dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/data/base.py CHANGED
@@ -21,11 +21,10 @@ from ultralytics.utils.patches import imread
 
 
 class BaseDataset(Dataset):
-    """
-    Base dataset class for loading and processing image data.
+    """Base dataset class for loading and processing image data.
 
-    This class provides core functionality for loading images, caching, and preparing data for training and inference
-    in object detection tasks.
+    This class provides core functionality for loading images, caching, and preparing data for training and inference in
+    object detection tasks.
 
     Attributes:
         img_path (str): Path to the folder containing images.
@@ -34,7 +33,8 @@ class BaseDataset(Dataset):
         single_cls (bool): Whether to treat all objects as a single class.
         prefix (str): Prefix to print in log messages.
         fraction (float): Fraction of dataset to utilize.
-        channels (int): Number of channels in the images (1 for grayscale, 3 for RGB).
+        channels (int): Number of channels in the images (1 for grayscale, 3 for color). Color images loaded with OpenCV
+            are in BGR channel order.
         cv2_flag (int): OpenCV flag for reading images.
         im_files (list[str]): List of image file paths.
         labels (list[dict]): List of label data dictionaries.
@@ -86,8 +86,7 @@ class BaseDataset(Dataset):
         fraction: float = 1.0,
         channels: int = 3,
     ):
-        """
-        Initialize BaseDataset with given configuration and options.
+        """Initialize BaseDataset with given configuration and options.
 
         Args:
             img_path (str | list[str]): Path to the folder containing images or list of image paths.
@@ -103,7 +102,8 @@ class BaseDataset(Dataset):
             single_cls (bool): If True, single class training is used.
             classes (list[int], optional): List of included classes.
             fraction (float): Fraction of dataset to utilize.
-            channels (int): Number of channels in the images (1 for grayscale, 3 for RGB).
+            channels (int): Number of channels in the images (1 for grayscale, 3 for color). Color images loaded with
+                OpenCV are in BGR channel order.
         """
         super().__init__()
         self.img_path = img_path
@@ -148,8 +148,7 @@ class BaseDataset(Dataset):
         self.transforms = self.build_transforms(hyp=hyp)
 
     def get_img_files(self, img_path: str | list[str]) -> list[str]:
-        """
-        Read image files from the specified path.
+        """Read image files from the specified path.
 
         Args:
             img_path (str | list[str]): Path or list of paths to image directories or files.
@@ -186,8 +185,7 @@ class BaseDataset(Dataset):
         return im_files
 
     def update_labels(self, include_class: list[int] | None) -> None:
-        """
-        Update labels to include only specified classes.
+        """Update labels to include only specified classes.
 
         Args:
             include_class (list[int], optional): List of classes to include. If None, all classes are included.
@@ -210,8 +208,7 @@ class BaseDataset(Dataset):
                 self.labels[i]["cls"][:, 0] = 0
 
     def load_image(self, i: int, rect_mode: bool = True) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
-        """
-        Load an image from dataset index 'i'.
+        """Load an image from dataset index 'i'.
 
         Args:
             i (int): Index of the image to load.
@@ -286,8 +283,7 @@ class BaseDataset(Dataset):
             np.save(f.as_posix(), imread(self.im_files[i]), allow_pickle=False)
 
     def check_cache_disk(self, safety_margin: float = 0.5) -> bool:
-        """
-        Check if there's enough disk space for caching images.
+        """Check if there's enough disk space for caching images.
 
         Args:
             safety_margin (float): Safety margin factor for disk space calculation.
@@ -307,10 +303,10 @@ class BaseDataset(Dataset):
             b += im.nbytes
             if not os.access(Path(im_file).parent, os.W_OK):
                 self.cache = None
-                LOGGER.warning(f"{self.prefix}Skipping caching images to disk, directory not writeable")
+                LOGGER.warning(f"{self.prefix}Skipping caching images to disk, directory not writable")
                 return False
         disk_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset to disk
-        total, used, free = shutil.disk_usage(Path(self.im_files[0]).parent)
+        total, _used, free = shutil.disk_usage(Path(self.im_files[0]).parent)
         if disk_required > free:
             self.cache = None
             LOGGER.warning(
@@ -322,8 +318,7 @@ class BaseDataset(Dataset):
         return True
 
     def check_cache_ram(self, safety_margin: float = 0.5) -> bool:
-        """
-        Check if there's enough RAM for caching images.
+        """Check if there's enough RAM for caching images.
 
         Args:
             safety_margin (float): Safety margin factor for RAM calculation.
@@ -381,8 +376,7 @@ class BaseDataset(Dataset):
         return self.transforms(self.get_image_and_label(index))
 
     def get_image_and_label(self, index: int) -> dict[str, Any]:
-        """
-        Get and return label information from the dataset.
+        """Get and return label information from the dataset.
 
         Args:
             index (int): Index of the image to retrieve.
@@ -410,8 +404,7 @@ class BaseDataset(Dataset):
         return label
 
     def build_transforms(self, hyp: dict[str, Any] | None = None):
-        """
-        Users can customize augmentations here.
+        """Users can customize augmentations here.
 
         Examples:
             >>> if self.augment:
@@ -424,8 +417,7 @@ class BaseDataset(Dataset):
         raise NotImplementedError
 
     def get_labels(self) -> list[dict[str, Any]]:
-        """
-        Users can customize their own format here.
+        """Users can customize their own format here.
 
         Examples:
            Ensure output is a dictionary with the following keys:
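The updated `channels` documentation above points out that color images loaded with OpenCV arrive in BGR channel order, not RGB. A minimal illustrative sketch (not part of the package; the file path is hypothetical) of converting such an image for an RGB-expecting consumer:

    import cv2

    im_bgr = cv2.imread("example.jpg")  # hypothetical path; OpenCV returns an HWC uint8 array in BGR order
    im_rgb = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)  # convert only when a downstream consumer expects RGB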
ultralytics/data/build.py CHANGED
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import math
 import os
 import random
 from collections.abc import Iterator
@@ -11,8 +12,9 @@ from urllib.parse import urlsplit
 
 import numpy as np
 import torch
+import torch.distributed as dist
 from PIL import Image
-from torch.utils.data import dataloader, distributed
+from torch.utils.data import Dataset, dataloader, distributed
 
 from ultralytics.cfg import IterableSimpleNamespace
 from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
@@ -33,8 +35,7 @@ from ultralytics.utils.torch_utils import TORCH_2_0
 
 
 class InfiniteDataLoader(dataloader.DataLoader):
-    """
-    Dataloader that reuses workers for infinite iteration.
+    """DataLoader that reuses workers for infinite iteration.
 
     This dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency
     for training loops that need to iterate through the dataset multiple times without recreating workers.
@@ -50,7 +51,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
         reset: Reset the iterator, useful when modifying dataset settings during training.
 
     Examples:
-        Create an infinite dataloader for training
+        Create an infinite DataLoader for training
        >>> dataset = YOLODataset(...)
        >>> dataloader = InfiniteDataLoader(dataset, batch_size=16, shuffle=True)
        >>> for batch in dataloader:  # Infinite iteration
@@ -75,7 +76,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
             yield next(self.iterator)
 
     def __del__(self):
-        """Ensure that workers are properly terminated when the dataloader is deleted."""
+        """Ensure that workers are properly terminated when the DataLoader is deleted."""
         try:
             if not hasattr(self.iterator, "_workers"):
                 return
@@ -92,11 +93,10 @@ class InfiniteDataLoader(dataloader.DataLoader):
 
 
 class _RepeatSampler:
-    """
-    Sampler that repeats forever for infinite iteration.
+    """Sampler that repeats forever for infinite iteration.
 
-    This sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration
-    over a dataset without recreating the sampler.
+    This sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration over a
+    dataset without recreating the sampler.
 
     Attributes:
         sampler (Dataset.sampler): The sampler to repeat.
@@ -112,7 +112,109 @@ class _RepeatSampler:
             yield from iter(self.sampler)
 
 
-def seed_worker(worker_id: int):  # noqa
+class ContiguousDistributedSampler(torch.utils.data.Sampler):
+    """Distributed sampler that assigns contiguous batch-aligned chunks of the dataset to each GPU.
+
+    Unlike PyTorch's DistributedSampler which distributes samples in a round-robin fashion (GPU 0 gets indices
+    [0,2,4,...], GPU 1 gets [1,3,5,...]), this sampler gives each GPU contiguous batches of the dataset (GPU 0 gets
+    batches [0,1,2,...], GPU 1 gets batches [k,k+1,...], etc.). This preserves any ordering or grouping in the original
+    dataset, which is critical when samples are organized by similarity (e.g., images sorted by size to enable efficient
+    batching without padding when using rect=True).
+
+    The sampler handles uneven batch counts by distributing remainder batches to the first few ranks, ensuring all
+    samples are covered exactly once across all GPUs.
+
+    Args:
+        dataset (Dataset): Dataset to sample from. Must implement __len__.
+        num_replicas (int, optional): Number of distributed processes. Defaults to world size.
+        batch_size (int, optional): Batch size used by dataloader. Defaults to dataset batch size.
+        rank (int, optional): Rank of current process. Defaults to current rank.
+        shuffle (bool, optional): Whether to shuffle indices within each rank's chunk. Defaults to False. When True,
+            shuffling is deterministic and controlled by set_epoch() for reproducibility.
+
+    Examples:
+        >>> # For validation with size-grouped images
+        >>> sampler = ContiguousDistributedSampler(val_dataset, batch_size=32, shuffle=False)
+        >>> loader = DataLoader(val_dataset, batch_size=32, sampler=sampler)
+        >>> # For training with shuffling
+        >>> sampler = ContiguousDistributedSampler(train_dataset, batch_size=32, shuffle=True)
+        >>> for epoch in range(num_epochs):
+        ...     sampler.set_epoch(epoch)
+        ...     for batch in loader:
+        ...         ...
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: int | None = None,
+        batch_size: int | None = None,
+        rank: int | None = None,
+        shuffle: bool = False,
+    ) -> None:
+        """Initialize the sampler with dataset and distributed training parameters."""
+        if num_replicas is None:
+            num_replicas = dist.get_world_size() if dist.is_initialized() else 1
+        if rank is None:
+            rank = dist.get_rank() if dist.is_initialized() else 0
+        if batch_size is None:
+            batch_size = getattr(dataset, "batch_size", 1)
+
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.shuffle = shuffle
+        self.total_size = len(dataset)
+        # ensure all ranks have a sample if batch size >= total size; degenerates to round-robin sampler
+        self.batch_size = 1 if batch_size >= self.total_size else batch_size
+        self.num_batches = math.ceil(self.total_size / self.batch_size)
+
+    def _get_rank_indices(self) -> tuple[int, int]:
+        """Calculate the start and end sample indices for this rank."""
+        # Calculate which batches this rank handles
+        batches_per_rank_base = self.num_batches // self.num_replicas
+        remainder = self.num_batches % self.num_replicas
+
+        # This rank gets an extra batch if rank < remainder
+        batches_for_this_rank = batches_per_rank_base + (1 if self.rank < remainder else 0)
+
+        # Calculate starting batch: base position + number of extra batches given to earlier ranks
+        start_batch = self.rank * batches_per_rank_base + min(self.rank, remainder)
+        end_batch = start_batch + batches_for_this_rank
+
+        # Convert batch indices to sample indices
+        start_idx = start_batch * self.batch_size
+        end_idx = min(end_batch * self.batch_size, self.total_size)
+
+        return start_idx, end_idx
+
+    def __iter__(self) -> Iterator:
+        """Generate indices for this rank's contiguous chunk of the dataset."""
+        start_idx, end_idx = self._get_rank_indices()
+        indices = list(range(start_idx, end_idx))
+
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = [indices[i] for i in torch.randperm(len(indices), generator=g).tolist()]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        """Return the number of samples in this rank's chunk."""
+        start_idx, end_idx = self._get_rank_indices()
+        return end_idx - start_idx
+
+    def set_epoch(self, epoch: int) -> None:
+        """Set the epoch for this sampler to ensure different shuffling patterns across epochs.
+
+        Args:
+            epoch (int): Epoch number to use as the random seed for shuffling.
+        """
+        self.epoch = epoch
+
+
+def seed_worker(worker_id: int) -> None:
     """Set dataloader worker seed for reproducibility across worker processes."""
     worker_seed = torch.initial_seed() % 2**32
     np.random.seed(worker_seed)
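The class docstring above describes the batch-aligned split in prose; the following standalone sketch mirrors the `_get_rank_indices` arithmetic on a toy configuration so the remainder handling can be checked by hand (illustrative only; pure Python, no torch or ultralytics imports):

    import math

    def contiguous_chunk(total_size: int, batch_size: int, rank: int, num_replicas: int) -> range:
        # Mirrors ContiguousDistributedSampler._get_rank_indices: split batches, not samples, across ranks.
        num_batches = math.ceil(total_size / batch_size)
        base, remainder = divmod(num_batches, num_replicas)
        start_batch = rank * base + min(rank, remainder)  # earlier ranks absorb the remainder batches
        end_batch = start_batch + base + (1 if rank < remainder else 0)
        return range(start_batch * batch_size, min(end_batch * batch_size, total_size))

    # 10 samples, batch size 4 -> 3 batches over 2 ranks: rank 0 gets batches 0-1, rank 1 gets batch 2.
    assert list(contiguous_chunk(10, 4, rank=0, num_replicas=2)) == [0, 1, 2, 3, 4, 5, 6, 7]
    assert list(contiguous_chunk(10, 4, rank=1, num_replicas=2)) == [8, 9]

Contrast this with DistributedSampler, which would interleave indices 0, 2, 4, ... and 1, 3, 5, ... across the two ranks and break any size-based grouping in the dataset.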
@@ -128,7 +230,7 @@ def build_yolo_dataset(
     rect: bool = False,
     stride: int = 32,
     multi_modal: bool = False,
-):
+) -> Dataset:
     """Build and return a YOLO dataset based on configuration parameters."""
     dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
     return dataset(
@@ -159,7 +261,7 @@ def build_grounding(
     rect: bool = False,
     stride: int = 32,
     max_samples: int = 80,
-):
+) -> Dataset:
     """Build and return a GroundingDataset based on configuration parameters."""
     return GroundingDataset(
         img_path=img_path,
@@ -181,9 +283,16 @@ def build_grounding(
     )
 
 
-def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):
-    """
-    Create and return an InfiniteDataLoader or DataLoader for training or validation.
+def build_dataloader(
+    dataset,
+    batch: int,
+    workers: int,
+    shuffle: bool = True,
+    rank: int = -1,
+    drop_last: bool = False,
+    pin_memory: bool = True,
+) -> InfiniteDataLoader:
+    """Create and return an InfiniteDataLoader or DataLoader for training or validation.
 
     Args:
         dataset (Dataset): Dataset to load data from.
@@ -192,6 +301,7 @@ def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):
         shuffle (bool, optional): Whether to shuffle the dataset.
         rank (int, optional): Process rank in distributed training. -1 for single-GPU training.
         drop_last (bool, optional): Whether to drop the last incomplete batch.
+        pin_memory (bool, optional): Whether to use pinned memory for dataloader.
 
     Returns:
         (InfiniteDataLoader): A dataloader that can be used for training or validation.
@@ -204,7 +314,13 @@ def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):
     batch = min(batch, len(dataset))
     nd = torch.cuda.device_count()  # number of CUDA devices
     nw = min(os.cpu_count() // max(nd, 1), workers)  # number of workers
-    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    sampler = (
+        None
+        if rank == -1
+        else distributed.DistributedSampler(dataset, shuffle=shuffle)
+        if shuffle
+        else ContiguousDistributedSampler(dataset)
+    )
     generator = torch.Generator()
     generator.manual_seed(6148914691236517205 + RANK)
     return InfiniteDataLoader(
@@ -214,7 +330,7 @@ def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):
         num_workers=nw,
         sampler=sampler,
         prefetch_factor=4 if nw > 0 else None,  # increase over default 2
-        pin_memory=nd > 0,
+        pin_memory=nd > 0 and pin_memory,
         collate_fn=getattr(dataset, "collate_fn", None),
         worker_init_fn=seed_worker,
         generator=generator,
@@ -222,9 +338,10 @@ def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):
     )
 
 
-def check_source(source):
-    """
-    Check the type of input source and return corresponding flag values.
+def check_source(
+    source: str | int | Path | list | tuple | np.ndarray | Image.Image | torch.Tensor,
+) -> tuple[Any, bool, bool, bool, bool, bool]:
+    """Check the type of input source and return corresponding flag values.
 
     Args:
         source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.
@@ -271,12 +388,17 @@ def check_source(source):
     return source, webcam, screenshot, from_img, in_memory, tensor
 
 
-def load_inference_source(source=None, batch: int = 1, vid_stride: int = 1, buffer: bool = False, channels: int = 3):
-    """
-    Load an inference source for object detection and apply necessary transformations.
+def load_inference_source(
+    source: str | int | Path | list | tuple | np.ndarray | Image.Image | torch.Tensor,
+    batch: int = 1,
+    vid_stride: int = 1,
+    buffer: bool = False,
+    channels: int = 3,
+):
+    """Load an inference source for object detection and apply necessary transformations.
 
     Args:
-        source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.
+        source (str | Path | list | tuple | torch.Tensor | PIL.Image | np.ndarray): The input source for inference.
         batch (int, optional): Batch size for dataloaders.
         vid_stride (int, optional): The frame interval for video sources.
         buffer (bool, optional): Whether stream frames will be buffered.
@@ -295,7 +417,7 @@ def load_inference_source(source=None, batch: int = 1, vid_stride: int = 1, buffer: bool = False, channels: int = 3):
     source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
     source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)
 
-    # Dataloader
+    # DataLoader
     if tensor:
         dataset = LoadTensor(source)
     elif in_memory:
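With these changes, the sampler that `build_dataloader` attaches now depends on both `rank` and `shuffle`: single-process runs get no sampler, DDP training with shuffling keeps `DistributedSampler`, and non-shuffled DDP validation gets the new `ContiguousDistributedSampler`. A brief usage sketch, assuming a toy stand-in dataset (a real call would pass a YOLODataset built via `build_yolo_dataset`):

    from torch.utils.data import Dataset

    from ultralytics.data.build import build_dataloader


    class ToyDataset(Dataset):
        # Stand-in for illustration only; real callers pass a YOLODataset with a collate_fn.
        def __len__(self):
            return 64

        def __getitem__(self, i):
            return i


    # Single-process run (rank=-1): no distributed sampler; memory is pinned only when CUDA is present.
    loader = build_dataloader(ToyDataset(), batch=16, workers=0, shuffle=True, rank=-1)

    # Under DDP, rank >= 0 with shuffle=False selects ContiguousDistributedSampler, and the new
    # pin_memory=False flag can be passed to disable pinned host memory if RAM is constrained.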