ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (137) hide show
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -21,12 +21,33 @@ import requests
21
21
  import torch
22
22
  from matplotlib import font_manager
23
23
 
24
- from ultralytics.utils import (ASSETS, AUTOINSTALL, LINUX, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, SimpleNamespace,
25
- ThreadingLocked, TryExcept, clean_url, colorstr, downloads, emojis, is_colab, is_docker,
26
- is_github_action_running, is_jupyter, is_kaggle, is_online, is_pip_package, url2file)
27
-
28
-
29
- def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
24
+ from ultralytics.utils import (
25
+ ASSETS,
26
+ AUTOINSTALL,
27
+ LINUX,
28
+ LOGGER,
29
+ ONLINE,
30
+ ROOT,
31
+ USER_CONFIG_DIR,
32
+ SimpleNamespace,
33
+ ThreadingLocked,
34
+ TryExcept,
35
+ clean_url,
36
+ colorstr,
37
+ downloads,
38
+ emojis,
39
+ is_colab,
40
+ is_docker,
41
+ is_github_action_running,
42
+ is_jupyter,
43
+ is_kaggle,
44
+ is_online,
45
+ is_pip_package,
46
+ url2file,
47
+ )
48
+
49
+
50
+ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
30
51
  """
31
52
  Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
32
53
 
@@ -46,23 +67,23 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):
46
67
  """
47
68
 
48
69
  if package:
49
- requires = [x for x in metadata.distribution(package).requires if 'extra == ' not in x]
70
+ requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
50
71
  else:
51
72
  requires = Path(file_path).read_text().splitlines()
52
73
 
53
74
  requirements = []
54
75
  for line in requires:
55
76
  line = line.strip()
56
- if line and not line.startswith('#'):
57
- line = line.split('#')[0].strip() # ignore inline comments
58
- match = re.match(r'([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?', line)
77
+ if line and not line.startswith("#"):
78
+ line = line.split("#")[0].strip() # ignore inline comments
79
+ match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
59
80
  if match:
60
- requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ''))
81
+ requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))
61
82
 
62
83
  return requirements
63
84
 
64
85
 
65
- def parse_version(version='0.0.0') -> tuple:
86
+ def parse_version(version="0.0.0") -> tuple:
66
87
  """
67
88
  Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
68
89
  function replaces deprecated 'pkg_resources.parse_version(v)'.
@@ -74,9 +95,9 @@ def parse_version(version='0.0.0') -> tuple:
74
95
  (tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1)
75
96
  """
76
97
  try:
