dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
  2. dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
  3. dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +23 -0
  8. tests/conftest.py +59 -0
  9. tests/test_cli.py +131 -0
  10. tests/test_cuda.py +216 -0
  11. tests/test_engine.py +157 -0
  12. tests/test_exports.py +309 -0
  13. tests/test_integrations.py +151 -0
  14. tests/test_python.py +777 -0
  15. tests/test_solutions.py +371 -0
  16. ultralytics/__init__.py +48 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1028 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  29. ultralytics/cfg/datasets/VOC.yaml +102 -0
  30. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  31. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  32. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  33. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  34. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  35. ultralytics/cfg/datasets/coco.yaml +118 -0
  36. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  37. ultralytics/cfg/datasets/coco128.yaml +101 -0
  38. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  39. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  40. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  41. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  42. ultralytics/cfg/datasets/coco8.yaml +101 -0
  43. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  44. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  45. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  46. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  47. ultralytics/cfg/datasets/dota8.yaml +35 -0
  48. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  49. ultralytics/cfg/datasets/kitti.yaml +27 -0
  50. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  51. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  52. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  53. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  54. ultralytics/cfg/datasets/signature.yaml +21 -0
  55. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  56. ultralytics/cfg/datasets/xView.yaml +155 -0
  57. ultralytics/cfg/default.yaml +130 -0
  58. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  59. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  60. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  61. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  62. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  63. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  64. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  65. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  67. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  68. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  69. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  70. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  71. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  72. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  73. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  74. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  75. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  77. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  78. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  79. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  80. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  81. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  82. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  83. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  84. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  85. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  86. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  87. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  88. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  89. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  90. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  91. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  92. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  93. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  94. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  95. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  97. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  99. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  100. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  102. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  103. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  104. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  105. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  106. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  109. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  110. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  111. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  112. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  113. ultralytics/cfg/trackers/botsort.yaml +21 -0
  114. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  115. ultralytics/data/__init__.py +26 -0
  116. ultralytics/data/annotator.py +66 -0
  117. ultralytics/data/augment.py +2801 -0
  118. ultralytics/data/base.py +435 -0
  119. ultralytics/data/build.py +437 -0
  120. ultralytics/data/converter.py +855 -0
  121. ultralytics/data/dataset.py +834 -0
  122. ultralytics/data/loaders.py +704 -0
  123. ultralytics/data/scripts/download_weights.sh +18 -0
  124. ultralytics/data/scripts/get_coco.sh +61 -0
  125. ultralytics/data/scripts/get_coco128.sh +18 -0
  126. ultralytics/data/scripts/get_imagenet.sh +52 -0
  127. ultralytics/data/split.py +138 -0
  128. ultralytics/data/split_dota.py +344 -0
  129. ultralytics/data/utils.py +798 -0
  130. ultralytics/engine/__init__.py +1 -0
  131. ultralytics/engine/exporter.py +1580 -0
  132. ultralytics/engine/model.py +1125 -0
  133. ultralytics/engine/predictor.py +508 -0
  134. ultralytics/engine/results.py +1522 -0
  135. ultralytics/engine/trainer.py +977 -0
  136. ultralytics/engine/tuner.py +449 -0
  137. ultralytics/engine/validator.py +387 -0
  138. ultralytics/hub/__init__.py +166 -0
  139. ultralytics/hub/auth.py +151 -0
  140. ultralytics/hub/google/__init__.py +174 -0
  141. ultralytics/hub/session.py +422 -0
  142. ultralytics/hub/utils.py +162 -0
  143. ultralytics/models/__init__.py +9 -0
  144. ultralytics/models/fastsam/__init__.py +7 -0
  145. ultralytics/models/fastsam/model.py +79 -0
  146. ultralytics/models/fastsam/predict.py +169 -0
  147. ultralytics/models/fastsam/utils.py +23 -0
  148. ultralytics/models/fastsam/val.py +38 -0
  149. ultralytics/models/nas/__init__.py +7 -0
  150. ultralytics/models/nas/model.py +98 -0
  151. ultralytics/models/nas/predict.py +56 -0
  152. ultralytics/models/nas/val.py +38 -0
  153. ultralytics/models/rtdetr/__init__.py +7 -0
  154. ultralytics/models/rtdetr/model.py +63 -0
  155. ultralytics/models/rtdetr/predict.py +88 -0
  156. ultralytics/models/rtdetr/train.py +89 -0
  157. ultralytics/models/rtdetr/val.py +216 -0
  158. ultralytics/models/sam/__init__.py +25 -0
  159. ultralytics/models/sam/amg.py +275 -0
  160. ultralytics/models/sam/build.py +365 -0
  161. ultralytics/models/sam/build_sam3.py +377 -0
  162. ultralytics/models/sam/model.py +169 -0
  163. ultralytics/models/sam/modules/__init__.py +1 -0
  164. ultralytics/models/sam/modules/blocks.py +1067 -0
  165. ultralytics/models/sam/modules/decoders.py +495 -0
  166. ultralytics/models/sam/modules/encoders.py +794 -0
  167. ultralytics/models/sam/modules/memory_attention.py +298 -0
  168. ultralytics/models/sam/modules/sam.py +1160 -0
  169. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  170. ultralytics/models/sam/modules/transformer.py +344 -0
  171. ultralytics/models/sam/modules/utils.py +512 -0
  172. ultralytics/models/sam/predict.py +3940 -0
  173. ultralytics/models/sam/sam3/__init__.py +3 -0
  174. ultralytics/models/sam/sam3/decoder.py +546 -0
  175. ultralytics/models/sam/sam3/encoder.py +529 -0
  176. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  177. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  178. ultralytics/models/sam/sam3/model_misc.py +199 -0
  179. ultralytics/models/sam/sam3/necks.py +129 -0
  180. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  181. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  182. ultralytics/models/sam/sam3/vitdet.py +547 -0
  183. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  184. ultralytics/models/utils/__init__.py +1 -0
  185. ultralytics/models/utils/loss.py +466 -0
  186. ultralytics/models/utils/ops.py +315 -0
  187. ultralytics/models/yolo/__init__.py +7 -0
  188. ultralytics/models/yolo/classify/__init__.py +7 -0
  189. ultralytics/models/yolo/classify/predict.py +90 -0
  190. ultralytics/models/yolo/classify/train.py +202 -0
  191. ultralytics/models/yolo/classify/val.py +216 -0
  192. ultralytics/models/yolo/detect/__init__.py +7 -0
  193. ultralytics/models/yolo/detect/predict.py +122 -0
  194. ultralytics/models/yolo/detect/train.py +227 -0
  195. ultralytics/models/yolo/detect/val.py +507 -0
  196. ultralytics/models/yolo/model.py +430 -0
  197. ultralytics/models/yolo/obb/__init__.py +7 -0
  198. ultralytics/models/yolo/obb/predict.py +56 -0
  199. ultralytics/models/yolo/obb/train.py +79 -0
  200. ultralytics/models/yolo/obb/val.py +302 -0
  201. ultralytics/models/yolo/pose/__init__.py +7 -0
  202. ultralytics/models/yolo/pose/predict.py +65 -0
  203. ultralytics/models/yolo/pose/train.py +110 -0
  204. ultralytics/models/yolo/pose/val.py +248 -0
  205. ultralytics/models/yolo/segment/__init__.py +7 -0
  206. ultralytics/models/yolo/segment/predict.py +109 -0
  207. ultralytics/models/yolo/segment/train.py +69 -0
  208. ultralytics/models/yolo/segment/val.py +307 -0
  209. ultralytics/models/yolo/world/__init__.py +5 -0
  210. ultralytics/models/yolo/world/train.py +173 -0
  211. ultralytics/models/yolo/world/train_world.py +178 -0
  212. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  213. ultralytics/models/yolo/yoloe/predict.py +162 -0
  214. ultralytics/models/yolo/yoloe/train.py +287 -0
  215. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  216. ultralytics/models/yolo/yoloe/val.py +206 -0
  217. ultralytics/nn/__init__.py +27 -0
  218. ultralytics/nn/autobackend.py +964 -0
  219. ultralytics/nn/modules/__init__.py +182 -0
  220. ultralytics/nn/modules/activation.py +54 -0
  221. ultralytics/nn/modules/block.py +1947 -0
  222. ultralytics/nn/modules/conv.py +669 -0
  223. ultralytics/nn/modules/head.py +1183 -0
  224. ultralytics/nn/modules/transformer.py +793 -0
  225. ultralytics/nn/modules/utils.py +159 -0
  226. ultralytics/nn/tasks.py +1768 -0
  227. ultralytics/nn/text_model.py +356 -0
  228. ultralytics/py.typed +1 -0
  229. ultralytics/solutions/__init__.py +41 -0
  230. ultralytics/solutions/ai_gym.py +108 -0
  231. ultralytics/solutions/analytics.py +264 -0
  232. ultralytics/solutions/config.py +107 -0
  233. ultralytics/solutions/distance_calculation.py +123 -0
  234. ultralytics/solutions/heatmap.py +125 -0
  235. ultralytics/solutions/instance_segmentation.py +86 -0
  236. ultralytics/solutions/object_blurrer.py +89 -0
  237. ultralytics/solutions/object_counter.py +190 -0
  238. ultralytics/solutions/object_cropper.py +87 -0
  239. ultralytics/solutions/parking_management.py +280 -0
  240. ultralytics/solutions/queue_management.py +93 -0
  241. ultralytics/solutions/region_counter.py +133 -0
  242. ultralytics/solutions/security_alarm.py +151 -0
  243. ultralytics/solutions/similarity_search.py +219 -0
  244. ultralytics/solutions/solutions.py +828 -0
  245. ultralytics/solutions/speed_estimation.py +114 -0
  246. ultralytics/solutions/streamlit_inference.py +260 -0
  247. ultralytics/solutions/templates/similarity-search.html +156 -0
  248. ultralytics/solutions/trackzone.py +88 -0
  249. ultralytics/solutions/vision_eye.py +67 -0
  250. ultralytics/trackers/__init__.py +7 -0
  251. ultralytics/trackers/basetrack.py +115 -0
  252. ultralytics/trackers/bot_sort.py +257 -0
  253. ultralytics/trackers/byte_tracker.py +469 -0
  254. ultralytics/trackers/track.py +116 -0
  255. ultralytics/trackers/utils/__init__.py +1 -0
  256. ultralytics/trackers/utils/gmc.py +339 -0
  257. ultralytics/trackers/utils/kalman_filter.py +482 -0
  258. ultralytics/trackers/utils/matching.py +154 -0
  259. ultralytics/utils/__init__.py +1450 -0
  260. ultralytics/utils/autobatch.py +118 -0
  261. ultralytics/utils/autodevice.py +205 -0
  262. ultralytics/utils/benchmarks.py +728 -0
  263. ultralytics/utils/callbacks/__init__.py +5 -0
  264. ultralytics/utils/callbacks/base.py +233 -0
  265. ultralytics/utils/callbacks/clearml.py +146 -0
  266. ultralytics/utils/callbacks/comet.py +625 -0
  267. ultralytics/utils/callbacks/dvc.py +197 -0
  268. ultralytics/utils/callbacks/hub.py +110 -0
  269. ultralytics/utils/callbacks/mlflow.py +134 -0
  270. ultralytics/utils/callbacks/neptune.py +126 -0
  271. ultralytics/utils/callbacks/platform.py +453 -0
  272. ultralytics/utils/callbacks/raytune.py +42 -0
  273. ultralytics/utils/callbacks/tensorboard.py +123 -0
  274. ultralytics/utils/callbacks/wb.py +188 -0
  275. ultralytics/utils/checks.py +1020 -0
  276. ultralytics/utils/cpu.py +85 -0
  277. ultralytics/utils/dist.py +123 -0
  278. ultralytics/utils/downloads.py +529 -0
  279. ultralytics/utils/errors.py +35 -0
  280. ultralytics/utils/events.py +113 -0
  281. ultralytics/utils/export/__init__.py +7 -0
  282. ultralytics/utils/export/engine.py +237 -0
  283. ultralytics/utils/export/imx.py +325 -0
  284. ultralytics/utils/export/tensorflow.py +231 -0
  285. ultralytics/utils/files.py +219 -0
  286. ultralytics/utils/git.py +137 -0
  287. ultralytics/utils/instance.py +484 -0
  288. ultralytics/utils/logger.py +506 -0
  289. ultralytics/utils/loss.py +849 -0
  290. ultralytics/utils/metrics.py +1563 -0
  291. ultralytics/utils/nms.py +337 -0
  292. ultralytics/utils/ops.py +664 -0
  293. ultralytics/utils/patches.py +201 -0
  294. ultralytics/utils/plotting.py +1047 -0
  295. ultralytics/utils/tal.py +404 -0
  296. ultralytics/utils/torch_utils.py +984 -0
  297. ultralytics/utils/tqdm.py +443 -0
  298. ultralytics/utils/triton.py +112 -0
  299. ultralytics/utils/tuner.py +168 -0
