ultralytics-8.1.28-py3-none-any.whl → ultralytics-8.3.62-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +527 -67
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +44 -37
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +84 -56
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.28.dist-info/METADATA +0 -373
- ultralytics-8.1.28.dist-info/RECORD +0 -197
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
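
The per-file counts above are derived from the file manifests of the two wheels. As a rough local reproduction (this is a sketch, not the registry's own tooling; the wheel filenames below are assumptions and must exist locally first), the added and removed files can be listed by comparing the archives' member lists:

# Hedged sketch: compare the member lists of two locally downloaded wheels,
# e.g. obtained with `pip download ultralytics==8.1.28 --no-deps` and
# `pip download ultralytics==8.3.62 --no-deps` (filenames assumed).
from zipfile import ZipFile

old_files = set(ZipFile("ultralytics-8.1.28-py3-none-any.whl").namelist())
new_files = set(ZipFile("ultralytics-8.3.62-py3-none-any.whl").namelist())

print(sorted(new_files - old_files))  # added, e.g. tests/*, cfg/models/11/*, sam/modules/blocks.py
print(sorted(old_files - new_files))  # removed, e.g. ultralytics/data/explorer/*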
ultralytics/models/sam/amg.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import math
 from itertools import product
@@ -11,7 +11,7 @@ import torch
 def is_box_near_crop_edge(
     boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
 ) -> torch.Tensor:
-    """
+    """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -22,7 +22,7 @@ def is_box_near_crop_edge(
 
 
 def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
-    """
+    """Yields batches of data from input arguments with specified batch size for efficient processing."""
     assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
     for b in range(n_batches):
@@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
     """
     Computes the stability score for a batch of masks.
 
-    The stability score is the IoU between
-    and low values.
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
+    high and low values.
+
+    Args:
+        masks (torch.Tensor): Batch of predicted mask logits.
+        mask_threshold (float): Threshold value for creating binary masks.
+        threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+
+    Returns:
+        (torch.Tensor): Stability scores for each mask in the batch.
 
     Notes:
         - One mask is always contained inside the other.
-        -
+        - Memory is saved by preventing unnecessary cast to torch.int64.
+
+    Examples:
+        >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
+        >>> mask_threshold = 0.5
+        >>> threshold_offset = 0.1
+        >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
     """
     intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
 
 
 def build_point_grid(n_per_side: int) -> np.ndarray:
-    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
+    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
     offset = 1 / (2 * n_per_side)
     points_one_side = np.linspace(offset, 1 - offset, n_per_side)
     points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
@@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
 
 
 def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
-    """
+    """Generates point grids for multiple crop layers with varying scales and densities."""
     return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
 
 
 def generate_crop_boxes(
     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> Tuple[List[List[int]], List[int]]:
-    """
-    Generates a list of crop boxes of different sizes.
-
-    Each layer has (2**i)**2 boxes for the ith layer.
-    """
+    """Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions."""
     crop_boxes, layer_idxs = [], []
     im_h, im_w = im_size
     short_side = min(im_h, im_w)
@@ -99,7 +109,7 @@ def generate_crop_boxes(
 
 
 def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
-    """Uncrop bounding boxes by adding the crop box offset."""
+    """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
     x0, y0, _, _ = crop_box
     offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
     # Check if boxes has a channel dimension
@@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
 
 
 def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
-    """Uncrop points by adding the crop box offset."""
+    """Uncrop points by adding the crop box offset to their coordinates."""
     x0, y0, _, _ = crop_box
     offset = torch.tensor([[x0, y0]], device=points.device)
     # Check if points has a channel dimension
@@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
 
 
 def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
-    """Uncrop masks by padding them to the original image size."""
+    """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
     x0, y0, x1, y1 = crop_box
     if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
         return masks
@@ -130,10 +140,10 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
 
 
 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
-    """
+    """Removes small disconnected regions or holes in a mask based on area threshold and mode."""
     import cv2  # type: ignore
 
-    assert mode in {"holes", "islands"}
+    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
     correct_holes = mode == "holes"
     working_mask = (correct_holes ^ mask).astype(np.uint8)
     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
@@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
 
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
-    """
-    Calculates boxes in XYXY format around masks.
-
-    Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
-    """
+    """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
     # torch.max below raises an error on empty inputs, just skip in this case
     if torch.numel(masks) == 0:
         return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
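
Beyond the license-header update, the amg.py changes are documentation-focused: `calculate_stability_score` gains a full Args/Returns/Examples docstring and `remove_small_regions` now carries an explanatory assert message. A minimal sketch following the docstring example added above (tensor contents are illustrative only; the second call is an extra illustration, not part of the package's examples):

# Illustrative values taken from the new calculate_stability_score docstring.
import torch

from ultralytics.models.sam.amg import batched_mask_to_box, calculate_stability_score

masks = torch.rand(10, 256, 256)  # batch of 10 predicted mask logits
scores = calculate_stability_score(masks, mask_threshold=0.5, threshold_offset=0.1)
print(scores.shape)  # torch.Size([10]), one stability score per mask

boxes = batched_mask_to_box(masks > 0.5)  # XYXY boxes; [0, 0, 0, 0] for empty masks
print(boxes.shape)  # torch.Size([10, 4])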
ultralytics/models/sam/build.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
@@ -11,15 +11,17 @@ from functools import partial
 import torch
 
 from ultralytics.utils.downloads import attempt_download_asset
+
 from .modules.decoders import MaskDecoder
-from .modules.encoders import ImageEncoderViT, PromptEncoder
-from .modules.
+from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM2Model, SAMModel
 from .modules.tiny_encoder import TinyViT
 from .modules.transformer import TwoWayTransformer
 
 
 def build_sam_vit_h(checkpoint=None):
-    """
+    """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1280,
         encoder_depth=32,
@@ -30,7 +32,7 @@ def build_sam_vit_h(checkpoint=None):
 
 
 def build_sam_vit_l(checkpoint=None):
-    """
+    """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1024,
         encoder_depth=24,
@@ -41,7 +43,7 @@ def build_sam_vit_l(checkpoint=None):
 
 
 def build_sam_vit_b(checkpoint=None):
-    """
+    """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
     return _build_sam(
         encoder_embed_dim=768,
         encoder_depth=12,
@@ -52,7 +54,7 @@ def build_sam_vit_b(checkpoint=None):
 
 
 def build_mobile_sam(checkpoint=None):
-    """
+    """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
     return _build_sam(
         encoder_embed_dim=[64, 128, 160, 320],
         encoder_depth=[2, 2, 6, 2],
@@ -63,10 +65,85 @@ def build_mobile_sam(checkpoint=None):
     )
 
 
+def build_sam2_t(checkpoint=None):
+    """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=96,
+        encoder_stages=[1, 2, 7, 2],
+        encoder_num_heads=1,
+        encoder_global_att_blocks=[5, 7, 9],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_s(checkpoint=None):
+    """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=96,
+        encoder_stages=[1, 2, 11, 2],
+        encoder_num_heads=1,
+        encoder_global_att_blocks=[7, 10, 13],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_backbone_channel_list=[768, 384, 192, 96],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_b(checkpoint=None):
+    """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=112,
+        encoder_stages=[2, 3, 16, 3],
+        encoder_num_heads=2,
+        encoder_global_att_blocks=[12, 16, 20],
+        encoder_window_spec=[8, 4, 14, 7],
+        encoder_window_spatial_size=[14, 14],
+        encoder_backbone_channel_list=[896, 448, 224, 112],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam2_l(checkpoint=None):
+    """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    return _build_sam2(
+        encoder_embed_dim=144,
+        encoder_stages=[2, 6, 36, 4],
+        encoder_num_heads=2,
+        encoder_global_att_blocks=[23, 33, 43],
+        encoder_window_spec=[8, 4, 16, 8],
+        encoder_backbone_channel_list=[1152, 576, 288, 144],
+        checkpoint=checkpoint,
+    )
+
+
 def _build_sam(
-    encoder_embed_dim,
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+    mobile_sam=False,
 ):
-    """
+    """
+    Builds a Segment Anything Model (SAM) with specified encoder parameters.
+
+    Args:
+        encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
+        encoder_depth (int | List[int]): Depth of the encoder.
+        encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
+        encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
+        checkpoint (str | None): Path to the model checkpoint file.
+        mobile_sam (bool): Whether to build a Mobile-SAM model.
+
+    Returns:
+        (SAMModel): A Segment Anything Model instance with the specified architecture.

+    Examples:
+        >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
+        >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
+    """
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
@@ -104,7 +181,7 @@ def _build_sam(
             out_chans=prompt_embed_dim,
         )
     )
-    sam =
+    sam = SAMModel(
         image_encoder=image_encoder,
         prompt_encoder=PromptEncoder(
             embed_dim=prompt_embed_dim,
@@ -133,21 +210,142 @@ def _build_sam(
             state_dict = torch.load(f)
         sam.load_state_dict(state_dict)
     sam.eval()
-    # sam.load_state_dict(torch.load(checkpoint), strict=True)
-    # sam.eval()
     return sam
 
 
+def _build_sam2(
+    encoder_embed_dim=1280,
+    encoder_stages=[2, 6, 36, 4],
+    encoder_num_heads=2,
+    encoder_global_att_blocks=[7, 15, 23, 31],
+    encoder_backbone_channel_list=[1152, 576, 288, 144],
+    encoder_window_spatial_size=[7, 7],
+    encoder_window_spec=[8, 4, 16, 8],
+    checkpoint=None,
+):
+    """
+    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+
+    Args:
+        encoder_embed_dim (int): Embedding dimension for the encoder.
+        encoder_stages (List[int]): Number of blocks in each stage of the encoder.
+        encoder_num_heads (int): Number of attention heads in the encoder.
+        encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
+        encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
+        encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
+        encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
+        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+
+    Returns:
+        (SAM2Model): A configured and initialized SAM2 model.
+
+    Examples:
+        >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
+        >>> sam2_model.eval()
+    """
+    image_encoder = ImageEncoder(
+        trunk=Hiera(
+            embed_dim=encoder_embed_dim,
+            num_heads=encoder_num_heads,
+            stages=encoder_stages,
+            global_att_blocks=encoder_global_att_blocks,
+            window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+            window_spec=encoder_window_spec,
+        ),
+        neck=FpnNeck(
+            d_model=256,
+            backbone_channel_list=encoder_backbone_channel_list,
+            fpn_top_down_levels=[2, 3],
+            fpn_interp_model="nearest",
+        ),
+        scalp=1,
+    )
+    memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+    memory_encoder = MemoryEncoder(out_dim=64)
+
+    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
+    sam2 = SAM2Model(
+        image_encoder=image_encoder,
+        memory_attention=memory_attention,
+        memory_encoder=memory_encoder,
+        num_maskmem=7,
+        image_size=1024,
+        sigmoid_scale_for_mem_enc=20.0,
+        sigmoid_bias_for_mem_enc=-10.0,
+        use_mask_input_as_output_without_sam=True,
+        directly_add_no_mem_embed=True,
+        use_high_res_features_in_sam=True,
+        multimask_output_in_sam=True,
+        iou_prediction_use_sigmoid=True,
+        use_obj_ptrs_in_encoder=True,
+        add_tpos_enc_to_obj_ptrs=True,
+        only_obj_ptrs_in_the_past_for_eval=True,
+        pred_obj_scores=True,
+        pred_obj_scores_mlp=True,
+        fixed_no_obj_ptr=True,
+        multimask_output_for_tracking=True,
+        use_multimask_token_for_obj_ptr=True,
+        multimask_min_pt_num=0,
+        multimask_max_pt_num=1,
+        use_mlp_for_obj_ptr_proj=True,
+        compile_image_encoder=False,
+        no_obj_embed_spatial=is_sam2_1,
+        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
+        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
+        sam_mask_decoder_extra_args=dict(
+            dynamic_multimask_via_stability=True,
+            dynamic_multimask_stability_delta=0.05,
+            dynamic_multimask_stability_thresh=0.98,
+        ),
+    )
+
+    if checkpoint is not None:
+        checkpoint = attempt_download_asset(checkpoint)
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)["model"]
+        sam2.load_state_dict(state_dict)
+    sam2.eval()
+    return sam2
+
+
 sam_model_map = {
     "sam_h.pt": build_sam_vit_h,
     "sam_l.pt": build_sam_vit_l,
     "sam_b.pt": build_sam_vit_b,
     "mobile_sam.pt": build_mobile_sam,
+    "sam2_t.pt": build_sam2_t,
+    "sam2_s.pt": build_sam2_s,
+    "sam2_b.pt": build_sam2_b,
+    "sam2_l.pt": build_sam2_l,
+    "sam2.1_t.pt": build_sam2_t,
+    "sam2.1_s.pt": build_sam2_s,
+    "sam2.1_b.pt": build_sam2_b,
+    "sam2.1_l.pt": build_sam2_l,
 }
 
 
 def build_sam(ckpt="sam_b.pt"):
-    """
+    """
+    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+
+    Args:
+        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+
+    Returns:
+        (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
+
+    Raises:
+        FileNotFoundError: If the provided checkpoint is not a supported SAM model.
+
+    Examples:
+        >>> sam_model = build_sam("sam_b.pt")
+        >>> sam_model = build_sam("path/to/custom_checkpoint.pt")
+
+    Notes:
+        Supported pre-defined models include:
+        - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
+        - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
+    """
     model_builder = None
     ckpt = str(ckpt)  # to allow Path ckpt types
    for k in sam_model_map.keys():
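
The headline change in build.py is SAM2 support: a new `_build_sam2` constructor, the `build_sam2_t/s/b/l` helpers, and eight new `sam_model_map` entries (including the sam2.1 checkpoints), so `build_sam` now dispatches to either model family by checkpoint name. A small sketch of the two entry points (checkpoint weights are downloaded on first use):

# Minimal sketch of the checkpoint-name dispatch shown in this diff.
from ultralytics.models.sam.build import build_sam

sam = build_sam("sam_b.pt")    # resolves to build_sam_vit_b -> SAMModel
sam2 = build_sam("sam2_b.pt")  # new in 8.3.x: resolves to build_sam2_b -> SAM2Model
print(type(sam).__name__, type(sam2).__name__)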
ultralytics/models/sam/model.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 SAM model interface.
 
@@ -18,40 +18,68 @@ from pathlib import Path
 
 from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
+
 from .build import build_sam
-from .predict import Predictor
+from .predict import Predictor, SAM2Predictor
 
 
 class SAM(Model):
     """
-    SAM (Segment Anything Model) interface class.
-
-
-
-
+    SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+
+    This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
+    promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
+    boxes, points, or labels, and features zero-shot performance capabilities.
+
+    Attributes:
+        model (torch.nn.Module): The loaded SAM model.
+        is_sam2 (bool): Indicates whether the model is SAM2 variant.
+        task (str): The task type, set to "segment" for SAM models.
+
+    Methods:
+        predict: Performs segmentation prediction on the given image or video source.
+        info: Logs information about the SAM model.
+
+    Examples:
+        >>> sam = SAM("sam_b.pt")
+        >>> results = sam.predict("image.jpg", points=[[500, 375]])
+        >>> for r in results:
+        >>>     print(f"Detected {len(r.masks)} masks")
     """
 
     def __init__(self, model="sam_b.pt") -> None:
         """
-        Initializes the SAM
+        Initializes the SAM (Segment Anything Model) instance.
 
         Args:
             model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
 
         Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> print(sam.is_sam2)
         """
-        if model and Path(model).suffix not in
+        if model and Path(model).suffix not in {".pt", ".pth"}:
             raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        self.is_sam2 = "sam2" in Path(model).stem
         super().__init__(model=model, task="segment")
 
     def _load(self, weights: str, task=None):
         """
         Loads the specified weights into the SAM model.
 
+        This method initializes the SAM model with the provided weights file, setting up the model architecture
+        and loading the pre-trained parameters.
+
         Args:
-            weights (str): Path to the weights file.
-            task (str
+            weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
+            task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> sam._load("path/to/custom_weights.pt")
         """
         self.model = build_sam(weights)
 
@@ -60,33 +88,51 @@ class SAM(Model):
         Performs segmentation prediction on the given image or video source.
 
         Args:
-            source (str): Path to the image or video file, or a PIL.Image object, or
-
-
-
-
+            source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
+                a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments for prediction.
 
         Returns:
-            (
+            (List): The model predictions.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam.predict("image.jpg", points=[[500, 375]])
+            >>> for r in results:
+            ...     print(f"Detected {len(r.masks)} masks")
         """
         overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
-        kwargs
+        kwargs = {**overrides, **kwargs}
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)
 
     def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
-
+        Performs segmentation prediction on the given image or video source.
+
+        This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
+        for segmentation tasks.
 
         Args:
-            source (str): Path to the image or video file, or a PIL.Image
-
-
-
-
+            source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
+                object, or a numpy.ndarray object.
+            stream (bool): If True, enables real-time streaming.
+            bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+            points (List[List[float]] | None): List of points for prompted segmentation.
+            labels (List[int] | None): List of labels for prompted segmentation.
+            **kwargs (Any): Additional keyword arguments to be passed to the predict method.
 
         Returns:
-            (
+            (List): The model predictions, typically containing segmentation masks and other relevant information.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> results = sam("image.jpg", points=[[500, 375]])
+            >>> print(f"Detected {len(results[0].masks)} masks")
         """
         return self.predict(source, stream, bboxes, points, labels, **kwargs)
 
@@ -94,12 +140,20 @@ class SAM(Model):
         """
         Logs information about the SAM model.
 
+        This method provides details about the Segment Anything Model (SAM), including its architecture,
+        parameters, and computational requirements.
+
         Args:
-            detailed (bool
-            verbose (bool
+            detailed (bool): If True, displays detailed information about the model layers and operations.
+            verbose (bool): If True, prints the information to the console.
 
         Returns:
-            (tuple): A tuple containing the model's information.
+            (tuple): A tuple containing the model's information (string representations of the model).
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> info = sam.info()
+            >>> print(info[0])  # Print summary information
         """
         return model_info(self.model, detailed=detailed, verbose=verbose)
 
@@ -109,6 +163,13 @@ class SAM(Model):
         Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
 
         Returns:
-            (
+            (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+                class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+
+        Examples:
+            >>> sam = SAM("sam_b.pt")
+            >>> task_map = sam.task_map
+            >>> print(task_map)
+            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
         """
-        return {"segment": {"predictor": Predictor}}
+        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
@@ -1 +1 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
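
Taken together, the model.py changes add an `is_sam2` attribute in `__init__` and route `task_map` to `SAM2Predictor` for SAM2 weights, while `predict` and `__call__` keep the prompted interface and merge caller kwargs over the segment-mode defaults. A short sketch mirroring the docstring examples added in this diff ("image.jpg" and the point prompt are placeholders; the checkpoint is downloaded on first use):

# Sketch of the updated SAM interface, following the new docstring examples.
from ultralytics import SAM

sam = SAM("sam2_b.pt")  # .pt/.pth weights only; is_sam2 is True for SAM2 checkpoints
print(sam.is_sam2)      # True, so task_map routes "segment" to SAM2Predictor

results = sam.predict("image.jpg", points=[[500, 375]])
for r in results:
    print(f"Detected {len(r.masks)} masks")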