dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from pathlib import Path
4
+ from typing import Any, Dict, List, Optional, Tuple
4
5
 
5
6
  import numpy as np
6
7
  import torch
@@ -26,18 +27,20 @@ class PoseValidator(DetectionValidator):
26
27
  metrics (PoseMetrics): Metrics object for pose evaluation.
27
28
 
28
29
  Methods:
29
- preprocess: Preprocesses batch data for pose validation.
30
- get_desc: Returns description of evaluation metrics.
31
- init_metrics: Initializes pose metrics for the model.
32
- _prepare_batch: Prepares a batch for processing.
33
- _prepare_pred: Prepares and scales predictions for evaluation.
34
- update_metrics: Updates metrics with new predictions.
35
- _process_batch: Processes batch to compute IoU between detections and ground truth.
36
- plot_val_samples: Plots validation samples with ground truth annotations.
37
- plot_predictions: Plots model predictions.
38
- save_one_txt: Saves detections to a text file.
39
- pred_to_json: Converts predictions to COCO JSON format.
40
- eval_json: Evaluates model using COCO JSON format.
30
+ preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
31
+ get_desc: Return description of evaluation metrics in string format.
32
+ init_metrics: Initialize pose estimation metrics for YOLO model.
33
+ _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
34
+ dimensions.
35
+ _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
36
+ update_metrics: Update metrics with new predictions and ground truth data.
37
+ _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
38
+ detections and ground truth.
39
+ plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
40
+ plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
41
+ save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
42
+ pred_to_json: Convert YOLO predictions to COCO JSON format.
43
+ eval_json: Evaluate object detection model using COCO JSON format.
41
44
 
42
45
  Examples:
43
46
  >>> from ultralytics.models.yolo.pose import PoseValidator
@@ -82,13 +85,13 @@ class PoseValidator(DetectionValidator):
82
85
  "See https://github.com/ultralytics/ultralytics/issues/4031."
83
86
  )
84
87
 
85
- def preprocess(self, batch):
88
+ def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
86
89
  """Preprocess batch by converting keypoints data to float and moving it to the device."""
87
90
  batch = super().preprocess(batch)
88
91
  batch["keypoints"] = batch["keypoints"].to(self.device).float()
89
92
  return batch
90
93
 
91
- def get_desc(self):
94
+ def get_desc(self) -> str:
92
95
  """Return description of evaluation metrics in string format."""
93
96
  return ("%22s" + "%11s" * 10) % (
94
97
  "Class",
@@ -113,7 +116,7 @@ class PoseValidator(DetectionValidator):
113
116
  self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
114
117
  self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
115
118
 
116
- def _prepare_batch(self, si, batch):
119
+ def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
117
120
  """
118
121
  Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
119
122
 
@@ -122,7 +125,7 @@ class PoseValidator(DetectionValidator):
122
125
  batch (dict): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
123
126
 
124
127
  Returns:
125
- pbatch (dict): Prepared batch with keypoints scaled to original image dimensions.
128
+ (dict): Prepared batch with keypoints scaled to original image dimensions.
126
129
 
127
130
  Notes:
128
131
  This method extends the parent class's _prepare_batch method by adding keypoint processing.
@@ -138,7 +141,7 @@ class PoseValidator(DetectionValidator):
138
141
  pbatch["kpts"] = kpts
139
142
  return pbatch
140
143
 
141
- def _prepare_pred(self, pred, pbatch):
144
+ def _prepare_pred(self, pred: torch.Tensor, pbatch: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
142
145
  """
143
146
  Prepare and scale keypoints in predictions for pose processing.
144
147
 
@@ -155,6 +158,7 @@ class PoseValidator(DetectionValidator):
155
158
 
156
159
  Returns:
157
160
  predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
161
+ pred_kpts (torch.Tensor): Predicted keypoints scaled to original image dimensions.
158
162
  """
159
163
  predn = super()._prepare_pred(pred, pbatch)
160
164
  nk = pbatch["kpts"].shape[1]
@@ -162,7 +166,7 @@ class PoseValidator(DetectionValidator):
162
166
  ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
163
167
  return predn, pred_kpts
164
168
 
165
- def update_metrics(self, preds, batch):
169
+ def update_metrics(self, preds: List[torch.Tensor], batch: Dict[str, Any]):
166
170
  """
167
171
  Update metrics with new predictions and ground truth data.
168
172
 
@@ -224,7 +228,14 @@ class PoseValidator(DetectionValidator):
224
228
  self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
225
229
  )