@@ -0,0 +1,798 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import random
8
+ import subprocess
9
+ import time
10
+ import zipfile
11
+ from multiprocessing.pool import ThreadPool
12
+ from pathlib import Path
13
+ from tarfile import is_tarfile
14
+ from typing import Any
15
+
16
+ import cv2
17
+ import numpy as np
18
+ from PIL import Image, ImageOps
19
+
20
+ from ultralytics.nn.autobackend import check_class_names
21
+ from ultralytics.utils import (
22
+ ASSETS_URL,
23
+ DATASETS_DIR,
24
+ LOGGER,
25
+ NUM_THREADS,
26
+ ROOT,
27
+ SETTINGS_FILE,
28
+ TQDM,
29
+ YAML,
30
+ clean_url,
31
+ colorstr,
32
+ emojis,
33
+ is_dir_writeable,
34
+ )
35
+ from ultralytics.utils.checks import check_file, check_font, is_ascii
36
+ from ultralytics.utils.downloads import download, safe_download, unzip_file
37
+ from ultralytics.utils.ops import segments2boxes
38
+
39
# Pointer to the dataset formatting documentation
HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
# Lowercase file suffixes without the leading dot. verify_image / verify_image_label also compare lowercased PIL
# format names (im.format.lower()) against IMG_FORMATS.
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}  # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
# Human-readable list of supported formats, embedded in image-validation error messages.
# NOTE: the sets are rendered via their repr, so element order may vary between interpreter runs (string hashing).
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
43
+
44
+
45
def img2label_paths(img_paths: list[str]) -> list[str]:
    """Map image file paths to their YOLO label file paths.

    The last ``<sep>images<sep>`` component of each path is swapped for ``<sep>labels<sep>`` (paths without it are left
    as-is), and everything after the final "." is replaced by ``txt``.

    Args:
        img_paths (list[str]): Image file paths.

    Returns:
        (list[str]): Label file paths, in the same order as ``img_paths``.
    """
    img_dir = f"{os.sep}images{os.sep}"
    lbl_dir = f"{os.sep}labels{os.sep}"

    def _to_label(path: str) -> str:
        head, found, tail = path.rpartition(img_dir)
        swapped = f"{head}{lbl_dir}{tail}" if found else path
        return swapped.rsplit(".", 1)[0] + ".txt"

    return list(map(_to_label, img_paths))
