dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,13 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import json
4
6
  from collections import defaultdict
5
7
  from itertools import repeat
6
8
  from multiprocessing.pool import ThreadPool
7
9
  from pathlib import Path
10
+ from typing import Any
8
11
 
9
12
  import cv2
10
13
  import numpy as np
@@ -44,8 +47,7 @@ DATASET_CACHE_VERSION = "1.0.3"
44
47
 
45
48
 
46
49
  class YOLODataset(BaseDataset):
47
- """
48
- Dataset class for loading object detection and/or segmentation labels in YOLO format.
50
+ """Dataset class for loading object detection and/or segmentation labels in YOLO format.
49
51
 
50
52
  This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
51
53
  (OBB) tasks using the YOLO format.
@@ -58,20 +60,19 @@ class YOLODataset(BaseDataset):
58
60
 
59
61
  Methods:
60
62
  cache_labels: Cache dataset labels, check images and read shapes.
61
- get_labels: Returns dictionary of labels for YOLO training.
62
- build_transforms: Builds and appends transforms to the list.
63
- close_mosaic: Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
64
- update_labels_info: Updates label format for different tasks.
65
- collate_fn: Collates data samples into batches.
63
+ get_labels: Return dictionary of labels for YOLO training.
64
+ build_transforms: Build and append transforms to the list.
65
+ close_mosaic: Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
66
+ update_labels_info: Update label format for different tasks.
67
+ collate_fn: Collate data samples into batches.
66
68
 
67
69
  Examples:
68
70
  >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
69
71
  >>> dataset.get_labels()
70
72
  """
71
73
 
72
- def __init__(self, *args, data=None, task="detect", **kwargs):
73
- """
74
- Initialize the YOLODataset.
74
+ def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
75
+ """Initialize the YOLODataset.
75
76
 
76
77
  Args:
77
78
  data (dict, optional): Dataset configuration dictionary.
@@ -84,11 +85,10 @@ class YOLODataset(BaseDataset):
84
85
  self.use_obb = task == "obb"
85
86
  self.data = data
86
87
  assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
87
- super().__init__(*args, channels=self.data["channels"], **kwargs)
88
+ super().__init__(*args, channels=self.data.get("channels", 3), **kwargs)
88
89
 
89
- def cache_labels(self, path=Path("./labels.cache")):
90
- """
91
- Cache dataset labels, check images and read shapes.
90
+ def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
91
+ """Cache dataset labels, check images and read shapes.
92
92
 
93
93
  Args:
94
94
  path (Path): Path where to save the cache file.
@@ -154,14 +154,13 @@ class YOLODataset(BaseDataset):
154
154
  save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
155
155
  return x
156
156
 
157
- def get_labels(self):
158
- """
159
- Returns dictionary of labels for YOLO training.
157
+ def get_labels(self) -> list[dict]:
158
+ """Return dictionary of labels for YOLO training.
160
159
 
161
160
  This method loads labels from disk or cache, verifies their integrity, and prepares them for training.
162
161
 
163
162
  Returns:
164
- (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
163
+ (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
165
164
  """
166
165
  self.label_files = img2label_paths(self.im_files)
167
166
  cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
@@ -169,7 +168,7 @@ class YOLODataset(BaseDataset):
169
168
  cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
170
169
  assert cache["version"] == DATASET_CACHE_VERSION # matches current version
171
170
  assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
172
- except (FileNotFoundError, AssertionError, AttributeError):
171
+ except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
173
172
  cache, exists = self.cache_labels(cache_path), False # run cache ops
174
173
 
175
174
  # Display cache
@@ -204,9 +203,8 @@ class YOLODataset(BaseDataset):
204
203
  LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
205
204
  return labels
206
205
 
207
- def build_transforms(self, hyp=None):
208
- """
209
- Builds and appends transforms to the list.
206
+ def build_transforms(self, hyp: dict | None = None) -> Compose:
207
+ """Build and append transforms to the list.
210
208
 
211
209
  Args:
212
210
  hyp (dict, optional): Hyperparameters for transforms.
@@ -236,9 +234,8 @@ class YOLODataset(BaseDataset):
236
234
  )
237
235
  return transforms
238
236
 
239
- def close_mosaic(self, hyp):
240
- """
241
- Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.
237
+ def close_mosaic(self, hyp: dict) -> None:
238
+ """Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.
242
239
 
