ultralytics 8.3.98__py3-none-any.whl → 8.3.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. tests/test_python.py +56 -0
  2. ultralytics/__init__.py +3 -2
  3. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  4. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  5. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  6. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  7. ultralytics/data/augment.py +101 -5
  8. ultralytics/data/dataset.py +165 -12
  9. ultralytics/engine/exporter.py +4 -3
  10. ultralytics/engine/trainer.py +16 -7
  11. ultralytics/models/__init__.py +2 -2
  12. ultralytics/models/yolo/__init__.py +3 -3
  13. ultralytics/models/yolo/detect/val.py +6 -1
  14. ultralytics/models/yolo/model.py +182 -3
  15. ultralytics/models/yolo/segment/val.py +43 -16
  16. ultralytics/models/yolo/yoloe/__init__.py +21 -0
  17. ultralytics/models/yolo/yoloe/predict.py +170 -0
  18. ultralytics/models/yolo/yoloe/train.py +355 -0
  19. ultralytics/models/yolo/yoloe/train_seg.py +141 -0
  20. ultralytics/models/yolo/yoloe/val.py +187 -0
  21. ultralytics/nn/autobackend.py +3 -2
  22. ultralytics/nn/modules/__init__.py +18 -1
  23. ultralytics/nn/modules/block.py +17 -1
  24. ultralytics/nn/modules/head.py +359 -22
  25. ultralytics/nn/tasks.py +276 -10
  26. ultralytics/nn/text_model.py +193 -0
  27. ultralytics/utils/callbacks/comet.py +3 -6
  28. ultralytics/utils/downloads.py +6 -2
  29. ultralytics/utils/loss.py +67 -6
  30. ultralytics/utils/plotting.py +1 -1
  31. ultralytics/utils/tal.py +1 -1
  32. {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/METADATA +10 -10
  33. {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/RECORD +37 -27
  34. {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
  35. {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
  36. {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
  37. {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
ultralytics/data/dataset.py
@@ -13,7 +13,7 @@ from PIL import Image
 from torch.utils.data import ConcatDataset

 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
-from ultralytics.utils.ops import resample_segments
+from ultralytics.utils.ops import resample_segments, segments2boxes
 from ultralytics.utils.torch_utils import TORCHVISION_0_18

 from .augment import (
@@ -27,6 +27,7 @@ from .augment import (
     v8_transforms,
 )
 from .base import BaseDataset
+from .converter import merge_multi_segment
 from .utils import (
     HELP_URL,
     LOGGER,
@@ -289,12 +290,15 @@ class YOLODataset(BaseDataset):
             (dict): Collated batch with stacked tensors.
         """
         new_batch = {}
+        batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
         keys = batch[0].keys()
         values = list(zip(*[list(b.values()) for b in batch]))
         for i, k in enumerate(keys):
             value = values[i]
-            if k == "img":
+            if k == "img" or k == "text_feats":
                 value = torch.stack(value, 0)
+            elif k == "visuals":
+                value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
             if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                 value = torch.cat(value, 0)
             new_batch[k] = value
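To make the new collate path concrete, here is a minimal standalone sketch (keys and shapes are illustrative, not the trainer's real batch layout): same-shape entries such as `img` and the new `text_feats` are stacked, while ragged `visuals` are zero-padded to the longest sample in the batch. Sorting keys first guards against samples emitting dict keys in different orders, e.g. across concatenated datasets.

import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples with equal-shape images but different numbers of visual prompts
batch = [
    {"img": torch.zeros(3, 640, 640), "visuals": torch.zeros(2, 4)},
    {"img": torch.zeros(3, 640, 640), "visuals": torch.zeros(5, 4)},
]

imgs = torch.stack([b["img"] for b in batch], 0)  # same shape -> stack: (2, 3, 640, 640)
visuals = pad_sequence([b["visuals"] for b in batch], batch_first=True)  # (2, 5, 4), shorter sample zero-padded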
@@ -346,7 +350,9 @@ class YOLOMultiModalDataset(YOLODataset):
         """
         labels = super().update_labels_info(label)
         # NOTE: some categories are concatenated with its synonyms by `/`.
+        # NOTE: and `RandomLoadText` would randomly select one of them if there are multiple words.
         labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+
         return labels

     def build_transforms(self, hyp=None):
@@ -362,9 +368,46 @@ class YOLOMultiModalDataset(YOLODataset):
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=min(self.data["nc"], 80),
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms

+    @property
+    def category_names(self):
+        """
+        Return category names for the dataset.
+
+        Returns:
+            (Tuple[str]): List of class names.
+        """
+        names = self.data["names"].values()
+        return {n.strip() for name in names for n in name.split("/")}  # category names
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        texts = [v.split("/") for v in self.data["names"].values()]
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for c in label["cls"]:  # to check
+                text = texts[int(c)]
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+

 class GroundingDataset(YOLODataset):
     """
@@ -386,17 +429,17 @@ class GroundingDataset(YOLODataset):
         >>> len(dataset)  # Number of valid images with annotations
     """

-    def __init__(self, *args, task="detect", json_file, **kwargs):
+    def __init__(self, *args, task="detect", json_file="", **kwargs):
         """
         Initialize a GroundingDataset for object detection.

         Args:
             json_file (str): Path to the JSON file containing annotations.
-            task (str): Must be 'detect' for GroundingDataset.
+            task (str): Must be 'detect' or 'segment' for GroundingDataset.
             *args (Any): Additional positional arguments for the parent class.
             **kwargs (Any): Additional keyword arguments for the parent class.
         """
-        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
+        assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
         super().__init__(*args, task=task, data={}, **kwargs)
@@ -412,14 +455,31 @@ class GroundingDataset(YOLODataset):
         """
         return []

-    def get_labels(self):
+    def verify_labels(self, labels):
+        """Verify the number of instances in the dataset matches expected counts."""
+        instance_count = sum(label["bboxes"].shape[0] for label in labels)
+        if "final_mixed_train_no_coco_segm" in self.json_file:
+            assert instance_count == 3662344
+        elif "final_mixed_train_no_coco" in self.json_file:
+            assert instance_count == 3681235
+        elif "final_flickr_separateGT_train_segm" in self.json_file:
+            assert instance_count == 638214
+        elif "final_flickr_separateGT_train" in self.json_file:
+            assert instance_count == 640704
+        else:
+            assert False
+
+    def cache_labels(self, path=Path("./labels.cache")):
         """
         Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.

+        Args:
+            path (Path): Path where to save the cache file.
+
         Returns:
-            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+            (dict): Dictionary containing cached labels and related information.
         """
-        labels = []
+        x = {"labels": []}
         LOGGER.info("Loading annotation file...")
         with open(self.json_file) as f:
             annotations = json.load(f)
@@ -435,6 +495,7 @@ class GroundingDataset(YOLODataset):
                 continue
             self.im_files.append(str(im_file))
             bboxes = []
+            segments = []
             cat2id = {}
             texts = []
             for ann in anns:
@@ -448,7 +509,10 @@ class GroundingDataset(YOLODataset):
                     continue

                 caption = img["caption"]
-                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])
+                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip()
+                if not cat_name:
+                    continue
+
                 if cat_name not in cat2id:
                     cat2id[cat_name] = len(cat2id)
                     texts.append([cat_name])
@@ -456,18 +520,66 @@ class GroundingDataset(YOLODataset):
                 box = [cls] + box.tolist()
                 if box not in bboxes:
                     bboxes.append(box)
+                    if ann.get("segmentation") is not None:
+                        if len(ann["segmentation"]) == 0:
+                            segments.append(box)
+                            continue
+                        elif len(ann["segmentation"]) > 1:
+                            s = merge_multi_segment(ann["segmentation"])
+                            s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist()
+                        else:
+                            s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
+                            s = (
+                                (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32))
+                                .reshape(-1)
+                                .tolist()
+                            )
+                        s = [cls] + s
+                        segments.append(s)
             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
-            labels.append(
+
+            if segments:
+                classes = np.array([x[0] for x in segments], dtype=np.float32)
+                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in segments]  # (cls, xy1...)
+                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+            lb = np.array(lb, dtype=np.float32)
+
+            x["labels"].append(
                 {
                     "im_file": im_file,
                     "shape": (h, w),
                     "cls": lb[:, 0:1],  # n, 1
                     "bboxes": lb[:, 1:],  # n, 4
+                    "segments": segments,
                     "normalized": True,
                     "bbox_format": "xywh",
                     "texts": texts,
                 }
             )
+        x["hash"] = get_hash(self.json_file)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+        return x
+
+    def get_labels(self):
+        """
+        Load labels from cache or generate them from JSON file.
+
+        Returns:
+            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+        """
+        cache_path = Path(self.json_file).with_suffix(".cache")
+        try:
+            cache, _ = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash(self.json_file)  # identical hash
+        except (FileNotFoundError, AssertionError, AttributeError):
+            cache, _ = self.cache_labels(cache_path), False  # run cache ops
+        [cache.pop(k) for k in ("hash", "version")]  # remove items
+        labels = cache["labels"]
+        # self.verify_labels(labels)
+        self.im_files = [str(label["im_file"]) for label in labels]
+        if LOCAL_RANK in {-1, 0}:
+            LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
         return labels

     def build_transforms(self, hyp=None):
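For reference, the box rederivation that `segments2boxes` performs on the normalized polygons above amounts to a min/max bound followed by an xyxy-to-xywh conversion; a simplified single-polygon sketch (not the library function itself):

import numpy as np

def polygon_to_xywh(points: np.ndarray) -> np.ndarray:
    """Reduce an (n, 2) array of normalized xy points to one xywh box (simplified segments2boxes)."""
    x, y = points[:, 0], points[:, 1]
    x1, y1, x2, y2 = x.min(), y.min(), x.max(), y.max()
    return np.array([(x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1], dtype=np.float32)

w, h = 640, 480  # image size used for normalization, as in cache_labels
poly = np.array([[100, 50], [200, 50], [200, 150], [100, 150]], dtype=np.float32)
box = polygon_to_xywh(poly / np.array([w, h], dtype=np.float32))  # normalized (cx, cy, bw, bh)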
@@ -483,9 +595,38 @@ class GroundingDataset(YOLODataset):
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=80,
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms

+    @property
+    def category_names(self):
+        """Return unique category names from the dataset."""
+        return {t.strip() for label in self.labels for text in label["texts"] for t in text}
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for text in label["texts"]:
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+

 class YOLOConcatDataset(ConcatDataset):
     """
@@ -516,6 +657,18 @@ class YOLOConcatDataset(ConcatDataset):
         """
         return YOLODataset.collate_fn(batch)

+    def close_mosaic(self, hyp):
+        """
+        Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
+
+        Args:
+            hyp (dict): Hyperparameters for transforms.
+        """
+        for dataset in self.datasets:
+            if not hasattr(dataset, "close_mosaic"):
+                continue
+            dataset.close_mosaic(hyp)
+

 # TODO: support semantic segmentation
 class SemanticDataset(BaseDataset):

ultralytics/engine/exporter.py
@@ -327,6 +327,7 @@ class Exporter:
                 "See https://docs.ultralytics.com/models/yolo-world for details."
             )
             model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
+
         if self.args.int8 and not self.args.data:
             self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
             LOGGER.warning(
@@ -635,7 +636,7 @@ class Exporter:
         # Generate calibration data for integer quantization
         ignored_scope = None
         if isinstance(self.model.model[-1], Detect):
-            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
             head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
             ignored_scope = nncf.IgnoredScope(  # ignore operations
                 patterns=[
@@ -797,12 +798,12 @@ class Exporter:
             LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
         # TODO CoreML Segment and Pose model pipelining
         model = self.model
-
         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
         ct_model = ct.convert(
             ts,
-            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
+            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],  # expects ct.TensorType
             classifier_config=classifier_config,
+            minimum_deployment_target=ct.target.iOS16,
             convert_to="neuralnetwork" if mlmodel else "mlprogram",
         )
         bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)

ultralytics/engine/trainer.py
@@ -249,6 +249,7 @@ class BaseTrainer:
             )
         always_freeze_names = [".dfl"]  # always freeze these layers
         freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        self.freeze_layer_names = freeze_layer_names
         for k, v in self.model.named_parameters():
             # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
             if any(x in k for x in freeze_layer_names):
@@ -350,7 +351,7 @@ class BaseTrainer:
                 warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
                 self.scheduler.step()

-            self.model.train()
+            self._model_train()
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
@@ -381,7 +382,8 @@ class BaseTrainer:
                 # Forward
                 with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    self.loss, self.loss_items = self.model(batch)
+                    loss, self.loss_items = self.model(batch)
+                    self.loss = loss.sum()
                     if RANK != -1:
                         self.loss *= world_size
                     self.tloss = (
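The `.sum()` above lets the trainer accept either a 0-dim loss (existing models) or a per-component loss vector (as the new YOLOE trainers return) through one code path; a quick sketch of why both reduce cleanly:

import torch

scalar_loss = torch.tensor(2.6, requires_grad=True)  # old behavior: already a scalar
vector_loss = torch.tensor([1.5, 0.8, 0.3], requires_grad=True)  # new: per-component losses

assert scalar_loss.sum().shape == vector_loss.sum().shape == torch.Size([])
vector_loss.sum().backward()  # gradients still flow to every component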
@@ -496,9 +498,7 @@ class BaseTrainer:
             memory = torch.mps.driver_allocated_memory()
             if fraction:
                 return __import__("psutil").virtual_memory().percent / 100
-        elif self.device.type == "cpu":
-            pass
-        else:
+        elif self.device.type != "cpu":
             memory = torch.cuda.memory_reserved()
             if fraction:
                 total = torch.cuda.get_device_properties(self.device).total_memory
@@ -520,6 +520,14 @@ class BaseTrainer:

         return pd.read_csv(self.csv).to_dict(orient="list")

+    def _model_train(self):
+        """Set model in training mode."""
+        self.model.train()
+        # Freeze BN stat
+        for n, m in self.model.named_modules():
+            if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
         import io
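Why `_model_train` exists: `model.train()` recursively puts every submodule back in training mode, so frozen BatchNorm layers would silently keep updating their running statistics. A self-contained sketch with a made-up frozen-layer name:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
freeze_layer_names = ["1"]  # pretend submodule "1" (the BN layer) is frozen

model.train()  # flips ALL submodules back to training mode
for n, m in model.named_modules():
    if any(f in n for f in freeze_layer_names) and isinstance(m, nn.BatchNorm2d):
        m.eval()  # keep running_mean/running_var fixed for the frozen layer

assert model[1].training is False  # BN stats stay frozen while the rest trains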
@@ -720,7 +728,7 @@ class BaseTrainer:

         # Check that resume data YAML exists, otherwise strip to force re-download of dataset
         ckpt_args = attempt_load_weights(last).args
-        if not Path(ckpt_args["data"]).exists():
+        if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
             ckpt_args["data"] = self.args.data

         resume = True
@@ -812,7 +820,8 @@ class BaseTrainer:
                 fullname = f"{module_name}.{param_name}" if module_name else param_name
                 if "bias" in fullname:  # bias (no decay)
                     g[2].append(param)
-                elif isinstance(module, bn):  # weight (no decay)
+                elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
+                    # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                     g[1].append(param)
                 else:  # weight (with decay)
                     g[0].append(param)
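For context, the three-bucket grouping this change extends, condensed into a runnable sketch (the `logit_scale` test is the new part; it exempts the learnable temperature of ContrastiveHead/BNContrastiveHead from weight decay):

import torch.nn as nn

bn = tuple(v for k, v in vars(nn).items() if "Norm" in k)  # all normalization layer classes
g = [], [], []  # decayed weights, no-decay weights, biases

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
for module_name, module in model.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        fullname = f"{module_name}.{param_name}" if module_name else param_name
        if "bias" in fullname:
            g[2].append(param)  # biases: never decayed
        elif isinstance(module, bn) or "logit_scale" in fullname:
            g[1].append(param)  # norm weights and contrastive temperatures: no decay
        else:
            g[0].append(param)  # ordinary weights: decayed

assert (len(g[0]), len(g[1]), len(g[2])) == (1, 1, 2)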

ultralytics/models/__init__.py
@@ -4,6 +4,6 @@ from .fastsam import FastSAM
 from .nas import NAS
 from .rtdetr import RTDETR
 from .sam import SAM
-from .yolo import YOLO, YOLOWorld
+from .yolo import YOLO, YOLOE, YOLOWorld

-__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld"  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "YOLOE"  # allow simpler import

ultralytics/models/yolo/__init__.py
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe

-from .model import YOLO, YOLOWorld
+from .model import YOLO, YOLOE, YOLOWorld

-__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"

ultralytics/models/yolo/detect/val.py
@@ -455,8 +455,13 @@ class DetectionValidator(BaseValidator):
             val.print_results()  # explicitly call print_results
             # update mAP50-95 and mAP50
             stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
-                val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
+                val.stats[:2] if self.is_coco else [val.results["AP"], val.results["AP50"]]
             )
+            if self.is_lvis:
+                stats["metrics/APr(B)"] = val.results["APr"]
+                stats["metrics/APc(B)"] = val.results["APc"]
+                stats["metrics/APf(B)"] = val.results["APf"]
+                stats["fitness"] = val.results["AP"]
         except Exception as e:
             LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats

ultralytics/models/yolo/model.py
@@ -4,7 +4,16 @@ from pathlib import Path

 from ultralytics.engine.model import Model
 from ultralytics.models import yolo
-from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
+from ultralytics.nn.tasks import (
+    ClassificationModel,
+    DetectionModel,
+    OBBModel,
+    PoseModel,
+    SegmentationModel,
+    WorldModel,
+    YOLOEModel,
+    YOLOESegModel,
+)
 from ultralytics.utils import ROOT, yaml_load

@@ -12,12 +21,16 @@ class YOLO(Model):
     """YOLO (You Only Look Once) object detection model."""

     def __init__(self, model="yolo11n.pt", task=None, verbose=False):
-        """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
+        """Initialize YOLO model, switching to YOLOWorld/YOLOE if model filename contains '-world'/'yoloe'."""
         path = Path(model)
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
             self.__class__ = type(new_instance)
             self.__dict__ = new_instance.__dict__
+        elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOE PyTorch model
+            new_instance = YOLOE(path, task=task, verbose=verbose)
+            self.__class__ = type(new_instance)
+            self.__dict__ = new_instance.__dict__
         else:
             # Continue with default YOLO initialization
             super().__init__(model=model, task=task, verbose=verbose)
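In effect, constructing `YOLO` with a filename whose stem contains "yoloe" now transparently rebinds the instance; a short usage sketch (the checkpoint name follows the release's naming convention and is fetched on first use):

from ultralytics import YOLO, YOLOE

model = YOLO("yoloe-v8s-seg.pt")  # stem contains "yoloe", so __init__ swaps the class
assert isinstance(model, YOLOE)  # the returned object is a YOLOE, not a plain YOLO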
@@ -96,7 +109,7 @@ class YOLOWorld(Model):
         Set the model's class names for detection.

         Args:
-            classes (List(str)): A list of categories i.e. ["person"].
+            classes (list[str]): A list of categories i.e. ["person"].
         """
         self.model.set_classes(classes)
         # Remove background if it's given
@@ -108,3 +121,169 @@ class YOLOWorld(Model):
         # Reset method class names
         if self.predictor:
             self.predictor.model.names = classes
+
+
+class YOLOE(Model):
+    """YOLOE object detection and segmentation model."""
+
+    def __init__(self, model="yoloe-v8s-seg.pt", task=None, verbose=False) -> None:
+        """
+        Initialize YOLOE model with a pre-trained model file.
+
+        Args:
+            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+            task (str, optional): Task type for the model. Auto-detected if None.
+            verbose (bool): If True, prints additional information during initialization.
+        """
+        super().__init__(model=model, task=task, verbose=verbose)
+
+        # Assign default COCO class names when there are no custom names
+        if not hasattr(self.model, "names"):
+            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
+
+    @property
+    def task_map(self):
+        """Map head to model, validator, and predictor classes."""
+        return {
+            "detect": {
+                "model": YOLOEModel,
+                "validator": yolo.yoloe.YOLOEDetectValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.yoloe.YOLOETrainer,
+            },
+            "segment": {
+                "model": YOLOESegModel,
+                "validator": yolo.yoloe.YOLOESegValidator,
+                "predictor": yolo.segment.SegmentationPredictor,
+                "trainer": yolo.yoloe.YOLOESegTrainer,
+            },
+        }
+
+    def get_text_pe(self, texts):
+        """Get text positional embeddings for the given texts."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_text_pe(texts)
+
+    def get_visual_pe(self, img, visual):
+        """Get visual positional embeddings for the given image and visual features."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_visual_pe(img, visual)
+
+    def set_vocab(self, vocab, names):
+        """Set vocabulary and class names for the model."""
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_vocab(vocab, names=names)
+
+    def get_vocab(self, names):
+        """Get vocabulary for the given class names."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_vocab(names)
+
+    def set_classes(self, classes, embeddings):
+        """
+        Set the model's class names and embeddings for detection.
+
+        Args:
+            classes (list[str]): A list of categories i.e. ["person"].
+            embeddings (torch.Tensor): Embeddings corresponding to the classes.
+        """
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_classes(classes, embeddings)
+        # Verify no background class is present
+        assert " " not in classes
+        self.model.names = classes
+
+        # Reset method class names
+        if self.predictor:
+            self.predictor.model.names = classes
+
+    def val(
+        self,
+        validator=None,
+        load_vp=False,
+        refer_data=None,
+        **kwargs,
+    ):
+        """
+        Validate the model using text or visual prompts.
+
+        Args:
+            validator (callable, optional): A callable validator function. If None, a default validator is loaded.
+            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
+            refer_data (str, optional): Path to the reference data for visual prompts.
+            **kwargs (Any): Additional keyword arguments to override default settings.
+
+        Returns:
+            (dict): Validation statistics containing metrics computed during validation.
+        """
+        custom = {"rect": not load_vp}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # highest priority args on the right
+
+        validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
+        validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
+        self.metrics = validator.metrics
+        return validator.metrics
+
+    def predict(
+        self,
+        source=None,
+        stream: bool = False,
+        visual_prompts: dict = {},
+        refer_image=None,
+        predictor=None,
+        **kwargs,
+    ):
+        """
+        Run prediction on images, videos, directories, streams, etc.
+
+        Args:
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
+                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
+                generator as they are computed.
+            visual_prompts (dict): Dictionary containing visual prompts for the model. Must include 'bboxes' and
+                'cls' keys when non-empty.
+            refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
+                loaded based on the task.
+            **kwargs (Any): Additional keyword arguments passed to the predictor.
+
+        Returns:
+            (List | generator): List of Results objects or generator of Results objects if stream=True.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s-seg.pt")
+            >>> results = model.predict("path/to/image.jpg")
+            >>> # With visual prompts
+            >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+            >>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
+        """
+        if len(visual_prompts):
+            assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
+                f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
+            )
+            assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
+                f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
+                f"{len(visual_prompts['cls'])} respectively"
+            )
+        self.predictor = (predictor or self._smart_load("predictor"))(
+            overrides={"task": "segment", "mode": "predict", "save": False, "verbose": False}, _callbacks=self.callbacks
+        )
+
+        if len(visual_prompts):
+            num_cls = (
+                max(len(set(c)) for c in visual_prompts["cls"])
+                if isinstance(source, list)  # means multiple images
+                else len(set(visual_prompts["cls"]))
+            )
+            self.model.model[-1].nc = num_cls
+            self.model.names = [f"object{i}" for i in range(num_cls)]
+            self.predictor.set_prompts(visual_prompts)
+
+        self.predictor.setup_model(model=self.model)
+        if refer_image is not None and len(visual_prompts):
+            vpe = self.predictor.get_vpe(refer_image)
+            self.model.set_classes(self.model.names, vpe)
+            self.predictor = None  # reset predictor
+
+        return super().predict(source, stream, **kwargs)
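Putting the new class together, a hedged end-to-end usage sketch based only on the methods defined above (the image path is a placeholder; the visual-prompt format mirrors the docstring example):

from ultralytics import YOLOE

model = YOLOE("yoloe-v8s-seg.pt")

# Text prompts: embed class names, then bind names + embeddings to the head
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))
results = model.predict("path/to/image.jpg")

# Visual prompts: boxes and classes, validated by predict() before inference
prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
results = model.predict("path/to/image.jpg", visual_prompts=prompts)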