dgenerate-ultralytics-headless 8.3.196-py3-none-any.whl → 8.3.248-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/segment/train.py:

```diff
@@ -8,12 +8,10 @@ from pathlib import Path
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.utils.plotting import plot_results
 
 
 class SegmentationTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on a segmentation model.
+    """A class extending the DetectionTrainer class for training based on a segmentation model.
 
     This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
     functionality including model initialization, validation, and visualization.
@@ -29,8 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
-        """
-        Initialize a SegmentationTrainer object.
+        """Initialize a SegmentationTrainer object.
 
         Args:
             cfg (dict): Configuration dictionary with default training settings.
@@ -41,11 +38,9 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
             overrides = {}
         overrides["task"] = "segment"
         super().__init__(cfg, overrides, _callbacks)
-        self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "masks"]
 
     def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
-        """
-        Initialize and return a SegmentationModel with specified configuration and weights.
+        """Initialize and return a SegmentationModel with specified configuration and weights.
 
         Args:
             cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
@@ -72,7 +67,3 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         return yolo.segment.SegmentationValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
-
-    def plot_metrics(self):
-        """Plot training/validation metrics."""
-        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
```
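Usage of the trainer is unchanged by this diff; for orientation, a minimal sketch of a training run, assuming the bundled `coco8-seg.yaml` sample dataset and a `yolo11n-seg.pt` checkpoint:

```python
from ultralytics.models.yolo.segment import SegmentationTrainer

# Minimal segmentation training run; dataset and weights download on first use.
args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=1)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
```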
ultralytics/models/yolo/segment/val.py:

```diff
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from multiprocessing.pool import ThreadPool
 from pathlib import Path
 from typing import Any
 
@@ -11,17 +10,16 @@ import torch
 import torch.nn.functional as F
 
 from ultralytics.models.yolo.detect import DetectionValidator
-from ultralytics.utils import LOGGER, NUM_THREADS, ops
+from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import SegmentMetrics, mask_iou
 
 
 class SegmentationValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a segmentation model.
+    """A class extending the DetectionValidator class for validation based on a segmentation model.
 
-    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
-
+    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
+    compute metrics such as mAP for both detection and segmentation tasks.
 
     Attributes:
         plot_masks (list): List to store masks for plotting.
@@ -38,11 +36,10 @@ class SegmentationValidator(DetectionValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
+        """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (Path, optional): Directory to save results.
             args (namespace, optional): Arguments for the validator.
             _callbacks (list, optional): List of callback functions.
@@ -53,8 +50,7 @@ class SegmentationValidator(DetectionValidator):
         self.metrics = SegmentMetrics()
 
     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Preprocess batch of images for YOLO segmentation validation.
+        """Preprocess batch of images for YOLO segmentation validation.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -67,8 +63,7 @@ class SegmentationValidator(DetectionValidator):
         return batch
 
     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize metrics and select mask processing function based on save_json flag.
+        """Initialize metrics and select mask processing function based on save_json flag.
 
         Args:
             model (torch.nn.Module): Model to validate.
@@ -96,8 +91,7 @@ class SegmentationValidator(DetectionValidator):
         )
 
     def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
-        """
-        Post-process YOLO predictions and return output detections with proto.
+        """Post-process YOLO predictions and return output detections with proto.
 
         Args:
             preds (list[torch.Tensor]): Raw predictions from the model.
@@ -112,7 +106,7 @@ class SegmentationValidator(DetectionValidator):
             coefficient = pred.pop("extra")
             pred["masks"] = (
                 self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
-                if
+                if coefficient.shape[0]
                 else torch.zeros(
                     (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
                     dtype=torch.uint8,
@@ -122,8 +116,7 @@ class SegmentationValidator(DetectionValidator):
         return preds
 
     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch for training or inference by processing images and targets.
+        """Prepare a batch for training or inference by processing images and targets.
 
         Args:
             si (int): Batch index.
@@ -133,22 +126,23 @@ class SegmentationValidator(DetectionValidator):
             (dict[str, Any]): Prepared batch with processed annotations.
         """
         prepared_batch = super()._prepare_batch(si, batch)
-        nl =
+        nl = prepared_batch["cls"].shape[0]
         if self.args.overlap_mask:
             masks = batch["masks"][si]
             index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
             masks = (masks == index).float()
         else:
             masks = batch["masks"][batch["batch_idx"] == si]
-        if nl
-
-            masks
+        if nl:
+            mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
+            if masks.shape[1:] != mask_size:
+                masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
+            masks = masks.gt_(0.5)
         prepared_batch["masks"] = masks
         return prepared_batch
 
     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
+        """Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
 
         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
@@ -157,28 +151,27 @@ class SegmentationValidator(DetectionValidator):
         Returns:
             (dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
 
-        Notes:
-            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
-            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
-
         Examples:
             >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
             >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
             >>> correct_preds = validator._process_batch(preds, batch)
+
+        Notes:
+            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
+            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
         """
         tp = super()._process_batch(preds, batch)
         gt_cls = batch["cls"]
-        if
-            tp_m = np.zeros((
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
         else:
-            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
+            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1).float())  # float, uint8
             tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
         tp.update({"tp_m": tp_m})  # update tp with mask IoU
         return tp
 
     def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
-        """
-        Plot batch predictions with masks and bounding boxes.
+        """Plot batch predictions with masks and bounding boxes.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -187,14 +180,13 @@ class SegmentationValidator(DetectionValidator):
         """
         for p in preds:
             masks = p["masks"]
-            if masks.shape[0] >
-                LOGGER.warning("Limiting validation plots to
-                p["masks"] = torch.as_tensor(masks[:
-        super().plot_predictions(batch, preds, ni, max_det=
+            if masks.shape[0] > self.args.max_det:
+                LOGGER.warning(f"Limiting validation plots to 'max_det={self.args.max_det}' items.")
+                p["masks"] = torch.as_tensor(masks[: self.args.max_det], dtype=torch.uint8).cpu()
+        super().plot_predictions(batch, preds, ni, max_det=self.args.max_det)  # plot bboxes
 
     def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
@@ -213,24 +205,84 @@ class SegmentationValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Save one JSON result for COCO evaluation.
+        """Save one JSON result for COCO evaluation.
 
         Args:
             predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
-        from faster_coco_eval.core.mask import encode  # noqa
-
-        def single_encode(x):
-            """Encode predicted masks as RLE and append results to jdict."""
-            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
-            rle["counts"] = rle["counts"].decode("utf-8")
-            return rle
 
-
-
-
+        def to_string(counts: list[int]) -> str:
+            """Converts the RLE object into a compact string representation. Each count is delta-encoded and
+            variable-length encoded as a string.
+
+            Args:
+                counts (list[int]): List of RLE counts.
+            """
+            result = []
+
+            for i in range(len(counts)):
+                x = int(counts[i])
+
+                # Apply delta encoding for all counts after the second entry
+                if i > 2:
+                    x -= int(counts[i - 2])
+
+                # Variable-length encode the value
+                while True:
+                    c = x & 0x1F  # Take 5 bits
+                    x >>= 5
+
+                    # If the sign bit (0x10) is set, continue if x != -1;
+                    # otherwise, continue if x != 0
+                    more = (x != -1) if (c & 0x10) else (x != 0)
+                    if more:
+                        c |= 0x20  # Set continuation bit
+                    c += 48  # Shift to ASCII
+                    result.append(chr(c))
+                    if not more:
+                        break
+
+            return "".join(result)
+
+        def multi_encode(pixels: torch.Tensor) -> list[int]:
+            """Convert multiple binary masks using Run-Length Encoding (RLE).
+
+            Args:
+                pixels (torch.Tensor): A 2D tensor where each row represents a flattened binary mask with shape [N,
+                    H*W].
+
+            Returns:
+                (list[int]): A list of RLE counts for each mask.
+            """
+            transitions = pixels[:, 1:] != pixels[:, :-1]
+            row_idx, col_idx = torch.where(transitions)
+            col_idx = col_idx + 1
+
+            # Compute run lengths
+            counts = []
+            for i in range(pixels.shape[0]):
+                positions = col_idx[row_idx == i]
+                if len(positions):
+                    count = torch.diff(positions).tolist()
+                    count.insert(0, positions[0].item())
+                    count.append(len(pixels[i]) - positions[-1].item())
+                else:
+                    count = [len(pixels[i])]
+
+                # Ensure starting with background (0) count
+                if pixels[i][0].item() == 1:
+                    count = [0, *count]
+                counts.append(count)
+
+            return counts
+
+        pred_masks = predn["masks"].transpose(2, 1).contiguous().view(len(predn["masks"]), -1)  # N, H*W
+        h, w = predn["masks"].shape[1:3]
+        counts = multi_encode(pred_masks)
+        rles = []
+        for c in counts:
+            rles.append({"size": [h, w], "counts": to_string(c)})
         super().pred_to_json(predn, pbatch)
         for i, r in enumerate(rles):
             self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation
@@ -239,11 +291,9 @@ class SegmentationValidator(DetectionValidator):
         """Scales predictions to the original image size."""
         return {
             **super().scale_preds(predn, pbatch),
-            "masks": ops.
-
-
-                ratio_pad=pbatch["ratio_pad"],
-            ),
+            "masks": ops.scale_masks(predn["masks"][None], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])[
+                0
+            ].byte(),
         }
 
     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
```
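The new `multi_encode`/`to_string` pair replaces `faster_coco_eval`'s `encode` with COCO-style compressed RLE done in plain PyTorch: masks are flattened column-major, run lengths always start with a background (0) run, later counts are delta-encoded against the count two positions back, and values are packed 5 bits per ASCII character offset by 48. A self-contained sketch of the same scheme (helper names here are illustrative, not part of the package API):

```python
import torch


def rle_counts(pixels: torch.Tensor) -> list[list[int]]:
    """Run-length encode each row of a [N, H*W] tensor of column-major flattened binary masks."""
    transitions = pixels[:, 1:] != pixels[:, :-1]
    row_idx, col_idx = torch.where(transitions)
    col_idx = col_idx + 1  # transition indices -> run boundaries
    counts = []
    for i in range(pixels.shape[0]):
        positions = col_idx[row_idx == i]
        if len(positions):
            count = torch.diff(positions).tolist()
            count.insert(0, positions[0].item())
            count.append(len(pixels[i]) - positions[-1].item())
        else:
            count = [len(pixels[i])]
        if pixels[i][0].item() == 1:  # COCO RLE must start with a background (0) run
            count = [0, *count]
        counts.append(count)
    return counts


def rle_string(counts: list[int]) -> str:
    """Pack counts into COCO's compact ASCII form: 5 bits per character, offset by 48."""
    result = []
    for i in range(len(counts)):
        x = int(counts[i])
        if i > 2:  # delta-encode against the count two positions back
            x -= int(counts[i - 2])
        while True:
            c = x & 0x1F
            x >>= 5
            more = (x != -1) if (c & 0x10) else (x != 0)
            if more:
                c |= 0x20  # continuation bit
            result.append(chr(c + 48))
            if not more:
                break
    return "".join(result)


# A 4x4 mask with a 2x2 foreground square; flatten column-major as pred_to_json does.
mask = torch.zeros(4, 4, dtype=torch.uint8)
mask[1:3, 1:3] = 1
flat = mask.transpose(1, 0).contiguous().view(1, -1)
counts = rle_counts(flat)[0]
print(counts)              # [5, 2, 2, 2, 5]: alternating background/foreground runs
print(rle_string(counts))  # compact string stored under 'counts' in the JSON RLE
```

The column-major pixel sequence for this mask is 0000 0110 0110 0000, hence the run lengths [5, 2, 2, 2, 5]; the string form is what lands in each `jdict` entry's `"segmentation"` field.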
ultralytics/models/yolo/world/train.py:

```diff
@@ -24,8 +24,7 @@ def on_pretrain_routine_end(trainer) -> None:
 
 
 class WorldTrainer(DetectionTrainer):
-    """
-    A trainer class for fine-tuning YOLO World models on close-set datasets.
+    """A trainer class for fine-tuning YOLO World models on close-set datasets.
 
     This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
     features for improved object detection and understanding. It handles text embedding generation and caching to
@@ -54,8 +53,7 @@ class WorldTrainer(DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a WorldTrainer object with given arguments.
+        """Initialize a WorldTrainer object with given arguments.
 
         Args:
             cfg (dict[str, Any]): Configuration for the trainer.
@@ -64,12 +62,12 @@ class WorldTrainer(DetectionTrainer):
         """
         if overrides is None:
             overrides = {}
+        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
         super().__init__(cfg, overrides, _callbacks)
         self.text_embeddings = None
 
     def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
-        """
-        Return WorldModel initialized with specified config and weights.
+        """Return WorldModel initialized with specified config and weights.
 
         Args:
             cfg (dict[str, Any] | str, optional): Model configuration.
@@ -94,8 +92,7 @@ class WorldTrainer(DetectionTrainer):
         return model
 
     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -114,11 +111,10 @@ class WorldTrainer(DetectionTrainer):
         return dataset
 
     def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
-        """
-        Set text embeddings for datasets to accelerate training by caching category names.
+        """Set text embeddings for datasets to accelerate training by caching category names.
 
-        This method collects unique category names from all datasets, then generates and caches text embeddings
-
+        This method collects unique category names from all datasets, then generates and caches text embeddings for
+        these categories to improve training efficiency.
 
         Args:
             datasets (list[Any]): List of datasets from which to extract category names.
@@ -140,8 +136,7 @@ class WorldTrainer(DetectionTrainer):
         self.text_embeddings = text_embeddings
 
     def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
-        """
-        Generate text embeddings for a list of text samples.
+        """Generate text embeddings for a list of text samples.
 
         Args:
             texts (list[str]): List of text samples to encode.
@@ -171,7 +166,8 @@ class WorldTrainer(DetectionTrainer):
 
         # Add text features
         texts = list(itertools.chain(*batch["texts"]))
-        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
-
+        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
+            self.device, non_blocking=self.device.type == "cuda"
+        )
         batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
         return batch
```
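Note the new guard: `WorldTrainer.__init__` now rejects any truthy `compile` override before calling the parent constructor. A minimal invocation consistent with that constraint (assuming the standard `yolov8s-world.pt` weights and the `coco8.yaml` sample dataset):

```python
from ultralytics.models.yolo.world import WorldTrainer

# 'compile' defaults to False; the new assert raises on any truthy override.
args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=1, compile=False)
trainer = WorldTrainer(overrides=args)
trainer.train()
```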
ultralytics/models/yolo/world/train_world.py:

```diff
@@ -10,8 +10,7 @@ from ultralytics.utils.torch_utils import unwrap_model
 
 
 class WorldTrainerFromScratch(WorldTrainer):
-    """
-    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
+    """A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
 
     This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
     supporting training YOLO-World models with combined vision-language capabilities.
@@ -53,45 +52,25 @@ class WorldTrainerFromScratch(WorldTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a WorldTrainerFromScratch object.
+        """Initialize a WorldTrainerFromScratch object.
 
-        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
-
+        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
+        detection and grounding datasets for vision-language capabilities.
 
         Args:
             cfg (dict): Configuration dictionary with default parameters for model training.
             overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.
             _callbacks (list, optional): List of callback functions to be executed during different stages of training.
-
-        Examples:
-            >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
-            >>> from ultralytics import YOLOWorld
-            >>> data = dict(
-            ...     train=dict(
-            ...         yolo_data=["Objects365.yaml"],
-            ...         grounding_data=[
-            ...             dict(
-            ...                 img_path="flickr30k/images",
-            ...                 json_file="flickr30k/final_flickr_separateGT_train.json",
-            ...             ),
-            ...         ],
-            ...     ),
-            ...     val=dict(yolo_data=["lvis.yaml"]),
-            ... )
-            >>> model = YOLOWorld("yolov8s-worldv2.yaml")
-            >>> model.train(data=data, trainer=WorldTrainerFromScratch)
         """
         if overrides is None:
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)
 
     def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
-        This method constructs appropriate datasets based on the mode and input paths, handling both
-
+        This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
+        datasets and grounding datasets with different formats.
 
         Args:
             img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -122,11 +101,10 @@ class WorldTrainerFromScratch(WorldTrainer):
         return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
 
     def get_dataset(self):
-        """
-        Get train and validation paths from data dictionary.
+        """Get train and validation paths from data dictionary.
 
-        Processes the data configuration to extract paths for training and validation datasets,
-
+        Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
+        detection datasets and grounding datasets.
 
         Returns:
             train_path (str): Train dataset path.
@@ -187,8 +165,7 @@ class WorldTrainerFromScratch(WorldTrainer):
             pass
 
     def final_eval(self):
-        """
-        Perform final evaluation and validation for the YOLO-World model.
+        """Perform final evaluation and validation for the YOLO-World model.
 
         Configures the validator with appropriate dataset and split information before running evaluation.
 
```
ultralytics/models/yolo/yoloe/__init__.py:

```diff
@@ -6,17 +6,17 @@ from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch
 from .val import YOLOEDetectValidator, YOLOESegValidator
 
 __all__ = [
-    "YOLOETrainer",
-    "YOLOEPETrainer",
-    "YOLOESegTrainer",
     "YOLOEDetectValidator",
-    "YOLOESegValidator",
+    "YOLOEPEFreeTrainer",
     "YOLOEPESegTrainer",
+    "YOLOEPETrainer",
+    "YOLOESegTrainer",
     "YOLOESegTrainerFromScratch",
     "YOLOESegVPTrainer",
-    "YOLOEVPTrainer",
-    "YOLOEPEFreeTrainer",
+    "YOLOESegValidator",
+    "YOLOETrainer",
+    "YOLOETrainerFromScratch",
     "YOLOEVPDetectPredictor",
     "YOLOEVPSegPredictor",
-    "YOLOETrainerFromScratch",
+    "YOLOEVPTrainer",
 ]
```
ultralytics/models/yolo/yoloe/predict.py:

```diff
@@ -9,11 +9,10 @@ from ultralytics.models.yolo.segment import SegmentationPredictor
 
 
 class YOLOEVPDetectPredictor(DetectionPredictor):
-    """
-    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+    """A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
 
-    This mixin provides common functionality for YOLO models that use visual prompting, including
-
+    This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
+    handling, and preprocessing transformations.
 
     Attributes:
         model (torch.nn.Module): The YOLO model for inference.
@@ -29,8 +28,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
     """
 
     def setup_model(self, model, verbose: bool = True):
-        """
-        Set up the model for prediction.
+        """Set up the model for prediction.
 
         Args:
             model (torch.nn.Module): Model to load or use.
@@ -40,21 +38,19 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         self.done_warmup = True
 
     def set_prompts(self, prompts):
-        """
-        Set the visual prompts for the model.
+        """Set the visual prompts for the model.
 
         Args:
-            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
-
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
+                with class indices.
         """
         self.prompts = prompts
 
     def pre_transform(self, im):
-        """
-        Preprocess images and prompts before inference.
+        """Preprocess images and prompts before inference.
 
-        This method applies letterboxing to the input image and transforms the visual prompts
-
+        This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
+        accordingly.
 
         Args:
             im (list): List containing a single input image.
@@ -94,8 +90,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return img
 
     def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
-        """
-        Process a single image by resizing bounding boxes or masks and generating visuals.
+        """Process a single image by resizing bounding boxes or masks and generating visuals.
 
         Args:
             dst_shape (tuple): The target shape (height, width) of the image.
@@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
 
     def inference(self, im, *args, **kwargs):
-        """
-        Run inference with visual prompts.
+        """Run inference with visual prompts.
 
         Args:
             im (torch.Tensor): Input image tensor.
@@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return super().inference(im, vpe=self.prompts, *args, **kwargs)
 
     def get_vpe(self, source):
-        """
-        Process the source to get the visual prompt embeddings (VPE).
+        """Process the source to get the visual prompt embeddings (VPE).
 
         Args:
-            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
-
-
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
+                make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+                torch tensors.
 
         Returns:
             (torch.Tensor): The visual prompt embeddings (VPE) from the model.
```