ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
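
For orientation before reading the per-file diff below, a quick local check of which of the two builds is installed (a minimal sketch, assuming the wheel was installed from PyPI under the name shown above):

    from importlib.metadata import version

    print(version("ultralytics"))  # expected "8.0.238" or "8.0.239", matching the wheels compared here
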
ultralytics/data/utils.py CHANGED
@@ -17,36 +17,47 @@ import numpy as np
 from PIL import Image, ImageOps
 
 from ultralytics.nn.autobackend import check_class_names
-from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, TQDM, clean_url, colorstr,
-                               emojis, yaml_load, yaml_save)
+from ultralytics.utils import (
+    DATASETS_DIR,
+    LOGGER,
+    NUM_THREADS,
+    ROOT,
+    SETTINGS_YAML,
+    TQDM,
+    clean_url,
+    colorstr,
+    emojis,
+    yaml_load,
+    yaml_save,
+)
 from ultralytics.utils.checks import check_file, check_font, is_ascii
 from ultralytics.utils.downloads import download, safe_download, unzip_file
 from ultralytics.utils.ops import segments2boxes
 
-HELP_URL = 'See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance.'
-IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # image suffixes
-VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm'  # video suffixes
-PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders
+HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
+IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # image suffixes
+VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"  # video suffixes
+PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
 
 
 def img2label_paths(img_paths):
     """Define label paths as a function of image paths."""
-    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
-    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
+    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
+    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
 
 
 def get_hash(paths):
     """Returns a single hash value of a list of paths (files or dirs)."""
     size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
     h = hashlib.sha256(str(size).encode())  # hash sizes
-    h.update(''.join(paths).encode())  # hash paths
+    h.update("".join(paths).encode())  # hash paths
     return h.hexdigest()  # return hash
 
 
 def exif_size(img: Image.Image):
     """Returns exif-corrected PIL size."""
     s = img.size  # (width, height)
-    if img.format == 'JPEG':  # only support JPEG images
+    if img.format == "JPEG":  # only support JPEG images
         with contextlib.suppress(Exception):
             exif = img.getexif()
             if exif:
@@ -60,24 +71,24 @@ def verify_image(args):
     """Verify one image."""
     (im_file, cls), prefix = args
     # Number (found, corrupt), message
-    nf, nc, msg = 0, 0, ''
+    nf, nc, msg = 0, 0, ""
     try:
         im = Image.open(im_file)
         im.verify()  # PIL verify
         shape = exif_size(im)  # image size
         shape = (shape[1], shape[0])  # hw
-        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
-        if im.format.lower() in ('jpg', 'jpeg'):
-            with open(im_file, 'rb') as f:
+        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+        if im.format.lower() in ("jpg", "jpeg"):
+            with open(im_file, "rb") as f:
                 f.seek(-2, 2)
-                if f.read() != b'\xff\xd9':  # corrupt JPEG
-                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
-                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+                if f.read() != b"\xff\xd9":  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
         nf = 1
     except Exception as e:
         nc = 1
-        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
     return (im_file, cls), nf, nc, msg
 
 
@@ -85,21 +96,21 @@ def verify_image_label(args):
     """Verify one image-label pair."""
     im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
     # Number (missing, found, empty, corrupt), message, segments, keypoints
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
     try:
         # Verify images
         im = Image.open(im_file)
         im.verify()  # PIL verify
         shape = exif_size(im)  # image size
         shape = (shape[1], shape[0])  # hw
-        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
-        if im.format.lower() in ('jpg', 'jpeg'):
-            with open(im_file, 'rb') as f:
+        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+        if im.format.lower() in ("jpg", "jpeg"):
+            with open(im_file, "rb") as f:
                 f.seek(-2, 2)
-                if f.read() != b'\xff\xd9':  # corrupt JPEG
-                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
-                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+                if f.read() != b"\xff\xd9":  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
 
         # Verify labels
         if os.path.isfile(lb_file):
@@ -114,25 +125,26 @@ def verify_image_label(args):
             nl = len(lb)
             if nl:
                 if keypoint:
-                    assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
+                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                     points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                 else:
-                    assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
+                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                     points = lb[:, 1:]
-                assert points.max() <= 1, f'non-normalized or out of bounds coordinates {points[points > 1]}'
-                assert lb.min() >= 0, f'negative label values {lb[lb < 0]}'
+                assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}"
+                assert lb.min() >= 0, f"negative label values {lb[lb < 0]}"
 
                 # All labels
                 max_cls = lb[:, 0].max()  # max label count
-                assert max_cls <= num_cls, \
-                    f'Label class {int(max_cls)} exceeds dataset class count {num_cls}. ' \
-                    f'Possible class labels are 0-{num_cls - 1}'
+                assert max_cls <= num_cls, (
+                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
+                    f"Possible class labels are 0-{num_cls - 1}"
+                )
                 _, i = np.unique(lb, axis=0, return_index=True)
                 if len(i) < nl:  # duplicate row check
                     lb = lb[i]  # remove duplicates
                     if segments:
                         segments = [segments[x] for x in i]
-                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
             else:
                 ne = 1  # label empty
                 lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
@@ -148,7 +160,7 @@ def verify_image_label(args):
         return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
     except Exception as e:
         nc = 1
-        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
         return [None, None, None, None, None, nm, nf, ne, nc, msg]
 
 
@@ -194,8 +206,10 @@ def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
 
 def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
     """Return a (640, 640) overlap mask."""
-    masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
-                     dtype=np.int32 if len(segments) > 255 else np.uint8)
+    masks = np.zeros(
+        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
+        dtype=np.int32 if len(segments) > 255 else np.uint8,
+    )
     areas = []
     ms = []
     for si in range(len(segments)):
@@ -226,7 +240,7 @@ def find_dataset_yaml(path: Path) -> Path:
     Returns:
         (Path): The path of the found YAML file.
     """
-    files = list(path.glob('*.yaml')) or list(path.rglob('*.yaml'))  # try root level first and then recursive
+    files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))  # try root level first and then recursive
     assert files, f"No YAML file found in '{path.resolve()}'"
     if len(files) > 1:
         files = [f for f in files if f.stem == path.stem]  # prefer *.yaml files that match
@@ -253,7 +267,7 @@ def check_det_dataset(dataset, autodownload=True):
     file = check_file(dataset)
 
     # Download (optional)
-    extract_dir = ''
+    extract_dir = ""
     if zipfile.is_zipfile(file) or is_tarfile(file):
         new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
         file = find_dataset_yaml(DATASETS_DIR / new_dir)
@@ -263,43 +277,44 @@ def check_det_dataset(dataset, autodownload=True):
     data = yaml_load(file, append_filename=True)  # dictionary
 
     # Checks
-    for k in 'train', 'val':
+    for k in "train", "val":
         if k not in data:
-            if k != 'val' or 'validation' not in data:
+            if k != "val" or "validation" not in data:
                 raise SyntaxError(
-                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
+                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
+                )
             LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.")
-            data['val'] = data.pop('validation')  # replace 'validation' key with 'val' key
-    if 'names' not in data and 'nc' not in data:
+            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
+    if "names" not in data and "nc" not in data:
         raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
-    if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
+    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
         raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
-    if 'names' not in data:
-        data['names'] = [f'class_{i}' for i in range(data['nc'])]
+    if "names" not in data:
+        data["names"] = [f"class_{i}" for i in range(data["nc"])]
     else:
-        data['nc'] = len(data['names'])
+        data["nc"] = len(data["names"])
 
-    data['names'] = check_class_names(data['names'])
+    data["names"] = check_class_names(data["names"])
 
     # Resolve paths
-    path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset root
+    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
     if not path.is_absolute():
         path = (DATASETS_DIR / path).resolve()
 
     # Set paths
-    data['path'] = path  # download scripts
-    for k in 'train', 'val', 'test':
+    data["path"] = path  # download scripts
+    for k in "train", "val", "test":
         if data.get(k):  # prepend path
             if isinstance(data[k], str):
                 x = (path / data[k]).resolve()
-                if not x.exists() and data[k].startswith('../'):
+                if not x.exists() and data[k].startswith("../"):
                     x = (path / data[k][3:]).resolve()
                 data[k] = str(x)
             else:
                 data[k] = [str((path / x).resolve()) for x in data[k]]
 
     # Parse YAML
-    val, s = (data.get(x) for x in ('val', 'download'))
+    val, s = (data.get(x) for x in ("val", "download"))
     if val:
         val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
         if not all(x.exists() for x in val):
@@ -312,22 +327,22 @@ def check_det_dataset(dataset, autodownload=True):
                 raise FileNotFoundError(m)
             t = time.time()
             r = None  # success
-            if s.startswith('http') and s.endswith('.zip'):  # URL
+            if s.startswith("http") and s.endswith(".zip"):  # URL
                 safe_download(url=s, dir=DATASETS_DIR, delete=True)
-            elif s.startswith('bash '):  # bash script
-                LOGGER.info(f'Running {s} ...')
+            elif s.startswith("bash "):  # bash script
+                LOGGER.info(f"Running {s} ...")
                 r = os.system(s)
             else:  # python script
-                exec(s, {'yaml': data})
-            dt = f'({round(time.time() - t, 1)}s)'
-            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
-            LOGGER.info(f'Dataset download {s}\n')
-    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fonts
+                exec(s, {"yaml": data})
+            dt = f"({round(time.time() - t, 1)}s)"
+            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
+            LOGGER.info(f"Dataset download {s}\n")
+    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts
 
     return data  # dictionary
 
 
-def check_cls_dataset(dataset, split=''):
+def check_cls_dataset(dataset, split=""):
     """
     Checks a classification dataset such as Imagenet.
 
@@ -348,54 +363,59 @@ def check_cls_dataset(dataset, split=''):
     """
 
     # Download (optional if dataset=https://file.zip is passed directly)
-    if str(dataset).startswith(('http:/', 'https:/')):
+    if str(dataset).startswith(("http:/", "https:/")):
         dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
 
     dataset = Path(dataset)
     data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
     if not data_dir.is_dir():
-        LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+        LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...")
         t = time.time()
-        if str(dataset) == 'imagenet':
+        if str(dataset) == "imagenet":
             subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
         else:
-            url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
+            url = f"https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip"
             download(url, dir=data_dir.parent)
         s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
         LOGGER.info(s)
-    train_set = data_dir / 'train'
-    val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if \
-        (data_dir / 'validation').exists() else None  # data/test or data/val
-    test_set = data_dir / 'test' if (data_dir / 'test').exists() else None  # data/val or data/test
-    if split == 'val' and not val_set:
+    train_set = data_dir / "train"
+    val_set = (
+        data_dir / "val"
+        if (data_dir / "val").exists()
+        else data_dir / "validation"
+        if (data_dir / "validation").exists()
+        else None
+    )  # data/test or data/val
+    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/val or data/test
+    if split == "val" and not val_set:
         LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
-    elif split == 'test' and not test_set:
+    elif split == "test" and not test_set:
         LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
 
-    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
-    names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
+    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
+    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
     names = dict(enumerate(sorted(names)))
 
     # Print to console
-    for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
+    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
         prefix = f'{colorstr(f"{k}:")} {v}...'
         if v is None:
             LOGGER.info(prefix)
         else:
-            files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
+            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
             nf = len(files)  # number of files
             nd = len({file.parent for file in files})  # number of directories
             if nf == 0:
-                if k == 'train':
+                if k == "train":
                     raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
                 else:
-                    LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
+                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found")
             elif nd != nc:
-                LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
+                LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}")
             else:
-                LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
+                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")
 
-    return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
+    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
 
 
 class HUBDatasetStats:
@@ -423,42 +443,43 @@ class HUBDatasetStats:
         ```
     """
 
-    def __init__(self, path='coco8.yaml', task='detect', autodownload=False):
+    def __init__(self, path="coco8.yaml", task="detect", autodownload=False):
         """Initialize class."""
         path = Path(path).resolve()
-        LOGGER.info(f'Starting HUB dataset checks for {path}....')
+        LOGGER.info(f"Starting HUB dataset checks for {path}....")
 
         self.task = task  # detect, segment, pose, classify
-        if self.task == 'classify':
+        if self.task == "classify":
             unzip_dir = unzip_file(path)
             data = check_cls_dataset(unzip_dir)
-            data['path'] = unzip_dir
+            data["path"] = unzip_dir
         else:  # detect, segment, pose
             _, data_dir, yaml_path = self._unzip(Path(path))
             try:
                 # Load YAML with checks
                 data = yaml_load(yaml_path)
-                data['path'] = ''  # strip path since YAML should be in dataset root for all HUB datasets
+                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                 yaml_save(yaml_path, data)
                 data = check_det_dataset(yaml_path, autodownload)  # dict
-                data['path'] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
+                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
             except Exception as e:
-                raise Exception('error/HUB/dataset_stats/init') from e
+                raise Exception("error/HUB/dataset_stats/init") from e
 
         self.hub_dir = Path(f'{data["path"]}-hub')
-        self.im_dir = self.hub_dir / 'images'
+        self.im_dir = self.hub_dir / "images"
         self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
-        self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())}  # statistics dictionary
+        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
         self.data = data
 
     @staticmethod
     def _unzip(path):
         """Unzip data.zip."""
-        if not str(path).endswith('.zip'):  # path is data.yaml
+        if not str(path).endswith(".zip"):  # path is data.yaml
             return False, None, path
         unzip_dir = unzip_file(path, path=path.parent)
-        assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
-                                   f'path/to/abc.zip MUST unzip to path/to/abc/'
+        assert unzip_dir.is_dir(), (
+            f"Error unzipping {path}, {unzip_dir} not found. " f"path/to/abc.zip MUST unzip to path/to/abc/"
+        )
         return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path
 
     def _hub_ops(self, f):
@@ -470,31 +491,31 @@ class HUBDatasetStats:
 
         def _round(labels):
             """Update labels to integer class and 4 decimal place floats."""
-            if self.task == 'detect':
-                coordinates = labels['bboxes']
-            elif self.task == 'segment':
-                coordinates = [x.flatten() for x in labels['segments']]
-            elif self.task == 'pose':
-                n = labels['keypoints'].shape[0]
-                coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
+            if self.task == "detect":
+                coordinates = labels["bboxes"]
+            elif self.task == "segment":
+                coordinates = [x.flatten() for x in labels["segments"]]
+            elif self.task == "pose":
+                n = labels["keypoints"].shape[0]
+                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, -1)), 1)
             else:
-                raise ValueError('Undefined dataset task.')
-            zipped = zip(labels['cls'], coordinates)
+                raise ValueError("Undefined dataset task.")
+            zipped = zip(labels["cls"], coordinates)
             return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]
 
-        for split in 'train', 'val', 'test':
+        for split in "train", "val", "test":
             self.stats[split] = None  # predefine
             path = self.data.get(split)
 
             # Check split
             if path is None:  # no split
                 continue
-            files = [f for f in Path(path).rglob('*.*') if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
+            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
             if not files:  # no images
                 continue
 
             # Get dataset statistics
-            if self.task == 'classify':
+            if self.task == "classify":
                 from torchvision.datasets import ImageFolder
 
                 dataset = ImageFolder(self.data[split])
@@ -504,38 +525,35 @@ class HUBDatasetStats:
                     x[im[1]] += 1
 
                 self.stats[split] = {
-                    'instance_stats': {
-                        'total': len(dataset),
-                        'per_class': x.tolist()},
-                    'image_stats': {
-                        'total': len(dataset),
-                        'unlabelled': 0,
-                        'per_class': x.tolist()},
-                    'labels': [{
-                        Path(k).name: v} for k, v in dataset.imgs]}
+                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
+                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
+                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
+                }
             else:
                 from ultralytics.data import YOLODataset
 
                 dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
-                x = np.array([
-                    np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
-                    for label in TQDM(dataset.labels, total=len(dataset), desc='Statistics')])  # shape(128x80)
+                x = np.array(
+                    [
+                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
+                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
+                    ]
+                )  # shape(128x80)
                 self.stats[split] = {
-                    'instance_stats': {
-                        'total': int(x.sum()),
-                        'per_class': x.sum(0).tolist()},
-                    'image_stats': {
-                        'total': len(dataset),
-                        'unlabelled': int(np.all(x == 0, 1).sum()),
-                        'per_class': (x > 0).sum(0).tolist()},
-                    'labels': [{
-                        Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
+                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
+                    "image_stats": {
+                        "total": len(dataset),
+                        "unlabelled": int(np.all(x == 0, 1).sum()),
+                        "per_class": (x > 0).sum(0).tolist(),
+                    },
+                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
+                }
 
         # Save, print and return
         if save:
-            stats_path = self.hub_dir / 'stats.json'
-            LOGGER.info(f'Saving {stats_path.resolve()}...')
-            with open(stats_path, 'w') as f:
+            stats_path = self.hub_dir / "stats.json"
+            LOGGER.info(f"Saving {stats_path.resolve()}...")
+            with open(stats_path, "w") as f:
                 json.dump(self.stats, f)  # save stats.json
         if verbose:
             LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
@@ -545,14 +563,14 @@ class HUBDatasetStats:
         """Compress images for Ultralytics HUB."""
         from ultralytics.data import YOLODataset  # ClassificationDataset
 
-        for split in 'train', 'val', 'test':
+        for split in "train", "val", "test":
             if self.data.get(split) is None:
                 continue
             dataset = YOLODataset(img_path=self.data[split], data=self.data)
             with ThreadPool(NUM_THREADS) as pool:
-                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
+                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                     pass
-        LOGGER.info(f'Done. All images saved to {self.im_dir}')
+        LOGGER.info(f"Done. All images saved to {self.im_dir}")
         return self.im_dir
 
 
@@ -583,9 +601,9 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
         r = max_dim / max(im.height, im.width)  # ratio
         if r < 1.0:  # image too large
             im = im.resize((int(im.width * r), int(im.height * r)))
-        im.save(f_new or f, 'JPEG', quality=quality, optimize=True)  # save
+        im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
     except Exception as e:  # use OpenCV
-        LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
+        LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}")
         im = cv2.imread(f)
         im_height, im_width = im.shape[:2]
         r = max_dim / max(im_height, im_width)  # ratio
@@ -594,7 +612,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
         cv2.imwrite(str(f_new or f), im)
 
 
-def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
+def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False):
     """
     Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
 
@@ -612,18 +630,18 @@ def autosplit(path=DATASETS_DIR / 'coco8/images', weights=(0.9, 0.1, 0.0), annot
     """
 
     path = Path(path)  # images dir
-    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
+    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
     random.seed(0)  # for reproducibility
     indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split
 
-    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
+    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
     for x in txt:
         if (path.parent / x).exists():
             (path.parent / x).unlink()  # remove existing
 
-    LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
+    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
     for i, img in TQDM(zip(indices, files), total=n):
         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
-            with open(path.parent / txt[i], 'a') as f:
-                f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file
+            with open(path.parent / txt[i], "a") as f:
+                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
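
In this file the diff is limited to quote style and line re-wrapping, so the helpers keep their behavior. A minimal usage sketch of two of the touched functions, based only on the signatures shown above; "coco8.yaml" is the default name from those signatures and may be auto-downloaded, and the example image path is made up for illustration:

    from ultralytics.data.utils import check_det_dataset, img2label_paths

    # Resolve and validate a detection dataset YAML (may download it if missing)
    data = check_det_dataset("coco8.yaml", autodownload=True)
    print(data["nc"], data["names"])  # class count and id -> name mapping

    # img2label_paths is a pure path transform: .../images/x.jpg -> .../labels/x.txt
    print(img2label_paths(["datasets/coco8/images/train/im0.jpg"]))
    # ['datasets/coco8/labels/train/im0.txt']  (with a POSIX path separator)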