bplusplus 0.1.1-py3-none-any.whl → 1.1.0-py3-none-any.whl

This diff reflects the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.

This version of bplusplus might be problematic.

Files changed (95)
  1. bplusplus/__init__.py +5 -3
  2. bplusplus/{collect_images.py → collect.py} +3 -3
  3. bplusplus/prepare.py +573 -0
  4. bplusplus/train_validate.py +8 -64
  5. bplusplus/yolov5detect/__init__.py +1 -0
  6. bplusplus/yolov5detect/detect.py +444 -0
  7. bplusplus/yolov5detect/export.py +1530 -0
  8. bplusplus/yolov5detect/insect.yaml +8 -0
  9. bplusplus/yolov5detect/models/__init__.py +0 -0
  10. bplusplus/yolov5detect/models/common.py +1109 -0
  11. bplusplus/yolov5detect/models/experimental.py +130 -0
  12. bplusplus/yolov5detect/models/hub/anchors.yaml +56 -0
  13. bplusplus/yolov5detect/models/hub/yolov3-spp.yaml +52 -0
  14. bplusplus/yolov5detect/models/hub/yolov3-tiny.yaml +42 -0
  15. bplusplus/yolov5detect/models/hub/yolov3.yaml +52 -0
  16. bplusplus/yolov5detect/models/hub/yolov5-bifpn.yaml +49 -0
  17. bplusplus/yolov5detect/models/hub/yolov5-fpn.yaml +43 -0
  18. bplusplus/yolov5detect/models/hub/yolov5-p2.yaml +55 -0
  19. bplusplus/yolov5detect/models/hub/yolov5-p34.yaml +42 -0
  20. bplusplus/yolov5detect/models/hub/yolov5-p6.yaml +57 -0
  21. bplusplus/yolov5detect/models/hub/yolov5-p7.yaml +68 -0
  22. bplusplus/yolov5detect/models/hub/yolov5-panet.yaml +49 -0
  23. bplusplus/yolov5detect/models/hub/yolov5l6.yaml +61 -0
  24. bplusplus/yolov5detect/models/hub/yolov5m6.yaml +61 -0
  25. bplusplus/yolov5detect/models/hub/yolov5n6.yaml +61 -0
  26. bplusplus/yolov5detect/models/hub/yolov5s-LeakyReLU.yaml +50 -0
  27. bplusplus/yolov5detect/models/hub/yolov5s-ghost.yaml +49 -0
  28. bplusplus/yolov5detect/models/hub/yolov5s-transformer.yaml +49 -0
  29. bplusplus/yolov5detect/models/hub/yolov5s6.yaml +61 -0
  30. bplusplus/yolov5detect/models/hub/yolov5x6.yaml +61 -0
  31. bplusplus/yolov5detect/models/segment/yolov5l-seg.yaml +49 -0
  32. bplusplus/yolov5detect/models/segment/yolov5m-seg.yaml +49 -0
  33. bplusplus/yolov5detect/models/segment/yolov5n-seg.yaml +49 -0
  34. bplusplus/yolov5detect/models/segment/yolov5s-seg.yaml +49 -0
  35. bplusplus/yolov5detect/models/segment/yolov5x-seg.yaml +49 -0
  36. bplusplus/yolov5detect/models/tf.py +797 -0
  37. bplusplus/yolov5detect/models/yolo.py +495 -0
  38. bplusplus/yolov5detect/models/yolov5l.yaml +49 -0
  39. bplusplus/yolov5detect/models/yolov5m.yaml +49 -0
  40. bplusplus/yolov5detect/models/yolov5n.yaml +49 -0
  41. bplusplus/yolov5detect/models/yolov5s.yaml +49 -0
  42. bplusplus/yolov5detect/models/yolov5x.yaml +49 -0
  43. bplusplus/yolov5detect/utils/__init__.py +97 -0
  44. bplusplus/yolov5detect/utils/activations.py +134 -0
  45. bplusplus/yolov5detect/utils/augmentations.py +448 -0
  46. bplusplus/yolov5detect/utils/autoanchor.py +175 -0
  47. bplusplus/yolov5detect/utils/autobatch.py +70 -0
  48. bplusplus/yolov5detect/utils/aws/__init__.py +0 -0
  49. bplusplus/yolov5detect/utils/aws/mime.sh +26 -0
  50. bplusplus/yolov5detect/utils/aws/resume.py +41 -0
  51. bplusplus/yolov5detect/utils/aws/userdata.sh +27 -0
  52. bplusplus/yolov5detect/utils/callbacks.py +72 -0
  53. bplusplus/yolov5detect/utils/dataloaders.py +1385 -0
  54. bplusplus/yolov5detect/utils/docker/Dockerfile +73 -0
  55. bplusplus/yolov5detect/utils/docker/Dockerfile-arm64 +40 -0
  56. bplusplus/yolov5detect/utils/docker/Dockerfile-cpu +42 -0
  57. bplusplus/yolov5detect/utils/downloads.py +136 -0
  58. bplusplus/yolov5detect/utils/flask_rest_api/README.md +70 -0
  59. bplusplus/yolov5detect/utils/flask_rest_api/example_request.py +17 -0
  60. bplusplus/yolov5detect/utils/flask_rest_api/restapi.py +49 -0
  61. bplusplus/yolov5detect/utils/general.py +1294 -0
  62. bplusplus/yolov5detect/utils/google_app_engine/Dockerfile +25 -0
  63. bplusplus/yolov5detect/utils/google_app_engine/additional_requirements.txt +6 -0
  64. bplusplus/yolov5detect/utils/google_app_engine/app.yaml +16 -0
  65. bplusplus/yolov5detect/utils/loggers/__init__.py +476 -0
  66. bplusplus/yolov5detect/utils/loggers/clearml/README.md +222 -0
  67. bplusplus/yolov5detect/utils/loggers/clearml/__init__.py +0 -0
  68. bplusplus/yolov5detect/utils/loggers/clearml/clearml_utils.py +230 -0
  69. bplusplus/yolov5detect/utils/loggers/clearml/hpo.py +90 -0
  70. bplusplus/yolov5detect/utils/loggers/comet/README.md +250 -0
  71. bplusplus/yolov5detect/utils/loggers/comet/__init__.py +551 -0
  72. bplusplus/yolov5detect/utils/loggers/comet/comet_utils.py +151 -0
  73. bplusplus/yolov5detect/utils/loggers/comet/hpo.py +126 -0
  74. bplusplus/yolov5detect/utils/loggers/comet/optimizer_config.json +135 -0
  75. bplusplus/yolov5detect/utils/loggers/wandb/__init__.py +0 -0
  76. bplusplus/yolov5detect/utils/loggers/wandb/wandb_utils.py +210 -0
  77. bplusplus/yolov5detect/utils/loss.py +259 -0
  78. bplusplus/yolov5detect/utils/metrics.py +381 -0
  79. bplusplus/yolov5detect/utils/plots.py +517 -0
  80. bplusplus/yolov5detect/utils/segment/__init__.py +0 -0
  81. bplusplus/yolov5detect/utils/segment/augmentations.py +100 -0
  82. bplusplus/yolov5detect/utils/segment/dataloaders.py +366 -0
  83. bplusplus/yolov5detect/utils/segment/general.py +160 -0
  84. bplusplus/yolov5detect/utils/segment/loss.py +198 -0
  85. bplusplus/yolov5detect/utils/segment/metrics.py +225 -0
  86. bplusplus/yolov5detect/utils/segment/plots.py +152 -0
  87. bplusplus/yolov5detect/utils/torch_utils.py +482 -0
  88. bplusplus/yolov5detect/utils/triton.py +90 -0
  89. bplusplus-1.1.0.dist-info/METADATA +179 -0
  90. bplusplus-1.1.0.dist-info/RECORD +92 -0
  91. bplusplus/build_model.py +0 -38
  92. bplusplus-0.1.1.dist-info/METADATA +0 -97
  93. bplusplus-0.1.1.dist-info/RECORD +0 -8
  94. {bplusplus-0.1.1.dist-info → bplusplus-1.1.0.dist-info}/LICENSE +0 -0
  95. {bplusplus-0.1.1.dist-info → bplusplus-1.1.0.dist-info}/WHEEL +0 -0
bplusplus/yolov5detect/utils/general.py (new file)
@@ -0,0 +1,1294 @@
1
+ # Ultralytics YOLOv5 🚀, AGPL-3.0 license
2
+ """General utils."""
3
+
4
+ import contextlib
5
+ import glob
6
+ import inspect
7
+ import logging
8
+ import logging.config
9
+ import math
10
+ import os
11
+ import platform
12
+ import random
13
+ import re
14
+ import signal
15
+ import subprocess
16
+ import sys
17
+ import time
18
+ import urllib
19
+ from copy import deepcopy
20
+ from datetime import datetime
21
+ from itertools import repeat
22
+ from multiprocessing.pool import ThreadPool
23
+ from pathlib import Path
24
+ from subprocess import check_output
25
+ from tarfile import is_tarfile
26
+ from typing import Optional
27
+ from zipfile import ZipFile, is_zipfile
28
+
29
+ import cv2
30
+ import numpy as np
31
+ import pandas as pd
32
+ import pkg_resources as pkg
33
+ import torch
34
+ import torchvision
35
+ import yaml
36
+
37
+ # Import 'ultralytics' package or install if missing
38
+ try:
39
+ import ultralytics
40
+
41
+ assert hasattr(ultralytics, "__version__") # verify package is not directory
42
+ except (ImportError, AssertionError):
43
+ os.system("pip install -U ultralytics")
44
+ import ultralytics
45
+
46
+ from ultralytics.utils.checks import check_requirements
47
+
48
+ from utils import TryExcept, emojis
49
+ from utils.downloads import curl_download, gsutil_getsize
50
+ from utils.metrics import box_iou, fitness
51
+
52
+ FILE = Path(__file__).resolve()
53
+ ROOT = FILE.parents[1] # YOLOv5 root directory
54
+ RANK = int(os.getenv("RANK", -1))
55
+
56
+ # Settings
57
+ NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
58
+ DATASETS_DIR = Path(os.getenv("YOLOv5_DATASETS_DIR", ROOT.parent / "datasets")) # global datasets directory
59
+ AUTOINSTALL = str(os.getenv("YOLOv5_AUTOINSTALL", True)).lower() == "true" # global auto-install mode
60
+ VERBOSE = str(os.getenv("YOLOv5_VERBOSE", True)).lower() == "true" # global verbose mode
61
+ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" # tqdm bar format
62
+ FONT = "Arial.ttf" # https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf
63
+
64
+ torch.set_printoptions(linewidth=320, precision=5, profile="long")
65
+ np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5
66
+ pd.options.display.max_columns = 10
67
+ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
68
+ os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads
69
+ os.environ["OMP_NUM_THREADS"] = "1" if platform.system() == "darwin" else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
70
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress verbose TF compiler warnings in Colab
71
+ os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings
72
+ os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output when computing FLOPs
73
+
74
+
75
+ def is_ascii(s=""):
76
+ """Checks if input string `s` contains only ASCII characters; returns `True` if so, otherwise `False`."""
77
+ s = str(s) # convert list, tuple, None, etc. to str
78
+ return len(s.encode().decode("ascii", "ignore")) == len(s)
79
+
80
+
81
+ def is_chinese(s="人工智能"):
82
+ """Determines if a string `s` contains any Chinese characters; returns `True` if so, otherwise `False`."""
83
+ return bool(re.search("[\u4e00-\u9fff]", str(s)))
84
+
85
+
86
+ def is_colab():
87
+ """Checks if the current environment is a Google Colab instance; returns `True` for Colab, otherwise `False`."""
88
+ return "google.colab" in sys.modules
89
+
90
+
91
+ def is_jupyter():
92
+ """
93
+ Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.
94
+
95
+ Returns:
96
+ bool: True if running inside a Jupyter Notebook, False otherwise.
97
+ """
98
+ with contextlib.suppress(Exception):
99
+ from IPython import get_ipython
100
+
101
+ return get_ipython() is not None
102
+ return False
103
+
104
+
105
+ def is_kaggle():
106
+ """Checks if the current environment is a Kaggle Notebook by validating environment variables."""
107
+ return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
108
+
109
+
110
+ def is_docker() -> bool:
111
+ """Check if the process runs inside a docker container."""
112
+ if Path("/.dockerenv").exists():
113
+ return True
114
+ try: # check if docker is in control groups
115
+ with open("/proc/self/cgroup") as file:
116
+ return any("docker" in line for line in file)
117
+ except OSError:
118
+ return False
119
+
120
+
121
+ def is_writeable(dir, test=False):
122
+ """Checks if a directory is writable, optionally testing by creating a temporary file if `test=True`."""
123
+ if not test:
124
+ return os.access(dir, os.W_OK) # possible issues on Windows
125
+ file = Path(dir) / "tmp.txt"
126
+ try:
127
+ with open(file, "w"): # open file with write permissions
128
+ pass
129
+ file.unlink() # remove file
130
+ return True
131
+ except OSError:
132
+ return False
133
+
134
+
135
+ LOGGING_NAME = "yolov5"
136
+
137
+
138
+ def set_logging(name=LOGGING_NAME, verbose=True):
139
+ """Configures logging with specified verbosity; `name` sets the logger's name, `verbose` controls logging level."""
140
+ rank = int(os.getenv("RANK", -1)) # rank in world for Multi-GPU trainings
141
+ level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
142
+ logging.config.dictConfig(
143
+ {
144
+ "version": 1,
145
+ "disable_existing_loggers": False,
146
+ "formatters": {name: {"format": "%(message)s"}},
147
+ "handlers": {
148
+ name: {
149
+ "class": "logging.StreamHandler",
150
+ "formatter": name,
151
+ "level": level,
152
+ }
153
+ },
154
+ "loggers": {
155
+ name: {
156
+ "level": level,
157
+ "handlers": [name],
158
+ "propagate": False,
159
+ }
160
+ },
161
+ }
162
+ )
163
+
164
+
165
+ set_logging(LOGGING_NAME) # run before defining LOGGER
166
+ LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
167
+ if platform.system() == "Windows":
168
+ for fn in LOGGER.info, LOGGER.warning:
169
+ setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
170
+
171
+
172
+ def user_config_dir(dir="Ultralytics", env_var="YOLOV5_CONFIG_DIR"):
173
+ """Returns user configuration directory path, preferring environment variable `YOLOV5_CONFIG_DIR` if set, else OS-
174
+ specific.
175
+ """
176
+ env = os.getenv(env_var)
177
+ if env:
178
+ path = Path(env) # use environment variable
179
+ else:
180
+ cfg = {"Windows": "AppData/Roaming", "Linux": ".config", "Darwin": "Library/Application Support"} # 3 OS dirs
181
+ path = Path.home() / cfg.get(platform.system(), "") # OS-specific config dir
182
+ path = (path if is_writeable(path) else Path("/tmp")) / dir # GCP and AWS lambda fix, only /tmp is writeable
183
+ path.mkdir(exist_ok=True) # make if required
184
+ return path
185
+
186
+
187
+ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
188
+
189
+
190
+ class Profile(contextlib.ContextDecorator):
191
+ """Context manager and decorator for profiling code execution time, with optional CUDA synchronization."""
192
+
193
+ def __init__(self, t=0.0, device: torch.device = None):
194
+ """Initializes a profiling context for YOLOv5 with optional timing threshold and device specification."""
195
+ self.t = t
196
+ self.device = device
197
+ self.cuda = bool(device and str(device).startswith("cuda"))
198
+
199
+ def __enter__(self):
200
+ """Initializes timing at the start of a profiling context block for performance measurement."""
201
+ self.start = self.time()
202
+ return self
203
+
204
+ def __exit__(self, type, value, traceback):
205
+ """Concludes timing, updating duration for profiling upon exiting a context block."""
206
+ self.dt = self.time() - self.start # delta-time
207
+ self.t += self.dt # accumulate dt
208
+
209
+ def time(self):
210
+ """Measures and returns the current time, synchronizing CUDA operations if `cuda` is True."""
211
+ if self.cuda:
212
+ torch.cuda.synchronize(self.device)
213
+ return time.time()
214
+
215
+
216
+ class Timeout(contextlib.ContextDecorator):
217
+ """Enforces a timeout on code execution, raising TimeoutError if the specified duration is exceeded."""
218
+
219
+ def __init__(self, seconds, *, timeout_msg="", suppress_timeout_errors=True):
220
+ """Initializes a timeout context/decorator with defined seconds, optional message, and error suppression."""
221
+ self.seconds = int(seconds)
222
+ self.timeout_message = timeout_msg
223
+ self.suppress = bool(suppress_timeout_errors)
224
+
225
+ def _timeout_handler(self, signum, frame):
226
+ """Raises a TimeoutError with a custom message when a timeout event occurs."""
227
+ raise TimeoutError(self.timeout_message)
228
+
229
+ def __enter__(self):
230
+ """Initializes timeout mechanism on non-Windows platforms, starting a countdown to raise TimeoutError."""
231
+ if platform.system() != "Windows": # not supported on Windows
232
+ signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
233
+ signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
234
+
235
+ def __exit__(self, exc_type, exc_val, exc_tb):
236
+ """Disables active alarm on non-Windows systems and optionally suppresses TimeoutError if set."""
237
+ if platform.system() != "Windows":
238
+ signal.alarm(0) # Cancel SIGALRM if it's scheduled
239
+ if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
240
+ return True
241
+
242
+
243
+ class WorkingDirectory(contextlib.ContextDecorator):
244
+ """Context manager/decorator to temporarily change the working directory within a 'with' statement or decorator."""
245
+
246
+ def __init__(self, new_dir):
247
+ """Initializes a context manager/decorator to temporarily change the working directory."""
248
+ self.dir = new_dir # new dir
249
+ self.cwd = Path.cwd().resolve() # current dir
250
+
251
+ def __enter__(self):
252
+ """Temporarily changes the working directory within a 'with' statement context."""
253
+ os.chdir(self.dir)
254
+
255
+ def __exit__(self, exc_type, exc_val, exc_tb):
256
+ """Restores the original working directory upon exiting a 'with' statement context."""
257
+ os.chdir(self.cwd)
258
+
259
+
260
+ def methods(instance):
261
+ """Returns list of method names for a class/instance excluding dunder methods."""
262
+ return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
263
+
264
+
265
+ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
266
+ """Logs the arguments of the calling function, with options to include the filename and function name."""
267
+ x = inspect.currentframe().f_back # previous frame
268
+ file, _, func, _, _ = inspect.getframeinfo(x)
269
+ if args is None: # get args automatically
270
+ args, _, _, frm = inspect.getargvalues(x)
271
+ args = {k: v for k, v in frm.items() if k in args}
272
+ try:
273
+ file = Path(file).resolve().relative_to(ROOT).with_suffix("")
274
+ except ValueError:
275
+ file = Path(file).stem
276
+ s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
277
+ LOGGER.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
278
+
279
+
280
+ def init_seeds(seed=0, deterministic=False):
281
+ """
282
+ Initializes RNG seeds and sets deterministic options if specified.
283
+
284
+ See https://pytorch.org/docs/stable/notes/randomness.html
285
+ """
286
+ random.seed(seed)
287
+ np.random.seed(seed)
288
+ torch.manual_seed(seed)
289
+ torch.cuda.manual_seed(seed)
290
+ torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
291
+ # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
292
+ if deterministic and check_version(torch.__version__, "1.12.0"): # https://github.com/ultralytics/yolov5/pull/8213
293
+ torch.use_deterministic_algorithms(True)
294
+ torch.backends.cudnn.deterministic = True
295
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
296
+ os.environ["PYTHONHASHSEED"] = str(seed)
297
+
298
+
299
+ def intersect_dicts(da, db, exclude=()):
300
+ """Returns intersection of `da` and `db` dicts with matching keys and shapes, excluding `exclude` keys; uses `da`
301
+ values.
302
+ """
303
+ return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
304
+
305
+
306
+ def get_default_args(func):
307
+ """Returns a dict of `func` default arguments by inspecting its signature."""
308
+ signature = inspect.signature(func)
309
+ return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
310
+
311
+
312
+ def get_latest_run(search_dir="."):
313
+ """Returns the path to the most recent 'last.pt' file in /runs to resume from, searches in `search_dir`."""
314
+ last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
315
+ return max(last_list, key=os.path.getctime) if last_list else ""
316
+
317
+
318
+ def file_age(path=__file__):
319
+ """Calculates and returns the age of a file in days based on its last modification time."""
320
+ dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
321
+ return dt.days # + dt.seconds / 86400 # fractional days
322
+
323
+
324
+ def file_date(path=__file__):
325
+ """Returns a human-readable file modification date in 'YYYY-M-D' format, given a file path."""
326
+ t = datetime.fromtimestamp(Path(path).stat().st_mtime)
327
+ return f"{t.year}-{t.month}-{t.day}"
328
+
329
+
330
+ def file_size(path):
331
+ """Returns file or directory size in megabytes (MB) for a given path, where directories are recursively summed."""
332
+ mb = 1 << 20 # bytes to MiB (1024 ** 2)
333
+ path = Path(path)
334
+ if path.is_file():
335
+ return path.stat().st_size / mb
336
+ elif path.is_dir():
337
+ return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
338
+ else:
339
+ return 0.0
340
+
341
+
342
+ def check_online():
343
+ """Checks internet connectivity by attempting to create a connection to "1.1.1.1" on port 443, retries once if the
344
+ first attempt fails.
345
+ """
346
+ import socket
347
+
348
+ def run_once():
349
+ """Checks internet connectivity by attempting to create a connection to "1.1.1.1" on port 443."""
350
+ try:
351
+ socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
352
+ return True
353
+ except OSError:
354
+ return False
355
+
356
+ return run_once() or run_once() # check twice to increase robustness to intermittent connectivity issues
357
+
358
+
359
+ def git_describe(path=ROOT):
360
+ """
361
+ Returns a human-readable git description of the repository at `path`, or an empty string on failure.
362
+
363
+ Example output is 'fv5.0-5-g3e25f1e'. See https://git-scm.com/docs/git-describe.
364
+ """
365
+ try:
366
+ assert (Path(path) / ".git").is_dir()
367
+ return check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1]
368
+ except Exception:
369
+ return ""
370
+
371
+
372
+ @TryExcept()
373
+ @WorkingDirectory(ROOT)
374
+ def check_git_status(repo="ultralytics/yolov5", branch="master"):
375
+ """Checks if YOLOv5 code is up-to-date with the repository, advising 'git pull' if behind; errors return informative
376
+ messages.
377
+ """
378
+ url = f"https://github.com/{repo}"
379
+ msg = f", for updates see {url}"
380
+ s = colorstr("github: ") # string
381
+ assert Path(".git").exists(), s + "skipping check (not a git repository)" + msg
382
+ assert check_online(), s + "skipping check (offline)" + msg
383
+
384
+ splits = re.split(pattern=r"\s", string=check_output("git remote -v", shell=True).decode())
385
+ matches = [repo in s for s in splits]
386
+ if any(matches):
387
+ remote = splits[matches.index(True) - 1]
388
+ else:
389
+ remote = "ultralytics"
390
+ check_output(f"git remote add {remote} {url}", shell=True)
391
+ check_output(f"git fetch {remote}", shell=True, timeout=5) # git fetch
392
+ local_branch = check_output("git rev-parse --abbrev-ref HEAD", shell=True).decode().strip() # checked out
393
+ n = int(check_output(f"git rev-list {local_branch}..{remote}/{branch} --count", shell=True)) # commits behind
394
+ if n > 0:
395
+ pull = "git pull" if remote == "origin" else f"git pull {remote} {branch}"
396
+ s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use '{pull}' or 'git clone {url}' to update."
397
+ else:
398
+ s += f"up to date with {url} ✅"
399
+ LOGGER.info(s)
400
+
401
+
402
+ @WorkingDirectory(ROOT)
403
+ def check_git_info(path="."):
404
+ """Checks YOLOv5 git info, returning a dict with remote URL, branch name, and commit hash."""
405
+ check_requirements("gitpython")
406
+ import git
407
+
408
+ try:
409
+ repo = git.Repo(path)
410
+ remote = repo.remotes.origin.url.replace(".git", "") # i.e. 'https://github.com/ultralytics/yolov5'
411
+ commit = repo.head.commit.hexsha # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
412
+ try:
413
+ branch = repo.active_branch.name # i.e. 'main'
414
+ except TypeError: # not on any branch
415
+ branch = None # i.e. 'detached HEAD' state
416
+ return {"remote": remote, "branch": branch, "commit": commit}
417
+ except git.exc.InvalidGitRepositoryError: # path is not a git dir
418
+ return {"remote": None, "branch": None, "commit": None}
419
+
420
+
421
+ def check_python(minimum="3.8.0"):
422
+ """Checks if current Python version meets the minimum required version, exits if not."""
423
+ check_version(platform.python_version(), minimum, name="Python ", hard=True)
424
+
425
+
426
+ def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):
427
+ """Checks if the current version meets the minimum required version, exits or warns based on parameters."""
428
+ current, minimum = (pkg.parse_version(x) for x in (current, minimum))
429
+ result = (current == minimum) if pinned else (current >= minimum) # bool
430
+ s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" # string
431
+ if hard:
432
+ assert result, emojis(s) # assert min requirements met
433
+ if verbose and not result:
434
+ LOGGER.warning(s)
435
+ return result
436
+
437
+
438
+ def check_img_size(imgsz, s=32, floor=0):
439
+ """Adjusts image size to be divisible by stride `s`, supports int or list/tuple input, returns adjusted size."""
440
+ if isinstance(imgsz, int): # integer i.e. img_size=640
441
+ new_size = max(make_divisible(imgsz, int(s)), floor)
442
+ else: # list i.e. img_size=[640, 480]
443
+ imgsz = list(imgsz) # convert to list if tuple
444
+ new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
445
+ if new_size != imgsz:
446
+ LOGGER.warning(f"WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}")
447
+ return new_size
448
+
449
+
450
+ def check_imshow(warn=False):
451
+ """Checks environment support for image display; warns on failure if `warn=True`."""
452
+ try:
453
+ assert not is_jupyter()
454
+ assert not is_docker()
455
+ cv2.imshow("test", np.zeros((1, 1, 3)))
456
+ cv2.waitKey(1)
457
+ cv2.destroyAllWindows()
458
+ cv2.waitKey(1)
459
+ return True
460
+ except Exception as e:
461
+ if warn:
462
+ LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
463
+ return False
464
+
465
+
466
+ def check_suffix(file="yolov5s.pt", suffix=(".pt",), msg=""):
467
+ """Validates if a file or files have an acceptable suffix, raising an error if not."""
468
+ if file and suffix:
469
+ if isinstance(suffix, str):
470
+ suffix = [suffix]
471
+ for f in file if isinstance(file, (list, tuple)) else [file]:
472
+ s = Path(f).suffix.lower() # file suffix
473
+ if len(s):
474
+ assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
475
+
476
+
477
+ def check_yaml(file, suffix=(".yaml", ".yml")):
478
+ """Searches/downloads a YAML file, verifies its suffix (.yaml or .yml), and returns the file path."""
479
+ return check_file(file, suffix)
480
+
481
+
482
+ def check_file(file, suffix=""):
483
+ """Searches/downloads a file, checks its suffix (if provided), and returns the file path."""
484
+ check_suffix(file, suffix) # optional
485
+ file = str(file) # convert to str()
486
+ if os.path.isfile(file) or not file: # exists
487
+ return file
488
+ elif file.startswith(("http:/", "https:/")): # download
489
+ url = file # warning: Pathlib turns :// -> :/
490
+ file = Path(urllib.parse.unquote(file).split("?")[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
491
+ if os.path.isfile(file):
492
+ LOGGER.info(f"Found {url} locally at {file}") # file already exists
493
+ else:
494
+ LOGGER.info(f"Downloading {url} to {file}...")
495
+ torch.hub.download_url_to_file(url, file)
496
+ assert Path(file).exists() and Path(file).stat().st_size > 0, f"File download failed: {url}" # check
497
+ return file
498
+ elif file.startswith("clearml://"): # ClearML Dataset ID
499
+ assert (
500
+ "clearml" in sys.modules
501
+ ), "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
502
+ return file
503
+ else: # search
504
+ files = []
505
+ for d in "data", "models", "utils": # search directories
506
+ files.extend(glob.glob(str(ROOT / d / "**" / file), recursive=True)) # find file
507
+ assert len(files), f"File not found: {file}" # assert file was found
508
+ assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
509
+ return files[0] # return file
510
+
511
+
512
+ def check_font(font=FONT, progress=False):
513
+ """Ensures specified font exists or downloads it from Ultralytics assets, optionally displaying progress."""
514
+ font = Path(font)
515
+ file = CONFIG_DIR / font.name
516
+ if not font.exists() and not file.exists():
517
+ url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font.name}"
518
+ LOGGER.info(f"Downloading {url} to {file}...")
519
+ torch.hub.download_url_to_file(url, str(file), progress=progress)
520
+
521
+
522
+ def check_dataset(data, autodownload=True):
523
+ """Validates and/or auto-downloads a dataset, returning its configuration as a dictionary."""
524
+ # Download (optional)
525
+ extract_dir = ""
526
+ if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
527
+ download(data, dir=f"{DATASETS_DIR}/{Path(data).stem}", unzip=True, delete=False, curl=False, threads=1)
528
+ data = next((DATASETS_DIR / Path(data).stem).rglob("*.yaml"))
529
+ extract_dir, autodownload = data.parent, False
530
+
531
+ # Read yaml (optional)
532
+ if isinstance(data, (str, Path)):
533
+ data = yaml_load(data) # dictionary
534
+
535
+ # Checks
536
+ for k in "train", "val", "names":
537
+ assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
538
+ if isinstance(data["names"], (list, tuple)): # old array format
539
+ data["names"] = dict(enumerate(data["names"])) # convert to dict
540
+ assert all(isinstance(k, int) for k in data["names"].keys()), "data.yaml names keys must be integers, i.e. 2: car"
541
+ data["nc"] = len(data["names"])
542
+
543
+ # Resolve paths
544
+ path = Path(extract_dir or data.get("path") or "") # optional 'path' default to '.'
545
+ if not path.is_absolute():
546
+ path = (ROOT / path).resolve()
547
+ data["path"] = path # download scripts
548
+ for k in "train", "val", "test":
549
+ if data.get(k): # prepend path
550
+ if isinstance(data[k], str):
551
+ x = (path / data[k]).resolve()
552
+ if not x.exists() and data[k].startswith("../"):
553
+ x = (path / data[k][3:]).resolve()
554
+ data[k] = str(x)
555
+ else:
556
+ data[k] = [str((path / x).resolve()) for x in data[k]]
557
+
558
+ # Parse yaml
559
+ train, val, test, s = (data.get(x) for x in ("train", "val", "test", "download"))
560
+ if val:
561
+ val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
562
+ if not all(x.exists() for x in val):
563
+ LOGGER.info("\nDataset not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()])
564
+ if not s or not autodownload:
565
+ raise Exception("Dataset not found ❌")
566
+ t = time.time()
567
+ if s.startswith("http") and s.endswith(".zip"): # URL
568
+ f = Path(s).name # filename
569
+ LOGGER.info(f"Downloading {s} to {f}...")
570
+ torch.hub.download_url_to_file(s, f)
571
+ Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
572
+ unzip_file(f, path=DATASETS_DIR) # unzip
573
+ Path(f).unlink() # remove zip
574
+ r = None # success
575
+ elif s.startswith("bash "): # bash script
576
+ LOGGER.info(f"Running {s} ...")
577
+ r = subprocess.run(s, shell=True)
578
+ else: # python script
579
+ r = exec(s, {"yaml": data}) # return None
580
+ dt = f"({round(time.time() - t, 1)}s)"
581
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
582
+ LOGGER.info(f"Dataset download {s}")
583
+ check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf", progress=True) # download fonts
584
+ return data # dictionary
585
+
586
+
587
+ def check_amp(model):
588
+ """Checks PyTorch AMP functionality for a model, returns True if AMP operates correctly, otherwise False."""
589
+ from models.common import AutoShape, DetectMultiBackend
590
+
591
+ def amp_allclose(model, im):
592
+ """Compares FP32 and AMP model inference outputs, ensuring they are close within a 10% absolute tolerance."""
593
+ m = AutoShape(model, verbose=False) # model
594
+ a = m(im).xywhn[0] # FP32 inference
595
+ m.amp = True
596
+ b = m(im).xywhn[0] # AMP inference
597
+ return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
598
+
599
+ prefix = colorstr("AMP: ")
600
+ device = next(model.parameters()).device # get model device
601
+ if device.type in ("cpu", "mps"):
602
+ return False # AMP only used on CUDA devices
603
+ f = ROOT / "data" / "images" / "bus.jpg" # image to check
604
+ im = f if f.exists() else "https://ultralytics.com/images/bus.jpg" if check_online() else np.ones((640, 640, 3))
605
+ try:
606
+ assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend("yolov5n.pt", device), im)
607
+ LOGGER.info(f"{prefix}checks passed ✅")
608
+ return True
609
+ except Exception:
610
+ help_url = "https://github.com/ultralytics/yolov5/issues/7908"
611
+ LOGGER.warning(f"{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}")
612
+ return False
613
+
614
+
615
+ def yaml_load(file="data.yaml"):
616
+ """Safely loads and returns the contents of a YAML file specified by `file` argument."""
617
+ with open(file, errors="ignore") as f:
618
+ return yaml.safe_load(f)
619
+
620
+
621
+ def yaml_save(file="data.yaml", data=None):
622
+ """Safely saves `data` to a YAML file specified by `file`, converting `Path` objects to strings; `data` is a
623
+ dictionary.
624
+ """
625
+ if data is None:
626
+ data = {}
627
+ with open(file, "w") as f:
628
+ yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
629
+
630
+
631
+ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX")):
632
+ """Unzips `file` to `path` (default: file's parent), excluding filenames containing any in `exclude` (`.DS_Store`,
633
+ `__MACOSX`).
634
+ """
635
+ if path is None:
636
+ path = Path(file).parent # default path
637
+ with ZipFile(file) as zipObj:
638
+ for f in zipObj.namelist(): # list all archived filenames in the zip
639
+ if all(x not in f for x in exclude):
640
+ zipObj.extract(f, path=path)
641
+
642
+
643
+ def url2file(url):
644
+ """
645
+ Converts a URL string to a valid filename by stripping protocol, domain, and any query parameters.
646
+
647
+ Example https://url.com/file.txt?auth -> file.txt
648
+ """
649
+ url = str(Path(url)).replace(":/", "://") # Pathlib turns :// -> :/
650
+ return Path(urllib.parse.unquote(url)).name.split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth
651
+
652
+
653
+ def download(url, dir=".", unzip=True, delete=True, curl=False, threads=1, retry=3):
654
+ """Downloads and optionally unzips files concurrently, supporting retries and curl fallback."""
655
+
656
+ def download_one(url, dir):
657
+ """Downloads a single file from `url` to `dir`, with retry support and optional curl fallback."""
658
+ success = True
659
+ if os.path.isfile(url):
660
+ f = Path(url) # filename
661
+ else: # does not exist
662
+ f = dir / Path(url).name
663
+ LOGGER.info(f"Downloading {url} to {f}...")
664
+ for i in range(retry + 1):
665
+ if curl:
666
+ success = curl_download(url, f, silent=(threads > 1))
667
+ else:
668
+ torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
669
+ success = f.is_file()
670
+ if success:
671
+ break
672
+ elif i < retry:
673
+ LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
674
+ else:
675
+ LOGGER.warning(f"❌ Failed to download {url}...")
676
+
677
+ if unzip and success and (f.suffix == ".gz" or is_zipfile(f) or is_tarfile(f)):
678
+ LOGGER.info(f"Unzipping {f}...")
679
+ if is_zipfile(f):
680
+ unzip_file(f, dir) # unzip
681
+ elif is_tarfile(f):
682
+ subprocess.run(["tar", "xf", f, "--directory", f.parent], check=True) # unzip
683
+ elif f.suffix == ".gz":
684
+ subprocess.run(["tar", "xfz", f, "--directory", f.parent], check=True) # unzip
685
+ if delete:
686
+ f.unlink() # remove zip
687
+
688
+ dir = Path(dir)
689
+ dir.mkdir(parents=True, exist_ok=True) # make directory
690
+ if threads > 1:
691
+ pool = ThreadPool(threads)
692
+ pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
693
+ pool.close()
694
+ pool.join()
695
+ else:
696
+ for u in [url] if isinstance(url, (str, Path)) else url:
697
+ download_one(u, dir)
698
+
699
+
700
+ def make_divisible(x, divisor):
701
+ """Adjusts `x` to be divisible by `divisor`, returning the nearest greater or equal value."""
702
+ if isinstance(divisor, torch.Tensor):
703
+ divisor = int(divisor.max()) # to int
704
+ return math.ceil(x / divisor) * divisor
705
+
706
+
707
+ def clean_str(s):
708
+ """Cleans a string by replacing special characters with underscore, e.g., `clean_str('#example!')` returns
709
+ '_example_'.
710
+ """
711
+ return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
712
+
713
+
714
+ def one_cycle(y1=0.0, y2=1.0, steps=100):
715
+ """
716
+ Generates a lambda for a sinusoidal ramp from y1 to y2 over 'steps'.
717
+
718
+ See https://arxiv.org/pdf/1812.01187.pdf for details.
719
+ """
720
+ return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
721
+
722
+
723
+ def colorstr(*input):
724
+ """
725
+ Colors a string using ANSI escape codes, e.g., colorstr('blue', 'hello world').
726
+
727
+ See https://en.wikipedia.org/wiki/ANSI_escape_code.
728
+ """
729
+ *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
730
+ colors = {
731
+ "black": "\033[30m", # basic colors
732
+ "red": "\033[31m",
733
+ "green": "\033[32m",
734
+ "yellow": "\033[33m",
735
+ "blue": "\033[34m",
736
+ "magenta": "\033[35m",
737
+ "cyan": "\033[36m",
738
+ "white": "\033[37m",
739
+ "bright_black": "\033[90m", # bright colors
740
+ "bright_red": "\033[91m",
741
+ "bright_green": "\033[92m",
742
+ "bright_yellow": "\033[93m",
743
+ "bright_blue": "\033[94m",
744
+ "bright_magenta": "\033[95m",
745
+ "bright_cyan": "\033[96m",
746
+ "bright_white": "\033[97m",
747
+ "end": "\033[0m", # misc
748
+ "bold": "\033[1m",
749
+ "underline": "\033[4m",
750
+ }
751
+ return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
752
+
753
+
754
+ def labels_to_class_weights(labels, nc=80):
755
+ """Calculates class weights from labels to handle class imbalance in training; input shape: (n, 5)."""
756
+ if labels[0] is None: # no labels loaded
757
+ return torch.Tensor()
758
+
759
+ labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
760
+ classes = labels[:, 0].astype(int) # labels = [class xywh]
761
+ weights = np.bincount(classes, minlength=nc) # occurrences per class
762
+
763
+ # Prepend gridpoint count (for uCE training)
764
+ # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
765
+ # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
766
+
767
+ weights[weights == 0] = 1 # replace empty bins with 1
768
+ weights = 1 / weights # number of targets per class
769
+ weights /= weights.sum() # normalize
770
+ return torch.from_numpy(weights).float()
771
+
772
+
773
+ def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
774
+ """Calculates image weights from labels using class weights for weighted sampling."""
775
+ # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
776
+ class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
777
+ return (class_weights.reshape(1, nc) * class_counts).sum(1)
778
+
779
+
780
+ def coco80_to_coco91_class():
781
+ """
782
+ Converts COCO 80-class index to COCO 91-class index used in the paper.
783
+
784
+ Reference: https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
785
+ """
786
+ # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
787
+ # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
788
+ # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
789
+ # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
790
+ return [
791
+ 1,
792
+ 2,
793
+ 3,
794
+ 4,
795
+ 5,
796
+ 6,
797
+ 7,
798
+ 8,
799
+ 9,
800
+ 10,
801
+ 11,
802
+ 13,
803
+ 14,
804
+ 15,
805
+ 16,
806
+ 17,
807
+ 18,
808
+ 19,
809
+ 20,
810
+ 21,
811
+ 22,
812
+ 23,
813
+ 24,
814
+ 25,
815
+ 27,
816
+ 28,
817
+ 31,
818
+ 32,
819
+ 33,
820
+ 34,
821
+ 35,
822
+ 36,
823
+ 37,
824
+ 38,
825
+ 39,
826
+ 40,
827
+ 41,
828
+ 42,
829
+ 43,
830
+ 44,
831
+ 46,
832
+ 47,
833
+ 48,
834
+ 49,
835
+ 50,
836
+ 51,
837
+ 52,
838
+ 53,
839
+ 54,
840
+ 55,
841
+ 56,
842
+ 57,
843
+ 58,
844
+ 59,
845
+ 60,
846
+ 61,
847
+ 62,
848
+ 63,
849
+ 64,
850
+ 65,
851
+ 67,
852
+ 70,
853
+ 72,
854
+ 73,
855
+ 74,
856
+ 75,
857
+ 76,
858
+ 77,
859
+ 78,
860
+ 79,
861
+ 80,
862
+ 81,
863
+ 82,
864
+ 84,
865
+ 85,
866
+ 86,
867
+ 87,
868
+ 88,
869
+ 89,
870
+ 90,
871
+ ]
872
+
873
+
874
+ def xyxy2xywh(x):
875
+ """Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right."""
876
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
877
+ y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
878
+ y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
879
+ y[..., 2] = x[..., 2] - x[..., 0] # width
880
+ y[..., 3] = x[..., 3] - x[..., 1] # height
881
+ return y
882
+
883
+
884
+ def xywh2xyxy(x):
885
+ """Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right."""
886
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
887
+ y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
888
+ y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
889
+ y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
890
+ y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
891
+ return y
892
+
893
+
894
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
895
+ """Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right."""
896
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
897
+ y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
898
+ y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
899
+ y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
900
+ y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
901
+ return y
902
+
903
+
904
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
905
+ """Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right."""
906
+ if clip:
907
+ clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
908
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
909
+ y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
910
+ y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
911
+ y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
912
+ y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
913
+ return y
914
+
915
+
916
+ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
917
+ """Convert normalized segments into pixel segments, shape (n,2)."""
918
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
919
+ y[..., 0] = w * x[..., 0] + padw # top left x
920
+ y[..., 1] = h * x[..., 1] + padh # top left y
921
+ return y
922
+
923
+
924
+ def segment2box(segment, width=640, height=640):
925
+ """Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)."""
926
+ x, y = segment.T # segment xy
927
+ inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
928
+ (
929
+ x,
930
+ y,
931
+ ) = x[inside], y[inside]
932
+ return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
933
+
934
+
935
+ def segments2boxes(segments):
936
+ """Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)."""
937
+ boxes = []
938
+ for s in segments:
939
+ x, y = s.T # segment xy
940
+ boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
941
+ return xyxy2xywh(np.array(boxes)) # cls, xywh
942
+
943
+
944
+ def resample_segments(segments, n=1000):
945
+ """Resamples an (n,2) segment to a fixed number of points for consistent representation."""
946
+ for i, s in enumerate(segments):
947
+ s = np.concatenate((s, s[0:1, :]), axis=0)
948
+ x = np.linspace(0, len(s) - 1, n)
949
+ xp = np.arange(len(s))
950
+ segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
951
+ return segments
952
+
953
+
954
+ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
955
+ """Rescales (xyxy) bounding boxes from img1_shape to img0_shape, optionally using provided `ratio_pad`."""
956
+ if ratio_pad is None: # calculate from img0_shape
957
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
958
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
959
+ else:
960
+ gain = ratio_pad[0][0]
961
+ pad = ratio_pad[1]
962
+
963
+ boxes[..., [0, 2]] -= pad[0] # x padding
964
+ boxes[..., [1, 3]] -= pad[1] # y padding
965
+ boxes[..., :4] /= gain
966
+ clip_boxes(boxes, img0_shape)
967
+ return boxes
968
+
969
+
970
+ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
971
+ """Rescales segment coordinates from img1_shape to img0_shape, optionally normalizing them with custom padding."""
972
+ if ratio_pad is None: # calculate from img0_shape
973
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
974
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
975
+ else:
976
+ gain = ratio_pad[0][0]
977
+ pad = ratio_pad[1]
978
+
979
+ segments[:, 0] -= pad[0] # x padding
980
+ segments[:, 1] -= pad[1] # y padding
981
+ segments /= gain
982
+ clip_segments(segments, img0_shape)
983
+ if normalize:
984
+ segments[:, 0] /= img0_shape[1] # width
985
+ segments[:, 1] /= img0_shape[0] # height
986
+ return segments
987
+
988
+
989
+ def clip_boxes(boxes, shape):
990
+ """Clips bounding box coordinates (xyxy) to fit within the specified image shape (height, width)."""
991
+ if isinstance(boxes, torch.Tensor): # faster individually
992
+ boxes[..., 0].clamp_(0, shape[1]) # x1
993
+ boxes[..., 1].clamp_(0, shape[0]) # y1
994
+ boxes[..., 2].clamp_(0, shape[1]) # x2
995
+ boxes[..., 3].clamp_(0, shape[0]) # y2
996
+ else: # np.array (faster grouped)
997
+ boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
998
+ boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
999
+
1000
+
1001
+ def clip_segments(segments, shape):
1002
+ """Clips segment coordinates (xy1, xy2, ...) to an image's boundaries given its shape (height, width)."""
1003
+ if isinstance(segments, torch.Tensor): # faster individually
1004
+ segments[:, 0].clamp_(0, shape[1]) # x
1005
+ segments[:, 1].clamp_(0, shape[0]) # y
1006
+ else: # np.array (faster grouped)
1007
+ segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
1008
+ segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
1009
+
1010
+
1011
+ def non_max_suppression(
1012
+ prediction,
1013
+ conf_thres=0.25,
1014
+ iou_thres=0.45,
1015
+ classes=None,
1016
+ agnostic=False,
1017
+ multi_label=False,
1018
+ labels=(),
1019
+ max_det=300,
1020
+ nm=0, # number of masks
1021
+ ):
1022
+ """
1023
+ Non-Maximum Suppression (NMS) on inference results to reject overlapping detections.
1024
+
1025
+ Returns:
1026
+ list of detections, on (n,6) tensor per image [xyxy, conf, cls]
1027
+ """
1028
+ # Checks
1029
+ assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
1030
+ assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
1031
+ if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
1032
+ prediction = prediction[0] # select only inference output
1033
+
1034
+ device = prediction.device
1035
+ mps = "mps" in device.type # Apple MPS
1036
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
1037
+ prediction = prediction.cpu()
1038
+ bs = prediction.shape[0] # batch size
1039
+ nc = prediction.shape[2] - nm - 5 # number of classes
1040
+ xc = prediction[..., 4] > conf_thres # candidates
1041
+
1042
+ # Settings
1043
+ # min_wh = 2 # (pixels) minimum box width and height
1044
+ max_wh = 7680 # (pixels) maximum box width and height
1045
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
1046
+ time_limit = 0.5 + 0.05 * bs # seconds to quit after
1047
+ redundant = True # require redundant detections
1048
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
1049
+ merge = False # use merge-NMS
1050
+
1051
+ t = time.time()
1052
+ mi = 5 + nc # mask start index
1053
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
1054
+ for xi, x in enumerate(prediction): # image index, image inference
1055
+ # Apply constraints
1056
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
1057
+ x = x[xc[xi]] # confidence
1058
+
1059
+ # Cat apriori labels if autolabelling
1060
+ if labels and len(labels[xi]):
1061
+ lb = labels[xi]
1062
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
1063
+ v[:, :4] = lb[:, 1:5] # box
1064
+ v[:, 4] = 1.0 # conf
1065
+ v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
1066
+ x = torch.cat((x, v), 0)
1067
+
1068
+ # If none remain process next image
1069
+ if not x.shape[0]:
1070
+ continue
1071
+
1072
+ # Compute conf
1073
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
1074
+
1075
+ # Box/Mask
1076
+ box = xywh2xyxy(x[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
1077
+ mask = x[:, mi:] # zero columns if no masks
1078
+
1079
+ # Detections matrix nx6 (xyxy, conf, cls)
1080
+ if multi_label:
1081
+ i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
1082
+ x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
1083
+ else: # best class only
1084
+ conf, j = x[:, 5:mi].max(1, keepdim=True)
1085
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
1086
+
1087
+ # Filter by class
1088
+ if classes is not None:
1089
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
1090
+
1091
+ # Apply finite constraint
1092
+ # if not torch.isfinite(x).all():
1093
+ # x = x[torch.isfinite(x).all(1)]
1094
+
1095
+ # Check shape
1096
+ n = x.shape[0] # number of boxes
1097
+ if not n: # no boxes
1098
+ continue
1099
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
1100
+
1101
+ # Batched NMS
1102
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
1103
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
1104
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
1105
+ i = i[:max_det] # limit detections
1106
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
1107
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
1108
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
1109
+ weights = iou * scores[None] # box weights
1110
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
1111
+ if redundant:
1112
+ i = i[iou.sum(1) > 1] # require redundancy
1113
+
1114
+ output[xi] = x[i]
1115
+ if mps:
1116
+ output[xi] = output[xi].to(device)
1117
+ if (time.time() - t) > time_limit:
1118
+ LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
1119
+ break # time limit exceeded
1120
+
1121
+ return output
1122
+
1123
+
1124
+ def strip_optimizer(f="best.pt", s=""):
1125
+ """
1126
+ Strips optimizer and optionally saves checkpoint to finalize training; arguments are file path 'f' and save path
1127
+ 's'.
1128
+
1129
+ Example: from utils.general import *; strip_optimizer()
1130
+ """
1131
+ x = torch.load(f, map_location=torch.device("cpu"))
1132
+ if x.get("ema"):
1133
+ x["model"] = x["ema"] # replace model with ema
1134
+ for k in "optimizer", "best_fitness", "ema", "updates": # keys
1135
+ x[k] = None
1136
+ x["epoch"] = -1
1137
+ x["model"].half() # to FP16
1138
+ for p in x["model"].parameters():
1139
+ p.requires_grad = False
1140
+ torch.save(x, s or f)
1141
+ mb = os.path.getsize(s or f) / 1e6 # filesize
1142
+ LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
1143
+
1144
+
1145
+ def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr("evolve: ")):
1146
+ """Logs evolution results and saves to CSV and YAML in `save_dir`, optionally syncs with `bucket`."""
1147
+ evolve_csv = save_dir / "evolve.csv"
1148
+ evolve_yaml = save_dir / "hyp_evolve.yaml"
1149
+ keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps]
1150
+ keys = tuple(x.strip() for x in keys)
1151
+ vals = results + tuple(hyp.values())
1152
+ n = len(keys)
1153
+
1154
+ # Download (optional)
1155
+ if bucket:
1156
+ url = f"gs://{bucket}/evolve.csv"
1157
+ if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
1158
+ subprocess.run(["gsutil", "cp", f"{url}", f"{save_dir}"]) # download evolve.csv if larger than local
1159
+
1160
+ # Log to evolve.csv
1161
+ s = "" if evolve_csv.exists() else (("%20s," * n % keys).rstrip(",") + "\n") # add header
1162
+ with open(evolve_csv, "a") as f:
1163
+ f.write(s + ("%20.5g," * n % vals).rstrip(",") + "\n")
1164
+
1165
+ # Save yaml
1166
+ with open(evolve_yaml, "w") as f:
1167
+ data = pd.read_csv(evolve_csv, skipinitialspace=True)
1168
+ data = data.rename(columns=lambda x: x.strip()) # strip keys
1169
+ i = np.argmax(fitness(data.values[:, :4])) #
1170
+ generations = len(data)
1171
+ f.write(
1172
+ "# YOLOv5 Hyperparameter Evolution Results\n"
1173
+ + f"# Best generation: {i}\n"
1174
+ + f"# Last generation: {generations - 1}\n"
1175
+ + "# "
1176
+ + ", ".join(f"{x.strip():>20s}" for x in keys[:7])
1177
+ + "\n"
1178
+ + "# "
1179
+ + ", ".join(f"{x:>20.5g}" for x in data.values[i, :7])
1180
+ + "\n\n"
1181
+ )
1182
+ yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
1183
+
1184
+ # Print to screen
1185
+ LOGGER.info(
1186
+ prefix
1187
+ + f"{generations} generations finished, current result:\n"
1188
+ + prefix
1189
+ + ", ".join(f"{x.strip():>20s}" for x in keys)
1190
+ + "\n"
1191
+ + prefix
1192
+ + ", ".join(f"{x:20.5g}" for x in vals)
1193
+ + "\n\n"
1194
+ )
1195
+
1196
+ if bucket:
1197
+ subprocess.run(["gsutil", "cp", f"{evolve_csv}", f"{evolve_yaml}", f"gs://{bucket}"]) # upload
1198
+
1199
+
1200
+ def apply_classifier(x, model, img, im0):
1201
+ """Applies second-stage classifier to YOLO outputs, filtering detections by class match."""
1202
+ # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
1203
+ im0 = [im0] if isinstance(im0, np.ndarray) else im0
1204
+ for i, d in enumerate(x): # per image
1205
+ if d is not None and len(d):
1206
+ d = d.clone()
1207
+
1208
+ # Reshape and pad cutouts
1209
+ b = xyxy2xywh(d[:, :4]) # boxes
1210
+ b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
1211
+ b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
1212
+ d[:, :4] = xywh2xyxy(b).long()
1213
+
1214
+ # Rescale boxes from img_size to im0 size
1215
+ scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)
1216
+
1217
+ # Classes
1218
+ pred_cls1 = d[:, 5].long()
1219
+ ims = []
1220
+ for a in d:
1221
+ cutout = im0[i][int(a[1]) : int(a[3]), int(a[0]) : int(a[2])]
1222
+ im = cv2.resize(cutout, (224, 224)) # BGR
1223
+
1224
+ im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
1225
+ im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
1226
+ im /= 255 # 0 - 255 to 0.0 - 1.0
1227
+ ims.append(im)
1228
+
1229
+ pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
1230
+ x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
1231
+
1232
+ return x
1233
+
1234
+
1235
+ def increment_path(path, exist_ok=False, sep="", mkdir=False):
1236
+ """
1237
+ Generates an incremented file or directory path if it exists, with optional mkdir; args: path, exist_ok=False,
1238
+ sep="", mkdir=False.
1239
+
1240
+ Example: runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc
1241
+ """
1242
+ path = Path(path) # os-agnostic
1243
+ if path.exists() and not exist_ok:
1244
+ path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
1245
+
1246
+ # Method 1
1247
+ for n in range(2, 9999):
1248
+ p = f"{path}{sep}{n}{suffix}" # increment path
1249
+ if not os.path.exists(p): #
1250
+ break
1251
+ path = Path(p)
1252
+
1253
+ # Method 2 (deprecated)
1254
+ # dirs = glob.glob(f"{path}{sep}*") # similar paths
1255
+ # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
1256
+ # i = [int(m.groups()[0]) for m in matches if m] # indices
1257
+ # n = max(i) + 1 if i else 2 # increment number
1258
+ # path = Path(f"{path}{sep}{n}{suffix}") # increment path
1259
+
1260
+ if mkdir:
1261
+ path.mkdir(parents=True, exist_ok=True) # make directory
1262
+
1263
+ return path
1264
+
1265
+
1266
+ # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------------
1267
+ imshow_ = cv2.imshow # copy to avoid recursion errors
1268
+
1269
+
1270
+ def imread(filename, flags=cv2.IMREAD_COLOR):
1271
+ """Reads an image from a file and returns it as a numpy array, using OpenCV's imdecode to support multilanguage
1272
+ paths.
1273
+ """
1274
+ return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
1275
+
1276
+
1277
+ def imwrite(filename, img):
1278
+ """Writes an image to a file, returns True on success and False on failure, supports multilanguage paths."""
1279
+ try:
1280
+ cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
1281
+ return True
1282
+ except Exception:
1283
+ return False
1284
+
1285
+
1286
+ def imshow(path, im):
1287
+ """Displays an image using Unicode path, requires encoded path and image matrix as input."""
1288
+ imshow_(path.encode("unicode_escape").decode(), im)
1289
+
1290
+
1291
+ if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
1292
+ cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
1293
+
1294
+ # Variables ------------------------------------------------------------------------------------------------------------
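
The hunk above adds the vendored YOLOv5 utils/general.py in full. As a quick orientation, the sketch below exercises two of the helpers it defines, xyxy2xywh/xywh2xyxy and non_max_suppression, on dummy data. This is a minimal sketch, not part of the package diff: the flat utils.general import path is an assumption carried over from upstream YOLOv5, and since the vendored copy keeps the same relative layout, the bplusplus/yolov5detect directory would need to be on sys.path for it to resolve.

import numpy as np
import torch

# Assumed import path (upstream YOLOv5 layout); requires bplusplus/yolov5detect on sys.path.
from utils.general import non_max_suppression, xywh2xyxy, xyxy2xywh

# Round-trip a box between corner (xyxy) and center (xywh) formats.
boxes_xyxy = np.array([[10.0, 20.0, 110.0, 220.0]])   # x1, y1, x2, y2
boxes_xywh = xyxy2xywh(boxes_xyxy)                     # -> [[60., 120., 100., 200.]]
assert np.allclose(xywh2xyxy(boxes_xywh), boxes_xyxy)

# Run NMS on a dummy raw prediction of shape (batch, boxes, 5 + num_classes).
pred = torch.rand(1, 100, 5 + 80)
detections = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
print(detections[0].shape)                             # (n, 6): xyxy, confidence, class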