ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +118 -30
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +5 -5
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +13 -19
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +67 -88
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +21 -18
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +12 -13
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +20 -11
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +22 -11
  53. ultralytics/models/nas/predict.py +9 -4
  54. ultralytics/models/nas/val.py +5 -5
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +18 -15
  57. ultralytics/models/rtdetr/train.py +20 -16
  58. ultralytics/models/rtdetr/val.py +42 -6
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +24 -3
  73. ultralytics/models/yolo/classify/train.py +77 -10
  74. ultralytics/models/yolo/classify/val.py +40 -15
  75. ultralytics/models/yolo/detect/predict.py +23 -10
  76. ultralytics/models/yolo/detect/train.py +85 -15
  77. ultralytics/models/yolo/detect/val.py +145 -21
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +12 -4
  80. ultralytics/models/yolo/obb/train.py +7 -0
  81. ultralytics/models/yolo/obb/val.py +25 -7
  82. ultralytics/models/yolo/pose/predict.py +22 -6
  83. ultralytics/models/yolo/pose/train.py +17 -1
  84. ultralytics/models/yolo/pose/val.py +46 -21
  85. ultralytics/models/yolo/segment/predict.py +22 -8
  86. ultralytics/models/yolo/segment/train.py +6 -0
  87. ultralytics/models/yolo/segment/val.py +100 -14
  88. ultralytics/models/yolo/world/train.py +38 -8
  89. ultralytics/models/yolo/world/train_world.py +39 -10
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +3 -0
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +221 -69
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +32 -27
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +116 -35
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +13 -9
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +112 -45
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +61 -53
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +64 -45
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +181 -33
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +8 -16
  149. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.89.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,16 @@ class SegmentationValidator(DetectionValidator):
18
18
  """
19
19
  A class extending the DetectionValidator class for validation based on a segmentation model.
20
20
 
21
+ This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
22
+ to compute metrics such as mAP for both detection and segmentation tasks.
23
+
24
+ Attributes:
25
+ plot_masks (List): List to store masks for plotting.
26
+ process (callable): Function to process masks based on save_json and save_txt flags.
27
+ args (namespace): Arguments for the validator.
28
+ metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
29
+ stats (Dict): Dictionary to store statistics during validation.
30
+
21
31
  Examples:
22
32
  >>> from ultralytics.models.yolo.segment import SegmentationValidator
23
33
  >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
@@ -26,7 +36,16 @@ class SegmentationValidator(DetectionValidator):
26
36
  """
27
37
 
28
38
  def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
