dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/data/utils.py CHANGED
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 import json
 import os
 import random
@@ -9,6 +11,7 @@ import zipfile
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
 from tarfile import is_tarfile
+from typing import Any
 
 import cv2
 import numpy as np
@@ -16,9 +19,9 @@ from PIL import Image, ImageOps
 
 from ultralytics.nn.autobackend import check_class_names
 from ultralytics.utils import (
+    ASSETS_URL,
     DATASETS_DIR,
     LOGGER,
-    MACOS,
     NUM_THREADS,
     ROOT,
     SETTINGS_FILE,
@@ -36,25 +39,25 @@ from ultralytics.utils.ops import segments2boxes
 HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
 IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}  # image suffixes
 VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
-PIN_MEMORY = str(os.getenv("PIN_MEMORY", not MACOS)).lower() == "true"  # global pin_memory for dataloaders
 FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
 
 
-def img2label_paths(img_paths):
-    """Define label paths as a function of image paths."""
+def img2label_paths(img_paths: list[str]) -> list[str]:
+    """Convert image paths to label paths by replacing 'images' with 'labels' and extension with '.txt'."""
     sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
     return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
 
 
-def check_file_speeds(files, threshold_ms=10, threshold_mb=50, max_files=5, prefix=""):
-    """
-    Check dataset file access speed and provide performance feedback.
+def check_file_speeds(
+    files: list[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = ""
+):
+    """Check dataset file access speed and provide performance feedback.
 
-    This function tests the access speed of dataset files by measuring ping (stat call) time and read speed.
-    It samples up to 5 files from the provided list and warns if access times exceed the threshold.
+    This function tests the access speed of dataset files by measuring ping (stat call) time and read speed. It samples
+    up to 5 files from the provided list and warns if access times exceed the threshold.
 
     Args:
-        files (list): List of file paths to check for access speed.
+        files (list[str]): List of file paths to check for access speed.
         threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.
         threshold_mb (float, optional): Threshold in megabytes per second for read speed warnings.
        max_files (int, optional): The maximum number of files to check.
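The `img2label_paths` convention above pairs each image with a label file by swapping the last `/images/` path component for `/labels/` and the file extension for `.txt`. A minimal sketch of that mapping, assuming a POSIX filesystem (the example path is hypothetical):

```python
import os


def img2label_paths(img_paths):
    """Swap the last /images/ component for /labels/ and the suffix for .txt."""
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]


print(img2label_paths(["/data/coco/images/train/0001.jpg"]))
# ['/data/coco/labels/train/0001.txt']
```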
@@ -65,7 +68,7 @@ def check_file_speeds(files, threshold_ms=10, threshold_mb=50, max_files=5, prefix=""):
         >>> image_files = list(Path("dataset/images").glob("*.jpg"))
         >>> check_file_speeds(image_files, threshold_ms=15)
     """
-    if not files or len(files) == 0:
+    if not files:
         LOGGER.warning(f"{prefix}Image speed checks: No files to check")
         return
 
@@ -122,8 +125,8 @@
     )
 
 
-def get_hash(paths):
-    """Returns a single hash value of a list of paths (files or dirs)."""
+def get_hash(paths: list[str]) -> str:
+    """Return a single hash value of a list of paths (files or dirs)."""
     size = 0
     for p in paths:
         try:
@@ -135,8 +138,8 @@ def get_hash(paths):
     return h.hexdigest()  # return hash
 
 
-def exif_size(img: Image.Image):
-    """Returns exif-corrected PIL size."""
+def exif_size(img: Image.Image) -> tuple[int, int]:
+    """Return exif-corrected PIL size."""
     s = img.size  # (width, height)
     if img.format == "JPEG":  # only support JPEG images
         try:
@@ -149,7 +152,7 @@
     return s
 
 
-def verify_image(args):
+def verify_image(args: tuple) -> tuple:
     """Verify one image."""
     (im_file, cls), prefix = args
     # Number (found, corrupt), message
@@ -174,7 +177,7 @@
     return (im_file, cls), nf, nc, msg
 
 
-def verify_image_label(args):
+def verify_image_label(args: tuple) -> list:
     """Verify one image-label pair."""
     im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls = args
     # Number (missing, found, empty, corrupt), message, segments, keypoints
@@ -211,13 +214,12 @@
             else:
                 assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                 points = lb[:, 1:]
-            assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
-            assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"
+            # Coordinate points check with 1% tolerance
+            assert points.max() <= 1.01, f"non-normalized or out of bounds coordinates {points[points > 1.01]}"
+            assert lb.min() >= -0.01, f"negative class labels or coordinate {lb[lb < -0.01]}"
 
             # All labels
-            if single_cls:
-                lb[:, 0] = 0
-            max_cls = lb[:, 0].max()  # max label count
+            max_cls = 0 if single_cls else lb[:, 0].max()  # max label count
             assert max_cls < num_cls, (
                 f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                 f"Possible class labels are 0-{num_cls - 1}"
@@ -233,7 +235,7 @@
             lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
         else:
             nm = 1  # label missing
-            lb = np.zeros((0, (5 + nkpt * ndim) if keypoints else 5), dtype=np.float32)
+            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
         if keypoint:
             keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
             if ndim == 2:
@@ -247,19 +249,18 @@
         return [None, None, None, None, None, nm, nf, ne, nc, msg]
 
 
-def visualize_image_annotations(image_path, txt_path, label_map):
-    """
-    Visualizes YOLO annotations (bounding boxes and class labels) on an image.
+def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[int, str]):
+    """Visualize YOLO annotations (bounding boxes and class labels) on an image.
 
-    This function reads an image and its corresponding annotation file in YOLO format, then
-    draws bounding boxes around detected objects and labels them with their respective class names.
-    The bounding box colors are assigned based on the class ID, and the text color is dynamically
-    adjusted for readability, depending on the background color's luminance.
+    This function reads an image and its corresponding annotation file in YOLO format, then draws bounding boxes around
+    detected objects and labels them with their respective class names. The bounding box colors are assigned based on
+    the class ID, and the text color is dynamically adjusted for readability, depending on the background color's
+    luminance.
 
     Args:
         image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL.
         txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object.
-        label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings).
+        label_map (dict[int, str]): A dictionary that maps class IDs (integers) to class labels (strings).
 
     Examples:
         >>> label_map = {0: "cat", 1: "dog", 2: "bird"}  # It should include all annotated classes details
@@ -280,7 +281,7 @@ def visualize_image_annotations(image_path, txt_path, label_map):
             w = width * img_width
             h = height * img_height
             annotations.append((x, y, w, h, int(class_id)))
-    fig, ax = plt.subplots(1)  # Plot the image and annotations
+    _, ax = plt.subplots(1)  # Plot the image and annotations
     for x, y, w, h, label in annotations:
         color = tuple(c / 255 for c in colors(label, True))  # Get and normalize the RGB color
         rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none")  # Create a rectangle
@@ -291,14 +292,15 @@
     plt.show()
 
 
-def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
-    """
-    Convert a list of polygons to a binary mask of the specified image size.
+def polygon2mask(
+    imgsz: tuple[int, int], polygons: list[np.ndarray], color: int = 1, downsample_ratio: int = 1
+) -> np.ndarray:
+    """Convert a list of polygons to a binary mask of the specified image size.
 
     Args:
-        imgsz (tuple): The size of the image as (height, width).
-        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
-            N is the number of polygons, and M is the number of points such that M % 2 = 0.
+        imgsz (tuple[int, int]): The size of the image as (height, width).
+        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
+            number of polygons, and M is the number of points such that M % 2 = 0.
         color (int, optional): The color value to fill in the polygons on the mask.
         downsample_ratio (int, optional): Factor by which to downsample the mask.
@@ -314,14 +316,15 @@ def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
     return cv2.resize(mask, (nw, nh))
 
 
-def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
-    """
-    Convert a list of polygons to a set of binary masks of the specified image size.
+def polygons2masks(
+    imgsz: tuple[int, int], polygons: list[np.ndarray], color: int, downsample_ratio: int = 1
+) -> np.ndarray:
+    """Convert a list of polygons to a set of binary masks of the specified image size.
 
     Args:
-        imgsz (tuple): The size of the image as (height, width).
-        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where
-            N is the number of polygons, and M is the number of points such that M % 2 = 0.
+        imgsz (tuple[int, int]): The size of the image as (height, width).
+        polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
+            number of polygons, and M is the number of points such that M % 2 = 0.
         color (int): The color value to fill in the polygons on the masks.
         downsample_ratio (int, optional): Factor by which to downsample each mask.
@@ -331,7 +334,9 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
     return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])
 
 
-def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
+def polygons2masks_overlap(
+    imgsz: tuple[int, int], segments: list[np.ndarray], downsample_ratio: int = 1
+) -> tuple[np.ndarray, np.ndarray]:
     """Return a (640, 640) overlap mask."""
     masks = np.zeros(
         (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
@@ -339,8 +344,13 @@
     )
     areas = []
     ms = []
-    for si in range(len(segments)):
-        mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
+    for segment in segments:
+        mask = polygon2mask(
+            imgsz,
+            [segment.reshape(-1)],
+            downsample_ratio=downsample_ratio,
+            color=1,
+        )
         ms.append(mask.astype(masks.dtype))
         areas.append(mask.sum())
     areas = np.asarray(areas)
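For context, `polygon2mask` rasterizes a flattened `(x1, y1, x2, y2, ...)` polygon onto a zeroed canvas, which is what the refactored loop above now calls once per segment. A standalone sketch of the same `cv2.fillPoly` pattern (sizes and coordinates are illustrative, not the library call itself):

```python
import cv2
import numpy as np

imgsz = (8, 8)  # (height, width)
polygon = np.array([1, 1, 6, 1, 6, 6, 1, 6], dtype=np.float32)  # flattened x,y pairs

mask = np.zeros(imgsz, dtype=np.uint8)
pts = polygon.reshape(-1, 2).astype(np.int32)  # fillPoly expects integer (N, 2) points
cv2.fillPoly(mask, [pts], color=1)  # fill the polygon region with 1s
print(mask.sum())  # count of filled pixels
```

In the overlap variant, the per-segment masks collected here are then merged into a single array of instance indices, ordered by area so that smaller objects stay visible where they overlap larger ones.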
@@ -354,8 +364,7 @@
 
 
 def find_dataset_yaml(path: Path) -> Path:
-    """
-    Find and return the YAML file associated with a Detect, Segment or Pose dataset.
+    """Find and return the YAML file associated with a Detect, Segment or Pose dataset.
 
     This function searches for a YAML file at the root level of the provided directory first, and if not found, it
     performs a recursive search. It prefers YAML files that have the same stem as the provided path.
@@ -374,9 +383,8 @@
     return files[0]
 
 
-def check_det_dataset(dataset, autodownload=True):
-    """
-    Download, verify, and/or unzip a dataset if not found locally.
+def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]:
+    """Download, verify, and/or unzip a dataset if not found locally.
 
     This function checks the availability of a specified dataset, and if not found, it has the option to download and
     unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
@@ -387,7 +395,7 @@
         autodownload (bool, optional): Whether to automatically download the dataset if not found.
 
     Returns:
-        (dict): Parsed dataset information and paths.
+        (dict[str, Any]): Parsed dataset information and paths.
     """
     file = check_file(dataset)
 
@@ -446,7 +454,7 @@
     if not all(x.exists() for x in val):
         name = clean_url(dataset)  # dataset name with URL auth stripped
         LOGGER.info("")
-        m = f"Dataset '{name}' images not found, missing path '{[x for x in val if not x.exists()][0]}'"
+        m = f"Dataset '{name}' images not found, missing path '{next(x for x in val if not x.exists())}'"
         if s and autodownload:
             LOGGER.warning(m)
         else:
@@ -458,7 +466,7 @@
             safe_download(url=s, dir=DATASETS_DIR, delete=True)
         elif s.startswith("bash "):  # bash script
             LOGGER.info(f"Running {s} ...")
-            r = os.system(s)
+            subprocess.run(s.split(), check=True)
         else:  # python script
             exec(s, {"yaml": data})
     dt = f"({round(time.time() - t, 1)}s)"
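The `os.system(s)` → `subprocess.run(s.split(), check=True)` swap above runs the download script without an intermediate shell, and `check=True` raises on a non-zero exit status instead of returning an ignorable code. A minimal sketch of the pattern (the script path is hypothetical):

```python
import subprocess

s = "bash ./scripts/get_dataset.sh"  # hypothetical download command
try:
    subprocess.run(s.split(), check=True)  # raises CalledProcessError on non-zero exit
except (subprocess.CalledProcessError, FileNotFoundError) as e:
    print(f"download failed: {e}")
```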
@@ -469,25 +477,24 @@
     return data  # dictionary
 
 
-def check_cls_dataset(dataset, split=""):
-    """
-    Checks a classification dataset such as Imagenet.
+def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
+    """Check a classification dataset such as Imagenet.
 
-    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
-    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
+    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information. If the
+    dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
 
     Args:
         dataset (str | Path): The name of the dataset.
         split (str, optional): The split of the dataset. Either 'val', 'test', or ''.
 
     Returns:
-        (dict): A dictionary containing the following keys:
+        (dict[str, Any]): A dictionary containing the following keys:
 
             - 'train' (Path): The directory path containing the training set of the dataset.
             - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
-            - 'names' (dict): A dictionary of class names in the dataset.
+            - 'names' (dict[int, str]): A dictionary of class names in the dataset.
     """
     # Download (optional if dataset=https://file.zip is passed directly)
     if str(dataset).startswith(("http:/", "https:/")):
@@ -499,20 +506,23 @@
     dataset = Path(dataset)
     data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
     if not data_dir.is_dir():
+        if data_dir.suffix != "":
+            raise ValueError(
+                f'Classification datasets must be a directory (data="path/to/dir") not a file (data="{dataset}"), '
+                "See https://docs.ultralytics.com/datasets/classify/"
+            )
         LOGGER.info("")
         LOGGER.warning(f"Dataset not found, missing path {data_dir}, attempting download...")
         t = time.time()
         if str(dataset) == "imagenet":
-            subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
+            subprocess.run(["bash", str(ROOT / "data/scripts/get_imagenet.sh")], check=True)
         else:
-            url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
-            download(url, dir=data_dir.parent)
+            download(f"{ASSETS_URL}/{dataset}.zip", dir=data_dir.parent)
         LOGGER.info(f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n")
     train_set = data_dir / "train"
     if not train_set.is_dir():
         LOGGER.warning(f"Dataset 'split=train' not found at {train_set}")
-        image_files = list(data_dir.rglob("*.jpg")) + list(data_dir.rglob("*.png"))
-        if image_files:
+        if image_files := list(data_dir.rglob("*.jpg")) + list(data_dir.rglob("*.png")):
             from ultralytics.data.split import split_classify_dataset
 
             LOGGER.info(f"Found {len(image_files)} images in subdirectories. Attempting to split...")
@@ -525,6 +535,8 @@
         if (data_dir / "val").exists()
         else data_dir / "validation"
         if (data_dir / "validation").exists()
+        else data_dir / "valid"
+        if (data_dir / "valid").exists()
         else None
     )  # data/test or data/val
     test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
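With the added branch, the validation split now resolves to the first of `val`, `validation`, or `valid` that exists. The chained conditional expression above is equivalent to a short generator scan, sketched here (the directory name is hypothetical):

```python
from pathlib import Path

data_dir = Path("datasets/my-classify-data")  # hypothetical dataset root
val_set = next((data_dir / n for n in ("val", "validation", "valid") if (data_dir / n).exists()), None)
print(val_set)  # first existing split directory, or None
```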
@@ -562,17 +574,23 @@
 
 
 class HUBDatasetStats:
-    """
-    A class for generating HUB dataset JSON and `-hub` dataset directory.
+    """A class for generating HUB dataset JSON and `-hub` dataset directory.
 
     Args:
-        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'.
-        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
-        autodownload (bool): Attempt to download dataset if not found locally. Default is False.
+        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
+        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'.
+        autodownload (bool): Attempt to download dataset if not found locally.
 
-    Note:
-        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
-        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
+    Attributes:
+        task (str): Dataset task type.
+        hub_dir (Path): Directory path for HUB dataset files.
+        im_dir (Path): Directory path for compressed images.
+        stats (dict): Statistics dictionary containing dataset information.
+        data (dict): Dataset configuration data.
+
+    Methods:
+        get_json: Return dataset JSON for Ultralytics HUB.
+        process_images: Compress images for Ultralytics HUB.
 
     Examples:
         >>> from ultralytics.data.utils import HUBDatasetStats
@@ -583,9 +601,13 @@
         >>> stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify")  # classification dataset
         >>> stats.get_json(save=True)
         >>> stats.process_images()
+
+    Notes:
+        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
+        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
     """
 
-    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
+    def __init__(self, path: str = "coco8.yaml", task: str = "detect", autodownload: bool = False):
         """Initialize class."""
         path = Path(path).resolve()
         LOGGER.info(f"Starting HUB dataset checks for {path}....")
@@ -613,7 +635,7 @@
         self.data = data
 
     @staticmethod
-    def _unzip(path):
+    def _unzip(path: Path) -> tuple[bool, str, Path]:
         """Unzip data.zip."""
         if not str(path).endswith(".zip"):  # path is data.yaml
             return False, None, path
@@ -623,11 +645,11 @@
         )
         return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path
 
-    def _hub_ops(self, f):
-        """Saves a compressed image for HUB previews."""
+    def _hub_ops(self, f: str):
+        """Save a compressed image for HUB previews."""
         compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub
 
-    def get_json(self, save=False, verbose=False):
+    def get_json(self, save: bool = False, verbose: bool = False) -> dict:
         """Return dataset JSON for Ultralytics HUB."""
 
         def _round(labels):
@@ -701,7 +723,7 @@
             LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
         return self.stats
 
-    def process_images(self):
+    def process_images(self) -> Path:
         """Compress images for Ultralytics HUB."""
         from ultralytics.data import YOLODataset  # ClassificationDataset
 
@@ -717,11 +739,10 @@
         return self.im_dir
 
 
-def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
-    """
-    Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python
-    Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
-    resized.
+def compress_one_image(f: str, f_new: str | None = None, max_dim: int = 1920, quality: int = 50):
+    """Compress a single image file to reduced size while preserving its aspect ratio and quality using either the
+    Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it
+    will not be resized.
 
     Args:
         f (str): The path to the input image file.
@@ -754,7 +775,7 @@
     cv2.imwrite(str(f_new or f), im)
 
 
-def load_dataset_cache_file(path):
+def load_dataset_cache_file(path: Path) -> dict:
     """Load an Ultralytics *.cache dictionary from path."""
     import gc
 
@@ -764,7 +785,7 @@
     return cache
 
 
-def save_dataset_cache_file(prefix, path, x, version):
+def save_dataset_cache_file(prefix: str, path: Path, x: dict, version: str):
     """Save an Ultralytics dataset *.cache dictionary x to path."""
     x["version"] = version  # add cache version
     if is_dir_writeable(path.parent):
@@ -774,4 +795,4 @@
         np.save(file, x)
         LOGGER.info(f"{prefix}New cache created: {path}")
     else:
-        LOGGER.warning(f"{prefix}Cache directory {path.parent} is not writeable, cache not saved.")
+        LOGGER.warning(f"{prefix}Cache directory {path.parent} is not writable, cache not saved.")
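The two cache helpers round-trip a plain Python dict through NumPy's binary format, as the `np.save(file, x)` call above shows. A minimal sketch of the same pattern (file name and contents are illustrative):

```python
from pathlib import Path

import numpy as np

path = Path("labels.cache")
x = {"version": "1.0.3", "hash": "abc123", "results": (3, 0, 0, 0)}

with open(path, "wb") as f:  # writing to an open file stops np.save appending ".npy"
    np.save(f, x)

cache = np.load(path, allow_pickle=True).item()  # a pickled dict needs allow_pickle=True
print(cache["version"])
```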