ultralytics 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +11 -11
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +52 -51
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +191 -161
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +4 -6
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +2 -2
  93. ultralytics/solutions/instance_segmentation.py +7 -4
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +15 -11
  96. ultralytics/solutions/object_cropper.py +3 -2
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +189 -79
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +45 -29
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/METADATA +2 -2
  143. ultralytics-8.3.145.dist-info/RECORD +272 -0
  144. ultralytics-8.3.143.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/top_level.txt +0 -0

ultralytics/models/rtdetr/train.py

@@ -1,6 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from copy import copy
+from typing import Optional

 from ultralytics.models.yolo.detect import DetectionTrainer
 from ultralytics.nn.tasks import RTDETRDetectionModel
@@ -18,12 +19,17 @@ class RTDETRTrainer(DetectionTrainer):
     speed.

     Attributes:
-        loss_names (Tuple[str]): Names of the loss components used for training.
+        loss_names (tuple): Names of the loss components used for training.
         data (dict): Dataset configuration containing class count and other parameters.
         args (dict): Training arguments and hyperparameters.
         save_dir (Path): Directory to save training results.
         test_loader (DataLoader): DataLoader for validation/testing data.

+    Methods:
+        get_model: Initialize and return an RT-DETR model for object detection tasks.
+        build_dataset: Build and return an RT-DETR dataset for training or validation.
+        get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
+
     Notes:
         - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
         - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
@@ -35,7 +41,7 @@ class RTDETRTrainer(DetectionTrainer):
         >>> trainer.train()
     """

-    def get_model(self, cfg=None, weights=None, verbose=True):
+    def get_model(self, cfg: Optional[dict] = None, weights: Optional[str] = None, verbose: bool = True):
         """
         Initialize and return an RT-DETR model for object detection tasks.

@@ -52,7 +58,7 @@ class RTDETRTrainer(DetectionTrainer):
             model.load(weights)
         return model

-    def build_dataset(self, img_path, mode="val", batch=None):
+    def build_dataset(self, img_path: str, mode: str = "val", batch: Optional[int] = None):
         """
         Build and return an RT-DETR dataset for training or validation.

@@ -80,6 +86,6 @@ class RTDETRTrainer(DetectionTrainer):
         )

     def get_validator(self):
-        """Returns a DetectionValidator suitable for RT-DETR model validation."""
+        """Return a DetectionValidator suitable for RT-DETR model validation."""
         self.loss_names = "giou_loss", "cls_loss", "l1_loss"
         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
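
For context, a minimal sketch of how the retyped RTDETRTrainer entry points above are typically driven, adapted from the class docstring example referenced in this hunk; the model YAML, dataset, and epoch count are illustrative placeholders rather than part of the diff:

# Illustrative only: train() exercises get_model, build_dataset, and get_validator internally
from ultralytics.models.rtdetr.train import RTDETRTrainer

args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)  # placeholder config
trainer = RTDETRTrainer(overrides=args)
trainer.train()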

ultralytics/models/rtdetr/val.py

@@ -16,6 +16,22 @@ class RTDETRDataset(YOLODataset):

     This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
     real-time detection and tracking tasks.
+
+    Attributes:
+        augment (bool): Whether to apply data augmentation.
+        rect (bool): Whether to use rectangular training.
+        use_segments (bool): Whether to use segmentation masks.
+        use_keypoints (bool): Whether to use keypoint annotations.
+        imgsz (int): Target image size for training.
+
+    Methods:
+        load_image: Load one image from dataset index.
+        build_transforms: Build transformation pipeline for the dataset.
+
+    Examples:
+        Initialize an RT-DETR dataset
+        >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
+        >>> image, hw = dataset.load_image(0)
     """

     def __init__(self, *args, data=None, **kwargs):
@@ -27,7 +43,7 @@ class RTDETRDataset(YOLODataset):

         Args:
             *args (Any): Variable length argument list passed to the parent YOLODataset class.
-            data (Dict | None): Dictionary containing dataset information. If None, default values will be used.
+            data (dict | None): Dictionary containing dataset information. If None, default values will be used.
             **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
         """
         super().__init__(*args, data=data, **kwargs)
@@ -41,11 +57,12 @@ class RTDETRDataset(YOLODataset):
             rect_mode (bool, optional): Whether to use rectangular mode for batch inference.

         Returns:
-            im (numpy.ndarray): The loaded image.
+            im (torch.Tensor): The loaded image.
             resized_hw (tuple): Height and width of the resized image with shape (2,).

         Examples:
-            >>> dataset = RTDETRDataset(...)
+            Load an image from the dataset
+            >>> dataset = RTDETRDataset(img_path="path/to/images")
             >>> image, hw = dataset.load_image(0)
         """
         return super().load_image(i=i, rect_mode=rect_mode)
@@ -90,13 +107,22 @@ class RTDETRValidator(DetectionValidator):
     The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
     post-processing, and updates evaluation metrics accordingly.

+    Attributes:
+        args (Namespace): Configuration arguments for validation.
+        data (dict): Dataset configuration dictionary.
+
+    Methods:
+        build_dataset: Build an RTDETR Dataset for validation.
+        postprocess: Apply Non-maximum suppression to prediction outputs.
+
     Examples:
+        Initialize and run RT-DETR validation
         >>> from ultralytics.models.rtdetr import RTDETRValidator
         >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
         >>> validator = RTDETRValidator(args=args)
         >>> validator()

-    Note:
+    Notes:
         For further details on the attributes and methods, refer to the parent DetectionValidator class.
     """

@@ -106,7 +132,8 @@ class RTDETRValidator(DetectionValidator):

         Args:
             img_path (str): Path to the folder containing images.
-            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+            mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
+                each mode.
             batch (int, optional): Size of batches, this is for `rect`.

         Returns:
@@ -129,10 +156,10 @@ class RTDETRValidator(DetectionValidator):
         Apply Non-maximum suppression to prediction outputs.

         Args:
-            preds (List | Tuple | torch.Tensor): Raw predictions from the model.
+            preds (list | tuple | torch.Tensor): Raw predictions from the model.

         Returns:
-            (List[torch.Tensor]): List of processed predictions for each image in batch.
+            (list[torch.Tensor]): List of processed predictions for each image in batch.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -153,7 +180,7 @@ class RTDETRValidator(DetectionValidator):

     def _prepare_batch(self, si, batch):
         """
-        Prepares a batch for validation by applying necessary transformations.
+        Prepare a batch for validation by applying necessary transformations.

         Args:
             si (int): Batch index.
@@ -176,7 +203,7 @@ class RTDETRValidator(DetectionValidator):

     def _prepare_pred(self, pred, pbatch):
         """
-        Prepares predictions by scaling bounding boxes to original image dimensions.
+        Prepare predictions by scaling bounding boxes to original image dimensions.

         Args:
             pred (torch.Tensor): Raw predictions.

ultralytics/models/sam/amg.py

@@ -11,7 +11,24 @@ import torch
 def is_box_near_crop_edge(
     boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
 ) -> torch.Tensor:
-    """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
+    """
+    Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+
+    Args:
+        boxes (torch.Tensor): Bounding boxes in XYXY format.
+        crop_box (List[int]): Crop box coordinates in [x0, y0, x1, y1] format.
+        orig_box (List[int]): Original image box coordinates in [x0, y0, x1, y1] format.
+        atol (float, optional): Absolute tolerance for edge proximity detection.
+
+    Returns:
+        (torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
+
+    Examples:
+        >>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
+        >>> crop_box = [0, 0, 200, 200]
+        >>> orig_box = [0, 0, 300, 300]
+        >>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
+    """
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -52,7 +69,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:

 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     """
-    Computes the stability score for a batch of masks.
+    Compute the stability score for a batch of masks.

     The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
     high and low values.
@@ -90,7 +107,7 @@ def build_point_grid(n_per_side: int) -> np.ndarray:


 def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
-    """Generates point grids for multiple crop layers with varying scales and densities."""
+    """Generate point grids for multiple crop layers with varying scales and densities."""
     return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]


@@ -98,7 +115,7 @@ def generate_crop_boxes(
     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
 ) -> Tuple[List[List[int]], List[int]]:
     """
-    Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+    Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.

     Args:
         im_size (Tuple[int, ...]): Height and width of the input image.
@@ -106,8 +123,8 @@ def generate_crop_boxes(
         overlap_ratio (float): Ratio of overlap between adjacent crop boxes.

     Returns:
-        (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
-        (List[int]): List of layer indices corresponding to each crop box.
+        crop_boxes (List[List[int]]): List of crop boxes in [x0, y0, x1, y1] format.
+        layer_idxs (List[int]): List of layer indices corresponding to each crop box.

     Examples:
         >>> im_size = (800, 1200)  # Height, width
@@ -124,7 +141,7 @@ def generate_crop_boxes(
     layer_idxs.append(0)

     def crop_len(orig_len, n_crops, overlap):
-        """Calculates the length of each crop given the original length, number of crops, and overlap."""
+        """Calculate the length of each crop given the original length, number of crops, and overlap."""
         return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

     for i_layer in range(n_layers):
@@ -179,16 +196,17 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:

 def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
     """
-    Removes small disconnected regions or holes in a mask based on area threshold and mode.
+    Remove small disconnected regions or holes in a mask based on area threshold and mode.

     Args:
         mask (np.ndarray): Binary mask to process.
         area_thresh (float): Area threshold below which regions will be removed.
-        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected regions.
+        mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
+            regions.

     Returns:
-        (np.ndarray): Processed binary mask with small regions removed.
-        (bool): Whether any regions were modified.
+        processed_mask (np.ndarray): Processed binary mask with small regions removed.
+        modified (bool): Whether any regions were modified.

     Examples:
         >>> mask = np.zeros((100, 100), dtype=np.bool_)
@@ -216,7 +234,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup

 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
     """
-    Calculates bounding boxes in XYXY format around binary masks.
+    Calculate bounding boxes in XYXY format around binary masks.

     Args:
         masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).

ultralytics/models/sam/build.py

@@ -21,7 +21,7 @@ from .modules.transformer import TwoWayTransformer


 def build_sam_vit_h(checkpoint=None):
-    """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
+    """Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1280,
         encoder_depth=32,
@@ -32,7 +32,7 @@


 def build_sam_vit_l(checkpoint=None):
-    """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
+    """Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=1024,
         encoder_depth=24,
@@ -43,7 +43,7 @@


 def build_sam_vit_b(checkpoint=None):
-    """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
+    """Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters."""
     return _build_sam(
         encoder_embed_dim=768,
         encoder_depth=12,
@@ -54,7 +54,7 @@


 def build_mobile_sam(checkpoint=None):
-    """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
+    """Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
     return _build_sam(
         encoder_embed_dim=[64, 128, 160, 320],
         encoder_depth=[2, 2, 6, 2],
@@ -66,7 +66,7 @@


 def build_sam2_t(checkpoint=None):
-    """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+    """Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=96,
         encoder_stages=[1, 2, 7, 2],
@@ -79,7 +79,7 @@


 def build_sam2_s(checkpoint=None):
-    """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    """Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=96,
         encoder_stages=[1, 2, 11, 2],
@@ -92,7 +92,7 @@


 def build_sam2_b(checkpoint=None):
-    """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+    """Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=112,
         encoder_stages=[2, 3, 16, 3],
@@ -106,7 +106,7 @@


 def build_sam2_l(checkpoint=None):
-    """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+    """Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
     return _build_sam2(
         encoder_embed_dim=144,
         encoder_stages=[2, 6, 36, 4],
@@ -127,15 +127,15 @@ def _build_sam(
     mobile_sam=False,
 ):
     """
-    Builds a Segment Anything Model (SAM) with specified encoder parameters.
+    Build a Segment Anything Model (SAM) with specified encoder parameters.

     Args:
         encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
         encoder_depth (int | List[int]): Depth of the encoder.
         encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
         encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
-        checkpoint (str | None): Path to the model checkpoint file.
-        mobile_sam (bool): Whether to build a Mobile-SAM model.
+        checkpoint (str | None, optional): Path to the model checkpoint file.
+        mobile_sam (bool, optional): Whether to build a Mobile-SAM model.

     Returns:
         (SAMModel): A Segment Anything Model instance with the specified architecture.
@@ -224,17 +224,17 @@ def _build_sam2(
     checkpoint=None,
 ):
     """
-    Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+    Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.

     Args:
-        encoder_embed_dim (int): Embedding dimension for the encoder.
-        encoder_stages (List[int]): Number of blocks in each stage of the encoder.
-        encoder_num_heads (int): Number of attention heads in the encoder.
-        encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
-        encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
-        encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
-        encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
-        checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+        encoder_embed_dim (int, optional): Embedding dimension for the encoder.
+        encoder_stages (List[int], optional): Number of blocks in each stage of the encoder.
+        encoder_num_heads (int, optional): Number of attention heads in the encoder.
+        encoder_global_att_blocks (List[int], optional): Indices of global attention blocks in the encoder.
+        encoder_backbone_channel_list (List[int], optional): Channel dimensions for each level of the encoder backbone.
+        encoder_window_spatial_size (List[int], optional): Spatial size of the window for position embeddings.
+        encoder_window_spec (List[int], optional): Window specifications for each stage of the encoder.
+        checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.

     Returns:
         (SAM2Model): A configured and initialized SAM2 model.
@@ -326,10 +326,10 @@ sam_model_map = {

 def build_sam(ckpt="sam_b.pt"):
     """
-    Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+    Build and return a Segment Anything Model (SAM) based on the provided checkpoint.

     Args:
-        ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+        ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.

     Returns:
         (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.

ultralytics/models/sam/model.py

@@ -15,6 +15,7 @@ Key Features:
 """

 from pathlib import Path
+from typing import Dict, Type

 from ultralytics.engine.model import Model
 from ultralytics.utils.torch_utils import model_info
@@ -36,8 +37,8 @@ class SAM(Model):
         task (str): The task type, set to "segment" for SAM models.

     Methods:
-        predict: Performs segmentation prediction on the given image or video source.
-        info: Logs information about the SAM model.
+        predict: Perform segmentation prediction on the given image or video source.
+        info: Log information about the SAM model.

     Examples:
         >>> sam = SAM("sam_b.pt")
@@ -46,7 +47,7 @@ class SAM(Model):
         >>> print(f"Detected {len(r.masks)} masks")
     """

-    def __init__(self, model="sam_b.pt") -> None:
+    def __init__(self, model: str = "sam_b.pt") -> None:
         """
         Initialize the SAM (Segment Anything Model) instance.

@@ -81,7 +82,7 @@ class SAM(Model):

         self.model = build_sam(weights)

-    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+    def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
         """
         Perform segmentation prediction on the given image or video source.

@@ -108,7 +109,7 @@ class SAM(Model):
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)

-    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
+    def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
         """
         Perform segmentation prediction on the given image or video source.

@@ -134,7 +135,7 @@ class SAM(Model):
         """
         return self.predict(source, stream, bboxes, points, labels, **kwargs)

-    def info(self, detailed=False, verbose=True):
+    def info(self, detailed: bool = False, verbose: bool = True):
         """
         Log information about the SAM model.

@@ -153,13 +154,13 @@ class SAM(Model):
         return model_info(self.model, detailed=detailed, verbose=verbose)

     @property
-    def task_map(self):
+    def task_map(self) -> Dict[str, Dict[str, Type[Predictor]]]:
         """
         Provide a mapping from the 'segment' task to its corresponding 'Predictor'.

         Returns:
-            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding Predictor
-                class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+            (Dict[str, Dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
+                Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.

         Examples:
             >>> sam = SAM("sam_b.pt")
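
And a short usage sketch composing the SAM docstring examples touched above; the weights file, image path, and prompt box are placeholders, not part of the diff:

# Illustrative only: prompt-based SAM inference through the typed predict()/__call__ API
from ultralytics import SAM

sam = SAM("sam_b.pt")  # placeholder weights
results = sam.predict("image.jpg", bboxes=[100, 100, 200, 200])  # bboxes/points/labels are optional prompts
for r in results:
    print(f"Detected {len(r.masks)} masks")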