dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
4
5
|
from copy import copy, deepcopy
|
|
6
|
+
from pathlib import Path
|
|
5
7
|
|
|
6
8
|
import torch
|
|
7
9
|
|
|
@@ -10,21 +12,29 @@ from ultralytics.data.augment import LoadVisualPrompt
|
|
|
10
12
|
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
|
|
11
13
|
from ultralytics.nn.tasks import YOLOEModel
|
|
12
14
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
|
13
|
-
from ultralytics.utils.torch_utils import
|
|
15
|
+
from ultralytics.utils.torch_utils import unwrap_model
|
|
14
16
|
|
|
15
17
|
from ..world.train_world import WorldTrainerFromScratch
|
|
16
18
|
from .val import YOLOEDetectValidator
|
|
17
19
|
|
|
18
20
|
|
|
19
21
|
class YOLOETrainer(DetectionTrainer):
|
|
20
|
-
"""A
|
|
22
|
+
"""A trainer class for YOLOE object detection models.
|
|
21
23
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
24
|
+
This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
|
|
25
|
+
model initialization, validation, and dataset building with multi-modal support.
|
|
26
|
+
|
|
27
|
+
Attributes:
|
|
28
|
+
loss_names (tuple): Names of loss components used during training.
|
|
25
29
|
|
|
26
|
-
|
|
27
|
-
|
|
30
|
+
Methods:
|
|
31
|
+
get_model: Initialize and return a YOLOEModel with specified configuration.
|
|
32
|
+
get_validator: Return a YOLOEDetectValidator for model validation.
|
|
33
|
+
build_dataset: Build YOLO dataset with multi-modal support for training.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
|
|
37
|
+
"""Initialize the YOLOE Trainer with specified configurations.
|
|
28
38
|
|
|
29
39
|
Args:
|
|
30
40
|
cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
|
|
@@ -33,17 +43,17 @@ class YOLOETrainer(DetectionTrainer):
|
|
|
33
43
|
"""
|
|
34
44
|
if overrides is None:
|
|
35
45
|
overrides = {}
|
|
46
|
+
assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
|
|
36
47
|
overrides["overlap_mask"] = False
|
|
37
48
|
super().__init__(cfg, overrides, _callbacks)
|
|
38
49
|
|
|
39
|
-
def get_model(self, cfg=None, weights=None, verbose=True):
|
|
40
|
-
"""
|
|
41
|
-
Return a YOLOEModel initialized with the specified configuration and weights.
|
|
50
|
+
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
|
51
|
+
"""Return a YOLOEModel initialized with the specified configuration and weights.
|
|
42
52
|
|
|
43
53
|
Args:
|
|
44
|
-
cfg (dict | str
|
|
45
|
-
|
|
46
|
-
weights (str | Path
|
|
54
|
+
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
|
|
55
|
+
path to a YAML file, or None to use default configuration.
|
|
56
|
+
weights (str | Path, optional): Path to pretrained weights file to load into the model.
|
|
47
57
|
verbose (bool): Whether to display model information during initialization.
|
|
48
58
|
|
|
49
59
|
Returns:
|
|
@@ -68,36 +78,41 @@ class YOLOETrainer(DetectionTrainer):
|
|
|
68
78
|
return model
|
|
69
79
|
|
|
70
80
|
def get_validator(self):
|
|
71
|
-
"""
|
|
81
|
+
"""Return a YOLOEDetectValidator for YOLOE model validation."""
|
|
72
82
|
self.loss_names = "box", "cls", "dfl"
|
|
73
83
|
return YOLOEDetectValidator(
|
|
74
84
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
|
75
85
|
)
|
|
76
86
|
|
|
77
|
-
def build_dataset(self, img_path, mode="train", batch=None):
|
|
78
|
-
"""
|
|
79
|
-
Build YOLO Dataset.
|
|
87
|
+
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
|
88
|
+
"""Build YOLO Dataset.
|
|
80
89
|
|
|
81
90
|
Args:
|
|
82
91
|
img_path (str): Path to the folder containing images.
|
|
83
|
-
mode (str):
|
|
84
|
-
batch (int, optional): Size of batches, this is for
|
|
92
|
+
mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
|
|
93
|
+
batch (int, optional): Size of batches, this is for rectangular training.
|
|
85
94
|
|
|
86
95
|
Returns:
|
|
87
96
|
(Dataset): YOLO dataset configured for training or validation.
|
|
88
97
|
"""
|
|
89
|
-
gs = max(int(
|
|
98
|
+
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
|
|
90
99
|
return build_yolo_dataset(
|
|
91
100
|
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
|
|
92
101
|
)
|
|
93
102
|
|
|
94
103
|
|
|
95
104
|
class YOLOEPETrainer(DetectionTrainer):
|
|
96
|
-
"""Fine-tune YOLOE model
|
|
105
|
+
"""Fine-tune YOLOE model using linear probing approach.
|
|
97
106
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
107
|
+
This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
|
|
108
|
+
datasets while preserving pretrained features.
|
|
109
|
+
|
|
110
|
+
Methods:
|
|
111
|
+
get_model: Initialize YOLOEModel with frozen layers except projection layers.
|
|
112
|
+
"""
|
|
113
|
+
|
|
114
|
+
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
|
115
|
+
"""Return YOLOEModel initialized with specified config and weights.
|
|
101
116
|
|
|
102
117
|
Args:
|
|
103
118
|
cfg (dict | str, optional): Model configuration.
|
|
@@ -139,17 +154,24 @@ class YOLOEPETrainer(DetectionTrainer):
|
|
|
139
154
|
|
|
140
155
|
|
|
141
156
|
class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|
142
|
-
"""Train YOLOE models from scratch.
|
|
157
|
+
"""Train YOLOE models from scratch with text embedding support.
|
|
143
158
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
159
|
+
This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
|
|
160
|
+
text embeddings and grounding datasets.
|
|
161
|
+
|
|
162
|
+
Methods:
|
|
163
|
+
build_dataset: Build datasets for training with grounding support.
|
|
164
|
+
generate_text_embeddings: Generate and cache text embeddings for training.
|
|
165
|
+
"""
|
|
166
|
+
|
|
167
|
+
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
|
|
168
|
+
"""Build YOLO Dataset for training or validation.
|
|
147
169
|
|
|
148
|
-
This method constructs appropriate datasets based on the mode and input paths, handling both
|
|
149
|
-
|
|
170
|
+
This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
|
|
171
|
+
datasets and grounding datasets with different formats.
|
|
150
172
|
|
|
151
173
|
Args:
|
|
152
|
-
img_path (
|
|
174
|
+
img_path (list[str] | str): Path to the folder containing images or list of paths.
|
|
153
175
|
mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
|
|
154
176
|
batch (int, optional): Size of batches, used for rectangular training/validation.
|
|
155
177
|
|
|
@@ -158,22 +180,11 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|
|
158
180
|
"""
|
|
159
181
|
return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
|
|
160
182
|
|
|
161
|
-
def
|
|
162
|
-
"""
|
|
163
|
-
batch = DetectionTrainer.preprocess_batch(self, batch)
|
|
164
|
-
|
|
165
|
-
texts = list(itertools.chain(*batch["texts"]))
|
|
166
|
-
txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
|
|
167
|
-
txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
|
|
168
|
-
batch["txt_feats"] = txt_feats
|
|
169
|
-
return batch
|
|
170
|
-
|
|
171
|
-
def generate_text_embeddings(self, texts, batch, cache_dir):
|
|
172
|
-
"""
|
|
173
|
-
Generate text embeddings for a list of text samples.
|
|
183
|
+
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
|
|
184
|
+
"""Generate text embeddings for a list of text samples.
|
|
174
185
|
|
|
175
186
|
Args:
|
|
176
|
-
texts (
|
|
187
|
+
texts (list[str]): List of text samples to encode.
|
|
177
188
|
batch (int): Batch size for processing.
|
|
178
189
|
cache_dir (Path): Directory to save/load cached embeddings.
|
|
179
190
|
|
|
@@ -184,42 +195,49 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|
|
184
195
|
cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
|
|
185
196
|
if cache_path.exists():
|
|
186
197
|
LOGGER.info(f"Reading existed cache from '{cache_path}'")
|
|
187
|
-
txt_map = torch.load(cache_path)
|
|
198
|
+
txt_map = torch.load(cache_path, map_location=self.device)
|
|
188
199
|
if sorted(txt_map.keys()) == sorted(texts):
|
|
189
200
|
return txt_map
|
|
190
201
|
LOGGER.info(f"Caching text embeddings to '{cache_path}'")
|
|
191
202
|
assert self.model is not None
|
|
192
|
-
txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
|
|
203
|
+
txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
|
|
193
204
|
txt_map = dict(zip(texts, txt_feats.squeeze(0)))
|
|
194
205
|
torch.save(txt_map, cache_path)
|
|
195
206
|
return txt_map
|
|
196
207
|
|
|
197
208
|
|
|
198
209
|
class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
|
199
|
-
"""Train prompt-free YOLOE model.
|
|
210
|
+
"""Train prompt-free YOLOE model.
|
|
211
|
+
|
|
212
|
+
This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
|
|
213
|
+
require text prompts during inference.
|
|
214
|
+
|
|
215
|
+
Methods:
|
|
216
|
+
get_validator: Return standard DetectionValidator for validation.
|
|
217
|
+
preprocess_batch: Preprocess batches without text features.
|
|
218
|
+
set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).
|
|
219
|
+
"""
|
|
200
220
|
|
|
201
221
|
def get_validator(self):
|
|
202
|
-
"""
|
|
222
|
+
"""Return a DetectionValidator for YOLO model validation."""
|
|
203
223
|
self.loss_names = "box", "cls", "dfl"
|
|
204
224
|
return DetectionValidator(
|
|
205
225
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
|
206
226
|
)
|
|
207
227
|
|
|
208
228
|
def preprocess_batch(self, batch):
|
|
209
|
-
"""
|
|
210
|
-
|
|
211
|
-
return batch
|
|
229
|
+
"""Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
|
|
230
|
+
return DetectionTrainer.preprocess_batch(self, batch)
|
|
212
231
|
|
|
213
|
-
def set_text_embeddings(self, datasets, batch):
|
|
214
|
-
"""
|
|
215
|
-
Set text embeddings for datasets to accelerate training by caching category names.
|
|
232
|
+
def set_text_embeddings(self, datasets, batch: int):
|
|
233
|
+
"""Set text embeddings for datasets to accelerate training by caching category names.
|
|
216
234
|
|
|
217
|
-
This method collects unique category names from all datasets, generates text embeddings for them,
|
|
218
|
-
|
|
219
|
-
|
|
235
|
+
This method collects unique category names from all datasets, generates text embeddings for them, and caches
|
|
236
|
+
these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
|
|
237
|
+
the first dataset's image path.
|
|
220
238
|
|
|
221
239
|
Args:
|
|
222
|
-
datasets (
|
|
240
|
+
datasets (list[Dataset]): List of datasets containing category names to process.
|
|
223
241
|
batch (int): Batch size for processing text embeddings.
|
|
224
242
|
|
|
225
243
|
Notes:
|
|
@@ -231,14 +249,20 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
|
|
231
249
|
|
|
232
250
|
|
|
233
251
|
class YOLOEVPTrainer(YOLOETrainerFromScratch):
|
|
234
|
-
"""Train YOLOE model with visual prompts.
|
|
252
|
+
"""Train YOLOE model with visual prompts.
|
|
235
253
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
254
|
+
This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
|
|
255
|
+
alongside images to guide the detection process.
|
|
256
|
+
|
|
257
|
+
Methods:
|
|
258
|
+
build_dataset: Build dataset with visual prompt loading transforms.
|
|
259
|
+
"""
|
|
260
|
+
|
|
261
|
+
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
|
|
262
|
+
"""Build YOLO Dataset for training or validation with visual prompts.
|
|
239
263
|
|
|
240
264
|
Args:
|
|
241
|
-
img_path (
|
|
265
|
+
img_path (list[str] | str): Path to the folder containing images or list of paths.
|
|
242
266
|
mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
|
|
243
267
|
batch (int, optional): Size of batches, used for rectangular training/validation.
|
|
244
268
|
|
|
@@ -261,9 +285,3 @@ class YOLOEVPTrainer(YOLOETrainerFromScratch):
|
|
|
261
285
|
d.transforms.append(LoadVisualPrompt())
|
|
262
286
|
else:
|
|
263
287
|
self.train_loader.dataset.transforms.append(LoadVisualPrompt())
|
|
264
|
-
|
|
265
|
-
def preprocess_batch(self, batch):
|
|
266
|
-
"""Preprocesses a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
|
|
267
|
-
batch = super().preprocess_batch(batch)
|
|
268
|
-
batch["visuals"] = batch["visuals"].to(self.device)
|
|
269
|
-
return batch
|
|
@@ -11,11 +11,10 @@ from .val import YOLOESegValidator
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
14
|
-
"""
|
|
15
|
-
Trainer class for YOLOE segmentation models.
|
|
14
|
+
"""Trainer class for YOLOE segmentation models.
|
|
16
15
|
|
|
17
|
-
This class combines YOLOETrainer and SegmentationTrainer to provide training functionality
|
|
18
|
-
|
|
16
|
+
This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
|
|
17
|
+
segmentation models, enabling both object detection and instance segmentation capabilities.
|
|
19
18
|
|
|
20
19
|
Attributes:
|
|
21
20
|
cfg (dict): Configuration dictionary with training parameters.
|
|
@@ -24,11 +23,10 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
|
24
23
|
"""
|
|
25
24
|
|
|
26
25
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
|
27
|
-
"""
|
|
28
|
-
Return YOLOESegModel initialized with specified config and weights.
|
|
26
|
+
"""Return YOLOESegModel initialized with specified config and weights.
|
|
29
27
|
|
|
30
28
|
Args:
|
|
31
|
-
cfg (dict | str): Model configuration dictionary or YAML file path.
|
|
29
|
+
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
|
|
32
30
|
weights (str, optional): Path to pretrained weights file.
|
|
33
31
|
verbose (bool): Whether to display model information.
|
|
34
32
|
|
|
@@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
|
49
47
|
return model
|
|
50
48
|
|
|
51
49
|
def get_validator(self):
|
|
52
|
-
"""
|
|
53
|
-
Create and return a validator for YOLOE segmentation model evaluation.
|
|
50
|
+
"""Create and return a validator for YOLOE segmentation model evaluation.
|
|
54
51
|
|
|
55
52
|
Returns:
|
|
56
53
|
(YOLOESegValidator): Validator for YOLOE segmentation models.
|
|
@@ -62,19 +59,20 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
|
62
59
|
|
|
63
60
|
|
|
64
61
|
class YOLOEPESegTrainer(SegmentationTrainer):
|
|
65
|
-
"""
|
|
66
|
-
Fine-tune YOLOESeg model in linear probing way.
|
|
62
|
+
"""Fine-tune YOLOESeg model in linear probing way.
|
|
67
63
|
|
|
68
64
|
This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
|
|
69
|
-
most of the model and only training specific layers.
|
|
65
|
+
most of the model and only training specific layers for efficient adaptation to new tasks.
|
|
66
|
+
|
|
67
|
+
Attributes:
|
|
68
|
+
data (dict): Dataset configuration containing channels, class names, and number of classes.
|
|
70
69
|
"""
|
|
71
70
|
|
|
72
71
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
|
73
|
-
"""
|
|
74
|
-
Return YOLOESegModel initialized with specified config and weights for linear probing.
|
|
72
|
+
"""Return YOLOESegModel initialized with specified config and weights for linear probing.
|
|
75
73
|
|
|
76
74
|
Args:
|
|
77
|
-
cfg (dict | str): Model configuration dictionary or YAML file path.
|
|
75
|
+
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
|
|
78
76
|
weights (str, optional): Path to pretrained weights file.
|
|
79
77
|
verbose (bool): Whether to display model information.
|
|
80
78
|
|
|
@@ -113,12 +111,12 @@ class YOLOEPESegTrainer(SegmentationTrainer):
|
|
|
113
111
|
|
|
114
112
|
|
|
115
113
|
class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
|
|
116
|
-
"""Trainer for YOLOE segmentation from scratch."""
|
|
114
|
+
"""Trainer for YOLOE segmentation models trained from scratch without pretrained weights."""
|
|
117
115
|
|
|
118
116
|
pass
|
|
119
117
|
|
|
120
118
|
|
|
121
119
|
class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
|
|
122
|
-
"""Trainer for YOLOE segmentation with VP."""
|
|
120
|
+
"""Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities."""
|
|
123
121
|
|
|
124
122
|
pass
|
|
@@ -1,6 +1,10 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
from copy import deepcopy
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
4
8
|
|
|
5
9
|
import torch
|
|
6
10
|
from torch.nn import functional as F
|
|
@@ -17,27 +21,39 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
|
|
17
21
|
|
|
18
22
|
|
|
19
23
|
class YOLOEDetectValidator(DetectionValidator):
|
|
20
|
-
"""
|
|
21
|
-
A mixin class for YOLOE model validation that handles both text and visual prompt embeddings.
|
|
24
|
+
"""A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
|
|
22
25
|
|
|
23
|
-
This
|
|
24
|
-
|
|
25
|
-
|
|
26
|
+
This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
|
|
27
|
+
validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
|
|
28
|
+
evaluation strategies for prompt-based object detection.
|
|
26
29
|
|
|
27
30
|
Attributes:
|
|
28
31
|
device (torch.device): The device on which validation is performed.
|
|
29
32
|
args (namespace): Configuration arguments for validation.
|
|
30
33
|
dataloader (DataLoader): DataLoader for validation data.
|
|
34
|
+
|
|
35
|
+
Methods:
|
|
36
|
+
get_visual_pe: Extract visual prompt embeddings from training samples.
|
|
37
|
+
preprocess: Preprocess batch data ensuring visuals are on the same device as images.
|
|
38
|
+
get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
|
|
39
|
+
__call__: Run validation using either text or visual prompt embeddings.
|
|
40
|
+
|
|
41
|
+
Examples:
|
|
42
|
+
Validate with text prompts
|
|
43
|
+
>>> validator = YOLOEDetectValidator()
|
|
44
|
+
>>> stats = validator(model=model, load_vp=False)
|
|
45
|
+
|
|
46
|
+
Validate with visual prompts
|
|
47
|
+
>>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
|
|
31
48
|
"""
|
|
32
49
|
|
|
33
50
|
@smart_inference_mode()
|
|
34
|
-
def get_visual_pe(self, dataloader, model):
|
|
35
|
-
"""
|
|
36
|
-
Extract visual prompt embeddings from training samples.
|
|
51
|
+
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
|
|
52
|
+
"""Extract visual prompt embeddings from training samples.
|
|
37
53
|
|
|
38
|
-
This
|
|
39
|
-
|
|
40
|
-
|
|
54
|
+
This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
|
|
55
|
+
normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
|
|
56
|
+
zero.
|
|
41
57
|
|
|
42
58
|
Args:
|
|
43
59
|
dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
|
|
@@ -47,12 +63,13 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
47
63
|
(torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
|
|
48
64
|
"""
|
|
49
65
|
assert isinstance(model, YOLOEModel)
|
|
50
|
-
names = [name.split("/")[0] for name in list(dataloader.dataset.data["names"].values())]
|
|
66
|
+
names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
|
|
51
67
|
visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
|
|
52
68
|
cls_visual_num = torch.zeros(len(names))
|
|
53
69
|
|
|
54
70
|
desc = "Get visual prompt embeddings from samples"
|
|
55
71
|
|
|
72
|
+
# Count samples per class
|
|
56
73
|
for batch in dataloader:
|
|
57
74
|
cls = batch["cls"].squeeze(-1).to(torch.int).unique()
|
|
58
75
|
count = torch.bincount(cls, minlength=len(names))
|
|
@@ -60,6 +77,7 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
60
77
|
|
|
61
78
|
cls_visual_num = cls_visual_num.to(self.device)
|
|
62
79
|
|
|
80
|
+
# Extract visual prompt embeddings
|
|
63
81
|
pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
|
|
64
82
|
for batch in pbar:
|
|
65
83
|
batch = self.preprocess(batch)
|
|
@@ -69,34 +87,26 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
69
87
|
for i in range(preds.shape[0]):
|
|
70
88
|
cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
|
|
71
89
|
pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
|
|
72
|
-
pad_cls[:
|
|
90
|
+
pad_cls[: cls.shape[0]] = cls
|
|
73
91
|
for c in cls:
|
|
74
92
|
visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
|
|
75
93
|
|
|
94
|
+
# Normalize embeddings for classes with samples, set others to zero
|
|
76
95
|
visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
|
|
77
96
|
visual_pe[cls_visual_num == 0] = 0
|
|
78
97
|
return visual_pe.unsqueeze(0)
|
|
79
98
|
|
|
80
|
-
def
|
|
81
|
-
"""
|
|
82
|
-
batch = super().preprocess(batch)
|
|
83
|
-
if "visuals" in batch:
|
|
84
|
-
batch["visuals"] = batch["visuals"].to(batch["img"].device)
|
|
85
|
-
return batch
|
|
86
|
-
|
|
87
|
-
def get_vpe_dataloader(self, data):
|
|
88
|
-
"""
|
|
89
|
-
Create a dataloader for LVIS training visual prompt samples.
|
|
99
|
+
def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
|
|
100
|
+
"""Create a dataloader for LVIS training visual prompt samples.
|
|
90
101
|
|
|
91
|
-
This
|
|
92
|
-
|
|
93
|
-
for validation purposes.
|
|
102
|
+
This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
|
|
103
|
+
necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
|
|
94
104
|
|
|
95
105
|
Args:
|
|
96
106
|
data (dict): Dataset configuration dictionary containing paths and settings.
|
|
97
107
|
|
|
98
108
|
Returns:
|
|
99
|
-
(torch.utils.data.DataLoader): The
|
|
109
|
+
(torch.utils.data.DataLoader): The dataloader for visual prompt samples.
|
|
100
110
|
"""
|
|
101
111
|
dataset = build_yolo_dataset(
|
|
102
112
|
self.args,
|
|
@@ -120,17 +130,22 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
120
130
|
)
|
|
121
131
|
|
|
122
132
|
@smart_inference_mode()
|
|
123
|
-
def __call__(
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
133
|
+
def __call__(
|
|
134
|
+
self,
|
|
135
|
+
trainer: Any | None = None,
|
|
136
|
+
model: YOLOEModel | str | None = None,
|
|
137
|
+
refer_data: str | None = None,
|
|
138
|
+
load_vp: bool = False,
|
|
139
|
+
) -> dict[str, Any]:
|
|
140
|
+
"""Run validation on the model using either text or visual prompt embeddings.
|
|
141
|
+
|
|
142
|
+
This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
|
|
143
|
+
supports validation during training (using a trainer object) or standalone validation with a provided model. For
|
|
144
|
+
visual prompts, reference data can be specified to extract embeddings from a different dataset.
|
|
130
145
|
|
|
131
146
|
Args:
|
|
132
147
|
trainer (object, optional): Trainer object containing the model and device.
|
|
133
|
-
model (YOLOEModel, optional): Model to validate. Required if
|
|
148
|
+
model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
|
|
134
149
|
refer_data (str, optional): Path to reference data for visual prompts.
|
|
135
150
|
load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
|
|
136
151
|
|
|
@@ -140,7 +155,7 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
140
155
|
if trainer is not None:
|
|
141
156
|
self.device = trainer.device
|
|
142
157
|
model = trainer.ema.ema
|
|
143
|
-
names = [name.split("/")[0] for name in list(self.dataloader.dataset.data["names"].values())]
|
|
158
|
+
names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]
|
|
144
159
|
|
|
145
160
|
if load_vp:
|
|
146
161
|
LOGGER.info("Validate using the visual prompt.")
|
|
@@ -156,15 +171,15 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
156
171
|
else:
|
|
157
172
|
if refer_data is not None:
|
|
158
173
|
assert load_vp, "Refer data is only used for visual prompt validation."
|
|
159
|
-
self.device = select_device(self.args.device)
|
|
174
|
+
self.device = select_device(self.args.device, verbose=False)
|
|
160
175
|
|
|
161
|
-
if isinstance(model, str):
|
|
162
|
-
from ultralytics.nn.tasks import
|
|
176
|
+
if isinstance(model, (str, Path)):
|
|
177
|
+
from ultralytics.nn.tasks import load_checkpoint
|
|
163
178
|
|
|
164
|
-
model =
|
|
179
|
+
model, _ = load_checkpoint(model, device=self.device) # model, ckpt
|
|
165
180
|
model.eval().to(self.device)
|
|
166
181
|
data = check_det_dataset(refer_data or self.args.data)
|
|
167
|
-
names = [name.split("/")[0] for name in list(data["names"].values())]
|
|
182
|
+
names = [name.split("/", 1)[0] for name in list(data["names"].values())]
|
|
168
183
|
|
|
169
184
|
if load_vp:
|
|
170
185
|
LOGGER.info("Validate using the visual prompt.")
|
ultralytics/nn/__init__.py
CHANGED
|
@@ -5,25 +5,23 @@ from .tasks import (
|
|
|
5
5
|
ClassificationModel,
|
|
6
6
|
DetectionModel,
|
|
7
7
|
SegmentationModel,
|
|
8
|
-
attempt_load_one_weight,
|
|
9
|
-
attempt_load_weights,
|
|
10
8
|
guess_model_scale,
|
|
11
9
|
guess_model_task,
|
|
10
|
+
load_checkpoint,
|
|
12
11
|
parse_model,
|
|
13
12
|
torch_safe_load,
|
|
14
13
|
yaml_model_load,
|
|
15
14
|
)
|
|
16
15
|
|
|
17
16
|
__all__ = (
|
|
18
|
-
"
|
|
19
|
-
"
|
|
20
|
-
"parse_model",
|
|
21
|
-
"yaml_model_load",
|
|
22
|
-
"guess_model_task",
|
|
23
|
-
"guess_model_scale",
|
|
24
|
-
"torch_safe_load",
|
|
17
|
+
"BaseModel",
|
|
18
|
+
"ClassificationModel",
|
|
25
19
|
"DetectionModel",
|
|
26
20
|
"SegmentationModel",
|
|
27
|
-
"
|
|
28
|
-
"
|
|
21
|
+
"guess_model_scale",
|
|
22
|
+
"guess_model_task",
|
|
23
|
+
"load_checkpoint",
|
|
24
|
+
"parse_model",
|
|
25
|
+
"torch_safe_load",
|
|
26
|
+
"yaml_model_load",
|
|
29
27
|
)
|