ultralytics-8.1.39-py3-none-any.whl → ultralytics-8.1.41-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (47)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +3 -3
  3. ultralytics/data/augment.py +2 -2
  4. ultralytics/data/base.py +2 -2
  5. ultralytics/data/converter.py +2 -2
  6. ultralytics/data/dataset.py +4 -4
  7. ultralytics/data/loaders.py +11 -8
  8. ultralytics/data/split_dota.py +1 -1
  9. ultralytics/data/utils.py +8 -7
  10. ultralytics/engine/exporter.py +3 -3
  11. ultralytics/engine/model.py +6 -3
  12. ultralytics/engine/results.py +2 -2
  13. ultralytics/engine/trainer.py +22 -25
  14. ultralytics/engine/validator.py +2 -2
  15. ultralytics/hub/utils.py +1 -1
  16. ultralytics/models/fastsam/model.py +1 -1
  17. ultralytics/models/fastsam/prompt.py +4 -5
  18. ultralytics/models/nas/model.py +1 -1
  19. ultralytics/models/sam/model.py +1 -1
  20. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  21. ultralytics/models/yolo/classify/train.py +1 -1
  22. ultralytics/models/yolo/detect/train.py +1 -1
  23. ultralytics/models/yolo/world/train.py +16 -15
  24. ultralytics/nn/autobackend.py +5 -5
  25. ultralytics/nn/modules/conv.py +1 -1
  26. ultralytics/nn/modules/head.py +4 -4
  27. ultralytics/nn/tasks.py +1 -1
  28. ultralytics/solutions/ai_gym.py +1 -1
  29. ultralytics/solutions/heatmap.py +1 -1
  30. ultralytics/trackers/byte_tracker.py +1 -1
  31. ultralytics/trackers/track.py +1 -1
  32. ultralytics/trackers/utils/gmc.py +1 -1
  33. ultralytics/utils/__init__.py +4 -4
  34. ultralytics/utils/benchmarks.py +2 -2
  35. ultralytics/utils/callbacks/comet.py +1 -1
  36. ultralytics/utils/callbacks/mlflow.py +1 -1
  37. ultralytics/utils/checks.py +6 -4
  38. ultralytics/utils/downloads.py +2 -2
  39. ultralytics/utils/metrics.py +1 -1
  40. ultralytics/utils/plotting.py +1 -1
  41. ultralytics/utils/torch_utils.py +4 -4
  42. {ultralytics-8.1.39.dist-info → ultralytics-8.1.41.dist-info}/METADATA +1 -1
  43. {ultralytics-8.1.39.dist-info → ultralytics-8.1.41.dist-info}/RECORD +47 -47
  44. {ultralytics-8.1.39.dist-info → ultralytics-8.1.41.dist-info}/LICENSE +0 -0
  45. {ultralytics-8.1.39.dist-info → ultralytics-8.1.41.dist-info}/WHEEL +0 -0
  46. {ultralytics-8.1.39.dist-info → ultralytics-8.1.41.dist-info}/entry_points.txt +0 -0
  47. {ultralytics-8.1.39.dist-info → ultralytics-8.1.41.dist-info}/top_level.txt +0 -0
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- __version__ = "8.1.39"
3
+ __version__ = "8.1.41"
4
4
 
5
5
  from ultralytics.data.explorer.explorer import Explorer
6
6
  from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
@@ -272,7 +272,7 @@ def get_save_dir(args, name=None):
272
272
 
273
273
  project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
274
274
  name = name or args.name or f"{args.mode}"
275
- save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
275
+ save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True)
276
276
 
277
277
  return Path(save_dir)
278
278
 
@@ -566,10 +566,10 @@ def entrypoint(debug=""):
566
566
  task = model.task
567
567
 
568
568
  # Mode
569
- if mode in ("predict", "track") and "source" not in overrides:
569
+ if mode in {"predict", "track"} and "source" not in overrides:
570
570
  overrides["source"] = DEFAULT_CFG.source or ASSETS
571
571
  LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
572
- elif mode in ("train", "val"):
572
+ elif mode in {"train", "val"}:
573
573
  if "data" not in overrides and "resume" not in overrides:
574
574
  overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
575
575
  LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
@@ -191,7 +191,7 @@ class Mosaic(BaseMixTransform):
191
191
  def __init__(self, dataset, imgsz=640, p=1.0, n=4):
192
192
  """Initializes the object with a dataset, image size, probability, and border."""
