dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
ultralytics/models/sam/build.py
@@ -0,0 +1,358 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch

from ultralytics.utils.downloads import attempt_download_asset

from .modules.decoders import MaskDecoder
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
from .modules.sam import SAM2Model, SAMModel
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer


def build_sam_vit_h(checkpoint=None):
    """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


def build_sam_vit_l(checkpoint=None):
    """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def build_mobile_sam(checkpoint=None):
    """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
    return _build_sam(
        encoder_embed_dim=[64, 128, 160, 320],
        encoder_depth=[2, 2, 6, 2],
        encoder_num_heads=[2, 4, 5, 10],
        encoder_global_attn_indexes=None,
        mobile_sam=True,
        checkpoint=checkpoint,
    )


def build_sam2_t(checkpoint=None):
    """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=96,
        encoder_stages=[1, 2, 7, 2],
        encoder_num_heads=1,
        encoder_global_att_blocks=[5, 7, 9],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_backbone_channel_list=[768, 384, 192, 96],
        checkpoint=checkpoint,
    )


def build_sam2_s(checkpoint=None):
    """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=96,
        encoder_stages=[1, 2, 11, 2],
        encoder_num_heads=1,
        encoder_global_att_blocks=[7, 10, 13],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_backbone_channel_list=[768, 384, 192, 96],
        checkpoint=checkpoint,
    )


def build_sam2_b(checkpoint=None):
    """Builds and returns a SAM2 base-size model with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=112,
        encoder_stages=[2, 3, 16, 3],
        encoder_num_heads=2,
        encoder_global_att_blocks=[12, 16, 20],
        encoder_window_spec=[8, 4, 14, 7],
        encoder_window_spatial_size=[14, 14],
        encoder_backbone_channel_list=[896, 448, 224, 112],
        checkpoint=checkpoint,
    )


def build_sam2_l(checkpoint=None):
    """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
    return _build_sam2(
        encoder_embed_dim=144,
        encoder_stages=[2, 6, 36, 4],
        encoder_num_heads=2,
        encoder_global_att_blocks=[23, 33, 43],
        encoder_window_spec=[8, 4, 16, 8],
        encoder_backbone_channel_list=[1152, 576, 288, 144],
        checkpoint=checkpoint,
    )


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    mobile_sam=False,
):
    """
    Builds a Segment Anything Model (SAM) with specified encoder parameters.

    Args:
        encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
        encoder_depth (int | List[int]): Depth of the encoder.
        encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
        encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
        checkpoint (str | None): Path to the model checkpoint file.
        mobile_sam (bool): Whether to build a Mobile-SAM model.

    Returns:
        (SAMModel): A Segment Anything Model instance with the specified architecture.

    Examples:
        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    image_encoder = (
        TinyViT(
            img_size=1024,
            in_chans=3,
            num_classes=1000,
            embed_dims=encoder_embed_dim,
            depths=encoder_depth,
            num_heads=encoder_num_heads,
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )
        if mobile_sam
        else ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        )
    )
    sam = SAMModel(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    if checkpoint is not None:
        checkpoint = attempt_download_asset(checkpoint)
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    sam.eval()
    return sam


def _build_sam2(
    encoder_embed_dim=1280,
    encoder_stages=[2, 6, 36, 4],
    encoder_num_heads=2,
    encoder_global_att_blocks=[7, 15, 23, 31],
    encoder_backbone_channel_list=[1152, 576, 288, 144],
    encoder_window_spatial_size=[7, 7],
    encoder_window_spec=[8, 4, 16, 8],
    checkpoint=None,
):
    """
    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.

    Args:
        encoder_embed_dim (int): Embedding dimension for the encoder.
        encoder_stages (List[int]): Number of blocks in each stage of the encoder.
        encoder_num_heads (int): Number of attention heads in the encoder.
        encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
        encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
        encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
        encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.

    Returns:
        (SAM2Model): A configured and initialized SAM2 model.

    Examples:
        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
        >>> sam2_model.eval()
    """
    image_encoder = ImageEncoder(
        trunk=Hiera(
            embed_dim=encoder_embed_dim,
            num_heads=encoder_num_heads,
            stages=encoder_stages,
            global_att_blocks=encoder_global_att_blocks,
            window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
            window_spec=encoder_window_spec,
        ),
        neck=FpnNeck(
            d_model=256,
            backbone_channel_list=encoder_backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        ),
        scalp=1,
    )
    memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
    memory_encoder = MemoryEncoder(out_dim=64)

    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
    sam2 = SAM2Model(
        image_encoder=image_encoder,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        num_maskmem=7,
        image_size=1024,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        no_obj_embed_spatial=is_sam2_1,
        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
        sam_mask_decoder_extra_args=dict(
            dynamic_multimask_via_stability=True,
            dynamic_multimask_stability_delta=0.05,
            dynamic_multimask_stability_thresh=0.98,
        ),
    )

    if checkpoint is not None:
        checkpoint = attempt_download_asset(checkpoint)
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)["model"]
        sam2.load_state_dict(state_dict)
    sam2.eval()
    return sam2


sam_model_map = {
    "sam_h.pt": build_sam_vit_h,
    "sam_l.pt": build_sam_vit_l,
    "sam_b.pt": build_sam_vit_b,
    "mobile_sam.pt": build_mobile_sam,
    "sam2_t.pt": build_sam2_t,
    "sam2_s.pt": build_sam2_s,
    "sam2_b.pt": build_sam2_b,
    "sam2_l.pt": build_sam2_l,
    "sam2.1_t.pt": build_sam2_t,
    "sam2.1_s.pt": build_sam2_s,
    "sam2.1_b.pt": build_sam2_b,
    "sam2.1_l.pt": build_sam2_l,
}


def build_sam(ckpt="sam_b.pt"):
    """
    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.

    Args:
        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.

    Returns:
        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.

    Raises:
        FileNotFoundError: If the provided checkpoint is not a supported SAM model.

    Examples:
        >>> sam_model = build_sam("sam_b.pt")
        >>> sam_model = build_sam("path/to/custom_checkpoint.pt")

    Notes:
        Supported pre-defined models include:
        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
    """
    model_builder = None
    ckpt = str(ckpt)  # to allow Path ckpt types
    for k in sam_model_map.keys():
        if ckpt.endswith(k):
            model_builder = sam_model_map.get(k)

    if not model_builder:
        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")

    return model_builder(ckpt)
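To make the builder API above concrete, here is a minimal usage sketch. It assumes the wheel is importable as `ultralytics`, that the named checkpoint `sam_b.pt` is available locally or downloadable via `attempt_download_asset`, and that `build_sam` is imported from the module shown above; nothing else beyond the standard library is required.

# Minimal usage sketch (assumptions noted above; checkpoint name is illustrative).
from ultralytics.models.sam.build import build_sam

sam = build_sam("sam_b.pt")  # resolves the builder via sam_model_map, loads weights, calls eval()
n_params = sum(p.numel() for p in sam.parameters())
print(type(sam).__name__, f"{n_params / 1e6:.1f}M parameters")

Passing a sam2*/sam2.1* checkpoint name routes through _build_sam2 and returns a SAM2Model instead of a SAMModel.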
ultralytics/models/sam/model.py
@@ -0,0 +1,170 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
SAM model interface.

This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image
segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
image distributions and tasks without prior knowledge.

Key Features:
    - Promptable segmentation
    - Real-time performance
    - Zero-shot transfer capabilities
    - Trained on SA-1B dataset
"""

from pathlib import Path

from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info

from .predict import Predictor, SAM2Predictor


class SAM(Model):
    """
    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.

    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
    promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
    boxes, points, or labels, and features zero-shot performance capabilities.

    Attributes:
        model (torch.nn.Module): The loaded SAM model.
        is_sam2 (bool): Indicates whether the model is a SAM2 variant.
        task (str): The task type, set to "segment" for SAM models.

    Methods:
        predict: Performs segmentation prediction on the given image or video source.
        info: Logs information about the SAM model.

    Examples:
        >>> sam = SAM("sam_b.pt")
        >>> results = sam.predict("image.jpg", points=[[500, 375]])
        >>> for r in results:
        ...     print(f"Detected {len(r.masks)} masks")
    """

    def __init__(self, model="sam_b.pt") -> None:
        """
        Initialize the SAM (Segment Anything Model) instance.

        Args:
            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

        Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> print(sam.is_sam2)
        """
        if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
        self.is_sam2 = "sam2" in Path(model).stem
        super().__init__(model=model, task="segment")

    def _load(self, weights: str, task=None):
        """
        Load the specified weights into the SAM model.

        Args:
            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> sam._load("path/to/custom_weights.pt")
        """
        from .build import build_sam  # slow import

        self.model = build_sam(weights)

    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Perform segmentation prediction on the given image or video source.

        Args:
            source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
                a numpy.ndarray object.
            stream (bool): If True, enables real-time streaming.
            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
            points (List[List[float]] | None): List of points for prompted segmentation.
            labels (List[int] | None): List of labels for prompted segmentation.
            **kwargs (Any): Additional keyword arguments for prediction.

        Returns:
            (list): The model predictions.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> results = sam.predict("image.jpg", points=[[500, 375]])
            >>> for r in results:
            ...     print(f"Detected {len(r.masks)} masks")
        """
        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
        kwargs = {**overrides, **kwargs}
        prompts = dict(bboxes=bboxes, points=points, labels=labels)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Perform segmentation prediction on the given image or video source.

        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
        for segmentation tasks.

        Args:
            source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
                object, or a numpy.ndarray object.
            stream (bool): If True, enables real-time streaming.
            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
            points (List[List[float]] | None): List of points for prompted segmentation.
            labels (List[int] | None): List of labels for prompted segmentation.
            **kwargs (Any): Additional keyword arguments to be passed to the predict method.

        Returns:
            (list): The model predictions, typically containing segmentation masks and other relevant information.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> results = sam("image.jpg", points=[[500, 375]])
            >>> print(f"Detected {len(results[0].masks)} masks")
        """
        return self.predict(source, stream, bboxes, points, labels, **kwargs)

    def info(self, detailed=False, verbose=True):
        """
        Log information about the SAM model.

        Args:
            detailed (bool): If True, displays detailed information about the model layers and operations.
            verbose (bool): If True, prints the information to the console.

        Returns:
            (tuple): A tuple containing the model's information (string representations of the model).

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> info = sam.info()
            >>> print(info[0])  # Print summary information
        """
        return model_info(self.model, detailed=detailed, verbose=verbose)

    @property
    def task_map(self):
        """
        Provide a mapping from the 'segment' task to its corresponding 'Predictor'.

        Returns:
            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
                Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.

        Examples:
            >>> sam = SAM("sam_b.pt")
            >>> task_map = sam.task_map
            >>> print(task_map)
            {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
        """
        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
ultralytics/models/sam/modules/__init__.py
@@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license