29
- """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
39
+ """
40
+ Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
41
+
42
+ Args:
43
+ dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
44
+ save_dir (Path, optional): Directory to save results.
45
+ pbar (Any, optional): Progress bar for displaying progress.
46
+ args (namespace, optional): Arguments for the validator.
47
+ _callbacks (List, optional): List of callback functions.
48
+ """
30
49
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
31
50
  self.plot_masks = None
32
51
  self.process = None
@@ -34,13 +53,18 @@ class SegmentationValidator(DetectionValidator):
34
53
  self.metrics = SegmentMetrics(save_dir=self.save_dir)
35
54
 
36
55
  def preprocess(self, batch):
37
- """Preprocesses batch by converting masks to float and sending to device."""
56
+ """Preprocess batch by converting masks to float and sending to device."""
38
57
  batch = super().preprocess(batch)
39
58
  batch["masks"] = batch["masks"].to(self.device).float()
40
59
  return batch
41
60
 
42
61
  def init_metrics(self, model):
43
- """Initialize metrics and select mask processing function based on save_json flag."""
62
+ """
63
+ Initialize metrics and select mask processing function based on save_json flag.
64
+
65
+ Args:
66
+ model (torch.nn.Module): Model to validate.
67
+ """
44
68
  super().init_metrics(model)
45
69
  self.plot_masks = []
46
70
  if self.args.save_json:
@@ -66,26 +90,61 @@ class SegmentationValidator(DetectionValidator):
66
90
  )
67
91
 
68
92
  def postprocess(self, preds):
69
- """Post-processes YOLO predictions and returns output detections with proto."""
93
+ """
94
+ Post-process YOLO predictions and return output detections with proto.
95
+
96
+ Args:
97
+ preds (List): Raw predictions from the model.
98
+
99
+ Returns:
100
+ p (torch.Tensor): Processed detection predictions.
101
+ proto (torch.Tensor): Prototype masks for segmentation.
102
+ """
70
103
  p = super().postprocess(preds[0])
71
104
  proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
72
105
  return p, proto
73
106
 
74
107
  def _prepare_batch(self, si, batch):
75
- """Prepares a batch for training or inference by processing images and targets."""
108
+ """
109
+ Prepare a batch for training or inference by processing images and targets.
110
+
111
+ Args:
112
+ si (int): Batch index.
113
+ batch (Dict): Batch data containing images and targets.
114
+
115
+ Returns:
116
+ (Dict): Prepared batch with processed images and targets.
117
+ """
76
118
  prepared_batch = super()._prepare_batch(si, batch)
77
119
  midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
78
120
  prepared_batch["masks"] = batch["masks"][midx]
79
121
  return prepared_batch
80
122
 
81
123
  def _prepare_pred(self, pred, pbatch, proto):
82
- """Prepares a batch for training or inference by processing images and targets."""
124
+ """
125
+ Prepare predictions for evaluation by processing bounding boxes and masks.
126
+
127
+ Args:
128
+ pred (torch.Tensor): Raw predictions from the model.
129
+ pbatch (Dict): Prepared batch data.
130
+ proto (torch.Tensor): Prototype masks for segmentation.
131
+
132
+ Returns:
133
+ predn (torch.Tensor): Processed bounding box predictions.
134
+ pred_masks (torch.Tensor): Processed mask predictions.
135
+ """
83
136
  predn = super()._prepare_pred(pred, pbatch)
84
137
  pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
85
138
  return predn, pred_masks
86
139
 
87
140
  def update_metrics(self, preds, batch):
88
- """Metrics."""
141
+ """
142
+ Update metrics with the current batch predictions and targets.
143
+
144
+ Args:
145
+ preds (List): Predictions from the model.
146
+ batch (Dict): Batch data containing images and targets.
147
+ """
89
148
  for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
90
149
  self.seen += 1
91
150
  npr = len(pred)
@@ -154,7 +213,7 @@ class SegmentationValidator(DetectionValidator):
154
213
  )
155
214
 
156
215
  def finalize_metrics(self, *args, **kwargs):
157
- """Sets speed and confusion matrix for evaluation metrics."""
216
+ """Set speed and confusion matrix for evaluation metrics."""
158
217
  self.metrics.speed = self.speed
159
218
  self.metrics.confusion_matrix = self.confusion_matrix
160
219
 
@@ -168,9 +227,9 @@ class SegmentationValidator(DetectionValidator):
168
227
  gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
169
228
  Each row is of the format [x1, y1, x2, y2].
170
229
  gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
171
- pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should
230
+ pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
172
231
  match the ground truth masks.
173
- gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available.
232
+ gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
174
233
  overlap (bool): Flag indicating if overlapping masks should be considered.
175
234
  masks (bool): Flag indicating if the batch contains mask data.
176
235
 
@@ -203,7 +262,13 @@ class SegmentationValidator(DetectionValidator):
203
262
  return self.match_predictions(detections[:, 5], gt_cls, iou)
204
263
 
205
264
  def plot_val_samples(self, batch, ni):
206
- """Plots validation samples with bounding box labels."""
265
+ """
266
+ Plot validation samples with bounding box labels and masks.
267
+
268
+ Args:
269
+ batch (Dict): Batch data containing images and targets.
270
+ ni (int): Batch index.
271
+ """
207
272
  plot_images(
208
273
  batch["img"],
209
274
  batch["batch_idx"],
@@ -217,7 +282,14 @@ class SegmentationValidator(DetectionValidator):
217
282
  )
218
283
 
219
284
  def plot_predictions(self, batch, preds, ni):
220
- """Plots batch predictions with masks and bounding boxes."""
285
+ """
286
+ Plot batch predictions with masks and bounding boxes.
287
+
288
+ Args:
289
+ batch (Dict): Batch data containing images.
290
+ preds (List): Predictions from the model.
291
+ ni (int): Batch index.
292
+ """
221
293
  plot_images(
222
294
  batch["img"],
223
295
  *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
@@ -230,7 +302,16 @@ class SegmentationValidator(DetectionValidator):
230
302
  self.plot_masks.clear()
231
303
 
232
304
  def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
233
- """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
305
+ """
306
+ Save YOLO detections to a txt file in normalized coordinates in a specific format.
307
+
308
+ Args:
309
+ predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
310
+ pred_masks (torch.Tensor): Predicted masks.
311
+ save_conf (bool): Whether to save confidence scores.
312
+ shape (Tuple): Original image shape.
313
+ file (Path): File path to save the detections.
314
+ """
234
315
  from ultralytics.engine.results import Results
