ultralytics 8.3.97 → 8.3.99 (py3-none-any.whl)

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (41)
  1. tests/test_python.py +56 -0
  2. ultralytics/__init__.py +3 -2
  3. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  4. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  5. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  6. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  7. ultralytics/data/augment.py +101 -5
  8. ultralytics/data/dataset.py +165 -12
  9. ultralytics/engine/exporter.py +13 -13
  10. ultralytics/engine/trainer.py +16 -7
  11. ultralytics/models/__init__.py +2 -2
  12. ultralytics/models/nas/model.py +1 -0
  13. ultralytics/models/nas/predict.py +4 -24
  14. ultralytics/models/nas/val.py +1 -4
  15. ultralytics/models/yolo/__init__.py +3 -3
  16. ultralytics/models/yolo/detect/val.py +6 -1
  17. ultralytics/models/yolo/model.py +182 -3
  18. ultralytics/models/yolo/segment/val.py +43 -16
  19. ultralytics/models/yolo/yoloe/__init__.py +21 -0
  20. ultralytics/models/yolo/yoloe/predict.py +170 -0
  21. ultralytics/models/yolo/yoloe/train.py +355 -0
  22. ultralytics/models/yolo/yoloe/train_seg.py +141 -0
  23. ultralytics/models/yolo/yoloe/val.py +187 -0
  24. ultralytics/nn/autobackend.py +3 -2
  25. ultralytics/nn/modules/__init__.py +18 -1
  26. ultralytics/nn/modules/block.py +17 -1
  27. ultralytics/nn/modules/head.py +359 -22
  28. ultralytics/nn/tasks.py +276 -10
  29. ultralytics/nn/text_model.py +193 -0
  30. ultralytics/utils/callbacks/comet.py +3 -6
  31. ultralytics/utils/downloads.py +6 -2
  32. ultralytics/utils/instance.py +7 -2
  33. ultralytics/utils/loss.py +67 -6
  34. ultralytics/utils/plotting.py +1 -1
  35. ultralytics/utils/tal.py +1 -1
  36. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/METADATA +69 -67
  37. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/RECORD +41 -31
  38. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
  39. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
  40. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
  41. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0

--- a/ultralytics/data/dataset.py
+++ b/ultralytics/data/dataset.py
@@ -13,7 +13,7 @@ from PIL import Image
 from torch.utils.data import ConcatDataset
 
 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
-from ultralytics.utils.ops import resample_segments
+from ultralytics.utils.ops import resample_segments, segments2boxes
 from ultralytics.utils.torch_utils import TORCHVISION_0_18
 
 from .augment import (
@@ -27,6 +27,7 @@ from .augment import (
     v8_transforms,
 )
 from .base import BaseDataset
+from .converter import merge_multi_segment
 from .utils import (
     HELP_URL,
     LOGGER,
@@ -289,12 +290,15 @@ class YOLODataset(BaseDataset):
             (dict): Collated batch with stacked tensors.
         """
         new_batch = {}
+        batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
         keys = batch[0].keys()
         values = list(zip(*[list(b.values()) for b in batch]))
         for i, k in enumerate(keys):
             value = values[i]
-            if k == "img":
+            if k == "img" or k == "text_feats":
                 value = torch.stack(value, 0)
+            elif k == "visuals":
+                value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
             if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                 value = torch.cat(value, 0)
             new_batch[k] = value
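In the new `collate_fn`, `img` and `text_feats` are fixed-shape per sample and can be stacked directly, while `visuals` can hold a different number of visual prompts per image, so it is zero-padded to the longest sample with `pad_sequence`. A minimal sketch of the difference, using toy shapes rather than real dataset outputs:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy "visuals": a varying number of prompts per image (3, 1, 2), feature dim 4
visuals = [torch.ones(3, 4), torch.ones(1, 4), torch.ones(2, 4)]
padded = pad_sequence(visuals, batch_first=True)  # zero-pads each sample along dim 0
print(padded.shape)  # torch.Size([3, 3, 4])

# Fixed-shape entries like "img" can simply be stacked
imgs = [torch.zeros(3, 32, 32) for _ in range(3)]
print(torch.stack(imgs, 0).shape)  # torch.Size([3, 3, 32, 32])
```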
@@ -346,7 +350,9 @@ class YOLOMultiModalDataset(YOLODataset):
         """
         labels = super().update_labels_info(label)
         # NOTE: some categories are concatenated with its synonyms by `/`.
+        # NOTE: and `RandomLoadText` would randomly select one of them if there are multiple words.
         labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+
         return labels
 
     def build_transforms(self, hyp=None):
@@ -362,9 +368,46 @@
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=min(self.data["nc"], 80),
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms
 
+    @property
+    def category_names(self):
+        """
+        Return category names for the dataset.
+
+        Returns:
+            (Tuple[str]): List of class names.
+        """
+        names = self.data["names"].values()
+        return {n.strip() for name in names for n in name.split("/")}  # category names
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        texts = [v.split("/") for v in self.data["names"].values()]
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for c in label["cls"]:  # to check
+                text = texts[int(c)]
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+
 
 class GroundingDataset(YOLODataset):
     """
@@ -386,17 +429,17 @@ class GroundingDataset(YOLODataset):
         >>> len(dataset)  # Number of valid images with annotations
     """
 
-    def __init__(self, *args, task="detect", json_file, **kwargs):
+    def __init__(self, *args, task="detect", json_file="", **kwargs):
         """
         Initialize a GroundingDataset for object detection.
 
         Args:
             json_file (str): Path to the JSON file containing annotations.
-            task (str): Must be 'detect' for GroundingDataset.
+            task (str): Must be 'detect' or 'segment' for GroundingDataset.
             *args (Any): Additional positional arguments for the parent class.
             **kwargs (Any): Additional keyword arguments for the parent class.
         """
-        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
+        assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
         super().__init__(*args, task=task, data={}, **kwargs)
 
@@ -412,14 +455,31 @@
         """
         return []
 
-    def get_labels(self):
+    def verify_labels(self, labels):
+        """Verify the number of instances in the dataset matches expected counts."""
+        instance_count = sum(label["bboxes"].shape[0] for label in labels)
+        if "final_mixed_train_no_coco_segm" in self.json_file:
+            assert instance_count == 3662344
+        elif "final_mixed_train_no_coco" in self.json_file:
+            assert instance_count == 3681235
+        elif "final_flickr_separateGT_train_segm" in self.json_file:
+            assert instance_count == 638214
+        elif "final_flickr_separateGT_train" in self.json_file:
+            assert instance_count == 640704
+        else:
+            assert False
+
+    def cache_labels(self, path=Path("./labels.cache")):
         """
         Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.
 
+        Args:
+            path (Path): Path where to save the cache file.
+
         Returns:
-            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+            (dict): Dictionary containing cached labels and related information.
         """
-        labels = []
+        x = {"labels": []}
         LOGGER.info("Loading annotation file...")
         with open(self.json_file) as f:
             annotations = json.load(f)
@@ -435,6 +495,7 @@
                 continue
             self.im_files.append(str(im_file))
             bboxes = []
+            segments = []
             cat2id = {}
             texts = []
             for ann in anns:
@@ -448,7 +509,10 @@
                     continue
 
                 caption = img["caption"]
-                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])
+                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip()
+                if not cat_name:
+                    continue
+
                 if cat_name not in cat2id:
                     cat2id[cat_name] = len(cat2id)
                     texts.append([cat_name])
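For context, `tokens_positive` stores character spans into the image caption, and the phrase they cover is now lowercased and stripped, with empty phrases skipped. A toy illustration with a made-up caption and span:

```python
caption = "A Man rides a red Bicycle"
tokens_positive = [[2, 5]]  # character span covering "Man"

cat_name = " ".join(caption[t[0] : t[1]] for t in tokens_positive).lower().strip()
print(cat_name)  # 'man'
```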
@@ -456,18 +520,66 @@
                 box = [cls] + box.tolist()
                 if box not in bboxes:
                     bboxes.append(box)
+                    if ann.get("segmentation") is not None:
+                        if len(ann["segmentation"]) == 0:
+                            segments.append(box)
+                            continue
+                        elif len(ann["segmentation"]) > 1:
+                            s = merge_multi_segment(ann["segmentation"])
+                            s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist()
+                        else:
+                            s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
+                            s = (
+                                (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32))
+                                .reshape(-1)
+                                .tolist()
+                            )
+                        s = [cls] + s
+                        segments.append(s)
             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
-            labels.append(
+
+            if segments:
+                classes = np.array([x[0] for x in segments], dtype=np.float32)
+                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in segments]  # (cls, xy1...)
+                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+            lb = np.array(lb, dtype=np.float32)
+
+            x["labels"].append(
                 {
                     "im_file": im_file,
                     "shape": (h, w),
                     "cls": lb[:, 0:1],  # n, 1
                     "bboxes": lb[:, 1:],  # n, 4
+                    "segments": segments,
                     "normalized": True,
                     "bbox_format": "xywh",
                     "texts": texts,
                 }
             )
+        x["hash"] = get_hash(self.json_file)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+        return x
+
+    def get_labels(self):
+        """
+        Load labels from cache or generate them from JSON file.
+
+        Returns:
+            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+        """
+        cache_path = Path(self.json_file).with_suffix(".cache")
+        try:
+            cache, _ = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash(self.json_file)  # identical hash
+        except (FileNotFoundError, AssertionError, AttributeError):
+            cache, _ = self.cache_labels(cache_path), False  # run cache ops
+        [cache.pop(k) for k in ("hash", "version")]  # remove items
+        labels = cache["labels"]
+        # self.verify_labels(labels)
+        self.im_files = [str(label["im_file"]) for label in labels]
+        if LOCAL_RANK in {-1, 0}:
+            LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
         return labels
 
     def build_transforms(self, hyp=None):
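The new `get_labels` follows the usual Ultralytics cache contract: try to load `<json_file>.cache`, accept it only when its `version` and `hash` entries match the current cache version and the annotation file hash, and rebuild through `cache_labels` otherwise. A minimal sketch of that pattern, with the loader, builder, and hash callables passed in as hypothetical stand-ins for `load_dataset_cache_file`, `cache_labels`, and `get_hash`:

```python
from pathlib import Path

DATASET_CACHE_VERSION = "1.0.3"  # placeholder value for illustration

def load_or_build_labels(json_file, load_cache, build_cache, get_hash):
    """Load labels from cache when version and source hash match, else rebuild."""
    cache_path = Path(json_file).with_suffix(".cache")
    try:
        cache = load_cache(cache_path)  # may raise FileNotFoundError
        assert cache["version"] == DATASET_CACHE_VERSION  # stale format: rebuild
        assert cache["hash"] == get_hash(json_file)  # annotations changed: rebuild
    except (FileNotFoundError, AssertionError, AttributeError):
        cache = build_cache(cache_path)  # writes a fresh cache to disk and returns it
    return cache["labels"]
```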
@@ -483,9 +595,38 @@
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=80,
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms
 
+    @property
+    def category_names(self):
+        """Return unique category names from the dataset."""
+        return {t.strip() for label in self.labels for text in label["texts"] for t in text}
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for text in label["texts"]:
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+
 
 class YOLOConcatDataset(ConcatDataset):
     """
@@ -516,6 +657,18 @@ class YOLOConcatDataset(ConcatDataset):
         """
         return YOLODataset.collate_fn(batch)
 
+    def close_mosaic(self, hyp):
+        """
+        Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
+
+        Args:
+            hyp (dict): Hyperparameters for transforms.
+        """
+        for dataset in self.datasets:
+            if not hasattr(dataset, "close_mosaic"):
+                continue
+            dataset.close_mosaic(hyp)
+
 
 # TODO: support semantic segmentation
 class SemanticDataset(BaseDataset):

--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -58,6 +58,7 @@ TensorFlow.js:
 import gc
 import json
 import os
+import re
 import shutil
 import subprocess
 import time
@@ -326,6 +327,7 @@ class Exporter:
                 "See https://docs.ultralytics.com/models/yolo-world for details."
             )
             model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
+
         if self.args.int8 and not self.args.data:
             self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
             LOGGER.warning(
@@ -634,7 +636,7 @@
         # Generate calibration data for integer quantization
         ignored_scope = None
         if isinstance(self.model.model[-1], Detect):
-            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
             head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
             ignored_scope = nncf.IgnoredScope(  # ignore operations
                 patterns=[
@@ -796,12 +798,12 @@
             LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
             # TODO CoreML Segment and Pose model pipelining
         model = self.model
-
         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
         ct_model = ct.convert(
             ts,
-            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
+            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],  # expects ct.TensorType
             classifier_config=classifier_config,
+            minimum_deployment_target=ct.target.iOS16,
             convert_to="neuralnetwork" if mlmodel else "mlprogram",
         )
         bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
@@ -1231,17 +1233,15 @@
 
         LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
 
+        # Install Java>=17
         try:
-            out = subprocess.run(
-                ["java", "--version"], check=True, capture_output=True
-            )  # Java 17 is required for imx500-converter
-            if "openjdk 17" not in str(out.stdout):
-                raise FileNotFoundError
-        except FileNotFoundError:
-            c = ["apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"]
-            if is_sudo_available():
-                c.insert(0, "sudo")
-            subprocess.run(c, check=True)
+            java_output = subprocess.run(["java", "--version"], check=True, capture_output=True).stdout.decode()
+            version_match = re.search(r"(?:openjdk|java) (\d+)", java_output)
+            java_version = int(version_match.group(1)) if version_match else 0
+            assert java_version >= 17, "Java version too old"
+        except (FileNotFoundError, subprocess.CalledProcessError, AssertionError):
+            cmd = (["sudo"] if is_sudo_available() else []) + ["apt", "install", "-y", "default-jre"]
+            subprocess.run(cmd, check=True)
 
         def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
             for batch in dataloader:
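The rewritten check parses the major version out of the `java --version` banner instead of string-matching `openjdk 17`, so any Java at or above 17 passes, and a missing or too-old runtime falls back to an apt install of `default-jre`. A standalone check of the regex against made-up banner strings:

```python
import re

def java_major(banner):
    """Extract the major Java version from a `java --version` banner line."""
    match = re.search(r"(?:openjdk|java) (\d+)", banner)
    return int(match.group(1)) if match else 0

for banner in ("openjdk 17.0.10 2024-01-16", "java 21.0.2 2024-01-16 LTS", "command not found"):
    major = java_major(banner)
    print(major, "ok" if major >= 17 else "would install default-jre")
```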

--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -249,6 +249,7 @@ class BaseTrainer:
         )
         always_freeze_names = [".dfl"]  # always freeze these layers
         freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        self.freeze_layer_names = freeze_layer_names
         for k, v in self.model.named_parameters():
             # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
             if any(x in k for x in freeze_layer_names):
@@ -350,7 +351,7 @@
                 warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
                 self.scheduler.step()
 
-            self.model.train()
+            self._model_train()
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
@@ -381,7 +382,8 @@
                 # Forward
                 with autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    self.loss, self.loss_items = self.model(batch)
+                    loss, self.loss_items = self.model(batch)
+                    self.loss = loss.sum()
                 if RANK != -1:
                     self.loss *= world_size
                 self.tloss = (
@@ -496,9 +498,7 @@
             memory = torch.mps.driver_allocated_memory()
             if fraction:
                 return __import__("psutil").virtual_memory().percent / 100
-        elif self.device.type == "cpu":
-            pass
-        else:
+        elif self.device.type != "cpu":
             memory = torch.cuda.memory_reserved()
             if fraction:
                 total = torch.cuda.get_device_properties(self.device).total_memory
@@ -520,6 +520,14 @@
 
         return pd.read_csv(self.csv).to_dict(orient="list")
 
+    def _model_train(self):
+        """Set model in training mode."""
+        self.model.train()
+        # Freeze BN stat
+        for n, m in self.model.named_modules():
+            if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
         import io
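`_model_train` re-enters train mode but pins `BatchNorm2d` modules under frozen layer prefixes in eval mode; zeroing `requires_grad` alone would still let their running statistics drift on every forward pass. A minimal demonstration of why the extra `eval()` call matters:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).train()
before = bn.running_mean.clone()
bn(torch.randn(4, 8, 16, 16))
print(torch.equal(bn.running_mean, before))  # False: train mode updates running stats

bn.eval()  # what _model_train() does for BatchNorm under frozen prefixes
before = bn.running_mean.clone()
bn(torch.randn(4, 8, 16, 16))
print(torch.equal(bn.running_mean, before))  # True: stats stay fixed in eval mode
```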
@@ -720,7 +728,7 @@
 
             # Check that resume data YAML exists, otherwise strip to force re-download of dataset
             ckpt_args = attempt_load_weights(last).args
-            if not Path(ckpt_args["data"]).exists():
+            if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
                 ckpt_args["data"] = self.args.data
 
             resume = True
@@ -812,7 +820,8 @@
                 fullname = f"{module_name}.{param_name}" if module_name else param_name
                 if "bias" in fullname:  # bias (no decay)
                     g[2].append(param)
-                elif isinstance(module, bn):  # weight (no decay)
+                elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
+                    # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                     g[1].append(param)
                 else:  # weight (with decay)
                     g[0].append(param)
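With this change, any parameter whose fully qualified name contains `logit_scale` (the learnable temperature in `ContrastiveHead` and `BNContrastiveHead`) joins normalization weights in the no-decay group. A sketch of the three-bucket rule as the hunk describes it; the function name here is illustrative, not the trainer's actual API:

```python
import torch.nn as nn

def param_groups(model, bn=(nn.BatchNorm2d, nn.LayerNorm)):
    """g[0]: decayed weights, g[1]: no-decay weights, g[2]: biases."""
    g = ([], [], [])
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)
    return g
```

Each bucket can then be registered as its own optimizer param group, with weight decay applied only to g[0].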

--- a/ultralytics/models/__init__.py
+++ b/ultralytics/models/__init__.py
@@ -4,6 +4,6 @@ from .fastsam import FastSAM
 from .nas import NAS
 from .rtdetr import RTDETR
 from .sam import SAM
-from .yolo import YOLO, YOLOWorld
+from .yolo import YOLO, YOLOE, YOLOWorld
 
-__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld"  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "YOLOE"  # allow simpler import

--- a/ultralytics/models/nas/model.py
+++ b/ultralytics/models/nas/model.py
@@ -81,6 +81,7 @@ class NAS(Model):
         self.model.pt_path = weights  # for export()
         self.model.task = "detect"  # for export()
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
+        self.model.eval()
 
     def info(self, detailed: bool = False, verbose: bool = True):
         """

--- a/ultralytics/models/nas/predict.py
+++ b/ultralytics/models/nas/predict.py
@@ -2,16 +2,15 @@
 
 import torch
 
-from ultralytics.engine.predictor import BasePredictor
-from ultralytics.engine.results import Results
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
 from ultralytics.utils import ops
 
 
-class NASPredictor(BasePredictor):
+class NASPredictor(DetectionPredictor):
     """
     Ultralytics YOLO NAS Predictor for object detection.
 
-    This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
+    This class extends the `DetectionPredictor` from Ultralytics engine and is responsible for post-processing the
     raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
     scaling the bounding boxes to fit the original image dimensions.
 
@@ -38,23 +37,4 @@ class NASPredictor(BasePredictor):
         # Convert boxes from xyxy to xywh format and concatenate with class scores
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
-
-        # Apply non-maximum suppression to filter overlapping detections
-        preds = ops.non_max_suppression(
-            preds,
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            classes=self.args.classes,
-        )
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
-            # Scale bounding boxes to match original image dimensions
-            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
-        return results
+        return super().postprocess(preds, img, orig_imgs)

--- a/ultralytics/models/nas/val.py
+++ b/ultralytics/models/nas/val.py
@@ -36,7 +36,4 @@ class NASValidator(DetectionValidator):
         """Apply Non-maximum suppression to prediction outputs."""
         boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding box format from xyxy to xywh
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with scores and permute
-        return super().postprocess(
-            preds,
-            max_time_img=0.5,
-        )
+        return super().postprocess(preds)

--- a/ultralytics/models/yolo/__init__.py
+++ b/ultralytics/models/yolo/__init__.py
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
 
-from .model import YOLO, YOLOWorld
+from .model import YOLO, YOLOE, YOLOWorld
 
-__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"

--- a/ultralytics/models/yolo/detect/val.py
+++ b/ultralytics/models/yolo/detect/val.py
@@ -455,8 +455,13 @@ class DetectionValidator(BaseValidator):
             val.print_results()  # explicitly call print_results
             # update mAP50-95 and mAP50
             stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
-                val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
+                val.stats[:2] if self.is_coco else [val.results["AP"], val.results["AP50"]]
             )
+            if self.is_lvis:
+                stats["metrics/APr(B)"] = val.results["APr"]
+                stats["metrics/APc(B)"] = val.results["APc"]
+                stats["metrics/APf(B)"] = val.results["APf"]
+                stats["fitness"] = val.results["AP"]
         except Exception as e:
             LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats