dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
ultralytics/models/yolo/yoloe/train_seg.py
@@ -0,0 +1,124 @@
```python
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import copy, deepcopy

from ultralytics.models.yolo.segment import SegmentationTrainer
from ultralytics.nn.tasks import YOLOESegModel
from ultralytics.utils import RANK

from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
from .val import YOLOESegValidator


class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
    """
    Trainer class for YOLOE segmentation models.

    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality
    specifically for YOLOE segmentation models.

    Attributes:
        cfg (dict): Configuration dictionary with training parameters.
        overrides (dict): Dictionary with parameter overrides.
        _callbacks (list): List of callback functions for training events.
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return YOLOESegModel initialized with specified config and weights.

        Args:
            cfg (dict | str): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOESegModel): Initialized YOLOE segmentation model.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOESegModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """
        Create and return a validator for YOLOE segmentation model evaluation.

        Returns:
            (YOLOESegValidator): Validator for YOLOE segmentation models.
        """
        self.loss_names = "box", "seg", "cls", "dfl"
        return YOLOESegValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )


class YOLOEPESegTrainer(SegmentationTrainer):
    """
    Fine-tune YOLOESeg model in linear probing way.

    This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
    most of the model and only training specific layers.
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return YOLOESegModel initialized with specified config and weights for linear probing.

        Args:
            cfg (dict | str): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOESegModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=self.data["nc"],
            verbose=verbose and RANK == -1,
        )

        del model.model[-1].savpe

        assert weights is not None, "Pretrained weights must be provided for linear probing."
        if weights:
            model.load(weights)

        model.eval()
        names = list(self.data["names"].values())
        # NOTE: `get_text_pe` relates to the text model and YOLOEDetect.reprta;
        # it'd get correct results as long as proper pretrained weights are loaded.
        tpe = model.get_text_pe(names)
        model.set_classes(names, tpe)
        model.model[-1].fuse(model.pe)
        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
        del model.pe
        model.train()

        return model


class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
    """Trainer for YOLOE segmentation from scratch."""

    pass


class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
    """Trainer for YOLOE segmentation with VP."""

    pass
```
ultralytics/models/yolo/yoloe/val.py
@@ -0,0 +1,191 @@
```python
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import deepcopy

import torch
from torch.nn import functional as F

from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.nn.modules.head import YOLOEDetect
from ultralytics.nn.tasks import YOLOEModel
from ultralytics.utils import LOGGER, TQDM
from ultralytics.utils.torch_utils import select_device, smart_inference_mode


class YOLOEDetectValidator(DetectionValidator):
    """
    A mixin class for YOLOE model validation that handles both text and visual prompt embeddings.

    This mixin provides functionality to validate YOLOE models using either text or visual prompt embeddings.
    It includes methods for extracting visual prompt embeddings from samples, preprocessing batches, and
    running validation with different prompt types.

    Attributes:
        device (torch.device): The device on which validation is performed.
        args (namespace): Configuration arguments for validation.
        dataloader (DataLoader): DataLoader for validation data.
    """

    @smart_inference_mode()
    def get_visual_pe(self, dataloader, model):
        """
        Extract visual prompt embeddings from training samples.

        This function processes a dataloader to compute visual prompt embeddings for each class
        using a YOLOE model. It normalizes the embeddings and handles cases where no samples
        exist for a class.

        Args:
            dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
            model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.

        Returns:
            (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
        """
        assert isinstance(model, YOLOEModel)
        names = [name.split("/")[0] for name in list(dataloader.dataset.data["names"].values())]
        visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
        cls_visual_num = torch.zeros(len(names))

        desc = "Get visual prompt embeddings from samples"

        for batch in dataloader:
            cls = batch["cls"].squeeze(-1).to(torch.int).unique()
            count = torch.bincount(cls, minlength=len(names))
            cls_visual_num += count

        cls_visual_num = cls_visual_num.to(self.device)

        pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
        for batch in pbar:
            batch = self.preprocess(batch)
            preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)

            batch_idx = batch["batch_idx"]
            for i in range(preds.shape[0]):
                cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
                pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
                pad_cls[: len(cls)] = cls
                for c in cls:
                    visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]

        visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
        visual_pe[cls_visual_num == 0] = 0
        return visual_pe.unsqueeze(0)

    def preprocess(self, batch):
        """Preprocess batch data, ensuring visuals are on the same device as images."""
        batch = super().preprocess(batch)
        if "visuals" in batch:
            batch["visuals"] = batch["visuals"].to(batch["img"].device)
        return batch

    def get_vpe_dataloader(self, data):
        """
        Create a dataloader for LVIS training visual prompt samples.

        This function prepares a dataloader for visual prompt embeddings (VPE) using the LVIS dataset.
        It applies necessary transformations and configurations to the dataset and returns a dataloader
        for validation purposes.

        Args:
            data (dict): Dataset configuration dictionary containing paths and settings.

        Returns:
            (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
        """
        dataset = build_yolo_dataset(
            self.args,
            data.get(self.args.split, data.get("val")),
            self.args.batch,
            data,
            mode="val",
            rect=False,
        )
        if isinstance(dataset, YOLOConcatDataset):
            for d in dataset.datasets:
                d.transforms.append(LoadVisualPrompt())
        else:
            dataset.transforms.append(LoadVisualPrompt())
        return build_dataloader(
            dataset,
            self.args.batch,
            self.args.workers,
            shuffle=False,
            rank=-1,
        )

    @smart_inference_mode()
    def __call__(self, trainer=None, model=None, refer_data=None, load_vp=False):
        """
        Run validation on the model using either text or visual prompt embeddings.

        This method validates the model using either text prompts or visual prompts, depending
        on the `load_vp` flag. It supports validation during training (using a trainer object)
        or standalone validation with a provided model.

        Args:
            trainer (object, optional): Trainer object containing the model and device.
            model (YOLOEModel, optional): Model to validate. Required if `trainer` is not provided.
            refer_data (str, optional): Path to reference data for visual prompts.
            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.

        Returns:
            (dict): Validation statistics containing metrics computed during validation.
        """
        if trainer is not None:
            self.device = trainer.device
            model = trainer.ema.ema
            names = [name.split("/")[0] for name in list(self.dataloader.dataset.data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
                self.args.half = False
                # Directly use the same dataloader for visual embeddings extracted during training
                vpe = self.get_visual_pe(self.dataloader, model)
                model.set_classes(names, vpe)
            else:
                LOGGER.info("Validate using the text prompt.")
                tpe = model.get_text_pe(names)
                model.set_classes(names, tpe)
            stats = super().__call__(trainer, model)
        else:
            if refer_data is not None:
                assert load_vp, "Refer data is only used for visual prompt validation."
            self.device = select_device(self.args.device)

            if isinstance(model, str):
                from ultralytics.nn.tasks import attempt_load_weights

                model = attempt_load_weights(model, device=self.device, inplace=True)
            model.eval().to(self.device)
            data = check_det_dataset(refer_data or self.args.data)
            names = [name.split("/")[0] for name in list(data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
                self.args.half = False
                # TODO: need to check if the names from refer data are consistent with the evaluated dataset
                # could use same dataset or refer to extract visual prompt embeddings
                dataloader = self.get_vpe_dataloader(data)
                vpe = self.get_visual_pe(dataloader, model)
                model.set_classes(names, vpe)
                stats = super().__call__(model=deepcopy(model))
            elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
                return super().__call__(trainer, model)
            else:
                LOGGER.info("Validate using the text prompt.")
                tpe = model.get_text_pe(names)
                model.set_classes(names, tpe)
                stats = super().__call__(model=deepcopy(model))
        return stats


class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):
    """YOLOE segmentation validator that supports both text and visual prompt embeddings."""

    pass
```
ultralytics/nn/__init__.py
@@ -0,0 +1,29 @@
```python
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .tasks import (
    BaseModel,
    ClassificationModel,
    DetectionModel,
    SegmentationModel,
    attempt_load_one_weight,
    attempt_load_weights,
    guess_model_scale,
    guess_model_task,
    parse_model,
    torch_safe_load,
    yaml_model_load,
)

__all__ = (
    "attempt_load_one_weight",
    "attempt_load_weights",
    "parse_model",
    "yaml_model_load",
    "guess_model_task",
    "guess_model_scale",
    "torch_safe_load",
    "DetectionModel",
    "SegmentationModel",
    "ClassificationModel",
    "BaseModel",
)
```