ultralytics 8.0.159__py3-none-any.whl → 8.0.161__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (30):
  1. ultralytics/__init__.py +2 -3
  2. ultralytics/data/dataset.py +74 -20
  3. ultralytics/data/utils.py +39 -5
  4. ultralytics/engine/trainer.py +4 -1
  5. ultralytics/hub/__init__.py +2 -25
  6. ultralytics/hub/auth.py +2 -22
  7. ultralytics/models/fastsam/predict.py +8 -11
  8. ultralytics/models/nas/predict.py +5 -5
  9. ultralytics/models/rtdetr/predict.py +5 -5
  10. ultralytics/models/sam/modules/sam.py +21 -35
  11. ultralytics/models/sam/predict.py +4 -4
  12. ultralytics/models/yolo/classify/predict.py +4 -5
  13. ultralytics/models/yolo/classify/train.py +1 -1
  14. ultralytics/models/yolo/classify/val.py +1 -1
  15. ultralytics/models/yolo/detect/predict.py +5 -7
  16. ultralytics/models/yolo/pose/predict.py +6 -11
  17. ultralytics/models/yolo/segment/predict.py +8 -13
  18. ultralytics/nn/modules/conv.py +6 -1
  19. ultralytics/trackers/utils/kalman_filter.py +71 -95
  20. ultralytics/utils/callbacks/tensorboard.py +3 -3
  21. ultralytics/utils/checks.py +6 -5
  22. ultralytics/utils/downloads.py +12 -13
  23. ultralytics/utils/metrics.py +0 -11
  24. ultralytics/utils/ops.py +84 -117
  25. {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/METADATA +1 -1
  26. {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/RECORD +30 -30
  27. {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/WHEEL +1 -1
  28. {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/LICENSE +0 -0
  29. {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/entry_points.txt +0 -0
  30. {ultralytics-8.0.159.dist-info → ultralytics-8.0.161.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py CHANGED
@@ -1,8 +1,7 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
- __version__ = '8.0.159'
+ __version__ = '8.0.161'
 
- from ultralytics.hub import start
  from ultralytics.models import RTDETR, SAM, YOLO
  from ultralytics.models.fastsam import FastSAM
  from ultralytics.models.nas import NAS
@@ -10,4 +9,4 @@ from ultralytics.utils import SETTINGS as settings
  from ultralytics.utils.checks import check_yolo as checks
  from ultralytics.utils.downloads import download
 
- __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings'  # allow simpler import
+ __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings'  # allow simpler import
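Dropping `from ultralytics.hub import start` also removes `start` from the top-level package namespace. A quick sanity check, as a minimal sketch assuming this release is installed (not part of the diff):

```python
import ultralytics

print(ultralytics.__version__)        # expected: '8.0.161'
print(hasattr(ultralytics, 'start'))  # expected: False, start() was removed
```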
ultralytics/data/dataset.py CHANGED
@@ -1,5 +1,5 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
-
+ import contextlib
  from itertools import repeat
  from multiprocessing.pool import ThreadPool
  from pathlib import Path
@@ -10,11 +10,14 @@ import torch
  import torchvision
  from tqdm import tqdm
 
- from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
+ from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, is_dir_writeable
 
  from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
  from .base import BaseDataset
- from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
+ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
+
+ # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
+ DATASET_CACHE_VERSION = '1.0.3'
 
 
  class YOLODataset(BaseDataset):
@@ -29,7 +32,6 @@ class YOLODataset(BaseDataset):
      Returns:
          (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
      """
-     cache_version = '1.0.2'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
 
      def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
          self.use_segments = use_segments
@@ -87,15 +89,7 @@ class YOLODataset(BaseDataset):
          x['hash'] = get_hash(self.label_files + self.im_files)
          x['results'] = nf, nm, ne, nc, len(self.im_files)
          x['msgs'] = msgs  # warnings
-         x['version'] = self.cache_version  # cache version
-         if is_dir_writeable(path.parent):
-             if path.exists():
-                 path.unlink()  # remove *.cache file if exists
-             np.save(str(path), x)  # save cache for next time
-             path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
-             LOGGER.info(f'{self.prefix}New cache created: {path}')
-         else:
-             LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+         save_dataset_cache_file(self.prefix, path, x)
          return x
 
      def get_labels(self):
@@ -103,11 +97,8 @@ class YOLODataset(BaseDataset):
          self.label_files = img2label_paths(self.im_files)
          cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
          try:
-             import gc
-             gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-             cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load dict
-             gc.enable()
-             assert cache['version'] == self.cache_version  # matches current version
+             cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+             assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
              assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
          except (FileNotFoundError, AssertionError, AttributeError):
              cache, exists = self.cache_labels(cache_path), False  # run cache ops
@@ -116,7 +107,7 @@ class YOLODataset(BaseDataset):
          nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
          if exists and LOCAL_RANK in (-1, 0):
              d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
-             tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
+             tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display results
          if cache['msgs']:
              LOGGER.info('\n'.join(cache['msgs']))  # display warnings
          if nf == 0:  # number of labels found
@@ -216,7 +207,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
          album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
      """
 
-     def __init__(self, root, args, augment=False, cache=False):
+     def __init__(self, root, args, augment=False, cache=False, prefix=''):
          """
          Initialize YOLO object with root, image size, augmentations, and cache settings.
 
@@ -229,8 +220,10 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
          super().__init__(root=root)
          if augment and args.fraction < 1.0:  # reduce training fraction
              self.samples = self.samples[:round(len(self.samples) * args.fraction)]
+         self.prefix = colorstr(f'{prefix}: ') if prefix else ''
          self.cache_ram = cache is True or cache == 'ram'
          self.cache_disk = cache == 'disk'
+         self.samples = self.verify_images()  # filter out bad images
          self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
          self.torch_transforms = classify_transforms(args.imgsz)
          self.album_transforms = classify_albumentations(
@@ -266,6 +259,67 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
      def __len__(self) -> int:
          return len(self.samples)
 
+     def verify_images(self):
+         """Verify all images in dataset."""
+         desc = f'{self.prefix}Scanning {self.root}...'
+         path = Path(self.root).with_suffix('.cache')  # *.cache file path
+
+         with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+             cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+             assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
+             assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash
+             nf, nc, n, samples = cache.pop('results')  # found, missing, empty, corrupt, total
+             if LOCAL_RANK in (-1, 0):
+                 d = f'{desc} {nf} images, {nc} corrupt'
+                 tqdm(None, desc=d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)
+                 if cache['msgs']:
+                     LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+             return samples
+
+         # Run scan if *.cache retrieval failed
+         nf, nc, msgs, samples, x = 0, 0, [], [], {}
+         with ThreadPool(NUM_THREADS) as pool:
+             results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
+             pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT)
+             for sample, nf_f, nc_f, msg in pbar:
+                 if nf_f:
+                     samples.append(sample)
+                 if msg:
+                     msgs.append(msg)
+                 nf += nf_f
+                 nc += nc_f
+                 pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+             pbar.close()
+         if msgs:
+             LOGGER.info('\n'.join(msgs))
+         x['hash'] = get_hash([x[0] for x in self.samples])
+         x['results'] = nf, nc, len(samples), samples
+         x['msgs'] = msgs  # warnings
+         save_dataset_cache_file(self.prefix, path, x)
+         return samples
+
+
+ def load_dataset_cache_file(path):
+     """Load an Ultralytics *.cache dictionary from path."""
+     import gc
+     gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
+     cache = np.load(str(path), allow_pickle=True).item()  # load dict
+     gc.enable()
+     return cache
+
+
+ def save_dataset_cache_file(prefix, path, x):
+     """Save an Ultralytics dataset *.cache dictionary x to path."""
+     x['version'] = DATASET_CACHE_VERSION  # add cache version
+     if is_dir_writeable(path.parent):
+         if path.exists():
+             path.unlink()  # remove *.cache file if exists
+         np.save(str(path), x)  # save cache for next time
+         path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
+         LOGGER.info(f'{prefix}New cache created: {path}')
+     else:
+         LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+
 
  # TODO: support semantic segmentation
  class SemanticDataset(BaseDataset):
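The caching logic formerly inlined in `YOLODataset` is now shared between the detection and classification datasets via the module-level helpers above. A minimal round-trip sketch, assuming this release is installed; the payload dict here is a toy stand-in for a real scan result:

```python
import tempfile
from pathlib import Path

from ultralytics.data.dataset import (DATASET_CACHE_VERSION, load_dataset_cache_file,
                                      save_dataset_cache_file)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'labels.cache'
    x = {'hash': 'abc123', 'results': (0, 0, 0, []), 'msgs': []}  # toy payload
    save_dataset_cache_file('demo: ', path, x)  # stamps x['version'] = DATASET_CACHE_VERSION
    cache = load_dataset_cache_file(path)
    assert cache['version'] == DATASET_CACHE_VERSION  # same check get_labels() performs
```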
ultralytics/data/utils.py CHANGED
@@ -57,6 +57,31 @@ def exif_size(img: Image.Image):
      return s
 
 
+ def verify_image(args):
+     """Verify one image."""
+     (im_file, cls), prefix = args
+     # Number (found, corrupt), message
+     nf, nc, msg = 0, 0, ''
+     try:
+         im = Image.open(im_file)
+         im.verify()  # PIL verify
+         shape = exif_size(im)  # image size
+         shape = (shape[1], shape[0])  # hw
+         assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+         assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
+         if im.format.lower() in ('jpg', 'jpeg'):
+             with open(im_file, 'rb') as f:
+                 f.seek(-2, 2)
+                 if f.read() != b'\xff\xd9':  # corrupt JPEG
+                     ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
+                     msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+         nf = 1
+     except Exception as e:
+         nc = 1
+         msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+     return (im_file, cls), nf, nc, msg
+
+
  def verify_image_label(args):
      """Verify one image-label pair."""
      im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
@@ -296,7 +321,7 @@ def check_cls_dataset(dataset: str, split=''):
      dataset = Path(dataset)
      data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
      if not data_dir.is_dir():
-         LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+         LOGGER.warning(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
          t = time.time()
          if str(dataset) == 'imagenet':
              subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
@@ -310,9 +335,9 @@ def check_cls_dataset(dataset: str, split=''):
          data_dir / 'validation').exists() else None  # data/test or data/val
      test_set = data_dir / 'test' if (data_dir / 'test').exists() else None  # data/val or data/test
      if split == 'val' and not val_set:
-         LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
+         LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
      elif split == 'test' and not test_set:
-         LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
+         LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
 
      nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
      names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
@@ -320,13 +345,22 @@ def check_cls_dataset(dataset: str, split=''):
 
      # Print to console
      for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
+         prefix = f'{colorstr(k)} {v}...'
          if v is None:
-             LOGGER.info(f'{colorstr(k)}: {v}')
+             LOGGER.info(prefix)
          else:
              files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
              nf = len(files)  # number of files
              nd = len({file.parent for file in files})  # number of directories
-             LOGGER.info(f'{colorstr(k)}: {v}... found {nf} images in {nd} classes ✅ ')  # keep trailing space
+             if nf == 0:
+                 if k == 'train':
+                     raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
+                 else:
+                     LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found')
+             elif nd != nc:
+                 LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}')
+             else:
+                 LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')
 
      return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}

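The new `verify_image()` takes a `((im_file, cls), prefix)` tuple so it can be fed straight to `ThreadPool.imap` via `zip(self.samples, repeat(prefix))`, as `ClassificationDataset.verify_images()` does above. A single-file sketch with a hypothetical image path:

```python
from ultralytics.data.utils import verify_image

sample = (('datasets/mnist160/train/0/img.png', 0), 'demo: ')  # ((im_file, cls), prefix)
(im_file, cls), nf, nc, msg = verify_image(sample)
print(nf, nc, msg)  # nf=1 for a readable image; nc=1 plus a warning message if corrupt
```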
ultralytics/engine/trainer.py CHANGED
@@ -10,6 +10,7 @@ import math
  import os
  import subprocess
  import time
+ import warnings
  from copy import deepcopy
  from datetime import datetime, timedelta
  from pathlib import Path
@@ -378,7 +379,9 @@ class BaseTrainer:
 
              self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
 
-             self.scheduler.step()
+             with warnings.catch_warnings():
+                 warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                 self.scheduler.step()
              self.run_callbacks('on_train_epoch_end')
 
              if RANK in (-1, 0):
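`warnings.catch_warnings()` scopes the filter to the single `scheduler.step()` call rather than muting warnings process-wide. A self-contained illustration of the pattern with a toy optimizer/scheduler pair:

```python
import warnings

import torch

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda epoch: 0.9 ** epoch)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')  # hides 'Detected lr_scheduler.step() before optimizer.step()'
    sched.step()  # stepping before opt.step() would normally emit that UserWarning
```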
ultralytics/hub/__init__.py CHANGED
@@ -5,7 +5,7 @@ import requests
  from ultralytics.data.utils import HUBDatasetStats
  from ultralytics.hub.auth import Auth
  from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
- from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
+ from ultralytics.utils import LOGGER, SETTINGS
 
 
  def login(api_key=''):
@@ -37,29 +37,10 @@ def logout():
      ```
      """
      SETTINGS['api_key'] = ''
-     yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
+     SETTINGS.save()
      LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
 
 
- def start(key=''):
-     """
-     Start training models with Ultralytics HUB (DEPRECATED).
-
-     Args:
-         key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
-             or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
-     """
-     api_key, model_id = key.split('_')
-     LOGGER.warning(f"""
- WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
-
- from ultralytics import YOLO, hub
-
- hub.login('{api_key}')
- model = YOLO('{HUB_WEB_ROOT}/models/{model_id}')
- model.train()""")
-
-
  def reset_model(model_id=''):
      """Reset a trained model to an untrained state."""
      r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
@@ -117,7 +98,3 @@ def check_dataset(path='', task='detect'):
      """
      HUBDatasetStats(path=path, task=task).get_json()
      LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
-
-
- if __name__ == '__main__':
-     start()
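Callers still using the removed `start()` can switch to the recipe its own deprecation warning has printed since 8.0.60; the API key and model ID below are placeholders:

```python
from ultralytics import YOLO, hub

hub.login('API_KEY')  # placeholder API key
model = YOLO('https://hub.ultralytics.com/models/MODEL_ID')  # placeholder model URL
model.train()
```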
ultralytics/hub/auth.py CHANGED
@@ -73,8 +73,7 @@ class Auth:
              bool: True if authentication is successful, False otherwise.
          """
          try:
-             header = self.get_auth_header()
-             if header:
+             if header := self.get_auth_header():
                  r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
                  if not r.json().get('success', False):
                      raise ConnectionError('Unable to authenticate.')
@@ -117,23 +116,4 @@ class Auth:
              return {'authorization': f'Bearer {self.id_token}'}
          elif self.api_key:
              return {'x-api-key': self.api_key}
-         else:
-             return None
-
-     def get_state(self) -> bool:
-         """
-         Get the authentication state.
-
-         Returns:
-             bool: True if either id_token or API key is set, False otherwise.
-         """
-         return self.id_token or self.api_key
-
-     def set_api_key(self, key: str):
-         """
-         Set the API key for authentication.
-
-         Args:
-             key (str): The API key string.
-         """
-         self.api_key = key
+         # else returns None
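The walrus operator (`:=`) assigns and tests in one expression and requires Python 3.8+. A minimal equivalent sketch with a stand-in for `Auth.get_auth_header()`:

```python
def get_auth_header():  # stand-in for Auth.get_auth_header()
    return {'x-api-key': 'demo'}

if header := get_auth_header():  # assign, then branch only if truthy
    print('would POST with headers:', header)
```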
ultralytics/models/fastsam/predict.py CHANGED
@@ -15,7 +15,6 @@ class FastSAMPredictor(DetectionPredictor):
          self.args.task = 'segment'
 
      def postprocess(self, preds, img, orig_imgs):
-         """TODO: filter by classes."""
          p = ops.non_max_suppression(preds[0],
                                      self.args.conf,
                                      self.args.iou,
@@ -32,22 +31,20 @@ class FastSAMPredictor(DetectionPredictor):
          full_box[0][6:] = p[0][critical_iou_index][:, 6:]
          p[0][critical_iou_index] = full_box
          results = []
+         is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
          proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
          for i, pred in enumerate(p):
-             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-             path = self.batch[0]
-             img_path = path[i] if isinstance(path, list) else path
+             orig_img = orig_imgs[i] if is_list else orig_imgs
+             img_path = self.batch[0][i]
              if not len(pred):  # save empty boxes
-                 results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
-                 continue
-             if self.args.retina_masks:
-                 if not isinstance(orig_imgs, torch.Tensor):
+                 masks = None
+             elif self.args.retina_masks:
+                 if is_list:
                      pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                  masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
              else:
                  masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                 if not isinstance(orig_imgs, torch.Tensor):
+                 if is_list:
                      pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-             results.append(
-                 Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
+             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
          return results
ultralytics/models/nas/predict.py CHANGED
@@ -24,11 +24,11 @@ class NASPredictor(BasePredictor):
                                          classes=self.args.classes)
 
          results = []
+         is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
          for i, pred in enumerate(preds):
-             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-             if not isinstance(orig_imgs, torch.Tensor):
+             orig_img = orig_imgs[i] if is_list else orig_imgs
+             if is_list:
                  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-             path = self.batch[0]
-             img_path = path[i] if isinstance(path, list) else path
-             results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
+             img_path = self.batch[0][i]
+             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
          return results
ultralytics/models/rtdetr/predict.py CHANGED
@@ -28,6 +28,7 @@ class RTDETRPredictor(BasePredictor):
          nd = preds[0].shape[-1]
          bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
          results = []
+         is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
          for i, bbox in enumerate(bboxes):  # (300, 4)
              bbox = ops.xywh2xyxy(bbox)
              score, cls = scores[i].max(-1, keepdim=True)  # (300, 1)
@@ -35,14 +36,13 @@ class RTDETRPredictor(BasePredictor):
              if self.args.classes is not None:
                  idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
              pred = torch.cat([bbox, score, cls], dim=-1)[idx]  # filter
-             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
+             orig_img = orig_imgs[i] if is_list else orig_imgs
              oh, ow = orig_img.shape[:2]
-             if not isinstance(orig_imgs, torch.Tensor):
+             if is_list:
                  pred[..., [0, 2]] *= ow
                  pred[..., [1, 3]] *= oh
-             path = self.batch[0]
-             img_path = path[i] if isinstance(path, list) else path
-             results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
+             img_path = self.batch[0][i]
+             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
          return results
 
      def pre_transform(self, im):
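For reference, `ops.xywh2xyxy()` converts center-format boxes to corner format before the score/class filtering above. A quick numeric check:

```python
import torch

from ultralytics.utils import ops

box = torch.tensor([[50., 50., 20., 10.]])  # cx, cy, w, h
print(ops.xywh2xyxy(box))  # tensor([[40., 45., 60., 55.]]) -> x1, y1, x2, y2
```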
ultralytics/models/sam/modules/sam.py CHANGED
@@ -30,11 +30,10 @@ class Sam(nn.Module):
      SAM predicts object masks from an image and input prompts.
 
      Args:
-         image_encoder (ImageEncoderViT): The backbone used to encode the
-             image into image embeddings that allow for efficient mask prediction.
+         image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
+             efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
-         mask_decoder (MaskDecoder): Predicts masks from the image embeddings
-             and encoded prompts.
+         mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
      """
@@ -65,34 +64,25 @@ class Sam(nn.Module):
 
          Args:
              batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
-                 key can be excluded if it is not present.
-                 'image': The image as a torch tensor in 3xHxW format,
-                     already transformed for input to the model.
-                 'original_size': (tuple(int, int)) The original size of
-                     the image before transformation, as (H, W).
-                 'point_coords': (torch.Tensor) Batched point prompts for
-                     this image, with shape BxNx2. Already transformed to the
-                     input frame of the model.
-                 'point_labels': (torch.Tensor) Batched labels for point prompts,
-                     with shape BxN.
-                 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
-                     Already transformed to the input frame of the model.
-                 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
-                     in the form Bx1xHxW.
+                 key can be excluded if it is not present.
+                 'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
+                 'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
+                 'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
+                     transformed to the input frame of the model.
+                 'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
+                 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
+                     the model.
+                 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
              multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
                  mask.
 
          Returns:
              (list(dict)): A list over input images, where each element is as dictionary with the following keys.
-                 'masks': (torch.Tensor) Batched binary mask predictions,
-                     with shape BxCxHxW, where B is the number of input prompts,
-                     C is determined by multimask_output, and (H, W) is the
-                     original size of the image.
-                 'iou_predictions': (torch.Tensor) The model's predictions
-                     of mask quality, in shape BxC.
-                 'low_res_logits': (torch.Tensor) Low resolution logits with
-                     shape BxCxHxW, where H=W=256. Can be passed as mask input
-                     to subsequent iterations of prediction.
+                 'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
+                     input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
+                 'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
+                 'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
+                     as mask input to subsequent iterations of prediction.
          """
          input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
          image_embeddings = self.image_encoder(input_images)
@@ -137,16 +127,12 @@ class Sam(nn.Module):
          Remove padding and upscale masks to the original image size.
 
          Args:
-             masks (torch.Tensor): Batched masks from the mask_decoder,
-                 in BxCxHxW format.
-             input_size (tuple(int, int)): The size of the image input to the
-                 model, in (H, W) format. Used to remove padding.
-             original_size (tuple(int, int)): The original size of the image
-                 before resizing for input to the model, in (H, W) format.
+             masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
+             input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
+             original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
 
          Returns:
-             (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
-                 is given by original_size.
+             (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
          """
          masks = F.interpolate(
              masks,
ultralytics/models/sam/predict.py CHANGED
@@ -318,8 +318,9 @@ class Predictor(BasePredictor):
          pred_bboxes = preds[2] if self.segment_all else None
          names = dict(enumerate(str(i) for i in range(len(pred_masks))))
          results = []
+         is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
          for i, masks in enumerate([pred_masks]):
-             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
+             orig_img = orig_imgs[i] if is_list else orig_imgs
              if pred_bboxes is not None:
                  pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                  cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
@@ -327,9 +328,8 @@ class Predictor(BasePredictor):
 
              masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
              masks = masks > self.model.mask_threshold  # to bool
-             path = self.batch[0]
-             img_path = path[i] if isinstance(path, list) else path
-             results.append(Results(orig_img=orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
+             img_path = self.batch[0][i]
+             results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
          # Reset segment-all mode.
          self.segment_all = False
          return results
ultralytics/models/yolo/classify/predict.py CHANGED
@@ -39,10 +39,9 @@ class ClassificationPredictor(BasePredictor):
      def postprocess(self, preds, img, orig_imgs):
          """Post-processes predictions to return Results objects."""
          results = []
+         is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
          for i, pred in enumerate(preds):
-             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-             path = self.batch[0]
-             img_path = path[i] if isinstance(path, list) else path
-             results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
-
+             orig_img = orig_imgs[i] if is_list else orig_imgs
+             img_path = self.batch[0][i]
+             results.append(Results(orig_img, path=img_path, names=self.model.names, probs=pred))
          return results
ultralytics/models/yolo/classify/train.py CHANGED
@@ -79,7 +79,7 @@ class ClassificationTrainer(BaseTrainer):
          return ckpt
 
      def build_dataset(self, img_path, mode='train', batch=None):
-         return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
+         return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
 
      def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
          """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
ultralytics/models/yolo/classify/val.py CHANGED
@@ -77,7 +77,7 @@ class ClassificationValidator(BaseValidator):
          return self.metrics.results_dict
 
      def build_dataset(self, img_path):
-         return ClassificationDataset(root=img_path, args=self.args, augment=False)
+         return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
      def get_dataloader(self, dataset_path, batch_size):
          """Builds and returns a data loader for classification tasks with given parameters."""
ultralytics/models/yolo/detect/predict.py CHANGED
@@ -1,7 +1,5 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
- import torch
-
  from ultralytics.engine.predictor import BasePredictor
  from ultralytics.engine.results import Results
  from ultralytics.utils import ops
@@ -32,11 +30,11 @@ class DetectionPredictor(BasePredictor):
                                          classes=self.args.classes)
 
          results = []
+         is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
          for i, pred in enumerate(preds):
-             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-             if not isinstance(orig_imgs, torch.Tensor):
+             orig_img = orig_imgs[i] if is_list else orig_imgs
+             if is_list:
                  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-             path = self.batch[0]
-             img_path = path[i] if isinstance(path, list) else path
-             results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
+             img_path = self.batch[0][i]
+             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
          return results
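The same refactor recurs across all predictors in this release: the `isinstance(orig_imgs, list)` test is hoisted out of the loop into `is_list`, and `img_path` is indexed directly from `self.batch[0]`. A behavior sketch with arrays standing in for frames:

```python
import numpy as np

orig_imgs = [np.zeros((480, 640, 3)), np.zeros((720, 1280, 3))]  # list input -> rescale per image
is_list = isinstance(orig_imgs, list)  # computed once, outside the loop
for i, orig_img in enumerate(orig_imgs if is_list else [orig_imgs]):
    print(i, orig_img.shape)  # each Results object gets its own original image
```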
ultralytics/models/yolo/pose/predict.py CHANGED
@@ -38,18 +38,13 @@ class PosePredictor(DetectionPredictor):
                                          nc=len(self.model.names))
 
          results = []
+         is_list = isinstance(orig_imgs, list)  # input images are a list, not a torch.Tensor
          for i, pred in enumerate(preds):
-             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
-             shape = orig_img.shape
-             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+             orig_img = orig_imgs[i] if is_list else orig_imgs
+             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
              pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
-             pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape)
-             path = self.batch[0]
-             img_path = path[i] if isinstance(path, list) else path
+             pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+             img_path = self.batch[0][i]
              results.append(
-                 Results(orig_img=orig_img,
-                         path=img_path,
-                         names=self.model.names,
-                         boxes=pred[:, :6],
-                         keypoints=pred_kpts))
+                 Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
          return results