193
193
  assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
194
- assert n in (4, 9), "grid must be equal to 4 or 9."
194
+ assert n in {4, 9}, "grid must be equal to 4 or 9."
195
195
  super().__init__(dataset=dataset, p=p)
196
196
  self.dataset = dataset
197
197
  self.imgsz = imgsz
@@ -685,7 +685,7 @@ class RandomFlip:
685
685
  Default is 'horizontal'.
686
686
  flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
687
687
  """
688
- assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
688
+ assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
689
689
  assert 0 <= p <= 1.0
690
690
 
691
691
  self.p = p
ultralytics/data/base.py CHANGED
@@ -15,7 +15,7 @@ import psutil
15
15
  from torch.utils.data import Dataset
16
16
 
17
17
  from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
18
- from .utils import HELP_URL, IMG_FORMATS
18
+ from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS
19
19
 
20
20
 
21
21
  class BaseDataset(Dataset):
@@ -118,7 +118,7 @@ class BaseDataset(Dataset):
118
118
  raise FileNotFoundError(f"{self.prefix}{p} does not exist")
119
119
  im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
120
120
  # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
121
- assert im_files, f"{self.prefix}No images found in {img_path}"
121
+ assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
122
122
  except Exception as e:
123
123
  raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
124
124
  if self.fraction < 1:
@@ -481,7 +481,7 @@ def merge_multi_segment(segments):
481
481
  segments[i] = np.roll(segments[i], -idx[0], axis=0)
482
482
  segments[i] = np.concatenate([segments[i], segments[i][:1]])
483
483
  # Deal with the first segment and the last one
484
- if i in [0, len(idx_list) - 1]:
484
+ if i in {0, len(idx_list) - 1}:
485
485
  s.append(segments[i])
486
486
  else:
487
487
  idx = [0, idx[1] - idx[0]]
@@ -489,7 +489,7 @@ def merge_multi_segment(segments):
489
489
 
490
490
  else:
491
491
  for i in range(len(idx_list) - 1, -1, -1):
492
- if i not in [0, len(idx_list) - 1]:
492
+ if i not in {0, len(idx_list) - 1}:
493
493
  idx = idx_list[i]
494
494
  nidx = abs(idx[1] - idx[0])
495
495
  s.append(segments[i][nidx:])
@@ -77,7 +77,7 @@ class YOLODataset(BaseDataset):
77
77
  desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
78
78
  total = len(self.im_files)
79
79
  nkpt, ndim = self.data.get("kpt_shape", (0, 0))
80
- if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
80
+ if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
81
81
  raise ValueError(
82
82
  "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
83
83
  "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
@@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
142
142
 
143
143
  # Display cache
144
144
  nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
145
- if exists and LOCAL_RANK in (-1, 0):
145
+ if exists and LOCAL_RANK in {-1, 0}:
146
146
  d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
147
147
  TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
148
148
  if cache["msgs"]:
@@ -235,7 +235,7 @@ class YOLODataset(BaseDataset):
235
235
  value = values[i]
236
236
  if k == "img":
237
237
  value = torch.stack(value, 0)
238
- if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
238
+ if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
239
239
  value = torch.cat(value, 0)
240
240
  new_batch[k] = value
241
241
  new_batch["batch_idx"] = list(new_batch["batch_idx"])
@@ -334,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
334
334
  assert cache["version"] == DATASET_CACHE_VERSION # matches current version
335
335
  assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
336
336
  nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
337
- if LOCAL_RANK in (-1, 0):
337
+ if LOCAL_RANK in {-1, 0}:
338
338
  d = f"{desc} {nf} images, {nc} corrupt"
339
339
  TQDM(None, desc=d, total=n, initial=n)
340
340
  if cache["msgs"]:
@@ -15,7 +15,7 @@ import requests
15
15
  import torch
16
16
  from PIL import Image
17
17
 
18
- from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
18
+ from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG
19
19
  from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
20
20
  from ultralytics.utils.checks import check_requirements
21
21
 
@@ -83,7 +83,7 @@ class LoadStreams:
83
83
  for i, s in enumerate(sources): # index, source
84
84
  # Start thread to read frames from video stream
85
85
  st = f"{i + 1}/{n}: {s}... "
86
- if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
86
+ if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video
87
87
  # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
88
88
  s = get_best_youtube_url(s)
89
89
  s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
@@ -291,8 +291,14 @@ class LoadImagesAndVideos:
291
291
  else:
292
292
  raise FileNotFoundError(f"{p} does not exist")
293
293
 
294
- images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
295
- videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
294
+ # Define files as images or videos
295
+ images, videos = [], []
296
+ for f in files:
297
+ suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase
298
+ if suffix in IMG_FORMATS:
299
+ images.append(f)
300
+ elif suffix in VID_FORMATS:
301
+ videos.append(f)
296
302
  ni, nv = len(images), len(videos)
297
303
 
298
304
  self.files = images + videos
@@ -307,10 +313,7 @@ class LoadImagesAndVideos:
307
313
  else:
308
314
  self.cap = None
309
315
  if self.nf == 0:
310
- raise FileNotFoundError(
311
- f"No images or videos found in {p}. "
312
- f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
313
- )
316
+ raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
314
317
 
315
318
  def __iter__(self):
316
319
  """Returns an iterator object for VideoStream or ImageFolder."""
@@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"):
71
71
  - train
72
72
  - val
73
73
  """