226
230
 
227
- def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
231
+ def _process_batch(
232
+ self,
233
+ detections: torch.Tensor,
234
+ gt_bboxes: torch.Tensor,
235
+ gt_cls: torch.Tensor,
236
+ pred_kpts: Optional[torch.Tensor] = None,
237
+ gt_kpts: Optional[torch.Tensor] = None,
238
+ ) -> torch.Tensor:
228
239
  """
229
240
  Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
230
241
 
@@ -234,9 +245,9 @@ class PoseValidator(DetectionValidator):
234
245
  gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
235
246
  box is of the format (x1, y1, x2, y2).
236
247
  gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
237
- pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
248
+ pred_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing predicted keypoints, where
238
249
  51 corresponds to 17 keypoints each having 3 values.
239
- gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
250
+ gt_kpts (torch.Tensor, optional): Tensor with shape (N, 51) representing ground truth keypoints.
240
251
 
241
252
  Returns:
242
253
  (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
@@ -255,7 +266,7 @@ class PoseValidator(DetectionValidator):
255
266
 
256
267
  return self.match_predictions(detections[:, 5], gt_cls, iou)
257
268
 
258
- def plot_val_samples(self, batch, ni):
269
+ def plot_val_samples(self, batch: Dict[str, Any], ni: int):
259
270
  """
260
271
  Plot and save validation set samples with ground truth bounding boxes and keypoints.
261
272
 
@@ -281,7 +292,7 @@ class PoseValidator(DetectionValidator):
281
292
  on_plot=self.on_plot,
282
293
  )
283
294
 
284
- def plot_predictions(self, batch, preds, ni):
295
+ def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int):
285
296
  """
286
297
  Plot and save model predictions with bounding boxes and keypoints.
287
298
 
@@ -305,7 +316,14 @@ class PoseValidator(DetectionValidator):
305
316
  on_plot=self.on_plot,
306
317
  ) # pred
307
318
 
308
- def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
319
+ def save_one_txt(
320
+ self,
321
+ predn: torch.Tensor,
322
+ pred_kpts: torch.Tensor,
323
+ save_conf: bool,
324
+ shape: Tuple[int, int],
325
+ file: Path,
326
+ ):
309
327
  """
310
328
  Save YOLO pose detections to a text file in normalized coordinates.
311
329
 
@@ -331,7 +349,7 @@ class PoseValidator(DetectionValidator):
331
349
  keypoints=pred_kpts,
332
350
  ).save_txt(file, save_conf=save_conf)
333
351
 
334
- def pred_to_json(self, predn, filename):
352
+ def pred_to_json(self, predn: torch.Tensor, filename: str):
335
353
  """
336
354
  Convert YOLO predictions to COCO JSON format.
337
355
 
@@ -364,7 +382,7 @@ class PoseValidator(DetectionValidator):
364
382
  }
365
383
  )
366
384
 
367
- def eval_json(self, stats):
385
+ def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
368
386
  """Evaluate object detection model using COCO JSON format."""
369
387
  if self.args.save_json and self.is_coco and len(self.jdict):
370
388
  anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
@@ -18,9 +18,9 @@ class SegmentationPredictor(DetectionPredictor):
18
18
  batch (list): Current batch of images being processed.
19
19
 
20
20
  Methods:
21
- postprocess: Applies non-max suppression and processes detections.
22
- construct_results: Constructs a list of result objects from predictions.
23
- construct_result: Constructs a single result object from a prediction.
21
+ postprocess: Apply non-max suppression and process segmentation detections.
22
+ construct_results: Construct a list of result objects from predictions.
23
+ construct_result: Construct a single result object from a prediction.
24
24
 
25
25
  Examples:
26
26
  >>> from ultralytics.utils import ASSETS
@@ -38,7 +38,7 @@ class SegmentationPredictor(DetectionPredictor):
38
38
  prediction results.
39
39
 
40
40
  Args:
41
- cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
41
+ cfg (dict): Configuration for the predictor.
42
42
  overrides (dict, optional): Configuration overrides that take precedence over cfg.
43
43
  _callbacks (list, optional): List of callback functions to be invoked during prediction.
44
44
  """
@@ -56,7 +56,7 @@ class SegmentationPredictor(DetectionPredictor):
56
56
 
57
57
  Returns:
58
58
  (list): List of Results objects containing the segmentation predictions for each image in the batch.
59
- Each Results object includes both bounding boxes and segmentation masks.
59
+ Each Results object includes both bounding boxes and segmentation masks.
60
60
 
61
61
  Examples:
62
62
  >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
@@ -1,6 +1,8 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from copy import copy
4
+ from pathlib import Path
5
+ from typing import Dict, Optional, Union
4
6
 
5
7
  from ultralytics.models import yolo
6
8
  from ultralytics.nn.tasks import SegmentationModel
@@ -25,7 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
25
27
  >>> trainer.train()
26
28
  """
27
29
 
28
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
30
+ def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict] = None, _callbacks=None):
29
31
  """
30
32
  Initialize a SegmentationTrainer object.
31
33
 
@@ -33,7 +35,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
33
35
  functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
34
36
 
35
37
  Args:
36
- cfg (dict): Configuration dictionary with default training settings. Defaults to DEFAULT_CFG.
38
+ cfg (dict): Configuration dictionary with default training settings.
37
39
  overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
38
40
  _callbacks (list, optional): List of callback functions to be executed during training.
39
41
 
@@ -48,13 +50,15 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
48
50
  overrides["task"] = "segment"
49
51
  super().__init__(cfg, overrides, _callbacks)
50
52
 
51
- def get_model(self, cfg=None, weights=None, verbose=True):
53
+ def get_model(
54
+ self, cfg: Optional[Union[Dict, str]] = None, weights: Optional[Union[str, Path]] = None, verbose: bool = True
55
+ ):
52
56
  """
53
57
  Initialize and return a SegmentationModel with specified configuration and weights.
54
58
 
55
59
  Args:
56
- cfg (dict | str | None): Model configuration. Can be a dictionary, a path to a YAML file, or None.
57
- weights (str | Path | None): Path to pretrained weights file.
60
+ cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
61
+ weights (str | Path, optional): Path to pretrained weights file.
58
62
  verbose (bool): Whether to display model information during initialization.
59
63
 
60
64
  Returns:
@@ -78,7 +82,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
78
82
  self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
79
83
  )
80
84
 
81
- def plot_training_samples(self, batch, ni):
85
+ def plot_training_samples(self, batch: Dict, ni: int):
82
86
  """
83
87
  Plot training sample images with labels, bounding boxes, and masks.
84
88
 
@@ -119,5 +123,5 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
119
123
  )
120
124
 
121
125
  def plot_metrics(self):
122
- """Plots training/val metrics."""
126
+ """Plot training/validation metrics."""
123
127
  plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
@@ -69,7 +69,7 @@ class SegmentationValidator(DetectionValidator):
69
69
  self.plot_masks = []
70
70
  if self.args.save_json:
71
71
  check_requirements("pycocotools>=2.0.6")
72
- # more accurate vs faster
72
+ # More accurate vs faster
73
73
  self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
74
74
  self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