77
- return tuple(map(int, re.findall(r'\d+', version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
98
+ return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
78
99
  except Exception as e:
79
- LOGGER.warning(f'WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}')
100
+ LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
80
101
  return 0, 0, 0
81
102
 
82
103
 
@@ -121,15 +142,19 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
121
142
  elif isinstance(imgsz, (list, tuple)):
122
143
  imgsz = list(imgsz)
123
144
  else:
124
- raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
125
- f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
145
+ raise TypeError(
146
+ f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
147
+ f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
148
+ )
126
149
 
127
150
  # Apply max_dim
128
151
  if len(imgsz) > max_dim:
129
- msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
130
- "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
152
+ msg = (
153
+ "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
154
+ "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
155
+ )
131
156
  if max_dim != 1:
132
- raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
157
+ raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
133
158
  LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
134
159
  imgsz = [max(imgsz)]
135
160
  # Make image size a multiple of the stride
@@ -137,7 +162,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
137
162
 
138
163
  # Print warning message if image size was updated
139
164
  if sz != imgsz:
140
- LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
165
+ LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
141
166
 
142
167
  # Add missing dimensions if necessary
143
168
  sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@@ -145,12 +170,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
145
170
  return sz
146
171
 
147
172
 
148
- def check_version(current: str = '0.0.0',
149
- required: str = '0.0.0',
150
- name: str = 'version',
151
- hard: bool = False,
152
- verbose: bool = False,
153
- msg: str = '') -> bool:
173
+ def check_version(
174
+ current: str = "0.0.0",
175
+ required: str = "0.0.0",
176
+ name: str = "version",
177
+ hard: bool = False,
178
+ verbose: bool = False,
179
+ msg: str = "",
180
+ ) -> bool:
154
181
  """
155
182
  Check current version against the required version or range.
156
183
 
@@ -181,7 +208,7 @@ def check_version(current: str = '0.0.0',
181
208
  ```
182
209
  """
183
210
  if not current: # if current is '' or None
184
- LOGGER.warning(f'WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.')
211
+ LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
185
212
  return True
186
213
  elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
187
214
  try:
@@ -189,34 +216,34 @@ def check_version(current: str = '0.0.0',
189
216
  current = metadata.version(current) # get version string from package name
190
217
  except metadata.PackageNotFoundError:
191
218
  if hard:
192
- raise ModuleNotFoundError(emojis(f'WARNING ⚠️ {current} package is required but not installed'))
219
+ raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed"))
193
220
  else:
194
221
  return False
195
222
 
196
223
  if not required: # if required is '' or None
197
224
  return True
198
225
 
199
- op = ''
200
- version = ''
226
+ op = ""
227
+ version = ""
201
228
  result = True
202
229
  c = parse_version(current) # '1.2.3' -> (1, 2, 3)
203
- for r in required.strip(',').split(','):
204
- op, version = re.match(r'([^0-9]*)([\d.]+)', r).groups() # split '>=22.04' -> ('>=', '22.04')
230
+ for r in required.strip(",").split(","):
231
+ op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
205
232
  v = parse_version(version) # '1.2.3' -> (1, 2, 3)
206
- if op == '==' and c != v:
233
+ if op == "==" and c != v:
207
234
  result = False
208
- elif op == '!=' and c == v:
235
+ elif op == "!=" and c == v:
209
236
  result = False
210
- elif op in ('>=', '') and not (c >= v): # if no constraint passed assume '>=required'
237
+ elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
211
238
  result = False
212
- elif op == '<=' and not (c <= v):
239
+ elif op == "<=" and not (c <= v):
213
240
  result = False
214
- elif op == '>' and not (c > v):
241
+ elif op == ">" and not (c > v):
215
242
  result = False
216
- elif op == '<' and not (c < v):
243
+ elif op == "<" and not (c < v):
217
244
  result = False
218
245
  if not result:
219
- warning = f'WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}'
246
+ warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
220
247
  if hard:
221
248
  raise ModuleNotFoundError(emojis(warning)) # assert version requirements met
222
249
  if verbose:
@@ -224,7 +251,7 @@ def check_version(current: str = '0.0.0',
224
251
  return result
225
252
 
226
253
 
227
- def check_latest_pypi_version(package_name='ultralytics'):
254
+ def check_latest_pypi_version(package_name="ultralytics"):
228
255
  """
229
256
  Returns the latest version of a PyPI package without downloading or installing it.
230
257
 
@@ -236,9 +263,9 @@ def check_latest_pypi_version(package_name='ultralytics'):
236
263
  """
237
264
  with contextlib.suppress(Exception):
238
265
  requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
239
- response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3)
266
+ response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
240
267
  if response.status_code == 200:
241
- return response.json()['info']['version']
268
+ return response.json()["info"]["version"]
242
269
 
243
270
 
244
271
  def check_pip_update_available():
@@ -251,16 +278,19 @@ def check_pip_update_available():
251
278
  if ONLINE and is_pip_package():
252
279
  with contextlib.suppress(Exception):
253
280
  from ultralytics import __version__
281
+
254
282
  latest = check_latest_pypi_version()
255
- if check_version(__version__, f'<{latest}'): # check if current version is < latest version
256
- LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
257
- f"Update with 'pip install -U ultralytics'")
283
+ if check_version(__version__, f"<{latest}"): # check if current version is < latest version
284
+ LOGGER.info(
285
+ f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
286
+ f"Update with 'pip install -U ultralytics'"
287
+ )
258
288
  return True
259
289
  return False
260
290
 
261
291
 
262
292
  @ThreadingLocked()
263
- def check_font(font='Arial.ttf'):
293
+ def check_font(font="Arial.ttf"):
264
294
  """
265
295
  Find font locally or download to user's configuration directory if it does not already exist.
266
296
 
@@ -283,13 +313,13 @@ def check_font(font='Arial.ttf'):
283
313
  return matches[0]
284
314
 
285
315
  # Download to USER_CONFIG_DIR if missing
286
- url = f'https://ultralytics.com/assets/{name}'
316
+ url = f"https://ultralytics.com/assets/{name}"
287
317
  if downloads.is_url(url):
288
318
  downloads.safe_download(url=url, file=file)
289
319
  return file
290
320
 
291
321
 
292
- def check_python(minimum: str = '3.8.0') -> bool:
322
+ def check_python(minimum: str = "3.8.0") -> bool:
293
323
  """
294
324
  Check current python version against the required minimum version.
295
325
 
@@ -299,11 +329,11 @@ def check_python(minimum: str = '3.8.0') -> bool:
299
329
  Returns:
300
330
  None
301
331
  """
302
- return check_version(platform.python_version(), minimum, name='Python ', hard=True)
332
+ return check_version(platform.python_version(), minimum, name="Python ", hard=True)
303
333
 
304
334
 
305
335
  @TryExcept()
306
- def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
336
+ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
307
337
  """
308
338
  Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.
309
339
 
@@ -329,41 +359,42 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
329
359
  ```
330
360
  """
331
361
 
332
- prefix = colorstr('red', 'bold', 'requirements:')
362
+ prefix = colorstr("red", "bold", "requirements:")
333
363
  check_python() # check python version
334
364
  check_torchvision() # check torch-torchvision compatibility
335
365
  if isinstance(requirements, Path): # requirements.txt file
336
366
  file = requirements.resolve()
337
- assert file.exists(), f'{prefix} {file} not found, check failed.'
338
- requirements = [f'{x.name}{x.specifier}' for x in parse_requirements(file) if x.name not in exclude]
367
+ assert file.exists(), f"{prefix} {file} not found, check failed."
368
+ requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
339
369
  elif isinstance(requirements, str):
340
370
  requirements = [requirements]
341
371
 
342
372
  pkgs = []
343
373
  for r in requirements:
344
- r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
345
- match = re.match(r'([a-zA-Z0-9-_]+)([<>!=~]+.*)?', r_stripped)
346
- name, required = match[1], match[2].strip() if match[2] else ''
374
+ r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
375
+ match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
376
+ name, required = match[1], match[2].strip() if match[2] else ""
347
377
  try:
348
378
  assert check_version(metadata.version(name), required) # exception if requirements not met
349
379
  except (AssertionError, metadata.PackageNotFoundError):
350
380
  pkgs.append(r)
351
381
 
352
- s = ' '.join(f'"{x}"' for x in pkgs) # console string
382
+ s = " ".join(f'"{x}"' for x in pkgs) # console string
353
383
  if s:
354
384
  if install and AUTOINSTALL: # check environment variable
355
385
  n = len(pkgs) # number of packages updates
356
386
  LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
357
387
  try:
358
388
  t = time.time()
359
- assert is_online(), 'AutoUpdate skipped (offline)'
360
- LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
389
+ assert is_online(), "AutoUpdate skipped (offline)"
390
+ LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
361
391
  dt = time.time() - t
362
392
  LOGGER.info(
363
393
  f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
364
- f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
394
+ f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
395
+ )
365
396
  except Exception as e:
366
- LOGGER.warning(f'{prefix} ❌ {e}')
397
+ LOGGER.warning(f"{prefix} ❌ {e}")
367
398
  return False
368
399
  else:
369
400
  return False
@@ -386,76 +417,82 @@ def check_torchvision():
386
417
  import torchvision
387
418
 
388
419
  # Compatibility table
389
- compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
420
+ compatibility_table = {"2.0": ["0.15"], "1.13": ["0.14"], "1.12": ["0.13"]}
390
421
 
391
422
  # Extract only the major and minor versions
392
- v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
393
- v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
423
+ v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
424
+ v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])
394
425
 
395
426
  if v_torch in compatibility_table:
396
427
  compatible_versions = compatibility_table[v_torch]
397
428
  if all(v_torchvision != v for v in compatible_versions):
398
- print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
399
- f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
400
- "'pip install -U torch torchvision' to update both.\n"
401
- 'For a full compatibility table see https://github.com/pytorch/vision#installation')
429
+ print(
430
+ f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
431
+ f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
432
+ "'pip install -U torch torchvision' to update both.\n"
433
+ "For a full compatibility table see https://github.com/pytorch/vision#installation"
434
+ )
402
435
 
403
436
 
404
- def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
437
+ def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
405
438
  """Check file(s) for acceptable suffix."""
406
439
  if file and suffix:
407
440
  if isinstance(suffix, str):
408
- suffix = (suffix, )
441
+ suffix = (suffix,)
409
442
  for f in file if isinstance(file, (list, tuple)) else [file]:
410
443
  s = Path(f).suffix.lower().strip() # file suffix
411
444
  if len(s):
412
- assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
445
+ assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}"
413
446
 
414
447
 
415
448
  def check_yolov5u_filename(file: str, verbose: bool = True):
416
449
  """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
417
- if 'yolov3' in file or 'yolov5' in file:
418
- if 'u.yaml' in file:
419
- file = file.replace('u.yaml', '.yaml') # i.e. yolov5nu.yaml -> yolov5n.yaml
420
- elif '.pt' in file and 'u' not in file:
450
+ if "yolov3" in file or "yolov5" in file:
451
+ if "u.yaml" in file:
452
+ file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
453
+ elif ".pt" in file and "u" not in file:
421
454
  original_file = file
422
- file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt
423
- file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file) # i.e. yolov5n6.pt -> yolov5n6u.pt
424
- file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
455
+ file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
456
+ file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt
457
+ file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
425
458
  if file != original_file and verbose:
426
459
  LOGGER.info(
427
460
  f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
428
- f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
429
- f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
461
+ f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
462
+ f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
463
+ )
430
464
  return file
431
465
 
432
466
 
433
- def check_model_file_from_stem(model='yolov8n'):
467
+ def check_model_file_from_stem(model="yolov8n"):
434
468
  """Return a model filename from a valid model stem."""
435
469
  if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS:
436
- return Path(model).with_suffix('.pt') # add suffix, i.e. yolov8n -> yolov8n.pt
470
+ return Path(model).with_suffix(".pt") # add suffix, i.e. yolov8n -> yolov8n.pt
437
471
  else:
438
472
  return model
439
473
 
440
474
 
441
- def check_file(file, suffix='', download=True, hard=True):
475
+ def check_file(file, suffix="", download=True, hard=True):
442
476
  """Search/download file (if necessary) and return path."""
443
477
  check_suffix(file, suffix) # optional
444
478
  file = str(file).strip() # convert to string and strip spaces
445
479
  file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
446
- if (not file or ('://' not in file and Path(file).exists()) or # '://' check required in Windows Python<3.10
447
- file.lower().startswith('grpc://')): # file exists or gRPC Triton images
480
+ if (
481
+ not file
482
+ or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
483
+ or file.lower().startswith("grpc://")
484
+ ): # file exists or gRPC Triton images
448
485
  return file
449
- elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://', 'tcp://')): # download
486
+ elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
450
487
  url = file # warning: Pathlib turns :// -> :/
451
488
  file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
452
489
  if Path(file).exists():
453
- LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
490
+ LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
454
491
  else:
455
492
  downloads.safe_download(url=url, file=file, unzip=False)
456
493
  return file
457
494
  else: # search
458
- files = glob.glob(str(ROOT / 'cfg' / '**' / file), recursive=True) # find file
495
+ files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
459
496
  if not files and hard:
460
497
  raise FileNotFoundError(f"'{file}' does not exist")
461
498
  elif len(files) > 1 and hard:
@@ -463,7 +500,7 @@ def check_file(file, suffix='', download=True, hard=True):
463
500
  return files[0] if len(files) else [] # return file
464
501
 
465
502
 
466
- def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
503
+ def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
467
504
  """Search/download YAML file (if necessary) and return path, checking suffix."""
468
505
  return check_file(file, suffix, hard=hard)
469
506
 
@@ -482,51 +519,52 @@ def check_is_path_safe(basedir, path):
482
519
  base_dir_resolved = Path(basedir).resolve()
483
520
  path_resolved = Path(path).resolve()
484
521
 
485
- return path_resolved.is_file() and path_resolved.parts[:len(base_dir_resolved.parts)] == base_dir_resolved.parts
522
+ return path_resolved.is_file() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
486
523
 
487
524
 
488
525
  def check_imshow(warn=False):
489
526
  """Check if environment supports image displays."""
490
527
  try:
491
528
  if LINUX:
492
- assert 'DISPLAY' in os.environ and not is_docker() and not is_colab() and not is_kaggle()
493
- cv2.imshow('test', np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
529
+ assert "DISPLAY" in os.environ and not is_docker() and not is_colab() and not is_kaggle()
530
+ cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image
494
531
  cv2.waitKey(1)
495
532
  cv2.destroyAllWindows()
496
533
  cv2.waitKey(1)
497
534
  return True
498
535
  except Exception as e:
499
536
  if warn:
500
- LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
537
+ LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
501
538
  return False
502
539
 
503
540
 
504
- def check_yolo(verbose=True, device=''):
541
+ def check_yolo(verbose=True, device=""):
505
542
  """Return a human-readable YOLO software and hardware summary."""
506
543
  import psutil
507
544
 
508
545
  from ultralytics.utils.torch_utils import select_device
509
546
 
510
547
  if is_jupyter():
511
- if check_requirements('wandb', install=False):
512
- os.system('pip uninstall -y wandb') # uninstall wandb: unwanted account creation prompt with infinite hang
548
+ if check_requirements("wandb", install=False):
549
+ os.system("pip uninstall -y wandb") # uninstall wandb: unwanted account creation prompt with infinite hang
513
550
  if is_colab():
514
- shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory
551
+ shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
515
552
 
516
553
  if verbose:
517
554
  # System info
518
555
  gib = 1 << 30 # bytes per GiB
519
556
  ram = psutil.virtual_memory().total
520
- total, used, free = shutil.disk_usage('/')
521
- s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
557
+ total, used, free = shutil.disk_usage("/")
558
+ s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
522
559
  with contextlib.suppress(Exception): # clear display if ipython is installed
523
560
  from IPython import display
561
+
524
562
  display.clear_output()
525
563
  else:
526
- s = ''
564
+ s = ""
527
565
 
528
566
  select_device(device=device, newline=False)
529
- LOGGER.info(f'Setup complete ✅ {s}')
567
+ LOGGER.info(f"Setup complete ✅ {s}")
530
568
 
531
569
 
532
570
  def collect_system_info():
@@ -537,32 +575,36 @@ def collect_system_info():
537
575
  from ultralytics.utils import ENVIRONMENT, is_git_dir
538
576
  from ultralytics.utils.torch_utils import get_cpu_info
539
577
 
540
- ram_info = psutil.virtual_memory().total / (1024 ** 3) # Convert bytes to GB
578
+ ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
541
579
  check_yolo()
542
- LOGGER.info(f"\n{'OS':<20}{platform.platform()}\n"
543
- f"{'Environment':<20}{ENVIRONMENT}\n"
544
- f"{'Python':<20}{sys.version.split()[0]}\n"
545
- f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
546
- f"{'RAM':<20}{ram_info:.2f} GB\n"
547
- f"{'CPU':<20}{get_cpu_info()}\n"
548
- f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n")
549
-
550
- for r in parse_requirements(package='ultralytics'):
580
+ LOGGER.info(
581
+ f"\n{'OS':<20}{platform.platform()}\n"
582
+ f"{'Environment':<20}{ENVIRONMENT}\n"
583
+ f"{'Python':<20}{sys.version.split()[0]}\n"
584
+ f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
585
+ f"{'RAM':<20}{ram_info:.2f} GB\n"
586
+ f"{'CPU':<20}{get_cpu_info()}\n"
587
+ f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
588
+ )
589
+
590
+ for r in parse_requirements(package="ultralytics"):
551
591
  try:
552
592
  current = metadata.version(r.name)
553
- is_met = '' if check_version(current, str(r.specifier), hard=True) else ''
593
+ is_met = "" if check_version(current, str(r.specifier), hard=True) else ""
554
594
  except metadata.PackageNotFoundError:
555
- current = '(not installed)'
556
- is_met = ''
557
- LOGGER.info(f'{r.name:<20}{is_met}{current}{r.specifier}')
595
+ current = "(not installed)"
596
+ is_met = ""
597
+ LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")
558
598
 
559
599
  if is_github_action_running():
560
- LOGGER.info(f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
561
- f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
562
- f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
563
- f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
564
- f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
565
- f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n")
600
+ LOGGER.info(
601
+ f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
602
+ f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
603
+ f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
604
+ f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
605
+ f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
606
+ f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
607
+ )
566
608
 
567
609
 
568
610
  def check_amp(model):
@@ -587,7 +629,7 @@ def check_amp(model):
587
629
  (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
588
630
  """
589
631
  device = next(model.parameters()).device # get model device
590
- if device.type in ('cpu', 'mps'):
632
+ if device.type in ("cpu", "mps"):
591
633
  return False # AMP only used on CUDA devices
592
634
 
593
635
  def amp_allclose(m, im):
@@ -598,22 +640,27 @@ def check_amp(model):
598
640
  del m
599
641
  return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
600
642
 
601
- im = ASSETS / 'bus.jpg' # image to check
602
- prefix = colorstr('AMP: ')
603
- LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
643
+ im = ASSETS / "bus.jpg" # image to check
644
+ prefix = colorstr("AMP: ")
645
+ LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
604
646
  warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
605
647
  try:
606
648
  from ultralytics import YOLO
607
- assert amp_allclose(YOLO('yolov8n.pt'), im)
608
- LOGGER.info(f'{prefix}checks passed ✅')
649
+
650
+ assert amp_allclose(YOLO("yolov8n.pt"), im)
651
+ LOGGER.info(f"{prefix}checks passed ✅")
609
652
  except ConnectionError:
610
- LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
653
+ LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
611
654
  except (AttributeError, ModuleNotFoundError):
612
- LOGGER.warning(f'{prefix}checks skipped ⚠️. '
613
- f'Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}')
655
+ LOGGER.warning(
656
+ f"{prefix}checks skipped ⚠️. "
657
+ f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
658
+ )
614
659
  except AssertionError:
615
- LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
616
- f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
660
+ LOGGER.warning(
661
+ f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
662
+ f"NaN losses or zero-mAP results, so AMP will be disabled during training."
663
+ )
617
664
  return False
618
665
  return True
619
666
 
@@ -621,8 +668,8 @@ def check_amp(model):
621
668
  def git_describe(path=ROOT): # path must be a directory
622
669
  """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
623
670
  with contextlib.suppress(Exception):
624
- return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
625
- return ''
671
+ return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
672
+ return ""
626
673
 
627
674
 
628
675
  def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
@@ -630,7 +677,7 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
630
677
 
631
678
  def strip_auth(v):
632
679
  """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
633
- return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
680
+ return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
634
681
 
635
682
  x = inspect.currentframe().f_back # previous frame
636
683
  file, _, func, _, _ = inspect.getframeinfo(x)
@@ -638,11 +685,11 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
638
685
  args, _, _, frm = inspect.getargvalues(x)
639
686
  args = {k: v for k, v in frm.items() if k in args}
640
687
  try:
641
- file = Path(file).resolve().relative_to(ROOT).with_suffix('')
688
+ file = Path(file).resolve().relative_to(ROOT).with_suffix("")
642
689
  except ValueError:
643
690
  file = Path(file).stem
644
- s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
645
- LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items()))
691
+ s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
692
+ LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items()))
646
693
 
647
694
 
648
695
  def cuda_device_count() -> int:
@@ -654,11 +701,12 @@ def cuda_device_count() -> int:
654
701
  """
655
702
  try:
656
703
  # Run the nvidia-smi command and capture its output
657
- output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
658
- encoding='utf-8')
704
+ output = subprocess.check_output(
705
+ ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
706
+ )
659
707
 
660
708
  # Take the first line and strip any leading/trailing white space
661
- first_line = output.strip().split('\n')[0]
709
+ first_line = output.strip().split("\n")[0]
662
710
 
663
711
  return int(first_line)
664
712
  except (subprocess.CalledProcessError, FileNotFoundError, ValueError):