74
- assert split in ["train", "val"]
74
+ assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
75
75
  im_dir = Path(data_root) / "images" / split
76
76
  assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
77
77
  im_files = glob(str(Path(data_root) / "images" / split / "*"))
ultralytics/data/utils.py CHANGED
@@ -39,6 +39,7 @@ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatt
39
39
  IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
40
40
  VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
41
41
  PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
42
+ FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
42
43
 
43
44
 
44
45
  def img2label_paths(img_paths):
@@ -63,7 +64,7 @@ def exif_size(img: Image.Image):
63
64
  exif = img.getexif()
64
65
  if exif:
65
66
  rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
66
- if rotation in [6, 8]: # rotation 270 or 90
67
+ if rotation in {6, 8}: # rotation 270 or 90
67
68
  s = s[1], s[0]
68
69
  return s
69
70
 
@@ -79,8 +80,8 @@ def verify_image(args):
79
80
  shape = exif_size(im) # image size
80
81
  shape = (shape[1], shape[0]) # hw
81
82
  assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
82
- assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
83
- if im.format.lower() in ("jpg", "jpeg"):
83
+ assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
84
+ if im.format.lower() in {"jpg", "jpeg"}:
84
85
  with open(im_file, "rb") as f:
85
86
  f.seek(-2, 2)
86
87
  if f.read() != b"\xff\xd9": # corrupt JPEG
@@ -105,8 +106,8 @@ def verify_image_label(args):
105
106
  shape = exif_size(im) # image size
106
107
  shape = (shape[1], shape[0]) # hw
107
108
  assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
108
- assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
109
- if im.format.lower() in ("jpg", "jpeg"):
109
+ assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
110
+ if im.format.lower() in {"jpg", "jpeg"}:
110
111
  with open(im_file, "rb") as f:
111
112
  f.seek(-2, 2)
112
113
  if f.read() != b"\xff\xd9": # corrupt JPEG
@@ -336,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True):
336
337
  else: # python script
337
338
  exec(s, {"yaml": data})
338
339
  dt = f"({round(time.time() - t, 1)}s)"
339
- s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
340
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
340
341
  LOGGER.info(f"Dataset download {s}\n")
341
342
  check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
342
343
 
@@ -366,7 +367,7 @@ def check_cls_dataset(dataset, split=""):
366
367
  # Download (optional if dataset=https://file.zip is passed directly)
367
368
  if str(dataset).startswith(("http:/", "https:/")):
368
369
  dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
369
- elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
370
+ elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
370
371
  file = check_file(dataset)
371
372
  dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
372
373
 
@@ -159,7 +159,7 @@ class Exporter:
159
159
  _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