243
240
  Args:
244
241
  hyp (dict): Hyperparameters for transforms.
@@ -249,9 +246,8 @@ class YOLODataset(BaseDataset):
249
246
  hyp.cutmix = 0.0
250
247
  self.transforms = self.build_transforms(hyp)
251
248
 
252
- def update_labels_info(self, label):
253
- """
254
- Custom your label format here.
249
+ def update_labels_info(self, label: dict) -> dict:
250
+ """Update label format for different tasks.
255
251
 
256
252
  Args:
257
253
  label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@@ -259,7 +255,7 @@ class YOLODataset(BaseDataset):
259
255
  Returns:
260
256
  (dict): Updated label dictionary with instances.
261
257
 
262
- Note:
258
+ Notes:
263
259
  cls is not with bboxes now, classification and semantic segmentation need an independent cls label
264
260
  Can also support classification and semantic segmentation by adding or removing dict keys there.
265
261
  """
@@ -283,12 +279,11 @@ class YOLODataset(BaseDataset):
283
279
  return label
284
280
 
285
281
  @staticmethod
286
- def collate_fn(batch):
287
- """
288
- Collates data samples into batches.
282
+ def collate_fn(batch: list[dict]) -> dict:
283
+ """Collate data samples into batches.
289
284
 
290
285
  Args:
291
- batch (List[dict]): List of dictionaries containing sample data.
286
+ batch (list[dict]): List of dictionaries containing sample data.
292
287
 
293
288
  Returns:
294
289
  (dict): Collated batch with stacked tensors.
@@ -314,15 +309,14 @@ class YOLODataset(BaseDataset):
314
309
 
315
310
 
316
311
  class YOLOMultiModalDataset(YOLODataset):
317
- """
318
- Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
312
+ """Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
319
313
 
320
- This class extends YOLODataset to add text information for multi-modal model training, enabling models to
321
- process both image and text data.
314
+ This class extends YOLODataset to add text information for multi-modal model training, enabling models to process
315
+ both image and text data.
322
316
 
323
317
  Methods:
324
- update_labels_info: Adds text information for multi-modal model training.
325
- build_transforms: Enhances data transformations with text augmentation.
318
+ update_labels_info: Add text information for multi-modal model training.
319
+ build_transforms: Enhance data transformations with text augmentation.
326
320
 
327
321
  Examples:
328
322
  >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
@@ -330,9 +324,8 @@ class YOLOMultiModalDataset(YOLODataset):
330
324
  >>> print(batch.keys()) # Should include 'texts'
331
325
  """
332
326
 
333
- def __init__(self, *args, data=None, task="detect", **kwargs):
334
- """
335
- Initialize a YOLOMultiModalDataset.
327
+ def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
328
+ """Initialize a YOLOMultiModalDataset.
336
329
 
337
330
  Args:
338
331
  data (dict, optional): Dataset configuration dictionary.
@@ -342,9 +335,8 @@ class YOLOMultiModalDataset(YOLODataset):
342
335
  """
343
336
  super().__init__(*args, data=data, task=task, **kwargs)
344
337
 
345
- def update_labels_info(self, label):
346
- """
347
- Add texts information for multi-modal model training.
338
+ def update_labels_info(self, label: dict) -> dict:
339
+ """Add text information for multi-modal model training.
348
340
 
349
341
  Args:
350
342
  label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@@ -359,9 +351,8 @@ class YOLOMultiModalDataset(YOLODataset):
359
351
 
360
352
  return labels
361
353
 
362
- def build_transforms(self, hyp=None):
363
- """
364
- Enhances data transformations with optional text augmentation for multi-modal training.
354
+ def build_transforms(self, hyp: dict | None = None) -> Compose:
355
+ """Enhance data transformations with optional text augmentation for multi-modal training.
365
356
 
366
357
  Args:
367
358
  hyp (dict, optional): Hyperparameters for transforms.
@@ -385,11 +376,10 @@ class YOLOMultiModalDataset(YOLODataset):
385
376
 
386
377
  @property
387
378
  def category_names(self):
388
- """
389
- Return category names for the dataset.
379
+ """Return category names for the dataset.
390
380
 
391
381
  Returns:
392
- (Set[str]): List of class names.
382
+ (set[str]): List of class names.
393
383
  """
394
384
  names = self.data["names"].values()
395
385
  return {n.strip() for name in names for n in name.split("/")} # category names
