ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +125 -39
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +34 -33
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +33 -47
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +69 -90
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +31 -38
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +21 -26
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +23 -17
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +29 -24
  53. ultralytics/models/nas/predict.py +14 -11
  54. ultralytics/models/nas/val.py +11 -13
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +21 -21
  57. ultralytics/models/rtdetr/train.py +25 -24
  58. ultralytics/models/rtdetr/val.py +47 -14
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +30 -12
  73. ultralytics/models/yolo/classify/train.py +83 -19
  74. ultralytics/models/yolo/classify/val.py +45 -23
  75. ultralytics/models/yolo/detect/predict.py +29 -19
  76. ultralytics/models/yolo/detect/train.py +90 -23
  77. ultralytics/models/yolo/detect/val.py +150 -29
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +18 -13
  80. ultralytics/models/yolo/obb/train.py +12 -8
  81. ultralytics/models/yolo/obb/val.py +35 -22
  82. ultralytics/models/yolo/pose/predict.py +28 -15
  83. ultralytics/models/yolo/pose/train.py +21 -8
  84. ultralytics/models/yolo/pose/val.py +51 -31
  85. ultralytics/models/yolo/segment/predict.py +27 -16
  86. ultralytics/models/yolo/segment/train.py +11 -8
  87. ultralytics/models/yolo/segment/val.py +110 -29
  88. ultralytics/models/yolo/world/train.py +43 -16
  89. ultralytics/models/yolo/world/train_world.py +61 -36
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +12 -12
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +226 -79
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +37 -35
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +139 -68
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +37 -56
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +117 -52
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +65 -61
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +72 -59
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +202 -64
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +13 -25
  149. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.88.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/models/rtdetr/val.py
@@ -22,13 +22,20 @@ class RTDETRDataset(YOLODataset):
         """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
         super().__init__(*args, data=data, **kwargs)
 
-    # NOTE: add stretch version load_image for RTDETR mosaic
     def load_image(self, i, rect_mode=False):
         """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
         return super().load_image(i=i, rect_mode=rect_mode)
 
     def build_transforms(self, hyp=None):
-        """Temporary, only for evaluation."""
+        """
+        Build transformation pipeline for the dataset.
+
+        Args:
+            hyp (Dict, optional): Hyperparameters for transformations.
+
+        Returns:
+            (Compose): Composition of transformation functions.
+        """
         if self.augment:
             hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
             hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
@@ -58,14 +65,11 @@ class RTDETRValidator(DetectionValidator):
     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
     post-processing, and updates evaluation metrics accordingly.
 
-    Example:
-        ```python
-        from ultralytics.models.rtdetr import RTDETRValidator
-
-        args = dict(model="rtdetr-l.pt", data="coco8.yaml")
-        validator = RTDETRValidator(args=args)
-        validator()
-        ```
+    Examples:
+        >>> from ultralytics.models.rtdetr import RTDETRValidator
+        >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
+        >>> validator = RTDETRValidator(args=args)
+        >>> validator()
 
     Note:
         For further details on the attributes and methods, refer to the parent DetectionValidator class.
@@ -78,7 +82,10 @@ class RTDETRValidator(DetectionValidator):
         Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-           batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
+           batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (RTDETRDataset): Dataset configured for RT-DETR validation.
         """
         return RTDETRDataset(
             img_path=img_path,
@@ -93,7 +100,15 @@ class RTDETRValidator(DetectionValidator):
         )
 
     def postprocess(self, preds):
-        """Apply Non-maximum suppression to prediction outputs."""
+        """
+        Apply Non-maximum suppression to prediction outputs.
+
+        Args:
+            preds (List | Tuple | torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (List[torch.Tensor]): List of processed predictions for each image in batch.
+        """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
 
@@ -114,7 +129,16 @@ class RTDETRValidator(DetectionValidator):
         return outputs
 
     def _prepare_batch(self, si, batch):
-        """Prepares a batch for training or inference by applying transformations."""
+        """
+        Prepares a batch for validation by applying necessary transformations.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and annotations.
+
+        Returns:
+            (Dict): Prepared batch with transformed annotations.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -128,7 +152,16 @@ class RTDETRValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """Prepares and returns a batch with transformed bounding boxes and class labels."""
+        """
+        Prepares predictions by scaling bounding boxes to original image dimensions.
+
+        Args:
+            pred (torch.Tensor): Raw predictions.
+            pbatch (Dict): Prepared batch information.
+
+        Returns:
+            (torch.Tensor): Predictions scaled to original image dimensions.
+        """
         predn = pred.clone()
         predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
         predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
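The RTDETRValidator docstring above now carries its usage example in doctest form. A minimal standalone sketch of that same flow, assuming the `rtdetr-l.pt` weights and `coco8.yaml` dataset are available locally or auto-downloadable:

```python
from ultralytics.models.rtdetr import RTDETRValidator

# Same usage as the Examples block above; nothing beyond the documented arguments is assumed.
args = dict(model="rtdetr-l.pt", data="coco8.yaml")
validator = RTDETRValidator(args=args)
validator()  # runs validation on coco8 and records detection metrics
```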
ultralytics/models/sam/__init__.py
@@ -3,4 +3,4 @@
 from .model import SAM
 from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
 
-__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list of exportable items
ultralytics/models/sam/amg.py
@@ -76,7 +76,24 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
 def generate_crop_boxes(
     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> Tuple[List[List[int]], List[int]]:
-    """Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions."""
+    """
+    Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+
+    Args:
+        im_size (Tuple[int, ...]): Height and width of the input image.
+        n_layers (int): Number of layers to generate crop boxes for.
+        overlap_ratio (float): Ratio of overlap between adjacent crop boxes.
+
+    Returns:
+        (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+        (List[int]): List of layer indices corresponding to each crop box.
+
+    Examples:
+        >>> im_size = (800, 1200)  # Height, width
+        >>> n_layers = 3
+        >>> overlap_ratio = 0.25
+        >>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
+    """
     crop_boxes, layer_idxs = [], []
     im_h, im_w = im_size
     short_side = min(im_h, im_w)
@@ -86,7 +103,7 @@ def generate_crop_boxes(
     layer_idxs.append(0)
 
     def crop_len(orig_len, n_crops, overlap):
-        """Crops bounding boxes to the size of the input image."""
+        """Calculates the length of each crop given the original length, number of crops, and overlap."""
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
 
     for i_layer in range(n_layers):
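The corrected `crop_len` docstring above now matches what the helper actually computes. A short self-contained illustration of that formula (the function body is copied from the hunk; the numbers are only an example):

```python
import math

def crop_len(orig_len, n_crops, overlap):
    # Formula from the hunk above: each crop is just long enough that n_crops of them,
    # overlapping by `overlap` pixels, still cover the original length.
    return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

length = crop_len(800, 2, 200)    # ceil((200 * 1 + 800) / 2) = 500
assert 2 * length - 200 >= 800    # two 500 px crops overlapping by 200 px span the full 800 px
```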
@@ -140,7 +157,24 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
 
 
 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
-    """Removes small disconnected regions or holes in a mask based on area threshold and mode."""
+    """
+    Removes small disconnected regions or holes in a mask based on area threshold and mode.
+
+    Args:
+        mask (np.ndarray): Binary mask to process.
+        area_thresh (float): Area threshold below which regions will be removed.
+        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected regions.
+
+    Returns:
+        (np.ndarray): Processed binary mask with small regions removed.
+        (bool): Whether any regions were modified.
+
+    Examples:
+        >>> mask = np.zeros((100, 100), dtype=np.bool_)
+        >>> mask[40:60, 40:60] = True  # Create a square
+        >>> mask[45:55, 45:55] = False  # Create a hole
+        >>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
+    """
     import cv2  # type: ignore
 
     assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
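Expanding the new doctest into a runnable check; the threshold of 200 (larger than the 100-pixel hole) is this sketch's own choice so the fill actually triggers, following the "below which regions will be removed" wording in the docstring:

```python
import numpy as np
from ultralytics.models.sam.amg import remove_small_regions  # module path per the file list above

mask = np.zeros((100, 100), dtype=bool)
mask[40:60, 40:60] = True   # 20x20 foreground square
mask[45:55, 45:55] = False  # 10x10 hole, area 100

# Holes smaller than area_thresh are filled in "holes" mode; 100 < 200, so this one should be.
filled, changed = remove_small_regions(mask, area_thresh=200, mode="holes")
print(changed, bool(filled[50, 50]))  # expected: True True
```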
@@ -160,7 +194,19 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
 
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
-    """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
+    """
+    Calculates bounding boxes in XYXY format around binary masks.
+
+    Args:
+        masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
+
+    Returns:
+        (torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).
+
+    Notes:
+        - Handles empty masks by returning zero boxes.
+        - Preserves input tensor dimensions in the output.
+    """
     # torch.max below raises an error on empty inputs, just skip in this case
     if torch.numel(masks) == 0:
         return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
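A hedged usage sketch of the function documented above, imported from the module listed in the file table (`ultralytics/models/sam/amg.py`); the exact box values are inferred from the XYXY description rather than stated in the diff:

```python
import torch
from ultralytics.models.sam.amg import batched_mask_to_box

masks = torch.zeros(3, 64, 64, dtype=torch.bool)
masks[0, 10:30, 20:40] = True       # one rectangular mask; masks[1] and masks[2] stay empty

boxes = batched_mask_to_box(masks)  # shape (3, 4), XYXY order
print(boxes[0])                     # roughly [20, 10, 39, 29]; empty masks give all-zero rows
```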
ultralytics/models/sam/model.py
@@ -49,7 +49,7 @@ class SAM(Model):
 
     def __init__(self, model="sam_b.pt") -> None:
         """
-        Initializes the SAM (Segment Anything Model) instance.
+        Initialize the SAM (Segment Anything Model) instance.
 
         Args:
             model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
@@ -68,10 +68,7 @@ class SAM(Model):
 
     def _load(self, weights: str, task=None):
         """
-        Loads the specified weights into the SAM model.
-
-        This method initializes the SAM model with the provided weights file, setting up the model architecture
-        and loading the pre-trained parameters.
+        Load the specified weights into the SAM model.
 
         Args:
             weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
@@ -85,7 +82,7 @@ class SAM(Model):
 
     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
-        Performs segmentation prediction on the given image or video source.
+        Perform segmentation prediction on the given image or video source.
 
         Args:
             source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
@@ -112,7 +109,7 @@ class SAM(Model):
 
     def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
-        Performs segmentation prediction on the given image or video source.
+        Perform segmentation prediction on the given image or video source.
 
         This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
         for segmentation tasks.
@@ -138,10 +135,7 @@ class SAM(Model):
 
     def info(self, detailed=False, verbose=True):
         """
-        Logs information about the SAM model.
-
-        This method provides details about the Segment Anything Model (SAM), including its architecture,
-        parameters, and computational requirements.
+        Log information about the SAM model.
 
         Args:
             detailed (bool): If True, displays detailed information about the model layers and operations.
@@ -160,16 +154,16 @@ class SAM(Model):
     @property
     def task_map(self):
         """
-        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
+        Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
 
         Returns:
-            (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding Predictor
                 class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
 
         Examples:
             >>> sam = SAM("sam_b.pt")
             >>> task_map = sam.task_map
             >>> print(task_map)
-            {'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
+            {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
         """
         return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
ultralytics/models/sam/modules/decoders.py
@@ -48,7 +48,7 @@ class MaskDecoder(nn.Module):
         iou_head_hidden_dim: int = 256,
     ) -> None:
         """
-        Initializes the MaskDecoder module for generating masks and their quality scores.
+        Initialize the MaskDecoder module for generating masks and their associated quality scores.
 
         Args:
             transformer_dim (int): Channel dimension for the transformer module.
@@ -95,7 +95,7 @@ class MaskDecoder(nn.Module):
         multimask_output: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Predicts masks given image and prompt embeddings.
+        Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder.
@@ -105,9 +105,8 @@ class MaskDecoder(nn.Module):
             multimask_output (bool): Whether to return multiple masks or a single mask.
 
         Returns:
-            (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
-                - masks (torch.Tensor): Batched predicted masks.
-                - iou_pred (torch.Tensor): Batched predictions of mask quality.
+            masks (torch.Tensor): Batched predicted masks.
+            iou_pred (torch.Tensor): Batched predictions of mask quality.
 
         Examples:
             >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
@@ -140,7 +139,7 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Predicts masks and quality scores using image and prompt embeddings via transformer architecture."""
+        """Predict masks and quality scores using image and prompt embeddings via transformer architecture."""
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
         output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
@@ -236,7 +235,7 @@ class SAM2MaskDecoder(nn.Module):
         use_multimask_token_for_obj_ptr: bool = False,
     ) -> None:
         """
-        Initializes the SAM2MaskDecoder module for predicting instance segmentation masks.
+        Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
 
         This decoder extends the functionality of MaskDecoder, incorporating additional features such as
         high-resolution feature processing, dynamic multimask output, and object score prediction.
@@ -320,9 +319,9 @@ class SAM2MaskDecoder(nn.Module):
         multimask_output: bool,
         repeat_image: bool,
         high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Predicts masks given image and prompt embeddings.
+        Predict masks given image and prompt embeddings.
 
         Args:
             image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
@@ -334,11 +333,10 @@ class SAM2MaskDecoder(nn.Module):
             high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
 
         Returns:
-            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
-                - masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
-                - iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
-                - sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
-                - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
+            masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
+            iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
+            sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
+            object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
 
         Examples:
             >>> image_embeddings = torch.rand(1, 256, 64, 64)
@@ -390,8 +388,8 @@ class SAM2MaskDecoder(nn.Module):
         dense_prompt_embeddings: torch.Tensor,
         repeat_image: bool,
         high_res_features: Optional[List[torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Predicts instance segmentation masks from image and prompt embeddings using a transformer."""
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Predict instance segmentation masks from image and prompt embeddings using a transformer."""
         # Concatenate output tokens
         s = 0
         if self.pred_obj_scores:
@@ -454,7 +452,7 @@ class SAM2MaskDecoder(nn.Module):
         return masks, iou_pred, mask_tokens_out, object_score_logits
 
     def _get_stability_scores(self, mask_logits):
-        """Computes mask stability scores based on IoU between upper and lower thresholds."""
+        """Compute mask stability scores based on IoU between upper and lower thresholds."""
         mask_logits = mask_logits.flatten(-2)
         stability_delta = self.dynamic_multimask_stability_delta
         area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
@@ -463,7 +461,7 @@ class SAM2MaskDecoder(nn.Module):
 
     def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
         """
-        Dynamically selects the most stable mask output based on stability scores and IoU predictions.
+        Dynamically select the most stable mask output based on stability scores and IoU predictions.
 
         This method is used when outputting a single mask. If the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
@@ -476,9 +474,8 @@ class SAM2MaskDecoder(nn.Module):
             all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
 
         Returns:
-            (Tuple[torch.Tensor, torch.Tensor]):
-                - mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
-                - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
+            mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
+            iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
 
         Examples:
             >>> decoder = SAM2MaskDecoder(...)
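The `_get_stability_scores` hunk above shows the idea behind the stability-based selection: the score is the ratio between the mask area at a raised logit threshold and the area at a lowered one. A standalone illustration of that computation (my own re-derivation, not the library code; `delta` stands in for `dynamic_multimask_stability_delta`):

```python
import torch

def stability_score(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor:
    # Ratio of the mask area at the upper threshold (+delta) to the area at the lower
    # threshold (-delta); values near 1.0 mean the mask barely moves as the threshold shifts.
    flat = mask_logits.flatten(-2)                 # (..., H*W)
    area_hi = (flat > delta).sum(dim=-1).float()   # intersection-like area
    area_lo = (flat > -delta).sum(dim=-1).float()  # union-like area
    return torch.where(area_lo > 0, area_hi / area_lo, torch.ones_like(area_hi))

print(stability_score(torch.randn(2, 3, 256, 256)))  # one score per (batch, mask)
```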
ultralytics/models/sam/modules/encoders.py
@@ -65,7 +65,7 @@ class ImageEncoderViT(nn.Module):
         global_attn_indexes: Tuple[int, ...] = (),
     ) -> None:
         """
-        Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+        Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
 
         Args:
             img_size (int): Input image size, assumed to be square.
@@ -85,13 +85,6 @@ class ImageEncoderViT(nn.Module):
             window_size (int): Size of attention window for windowed attention blocks.
             global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
 
-        Attributes:
-            img_size (int): Dimension of input images.
-            patch_embed (PatchEmbed): Module for patch embedding.
-            pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
-            blocks (nn.ModuleList): List of transformer blocks.
-            neck (nn.Sequential): Neck module for final processing.
-
         Examples:
             >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
             >>> input_image = torch.randn(1, 3, 224, 224)
@@ -148,7 +141,7 @@ class ImageEncoderViT(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+        """Process input through patch embedding, positional embedding, transformer blocks, and neck module."""
         x = self.patch_embed(x)
         if self.pos_embed is not None:
             pos_embed = (
@@ -201,10 +194,7 @@ class PromptEncoder(nn.Module):
         activation: Type[nn.Module] = nn.GELU,
     ) -> None:
         """
-        Initializes the PromptEncoder module for encoding various types of prompts.
-
-        This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
-        producing both sparse and dense embeddings.
+        Initialize the PromptEncoder module for encoding various types of prompts.
 
         Args:
             embed_dim (int): The dimension of the embeddings.
@@ -213,17 +203,6 @@ class PromptEncoder(nn.Module):
             mask_in_chans (int): The number of hidden channels used for encoding input masks.
             activation (Type[nn.Module]): The activation function to use when encoding input masks.
 
-        Attributes:
-            embed_dim (int): Dimension of the embeddings.
-            input_image_size (Tuple[int, int]): Size of the input image as (H, W).
-            image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
-            pe_layer (PositionEmbeddingRandom): Module for random position embedding.
-            num_point_embeddings (int): Number of point embeddings for different types of points.
-            point_embeddings (nn.ModuleList): List of point embeddings.
-            not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
-            mask_input_size (Tuple[int, int]): Size of the input mask.
-            mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
-
         Examples:
             >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
             >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
@@ -258,9 +237,9 @@ class PromptEncoder(nn.Module):
 
     def get_dense_pe(self) -> torch.Tensor:
         """
-        Returns the dense positional encoding used for encoding point prompts.
+        Return the dense positional encoding used for encoding point prompts.
 
-        This method generates a positional encoding for a dense set of points matching the shape of the image
+        Generate a positional encoding for a dense set of points matching the shape of the image
         encoding. The encoding is used to provide spatial information to the model when processing point prompts.
 
         Returns:
@@ -276,7 +255,7 @@ class PromptEncoder(nn.Module):
         return self.pe_layer(self.image_embedding_size).unsqueeze(0)
 
     def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
-        """Embeds point prompts by applying positional encoding and label-specific embeddings."""
+        """Embed point prompts by applying positional encoding and label-specific embeddings."""
         points = points + 0.5  # Shift to center of pixel
         if pad:
             padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
@@ -293,7 +272,7 @@ class PromptEncoder(nn.Module):
         return point_embedding
 
     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
-        """Embeds box prompts by applying positional encoding and adding corner embeddings."""
+        """Embed box prompts by applying positional encoding and adding corner embeddings."""
         boxes = boxes + 0.5  # Shift to center of pixel
         coords = boxes.reshape(-1, 2, 2)
         corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
@@ -302,7 +281,7 @@ class PromptEncoder(nn.Module):
         return corner_embedding
 
     def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
-        """Embeds mask inputs by downscaling and processing through convolutional layers."""
+        """Embed mask inputs by downscaling and processing through convolutional layers."""
         return self.mask_downscaling(masks)
 
     @staticmethod
@@ -311,7 +290,7 @@ class PromptEncoder(nn.Module):
         boxes: Optional[torch.Tensor],
         masks: Optional[torch.Tensor],
     ) -> int:
-        """Gets the batch size of the output given the batch size of the input prompts."""
+        """Get the batch size of the output given the batch size of the input prompts."""
         if points is not None:
             return points[0].shape[0]
         elif boxes is not None:
@@ -322,7 +301,7 @@ class PromptEncoder(nn.Module):
             return 1
 
     def _get_device(self) -> torch.device:
-        """Returns the device of the first point embedding's weight tensor."""
+        """Return the device of the first point embedding's weight tensor."""
         return self.point_embeddings[0].weight.device
 
     def forward(
@@ -332,7 +311,7 @@ class PromptEncoder(nn.Module):
         masks: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Embeds different types of prompts, returning both sparse and dense embeddings.
+        Embed different types of prompts, returning both sparse and dense embeddings.
 
         Args:
             points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
@@ -377,7 +356,7 @@ class PromptEncoder(nn.Module):
 
 class MemoryEncoder(nn.Module):
     """
-    Encodes pixel features and masks into a memory representation for efficient image segmentation.
+    Encode pixel features and masks into a memory representation for efficient image segmentation.
     This class processes pixel-level features and masks, fusing them to generate encoded memory representations
     suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
 
@@ -390,7 +369,7 @@ class MemoryEncoder(nn.Module):
         out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
 
     Methods:
-        forward: Processes input pixel features and masks to generate encoded memory representations.
+        forward: Process input pixel features and masks to generate encoded memory representations.
 
     Examples:
         >>> import torch
@@ -407,7 +386,7 @@ class MemoryEncoder(nn.Module):
         out_dim,
         in_dim=256,  # in_dim of pix_feats
     ):
-        """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
+        """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations."""
         super().__init__()
 
         self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
@@ -425,7 +404,7 @@ class MemoryEncoder(nn.Module):
         masks: torch.Tensor,
         skip_mask_sigmoid: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Processes pixel features and masks to generate encoded memory representations for segmentation."""
+        """Process pixel features and masks to generate encoded memory representations for segmentation."""
         if not skip_mask_sigmoid:
             masks = F.sigmoid(masks)
         masks = self.mask_downsampler(masks)
@@ -445,7 +424,7 @@ class MemoryEncoder(nn.Module):
 
 class ImageEncoder(nn.Module):
     """
-    Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
+    Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
 
     This class combines a trunk network for feature extraction with a neck network for feature refinement
     and positional encoding generation. It can optionally discard the lowest resolution features.
@@ -456,7 +435,7 @@ class ImageEncoder(nn.Module):
         scalp (int): Number of lowest resolution feature levels to discard.
 
     Methods:
-        forward: Processes the input image through the trunk and neck networks.
+        forward: Process the input image through the trunk and neck networks.
 
     Examples:
         >>> trunk = SomeTrunkNetwork()
@@ -474,7 +453,7 @@ class ImageEncoder(nn.Module):
         neck: nn.Module,
         scalp: int = 0,
     ):
-        """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
+        """Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
         super().__init__()
         self.trunk = trunk
         self.neck = neck
@@ -484,7 +463,7 @@ class ImageEncoder(nn.Module):
         )
 
     def forward(self, sample: torch.Tensor):
-        """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+        """Encode input through patch embedding, positional embedding, transformer blocks, and neck module."""
         features, pos = self.neck(self.trunk(sample))
         if self.scalp > 0:
             # Discard the lowest resolution features
@@ -514,7 +493,7 @@ class FpnNeck(nn.Module):
         fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
 
     Methods:
-        forward: Performs forward pass through the FPN neck.
+        forward: Perform forward pass through the FPN neck.
 
     Examples:
         >>> backbone_channels = [64, 128, 256, 512]
@@ -665,8 +644,8 @@ class Hiera(nn.Module):
        channel_list (List[int]): List of output channel dimensions for each stage.
 
    Methods:
-        _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
-        forward: Performs the forward pass through the Hiera model.
+        _get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.
+        forward: Perform the forward pass through the Hiera model.
 
    Examples:
        >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
@@ -702,7 +681,7 @@ class Hiera(nn.Module):
        ),
        return_interm_layers=True,  # return feats from every stage
    ):
-        """Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
+        """Initialize the Hiera model, configuring its hierarchical vision transformer architecture."""
        super().__init__()
 
        assert len(stages) == len(window_spec)
@@ -768,7 +747,7 @@ class Hiera(nn.Module):
        )
 
    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
-        """Generates positional embeddings by interpolating and combining window and background embeddings."""
+        """Generate positional embeddings by interpolating and combining window and background embeddings."""
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
@@ -777,7 +756,7 @@ class Hiera(nn.Module):
        return pos_embed
 
    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
-        """Performs forward pass through Hiera model, extracting multiscale features from input images."""
+        """Perform forward pass through Hiera model, extracting multiscale features from input images."""
        x = self.patch_embed(x)
        # x: (B, H, W, C)
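The PromptEncoder hunks above retain the constructor doctest. Expanding it into a hedged, self-contained sketch (the output shapes in the comment are my reading of the docstring, not stated in the diff):

```python
import torch
from ultralytics.models.sam.modules.encoders import PromptEncoder

# Constructor arguments follow the doctest kept in the diff above:
# embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024), mask_in_chans=16.
encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))     # coordinates and labels
sparse, dense = encoder(points=points, boxes=None, masks=None)  # sparse and dense embeddings
print(sparse.shape, dense.shape)  # roughly (1, N, 256) and (1, 256, 64, 64)
```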