160
160
  """
161
161
  self.args = get_cfg(cfg, overrides)
162
- if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
162
+ if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors
163
163
  os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
164
164
 
165
165
  self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -171,9 +171,9 @@ class Exporter:
171
171
  self.run_callbacks("on_export_start")
172
172
  t = time.time()
173
173
  fmt = self.args.format.lower() # to lowercase
174
- if fmt in ("tensorrt", "trt"): # 'engine' aliases
174
+ if fmt in {"tensorrt", "trt"}: # 'engine' aliases
175
175
  fmt = "engine"
176
- if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
176
+ if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases
177
177
  fmt = "coreml"
178
178
  fmts = tuple(export_formats()["Argument"][1:]) # available export formats
179
179
  flags = [x == fmt for x in fmts]
@@ -145,7 +145,7 @@ class Model(nn.Module):
145
145
  return
146
146
 
147
147
  # Load or create new YOLO model
148
- if Path(model).suffix in (".yaml", ".yml"):
148
+ if Path(model).suffix in {".yaml", ".yml"}:
149
149
  self._new(model, task=task, verbose=verbose)
150
150
  else:
151
151
  self._load(model, task=task)
@@ -666,7 +666,7 @@ class Model(nn.Module):
666
666
  self.trainer.hub_session = self.session # attach optional HUB session
667
667
  self.trainer.train()
668
668
  # Update model and cfg after training
669
- if RANK in (-1, 0):
669
+ if RANK in {-1, 0}:
670
670
  ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
671
671
  self.model, _ = attempt_load_one_weight(ckpt)
672
672
  self.overrides = self.model.args
@@ -735,7 +735,10 @@ class Model(nn.Module):
735
735
 
736
736
  if hasattr(self.model, "names"):
737
737
  return check_class_names(self.model.names)
738
- elif self.predictor:
738
+ else:
739
+ if not self.predictor: # export formats will not have predictor defined until predict() is called
740
+ self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
741
+ self.predictor.setup_model(model=self.model, verbose=False)
739
742
  return self.predictor.model.names
740
743
 
741
744
  @property
@@ -470,7 +470,7 @@ class Boxes(BaseTensor):
470
470
  if boxes.ndim == 1:
471
471
  boxes = boxes[None, :]
472
472
  n = boxes.shape[-1]
473
- assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
473
+ assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
474
474
  super().__init__(boxes, orig_shape)
475
475
  self.is_track = n == 7
476
476
  self.orig_shape = orig_shape
@@ -687,7 +687,7 @@ class OBB(BaseTensor):
687
687
  if boxes.ndim == 1:
688
688
  boxes = boxes[None, :]
689
689
  n = boxes.shape[-1]
690
- assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
690
+ assert n in {7, 8}, f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
691
691
  super().__init__(boxes, orig_shape)
692
692
  self.is_track = n == 8
693
693
  self.orig_shape = orig_shape
@@ -107,7 +107,7 @@ class BaseTrainer:
107
107
  self.save_dir = get_save_dir(self.args)
108
108
  self.args.name = self.save_dir.name # update name for loggers
109
109
  self.wdir = self.save_dir / "weights" # weights dir
110
- if RANK in (-1, 0):
110
+ if RANK in {-1, 0}:
111
111
  self.wdir.mkdir(parents=True, exist_ok=True) # make dir
112
112
  self.args.save_dir = str(self.save_dir)
113
113
  yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
@@ -121,7 +121,7 @@ class BaseTrainer:
121
121
  print_args(vars(self.args))
122
122
 
123
123
  # Device
124
- if self.device.type in ("cpu", "mps"):
124
+ if self.device.type in {"cpu", "mps"}:
125
125
  self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
126
126
 
127
127
  # Model and Dataset
@@ -144,7 +144,7 @@ class BaseTrainer:
144
144
 
145
145
  # Callbacks
146
146
  self.callbacks = _callbacks or callbacks.get_default_callbacks()
147
- if RANK in (-1, 0):
147
+ if RANK in {-1, 0}:
148
148
  callbacks.add_integration_callbacks(self)
149
149
 
150
150
  def add_callback(self, event: str, callback):
@@ -210,9 +210,9 @@ class BaseTrainer:
210
210
  torch.cuda.set_device(RANK)
211
211
  self.device = torch.device("cuda", RANK)
212
212
  # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
213
- os.environ["NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
213
+ os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
214
214
  dist.init_process_group(
215
- "nccl" if dist.is_nccl_available() else "gloo",
215
+ backend="nccl" if dist.is_nccl_available() else "gloo",
216
216
  timeout=timedelta(seconds=10800), # 3 hours
217
217
  rank=RANK,
218
218
  world_size=world_size,
@@ -251,7 +251,7 @@ class BaseTrainer:
251
251
 
252
252
  # Check AMP
253
253
  self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
254
- if self.amp and RANK in (-1, 0): # Single-GPU and DDP
254
+ if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
255
255
  callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
256
256
  self.amp = torch.tensor(check_amp(self.model), device=self.device)
257
257
  callbacks.default_callbacks = callbacks_backup # restore callbacks
@@ -274,7 +274,7 @@ class BaseTrainer:
274
274
  # Dataloaders
275
275
  batch_size = self.batch_size // max(world_size, 1)
276
276
  self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
277
- if RANK in (-1, 0):
277
+ if RANK in {-1, 0}:
278
278
  # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
279
279
  self.test_loader = self.get_dataloader(
280
280
  self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
@@ -340,7 +340,7 @@ class BaseTrainer:
340
340
  self._close_dataloader_mosaic()
341
341
  self.train_loader.reset()
342
342
 
343
- if RANK in (-1, 0):
343
+ if RANK in {-1, 0}:
344
344
  LOGGER.info(self.progress_string())
345
345
  pbar = TQDM(enumerate(self.train_loader), total=nb)
346
346
  self.tloss = None
@@ -392,7 +392,7 @@ class BaseTrainer:
392
392
  mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
393
393
  loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
394
394
  losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
395
- if RANK in (-1, 0):
395
+ if RANK in {-1, 0}:
396
396
  pbar.set_description(
397
397
  ("%11s" * 2 + "%11.4g" * (2 + loss_len))
398
398
  % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
@@ -405,7 +405,7 @@ class BaseTrainer:
405
405
 
406
406
  self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
407
407
  self.run_callbacks("on_train_epoch_end")
408
- if RANK in (-1, 0):
408
+ if RANK in {-1, 0}:
409
409
  final_epoch = epoch + 1 >= self.epochs
410
410
  self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
411
411
 
@@ -447,7 +447,7 @@ class BaseTrainer:
447
447
  break # must break all DDP ranks
448
448
  epoch += 1
449
449
 
450
- if RANK in (-1, 0):
450
+ if RANK in {-1, 0}:
451
451
  # Do final val with best.pt
452
452
  LOGGER.info(
453
453
  f"\n{epoch - self.start_epoch + 1} epochs completed in "
@@ -503,12 +503,12 @@ class BaseTrainer:
503
503
  try:
504
504
  if self.args.task == "classify":
505
505
  data = check_cls_dataset(self.args.data)
506
- elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
506
+ elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
507
507
  "detect",
508
508
  "segment",
509
509
  "pose",
510
510
  "obb",
511
- ):
511
+ }:
512
512
  data = check_det_dataset(self.args.data)
513
513
  if "yaml_file" in data:
514
514
  self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
@@ -648,8 +648,8 @@ class BaseTrainer:
648
648
 
649
649
  resume = True
650
650
  self.args = get_cfg(ckpt_args)
651
- self.args.model = str(last) # reinstate model
652
- for k in "imgsz", "batch": # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
651
+ self.args.model = self.args.resume = str(last) # reinstate model
652
+ for k in "imgsz", "batch", "device": # allow arg updates to reduce memory or update device on resume
653
653
  if k in overrides:
654
654
  setattr(self.args, k, overrides[k])
655
655
 
@@ -662,7 +662,7 @@ class BaseTrainer:
662
662
 
663
663
  def resume_training(self, ckpt):
664
664
  """Resume YOLO training from given epoch and best fitness."""
665
- if ckpt is None:
665
+ if ckpt is None or not self.resume:
666
666
  return
667
667
  best_fitness = 0.0
668
668
  start_epoch = ckpt.get("epoch", -1) + 1
@@ -672,14 +672,11 @@ class BaseTrainer:
672
672
  if self.ema and ckpt.get("ema"):
673
673
  self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
674
674
  self.ema.updates = ckpt["updates"]
675
- if self.resume:
676
- assert start_epoch > 0, (
677
- f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
678
- f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
679
- )
680
- LOGGER.info(
681
- f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
682
- )
675
+ assert start_epoch > 0, (
676
+ f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
677
+ f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
678
+ )
679
+ LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
683
680
  if self.epochs < start_epoch:
684
681
  LOGGER.info(
685
682
  f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
@@ -740,7 +737,7 @@ class BaseTrainer:
740
737
  else: # weight (with decay)
741
738
  g[0].append(param)
742
739
 
743
- if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
740
+ if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
744
741
  optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
745
742
  elif name == "RMSProp":
746
743
  optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
@@ -139,14 +139,14 @@ class BaseValidator:
139
139
  self.args.batch = 1 # export.py models default to batch-size 1
140
140
  LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
141
141
 
142
- if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
142
+ if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
143
143
  self.data = check_det_dataset(self.args.data)
144
144
  elif self.args.task == "classify":
145
145
  self.data = check_cls_dataset(self.args.data, split=self.args.split)
146
146
  else:
147
147
  raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
148
148
 
149
- if self.device.type in ("cpu", "mps"):
149
+ if self.device.type in {"cpu", "mps"}:
150
150
  self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
151
151
  if not pt:
152
152
  self.args.rect = False
ultralytics/hub/utils.py CHANGED
@@ -198,7 +198,7 @@ class Events:
198
198
  }
199
199
  self.enabled = (
200
200
  SETTINGS["sync"]
201
- and RANK in (-1, 0)
201
+ and RANK in {-1, 0}
202
202
  and not TESTS_RUNNING
203
203
  and ONLINE
204
204
  and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
@@ -24,7 +24,7 @@ class FastSAM(Model):
24
24
  """Call the __init__ method of the parent class (YOLO) with the updated default model."""
25
25
  if str(model) == "FastSAM.pt":
26
26
  model = "FastSAM-x.pt"
27
- assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
27
+ assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
28
28
  super().__init__(model=model, task="segment")
29
29
 
30
30
  @property
@@ -9,7 +9,7 @@ import numpy as np
9
9
  import torch
10
10
  from PIL import Image
11
11
 
12
- from ultralytics.utils import TQDM
12
+ from ultralytics.utils import TQDM, checks
13
13
 
14
14
 
15
15
  class FastSAMPrompt:
@@ -33,9 +33,7 @@ class FastSAMPrompt:
33
33
  try:
34
34
  import clip
35
35
  except ImportError:
36
- from ultralytics.utils.checks import check_requirements
37
-
38
- check_requirements("git+https://github.com/ultralytics/CLIP.git")
36
+ checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
39
37
  import clip
40
38
  self.clip = clip
41
39
 
@@ -115,7 +113,8 @@ class FastSAMPrompt:
115
113
  points (list, optional): Points to be plotted. Defaults to None.
116
114
  point_label (list, optional): Labels for the points. Defaults to None.
117
115
  mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
118
- better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True.
116
+ better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
117
+ Defaults to True.
119
118
  retina (bool, optional): Whether to use retina mask. Defaults to False.
120
119
  with_contours (bool, optional): Whether to plot contours. Defaults to True.
121
120
  """
