dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
@@ -18,11 +18,10 @@ from .val import NASValidator
 
 
  class NAS(Model):
- """
- YOLO-NAS model for object detection.
+ """YOLO-NAS model for object detection.
 
- This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
- It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+ This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. It
+ is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
 
  Attributes:
  model (torch.nn.Module): The loaded YOLO-NAS model.
@@ -48,8 +47,7 @@ class NAS(Model):
  super().__init__(model, task="detect")
 
  def _load(self, weights: str, task=None) -> None:
- """
- Load an existing NAS model weights or create a new NAS model with pretrained weights.
+ """Load an existing NAS model weights or create a new NAS model with pretrained weights.
 
  Args:
  weights (str): Path to the model weights file or model name.
@@ -83,8 +81,7 @@ class NAS(Model):
  self.model.eval()
 
  def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
- """
- Log model information.
+ """Log model information.
 
  Args:
  detailed (bool): Show detailed information about model.
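Note: the hunks above, and most of the docstring-only hunks that follow, apply the same layout change (the pydocstyle D212 convention): the one-line summary moves onto the opening triple quotes instead of sitting on its own line, and the body text is re-wrapped to the line-length limit. A minimal before/after sketch with a hypothetical function name:

    # Before: summary on its own line below the opening quotes
    def load(weights):
        """
        Load model weights.

        Args:
            weights (str): Path to the weights file.
        """

    # After: summary on the opening quotes, body re-wrapped
    def load(weights):
        """Load model weights.

        Args:
            weights (str): Path to the weights file.
        """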
@@ -7,12 +7,11 @@ from ultralytics.utils import ops
 
 
  class NASPredictor(DetectionPredictor):
- """
- Ultralytics YOLO NAS Predictor for object detection.
+ """Ultralytics YOLO NAS Predictor for object detection.
 
- This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the
- raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
- scaling the bounding boxes to fit the original image dimensions.
+ This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the raw
+ predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and scaling the
+ bounding boxes to fit the original image dimensions.
 
  Attributes:
  args (Namespace): Namespace containing various configurations for post-processing including confidence
@@ -33,12 +32,11 @@ class NASPredictor(DetectionPredictor):
  """
 
  def postprocess(self, preds_in, img, orig_imgs):
- """
- Postprocess NAS model predictions to generate final detection results.
+ """Postprocess NAS model predictions to generate final detection results.
 
  This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
- post-processing operations to generate the final detection results compatible with Ultralytics
- result visualization and analysis tools.
+ post-processing operations to generate the final detection results compatible with Ultralytics result
+ visualization and analysis tools.
 
  Args:
  preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
@@ -9,8 +9,7 @@ __all__ = ["NASValidator"]
 
 
  class NASValidator(DetectionValidator):
- """
- Ultralytics YOLO NAS Validator for object detection.
+ """Ultralytics YOLO NAS Validator for object detection.
 
  Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
  generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
@@ -19,11 +19,10 @@ from .val import RTDETRValidator
 
 
  class RTDETR(Model):
- """
- Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
+ """Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
 
- This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
- query selection, and adaptable inference speed.
+ This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
+ selection, and adaptable inference speed.
 
  Attributes:
  model (str): Path to the pre-trained model.
@@ -39,8 +38,7 @@ class RTDETR(Model):
  """
 
  def __init__(self, model: str = "rtdetr-l.pt") -> None:
- """
- Initialize the RT-DETR model with the given pre-trained model file.
+ """Initialize the RT-DETR model with the given pre-trained model file.
 
  Args:
  model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
@@ -50,8 +48,7 @@ class RTDETR(Model):
 
  @property
  def task_map(self) -> dict:
- """
- Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+ """Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
 
  Returns:
  (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
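Note: the docstrings above describe the public RT-DETR interface. For orientation, a typical usage sketch with the RTDETR class exported from the ultralytics package (file names and arguments are illustrative):

    from ultralytics import RTDETR

    model = RTDETR("rtdetr-l.pt")       # load a pretrained RT-DETR model
    results = model.predict("bus.jpg")  # run inference on an image
    model.info()                        # log model information
    print(model.task_map)               # task-to-class mapping documented above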
@@ -9,11 +9,10 @@ from ultralytics.utils import ops
 
 
  class RTDETRPredictor(BasePredictor):
- """
- RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
+ """RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
 
- This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
- It supports key features like efficient hybrid encoding and IoU-aware query selection.
+ This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy. It
+ supports key features like efficient hybrid encoding and IoU-aware query selection.
 
  Attributes:
  imgsz (int): Image size for inference (must be square and scale-filled).
@@ -34,21 +33,20 @@ class RTDETRPredictor(BasePredictor):
  """
 
  def postprocess(self, preds, img, orig_imgs):
- """
- Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+ """Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
 
- The method filters detections based on confidence and class if specified in `self.args`. It converts
- model predictions to Results objects containing properly scaled bounding boxes.
+ The method filters detections based on confidence and class if specified in `self.args`. It converts model
+ predictions to Results objects containing properly scaled bounding boxes.
 
  Args:
- preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
- bounding boxes and scores.
+ preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
+ and scores.
  img (torch.Tensor): Processed input images with shape (N, 3, H, W).
  orig_imgs (list | torch.Tensor): Original, unprocessed images.
 
  Returns:
- results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
- confidence scores, and class labels.
+ results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
+ scores, and class labels.
  """
  if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
  preds = [preds, None]
@@ -75,15 +73,14 @@ class RTDETRPredictor(BasePredictor):
  return results
 
  def pre_transform(self, im):
- """
- Pre-transform input images before feeding them into the model for inference.
+ """Pre-transform input images before feeding them into the model for inference.
 
- The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square
- (640) and scale_filled.
+ The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square (640)
+ and scale_filled.
 
  Args:
- im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
- [(H, W, 3) x N] for list.
+ im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
+ list.
 
  Returns:
  (list): List of pre-transformed images ready for model inference.
@@ -12,12 +12,11 @@ from .val import RTDETRDataset, RTDETRValidator
 
 
  class RTDETRTrainer(DetectionTrainer):
- """
- Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
+ """Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
 
- This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.
- The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference
- speed.
+ This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
+ RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
+ inference speed.
 
  Attributes:
  loss_names (tuple): Names of the loss components used for training.
@@ -31,20 +30,19 @@ class RTDETRTrainer(DetectionTrainer):
  build_dataset: Build and return an RT-DETR dataset for training or validation.
  get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
 
- Notes:
- - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
- - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
-
  Examples:
  >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
  >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
  >>> trainer = RTDETRTrainer(overrides=args)
  >>> trainer.train()
+
+ Notes:
+ - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+ - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
  """
 
  def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
- """
- Initialize and return an RT-DETR model for object detection tasks.
+ """Initialize and return an RT-DETR model for object detection tasks.
 
  Args:
  cfg (dict, optional): Model configuration.
@@ -60,8 +58,7 @@ class RTDETRTrainer(DetectionTrainer):
  return model
 
  def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
- """
- Build and return an RT-DETR dataset for training or validation.
+ """Build and return an RT-DETR dataset for training or validation.
 
  Args:
  img_path (str): Path to the folder containing images.
@@ -16,8 +16,7 @@ __all__ = ("RTDETRValidator",) # tuple or list
 
 
  class RTDETRDataset(YOLODataset):
- """
- Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+ """Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
 
  This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
  real-time detection and tracking tasks.
@@ -40,8 +39,7 @@ class RTDETRDataset(YOLODataset):
  """
 
  def __init__(self, *args, data=None, **kwargs):
- """
- Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
+ """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
 
  This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
  model, building upon the base YOLODataset functionality.
@@ -54,8 +52,7 @@ class RTDETRDataset(YOLODataset):
  super().__init__(*args, data=data, **kwargs)
 
  def load_image(self, i, rect_mode=False):
- """
- Load one image from dataset index 'i'.
+ """Load one image from dataset index 'i'.
 
  Args:
  i (int): Index of the image to load.
@@ -73,8 +70,7 @@ class RTDETRDataset(YOLODataset):
  return super().load_image(i=i, rect_mode=rect_mode)
 
  def build_transforms(self, hyp=None):
- """
- Build transformation pipeline for the dataset.
+ """Build transformation pipeline for the dataset.
 
  Args:
  hyp (dict, optional): Hyperparameters for transformations.
@@ -89,7 +85,7 @@ class RTDETRDataset(YOLODataset):
  transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
  else:
  # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
- transforms = Compose([])
+ transforms = Compose([lambda x: {**x, **{"ratio_pad": [x["ratio_pad"], [0, 0]]}}])
  transforms.append(
  Format(
  bbox_format="xywh",
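Note: this is one of the few behavioral changes in this section. In non-augmented (validation) mode the pipeline is no longer an empty Compose; a lambda now injects a default "ratio_pad" entry into each sample before Format runs, presumably so downstream rescaling sees an explicit zero padding. A minimal sketch of what that lambda does to a sample dict (values are illustrative):

    sample = {"ratio_pad": (0.5, 0.5)}  # resize ratio recorded earlier in the pipeline
    step = lambda x: {**x, **{"ratio_pad": [x["ratio_pad"], [0, 0]]}}
    print(step(sample)["ratio_pad"])  # [(0.5, 0.5), [0, 0]] -> (ratio, pad) with zero padding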
@@ -105,8 +101,7 @@
 
 
  class RTDETRValidator(DetectionValidator):
- """
- RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+ """RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
  the RT-DETR (Real-Time DETR) object detection model.
 
  The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
@@ -132,8 +127,7 @@ class RTDETRValidator(DetectionValidator):
  """
 
  def build_dataset(self, img_path, mode="val", batch=None):
- """
- Build an RTDETR Dataset.
+ """Build an RTDETR Dataset.
 
  Args:
  img_path (str): Path to the folder containing images.
@@ -159,12 +153,12 @@ class RTDETRValidator(DetectionValidator):
  def postprocess(
  self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
  ) -> list[dict[str, torch.Tensor]]:
- """
- Apply Non-maximum suppression to prediction outputs.
+ """Apply Non-maximum suppression to prediction outputs.
 
  Args:
  preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
- (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.
+ (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and
+ class scores.
 
  Returns:
  (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
@@ -190,12 +184,11 @@ class RTDETRValidator(DetectionValidator):
  return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
 
  def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
- """
- Serialize YOLO predictions to COCO json format.
+ """Serialize YOLO predictions to COCO json format.
 
  Args:
- predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
- with bounding box coordinates, confidence scores, and class predictions.
+ predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys with
+ bounding box coordinates, confidence scores, and class predictions.
  pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
  """
  path = Path(pbatch["im_file"])
@@ -14,8 +14,7 @@ import torch
  def is_box_near_crop_edge(
  boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
  ) -> torch.Tensor:
- """
- Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
+ """Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
 
  Args:
  boxes (torch.Tensor): Bounding boxes in XYXY format.
@@ -42,8 +41,7 @@ def is_box_near_crop_edge(
 
 
  def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
- """
- Yield batches of data from input arguments with specified batch size for efficient processing.
+ """Yield batches of data from input arguments with specified batch size for efficient processing.
 
  This function takes a batch size and any number of iterables, then yields batches of elements from those
  iterables. All input iterables must have the same length.
@@ -71,11 +69,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
 
 
  def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
- """
- Compute the stability score for a batch of masks.
+ """Compute the stability score for a batch of masks.
 
- The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
- high and low values.
+ The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at high and
+ low values.
 
  Args:
  masks (torch.Tensor): Batch of predicted mask logits.
@@ -85,15 +82,15 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
  Returns:
  (torch.Tensor): Stability scores for each mask in the batch.
 
- Notes:
- - One mask is always contained inside the other.
- - Memory is saved by preventing unnecessary cast to torch.int64.
-
  Examples:
  >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
  >>> mask_threshold = 0.5
  >>> threshold_offset = 0.1
  >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
+
+ Notes:
+ - One mask is always contained inside the other.
+ - Memory is saved by preventing unnecessary cast to torch.int64.
  """
  intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
  unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
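Note: the stability score computed here is the IoU of the mask binarized at (mask_threshold + threshold_offset) and at (mask_threshold - threshold_offset). Since the high-threshold mask is contained in the low-threshold mask (see the Notes above), the IoU reduces to the ratio of their pixel counts, which is what the intersections/unions lines compute. A small worked sketch with made-up logits:

    import torch

    masks = torch.tensor([[[0.2, 0.6, 0.9],
                           [0.4, 0.7, 0.55]]])  # one 2x3 mask of logits, illustrative values
    t, off = 0.5, 0.1
    inter = (masks > (t + off)).sum((-1, -2)).float()  # pixels above 0.6 -> 2
    union = (masks > (t - off)).sum((-1, -2)).float()  # pixels above 0.4 -> 4
    print((inter / union).item())  # 0.5, the stability score for this mask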
@@ -117,8 +114,7 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
  def generate_crop_boxes(
  im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
  ) -> tuple[list[list[int]], list[int]]:
- """
- Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
+ """Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
 
  Args:
  im_size (tuple[int, ...]): Height and width of the input image.
@@ -198,8 +194,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w:
 
 
  def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
- """
- Remove small disconnected regions or holes in a mask based on area threshold and mode.
+ """Remove small disconnected regions or holes in a mask based on area threshold and mode.
 
  Args:
  mask (np.ndarray): Binary mask to process.
@@ -236,8 +231,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tup
 
 
  def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
- """
- Calculate bounding boxes in XYXY format around binary masks.
+ """Calculate bounding boxes in XYXY format around binary masks.
 
  Args:
  masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
@@ -11,7 +11,7 @@ from functools import partial
  import torch
 
  from ultralytics.utils.downloads import attempt_download_asset
- from ultralytics.utils.torch_utils import TORCH_1_13
+ from ultralytics.utils.patches import torch_load
 
  from .modules.decoders import MaskDecoder
  from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
@@ -127,8 +127,7 @@ def _build_sam(
  checkpoint=None,
  mobile_sam=False,
  ):
- """
- Build a Segment Anything Model (SAM) with specified encoder parameters.
+ """Build a Segment Anything Model (SAM) with specified encoder parameters.
 
  Args:
  encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
@@ -208,7 +207,7 @@ def _build_sam(
  if checkpoint is not None:
  checkpoint = attempt_download_asset(checkpoint)
  with open(checkpoint, "rb") as f:
- state_dict = torch.load(f, weights_only=False) if TORCH_1_13 else torch.load(f)
+ state_dict = torch_load(f)
  sam.load_state_dict(state_dict)
  sam.eval()
  return sam
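Note: this hunk (and the matching one in _build_sam2 below) replaces the inline TORCH_1_13 branching with the shared torch_load helper from ultralytics.utils.patches. The actual implementation lives in that module; a minimal sketch of what such a wrapper typically does, stated as an assumption rather than the verbatim source:

    import torch

    def torch_load(*args, **kwargs):
        """Call torch.load, defaulting weights_only=False on torch>=1.13 (sketch, not the real module)."""
        from ultralytics.utils.torch_utils import TORCH_1_13  # same flag the old code imported

        if TORCH_1_13 and "weights_only" not in kwargs:
            kwargs["weights_only"] = False
        return torch.load(*args, **kwargs)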
@@ -224,8 +223,7 @@ def _build_sam2(
  encoder_window_spec=[8, 4, 16, 8],
  checkpoint=None,
  ):
- """
- Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+ """Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
 
  Args:
  encoder_embed_dim (int, optional): Embedding dimension for the encoder.
@@ -303,7 +301,7 @@ def _build_sam2(
  if checkpoint is not None:
  checkpoint = attempt_download_asset(checkpoint)
  with open(checkpoint, "rb") as f:
- state_dict = (torch.load(f, weights_only=False) if TORCH_1_13 else torch.load(f))["model"]
+ state_dict = torch_load(f)["model"]
  sam2.load_state_dict(state_dict)
  sam2.eval()
  return sam2
@@ -326,8 +324,7 @@ sam_model_map = {
 
 
  def build_sam(ckpt="sam_b.pt"):
- """
- Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
+ """Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
 
  Args:
  ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.
@@ -25,12 +25,11 @@ from .predict import Predictor, SAM2Predictor
 
 
  class SAM(Model):
- """
- SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+ """SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
 
- This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
- promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
- boxes, points, or labels, and features zero-shot performance capabilities.
+ This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for promptable
+ segmentation with versatility in image analysis. It supports various prompts such as bounding boxes, points, or
+ labels, and features zero-shot performance capabilities.
 
  Attributes:
  model (torch.nn.Module): The loaded SAM model.
@@ -49,8 +48,7 @@ class SAM(Model):
  """
 
  def __init__(self, model: str = "sam_b.pt") -> None:
- """
- Initialize the SAM (Segment Anything Model) instance.
+ """Initialize the SAM (Segment Anything Model) instance.
 
  Args:
  model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
@@ -68,8 +66,7 @@ class SAM(Model):
  super().__init__(model=model, task="segment")
 
  def _load(self, weights: str, task=None):
- """
- Load the specified weights into the SAM model.
+ """Load the specified weights into the SAM model.
 
  Args:
  weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
@@ -84,12 +81,11 @@ class SAM(Model):
  self.model = build_sam(weights)
 
  def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
- """
- Perform segmentation prediction on the given image or video source.
+ """Perform segmentation prediction on the given image or video source.
 
  Args:
- source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or
- a np.ndarray object.
+ source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or a
+ np.ndarray object.
  stream (bool): If True, enables real-time streaming.
  bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
  points (list[list[float]] | None): List of points for prompted segmentation.
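Note: the predict signature documented above takes optional bboxes, points, and labels prompts. A brief usage sketch of the SAM class exported from the ultralytics package (paths and coordinates are illustrative):

    from ultralytics import SAM

    model = SAM("sam_b.pt")
    model.info()  # log model information

    # Prompted segmentation: a box prompt, then a labeled point prompt
    results = model.predict("image.jpg", bboxes=[[100, 100, 400, 400]])
    results = model.predict("image.jpg", points=[[250, 250]], labels=[1])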
@@ -111,15 +107,14 @@ class SAM(Model):
  return super().predict(source, stream, prompts=prompts, **kwargs)
 
  def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
- """
- Perform segmentation prediction on the given image or video source.
+ """Perform segmentation prediction on the given image or video source.
 
- This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
- for segmentation tasks.
+ This method is an alias for the 'predict' method, providing a convenient way to call the SAM model for
+ segmentation tasks.
 
  Args:
- source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image
- object, or a np.ndarray object.
+ source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image object, or a
+ np.ndarray object.
  stream (bool): If True, enables real-time streaming.
  bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
  points (list[list[float]] | None): List of points for prompted segmentation.
@@ -137,8 +132,7 @@ class SAM(Model):
  return self.predict(source, stream, bboxes, points, labels, **kwargs)
 
  def info(self, detailed: bool = False, verbose: bool = True):
- """
- Log information about the SAM model.
+ """Log information about the SAM model.
 
  Args:
  detailed (bool): If True, displays detailed information about the model layers and operations.
@@ -156,8 +150,7 @@ class SAM(Model):
 
  @property
  def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
- """
- Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
+ """Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
 
  Returns:
  (dict[str, dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding