dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/data/build.py CHANGED
@@ -1,14 +1,22 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
5
+ import math
3
6
  import os
4
7
  import random
8
+ from collections.abc import Iterator
5
9
  from pathlib import Path
10
+ from typing import Any
11
+ from urllib.parse import urlsplit
6
12
 
7
13
  import numpy as np
8
14
  import torch
15
+ import torch.distributed as dist
9
16
  from PIL import Image
10
17
  from torch.utils.data import dataloader, distributed
11
18
 
19
+ from ultralytics.cfg import IterableSimpleNamespace
12
20
  from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
13
21
  from ultralytics.data.loaders import (
14
22
  LOADERS,
@@ -20,40 +28,49 @@ from ultralytics.data.loaders import (
20
28
  SourceTypes,
21
29
  autocast_list,
22
30
  )
23
- from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
31
+ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
24
32
  from ultralytics.utils import RANK, colorstr
25
33
  from ultralytics.utils.checks import check_file
34
+ from ultralytics.utils.torch_utils import TORCH_2_0
26
35
 
27
36
 
28
37
  class InfiniteDataLoader(dataloader.DataLoader):
29
- """
30
- Dataloader that reuses workers.
38
+ """Dataloader that reuses workers for infinite iteration.
31
39
 
32
40
  This dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency
33
- for training loops that need to iterate through the dataset multiple times.
41
+ for training loops that need to iterate through the dataset multiple times without recreating workers.
34
42
 
35
43
  Attributes:
36
44
  batch_sampler (_RepeatSampler): A sampler that repeats indefinitely.
37
45
  iterator (Iterator): The iterator from the parent DataLoader.
38
46
 
39
47
  Methods:
40
- __len__: Returns the length of the batch sampler's sampler.
41
- __iter__: Creates a sampler that repeats indefinitely.
42
- __del__: Ensures workers are properly terminated.
43
- reset: Resets the iterator, useful when modifying dataset settings during training.
48
+ __len__: Return the length of the batch sampler's sampler.
49
+ __iter__: Create a sampler that repeats indefinitely.
50
+ __del__: Ensure workers are properly terminated.
51
+ reset: Reset the iterator, useful when modifying dataset settings during training.
52
+
53
+ Examples:
54
+ Create an infinite dataloader for training
55
+ >>> dataset = YOLODataset(...)
56
+ >>> dataloader = InfiniteDataLoader(dataset, batch_size=16, shuffle=True)
57
+ >>> for batch in dataloader: # Infinite iteration
58
+ >>> train_step(batch)
44
59
  """
45
60
 
46
- def __init__(self, *args, **kwargs):
61
+ def __init__(self, *args: Any, **kwargs: Any):
47
62
  """Initialize the InfiniteDataLoader with the same arguments as DataLoader."""
63
+ if not TORCH_2_0:
64
+ kwargs.pop("prefetch_factor", None) # not supported by earlier versions
48
65
  super().__init__(*args, **kwargs)
49
66
  object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
50
67
  self.iterator = super().__iter__()
51
68
 
52
- def __len__(self):
69
+ def __len__(self) -> int:
53
70
  """Return the length of the batch sampler's sampler."""
54
71
  return len(self.batch_sampler.sampler)
55
72
 
56
- def __iter__(self):
73
+ def __iter__(self) -> Iterator:
57
74
  """Create an iterator that yields indefinitely from the underlying iterator."""
58
75
  for _ in range(len(self)):
59
76
  yield next(self.iterator)
@@ -76,34 +93,137 @@ class InfiniteDataLoader(dataloader.DataLoader):
76
93
 
77
94
 
78
95
class _RepeatSampler:
    """Sampler wrapper that cycles through its wrapped sampler forever.

    Wraps any sampler and replays its contents endlessly, so a dataloader can iterate an unbounded
    number of times without rebuilding the sampler between passes.

    Attributes:
        sampler (Dataset.sampler): The underlying sampler whose indices are replayed indefinitely.
    """

    def __init__(self, sampler: Any):
        """Store the sampler whose output will be repeated forever."""
        self.sampler = sampler

    def __iter__(self) -> Iterator:
        """Yield every element of the wrapped sampler, restarting from the top once exhausted."""
        while True:
            # Re-iterate the wrapped sampler from scratch on each pass
            for index in self.sampler:
                yield index
97
113
 
98
114
 
99
- def seed_worker(worker_id): # noqa
115
class ContiguousDistributedSampler(torch.utils.data.Sampler):
    """Distributed sampler assigning each rank a contiguous, batch-aligned slice of the dataset.

    In contrast to PyTorch's round-robin DistributedSampler (rank 0 gets indices [0, 2, 4, ...],
    rank 1 gets [1, 3, 5, ...]), each rank here receives whole consecutive batches (rank 0 gets
    batches [0, 1, 2, ...], rank 1 the next chunk, etc.). This keeps any ordering or grouping
    present in the dataset intact, which matters when samples are organized by similarity — e.g.
    images sorted by size for efficient padding-free batching with rect=True.

    When the batch count does not divide evenly across ranks, the leftover batches go to the
    lowest-numbered ranks, so every sample is covered exactly once across all GPUs.

    Args:
        dataset (torch.utils.data.Dataset): Dataset to sample from. Must implement __len__.
        num_replicas (int, optional): Number of distributed processes. Defaults to world size.
        batch_size (int, optional): Batch size used by dataloader. Defaults to dataset batch size.
        rank (int, optional): Rank of current process. Defaults to current rank.
        shuffle (bool, optional): Whether to shuffle indices within each rank's chunk. Defaults to
            False. When True, shuffling is deterministic and controlled by set_epoch() for
            reproducibility.

    Examples:
        >>> # For validation with size-grouped images
        >>> sampler = ContiguousDistributedSampler(val_dataset, batch_size=32, shuffle=False)
        >>> loader = DataLoader(val_dataset, batch_size=32, sampler=sampler)
        >>> # For training with shuffling
        >>> sampler = ContiguousDistributedSampler(train_dataset, batch_size=32, shuffle=True)
        >>> for epoch in range(num_epochs):
        ...     sampler.set_epoch(epoch)
        ...     for batch in loader:
        ...         ...
    """

    def __init__(self, dataset, num_replicas=None, batch_size=None, rank=None, shuffle=False):
        """Initialize the sampler with dataset and distributed training parameters."""
        # Fall back to the live process group (or single-process defaults) when not specified
        in_ddp = dist.is_initialized()
        self.dataset = dataset
        self.num_replicas = (dist.get_world_size() if in_ddp else 1) if num_replicas is None else num_replicas
        self.rank = (dist.get_rank() if in_ddp else 0) if rank is None else rank
        self.batch_size = getattr(dataset, "batch_size", 1) if batch_size is None else batch_size
        self.shuffle = shuffle
        self.epoch = 0
        self.total_size = len(dataset)
        self.num_batches = math.ceil(self.total_size / self.batch_size)

    def _get_rank_indices(self):
        """Compute the [start, end) sample-index window owned by this rank."""
        base, extra = divmod(self.num_batches, self.num_replicas)
        # Ranks below `extra` each absorb one leftover batch
        my_batches = base + (1 if self.rank < extra else 0)
        # First batch index for this rank: base allotment plus extras handed to lower ranks
        first_batch = self.rank * base + min(self.rank, extra)
        start = first_batch * self.batch_size
        end = min((first_batch + my_batches) * self.batch_size, self.total_size)
        return start, end

    def __iter__(self):
        """Yield this rank's contiguous indices, optionally shuffled deterministically per epoch."""
        start, end = self._get_rank_indices()
        chunk = list(range(start, end))
        if self.shuffle:
            rng = torch.Generator()
            rng.manual_seed(self.epoch)  # epoch-seeded for reproducible per-epoch permutations
            chunk = [chunk[i] for i in torch.randperm(len(chunk), generator=rng).tolist()]
        return iter(chunk)

    def __len__(self):
        """Return the number of samples assigned to this rank."""
        start, end = self._get_rank_indices()
        return end - start

    def set_epoch(self, epoch):
        """Set the epoch for this sampler to ensure different shuffling patterns across epochs.

        Args:
            epoch (int): Epoch number to use as the random seed for shuffling.
        """
        self.epoch = epoch
208
+
209
+
210
def seed_worker(worker_id: int):
    """Seed NumPy and Python RNGs in a dataloader worker for reproducible augmentation."""
    derived_seed = torch.initial_seed() % 2**32  # fold torch's 64-bit worker seed into 32 bits
    np.random.seed(derived_seed)
    random.seed(derived_seed)
104
215
 
105
216
 
106
- def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
217
+ def build_yolo_dataset(
218
+ cfg: IterableSimpleNamespace,
219
+ img_path: str,
220
+ batch: int,
221
+ data: dict[str, Any],
222
+ mode: str = "train",
223
+ rect: bool = False,
224
+ stride: int = 32,
225
+ multi_modal: bool = False,
226
+ ):
107
227
  """Build and return a YOLO dataset based on configuration parameters."""
108
228
  dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
109
229
  return dataset(
@@ -115,7 +235,7 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
115
235
  rect=cfg.rect or rect, # rectangular batches
116
236
  cache=cfg.cache or None,
117
237
  single_cls=cfg.single_cls or False,
118
- stride=int(stride),
238
+ stride=stride,
119
239
  pad=0.0 if mode == "train" else 0.5,
120
240
  prefix=colorstr(f"{mode}: "),
121
241
  task=cfg.task,
@@ -125,11 +245,21 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
125
245
  )
126
246
 
127
247
 
128
- def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
248
+ def build_grounding(
249
+ cfg: IterableSimpleNamespace,
250
+ img_path: str,
251
+ json_file: str,
252
+ batch: int,
253
+ mode: str = "train",
254
+ rect: bool = False,
255
+ stride: int = 32,
256
+ max_samples: int = 80,
257
+ ):
129
258
  """Build and return a GroundingDataset based on configuration parameters."""
130
259
  return GroundingDataset(
131
260
  img_path=img_path,
132
261
  json_file=json_file,
262
+ max_samples=max_samples,
133
263
  imgsz=cfg.imgsz,
134
264
  batch_size=batch,
135
265
  augment=mode == "train", # augmentation
@@ -137,7 +267,7 @@ def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, s
137
267
  rect=cfg.rect or rect, # rectangular batches
138
268
  cache=cfg.cache or None,
139
269
  single_cls=cfg.single_cls or False,
140
- stride=int(stride),
270
+ stride=stride,
141
271
  pad=0.0 if mode == "train" else 0.5,
142
272
  prefix=colorstr(f"{mode}: "),
143
273
  task=cfg.task,
@@ -146,24 +276,44 @@ def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, s
146
276
  )
147
277
 
148
278
 
149
- def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
150
- """
151
- Create and return an InfiniteDataLoader or DataLoader for training or validation.
279
+ def build_dataloader(
280
+ dataset,
281
+ batch: int,
282
+ workers: int,
283
+ shuffle: bool = True,
284
+ rank: int = -1,
285
+ drop_last: bool = False,
286
+ pin_memory: bool = True,
287
+ ):
288
+ """Create and return an InfiniteDataLoader or DataLoader for training or validation.
152
289
 
153
290
  Args:
154
291
  dataset (Dataset): Dataset to load data from.
155
292
  batch (int): Batch size for the dataloader.
156
293
  workers (int): Number of worker threads for loading data.
157
- shuffle (bool): Whether to shuffle the dataset.
158
- rank (int): Process rank in distributed training. -1 for single-GPU training.
294
+ shuffle (bool, optional): Whether to shuffle the dataset.
295
+ rank (int, optional): Process rank in distributed training. -1 for single-GPU training.
296
+ drop_last (bool, optional): Whether to drop the last incomplete batch.
297
+ pin_memory (bool, optional): Whether to use pinned memory for dataloader.
159
298
 
160
299
  Returns:
161
300
  (InfiniteDataLoader): A dataloader that can be used for training or validation.
301
+
302
+ Examples:
303
+ Create a dataloader for training
304
+ >>> dataset = YOLODataset(...)
305
+ >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True)
162
306
  """
163
307
  batch = min(batch, len(dataset))
164
308
  nd = torch.cuda.device_count() # number of CUDA devices
165
309
  nw = min(os.cpu_count() // max(nd, 1), workers) # number of workers
166
- sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
310
+ sampler = (
311
+ None
312
+ if rank == -1
313
+ else distributed.DistributedSampler(dataset, shuffle=shuffle)
314
+ if shuffle
315
+ else ContiguousDistributedSampler(dataset)
316
+ )
167
317
  generator = torch.Generator()
168
318
  generator.manual_seed(6148914691236517205 + RANK)
169
319
  return InfiniteDataLoader(
@@ -172,38 +322,46 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
172
322
  shuffle=shuffle and sampler is None,
173
323
  num_workers=nw,
174
324
  sampler=sampler,
175
- pin_memory=PIN_MEMORY,
325
+ prefetch_factor=4 if nw > 0 else None, # increase over default 2
326
+ pin_memory=nd > 0 and pin_memory,
176
327
  collate_fn=getattr(dataset, "collate_fn", None),
177
328
  worker_init_fn=seed_worker,
178
329
  generator=generator,
330
+ drop_last=drop_last and len(dataset) % batch != 0,
179
331
  )
180
332
 
181
333
 
182
334
  def check_source(source):
183
- """
184
- Check the type of input source and return corresponding flag values.
335
+ """Check the type of input source and return corresponding flag values.
185
336
 
186
337
  Args:
187
- source (str | int | Path | List | Tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.
338
+ source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.
188
339
 
189
340
  Returns:
190
- source (str | int | Path | List | Tuple | np.ndarray | PIL.Image | torch.Tensor): The processed source.
341
+ source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The processed source.
191
342
  webcam (bool): Whether the source is a webcam.
192
343
  screenshot (bool): Whether the source is a screenshot.
193
344
  from_img (bool): Whether the source is an image or list of images.
194
345
  in_memory (bool): Whether the source is an in-memory object.
195
346
  tensor (bool): Whether the source is a torch.Tensor.
196
347
 
197
- Raises:
198
- TypeError: If the source type is unsupported.
348
+ Examples:
349
+ Check a file path source
350
+ >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source("image.jpg")
351
+
352
+ Check a webcam source
353
+ >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source(0)
199
354
  """