@@ -45,7 +45,7 @@ class NAS(Model):
45
45
 
46
46
  def __init__(self, model="yolo_nas_s.pt") -> None:
47
47
  """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
48
- assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
48
+ assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
49
49
  super().__init__(model, task="detect")
50
50
 
51
51
  @smart_inference_mode()
@@ -41,7 +41,7 @@ class SAM(Model):
41
41
  Raises:
42
42
  NotImplementedError: If the model file extension is not .pt or .pth.
43
43
  """
44
- if model and Path(model).suffix not in (".pt", ".pth"):
44
+ if model and Path(model).suffix not in {".pt", ".pth"}:
45
45
  raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
46
46
  super().__init__(model=model, task="segment")
47
47
 
@@ -112,7 +112,7 @@ class PatchMerging(nn.Module):
112
112
  self.out_dim = out_dim
113
113
  self.act = activation()
114
114
  self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
115
- stride_c = 1 if out_dim in [320, 448, 576] else 2
115
+ stride_c = 1 if out_dim in {320, 448, 576} else 2
116
116
  self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
117
117
  self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
118
118
 
@@ -68,7 +68,7 @@ class ClassificationTrainer(BaseTrainer):
68
68
  self.model, ckpt = attempt_load_one_weight(model, device="cpu")
69
69
  for p in self.model.parameters():
70
70
  p.requires_grad = True # for training
71
- elif model.split(".")[-1] in ("yaml", "yml"):
71
+ elif model.split(".")[-1] in {"yaml", "yml"}:
72
72
  self.model = self.get_model(cfg=model)
73
73
  elif model in torchvision.models.__dict__:
74
74
  self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
@@ -44,7 +44,7 @@ class DetectionTrainer(BaseTrainer):
44
44
 
45
45
  def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
46
46
  """Construct and return dataloader."""
47
- assert mode in ["train", "val"]
47
+ assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
48
48
  with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
49
49
  dataset = self.build_dataset(dataset_path, mode, batch_size)
50
50
  shuffle = mode == "train"