ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. See the registry's advisory for this release for more details.

Files changed (137)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/utils/dist.py CHANGED
@@ -18,13 +18,13 @@ def find_free_network_port() -> int:
18
18
  `MASTER_PORT` environment variable.
19
19
  """
20
20
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
21
- s.bind(('127.0.0.1', 0))
21
+ s.bind(("127.0.0.1", 0))
22
22
  return s.getsockname()[1] # port
23
23
 
24
24
 
25
25
  def generate_ddp_file(trainer):
26
26
  """Generates a DDP file and returns its file name."""
27
- module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
27
+ module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
28
28
 
29
29
  content = f"""
30
30
  # Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
@@ -39,13 +39,15 @@ if __name__ == "__main__":
39
39
  trainer = {name}(cfg=cfg, overrides=overrides)
40
40
  results = trainer.train()
41
41
  """
42
- (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
43
- with tempfile.NamedTemporaryFile(prefix='_temp_',
44
- suffix=f'{id(trainer)}.py',
45
- mode='w+',
46
- encoding='utf-8',
47
- dir=USER_CONFIG_DIR / 'DDP',
48
- delete=False) as file:
42
+ (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
43
+ with tempfile.NamedTemporaryFile(
44
+ prefix="_temp_",
45
+ suffix=f"{id(trainer)}.py",
46
+ mode="w+",
47
+ encoding="utf-8",
48
+ dir=USER_CONFIG_DIR / "DDP",
49
+ delete=False,
50
+ ) as file:
49
51
  file.write(content)
50
52
  return file.name
51
53
 
@@ -53,16 +55,17 @@ if __name__ == "__main__":
53
55
  def generate_ddp_command(world_size, trainer):
54
56
  """Generates and returns command for distributed training."""
55
57
  import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
58
+
56
59
  if not trainer.resume:
57
60
  shutil.rmtree(trainer.save_dir) # remove the save_dir
58
61
  file = generate_ddp_file(trainer)
59
- dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
62
+ dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
60
63
  port = find_free_network_port()
61
- cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
64
+ cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file]
62
65
  return cmd, file
63
66
 
64
67
 
65
68
  def ddp_cleanup(trainer, file):
66
69
  """Delete temp file if created."""
67
- if f'{id(trainer)}.py' in file: # if temp_file suffix in file
70
+ if f"{id(trainer)}.py" in file: # if temp_file suffix in file
68
71
  os.remove(file)
ultralytics/utils/downloads.py CHANGED
@@ -15,15 +15,17 @@ import torch
15
15
  from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
16
16
 
17
17
  # Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
18
- GITHUB_ASSETS_REPO = 'ultralytics/assets'
19
- GITHUB_ASSETS_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')] + \
20
- [f'yolov5{k}{resolution}u.pt' for k in 'nsmlx' for resolution in ('', '6')] + \
21
- [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
22
- [f'yolo_nas_{k}.pt' for k in 'sml'] + \
23
- [f'sam_{k}.pt' for k in 'bl'] + \
24
- [f'FastSAM-{k}.pt' for k in 'sx'] + \
25
- [f'rtdetr-{k}.pt' for k in 'lx'] + \
26
- ['mobile_sam.pt']
18
+ GITHUB_ASSETS_REPO = "ultralytics/assets"
19
+ GITHUB_ASSETS_NAMES = (
20
+ [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
21
+ + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
22
+ + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
23
+ + [f"yolo_nas_{k}.pt" for k in "sml"]
24
+ + [f"sam_{k}.pt" for k in "bl"]
25
+ + [f"FastSAM-{k}.pt" for k in "sx"]
26
+ + [f"rtdetr-{k}.pt" for k in "lx"]
27
+ + ["mobile_sam.pt"]
28
+ )
27
29
  GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
28
30
 
29
31
 
@@ -56,7 +58,7 @@ def is_url(url, check=True):
56
58
  return False
57
59
 
58
60
 
59
- def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
61
+ def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
60
62
  """
61
63
  Deletes all ".DS_store" files under a specified directory.
62
64
 
