ultralytics 8.0.159__py3-none-any.whl → 8.0.161__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic.
- ultralytics/__init__.py +2 -3
- ultralytics/data/dataset.py +74 -20
- ultralytics/data/utils.py +39 -5
- ultralytics/engine/trainer.py +4 -1
- ultralytics/hub/__init__.py +2 -25
- ultralytics/hub/auth.py +2 -22
- ultralytics/models/fastsam/predict.py +8 -11
- ultralytics/models/nas/predict.py +5 -5
- ultralytics/models/rtdetr/predict.py +5 -5
- ultralytics/models/sam/modules/sam.py +21 -35
- ultralytics/models/sam/predict.py +4 -4
- ultralytics/models/yolo/classify/predict.py +4 -5
- ultralytics/models/yolo/classify/train.py +1 -1
- ultralytics/models/yolo/classify/val.py +1 -1
- ultralytics/models/yolo/detect/predict.py +5 -7
- ultralytics/models/yolo/pose/predict.py +6 -11
- ultralytics/models/yolo/segment/predict.py +8 -13
- ultralytics/nn/modules/conv.py +6 -1
- ultralytics/trackers/utils/kalman_filter.py +71 -95
- ultralytics/utils/callbacks/tensorboard.py +3 -3
- ultralytics/utils/checks.py +6 -5
- ultralytics/utils/downloads.py +12 -13
- ultralytics/utils/metrics.py +0 -11
- ultralytics/utils/ops.py +84 -117
- {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/METADATA +1 -1
- {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/RECORD +30 -30
- {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/WHEEL +1 -1
- {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py
CHANGED
@@ -1,8 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = '8.0.159'
+__version__ = '8.0.161'
 
-from ultralytics.hub import start
 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
 from ultralytics.models.nas import NAS
@@ -10,4 +9,4 @@ from ultralytics.utils import SETTINGS as settings
 from ultralytics.utils.checks import check_yolo as checks
 from ultralytics.utils.downloads import download
 
-__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings', 'start'  # allow simpler import
+__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings'  # allow simpler import
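Note: dropping `start` from the package namespace completes a deprecation that began in 8.0.60. The replacement HUB workflow, quoted from the warning message deleted in the ultralytics/hub/__init__.py section below (API_KEY and MODEL_ID are placeholders for your own credentials and model):

from ultralytics import YOLO, hub

hub.login('API_KEY')
model = YOLO('https://hub.ultralytics.com/models/MODEL_ID')
model.train()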
ultralytics/data/dataset.py
CHANGED
@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-
+import contextlib
 from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
@@ -10,11 +10,14 @@ import torch
 import torchvision
 from tqdm import tqdm
 
-from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
+from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, is_dir_writeable
 
 from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
 from .base import BaseDataset
-from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
+from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
+
+# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
+DATASET_CACHE_VERSION = '1.0.3'
 
 
 class YOLODataset(BaseDataset):
@@ -29,7 +32,6 @@ class YOLODataset(BaseDataset):
     Returns:
         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
     """
-    cache_version = '1.0.2'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
 
     def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
         self.use_segments = use_segments
@@ -87,15 +89,7 @@
         x['hash'] = get_hash(self.label_files + self.im_files)
         x['results'] = nf, nm, ne, nc, len(self.im_files)
         x['msgs'] = msgs  # warnings
-        x['version'] = self.cache_version  # cache version
-        if is_dir_writeable(path.parent):
-            if path.exists():
-                path.unlink()  # remove *.cache file if exists
-            np.save(str(path), x)  # save cache for next time
-            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
-            LOGGER.info(f'{self.prefix}New cache created: {path}')
-        else:
-            LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+        save_dataset_cache_file(self.prefix, path, x)
         return x
 
     def get_labels(self):
@@ -103,11 +97,8 @@
         self.label_files = img2label_paths(self.im_files)
         cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
         try:
-            import gc
-            gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load dict
-            gc.enable()
-            assert cache['version'] == self.cache_version  # matches current version
+            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
             assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
         except (FileNotFoundError, AssertionError, AttributeError):
             cache, exists = self.cache_labels(cache_path), False  # run cache ops
@@ -116,7 +107,7 @@
         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in (-1, 0):
             d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
-            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
+            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display results
         if cache['msgs']:
             LOGGER.info('\n'.join(cache['msgs']))  # display warnings
         if nf == 0:  # number of labels found
@@ -216,7 +207,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
     """
 
-    def __init__(self, root, args, augment=False, cache=False):
+    def __init__(self, root, args, augment=False, cache=False, prefix=''):
         """
         Initialize YOLO object with root, image size, augmentations, and cache settings.
 
@@ -229,8 +220,10 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         super().__init__(root=root)
         if augment and args.fraction < 1.0:  # reduce training fraction
             self.samples = self.samples[:round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f'{prefix}: ') if prefix else ''
         self.cache_ram = cache is True or cache == 'ram'
         self.cache_disk = cache == 'disk'
+        self.samples = self.verify_images()  # filter out bad images
         self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
         self.torch_transforms = classify_transforms(args.imgsz)
         self.album_transforms = classify_albumentations(
@@ -266,6 +259,67 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
     def __len__(self) -> int:
         return len(self.samples)
 
+    def verify_images(self):
+        """Verify all images in dataset."""
+        desc = f'{self.prefix}Scanning {self.root}...'
+        path = Path(self.root).with_suffix('.cache')  # *.cache file path
+
+        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
+            assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop('results')  # found, missing, empty, corrupt, total
+            if LOCAL_RANK in (-1, 0):
+                d = f'{desc} {nf} images, {nc} corrupt'
+                tqdm(None, desc=d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)
+                if cache['msgs']:
+                    LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+            return samples
+
+        # Run scan if *.cache retrieval failed
+        nf, nc, msgs, samples, x = 0, 0, [], [], {}
+        with ThreadPool(NUM_THREADS) as pool:
+            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
+            pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT)
+            for sample, nf_f, nc_f, msg in pbar:
+                if nf_f:
+                    samples.append(sample)
+                if msg:
+                    msgs.append(msg)
+                nf += nf_f
+                nc += nc_f
+                pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+            pbar.close()
+        if msgs:
+            LOGGER.info('\n'.join(msgs))
+        x['hash'] = get_hash([x[0] for x in self.samples])
+        x['results'] = nf, nc, len(samples), samples
+        x['msgs'] = msgs  # warnings
+        save_dataset_cache_file(self.prefix, path, x)
+        return samples
+
+
+def load_dataset_cache_file(path):
+    """Load an Ultralytics *.cache dictionary from path."""
+    import gc
+    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
+    cache = np.load(str(path), allow_pickle=True).item()  # load dict
+    gc.enable()
+    return cache
+
+
+def save_dataset_cache_file(prefix, path, x):
+    """Save an Ultralytics dataset *.cache dictionary x to path."""
+    x['version'] = DATASET_CACHE_VERSION  # add cache version
+    if is_dir_writeable(path.parent):
+        if path.exists():
+            path.unlink()  # remove *.cache file if exists
+        np.save(str(path), x)  # save cache for next time
+        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
+        LOGGER.info(f'{prefix}New cache created: {path}')
+    else:
+        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+
 
 # TODO: support semantic segmentation
 class SemanticDataset(BaseDataset):
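A note on the new load_dataset_cache_file(): the gc.disable()/gc.enable() pair carried over from the old inline loader exists because np.load(..., allow_pickle=True) unpickles many small label objects, and CPython's generational garbage collector otherwise runs repeatedly during the load (see the PR linked in the comment). A minimal sketch of the same pattern, written here with plain pickle and a try/finally so the collector is restored even on failure:

import gc
import pickle

def load_big_pickle(path):
    """Load a large pickled object with the GC paused during unpickling."""
    gc.disable()  # nothing is freed while loading, so pausing collection is safe
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    finally:
        gc.enable()  # restore normal collection even if pickle.load raises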
ultralytics/data/utils.py
CHANGED
@@ -57,6 +57,31 @@ def exif_size(img: Image.Image):
     return s
 
 
+def verify_image(args):
+    """Verify one image."""
+    (im_file, cls), prefix = args
+    # Number (found, corrupt), message
+    nf, nc, msg = 0, 0, ''
+    try:
+        im = Image.open(im_file)
+        im.verify()  # PIL verify
+        shape = exif_size(im)  # image size
+        shape = (shape[1], shape[0])  # hw
+        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
+        if im.format.lower() in ('jpg', 'jpeg'):
+            with open(im_file, 'rb') as f:
+                f.seek(-2, 2)
+                if f.read() != b'\xff\xd9':  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
+                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+        nf = 1
+    except Exception as e:
+        nc = 1
+        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+    return (im_file, cls), nf, nc, msg
+
+
 def verify_image_label(args):
     """Verify one image-label pair."""
     im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
@@ -296,7 +321,7 @@ def check_cls_dataset(dataset: str, split=''):
     dataset = Path(dataset)
     data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
     if not data_dir.is_dir():
-        LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+        LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
         t = time.time()
         if str(dataset) == 'imagenet':
             subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
@@ -310,9 +335,9 @@ def check_cls_dataset(dataset: str, split=''):
         data_dir / 'validation').exists() else None  # data/test or data/val
     test_set = data_dir / 'test' if (data_dir / 'test').exists() else None  # data/val or data/test
     if split == 'val' and not val_set:
-        LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
+        LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
     elif split == 'test' and not test_set:
-        LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
+        LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
 
     nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
     names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
@@ -320,13 +345,22 @@
 
     # Print to console
     for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
+        prefix = f'{colorstr(k)} {v}...'
         if v is None:
-            LOGGER.info(f'{colorstr(k)}: {v}')
+            LOGGER.info(prefix)
         else:
             files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
             nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
-            LOGGER.info(f'{colorstr(k)}: {v}... found {nf} images in {nd} classes ✅ ')
+            if nf == 0:
+                if k == 'train':
+                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
+                else:
+                    LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
+            elif nd != nc:
+                LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
+            else:
+                LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
 
     return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
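The new verify_image() reuses a long-standing YOLOv5 heuristic: a complete JPEG stream ends with the two-byte end-of-image (EOI) marker FF D9, so a truncated download can be detected by reading just the final two bytes; flagged files are then re-saved through PIL to rewrite a valid stream. A standalone sketch of just the check:

def jpeg_is_complete(path):
    """Return True if the file ends with the JPEG end-of-image (EOI) marker FF D9."""
    with open(path, 'rb') as f:
        f.seek(-2, 2)  # whence=2 (os.SEEK_END): position two bytes before end of file
        return f.read() == b'\xff\xd9'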
ultralytics/engine/trainer.py
CHANGED
@@ -10,6 +10,7 @@ import math
 import os
 import subprocess
 import time
+import warnings
 from copy import deepcopy
 from datetime import datetime, timedelta
 from pathlib import Path
@@ -378,7 +379,9 @@ class BaseTrainer:
 
             self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
 
-            self.scheduler.step()
+            with warnings.catch_warnings():
+                warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                self.scheduler.step()
            self.run_callbacks('on_train_epoch_end')
 
             if RANK in (-1, 0):
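Context for the trainer change: PyTorch emits the 'Detected lr_scheduler.step() before optimizer.step()' UserWarning whenever the scheduler advances in an iteration where the optimizer did not, which is expected noise with AMP since GradScaler skips optimizer.step() on inf/NaN gradients. warnings.catch_warnings() scopes the suppression to the single call rather than the whole process; a sketch of the scoping behaviour:

import warnings

def quiet_step(scheduler):
    """Step a scheduler with warnings ignored only inside this block."""
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')  # filter applies only within the with-block
        scheduler.step()
    # previous warning filters are restored here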
ultralytics/hub/__init__.py
CHANGED
@@ -5,7 +5,7 @@ import requests
 from ultralytics.data.utils import HUBDatasetStats
 from ultralytics.hub.auth import Auth
 from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
-from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
+from ultralytics.utils import LOGGER, SETTINGS
 
 
 def login(api_key=''):
@@ -37,29 +37,10 @@ def logout():
     ```
     """
     SETTINGS['api_key'] = ''
-    yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
+    SETTINGS.save()
     LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
 
 
-def start(key=''):
-    """
-    Start training models with Ultralytics HUB (DEPRECATED).
-
-    Args:
-        key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
-            or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
-    """
-    api_key, model_id = key.split('_')
-    LOGGER.warning(f"""
-WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
-
-from ultralytics import YOLO, hub
-
-hub.login('{api_key}')
-model = YOLO('{HUB_WEB_ROOT}/models/{model_id}')
-model.train()""")
-
-
 def reset_model(model_id=''):
     """Reset a trained model to an untrained state."""
     r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
@@ -117,7 +98,3 @@ def check_dataset(path='', task='detect'):
     """
     HUBDatasetStats(path=path, task=task).get_json()
     LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
-
-
-if __name__ == '__main__':
-    start()
ultralytics/hub/auth.py
CHANGED
@@ -73,8 +73,7 @@ class Auth:
             bool: True if authentication is successful, False otherwise.
         """
         try:
-            header = self.get_auth_header()
-            if header:
+            if header := self.get_auth_header():
                 r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
                 if not r.json().get('success', False):
                     raise ConnectionError('Unable to authenticate.')
@@ -117,23 +116,4 @@ class Auth:
             return {'authorization': f'Bearer {self.id_token}'}
         elif self.api_key:
             return {'x-api-key': self.api_key}
-        else:
-            return None
-
-    def get_state(self) -> bool:
-        """
-        Get the authentication state.
-
-        Returns:
-            bool: True if either id_token or API key is set, False otherwise.
-        """
-        return self.id_token or self.api_key
-
-    def set_api_key(self, key: str):
-        """
-        Set the API key for authentication.
-
-        Args:
-            key (str): The API key string.
-        """
-        self.api_key = key
+        # else returns None
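The auth change folds the assignment into the condition with an assignment expression (the := "walrus" operator, PEP 572), available from Python 3.8, which matches the package's Python>=3.8 requirement. A small self-contained illustration of the equivalence (get_auth_header here is a stand-in for Auth.get_auth_header(), which returns a header dict or None):

def get_auth_header():
    """Stand-in for Auth.get_auth_header(); returns a header dict or None."""
    return {'x-api-key': 'demo'}

# Two statements...
header = get_auth_header()
if header:
    print(header)

# ...versus one assignment expression: 'header' is bound and tested together,
# and remains in scope inside the if-body.
if header := get_auth_header():
    print(header)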
ultralytics/models/fastsam/predict.py
CHANGED
@@ -15,7 +15,6 @@ class FastSAMPredictor(DetectionPredictor):
         self.args.task = 'segment'
 
     def postprocess(self, preds, img, orig_imgs):
-        """TODO: filter by classes."""
         p = ops.non_max_suppression(preds[0],
                                     self.args.conf,
                                     self.args.iou,
@@ -32,22 +31,20 @@ class FastSAMPredictor(DetectionPredictor):
         full_box[0][6:] = p[0][critical_iou_index][:, 6:]
         p[0][critical_iou_index] = full_box
         results = []
+        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         for i, pred in enumerate(p):
-            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-            path = self.batch[0]
-            img_path = path[i] if isinstance(path, list) else path
+            orig_img = orig_imgs[i] if is_list else orig_imgs
+            img_path = self.batch[0][i]
             if not len(pred):  # save empty boxes
-                results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
-                continue
-            if self.args.retina_masks:
-                if not isinstance(orig_imgs, torch.Tensor):
+                masks = None
+            elif self.args.retina_masks:
+                if is_list:
                     pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                 masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
             else:
                 masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                if not isinstance(orig_imgs, torch.Tensor):
+                if is_list:
                     pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(
-                Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
         return results
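The same postprocess refactor recurs in every predictor below (NAS, RT-DETR, SAM, classify, detect, pose): orig_imgs is either a Python list of per-image numpy arrays or a single pre-stacked torch.Tensor, and the old code probed this twice with two different tests (isinstance(orig_imgs, list) and not isinstance(orig_imgs, torch.Tensor)). Hoisting one is_list flag out of the loop unifies the tests, lets img_path come directly from self.batch[0][i], and removes the now-unused torch import from detect/predict.py. A condensed, hypothetical sketch of the shared control flow (names are illustrative, not the library API):

def postprocess_generic(preds, img, orig_imgs, paths, scale_boxes):
    """Illustrative skeleton of the shared predictor postprocess loop."""
    results = []
    is_list = isinstance(orig_imgs, list)  # list of raw np arrays vs. one stacked tensor
    for i, pred in enumerate(preds):
        orig_img = orig_imgs[i] if is_list else orig_imgs
        if is_list:  # raw arrays still need boxes mapped back from network size to image size
            pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        results.append((orig_img, paths[i], pred))
    return results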
ultralytics/models/nas/predict.py
CHANGED
@@ -24,11 +24,11 @@ class NASPredictor(BasePredictor):
                                           classes=self.args.classes)
 
         results = []
+        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-            if not isinstance(orig_imgs, torch.Tensor):
+            orig_img = orig_imgs[i] if is_list else orig_imgs
+            if is_list:
                 pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            path = self.batch[0]
-            img_path = path[i] if isinstance(path, list) else path
-            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
ultralytics/models/rtdetr/predict.py
CHANGED
@@ -28,6 +28,7 @@ class RTDETRPredictor(BasePredictor):
         nd = preds[0].shape[-1]
         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
         results = []
+        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, bbox in enumerate(bboxes):  # (300, 4)
             bbox = ops.xywh2xyxy(bbox)
             score, cls = scores[i].max(-1, keepdim=True)  # (300, 1)
@@ -35,14 +36,13 @@ class RTDETRPredictor(BasePredictor):
             if self.args.classes is not None:
                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
             pred = torch.cat([bbox, score, cls], dim=-1)[idx]  # filter
-            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
+            orig_img = orig_imgs[i] if is_list else orig_imgs
             oh, ow = orig_img.shape[:2]
-            if not isinstance(orig_imgs, torch.Tensor):
+            if is_list:
                 pred[..., [0, 2]] *= ow
                 pred[..., [1, 3]] *= oh
-            path = self.batch[0]
-            img_path = path[i] if isinstance(path, list) else path
-            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
 
     def pre_transform(self, im):
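RT-DETR's postprocess scales by the raw image width ow and height oh instead of calling scale_boxes(), which implies the decoder emits xywh boxes normalized to [0, 1]. A worked example with a single hypothetical box:

import torch

from ultralytics.utils import ops

bbox = torch.tensor([[0.5, 0.5, 0.4, 0.2]])  # normalized cx, cy, w, h
bbox = ops.xywh2xyxy(bbox)                   # -> [[0.30, 0.40, 0.70, 0.60]], still normalized
oh, ow = 480, 640                            # original image height and width
bbox[..., [0, 2]] *= ow                      # x1, x2 -> 192.0, 448.0
bbox[..., [1, 3]] *= oh                      # y1, y2 -> 192.0, 288.0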
ultralytics/models/sam/modules/sam.py
CHANGED
@@ -30,11 +30,10 @@ class Sam(nn.Module):
     SAM predicts object masks from an image and input prompts.
 
     Args:
-        image_encoder (ImageEncoderViT): The backbone used to encode the
-            image into image embeddings that allow for efficient mask prediction.
+        image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
+            efficient mask prediction.
         prompt_encoder (PromptEncoder): Encodes various types of input prompts.
-        mask_decoder (MaskDecoder): Predicts masks from the image embeddings
-            and encoded prompts.
+        mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
         pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
         pixel_std (list(float)): Std values for normalizing pixels in the input image.
     """
@@ -65,34 +64,25 @@ class Sam(nn.Module):
 
         Args:
             batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
-              key can be excluded if it is not present.
-              'image': The image as a torch tensor in 3xHxW format,
-                already transformed for input to the model.
-              'original_size': (tuple(int, int)) The original size of
-                the image before transformation, as (H, W).
-              'point_coords': (torch.Tensor) Batched point prompts for
-                this image, with shape BxNx2. Already transformed to the
-                input frame of the model.
-              'point_labels': (torch.Tensor) Batched labels for point prompts,
-                with shape BxN.
-              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
-                Already transformed to the input frame of the model.
-              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
-                in the form Bx1xHxW.
+                key can be excluded if it is not present.
+                'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
+                'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
+                'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
+                    transformed to the input frame of the model.
+                'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
+                'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
+                    the model.
+                'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
             multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
                 mask.
 
         Returns:
             (list(dict)): A list over input images, where each element is as dictionary with the following keys.
-              'masks': (torch.Tensor) Batched binary mask predictions,
-                with shape BxCxHxW, where B is the number of input prompts,
-                C is determined by multimask_output, and (H, W) is the
-                original size of the image.
-              'iou_predictions': (torch.Tensor) The model's predictions
-                of mask quality, in shape BxC.
-              'low_res_logits': (torch.Tensor) Low resolution logits with
-                shape BxCxHxW, where H=W=256. Can be passed as mask input
-                to subsequent iterations of prediction.
+                'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
+                    input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
+                'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
+                'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
+                    as mask input to subsequent iterations of prediction.
         """
         input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
         image_embeddings = self.image_encoder(input_images)
@@ -137,16 +127,12 @@ class Sam(nn.Module):
         Remove padding and upscale masks to the original image size.
 
         Args:
-            masks (torch.Tensor): Batched masks from the mask_decoder,
-              in BxCxHxW format.
-            input_size (tuple(int, int)): The size of the image input to the
-              model, in (H, W) format. Used to remove padding.
-            original_size (tuple(int, int)): The original size of the image
-              before resizing for input to the model, in (H, W) format.
+            masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
+            input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
+            original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
 
         Returns:
-            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
-              is given by original_size.
+            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
         """
         masks = F.interpolate(
             masks,
ultralytics/models/sam/predict.py
CHANGED
@@ -318,8 +318,9 @@ class Predictor(BasePredictor):
         pred_bboxes = preds[2] if self.segment_all else None
         names = dict(enumerate(str(i) for i in range(len(pred_masks))))
         results = []
+        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, masks in enumerate([pred_masks]):
-            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
+            orig_img = orig_imgs[i] if is_list else orig_imgs
             if pred_bboxes is not None:
                 pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                 cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
@@ -327,9 +328,8 @@ class Predictor(BasePredictor):
 
             masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
             masks = masks > self.model.mask_threshold  # to bool
-            path = self.batch[0]
-            img_path = path[i] if isinstance(path, list) else path
-            results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
         # Reset segment-all mode.
         self.segment_all = False
         return results
ultralytics/models/yolo/classify/predict.py
CHANGED
@@ -39,10 +39,9 @@ class ClassificationPredictor(BasePredictor):
     def postprocess(self, preds, img, orig_imgs):
         """Post-processes predictions to return Results objects."""
         results = []
+        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-            path = self.batch[0]
-            img_path = path[i] if isinstance(path, list) else path
-            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
-
+            orig_img = orig_imgs[i] if is_list else orig_imgs
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
         return results
ultralytics/models/yolo/classify/train.py
CHANGED
@@ -79,7 +79,7 @@ class ClassificationTrainer(BaseTrainer):
         return ckpt
 
     def build_dataset(self, img_path, mode='train', batch=None):
-        return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
+        return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
         """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
ultralytics/models/yolo/classify/val.py
CHANGED
@@ -77,7 +77,7 @@ class ClassificationValidator(BaseValidator):
         return self.metrics.results_dict
 
     def build_dataset(self, img_path):
-        return ClassificationDataset(root=img_path, args=self.args, augment=False)
+        return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path, batch_size):
         """Builds and returns a data loader for classification tasks with given parameters."""
ultralytics/models/yolo/detect/predict.py
CHANGED
@@ -1,7 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-import torch
-
 from ultralytics.engine.predictor import BasePredictor
 from ultralytics.engine.results import Results
 from ultralytics.utils import ops
@@ -32,11 +30,11 @@ class DetectionPredictor(BasePredictor):
                                         classes=self.args.classes)
 
         results = []
+        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-            if not isinstance(orig_imgs, torch.Tensor):
+            orig_img = orig_imgs[i] if is_list else orig_imgs
+            if is_list:
                 pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            path = self.batch[0]
-            img_path = path[i] if isinstance(path, list) else path
-            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
ultralytics/models/yolo/pose/predict.py
CHANGED
@@ -38,18 +38,13 @@ class PosePredictor(DetectionPredictor):
                                         nc=len(self.model.names))
 
         results = []
+        is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
         for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-            shape = orig_img.shape
-            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+            orig_img = orig_imgs[i] if is_list else orig_imgs
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
             pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
-            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape)
-            path = self.batch[0]
-            img_path = path[i] if isinstance(path, list) else path
+            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+            img_path = self.batch[0][i]
             results.append(
-                Results(orig_img=orig_img,
-                        path=img_path,
-                        names=self.model.names,
-                        boxes=pred[:, :6],
-                        keypoints=pred_kpts))
+                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
         return results