75
75
 
@@ -247,7 +247,7 @@ class SegmentationValidator(DetectionValidator):
247
247
  Returns:
248
248
  (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
249
249
 
250
- Note:
250
+ Notes:
251
251
  - If `masks` is True, the function computes IoU between predicted and ground truth masks.
252
252
  - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
253
253
 
@@ -2,6 +2,7 @@
2
2
 
3
3
  import itertools
4
4
  from pathlib import Path
5
+ from typing import Any, Dict, List, Optional
5
6
 
6
7
  import torch
7
8
 
@@ -12,8 +13,8 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
12
13
  from ultralytics.utils.torch_utils import de_parallel
13
14
 
14
15
 
15
- def on_pretrain_routine_end(trainer):
16
- """Callback to set up model classes and text encoder at the end of the pretrain routine."""
16
+ def on_pretrain_routine_end(trainer) -> None:
17
+ """Set up model classes and text encoder at the end of the pretrain routine."""
17
18
  if RANK in {-1, 0}:
18
19
  # Set class names for evaluation
19
20
  names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
@@ -22,45 +23,54 @@ def on_pretrain_routine_end(trainer):
22
23
 
23
24
  class WorldTrainer(DetectionTrainer):
24
25
  """
25
- A class to fine-tune a world model on a close-set dataset.
26
+ A trainer class for fine-tuning YOLO World models on close-set datasets.
26
27
 
27
- This trainer extends the DetectionTrainer to support training YOLO World models, which combine
28
- visual and textual features for improved object detection and understanding.
28
+ This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
29
+ features for improved object detection and understanding. It handles text embedding generation and caching to
30
+ accelerate training with multi-modal data.
29
31
 
30
32
  Attributes:
31
- clip (module): The CLIP module for text-image understanding.
32
- text_model (module): The text encoder model from CLIP.
33
+ text_embeddings (Dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate
34
+ training.
33
35
  model (WorldModel): The YOLO World model being trained.
34
- data (dict): Dataset configuration containing class information.
35
- args (dict): Training arguments and configuration.
36
+ data (Dict[str, Any]): Dataset configuration containing class information.
37
+ args (Any): Training arguments and configuration.
38
+
39
+ Methods:
40
+ get_model: Return WorldModel initialized with specified config and weights.
41
+ build_dataset: Build YOLO Dataset for training or validation.
42
+ set_text_embeddings: Set text embeddings for datasets to accelerate training.
43
+ generate_text_embeddings: Generate text embeddings for a list of text samples.
44
+ preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.
36
45
 
37
46
  Examples:
38
- >>> from ultralytics.models.yolo.world import WorldModel
47
+ Initialize and train a YOLO World model
48
+ >>> from ultralytics.models.yolo.world import WorldTrainer
39
49
  >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
40
50
  >>> trainer = WorldTrainer(overrides=args)
41
51
  >>> trainer.train()
42
52
  """
43
53
 
44
- def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
54
+ def __init__(self, cfg=DEFAULT_CFG, overrides: Optional[Dict[str, Any]] = None, _callbacks=None):
45
55
  """
46
56
  Initialize a WorldTrainer object with given arguments.
47
57
 
48
58
  Args:
49
- cfg (dict): Configuration for the trainer.
50
- overrides (dict, optional): Configuration overrides.
51
- _callbacks (list, optional): List of callback functions.
59
+ cfg (Dict[str, Any]): Configuration for the trainer.
60
+ overrides (Dict[str, Any], optional): Configuration overrides.
61
+ _callbacks (List[Any], optional): List of callback functions.
52
62
  """
53
63
  if overrides is None:
54
64
  overrides = {}
55
65
  super().__init__(cfg, overrides, _callbacks)
56
66
  self.text_embeddings = None
57
67
 
58
- def get_model(self, cfg=None, weights=None, verbose=True):
68
+ def get_model(self, cfg=None, weights: Optional[str] = None, verbose: bool = True) -> WorldModel:
59
69
  """
60
70
  Return WorldModel initialized with specified config and weights.
61
71
 
62
72
  Args:
63
- cfg (Dict | str, optional): Model configuration.
73
+ cfg (Dict[str, Any] | str, optional): Model configuration.
64
74
  weights (str, optional): Path to pretrained weights.
65
75
  verbose (bool): Whether to display model info.
66
76
 
@@ -81,7 +91,7 @@ class WorldTrainer(DetectionTrainer):
81
91
 
82
92
  return model
83
93
 
84
- def build_dataset(self, img_path, mode="train", batch=None):
94
+ def build_dataset(self, img_path: str, mode: str = "train", batch: Optional[int] = None):
85
95
  """
86
96
  Build YOLO Dataset for training or validation.
87
97
 
@@ -91,7 +101,7 @@ class WorldTrainer(DetectionTrainer):
91
101
  batch (int, optional): Size of batches, this is for `rect`.
92
102
 
93
103
  Returns:
94
- (Dataset): YOLO dataset configured for training or validation.
104
+ (Any): YOLO dataset configured for training or validation.
95
105
  """
96
106
  gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
97
107
  dataset = build_yolo_dataset(
@@ -101,7 +111,7 @@ class WorldTrainer(DetectionTrainer):
101
111
  self.set_text_embeddings([dataset], batch) # cache text embeddings to accelerate training
102
112
  return dataset
103
113
 
104
- def set_text_embeddings(self, datasets, batch):
114
+ def set_text_embeddings(self, datasets: List[Any], batch: Optional[int]) -> None:
105
115
  """
106
116
  Set text embeddings for datasets to accelerate training by caching category names.
107
117
 
@@ -109,7 +119,7 @@ class WorldTrainer(DetectionTrainer):
109
119
  for these categories to improve training efficiency.
110
120
 
111
121
  Args:
112
- datasets (List[Dataset]): List of datasets from which to extract category names.
122
+ datasets (List[Any]): List of datasets from which to extract category names.
113
123
  batch (int | None): Batch size used for processing.
114
124
 
115
125
  Notes:
@@ -127,7 +137,7 @@ class WorldTrainer(DetectionTrainer):
127
137
  )
