ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +125 -39
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +34 -33
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +33 -47
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +69 -90
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +31 -38
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +21 -26
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +23 -17
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +29 -24
  53. ultralytics/models/nas/predict.py +14 -11
  54. ultralytics/models/nas/val.py +11 -13
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +21 -21
  57. ultralytics/models/rtdetr/train.py +25 -24
  58. ultralytics/models/rtdetr/val.py +47 -14
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +30 -12
  73. ultralytics/models/yolo/classify/train.py +83 -19
  74. ultralytics/models/yolo/classify/val.py +45 -23
  75. ultralytics/models/yolo/detect/predict.py +29 -19
  76. ultralytics/models/yolo/detect/train.py +90 -23
  77. ultralytics/models/yolo/detect/val.py +150 -29
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +18 -13
  80. ultralytics/models/yolo/obb/train.py +12 -8
  81. ultralytics/models/yolo/obb/val.py +35 -22
  82. ultralytics/models/yolo/pose/predict.py +28 -15
  83. ultralytics/models/yolo/pose/train.py +21 -8
  84. ultralytics/models/yolo/pose/val.py +51 -31
  85. ultralytics/models/yolo/segment/predict.py +27 -16
  86. ultralytics/models/yolo/segment/train.py +11 -8
  87. ultralytics/models/yolo/segment/val.py +110 -29
  88. ultralytics/models/yolo/world/train.py +43 -16
  89. ultralytics/models/yolo/world/train_world.py +61 -36
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +12 -12
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +226 -79
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +37 -35
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +139 -68
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +37 -56
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +117 -52
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +65 -61
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +72 -59
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +202 -64
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +13 -25
  149. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.88.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
@@ -16,18 +16,38 @@ class PoseValidator(DetectionValidator):
16
16
  """
17
17
  A class extending the DetectionValidator class for validation based on a pose model.
18
18
 
19
- Example:
20
- ```python
21
- from ultralytics.models.yolo.pose import PoseValidator
22
-
23
- args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
24
- validator = PoseValidator(args=args)
25
- validator()
26
- ```
19
+ This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
20
+ specialized metrics for pose evaluation.
21
+
22
+ Attributes:
23
+ sigma (np.ndarray): Sigma values for OKS calculation, either from OKS_SIGMA or ones divided by number of keypoints.
24
+ kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
25
+ args (Dict): Arguments for the validator including task set to "pose".
26
+ metrics (PoseMetrics): Metrics object for pose evaluation.
27
+
28
+ Methods:
29
+ preprocess: Preprocesses batch data for pose validation.
30
+ get_desc: Returns description of evaluation metrics.
31
+ init_metrics: Initializes pose metrics for the model.
32
+ _prepare_batch: Prepares a batch for processing.
33
+ _prepare_pred: Prepares and scales predictions for evaluation.
34
+ update_metrics: Updates metrics with new predictions.
35
+ _process_batch: Processes batch to compute IoU between detections and ground truth.
36
+ plot_val_samples: Plots validation samples with ground truth annotations.
37
+ plot_predictions: Plots model predictions.
38
+ save_one_txt: Saves detections to a text file.
39
+ pred_to_json: Converts predictions to COCO JSON format.
40
+ eval_json: Evaluates model using COCO JSON format.
41
+
42
+ Examples:
43
+ >>> from ultralytics.models.yolo.pose import PoseValidator
44
+ >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
45
+ >>> validator = PoseValidator(args=args)
46
+ >>> validator()
27
47
  """
28
48
 
29
49
  def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
30
- """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
50
+ """Initialize a PoseValidator object with custom parameters and assigned attributes."""
31
51
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
32
52
  self.sigma = None
33
53
  self.kpt_shape = None
@@ -40,13 +60,13 @@ class PoseValidator(DetectionValidator):
40
60
  )
41
61
 
42
62
  def preprocess(self, batch):
43
- """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
63
+ """Preprocess batch by converting keypoints data to float and moving it to the device."""
44
64
  batch = super().preprocess(batch)
45
65
  batch["keypoints"] = batch["keypoints"].to(self.device).float()
46
66
  return batch
47
67
 
48
68
  def get_desc(self):
49
- """Returns description of evaluation metrics in string format."""
69
+ """Return description of evaluation metrics in string format."""
50
70
  return ("%22s" + "%11s" * 10) % (
51
71
  "Class",
52
72
  "Images",
@@ -62,7 +82,7 @@ class PoseValidator(DetectionValidator):
62
82
  )
