ultralytics-opencv-headless 8.3.246__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1026 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1578 -0
- ultralytics/engine/model.py +1124 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +974 -0
- ultralytics/engine/tuner.py +448 -0
- ultralytics/engine/validator.py +384 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +958 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +313 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1006 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +315 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +501 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1045 -0
- ultralytics/utils/tal.py +403 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +440 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +160 -0
- ultralytics_opencv_headless-8.3.246.dist-info/METADATA +374 -0
- ultralytics_opencv_headless-8.3.246.dist-info/RECORD +298 -0
- ultralytics_opencv_headless-8.3.246.dist-info/WHEEL +5 -0
- ultralytics_opencv_headless-8.3.246.dist-info/entry_points.txt +3 -0
- ultralytics_opencv_headless-8.3.246.dist-info/licenses/LICENSE +661 -0
- ultralytics_opencv_headless-8.3.246.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
|
+
|
|
3
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
4
|
+
# All rights reserved.
|
|
5
|
+
|
|
6
|
+
# This source code is licensed under the license found in the
|
|
7
|
+
# LICENSE file in the root directory of this source tree.
|
|
8
|
+
|
|
9
|
+
from functools import partial
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from ultralytics.utils.downloads import attempt_download_asset
|
|
14
|
+
from ultralytics.utils.patches import torch_load
|
|
15
|
+
|
|
16
|
+
from .modules.decoders import MaskDecoder
|
|
17
|
+
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
|
|
18
|
+
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
|
|
19
|
+
from .modules.sam import SAM2Model, SAMModel
|
|
20
|
+
from .modules.tiny_encoder import TinyViT
|
|
21
|
+
from .modules.transformer import TwoWayTransformer
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _load_checkpoint(model, checkpoint):
    """Load weights from a checkpoint path into `model` and return the model."""
    if checkpoint is None:
        return model

    ckpt_path = attempt_download_asset(checkpoint)
    with open(ckpt_path, "rb") as f:
        state_dict = torch_load(f)
    # Some checkpoints wrap the weights in a nested {"model": ...} dict; unwrap them
    inner = state_dict["model"] if "model" in state_dict else None
    if isinstance(inner, dict):
        state_dict = inner
    model.load_state_dict(state_dict)
    return model
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def build_sam_vit_h(checkpoint=None):
    """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
    # ViT-H encoder: 1280-dim embeddings, 32 blocks, 16 heads, global attention every 8th block.
    vit_h_cfg = {
        "encoder_embed_dim": 1280,
        "encoder_depth": 32,
        "encoder_num_heads": 16,
        "encoder_global_attn_indexes": [7, 15, 23, 31],
    }
    return _build_sam(checkpoint=checkpoint, **vit_h_cfg)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def build_sam_vit_l(checkpoint=None):
    """Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
    # ViT-L encoder: 1024-dim embeddings, 24 blocks, 16 heads, global attention every 6th block.
    vit_l_cfg = {
        "encoder_embed_dim": 1024,
        "encoder_depth": 24,
        "encoder_num_heads": 16,
        "encoder_global_attn_indexes": [5, 11, 17, 23],
    }
    return _build_sam(checkpoint=checkpoint, **vit_l_cfg)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def build_sam_vit_b(checkpoint=None):
    """Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters."""
    # ViT-B encoder: 768-dim embeddings, 12 blocks, 12 heads, global attention every 3rd block.
    vit_b_cfg = {
        "encoder_embed_dim": 768,
        "encoder_depth": 12,
        "encoder_num_heads": 12,
        "encoder_global_attn_indexes": [2, 5, 8, 11],
    }
    return _build_sam(checkpoint=checkpoint, **vit_b_cfg)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def build_mobile_sam(checkpoint=None):
    """Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
    # TinyViT backbone: per-stage embed dims / depths / heads instead of single ViT values.
    mobile_cfg = {
        "encoder_embed_dim": [64, 128, 160, 320],
        "encoder_depth": [2, 2, 6, 2],
        "encoder_num_heads": [2, 4, 5, 10],
        "encoder_global_attn_indexes": None,
        "mobile_sam": True,
    }
    return _build_sam(checkpoint=checkpoint, **mobile_cfg)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def build_sam2_t(checkpoint=None):
    """Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
    # Hiera-tiny trunk configuration.
    tiny_cfg = {
        "encoder_embed_dim": 96,
        "encoder_stages": [1, 2, 7, 2],
        "encoder_num_heads": 1,
        "encoder_global_att_blocks": [5, 7, 9],
        "encoder_window_spec": [8, 4, 14, 7],
        "encoder_backbone_channel_list": [768, 384, 192, 96],
    }
    return _build_sam2(checkpoint=checkpoint, **tiny_cfg)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def build_sam2_s(checkpoint=None):
    """Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
    # Hiera-small trunk: same widths as tiny, deeper third stage and later global-attention blocks.
    small_cfg = {
        "encoder_embed_dim": 96,
        "encoder_stages": [1, 2, 11, 2],
        "encoder_num_heads": 1,
        "encoder_global_att_blocks": [7, 10, 13],
        "encoder_window_spec": [8, 4, 14, 7],
        "encoder_backbone_channel_list": [768, 384, 192, 96],
    }
    return _build_sam2(checkpoint=checkpoint, **small_cfg)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def build_sam2_b(checkpoint=None):
    """Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters."""
    # Hiera-base+ trunk; note the non-default 14x14 window position-embedding size.
    base_cfg = {
        "encoder_embed_dim": 112,
        "encoder_stages": [2, 3, 16, 3],
        "encoder_num_heads": 2,
        "encoder_global_att_blocks": [12, 16, 20],
        "encoder_window_spec": [8, 4, 14, 7],
        "encoder_window_spatial_size": [14, 14],
        "encoder_backbone_channel_list": [896, 448, 224, 112],
    }
    return _build_sam2(checkpoint=checkpoint, **base_cfg)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def build_sam2_l(checkpoint=None):
    """Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
    # Hiera-large trunk configuration.
    large_cfg = {
        "encoder_embed_dim": 144,
        "encoder_stages": [2, 6, 36, 4],
        "encoder_num_heads": 2,
        "encoder_global_att_blocks": [23, 33, 43],
        "encoder_window_spec": [8, 4, 16, 8],
        "encoder_backbone_channel_list": [1152, 576, 288, 144],
    }
    return _build_sam2(checkpoint=checkpoint, **large_cfg)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    mobile_sam=False,
):
    """Build a Segment Anything Model (SAM) with specified encoder parameters.

    Args:
        encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
        encoder_depth (int | list[int]): Depth of the encoder.
        encoder_num_heads (int | list[int]): Number of attention heads in the encoder.
        encoder_global_attn_indexes (list[int] | None): Indexes for global attention in the encoder.
        checkpoint (str | None, optional): Path to the model checkpoint file.
        mobile_sam (bool, optional): Whether to build a Mobile-SAM model.

    Returns:
        (SAMModel): A Segment Anything Model instance with the specified architecture.

    Examples:
        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # Spatial size of the encoder's output feature map (1024 / 16 = 64).
    image_embedding_size = image_size // vit_patch_size
    # Mobile-SAM uses the lightweight TinyViT backbone (encoder_* args are per-stage
    # lists there); standard SAM uses a plain ViT (encoder_* args are scalars).
    image_encoder = (
        TinyViT(
            img_size=1024,
            in_chans=3,
            num_classes=1000,
            embed_dims=encoder_embed_dim,
            depths=encoder_depth,
            num_heads=encoder_num_heads,
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )
        if mobile_sam
        else ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        )
    )
    sam = SAMModel(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        # ImageNet per-channel mean/std used for input normalization.
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    if checkpoint is not None:
        sam = _load_checkpoint(sam, checkpoint)
    # Inference-only builder: always return the model in eval mode.
    sam.eval()
    return sam
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def _build_sam2(
    encoder_embed_dim=1280,
    encoder_stages=(2, 6, 36, 4),
    encoder_num_heads=2,
    encoder_global_att_blocks=(7, 15, 23, 31),
    encoder_backbone_channel_list=(1152, 576, 288, 144),
    encoder_window_spatial_size=(7, 7),
    encoder_window_spec=(8, 4, 16, 8),
    checkpoint=None,
):
    """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.

    Args:
        encoder_embed_dim (int, optional): Embedding dimension for the encoder.
        encoder_stages (list[int], optional): Number of blocks in each stage of the encoder.
        encoder_num_heads (int, optional): Number of attention heads in the encoder.
        encoder_global_att_blocks (list[int], optional): Indices of global attention blocks in the encoder.
        encoder_backbone_channel_list (list[int], optional): Channel dimensions for each level of the encoder backbone.
        encoder_window_spatial_size (list[int], optional): Spatial size of the window for position embeddings.
        encoder_window_spec (list[int], optional): Window specifications for each stage of the encoder.
        checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.

    Returns:
        (SAM2Model): A configured and initialized SAM2 model.

    Examples:
        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
        >>> sam2_model.eval()
    """
    # Hiera trunk + FPN neck producing 256-dim multi-scale features.
    image_encoder = ImageEncoder(
        trunk=Hiera(
            embed_dim=encoder_embed_dim,
            num_heads=encoder_num_heads,
            stages=encoder_stages,
            global_att_blocks=encoder_global_att_blocks,
            window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
            window_spec=encoder_window_spec,
        ),
        neck=FpnNeck(
            d_model=256,
            backbone_channel_list=encoder_backbone_channel_list,
            fpn_top_down_levels=[2, 3],
            fpn_interp_model="nearest",
        ),
        scalp=1,
    )
    memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
    memory_encoder = MemoryEncoder(out_dim=64)

    # SAM 2.1 checkpoints need extra modules/flags not present in SAM 2.0;
    # detected from the checkpoint filename (assumes `checkpoint` is a str here).
    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
    sam2 = SAM2Model(
        image_encoder=image_encoder,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        num_maskmem=7,
        image_size=1024,
        sigmoid_scale_for_mem_enc=20.0,
        sigmoid_bias_for_mem_enc=-10.0,
        use_mask_input_as_output_without_sam=True,
        directly_add_no_mem_embed=True,
        use_high_res_features_in_sam=True,
        multimask_output_in_sam=True,
        iou_prediction_use_sigmoid=True,
        use_obj_ptrs_in_encoder=True,
        add_tpos_enc_to_obj_ptrs=True,
        only_obj_ptrs_in_the_past_for_eval=True,
        pred_obj_scores=True,
        pred_obj_scores_mlp=True,
        fixed_no_obj_ptr=True,
        multimask_output_for_tracking=True,
        use_multimask_token_for_obj_ptr=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        use_mlp_for_obj_ptr_proj=True,
        compile_image_encoder=False,
        # SAM 2.1-only features, enabled when loading a sam2.1 checkpoint.
        no_obj_embed_spatial=is_sam2_1,
        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
        sam_mask_decoder_extra_args=dict(
            dynamic_multimask_via_stability=True,
            dynamic_multimask_stability_delta=0.05,
            dynamic_multimask_stability_thresh=0.98,
        ),
    )

    if checkpoint is not None:
        sam2 = _load_checkpoint(sam2, checkpoint)
    # Inference-only builder: always return the model in eval mode.
    sam2.eval()
    return sam2
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# Map of supported checkpoint filename suffixes to their builder functions.
# SAM2 and SAM2.1 checkpoints of the same size share a builder (2.1-specific
# behavior is detected from the filename inside _build_sam2).
_sam2_builders = {"t": build_sam2_t, "s": build_sam2_s, "b": build_sam2_b, "l": build_sam2_l}
sam_model_map = {
    "sam_h.pt": build_sam_vit_h,
    "sam_l.pt": build_sam_vit_l,
    "sam_b.pt": build_sam_vit_b,
    "mobile_sam.pt": build_mobile_sam,
    **{f"sam2_{size}.pt": builder for size, builder in _sam2_builders.items()},
    **{f"sam2.1_{size}.pt": builder for size, builder in _sam2_builders.items()},
}
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def build_sam(ckpt="sam_b.pt"):
    """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.

    Args:
        ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.

    Returns:
        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.

    Raises:
        FileNotFoundError: If the provided checkpoint is not a supported SAM model.

    Examples:
        >>> sam_model = build_sam("sam_b.pt")
        >>> sam_model = build_sam("path/to/custom_checkpoint.pt")

    Notes:
        Supported pre-defined models include:
        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
    """
    model_builder = None
    ckpt = str(ckpt)  # to allow Path ckpt types
    for k in sam_model_map:
        if ckpt.endswith(k):
            model_builder = sam_model_map[k]  # key membership just proven; direct indexing is safe
            break  # fix: stop at first match instead of scanning the remaining keys

    if not model_builder:
        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")

    return model_builder(ckpt)
|