128
138
  self.text_embeddings = text_embeddings
129
139
 
130
- def generate_text_embeddings(self, texts, batch, cache_dir):
140
+ def generate_text_embeddings(self, texts: List[str], batch: int, cache_dir: Path) -> Dict[str, torch.Tensor]:
131
141
  """
132
142
  Generate text embeddings for a list of text samples.
133
143
 
@@ -137,7 +147,7 @@ class WorldTrainer(DetectionTrainer):
137
147
  cache_dir (Path): Directory to save/load cached embeddings.
138
148
 
139
149
  Returns:
140
- (dict): Dictionary mapping text samples to their embeddings.
150
+ (Dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.
141
151
  """
142
152
  model = "clip:ViT-B/32"
143
153
  cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
@@ -153,7 +163,7 @@ class WorldTrainer(DetectionTrainer):
153
163
  torch.save(txt_map, cache_path)
154
164
  return txt_map
155
165
 
156
- def preprocess_batch(self, batch):
166
+ def preprocess_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
157
167
  """Preprocess a batch of images and text for YOLOWorld training."""
158
168
  batch = DetectionTrainer.preprocess_batch(self, batch)
159
169
 
@@ -18,6 +18,14 @@ class WorldTrainerFromScratch(WorldTrainer):
18
18
  cfg (dict): Configuration dictionary with default parameters for model training.
19
19
  overrides (dict): Dictionary of parameter overrides to customize the configuration.
20
20
  _callbacks (list): List of callback functions to be executed during different stages of training.