200
355
  webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
201
356
  if isinstance(source, (str, int, Path)): # int for local usb camera
202
357
  source = str(source)
203
- is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
204
- is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
358
+ source_lower = source.lower()
359
+ is_url = source_lower.startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
360
+ is_file = (urlsplit(source_lower).path if is_url else source_lower).rpartition(".")[-1] in (
361
+ IMG_FORMATS | VID_FORMATS
362
+ )
205
363
  webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
206
- screenshot = source.lower() == "screen"
364
+ screenshot = source_lower == "screen"
207
365
  if is_url and is_file:
208
366
  source = check_file(source) # download
209
367
  elif isinstance(source, LOADERS):
@@ -221,19 +379,25 @@ def check_source(source):
221
379
  return source, webcam, screenshot, from_img, in_memory, tensor
222
380
 
223
381
 
224
- def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False, channels=3):
225
- """
226
- Load an inference source for object detection and apply necessary transformations.
382
+ def load_inference_source(source=None, batch: int = 1, vid_stride: int = 1, buffer: bool = False, channels: int = 3):
383
+ """Load an inference source for object detection and apply necessary transformations.
227
384
 
228
385
  Args:
229
386
  source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.
230
387
  batch (int, optional): Batch size for dataloaders.
231
388
  vid_stride (int, optional): The frame interval for video sources.
232
389
  buffer (bool, optional): Whether stream frames will be buffered.
233
- channels (int): The number of input channels for the model.
390
+ channels (int, optional): The number of input channels for the model.
234
391
 
235
392
  Returns:
236
393
  (Dataset): A dataset object for the specified input source with attached source_type attribute.
394
+
395
+ Examples:
396
+ Load an image source for inference
397
+ >>> dataset = load_inference_source("image.jpg", batch=1)
398
+
399
+ Load a video stream source
400
+ >>> dataset = load_inference_source("rtsp://example.com/stream", vid_stride=2)
237
401
  """
238
402
  source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
239
403
  source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)