megadetector 5.0.5__py3-none-any.whl → 5.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of megadetector might be problematic. Click here for more details.

Files changed (132) hide show
  1. api/batch_processing/data_preparation/manage_local_batch.py +302 -263
  2. api/batch_processing/data_preparation/manage_video_batch.py +81 -2
  3. api/batch_processing/postprocessing/add_max_conf.py +1 -0
  4. api/batch_processing/postprocessing/categorize_detections_by_size.py +50 -19
  5. api/batch_processing/postprocessing/compare_batch_results.py +110 -60
  6. api/batch_processing/postprocessing/load_api_results.py +56 -70
  7. api/batch_processing/postprocessing/md_to_coco.py +1 -1
  8. api/batch_processing/postprocessing/md_to_labelme.py +2 -1
  9. api/batch_processing/postprocessing/postprocess_batch_results.py +240 -81
  10. api/batch_processing/postprocessing/render_detection_confusion_matrix.py +625 -0
  11. api/batch_processing/postprocessing/repeat_detection_elimination/find_repeat_detections.py +71 -23
  12. api/batch_processing/postprocessing/repeat_detection_elimination/remove_repeat_detections.py +1 -1
  13. api/batch_processing/postprocessing/repeat_detection_elimination/repeat_detections_core.py +227 -75
  14. api/batch_processing/postprocessing/subset_json_detector_output.py +132 -5
  15. api/batch_processing/postprocessing/top_folders_to_bottom.py +1 -1
  16. api/synchronous/api_core/animal_detection_api/detection/run_detector_batch.py +2 -2
  17. classification/prepare_classification_script.py +191 -191
  18. data_management/coco_to_yolo.py +68 -45
  19. data_management/databases/integrity_check_json_db.py +7 -5
  20. data_management/generate_crops_from_cct.py +3 -3
  21. data_management/get_image_sizes.py +8 -6
  22. data_management/importers/add_timestamps_to_icct.py +79 -0
  23. data_management/importers/animl_results_to_md_results.py +160 -0
  24. data_management/importers/auckland_doc_test_to_json.py +4 -4
  25. data_management/importers/auckland_doc_to_json.py +1 -1
  26. data_management/importers/awc_to_json.py +5 -5
  27. data_management/importers/bellevue_to_json.py +5 -5
  28. data_management/importers/carrizo_shrubfree_2018.py +5 -5
  29. data_management/importers/carrizo_trail_cam_2017.py +5 -5
  30. data_management/importers/cct_field_adjustments.py +2 -3
  31. data_management/importers/channel_islands_to_cct.py +4 -4
  32. data_management/importers/ena24_to_json.py +5 -5
  33. data_management/importers/helena_to_cct.py +10 -10
  34. data_management/importers/idaho-camera-traps.py +12 -12
  35. data_management/importers/idfg_iwildcam_lila_prep.py +8 -8
  36. data_management/importers/jb_csv_to_json.py +4 -4
  37. data_management/importers/missouri_to_json.py +1 -1
  38. data_management/importers/noaa_seals_2019.py +1 -1
  39. data_management/importers/pc_to_json.py +5 -5
  40. data_management/importers/prepare-noaa-fish-data-for-lila.py +4 -4
  41. data_management/importers/prepare_zsl_imerit.py +5 -5
  42. data_management/importers/rspb_to_json.py +4 -4
  43. data_management/importers/save_the_elephants_survey_A.py +5 -5
  44. data_management/importers/save_the_elephants_survey_B.py +6 -6
  45. data_management/importers/snapshot_safari_importer.py +9 -9
  46. data_management/importers/snapshot_serengeti_lila.py +9 -9
  47. data_management/importers/timelapse_csv_set_to_json.py +5 -7
  48. data_management/importers/ubc_to_json.py +4 -4
  49. data_management/importers/umn_to_json.py +4 -4
  50. data_management/importers/wellington_to_json.py +1 -1
  51. data_management/importers/wi_to_json.py +2 -2
  52. data_management/importers/zamba_results_to_md_results.py +181 -0
  53. data_management/labelme_to_coco.py +35 -7
  54. data_management/labelme_to_yolo.py +229 -0
  55. data_management/lila/add_locations_to_island_camera_traps.py +1 -1
  56. data_management/lila/add_locations_to_nacti.py +147 -0
  57. data_management/lila/create_lila_blank_set.py +474 -0
  58. data_management/lila/create_lila_test_set.py +2 -1
  59. data_management/lila/create_links_to_md_results_files.py +106 -0
  60. data_management/lila/download_lila_subset.py +46 -21
  61. data_management/lila/generate_lila_per_image_labels.py +23 -14
  62. data_management/lila/get_lila_annotation_counts.py +17 -11
  63. data_management/lila/lila_common.py +14 -11
  64. data_management/lila/test_lila_metadata_urls.py +116 -0
  65. data_management/ocr_tools.py +829 -0
  66. data_management/resize_coco_dataset.py +13 -11
  67. data_management/yolo_output_to_md_output.py +84 -12
  68. data_management/yolo_to_coco.py +38 -20
  69. detection/process_video.py +36 -14
  70. detection/pytorch_detector.py +23 -8
  71. detection/run_detector.py +76 -19
  72. detection/run_detector_batch.py +178 -63
  73. detection/run_inference_with_yolov5_val.py +326 -57
  74. detection/run_tiled_inference.py +153 -43
  75. detection/video_utils.py +34 -8
  76. md_utils/ct_utils.py +172 -1
  77. md_utils/md_tests.py +372 -51
  78. md_utils/path_utils.py +167 -39
  79. md_utils/process_utils.py +26 -7
  80. md_utils/split_locations_into_train_val.py +215 -0
  81. md_utils/string_utils.py +10 -0
  82. md_utils/url_utils.py +0 -2
  83. md_utils/write_html_image_list.py +9 -26
  84. md_visualization/plot_utils.py +12 -8
  85. md_visualization/visualization_utils.py +106 -7
  86. md_visualization/visualize_db.py +16 -8
  87. md_visualization/visualize_detector_output.py +208 -97
  88. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/METADATA +3 -6
  89. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/RECORD +98 -121
  90. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/WHEEL +1 -1
  91. taxonomy_mapping/map_lila_taxonomy_to_wi_taxonomy.py +1 -1
  92. taxonomy_mapping/map_new_lila_datasets.py +43 -39
  93. taxonomy_mapping/prepare_lila_taxonomy_release.py +5 -2
  94. taxonomy_mapping/preview_lila_taxonomy.py +27 -27
  95. taxonomy_mapping/species_lookup.py +33 -13
  96. taxonomy_mapping/taxonomy_csv_checker.py +7 -5
  97. api/synchronous/api_core/yolov5/detect.py +0 -252
  98. api/synchronous/api_core/yolov5/export.py +0 -607
  99. api/synchronous/api_core/yolov5/hubconf.py +0 -146
  100. api/synchronous/api_core/yolov5/models/__init__.py +0 -0
  101. api/synchronous/api_core/yolov5/models/common.py +0 -738
  102. api/synchronous/api_core/yolov5/models/experimental.py +0 -104
  103. api/synchronous/api_core/yolov5/models/tf.py +0 -574
  104. api/synchronous/api_core/yolov5/models/yolo.py +0 -338
  105. api/synchronous/api_core/yolov5/train.py +0 -670
  106. api/synchronous/api_core/yolov5/utils/__init__.py +0 -36
  107. api/synchronous/api_core/yolov5/utils/activations.py +0 -103
  108. api/synchronous/api_core/yolov5/utils/augmentations.py +0 -284
  109. api/synchronous/api_core/yolov5/utils/autoanchor.py +0 -170
  110. api/synchronous/api_core/yolov5/utils/autobatch.py +0 -66
  111. api/synchronous/api_core/yolov5/utils/aws/__init__.py +0 -0
  112. api/synchronous/api_core/yolov5/utils/aws/resume.py +0 -40
  113. api/synchronous/api_core/yolov5/utils/benchmarks.py +0 -148
  114. api/synchronous/api_core/yolov5/utils/callbacks.py +0 -71
  115. api/synchronous/api_core/yolov5/utils/dataloaders.py +0 -1087
  116. api/synchronous/api_core/yolov5/utils/downloads.py +0 -178
  117. api/synchronous/api_core/yolov5/utils/flask_rest_api/example_request.py +0 -19
  118. api/synchronous/api_core/yolov5/utils/flask_rest_api/restapi.py +0 -46
  119. api/synchronous/api_core/yolov5/utils/general.py +0 -1018
  120. api/synchronous/api_core/yolov5/utils/loggers/__init__.py +0 -187
  121. api/synchronous/api_core/yolov5/utils/loggers/wandb/__init__.py +0 -0
  122. api/synchronous/api_core/yolov5/utils/loggers/wandb/log_dataset.py +0 -27
  123. api/synchronous/api_core/yolov5/utils/loggers/wandb/sweep.py +0 -41
  124. api/synchronous/api_core/yolov5/utils/loggers/wandb/wandb_utils.py +0 -577
  125. api/synchronous/api_core/yolov5/utils/loss.py +0 -234
  126. api/synchronous/api_core/yolov5/utils/metrics.py +0 -355
  127. api/synchronous/api_core/yolov5/utils/plots.py +0 -489
  128. api/synchronous/api_core/yolov5/utils/torch_utils.py +0 -314
  129. api/synchronous/api_core/yolov5/val.py +0 -394
  130. md_utils/matlab_porting_tools.py +0 -97
  131. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/LICENSE +0 -0
  132. {megadetector-5.0.5.dist-info → megadetector-5.0.7.dist-info}/top_level.txt +0 -0
@@ -1,1018 +0,0 @@
1
- # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
- """
3
- General utils
4
- """
5
-
6
- import contextlib
7
- import glob
8
- import inspect
9
- import logging
10
- import math
11
- import os
12
- import platform
13
- import random
14
- import re
15
- import shutil
16
- import signal
17
- import threading
18
- import time
19
- import urllib
20
- from datetime import datetime
21
- from itertools import repeat
22
- from multiprocessing.pool import ThreadPool
23
- from pathlib import Path
24
- from subprocess import check_output
25
- from typing import Optional
26
- from zipfile import ZipFile
27
-
28
- import cv2
29
- import numpy as np
30
- import pandas as pd
31
- import pkg_resources as pkg
32
- import torch
33
- import torchvision
34
- import yaml
35
-
36
- from utils.downloads import gsutil_getsize
37
- from utils.metrics import box_iou, fitness
38
-
39
- FILE = Path(__file__).resolve()
40
- ROOT = FILE.parents[1] # YOLOv5 root directory
41
- RANK = int(os.getenv('RANK', -1))
42
-
43
- # Settings
44
- DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
45
- NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
46
- AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
47
- VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
48
- FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
49
-
50
- torch.set_printoptions(linewidth=320, precision=5, profile='long')
51
- np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
52
- pd.options.display.max_columns = 10
53
- cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
54
- os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
55
- os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
56
-
57
-
58
- def is_kaggle():
59
- # Is environment a Kaggle Notebook?
60
- try:
61
- assert os.environ.get('PWD') == '/kaggle/working'
62
- assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
63
- return True
64
- except AssertionError:
65
- return False
66
-
67
-
68
- def is_writeable(dir, test=False):
69
- # Return True if directory has write permissions, test opening a file with write permissions if test=True
70
- if not test:
71
- return os.access(dir, os.R_OK) # possible issues on Windows
72
- file = Path(dir) / 'tmp.txt'
73
- try:
74
- with open(file, 'w'): # open file with write permissions
75
- pass
76
- file.unlink() # remove file
77
- return True
78
- except OSError:
79
- return False
80
-
81
-
82
- def set_logging(name=None, verbose=VERBOSE):
83
- # Sets level and returns logger
84
- if is_kaggle():
85
- for h in logging.root.handlers:
86
- logging.root.removeHandler(h) # remove all handlers associated with the root logger object
87
- rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
88
- level = logging.INFO if verbose and rank in {-1, 0} else logging.WARNING
89
- log = logging.getLogger(name)
90
- log.setLevel(level)
91
- handler = logging.StreamHandler()
92
- handler.setFormatter(logging.Formatter("%(message)s"))
93
- handler.setLevel(level)
94
- log.addHandler(handler)
95
-
96
-
97
- set_logging() # run before defining LOGGER
98
- LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.)
99
-
100
-
101
- def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
102
- # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
103
- env = os.getenv(env_var)
104
- if env:
105
- path = Path(env) # use environment variable
106
- else:
107
- cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
108
- path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
109
- path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
110
- path.mkdir(exist_ok=True) # make if required
111
- return path
112
-
113
-
114
- CONFIG_DIR = user_config_dir() # Ultralytics settings dir
115
-
116
-
117
- class Profile(contextlib.ContextDecorator):
118
- # Usage: @Profile() decorator or 'with Profile():' context manager
119
- def __enter__(self):
120
- self.start = time.time()
121
-
122
- def __exit__(self, type, value, traceback):
123
- print(f'Profile results: {time.time() - self.start:.5f}s')
124
-
125
-
126
- class Timeout(contextlib.ContextDecorator):
127
- # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
128
- def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
129
- self.seconds = int(seconds)
130
- self.timeout_message = timeout_msg
131
- self.suppress = bool(suppress_timeout_errors)
132
-
133
- def _timeout_handler(self, signum, frame):
134
- raise TimeoutError(self.timeout_message)
135
-
136
- def __enter__(self):
137
- if platform.system() != 'Windows': # not supported on Windows
138
- signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
139
- signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
140
-
141
- def __exit__(self, exc_type, exc_val, exc_tb):
142
- if platform.system() != 'Windows':
143
- signal.alarm(0) # Cancel SIGALRM if it's scheduled
144
- if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
145
- return True
146
-
147
-
148
- class WorkingDirectory(contextlib.ContextDecorator):
149
- # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
150
- def __init__(self, new_dir):
151
- self.dir = new_dir # new dir
152
- self.cwd = Path.cwd().resolve() # current dir
153
-
154
- def __enter__(self):
155
- os.chdir(self.dir)
156
-
157
- def __exit__(self, exc_type, exc_val, exc_tb):
158
- os.chdir(self.cwd)
159
-
160
-
161
- def try_except(func):
162
- # try-except function. Usage: @try_except decorator
163
- def handler(*args, **kwargs):
164
- try:
165
- func(*args, **kwargs)
166
- except Exception as e:
167
- print(e)
168
-
169
- return handler
170
-
171
-
172
- def threaded(func):
173
- # Multi-threads a target function and returns thread. Usage: @threaded decorator
174
- def wrapper(*args, **kwargs):
175
- thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
176
- thread.start()
177
- return thread
178
-
179
- return wrapper
180
-
181
-
182
- def methods(instance):
183
- # Get class/instance methods
184
- return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
185
-
186
-
187
- def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
188
- # Print function arguments (optional args dict)
189
- x = inspect.currentframe().f_back # previous frame
190
- file, _, fcn, _, _ = inspect.getframeinfo(x)
191
- if args is None: # get args automatically
192
- args, _, _, frm = inspect.getargvalues(x)
193
- args = {k: v for k, v in frm.items() if k in args}
194
- s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
195
- LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
196
-
197
-
198
- def init_seeds(seed=0):
199
- # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
200
- # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
201
- import torch.backends.cudnn as cudnn
202
- random.seed(seed)
203
- np.random.seed(seed)
204
- torch.manual_seed(seed)
205
- cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
206
-
207
-
208
- def intersect_dicts(da, db, exclude=()):
209
- # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
210
- return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
211
-
212
-
213
- def get_latest_run(search_dir='.'):
214
- # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
215
- last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
216
- return max(last_list, key=os.path.getctime) if last_list else ''
217
-
218
-
219
- def is_docker():
220
- # Is environment a Docker container?
221
- return Path('/workspace').exists() # or Path('/.dockerenv').exists()
222
-
223
-
224
- def is_colab():
225
- # Is environment a Google Colab instance?
226
- try:
227
- import google.colab
228
- return True
229
- except ImportError:
230
- return False
231
-
232
-
233
- def is_pip():
234
- # Is file in a pip package?
235
- return 'site-packages' in Path(__file__).resolve().parts
236
-
237
-
238
- def is_ascii(s=''):
239
- # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
240
- s = str(s) # convert list, tuple, None, etc. to str
241
- return len(s.encode().decode('ascii', 'ignore')) == len(s)
242
-
243
-
244
- def is_chinese(s='人工智能'):
245
- # Is string composed of any Chinese characters?
246
- return bool(re.search('[\u4e00-\u9fff]', str(s)))
247
-
248
-
249
- def emojis(str=''):
250
- # Return platform-dependent emoji-safe version of string
251
- return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
252
-
253
-
254
- def file_age(path=__file__):
255
- # Return days since last file update
256
- dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
257
- return dt.days # + dt.seconds / 86400 # fractional days
258
-
259
-
260
- def file_date(path=__file__):
261
- # Return human-readable file modification date, i.e. '2021-3-26'
262
- t = datetime.fromtimestamp(Path(path).stat().st_mtime)
263
- return f'{t.year}-{t.month}-{t.day}'
264
-
265
-
266
- def file_size(path):
267
- # Return file/dir size (MB)
268
- mb = 1 << 20 # bytes to MiB (1024 ** 2)
269
- path = Path(path)
270
- if path.is_file():
271
- return path.stat().st_size / mb
272
- elif path.is_dir():
273
- return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
274
- else:
275
- return 0.0
276
-
277
-
278
- def check_online():
279
- # Check internet connectivity
280
- import socket
281
- try:
282
- socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
283
- return True
284
- except OSError:
285
- return False
286
-
287
-
288
- def git_describe(path=ROOT): # path must be a directory
289
- # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
290
- try:
291
- assert (Path(path) / '.git').is_dir()
292
- return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
293
- except Exception:
294
- return ''
295
-
296
-
297
- @try_except
298
- @WorkingDirectory(ROOT)
299
- def check_git_status():
300
- # Recommend 'git pull' if code is out of date
301
- msg = ', for updates see https://github.com/ultralytics/yolov5'
302
- s = colorstr('github: ') # string
303
- assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
304
- assert not is_docker(), s + 'skipping check (Docker image)' + msg
305
- assert check_online(), s + 'skipping check (offline)' + msg
306
-
307
- cmd = 'git fetch && git config --get remote.origin.url'
308
- url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
309
- branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
310
- n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
311
- if n > 0:
312
- s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
313
- else:
314
- s += f'up to date with {url} ✅'
315
- LOGGER.info(emojis(s)) # emoji-safe
316
-
317
-
318
- def check_python(minimum='3.7.0'):
319
- # Check current python version vs. required python version
320
- check_version(platform.python_version(), minimum, name='Python ', hard=True)
321
-
322
-
323
- def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
324
- # Check version vs. required version
325
- current, minimum = (pkg.parse_version(x) for x in (current, minimum))
326
- result = (current == minimum) if pinned else (current >= minimum) # bool
327
- s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
328
- if hard:
329
- assert result, s # assert min requirements met
330
- if verbose and not result:
331
- LOGGER.warning(s)
332
- return result
333
-
334
-
335
- @try_except
336
- def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
337
- # Check installed dependencies meet requirements (pass *.txt file or list of packages)
338
- prefix = colorstr('red', 'bold', 'requirements:')
339
- check_python() # check python version
340
- if isinstance(requirements, (str, Path)): # requirements.txt file
341
- file = Path(requirements)
342
- assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
343
- with file.open() as f:
344
- requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
345
- else: # list or tuple of packages
346
- requirements = [x for x in requirements if x not in exclude]
347
-
348
- n = 0 # number of packages updates
349
- for i, r in enumerate(requirements):
350
- try:
351
- pkg.require(r)
352
- except Exception: # DistributionNotFound or VersionConflict if requirements not met
353
- s = f"{prefix} {r} not found and is required by YOLOv5"
354
- if install and AUTOINSTALL: # check environment variable
355
- LOGGER.info(f"{s}, attempting auto-update...")
356
- try:
357
- assert check_online(), f"'pip install {r}' skipped (offline)"
358
- LOGGER.info(check_output(f'pip install "{r}" {cmds[i] if cmds else ""}', shell=True).decode())
359
- n += 1
360
- except Exception as e:
361
- LOGGER.warning(f'{prefix} {e}')
362
- else:
363
- LOGGER.info(f'{s}. Please install and rerun your command.')
364
-
365
- if n: # if packages updated
366
- source = file.resolve() if 'file' in locals() else requirements
367
- s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
368
- f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
369
- LOGGER.info(emojis(s))
370
-
371
-
372
- def check_img_size(imgsz, s=32, floor=0):
373
- # Verify image size is a multiple of stride s in each dimension
374
- if isinstance(imgsz, int): # integer i.e. img_size=640
375
- new_size = max(make_divisible(imgsz, int(s)), floor)
376
- else: # list i.e. img_size=[640, 480]
377
- imgsz = list(imgsz) # convert to list if tuple
378
- new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
379
- if new_size != imgsz:
380
- LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
381
- return new_size
382
-
383
-
384
- def check_imshow():
385
- # Check if environment supports image displays
386
- try:
387
- assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
388
- assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
389
- cv2.imshow('test', np.zeros((1, 1, 3)))
390
- cv2.waitKey(1)
391
- cv2.destroyAllWindows()
392
- cv2.waitKey(1)
393
- return True
394
- except Exception as e:
395
- LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
396
- return False
397
-
398
-
399
- def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
400
- # Check file(s) for acceptable suffix
401
- if file and suffix:
402
- if isinstance(suffix, str):
403
- suffix = [suffix]
404
- for f in file if isinstance(file, (list, tuple)) else [file]:
405
- s = Path(f).suffix.lower() # file suffix
406
- if len(s):
407
- assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
408
-
409
-
410
- def check_yaml(file, suffix=('.yaml', '.yml')):
411
- # Search/download YAML file (if necessary) and return path, checking suffix
412
- return check_file(file, suffix)
413
-
414
-
415
- def check_file(file, suffix=''):
416
- # Search/download file (if necessary) and return path
417
- check_suffix(file, suffix) # optional
418
- file = str(file) # convert to str()
419
- if Path(file).is_file() or not file: # exists
420
- return file
421
- elif file.startswith(('http:/', 'https:/')): # download
422
- url = file # warning: Pathlib turns :// -> :/
423
- file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
424
- if Path(file).is_file():
425
- LOGGER.info(f'Found {url} locally at {file}') # file already exists
426
- else:
427
- LOGGER.info(f'Downloading {url} to {file}...')
428
- torch.hub.download_url_to_file(url, file)
429
- assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
430
- return file
431
- else: # search
432
- files = []
433
- for d in 'data', 'models', 'utils': # search directories
434
- files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
435
- assert len(files), f'File not found: {file}' # assert file was found
436
- assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
437
- return files[0] # return file
438
-
439
-
440
- def check_font(font=FONT, progress=False):
441
- # Download font to CONFIG_DIR if necessary
442
- font = Path(font)
443
- file = CONFIG_DIR / font.name
444
- if not font.exists() and not file.exists():
445
- url = "https://ultralytics.com/assets/" + font.name
446
- LOGGER.info(f'Downloading {url} to {file}...')
447
- torch.hub.download_url_to_file(url, str(file), progress=progress)
448
-
449
-
450
- def check_dataset(data, autodownload=True):
451
- # Download, check and/or unzip dataset if not found locally
452
-
453
- # Download (optional)
454
- extract_dir = ''
455
- if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
456
- download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
457
- data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
458
- extract_dir, autodownload = data.parent, False
459
-
460
- # Read yaml (optional)
461
- if isinstance(data, (str, Path)):
462
- with open(data, errors='ignore') as f:
463
- data = yaml.safe_load(f) # dictionary
464
-
465
- # Checks
466
- for k in 'train', 'val', 'nc':
467
- assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
468
- if 'names' not in data:
469
- LOGGER.warning(emojis("data.yaml 'names:' field missing ⚠, assigning default names 'class0', 'class1', etc."))
470
- data['names'] = [f'class{i}' for i in range(data['nc'])] # default names
471
-
472
- # Resolve paths
473
- path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
474
- if not path.is_absolute():
475
- path = (ROOT / path).resolve()
476
- for k in 'train', 'val', 'test':
477
- if data.get(k): # prepend path
478
- data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
479
-
480
- # Parse yaml
481
- train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
482
- if val:
483
- val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
484
- if not all(x.exists() for x in val):
485
- LOGGER.info(emojis('\nDataset not found ⚠, missing paths %s' % [str(x) for x in val if not x.exists()]))
486
- if not s or not autodownload:
487
- raise Exception(emojis('Dataset not found ❌'))
488
- t = time.time()
489
- root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
490
- if s.startswith('http') and s.endswith('.zip'): # URL
491
- f = Path(s).name # filename
492
- LOGGER.info(f'Downloading {s} to {f}...')
493
- torch.hub.download_url_to_file(s, f)
494
- Path(root).mkdir(parents=True, exist_ok=True) # create root
495
- ZipFile(f).extractall(path=root) # unzip
496
- Path(f).unlink() # remove zip
497
- r = None # success
498
- elif s.startswith('bash '): # bash script
499
- LOGGER.info(f'Running {s} ...')
500
- r = os.system(s)
501
- else: # python script
502
- r = exec(s, {'yaml': data}) # return None
503
- dt = f'({round(time.time() - t, 1)}s)'
504
- s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
505
- LOGGER.info(emojis(f"Dataset download {s}"))
506
- check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
507
- return data # dictionary
508
-
509
-
510
- def check_amp(model):
511
- # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
512
- from models.common import AutoShape, DetectMultiBackend
513
-
514
- def amp_allclose(model, im):
515
- # All close FP32 vs AMP results
516
- m = AutoShape(model, verbose=False) # model
517
- a = m(im).xywhn[0] # FP32 inference
518
- m.amp = True
519
- b = m(im).xywhn[0] # AMP inference
520
- return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
521
-
522
- prefix = colorstr('AMP: ')
523
- device = next(model.parameters()).device # get model device
524
- if device.type == 'cpu':
525
- return False # AMP disabled on CPU
526
- f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
527
- im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
528
- try:
529
- assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
530
- LOGGER.info(emojis(f'{prefix}checks passed ✅'))
531
- return True
532
- except Exception:
533
- help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
534
- LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
535
- return False
536
-
537
-
538
- def url2file(url):
539
- # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
540
- url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
541
- return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
542
-
543
-
544
- def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
545
- # Multi-threaded file download and unzip function, used in data.yaml for autodownload
546
- def download_one(url, dir):
547
- # Download 1 file
548
- success = True
549
- f = dir / Path(url).name # filename
550
- if Path(url).is_file(): # exists in current path
551
- Path(url).rename(f) # move to dir
552
- elif not f.exists():
553
- LOGGER.info(f'Downloading {url} to {f}...')
554
- for i in range(retry + 1):
555
- if curl:
556
- s = 'sS' if threads > 1 else '' # silent
557
- r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
558
- success = r == 0
559
- else:
560
- torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
561
- success = f.is_file()
562
- if success:
563
- break
564
- elif i < retry:
565
- LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
566
- else:
567
- LOGGER.warning(f'Failed to download {url}...')
568
-
569
- if unzip and success and f.suffix in ('.zip', '.gz'):
570
- LOGGER.info(f'Unzipping {f}...')
571
- if f.suffix == '.zip':
572
- ZipFile(f).extractall(path=dir) # unzip
573
- elif f.suffix == '.gz':
574
- os.system(f'tar xfz {f} --directory {f.parent}') # unzip
575
- if delete:
576
- f.unlink() # remove zip
577
-
578
- dir = Path(dir)
579
- dir.mkdir(parents=True, exist_ok=True) # make directory
580
- if threads > 1:
581
- pool = ThreadPool(threads)
582
- pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
583
- pool.close()
584
- pool.join()
585
- else:
586
- for u in [url] if isinstance(url, (str, Path)) else url:
587
- download_one(u, dir)
588
-
589
-
590
- def make_divisible(x, divisor):
591
- # Returns nearest x divisible by divisor
592
- if isinstance(divisor, torch.Tensor):
593
- divisor = int(divisor.max()) # to int
594
- return math.ceil(x / divisor) * divisor
595
-
596
-
597
- def clean_str(s):
598
- # Cleans a string by replacing special characters with underscore _
599
- return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
600
-
601
-
602
- def one_cycle(y1=0.0, y2=1.0, steps=100):
603
- # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
604
- return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
605
-
606
-
607
- def colorstr(*input):
608
- # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
609
- *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
610
- colors = {
611
- 'black': '\033[30m', # basic colors
612
- 'red': '\033[31m',
613
- 'green': '\033[32m',
614
- 'yellow': '\033[33m',
615
- 'blue': '\033[34m',
616
- 'magenta': '\033[35m',
617
- 'cyan': '\033[36m',
618
- 'white': '\033[37m',
619
- 'bright_black': '\033[90m', # bright colors
620
- 'bright_red': '\033[91m',
621
- 'bright_green': '\033[92m',
622
- 'bright_yellow': '\033[93m',
623
- 'bright_blue': '\033[94m',
624
- 'bright_magenta': '\033[95m',
625
- 'bright_cyan': '\033[96m',
626
- 'bright_white': '\033[97m',
627
- 'end': '\033[0m', # misc
628
- 'bold': '\033[1m',
629
- 'underline': '\033[4m'}
630
- return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
631
-
632
-
633
- def labels_to_class_weights(labels, nc=80):
634
- # Get class weights (inverse frequency) from training labels
635
- if labels[0] is None: # no labels loaded
636
- return torch.Tensor()
637
-
638
- labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
639
- classes = labels[:, 0].astype(np.int) # labels = [class xywh]
640
- weights = np.bincount(classes, minlength=nc) # occurrences per class
641
-
642
- # Prepend gridpoint count (for uCE training)
643
- # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
644
- # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
645
-
646
- weights[weights == 0] = 1 # replace empty bins with 1
647
- weights = 1 / weights # number of targets per class
648
- weights /= weights.sum() # normalize
649
- return torch.from_numpy(weights)
650
-
651
-
652
- def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
653
- # Produces image weights based on class_weights and image contents
654
- # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
655
- class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
656
- return (class_weights.reshape(1, nc) * class_counts).sum(1)
657
-
658
-
659
- def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
660
- # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
661
- # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
662
- # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
663
- # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
664
- # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
665
- return [
666
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
667
- 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
668
- 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
669
-
670
-
671
- def xyxy2xywh(x):
672
- # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
673
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
674
- y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
675
- y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
676
- y[:, 2] = x[:, 2] - x[:, 0] # width
677
- y[:, 3] = x[:, 3] - x[:, 1] # height
678
- return y
679
-
680
-
681
- def xywh2xyxy(x):
682
- # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
683
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
684
- y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
685
- y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
686
- y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
687
- y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
688
- return y
689
-
690
-
691
- def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
692
- # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
693
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
694
- y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
695
- y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
696
- y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
697
- y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
698
- return y
699
-
700
-
701
- def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
702
- # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
703
- if clip:
704
- clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
705
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
706
- y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
707
- y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
708
- y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
709
- y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
710
- return y
711
-
712
-
713
- def xyn2xy(x, w=640, h=640, padw=0, padh=0):
714
- # Convert normalized segments into pixel segments, shape (n,2)
715
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
716
- y[:, 0] = w * x[:, 0] + padw # top left x
717
- y[:, 1] = h * x[:, 1] + padh # top left y
718
- return y
719
-
720
-
721
- def segment2box(segment, width=640, height=640):
722
- # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
723
- x, y = segment.T # segment xy
724
- inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
725
- x, y, = x[inside], y[inside]
726
- return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
727
-
728
-
729
- def segments2boxes(segments):
730
- # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
731
- boxes = []
732
- for s in segments:
733
- x, y = s.T # segment xy
734
- boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
735
- return xyxy2xywh(np.array(boxes)) # cls, xywh
736
-
737
-
738
- def resample_segments(segments, n=1000):
739
- # Up-sample an (n,2) segment
740
- for i, s in enumerate(segments):
741
- s = np.concatenate((s, s[0:1, :]), axis=0)
742
- x = np.linspace(0, len(s) - 1, n)
743
- xp = np.arange(len(s))
744
- segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
745
- return segments
746
-
747
-
748
- def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
749
- # Rescale coords (xyxy) from img1_shape to img0_shape
750
- if ratio_pad is None: # calculate from img0_shape
751
- gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
752
- pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
753
- else:
754
- gain = ratio_pad[0][0]
755
- pad = ratio_pad[1]
756
-
757
- coords[:, [0, 2]] -= pad[0] # x padding
758
- coords[:, [1, 3]] -= pad[1] # y padding
759
- coords[:, :4] /= gain
760
- clip_coords(coords, img0_shape)
761
- return coords
762
-
763
-
764
- def clip_coords(boxes, shape):
765
- # Clip bounding xyxy bounding boxes to image shape (height, width)
766
- if isinstance(boxes, torch.Tensor): # faster individually
767
- boxes[:, 0].clamp_(0, shape[1]) # x1
768
- boxes[:, 1].clamp_(0, shape[0]) # y1
769
- boxes[:, 2].clamp_(0, shape[1]) # x2
770
- boxes[:, 3].clamp_(0, shape[0]) # y2
771
- else: # np.array (faster grouped)
772
- boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
773
- boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
774
-
775
-
776
- def non_max_suppression(prediction,
777
- conf_thres=0.25,
778
- iou_thres=0.45,
779
- classes=None,
780
- agnostic=False,
781
- multi_label=False,
782
- labels=(),
783
- max_det=300):
784
- """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes
785
-
786
- Returns:
787
- list of detections, on (n,6) tensor per image [xyxy, conf, cls]
788
- """
789
-
790
- bs = prediction.shape[0] # batch size
791
- nc = prediction.shape[2] - 5 # number of classes
792
- xc = prediction[..., 4] > conf_thres # candidates
793
-
794
- # Checks
795
- assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
796
- assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
797
-
798
- # Settings
799
- # min_wh = 2 # (pixels) minimum box width and height
800
- max_wh = 7680 # (pixels) maximum box width and height
801
- max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
802
- time_limit = 0.3 + 0.03 * bs # seconds to quit after
803
- redundant = True # require redundant detections
804
- multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
805
- merge = False # use merge-NMS
806
-
807
- t = time.time()
808
- output = [torch.zeros((0, 6), device=prediction.device)] * bs
809
- for xi, x in enumerate(prediction): # image index, image inference
810
- # Apply constraints
811
- # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
812
- x = x[xc[xi]] # confidence
813
-
814
- # Cat apriori labels if autolabelling
815
- if labels and len(labels[xi]):
816
- lb = labels[xi]
817
- v = torch.zeros((len(lb), nc + 5), device=x.device)
818
- v[:, :4] = lb[:, 1:5] # box
819
- v[:, 4] = 1.0 # conf
820
- v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
821
- x = torch.cat((x, v), 0)
822
-
823
- # If none remain process next image
824
- if not x.shape[0]:
825
- continue
826
-
827
- # Compute conf
828
- x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
829
-
830
- # Box (center x, center y, width, height) to (x1, y1, x2, y2)
831
- box = xywh2xyxy(x[:, :4])
832
-
833
- # Detections matrix nx6 (xyxy, conf, cls)
834
- if multi_label:
835
- i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
836
- x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
837
- else: # best class only
838
- conf, j = x[:, 5:].max(1, keepdim=True)
839
- x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
840
-
841
- # Filter by class
842
- if classes is not None:
843
- x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
844
-
845
- # Apply finite constraint
846
- # if not torch.isfinite(x).all():
847
- # x = x[torch.isfinite(x).all(1)]
848
-
849
- # Check shape
850
- n = x.shape[0] # number of boxes
851
- if not n: # no boxes
852
- continue
853
- elif n > max_nms: # excess boxes
854
- x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
855
-
856
- # Batched NMS
857
- c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
858
- boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
859
- i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
860
- if i.shape[0] > max_det: # limit detections
861
- i = i[:max_det]
862
- if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
863
- # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
864
- iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
865
- weights = iou * scores[None] # box weights
866
- x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
867
- if redundant:
868
- i = i[iou.sum(1) > 1] # require redundancy
869
-
870
- output[xi] = x[i]
871
- if (time.time() - t) > time_limit:
872
- LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
873
- break # time limit exceeded
874
-
875
- return output
876
-
877
-
878
- def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
879
- # Strip optimizer from 'f' to finalize training, optionally save as 's'
880
- x = torch.load(f, map_location=torch.device('cpu'))
881
- if x.get('ema'):
882
- x['model'] = x['ema'] # replace model with ema
883
- for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
884
- x[k] = None
885
- x['epoch'] = -1
886
- x['model'].half() # to FP16
887
- for p in x['model'].parameters():
888
- p.requires_grad = False
889
- torch.save(x, s or f)
890
- mb = os.path.getsize(s or f) / 1E6 # filesize
891
- LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
892
-
893
-
894
- def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
895
- evolve_csv = save_dir / 'evolve.csv'
896
- evolve_yaml = save_dir / 'hyp_evolve.yaml'
897
- keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
898
- 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
899
- keys = tuple(x.strip() for x in keys)
900
- vals = results + tuple(hyp.values())
901
- n = len(keys)
902
-
903
- # Download (optional)
904
- if bucket:
905
- url = f'gs://{bucket}/evolve.csv'
906
- if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
907
- os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
908
-
909
- # Log to evolve.csv
910
- s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
911
- with open(evolve_csv, 'a') as f:
912
- f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
913
-
914
- # Save yaml
915
- with open(evolve_yaml, 'w') as f:
916
- data = pd.read_csv(evolve_csv)
917
- data = data.rename(columns=lambda x: x.strip()) # strip keys
918
- i = np.argmax(fitness(data.values[:, :4])) #
919
- generations = len(data)
920
- f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
921
- f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
922
- '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
923
- yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
924
-
925
- # Print to screen
926
- LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
927
- ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
928
- for x in vals) + '\n\n')
929
-
930
- if bucket:
931
- os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
932
-
933
-
934
- def apply_classifier(x, model, img, im0):
935
- # Apply a second stage classifier to YOLO outputs
936
- # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
937
- im0 = [im0] if isinstance(im0, np.ndarray) else im0
938
- for i, d in enumerate(x): # per image
939
- if d is not None and len(d):
940
- d = d.clone()
941
-
942
- # Reshape and pad cutouts
943
- b = xyxy2xywh(d[:, :4]) # boxes
944
- b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
945
- b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
946
- d[:, :4] = xywh2xyxy(b).long()
947
-
948
- # Rescale boxes from img_size to im0 size
949
- scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
950
-
951
- # Classes
952
- pred_cls1 = d[:, 5].long()
953
- ims = []
954
- for a in d:
955
- cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
956
- im = cv2.resize(cutout, (224, 224)) # BGR
957
-
958
- im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
959
- im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
960
- im /= 255 # 0 - 255 to 0.0 - 1.0
961
- ims.append(im)
962
-
963
- pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
964
- x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
965
-
966
- return x
967
-
968
-
969
- def increment_path(path, exist_ok=False, sep='', mkdir=False):
970
- # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
971
- path = Path(path) # os-agnostic
972
- if path.exists() and not exist_ok:
973
- path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
974
-
975
- # Method 1
976
- for n in range(2, 9999):
977
- p = f'{path}{sep}{n}{suffix}' # increment path
978
- if not os.path.exists(p): #
979
- break
980
- path = Path(p)
981
-
982
- # Method 2 (deprecated)
983
- # dirs = glob.glob(f"{path}{sep}*") # similar paths
984
- # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
985
- # i = [int(m.groups()[0]) for m in matches if m] # indices
986
- # n = max(i) + 1 if i else 2 # increment number
987
- # path = Path(f"{path}{sep}{n}{suffix}") # increment path
988
-
989
- if mkdir:
990
- path.mkdir(parents=True, exist_ok=True) # make directory
991
-
992
- return path
993
-
994
-
995
- # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
996
- imshow_ = cv2.imshow # copy to avoid recursion errors
997
-
998
-
999
- def imread(path, flags=cv2.IMREAD_COLOR):
1000
- return cv2.imdecode(np.fromfile(path, np.uint8), flags)
1001
-
1002
-
1003
- def imwrite(path, im):
1004
- try:
1005
- cv2.imencode(Path(path).suffix, im)[1].tofile(path)
1006
- return True
1007
- except Exception:
1008
- return False
1009
-
1010
-
1011
- def imshow(path, im):
1012
- imshow_(path.encode('unicode_escape').decode(), im)
1013
-
1014
-
1015
- cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
1016
-
1017
- # Variables ------------------------------------------------------------------------------------------------------------
1018
- NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm