ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +12 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +16 -8
  96. ultralytics/solutions/object_cropper.py +12 -5
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +215 -85
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.142.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
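Of the files listed above, the hunks reproduced below are from ultralytics/models/sam/predict.py (+114 −101), a docstring cleanup of the SAM predictor classes. For orientation, here is a minimal sketch of the prompt-based API these docstrings describe, assuming the documented top-level SAM entry point; the checkpoint name and image path are placeholders:

    from ultralytics import SAM

    # Load a SAM checkpoint; prompts can be boxes, points, or point labels
    model = SAM("sam2_b.pt")
    results = model("image.jpg", bboxes=[100, 100, 200, 200])  # box prompt in XYXY pixels
    results = model("image.jpg", points=[[150, 150]], labels=[1])  # point prompt, 1 = foreground

ultralytics/models/sam/predict.py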
@@ -47,26 +47,26 @@ class Predictor(BasePredictor):
  device (torch.device): The device (CPU or GPU) on which the model is loaded.
  im (torch.Tensor): The preprocessed input image.
  features (torch.Tensor): Extracted image features.
- prompts (dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
+ prompts (Dict[str, Any]): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
  segment_all (bool): Flag to indicate if full image segmentation should be performed.
  mean (torch.Tensor): Mean values for image normalization.
  std (torch.Tensor): Standard deviation values for image normalization.

  Methods:
- preprocess: Prepares input images for model inference.
- pre_transform: Performs initial transformations on the input image.
- inference: Performs segmentation inference based on input prompts.
+ preprocess: Prepare input images for model inference.
+ pre_transform: Perform initial transformations on the input image.
+ inference: Perform segmentation inference based on input prompts.
  prompt_inference: Internal function for prompt-based segmentation inference.
- generate: Generates segmentation masks for an entire image.
- setup_model: Initializes the SAM model for inference.
- get_model: Builds and returns a SAM model.
- postprocess: Post-processes model outputs to generate final results.
- setup_source: Sets up the data source for inference.
- set_image: Sets and preprocesses a single image for inference.
- get_im_features: Extracts image features using the SAM image encoder.
- set_prompts: Sets prompts for subsequent inference.
- reset_image: Resets the current image and its features.
- remove_small_regions: Removes small disconnected regions and holes from masks.
+ generate: Generate segmentation masks for an entire image.
+ setup_model: Initialize the SAM model for inference.
+ get_model: Build and return a SAM model.
+ postprocess: Post-process model outputs to generate final results.
+ setup_source: Set up the data source for inference.
+ set_image: Set and preprocess a single image for inference.
+ get_im_features: Extract image features using the SAM image encoder.
+ set_prompts: Set prompts for subsequent inference.
+ reset_image: Reset the current image and its features.
+ remove_small_regions: Remove small disconnected regions and holes from masks.

  Examples:
  >>> predictor = Predictor()
@@ -86,8 +86,8 @@ class Predictor(BasePredictor):

  Args:
  cfg (dict): Configuration dictionary containing default settings.
- overrides (Dict | None): Dictionary of values to override default configuration.
- _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+ overrides (dict | None): Dictionary of values to override default configuration.
+ _callbacks (dict | None): Dictionary of callback functions to customize behavior.

  Examples:
  >>> predictor_example = Predictor(cfg=DEFAULT_CFG)
@@ -115,7 +115,7 @@ class Predictor(BasePredictor):
  im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.

  Returns:
- im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
+ (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.

  Examples:
  >>> predictor = Predictor()
@@ -182,9 +182,9 @@ class Predictor(BasePredictor):
  **kwargs (Any): Additional keyword arguments.

  Returns:
- (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
- (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
- (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+ pred_masks (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
+ pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+ pred_logits (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.

  Examples:
  >>> predictor = Predictor()
@@ -205,7 +205,7 @@ class Predictor(BasePredictor):

  def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
  """
- Performs image segmentation inference based on input cues using SAM's specialized architecture.
+ Perform image segmentation inference based on input cues using SAM's specialized architecture.

  This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
  It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
@@ -218,12 +218,9 @@ class Predictor(BasePredictor):
  masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
  multimask_output (bool): Flag to return multiple masks for ambiguous prompts.

- Raises:
- AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
  Returns:
- (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
- (np.ndarray): Quality scores predicted by the model for each mask, with length C.
+ pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+ pred_scores (np.ndarray): Quality scores predicted by the model for each mask, with length C.

  Examples:
  >>> predictor = Predictor()
@@ -253,20 +250,23 @@ class Predictor(BasePredictor):

  def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
  """
- Prepares and transforms the input prompts for processing based on the destination shape.
+ Prepare and transform the input prompts for processing based on the destination shape.

  Args:
  dst_shape (tuple): The target shape (height, width) for the prompts.
  bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
  points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
  labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background.
- masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
+ masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+
+ Returns:
+ bboxes (torch.Tensor | None): Transformed bounding boxes.
+ points (torch.Tensor | None): Transformed points.
+ labels (torch.Tensor | None): Transformed labels.
+ masks (torch.Tensor | None): Transformed masks.

  Raises:
  AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
- Returns:
- (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
  """
  src_shape = self.batch[1][0].shape[:2]
  r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
@@ -407,7 +407,7 @@ class Predictor(BasePredictor):

  def setup_model(self, model=None, verbose=True):
  """
- Initializes the Segment Anything Model (SAM) for inference.
+ Initialize the Segment Anything Model (SAM) for inference.

  This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
  parameters for image normalization and other Ultralytics compatibility settings.
@@ -437,20 +437,20 @@ class Predictor(BasePredictor):
  self.done_warmup = True

  def get_model(self):
- """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
+ """Retrieve or build the Segment Anything Model (SAM) for image segmentation tasks."""
  from .build import build_sam # slow import

  return build_sam(self.args.model)

  def postprocess(self, preds, img, orig_imgs):
  """
- Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
+ Post-process SAM's inference outputs to generate object detection masks and bounding boxes.

  This method scales masks and boxes to the original image size and applies a threshold to the mask
  predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.

  Args:
- preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
+ preds (tuple): The output from SAM model inference, containing:
  - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
  - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
  - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
@@ -458,7 +458,7 @@ class Predictor(BasePredictor):
  orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.

  Returns:
- results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
+ (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
  metadata for each processed image.

  Examples:
@@ -495,7 +495,7 @@ class Predictor(BasePredictor):

  def setup_source(self, source):
  """
- Sets up the data source for inference.
+ Set up the data source for inference.

  This method configures the data source from which images will be fetched for inference. It supports
  various input types such as image files, directories, video files, and other compatible data sources.
@@ -520,7 +520,7 @@ class Predictor(BasePredictor):

  def set_image(self, image):
  """
- Preprocesses and sets a single image for inference.
+ Preprocess and set a single image for inference.

  This method prepares the model for inference on a single image by setting up the model if not already
  initialized, configuring the data source, and preprocessing the image for feature extraction. It
@@ -530,14 +530,14 @@ class Predictor(BasePredictor):
  image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
  an image read by cv2.

- Raises:
- AssertionError: If more than one image is attempted to be set.
-
  Examples:
  >>> predictor = Predictor()
  >>> predictor.set_image("path/to/image.jpg")
  >>> predictor.set_image(cv2.imread("path/to/image.jpg"))

+ Raises:
+ AssertionError: If more than one image is attempted to be set.
+
  Notes:
  - This method should be called before performing inference on a new image.
  - The extracted features are stored in the `self.features` attribute for later use.
@@ -552,7 +552,7 @@ class Predictor(BasePredictor):
  break

  def get_im_features(self, im):
- """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
+ """Extract image features using the SAM model's image encoder for subsequent mask prediction."""
  assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
  f"SAM models only support square image size, but got {self.imgsz}."
  )
@@ -560,11 +560,11 @@ class Predictor(BasePredictor):
  return self.model.image_encoder(im)

  def set_prompts(self, prompts):
- """Sets prompts for subsequent inference operations."""
+ """Set prompts for subsequent inference operations."""
  self.prompts = prompts

  def reset_image(self):
- """Resets the current image and its features, clearing them for subsequent inference."""
+ """Reset the current image and its features, clearing them for subsequent inference."""
  self.im = None
  self.features = None

@@ -630,18 +630,18 @@ class SAM2Predictor(Predictor):
  prompt-based inference.

  Attributes:
- _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
+ _bb_feat_sizes (List[tuple]): Feature sizes for different backbone levels.
  model (torch.nn.Module): The loaded SAM2 model.
  device (torch.device): The device (CPU or GPU) on which the model is loaded.
- features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
+ features (dict): Cached image features for efficient inference.
  segment_all (bool): Flag to indicate if all segments should be predicted.
- prompts (dict): Dictionary to store various types of prompts for inference.
+ prompts (Dict[str, Any]): Dictionary to store various types of prompts for inference.

  Methods:
- get_model: Retrieves and initializes the SAM2 model.
- prompt_inference: Performs image segmentation inference based on various prompts.
- set_image: Preprocesses and sets a single image for inference.
- get_im_features: Extracts and processes image features using SAM2's image encoder.
+ get_model: Retrieve and initialize the SAM2 model.
+ prompt_inference: Perform image segmentation inference based on various prompts.
+ set_image: Preprocess and set a single image for inference.
+ get_im_features: Extract and process image features using SAM2's image encoder.

  Examples:
  >>> predictor = SAM2Predictor(cfg)
@@ -658,7 +658,7 @@ class SAM2Predictor(Predictor):
  ]

  def get_model(self):
- """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
+ """Retrieve and initialize the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
  from .build import build_sam # slow import

  return build_sam(self.args.model)
@@ -674,7 +674,7 @@ class SAM2Predictor(Predictor):
  img_idx=-1,
  ):
  """
- Performs image segmentation inference based on various prompts using SAM2 architecture.
+ Perform image segmentation inference based on various prompts using SAM2 architecture.

  This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
  based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
@@ -690,8 +690,8 @@ class SAM2Predictor(Predictor):
  img_idx (int): Index of the image in the batch to process.

  Returns:
- (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
- (np.ndarray): Quality scores for each mask, with length C.
+ pred_masks (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+ pred_scores (np.ndarray): Quality scores for each mask, with length C.

  Examples:
  >>> predictor = SAM2Predictor(cfg)
@@ -733,20 +733,22 @@ class SAM2Predictor(Predictor):

  def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
  """
- Prepares and transforms the input prompts for processing based on the destination shape.
+ Prepare and transform the input prompts for processing based on the destination shape.

  Args:
  dst_shape (tuple): The target shape (height, width) for the prompts.
  bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
  points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
  labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
- masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
+ masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+
+ Returns:
+ points (torch.Tensor | None): Transformed points.
+ labels (torch.Tensor | None): Transformed labels.
+ masks (torch.Tensor | None): Transformed masks.

  Raises:
  AssertionError: If the number of points don't match the number of labels, in case labels were passed.
-
- Returns:
- (tuple): A tuple containing transformed points, labels, and masks.
  """
  bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
  if bboxes is not None:
@@ -763,7 +765,7 @@ class SAM2Predictor(Predictor):

  def set_image(self, image):
  """
- Preprocesses and sets a single image for inference using the SAM2 model.
+ Preprocess and set a single image for inference using the SAM2 model.

  This method initializes the model if not already done, configures the data source to the specified image,
  and preprocesses the image for feature extraction. It supports setting only one image at a time.
@@ -771,14 +773,14 @@ class SAM2Predictor(Predictor):
  Args:
  image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.

- Raises:
- AssertionError: If more than one image is attempted to be set.
-
  Examples:
  >>> predictor = SAM2Predictor()
  >>> predictor.set_image("path/to/image.jpg")
  >>> predictor.set_image(np.array([...])) # Using a numpy array

+ Raises:
+ AssertionError: If more than one image is attempted to be set.
+
  Notes:
  - This method must be called before performing any inference on a new image.
  - The method caches the extracted features for efficient subsequent inferences on the same image.
@@ -794,7 +796,7 @@ class SAM2Predictor(Predictor):
  break

  def get_im_features(self, im):
- """Extracts image features from the SAM image encoder for subsequent processing."""
+ """Extract image features from the SAM image encoder for subsequent processing."""
  assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], (
  f"SAM 2 models only support square image size, but got {self.imgsz}."
  )
@@ -827,10 +829,20 @@ class SAM2VideoPredictor(SAM2Predictor):
  clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
  callbacks (dict): A dictionary of callbacks for various prediction lifecycle events.

- Args:
- cfg (dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
- overrides (dict, Optional): Additional configuration overrides. Defaults to None.
- _callbacks (list, Optional): Custom callbacks to be added. Defaults to None.
+ Methods:
+ get_model: Retrieve and configure the model with binarization enabled.
+ inference: Perform image segmentation inference based on the given input cues.
+ postprocess: Post-process the predictions to apply non-overlapping constraints if required.
+ add_new_prompts: Add new points or masks to a specific frame for a given object ID.
+ propagate_in_video_preflight: Prepare inference_state and consolidate temporary outputs before tracking.
+ init_state: Initialize an inference state for the predictor.
+ get_im_features: Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.
+
+ Examples:
+ >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
+ >>> predictor.set_image("path/to/video_frame.jpg")
+ >>> bboxes = [[100, 100, 200, 200]]
+ >>> results = predictor(bboxes=bboxes)

  Note:
  The `fill_hole_area` attribute is defined but not used in the current implementation.
@@ -848,8 +860,8 @@ class SAM2VideoPredictor(SAM2Predictor):

  Args:
  cfg (dict): Configuration dictionary containing default settings.
- overrides (Dict | None): Dictionary of values to override default configuration.
- _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+ overrides (dict | None): Dictionary of values to override default configuration.
+ _callbacks (dict | None): Dictionary of callback functions to customize behavior.

  Examples:
  >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
@@ -865,7 +877,7 @@ class SAM2VideoPredictor(SAM2Predictor):

  def get_model(self):
  """
- Retrieves and configures the model with binarization enabled.
+ Retrieve and configure the model with binarization enabled.

  Note:
  This method overrides the base class implementation to set the binarize flag to True.
@@ -888,8 +900,8 @@ class SAM2VideoPredictor(SAM2Predictor):
  masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.

  Returns:
- (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
- (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+ pred_masks (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
+ pred_scores (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
  """
  # Override prompts if any stored in self.prompts
  bboxes = self.prompts.pop("bboxes", bboxes)
@@ -947,19 +959,19 @@ class SAM2VideoPredictor(SAM2Predictor):

  def postprocess(self, preds, img, orig_imgs):
  """
- Post-processes the predictions to apply non-overlapping constraints if required.
+ Post-process the predictions to apply non-overlapping constraints if required.

  This method extends the post-processing functionality by applying non-overlapping constraints
  to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
  the masks do not overlap, which can be useful for certain applications.

  Args:
- preds (Tuple[torch.Tensor]): The predictions from the model.
+ preds (tuple): The predictions from the model.
  img (torch.Tensor): The processed image tensor.
  orig_imgs (List[np.ndarray]): The original images before processing.

  Returns:
- results (list): The post-processed predictions.
+ (list): The post-processed predictions.

  Note:
  If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
@@ -982,7 +994,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  frame_idx=0,
  ):
  """
- Adds new points or masks to a specific frame for a given object ID.
+ Add new points or masks to a specific frame for a given object ID.

  This method updates the inference state with new prompts (points or masks) for a specified
  object and frame index. It ensures that the prompts are either points or masks, but not both,
@@ -991,13 +1003,14 @@ class SAM2VideoPredictor(SAM2Predictor):
  Args:
  obj_id (int): The ID of the object to which the prompts are associated.
- points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
- labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
- masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
- frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
+ points (torch.Tensor, optional): The coordinates of the points of interest.
+ labels (torch.Tensor, optional): The labels corresponding to the points.
+ masks (torch.Tensor, optional): Binary masks for the object.
+ frame_idx (int, optional): The index of the frame to which the prompts are applied.

  Returns:
- (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
+ pred_masks (torch.Tensor): The flattened predicted masks.
+ pred_scores (torch.Tensor): A tensor of ones indicating the number of objects.

  Raises:
  AssertionError: If both `masks` and `points` are provided, or neither is provided.
@@ -1194,16 +1207,16 @@ class SAM2VideoPredictor(SAM2Predictor):

  def get_im_features(self, im, batch=1):
  """
- Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
+ Extract and process image features using SAM2's image encoder for subsequent segmentation tasks.

  Args:
  im (torch.Tensor): The input image tensor.
- batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
+ batch (int, optional): The batch size for expanding features if there are multiple prompts.

  Returns:
  vis_feats (torch.Tensor): The visual features extracted from the image.
  vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
- feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
+ feat_sizes (List[tuple]): A list containing the sizes of the extracted features.

  Note:
  - If `batch` is greater than 1, the features are expanded to fit the batch size.
@@ -1227,7 +1240,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  obj_id (int): The unique identifier of the object provided by the client side.

  Returns:
- obj_idx (int): The index of the object on the model side.
+ (int): The index of the object on the model side.

  Raises:
  RuntimeError: If an attempt is made to add a new object after tracking has started.
@@ -1291,14 +1304,14 @@ class SAM2VideoPredictor(SAM2Predictor):
  frame_idx (int): The index of the current frame.
  batch_size (int): The batch size for processing the frame.
  is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
- point_inputs (dict, Optional): Input points and their labels. Defaults to None.
- mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
+ point_inputs (dict | None): Input points and their labels.
+ mask_inputs (torch.Tensor | None): Input binary masks.
  reverse (bool): Indicates if the tracking should be performed in reverse order.
  run_mem_encoder (bool): Indicates if the memory encoder should be executed.
- prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
+ prev_sam_mask_logits (torch.Tensor | None): Previous mask logits for the current object.

  Returns:
- current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
+ (dict): A dictionary containing the output of the tracking step, including updated features and predictions.

  Raises:
  AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
@@ -1348,7 +1361,7 @@ class SAM2VideoPredictor(SAM2Predictor):

  def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
  """
- Caches and manages the positional encoding for mask memory across frames and objects.
+ Cache and manage the positional encoding for mask memory across frames and objects.

  This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
  mask memory, which is constant across frames and objects, thus reducing the amount of
@@ -1358,11 +1371,11 @@ class SAM2VideoPredictor(SAM2Predictor):
  the current batch size.

  Args:
- out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
+ out_maskmem_pos_enc (List[torch.Tensor] | None): The positional encoding for mask memory.
  Should be a list of tensors or None.

  Returns:
- out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
+ (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.

  Note:
  - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
@@ -1393,7 +1406,7 @@ class SAM2VideoPredictor(SAM2Predictor):
  run_mem_encoder=False,
  ):
  """
- Consolidates per-object temporary outputs into a single output for all objects.
+ Consolidate per-object temporary outputs into a single output for all objects.

  This method combines the temporary outputs for each object on a given frame into a unified
  output. It fills in any missing objects either from the main output dictionary or leaves
@@ -1402,13 +1415,12 @@ class SAM2VideoPredictor(SAM2Predictor):

  Args:
  frame_idx (int): The index of the frame for which to consolidate outputs.
- is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
- Defaults to False.
- run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
- consolidating the outputs. Defaults to False.
+ is_cond (bool, optional): Indicates if the frame is considered a conditioning frame.
+ run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after
+ consolidating the outputs.

  Returns:
- consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
+ (dict): A consolidated output dictionary containing the combined results for all objects.

  Note:
  - The method initializes the consolidated output with placeholder values for missing objects.
@@ -1538,7 +1550,8 @@ class SAM2VideoPredictor(SAM2Predictor):
  is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.

  Returns:
- (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
+ maskmem_features (torch.Tensor): The encoded mask features.
+ maskmem_pos_enc (torch.Tensor): The positional encoding.
  """
  # Retrieve correct image features
  current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
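The hunks above are documentation-only: summary lines move to imperative mood, Returns sections name their tuple elements (pred_masks, pred_scores, pred_logits), Raises blocks are reordered after Examples, and annotations are made consistent (for example, Dict | None becomes dict | None and "Optional" becomes "optional"). As a companion to those docstrings, here is a minimal sketch of the Predictor workflow they describe, following the overrides pattern from the Ultralytics docs; the model file and image path are placeholders:

    from ultralytics.models.sam import Predictor as SAMPredictor

    # cfg defaults plus an overrides dict, matching the Args sections shown above
    predictor = SAMPredictor(overrides=dict(task="segment", mode="predict", imgsz=1024, model="mobile_sam.pt"))

    predictor.set_image("path/to/image.jpg")  # caches features via get_im_features
    results = predictor(bboxes=[[100, 100, 200, 200]])  # dispatched through prompt_inference
    results = predictor(points=[[900, 370]], labels=[1])  # labels: 1 = foreground, 0 = background
    predictor.reset_image()  # clear the cached image and features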