@@ -408,48 +398,48 @@ class YOLOMultiModalDataset(YOLODataset):
408
398
  return category_freq
409
399
 
410
400
  @staticmethod
411
- def _get_neg_texts(category_freq, threshold=100):
401
+ def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
412
402
  """Get negative text samples based on frequency threshold."""
403
+ threshold = min(max(category_freq.values()), 100)
413
404
  return [k for k, v in category_freq.items() if v >= threshold]
414
405
 
415
406
 
416
407
  class GroundingDataset(YOLODataset):
417
- """
418
- Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format.
408
+ """Dataset class for object detection tasks using annotations from a JSON file in grounding format.
419
409
 
420
- This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than
421
- the standard YOLO format text files.
410
+ This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than the standard
411
+ YOLO format text files.
422
412
 
423
413
  Attributes:
424
414
  json_file (str): Path to the JSON file containing annotations.
425
415
 
426
416
  Methods:
427
- get_img_files: Returns empty list as image files are read in get_labels.
428
- get_labels: Loads annotations from a JSON file and prepares them for training.
429
- build_transforms: Configures augmentations for training with optional text loading.
417
+ get_img_files: Return empty list as image files are read in get_labels.
418
+ get_labels: Load annotations from a JSON file and prepare them for training.
419
+ build_transforms: Configure augmentations for training with optional text loading.
430
420
 
431
421
  Examples:
432
422
  >>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect")
433
423
  >>> len(dataset) # Number of valid images with annotations
434
424
  """
435
425
 
436
- def __init__(self, *args, task="detect", json_file="", **kwargs):
437
- """
438
- Initialize a GroundingDataset for object detection.
426
+ def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs):
427
+ """Initialize a GroundingDataset for object detection.
439
428
 
440
429
  Args:
441
430
  json_file (str): Path to the JSON file containing annotations.
442
431
  task (str): Must be 'detect' or 'segment' for GroundingDataset.
432
+ max_samples (int): Maximum number of samples to load for text augmentation.
443
433
  *args (Any): Additional positional arguments for the parent class.
444
434
  **kwargs (Any): Additional keyword arguments for the parent class.
445
435
  """
446
436
  assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
447
437
  self.json_file = json_file
438
+ self.max_samples = max_samples
448
439
  super().__init__(*args, task=task, data={"channels": 3}, **kwargs)
449
440
 
450
- def get_img_files(self, img_path):
451
- """
452
- The image files would be read in `get_labels` function, return empty list here.
441
+ def get_img_files(self, img_path: str) -> list:
442
+ """The image files would be read in `get_labels` function, return empty list here.
453
443
 
454
444
  Args:
455
445
  img_path (str): Path to the directory containing images.
@@ -459,29 +449,47 @@ class GroundingDataset(YOLODataset):
459
449
  """
460
450
  return []
461
451
 
462
- def verify_labels(self, labels):
463
- """Verify the number of instances in the dataset matches expected counts."""
464
- instance_count = sum(label["bboxes"].shape[0] for label in labels)
465
- if "final_mixed_train_no_coco_segm" in self.json_file:
466
- assert instance_count == 3662344
467
- elif "final_mixed_train_no_coco" in self.json_file:
468
- assert instance_count == 3681235
469
- elif "final_flickr_separateGT_train_segm" in self.json_file:
470
- assert instance_count == 638214
471
- elif "final_flickr_separateGT_train" in self.json_file:
472
- assert instance_count == 640704
473
- else:
474
- assert False
452
+ def verify_labels(self, labels: list[dict[str, Any]]) -> None:
453
+ """Verify the number of instances in the dataset matches expected counts.
454
+
455
+ This method checks if the total number of bounding box instances in the provided labels matches the expected
456
+ count for known datasets. It performs validation against a predefined set of datasets with known instance
457
+ counts.
458
+
459
+ Args:
460
+ labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary contains dataset
461
+ annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor containing bounding
462
+ box coordinates.
463
+
464
+ Raises:
465
+ AssertionError: If the actual instance count doesn't match the expected count for a recognized dataset.
475
466
 
