ultralytics 8.1.38__py3-none-any.whl → 8.1.40__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ultralytics has been flagged as potentially problematic.

Files changed (58)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +3 -3
  3. ultralytics/cfg/datasets/lvis.yaml +1239 -0
  4. ultralytics/data/__init__.py +18 -2
  5. ultralytics/data/augment.py +124 -3
  6. ultralytics/data/base.py +2 -2
  7. ultralytics/data/build.py +25 -3
  8. ultralytics/data/converter.py +24 -6
  9. ultralytics/data/dataset.py +142 -27
  10. ultralytics/data/loaders.py +11 -8
  11. ultralytics/data/split_dota.py +1 -1
  12. ultralytics/data/utils.py +33 -8
  13. ultralytics/engine/exporter.py +3 -3
  14. ultralytics/engine/model.py +6 -3
  15. ultralytics/engine/results.py +2 -2
  16. ultralytics/engine/trainer.py +59 -55
  17. ultralytics/engine/validator.py +2 -2
  18. ultralytics/hub/utils.py +1 -1
  19. ultralytics/models/fastsam/model.py +1 -1
  20. ultralytics/models/fastsam/prompt.py +4 -5
  21. ultralytics/models/nas/model.py +1 -1
  22. ultralytics/models/sam/model.py +1 -1
  23. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  24. ultralytics/models/yolo/__init__.py +2 -2
  25. ultralytics/models/yolo/classify/train.py +1 -1
  26. ultralytics/models/yolo/detect/train.py +1 -1
  27. ultralytics/models/yolo/detect/val.py +36 -17
  28. ultralytics/models/yolo/model.py +1 -0
  29. ultralytics/models/yolo/world/__init__.py +5 -0
  30. ultralytics/models/yolo/world/train.py +92 -0
  31. ultralytics/models/yolo/world/train_world.py +108 -0
  32. ultralytics/nn/autobackend.py +5 -5
  33. ultralytics/nn/modules/block.py +4 -2
  34. ultralytics/nn/modules/conv.py +1 -1
  35. ultralytics/nn/modules/head.py +13 -4
  36. ultralytics/nn/tasks.py +30 -14
  37. ultralytics/solutions/ai_gym.py +1 -1
  38. ultralytics/solutions/heatmap.py +85 -47
  39. ultralytics/solutions/object_counter.py +79 -64
  40. ultralytics/trackers/byte_tracker.py +1 -1
  41. ultralytics/trackers/track.py +1 -1
  42. ultralytics/trackers/utils/gmc.py +1 -1
  43. ultralytics/utils/__init__.py +4 -4
  44. ultralytics/utils/benchmarks.py +2 -2
  45. ultralytics/utils/callbacks/comet.py +1 -1
  46. ultralytics/utils/callbacks/mlflow.py +1 -1
  47. ultralytics/utils/checks.py +3 -3
  48. ultralytics/utils/downloads.py +2 -2
  49. ultralytics/utils/loss.py +1 -1
  50. ultralytics/utils/metrics.py +1 -1
  51. ultralytics/utils/plotting.py +36 -22
  52. ultralytics/utils/torch_utils.py +17 -3
  53. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/METADATA +1 -1
  54. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/RECORD +58 -54
  55. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/LICENSE +0 -0
  56. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/WHEEL +0 -0
  57. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/entry_points.txt +0 -0
  58. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/top_level.txt +0 -0
ultralytics/data/utils.py CHANGED
@@ -29,6 +29,7 @@ from ultralytics.utils import (
     emojis,
     yaml_load,
     yaml_save,
+    is_dir_writeable,
 )
 from ultralytics.utils.checks import check_file, check_font, is_ascii
 from ultralytics.utils.downloads import download, safe_download, unzip_file
@@ -38,6 +39,7 @@ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatt
 IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"}  # image suffixes
 VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
 PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
+FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"


 def img2label_paths(img_paths):
@@ -62,7 +64,7 @@ def exif_size(img: Image.Image):
     exif = img.getexif()
     if exif:
         rotation = exif.get(274, None)  # the EXIF key for the orientation tag is 274
-        if rotation in [6, 8]:  # rotation 270 or 90
+        if rotation in {6, 8}:  # rotation 270 or 90
             s = s[1], s[0]
     return s

@@ -78,8 +80,8 @@ def verify_image(args):
         shape = exif_size(im)  # image size
         shape = (shape[1], shape[0])  # hw
         assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
-        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
-        if im.format.lower() in ("jpg", "jpeg"):
+        assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
+        if im.format.lower() in {"jpg", "jpeg"}:
             with open(im_file, "rb") as f:
                 f.seek(-2, 2)
                 if f.read() != b"\xff\xd9":  # corrupt JPEG
@@ -104,8 +106,8 @@ def verify_image_label(args):
         shape = exif_size(im)  # image size
         shape = (shape[1], shape[0])  # hw
         assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
-        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
-        if im.format.lower() in ("jpg", "jpeg"):
+        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
+        if im.format.lower() in {"jpg", "jpeg"}:
             with open(im_file, "rb") as f:
                 f.seek(-2, 2)
                 if f.read() != b"\xff\xd9":  # corrupt JPEG
@@ -303,7 +305,7 @@ def check_det_dataset(dataset, autodownload=True):

     # Set paths
     data["path"] = path  # download scripts
-    for k in "train", "val", "test":
+    for k in "train", "val", "test", "minival":
         if data.get(k):  # prepend path
             if isinstance(data[k], str):
                 x = (path / data[k]).resolve()
@@ -335,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True):
         else:  # python script
             exec(s, {"yaml": data})
         dt = f"({round(time.time() - t, 1)}s)"
-        s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
+        s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
         LOGGER.info(f"Dataset download {s}\n")
     check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

@@ -365,7 +367,7 @@ def check_cls_dataset(dataset, split=""):
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
-    elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
+    elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

@@ -649,3 +651,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
             with open(path.parent / txt[i], "a") as f:
                 f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file
+
+
+def load_dataset_cache_file(path):
+    """Load an Ultralytics *.cache dictionary from path."""
+    import gc
+
+    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
+    cache = np.load(str(path), allow_pickle=True).item()  # load dict
+    gc.enable()
+    return cache
+
+
+def save_dataset_cache_file(prefix, path, x, version):
+    """Save an Ultralytics dataset *.cache dictionary x to path."""
+    x["version"] = version  # add cache version
+    if is_dir_writeable(path.parent):
+        if path.exists():
+            path.unlink()  # remove *.cache file if exists
+        np.save(str(path), x)  # save cache for next time
+        path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
+        LOGGER.info(f"{prefix}New cache created: {path}")
+    else:
+        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
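For context, a minimal sketch of how the two new cache helpers round-trip a labels cache, assuming a writeable dataset directory; the path, prefix, and dict contents below are illustrative, not taken from this release:

from pathlib import Path

from ultralytics.data.utils import load_dataset_cache_file, save_dataset_cache_file

cache_path = Path("datasets/coco8/labels/train.cache")  # hypothetical cache location
x = {"labels": [], "hash": "abc123"}  # stand-in for real cache contents

# np.save() writes train.cache.npy, which is then renamed back to train.cache
save_dataset_cache_file("train: ", cache_path, x, version="1.0")
cache = load_dataset_cache_file(cache_path)  # gc is disabled around np.load for speed
assert cache["version"] == "1.0"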
ultralytics/engine/exporter.py CHANGED
@@ -159,7 +159,7 @@ class Exporter:
             _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        if self.args.format.lower() in ("coreml", "mlmodel"):  # fix attempt for protobuf<3.20.x errors
+        if self.args.format.lower() in {"coreml", "mlmodel"}:  # fix attempt for protobuf<3.20.x errors
             os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"  # must run before TensorBoard callback

         self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -171,9 +171,9 @@ class Exporter:
         self.run_callbacks("on_export_start")
         t = time.time()
         fmt = self.args.format.lower()  # to lowercase
-        if fmt in ("tensorrt", "trt"):  # 'engine' aliases
+        if fmt in {"tensorrt", "trt"}:  # 'engine' aliases
             fmt = "engine"
-        if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"):  # 'coreml' aliases
+        if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}:  # 'coreml' aliases
             fmt = "coreml"
         fmts = tuple(export_formats()["Argument"][1:])  # available export formats
         flags = [x == fmt for x in fmts]
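The alias handling above is a small normalization step before format dispatch; a standalone sketch of the same logic (not the Exporter API itself):

def normalize_format(fmt: str) -> str:
    """Map user-facing export-format aliases to their canonical names."""
    fmt = fmt.lower()
    if fmt in {"tensorrt", "trt"}:  # 'engine' aliases
        return "engine"
    if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}:  # 'coreml' aliases
        return "coreml"
    return fmt


assert normalize_format("TRT") == "engine"
assert normalize_format("ios") == "coreml"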
ultralytics/engine/model.py CHANGED
@@ -145,7 +145,7 @@ class Model(nn.Module):
             return

         # Load or create new YOLO model
-        if Path(model).suffix in (".yaml", ".yml"):
+        if Path(model).suffix in {".yaml", ".yml"}:
             self._new(model, task=task, verbose=verbose)
         else:
             self._load(model, task=task)
@@ -666,7 +666,7 @@ class Model(nn.Module):
         self.trainer.hub_session = self.session  # attach optional HUB session
         self.trainer.train()
         # Update model and cfg after training
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
             self.model, _ = attempt_load_one_weight(ckpt)
             self.overrides = self.model.args
@@ -733,7 +733,10 @@ class Model(nn.Module):
         """
         from ultralytics.nn.autobackend import check_class_names

-        return check_class_names(self.model.names) if hasattr(self.model, "names") else None
+        if hasattr(self.model, "names"):
+            return check_class_names(self.model.names)
+        elif self.predictor:
+            return self.predictor.model.names

     @property
     def device(self) -> torch.device:
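The new `names` fallback matters mainly for weights that carry no `names` attribute themselves, such as some exported formats; a hedged usage sketch (file names are illustrative):

from ultralytics import YOLO

model = YOLO("yolov8n.onnx")  # hypothetical exported model without a .names attribute
model.predict("bus.jpg")  # populates model.predictor
print(model.names)  # previously None here; now resolvable via model.predictor.model.names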
ultralytics/engine/results.py CHANGED
@@ -470,7 +470,7 @@ class Boxes(BaseTensor):
         if boxes.ndim == 1:
             boxes = boxes[None, :]
         n = boxes.shape[-1]
-        assert n in (6, 7), f"expected 6 or 7 values but got {n}"  # xyxy, track_id, conf, cls
+        assert n in {6, 7}, f"expected 6 or 7 values but got {n}"  # xyxy, track_id, conf, cls
         super().__init__(boxes, orig_shape)
         self.is_track = n == 7
         self.orig_shape = orig_shape
@@ -687,7 +687,7 @@ class OBB(BaseTensor):
         if boxes.ndim == 1:
             boxes = boxes[None, :]
         n = boxes.shape[-1]
-        assert n in (7, 8), f"expected 7 or 8 values but got {n}"  # xywh, rotation, track_id, conf, cls
+        assert n in {7, 8}, f"expected 7 or 8 values but got {n}"  # xywh, rotation, track_id, conf, cls
         super().__init__(boxes, orig_shape)
         self.is_track = n == 8
         self.orig_shape = orig_shape
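The set-literal change here is cosmetic; the asserts document the two accepted tensor layouts. A quick sketch with dummy values, assuming the public `Boxes` constructor:

import torch

from ultralytics.engine.results import Boxes

no_track = torch.tensor([[0.0, 0.0, 10.0, 10.0, 0.9, 0.0]])  # n == 6: xyxy, conf, cls
with_track = torch.tensor([[0.0, 0.0, 10.0, 10.0, 1.0, 0.9, 0.0]])  # n == 7: xyxy, track_id, conf, cls

assert not Boxes(no_track, orig_shape=(640, 640)).is_track
assert Boxes(with_track, orig_shape=(640, 640)).is_track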
ultralytics/engine/trainer.py CHANGED
@@ -42,7 +42,7 @@ from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
     EarlyStopping,
     ModelEMA,
-    de_parallel,
+    convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
     select_device,
@@ -107,7 +107,7 @@ class BaseTrainer:
         self.save_dir = get_save_dir(self.args)
         self.args.name = self.save_dir.name  # update name for loggers
         self.wdir = self.save_dir / "weights"  # weights dir
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
             self.args.save_dir = str(self.save_dir)
             yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
@@ -121,27 +121,12 @@ class BaseTrainer:
             print_args(vars(self.args))

         # Device
-        if self.device.type in ("cpu", "mps"):
+        if self.device.type in {"cpu", "mps"}:
             self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

         # Model and Dataset
         self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
-        try:
-            if self.args.task == "classify":
-                self.data = check_cls_dataset(self.args.data)
-            elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
-                "detect",
-                "segment",
-                "pose",
-                "obb",
-            ):
-                self.data = check_det_dataset(self.args.data)
-                if "yaml_file" in self.data:
-                    self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
-        except Exception as e:
-            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
-
-        self.trainset, self.testset = self.get_dataset(self.data)
+        self.trainset, self.testset = self.get_dataset()
         self.ema = None

         # Optimization utils init
@@ -159,7 +144,7 @@ class BaseTrainer:

         # Callbacks
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             callbacks.add_integration_callbacks(self)

     def add_callback(self, event: str, callback):
@@ -225,7 +210,7 @@ class BaseTrainer:
         torch.cuda.set_device(RANK)
         self.device = torch.device("cuda", RANK)
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
-        os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
+        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
             "nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
@@ -266,7 +251,7 @@ class BaseTrainer:

         # Check AMP
         self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
-        if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
+        if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
             callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
             self.amp = torch.tensor(check_amp(self.model), device=self.device)
             callbacks.default_callbacks = callbacks_backup  # restore callbacks
@@ -289,7 +274,7 @@ class BaseTrainer:
         # Dataloaders
         batch_size = self.batch_size // max(world_size, 1)
         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
             # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
             self.test_loader = self.get_dataloader(
                 self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
@@ -355,7 +340,7 @@ class BaseTrainer:
                 self._close_dataloader_mosaic()
                 self.train_loader.reset()

-            if RANK in (-1, 0):
+            if RANK in {-1, 0}:
                 LOGGER.info(self.progress_string())
                 pbar = TQDM(enumerate(self.train_loader), total=nb)
             self.tloss = None
@@ -407,7 +392,7 @@ class BaseTrainer:
                 mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
                 loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
                 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
-                if RANK in (-1, 0):
+                if RANK in {-1, 0}:
                     pbar.set_description(
                         ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                         % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
@@ -420,7 +405,7 @@ class BaseTrainer:

             self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
             self.run_callbacks("on_train_epoch_end")
-            if RANK in (-1, 0):
+            if RANK in {-1, 0}:
                 final_epoch = epoch + 1 >= self.epochs
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

@@ -462,7 +447,7 @@ class BaseTrainer:
                 break  # must break all DDP ranks
             epoch += 1

-        if RANK in (-1, 0):
+        if RANK in {-1, 0}:
            # Do final val with best.pt
            LOGGER.info(
                f"\n{epoch - self.start_epoch + 1} epochs completed in "
@@ -477,40 +462,59 @@ class BaseTrainer:

     def save_model(self):
         """Save model training checkpoints with additional metadata."""
+        import io
         import pandas as pd  # scope for faster startup

-        metrics = {**self.metrics, **{"fitness": self.fitness}}
-        results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
-        ckpt = {
-            "epoch": self.epoch,
-            "best_fitness": self.best_fitness,
-            "model": deepcopy(de_parallel(self.model)).half(),
-            "ema": deepcopy(self.ema.ema).half(),
-            "updates": self.ema.updates,
-            "optimizer": self.optimizer.state_dict(),
-            "train_args": vars(self.args),  # save as dict
-            "train_metrics": metrics,
-            "train_results": results,
-            "date": datetime.now().isoformat(),
-            "version": __version__,
-            "license": "AGPL-3.0 (https://ultralytics.com/license)",
-            "docs": "https://docs.ultralytics.com",
-        }
-
-        # Save last and best
-        torch.save(ckpt, self.last)
+        # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
+        buffer = io.BytesIO()
+        torch.save(
+            {
+                "epoch": self.epoch,
+                "best_fitness": self.best_fitness,
+                "model": None,  # resume and final checkpoints derive from EMA
+                "ema": deepcopy(self.ema.ema).half(),
+                "updates": self.ema.updates,
+                "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
+                "train_args": vars(self.args),  # save as dict
+                "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
+                "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
+                "date": datetime.now().isoformat(),
+                "version": __version__,
+                "license": "AGPL-3.0 (https://ultralytics.com/license)",
+                "docs": "https://docs.ultralytics.com",
+            },
+            buffer,
+        )
+        serialized_ckpt = buffer.getvalue()  # get the serialized content to save
+
+        # Save checkpoints
+        self.last.write_bytes(serialized_ckpt)  # save last.pt
         if self.best_fitness == self.fitness:
-            torch.save(ckpt, self.best)
+            self.best.write_bytes(serialized_ckpt)  # save best.pt
         if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
+            (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'

-    @staticmethod
-    def get_dataset(data):
+    def get_dataset(self):
         """
         Get train, val path from data dict if it exists.

         Returns None if data format is not recognized.
         """
+        try:
+            if self.args.task == "classify":
+                data = check_cls_dataset(self.args.data)
+            elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
+                "detect",
+                "segment",
+                "pose",
+                "obb",
+            }:
+                data = check_det_dataset(self.args.data)
+                if "yaml_file" in data:
+                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
+        except Exception as e:
+            raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
+        self.data = data
         return data["train"], data.get("val") or data.get("test")

     def setup_model(self):
@@ -522,7 +526,7 @@ class BaseTrainer:
         ckpt = None
         if str(model).endswith(".pt"):
             weights, ckpt = attempt_load_one_weight(model)
-            cfg = ckpt["model"].yaml
+            cfg = weights.yaml
         else:
             cfg = model
         self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
@@ -661,8 +665,8 @@ class BaseTrainer:
         if ckpt is None:
             return
         best_fitness = 0.0
-        start_epoch = ckpt["epoch"] + 1
-        if ckpt["optimizer"] is not None:
+        start_epoch = ckpt.get("epoch", -1) + 1
+        if ckpt.get("optimizer", None) is not None:
             self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
             best_fitness = ckpt["best_fitness"]
         if self.ema and ckpt.get("ema"):
@@ -736,7 +740,7 @@ class BaseTrainer:
             else:  # weight (with decay)
                 g[0].append(param)

-        if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
+        if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
             optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == "RMSProp":
             optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
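The `save_model()` rewrite serializes the checkpoint once and reuses the same bytes for every file it writes; a standalone sketch of the pattern with a toy payload:

import io
from pathlib import Path

import torch

buffer = io.BytesIO()
torch.save({"epoch": 3, "ema": None}, buffer)  # serialize the checkpoint dict once
serialized_ckpt = buffer.getvalue()

Path("last.pt").write_bytes(serialized_ckpt)  # every copy reuses the same bytes,
Path("best.pt").write_bytes(serialized_ckpt)  # avoiding repeated torch.save() calls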
ultralytics/engine/validator.py CHANGED
@@ -139,14 +139,14 @@ class BaseValidator:
                 self.args.batch = 1  # export.py models default to batch-size 1
                 LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

-            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
+            if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
                 self.data = check_det_dataset(self.args.data)
             elif self.args.task == "classify":
                 self.data = check_cls_dataset(self.args.data, split=self.args.split)
             else:
                 raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

-            if self.device.type in ("cpu", "mps"):
+            if self.device.type in {"cpu", "mps"}:
                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
             if not pt:
                 self.args.rect = False
ultralytics/hub/utils.py CHANGED
@@ -198,7 +198,7 @@ class Events:
         }
         self.enabled = (
             SETTINGS["sync"]
-            and RANK in (-1, 0)
+            and RANK in {-1, 0}
             and not TESTS_RUNNING
             and ONLINE
             and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
ultralytics/models/fastsam/model.py CHANGED
@@ -24,7 +24,7 @@ class FastSAM(Model):
         """Call the __init__ method of the parent class (YOLO) with the updated default model."""
         if str(model) == "FastSAM.pt":
             model = "FastSAM-x.pt"
-        assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
+        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
         super().__init__(model=model, task="segment")

     @property
ultralytics/models/fastsam/prompt.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
 import torch
 from PIL import Image

-from ultralytics.utils import TQDM
+from ultralytics.utils import TQDM, checks


 class FastSAMPrompt:
@@ -33,9 +33,7 @@ class FastSAMPrompt:
         try:
             import clip
         except ImportError:
-            from ultralytics.utils.checks import check_requirements
-
-            check_requirements("git+https://github.com/openai/CLIP.git")
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
             import clip
         self.clip = clip

@@ -115,7 +113,8 @@ class FastSAMPrompt:
             points (list, optional): Points to be plotted. Defaults to None.
             point_label (list, optional): Labels for the points. Defaults to None.
             mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
-            better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True.
+            better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
+                Defaults to True.
             retina (bool, optional): Whether to use retina mask. Defaults to False.
             with_contours (bool, optional): Whether to plot contours. Defaults to True.
         """
ultralytics/models/nas/model.py CHANGED
@@ -45,7 +45,7 @@ class NAS(Model):

     def __init__(self, model="yolo_nas_s.pt") -> None:
         """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
-        assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
+        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
         super().__init__(model, task="detect")

     @smart_inference_mode()
ultralytics/models/sam/model.py CHANGED
@@ -41,7 +41,7 @@ class SAM(Model):
         Raises:
             NotImplementedError: If the model file extension is not .pt or .pth.
         """
-        if model and Path(model).suffix not in (".pt", ".pth"):
+        if model and Path(model).suffix not in {".pt", ".pth"}:
             raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
         super().__init__(model=model, task="segment")

ultralytics/models/sam/modules/tiny_encoder.py CHANGED
@@ -112,7 +112,7 @@ class PatchMerging(nn.Module):
         self.out_dim = out_dim
         self.act = activation()
         self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
-        stride_c = 1 if out_dim in [320, 448, 576] else 2
+        stride_c = 1 if out_dim in {320, 448, 576} else 2
         self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

ultralytics/models/yolo/__init__.py CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from ultralytics.models.yolo import classify, detect, obb, pose, segment
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world

 from .model import YOLO, YOLOWorld

-__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
ultralytics/models/yolo/classify/train.py CHANGED
@@ -68,7 +68,7 @@ class ClassificationTrainer(BaseTrainer):
             self.model, ckpt = attempt_load_one_weight(model, device="cpu")
             for p in self.model.parameters():
                 p.requires_grad = True  # for training
-        elif model.split(".")[-1] in ("yaml", "yml"):
+        elif model.split(".")[-1] in {"yaml", "yml"}:
             self.model = self.get_model(cfg=model)
         elif model in torchvision.models.__dict__:
             self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
ultralytics/models/yolo/detect/train.py CHANGED
@@ -44,7 +44,7 @@ class DetectionTrainer(BaseTrainer):

     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
         """Construct and return dataloader."""
-        assert mode in ["train", "val"]
+        assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
         with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
             dataset = self.build_dataset(dataset_path, mode, batch_size)
         shuffle = mode == "train"
ultralytics/models/yolo/detect/val.py CHANGED
@@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.nt_per_class = None
         self.is_coco = False
+        self.is_lvis = False
         self.class_map = None
         self.args.task = "detect"
         self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
         """Initialize evaluation metrics for YOLO."""
         val = self.data.get(self.args.split, "")  # validation path
         self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt")  # is COCO
-        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
-        self.args.save_json |= self.is_coco and not self.training  # run on final val if training COCO
+        self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco  # is LVIS
+        self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
+        self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training  # run on final val if training COCO
         self.names = model.names
         self.nc = len(model.names)
         self.metrics.names = self.names
@@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
+                    "category_id": self.class_map[int(p[5])]
+                    + (1 if self.is_lvis else 0),  # index starts from 1 if it's lvis
                     "bbox": [round(x, 3) for x in b],
                     "score": round(p[4], 5),
                 }
@@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):

     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
-            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements("pycocotools>=2.0.6")
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
-
-                for x in anno_json, pred_json:
+                for x in pred_json, anno_json:
                     assert x.is_file(), f"{x} file not found"
-                anno = COCO(str(anno_json))  # init annotations api
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                eval = COCOeval(anno, pred, "bbox")
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
                 if self.is_coco:
-                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = COCOeval(anno, pred, "bbox")
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))  # init annotations api
+                    pred = anno._load_json(str(pred_json))  # init predictions api (must pass string, not Path)
+                    eval = LVISEval(anno, pred, "bbox")
+                eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                 eval.evaluate()
                 eval.accumulate()
                 eval.summarize()
-                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
+                if self.is_lvis:
+                    eval.print_results()  # explicitly call print_results
+                # update mAP50-95 and mAP50
+                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
+                    eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
+                )
             except Exception as e:
-                LOGGER.warning(f"pycocotools unable to run: {e}")
+                LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats
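The `category_id` offset above exists because LVIS annotation category ids are 1-based while the model's class indices are 0-based; in isolation the computation is:

class_map = list(range(80))  # identity map for a hypothetical 80-class model
is_lvis = True
pred_cls = 0  # model's first class

category_id = class_map[pred_cls] + (1 if is_lvis else 0)  # 1 for LVIS, 0 for COCO-style maps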
ultralytics/models/yolo/model.py CHANGED
@@ -83,6 +83,7 @@ class YOLOWorld(Model):
                 "model": WorldModel,
                 "validator": yolo.detect.DetectionValidator,
                 "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.world.WorldTrainer,
             }
         }

ultralytics/models/yolo/world/__init__.py ADDED
@@ -0,0 +1,5 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from .train import WorldTrainer
+
+__all__ = ["WorldTrainer"]
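With `WorldTrainer` registered in `YOLOWorld`'s task map above, training is reachable through the standard `Model.train()` path; a hedged usage sketch (weights and dataset names are assumptions, not taken from this diff):

from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")  # hypothetical pretrained YOLO-World weights
model.train(data="coco8.yaml", epochs=1)  # dispatches to yolo.world.WorldTrainer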