@@ -77,12 +79,12 @@ def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
77
79
  """
78
80
  for file in files_to_delete:
79
81
  matches = list(Path(path).rglob(file))
80
- LOGGER.info(f'Deleting {file} files: {matches}')
82
+ LOGGER.info(f"Deleting {file} files: {matches}")
81
83
  for f in matches:
82
84
  f.unlink()
83
85
 
84
86
 
85
- def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
87
+ def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True):
86
88
  """
87
89
  Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is
88
90
  named after the directory and placed alongside it.
@@ -111,17 +113,17 @@ def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), p
111
113
  raise FileNotFoundError(f"Directory '{directory}' does not exist.")
112
114
 
113
115
  # Unzip with progress bar
114
- files_to_zip = [f for f in directory.rglob('*') if f.is_file() and all(x not in f.name for x in exclude)]
115
- zip_file = directory.with_suffix('.zip')
116
+ files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]
117
+ zip_file = directory.with_suffix(".zip")
116
118
  compression = ZIP_DEFLATED if compress else ZIP_STORED
117
- with ZipFile(zip_file, 'w', compression) as f:
118
- for file in TQDM(files_to_zip, desc=f'Zipping {directory} to {zip_file}...', unit='file', disable=not progress):
119
+ with ZipFile(zip_file, "w", compression) as f:
120
+ for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress):
119
121
  f.write(file, file.relative_to(directory))
120
122
 
121
123
  return zip_file # return path to zip file
122
124
 
123
125
 
124
- def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False, progress=True):
126
+ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
125
127
  """
126
128
  Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
127
129
 
@@ -161,7 +163,7 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
161
163
  files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
162
164
  top_level_dirs = {Path(f).parts[0] for f in files}
163
165
 
164
- if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith('/')):
166
+ if len(top_level_dirs) > 1 or (len(files) > 1 and not files[0].endswith("/")):
165
167
  # Zip has multiple files at top level
166
168
  path = extract_path = Path(path) / Path(file).stem # i.e. ../datasets/coco8
167
169
  else:
@@ -172,20 +174,20 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=Fals
172
174
  # Check if destination directory already exists and contains files
173
175
  if path.exists() and any(path.iterdir()) and not exist_ok:
174
176
  # If it exists and is not empty, return the path without unzipping
175
- LOGGER.warning(f'WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.')
177
+ LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
176
178
  return path
177
179
 
178
- for f in TQDM(files, desc=f'Unzipping {file} to {Path(path).resolve()}...', unit='file', disable=not progress):
180
+ for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
179
181
  # Ensure the file is within the extract_path to avoid path traversal security vulnerability
180
- if '..' in Path(f).parts:
181
- LOGGER.warning(f'Potentially insecure file path: {f}, skipping extraction.')
182
+ if ".." in Path(f).parts:
183
+ LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
182
184
  continue
183
185
  zipObj.extract(f, extract_path)
184
186
 
185
187
  return path # return unzip dir
186
188
 
187
189
 
188
- def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
190
+ def check_disk_space(url="https://ultralytics.com/assets/coco128.zip", sf=1.5, hard=True):
189
191
  """
190
192
  Check if there is sufficient disk space to download and store a file.
191
193
 
@@ -199,20 +201,23 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
199
201
  """
200
202
  try:
201
203
  r = requests.head(url) # response
202
- assert r.status_code < 400, f'URL error for {url}: {r.status_code} {r.reason}' # check response
204
+ assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
203
205
  except Exception:
204
206
  return True # requests issue, default to True
205
207
 
206
208
  # Check file size
207
209
  gib = 1 << 30 # bytes per GiB
208
- data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB)
209
- total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
210
+ data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
211
+ total, used, free = (x / gib for x in shutil.disk_usage(Path.cwd())) # bytes
212
+
210
213
  if data * sf < free:
211
214
  return True # sufficient space
212
215
 
213
216
  # Insufficient space
214
- text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
215
- f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
217
+ text = (
218
+ f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
219
+ f"Please free {data * sf - free:.1f} GB additional disk space and try again."
220
+ )
216
221
  if hard:
217
222
  raise MemoryError(text)
218
223
  LOGGER.warning(text)
@@ -238,36 +243,41 @@ def get_google_drive_file_info(link):
238
243
  url, filename = get_google_drive_file_info(link)
239
244
  ```
240
245
  """
