dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/rtdetr/val.py
CHANGED
@@ -1,5 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
 import torch
 
 from ultralytics.data import YOLODataset
@@ -11,48 +16,61 @@ __all__ = ("RTDETRValidator",) # tuple or list
 
 
 class RTDETRDataset(YOLODataset):
-    """
-    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+    """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
 
     This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
     real-time detection and tracking tasks.
+
+    Attributes:
+        augment (bool): Whether to apply data augmentation.
+        rect (bool): Whether to use rectangular training.
+        use_segments (bool): Whether to use segmentation masks.
+        use_keypoints (bool): Whether to use keypoint annotations.
+        imgsz (int): Target image size for training.
+
+    Methods:
+        load_image: Load one image from dataset index.
+        build_transforms: Build transformation pipeline for the dataset.
+
+    Examples:
+        Initialize an RT-DETR dataset
+        >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
+        >>> image, hw = dataset.load_image(0)
     """
 
     def __init__(self, *args, data=None, **kwargs):
-        """
-        Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
+        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
 
         This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
         model, building upon the base YOLODataset functionality.
 
         Args:
             *args (Any): Variable length argument list passed to the parent YOLODataset class.
-            data (
+            data (dict | None): Dictionary containing dataset information. If None, default values will be used.
             **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
         """
         super().__init__(*args, data=data, **kwargs)
 
     def load_image(self, i, rect_mode=False):
-        """
-        Load one image from dataset index 'i'.
+        """Load one image from dataset index 'i'.
 
         Args:
             i (int): Index of the image to load.
             rect_mode (bool, optional): Whether to use rectangular mode for batch inference.
 
         Returns:
-            im (
+            im (torch.Tensor): The loaded image.
            resized_hw (tuple): Height and width of the resized image with shape (2,).
 
         Examples:
-
+            Load an image from the dataset
+            >>> dataset = RTDETRDataset(img_path="path/to/images")
             >>> image, hw = dataset.load_image(0)
         """
         return super().load_image(i=i, rect_mode=rect_mode)
 
     def build_transforms(self, hyp=None):
-        """
-        Build transformation pipeline for the dataset.
+        """Build transformation pipeline for the dataset.
 
         Args:
             hyp (dict, optional): Hyperparameters for transformations.
@@ -67,7 +85,7 @@ class RTDETRDataset(YOLODataset):
             transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
         else:
             # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
-            transforms = Compose([])
+            transforms = Compose([lambda x: {**x, **{"ratio_pad": [x["ratio_pad"], [0, 0]]}}])
         transforms.append(
             Format(
                 bbox_format="xywh",
@@ -83,30 +101,38 @@ class RTDETRDataset(YOLODataset):
 
 
 class RTDETRValidator(DetectionValidator):
-    """
-    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+    """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
     the RT-DETR (Real-Time DETR) object detection model.
 
     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
     post-processing, and updates evaluation metrics accordingly.
 
+    Attributes:
+        args (Namespace): Configuration arguments for validation.
+        data (dict): Dataset configuration dictionary.
+
+    Methods:
+        build_dataset: Build an RTDETR Dataset for validation.
+        postprocess: Apply Non-maximum suppression to prediction outputs.
+
     Examples:
+        Initialize and run RT-DETR validation
         >>> from ultralytics.models.rtdetr import RTDETRValidator
         >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
         >>> validator = RTDETRValidator(args=args)
         >>> validator()
 
-
+    Notes:
         For further details on the attributes and methods, refer to the parent DetectionValidator class.
     """
 
     def build_dataset(self, img_path, mode="val", batch=None):
-        """
-        Build an RTDETR Dataset.
+        """Build an RTDETR Dataset.
 
         Args:
             img_path (str): Path to the folder containing images.
-            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for
+            mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
+                each mode.
             batch (int, optional): Size of batches, this is for `rect`.
 
         Returns:
@@ -124,15 +150,21 @@ class RTDETRValidator(DetectionValidator):
             data=self.data,
         )
 
-    def postprocess(
-
-
+    def postprocess(
+        self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
+    ) -> list[dict[str, torch.Tensor]]:
+        """Apply Non-maximum suppression to prediction outputs.
 
         Args:
-            preds (
+            preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
+                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+                class scores.
 
         Returns:
-            (
+            (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
+                - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
+                - 'conf': Tensor of shape (N,) with confidence scores
+                - 'cls': Tensor of shape (N,) with class indices
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -149,43 +181,31 @@ class RTDETRValidator(DetectionValidator):
             pred = pred[score.argsort(descending=True)]
             outputs[i] = pred[score > self.args.conf]
 
-        return outputs
-
-    def _prepare_batch(self, si, batch):
-        """
-        Prepares a batch for validation by applying necessary transformations.
-
-        Args:
-            si (int): Batch index.
-            batch (dict): Batch data containing images and annotations.
+        return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
 
-
-
-        """
-        idx = batch["batch_idx"] == si
-        cls = batch["cls"][idx].squeeze(-1)
-        bbox = batch["bboxes"][idx]
-        ori_shape = batch["ori_shape"][si]
-        imgsz = batch["img"].shape[2:]
-        ratio_pad = batch["ratio_pad"][si]
-        if len(cls):
-            bbox = ops.xywh2xyxy(bbox)  # target boxes
-            bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
-            bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
-        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
-
-    def _prepare_pred(self, pred, pbatch):
-        """
-        Prepares predictions by scaling bounding boxes to original image dimensions.
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Serialize YOLO predictions to COCO json format.
 
         Args:
-
-
-
-        Returns:
-            (torch.Tensor): Predictions scaled to original image dimensions.
+            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
-
-
-
-
+        path = Path(pbatch["im_file"])
+        stem = path.stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = predn["bboxes"].clone()
+        box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
+        box = ops.xyxy2xywh(box)  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "file_name": path.name,
+                    "category_id": self.class_map[int(c)],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(s, 5),
+                }
+            )
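The postprocess rewrite above also changes what callers receive: instead of one (N, 6) tensor per image, each image now yields a dict keyed by 'bboxes', 'conf', and 'cls' (see the final return line in the hunk). A minimal sketch of adapting caller code, using a made-up tensor rather than real RT-DETR output:

import torch

# Hypothetical post-NMS rows for one image: columns are x1, y1, x2, y2, conf, cls
old_style = torch.rand(8, 6)

# 8.3.137 callers indexed tensor columns; 8.3.224 exposes the same data by key
new_style = {"bboxes": old_style[:, :4], "conf": old_style[:, 4], "cls": old_style[:, 5]}
boxes, scores, classes = new_style["bboxes"], new_style["conf"], new_style["cls"]
print(boxes.shape, scores.shape, classes.shape)  # torch.Size([8, 4]) torch.Size([8]) torch.Size([8])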
ultralytics/models/sam/__init__.py
CHANGED
@@ -1,6 +1,12 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from .model import SAM
-from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
+from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
 
-__all__ =
+__all__ = (
+    "SAM",
+    "Predictor",
+    "SAM2DynamicInteractivePredictor",
+    "SAM2Predictor",
+    "SAM2VideoPredictor",
+)  # tuple or list of exportable items
ultralytics/models/sam/amg.py
CHANGED
@@ -1,17 +1,36 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 import math
+from collections.abc import Generator
 from itertools import product
-from typing import Any
+from typing import Any
 
 import numpy as np
 import torch
 
 
 def is_box_near_crop_edge(
-    boxes: torch.Tensor, crop_box:
+    boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
 ) -> torch.Tensor:
-    """
+    """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+
+    Args:
+        boxes (torch.Tensor): Bounding boxes in XYXY format.
+        crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
+        orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
+        atol (float, optional): Absolute tolerance for edge proximity detection.
+
+    Returns:
+        (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
+
+    Examples:
+        >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
+        >>> crop_box = [0, 0, 200, 200]
+        >>> orig_box = [0, 0, 300, 300]
+        >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
+    """
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -21,9 +40,8 @@ def is_box_near_crop_edge(
     return torch.any(near_crop_edge, dim=1)
 
 
-def batch_iterator(batch_size: int, *args) -> Generator[
-    """
-    Yield batches of data from input arguments with specified batch size for efficient processing.
+def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
+    """Yield batches of data from input arguments with specified batch size for efficient processing.
 
     This function takes a batch size and any number of iterables, then yields batches of elements from those
     iterables. All input iterables must have the same length.
@@ -33,7 +51,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
         *args (Any): Variable length input iterables to batch. All iterables must have the same length.
 
     Yields:
-        (
+        (list[Any]): A list of batched elements from each input iterable.
 
     Examples:
         >>> data = [1, 2, 3, 4, 5]
@@ -51,11 +69,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
 
 
 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
-    """
-    Computes the stability score for a batch of masks.
+    """Compute the stability score for a batch of masks.
 
-    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
-
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
+    low values.
 
     Args:
         masks (torch.Tensor): Batch of predicted mask logits.
@@ -65,15 +82,15 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
     Returns:
         (torch.Tensor): Stability scores for each mask in the batch.
 
-    Notes:
-        - One mask is always contained inside the other.
-        - Memory is saved by preventing unnecessary cast to torch.int64.
-
     Examples:
         >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
        >>> mask_threshold = 0.5
        >>> threshold_offset = 0.1
        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+
+    Notes:
+        - One mask is always contained inside the other.
+        - Memory is saved by preventing unnecessary cast to torch.int64.
     """
     intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -89,25 +106,24 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
     return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
 
 
-def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) ->
-    """
+def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
+    """Generate point grids for multiple crop layers with varying scales and densities."""
     return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
 
 
 def generate_crop_boxes(
-    im_size:
-) ->
-    """
-    Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+    im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> tuple[list[list[int]], list[int]]:
+    """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
 
     Args:
-        im_size (
+        im_size (tuple[int, ...]): Height and width of the input image.
         n_layers (int): Number of layers to generate crop boxes for.
         overlap_ratio (float): Ratio of overlap between adjacent crop boxes.
 
     Returns:
-        (
-        (
+        crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+        layer_idxs (list[int]): List of layer indices corresponding to each crop box.
 
     Examples:
         >>> im_size = (800, 1200)  # Height, width
@@ -124,8 +140,8 @@ def generate_crop_boxes(
     layer_idxs.append(0)
 
     def crop_len(orig_len, n_crops, overlap):
-        """
-        return
+        """Calculate the length of each crop given the original length, number of crops, and overlap."""
+        return math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)
 
     for i_layer in range(n_layers):
         n_crops_per_side = 2 ** (i_layer + 1)
@@ -146,7 +162,7 @@ def generate_crop_boxes(
     return crop_boxes, layer_idxs
 
 
-def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box:
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
     """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
     x0, y0, _, _ = crop_box
     offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
@@ -156,7 +172,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
     return boxes + offset
 
 
-def uncrop_points(points: torch.Tensor, crop_box:
+def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
     """Uncrop points by adding the crop box offset to their coordinates."""
     x0, y0, _, _ = crop_box
     offset = torch.tensor([[x0, y0]], device=points.device)
@@ -166,7 +182,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
     return points + offset
 
 
-def uncrop_masks(masks: torch.Tensor, crop_box:
+def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
     """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
     x0, y0, x1, y1 = crop_box
     if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
@@ -177,18 +193,18 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
     return torch.nn.functional.pad(masks, pad, value=0)
 
 
-def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) ->
-    """
-    Removes small disconnected regions or holes in a mask based on area threshold and mode.
+def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
+    """Remove small disconnected regions or holes in a mask based on area threshold and mode.
 
     Args:
         mask (np.ndarray): Binary mask to process.
         area_thresh (float): Area threshold below which regions will be removed.
-        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+            regions.
 
     Returns:
-        (np.ndarray): Processed binary mask with small regions removed.
-        (bool): Whether any regions were modified.
+        processed_mask (np.ndarray): Processed binary mask with small regions removed.
+        modified (bool): Whether any regions were modified.
 
     Examples:
         >>> mask = np.zeros((100, 100), dtype=np.bool_)
@@ -206,7 +222,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
     small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
     if not small_regions:
         return mask, False
-    fill_labels = [0
+    fill_labels = [0, *small_regions]
     if not correct_holes:
         # If every region is below threshold, keep largest
         fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
@@ -215,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
 
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
-    """
-    Calculates bounding boxes in XYXY format around binary masks.
+    """Calculate bounding boxes in XYXY format around binary masks.
 
     Args:
         masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
ultralytics/models/sam/build.py
CHANGED
@@ -11,6 +11,7 @@ from functools import partial
 import torch
 
 from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.torch_utils import TORCH_1_13
 
 from .modules.decoders import MaskDecoder
 from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
@@ -21,7 +22,7 @@ from .modules.transformer import TwoWayTransformer
 
 
 def build_sam_vit_h(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1280,
         encoder_depth=32,
@@ -32,7 +33,7 @@ def build_sam_vit_h(checkpoint=None):
 
 
 def build_sam_vit_l(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1024,
         encoder_depth=24,
@@ -43,7 +44,7 @@ def build_sam_vit_l(checkpoint=None):
 
 
 def build_sam_vit_b(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=768,
         encoder_depth=12,
@@ -54,7 +55,7 @@ def build_sam_vit_b(checkpoint=None):
 
 
 def build_mobile_sam(checkpoint=None):
-    """
+    """Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
     return _build_sam(
         encoder_embed_dim=[64, 128, 160, 320],
         encoder_depth=[2, 2, 6, 2],
@@ -66,7 +67,7 @@ def build_mobile_sam(checkpoint=None):
 
 
 def build_sam2_t(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=96,
         encoder_stages=[1, 2, 7, 2],
@@ -79,7 +80,7 @@ def build_sam2_t(checkpoint=None):
 
 
 def build_sam2_s(checkpoint=None):
-    """
+    """Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=96,
         encoder_stages=[1, 2, 11, 2],
@@ -92,7 +93,7 @@ def build_sam2_s(checkpoint=None):
 
 
 def build_sam2_b(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=112,
         encoder_stages=[2, 3, 16, 3],
@@ -106,7 +107,7 @@ def build_sam2_b(checkpoint=None):
 
 
 def build_sam2_l(checkpoint=None):
-    """
+    """Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=144,
         encoder_stages=[2, 6, 36, 4],
@@ -126,16 +127,15 @@ def _build_sam(
     checkpoint=None,
     mobile_sam=False,
 ):
-    """
-    Builds a Segment Anything Model (SAM) with specified encoder parameters.
+    """Build a Segment Anything Model (SAM) with specified encoder parameters.
 
     Args:
-        encoder_embed_dim (int |
-        encoder_depth (int |
-        encoder_num_heads (int |
-        encoder_global_attn_indexes (
-        checkpoint (str | None): Path to the model checkpoint file.
-        mobile_sam (bool): Whether to build a Mobile-SAM model.
+        encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
+        encoder_depth (int | list[int]): Depth of the encoder.
+        encoder_num_heads (int | list[int]): Number of attention heads in the encoder.
+        encoder_global_attn_indexes (list[int] | None): Indexes for global attention in the encoder.
+        checkpoint (str | None, optional): Path to the model checkpoint file.
+        mobile_sam (bool, optional): Whether to build a Mobile-SAM model.
 
     Returns:
         (SAMModel): A Segment Anything Model instance with the specified architecture.
@@ -207,7 +207,7 @@ def _build_sam(
     if checkpoint is not None:
         checkpoint = attempt_download_asset(checkpoint)
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, weights_only=False) if TORCH_1_13 else torch.load(f)
         sam.load_state_dict(state_dict)
     sam.eval()
     return sam
@@ -223,18 +223,17 @@ def _build_sam2(
     encoder_window_spec=[8, 4, 16, 8],
     checkpoint=None,
 ):
-    """
-    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+    """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
 
     Args:
-        encoder_embed_dim (int): Embedding dimension for the encoder.
-        encoder_stages (
-        encoder_num_heads (int): Number of attention heads in the encoder.
-        encoder_global_att_blocks (
-        encoder_backbone_channel_list (
-        encoder_window_spatial_size (
-        encoder_window_spec (
-        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+        encoder_embed_dim (int, optional): Embedding dimension for the encoder.
+        encoder_stages (list[int], optional): Number of blocks in each stage of the encoder.
+        encoder_num_heads (int, optional): Number of attention heads in the encoder.
+        encoder_global_att_blocks (list[int], optional): Indices of global attention blocks in the encoder.
+        encoder_backbone_channel_list (list[int], optional): Channel dimensions for each level of the encoder backbone.
+        encoder_window_spatial_size (list[int], optional): Spatial size of the window for position embeddings.
+        encoder_window_spec (list[int], optional): Window specifications for each stage of the encoder.
+        checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.
 
     Returns:
         (SAM2Model): A configured and initialized SAM2 model.
@@ -302,7 +301,7 @@ def _build_sam2(
     if checkpoint is not None:
         checkpoint = attempt_download_asset(checkpoint)
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)["model"]
+            state_dict = (torch.load(f, weights_only=False) if TORCH_1_13 else torch.load(f))["model"]
         sam2.load_state_dict(state_dict)
     sam2.eval()
     return sam2
@@ -325,11 +324,10 @@ sam_model_map = {
 
 
 def build_sam(ckpt="sam_b.pt"):
-    """
-    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+    """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
 
     Args:
-        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+        ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.
 
     Returns:
         (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.