49
+
50
+
51
def check_file_speeds(
    files: list[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = ""
):
    """Check dataset file access speed and provide performance feedback.

    This function tests the access speed of dataset files by measuring ping (stat call) time and read speed. It samples
    up to ``max_files`` files from the provided list. Access is reported as fast when the mean ping is below
    ``threshold_ms`` or the mean read speed is at least ``threshold_mb``; otherwise a slow-access warning is logged.
    Files that cannot be stat'ed or read are skipped.

    Args:
        files (list[str]): List of file paths to check for access speed.
        threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.
        threshold_mb (float, optional): Threshold in megabytes per second for read speed warnings.
        max_files (int, optional): The maximum number of files to check.
        prefix (str, optional): Prefix string to add to log messages.

    Examples:
        >>> from pathlib import Path
        >>> image_files = list(Path("dataset/images").glob("*.jpg"))
        >>> check_file_speeds(image_files, threshold_ms=15)
    """
    if not files:
        LOGGER.warning(f"{prefix}Image speed checks: No files to check")
        return

    # Sample at most `max_files` files so the check stays cheap on large datasets
    files = random.sample(files, min(max_files, len(files)))

    ping_times = []  # stat() latency per file, ms
    file_sizes = []  # bytes
    read_speeds = []  # MB/s

    for f in files:
        try:
            # Measure ping (stat call)
            start = time.perf_counter()
            file_size = os.stat(f).st_size
            ping_times.append((time.perf_counter() - start) * 1000)  # ms
            file_sizes.append(file_size)

            # Measure read speed
            start = time.perf_counter()
            with open(f, "rb") as file_obj:
                file_obj.read()
            read_time = time.perf_counter() - start
            if read_time > 0:  # Avoid division by zero
                read_speeds.append(file_size / (1 << 20) / read_time)  # MB/s
        except Exception:
            continue  # best-effort diagnostic: skip files that cannot be accessed

    if not ping_times:
        LOGGER.warning(f"{prefix}Image speed checks: failed to access files")
        return

    # Calculate stats with uncertainties
    avg_ping = np.mean(ping_times)
    std_ping = np.std(ping_times, ddof=1) if len(ping_times) > 1 else 0
    size_msg = f", size: {np.mean(file_sizes) / (1 << 10):.1f} KB"
    ping_msg = f"ping: {avg_ping:.1f}±{std_ping:.1f} ms"

    if read_speeds:
        avg_speed = np.mean(read_speeds)
        std_speed = np.std(read_speeds, ddof=1) if len(read_speeds) > 1 else 0
        speed_msg = f", read: {avg_speed:.1f}±{std_speed:.1f} MB/s"
        # High throughput is good: a read speed BELOW the threshold must never count as fast
        fast_read = avg_speed >= threshold_mb
    else:
        speed_msg = ""
        fast_read = False  # no successful timed reads, so judge on ping alone (avoids an unbound avg_speed)

    if avg_ping < threshold_ms or fast_read:
        LOGGER.info(f"{prefix}Fast image access ✅ ({ping_msg}{speed_msg}{size_msg})")
    else:
        LOGGER.warning(
            f"{prefix}Slow image access detected ({ping_msg}{speed_msg}{size_msg}). "
            f"Use local storage instead of remote/mounted storage for better performance. "
            f"See https://docs.ultralytics.com/guides/model-training-tips/"
        )
126
+
127
+
128
def get_hash(paths: list[str]) -> str:
    """Compute one SHA-256 hex digest summarizing a list of paths (files or directories).

    The digest covers the total on-disk size of all entries that can be stat'ed (missing or inaccessible ones count
    as 0) followed by the concatenated path strings.

    Args:
        paths (list[str]): Paths to summarize. The list is iterated twice, so it must not be a one-shot iterator.

    Returns:
        (str): Hexadecimal SHA-256 digest.
    """
    import hashlib

    def _size_or_zero(path: str) -> int:
        try:
            return os.stat(path).st_size
        except OSError:
            return 0

    total_size = sum(_size_or_zero(path) for path in paths)
    digest = hashlib.sha256(str(total_size).encode())  # hash sizes
    digest.update("".join(paths).encode())  # hash paths
    return digest.hexdigest()
139
+
140
+
141
def exif_size(img: Image.Image) -> tuple[int, int]:
    """Return the EXIF-corrected size of a PIL image.

    EXIF orientations 5-8 all involve a 90-degree turn (optionally combined with a mirror), so the displayed image has
    width and height swapped relative to the stored pixels. Orientations 1-4 keep the stored dimensions. Only JPEG
    images are inspected; any error while reading EXIF data falls back to the stored size.

    Args:
        img (Image.Image): PIL image to inspect.

    Returns:
        (tuple[int, int]): Image size as (width, height) after applying the EXIF orientation.
    """
    s = img.size  # (width, height)
    if img.format == "JPEG":  # only support JPEG images
        try:
            if exif := img.getexif():
                rotation = exif.get(274, None)  # the EXIF key for the orientation tag is 274
                # 5=transpose, 6=rotate 270, 7=transverse, 8=rotate 90 (Pillow terms): all swap width and height.
                # Previously only {6, 8} were handled, giving a wrong shape for orientations 5 and 7.
                if rotation in {5, 6, 7, 8}:
                    s = s[1], s[0]
        except Exception:
            pass  # unreadable EXIF: keep the stored size
    return s
153
+
154
+
155
def verify_image(args: tuple) -> tuple:
    """Verify one image file, repairing a truncated JPEG in place when possible.

    Validation failures are raised explicitly rather than through ``assert``, so they still fire under ``python -O``.
    All failures are caught locally and reported through ``nc``/``msg`` instead of propagating.

    Args:
        args (tuple): ``((im_file, cls), prefix)``. ``im_file`` is the image path, ``cls`` is passed through unchanged,
            and ``prefix`` is prepended to any message.

    Returns:
        (tuple): ``((im_file, cls), nf, nc, msg)``. ``nf`` is 1 if the image passed verification, ``nc`` is 1 if it
            was found corrupt, and ``msg`` describes a repair or failure ("" otherwise).
    """
    (im_file, cls), prefix = args
    # Number (found, corrupt), message
    nf, nc, msg = 0, 0, ""
    try:
        # The context manager closes the file handle PIL keeps open after Image.open()
        with Image.open(im_file) as im:
            im.verify()  # PIL verify
            shape = exif_size(im)  # image size (w, h)
            fmt = im.format
        shape = (shape[1], shape[0])  # hw
        if not (shape[0] > 9 and shape[1] > 9):
            raise ValueError(f"image size {shape} <10 pixels")
        if fmt.lower() not in IMG_FORMATS:
            raise ValueError(f"Invalid image format {fmt}. {FORMATS_HELP_MSG}")
        if fmt.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                truncated = f.read() != b"\xff\xd9"  # a complete JPEG ends with the EOI marker FF D9
            if truncated:  # corrupt JPEG: re-encode it in place
                with Image.open(im_file) as src:
                    ImageOps.exif_transpose(src).save(im_file, "JPEG", subsampling=0, quality=100)
                msg = f"{prefix}{im_file}: corrupt JPEG restored and saved"
        nf = 1
    except Exception as e:
        nc = 1
        msg = f"{prefix}{im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), nf, nc, msg
178
+
179
+
180
def verify_image_label(args: tuple) -> list:
    """Verify one image-label pair and load its labels.

    The image is opened with PIL, and its size and format are checked. A truncated JPEG is re-encoded in place. The
    label file (if present) is parsed, validated, and de-duplicated. Validation failures are raised explicitly rather
    than through ``assert``, so they still fire under ``python -O``. All failures are caught locally and reported via
    ``nc``/``msg``.

    Args:
        args (tuple): ``(im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls)``.
            - ``im_file``: image path.
            - ``lb_file``: label file path.
            - ``prefix``: prepended to messages.
            - ``keypoint``: whether rows carry keypoints.
            - ``num_cls``: number of dataset classes.
            - ``nkpt``, ``ndim``: keypoints per row and values per keypoint.
            - ``single_cls``: skip the class-range check.

    Returns:
        (list | tuple): On success, ``(im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg)``.
            - ``lb`` is an (n, 5) float32 array of (cls, x, y, w, h).
            - ``shape`` is (h, w).
            - ``nm``, ``nf``, ``ne`` flag a missing, found, or empty label file.
            - On failure, returns ``[None, None, None, None, None, nm, nf, ne, nc, msg]`` with ``nc == 1``.
    """
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images (the context manager closes the file handle PIL keeps open)
        with Image.open(im_file) as im:
            im.verify()  # PIL verify
            shape = exif_size(im)  # image size (w, h)
            fmt = im.format
        shape = (shape[1], shape[0])  # hw
        if not (shape[0] > 9 and shape[1] > 9):
            raise ValueError(f"image size {shape} <10 pixels")
        if fmt.lower() not in IMG_FORMATS:
            raise ValueError(f"invalid image format {fmt}. {FORMATS_HELP_MSG}")
        if fmt.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                truncated = f.read() != b"\xff\xd9"  # a complete JPEG ends with the EOI marker FF D9
            if truncated:  # corrupt JPEG: re-encode it in place
                with Image.open(im_file) as src:
                    ImageOps.exif_transpose(src).save(im_file, "JPEG", subsampling=0, quality=100)
                msg = f"{prefix}{im_file}: corrupt JPEG restored and saved"

        # Verify labels
        ncols = (5 + nkpt * ndim) if keypoint else 5  # expected columns per label row
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file, encoding="utf-8") as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
            if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                classes = np.array([x[0] for x in lb], dtype=np.float32)
                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
            lb = np.array(lb, dtype=np.float32)
            if nl := len(lb):
                if keypoint:
                    if lb.shape[1] != ncols:
                        raise ValueError(f"labels require {ncols} columns each")
                    # NOTE(review): only keypoint x, y are upper-bound checked in pose mode; box columns 1-4 are
                    # covered by the lower-bound check below only.
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    if lb.shape[1] != 5:
                        raise ValueError(f"labels require 5 columns, {lb.shape[1]} columns detected")
                    points = lb[:, 1:]
                # Coordinate points check with 1% tolerance (`not <=` keeps NaN values failing, as the assert did)
                if not points.max() <= 1.01:
                    raise ValueError(f"non-normalized or out of bounds coordinates {points[points > 1.01]}")
                if not lb.min() >= -0.01:
                    raise ValueError(f"negative class labels or coordinate {lb[lb < -0.01]}")

                # All labels
                max_cls = 0 if single_cls else lb[:, 0].max()  # max label count
                if not max_cls < num_cls:
                    raise ValueError(
                        f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                        f"Possible class labels are 0-{num_cls - 1}"
                    )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}{im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, ncols), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, ncols), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                # Keypoints with a negative x or y are marked invisible (0.0).
                # NOTE(review): the lb.min() >= -0.01 check above rejects anything below -0.01, so only values in
                # [-0.01, 0) can reach this mask -- confirm that is the intended invisibility convention.
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}{im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]
250
+
251
+
252
def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[int, str]):
    """Visualize YOLO annotations (bounding boxes and class labels) on an image.

    This function reads an image and its corresponding annotation file in YOLO format, then draws bounding boxes around
    detected objects and labels them with their respective class names. The bounding box colors are assigned based on
    the class ID, and the text color is dynamically adjusted for readability, depending on the background color's
    luminance. Blank lines in the annotation file are skipped.

    Args:
        image_path (str): Path to the image file to annotate. The file must be readable by PIL.
        txt_path (str): Path to the annotation file in YOLO format, which should contain one line per object.
        label_map (dict[int, str]): A dictionary that maps class IDs (integers) to class labels (strings).

    Raises:
        ValueError: If a non-blank annotation line does not contain exactly 5 numeric values.
        KeyError: If an annotated class ID is missing from ``label_map``.

    Examples:
        >>> label_map = {0: "cat", 1: "dog", 2: "bird"}  # Should include all annotated classes
        >>> visualize_image_annotations("path/to/image.jpg", "path/to/annotations.txt", label_map)
    """
    import matplotlib.pyplot as plt

    from ultralytics.utils.plotting import colors

    with Image.open(image_path) as pil_img:  # close the file handle once pixels are copied out
        img = np.array(pil_img)
    img_height, img_width = img.shape[:2]
    annotations = []
    with open(txt_path, encoding="utf-8") as file:
        for line in file:
            if not (values := line.split()):
                continue  # skip blank lines (e.g. a trailing empty line) instead of failing to unpack
            class_id, x_center, y_center, width, height = map(float, values)
            x = (x_center - width / 2) * img_width
            y = (y_center - height / 2) * img_height
            w = width * img_width
            h = height * img_height
            annotations.append((x, y, w, h, int(class_id)))
    _, ax = plt.subplots(1)  # Plot the image and annotations
    for x, y, w, h, label in annotations:
        color = tuple(c / 255 for c in colors(label, False))  # Get and normalize an RGB color for Matplotlib
        rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none")  # Create a rectangle
        ax.add_patch(rect)
        luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2]  # Formula for luminance
        ax.text(x, y - 5, label_map[label], color="white" if luminance < 0.5 else "black", backgroundcolor=color)
    # 2-D (single-channel) images would otherwise be drawn with matplotlib's default false-color colormap
    ax.imshow(img, cmap="gray" if img.ndim == 2 else None)
    plt.show()
