ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156):
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
@@ -46,16 +46,38 @@ class YOLODataset(BaseDataset):
46
46
  """
47
47
  Dataset class for loading object detection and/or segmentation labels in YOLO format.
48
48
 
49
- Args:
50
- data (dict, optional): A dataset YAML dictionary. Defaults to None.
51
- task (str): An explicit arg to point current task, Defaults to 'detect'.
49
+ This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
50
+ (OBB) tasks using the YOLO format.
52
51
 
53
- Returns:
54
- (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
52
+ Attributes:
53
+ use_segments (bool): Indicates if segmentation masks should be used.
54
+ use_keypoints (bool): Indicates if keypoints should be used for pose estimation.
55
+ use_obb (bool): Indicates if oriented bounding boxes should be used.
56
+ data (dict): Dataset configuration dictionary.
57
+
58
+ Methods:
59
+ cache_labels: Cache dataset labels, check images and read shapes.
60
+ get_labels: Returns dictionary of labels for YOLO training.
61
+ build_transforms: Builds and appends transforms to the list.
62
+ close_mosaic: Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
63
+ update_labels_info: Updates label format for different tasks.
64
+ collate_fn: Collates data samples into batches.
65
+
66
+ Examples:
67
+ >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
68
+ >>> dataset.get_labels()
55
69
  """
56
70
 
57
71
  def __init__(self, *args, data=None, task="detect", **kwargs):
58
- """Initializes the YOLODataset with optional configurations for segments and keypoints."""
72
+ """
73
+ Initialize the YOLODataset.
74
+
75
+ Args:
76
+ data (dict, optional): Dataset configuration dictionary.
77
+ task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
78
+ *args (Any): Additional positional arguments for the parent class.
79
+ **kwargs (Any): Additional keyword arguments for the parent class.
80
+ """
59
81
  self.use_segments = task == "segment"
60
82
  self.use_keypoints = task == "pose"
61
83
  self.use_obb = task == "obb"
@@ -68,10 +90,10 @@ class YOLODataset(BaseDataset):
68
90
  Cache dataset labels, check images and read shapes.
69
91
 
70
92
  Args:
71
- path (Path): Path where to save the cache file. Default is Path("./labels.cache").
93
+ path (Path): Path where to save the cache file.
72
94
 
73
95
  Returns:
