ultralytics_opencv_headless-8.3.242-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1026 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1574 -0
- ultralytics/engine/model.py +1124 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +974 -0
- ultralytics/engine/tuner.py +448 -0
- ultralytics/engine/validator.py +384 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +958 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +998 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +315 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +444 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1560 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1045 -0
- ultralytics/utils/tal.py +403 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +440 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +160 -0
- ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
- ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
- ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
- ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
- ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
- ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/models/yolo/yoloe/val.py
@@ -0,0 +1,206 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from copy import deepcopy
+from pathlib import Path
+from typing import Any
+
+import torch
+from torch.nn import functional as F
+
+from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset
+from ultralytics.data.augment import LoadVisualPrompt
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.models.yolo.segment import SegmentationValidator
+from ultralytics.nn.modules.head import YOLOEDetect
+from ultralytics.nn.tasks import YOLOEModel
+from ultralytics.utils import LOGGER, TQDM
+from ultralytics.utils.torch_utils import select_device, smart_inference_mode
+
+
+class YOLOEDetectValidator(DetectionValidator):
+    """A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
+
+    This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
+    validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
+    evaluation strategies for prompt-based object detection.
+
+    Attributes:
+        device (torch.device): The device on which validation is performed.
+        args (namespace): Configuration arguments for validation.
+        dataloader (DataLoader): DataLoader for validation data.
+
+    Methods:
+        get_visual_pe: Extract visual prompt embeddings from training samples.
+        preprocess: Preprocess batch data ensuring visuals are on the same device as images.
+        get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
+        __call__: Run validation using either text or visual prompt embeddings.
+
+    Examples:
+        Validate with text prompts
+        >>> validator = YOLOEDetectValidator()
+        >>> stats = validator(model=model, load_vp=False)
+
+        Validate with visual prompts
+        >>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
+    """
+
+    @smart_inference_mode()
+    def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
+        """Extract visual prompt embeddings from training samples.
+
+        This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
+        normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
+        zero.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
+            model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.
+
+        Returns:
+            (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
+        """
+        assert isinstance(model, YOLOEModel)
+        names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
+        visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
+        cls_visual_num = torch.zeros(len(names))
+
+        desc = "Get visual prompt embeddings from samples"
+
+        # Count samples per class
+        for batch in dataloader:
+            cls = batch["cls"].squeeze(-1).to(torch.int).unique()
+            count = torch.bincount(cls, minlength=len(names))
+            cls_visual_num += count
+
+        cls_visual_num = cls_visual_num.to(self.device)
+
+        # Extract visual prompt embeddings
+        pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
+        for batch in pbar:
+            batch = self.preprocess(batch)
+            preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)
+
+            batch_idx = batch["batch_idx"]
+            for i in range(preds.shape[0]):
+                cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
+                pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
+                pad_cls[: cls.shape[0]] = cls
+                for c in cls:
+                    visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
+
+        # Normalize embeddings for classes with samples, set others to zero
+        visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
+        visual_pe[cls_visual_num == 0] = 0
+        return visual_pe.unsqueeze(0)
+
+    def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
+        """Create a dataloader for LVIS training visual prompt samples.
+
+        This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
+        necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
+
+        Args:
+            data (dict): Dataset configuration dictionary containing paths and settings.
+
+        Returns:
+            (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
+        """
+        dataset = build_yolo_dataset(
+            self.args,
+            data.get(self.args.split, data.get("val")),
+            self.args.batch,
+            data,
+            mode="val",
+            rect=False,
+        )
+        if isinstance(dataset, YOLOConcatDataset):
+            for d in dataset.datasets:
+                d.transforms.append(LoadVisualPrompt())
+        else:
+            dataset.transforms.append(LoadVisualPrompt())
+        return build_dataloader(
+            dataset,
+            self.args.batch,
+            self.args.workers,
+            shuffle=False,
+            rank=-1,
+        )
+
+    @smart_inference_mode()
+    def __call__(
+        self,
+        trainer: Any | None = None,
+        model: YOLOEModel | str | None = None,
+        refer_data: str | None = None,
+        load_vp: bool = False,
+    ) -> dict[str, Any]:
+        """Run validation on the model using either text or visual prompt embeddings.
+
+        This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
+        supports validation during training (using a trainer object) or standalone validation with a provided model. For
+        visual prompts, reference data can be specified to extract embeddings from a different dataset.
+
+        Args:
+            trainer (object, optional): Trainer object containing the model and device.
+            model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
+            refer_data (str, optional): Path to reference data for visual prompts.
+            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
+
+        Returns:
+            (dict): Validation statistics containing metrics computed during validation.
+        """
+        if trainer is not None:
+            self.device = trainer.device
+            model = trainer.ema.ema
+            names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]
+
+            if load_vp:
+                LOGGER.info("Validate using the visual prompt.")
+                self.args.half = False
+                # Directly use the same dataloader for visual embeddings extracted during training
+                vpe = self.get_visual_pe(self.dataloader, model)
+                model.set_classes(names, vpe)
+            else:
+                LOGGER.info("Validate using the text prompt.")
+                tpe = model.get_text_pe(names)
+                model.set_classes(names, tpe)
+            stats = super().__call__(trainer, model)
+        else:
+            if refer_data is not None:
+                assert load_vp, "Refer data is only used for visual prompt validation."
+            self.device = select_device(self.args.device, verbose=False)
+
+            if isinstance(model, (str, Path)):
+                from ultralytics.nn.tasks import load_checkpoint
+
+                model, _ = load_checkpoint(model, device=self.device)  # model, ckpt
+            model.eval().to(self.device)
+            data = check_det_dataset(refer_data or self.args.data)
+            names = [name.split("/", 1)[0] for name in list(data["names"].values())]
+
+            if load_vp:
+                LOGGER.info("Validate using the visual prompt.")
+                self.args.half = False
+                # TODO: need to check if the names from refer data is consistent with the evaluated dataset
+                # could use same dataset or refer to extract visual prompt embeddings
+                dataloader = self.get_vpe_dataloader(data)
+                vpe = self.get_visual_pe(dataloader, model)
+                model.set_classes(names, vpe)
+                stats = super().__call__(model=deepcopy(model))
+            elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
+                return super().__call__(trainer, model)
+            else:
+                LOGGER.info("Validate using the text prompt.")
+                tpe = model.get_text_pe(names)
+                model.set_classes(names, tpe)
+                stats = super().__call__(model=deepcopy(model))
+        return stats
+
+
+class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):
+    """YOLOE segmentation validator that supports both text and visual prompt embeddings."""
+
+    pass
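The heart of get_visual_pe in the hunk above is simple arithmetic: each class's visual prompt embedding is the mean of that class's per-sample embeddings, L2-normalized, with an all-zero vector for classes that never occur. The standalone sketch below reproduces just that arithmetic outside the validator; the class count, embedding width, and sample values are illustrative assumptions, not data from the package.

import torch
from torch.nn import functional as F

# Toy setup: 3 classes, 4-dim embeddings; class 2 has no samples (assumed values).
num_classes, embed_dim = 3, 4
visual_pe = torch.zeros(num_classes, embed_dim)
cls_visual_num = torch.tensor([2.0, 1.0, 0.0])  # per-class sample counts

# Each present class contributes per-sample embeddings, accumulated as a running
# mean (visual_pe[c] += embedding / count), mirroring the validator's inner loop.
sample_embeddings = {
    0: [torch.ones(embed_dim), torch.full((embed_dim,), 3.0)],
    1: [torch.full((embed_dim,), 2.0)],
}
for c, embeddings in sample_embeddings.items():
    for e in embeddings:
        visual_pe[c] += e / cls_visual_num[c]

# Normalize classes that had samples, zero the rest, then add the leading batch dim.
visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
visual_pe[cls_visual_num == 0] = 0
print(visual_pe.unsqueeze(0).shape)  # torch.Size([1, 3, 4])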
ultralytics/nn/__init__.py
@@ -0,0 +1,27 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .tasks import (
+    BaseModel,
+    ClassificationModel,
+    DetectionModel,
+    SegmentationModel,
+    guess_model_scale,
+    guess_model_task,
+    load_checkpoint,
+    parse_model,
+    torch_safe_load,
+    yaml_model_load,
+)
+
+__all__ = (
+    "BaseModel",
+    "ClassificationModel",
+    "DetectionModel",
+    "SegmentationModel",
+    "guess_model_scale",
+    "guess_model_task",
+    "load_checkpoint",
+    "parse_model",
+    "torch_safe_load",
+    "yaml_model_load",
+)
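The hunk above is a plain re-export module: it lifts the task-model classes and the checkpoint/parsing helpers out of ultralytics.nn.tasks so consumers can import them from the shorter ultralytics.nn path. A minimal usage sketch, assuming the "yolo11n.yaml" model name resolves against the bundled cfg/models files listed earlier in this diff:

from ultralytics.nn import DetectionModel, guess_model_task

# Build a detection model from a bundled model YAML (name resolution is assumed here).
model = DetectionModel("yolo11n.yaml")
print(guess_model_task("yolo11n.yaml"))  # expected: "detect"
print(sum(p.numel() for p in model.parameters()))  # parameter count for the "n" scale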