293
+
294
+
295
def polygon2mask(
    imgsz: tuple[int, int], polygons: list[np.ndarray], color: int = 1, downsample_ratio: int = 1
) -> np.ndarray:
    """Rasterize polygons into one mask, then optionally shrink it.

    Args:
        imgsz (tuple[int, int]): Target image size as (height, width).
        polygons (list[np.ndarray]): Polygons to fill. Each entry is a flat sequence of x, y coordinates (even
            length). All entries are stacked into a single array, so they must have equal length.
        color (int, optional): Value written into every pixel covered by a polygon.
        downsample_ratio (int, optional): Integer factor by which height and width are divided after filling.

    Returns:
        (np.ndarray): uint8 mask of shape (height // downsample_ratio, width // downsample_ratio).
    """
    height, width = imgsz[0], imgsz[1]
    canvas = np.zeros(imgsz, dtype=np.uint8)
    pts = np.asarray(polygons, dtype=np.int32)
    cv2.fillPoly(canvas, pts.reshape((pts.shape[0], -1, 2)), color=color)
    # Fill at full resolution first and resize afterwards, keeping the loss computation consistent with mask_ratio=1
    return cv2.resize(canvas, (width // downsample_ratio, height // downsample_ratio))
317
+
318
+
319
def polygons2masks(
    imgsz: tuple[int, int], polygons: list[np.ndarray], color: int, downsample_ratio: int = 1
) -> np.ndarray:
    """Rasterize each polygon into its own mask and stack the results.

    Args:
        imgsz (tuple[int, int]): Target image size as (height, width).
        polygons (list[np.ndarray]): One polygon per output mask. Each is flattened to an x, y sequence before
            rasterization.
        color (int): Value written into the pixels covered by each polygon.
        downsample_ratio (int, optional): Integer factor by which each mask's height and width are divided.

    Returns:
        (np.ndarray): Stacked masks, one per polygon, produced by ``polygon2mask``.
    """
    stacked = []
    for polygon in polygons:
        flat = polygon.reshape(-1)  # a single flat x, y coordinate sequence per mask
        stacked.append(polygon2mask(imgsz, [flat], color, downsample_ratio))
    return np.array(stacked)
335
+
336
+
337
def polygons2masks_overlap(
    imgsz: tuple[int, int], segments: list[np.ndarray], downsample_ratio: int = 1
) -> tuple[np.ndarray, np.ndarray]:
    """Rasterize segments into a single overlap mask whose pixels hold 1-based instance ids.

    Segments are painted in descending order of mask area, so where instances overlap the smaller one (painted later)
    wins. A pixel value ``k > 0`` refers to ``segments[index[k - 1]]``; 0 is background.

    Args:
        imgsz (tuple[int, int]): Image size as (height, width).
        segments (list[np.ndarray]): Polygon coordinates per instance; each is flattened before rasterization.
        downsample_ratio (int, optional): Factor by which the mask height and width are divided.

    Returns:
        masks (np.ndarray): Overlap mask of shape (height // downsample_ratio, width // downsample_ratio); uint8 for
            up to 255 segments, int32 otherwise.
        index (np.ndarray): Painting order as indices into ``segments`` (largest area first).
    """
    masks = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        dtype=np.int32 if len(segments) > 255 else np.uint8,  # ids up to len(segments) must fit
    )
    areas = []
    ms = []
    for segment in segments:
        mask = polygon2mask(
            imgsz,
            [segment.reshape(-1)],
            downsample_ratio=downsample_ratio,
            color=1,
        )
        ms.append(mask.astype(masks.dtype))
        areas.append(mask.sum())
    areas = np.asarray(areas)
    # NOTE: `areas` has an unsigned dtype (sums of a uint8 mask), so `-areas` wraps. Larger areas still sort first,
    # while zero-area segments sort to the front; they paint no pixels.
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    # Paint by direct assignment. The former `masks + ms[i] * (i + 1)` followed by a clip wrapped around in uint8
    # once ids reached ~128+ on overlapping pixels, producing wrong instance ids.
    for instance_id, m in enumerate(ms, start=1):
        masks[m > 0] = instance_id
    return masks, index
364
+
365
+
366
def find_dataset_yaml(path: Path) -> Path:
    """Locate the single dataset YAML for a Detect, Segment or Pose dataset directory.

    YAML files directly inside ``path`` are considered first. Only if there are none is the directory searched
    recursively. When several candidates exist, only those whose stem equals ``path.stem`` are kept.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.

    Raises:
        AssertionError: If no YAML file exists, or if the candidates cannot be narrowed down to exactly one.
    """
    candidates = list(path.glob("*.yaml"))
    if not candidates:
        candidates = list(path.rglob("*.yaml"))  # nothing at the root level, so look in subdirectories
    assert candidates, f"No YAML file found in '{path.resolve()}'"
    if len(candidates) > 1:
        candidates = [candidate for candidate in candidates if candidate.stem == path.stem]
    assert len(candidates) == 1, (
        f"Expected 1 YAML file in '{path.resolve()}', but found {len(candidates)}.\n{candidates}"
    )
    return candidates[0]
384
+
385
+
386
def _resolve_split_path(root: Path, split_path: str) -> str:
    """Resolve a split path against `root`, retrying without a leading '../' when the first result does not exist."""
    resolved = (root / split_path).resolve()
    if not resolved.exists() and str(split_path).startswith("../"):
        resolved = (root / str(split_path)[3:]).resolve()
    return str(resolved)


def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]:
    """Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found.

    Returns:
        (dict[str, Any]): Parsed dataset information and paths. 'train'/'val'/'test'/'minival' entries are resolved
            to absolute path strings (or lists of them), 'path' is the dataset root as a Path, and 'names' and 'nc'
            are always present.

    Raises:
        SyntaxError: If 'train' or 'val' (or 'validation') is missing, if both 'names' and 'nc' are missing, or if
            their lengths disagree.
        FileNotFoundError: If validation images are missing and no download can be attempted.

    Notes:
        A YAML 'download' value that is neither an http(s) '*.zip' URL nor a 'bash ...' command is executed as Python
        code with exec(). Only use dataset YAMLs from trusted sources.
    """
    file = check_file(dataset)

    # Download (optional): an archive is extracted into DATASETS_DIR and the YAML inside it is used
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        extract_dir, autodownload = file.parent, False

    # Read YAML
    data = YAML.load(file, append_filename=True)  # dictionary

    # Checks
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            LOGGER.warning("renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
    if "names" not in data and "nc" not in data:
        raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]
    else:
        data["nc"] = len(data["names"])

    data["names"] = check_class_names(data["names"])
    data["channels"] = data.get("channels", 3)  # get image channels, default to 3

    # Resolve paths: extracted archive dir, then the YAML 'path' key, then the YAML file's own directory
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.exists() and not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()  # path relative to DATASETS_DIR

    # Set paths. String and list entries get the same resolution, including the '../' fallback.
    data["path"] = path  # download scripts
    for k in "train", "val", "test", "minival":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                data[k] = _resolve_split_path(path, data[k])
            else:
                data[k] = [_resolve_split_path(path, x) for x in data[k]]

    # Parse YAML
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            LOGGER.info("")
            m = f"Dataset '{name}' images not found, missing path '{next(x for x in val if not x.exists())}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'"
                raise FileNotFoundError(m)
            t = time.time()
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                subprocess.run(s.split(), check=True)  # check=True raises on a non-zero exit status
            else:  # python script
                # SECURITY: runs arbitrary Python taken from the YAML 'download' field; trusted YAMLs only.
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            # All failure modes above raise, so reaching this line means the download step completed.
            LOGGER.info(f"Dataset download success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}\n")
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # dictionary
478
+
479
+
480
def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
    """Check a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information. If the
    dataset is not found locally, it attempts to download the dataset from the internet and save it locally. If there
    is no 'train' directory but images exist, it attempts to split them with `split_classify_dataset`
    (train_ratio=0.8).

    Args:
        dataset (str | Path): The dataset name, a local directory, an archive (.zip/.tar/.gz), or an http(s) URL.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''.

    Returns:
        (dict[str, Any]): A dictionary containing the following keys:

            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path | None): The directory path containing the validation set of the dataset.
            - 'test' (Path | None): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict[int, str]): A dictionary of class names in the dataset.
            - 'channels' (int): Number of image channels (always 3).

    Raises:
        ValueError: If the dataset resolves to a file path rather than a directory.
        FileNotFoundError: If the train split has no images, including when the train directory does not exist.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif str(dataset).endswith((".zip", ".tar", ".gz")):
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        if data_dir.suffix != "":
            raise ValueError(
                f'Classification datasets must be a directory (data="path/to/dir") not a file (data="{dataset}"), '
                "See https://docs.ultralytics.com/datasets/classify/"
            )
        LOGGER.info("")
        LOGGER.warning(f"Dataset not found, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(["bash", str(ROOT / "data/scripts/get_imagenet.sh")], check=True)
        else:
            download(f"{ASSETS_URL}/{dataset}.zip", dir=data_dir.parent)
        LOGGER.info(f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n")
    train_set = data_dir / "train"
    if not train_set.is_dir():
        LOGGER.warning(f"Dataset 'split=train' not found at {train_set}")
        if image_files := list(data_dir.rglob("*.jpg")) + list(data_dir.rglob("*.png")):
            from ultralytics.data.split import split_classify_dataset

            LOGGER.info(f"Found {len(image_files)} images in subdirectories. Attempting to split...")
            data_dir = split_classify_dataset(data_dir, train_ratio=0.8)
            train_set = data_dir / "train"
        else:
            LOGGER.error(f"No images found in {data_dir} or its subdirectories.")
    # The first existing directory among val/validation/valid is the validation split, else None
    val_set = next((data_dir / d for d in ("val", "validation", "valid") if (data_dir / d).exists()), None)
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
    if split == "val" and not val_set:
        LOGGER.warning("Dataset 'split=val' not found, using 'split=test' instead.")
        val_set = test_set
    elif split == "test" and not test_set:
        LOGGER.warning("Dataset 'split=test' not found, using 'split=val' instead.")
        test_set = val_set

    # Class names are the sub-directories of the train split, scanned once. A missing train directory yields no
    # classes, so the explicit "no training images found" error below is raised instead of a bare iterdir() error.
    class_dirs = sorted(x.name for x in train_set.iterdir() if x.is_dir()) if train_set.is_dir() else []
    nc = len(class_dirs)  # number of classes
    names = dict(enumerate(class_dirs))

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f"{colorstr(f'{k}:')} {v}..."
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(f"{dataset} '{k}:' no training images found")
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes (no images found)")
            elif nd != nc:
                LOGGER.error(f"{prefix} found {nf} images in {nd} classes (requires {nc} classes, not {nd})")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names, "channels": 3}
574
+
575
+
576
class HUBDatasetStats:
    """A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'obb', 'classify'.
        autodownload (bool): Attempt to download dataset if not found locally.

    Attributes:
        task (str): Dataset task type.
        hub_dir (Path): Directory path for HUB dataset files ('<dataset root>-hub').
        im_dir (Path): Directory path for compressed images.
        stats (dict): Statistics dictionary containing dataset information.
        data (dict): Dataset configuration data.

    Methods:
        get_json: Return dataset JSON for Ultralytics HUB.
        process_images: Compress images for Ultralytics HUB.

    Examples:
        >>> from ultralytics.data.utils import HUBDatasetStats
        >>> stats = HUBDatasetStats("path/to/coco8.zip", task="detect")  # detect dataset
        >>> stats = HUBDatasetStats("path/to/coco8-seg.zip", task="segment")  # segment dataset
        >>> stats = HUBDatasetStats("path/to/coco8-pose.zip", task="pose")  # pose dataset
        >>> stats = HUBDatasetStats("path/to/dota8.zip", task="obb")  # OBB dataset
        >>> stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify")  # classification dataset
        >>> stats.get_json(save=True)
        >>> stats.process_images()

    Notes:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
    """

    def __init__(self, path: str = "coco8.yaml", task: str = "detect", autodownload: bool = False):
        """Initialize class: resolve and validate the dataset and seed the statistics dictionary.

        For non-classify tasks, the dataset YAML is rewritten on disk with 'path' set to '' before it is checked.

        Raises:
            Exception: 'error/HUB/dataset_stats/init', chained to the original error, if loading a non-classify
                dataset fails.
        """
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}....")

        self.task = task  # detect, segment, pose, classify, obb
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose, obb
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = YAML.load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                YAML.save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload)  # dict
                # For a zip, use the extracted directory. A plain YAML gives data_dir=None; keep the root resolved by
                # check_det_dataset rather than storing None, which would turn hub_dir into the literal 'None-hub'.
                if data_dir is not None:
                    data["path"] = data_dir
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f"{data['path']}-hub")
        self.im_dir = self.hub_dir / "images"
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path: Path) -> tuple[bool, str | None, Path]:
        """Unzip data.zip next to itself and locate its YAML.

        Returns:
            (tuple[bool, str | None, Path]): (zipped, data_dir, yaml_path). A non-zip `path` is treated as the YAML
                itself and returns (False, None, path).
        """
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f: str):
        """Save a compressed copy of image `f` into `self.im_dir` under the same file name (HUB previews)."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save: bool = False, verbose: bool = False) -> dict:
        """Return dataset JSON for Ultralytics HUB.

        Args:
            save (bool): Write the statistics to `self.hub_dir / "stats.json"`.
            verbose (bool): Log the statistics as indented JSON.

        Returns:
            (dict): `self.stats`, with a 'train'/'val'/'test' entry per split (None for missing or empty splits).
        """

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task in {"segment", "obb"}:  # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n, nk, nd = labels["keypoints"].shape
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
            else:
                raise ValueError(f"Undefined dataset task={self.task}.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder  # scope for faster 'import ultralytics'

                dataset = ImageFolder(self.data[split])

                x = np.zeros(len(dataset.classes), dtype=int)  # images per class
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                counts = [
                    np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                    for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                ]
                # One row of per-class instance counts per image. Keep a 2-D (0, nc) shape when there are no labels
                # so the axis-based reductions below still work.
                x = np.array(counts) if counts else np.zeros((0, self.data["nc"]), dtype=int)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            self.hub_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w", encoding="utf-8") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self) -> Path:
        """Compress every image of each available split into `self.im_dir` using a thread pool.

        Returns:
            (Path): The directory the compressed images were written to.
        """
        from ultralytics.data import YOLODataset  # ClassificationDataset

        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir
740
+
741
+
742
def compress_one_image(f: str | Path, f_new: str | Path | None = None, max_dim: int = 1920, quality: int = 50):
    """Compress a single image file to reduced size while preserving its aspect ratio using either the Python Imaging
    Library (PIL) or OpenCV. If the longest side is not larger than `max_dim`, the image is not resized.

    PIL writes the result as JPEG with the given `quality`. If anything in the PIL path fails, a warning is logged and
    OpenCV re-reads, resizes and writes the image instead (without applying `quality`). Note that this function sets
    `Image.MAX_IMAGE_PIXELS = None` globally, disabling Pillow's decompression-bomb limit for the whole process.

    Args:
        f (str | Path): The path to the input image file.
        f_new (str | Path, optional): The path to the output image file. If not specified, the input file will be
            overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image.
        quality (int, optional): The image compression quality as a percentage (PIL path only).

    Examples:
        >>> from pathlib import Path
        >>> from ultralytics.data.utils import compress_one_image
        >>> for f in Path("path/to/dataset").rglob("*.jpg"):
        >>>     compress_one_image(f)
    """
    try:  # use PIL
        Image.MAX_IMAGE_PIXELS = None  # Fix DecompressionBombError, allow optimization of image > ~178.9 million pixels
        # Context manager closes the source file even if saving fails, before the OpenCV fallback may overwrite it
        with Image.open(f) as im:
            if im.mode in {"RGBA", "LA", "P", "PA"}:  # JPEG cannot store alpha or palette modes, convert to RGB
                im = im.convert("RGB")
            r = max_dim / max(im.height, im.width)  # ratio
            if r < 1.0:  # image too large
                im = im.resize((int(im.width * r), int(im.height * r)))
            im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
    except Exception as e:  # use OpenCV
        LOGGER.warning(f"HUB ops PIL failure {f}: {e}")
        im = cv2.imread(str(f))  # pass a str: f may be a pathlib.Path, which older OpenCV bindings reject
        im_height, im_width = im.shape[:2]
        r = max_dim / max(im_height, im_width)  # ratio
        if r < 1.0:  # image too large
            im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)
