ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +12 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +16 -8
- ultralytics/solutions/object_cropper.py +12 -5
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +215 -85
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
- ultralytics-8.3.144.dist-info/RECORD +272 -0
- ultralytics-8.3.142.dist-info/RECORD +0 -272
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/predict.py

```diff
@@ -47,26 +47,26 @@ class Predictor(BasePredictor):
         device (torch.device): The device (CPU or GPU) on which the model is loaded.
         im (torch.Tensor): The preprocessed input image.
         features (torch.Tensor): Extracted image features.
-        prompts (
+        prompts (Dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
         segment_all (bool): Flag to indicate if full image segmentation should be performed.
         mean (torch.Tensor): Mean values for image normalization.
         std (torch.Tensor): Standard deviation values for image normalization.
 
     Methods:
-        preprocess:
-        pre_transform:
-        inference:
+        preprocess: Prepare input images for model inference.
+        pre_transform: Perform initial transformations on the input image.
+        inference: Perform segmentation inference based on input prompts.
         prompt_inference: Internal function for prompt-based segmentation inference.
-        generate:
-        setup_model:
-        get_model:
-        postprocess: Post-
-        setup_source:
-        set_image:
-        get_im_features:
-        set_prompts:
-        reset_image:
-        remove_small_regions:
+        generate: Generate segmentation masks for an entire image.
+        setup_model: Initialize the SAM model for inference.
+        get_model: Build and return a SAM model.
+        postprocess: Post-process model outputs to generate final results.
+        setup_source: Set up the data source for inference.
+        set_image: Set and preprocess a single image for inference.
+        get_im_features: Extract image features using the SAM image encoder.
+        set_prompts: Set prompts for subsequent inference.
+        reset_image: Reset the current image and its features.
+        remove_small_regions: Remove small disconnected regions and holes from masks.
 
     Examples:
         >>> predictor = Predictor()
@@ -86,8 +86,8 @@ class Predictor(BasePredictor):
 
         Args:
             cfg (dict): Configuration dictionary containing default settings.
-            overrides (
-            _callbacks (
+            overrides (dict | None): Dictionary of values to override default configuration.
+            _callbacks (dict | None): Dictionary of callback functions to customize behavior.
 
         Examples:
             >>> predictor_example = Predictor(cfg=DEFAULT_CFG)
```
```diff
@@ -115,7 +115,7 @@ class Predictor(BasePredictor):
             im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
 
         Returns:
-
+            (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
 
         Examples:
             >>> predictor = Predictor()
@@ -182,9 +182,9 @@ class Predictor(BasePredictor):
             **kwargs (Any): Additional keyword arguments.
 
         Returns:
-            (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
-            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
-            (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+            pred_masks (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
+            pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+            pred_logits (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
 
         Examples:
             >>> predictor = Predictor()
@@ -205,7 +205,7 @@ class Predictor(BasePredictor):
 
     def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
         """
-
+        Perform image segmentation inference based on input cues using SAM's specialized architecture.
 
         This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
         It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
@@ -218,12 +218,9 @@ class Predictor(BasePredictor):
             masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
             multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
 
-        Raises:
-            AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
         Returns:
-            (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
-            (np.ndarray): Quality scores predicted by the model for each mask, with length C.
+            pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+            pred_scores (np.ndarray): Quality scores predicted by the model for each mask, with length C.
 
         Examples:
             >>> predictor = Predictor()
@@ -253,20 +250,23 @@ class Predictor(BasePredictor):
 
     def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
         """
-
+        Prepare and transform the input prompts for processing based on the destination shape.
 
         Args:
             dst_shape (tuple): The target shape (height, width) for the prompts.
             bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
             points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
             labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
-            masks (List | np.ndarray
+            masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+
+        Returns:
+            bboxes (torch.Tensor | None): Transformed bounding boxes.
+            points (torch.Tensor | None): Transformed points.
+            labels (torch.Tensor | None): Transformed labels.
+            masks (torch.Tensor | None): Transformed masks.
 
         Raises:
             AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
-        Returns:
-            (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
         """
         src_shape = self.batch[1][0].shape[:2]
         r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
```
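The last two context lines above show the scaling rule: prompts arrive in source-image pixels and are rescaled by `r = min(dst_h / src_h, dst_w / src_w)` (or left as-is when `segment_all` is set). A rough numpy illustration of that convention; all shapes and values here are made up:

```python
import numpy as np

src_shape = (720, 1280)   # original image (H, W)
dst_shape = (1024, 1024)  # SAM input size (H, W)
r = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])

bboxes = np.array([[100, 100, 400, 300]], dtype=np.float32)  # XYXY, shape (N, 4)
points = np.array([[[250, 200]]], dtype=np.float32)          # (N, num_points, 2)
labels = np.ones(points.shape[:2], dtype=np.int32)           # 1 = foreground

scaled_bboxes, scaled_points = bboxes * r, points * r  # prompts now live in dst space
```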
```diff
@@ -407,7 +407,7 @@ class Predictor(BasePredictor):
 
     def setup_model(self, model=None, verbose=True):
         """
-
+        Initialize the Segment Anything Model (SAM) for inference.
 
         This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
         parameters for image normalization and other Ultralytics compatibility settings.
@@ -437,20 +437,20 @@ class Predictor(BasePredictor):
         self.done_warmup = True
 
     def get_model(self):
-        """
+        """Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks."""
         from .build import build_sam  # slow import
 
         return build_sam(self.args.model)
 
     def postprocess(self, preds, img, orig_imgs):
         """
-        Post-
+        Post-process SAM's inference outputs to generate object detection masks and bounding boxes.
 
         This method scales masks and boxes to the original image size and applies a threshold to the mask
         predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
 
         Args:
-            preds (
+            preds (tuple): The output from SAM model inference, containing:
                 - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
                 - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
                 - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
@@ -458,7 +458,7 @@ class Predictor(BasePredictor):
             orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
 
         Returns:
-
+            (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
                 metadata for each processed image.
 
         Examples:
```
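Because `postprocess` returns standard `Results` objects, SAM output is consumed like any other Ultralytics model's. A short sketch with placeholder paths:

```python
from ultralytics import SAM

model = SAM("sam_b.pt")
for r in model("path/to/image.jpg", bboxes=[100, 100, 200, 200]):
    if r.masks is not None:
        print(r.masks.data.shape)  # (N, H, W) masks at original image resolution
    print(r.boxes.xyxy)            # matching boxes in original-image pixels
```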
```diff
@@ -495,7 +495,7 @@ class Predictor(BasePredictor):
 
     def setup_source(self, source):
         """
-
+        Set up the data source for inference.
 
         This method configures the data source from which images will be fetched for inference. It supports
         various input types such as image files, directories, video files, and other compatible data sources.
@@ -520,7 +520,7 @@ class Predictor(BasePredictor):
 
     def set_image(self, image):
         """
-
+        Preprocess and set a single image for inference.
 
         This method prepares the model for inference on a single image by setting up the model if not already
         initialized, configuring the data source, and preprocessing the image for feature extraction. It
@@ -530,14 +530,14 @@ class Predictor(BasePredictor):
             image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
                 an image read by cv2.
 
-        Raises:
-            AssertionError: If more than one image is attempted to be set.
-
         Examples:
             >>> predictor = Predictor()
             >>> predictor.set_image("path/to/image.jpg")
             >>> predictor.set_image(cv2.imread("path/to/image.jpg"))
 
+        Raises:
+            AssertionError: If more than one image is attempted to be set.
+
         Notes:
             - This method should be called before performing inference on a new image.
             - The extracted features are stored in the `self.features` attribute for later use.
@@ -552,7 +552,7 @@ class Predictor(BasePredictor):
                 break
 
     def get_im_features(self, im):
-        """
+        """Extract image features using the SAM model's image encoder for subsequent mask prediction."""
         assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
             f"SAM models only support square image size, but got {self.imgsz}."
         )
@@ -560,11 +560,11 @@ class Predictor(BasePredictor):
         return self.model.image_encoder(im)
 
     def set_prompts(self, prompts):
-        """
+        """Set prompts for subsequent inference operations."""
         self.prompts = prompts
 
     def reset_image(self):
-        """
+        """Reset the current image and its features, clearing them for subsequent inference."""
         self.im = None
         self.features = None
 
```
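The `set_image` / `set_prompts` / `reset_image` cycle documented above supports interactive use: encode an image once, then prompt it repeatedly against the cached features. A sketch in the style of the Ultralytics SAM examples; the override values and paths are illustrative:

```python
from ultralytics.models.sam import Predictor as SAMPredictor

overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="mobile_sam.pt")
predictor = SAMPredictor(overrides=overrides)

predictor.set_image("path/to/image.jpg")          # encode the image once
results = predictor(bboxes=[100, 100, 200, 200])  # prompt against cached features
results = predictor(points=[150, 150], labels=[1])
predictor.reset_image()                           # clear cached image and features
```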
```diff
@@ -630,18 +630,18 @@ class SAM2Predictor(Predictor):
     prompt-based inference.
 
     Attributes:
-        _bb_feat_sizes (List[
+        _bb_feat_sizes (List[tuple]): Feature sizes for different backbone levels.
         model (torch.nn.Module): The loaded SAM2 model.
         device (torch.device): The device (CPU or GPU) on which the model is loaded.
-        features (
+        features (dict): Cached image features for efficient inference.
         segment_all (bool): Flag to indicate if all segments should be predicted.
-        prompts (
+        prompts (Dict[str, Any]): Dictionary to store various types of prompts for inference.
 
     Methods:
-        get_model:
-        prompt_inference:
-        set_image:
-        get_im_features:
+        get_model: Retrieve and initialize the SAM2 model.
+        prompt_inference: Perform image segmentation inference based on various prompts.
+        set_image: Preprocess and set a single image for inference.
+        get_im_features: Extract and process image features using SAM2's image encoder.
 
     Examples:
         >>> predictor = SAM2Predictor(cfg)
@@ -658,7 +658,7 @@ class SAM2Predictor(Predictor):
     ]
 
     def get_model(self):
-        """
+        """Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
         from .build import build_sam  # slow import
 
         return build_sam(self.args.model)
@@ -674,7 +674,7 @@ class SAM2Predictor(Predictor):
         img_idx=-1,
     ):
         """
-
+        Perform image segmentation inference based on various prompts using SAM2 architecture.
 
         This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
         based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
@@ -690,8 +690,8 @@ class SAM2Predictor(Predictor):
             img_idx (int): Index of the image in the batch to process.
 
         Returns:
-            (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
-            (np.ndarray): Quality scores for each mask, with length C.
+            pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+            pred_scores (np.ndarray): Quality scores for each mask, with length C.
 
         Examples:
             >>> predictor = SAM2Predictor(cfg)
```
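SAM2 checkpoints go through the same high-level entry point, with SAM2 weights routed to `SAM2Predictor`. A minimal sketch; the checkpoint name and paths are placeholders:

```python
from ultralytics import SAM

model = SAM("sam2_b.pt")  # a SAM2 checkpoint selects SAM2Predictor under the hood
results = model("path/to/image.jpg", points=[[150, 150], [300, 200]], labels=[1, 1])
```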
```diff
@@ -733,20 +733,22 @@ class SAM2Predictor(Predictor):
 
     def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
         """
-
+        Prepare and transform the input prompts for processing based on the destination shape.
 
         Args:
             dst_shape (tuple): The target shape (height, width) for the prompts.
             bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
             points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
             labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
-            masks (List | np.ndarray
+            masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+
+        Returns:
+            points (torch.Tensor | None): Transformed points.
+            labels (torch.Tensor | None): Transformed labels.
+            masks (torch.Tensor | None): Transformed masks.
 
         Raises:
             AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
-        Returns:
-            (tuple): A tuple containing transformed points, labels, and masks.
         """
         bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
         if bboxes is not None:
@@ -763,7 +765,7 @@ class SAM2Predictor(Predictor):
 
     def set_image(self, image):
         """
-
+        Preprocess and set a single image for inference using the SAM2 model.
 
         This method initializes the model if not already done, configures the data source to the specified image,
         and preprocesses the image for feature extraction. It supports setting only one image at a time.
@@ -771,14 +773,14 @@ class SAM2Predictor(Predictor):
         Args:
             image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
 
-        Raises:
-            AssertionError: If more than one image is attempted to be set.
-
         Examples:
             >>> predictor = SAM2Predictor()
             >>> predictor.set_image("path/to/image.jpg")
             >>> predictor.set_image(np.array([...]))  # Using a numpy array
 
+        Raises:
+            AssertionError: If more than one image is attempted to be set.
+
         Notes:
             - This method must be called before performing any inference on a new image.
             - The method caches the extracted features for efficient subsequent inferences on the same image.
@@ -794,7 +796,7 @@ class SAM2Predictor(Predictor):
                 break
 
     def get_im_features(self, im):
-        """
+        """Extract image features from the SAM image encoder for subsequent processing."""
         assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
             f"SAM 2 models only support square image size, but got {self.imgsz}."
         )
@@ -827,10 +829,20 @@ class SAM2VideoPredictor(SAM2Predictor):
         clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
         callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.
 
-
-
-
-
+    Methods:
+        get_model: Retrieve and configure the model with binarization enabled.
+        inference: Perform image segmentation inference based on the given input cues.
+        postprocess: Post-process the predictions to apply non-overlapping constraints if required.
+        add_new_prompts: Add new points or masks to a specific frame for a given object ID.
+        propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.
+        init_state: Initialize an inference state for the predictor.
+        get_im_features: Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
+
+    Examples:
+        >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
+        >>> predictor.set_image("path/to/video_frame.jpg")
+        >>> bboxes = [[100, 100, 200, 200]]
+        >>> results = predictor(bboxes=bboxes)
 
     Note:
         The `fill_hole_area` attribute is defined but not used in the current implementation.
@@ -848,8 +860,8 @@ class SAM2VideoPredictor(SAM2Predictor):
 
         Args:
             cfg (dict): Configuration dictionary containing default settings.
-            overrides (
-            _callbacks (
+            overrides (dict | None): Dictionary of values to override default configuration.
+            _callbacks (dict | None): Dictionary of callback functions to customize behavior.
 
         Examples:
             >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
```
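Video inference is typically driven through `SAM2VideoPredictor` directly, with prompts supplied for the first frame and propagated through the rest of the video. A sketch following the Ultralytics examples; the override values and source path are illustrative:

```python
from ultralytics.models.sam import SAM2VideoPredictor

overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2VideoPredictor(overrides=overrides)

# One foreground point on the first frame; the resulting mask is tracked across frames.
results = predictor(source="path/to/video.mp4", points=[920, 470], labels=[1])
```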
```diff
@@ -865,7 +877,7 @@ class SAM2VideoPredictor(SAM2Predictor):
 
     def get_model(self):
         """
-
+        Retrieve and configure the model with binarization enabled.
 
         Note:
             This method overrides the base class implementation to set the binarize flag to True.
@@ -888,8 +900,8 @@ class SAM2VideoPredictor(SAM2Predictor):
             masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
 
         Returns:
-            (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
-            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+            pred_masks (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
+            pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
         """
         # Override prompts if any stored in self.prompts
         bboxes = self.prompts.pop("bboxes", bboxes)
@@ -947,19 +959,19 @@ class SAM2VideoPredictor(SAM2Predictor):
 
     def postprocess(self, preds, img, orig_imgs):
         """
-        Post-
+        Post-process the predictions to apply non-overlapping constraints if required.
 
         This method extends the post-processing functionality by applying non-overlapping constraints
         to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
         the masks do not overlap, which can be useful for certain applications.
 
         Args:
-            preds (
+            preds (tuple): The predictions from the model.
             img (torch.Tensor): The processed image tensor.
             orig_imgs (List[np.ndarray]): The original images before processing.
 
         Returns:
-
+            (list): The post-processed predictions.
 
         Note:
             If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
```
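The non-overlap constraint mentioned here resolves pixels claimed by more than one object. A toy illustration of the general idea, not the library's implementation: keep, per pixel, only the object with the highest score.

```python
import torch

masks = torch.rand(3, 1, 64, 64) > 0.5     # (num_objects, 1, H, W) binary masks
scores = torch.rand(3, 1, 64, 64)          # per-pixel confidence per object
best = scores.argmax(dim=0, keepdim=True)  # winning object index at each pixel
keep = torch.arange(3).view(3, 1, 1, 1) == best
non_overlapping = masks & keep             # each pixel now belongs to at most one mask
```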
```diff
@@ -982,7 +994,7 @@ class SAM2VideoPredictor(SAM2Predictor):
         frame_idx=0,
     ):
         """
-
+        Add new points or masks to a specific frame for a given object ID.
 
         This method updates the inference state with new prompts (points or masks) for a specified
         object and frame index. It ensures that the prompts are either points or masks, but not both,
@@ -991,13 +1003,14 @@ class SAM2VideoPredictor(SAM2Predictor):
 
         Args:
             obj_id (int): The ID of the object to which the prompts are associated.
-            points (torch.Tensor,
-            labels (torch.Tensor,
-            masks (torch.Tensor, optional): Binary masks for the object.
-            frame_idx (int, optional): The index of the frame to which the prompts are applied.
+            points (torch.Tensor, optional): The coordinates of the points of interest.
+            labels (torch.Tensor, optional): The labels corresponding to the points.
+            masks (torch.Tensor, optional): Binary masks for the object.
+            frame_idx (int, optional): The index of the frame to which the prompts are applied.
 
         Returns:
-            (
+            pred_masks (torch.Tensor): The flattened predicted masks.
+            pred_scores (torch.Tensor): A tensor of ones indicating the number of objects.
 
         Raises:
             AssertionError: If both `masks` and `points` are provided, or neither is provided.
```
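A hypothetical sketch of attaching a point prompt to frame 0 for one object, based only on the signature and Args documented above; the tensor shapes are assumptions, and `predictor` is assumed to be a `SAM2VideoPredictor` whose inference state is already initialized:

```python
import torch


def prompt_first_frame(predictor, obj_id: int = 0):
    """Attach a single foreground point for one object to frame 0 of a video.

    Only points OR masks may be passed to add_new_prompts, never both.
    """
    points = torch.tensor([[[150.0, 150.0]]])  # assumed (1, num_points, 2) pixel coords
    labels = torch.tensor([[1]])               # 1 = foreground, 0 = background
    return predictor.add_new_prompts(obj_id=obj_id, points=points, labels=labels, frame_idx=0)
```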
```diff
@@ -1194,16 +1207,16 @@ class SAM2VideoPredictor(SAM2Predictor):
 
     def get_im_features(self, im, batch=1):
         """
-
+        Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
 
         Args:
             im (torch.Tensor): The input image tensor.
-            batch (int, optional): The batch size for expanding features if there are multiple prompts.
+            batch (int, optional): The batch size for expanding features if there are multiple prompts.
 
         Returns:
             vis_feats (torch.Tensor): The visual features extracted from the image.
             vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
-            feat_sizes (List
+            feat_sizes (List[tuple]): A list containing the sizes of the extracted features.
 
         Note:
             - If `batch` is greater than 1, the features are expanded to fit the batch size.
@@ -1227,7 +1240,7 @@ class SAM2VideoPredictor(SAM2Predictor):
         obj_id (int): The unique identifier of the object provided by the client side.
 
         Returns:
-
+            (int): The index of the object on the model side.
 
         Raises:
             RuntimeError: If an attempt is made to add a new object after tracking has started.
@@ -1291,14 +1304,14 @@ class SAM2VideoPredictor(SAM2Predictor):
             frame_idx (int): The index of the current frame.
             batch_size (int): The batch size for processing the frame.
             is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
-            point_inputs (dict
-            mask_inputs (torch.Tensor
+            point_inputs (dict | None): Input points and their labels.
+            mask_inputs (torch.Tensor | None): Input binary masks.
             reverse (bool): Indicates if the tracking should be performed in reverse order.
             run_mem_encoder (bool): Indicates if the memory encoder should be executed.
-            prev_sam_mask_logits (torch.Tensor
+            prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object.
 
         Returns:
-
+            (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
 
         Raises:
             AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
@@ -1348,7 +1361,7 @@ class SAM2VideoPredictor(SAM2Predictor):
 
     def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
         """
-
+        Cache and manage the positional encoding for mask memory across frames and objects.
 
         This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
         mask memory, which is constant across frames and objects, thus reducing the amount of
@@ -1358,11 +1371,11 @@ class SAM2VideoPredictor(SAM2Predictor):
         the current batch size.
 
         Args:
-            out_maskmem_pos_enc (List[torch.Tensor]
+            out_maskmem_pos_enc (List[torch.Tensor] | None): The positional encoding for mask memory.
                 Should be a list of tensors or None.
 
         Returns:
-
+            (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
 
         Note:
             - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
@@ -1393,7 +1406,7 @@ class SAM2VideoPredictor(SAM2Predictor):
         run_mem_encoder=False,
     ):
         """
-
+        Consolidate per-object temporary outputs into a single output for all objects.
 
         This method combines the temporary outputs for each object on a given frame into a unified
         output. It fills in any missing objects either from the main output dictionary or leaves
@@ -1402,13 +1415,12 @@ class SAM2VideoPredictor(SAM2Predictor):
 
         Args:
             frame_idx (int): The index of the frame for which to consolidate outputs.
-            is_cond (bool,
-
-
-                consolidating the outputs. Defaults to False.
+            is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.
+            run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after
+                consolidating the outputs.
 
         Returns:
-
+            (dict): A consolidated output dictionary containing the combined results for all objects.
 
         Note:
             - The method initializes the consolidated output with placeholder values for missing objects.
@@ -1538,7 +1550,8 @@ class SAM2VideoPredictor(SAM2Predictor):
             is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
 
         Returns:
-            (
+            maskmem_features (torch.Tensor): The encoded mask features.
+            maskmem_pos_enc (torch.Tensor): The positional encoding.
         """
         # Retrieve correct image features
         current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
```
|