21
+ data (dict): Final processed data configuration containing train/val paths and metadata.
22
+ training_data (dict): Dictionary mapping training dataset paths to their configurations.
23
+
24
+ Methods:
25
+ build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.
26
+ get_dataset: Get train and validation paths from data dictionary.
27
+ plot_training_labels: Skip label plotting for YOLO-World training.
28
+ final_eval: Perform final evaluation and validation for the YOLO-World model.
21
29
 
22
30
  Examples:
23
31
  >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
@@ -111,8 +119,8 @@ class WorldTrainerFromScratch(WorldTrainer):
111
119
  handling both YOLO detection datasets and grounding datasets.
112
120
 
113
121
  Returns:
114
- (str): Train dataset path.
115
- (str): Validation dataset path.
122
+ train_path (str): Train dataset path.
123
+ val_path (str): Validation dataset path.
116
124
 
117
125
  Raises:
118
126
  AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
@@ -159,7 +167,7 @@ class WorldTrainerFromScratch(WorldTrainer):
159
167
  return final_data
160
168
 
161
169
  def plot_training_labels(self):
162
- """Do not plot labels for YOLO-World training."""
170
+ """Skip label plotting for YOLO-World training."""
163
171
  pass
164
172
 
165
173
  def final_eval(self):
@@ -18,23 +18,23 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
18
18
  Attributes:
19
19
  model (torch.nn.Module): The YOLO model for inference.
20
20
  device (torch.device): Device to run the model on (CPU or CUDA).
21
- prompts (dict): Visual prompts containing class indices and bounding boxes or masks.
21
+ prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.
22
22
 
23
23
  Methods:
24
24
  setup_model: Initialize the YOLO model and set it to evaluation mode.
25
- set_return_vpe: Set whether to return visual prompt embeddings.
26
25
  set_prompts: Set the visual prompts for the model.
27
26
  pre_transform: Preprocess images and prompts before inference.
28
27
  inference: Run inference with visual prompts.
28
+ get_vpe: Process source to get visual prompt embeddings.
29
29
  """
30
30
 
31
- def setup_model(self, model, verbose=True):
31
+ def setup_model(self, model, verbose: bool = True):
32
32
  """
33
- Sets up the model for prediction.
33
+ Set up the model for prediction.
34
34
 
35
35
  Args:
36
36
  model (torch.nn.Module): Model to load or use.
37
- verbose (bool): If True, provides detailed logging.
37
+ verbose (bool, optional): If True, provides detailed logging.
38
38
  """
39
39
  super().setup_model(model, verbose=verbose)
40
40
  self.done_warmup = True
@@ -95,17 +95,17 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
95
95
 
96
96
  def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
97
97
  """
98
- Processes a single image by resizing bounding boxes or masks and generating visuals.
98
+ Process a single image by resizing bounding boxes or masks and generating visuals.
99
99
 
100
100
  Args:
101
101
  dst_shape (tuple): The target shape (height, width) of the image.
102
102
  src_shape (tuple): The original shape (height, width) of the image.
103
103
  category (str): The category of the image for visual prompts.
104
- bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2]. Defaults to None.
105
- masks (np.ndarray, optional): A list of masks corresponding to the image. Defaults to None.
104
+ bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2].
105
+ masks (np.ndarray, optional): A list of masks corresponding to the image.
106
106
 
107
107
  Returns:
108
- visuals: The processed visuals for the image.
108
+ (torch.Tensor): The processed visuals for the image.
109
109
 
110
110
  Raises:
111
111
  ValueError: If neither `bboxes` nor `masks` are provided.
@@ -146,7 +146,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
146
146
 
147
147
  def get_vpe(self, source):
148
148
  """
149
- Processes the source to get the visual prompt embeddings (VPE).
149
+ Process the source to get the visual prompt embeddings (VPE).
150
150
 
151
151
  Args:
152
152
  source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
@@ -164,6 +164,6 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
164
164
 
165
165
 
166
166
  class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
167
- """Predictor for YOLOE VP segmentation."""
167
+ """Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities."""
168
168
 
169
169
  pass