776
+
777
+
778
def load_dataset_cache_file(path: Path) -> dict:
    """Load an Ultralytics *.cache dictionary from path.

    Garbage collection is paused while unpickling to reduce load time
    (https://github.com/ultralytics/ultralytics/pull/1585). It is restored to its previous state afterwards, even
    when loading fails.

    Args:
        path (Path): Path to a *.cache file, e.g. one written by `save_dataset_cache_file`.

    Returns:
        (dict): The cached dictionary.

    Notes:
        The file is loaded with `allow_pickle=True`, which can execute arbitrary code; only load cache files from
        trusted locations.
    """
    import gc

    gc_was_enabled = gc.isenabled()
    gc.disable()  # reduce pickle load time
    try:
        return np.load(str(path), allow_pickle=True).item()  # 0-d object array -> dict
    finally:
        # Without finally, a failed load (missing/corrupt file) would leave gc disabled for the whole process
        if gc_was_enabled:
            gc.enable()
786
+
787
+
788
def save_dataset_cache_file(prefix: str, path: Path, x: dict, version: str):
    """Persist a dataset cache dictionary to `path` with `np.save`.

    `version` is first stored in `x` under the "version" key, which mutates the caller's dict. If the target directory
    is not writable, only a warning is logged and nothing is written.

    Args:
        prefix (str): Prefix prepended to the log messages.
        path (Path): Destination *.cache file.
        x (dict): Cache dictionary to save (modified in place).
        version (str): Cache version recorded in `x`.
    """
    x["version"] = version
    if not is_dir_writeable(path.parent):
        LOGGER.warning(f"{prefix}Cache directory {path.parent} is not writable, cache not saved.")
        return
    if path.exists():
        path.unlink()  # drop any stale cache first
    # Saving through an open handle keeps the exact *.cache name (np.save appends .npy to bare paths). The upstream
    # note also credits the context manager with fixing an async np.save issue on Windows.
    with open(str(path), "wb") as handle:
        np.save(handle, x)
    LOGGER.info(f"{prefix}New cache created: {path}")