241
- file_id = link.split('/d/')[1].split('/view')[0]
242
- drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
246
+ file_id = link.split("/d/")[1].split("/view")[0]
247
+ drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
243
248
  filename = None
244
249
 
245
250
  # Start session
246
251
  with requests.Session() as session:
247
252
  response = session.get(drive_url, stream=True)
248
- if 'quota exceeded' in str(response.content.lower()):
253
+ if "quota exceeded" in str(response.content.lower()):
249
254
  raise ConnectionError(
250
- emojis(f'❌ Google Drive file download quota exceeded. '
251
- f'Please try again later or download this file manually at {link}.'))
255
+ emojis(
256
+ f"❌ Google Drive file download quota exceeded. "
257
+ f"Please try again later or download this file manually at {link}."
258
+ )
259
+ )
252
260
  for k, v in response.cookies.items():
253
- if k.startswith('download_warning'):
254
- drive_url += f'&confirm={v}' # v is token
255
- cd = response.headers.get('content-disposition')
261
+ if k.startswith("download_warning"):
262
+ drive_url += f"&confirm={v}" # v is token
263
+ cd = response.headers.get("content-disposition")
256
264
  if cd:
257
265
  filename = re.findall('filename="(.+)"', cd)[0]
258
266
  return drive_url, filename
259
267
 
260
268
 
261
- def safe_download(url,
262
- file=None,
263
- dir=None,
264
- unzip=True,
265
- delete=False,
266
- curl=False,
267
- retry=3,
268
- min_bytes=1E0,
269
- exist_ok=False,
270
- progress=True):
269
+ def safe_download(
270
+ url,
271
+ file=None,
272
+ dir=None,
273
+ unzip=True,
274
+ delete=False,
275
+ curl=False,
276
+ retry=3,
277
+ min_bytes=1e0,
278
+ exist_ok=False,
279
+ progress=True,
280
+ ):
271
281
  """
272
282
  Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
273
283
 
@@ -294,36 +304,38 @@ def safe_download(url,
294
304
  path = safe_download(link)
295
305
  ```
296
306
  """
297
- gdrive = url.startswith('https://drive.google.com/') # check if the URL is a Google Drive link
307
+ gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
298
308
  if gdrive:
299
309
  url, file = get_google_drive_file_info(url)
300
310
 
301
- f = Path(dir or '.') / (file or url2file(url)) # URL converted to filename
302
- if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
311
+ f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
312
+ if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
303
313
  f = Path(url) # filename
304
314
  elif not f.is_file(): # URL and file do not exist
305
315
  desc = f"Downloading {url if gdrive else clean_url(url)} to '{f}'"
306
- LOGGER.info(f'{desc}...')
316
+ LOGGER.info(f"{desc}...")
307
317
  f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
308
318
  check_disk_space(url)
309
319
  for i in range(retry + 1):
310
320
  try:
311
321
  if curl or i > 0: # curl download with retry, continue
312
- s = 'sS' * (not progress) # silent
313
- r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
314
- assert r == 0, f'Curl return value {r}'
322
+ s = "sS" * (not progress) # silent
323
+ r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
324
+ assert r == 0, f"Curl return value {r}"
315
325
  else: # urllib download
316
- method = 'torch'
317
- if method == 'torch':
326
+ method = "torch"
327
+ if method == "torch":
318
328
  torch.hub.download_url_to_file(url, f, progress=progress)
319
329
  else:
320
- with request.urlopen(url) as response, TQDM(total=int(response.getheader('Content-Length', 0)),
321
- desc=desc,
322
- disable=not progress,
323
- unit='B',
324
- unit_scale=True,
325
- unit_divisor=1024) as pbar:
326
- with open(f, 'wb') as f_opened:
330
+ with request.urlopen(url) as response, TQDM(
331
+ total=int(response.getheader("Content-Length", 0)),
332
+ desc=desc,
333
+ disable=not progress,
334
+ unit="B",
335
+ unit_scale=True,
336
+ unit_divisor=1024,
337
+ ) as pbar:
338
+ with open(f, "wb") as f_opened:
327
339
  for data in response:
328
340
  f_opened.write(data)
329
341
  pbar.update(len(data))