476
- def cache_labels(self, path=Path("./labels.cache")):
467
+ Notes:
468
+ For unrecognized datasets (those not in the predefined expected_counts),
469
+ a warning is logged and verification is skipped.
477
470
  """
478
- Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.
471
+ expected_counts = {
472
+ "final_mixed_train_no_coco_segm": 3662412,
473
+ "final_mixed_train_no_coco": 3681235,
474
+ "final_flickr_separateGT_train_segm": 638214,
475
+ "final_flickr_separateGT_train": 640704,
476
+ }
477
+
478
+ instance_count = sum(label["bboxes"].shape[0] for label in labels)
479
+ for data_name, count in expected_counts.items():
480
+ if data_name in self.json_file:
481
+ assert instance_count == count, f"'{self.json_file}' has {instance_count} instances, expected {count}."
482
+ return
483
+ LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'")
484
+
485
+ def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]:
486
+ """Load annotations from a JSON file, filter, and normalize bounding boxes for each image.
479
487
 
480
488
  Args:
481
489
  path (Path): Path where to save the cache file.
482
490
 
483
491
  Returns:
484
- (dict): Dictionary containing cached labels and related information.
492
+ (dict[str, Any]): Dictionary containing cached labels and related information.
485
493
  """
486
494
  x = {"labels": []}
487
495
  LOGGER.info("Loading annotation file...")
@@ -521,7 +529,7 @@ class GroundingDataset(YOLODataset):
521
529
  cat2id[cat_name] = len(cat2id)
522
530
  texts.append([cat_name])
523
531
  cls = cat2id[cat_name] # class
524
- box = [cls] + box.tolist()
532
+ box = [cls, *box.tolist()]
525
533
  if box not in bboxes:
526
534
  bboxes.append(box)
527
535
  if ann.get("segmentation") is not None:
@@ -538,7 +546,7 @@ class GroundingDataset(YOLODataset):
538
546
  .reshape(-1)
539
547
  .tolist()
540
548
  )
541
- s = [cls] + s
549
+ s = [cls, *s]
542
550
  segments.append(s)
543
551
  lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
544
552
 
@@ -564,31 +572,29 @@ class GroundingDataset(YOLODataset):
564
572
  save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
565
573
  return x
566
574
 
567
- def get_labels(self):
568
- """
569
- Load labels from cache or generate them from JSON file.
575
+ def get_labels(self) -> list[dict]:
576
+ """Load labels from cache or generate them from JSON file.
570
577
 
571
578
  Returns:
572
- (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
579
+ (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
573
580
  """
574
581
  cache_path = Path(self.json_file).with_suffix(".cache")
575
582
  try:
576
583
  cache, _ = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file
577
584
  assert cache["version"] == DATASET_CACHE_VERSION # matches current version
578
585
  assert cache["hash"] == get_hash(self.json_file) # identical hash
579
- except (FileNotFoundError, AssertionError, AttributeError):
586
+ except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
580
587
  cache, _ = self.cache_labels(cache_path), False # run cache ops
581
588
  [cache.pop(k) for k in ("hash", "version")] # remove items
582
589
  labels = cache["labels"]
583
- # self.verify_labels(labels)
590
+ self.verify_labels(labels)
584
591
  self.im_files = [str(label["im_file"]) for label in labels]
585
592
  if LOCAL_RANK in {-1, 0}:
586
593
  LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
587
594
  return labels
588
595
 
589
- def build_transforms(self, hyp=None):
590
- """
591
- Configures augmentations for training with optional text loading.
596
+ def build_transforms(self, hyp: dict | None = None) -> Compose:
597
+ """Configure augmentations for training with optional text loading.
592
598
 
593
599
  Args:
594
600
  hyp (dict, optional): Hyperparameters for transforms.
@@ -603,7 +609,7 @@ class GroundingDataset(YOLODataset):
603
609
  # the strategy of selecting negative is restricted in one dataset,
604
610
  # while official pre-saved neg embeddings from all datasets at once.
605
611
  transform = RandomLoadText(
606
- max_samples=80,
612
+ max_samples=min(self.max_samples, 80),
607
613
  padding=True,
608
614
  padding_value=self._get_neg_texts(self.category_freq),
609
615
  )
@@ -627,17 +633,17 @@ class GroundingDataset(YOLODataset):
627
633
  return category_freq
628
634
 
629
635
  @staticmethod
630
- def _get_neg_texts(category_freq, threshold=100):
636
+ def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
631
637
  """Get negative text samples based on frequency threshold."""
638
+ threshold = min(max(category_freq.values()), 100)
632
639
  return [k for k, v in category_freq.items() if v >= threshold]