63
83
 
64
84
  def init_metrics(self, model):
65
- """Initiate pose estimation metrics for YOLO model."""
85
+ """Initialize pose estimation metrics for YOLO model."""
66
86
  super().init_metrics(model)
67
87
  self.kpt_shape = self.data["kpt_shape"]
68
88
  is_pose = self.kpt_shape == [17, 3]
@@ -71,7 +91,7 @@ class PoseValidator(DetectionValidator):
71
91
  self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
72
92
 
73
93
  def _prepare_batch(self, si, batch):
74
- """Prepares a batch for processing by converting keypoints to float and moving to device."""
94
+ """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions."""
75
95
  pbatch = super()._prepare_batch(si, batch)
76
96
  kpts = batch["keypoints"][batch["batch_idx"] == si]
77
97
  h, w = pbatch["imgsz"]
@@ -83,7 +103,7 @@ class PoseValidator(DetectionValidator):
83
103
  return pbatch
84
104
 
85
105
  def _prepare_pred(self, pred, pbatch):
86
- """Prepares and scales keypoints in a batch for pose processing."""
106
+ """Prepare and scale keypoints in predictions for pose processing."""
87
107
  predn = super()._prepare_pred(pred, pbatch)
88
108
  nk = pbatch["kpts"].shape[1]
89
109
  pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
@@ -91,7 +111,16 @@ class PoseValidator(DetectionValidator):
91
111
  return predn, pred_kpts
92
112
 
93
113
  def update_metrics(self, preds, batch):
94
- """Metrics."""
114
+ """
115
+ Update metrics with new predictions and ground truth data.
116
+
117
+ This method processes each prediction, compares it with ground truth, and updates various statistics
118
+ for performance evaluation.
119
+
120
+ Args:
121
+ preds (List[torch.Tensor]): List of prediction tensors from the model.
122
+ batch (Dict): Batch data containing images and ground truth annotations.
123
+ """
95
124
  for si, pred in enumerate(preds):
96
125
  self.seen += 1
97
126
  npr = len(pred)