@@ -334,26 +346,26 @@ def safe_download(url,
334
346
  f.unlink() # remove partial downloads
335
347
  except Exception as e:
336
348
  if i == 0 and not is_online():
337
- raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
349
+ raise ConnectionError(emojis(f"❌ Download failure for {url}. Environment is not online.")) from e
338
350
  elif i >= retry:
339
- raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
340
- LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
351
+ raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
352
+ LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
341
353
 
342
- if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
354
+ if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
343
355
  from zipfile import is_zipfile
344
356
 
345
357
  unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
346
358
  if is_zipfile(f):
347
359
  unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
348
- elif f.suffix in ('.tar', '.gz'):
349
- LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
350
- subprocess.run(['tar', 'xf' if f.suffix == '.tar' else 'xfz', f, '--directory', unzip_dir], check=True)
360
+ elif f.suffix in (".tar", ".gz"):
361
+ LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
362
+ subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
351
363
  if delete:
352
364
  f.unlink() # remove zip
353
365
  return unzip_dir
354
366
 
355
367
 
356
- def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
368
+ def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
357
369
  """
358
370
  Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the
359
371
  function fetches the latest release assets.
@@ -372,20 +384,20 @@ def get_github_assets(repo='ultralytics/assets', version='latest', retry=False):
372
384
  ```
373
385
  """
374
386
 
375
- if version != 'latest':
376
- version = f'tags/{version}' # i.e. tags/v6.2
377
- url = f'https://api.github.com/repos/{repo}/releases/{version}'
387
+ if version != "latest":
388
+ version = f"tags/{version}" # i.e. tags/v6.2
389
+ url = f"https://api.github.com/repos/{repo}/releases/{version}"
378
390
  r = requests.get(url) # github api
379
- if r.status_code != 200 and r.reason != 'rate limit exceeded' and retry: # failed and not 403 rate limit exceeded
391
+ if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
380
392
  r = requests.get(url) # try again
381
393
  if r.status_code != 200:
382
- LOGGER.warning(f'⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}')
383
- return '', []
394
+ LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}")
395
+ return "", []
384
396
  data = r.json()
385
- return data['tag_name'], [x['name'] for x in data['assets']] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
397
+ return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...]
386
398
 
387
399
 
388
- def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **kwargs):
400
+ def attempt_download_asset(file, repo="ultralytics/assets", release="v0.0.0", **kwargs):
389
401
  """
390
402
  Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file
391
403
  locally first, then tries to download it from the specified GitHub repository release.
@@ -409,32 +421,32 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0', **
409
421
  # YOLOv3/5u updates
410
422
  file = str(file)
411
423
  file = checks.check_yolov5u_filename(file)
412
- file = Path(file.strip().replace("'", ''))
424
+ file = Path(file.strip().replace("'", ""))
413
425
  if file.exists():
414
426
  return str(file)
415
- elif (SETTINGS['weights_dir'] / file).exists():
416
- return str(SETTINGS['weights_dir'] / file)
427
+ elif (SETTINGS["weights_dir"] / file).exists():
428
+ return str(SETTINGS["weights_dir"] / file)
417
429
  else:
418
430
  # URL specified
419
431
  name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
420
- download_url = f'https://github.com/{repo}/releases/download'
421
- if str(file).startswith(('http:/', 'https:/')): # download
422
- url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
432
+ download_url = f"https://github.com/{repo}/releases/download"
433
+ if str(file).startswith(("http:/", "https:/")): # download
434
+ url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
423
435
  file = url2file(name) # parse authentication https://url.com/file.txt?auth...
424
436
  if Path(file).is_file():
425
- LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
437
+ LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
426
438
  else:
427
- safe_download(url=url, file=file, min_bytes=1E5, **kwargs)
439
+ safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
428
440
 
429
441
  elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
430
- safe_download(url=f'{download_url}/{release}/{name}', file=file, min_bytes=1E5, **kwargs)
442
+ safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
431
443
 
432
444
  else:
433
445
  tag, assets = get_github_assets(repo, release)
434
446
  if not assets:
435
447
  tag, assets = get_github_assets(repo) # latest release
436
448
  if name in assets:
437
- safe_download(url=f'{download_url}/{tag}/{name}', file=file, min_bytes=1E5, **kwargs)
449
+ safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
438
450
 
439
451
  return str(file)
440
452
 
@@ -464,14 +476,18 @@ def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=
464
476
  if threads > 1:
465
477
  with ThreadPool(threads) as pool:
466
478
  pool.map(
467
- lambda x: safe_download(url=x[0],
468
- dir=x[1],
469
- unzip=unzip,
470
- delete=delete,
471
- curl=curl,
472
- retry=retry,
473
- exist_ok=exist_ok,
474
- progress=threads <= 1), zip(url, repeat(dir)))
479
+ lambda x: safe_download(
480
+ url=x[0],
481
+ dir=x[1],
482
+ unzip=unzip,
483
+ delete=delete,
484
+ curl=curl,
485
+ retry=retry,
486
+ exist_ok=exist_ok,
487
+ progress=threads <= 1,
488
+ ),
489
+ zip(url, repeat(dir)),
490
+ )
475
491
  pool.close()
476
492
  pool.join()
477
493
  else:
ultralytics/utils/errors.py CHANGED
@@ -17,6 +17,6 @@ class HUBModelError(Exception):
17
17
  The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package.
18
18
  """