74
- (dict): labels.
96
+ (dict): Dictionary containing cached labels and related information.
75
97
  """
76
98
  x = {"labels": []}
77
99
  nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
@@ -131,7 +153,14 @@ class YOLODataset(BaseDataset):
131
153
  return x
132
154
 
133
155
  def get_labels(self):
134
- """Returns dictionary of labels for YOLO training."""
156
+ """
157
+ Returns dictionary of labels for YOLO training.
158
+
159
+ This method loads labels from disk or cache, verifies their integrity, and prepares them for training.
160
+
161
+ Returns:
162
+ (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
163
+ """
135
164
  self.label_files = img2label_paths(self.im_files)
136
165
  cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
137
166
  try:
@@ -172,7 +201,15 @@ class YOLODataset(BaseDataset):
172
201
  return labels
173
202
 
174
203
  def build_transforms(self, hyp=None):
175
- """Builds and appends transforms to the list."""
204
+ """
205
+ Builds and appends transforms to the list.
206
+
207
+ Args:
208
+ hyp (dict, optional): Hyperparameters for transforms.
209
+
210
+ Returns:
211
+ (Compose): Composed transforms.
212
+ """
176
213
  if self.augment:
177
214
  hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
178
215
  hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
@@ -195,7 +232,12 @@ class YOLODataset(BaseDataset):
195
232
  return transforms
196
233
 
197
234
  def close_mosaic(self, hyp):
198
- """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
235
+ """
236
+ Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
237
+
238
+ Args:
239
+ hyp (dict): Hyperparameters for transforms.
240
+ """
199
241
  hyp.mosaic = 0.0 # set mosaic ratio=0.0
200
242
  hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
201
243
  hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
@@ -205,6 +247,12 @@ class YOLODataset(BaseDataset):
205
247
  """
206
248
  Custom your label format here.
207
249
 
250
+ Args:
251
+ label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
252
+
253
+ Returns:
254
+ (dict): Updated label dictionary with instances.
255
+
208
256
  Note:
209
257
  cls is not with bboxes now, classification and semantic segmentation need an independent cls label
210
258
  Can also support classification and semantic segmentation by adding or removing dict keys there.
@@ -230,7 +278,15 @@ class YOLODataset(BaseDataset):
230
278
 
231
279
  @staticmethod
232
280
  def collate_fn(batch):
233
- """Collates data samples into batches."""
281
+ """
282
+ Collates data samples into batches.
283
+
284
+ Args:
285
+ batch (List[dict]): List of dictionaries containing sample data.
286
+
287
+ Returns:
288
+ (dict): Collated batch with stacked tensors.
289
+ """
234
290
  new_batch = {}
235
291
  keys = batch[0].keys()
236
292
  values = list(zip(*[list(b.values()) for b in batch]))
@@ -250,29 +306,58 @@ class YOLODataset(BaseDataset):
250
306
 
251
307
  class YOLOMultiModalDataset(YOLODataset):
252
308
  """
253
- Dataset class for loading object detection and/or segmentation labels in YOLO format.
309
+ Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
254
310
 
255
- Args:
256
- data (dict, optional): A dataset YAML dictionary. Defaults to None.
257
- task (str): An explicit arg to point current task, Defaults to 'detect'.
311
+ This class extends YOLODataset to add text information for multi-modal model training, enabling models to
312
+ process both image and text data.
258
313
 
259
- Returns:
260
- (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
314
+ Methods:
315
+ update_labels_info: Adds text information for multi-modal model training.
316
+ build_transforms: Enhances data transformations with text augmentation.
317
+
318
+ Examples:
319
+ >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
320
+ >>> batch = next(iter(dataset))
321
+ >>> print(batch.keys()) # Should include 'texts'
261
322
  """
262
323
 
263
324
  def __init__(self, *args, data=None, task="detect", **kwargs):
264
- """Initializes a dataset object for object detection tasks with optional specifications."""
325
+ """
326
+ Initialize a YOLOMultiModalDataset.
327
+
328
+ Args:
329
+ data (dict, optional): Dataset configuration dictionary.
330
+ task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
331
+ *args (Any): Additional positional arguments for the parent class.
332
+ **kwargs (Any): Additional keyword arguments for the parent class.
333
+ """
265
334
  super().__init__(*args, data=data, task=task, **kwargs)
266
335
 
267
336
  def update_labels_info(self, label):
268
- """Add texts information for multi-modal model training."""
337
+ """
338
+ Add texts information for multi-modal model training.
339
+
340
+ Args:
341
+ label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
342
+
343
+ Returns:
344
+ (dict): Updated label dictionary with instances and texts.
345
+ """
269
346
  labels = super().update_labels_info(label)
270
347
  # NOTE: some categories are concatenated with its synonyms by `/`.
271
348
  labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
272
349
  return labels
273
350
 
274
351
  def build_transforms(self, hyp=None):
275
- """Enhances data transformations with optional text augmentation for multi-modal training."""
352
+ """
353
+ Enhances data transformations with optional text augmentation for multi-modal training.
354
+
355
+ Args:
356
+ hyp (dict, optional): Hyperparameters for transforms.
357
+
358
+ Returns:
359
+ (Compose): Composed transforms including text augmentation if applicable.
360
+ """
276
361
  transforms = super().build_transforms(hyp)
277
362
  if self.augment:
278
363
  # NOTE: hard-coded the args for now.
@@ -281,20 +366,58 @@ class YOLOMultiModalDataset(YOLODataset):
281
366
 
282
367
 
283
368
  class GroundingDataset(YOLODataset):
284
- """Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""
369
+ """
370
+ Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format.
371
+
372
+ This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than
373
+ the standard YOLO format text files.
374
+
375
+ Attributes:
376
+ json_file (str): Path to the JSON file containing annotations.
377
+
378
+ Methods:
379
+ get_img_files: Returns empty list as image files are read in get_labels.
380
+ get_labels: Loads annotations from a JSON file and prepares them for training.
381
+ build_transforms: Configures augmentations for training with optional text loading.
382
+
383
+ Examples:
384
+ >>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect")
385
+ >>> len(dataset) # Number of valid images with annotations
386
+ """
285
387
 
286
388
  def __init__(self, *args, task="detect", json_file, **kwargs):
287
- """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
389
+ """
390
+ Initialize a GroundingDataset for object detection.
391
+
392
+ Args:
393
+ json_file (str): Path to the JSON file containing annotations.
394
+ task (str): Must be 'detect' for GroundingDataset.
395
+ *args (Any): Additional positional arguments for the parent class.
396
+ **kwargs (Any): Additional keyword arguments for the parent class.
397
+ """
288
398
  assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
289
399
  self.json_file = json_file
290
400
  super().__init__(*args, task=task, data={}, **kwargs)
291
401
 
292
402
  def get_img_files(self, img_path):
293
- """The image files would be read in `get_labels` function, return empty list here."""
403
+ """
404
+ The image files would be read in `get_labels` function, return empty list here.
405
+
406
+ Args:
407
+ img_path (str): Path to the directory containing images.
408
+
409
+ Returns:
410
+ (List): Empty list as image files are read in get_labels.
411
+ """
294
412
  return []
295
413
 
296
414
  def get_labels(self):
297
- """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
415
+ """
416
+ Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.
417
+
418
+ Returns:
419
+ (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
420
+ """
298
421
  labels = []
299
422
  LOGGER.info("Loading annotation file...")
300
423
  with open(self.json_file) as f:
@@ -347,7 +470,15 @@ class GroundingDataset(YOLODataset):
347
470
  return labels
348
471
 
349
472
  def build_transforms(self, hyp=None):
350
- """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
473
+ """
474
+ Configures augmentations for training with optional text loading.
475
+
476
+ Args:
477
+ hyp (dict, optional): Hyperparameters for transforms.
478
+
479
+ Returns:
480
+ (Compose): Composed transforms including text augmentation if applicable.
481
+ """
351
482
  transforms = super().build_transforms(hyp)
352
483
  if self.augment:
353
484
  # NOTE: hard-coded the args for now.
@@ -359,27 +490,35 @@ class YOLOConcatDataset(ConcatDataset):
359
490
  """
360
491
  Dataset as a concatenation of multiple datasets.
361
492
 
362
- This class is useful to assemble different existing datasets.
493
+ This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same
494
+ collation function.
495
+
496
+ Methods:
497
+ collate_fn: Static method that collates data samples into batches using YOLODataset's collation function.
498
+
499
+ Examples:
500
+ >>> dataset1 = YOLODataset(...)
501
+ >>> dataset2 = YOLODataset(...)
502
+ >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2])
363
503
  """
364
504
 
365
505
  @staticmethod
366
506
  def collate_fn(batch):
367
- """Collates data samples into batches."""
507
+ """
508
+ Collates data samples into batches.
509
+
510
+ Args:
511
+ batch (List[dict]): List of dictionaries containing sample data.
512
+
513
+ Returns:
514
+ (dict): Collated batch with stacked tensors.
515
+ """
368
516
  return YOLODataset.collate_fn(batch)
369
517
 
370
518
 
371
519
  # TODO: support semantic segmentation
372
520
  class SemanticDataset(BaseDataset):
373
- """
374
- Semantic Segmentation Dataset.
375
-
376
- This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
377
- from the BaseDataset class.
378
-
379
- Note:
380
- This class is currently a placeholder and needs to be populated with methods and attributes for supporting
381
- semantic segmentation tasks.
382
- """
521
+ """Semantic Segmentation Dataset."""
383
522
 
384
523
  def __init__(self):
385
524
  """Initialize a SemanticDataset object."""
@@ -388,20 +527,25 @@ class SemanticDataset(BaseDataset):
388
527
 
389
528
  class ClassificationDataset:
390
529
  """
391
- Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
392
- augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
393
- learning models, with optional image transformations and caching mechanisms to speed up training.
530
+ Extends torchvision ImageFolder to support YOLO classification tasks.
394
531
 
395
- This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
396
- in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
397
- to ensure data integrity and consistency.
532
+ This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
533
+ handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
534
+ to speed up training.
398
535
 
399
536
  Attributes:
400
537
  cache_ram (bool): Indicates if caching in RAM is enabled.
401
538
  cache_disk (bool): Indicates if caching on disk is enabled.
402
- samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
539
+ samples (List): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
403
540
  file (if caching on disk), and optionally the loaded image array (if caching in RAM).
404
541
  torch_transforms (callable): PyTorch transforms to be applied to the images.
542
+ root (str): Root directory of the dataset.
543
+ prefix (str): Prefix for logging and cache filenames.
544
+
545
+ Methods:
546
+ __getitem__: Returns subset of data and targets corresponding to given indices.
547
+ __len__: Returns the total number of samples in the dataset.
548
+ verify_images: Verifies all images in dataset.
405
549
  """
406
550
 
407
551
  def __init__(self, root, args, augment=False, prefix=""):
@@ -411,12 +555,9 @@ class ClassificationDataset:
411
555
  Args:
412
556
  root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
413
557
  args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
414
- parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
415
- of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
416
- `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
417
- augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
418
- prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
419
- debugging. Default is an empty string.
558
+ parameters, and cache settings.
559
+ augment (bool, optional): Whether to apply augmentations to the dataset.
560
+ prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification.
420
561
  """
421
562
  import torchvision # scope for faster 'import ultralytics'
422
563
 
@@ -460,7 +601,15 @@ class ClassificationDataset:
460
601
  )
461
602
 
462
603
  def __getitem__(self, i):
463
- """Returns subset of data and targets corresponding to given indices."""
604
+ """
605
+ Returns subset of data and targets corresponding to given indices.
606
+
607
+ Args:
608
+ i (int): Index of the sample to retrieve.
609
+
610
+ Returns:
611
+ (dict): Dictionary containing the image and its class index.
612
+ """
464
613
  f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
465
614
  if self.cache_ram:
466
615
  if im is None: # Warning: two separate if statements required here, do not combine this with previous line
@@ -481,7 +630,12 @@ class ClassificationDataset:
481
630
  return len(self.samples)
482
631
 
483
632
  def verify_images(self):
484
- """Verify all images in dataset."""
633
+ """
634
+ Verify all images in dataset.
635
+
636
+ Returns:
637
+ (List): List of valid samples after verification.
638
+ """
485
639
  desc = f"{self.prefix}Scanning {self.root}..."
486
640
  path = Path(self.root).with_suffix(".cache") # *.cache file path
487
641
 
@@ -33,6 +33,7 @@ class SourceTypes:
33
33
  stream (bool): Flag indicating if the input source is a video stream.
34
34
  screenshot (bool): Flag indicating if the input source is a screenshot.
35
35
  from_img (bool): Flag indicating if the input source is an image file.
36
+ tensor (bool): Flag indicating if the input source is a tensor.
36
37
 
37
38
  Examples:
38
39
  >>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False)
@@ -19,14 +19,14 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
19
19
  Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.
20
20
 
21
21
  Args:
22
- polygon1 (np.ndarray): Polygon coordinates, shape (n, 8).
23
- bbox2 (np.ndarray): Bounding boxes, shape (n, 4).
24
- eps (float, optional): Small value to prevent division by zero. Defaults to 1e-6.
22
+ polygon1 (np.ndarray): Polygon coordinates with shape (n, 8).
23
+ bbox2 (np.ndarray): Bounding boxes with shape (n, 4).
24
+ eps (float, optional): Small value to prevent division by zero.
25
25
 
26
26
  Returns:
27
- (np.ndarray): IoF scores, shape (n, 1) or (n, m) if bbox2 is (m, 4).
27
+ (np.ndarray): IoF scores with shape (n, 1) or (n, m) if bbox2 is (m, 4).
28
28
 
29
- Note:
29
+ Notes:
30
30
  Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
31
31
  Bounding box format: [x_min, y_min, x_max, y_max].
32
32
  """
@@ -66,9 +66,12 @@ def load_yolo_dota(data_root, split="train"):
66
66
  Load DOTA dataset.
67
67
 
68
68
  Args:
69
- data_root (str): Data root.
69
+ data_root (str): Data root directory.
70
70
  split (str): The split data set, could be `train` or `val`.
71
71
 
72
+ Returns:
73
+ (List[Dict]): List of annotation dictionaries containing image information.
74
+
72
75
  Notes:
73
76
  The directory structure assumed for the DOTA dataset:
74
77
  - data_root
@@ -100,10 +103,13 @@ def get_windows(im_size, crop_sizes=(1024,), gaps=(200,), im_rate_thr=0.6, eps=0
100
103
 
101
104
  Args:
102
105
  im_size (tuple): Original image size, (h, w).
103
- crop_sizes (List(int)): Crop size of windows.
104
- gaps (List(int)): Gap between crops.
105
- im_rate_thr (float): Threshold of windows areas divided by image ares.
106
+ crop_sizes (List[int]): Crop size of windows.
107
+ gaps (List[int]): Gap between crops.
108
+ im_rate_thr (float): Threshold of windows areas divided by image areas.
106
109
  eps (float): Epsilon value for math operations.
110
+
111
+ Returns:
112
+ (np.ndarray): Array of window coordinates with shape (n, 4) where each row is [x_start, y_start, x_stop, y_stop].
107
113
  """
108
114
  h, w = im_size
109
115
  windows = []
@@ -157,9 +163,9 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir, allow_background_i
157
163
  Crop images and save new labels.
158
164
 
159
165
  Args:
160
- anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
161
- windows (list): A list of windows coordinates.
162
- window_objs (list): A list of labels inside each window.
166
+ anno (Dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
167
+ windows (np.ndarray): Array of windows coordinates with shape (n, 4).
168
+ window_objs (List): A list of labels inside each window.
163
169
  im_dir (str): The output directory path of images.
164
170
  lb_dir (str): The output directory path of labels.
165
171
  allow_background_images (bool): Whether to include background images without labels.
@@ -201,6 +207,13 @@ def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024
201
207
  """
202
208
  Split both images and labels.
203
209
 
210
+ Args:
211
+ data_root (str): Root directory of the dataset.
212
+ save_dir (str): Directory to save the split dataset.
213
+ split (str): The split data set, could be `train` or `val`.
214
+ crop_sizes (tuple): Tuple of crop sizes.
215
+ gaps (tuple): Tuple of gaps between crops.
216
+
204
217
  Notes:
205
218
  The directory structure assumed for the DOTA dataset:
206
219
  - data_root
@@ -231,6 +244,13 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
231
244
  """
232
245
  Split train and val set of DOTA.
233
246
 
247
+ Args:
248
+ data_root (str): Root directory of the dataset.
249
+ save_dir (str): Directory to save the split dataset.
250
+ crop_size (int): Base crop size.
251
+ gap (int): Base gap between crops.
252
+ rates (tuple): Scaling rates for crop_size and gap.
253
+
234
254
  Notes:
235
255
  The directory structure assumed for the DOTA dataset:
236
256
  - data_root
@@ -261,6 +281,13 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)):
261
281
  """
262
282
  Split test set of DOTA, labels are not included within this set.
263
283
 
284
+ Args:
285
+ data_root (str): Root directory of the dataset.
286
+ save_dir (str): Directory to save the split dataset.
287
+ crop_size (int): Base crop size.
288
+ gap (int): Base gap between crops.
289
+ rates (tuple): Scaling rates for crop_size and gap.
290
+
264
291
  Notes:
265
292
  The directory structure assumed for the DOTA dataset:
266
293
  - data_root
ultralytics/data/utils.py CHANGED
@@ -175,13 +175,8 @@ def visualize_image_annotations(image_path, txt_path, label_map):
175
175
  adjusted for readability, depending on the background color's luminance.
176
176
 
177
177
  Args:
178
- image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL (e.g., .jpg, .png).
179
- txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object with:
180
- - class_id (int): The class index.
181
- - x_center (float): The X center of the bounding box (relative to image width).
182
- - y_center (float): The Y center of the bounding box (relative to image height).
183
- - width (float): The width of the bounding box (relative to image width).
184
- - height (float): The height of the bounding box (relative to image height).
178
+ image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL.
179
+ txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object.
185
180
  label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings).
186
181
 
187
182
  Examples:
@@ -222,8 +217,8 @@ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
222
217
  imgsz (tuple): The size of the image as (height, width).
223
218
  polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
224
219
  N is the number of polygons, and M is the number of points such that M % 2 = 0.
225
- color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1.
226
- downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1.
220
+ color (int, optional): The color value to fill in the polygons on the mask.
221
+ downsample_ratio (int, optional): Factor by which to downsample the mask.
227
222
 
228
223
  Returns:
229
224
  (np.ndarray): A binary mask of the specified image size with the polygons filled in.
@@ -246,7 +241,7 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
246
241
  polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
247
242
  N is the number of polygons, and M is the number of points such that M % 2 = 0.
248
243
  color (int): The color value to fill in the polygons on the masks.
249
- downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1.
244
+ downsample_ratio (int, optional): Factor by which to downsample each mask.
250
245
 
251
246
  Returns:
252
247
  (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
@@ -281,8 +276,7 @@ def find_dataset_yaml(path: Path) -> Path:
281
276
  Find and return the YAML file associated with a Detect, Segment or Pose dataset.
282
277
 
283
278
  This function searches for a YAML file at the root level of the provided directory first, and if not found, it
284
- performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError
285
- is raised if no YAML file is found or if multiple YAML files are found.
279
+ performs a recursive search. It prefers YAML files that have the same stem as the provided path.
286
280
 
287
281
  Args:
288
282
  path (Path): The directory path to search for the YAML file.
@@ -308,7 +302,7 @@ def check_det_dataset(dataset, autodownload=True):
308
302
 
309
303
  Args:
310
304
  dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
311
- autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True.
305
+ autodownload (bool, optional): Whether to automatically download the dataset if not found.
312
306
 
313
307
  Returns:
314
308
  (dict): Parsed dataset information and paths.
@@ -400,7 +394,7 @@ def check_cls_dataset(dataset, split=""):
400
394
 
401
395
  Args:
402
396
  dataset (str | Path): The name of the dataset.
403
- split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
397
+ split (str, optional): The split of the dataset. Either 'val', 'test', or ''.
404
398
 
405
399
  Returns:
406
400
  (dict): A dictionary containing the following keys:
@@ -440,8 +434,10 @@ def check_cls_dataset(dataset, split=""):
440
434
  test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test
441
435
  if split == "val" and not val_set:
442
436
  LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
437
+ val_set = test_set
443
438
  elif split == "test" and not test_set:
444
439
  LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
440
+ test_set = val_set
445
441
 
446
442
  nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes
447
443
  names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list
@@ -634,8 +630,8 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
634
630
  Args:
635
631
  f (str): The path to the input image file.
636
632
  f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
637
- max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
638
- quality (int, optional): The image compression quality as a percentage. Default is 50%.
633
+ max_dim (int, optional): The maximum dimension (width or height) of the output image.
634
+ quality (int, optional): The image compression quality as a percentage.
639
635
 
640
636
  Examples:
641
637
  >>> from pathlib import Path
@@ -664,9 +660,9 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
664
660
  Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
665
661
 
666
662
  Args:
667
- path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'.
668
- weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
669
- annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
663
+ path (Path, optional): Path to images directory.
664
+ weights (list | tuple, optional): Train, validation, and test split fractions.
665
+ annotated_only (bool, optional): If True, only images with an associated txt file are used.
670
666
 
671
667
  Examples:
672
668
  >>> from ultralytics.data.utils import autosplit