@@ -161,18 +190,9 @@ class PoseValidator(DetectionValidator):
161
190
  (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
162
191
  where N is the number of detections.
163
192
 
164
- Example:
165
- ```python
166
- detections = torch.rand(100, 6) # 100 predictions: (x1, y1, x2, y2, conf, class)
167
- gt_bboxes = torch.rand(50, 4) # 50 ground truth boxes: (x1, y1, x2, y2)
168
- gt_cls = torch.randint(0, 2, (50,)) # 50 ground truth class indices
169
- pred_kpts = torch.rand(100, 51) # 100 predicted keypoints
170
- gt_kpts = torch.rand(50, 51) # 50 ground truth keypoints
171
- correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
172
- ```
173
-
174
- Note:
175
- `0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
193
+ Notes:
194
+ `0.53` scale factor used in area computation is referenced from
195
+ https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
176
196
  """
177
197
  if pred_kpts is not None and gt_kpts is not None:
178
198
  # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
@@ -184,7 +204,7 @@ class PoseValidator(DetectionValidator):
184
204
  return self.match_predictions(detections[:, 5], gt_cls, iou)
185
205
 
186
206
  def plot_val_samples(self, batch, ni):
187
- """Plots and saves validation set samples with predicted bounding boxes and keypoints."""
207
+ """Plot and save validation set samples with ground truth bounding boxes and keypoints."""
188
208
  plot_images(
189
209
  batch["img"],
190
210
  batch["batch_idx"],
@@ -198,7 +218,7 @@ class PoseValidator(DetectionValidator):
198
218
  )
199
219
 
200
220
  def plot_predictions(self, batch, preds, ni):
201
- """Plots predictions for YOLO model."""
221
+ """Plot and save model predictions with bounding boxes and keypoints."""
202
222
  pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
203
223
  plot_images(
204
224
  batch["img"],
@@ -223,7 +243,7 @@ class PoseValidator(DetectionValidator):
223
243
  ).save_txt(file, save_conf=save_conf)
224
244
 
225
245
  def pred_to_json(self, predn, filename):
226
- """Converts YOLO predictions to COCO JSON format."""
246
+ """Convert YOLO predictions to COCO JSON format."""
227
247
  stem = Path(filename).stem
228
248
  image_id = int(stem) if stem.isnumeric() else stem
229
249
  box = ops.xyxy2xywh(predn[:, :4]) # xywh
@@ -240,7 +260,7 @@ class PoseValidator(DetectionValidator):
240
260
  )
241
261
 
242
262
  def eval_json(self, stats):
243
- """Evaluates object detection model using COCO JSON format."""
263
+ """Evaluate object detection model using COCO JSON format."""
244
264
  if self.args.save_json and self.is_coco and len(self.jdict):
245
265
  anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
246
266
  pred_json = self.save_dir / "predictions.json" # predictions
@@ -9,31 +9,41 @@ class SegmentationPredictor(DetectionPredictor):
9
9
  """
10
10
  A class extending the DetectionPredictor class for prediction based on a segmentation model.
11
11
 
12
- Example:
13
- ```python
14
- from ultralytics.utils import ASSETS
15
- from ultralytics.models.yolo.segment import SegmentationPredictor
12
+ This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
13
+ prediction results.
16
14
 
17
- args = dict(model="yolo11n-seg.pt", source=ASSETS)
18
- predictor = SegmentationPredictor(overrides=args)
19
- predictor.predict_cli()
20
- ```
15
+ Attributes:
16
+ args (Dict): Configuration arguments for the predictor.
17
+ model (torch.nn.Module): The loaded YOLO segmentation model.
18
+ batch (List): Current batch of images being processed.
19
+
20
+ Methods:
21
+ postprocess: Applies non-max suppression and processes detections.
22
+ construct_results: Constructs a list of result objects from predictions.
23
+ construct_result: Constructs a single result object from a prediction.
24
+
25
+ Examples:
26
+ >>> from ultralytics.utils import ASSETS
27
+ >>> from ultralytics.models.yolo.segment import SegmentationPredictor
28
+ >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
29
+ >>> predictor = SegmentationPredictor(overrides=args)
30
+ >>> predictor.predict_cli()
21
31
  """
22
32
 
23
33
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
24
- """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
34
+ """Initialize the SegmentationPredictor with configuration, overrides, and callbacks."""
25
35
  super().__init__(cfg, overrides, _callbacks)
26
36
  self.args.task = "segment"
27
37
 
28
38
  def postprocess(self, preds, img, orig_imgs):
29
- """Applies non-max suppression and processes detections for each image in an input batch."""
30
- # tuple if PyTorch model or array if exported
39
+ """Apply non-max suppression and process detections for each image in the input batch."""
40
+ # Extract protos - tuple if PyTorch model or array if exported
31
41
  protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
32
42
  return super().postprocess(preds[0], img, orig_imgs, protos=protos)
33
43
 
34
44
  def construct_results(self, preds, img, orig_imgs, protos):
35
45
  """
36
- Constructs a list of result objects from the predictions.
46
+ Construct a list of result objects from the predictions.
37
47
 
38
48
  Args:
39
49
  preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -42,7 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
42
52
  protos (List[torch.Tensor]): List of prototype masks.
43
53
 
44
54
  Returns:
45
- (list): List of result objects containing the original images, image paths, class names, bounding boxes, and masks.
55
+ (List[Results]): List of result objects containing the original images, image paths, class names,
56
+ bounding boxes, and masks.
46
57
  """
47
58
  return [
48
59
  self.construct_result(pred, img, orig_img, img_path, proto)
@@ -51,7 +62,7 @@ class SegmentationPredictor(DetectionPredictor):
51
62
 
52
63
  def construct_result(self, pred, img, orig_img, img_path, proto):
53
64
  """
54
- Constructs the result object from the prediction.
65
+ Construct a single result object from the prediction.
55
66
 
56
67
  Args:
57
68
  pred (np.ndarray): The predicted bounding boxes, scores, and masks.
@@ -61,7 +72,7 @@ class SegmentationPredictor(DetectionPredictor):
61
72
  proto (torch.Tensor): The prototype masks.
62
73
 
63
74
  Returns:
64
- (Results): The result object containing the original image, image path, class names, bounding boxes, and masks.
75
+ (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
65
76
  """
66
77
  if not len(pred): # save empty boxes
67
78
  masks = None
@@ -72,6 +83,6 @@ class SegmentationPredictor(DetectionPredictor):
72
83
  masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
73
84
  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
74
85
  if masks is not None:
75
- keep = masks.sum((-2, -1)) > 0 # only keep preds with masks
86
+ keep = masks.sum((-2, -1)) > 0 # only keep predictions with masks
76
87
  pred, masks = pred[keep], masks[keep]
77
88
  return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
@@ -12,14 +12,17 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
12
12
  """
13
13
  A class extending the DetectionTrainer class for training based on a segmentation model.
14
14
 
15
- Example:
16
- ```python
17
- from ultralytics.models.yolo.segment import SegmentationTrainer
18
-
19
- args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
20
- trainer = SegmentationTrainer(overrides=args)
21
- trainer.train()
22
- ```
15
+ This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
16
+ functionality including model initialization, validation, and visualization.
17
+
18
+ Attributes:
19
+ loss_names (Tuple[str]): Names of the loss components used during training.
20
+
21
+ Examples:
22
+ >>> from ultralytics.models.yolo.segment import SegmentationTrainer
23
+ >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
24
+ >>> trainer = SegmentationTrainer(overrides=args)
25
+ >>> trainer.train()
23
26
  """
24
27
 
25
28
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -18,18 +18,34 @@ class SegmentationValidator(DetectionValidator):
18
18
  """
19
19
  A class extending the DetectionValidator class for validation based on a segmentation model.
20
20
 
21
- Example:
22
- ```python
23
- from ultralytics.models.yolo.segment import SegmentationValidator
24
-
25
- args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
26
- validator = SegmentationValidator(args=args)
27
- validator()
28
- ```
21
+ This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
22
+ to compute metrics such as mAP for both detection and segmentation tasks.
23
+
24
+ Attributes:
25
+ plot_masks (List): List to store masks for plotting.
26
+ process (callable): Function to process masks based on save_json and save_txt flags.
27
+ args (namespace): Arguments for the validator.
28
+ metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
29
+ stats (Dict): Dictionary to store statistics during validation.
30
+
31
+ Examples:
32
+ >>> from ultralytics.models.yolo.segment import SegmentationValidator
33
+ >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
34
+ >>> validator = SegmentationValidator(args=args)
35
+ >>> validator()
29
36
  """
30
37
 
31
38
  def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
32
- """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
39
+ """
40
+ Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
41
+
42
+ Args:
43
+ dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
44
+ save_dir (Path, optional): Directory to save results.
45
+ pbar (Any, optional): Progress bar for displaying progress.
46
+ args (namespace, optional): Arguments for the validator.
47
+ _callbacks (List, optional): List of callback functions.
48
+ """
33
49
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
34
50
  self.plot_masks = None
35
51
  self.process = None
@@ -37,13 +53,18 @@ class SegmentationValidator(DetectionValidator):
37
53
  self.metrics = SegmentMetrics(save_dir=self.save_dir)
38
54
 
39
55
  def preprocess(self, batch):
40
- """Preprocesses batch by converting masks to float and sending to device."""
56
+ """Preprocess batch by converting masks to float and sending to device."""
41
57
  batch = super().preprocess(batch)
42
58
  batch["masks"] = batch["masks"].to(self.device).float()
43
59
  return batch
44
60
 
45
61
  def init_metrics(self, model):
46
- """Initialize metrics and select mask processing function based on save_json flag."""
62
+ """
63
+ Initialize metrics and select mask processing function based on save_json flag.
64
+
65
+ Args:
66
+ model (torch.nn.Module): Model to validate.
67
+ """
47
68
  super().init_metrics(model)
48
69
  self.plot_masks = []
49
70
  if self.args.save_json:
@@ -69,26 +90,61 @@ class SegmentationValidator(DetectionValidator):
69
90
  )
70
91
 
71
92
  def postprocess(self, preds):
72
- """Post-processes YOLO predictions and returns output detections with proto."""
93
+ """
94
+ Post-process YOLO predictions and return output detections with proto.
95
+
96
+ Args:
97
+ preds (List): Raw predictions from the model.
98
+
99
+ Returns:
100
+ p (torch.Tensor): Processed detection predictions.
101
+ proto (torch.Tensor): Prototype masks for segmentation.
102
+ """
73
103
  p = super().postprocess(preds[0])
74
104
  proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
75
105
  return p, proto
76
106
 
77
107
  def _prepare_batch(self, si, batch):
78
- """Prepares a batch for training or inference by processing images and targets."""
108
+ """
109
+ Prepare a batch for training or inference by processing images and targets.
110
+
111
+ Args:
112
+ si (int): Batch index.
113
+ batch (Dict): Batch data containing images and targets.
114
+
115
+ Returns:
116
+ (Dict): Prepared batch with processed images and targets.
117
+ """
79
118
  prepared_batch = super()._prepare_batch(si, batch)
80
119
  midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
81
120
  prepared_batch["masks"] = batch["masks"][midx]
82
121
  return prepared_batch
83
122
 
84
123
  def _prepare_pred(self, pred, pbatch, proto):
85
- """Prepares a batch for training or inference by processing images and targets."""
124
+ """
125
+ Prepare predictions for evaluation by processing bounding boxes and masks.
126
+
127
+ Args:
128
+ pred (torch.Tensor): Raw predictions from the model.
129
+ pbatch (Dict): Prepared batch data.
130
+ proto (torch.Tensor): Prototype masks for segmentation.
131
+
132
+ Returns:
133
+ predn (torch.Tensor): Processed bounding box predictions.
134
+ pred_masks (torch.Tensor): Processed mask predictions.
135
+ """
86
136
  predn = super()._prepare_pred(pred, pbatch)
87
137
  pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
88
138
  return predn, pred_masks
89
139
 
90
140
  def update_metrics(self, preds, batch):
91
- """Metrics."""
141
+ """
142
+ Update metrics with the current batch predictions and targets.
143
+
144
+ Args:
145
+ preds (List): Predictions from the model.
146
+ batch (Dict): Batch data containing images and targets.
147
+ """
92
148
  for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
93
149
  self.seen += 1
94
150
  npr = len(pred)
@@ -157,7 +213,7 @@ class SegmentationValidator(DetectionValidator):
157
213
  )
158
214
 
159
215
  def finalize_metrics(self, *args, **kwargs):
160
- """Sets speed and confusion matrix for evaluation metrics."""
216
+ """Set speed and confusion matrix for evaluation metrics."""
161
217
  self.metrics.speed = self.speed
162
218
  self.metrics.confusion_matrix = self.confusion_matrix
163
219
 
@@ -171,9 +227,9 @@ class SegmentationValidator(DetectionValidator):
171
227
  gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
172
228
  Each row is of the format [x1, y1, x2, y2].
173
229
  gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
174
- pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should
230
+ pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
175
231
  match the ground truth masks.
176
- gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available.
232
+ gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
177
233
  overlap (bool): Flag indicating if overlapping masks should be considered.
178
234
  masks (bool): Flag indicating if the batch contains mask data.
179
235
 
@@ -184,13 +240,11 @@ class SegmentationValidator(DetectionValidator):
184
240
  - If `masks` is True, the function computes IoU between predicted and ground truth masks.
185
241
  - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
186
242
 
187
- Example:
188
- ```python
189
- detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
190
- gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
191
- gt_cls = torch.tensor([1, 0])
192
- correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
193
- ```
243
+ Examples:
244
+ >>> detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
245
+ >>> gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
246
+ >>> gt_cls = torch.tensor([1, 0])
247
+ >>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
194
248
  """
195
249
  if masks:
196
250
  if overlap:
@@ -208,7 +262,13 @@ class SegmentationValidator(DetectionValidator):
208
262
  return self.match_predictions(detections[:, 5], gt_cls, iou)
209
263
 
210
264
  def plot_val_samples(self, batch, ni):
211
- """Plots validation samples with bounding box labels."""
265
+ """
266
+ Plot validation samples with bounding box labels and masks.
267
+
268
+ Args:
269
+ batch (Dict): Batch data containing images and targets.
270
+ ni (int): Batch index.
271
+ """
212
272
  plot_images(
213
273
  batch["img"],
214
274
  batch["batch_idx"],
@@ -222,7 +282,14 @@ class SegmentationValidator(DetectionValidator):
222
282
  )
223
283
 
224
284
  def plot_predictions(self, batch, preds, ni):
225
- """Plots batch predictions with masks and bounding boxes."""
285
+ """
286
+ Plot batch predictions with masks and bounding boxes.
287
+
288
+ Args:
289
+ batch (Dict): Batch data containing images.
290
+ preds (List): Predictions from the model.
291
+ ni (int): Batch index.
292
+ """
226
293
  plot_images(
227
294
  batch["img"],
228
295
  *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
@@ -235,7 +302,16 @@ class SegmentationValidator(DetectionValidator):
235
302
  self.plot_masks.clear()
236
303
 
237
304
  def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
238
- """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
305
+ """
306
+ Save YOLO detections to a txt file in normalized coordinates in a specific format.
307
+
308
+ Args:
309
+ predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
310
+ pred_masks (torch.Tensor): Predicted masks.
311
+ save_conf (bool): Whether to save confidence scores.
312
+ shape (Tuple): Original image shape.
313
+ file (Path): File path to save the detections.
314
+ """
239
315
  from ultralytics.engine.results import Results
240
316
 
241
317
  Results(
@@ -248,7 +324,12 @@ class SegmentationValidator(DetectionValidator):
248
324
 
249
325
  def pred_to_json(self, predn, filename, pred_masks):
250
326
  """
251
- Save one JSON result.
327
+ Save one JSON result for COCO evaluation.
328
+
329
+ Args:
330
+ predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
331
+ filename (str): Image filename.
332
+ pred_masks (numpy.ndarray): Predicted masks.
252
333
 
253
334
  Examples:
254
335
  >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
@@ -10,9 +10,9 @@ from ultralytics.utils.torch_utils import de_parallel
10
10
 
11
11
 
12
12
  def on_pretrain_routine_end(trainer):
13
- """Callback."""
13
+ """Callback to set up model classes and text encoder at the end of the pretrain routine."""
14
14
  if RANK in {-1, 0}:
15
- # NOTE: for evaluation
15
+ # Set class names for evaluation
16
16
  names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
17
17
  de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
18
18
  device = next(trainer.model.parameters()).device
@@ -25,18 +25,32 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
25
25
  """
26
26
  A class to fine-tune a world model on a close-set dataset.
27
27
 
28
- Example:
29
- ```python
30
- from ultralytics.models.yolo.world import WorldModel
31
-
32
- args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
33
- trainer = WorldTrainer(overrides=args)
34
- trainer.train()
35
- ```
28
+ This trainer extends the DetectionTrainer to support training YOLO World models, which combine
29
+ visual and textual features for improved object detection and understanding.
30
+
31
+ Attributes:
32
+ clip (module): The CLIP module for text-image understanding.
33
+ text_model (module): The text encoder model from CLIP.
34
+ model (WorldModel): The YOLO World model being trained.
35
+ data (Dict): Dataset configuration containing class information.
36
+ args (Dict): Training arguments and configuration.
37
+
38
+ Examples:
39
+ >>> from ultralytics.models.yolo.world import WorldModel
40
+ >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
41
+ >>> trainer = WorldTrainer(overrides=args)
42
+ >>> trainer.train()
36
43
  """
37
44
 
38
45
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
39
- """Initialize a WorldTrainer object with given arguments."""
46
+ """
47
+ Initialize a WorldTrainer object with given arguments.
48
+
49
+ Args:
50
+ cfg (Dict): Configuration for the trainer.
51
+ overrides (Dict, optional): Configuration overrides.
52
+ _callbacks (List, optional): List of callback functions.
53
+ """
40
54
  if overrides is None:
41
55
  overrides = {}
42
56
  super().__init__(cfg, overrides, _callbacks)
@@ -50,7 +64,17 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
50
64
  self.clip = clip
51
65
 
52
66
  def get_model(self, cfg=None, weights=None, verbose=True):
53
- """Return WorldModel initialized with specified config and weights."""
67
+ """
68
+ Return WorldModel initialized with specified config and weights.
69
+
70
+ Args:
71
+ cfg (Dict | str, optional): Model configuration.
72
+ weights (str, optional): Path to pretrained weights.
73
+ verbose (bool): Whether to display model info.
74
+
75
+ Returns:
76
+ (WorldModel): Initialized WorldModel.
77
+ """
54
78
  # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
55
79
  # NOTE: Following the official config, nc hard-coded to 80 for now.
56
80
  model = WorldModel(
@@ -67,12 +91,15 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
67
91
 
68
92
  def build_dataset(self, img_path, mode="train", batch=None):
69
93
  """
70
- Build YOLO Dataset.
94
+ Build YOLO Dataset for training or validation.
71
95
 
72
96
  Args:
73
97
  img_path (str): Path to the folder containing images.
74
98
  mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
75
- batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
99
+ batch (int, optional): Size of batches, this is for `rect`.
100
+
101
+ Returns:
102
+ (Dataset): YOLO dataset configured for training or validation.
76
103
  """
77
104
  gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
78
105
  return build_yolo_dataset(
@@ -80,10 +107,10 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
80
107
  )
81
108
 
82
109
  def preprocess_batch(self, batch):
83
- """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
110
+ """Preprocess a batch of images and text for YOLOWorld training."""
84
111
  batch = super().preprocess_batch(batch)
85
112
 
86
- # NOTE: add text features
113
+ # Add text features
87
114
  texts = list(itertools.chain(*batch["texts"]))
88
115
  text_token = self.clip.tokenize(texts).to(batch["img"].device)
89
116
  txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32