ultralytics 8.3.98__py3-none-any.whl → 8.3.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_python.py +56 -0
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/data/augment.py +101 -5
- ultralytics/data/dataset.py +165 -12
- ultralytics/engine/exporter.py +4 -3
- ultralytics/engine/trainer.py +16 -7
- ultralytics/models/__init__.py +2 -2
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/detect/val.py +6 -1
- ultralytics/models/yolo/model.py +182 -3
- ultralytics/models/yolo/segment/val.py +43 -16
- ultralytics/models/yolo/yoloe/__init__.py +21 -0
- ultralytics/models/yolo/yoloe/predict.py +170 -0
- ultralytics/models/yolo/yoloe/train.py +355 -0
- ultralytics/models/yolo/yoloe/train_seg.py +141 -0
- ultralytics/models/yolo/yoloe/val.py +187 -0
- ultralytics/nn/autobackend.py +3 -2
- ultralytics/nn/modules/__init__.py +18 -1
- ultralytics/nn/modules/block.py +17 -1
- ultralytics/nn/modules/head.py +359 -22
- ultralytics/nn/tasks.py +276 -10
- ultralytics/nn/text_model.py +193 -0
- ultralytics/utils/callbacks/comet.py +3 -6
- ultralytics/utils/downloads.py +6 -2
- ultralytics/utils/loss.py +67 -6
- ultralytics/utils/plotting.py +1 -1
- ultralytics/utils/tal.py +1 -1
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/METADATA +10 -10
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/RECORD +37 -27
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
ultralytics/data/dataset.py
CHANGED
@@ -13,7 +13,7 @@ from PIL import Image
 from torch.utils.data import ConcatDataset

 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
-from ultralytics.utils.ops import resample_segments
+from ultralytics.utils.ops import resample_segments, segments2boxes
 from ultralytics.utils.torch_utils import TORCHVISION_0_18

 from .augment import (
@@ -27,6 +27,7 @@ from .augment import (
     v8_transforms,
 )
 from .base import BaseDataset
+from .converter import merge_multi_segment
 from .utils import (
     HELP_URL,
     LOGGER,
@@ -289,12 +290,15 @@ class YOLODataset(BaseDataset):
             (dict): Collated batch with stacked tensors.
         """
         new_batch = {}
+        batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
         keys = batch[0].keys()
         values = list(zip(*[list(b.values()) for b in batch]))
         for i, k in enumerate(keys):
             value = values[i]
-            if k == "img":
+            if k == "img" or k == "text_feats":
                 value = torch.stack(value, 0)
+            elif k == "visuals":
+                value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
             if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                 value = torch.cat(value, 0)
             new_batch[k] = value
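For orientation, a minimal self-contained sketch (not part of the diff) of what the new `visuals` branch does: `torch.nn.utils.rnn.pad_sequence` right-pads variable-length per-image prompt tensors into a single batch tensor. Shapes here are illustrative.

```python
import torch

# Hypothetical per-image visual-prompt tensors of different lengths
visuals = [torch.zeros(2, 32), torch.zeros(5, 32)]
padded = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True)
print(padded.shape)  # torch.Size([2, 5, 32]); shorter entries are zero-padded
```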
@@ -346,7 +350,9 @@ class YOLOMultiModalDataset(YOLODataset):
         """
         labels = super().update_labels_info(label)
         # NOTE: some categories are concatenated with its synonyms by `/`.
+        # NOTE: and `RandomLoadText` would randomly select one of them if there are multiple words.
         labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
+
         return labels

     def build_transforms(self, hyp=None):
@@ -362,9 +368,46 @@ class YOLOMultiModalDataset(YOLODataset):
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=min(self.data["nc"], 80),
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms

+    @property
+    def category_names(self):
+        """
+        Return category names for the dataset.
+
+        Returns:
+            (Tuple[str]): List of class names.
+        """
+        names = self.data["names"].values()
+        return {n.strip() for name in names for n in name.split("/")}  # category names
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        texts = [v.split("/") for v in self.data["names"].values()]
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for c in label["cls"]:  # to check
+                text = texts[int(c)]
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+

 class GroundingDataset(YOLODataset):
     """
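A hedged, self-contained sketch of the `category_freq` / `_get_neg_texts` logic added above, using invented toy data and a lowered threshold so the effect is visible:

```python
from collections import defaultdict

names = {0: "person/human", 1: "tv/television"}  # '/'-joined synonyms, as in update_labels_info
labels = [{"cls": [0, 0, 1]}, {"cls": [0]}]  # toy per-image class lists

texts = [v.split("/") for v in names.values()]
category_freq = defaultdict(int)
for label in labels:
    for c in label["cls"]:
        for t in texts[int(c)]:
            category_freq[t.strip()] += 1

# With threshold=2 (the real code defaults to 100), only frequent names
# qualify as padding negatives for RandomLoadText
neg_texts = [k for k, v in category_freq.items() if v >= 2]
print(neg_texts)  # ['person', 'human']
```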
@@ -386,17 +429,17 @@ class GroundingDataset(YOLODataset):
     >>> len(dataset)  # Number of valid images with annotations
     """

-    def __init__(self, *args, task="detect", json_file, **kwargs):
+    def __init__(self, *args, task="detect", json_file="", **kwargs):
         """
         Initialize a GroundingDataset for object detection.

         Args:
             json_file (str): Path to the JSON file containing annotations.
-            task (str): Must be 'detect' for GroundingDataset.
+            task (str): Must be 'detect' or 'segment' for GroundingDataset.
             *args (Any): Additional positional arguments for the parent class.
             **kwargs (Any): Additional keyword arguments for the parent class.
         """
-        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
+        assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
         super().__init__(*args, task=task, data={}, **kwargs)
@@ -412,14 +455,31 @@ class GroundingDataset(YOLODataset):
         """
         return []

-    def get_labels(self):
+    def verify_labels(self, labels):
+        """Verify the number of instances in the dataset matches expected counts."""
+        instance_count = sum(label["bboxes"].shape[0] for label in labels)
+        if "final_mixed_train_no_coco_segm" in self.json_file:
+            assert instance_count == 3662344
+        elif "final_mixed_train_no_coco" in self.json_file:
+            assert instance_count == 3681235
+        elif "final_flickr_separateGT_train_segm" in self.json_file:
+            assert instance_count == 638214
+        elif "final_flickr_separateGT_train" in self.json_file:
+            assert instance_count == 640704
+        else:
+            assert False
+
+    def cache_labels(self, path=Path("./labels.cache")):
         """
         Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.

+        Args:
+            path (Path): Path where to save the cache file.
+
         Returns:
-            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+            (dict): Dictionary containing cached labels and related information.
         """
-        labels = []
+        x = {"labels": []}
         LOGGER.info("Loading annotation file...")
         with open(self.json_file) as f:
             annotations = json.load(f)
@@ -435,6 +495,7 @@ class GroundingDataset(YOLODataset):
                 continue
             self.im_files.append(str(im_file))
             bboxes = []
+            segments = []
             cat2id = {}
             texts = []
             for ann in anns:
@@ -448,7 +509,10 @@ class GroundingDataset(YOLODataset):
                     continue

                 caption = img["caption"]
-                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]])
+                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip()
+                if not cat_name:
+                    continue
+
                 if cat_name not in cat2id:
                     cat2id[cat_name] = len(cat2id)
                     texts.append([cat_name])
@@ -456,18 +520,66 @@ class GroundingDataset(YOLODataset):
                 box = [cls] + box.tolist()
                 if box not in bboxes:
                     bboxes.append(box)
+                    if ann.get("segmentation") is not None:
+                        if len(ann["segmentation"]) == 0:
+                            segments.append(box)
+                            continue
+                        elif len(ann["segmentation"]) > 1:
+                            s = merge_multi_segment(ann["segmentation"])
+                            s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist()
+                        else:
+                            s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
+                            s = (
+                                (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32))
+                                .reshape(-1)
+                                .tolist()
+                            )
+                        s = [cls] + s
+                        segments.append(s)
             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
-            labels.append(
+
+            if segments:
+                classes = np.array([x[0] for x in segments], dtype=np.float32)
+                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in segments]  # (cls, xy1...)
+                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+            lb = np.array(lb, dtype=np.float32)
+
+            x["labels"].append(
                 {
                     "im_file": im_file,
                     "shape": (h, w),
                     "cls": lb[:, 0:1],  # n, 1
                     "bboxes": lb[:, 1:],  # n, 4
+                    "segments": segments,
                     "normalized": True,
                     "bbox_format": "xywh",
                     "texts": texts,
                 }
             )
+        x["hash"] = get_hash(self.json_file)
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+        return x
+
+    def get_labels(self):
+        """
+        Load labels from cache or generate them from JSON file.
+
+        Returns:
+            (List[dict]): List of label dictionaries, each containing information about an image and its annotations.
+        """
+        cache_path = Path(self.json_file).with_suffix(".cache")
+        try:
+            cache, _ = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash(self.json_file)  # identical hash
+        except (FileNotFoundError, AssertionError, AttributeError):
+            cache, _ = self.cache_labels(cache_path), False  # run cache ops
+        [cache.pop(k) for k in ("hash", "version")]  # remove items
+        labels = cache["labels"]
+        # self.verify_labels(labels)
+        self.im_files = [str(label["im_file"]) for label in labels]
+        if LOCAL_RANK in {-1, 0}:
+            LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
         return labels

     def build_transforms(self, hyp=None):
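The segment branch above reduces each polygon to an xywh box via `segments2boxes` (imported at the top of the file). A self-contained sketch of the equivalent computation, with `polygon_to_xywh` being an invented helper name:

```python
import numpy as np

def polygon_to_xywh(points: np.ndarray) -> np.ndarray:
    """Reduce an (N, 2) array of normalized xy points to a normalized xywh box."""
    x, y = points[:, 0], points[:, 1]
    return np.array(
        [(x.min() + x.max()) / 2, (y.min() + y.max()) / 2, x.max() - x.min(), y.max() - y.min()],
        dtype=np.float32,
    )

poly = np.array([[0.2, 0.2], [0.6, 0.2], [0.6, 0.6]], dtype=np.float32)
print(polygon_to_xywh(poly))  # approximately [0.4 0.4 0.4 0.4]
```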
@@ -483,9 +595,38 @@ class GroundingDataset(YOLODataset):
         transforms = super().build_transforms(hyp)
         if self.augment:
             # NOTE: hard-coded the args for now.
-            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
+            # NOTE: this implementation is different from official yoloe,
+            # the strategy of selecting negative is restricted in one dataset,
+            # while official pre-saved neg embeddings from all datasets at once.
+            transform = RandomLoadText(
+                max_samples=80,
+                padding=True,
+                padding_value=self._get_neg_texts(self.category_freq),
+            )
+            transforms.insert(-1, transform)
         return transforms

+    @property
+    def category_names(self):
+        """Return unique category names from the dataset."""
+        return {t.strip() for label in self.labels for text in label["texts"] for t in text}
+
+    @property
+    def category_freq(self):
+        """Return frequency of each category in the dataset."""
+        category_freq = defaultdict(int)
+        for label in self.labels:
+            for text in label["texts"]:
+                for t in text:
+                    t = t.strip()
+                    category_freq[t] += 1
+        return category_freq
+
+    @staticmethod
+    def _get_neg_texts(category_freq, threshold=100):
+        """Get negative text samples based on frequency threshold."""
+        return [k for k, v in category_freq.items() if v >= threshold]
+

 class YOLOConcatDataset(ConcatDataset):
     """
@@ -516,6 +657,18 @@ class YOLOConcatDataset(ConcatDataset):
         """
         return YOLODataset.collate_fn(batch)

+    def close_mosaic(self, hyp):
+        """
+        Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
+
+        Args:
+            hyp (dict): Hyperparameters for transforms.
+        """
+        for dataset in self.datasets:
+            if not hasattr(dataset, "close_mosaic"):
+                continue
+            dataset.close_mosaic(hyp)
+

 # TODO: support semantic segmentation
 class SemanticDataset(BaseDataset):
ultralytics/engine/exporter.py
CHANGED
@@ -327,6 +327,7 @@ class Exporter:
                 "See https://docs.ultralytics.com/models/yolo-world for details."
             )
             model.clip_model = None  # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445
+
         if self.args.int8 and not self.args.data:
             self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")]  # assign default data
             LOGGER.warning(
@@ -635,7 +636,7 @@
         # Generate calibration data for integer quantization
         ignored_scope = None
         if isinstance(self.model.model[-1], Detect):
-            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+            # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect, YOLOEDetect
             head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
             ignored_scope = nncf.IgnoredScope(  # ignore operations
                 patterns=[
@@ -797,12 +798,12 @@
             LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
             # TODO CoreML Segment and Pose model pipelining
         model = self.model
-
         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
         ct_model = ct.convert(
             ts,
-            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],
+            inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)],  # expects ct.TensorType
             classifier_config=classifier_config,
+            minimum_deployment_target=ct.target.iOS16,
             convert_to="neuralnetwork" if mlmodel else "mlprogram",
         )
         bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None)
ultralytics/engine/trainer.py
CHANGED
@@ -249,6 +249,7 @@ class BaseTrainer:
         )
         always_freeze_names = [".dfl"]  # always freeze these layers
         freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
+        self.freeze_layer_names = freeze_layer_names
         for k, v in self.model.named_parameters():
             # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
             if any(x in k for x in freeze_layer_names):
|
|
350
351
|
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
351
352
|
self.scheduler.step()
|
352
353
|
|
353
|
-
self.
|
354
|
+
self._model_train()
|
354
355
|
if RANK != -1:
|
355
356
|
self.train_loader.sampler.set_epoch(epoch)
|
356
357
|
pbar = enumerate(self.train_loader)
|
@@ -381,7 +382,8 @@ class BaseTrainer:
|
|
381
382
|
# Forward
|
382
383
|
with autocast(self.amp):
|
383
384
|
batch = self.preprocess_batch(batch)
|
384
|
-
|
385
|
+
loss, self.loss_items = self.model(batch)
|
386
|
+
self.loss = loss.sum()
|
385
387
|
if RANK != -1:
|
386
388
|
self.loss *= world_size
|
387
389
|
self.tloss = (
|
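The split into `loss` and `self.loss = loss.sum()` suggests models may now return a per-component loss vector rather than a scalar; summing restores the scalar needed for backpropagation. An illustrative sketch with invented component values:

```python
import torch

loss_items = torch.tensor([1.5, 0.7, 0.3], requires_grad=True)  # e.g. box/cls/dfl components
loss = loss_items.sum()  # scalar required by backward()
loss.backward()
```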
@@ -496,9 +498,7 @@ class BaseTrainer:
             memory = torch.mps.driver_allocated_memory()
             if fraction:
                 return __import__("psutil").virtual_memory().percent / 100
-        elif self.device.type == "cpu":
-            pass
-        else:
+        elif self.device.type != "cpu":
             memory = torch.cuda.memory_reserved()
             if fraction:
                 total = torch.cuda.get_device_properties(self.device).total_memory
@@ -520,6 +520,14 @@ class BaseTrainer:

         return pd.read_csv(self.csv).to_dict(orient="list")

+    def _model_train(self):
+        """Set model in training mode."""
+        self.model.train()
+        # Freeze BN stat
+        for n, m in self.model.named_modules():
+            if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
+                m.eval()
+
     def save_model(self):
         """Save model training checkpoints with additional metadata."""
         import io
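A standalone sketch of the BatchNorm-freezing behavior that `_model_train` adds, using a toy model; the name-fragment matching mirrors the loop above:

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
freeze_layer_names = ["1"]  # name fragments of frozen layers

model.train()
for n, m in model.named_modules():
    if any(f in n for f in freeze_layer_names) and isinstance(m, nn.BatchNorm2d):
        m.eval()  # keep running_mean/running_var fixed while the rest trains

print(model[1].training)  # False
```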
@@ -720,7 +728,7 @@ class BaseTrainer:

         # Check that resume data YAML exists, otherwise strip to force re-download of dataset
         ckpt_args = attempt_load_weights(last).args
-        if not Path(ckpt_args["data"]).exists():
+        if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
             ckpt_args["data"] = self.args.data

         resume = True
@@ -812,7 +820,8 @@
                 fullname = f"{module_name}.{param_name}" if module_name else param_name
                 if "bias" in fullname:  # bias (no decay)
                     g[2].append(param)
-                elif isinstance(module, bn):  # weight (no decay)
+                elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
+                    # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                     g[1].append(param)
                 else:  # weight (with decay)
                     g[0].append(param)
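A sketch of the updated no-decay grouping rule: any parameter whose qualified name contains `logit_scale` (e.g. a contrastive-head temperature) joins the no-decay group alongside BatchNorm weights. Toy model for illustration:

```python
import torch
import torch.nn as nn

g = [[], [], []]  # decay weights, no-decay weights, biases
model = nn.Sequential(nn.Linear(4, 4))
model.logit_scale = nn.Parameter(torch.ones(1))  # e.g. a contrastive-head temperature

for name, param in model.named_parameters():
    if "bias" in name:
        g[2].append(param)
    elif "logit_scale" in name:  # temperature params skip weight decay
        g[1].append(param)
    else:
        g[0].append(param)

print([len(x) for x in g])  # [1, 1, 1]
```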
ultralytics/models/__init__.py
CHANGED
@@ -4,6 +4,6 @@ from .fastsam import FastSAM
 from .nas import NAS
 from .rtdetr import RTDETR
 from .sam import SAM
-from .yolo import YOLO, YOLOWorld
+from .yolo import YOLO, YOLOE, YOLOWorld

-__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld"  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "YOLOE"  # allow simpler import
ultralytics/models/yolo/__init__.py
CHANGED
@@ -1,7 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
+from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe

-from .model import YOLO, YOLOWorld
+from .model import YOLO, YOLOE, YOLOWorld

-__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"
+__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"
ultralytics/models/yolo/detect/val.py
CHANGED
@@ -455,8 +455,13 @@
                 val.print_results()  # explicitly call print_results
                 # update mAP50-95 and mAP50
                 stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
-                    val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]]
+                    val.stats[:2] if self.is_coco else [val.results["AP"], val.results["AP50"]]
                 )
+                if self.is_lvis:
+                    stats["metrics/APr(B)"] = val.results["APr"]
+                    stats["metrics/APc(B)"] = val.results["APc"]
+                    stats["metrics/APf(B)"] = val.results["APf"]
+                    stats["fitness"] = val.results["AP"]
             except Exception as e:
                 LOGGER.warning(f"{pkg} unable to run: {e}")
             return stats
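A plausible invocation that would populate the new LVIS keys, assuming the `lvis.yaml` dataset config bundled with the package and the `lvis` evaluation package installed:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
metrics = model.val(data="lvis.yaml", save_json=True)  # the LVIS eval path sets is_lvis
```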
ultralytics/models/yolo/model.py
CHANGED
@@ -4,7 +4,16 @@ from pathlib import Path

 from ultralytics.engine.model import Model
 from ultralytics.models import yolo
-from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
+from ultralytics.nn.tasks import (
+    ClassificationModel,
+    DetectionModel,
+    OBBModel,
+    PoseModel,
+    SegmentationModel,
+    WorldModel,
+    YOLOEModel,
+    YOLOESegModel,
+)
 from ultralytics.utils import ROOT, yaml_load

@@ -12,12 +21,16 @@ class YOLO(Model):
     """YOLO (You Only Look Once) object detection model."""

     def __init__(self, model="yolo11n.pt", task=None, verbose=False):
-        """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
+        """Initialize YOLO model, switching to YOLOWorld/YOLOE if model filename contains '-world'/'yoloe'."""
         path = Path(model)
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
             self.__class__ = type(new_instance)
             self.__dict__ = new_instance.__dict__
+        elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOE PyTorch model
+            new_instance = YOLOE(path, task=task, verbose=verbose)
+            self.__class__ = type(new_instance)
+            self.__dict__ = new_instance.__dict__
         else:
             # Continue with default YOLO initialization
             super().__init__(model=model, task=task, verbose=verbose)
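With the filename-based dispatch above, a plain `YOLO()` call transparently becomes a `YOLOE` instance for yoloe-* weights. A hedged smoke test, assuming the checkpoint is available:

```python
from ultralytics import YOLO

model = YOLO("yoloe-v8s-seg.pt")  # filename contains 'yoloe', so the instance is rebound
print(type(model).__name__)  # YOLOE
```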
@@ -96,7 +109,7 @@ class YOLOWorld(Model):
         Set the model's class names for detection.

         Args:
-            classes (list(str)): A list of categories i.e ["person"].
+            classes (list[str]): A list of categories i.e. ["person"].
         """
         self.model.set_classes(classes)
         # Remove background if it's given
@@ -108,3 +121,169 @@
         # Reset method class names
         if self.predictor:
             self.predictor.model.names = classes
+
+
+class YOLOE(Model):
+    """YOLOE object detection and segmentation model."""
+
+    def __init__(self, model="yoloe-v8s-seg.pt", task=None, verbose=False) -> None:
+        """
+        Initialize YOLOE model with a pre-trained model file.
+
+        Args:
+            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+            task (str, optional): Task type for the model. Auto-detected if None.
+            verbose (bool): If True, prints additional information during initialization.
+        """
+        super().__init__(model=model, task=task, verbose=verbose)
+
+        # Assign default COCO class names when there are no custom names
+        if not hasattr(self.model, "names"):
+            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
+
+    @property
+    def task_map(self):
+        """Map head to model, validator, and predictor classes."""
+        return {
+            "detect": {
+                "model": YOLOEModel,
+                "validator": yolo.yoloe.YOLOEDetectValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.yoloe.YOLOETrainer,
+            },
+            "segment": {
+                "model": YOLOESegModel,
+                "validator": yolo.yoloe.YOLOESegValidator,
+                "predictor": yolo.segment.SegmentationPredictor,
+                "trainer": yolo.yoloe.YOLOESegTrainer,
+            },
+        }
+
+    def get_text_pe(self, texts):
+        """Get text positional embeddings for the given texts."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_text_pe(texts)
+
+    def get_visual_pe(self, img, visual):
+        """Get visual positional embeddings for the given image and visual features."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_visual_pe(img, visual)
+
+    def set_vocab(self, vocab, names):
+        """Set vocabulary and class names for the model."""
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_vocab(vocab, names=names)
+
+    def get_vocab(self, names):
+        """Get vocabulary for the given class names."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_vocab(names)
+
+    def set_classes(self, classes, embeddings):
+        """
+        Set the model's class names and embeddings for detection.
+
+        Args:
+            classes (list[str]): A list of categories i.e. ["person"].
+            embeddings (torch.Tensor): Embeddings corresponding to the classes.
+        """
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_classes(classes, embeddings)
+        # Verify no background class is present
+        assert " " not in classes
+        self.model.names = classes
+
+        # Reset method class names
+        if self.predictor:
+            self.predictor.model.names = classes
+
+    def val(
+        self,
+        validator=None,
+        load_vp=False,
+        refer_data=None,
+        **kwargs,
+    ):
+        """
+        Validate the model using text or visual prompts.
+
+        Args:
+            validator (callable, optional): A callable validator function. If None, a default validator is loaded.
+            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
+            refer_data (str, optional): Path to the reference data for visual prompts.
+            **kwargs (Any): Additional keyword arguments to override default settings.
+
+        Returns:
+            (dict): Validation statistics containing metrics computed during validation.
+        """
+        custom = {"rect": not load_vp}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # highest priority args on the right
+
+        validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
+        validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
+        self.metrics = validator.metrics
+        return validator.metrics
+
+    def predict(
+        self,
+        source=None,
+        stream: bool = False,
+        visual_prompts: dict = {},
+        refer_image=None,
+        predictor=None,
+        **kwargs,
+    ):
+        """
+        Run prediction on images, videos, directories, streams, etc.
+
+        Args:
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
+                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
+                generator as they are computed.
+            visual_prompts (dict): Dictionary containing visual prompts for the model. Must include 'bboxes' and
+                'cls' keys when non-empty.
+            refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
+                loaded based on the task.
+            **kwargs (Any): Additional keyword arguments passed to the predictor.
+
+        Returns:
+            (List | generator): List of Results objects or generator of Results objects if stream=True.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s-seg.pt")
+            >>> results = model.predict("path/to/image.jpg")
+            >>> # With visual prompts
+            >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+            >>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
+        """
+        if len(visual_prompts):
+            assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
+                f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
+            )
+            assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
+                f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
+                f"{len(visual_prompts['cls'])} respectively"
+            )
+        self.predictor = (predictor or self._smart_load("predictor"))(
+            overrides={"task": "segment", "mode": "predict", "save": False, "verbose": False}, _callbacks=self.callbacks
+        )
+
+        if len(visual_prompts):
+            num_cls = (
+                max(len(set(c)) for c in visual_prompts["cls"])
+                if isinstance(source, list)  # means multiple images
+                else len(set(visual_prompts["cls"]))
+            )
+            self.model.model[-1].nc = num_cls
+            self.model.names = [f"object{i}" for i in range(num_cls)]
+            self.predictor.set_prompts(visual_prompts)
+
+        self.predictor.setup_model(model=self.model)
+        if refer_image is not None and len(visual_prompts):
+            vpe = self.predictor.get_vpe(refer_image)
+            self.model.set_classes(self.model.names, vpe)
+            self.predictor = None  # reset predictor
+
+        return super().predict(source, stream, **kwargs)
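Taken together, a hedged end-to-end sketch of the two prompting modes YOLOE exposes; the visual-prompt dict mirrors the docstring example above, and the checkpoint name assumes the released yoloe-v8s-seg weights:

```python
from ultralytics import YOLOE

model = YOLOE("yoloe-v8s-seg.pt")

# Text prompts: embed the class names, then attach them to the head
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))
results = model.predict("https://ultralytics.com/images/bus.jpg")
results[0].show()

# Visual prompts, as in the predict() docstring above
prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
results = model.predict("https://ultralytics.com/images/bus.jpg", visual_prompts=prompts)
```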