633
640
 
634
641
 
635
642
  class YOLOConcatDataset(ConcatDataset):
636
- """
637
- Dataset as a concatenation of multiple datasets.
643
+ """Dataset as a concatenation of multiple datasets.
638
644
 
639
- This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same
640
- collation function.
645
+ This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same collation
646
+ function.
641
647
 
642
648
  Methods:
643
649
  collate_fn: Static method that collates data samples into batches using YOLODataset's collation function.
@@ -649,21 +655,19 @@ class YOLOConcatDataset(ConcatDataset):
649
655
  """
650
656
 
651
657
  @staticmethod
652
- def collate_fn(batch):
653
- """
654
- Collates data samples into batches.
658
+ def collate_fn(batch: list[dict]) -> dict:
659
+ """Collate data samples into batches.
655
660
 
656
661
  Args:
657
- batch (List[dict]): List of dictionaries containing sample data.
662
+ batch (list[dict]): List of dictionaries containing sample data.
658
663
 
659
664
  Returns:
660
665
  (dict): Collated batch with stacked tensors.
661
666
  """
662
667
  return YOLODataset.collate_fn(batch)
663
668
 
664
- def close_mosaic(self, hyp):
665
- """
666
- Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
669
+ def close_mosaic(self, hyp: dict) -> None:
670
+ """Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
667
671
 
668
672
  Args:
669
673
  hyp (dict): Hyperparameters for transforms.
@@ -684,8 +688,7 @@ class SemanticDataset(BaseDataset):
684
688
 
685
689
 
686
690
  class ClassificationDataset:
687
- """
688
- Extends torchvision ImageFolder to support YOLO classification tasks.
691
+ """Dataset class for image classification tasks extending torchvision ImageFolder functionality.
689
692
 
690
693
  This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
691
694
  handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
@@ -695,20 +698,19 @@ class ClassificationDataset:
695
698
  cache_ram (bool): Indicates if caching in RAM is enabled.
696
699
  cache_disk (bool): Indicates if caching on disk is enabled.
697
700
  samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
698
- file (if caching on disk), and optionally the loaded image array (if caching in RAM).
701
+ file (if caching on disk), and optionally the loaded image array (if caching in RAM).
699
702
  torch_transforms (callable): PyTorch transforms to be applied to the images.
700
703
  root (str): Root directory of the dataset.
701
704
  prefix (str): Prefix for logging and cache filenames.
702
705
 
703
706
  Methods:
704
- __getitem__: Returns subset of data and targets corresponding to given indices.
705
- __len__: Returns the total number of samples in the dataset.
706
- verify_images: Verifies all images in dataset.
707
+ __getitem__: Return subset of data and targets corresponding to given indices.
708
+ __len__: Return the total number of samples in the dataset.
709
+ verify_images: Verify all images in dataset.
707
710
  """
708
711
 
709
- def __init__(self, root, args, augment=False, prefix=""):
710
- """
711
- Initialize YOLO object with root, image size, augmentations, and cache settings.
712
+ def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
713
+ """Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.
712
714
 
713
715
  Args:
714
716
  root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
@@ -740,7 +742,7 @@ class ClassificationDataset:
740
742
  self.cache_ram = False
741
743
  self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files
742
744
  self.samples = self.verify_images() # filter out bad images
743
- self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
745
+ self.samples = [[*list(x), Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
744
746
  scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
745
747
  self.torch_transforms = (
746
748
  classify_augmentations(
@@ -758,9 +760,8 @@ class ClassificationDataset:
758
760
  else classify_transforms(size=args.imgsz)
759
761
  )
760
762
 
761
- def __getitem__(self, i):
762
- """
763
- Returns subset of data and targets corresponding to given indices.
763
+ def __getitem__(self, i: int) -> dict:
764
+ """Return subset of data and targets corresponding to given indices.
764
765
 
765
766
  Args:
766
767
  i (int): Index of the sample to retrieve.
@@ -787,9 +788,8 @@ class ClassificationDataset:
787
788
  """Return the total number of samples in the dataset."""
788
789
  return len(self.samples)
789
790
 
790
- def verify_images(self):
791
- """
792
- Verify all images in dataset.
791
+ def verify_images(self) -> list[tuple]:
792
+ """Verify all images in dataset.
793
793
 
794
794
  Returns:
795
795
  (list): List of valid samples after verification.