ultralytics-8.3.88-py3-none-any.whl → ultralytics-8.3.90-py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the public registry.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +125 -39
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +34 -33
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +33 -47
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +69 -90
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +31 -38
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +21 -26
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +23 -17
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +29 -24
- ultralytics/models/nas/predict.py +14 -11
- ultralytics/models/nas/val.py +11 -13
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +21 -21
- ultralytics/models/rtdetr/train.py +25 -24
- ultralytics/models/rtdetr/val.py +47 -14
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +30 -12
- ultralytics/models/yolo/classify/train.py +83 -19
- ultralytics/models/yolo/classify/val.py +45 -23
- ultralytics/models/yolo/detect/predict.py +29 -19
- ultralytics/models/yolo/detect/train.py +90 -23
- ultralytics/models/yolo/detect/val.py +150 -29
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +18 -13
- ultralytics/models/yolo/obb/train.py +12 -8
- ultralytics/models/yolo/obb/val.py +35 -22
- ultralytics/models/yolo/pose/predict.py +28 -15
- ultralytics/models/yolo/pose/train.py +21 -8
- ultralytics/models/yolo/pose/val.py +51 -31
- ultralytics/models/yolo/segment/predict.py +27 -16
- ultralytics/models/yolo/segment/train.py +11 -8
- ultralytics/models/yolo/segment/val.py +110 -29
- ultralytics/models/yolo/world/train.py +43 -16
- ultralytics/models/yolo/world/train_world.py +61 -36
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +12 -12
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +226 -79
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +37 -35
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +139 -68
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +37 -56
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +117 -52
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +65 -61
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +72 -59
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +202 -64
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +13 -25
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.88.dist-info/RECORD +0 -250
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/models/rtdetr/val.py
CHANGED
@@ -22,13 +22,20 @@ class RTDETRDataset(YOLODataset):
         """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
         super().__init__(*args, data=data, **kwargs)
 
-    # NOTE: add stretch version load_image for RTDETR mosaic
     def load_image(self, i, rect_mode=False):
         """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
         return super().load_image(i=i, rect_mode=rect_mode)
 
     def build_transforms(self, hyp=None):
-        """
+        """
+        Build transformation pipeline for the dataset.
+
+        Args:
+            hyp (Dict, optional): Hyperparameters for transformations.
+
+        Returns:
+            (Compose): Composition of transformation functions.
+        """
         if self.augment:
             hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
             hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
@@ -58,14 +65,11 @@ class RTDETRValidator(DetectionValidator):
     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
     post-processing, and updates evaluation metrics accordingly.
 
-
-
-
-
-
-        validator = RTDETRValidator(args=args)
-        validator()
-        ```
+    Examples:
+        >>> from ultralytics.models.rtdetr import RTDETRValidator
+        >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
+        >>> validator = RTDETRValidator(args=args)
+        >>> validator()
 
     Note:
         For further details on the attributes and methods, refer to the parent DetectionValidator class.
@@ -78,7 +82,10 @@ class RTDETRValidator(DetectionValidator):
         Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (RTDETRDataset): Dataset configured for RT-DETR validation.
         """
         return RTDETRDataset(
             img_path=img_path,
@@ -93,7 +100,15 @@ class RTDETRValidator(DetectionValidator):
         )
 
     def postprocess(self, preds):
-        """
+        """
+        Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (List | Tuple | torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[torch.Tensor]): List of processed predictions for each image in batch.
+        """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
 
@@ -114,7 +129,16 @@ class RTDETRValidator(DetectionValidator):
         return outputs
 
     def _prepare_batch(self, si, batch):
-        """
+        """
+        Prepares a batch for validation by applying necessary transformations.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and annotations.
+
+        Returns:
+            (Dict): Prepared batch with transformed annotations.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -128,7 +152,16 @@ class RTDETRValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """
+        Prepares predictions by scaling bounding boxes to original image dimensions.
+
+        Args:
+            pred (torch.Tensor): Raw predictions.
+            pbatch (Dict): Prepared batch information.
+
+        Returns:
+            (torch.Tensor): Predictions scaled to original image dimensions.
+        """
         predn = pred.clone()
         predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
         predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
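The refreshed RTDETRValidator docstring above now carries a runnable Examples block. The same flow works as a plain script; this is a minimal sketch of that documented usage, assuming network access so the rtdetr-l.pt weights and the coco8 dataset can be fetched automatically:

    # Minimal sketch mirroring the new docstring example: standalone RT-DETR validation.
    from ultralytics.models.rtdetr import RTDETRValidator

    args = dict(model="rtdetr-l.pt", data="coco8.yaml")
    validator = RTDETRValidator(args=args)
    validator()  # runs the full validation loop and reports detection metrics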
ultralytics/models/sam/__init__.py
CHANGED
@@ -3,4 +3,4 @@
 from .model import SAM
 from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
 
-__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list of exportable items
ultralytics/models/sam/amg.py
CHANGED
@@ -76,7 +76,24 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
 def generate_crop_boxes(
     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> Tuple[List[List[int]], List[int]]:
-    """
+    """
+    Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+
+    Args:
+        im_size (Tuple[int, ...]): Height and width of the input image.
+        n_layers (int): Number of layers to generate crop boxes for.
+        overlap_ratio (float): Ratio of overlap between adjacent crop boxes.
+
+    Returns:
+        (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+        (List[int]): List of layer indices corresponding to each crop box.
+
+    Examples:
+        >>> im_size = (800, 1200)  # Height, width
+        >>> n_layers = 3
+        >>> overlap_ratio = 0.25
+        >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
+    """
     crop_boxes, layer_idxs = [], []
     im_h, im_w = im_size
     short_side = min(im_h, im_w)
@@ -86,7 +103,7 @@ def generate_crop_boxes(
     layer_idxs.append(0)
 
     def crop_len(orig_len, n_crops, overlap):
-        """
+        """Calculates the length of each crop given the original length, number of crops, and overlap."""
         return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
 
     for i_layer in range(n_layers):
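The new generate_crop_boxes docstring above spells out its inputs and outputs; as a standalone sketch of the same documented example (import path taken from this package, printed values not asserted):

    # Sketch mirroring the documented Examples block for generate_crop_boxes.
    from ultralytics.models.sam.amg import generate_crop_boxes

    im_size = (800, 1200)  # (height, width)
    crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers=3, overlap_ratio=0.25)
    print(len(crop_boxes), crop_boxes[0])  # the first box spans the full image; later layers tile it with overlap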
@@ -140,7 +157,24 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
 
 
 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
-    """
+    """
+    Removes small disconnected regions or holes in a mask based on area threshold and mode.
+
+    Args:
+        mask (np.ndarray): Binary mask to process.
+        area_thresh (float): Area threshold below which regions will be removed.
+        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected regions.
+
+    Returns:
+        (np.ndarray): Processed binary mask with small regions removed.
+        (bool): Whether any regions were modified.
+
+    Examples:
+        >>> mask = np.zeros((100, 100), dtype=np.bool_)
+        >>> mask[40:60, 40:60] = True  # Create a square
+        >>> mask[45:55, 45:55] = False  # Create a hole
+        >>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
+    """
     import cv2  # type: ignore
 
     assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
@@ -160,7 +194,19 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
 
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
-    """
+    """
+    Calculates bounding boxes in XYXY format around binary masks.
+
+    Args:
+        masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
+
+    Returns:
+        (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).
+
+    Notes:
+        - Handles empty masks by returning zero boxes.
+        - Preserves input tensor dimensions in the output.
+    """
     # torch.max below raises an error on empty inputs, just skip in this case
     if torch.numel(masks) == 0:
         return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
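The two amg.py helpers documented above compose naturally: clean up a binary mask, then box it. A hedged sketch based on the docstring example; the area threshold is raised to 200 here (an assumption, not the documented value) so the 10x10 hole is actually small enough to be filled:

    import numpy as np
    import torch
    from ultralytics.models.sam.amg import batched_mask_to_box, remove_small_regions

    mask = np.zeros((100, 100), dtype=np.bool_)
    mask[40:60, 40:60] = True   # a 20x20 square
    mask[45:55, 45:55] = False  # a 10x10 hole inside it
    filled, modified = remove_small_regions(mask, area_thresh=200, mode="holes")  # fills holes smaller than 200 px

    box = batched_mask_to_box(torch.from_numpy(filled)[None])  # add a batch dim -> shape (1, 4)
    print(modified, box)  # True, plus a tight XYXY box around the restored square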
ultralytics/models/sam/model.py
CHANGED
@@ -49,7 +49,7 @@ class SAM(Model):
 
     def __init__(self, model="sam_b.pt") -> None:
         """
-
+        Initialize the SAM (Segment Anything Model) instance.
 
         Args:
             model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
@@ -68,10 +68,7 @@ class SAM(Model):
 
     def _load(self, weights: str, task=None):
         """
-
-
-        This method initializes the SAM model with the provided weights file, setting up the model architecture
-        and loading the pre-trained parameters.
+        Load the specified weights into the SAM model.
 
         Args:
             weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
@@ -85,7 +82,7 @@ class SAM(Model):
 
     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
-
+        Perform segmentation prediction on the given image or video source.
 
         Args:
             source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
@@ -112,7 +109,7 @@ class SAM(Model):
 
     def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
-
+        Perform segmentation prediction on the given image or video source.
 
         This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
         for segmentation tasks.
@@ -138,10 +135,7 @@ class SAM(Model):
 
     def info(self, detailed=False, verbose=True):
         """
-
-
-        This method provides details about the Segment Anything Model (SAM), including its architecture,
-        parameters, and computational requirements.
+        Log information about the SAM model.
 
         Args:
             detailed (bool): If True, displays detailed information about the model layers and operations.
@@ -160,16 +154,16 @@ class SAM(Model):
     @property
     def task_map(self):
         """
-
+        Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
 
         Returns:
-            (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding Predictor
                 class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
 
         Examples:
             >>> sam = SAM("sam_b.pt")
            >>> task_map = sam.task_map
            >>> print(task_map)
-            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
+            {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
         """
         return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
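The trimmed SAM docstrings above leave the prompt-based API itself unchanged. For orientation, a hedged sketch of that API as documented ("image.jpg" and the prompt coordinates are placeholders; weights are fetched on first use):

    from ultralytics import SAM

    sam = SAM("sam_b.pt")
    sam.info()  # log a summary of the model architecture and parameters
    results = sam.predict("image.jpg", bboxes=[100, 100, 400, 400])      # box prompt
    results = sam.predict("image.jpg", points=[[250, 250]], labels=[1])  # positive point prompt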
ultralytics/models/sam/modules/decoders.py
CHANGED
@@ -48,7 +48,7 @@ class MaskDecoder(nn.Module):
         iou_head_hidden_dim: int = 256,
     ) -> None:
         """
-
+        Initialize the MaskDecoder module for generating masks and their associated quality scores.
 
         Args:
             transformer_dim (int): Channel dimension for the transformer module.
@@ -95,7 +95,7 @@ class MaskDecoder(nn.Module):
         multimask_output: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-
+        Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder.
@@ -105,9 +105,8 @@ class MaskDecoder(nn.Module):
             multimask_output (bool): Whether to return multiple masks or a single mask.
 
         Returns:
-            (
-
-            - iou_pred (torch.Tensor): Batched predictions of mask quality.
+            masks (torch.Tensor): Batched predicted masks.
+            iou_pred (torch.Tensor): Batched predictions of mask quality.
 
         Examples:
             >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
@@ -140,7 +139,7 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
+        """Predict masks and quality scores using image and prompt embeddings via transformer architecture."""
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
         output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
@@ -236,7 +235,7 @@ class SAM2MaskDecoder(nn.Module):
         use_multimask_token_for_obj_ptr: bool = False,
     ) -> None:
         """
-
+        Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
 
         This decoder extends the functionality of MaskDecoder, incorporating additional features such as
         high-resolution feature processing, dynamic multimask output, and object score prediction.
@@ -320,9 +319,9 @@ class SAM2MaskDecoder(nn.Module):
         multimask_output: bool,
         repeat_image: bool,
         high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-
+        Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
@@ -334,11 +333,10 @@ class SAM2MaskDecoder(nn.Module):
             high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
 
         Returns:
-            (
-
-
-
-            - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
+            masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
+            iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
+            sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
+            object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
 
         Examples:
             >>> image_embeddings = torch.rand(1, 256, 64, 64)
@@ -390,8 +388,8 @@ class SAM2MaskDecoder(nn.Module):
         dense_prompt_embeddings: torch.Tensor,
         repeat_image: bool,
         high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Predict instance segmentation masks from image and prompt embeddings using a transformer."""
         # Concatenate output tokens
         s = 0
         if self.pred_obj_scores:
@@ -454,7 +452,7 @@ class SAM2MaskDecoder(nn.Module):
         return masks, iou_pred, mask_tokens_out, object_score_logits
 
     def _get_stability_scores(self, mask_logits):
-        """
+        """Compute mask stability scores based on IoU between upper and lower thresholds."""
         mask_logits = mask_logits.flatten(-2)
         stability_delta = self.dynamic_multimask_stability_delta
         area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
@@ -463,7 +461,7 @@ class SAM2MaskDecoder(nn.Module):
 
     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
         """
-        Dynamically
+        Dynamically select the most stable mask output based on stability scores and IoU predictions.
 
         This method is used when outputting a single mask. If the stability score from the current single-mask
         output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
@@ -476,9 +474,8 @@ class SAM2MaskDecoder(nn.Module):
             all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
 
         Returns:
-            (
-
-            - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
+            mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
+            iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
 
         Examples:
             >>> decoder = SAM2MaskDecoder(...)
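The new one-line docstring for _get_stability_scores summarizes the idea: stability is the ratio of the mask area at a raised logit threshold to the area at a lowered one. A standalone, hedged sketch of that computation follows; the function name, the delta default, and the final division are assumptions, since only the thresholded-area counts appear in the hunk above:

    import torch

    def stability_scores(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor:
        """Return per-mask stability in [0, 1] for logits shaped (..., H, W)."""
        flat = mask_logits.flatten(-2)                      # (..., H*W)
        area_hi = torch.sum(flat > delta, dim=-1).float()   # area above the upper threshold
        area_lo = torch.sum(flat > -delta, dim=-1).float()  # area above the lower threshold
        return torch.where(area_lo > 0, area_hi / area_lo, torch.ones_like(area_hi))

    print(stability_scores(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 3])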
ultralytics/models/sam/modules/encoders.py
CHANGED
@@ -65,7 +65,7 @@ class ImageEncoderViT(nn.Module):
         global_attn_indexes: Tuple[int, ...] = (),
     ) -> None:
         """
-
+        Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
 
         Args:
             img_size (int): Input image size, assumed to be square.
@@ -85,13 +85,6 @@ class ImageEncoderViT(nn.Module):
             window_size (int): Size of attention window for windowed attention blocks.
             global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
 
-        Attributes:
-            img_size (int): Dimension of input images.
-            patch_embed (PatchEmbed): Module for patch embedding.
-            pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
-            blocks (nn.ModuleList): List of transformer blocks.
-            neck (nn.Sequential): Neck module for final processing.
-
         Examples:
             >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
             >>> input_image = torch.randn(1, 3, 224, 224)
@@ -148,7 +141,7 @@ class ImageEncoderViT(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
+        """Process input through patch embedding, positional embedding, transformer blocks, and neck module."""
         x = self.patch_embed(x)
         if self.pos_embed is not None:
             pos_embed = (
@@ -201,10 +194,7 @@ class PromptEncoder(nn.Module):
         activation: Type[nn.Module] = nn.GELU,
     ) -> None:
         """
-
-
-        This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
-        producing both sparse and dense embeddings.
+        Initialize the PromptEncoder module for encoding various types of prompts.
 
         Args:
             embed_dim (int): The dimension of the embeddings.
@@ -213,17 +203,6 @@ class PromptEncoder(nn.Module):
             mask_in_chans (int): The number of hidden channels used for encoding input masks.
             activation (Type[nn.Module]): The activation function to use when encoding input masks.
 
-        Attributes:
-            embed_dim (int): Dimension of the embeddings.
-            input_image_size (Tuple[int, int]): Size of the input image as (H, W).
-            image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
-            pe_layer (PositionEmbeddingRandom): Module for random position embedding.
-            num_point_embeddings (int): Number of point embeddings for different types of points.
-            point_embeddings (nn.ModuleList): List of point embeddings.
-            not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
-            mask_input_size (Tuple[int, int]): Size of the input mask.
-            mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
-
         Examples:
             >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
             >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
@@ -258,9 +237,9 @@ class PromptEncoder(nn.Module):
 
     def get_dense_pe(self) -> torch.Tensor:
         """
-
+        Return the dense positional encoding used for encoding point prompts.
 
-
+        Generate a positional encoding for a dense set of points matching the shape of the image
         encoding. The encoding is used to provide spatial information to the model when processing point prompts.
 
         Returns:
@@ -276,7 +255,7 @@ class PromptEncoder(nn.Module):
         return self.pe_layer(self.image_embedding_size).unsqueeze(0)
 
     def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
-        """
+        """Embed point prompts by applying positional encoding and label-specific embeddings."""
         points = points + 0.5  # Shift to center of pixel
         if pad:
             padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
@@ -293,7 +272,7 @@ class PromptEncoder(nn.Module):
         return point_embedding
 
     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
-        """
+        """Embed box prompts by applying positional encoding and adding corner embeddings."""
         boxes = boxes + 0.5  # Shift to center of pixel
         coords = boxes.reshape(-1, 2, 2)
         corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
@@ -302,7 +281,7 @@ class PromptEncoder(nn.Module):
         return corner_embedding
 
     def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
-        """
+        """Embed mask inputs by downscaling and processing through convolutional layers."""
         return self.mask_downscaling(masks)
 
     @staticmethod
@@ -311,7 +290,7 @@ class PromptEncoder(nn.Module):
         boxes: Optional[torch.Tensor],
         masks: Optional[torch.Tensor],
     ) -> int:
-        """
+        """Get the batch size of the output given the batch size of the input prompts."""
         if points is not None:
             return points[0].shape[0]
         elif boxes is not None:
@@ -322,7 +301,7 @@ class PromptEncoder(nn.Module):
         return 1
 
     def _get_device(self) -> torch.device:
-        """
+        """Return the device of the first point embedding's weight tensor."""
         return self.point_embeddings[0].weight.device
 
     def forward(
@@ -332,7 +311,7 @@ class PromptEncoder(nn.Module):
         masks: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-
+        Embed different types of prompts, returning both sparse and dense embeddings.
 
         Args:
             points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
@@ -377,7 +356,7 @@ class PromptEncoder(nn.Module):
 
 class MemoryEncoder(nn.Module):
     """
-
+    Encode pixel features and masks into a memory representation for efficient image segmentation.
 
     This class processes pixel-level features and masks, fusing them to generate encoded memory representations
     suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
@@ -390,7 +369,7 @@ class MemoryEncoder(nn.Module):
         out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
 
     Methods:
-        forward:
+        forward: Process input pixel features and masks to generate encoded memory representations.
 
     Examples:
         >>> import torch
@@ -407,7 +386,7 @@ class MemoryEncoder(nn.Module):
         out_dim,
         in_dim=256,  # in_dim of pix_feats
     ):
-        """
+        """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations."""
         super().__init__()
 
         self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
@@ -425,7 +404,7 @@ class MemoryEncoder(nn.Module):
         masks: torch.Tensor,
         skip_mask_sigmoid: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
+        """Process pixel features and masks to generate encoded memory representations for segmentation."""
         if not skip_mask_sigmoid:
             masks = F.sigmoid(masks)
         masks = self.mask_downsampler(masks)
@@ -445,7 +424,7 @@ class MemoryEncoder(nn.Module):
 
 class ImageEncoder(nn.Module):
     """
-
+    Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
 
     This class combines a trunk network for feature extraction with a neck network for feature refinement
     and positional encoding generation. It can optionally discard the lowest resolution features.
@@ -456,7 +435,7 @@ class ImageEncoder(nn.Module):
         scalp (int): Number of lowest resolution feature levels to discard.
 
     Methods:
-        forward:
+        forward: Process the input image through the trunk and neck networks.
 
     Examples:
         >>> trunk = SomeTrunkNetwork()
@@ -474,7 +453,7 @@ class ImageEncoder(nn.Module):
         neck: nn.Module,
         scalp: int = 0,
     ):
-        """
+        """Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
         super().__init__()
         self.trunk = trunk
         self.neck = neck
@@ -484,7 +463,7 @@ class ImageEncoder(nn.Module):
         )
 
     def forward(self, sample: torch.Tensor):
-        """
+        """Encode input through patch embedding, positional embedding, transformer blocks, and neck module."""
         features, pos = self.neck(self.trunk(sample))
         if self.scalp > 0:
             # Discard the lowest resolution features
@@ -514,7 +493,7 @@ class FpnNeck(nn.Module):
         fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
 
     Methods:
-        forward:
+        forward: Perform forward pass through the FPN neck.
 
     Examples:
         >>> backbone_channels = [64, 128, 256, 512]
@@ -665,8 +644,8 @@ class Hiera(nn.Module):
         channel_list (List[int]): List of output channel dimensions for each stage.
 
     Methods:
-        _get_pos_embed:
-        forward:
+        _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.
+        forward: Perform the forward pass through the Hiera model.
 
     Examples:
         >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
@@ -702,7 +681,7 @@ class Hiera(nn.Module):
         ),
         return_interm_layers=True,  # return feats from every stage
     ):
-        """
+        """Initialize the Hiera model, configuring its hierarchical vision transformer architecture."""
         super().__init__()
 
         assert len(stages) == len(window_spec)
@@ -768,7 +747,7 @@ class Hiera(nn.Module):
         )
 
     def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
-        """
+        """Generate positional embeddings by interpolating and combining window and background embeddings."""
         h, w = hw
         window_embed = self.pos_embed_window
         pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
@@ -777,7 +756,7 @@ class Hiera(nn.Module):
         return pos_embed
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
-        """
+        """Perform forward pass through Hiera model, extracting multiscale features from input images."""
         x = self.patch_embed(x)
         # x: (B, H, W, C)
 
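The PromptEncoder docstring retained above keeps its Examples block; a hedged sketch built from that block (import path and keyword usage are assumptions drawn from the hunks, and the random labels mirror the documented example):

    import torch
    from ultralytics.models.sam.modules.encoders import PromptEncoder

    # Argument order follows the documented example: embed_dim, image_embedding_size, input_image_size, mask_in_chans.
    prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
    points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))  # coordinates and per-point labels
    sparse, dense = prompt_encoder(points=points, boxes=None, masks=None)
    print(sparse.shape, dense.shape)  # sparse point embeddings plus a dense grid matching the image embedding size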