19
19
 
20
- def __init__(self, message='Model not found. Please check model URL and try again.'):
20
+ def __init__(self, message="Model not found. Please check model URL and try again."):
21
21
  """Create an exception for when a model is not found."""
22
22
  super().__init__(emojis(message))
ultralytics/utils/files.py CHANGED
@@ -50,13 +50,13 @@ def spaces_in_path(path):
50
50
  """
51
51
 
52
52
  # If path has spaces, replace them with underscores
53
- if ' ' in str(path):
53
+ if " " in str(path):
54
54
  string = isinstance(path, str) # input type
55
55
  path = Path(path)
56
56
 
57
57
  # Create a temporary directory and construct the new path
58
58
  with tempfile.TemporaryDirectory() as tmp_dir:
59
- tmp_path = Path(tmp_dir) / path.name.replace(' ', '_')
59
+ tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
60
60
 
61
61
  # Copy file/directory
62
62
  if path.is_dir():
@@ -82,7 +82,7 @@ def spaces_in_path(path):
82
82
  yield path
83
83
 
84
84
 
85
- def increment_path(path, exist_ok=False, sep='', mkdir=False):
85
+ def increment_path(path, exist_ok=False, sep="", mkdir=False):
86
86
  """
87
87
  Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
88
88
 
@@ -102,11 +102,11 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
102
102
  """
103
103
  path = Path(path) # os-agnostic
104
104
  if path.exists() and not exist_ok:
105
- path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
105
+ path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
106
106
 
107
107
  # Method 1
108
108
  for n in range(2, 9999):
109
- p = f'{path}{sep}{n}{suffix}' # increment path
109
+ p = f"{path}{sep}{n}{suffix}" # increment path
110
110
  if not os.path.exists(p):
111
111
  break
112
112
  path = Path(p)
@@ -119,14 +119,14 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
119
119
 
120
120
  def file_age(path=__file__):
121
121
  """Return days since last file update."""
122
- dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
122
+ dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
123
123
  return dt.days # + dt.seconds / 86400 # fractional days
124
124
 
125
125
 
126
126
  def file_date(path=__file__):
127
127
  """Return human-readable file modification date, i.e. '2021-3-26'."""
128
128
  t = datetime.fromtimestamp(Path(path).stat().st_mtime)
129
- return f'{t.year}-{t.month}-{t.day}'
129
+ return f"{t.year}-{t.month}-{t.day}"
130
130
 
131
131
 
132
132
  def file_size(path):
@@ -137,11 +137,11 @@ def file_size(path):
137
137
  if path.is_file():
138
138
  return path.stat().st_size / mb
139
139
  elif path.is_dir():
140
- return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
140
+ return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
141
141
  return 0.0
142
142
 
143
143
 
144
- def get_latest_run(search_dir='.'):
144
+ def get_latest_run(search_dir="."):
145
145
  """Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
146
- last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
147
- return max(last_list, key=os.path.getctime) if last_list else ''
146
+ last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
147
+ return max(last_list, key=os.path.getctime) if last_list else ""