235
316
 
236
317
  Results(
@@ -243,7 +324,12 @@ class SegmentationValidator(DetectionValidator):
243
324
 
244
325
  def pred_to_json(self, predn, filename, pred_masks):
245
326
  """
246
- Save one JSON result.
327
+ Save one JSON result for COCO evaluation.
328
+
329
+ Args:
330
+ predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
331
+ filename (str): Image filename.
332
+ pred_masks (numpy.ndarray): Predicted masks.
247
333
 
248
334
  Examples:
249
335
  >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
@@ -10,9 +10,9 @@ from ultralytics.utils.torch_utils import de_parallel
10
10
 
11
11
 
12
12
  def on_pretrain_routine_end(trainer):
13
- """Callback."""
13
+ """Callback to set up model classes and text encoder at the end of the pretrain routine."""
14
14
  if RANK in {-1, 0}:
15
- # NOTE: for evaluation
15
+ # Set class names for evaluation
16
16
  names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
17
17
  de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
18
18
  device = next(trainer.model.parameters()).device
@@ -25,6 +25,16 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
25
25
  """
26
26
  A class to fine-tune a world model on a close-set dataset.
27
27
 
28
+ This trainer extends the DetectionTrainer to support training YOLO World models, which combine
29
+ visual and textual features for improved object detection and understanding.
30
+
31
+ Attributes:
32
+ clip (module): The CLIP module for text-image understanding.
33
+ text_model (module): The text encoder model from CLIP.
34
+ model (WorldModel): The YOLO World model being trained.
35
+ data (Dict): Dataset configuration containing class information.
36
+ args (Dict): Training arguments and configuration.
37
+
28
38
  Examples:
29
39
  >>> from ultralytics.models.yolo.world import WorldModel
30
40
  >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
@@ -33,7 +43,14 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
33
43
  """
34
44
 
35
45
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
36
- """Initialize a WorldTrainer object with given arguments."""
46
+ """
47
+ Initialize a WorldTrainer object with given arguments.
48
+
49
+ Args:
50
+ cfg (Dict): Configuration for the trainer.
51
+ overrides (Dict, optional): Configuration overrides.
52
+ _callbacks (List, optional): List of callback functions.
53
+ """
37
54
  if overrides is None:
38
55
  overrides = {}
39
56
  super().__init__(cfg, overrides, _callbacks)
@@ -47,7 +64,17 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
47
64
  self.clip = clip
48
65
 
49
66
  def get_model(self, cfg=None, weights=None, verbose=True):
50
- """Return WorldModel initialized with specified config and weights."""
67
+ """
68
+ Return WorldModel initialized with specified config and weights.
69
+
70
+ Args:
71
+ cfg (Dict | str, optional): Model configuration.
72
+ weights (str, optional): Path to pretrained weights.
73
+ verbose (bool): Whether to display model info.
74
+
75
+ Returns:
76
+ (WorldModel): Initialized WorldModel.
77
+ """
51
78
  # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
52
79
  # NOTE: Following the official config, nc hard-coded to 80 for now.
53
80
  model = WorldModel(
@@ -64,12 +91,15 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
64
91
 
65
92
  def build_dataset(self, img_path, mode="train", batch=None):
66
93
  """
67
- Build YOLO Dataset.
94
+ Build YOLO Dataset for training or validation.
68
95
 
69
96
  Args:
70
97
  img_path (str): Path to the folder containing images.
71
98
  mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
72
- batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
99
+ batch (int, optional): Size of batches, this is for `rect`.
100
+
101
+ Returns:
102
+ (Dataset): YOLO dataset configured for training or validation.
73
103
  """
74
104
  gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
75
105
  return build_yolo_dataset(
@@ -77,10 +107,10 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
77
107
  )
78
108
 
79
109
  def preprocess_batch(self, batch):
80
- """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
110
+ """Preprocess a batch of images and text for YOLOWorld training."""
81
111
  batch = super().preprocess_batch(batch)
82
112
 
83
- # NOTE: add text features
113
+ # Add text features
84
114
  texts = list(itertools.chain(*batch["texts"]))
85
115
  text_token = self.clip.tokenize(texts).to(batch["img"].device)
86
116
  txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32
@@ -9,7 +9,15 @@ from ultralytics.utils.torch_utils import de_parallel
9
9
 
10
10
  class WorldTrainerFromScratch(WorldTrainer):
11
11
  """
12
- A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
12
+ A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
13
+
14
+ This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
15
+ supporting training YOLO-World models with combined vision-language capabilities.
16
+
17
+ Attributes:
18
+ cfg (Dict): Configuration dictionary with default parameters for model training.
19
+ overrides (Dict): Dictionary of parameter overrides to customize the configuration.
20
+ _callbacks (List): List of callback functions to be executed during different stages of training.
13
21
 
14
22
  Examples:
15
23
  >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
@@ -35,19 +43,25 @@ class WorldTrainerFromScratch(WorldTrainer):
35
43
  """
36
44
 
37
45
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
38
- """Initialize a WorldTrainer object with given arguments."""
46
+ """Initialize a WorldTrainerFromScratch object with given configuration and callbacks."""
39
47
  if overrides is None:
40
48
  overrides = {}
41
49
  super().__init__(cfg, overrides, _callbacks)
42
50
 
43
51
  def build_dataset(self, img_path, mode="train", batch=None):
44
52
  """
45
- Build YOLO Dataset.
53
+ Build YOLO Dataset for training or validation.
54
+
55
+ This method constructs appropriate datasets based on the mode and input paths, handling both
56
+ standard YOLO datasets and grounding datasets with different formats.
46
57
 
47
58
  Args:
48
- img_path (List[str] | str): Path to the folder containing images.
49
- mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
50
- batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
59
+ img_path (List[str] | str): Path to the folder containing images or list of paths.
60
+ mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
61
+ batch (int, optional): Size of batches, used for rectangular training/validation.
62
+
63
+ Returns:
64
+ (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
51
65
  """
52
66
  gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
53
67
  if mode != "train":
@@ -62,9 +76,17 @@ class WorldTrainerFromScratch(WorldTrainer):
62
76
 
63
77
  def get_dataset(self):
64
78
  """
65
- Get train, val path from data dict if it exists.
79
+ Get train and validation paths from data dictionary.
80
+
81
+ Processes the data configuration to extract paths for training and validation datasets,
82
+ handling both YOLO detection datasets and grounding datasets.
66
83
 
67
- Returns None if data format is not recognized.
84
+ Returns:
85
+ (str): Train dataset path.
86
+ (str): Validation dataset path.
87
+
88
+ Raises:
89
+ AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
68
90
  """
69
91
  final_data = {}
70
92
  data_yaml = self.args.data
@@ -94,11 +116,18 @@ class WorldTrainerFromScratch(WorldTrainer):
94
116
  return final_data["train"], final_data["val"][0]
95
117
 
96
118
  def plot_training_labels(self):
97
- """DO NOT plot labels."""
119
+ """Do not plot labels for YOLO-World training."""
98
120
  pass
99
121
 
100
122
  def final_eval(self):
101
- """Performs final evaluation and validation for object detection YOLO-World model."""
123
+ """
124
+ Perform final evaluation and validation for the YOLO-World model.
125
+
126
+ Configures the validator with appropriate dataset and split information before running evaluation.
127
+
128
+ Returns:
129
+ (Dict): Dictionary containing evaluation metrics and results.
130
+ """
102
131
  val = self.args.data["val"]["yolo_data"][0]
103
132
  self.validator.args.data = val
104
133
  self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
@@ -19,11 +19,7 @@ from ultralytics.utils.downloads import attempt_download_asset, is_url
19
19
 
20
20
 
21
21
  def check_class_names(names):
22
- """
23
- Check class names.
24
-
25
- Map imagenet class codes to human-readable names if required. Convert lists to dicts.
26
- """
22
+ """Check class names and convert to dict format if needed."""
27
23
  if isinstance(names, list): # names is a list
28
24
  names = dict(enumerate(names)) # convert to dict
29
25
  if isinstance(names, dict):
@@ -78,8 +74,23 @@ class AutoBackend(nn.Module):
78
74
  | IMX | *_imx_model/ |
79
75
  | RKNN | *_rknn_model/ |
80
76
 
81
- This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy
82
- models across various platforms.
77
+ Attributes:
78
+ model (torch.nn.Module): The loaded YOLO model.
79
+ device (torch.device): The device (CPU or GPU) on which the model is loaded.
80
+ task (str): The type of task the model performs (detect, segment, classify, pose).
81
+ names (Dict): A dictionary of class names that the model can detect.
82
+ stride (int): The model stride, typically 32 for YOLO models.
83
+ fp16 (bool): Whether the model uses half-precision (FP16) inference.
84
+
85
+ Methods:
86
+ forward: Run inference on an input image.
87
+ from_numpy: Convert numpy array to tensor.
88
+ warmup: Warm up the model with a dummy input.
89
+ _model_type: Determine the model type from file path.
90
+
91
+ Examples:
92
+ >>> model = AutoBackend(weights="yolov8n.pt", device="cuda")
93
+ >>> results = model(img)
83
94
  """
84
95
 
85
96
  @torch.no_grad()
@@ -101,7 +112,7 @@ class AutoBackend(nn.Module):
101
112
  weights (str | torch.nn.Module): Path to the model weights file or a module instance. Defaults to 'yolo11n.pt'.
102
113
  device (torch.device): Device to run the model on. Defaults to CPU.
103
114
  dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
104
- data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional.
115
+ data (str | Path | optional): Path to the additional data.yaml file containing class names.
105
116
  fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
106
117
  batch (int): Batch-size to assume for inference.
107
118
  fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
@@ -539,12 +550,12 @@ class AutoBackend(nn.Module):
539
550
 
540
551
  Args:
541
552
  im (torch.Tensor): The image tensor to perform inference on.
542
- augment (bool): whether to perform data augmentation during inference, defaults to False
543
- visualize (bool): whether to visualize the output predictions, defaults to False
544
- embed (list, optional): A list of feature vectors/embeddings to return.
553
+ augment (bool): Whether to perform data augmentation during inference. Defaults to False.
554
+ visualize (bool): Whether to visualize the output predictions. Defaults to False.
555
+ embed (List, optional): A list of feature vectors/embeddings to return.
545
556
 
546
557
  Returns:
547
- (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
558
+ (torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model.
548
559
  """
549
560
  b, ch, h, w = im.shape # batch, channel, height, width
550
561
  if self.fp16 and im.dtype != torch.float16:
@@ -776,10 +787,13 @@ class AutoBackend(nn.Module):
776
787
  def _model_type(p="path/to/model.pt"):
777
788
  """
778
789
  Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml,
779
- saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
790
+ saved_model, pb, tflite, edgetpu, tfjs, ncnn, mnn, imx or paddle.
780
791
 
781
792
  Args:
782
- p (str): path to the model file. Defaults to path/to/model.pt
793
+ p (str): Path to the model file. Defaults to path/to/model.pt
794
+
795
+ Returns:
796
+ (List[bool]): List of booleans indicating the model type.
783
797
 
784
798
  Examples:
785
799
  >>> model = AutoBackend(weights="path/to/model.onnx")
@@ -2,6 +2,9 @@
2
2
  """
3
3
  Ultralytics modules.
4
4
 
5
+ This module provides access to various neural network components used in Ultralytics models, including convolution blocks,
6
+ attention mechanisms, transformer components, and detection/segmentation heads.
7
+
5
8
  Examples:
6
9
  Visualize a module with Netron.
7
10
  >>> from ultralytics.nn.modules import *
@@ -6,10 +6,19 @@ import torch.nn as nn
6
6
 
7
7
 
8
8
  class AGLU(nn.Module):
9
- """Unified activation function module from https://github.com/kostas1515/AGLU."""
9
+ """
10
+ Unified activation function module from https://github.com/kostas1515/AGLU.
11
+
12
+ This class implements a parameterized activation function with learnable parameters lambda and kappa.
13
+
14
+ Attributes:
15
+ act (nn.Softplus): Softplus activation function with negative beta.
16
+ lambd (nn.Parameter): Learnable lambda parameter initialized with uniform distribution.
17
+ kappa (nn.Parameter): Learnable kappa parameter initialized with uniform distribution.
18
+ """
10
19
 
11
20
  def __init__(self, device=None, dtype=None) -> None:
12
- """Initialize the Unified activation function."""
21
+ """Initialize the Unified activation function with learnable parameters."""
13
22
  super().__init__()
14
23
  self.act = nn.Softplus(beta=-1.0)
15
24
  self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # lambda parameter
@@ -17,5 +26,5 @@ class AGLU(nn.Module):
17
26
 
18
27
  def forward(self, x: torch.Tensor) -> torch.Tensor:
19
28
  """Compute the forward pass of the Unified activation function."""
20
- lam = torch.clamp(self.lambd, min=0.0001)
29
+ lam = torch.clamp(self.lambd, min=0.0001) # Clamp lambda to avoid division by zero
21
30
  return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))