dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/rtdetr/val.py
CHANGED
@@ -16,8 +16,7 @@ __all__ = ("RTDETRValidator",)  # tuple or list
 
 
 class RTDETRDataset(YOLODataset):
-    """
-    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+    """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
 
     This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
     real-time detection and tracking tasks.
@@ -36,12 +35,11 @@ class RTDETRDataset(YOLODataset):
     Examples:
         Initialize an RT-DETR dataset
         >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
-        >>> image, hw = dataset.load_image(0)
+        >>> image, hw0, hw = dataset.load_image(0)
     """
 
     def __init__(self, *args, data=None, **kwargs):
-        """
-        Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
+        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
 
         This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
         model, building upon the base YOLODataset functionality.
@@ -54,27 +52,26 @@ class RTDETRDataset(YOLODataset):
         super().__init__(*args, data=data, **kwargs)
 
     def load_image(self, i, rect_mode=False):
-        """
-        Load one image from dataset index 'i'.
+        """Load one image from dataset index 'i'.
 
         Args:
             i (int): Index of the image to load.
             rect_mode (bool, optional): Whether to use rectangular mode for batch inference.
 
         Returns:
-            im (
-
+            im (np.ndarray): Loaded image as a NumPy array.
+            hw_original (tuple[int, int]): Original image dimensions in (height, width) format.
+            hw_resized (tuple[int, int]): Resized image dimensions in (height, width) format.
 
         Examples:
             Load an image from the dataset
             >>> dataset = RTDETRDataset(img_path="path/to/images")
-            >>> image, hw = dataset.load_image(0)
+            >>> image, hw0, hw = dataset.load_image(0)
         """
         return super().load_image(i=i, rect_mode=rect_mode)
 
     def build_transforms(self, hyp=None):
-        """
-        Build transformation pipeline for the dataset.
+        """Build transformation pipeline for the dataset.
 
         Args:
             hyp (dict, optional): Hyperparameters for transformations.
@@ -105,8 +102,7 @@ class RTDETRDataset(YOLODataset):
 
 
 class RTDETRValidator(DetectionValidator):
-    """
-    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+    """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
     the RT-DETR (Real-Time DETR) object detection model.
 
     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
@@ -132,8 +128,7 @@ class RTDETRValidator(DetectionValidator):
     """
 
     def build_dataset(self, img_path, mode="val", batch=None):
-        """
-        Build an RTDETR Dataset.
+        """Build an RTDETR Dataset.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -156,15 +151,19 @@ class RTDETRValidator(DetectionValidator):
             data=self.data,
         )
 
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return predn
+
     def postprocess(
         self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
     ) -> list[dict[str, torch.Tensor]]:
-        """
-        Apply Non-maximum suppression to prediction outputs.
+        """Apply Non-maximum suppression to prediction outputs.
 
         Args:
             preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
-                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+                class scores.
 
         Returns:
             (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
@@ -190,12 +189,11 @@ class RTDETRValidator(DetectionValidator):
         return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
 
     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Serialize YOLO predictions to COCO json format.
+        """Serialize YOLO predictions to COCO json format.
 
         Args:
-            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
-
+            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
         path = Path(pbatch["im_file"])
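
A note on the load_image signature change documented above: it now returns the original and resized shapes separately, so older two-value unpacking breaks. A minimal sketch of the updated call pattern, reusing the placeholder dataset path from the docstring example (not a real path):

from ultralytics.models.rtdetr.val import RTDETRDataset

dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)  # placeholder path from the docstring

# Before: image, hw = dataset.load_image(0)
image, hw_original, hw_resized = dataset.load_image(0)
print(f"original (h, w): {hw_original}, resized (h, w): {hw_resized}")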
ultralytics/models/sam/__init__.py
CHANGED
@@ -1,12 +1,25 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from .model import SAM
-from .predict import
+from .predict import (
+    Predictor,
+    SAM2DynamicInteractivePredictor,
+    SAM2Predictor,
+    SAM2VideoPredictor,
+    SAM3Predictor,
+    SAM3SemanticPredictor,
+    SAM3VideoPredictor,
+    SAM3VideoSemanticPredictor,
+)
 
 __all__ = (
     "SAM",
     "Predictor",
+    "SAM2DynamicInteractivePredictor",
     "SAM2Predictor",
     "SAM2VideoPredictor",
-    "
+    "SAM3Predictor",
+    "SAM3SemanticPredictor",
+    "SAM3VideoPredictor",
+    "SAM3VideoSemanticPredictor",
 )  # tuple or list of exportable items
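
The expanded re-exports above make the new predictor classes importable at the package level. A quick sketch (import paths are taken directly from the diff; nothing is downloaded or loaded):

# These names now resolve at package level instead of requiring a deep
# import from ultralytics.models.sam.predict.
from ultralytics.models.sam import (
    SAM2DynamicInteractivePredictor,
    SAM3Predictor,
    SAM3SemanticPredictor,
    SAM3VideoPredictor,
)

print([cls.__name__ for cls in (SAM3Predictor, SAM3VideoPredictor)])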
ultralytics/models/sam/amg.py
CHANGED
@@ -14,8 +14,7 @@ import torch
 def is_box_near_crop_edge(
     boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
 ) -> torch.Tensor:
-    """
-    Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+    """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
 
     Args:
         boxes (torch.Tensor): Bounding boxes in XYXY format.
@@ -42,8 +41,7 @@ def is_box_near_crop_edge(
 
 
 def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
-    """
-    Yield batches of data from input arguments with specified batch size for efficient processing.
+    """Yield batches of data from input arguments with specified batch size for efficient processing.
 
     This function takes a batch size and any number of iterables, then yields batches of elements from those
     iterables. All input iterables must have the same length.
@@ -71,11 +69,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
 
 
 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
-    """
-    Compute the stability score for a batch of masks.
+    """Compute the stability score for a batch of masks.
 
-    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
-
+    The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
+    low values.
 
     Args:
         masks (torch.Tensor): Batch of predicted mask logits.
@@ -85,15 +82,15 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     Returns:
         (torch.Tensor): Stability scores for each mask in the batch.
 
-    Notes:
-        - One mask is always contained inside the other.
-        - Memory is saved by preventing unnecessary cast to torch.int64.
-
     Examples:
         >>> masks = torch.rand(10, 256, 256)  # Batch of 10 masks
        >>> mask_threshold = 0.5
         >>> threshold_offset = 0.1
         >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+
+    Notes:
+        - One mask is always contained inside the other.
+        - Memory is saved by preventing unnecessary cast to torch.int64.
     """
     intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -117,8 +114,7 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
 def generate_crop_boxes(
     im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> tuple[list[list[int]], list[int]]:
-    """
-    Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+    """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
 
     Args:
         im_size (tuple[int, ...]): Height and width of the input image.
@@ -145,7 +141,7 @@ def generate_crop_boxes(
 
     def crop_len(orig_len, n_crops, overlap):
         """Calculate the length of each crop given the original length, number of crops, and overlap."""
-        return
+        return math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)
 
     for i_layer in range(n_layers):
         n_crops_per_side = 2 ** (i_layer + 1)
@@ -198,8 +194,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
 
 
 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
-    """
-    Remove small disconnected regions or holes in a mask based on area threshold and mode.
+    """Remove small disconnected regions or holes in a mask based on area threshold and mode.
 
     Args:
         mask (np.ndarray): Binary mask to process.
@@ -227,7 +222,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
     small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
     if not small_regions:
         return mask, False
-    fill_labels = [0] + small_regions
+    fill_labels = [0, *small_regions]
     if not correct_holes:
         # If every region is below threshold, keep largest
         fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
@@ -236,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
 
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
-    """
-    Calculate bounding boxes in XYXY format around binary masks.
+    """Calculate bounding boxes in XYXY format around binary masks.
 
     Args:
         masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
ultralytics/models/sam/build.py
CHANGED
@@ -11,6 +11,7 @@ from functools import partial
 import torch
 
 from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.patches import torch_load
 
 from .modules.decoders import MaskDecoder
 from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
@@ -20,6 +21,21 @@ from .modules.tiny_encoder import TinyViT
 from .modules.transformer import TwoWayTransformer
 
 
+def _load_checkpoint(model, checkpoint):
+    """Load checkpoint into model from file path."""
+    if checkpoint is None:
+        return model
+
+    checkpoint = attempt_download_asset(checkpoint)
+    with open(checkpoint, "rb") as f:
+        state_dict = torch_load(f)
+    # Handle nested "model" key
+    if "model" in state_dict and isinstance(state_dict["model"], dict):
+        state_dict = state_dict["model"]
+    model.load_state_dict(state_dict)
+    return model
+
+
 def build_sam_vit_h(checkpoint=None):
     """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
     return _build_sam(
@@ -126,8 +142,7 @@ def _build_sam(
     checkpoint=None,
     mobile_sam=False,
 ):
-    """
-    Build a Segment Anything Model (SAM) with specified encoder parameters.
+    """Build a Segment Anything Model (SAM) with specified encoder parameters.
 
     Args:
         encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
@@ -205,26 +220,22 @@ def _build_sam(
         pixel_std=[58.395, 57.12, 57.375],
     )
     if checkpoint is not None:
-        checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
-        sam.load_state_dict(state_dict)
+        sam = _load_checkpoint(sam, checkpoint)
     sam.eval()
     return sam
 
 
 def _build_sam2(
     encoder_embed_dim=1280,
-    encoder_stages=[2, 6, 36, 4],
+    encoder_stages=(2, 6, 36, 4),
     encoder_num_heads=2,
-    encoder_global_att_blocks=[7, 15, 23, 31],
-    encoder_backbone_channel_list=[1152, 576, 288, 144],
-    encoder_window_spatial_size=[7, 7],
-    encoder_window_spec=[8, 4, 16, 8],
+    encoder_global_att_blocks=(7, 15, 23, 31),
+    encoder_backbone_channel_list=(1152, 576, 288, 144),
+    encoder_window_spatial_size=(7, 7),
+    encoder_window_spec=(8, 4, 16, 8),
     checkpoint=None,
 ):
-    """
-    Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+    """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
 
     Args:
         encoder_embed_dim (int, optional): Embedding dimension for the encoder.
@@ -300,10 +311,7 @@ def _build_sam2(
     )
 
     if checkpoint is not None:
-        checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)["model"]
-        sam2.load_state_dict(state_dict)
+        sam2 = _load_checkpoint(sam2, checkpoint)
     sam2.eval()
     return sam2
 
@@ -325,8 +333,7 @@ sam_model_map = {
 
 
 def build_sam(ckpt="sam_b.pt"):
-    """
-    Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
+    """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
 
     Args:
         ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.