ultralytics 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +11 -11
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +52 -51
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +191 -161
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +4 -6
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +2 -2
  93. ultralytics/solutions/instance_segmentation.py +7 -4
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +15 -11
  96. ultralytics/solutions/object_cropper.py +3 -2
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +189 -79
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +45 -29
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/METADATA +2 -2
  143. ultralytics-8.3.145.dist-info/RECORD +272 -0
  144. ultralytics-8.3.143.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/data/dataset.py CHANGED
@@ -5,6 +5,7 @@ from collections import defaultdict
  from itertools import repeat
  from multiprocessing.pool import ThreadPool
  from pathlib import Path
+ from typing import Dict, List, Optional, Tuple

  import cv2
  import numpy as np
@@ -58,18 +59,18 @@ class YOLODataset(BaseDataset):

  Methods:
  cache_labels: Cache dataset labels, check images and read shapes.
- get_labels: Returns dictionary of labels for YOLO training.
- build_transforms: Builds and appends transforms to the list.
- close_mosaic: Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
- update_labels_info: Updates label format for different tasks.
- collate_fn: Collates data samples into batches.
+ get_labels: Return dictionary of labels for YOLO training.
+ build_transforms: Build and append transforms to the list.
+ close_mosaic: Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
+ update_labels_info: Update label format for different tasks.
+ collate_fn: Collate data samples into batches.

  Examples:
  >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
  >>> dataset.get_labels()
  """

- def __init__(self, *args, data=None, task="detect", **kwargs):
+ def __init__(self, *args, data: Optional[Dict] = None, task: str = "detect", **kwargs):
  """
  Initialize the YOLODataset.

@@ -86,7 +87,7 @@ class YOLODataset(BaseDataset):
  assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
  super().__init__(*args, channels=self.data["channels"], **kwargs)

- def cache_labels(self, path=Path("./labels.cache")):
+ def cache_labels(self, path: Path = Path("./labels.cache")) -> Dict:
  """
  Cache dataset labels, check images and read shapes.

@@ -154,9 +155,9 @@ class YOLODataset(BaseDataset):
  save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
  return x

- def get_labels(self):
+ def get_labels(self) -> List[Dict]:
  """
- Returns dictionary of labels for YOLO training.
+ Return dictionary of labels for YOLO training.

  This method loads labels from disk or cache, verifies their integrity, and prepares them for training.

@@ -204,9 +205,9 @@ class YOLODataset(BaseDataset):
  LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
  return labels

- def build_transforms(self, hyp=None):
+ def build_transforms(self, hyp: Optional[Dict] = None) -> Compose:
  """
- Builds and appends transforms to the list.
+ Build and append transforms to the list.

  Args:
  hyp (dict, optional): Hyperparameters for transforms.
@@ -236,7 +237,7 @@ class YOLODataset(BaseDataset):
  )
  return transforms

- def close_mosaic(self, hyp):
+ def close_mosaic(self, hyp: Dict) -> None:
  """
  Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.

@@ -249,9 +250,9 @@ class YOLODataset(BaseDataset):
  hyp.cutmix = 0.0
  self.transforms = self.build_transforms(hyp)

- def update_labels_info(self, label):
+ def update_labels_info(self, label: Dict) -> Dict:
  """
- Custom your label format here.
+ Update label format for different tasks.

  Args:
  label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@@ -283,9 +284,9 @@ class YOLODataset(BaseDataset):
  return label

  @staticmethod
- def collate_fn(batch):
+ def collate_fn(batch: List[Dict]) -> Dict:
  """
- Collates data samples into batches.
+ Collate data samples into batches.

  Args:
  batch (List[dict]): List of dictionaries containing sample data.
@@ -321,8 +322,8 @@ class YOLOMultiModalDataset(YOLODataset):
  process both image and text data.

  Methods:
- update_labels_info: Adds text information for multi-modal model training.
- build_transforms: Enhances data transformations with text augmentation.
+ update_labels_info: Add text information for multi-modal model training.
+ build_transforms: Enhance data transformations with text augmentation.

  Examples:
  >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
@@ -330,7 +331,7 @@ class YOLOMultiModalDataset(YOLODataset):
  >>> print(batch.keys()) # Should include 'texts'
  """

- def __init__(self, *args, data=None, task="detect", **kwargs):
+ def __init__(self, *args, data: Optional[Dict] = None, task: str = "detect", **kwargs):
  """
  Initialize a YOLOMultiModalDataset.

@@ -342,9 +343,9 @@ class YOLOMultiModalDataset(YOLODataset):
  """
  super().__init__(*args, data=data, task=task, **kwargs)

- def update_labels_info(self, label):
+ def update_labels_info(self, label: Dict) -> Dict:
  """
- Add texts information for multi-modal model training.
+ Add text information for multi-modal model training.

  Args:
  label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@@ -359,9 +360,9 @@ class YOLOMultiModalDataset(YOLODataset):

  return labels

- def build_transforms(self, hyp=None):
+ def build_transforms(self, hyp: Optional[Dict] = None) -> Compose:
  """
- Enhances data transformations with optional text augmentation for multi-modal training.
+ Enhance data transformations with optional text augmentation for multi-modal training.

  Args:
  hyp (dict, optional): Hyperparameters for transforms.
@@ -408,14 +409,14 @@ class YOLOMultiModalDataset(YOLODataset):
  return category_freq

  @staticmethod
- def _get_neg_texts(category_freq, threshold=100):
+ def _get_neg_texts(category_freq: Dict, threshold: int = 100) -> List[str]:
  """Get negative text samples based on frequency threshold."""
  return [k for k, v in category_freq.items() if v >= threshold]


  class GroundingDataset(YOLODataset):
  """
- Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format.
+ Dataset class for object detection tasks using annotations from a JSON file in grounding format.

  This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than
  the standard YOLO format text files.
@@ -424,16 +425,16 @@ class GroundingDataset(YOLODataset):
  json_file (str): Path to the JSON file containing annotations.

  Methods:
- get_img_files: Returns empty list as image files are read in get_labels.
- get_labels: Loads annotations from a JSON file and prepares them for training.
- build_transforms: Configures augmentations for training with optional text loading.
+ get_img_files: Return empty list as image files are read in get_labels.
+ get_labels: Load annotations from a JSON file and prepare them for training.
+ build_transforms: Configure augmentations for training with optional text loading.

  Examples:
  >>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect")
  >>> len(dataset) # Number of valid images with annotations
  """

- def __init__(self, *args, task="detect", json_file="", **kwargs):
+ def __init__(self, *args, task: str = "detect", json_file: str = "", **kwargs):
  """
  Initialize a GroundingDataset for object detection.

@@ -447,7 +448,7 @@ class GroundingDataset(YOLODataset):
  self.json_file = json_file
  super().__init__(*args, task=task, data={"channels": 3}, **kwargs)

- def get_img_files(self, img_path):
+ def get_img_files(self, img_path: str) -> List:
  """
  The image files would be read in `get_labels` function, return empty list here.

@@ -459,7 +460,7 @@ class GroundingDataset(YOLODataset):
  """
  return []

- def verify_labels(self, labels):
+ def verify_labels(self, labels: List[Dict]) -> None:
  """Verify the number of instances in the dataset matches expected counts."""
  instance_count = sum(label["bboxes"].shape[0] for label in labels)
  if "final_mixed_train_no_coco_segm" in self.json_file:
@@ -473,9 +474,9 @@ class GroundingDataset(YOLODataset):
  else:
  assert False

- def cache_labels(self, path=Path("./labels.cache")):
+ def cache_labels(self, path: Path = Path("./labels.cache")) -> Dict:
  """
- Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.
+ Load annotations from a JSON file, filter, and normalize bounding boxes for each image.

  Args:
  path (Path): Path where to save the cache file.
@@ -564,7 +565,7 @@ class GroundingDataset(YOLODataset):
  save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
  return x

- def get_labels(self):
+ def get_labels(self) -> List[Dict]:
  """
  Load labels from cache or generate them from JSON file.

@@ -586,9 +587,9 @@ class GroundingDataset(YOLODataset):
  LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
  return labels

- def build_transforms(self, hyp=None):
+ def build_transforms(self, hyp: Optional[Dict] = None) -> Compose:
  """
- Configures augmentations for training with optional text loading.
+ Configure augmentations for training with optional text loading.

  Args:
  hyp (dict, optional): Hyperparameters for transforms.
@@ -627,7 +628,7 @@ class GroundingDataset(YOLODataset):
  return category_freq

  @staticmethod
- def _get_neg_texts(category_freq, threshold=100):
+ def _get_neg_texts(category_freq: Dict, threshold: int = 100) -> List[str]:
  """Get negative text samples based on frequency threshold."""
  return [k for k, v in category_freq.items() if v >= threshold]

@@ -649,9 +650,9 @@ class YOLOConcatDataset(ConcatDataset):
  """

  @staticmethod
- def collate_fn(batch):
+ def collate_fn(batch: List[Dict]) -> Dict:
  """
- Collates data samples into batches.
+ Collate data samples into batches.

  Args:
  batch (List[dict]): List of dictionaries containing sample data.
@@ -661,9 +662,9 @@ class YOLOConcatDataset(ConcatDataset):
  """
  return YOLODataset.collate_fn(batch)

- def close_mosaic(self, hyp):
+ def close_mosaic(self, hyp: Dict) -> None:
  """
- Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
+ Set mosaic, copy_paste and mixup options to 0.0 and build transformations.

  Args:
  hyp (dict): Hyperparameters for transforms.
@@ -685,7 +686,7 @@ class SemanticDataset(BaseDataset):

  class ClassificationDataset:
  """
- Extends torchvision ImageFolder to support YOLO classification tasks.
+ Dataset class for image classification tasks extending torchvision ImageFolder functionality.

  This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
  handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
  handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
@@ -701,14 +702,14 @@ class ClassificationDataset:
  prefix (str): Prefix for logging and cache filenames.

  Methods:
- __getitem__: Returns subset of data and targets corresponding to given indices.
- __len__: Returns the total number of samples in the dataset.
- verify_images: Verifies all images in dataset.
+ __getitem__: Return subset of data and targets corresponding to given indices.
+ __len__: Return the total number of samples in the dataset.
+ verify_images: Verify all images in dataset.
  """

- def __init__(self, root, args, augment=False, prefix=""):
+ def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
  """
- Initialize YOLO object with root, image size, augmentations, and cache settings.
+ Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.

  Args:
  root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
@@ -758,9 +759,9 @@ class ClassificationDataset:
  else classify_transforms(size=args.imgsz)
  )

- def __getitem__(self, i):
+ def __getitem__(self, i: int) -> Dict:
  """
- Returns subset of data and targets corresponding to given indices.
+ Return subset of data and targets corresponding to given indices.

  Args:
  i (int): Index of the sample to retrieve.
@@ -787,7 +788,7 @@ class ClassificationDataset:
  """Return the total number of samples in the dataset."""
  return len(self.samples)

- def verify_images(self):
+ def verify_images(self) -> List[Tuple]:
  """
  Verify all images in dataset.
 
ultralytics/data/loaders.py CHANGED
@@ -8,6 +8,7 @@ import urllib
  from dataclasses import dataclass
  from pathlib import Path
  from threading import Thread
+ from typing import Any, List, Optional, Tuple, Union

  import cv2
  import numpy as np
@@ -90,8 +91,16 @@ class LoadStreams:
  - The class implements a buffer system to manage frame storage and retrieval.
  """

- def __init__(self, sources="file.streams", vid_stride=1, buffer=False, channels=3):
- """Initialize stream loader for multiple video sources, supporting various stream types."""
+ def __init__(self, sources: str = "file.streams", vid_stride: int = 1, buffer: bool = False, channels: int = 3):
+ """
+ Initialize stream loader for multiple video sources, supporting various stream types.
+
+ Args:
+ sources (str): Path to streams file or single stream URL.
+ vid_stride (int): Video frame-rate stride.
+ buffer (bool): Whether to buffer input streams.
+ channels (int): Number of image channels (1 for grayscale, 3 for RGB).
+ """
  torch.backends.cudnn.benchmark = True # faster for fixed-size inference
  self.buffer = buffer # buffer input streams
  self.running = True # running flag for Thread
@@ -143,7 +152,7 @@ class LoadStreams:
  self.threads[i].start()
  LOGGER.info("") # newline

- def update(self, i, cap, stream):
+ def update(self, i: int, cap: cv2.VideoCapture, stream: str):
  """Read stream frames in daemon thread and update image buffer."""
  n, f = 0, self.frames[i] # frame number, frame array
  while self.running and cap.isOpened() and n < (f - 1):
@@ -167,7 +176,7 @@ class LoadStreams:
  time.sleep(0.01) # wait until the buffer is empty

  def close(self):
- """Terminates stream loader, stops threads, and releases video capture resources."""
+ """Terminate stream loader, stop threads, and release video capture resources."""
  self.running = False # stop flag for Thread
  for thread in self.threads:
  if thread.is_alive():
@@ -180,12 +189,12 @@ class LoadStreams:
  cv2.destroyAllWindows()

  def __iter__(self):
- """Iterates through YOLO image feed and re-opens unresponsive streams."""
+ """Iterate through YOLO image feed and re-open unresponsive streams."""
  self.count = -1
  return self

- def __next__(self):
- """Returns the next batch of frames from multiple video streams for processing."""
+ def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:
+ """Return the next batch of frames from multiple video streams for processing."""
  self.count += 1

  images = []
@@ -211,7 +220,7 @@ class LoadStreams:

  return self.sources, images, [""] * self.bs

- def __len__(self):
+ def __len__(self) -> int:
  """Return the number of video streams in the LoadStreams object."""
  return self.bs # 1E12 frames = 32 streams at 30 FPS for 30 years

@@ -248,8 +257,14 @@ class LoadScreenshots:
  ... print(f"Captured frame: {im.shape}")
  """

- def __init__(self, source, channels=3):
- """Initialize screenshot capture with specified screen and region parameters."""
+ def __init__(self, source: str, channels: int = 3):
+ """
+ Initialize screenshot capture with specified screen and region parameters.
+
+ Args:
+ source (str): Screen capture source string in format "screen_num left top width height".
+ channels (int): Number of image channels (1 for grayscale, 3 for RGB).
+ """
  check_requirements("mss")
  import mss # noqa

@@ -277,11 +292,11 @@ class LoadScreenshots:
  self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

  def __iter__(self):
- """Yields the next screenshot image from the specified screen or region for processing."""
+ """Yield the next screenshot image from the specified screen or region for processing."""
  return self

- def __next__(self):
- """Captures and returns the next screenshot as a numpy array using the mss library."""
+ def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:
+ """Capture and return the next screenshot as a numpy array using the mss library."""
  im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR
  im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2GRAY)[..., None] if self.cv2_flag == cv2.IMREAD_GRAYSCALE else im0
  s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
@@ -330,8 +345,16 @@ class LoadImagesAndVideos:
  - Can read from a text file containing paths to images and videos.
  """

- def __init__(self, path, batch=1, vid_stride=1, channels=3):
- """Initialize dataloader for images and videos, supporting various input formats."""
+ def __init__(self, path: Union[str, Path, List], batch: int = 1, vid_stride: int = 1, channels: int = 3):
+ """
+ Initialize dataloader for images and videos, supporting various input formats.
+
+ Args:
+ path (str | Path | List): Path to images/videos, directory, or list of paths.
+ batch (int): Batch size for processing.
+ vid_stride (int): Video frame-rate stride.
+ channels (int): Number of image channels (1 for grayscale, 3 for RGB).
+ """
  parent = None
  if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
  parent = Path(path).parent
@@ -376,12 +399,12 @@ class LoadImagesAndVideos:
  raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")

  def __iter__(self):
- """Iterates through image/video files, yielding source paths, images, and metadata."""
+ """Iterate through image/video files, yielding source paths, images, and metadata."""
  self.count = 0
  return self

- def __next__(self):
- """Returns the next batch of images or video frames with their paths and metadata."""
+ def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:
+ """Return the next batch of images or video frames with their paths and metadata."""
  paths, imgs, info = [], [], []
  while len(imgs) < self.bs:
  if self.count >= self.nf: # end of file list
@@ -450,8 +473,8 @@ class LoadImagesAndVideos:

  return paths, imgs, info

- def _new_video(self, path):
- """Creates a new video capture object for the given path and initializes video-related attributes."""
+ def _new_video(self, path: str):
+ """Create a new video capture object for the given path and initialize video-related attributes."""
  self.frame = 0
  self.cap = cv2.VideoCapture(path)
  self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
@@ -459,8 +482,8 @@ class LoadImagesAndVideos:
  raise FileNotFoundError(f"Failed to open video {path}")
  self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

- def __len__(self):
- """Returns the number of files (images and videos) in the dataset."""
+ def __len__(self) -> int:
+ """Return the number of files (images and videos) in the dataset."""
  return math.ceil(self.nf / self.bs) # number of batches

@@ -491,8 +514,14 @@ class LoadPilAndNumpy:
  Loaded 2 images
  """

- def __init__(self, im0, channels=3):
- """Initializes a loader for PIL and Numpy images, converting inputs to a standardized format."""
+ def __init__(self, im0: Union[Image.Image, np.ndarray, List], channels: int = 3):
+ """
+ Initialize a loader for PIL and Numpy images, converting inputs to a standardized format.
+
+ Args:
+ im0 (PIL.Image.Image | np.ndarray | List): Single image or list of images in PIL or numpy format.
+ channels (int): Number of image channels (1 for grayscale, 3 for RGB).
+ """
  if not isinstance(im0, list):
  im0 = [im0]
  # use `image{i}.jpg` when Image.filename returns an empty path.
@@ -503,7 +532,7 @@ class LoadPilAndNumpy:
  self.bs = len(self.im0)

  @staticmethod
- def _single_check(im, flag="RGB"):
+ def _single_check(im: Union[Image.Image, np.ndarray], flag: str = "RGB") -> np.ndarray:
  """Validate and format an image to numpy array, ensuring RGB order and contiguous memory."""
  assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
  if isinstance(im, Image.Image):
@@ -515,19 +544,19 @@ class LoadPilAndNumpy:
  im = im[..., None]
  return im

- def __len__(self):
- """Returns the length of the 'im0' attribute, representing the number of loaded images."""
+ def __len__(self) -> int:
+ """Return the length of the 'im0' attribute, representing the number of loaded images."""
  return len(self.im0)

- def __next__(self):
- """Returns the next batch of images, paths, and metadata for processing."""
+ def __next__(self) -> Tuple[List[str], List[np.ndarray], List[str]]:
+ """Return the next batch of images, paths, and metadata for processing."""
  if self.count == 1: # loop only once as it's batch inference
  raise StopIteration
  self.count += 1
  return self.paths, self.im0, [""] * self.bs

  def __iter__(self):
- """Iterates through PIL/numpy images, yielding paths, raw images, and metadata for processing."""
+ """Iterate through PIL/numpy images, yielding paths, raw images, and metadata for processing."""
  self.count = 0
  return self

@@ -556,16 +585,21 @@ class LoadTensor:
  >>> print(f"Processed {len(images)} images")
  """

- def __init__(self, im0) -> None:
- """Initialize LoadTensor object for processing torch.Tensor image data."""
+ def __init__(self, im0: torch.Tensor) -> None:
+ """
+ Initialize LoadTensor object for processing torch.Tensor image data.
+
+ Args:
+ im0 (torch.Tensor): Input tensor with shape (B, C, H, W).
+ """
  self.im0 = self._single_check(im0)
  self.bs = self.im0.shape[0]
  self.mode = "image"
  self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]

  @staticmethod
- def _single_check(im, stride=32):
- """Validates and formats a single image tensor, ensuring correct shape and normalization."""
+ def _single_check(im: torch.Tensor, stride: int = 32) -> torch.Tensor:
+ """Validate and format a single image tensor, ensuring correct shape and normalization."""
  s = (
  f"torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
  f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
@@ -586,24 +620,24 @@ class LoadTensor:
  return im

  def __iter__(self):
- """Yields an iterator object for iterating through tensor image data."""
+ """Yield an iterator object for iterating through tensor image data."""
  self.count = 0
  return self

- def __next__(self):
- """Yields the next batch of tensor images and metadata for processing."""
+ def __next__(self) -> Tuple[List[str], torch.Tensor, List[str]]:
+ """Yield the next batch of tensor images and metadata for processing."""
  if self.count == 1:
  raise StopIteration
  self.count += 1
  return self.paths, self.im0, [""] * self.bs

- def __len__(self):
- """Returns the batch size of the tensor input."""
+ def __len__(self) -> int:
+ """Return the batch size of the tensor input."""
  return self.bs


- def autocast_list(source):
- """Merges a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction."""
+ def autocast_list(source: List[Any]) -> List[Union[Image.Image, np.ndarray]]:
+ """Merge a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction."""
  files = []
  for im in source:
  if isinstance(im, (str, Path)): # filename or uri
@@ -619,14 +653,13 @@ def autocast_list(source):
  return files


- def get_best_youtube_url(url, method="pytube"):
+ def get_best_youtube_url(url: str, method: str = "pytube") -> Optional[str]:
  """
- Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
+ Retrieve the URL of the best quality MP4 video stream from a given YouTube video.

  Args:
  url (str): The URL of the YouTube video.
  method (str): The method to use for extracting video info. Options are "pytube", "pafy", and "yt-dlp".
- Defaults to "pytube".

  Returns:
  (str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
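The `loaders.py` hunks follow the same pattern: every loader (`LoadStreams`, `LoadScreenshots`, `LoadImagesAndVideos`, `LoadPilAndNumpy`) now spells out the shared iteration contract in its `__next__` annotation, a tuple of source paths, images, and per-source info strings. A minimal sketch of consuming that contract (the media path is a placeholder, not from this diff):

```python
# Hypothetical sketch: iterating a loader under the annotated contract
# __next__() -> Tuple[List[str], List[np.ndarray], List[str]].
from ultralytics.data.loaders import LoadImagesAndVideos

loader = LoadImagesAndVideos("path/to/media", batch=1, vid_stride=1, channels=3)
for paths, images, info in loader:
    # paths: source file paths; images: list of numpy arrays; info: log strings
    print(paths[0], images[0].shape, info[0])
```

`LoadTensor.__next__` is the one exception, annotated as `Tuple[List[str], torch.Tensor, List[str]]` since it returns the batched tensor itself rather than a list of arrays.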
ultralytics/data/split.py CHANGED
@@ -3,14 +3,15 @@
  import random
  import shutil
  from pathlib import Path
+ from typing import Tuple, Union

  from ultralytics.data.utils import IMG_FORMATS, img2label_paths
  from ultralytics.utils import DATASETS_DIR, LOGGER, TQDM


- def split_classify_dataset(source_dir, train_ratio=0.8):
+ def split_classify_dataset(source_dir: Union[str, Path], train_ratio: float = 0.8) -> Path:
  """
- Split dataset into train and val directories in a new directory.
+ Split classification dataset into train and val directories in a new directory.

  Creates a new directory '{source_dir}_split' with train/val subdirectories, preserving the original class
  structure with an 80/20 split by default.
@@ -46,13 +47,17 @@ def split_classify_dataset(source_dir, train_ratio=0.8):
  └── ...

  Args:
- source_dir (str | Path): Path to Caltech dataset root directory.
+ source_dir (str | Path): Path to classification dataset root directory.
  train_ratio (float): Ratio for train split, between 0 and 1.

+ Returns:
+ (Path): Path to the created split directory.
+
  Examples:
- >>> # Split dataset with default 80/20 ratio
+ Split dataset with default 80/20 ratio
  >>> split_classify_dataset("path/to/caltech")
- >>> # Split with custom ratio
+
+ Split with custom ratio
  >>> split_classify_dataset("path/to/caltech", 0.75)
  """
  source_path = Path(source_dir)
@@ -90,18 +95,26 @@ def split_classify_dataset(source_dir, train_ratio=0.8):
  return split_path


- def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
+ def autosplit(
+ path: Path = DATASETS_DIR / "coco8/images",
+ weights: Tuple[float, float, float] = (0.9, 0.1, 0.0),
+ annotated_only: bool = False,
+ ) -> None:
  """
  Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

  Args:
- path (Path, optional): Path to images directory.
- weights (list | tuple, optional): Train, validation, and test split fractions.
- annotated_only (bool, optional): If True, only images with an associated txt file are used.
+ path (Path): Path to images directory.
+ weights (tuple): Train, validation, and test split fractions.
+ annotated_only (bool): If True, only images with an associated txt file are used.

  Examples:
+ Split images with default weights
  >>> from ultralytics.data.split import autosplit
  >>> autosplit()
+
+ Split with custom weights and annotated images only
+ >>> autosplit(path="path/to/images", weights=(0.8, 0.15, 0.05), annotated_only=True)
  """
  path = Path(path) # images dir
  files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only
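The `split.py` changes add typed signatures and a documented `Returns` value for `split_classify_dataset`. Combining the two documented examples from the hunks above into one runnable sketch (paths are placeholders):

```python
# Sketch mirroring the docstring examples in this diff; paths are placeholders.
from pathlib import Path
from ultralytics.data.split import autosplit, split_classify_dataset

# Split a classification folder tree 75/25; now annotated to return the new
# "{source_dir}_split" directory as a Path.
split_dir: Path = split_classify_dataset("path/to/caltech", train_ratio=0.75)

# Independently, write autosplit_*.txt train/val/test lists for a detection
# image directory, keeping only images that have a matching label .txt file.
autosplit(path=Path("path/to/images"), weights=(0.8, 0.15, 0.05), annotated_only=True)
```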