dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in those registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/build_sam3.py
@@ -0,0 +1,377 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+import torch.nn as nn
+
+from ultralytics.nn.modules.transformer import MLP
+from ultralytics.utils.patches import torch_load
+
+from .modules.blocks import PositionEmbeddingSine, RoPEAttention
+from .modules.encoders import MemoryEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM3Model
+from .sam3.decoder import TransformerDecoder, TransformerDecoderLayer
+from .sam3.encoder import TransformerEncoderFusion, TransformerEncoderLayer
+from .sam3.geometry_encoders import SequenceGeometryEncoder
+from .sam3.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
+from .sam3.model_misc import DotProductScoring, TransformerWrapper
+from .sam3.necks import Sam3DualViTDetNeck
+from .sam3.sam3_image import SAM3SemanticModel
+from .sam3.text_encoder_ve import VETextEncoder
+from .sam3.vitdet import ViT
+from .sam3.vl_combiner import SAM3VLBackbone
+
+
+def _create_vision_backbone(compile_mode=None, enable_inst_interactivity=True) -> Sam3DualViTDetNeck:
+    """Create SAM3 visual backbone with ViT and neck."""
+    # Position encoding
+    position_encoding = PositionEmbeddingSine(
+        num_pos_feats=256,
+        normalize=True,
+        scale=None,
+        temperature=10000,
+    )
+
+    # ViT backbone
+    vit_backbone = ViT(
+        img_size=1008,
+        pretrain_img_size=336,
+        patch_size=14,
+        embed_dim=1024,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4.625,
+        norm_layer="LayerNorm",
+        drop_path_rate=0.1,
+        qkv_bias=True,
+        use_abs_pos=True,
+        tile_abs_pos=True,
+        global_att_blocks=(7, 15, 23, 31),
+        rel_pos_blocks=(),
+        use_rope=True,
+        use_interp_rope=True,
+        window_size=24,
+        pretrain_use_cls_token=True,
+        retain_cls_token=False,
+        ln_pre=True,
+        ln_post=False,
+        return_interm_layers=False,
+        bias_patch_embed=False,
+        compile_mode=compile_mode,
+    )
+    return Sam3DualViTDetNeck(
+        position_encoding=position_encoding,
+        d_model=256,
+        scale_factors=[4.0, 2.0, 1.0, 0.5],
+        trunk=vit_backbone,
+        add_sam2_neck=enable_inst_interactivity,
+    )
+
+
+def _create_sam3_transformer() -> TransformerWrapper:
+    """Create SAM3 detector encoder and decoder."""
+    encoder: TransformerEncoderFusion = TransformerEncoderFusion(
+        layer=TransformerEncoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=True,
+            pos_enc_at_cross_attn_keys=False,
+            pos_enc_at_cross_attn_queries=False,
+            pre_norm=True,
+            self_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+                batch_first=True,
+            ),
+            cross_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+                batch_first=True,
+            ),
+        ),
+        num_layers=6,
+        d_model=256,
+        num_feature_levels=1,
+        frozen=False,
+        use_act_checkpoint=True,
+        add_pooled_text_to_img_feat=False,
+        pool_text_with_mask=True,
+    )
+    decoder: TransformerDecoder = TransformerDecoder(
+        layer=TransformerDecoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            cross_attention=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0.1,
+                embed_dim=256,
+            ),
+            n_heads=8,
+            use_text_cross_attention=True,
+        ),
+        num_layers=6,
+        num_queries=200,
+        return_intermediate=True,
+        box_refine=True,
+        num_o2m_queries=0,
+        dac=True,
+        boxRPB="log",
+        d_model=256,
+        frozen=False,
+        interaction_layer=None,
+        dac_use_selfatt_ln=True,
+        use_act_checkpoint=True,
+        presence_token=True,
+    )
+
+    return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
+
+
+def build_sam3_image_model(checkpoint_path: str, enable_segmentation: bool = True, compile: bool = False):
+    """Build SAM3 image model.
+
+    Args:
+        checkpoint_path: Optional path to model checkpoint
+        enable_segmentation: Whether to enable segmentation head
+        compile: To enable compilation, set to "default"
+
+    Returns:
+        A SAM3 image model
+    """
+    try:
+        import clip
+    except ImportError:
+        from ultralytics.utils.checks import check_requirements
+
+        check_requirements("git+https://github.com/ultralytics/CLIP.git")
+        import clip
+    # Create visual components
+    compile_mode = "default" if compile else None
+    vision_encoder = _create_vision_backbone(compile_mode=compile_mode, enable_inst_interactivity=True)
+
+    # Create text components
+    text_encoder = VETextEncoder(
+        tokenizer=clip.simple_tokenizer.SimpleTokenizer(),
+        d_model=256,
+        width=1024,
+        heads=16,
+        layers=24,
+    )
+
+    # Create visual-language backbone
+    backbone = SAM3VLBackbone(visual=vision_encoder, text=text_encoder, scalp=1)
+
+    # Create transformer components
+    transformer = _create_sam3_transformer()
+
+    # Create dot product scoring
+    dot_prod_scoring = DotProductScoring(
+        d_model=256,
+        d_proj=256,
+        prompt_mlp=MLP(
+            input_dim=256,
+            hidden_dim=2048,
+            output_dim=256,
+            num_layers=2,
+            residual=True,
+            out_norm=nn.LayerNorm(256),
+        ),
+    )
+
+    # Create segmentation head if enabled
+    segmentation_head = (
+        UniversalSegmentationHead(
+            hidden_dim=256,
+            upsampling_stages=3,
+            aux_masks=False,
+            presence_head=False,
+            dot_product_scorer=None,
+            act_ckpt=True,
+            cross_attend_prompt=nn.MultiheadAttention(
+                num_heads=8,
+                dropout=0,
+                embed_dim=256,
+            ),
+            pixel_decoder=PixelDecoder(
+                num_upsampling_stages=3,
+                interpolation_mode="nearest",
+                hidden_dim=256,
+                compile_mode=compile_mode,
+            ),
+        )
+        if enable_segmentation
+        else None
+    )
+
+    # Create geometry encoder
+    input_geometry_encoder = SequenceGeometryEncoder(
+        pos_enc=PositionEmbeddingSine(
+            num_pos_feats=256,
+            normalize=True,
+            scale=None,
+            temperature=10000,
+        ),
+        encode_boxes_as_points=False,
+        boxes_direct_project=True,
+        boxes_pool=True,
+        boxes_pos_enc=True,
+        d_model=256,
+        num_layers=3,
+        layer=TransformerEncoderLayer(
+            d_model=256,
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=False,
+            pre_norm=True,
+            pos_enc_at_cross_attn_queries=False,
+            pos_enc_at_cross_attn_keys=True,
+        ),
+        use_act_ckpt=True,
+        add_cls=True,
+        add_post_encode_proj=True,
+    )
+
+    # Create the SAM3SemanticModel model
+    model = SAM3SemanticModel(
+        backbone=backbone,
+        transformer=transformer,
+        input_geometry_encoder=input_geometry_encoder,
+        segmentation_head=segmentation_head,
+        num_feature_levels=1,
+        o2m_mask_predict=True,
+        dot_prod_scoring=dot_prod_scoring,
+        use_instance_query=False,
+        multimask_output=True,
+    )
+
+    # Load checkpoint
+    model = _load_checkpoint(model, checkpoint_path)
+    model.eval()
+    return model
+
+
+def build_interactive_sam3(checkpoint_path: str, compile=None, with_backbone=True) -> SAM3Model:
+    """Build the SAM3 Tracker module for video tracking.
+
+    Returns:
+        Sam3TrackerPredictor: Wrapped SAM3 Tracker module
+    """
+    # Create model components
+    memory_encoder = MemoryEncoder(out_dim=64, interpol_size=[1152, 1152])
+    memory_attention = MemoryAttention(
+        batch_first=True,
+        d_model=256,
+        pos_enc_at_input=True,
+        layer=MemoryAttentionLayer(
+            dim_feedforward=2048,
+            dropout=0.1,
+            pos_enc_at_attn=False,
+            pos_enc_at_cross_attn_keys=True,
+            pos_enc_at_cross_attn_queries=False,
+            self_attn=RoPEAttention(
+                embedding_dim=256,
+                num_heads=1,
+                downsample_rate=1,
+                rope_theta=10000.0,
+                feat_sizes=[72, 72],
+            ),
+            d_model=256,
+            cross_attn=RoPEAttention(
+                embedding_dim=256,
+                num_heads=1,
+                downsample_rate=1,
+                kv_in_dim=64,
+                rope_theta=10000.0,
+                feat_sizes=[72, 72],
+                rope_k_repeat=True,
+            ),
+        ),
+        num_layers=4,
+    )
+
+    backbone = (
+        SAM3VLBackbone(scalp=1, visual=_create_vision_backbone(compile_mode=compile), text=None)
+        if with_backbone
+        else None
+    )
+    model = SAM3Model(
+        image_size=1008,
+        image_encoder=backbone,
+        memory_attention=memory_attention,
+        memory_encoder=memory_encoder,
+        backbone_stride=14,
+        num_maskmem=7,
+        sigmoid_scale_for_mem_enc=20.0,
+        sigmoid_bias_for_mem_enc=-10.0,
+        use_mask_input_as_output_without_sam=True,
+        directly_add_no_mem_embed=True,
+        use_high_res_features_in_sam=True,
+        multimask_output_in_sam=True,
+        iou_prediction_use_sigmoid=True,
+        use_obj_ptrs_in_encoder=True,
+        add_tpos_enc_to_obj_ptrs=True,
+        only_obj_ptrs_in_the_past_for_eval=True,
+        pred_obj_scores=True,
+        pred_obj_scores_mlp=True,
+        fixed_no_obj_ptr=True,
+        multimask_output_for_tracking=True,
+        use_multimask_token_for_obj_ptr=True,
+        multimask_min_pt_num=0,
+        multimask_max_pt_num=1,
+        use_mlp_for_obj_ptr_proj=True,
+        compile_image_encoder=False,
+        no_obj_embed_spatial=True,
+        proj_tpos_enc_in_obj_ptrs=True,
+        use_signed_tpos_enc_to_obj_ptrs=True,
+        sam_mask_decoder_extra_args=dict(
+            dynamic_multimask_via_stability=True,
+            dynamic_multimask_stability_delta=0.05,
+            dynamic_multimask_stability_thresh=0.98,
+        ),
+    )
+
+    # Load checkpoint if provided
+    model = _load_checkpoint(model, checkpoint_path, interactive=True)
+
+    # Setup device and mode
+    model.eval()
+    return model
+
+
+def _load_checkpoint(model, checkpoint, interactive=False):
+    """Load SAM3 model checkpoint from file."""
+    with open(checkpoint, "rb") as f:
+        ckpt = torch_load(f)
+    if "model" in ckpt and isinstance(ckpt["model"], dict):
+        ckpt = ckpt["model"]
+    sam3_image_ckpt = {k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k}
+    if interactive:
+        sam3_image_ckpt.update(
+            {
+                k.replace("backbone.vision_backbone", "image_encoder.vision_backbone"): v
+                for k, v in sam3_image_ckpt.items()
+                if "backbone.vision_backbone" in k
+            }
+        )
+        sam3_image_ckpt.update(
+            {
+                k.replace("tracker.transformer.encoder", "memory_attention"): v
+                for k, v in ckpt.items()
+                if "tracker.transformer" in k
+            }
+        )
+        sam3_image_ckpt.update(
+            {
+                k.replace("tracker.maskmem_backbone", "memory_encoder"): v
+                for k, v in ckpt.items()
+                if "tracker.maskmem_backbone" in k
+            }
+        )
+        sam3_image_ckpt.update({k.replace("tracker.", ""): v for k, v in ckpt.items() if "tracker." in k})
+    model.load_state_dict(sam3_image_ckpt, strict=False)
+    return model
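
For orientation, a minimal usage sketch of the two builders added above; the "sam3.pt" checkpoint path is a placeholder rather than a file shipped with this wheel, and the keyword values simply restate the defaults from the signatures in the diff:

    # Hedged sketch: assumes a locally available SAM3 checkpoint at the placeholder path "sam3.pt".
    from ultralytics.models.sam.build_sam3 import build_interactive_sam3, build_sam3_image_model

    # Image/semantic variant, with the segmentation head enabled (defaults written out explicitly).
    image_model = build_sam3_image_model("sam3.pt", enable_segmentation=True, compile=False)

    # Interactive/video-tracking variant, with memory encoder and memory attention.
    tracker = build_interactive_sam3("sam3.pt")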
ultralytics/models/sam/model.py
@@ -21,16 +21,15 @@ from pathlib import Path
 from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
 
-from .predict import Predictor, SAM2Predictor
+from .predict import Predictor, SAM2Predictor, SAM3Predictor
 
 
 class SAM(Model):
-    """
-    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+    """SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
 
-    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
-    promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
-    boxes, points, or labels, and features zero-shot performance capabilities.
+    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable
+    segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or
+    labels, and features zero-shot performance capabilities.
 
     Attributes:
         model (torch.nn.Module): The loaded SAM model.
@@ -45,31 +44,26 @@ class SAM(Model):
         >>> sam = SAM("sam_b.pt")
         >>> results = sam.predict("image.jpg", points=[[500, 375]])
         >>> for r in results:
-        >>>     print(f"Detected {len(r.masks)} masks")
+        ...     print(f"Detected {len(r.masks)} masks")
         """
 
     def __init__(self, model: str = "sam_b.pt") -> None:
-        """
-        Initialize the SAM (Segment Anything Model) instance.
+        """Initialize the SAM (Segment Anything Model) instance.
 
         Args:
            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
 
         Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.
-
-        Examples:
-            >>> sam = SAM("sam_b.pt")
-            >>> print(sam.is_sam2)
         """
        if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
        self.is_sam2 = "sam2" in Path(model).stem
+        self.is_sam3 = "sam3" in Path(model).stem
        super().__init__(model=model, task="segment")
 
     def _load(self, weights: str, task=None):
-        """
-        Load the specified weights into the SAM model.
+        """Load the specified weights into the SAM model.
 
         Args:
            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
@@ -79,17 +73,21 @@ class SAM(Model):
         >>> sam = SAM("sam_b.pt")
         >>> sam._load("path/to/custom_weights.pt")
         """
-        from .build import build_sam  # slow import
+        if self.is_sam3:
+            from .build_sam3 import build_interactive_sam3
 
-        self.model = build_sam(weights)
+            self.model = build_interactive_sam3(weights)
+        else:
+            from .build import build_sam  # slow import
+
+            self.model = build_sam(weights)
 
     def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
-        """
-        Perform segmentation prediction on the given image or video source.
+        """Perform segmentation prediction on the given image or video source.
 
         Args:
-            source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or
-                a np.ndarray object.
+            source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or a
+                np.ndarray object.
            stream (bool): If True, enables real-time streaming.
            bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
            points (list[list[float]] | None): List of points for prompted segmentation.
@@ -111,15 +109,14 @@ class SAM(Model):
        return super().predict(source, stream, prompts=prompts, **kwargs)
 
     def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
-        """
-        Perform segmentation prediction on the given image or video source.
+        """Perform segmentation prediction on the given image or video source.
 
-        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
-        for segmentation tasks.
+        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for
+        segmentation tasks.
 
         Args:
-            source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image
-                object, or a np.ndarray object.
+            source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image object, or a
+                np.ndarray object.
            stream (bool): If True, enables real-time streaming.
            bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
            points (list[list[float]] | None): List of points for prompted segmentation.
@@ -137,8 +134,7 @@ class SAM(Model):
        return self.predict(source, stream, bboxes, points, labels, **kwargs)
 
     def info(self, detailed: bool = False, verbose: bool = True):
-        """
-        Log information about the SAM model.
+        """Log information about the SAM model.
 
         Args:
            detailed (bool): If True, displays detailed information about the model layers and operations.
@@ -156,8 +152,7 @@ class SAM(Model):
 
     @property
     def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
-        """
-        Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
+        """Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
 
         Returns:
            (dict[str, dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
@@ -169,4 +164,6 @@ class SAM(Model):
            >>> print(task_map)
            {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
         """
-        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
+        return {
+            "segment": {"predictor": SAM2Predictor if self.is_sam2 else SAM3Predictor if self.is_sam3 else Predictor}
+        }
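
Combined with the build_sam3.py diff above, this routing change means a weights file whose filename stem contains "sam3" is now loaded through build_interactive_sam3 and dispatched to SAM3Predictor. A minimal sketch, assuming a placeholder "sam3.pt" checkpoint is available locally:

    # Hedged sketch of the new dispatch path; "sam3.pt" is a placeholder checkpoint name.
    from ultralytics import SAM

    sam = SAM("sam3.pt")  # is_sam3 is set from the filename stem, selecting SAM3Predictor via task_map
    results = sam.predict("image.jpg", points=[[500, 375]])  # same prompt API as shown in the class docstring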