dgenerate-ultralytics-headless 8.3.141__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/METADATA +1 -1
- dgenerate_ultralytics_headless-8.3.144.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +12 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +22 -19
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -158
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +13 -11
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +18 -12
- ultralytics/solutions/object_cropper.py +12 -5
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +215 -85
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +84 -42
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.141.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.141.dist-info → dgenerate_ultralytics_headless-8.3.144.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from copy import copy
+from typing import Optional

 from ultralytics.models.yolo.detect import DetectionTrainer
 from ultralytics.nn.tasks import RTDETRDetectionModel
@@ -18,12 +19,17 @@ class RTDETRTrainer(DetectionTrainer):
     speed.

     Attributes:
-        loss_names (
+        loss_names (tuple): Names of the loss components used for training.
         data (dict): Dataset configuration containing class count and other parameters.
         args (dict): Training arguments and hyperparameters.
         save_dir (Path): Directory to save training results.
         test_loader (DataLoader): DataLoader for validation/testing data.

+    Methods:
+        get_model: Initialize and return an RT-DETR model for object detection tasks.
+        build_dataset: Build and return an RT-DETR dataset for training or validation.
+        get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
+
     Notes:
         - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
         - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
@@ -35,7 +41,7 @@ class RTDETRTrainer(DetectionTrainer):
         >>> trainer.train()
     """

-    def get_model(self, cfg=None, weights=None, verbose=True):
+    def get_model(self, cfg: Optional[dict] = None, weights: Optional[str] = None, verbose: bool = True):
         """
         Initialize and return an RT-DETR model for object detection tasks.

@@ -52,7 +58,7 @@ class RTDETRTrainer(DetectionTrainer):
             model.load(weights)
         return model

-    def build_dataset(self, img_path, mode="val", batch=None):
+    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
         """
         Build and return an RT-DETR dataset for training or validation.

@@ -80,6 +86,6 @@ class RTDETRTrainer(DetectionTrainer):
         )

     def get_validator(self):
-        """
+        """Return a DetectionValidator suitable for RT-DETR model validation."""
         self.loss_names = "giou_loss", "cls_loss", "l1_loss"
         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
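Note: for orientation, a minimal sketch of how the retyped RTDETRTrainer is typically driven, mirroring the docstring example visible in the hunks above; the model and data names are assumptions about your local setup.

from ultralytics.models.rtdetr.train import RTDETRTrainer

# Overrides follow the class docstring example; any DetectionTrainer override is accepted.
args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)

# Only the annotations changed in this release:
#   get_model(cfg: Optional[dict] = None, weights: Optional[str] = None, verbose: bool = True)
#   build_dataset(img_path: str, mode: str = "val", batch: Optional[int] = None)
trainer.train()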
ultralytics/models/rtdetr/val.py CHANGED
@@ -16,6 +16,22 @@ class RTDETRDataset(YOLODataset):

     This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
     real-time detection and tracking tasks.
+
+    Attributes:
+        augment (bool): Whether to apply data augmentation.
+        rect (bool): Whether to use rectangular training.
+        use_segments (bool): Whether to use segmentation masks.
+        use_keypoints (bool): Whether to use keypoint annotations.
+        imgsz (int): Target image size for training.
+
+    Methods:
+        load_image: Load one image from dataset index.
+        build_transforms: Build transformation pipeline for the dataset.
+
+    Examples:
+        Initialize an RT-DETR dataset
+        >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
+        >>> image, hw = dataset.load_image(0)
     """

     def __init__(self, *args, data=None, **kwargs):
@@ -27,7 +43,7 @@ class RTDETRDataset(YOLODataset):

         Args:
             *args (Any): Variable length argument list passed to the parent YOLODataset class.
-            data (
+            data (dict | None): Dictionary containing dataset information. If None, default values will be used.
             **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
         """
         super().__init__(*args, data=data, **kwargs)
@@ -41,11 +57,12 @@ class RTDETRDataset(YOLODataset):
             rect_mode (bool, optional): Whether to use rectangular mode for batch inference.

         Returns:
-            im (
+            im (torch.Tensor): The loaded image.
             resized_hw (tuple): Height and width of the resized image with shape (2,).

         Examples:
-
+            Load an image from the dataset
+            >>> dataset = RTDETRDataset(img_path="path/to/images")
             >>> image, hw = dataset.load_image(0)
         """
         return super().load_image(i=i, rect_mode=rect_mode)
@@ -90,13 +107,22 @@ class RTDETRValidator(DetectionValidator):
     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
     post-processing, and updates evaluation metrics accordingly.

+    Attributes:
+        args (Namespace): Configuration arguments for validation.
+        data (dict): Dataset configuration dictionary.
+
+    Methods:
+        build_dataset: Build an RTDETR Dataset for validation.
+        postprocess: Apply Non-maximum suppression to prediction outputs.
+
     Examples:
+        Initialize and run RT-DETR validation
         >>> from ultralytics.models.rtdetr import RTDETRValidator
         >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
         >>> validator = RTDETRValidator(args=args)
         >>> validator()

-
+    Notes:
         For further details on the attributes and methods, refer to the parent DetectionValidator class.
     """

@@ -106,7 +132,8 @@ class RTDETRValidator(DetectionValidator):

         Args:
             img_path (str): Path to the folder containing images.
-            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for
+            mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
+                each mode.
             batch (int, optional): Size of batches, this is for `rect`.

         Returns:
@@ -129,10 +156,10 @@ class RTDETRValidator(DetectionValidator):
         Apply Non-maximum suppression to prediction outputs.

         Args:
-            preds (
+            preds (list | tuple | torch.Tensor): Raw predictions from the model.

         Returns:
-            (
+            (list[torch.Tensor]): List of processed predictions for each image in batch.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -153,7 +180,7 @@ class RTDETRValidator(DetectionValidator):

     def _prepare_batch(self, si, batch):
         """
-
+        Prepare a batch for validation by applying necessary transformations.

         Args:
             si (int): Batch index.
@@ -176,7 +203,7 @@ class RTDETRValidator(DetectionValidator):

     def _prepare_pred(self, pred, pbatch):
         """
-
+        Prepare predictions by scaling bounding boxes to original image dimensions.

         Args:
             pred (torch.Tensor): Raw predictions.
ultralytics/models/sam/amg.py CHANGED
@@ -11,7 +11,24 @@ import torch
 def is_box_near_crop_edge(
     boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
 ) -> torch.Tensor:
-    """
+    """
+    Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+
+    Args:
+        boxes (torch.Tensor): Bounding boxes in XYXY format.
+        crop_box (List[int]): Crop box coordinates in [x0, y0, x1, y1] format.
+        orig_box (List[int]): Original image box coordinates in [x0, y0, x1, y1] format.
+        atol (float, optional): Absolute tolerance for edge proximity detection.
+
+    Returns:
+        (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
+
+    Examples:
+        >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
+        >>> crop_box = [0, 0, 200, 200]
+        >>> orig_box = [0, 0, 300, 300]
+        >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
+    """
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -52,7 +69,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:

 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     """
-
+    Compute the stability score for a batch of masks.

     The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
     high and low values.
@@ -90,7 +107,7 @@ def build_point_grid(n_per_side: int) -> np.ndarray:


 def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
-    """
+    """Generate point grids for multiple crop layers with varying scales and densities."""
     return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]


@@ -98,7 +115,7 @@ def generate_crop_boxes(
     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> Tuple[List[List[int]], List[int]]:
     """
-
+    Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.

     Args:
         im_size (Tuple[int, ...]): Height and width of the input image.
@@ -106,8 +123,8 @@ def generate_crop_boxes(
         overlap_ratio (float): Ratio of overlap between adjacent crop boxes.

     Returns:
-        (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
-        (List[int]): List of layer indices corresponding to each crop box.
+        crop_boxes (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+        layer_idxs (List[int]): List of layer indices corresponding to each crop box.

     Examples:
         >>> im_size = (800, 1200)  # Height, width
@@ -124,7 +141,7 @@ def generate_crop_boxes(
     layer_idxs.append(0)

     def crop_len(orig_len, n_crops, overlap):
-        """
+        """Calculate the length of each crop given the original length, number of crops, and overlap."""
         return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

     for i_layer in range(n_layers):
@@ -179,16 +196,17 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:

 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
     """
-
+    Remove small disconnected regions or holes in a mask based on area threshold and mode.

     Args:
         mask (np.ndarray): Binary mask to process.
         area_thresh (float): Area threshold below which regions will be removed.
-        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+            regions.

     Returns:
-        (np.ndarray): Processed binary mask with small regions removed.
-        (bool): Whether any regions were modified.
+        processed_mask (np.ndarray): Processed binary mask with small regions removed.
+        modified (bool): Whether any regions were modified.

     Examples:
         >>> mask = np.zeros((100, 100), dtype=np.bool_)
@@ -216,7 +234,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup

 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
     """
-
+    Calculate bounding boxes in XYXY format around binary masks.

     Args:
         masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
ultralytics/models/sam/build.py CHANGED
@@ -21,7 +21,7 @@ from .modules.transformer import TwoWayTransformer


 def build_sam_vit_h(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1280,
         encoder_depth=32,
@@ -32,7 +32,7 @@ def build_sam_vit_h(checkpoint=None):


 def build_sam_vit_l(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1024,
         encoder_depth=24,
@@ -43,7 +43,7 @@ def build_sam_vit_l(checkpoint=None):


 def build_sam_vit_b(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=768,
         encoder_depth=12,
@@ -54,7 +54,7 @@ def build_sam_vit_b(checkpoint=None):


 def build_mobile_sam(checkpoint=None):
-    """
+    """Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
     return _build_sam(
         encoder_embed_dim=[64, 128, 160, 320],
         encoder_depth=[2, 2, 6, 2],
@@ -66,7 +66,7 @@ def build_mobile_sam(checkpoint=None):


 def build_sam2_t(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=96,
         encoder_stages=[1, 2, 7, 2],
@@ -79,7 +79,7 @@ def build_sam2_t(checkpoint=None):


 def build_sam2_s(checkpoint=None):
-    """
+    """Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=96,
         encoder_stages=[1, 2, 11, 2],
@@ -92,7 +92,7 @@ def build_sam2_s(checkpoint=None):


 def build_sam2_b(checkpoint=None):
-    """
+    """Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=112,
         encoder_stages=[2, 3, 16, 3],
@@ -106,7 +106,7 @@ def build_sam2_b(checkpoint=None):


 def build_sam2_l(checkpoint=None):
-    """
+    """Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=144,
         encoder_stages=[2, 6, 36, 4],
@@ -127,15 +127,15 @@ def _build_sam(
     mobile_sam=False,
 ):
     """
-
+    Build a Segment Anything Model (SAM) with specified encoder parameters.

     Args:
         encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
         encoder_depth (int | List[int]): Depth of the encoder.
         encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
         encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
-        checkpoint (str | None): Path to the model checkpoint file.
-        mobile_sam (bool): Whether to build a Mobile-SAM model.
+        checkpoint (str | None, optional): Path to the model checkpoint file.
+        mobile_sam (bool, optional): Whether to build a Mobile-SAM model.

     Returns:
         (SAMModel): A Segment Anything Model instance with the specified architecture.
@@ -224,17 +224,17 @@ def _build_sam2(
     checkpoint=None,
 ):
     """
-
+    Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.

     Args:
-        encoder_embed_dim (int): Embedding dimension for the encoder.
-        encoder_stages (List[int]): Number of blocks in each stage of the encoder.
-        encoder_num_heads (int): Number of attention heads in the encoder.
-        encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
-        encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
-        encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
-        encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
-        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+        encoder_embed_dim (int, optional): Embedding dimension for the encoder.
+        encoder_stages (List[int], optional): Number of blocks in each stage of the encoder.
+        encoder_num_heads (int, optional): Number of attention heads in the encoder.
+        encoder_global_att_blocks (List[int], optional): Indices of global attention blocks in the encoder.
+        encoder_backbone_channel_list (List[int], optional): Channel dimensions for each level of the encoder backbone.
+        encoder_window_spatial_size (List[int], optional): Spatial size of the window for position embeddings.
+        encoder_window_spec (List[int], optional): Window specifications for each stage of the encoder.
+        checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.

     Returns:
         (SAM2Model): A configured and initialized SAM2 model.
@@ -326,10 +326,10 @@ sam_model_map = {

 def build_sam(ckpt="sam_b.pt"):
     """
-
+    Build and return a Segment Anything Model (SAM) based on the provided checkpoint.

     Args:
-        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+        ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.

     Returns:
         (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
ultralytics/models/sam/model.py CHANGED
@@ -15,6 +15,7 @@ Key Features:
 """

 from pathlib import Path
+from typing import Dict, Type

 from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
@@ -36,8 +37,8 @@ class SAM(Model):
         task (str): The task type, set to "segment" for SAM models.

     Methods:
-        predict:
-        info:
+        predict: Perform segmentation prediction on the given image or video source.
+        info: Log information about the SAM model.

     Examples:
         >>> sam = SAM("sam_b.pt")
@@ -46,7 +47,7 @@ class SAM(Model):
         >>> print(f"Detected {len(r.masks)} masks")
     """

-    def __init__(self, model="sam_b.pt") -> None:
+    def __init__(self, model: str = "sam_b.pt") -> None:
         """
         Initialize the SAM (Segment Anything Model) instance.

@@ -81,7 +82,7 @@ class SAM(Model):

         self.model = build_sam(weights)

-    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+    def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
         """
         Perform segmentation prediction on the given image or video source.

@@ -108,7 +109,7 @@ class SAM(Model):
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)

-    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+    def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
         """
         Perform segmentation prediction on the given image or video source.

@@ -134,7 +135,7 @@ class SAM(Model):
         """
         return self.predict(source, stream, bboxes, points, labels, **kwargs)

-    def info(self, detailed=False, verbose=True):
+    def info(self, detailed: bool = False, verbose: bool = True):
         """
         Log information about the SAM model.

@@ -153,13 +154,13 @@ class SAM(Model):
         return model_info(self.model, detailed=detailed, verbose=verbose)

     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Type[Predictor]]]:
         """
         Provide a mapping from the 'segment' task to its corresponding 'Predictor'.

         Returns:
-            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
-            class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
+                Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.

         Examples:
             >>> sam = SAM("sam_b.pt")