dgenerate-ultralytics-headless 8.3.222-py3-none-any.whl → 8.3.225-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
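Most of the per-file line deltas above come from one mechanical change, visible throughout the hunks below: docstring summaries were moved onto the opening `"""` line and the wrapped text reflowed. A minimal before/after sketch of that style change (taken from the PoseTrainer hunk below, class body elided):

    # 8.3.222 style: summary on the line after the opening quotes
    class PoseTrainer(yolo.detect.DetectionTrainer):
        """
        A class extending the DetectionTrainer class for training YOLO pose estimation models.
        """

    # 8.3.225 style: summary moved onto the opening quotes
    class PoseTrainer(yolo.detect.DetectionTrainer):
        """A class extending the DetectionTrainer class for training YOLO pose estimation models."""

The main structural change in the list above is the split of ultralytics/utils/export/__init__.py into ultralytics/utils/export/engine.py and ultralytics/utils/export/tensorflow.py.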
ultralytics/models/yolo/pose/train.py

@@ -12,8 +12,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER


 class PoseTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training YOLO pose estimation models.
+    """A class extending the DetectionTrainer class for training YOLO pose estimation models.

     This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
     of pose keypoints alongside bounding boxes.
@@ -39,8 +38,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a PoseTrainer object for training YOLO pose estimation models.
+        """Initialize a PoseTrainer object for training YOLO pose estimation models.

         Args:
             cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -68,8 +66,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         weights: str | Path | None = None,
         verbose: bool = True,
     ) -> PoseModel:
-        """
-        Get pose estimation model with specified configuration and weights.
+        """Get pose estimation model with specified configuration and weights.

         Args:
             cfg (str | Path | dict, optional): Model configuration file path or dictionary.
@@ -105,8 +102,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         )

     def get_dataset(self) -> dict[str, Any]:
-        """
-        Retrieve the dataset and ensure it contains the required `kpt_shape` key.
+        """Retrieve the dataset and ensure it contains the required `kpt_shape` key.

         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
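The pose/train.py hunks above touch only docstrings; the training API is unchanged. For orientation, a usage sketch consistent with the constructor shown above (model and dataset names are illustrative):

    from ultralytics.models.yolo.pose import PoseTrainer

    args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)  # illustrative values
    trainer = PoseTrainer(overrides=args)
    trainer.train()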
ultralytics/models/yolo/pose/val.py

@@ -14,11 +14,10 @@ from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou


 class PoseValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a pose model.
+    """A class extending the DetectionValidator class for validation based on a pose model.

-    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
-    specialized metrics for pose evaluation.
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
+    metrics for pose evaluation.

     Attributes:
         sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
@@ -33,8 +32,8 @@ class PoseValidator(DetectionValidator):
         _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
             dimensions.
         _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
-        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
-            detections and ground truth.
+        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
+            and ground truth.
         plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
         plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
         save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
@@ -49,8 +48,7 @@ class PoseValidator(DetectionValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize a PoseValidator object for pose estimation validation.
+        """Initialize a PoseValidator object for pose estimation validation.

         This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
         specialized metrics for pose evaluation.
@@ -106,8 +104,7 @@ class PoseValidator(DetectionValidator):
         )

     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize evaluation metrics for YOLO pose validation.
+        """Initialize evaluation metrics for YOLO pose validation.

         Args:
             model (torch.nn.Module): Model to validate.
@@ -119,17 +116,15 @@ class PoseValidator(DetectionValidator):
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt

     def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
-        """
-        Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+        """Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.

-        This method extends the parent class postprocessing by extracting keypoints from the 'extra'
-        field of predictions and reshaping them according to the keypoint shape configuration.
-        The keypoints are reshaped from a flattened format to the proper dimensional structure
-        (typically [N, 17, 3] for COCO pose format).
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
+        predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
+        flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).

         Args:
-            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
-                bounding boxes, confidence scores, class predictions, and keypoint data.
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
+                scores, class predictions, and keypoint data.

         Returns:
             (dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
@@ -138,10 +133,10 @@ class PoseValidator(DetectionValidator):
                 - 'cls': Class predictions
                 - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)

-        Note:
-            If no keypoints are present in a prediction (empty keypoints), that prediction
-            is skipped and continues to the next one. The keypoints are extracted from the
-            'extra' field which contains additional task-specific data beyond basic detection.
+        Notes:
+            If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
+            to the next one. The keypoints are extracted from the 'extra' field which contains additional
+            task-specific data beyond basic detection.
         """
         preds = super().postprocess(preds)
         for pred in preds:
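The reshape this postprocess docstring describes is a plain view over the flattened keypoint data. A minimal sketch, assuming the COCO 17-keypoint layout mentioned above (tensor names are illustrative):

    import torch

    kpt_shape = (17, 3)            # 17 COCO keypoints, each (x, y, visibility)
    extra = torch.rand(8, 17 * 3)  # stand-in for the flattened 'extra' field of 8 detections
    keypoints = extra.view(-1, *kpt_shape)
    print(keypoints.shape)         # torch.Size([8, 17, 3])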
@@ -149,8 +144,7 @@ class PoseValidator(DetectionValidator):
         return preds

     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.

         Args:
             si (int): Batch index.
@@ -173,18 +167,18 @@ class PoseValidator(DetectionValidator):
         return pbatch

     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
+        """Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
+        truth.

         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
                 and 'keypoints' for keypoint predictions.
-            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
-                'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
+            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
+                for bounding boxes, and 'keypoints' for keypoint annotations.

         Returns:
-            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
-                true positives across 10 IoU levels.
+            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
+                positives across 10 IoU levels.

         Notes:
             `0.53` scale factor used in area computation is referenced from
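For context, the 'tp_p' matrix above is built from Object Keypoint Similarity (OKS) rather than box IoU. A rough sketch of the computation kpt_iou performs, not the library's exact code; area stands for the ground-truth box area, pre-scaled by the 0.53 factor the Notes reference:

    import numpy as np

    def oks(gt_kpts, pred_kpts, area, sigma, eps=1e-7):
        # gt_kpts, pred_kpts: (K, 3) arrays of (x, y, visibility); sigma: (K,) per-keypoint constants.
        d2 = ((gt_kpts[:, :2] - pred_kpts[:, :2]) ** 2).sum(-1)  # squared pixel distances
        e = d2 / (2 * sigma) ** 2 / (area + eps) / 2             # normalized per-keypoint error
        visible = gt_kpts[:, 2] > 0
        return float(np.exp(-e)[visible].sum() / (visible.sum() + eps))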
@@ -203,11 +197,10 @@
         return tp

     def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO pose detections to a text file in normalized coordinates.
+        """Save YOLO pose detections to a text file in normalized coordinates.

         Args:
-            predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
+            predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
             save_conf (bool): Whether to save confidence scores.
             shape (tuple[int, int]): Shape of the original image (height, width).
             file (Path): Output file path to save detections.
@@ -227,15 +220,14 @@
         ).save_txt(file, save_conf=save_conf)

     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Convert YOLO predictions to COCO JSON format.
+        """Convert YOLO predictions to COCO JSON format.

-        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
-        to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
+        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
+        format, and appends the results to the internal JSON dictionary (self.jdict).

         Args:
-            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
-                and 'keypoints' tensors.
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
+                tensors.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Notes:
ultralytics/models/yolo/segment/predict.py

@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops


 class SegmentationPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a segmentation model.
+    """A class extending the DetectionPredictor class for prediction based on a segmentation model.

     This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
     prediction results.
@@ -31,8 +30,7 @@ class SegmentationPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.

         This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
         prediction results.
@@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
         self.args.task = "segment"

     def postprocess(self, preds, img, orig_imgs):
-        """
-        Apply non-max suppression and process segmentation detections for each image in the input batch.
+        """Apply non-max suppression and process segmentation detections for each image in the input batch.

         Args:
             preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
@@ -55,8 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
             orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.

         Returns:
-            (list): List of Results objects containing the segmentation predictions for each image in the batch.
-                Each Results object includes both bounding boxes and segmentation masks.
+            (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
+                Results object includes both bounding boxes and segmentation masks.

         Examples:
             >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
@@ -67,8 +64,7 @@ class SegmentationPredictor(DetectionPredictor):
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)

     def construct_results(self, preds, img, orig_imgs, protos):
-        """
-        Construct a list of result objects from the predictions.
+        """Construct a list of result objects from the predictions.

         Args:
             preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -77,8 +73,8 @@ class SegmentationPredictor(DetectionPredictor):
             protos (list[torch.Tensor]): List of prototype masks.

         Returns:
-            (list[Results]): List of result objects containing the original images, image paths, class names,
-                bounding boxes, and masks.
+            (list[Results]): List of result objects containing the original images, image paths, class names, bounding
+                boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -86,8 +82,7 @@ class SegmentationPredictor(DetectionPredictor):
         ]

     def construct_result(self, pred, img, orig_img, img_path, proto):
-        """
-        Construct a single result object from the prediction.
+        """Construct a single result object from the prediction.

         Args:
             pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
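Only docstrings changed in segment/predict.py as well. A short usage sketch consistent with the Examples block kept in postprocess above (the image path is illustrative):

    from ultralytics.models.yolo.segment import SegmentationPredictor

    predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
    results = predictor("path/to/image.jpg")  # predictor instances are callable on a source
    for r in results:
        print(len(r.boxes), r.masks is not None)  # boxes plus segmentation masks per image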
ultralytics/models/yolo/segment/train.py

@@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, RANK


 class SegmentationTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on a segmentation model.
+    """A class extending the DetectionTrainer class for training based on a segmentation model.

     This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
     functionality including model initialization, validation, and visualization.
@@ -28,8 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
-        """
-        Initialize a SegmentationTrainer object.
+        """Initialize a SegmentationTrainer object.

         Args:
             cfg (dict): Configuration dictionary with default training settings.
@@ -42,8 +40,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         super().__init__(cfg, overrides, _callbacks)

     def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
-        """
-        Initialize and return a SegmentationModel with specified configuration and weights.
+        """Initialize and return a SegmentationModel with specified configuration and weights.

         Args:
             cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
ultralytics/models/yolo/segment/val.py

@@ -17,11 +17,10 @@ from ultralytics.utils.metrics import SegmentMetrics, mask_iou


 class SegmentationValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a segmentation model.
+    """A class extending the DetectionValidator class for validation based on a segmentation model.

-    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
-    to compute metrics such as mAP for both detection and segmentation tasks.
+    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
+    compute metrics such as mAP for both detection and segmentation tasks.

     Attributes:
         plot_masks (list): List to store masks for plotting.
@@ -38,8 +37,7 @@ class SegmentationValidator(DetectionValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
+        """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.

         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
@@ -53,8 +51,7 @@ class SegmentationValidator(DetectionValidator):
         self.metrics = SegmentMetrics()

     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Preprocess batch of images for YOLO segmentation validation.
+        """Preprocess batch of images for YOLO segmentation validation.

         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -67,8 +64,7 @@ class SegmentationValidator(DetectionValidator):
         return batch

     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize metrics and select mask processing function based on save_json flag.
+        """Initialize metrics and select mask processing function based on save_json flag.

         Args:
             model (torch.nn.Module): Model to validate.
@@ -96,8 +92,7 @@ class SegmentationValidator(DetectionValidator):
         )

     def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
-        """
-        Post-process YOLO predictions and return output detections with proto.
+        """Post-process YOLO predictions and return output detections with proto.

         Args:
             preds (list[torch.Tensor]): Raw predictions from the model.
@@ -122,8 +117,7 @@ class SegmentationValidator(DetectionValidator):
         return preds

     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch for training or inference by processing images and targets.
+        """Prepare a batch for training or inference by processing images and targets.

         Args:
             si (int): Batch index.
@@ -149,8 +143,7 @@ class SegmentationValidator(DetectionValidator):
         return prepared_batch

     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
+        """Compute correct prediction matrix for a batch based on bounding boxes and optional masks.

         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
@@ -159,14 +152,14 @@
         Returns:
             (dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.

-        Notes:
-            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
-            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
-
         Examples:
             >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
             >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
             >>> correct_preds = validator._process_batch(preds, batch)
+
+        Notes:
+            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
+            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
         """
         tp = super()._process_batch(preds, batch)
         gt_cls = batch["cls"]
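The 'tp_m' entry above comes from mask IoU. A simplified sketch of the computation mask_iou (imported in this file's header) performs on flattened per-instance masks; this is a rendering of the idea, not the library's exact code:

    import torch

    def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        # mask1: (N, H*W), mask2: (M, H*W); binary masks, one flattened row per instance.
        inter = torch.matmul(mask1, mask2.T)                        # (N, M) overlapping pixel counts
        union = mask1.sum(1)[:, None] + mask2.sum(1)[None] - inter  # inclusion-exclusion
        return inter / (union + eps)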
@@ -179,8 +172,7 @@
         return tp

     def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
-        """
-        Plot batch predictions with masks and bounding boxes.
+        """Plot batch predictions with masks and bounding boxes.

         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -195,8 +187,7 @@
         super().plot_predictions(batch, preds, ni, max_det=self.args.max_det)  # plot bboxes

     def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.

         Args:
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
@@ -215,8 +206,7 @@
         ).save_txt(file, save_conf=save_conf)

     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Save one JSON result for COCO evaluation.
+        """Save one JSON result for COCO evaluation.

         Args:
             predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
ultralytics/models/yolo/world/train.py

@@ -24,8 +24,7 @@ def on_pretrain_routine_end(trainer) -> None:


 class WorldTrainer(DetectionTrainer):
-    """
-    A trainer class for fine-tuning YOLO World models on close-set datasets.
+    """A trainer class for fine-tuning YOLO World models on close-set datasets.

     This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
     features for improved object detection and understanding. It handles text embedding generation and caching to
@@ -54,8 +53,7 @@ class WorldTrainer(DetectionTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a WorldTrainer object with given arguments.
+        """Initialize a WorldTrainer object with given arguments.

         Args:
             cfg (dict[str, Any]): Configuration for the trainer.
@@ -69,8 +67,7 @@ class WorldTrainer(DetectionTrainer):
         self.text_embeddings = None

     def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
-        """
-        Return WorldModel initialized with specified config and weights.
+        """Return WorldModel initialized with specified config and weights.

         Args:
             cfg (dict[str, Any] | str, optional): Model configuration.
@@ -95,8 +92,7 @@ class WorldTrainer(DetectionTrainer):
         return model

     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.

         Args:
             img_path (str): Path to the folder containing images.
@@ -115,11 +111,10 @@ class WorldTrainer(DetectionTrainer):
         return dataset

     def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
-        """
-        Set text embeddings for datasets to accelerate training by caching category names.
+        """Set text embeddings for datasets to accelerate training by caching category names.

-        This method collects unique category names from all datasets, then generates and caches text embeddings
-        for these categories to improve training efficiency.
+        This method collects unique category names from all datasets, then generates and caches text embeddings for
+        these categories to improve training efficiency.

         Args:
             datasets (list[Any]): List of datasets from which to extract category names.
@@ -141,8 +136,7 @@ class WorldTrainer(DetectionTrainer):
         self.text_embeddings = text_embeddings

     def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
-        """
-        Generate text embeddings for a list of text samples.
+        """Generate text embeddings for a list of text samples.

         Args:
             texts (list[str]): List of text samples to encode.
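set_text_embeddings and generate_text_embeddings above describe an embed-once cache keyed by category name. A minimal sketch of that caching idea, assuming an arbitrary batch encode function (all names here are illustrative, not the trainer's actual internals):

    import torch

    cache: dict[str, torch.Tensor] = {}

    def embed_texts(texts: list[str], encode) -> dict[str, torch.Tensor]:
        missing = [t for t in texts if t not in cache]
        if missing:
            cache.update(dict(zip(missing, encode(missing))))  # encode only uncached names
        return {t: cache[t] for t in texts}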
ultralytics/models/yolo/world/train_world.py

@@ -10,8 +10,7 @@ from ultralytics.utils.torch_utils import unwrap_model


 class WorldTrainerFromScratch(WorldTrainer):
-    """
-    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
+    """A class extending the WorldTrainer for training a world model from scratch on open-set datasets.

     This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
     supporting training YOLO-World models with combined vision-language capabilities.
@@ -53,11 +52,10 @@ class WorldTrainerFromScratch(WorldTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a WorldTrainerFromScratch object.
+        """Initialize a WorldTrainerFromScratch object.

-        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
-        object detection and grounding datasets for vision-language capabilities.
+        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
+        detection and grounding datasets for vision-language capabilities.

         Args:
             cfg (dict): Configuration dictionary with default parameters for model training.
@@ -87,11 +85,10 @@ class WorldTrainerFromScratch(WorldTrainer):
         super().__init__(cfg, overrides, _callbacks)

     def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.

-        This method constructs appropriate datasets based on the mode and input paths, handling both
-        standard YOLO datasets and grounding datasets with different formats.
+        This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
+        datasets and grounding datasets with different formats.

         Args:
             img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -122,11 +119,10 @@ class WorldTrainerFromScratch(WorldTrainer):
         return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

     def get_dataset(self):
-        """
-        Get train and validation paths from data dictionary.
+        """Get train and validation paths from data dictionary.

-        Processes the data configuration to extract paths for training and validation datasets,
-        handling both YOLO detection datasets and grounding datasets.
+        Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
+        detection datasets and grounding datasets.

         Returns:
             train_path (str): Train dataset path.
@@ -187,8 +183,7 @@ class WorldTrainerFromScratch(WorldTrainer):
         pass

     def final_eval(self):
-        """
-        Perform final evaluation and validation for the YOLO-World model.
+        """Perform final evaluation and validation for the YOLO-World model.

         Configures the validator with appropriate dataset and split information before running evaluation.

ultralytics/models/yolo/yoloe/predict.py

@@ -9,11 +9,10 @@ from ultralytics.models.yolo.segment import SegmentationPredictor


 class YOLOEVPDetectPredictor(DetectionPredictor):
-    """
-    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+    """A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.

-    This mixin provides common functionality for YOLO models that use visual prompting, including
-    model setup, prompt handling, and preprocessing transformations.
+    This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
+    handling, and preprocessing transformations.

     Attributes:
         model (torch.nn.Module): The YOLO model for inference.
@@ -29,8 +28,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
     """

     def setup_model(self, model, verbose: bool = True):
-        """
-        Set up the model for prediction.
+        """Set up the model for prediction.

         Args:
             model (torch.nn.Module): Model to load or use.
@@ -40,21 +38,19 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         self.done_warmup = True

     def set_prompts(self, prompts):
-        """
-        Set the visual prompts for the model.
+        """Set the visual prompts for the model.

         Args:
-            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
-                Must include a 'cls' key with class indices.
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
+                with class indices.
         """
         self.prompts = prompts

     def pre_transform(self, im):
-        """
-        Preprocess images and prompts before inference.
+        """Preprocess images and prompts before inference.

-        This method applies letterboxing to the input image and transforms the visual prompts
-        (bounding boxes or masks) accordingly.
+        This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
+        accordingly.

         Args:
             im (list): List containing a single input image.
@@ -94,8 +90,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return img

     def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
-        """
-        Process a single image by resizing bounding boxes or masks and generating visuals.
+        """Process a single image by resizing bounding boxes or masks and generating visuals.

         Args:
             dst_shape (tuple): The target shape (height, width) of the image.
@@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)

     def inference(self, im, *args, **kwargs):
-        """
-        Run inference with visual prompts.
+        """Run inference with visual prompts.

         Args:
             im (torch.Tensor): Input image tensor.
@@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return super().inference(im, vpe=self.prompts, *args, **kwargs)

     def get_vpe(self, source):
-        """
-        Process the source to get the visual prompt embeddings (VPE).
+        """Process the source to get the visual prompt embeddings (VPE).

         Args:
-            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
-                of the image to make predictions on. Accepts various types including file paths, URLs, PIL
-                images, numpy arrays, and torch tensors.
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
+                make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+                torch tensors.

         Returns:
             (torch.Tensor): The visual prompt embeddings (VPE) from the model.
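The visual-prompt flow documented above expects a prompts dict with a mandatory 'cls' key plus matching 'bboxes' or 'masks'. A hedged usage sketch through the high-level YOLOE API (model name, image path, and prompt values are illustrative):

    import numpy as np
    from ultralytics import YOLOE

    model = YOLOE("yoloe-11s-seg.pt")
    visual_prompts = dict(
        cls=np.array([0, 1]),  # one class index per prompt box
        bboxes=np.array([[50, 60, 200, 240], [10, 20, 80, 120]], dtype=np.float32),
    )
    results = model.predict("image.jpg", visual_prompts=visual_prompts)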