dgenerate-ultralytics-headless 8.3.253 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1028 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1580 -0
- ultralytics/engine/model.py +1125 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +977 -0
- ultralytics/engine/tuner.py +449 -0
- ultralytics/engine/validator.py +387 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +964 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +453 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1020 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +506 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1047 -0
- ultralytics/utils/tal.py +404 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +443 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +168 -0
ultralytics/models/yolo/yoloe/__init__.py (added, +22 lines)

```diff
@@ -0,0 +1,22 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor
+from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
+from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer
+from .val import YOLOEDetectValidator, YOLOESegValidator
+
+__all__ = [
+    "YOLOEDetectValidator",
+    "YOLOEPEFreeTrainer",
+    "YOLOEPESegTrainer",
+    "YOLOEPETrainer",
+    "YOLOESegTrainer",
+    "YOLOESegTrainerFromScratch",
+    "YOLOESegVPTrainer",
+    "YOLOESegValidator",
+    "YOLOETrainer",
+    "YOLOETrainerFromScratch",
+    "YOLOEVPDetectPredictor",
+    "YOLOEVPSegPredictor",
+    "YOLOEVPTrainer",
+]
```
ultralytics/models/yolo/yoloe/predict.py (added, +162 lines)

```diff
@@ -0,0 +1,162 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import numpy as np
+import torch
+
+from ultralytics.data.augment import LoadVisualPrompt
+from ultralytics.models.yolo.detect import DetectionPredictor
+from ultralytics.models.yolo.segment import SegmentationPredictor
+
+
+class YOLOEVPDetectPredictor(DetectionPredictor):
+    """A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+
+    This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
+    handling, and preprocessing transformations.
+
+    Attributes:
+        model (torch.nn.Module): The YOLO model for inference.
+        device (torch.device): Device to run the model on (CPU or CUDA).
+        prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.
+
+    Methods:
+        setup_model: Initialize the YOLO model and set it to evaluation mode.
+        set_prompts: Set the visual prompts for the model.
+        pre_transform: Preprocess images and prompts before inference.
+        inference: Run inference with visual prompts.
+        get_vpe: Process source to get visual prompt embeddings.
+    """
+
+    def setup_model(self, model, verbose: bool = True):
+        """Set up the model for prediction.
+
+        Args:
+            model (torch.nn.Module): Model to load or use.
+            verbose (bool, optional): If True, provides detailed logging.
+        """
+        super().setup_model(model, verbose=verbose)
+        self.done_warmup = True
+
+    def set_prompts(self, prompts):
+        """Set the visual prompts for the model.
+
+        Args:
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
+                with class indices.
+        """
+        self.prompts = prompts
+
+    def pre_transform(self, im):
+        """Preprocess images and prompts before inference.
+
+        This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
+        accordingly.
+
+        Args:
+            im (list): List containing a single input image.
+
+        Returns:
+            (list): Preprocessed image ready for model inference.
+
+        Raises:
+            ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
+        """
+        img = super().pre_transform(im)
+        bboxes = self.prompts.pop("bboxes", None)
+        masks = self.prompts.pop("masks", None)
+        category = self.prompts["cls"]
+        if len(img) == 1:
+            visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
+            prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
+        else:
+            # NOTE: only supports bboxes as prompts for now
+            assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
+            # NOTE: needs list[np.ndarray]
+            assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
+                f"Expected list[np.ndarray], but got {bboxes}!"
+            )
+            assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
+                f"Expected list[np.ndarray], but got {category}!"
+            )
+            assert len(im) == len(category) == len(bboxes), (
+                f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
+            )
+            visuals = [
+                self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
+                for i in range(len(img))
+            ]
+            prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)  # (B, N, H, W)
+        self.prompts = prompts.half() if self.model.fp16 else prompts.float()
+        return img
+
+    def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
+        """Process a single image by resizing bounding boxes or masks and generating visuals.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) of the image.
+            src_shape (tuple): The original shape (height, width) of the image.
+            category (str): The category of the image for visual prompts.
+            bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2].
+            masks (np.ndarray, optional): A list of masks corresponding to the image.
+
+        Returns:
+            (torch.Tensor): The processed visuals for the image.
+
+        Raises:
+            ValueError: If neither `bboxes` nor `masks` are provided.
+        """
+        if bboxes is not None and len(bboxes):
+            bboxes = np.array(bboxes, dtype=np.float32)
+            if bboxes.ndim == 1:
+                bboxes = bboxes[None, :]
+            # Calculate scaling factor and adjust bounding boxes
+            gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # gain = old / new
+            bboxes *= gain
+            bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
+            bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
+        elif masks is not None:
+            # Resize and process masks
+            resized_masks = super().pre_transform(masks)
+            masks = np.stack(resized_masks)  # (N, H, W)
+            masks[masks == 114] = 0  # Reset padding values to 0
+        else:
+            raise ValueError("Please provide valid bboxes or masks")
+
+        # Generate visuals using the visual prompt loader
+        return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
+
+    def inference(self, im, *args, **kwargs):
+        """Run inference with visual prompts.
+
+        Args:
+            im (torch.Tensor): Input image tensor.
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+
+        Returns:
+            (torch.Tensor): Model prediction results.
+        """
+        return super().inference(im, vpe=self.prompts, *args, **kwargs)
+
+    def get_vpe(self, source):
+        """Process the source to get the visual prompt embeddings (VPE).
+
+        Args:
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
+                make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+                torch tensors.
+
+        Returns:
+            (torch.Tensor): The visual prompt embeddings (VPE) from the model.
+        """
+        self.setup_source(source)
+        assert len(self.dataset) == 1, "get_vpe only supports one image!"
+        for _, im0s, _ in self.dataset:
+            im = self.preprocess(im0s)
+            return self.model(im, vpe=self.prompts, return_vpe=True)
+
+
+class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
+    """Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities."""
+
+    pass
```
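For orientation, here is a minimal usage sketch of the visual-prompt predictor added above. It assumes the upstream Ultralytics `YOLOE` wrapper and its `visual_prompts=`/`predictor=` predict arguments as documented upstream; the checkpoint name and box coordinates are illustrative and not part of this diff.

```python
# Hedged sketch (not part of the package diff): prompt YOLOE with one reference box.
import numpy as np

from ultralytics import YOLOE  # assumes the upstream wrapper class
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

model = YOLOE("yoloe-11s-seg.pt")  # illustrative checkpoint name
results = model.predict(
    "ultralytics/assets/bus.jpg",
    visual_prompts=dict(
        bboxes=np.array([[221.5, 405.8, 344.9, 857.5]]),  # reference box in (x1, y1, x2, y2) pixels
        cls=np.array([0]),  # prompt class index assigned to that box
    ),
    predictor=YOLOEVPSegPredictor,  # routes prediction through the class defined above
)
results[0].show()
```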
ultralytics/models/yolo/yoloe/train.py (added, +287 lines)

```diff
@@ -0,0 +1,287 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from copy import copy, deepcopy
+from pathlib import Path
+
+import torch
+
+from ultralytics.data import YOLOConcatDataset, build_yolo_dataset
+from ultralytics.data.augment import LoadVisualPrompt
+from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
+from ultralytics.nn.tasks import YOLOEModel
+from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+from ultralytics.utils.torch_utils import unwrap_model
+
+from ..world.train_world import WorldTrainerFromScratch
+from .val import YOLOEDetectValidator
+
+
+class YOLOETrainer(DetectionTrainer):
+    """A trainer class for YOLOE object detection models.
+
+    This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
+    model initialization, validation, and dataset building with multi-modal support.
+
+    Attributes:
+        loss_names (tuple): Names of loss components used during training.
+
+    Methods:
+        get_model: Initialize and return a YOLOEModel with specified configuration.
+        get_validator: Return a YOLOEDetectValidator for model validation.
+        build_dataset: Build YOLO dataset with multi-modal support for training.
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
+        """Initialize the YOLOE Trainer with specified configurations.
+
+        Args:
+            cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
+            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list, optional): List of callback functions to be applied during training.
+        """
+        if overrides is None:
+            overrides = {}
+        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
+        overrides["overlap_mask"] = False
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
+        """Return a YOLOEModel initialized with the specified configuration and weights.
+
+        Args:
+            cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
+                path to a YAML file, or None to use default configuration.
+            weights (str | Path, optional): Path to pretrained weights file to load into the model.
+            verbose (bool): Whether to display model information during initialization.
+
+        Returns:
+            (YOLOEModel): The initialized YOLOE model.
+
+        Notes:
+            - The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.
+            - The nc parameter here represents the maximum number of different text samples in one image,
+              rather than the actual number of classes.
+        """
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = YOLOEModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=self.data["channels"],
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """Return a YOLOEDetectValidator for YOLOE model validation."""
+        self.loss_names = "box", "cls", "dfl"
+        return YOLOEDetectValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
+        """Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for rectangular training.
+
+        Returns:
+            (Dataset): YOLO dataset configured for training or validation.
+        """
+        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
+        return build_yolo_dataset(
+            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+        )
+
+
+class YOLOEPETrainer(DetectionTrainer):
+    """Fine-tune YOLOE model using linear probing approach.
+
+    This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
+    datasets while preserving pretrained features.
+
+    Methods:
+        get_model: Initialize YOLOEModel with frozen layers except projection layers.
+    """
+
+    def get_model(self, cfg=None, weights=None, verbose: bool = True):
+        """Return YOLOEModel initialized with specified config and weights.
+
+        Args:
+            cfg (dict | str, optional): Model configuration.
+            weights (str, optional): Path to pretrained weights.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (YOLOEModel): Initialized model with frozen layers except for specific projection layers.
+        """
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = YOLOEModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=self.data["channels"],
+            nc=self.data["nc"],
+            verbose=verbose and RANK == -1,
+        )
+
+        del model.model[-1].savpe
+
+        assert weights is not None, "Pretrained weights must be provided for linear probing."
+        if weights:
+            model.load(weights)
+
+        model.eval()
+        names = list(self.data["names"].values())
+        # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,
+        # it'd get correct results as long as loading proper pretrained weights.
+        tpe = model.get_text_pe(names)
+        model.set_classes(names, tpe)
+        model.model[-1].fuse(model.pe)  # fuse text embeddings to classify head
+        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
+        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
+        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
+        del model.pe
+        model.train()
+
+        return model
+
+
+class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
+    """Train YOLOE models from scratch with text embedding support.
+
+    This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
+    text embeddings and grounding datasets.
+
+    Methods:
+        build_dataset: Build datasets for training with grounding support.
+        generate_text_embeddings: Generate and cache text embeddings for training.
+    """
+
+    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
+        """Build YOLO Dataset for training or validation.
+
+        This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
+        datasets and grounding datasets with different formats.
+
+        Args:
+            img_path (list[str] | str): Path to the folder containing images or list of paths.
+            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+            batch (int, optional): Size of batches, used for rectangular training/validation.
+
+        Returns:
+            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
+        """
+        return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
+
+    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
+        """Generate text embeddings for a list of text samples.
+
+        Args:
+            texts (list[str]): List of text samples to encode.
+            batch (int): Batch size for processing.
+            cache_dir (Path): Directory to save/load cached embeddings.
+
+        Returns:
+            (dict): Dictionary mapping text samples to their embeddings.
+        """
+        model = "mobileclip:blt"
+        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
+        if cache_path.exists():
+            LOGGER.info(f"Reading existed cache from '{cache_path}'")
+            txt_map = torch.load(cache_path, map_location=self.device)
+            if sorted(txt_map.keys()) == sorted(texts):
+                return txt_map
+        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
+        assert self.model is not None
+        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
+        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
+        torch.save(txt_map, cache_path)
+        return txt_map
+
+
+class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
+    """Train prompt-free YOLOE model.
+
+    This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
+    require text prompts during inference.
+
+    Methods:
+        get_validator: Return standard DetectionValidator for validation.
+        preprocess_batch: Preprocess batches without text features.
+        set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).
+    """
+
+    def get_validator(self):
+        """Return a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box", "cls", "dfl"
+        return DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+    def preprocess_batch(self, batch):
+        """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
+        return DetectionTrainer.preprocess_batch(self, batch)
+
+    def set_text_embeddings(self, datasets, batch: int):
+        """Set text embeddings for datasets to accelerate training by caching category names.
+
+        This method collects unique category names from all datasets, generates text embeddings for them, and caches
+        these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
+        the first dataset's image path.
+
+        Args:
+            datasets (list[Dataset]): List of datasets containing category names to process.
+            batch (int): Batch size for processing text embeddings.
+
+        Notes:
+            The method creates a dictionary mapping text samples to their embeddings and stores it
+            at the path specified by 'cache_path'. If the cache file already exists, it will be loaded
+            instead of regenerating the embeddings.
+        """
+        pass
+
+
+class YOLOEVPTrainer(YOLOETrainerFromScratch):
+    """Train YOLOE model with visual prompts.
+
+    This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
+    alongside images to guide the detection process.
+
+    Methods:
+        build_dataset: Build dataset with visual prompt loading transforms.
+    """
+
+    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
+        """Build YOLO Dataset for training or validation with visual prompts.
+
+        Args:
+            img_path (list[str] | str): Path to the folder containing images or list of paths.
+            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+            batch (int, optional): Size of batches, used for rectangular training/validation.
+
+        Returns:
+            (Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.
+        """
+        dataset = super().build_dataset(img_path, mode, batch)
+        if isinstance(dataset, YOLOConcatDataset):
+            for d in dataset.datasets:
+                d.transforms.append(LoadVisualPrompt())
+        else:
+            dataset.transforms.append(LoadVisualPrompt())
+        return dataset
+
+    def _close_dataloader_mosaic(self):
+        """Close mosaic augmentation and add visual prompt loading to the training dataset."""
+        super()._close_dataloader_mosaic()
+        if isinstance(self.train_loader.dataset, YOLOConcatDataset):
+            for d in self.train_loader.dataset.datasets:
+                d.transforms.append(LoadVisualPrompt())
+        else:
+            self.train_loader.dataset.transforms.append(LoadVisualPrompt())
```
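As a usage note for the trainers above: `YOLOETrainer.__init__` rejects `compile=True`, and `YOLOEPETrainer.get_model` asserts that pretrained weights are supplied, since (per its docstring) linear probing keeps most layers frozen and only trains the cloned `cv3` projection layers. A minimal sketch, assuming the upstream `YOLOE` wrapper and `model.train(trainer=...)` API, with illustrative checkpoint and dataset names:

```python
# Hedged sketch (not part of the package diff): linear-probe fine-tuning.
from ultralytics import YOLOE  # assumes the upstream wrapper class
from ultralytics.models.yolo.yoloe import YOLOEPETrainer

model = YOLOE("yoloe-11s.pt")  # illustrative pretrained checkpoint
model.train(
    data="coco128.yaml",  # illustrative detection dataset YAML
    epochs=10,
    trainer=YOLOEPETrainer,  # trains only the re-enabled head projection layers
)
```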
ultralytics/models/yolo/yoloe/train_seg.py (added, +122 lines)

```diff
@@ -0,0 +1,122 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from copy import copy, deepcopy
+
+from ultralytics.models.yolo.segment import SegmentationTrainer
+from ultralytics.nn.tasks import YOLOESegModel
+from ultralytics.utils import RANK
+
+from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
+from .val import YOLOESegValidator
+
+
+class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
+    """Trainer class for YOLOE segmentation models.
+
+    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
+    segmentation models, enabling both object detection and instance segmentation capabilities.
+
+    Attributes:
+        cfg (dict): Configuration dictionary with training parameters.
+        overrides (dict): Dictionary with parameter overrides.
+        _callbacks (list): List of callback functions for training events.
+    """
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return YOLOESegModel initialized with specified config and weights.
+
+        Args:
+            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
+            weights (str, optional): Path to pretrained weights file.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (YOLOESegModel): Initialized YOLOE segmentation model.
+        """
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = YOLOESegModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=self.data["channels"],
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """Create and return a validator for YOLOE segmentation model evaluation.
+
+        Returns:
+            (YOLOESegValidator): Validator for YOLOE segmentation models.
+        """
+        self.loss_names = "box", "seg", "cls", "dfl"
+        return YOLOESegValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
+
+
+class YOLOEPESegTrainer(SegmentationTrainer):
+    """Fine-tune YOLOESeg model in linear probing way.
+
+    This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
+    most of the model and only training specific layers for efficient adaptation to new tasks.
+
+    Attributes:
+        data (dict): Dataset configuration containing channels, class names, and number of classes.
+    """
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return YOLOESegModel initialized with specified config and weights for linear probing.
+
+        Args:
+            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
+            weights (str, optional): Path to pretrained weights file.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.
+        """
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, nc hard-coded to 80 for now.
+        model = YOLOESegModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=self.data["channels"],
+            nc=self.data["nc"],
+            verbose=verbose and RANK == -1,
+        )
+
+        del model.model[-1].savpe
+
+        assert weights is not None, "Pretrained weights must be provided for linear probing."
+        if weights:
+            model.load(weights)
+
+        model.eval()
+        names = list(self.data["names"].values())
+        # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,
+        # it'd get correct results as long as loading proper pretrained weights.
+        tpe = model.get_text_pe(names)
+        model.set_classes(names, tpe)
+        model.model[-1].fuse(model.pe)
+        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
+        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
+        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
+        del model.pe
+        model.train()
+
+        return model
+
+
+class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
+    """Trainer for YOLOE segmentation models trained from scratch without pretrained weights."""
+
+    pass
+
+
+class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
+    """Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities."""
+
+    pass
```
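A closing note on the two `pass`-bodied classes: they are pure mixin compositions. Under Python's C3 method resolution order, `YOLOESegVPTrainer` should pick up `build_dataset` (with the `LoadVisualPrompt` transform) from `YOLOEVPTrainer`, while `get_model` and `get_validator` resolve to the segmentation variants in `YOLOESegTrainer`. A quick check, assuming the package is installed:

```python
# Inspect which parent class supplies each override via the MRO.
from ultralytics.models.yolo.yoloe import YOLOESegVPTrainer

print([cls.__name__ for cls in YOLOESegVPTrainer.__mro__])
# Expected prefix: YOLOESegVPTrainer, YOLOEVPTrainer, YOLOESegTrainerFromScratch,
# YOLOETrainerFromScratch, YOLOESegTrainer, YOLOETrainer, ...
```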