dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import itertools
|
|
6
5
|
from copy import copy, deepcopy
|
|
7
6
|
from pathlib import Path
|
|
8
7
|
|
|
@@ -20,11 +19,10 @@ from .val import YOLOEDetectValidator
|
|
|
20
19
|
|
|
21
20
|
|
|
22
21
|
class YOLOETrainer(DetectionTrainer):
|
|
23
|
-
"""
|
|
24
|
-
A trainer class for YOLOE object detection models.
|
|
22
|
+
"""A trainer class for YOLOE object detection models.
|
|
25
23
|
|
|
26
|
-
This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
|
|
27
|
-
|
|
24
|
+
This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
|
|
25
|
+
model initialization, validation, and dataset building with multi-modal support.
|
|
28
26
|
|
|
29
27
|
Attributes:
|
|
30
28
|
loss_names (tuple): Names of loss components used during training.
|
|
@@ -36,8 +34,7 @@ class YOLOETrainer(DetectionTrainer):
|
|
|
36
34
|
"""
|
|
37
35
|
|
|
38
36
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
|
|
39
|
-
"""
|
|
40
|
-
Initialize the YOLOE Trainer with specified configurations.
|
|
37
|
+
"""Initialize the YOLOE Trainer with specified configurations.
|
|
41
38
|
|
|
42
39
|
Args:
|
|
43
40
|
cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
|
|
@@ -46,16 +43,16 @@ class YOLOETrainer(DetectionTrainer):
|
|
|
46
43
|
"""
|
|
47
44
|
if overrides is None:
|
|
48
45
|
overrides = {}
|
|
46
|
+
assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
|
|
49
47
|
overrides["overlap_mask"] = False
|
|
50
48
|
super().__init__(cfg, overrides, _callbacks)
|
|
51
49
|
|
|
52
50
|
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
|
53
|
-
"""
|
|
54
|
-
Return a YOLOEModel initialized with the specified configuration and weights.
|
|
51
|
+
"""Return a YOLOEModel initialized with the specified configuration and weights.
|
|
55
52
|
|
|
56
53
|
Args:
|
|
57
|
-
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
|
|
58
|
-
|
|
54
|
+
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
|
|
55
|
+
path to a YAML file, or None to use default configuration.
|
|
59
56
|
weights (str | Path, optional): Path to pretrained weights file to load into the model.
|
|
60
57
|
verbose (bool): Whether to display model information during initialization.
|
|
61
58
|
|
|
@@ -88,8 +85,7 @@ class YOLOETrainer(DetectionTrainer):
|
|
|
88
85
|
)
|
|
89
86
|
|
|
90
87
|
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
|
91
|
-
"""
|
|
92
|
-
Build YOLO Dataset.
|
|
88
|
+
"""Build YOLO Dataset.
|
|
93
89
|
|
|
94
90
|
Args:
|
|
95
91
|
img_path (str): Path to the folder containing images.
|
|
@@ -106,19 +102,17 @@ class YOLOETrainer(DetectionTrainer):
|
|
|
106
102
|
|
|
107
103
|
|
|
108
104
|
class YOLOEPETrainer(DetectionTrainer):
|
|
109
|
-
"""
|
|
110
|
-
Fine-tune YOLOE model using linear probing approach.
|
|
105
|
+
"""Fine-tune YOLOE model using linear probing approach.
|
|
111
106
|
|
|
112
|
-
This trainer freezes most model layers and only trains specific projection layers for efficient
|
|
113
|
-
|
|
107
|
+
This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
|
|
108
|
+
datasets while preserving pretrained features.
|
|
114
109
|
|
|
115
110
|
Methods:
|
|
116
111
|
get_model: Initialize YOLOEModel with frozen layers except projection layers.
|
|
117
112
|
"""
|
|
118
113
|
|
|
119
114
|
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
|
120
|
-
"""
|
|
121
|
-
Return YOLOEModel initialized with specified config and weights.
|
|
115
|
+
"""Return YOLOEModel initialized with specified config and weights.
|
|
122
116
|
|
|
123
117
|
Args:
|
|
124
118
|
cfg (dict | str, optional): Model configuration.
|
|
@@ -160,24 +154,21 @@ class YOLOEPETrainer(DetectionTrainer):
|
|
|
160
154
|
|
|
161
155
|
|
|
162
156
|
class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|
163
|
-
"""
|
|
164
|
-
Train YOLOE models from scratch with text embedding support.
|
|
157
|
+
"""Train YOLOE models from scratch with text embedding support.
|
|
165
158
|
|
|
166
|
-
This trainer combines YOLOE training capabilities with world training features, enabling
|
|
167
|
-
|
|
159
|
+
This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
|
|
160
|
+
text embeddings and grounding datasets.
|
|
168
161
|
|
|
169
162
|
Methods:
|
|
170
163
|
build_dataset: Build datasets for training with grounding support.
|
|
171
|
-
preprocess_batch: Process batches with text features.
|
|
172
164
|
generate_text_embeddings: Generate and cache text embeddings for training.
|
|
173
165
|
"""
|
|
174
166
|
|
|
175
167
|
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
|
|
176
|
-
"""
|
|
177
|
-
Build YOLO Dataset for training or validation.
|
|
168
|
+
"""Build YOLO Dataset for training or validation.
|
|
178
169
|
|
|
179
|
-
This method constructs appropriate datasets based on the mode and input paths, handling both
|
|
180
|
-
|
|
170
|
+
This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
|
|
171
|
+
datasets and grounding datasets with different formats.
|
|
181
172
|
|
|
182
173
|
Args:
|
|
183
174
|
img_path (list[str] | str): Path to the folder containing images or list of paths.
|
|
@@ -189,19 +180,8 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|
|
189
180
|
"""
|
|
190
181
|
return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
|
|
191
182
|
|
|
192
|
-
def preprocess_batch(self, batch):
|
|
193
|
-
"""Process batch for training, moving text features to the appropriate device."""
|
|
194
|
-
batch = DetectionTrainer.preprocess_batch(self, batch)
|
|
195
|
-
|
|
196
|
-
texts = list(itertools.chain(*batch["texts"]))
|
|
197
|
-
txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device, non_blocking=True)
|
|
198
|
-
txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
|
|
199
|
-
batch["txt_feats"] = txt_feats
|
|
200
|
-
return batch
|
|
201
|
-
|
|
202
183
|
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
|
|
203
|
-
"""
|
|
204
|
-
Generate text embeddings for a list of text samples.
|
|
184
|
+
"""Generate text embeddings for a list of text samples.
|
|
205
185
|
|
|
206
186
|
Args:
|
|
207
187
|
texts (list[str]): List of text samples to encode.
|
|
@@ -227,11 +207,10 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
|
|
|
227
207
|
|
|
228
208
|
|
|
229
209
|
class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
|
230
|
-
"""
|
|
231
|
-
Train prompt-free YOLOE model.
|
|
210
|
+
"""Train prompt-free YOLOE model.
|
|
232
211
|
|
|
233
|
-
This trainer combines linear probing capabilities with from-scratch training for prompt-free
|
|
234
|
-
|
|
212
|
+
This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
|
|
213
|
+
require text prompts during inference.
|
|
235
214
|
|
|
236
215
|
Methods:
|
|
237
216
|
get_validator: Return standard DetectionValidator for validation.
|
|
@@ -251,12 +230,11 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
|
|
251
230
|
return DetectionTrainer.preprocess_batch(self, batch)
|
|
252
231
|
|
|
253
232
|
def set_text_embeddings(self, datasets, batch: int):
|
|
254
|
-
"""
|
|
255
|
-
Set text embeddings for datasets to accelerate training by caching category names.
|
|
233
|
+
"""Set text embeddings for datasets to accelerate training by caching category names.
|
|
256
234
|
|
|
257
|
-
This method collects unique category names from all datasets, generates text embeddings for them,
|
|
258
|
-
|
|
259
|
-
|
|
235
|
+
This method collects unique category names from all datasets, generates text embeddings for them, and caches
|
|
236
|
+
these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
|
|
237
|
+
the first dataset's image path.
|
|
260
238
|
|
|
261
239
|
Args:
|
|
262
240
|
datasets (list[Dataset]): List of datasets containing category names to process.
|
|
@@ -271,20 +249,17 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
|
|
|
271
249
|
|
|
272
250
|
|
|
273
251
|
class YOLOEVPTrainer(YOLOETrainerFromScratch):
|
|
274
|
-
"""
|
|
275
|
-
Train YOLOE model with visual prompts.
|
|
252
|
+
"""Train YOLOE model with visual prompts.
|
|
276
253
|
|
|
277
|
-
This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
|
|
278
|
-
|
|
254
|
+
This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
|
|
255
|
+
alongside images to guide the detection process.
|
|
279
256
|
|
|
280
257
|
Methods:
|
|
281
258
|
build_dataset: Build dataset with visual prompt loading transforms.
|
|
282
|
-
preprocess_batch: Preprocess batches with visual prompts.
|
|
283
259
|
"""
|
|
284
260
|
|
|
285
261
|
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
|
|
286
|
-
"""
|
|
287
|
-
Build YOLO Dataset for training or validation with visual prompts.
|
|
262
|
+
"""Build YOLO Dataset for training or validation with visual prompts.
|
|
288
263
|
|
|
289
264
|
Args:
|
|
290
265
|
img_path (list[str] | str): Path to the folder containing images or list of paths.
|
|
@@ -11,8 +11,7 @@ from .val import YOLOESegValidator
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
14
|
-
"""
|
|
15
|
-
Trainer class for YOLOE segmentation models.
|
|
14
|
+
"""Trainer class for YOLOE segmentation models.
|
|
16
15
|
|
|
17
16
|
This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
|
|
18
17
|
segmentation models, enabling both object detection and instance segmentation capabilities.
|
|
@@ -24,8 +23,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
|
24
23
|
"""
|
|
25
24
|
|
|
26
25
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
|
27
|
-
"""
|
|
28
|
-
Return YOLOESegModel initialized with specified config and weights.
|
|
26
|
+
"""Return YOLOESegModel initialized with specified config and weights.
|
|
29
27
|
|
|
30
28
|
Args:
|
|
31
29
|
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
|
|
@@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
|
49
47
|
return model
|
|
50
48
|
|
|
51
49
|
def get_validator(self):
|
|
52
|
-
"""
|
|
53
|
-
Create and return a validator for YOLOE segmentation model evaluation.
|
|
50
|
+
"""Create and return a validator for YOLOE segmentation model evaluation.
|
|
54
51
|
|
|
55
52
|
Returns:
|
|
56
53
|
(YOLOESegValidator): Validator for YOLOE segmentation models.
|
|
@@ -62,8 +59,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
|
|
|
62
59
|
|
|
63
60
|
|
|
64
61
|
class YOLOEPESegTrainer(SegmentationTrainer):
|
|
65
|
-
"""
|
|
66
|
-
Fine-tune YOLOESeg model in linear probing way.
|
|
62
|
+
"""Fine-tune YOLOESeg model in linear probing way.
|
|
67
63
|
|
|
68
64
|
This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
|
|
69
65
|
most of the model and only training specific layers for efficient adaptation to new tasks.
|
|
@@ -73,8 +69,7 @@ class YOLOEPESegTrainer(SegmentationTrainer):
|
|
|
73
69
|
"""
|
|
74
70
|
|
|
75
71
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
|
76
|
-
"""
|
|
77
|
-
Return YOLOESegModel initialized with specified config and weights for linear probing.
|
|
72
|
+
"""Return YOLOESegModel initialized with specified config and weights for linear probing.
|
|
78
73
|
|
|
79
74
|
Args:
|
|
80
75
|
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
|
|
@@ -21,12 +21,11 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class YOLOEDetectValidator(DetectionValidator):
|
|
24
|
-
"""
|
|
25
|
-
A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
|
|
24
|
+
"""A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
|
|
26
25
|
|
|
27
|
-
This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
|
|
28
|
-
|
|
29
|
-
|
|
26
|
+
This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
|
|
27
|
+
validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
|
|
28
|
+
evaluation strategies for prompt-based object detection.
|
|
30
29
|
|
|
31
30
|
Attributes:
|
|
32
31
|
device (torch.device): The device on which validation is performed.
|
|
@@ -50,12 +49,11 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
50
49
|
|
|
51
50
|
@smart_inference_mode()
|
|
52
51
|
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
|
|
53
|
-
"""
|
|
54
|
-
Extract visual prompt embeddings from training samples.
|
|
52
|
+
"""Extract visual prompt embeddings from training samples.
|
|
55
53
|
|
|
56
|
-
This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
|
|
57
|
-
|
|
58
|
-
|
|
54
|
+
This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
|
|
55
|
+
normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
|
|
56
|
+
zero.
|
|
59
57
|
|
|
60
58
|
Args:
|
|
61
59
|
dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
|
|
@@ -89,7 +87,7 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
89
87
|
for i in range(preds.shape[0]):
|
|
90
88
|
cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
|
|
91
89
|
pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
|
|
92
|
-
pad_cls[:
|
|
90
|
+
pad_cls[: cls.shape[0]] = cls
|
|
93
91
|
for c in cls:
|
|
94
92
|
visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
|
|
95
93
|
|
|
@@ -99,12 +97,10 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
99
97
|
return visual_pe.unsqueeze(0)
|
|
100
98
|
|
|
101
99
|
def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
|
|
102
|
-
"""
|
|
103
|
-
Create a dataloader for LVIS training visual prompt samples.
|
|
100
|
+
"""Create a dataloader for LVIS training visual prompt samples.
|
|
104
101
|
|
|
105
|
-
This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
|
|
106
|
-
|
|
107
|
-
for validation purposes.
|
|
102
|
+
This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
|
|
103
|
+
necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
|
|
108
104
|
|
|
109
105
|
Args:
|
|
110
106
|
data (dict): Dataset configuration dictionary containing paths and settings.
|
|
@@ -141,12 +137,11 @@ class YOLOEDetectValidator(DetectionValidator):
|
|
|
141
137
|
refer_data: str | None = None,
|
|
142
138
|
load_vp: bool = False,
|
|
143
139
|
) -> dict[str, Any]:
|
|
144
|
-
"""
|
|
145
|
-
Run validation on the model using either text or visual prompt embeddings.
|
|
140
|
+
"""Run validation on the model using either text or visual prompt embeddings.
|
|
146
141
|
|
|
147
|
-
This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
|
|
148
|
-
|
|
149
|
-
|
|
142
|
+
This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
|
|
143
|
+
supports validation during training (using a trainer object) or standalone validation with a provided model. For
|
|
144
|
+
visual prompts, reference data can be specified to extract embeddings from a different dataset.
|
|
150
145
|
|
|
151
146
|
Args:
|
|
152
147
|
trainer (object, optional): Trainer object containing the model and device.
|
ultralytics/nn/__init__.py
CHANGED
|
@@ -14,14 +14,14 @@ from .tasks import (
|
|
|
14
14
|
)
|
|
15
15
|
|
|
16
16
|
__all__ = (
|
|
17
|
+
"BaseModel",
|
|
18
|
+
"ClassificationModel",
|
|
19
|
+
"DetectionModel",
|
|
20
|
+
"SegmentationModel",
|
|
21
|
+
"guess_model_scale",
|
|
22
|
+
"guess_model_task",
|
|
17
23
|
"load_checkpoint",
|
|
18
24
|
"parse_model",
|
|
19
|
-
"yaml_model_load",
|
|
20
|
-
"guess_model_task",
|
|
21
|
-
"guess_model_scale",
|
|
22
25
|
"torch_safe_load",
|
|
23
|
-
"
|
|
24
|
-
"SegmentationModel",
|
|
25
|
-
"ClassificationModel",
|
|
26
|
-
"BaseModel",
|
|
26
|
+
"yaml_model_load",
|
|
27
27
|
)
|