ultralytics 8.3.97__py3-none-any.whl → 8.3.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_python.py +56 -0
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/data/augment.py +101 -5
- ultralytics/data/dataset.py +165 -12
- ultralytics/engine/exporter.py +13 -13
- ultralytics/engine/trainer.py +16 -7
- ultralytics/models/__init__.py +2 -2
- ultralytics/models/nas/model.py +1 -0
- ultralytics/models/nas/predict.py +4 -24
- ultralytics/models/nas/val.py +1 -4
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/detect/val.py +6 -1
- ultralytics/models/yolo/model.py +182 -3
- ultralytics/models/yolo/segment/val.py +43 -16
- ultralytics/models/yolo/yoloe/__init__.py +21 -0
- ultralytics/models/yolo/yoloe/predict.py +170 -0
- ultralytics/models/yolo/yoloe/train.py +355 -0
- ultralytics/models/yolo/yoloe/train_seg.py +141 -0
- ultralytics/models/yolo/yoloe/val.py +187 -0
- ultralytics/nn/autobackend.py +3 -2
- ultralytics/nn/modules/__init__.py +18 -1
- ultralytics/nn/modules/block.py +17 -1
- ultralytics/nn/modules/head.py +359 -22
- ultralytics/nn/tasks.py +276 -10
- ultralytics/nn/text_model.py +193 -0
- ultralytics/utils/callbacks/comet.py +3 -6
- ultralytics/utils/downloads.py +6 -2
- ultralytics/utils/instance.py +7 -2
- ultralytics/utils/loss.py +67 -6
- ultralytics/utils/plotting.py +1 -1
- ultralytics/utils/tal.py +1 -1
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/METADATA +69 -67
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/RECORD +41 -31
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
CHANGED
@@ -4,7 +4,16 @@ from pathlib import Path
 
 from ultralytics.engine.model import Model
 from ultralytics.models import yolo
-from ultralytics.nn.tasks import
+from ultralytics.nn.tasks import (
+    ClassificationModel,
+    DetectionModel,
+    OBBModel,
+    PoseModel,
+    SegmentationModel,
+    WorldModel,
+    YOLOEModel,
+    YOLOESegModel,
+)
 from ultralytics.utils import ROOT, yaml_load
 
 
@@ -12,12 +21,16 @@ class YOLO(Model):
     """YOLO (You Only Look Once) object detection model."""
 
     def __init__(self, model="yolo11n.pt", task=None, verbose=False):
-        """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
+        """Initialize YOLO model, switching to YOLOWorld/YOLOE if model filename contains '-world'/'yoloe'."""
         path = Path(model)
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
             self.__class__ = type(new_instance)
             self.__dict__ = new_instance.__dict__
+        elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOE PyTorch model
+            new_instance = YOLOE(path, task=task, verbose=verbose)
+            self.__class__ = type(new_instance)
+            self.__dict__ = new_instance.__dict__
         else:
             # Continue with default YOLO initialization
             super().__init__(model=model, task=task, verbose=verbose)
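With this dispatch, a plain YOLO constructor transparently becomes a YOLOE instance when the weight filename contains "yoloe". A minimal sketch of the effect; the checkpoint name is illustrative, and it assumes YOLOE is re-exported from the package root, as the ultralytics/__init__.py change listed above suggests:

from ultralytics import YOLO, YOLOE

model = YOLO("yoloe-11s-seg.pt")  # "yoloe" in the filename stem routes through the new elif branch
assert isinstance(model, YOLOE)   # the instance was re-classed inside YOLO.__init__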
@@ -96,7 +109,7 @@ class YOLOWorld(Model):
         Set the model's class names for detection.
 
         Args:
-            classes (
+            classes (list[str]): A list of categories i.e. ["person"].
         """
         self.model.set_classes(classes)
         # Remove background if it's given
@@ -108,3 +121,169 @@ class YOLOWorld(Model):
         # Reset method class names
         if self.predictor:
             self.predictor.model.names = classes
+
+
+class YOLOE(Model):
+    """YOLOE object detection and segmentation model."""
+
+    def __init__(self, model="yoloe-v8s-seg.pt", task=None, verbose=False) -> None:
+        """
+        Initialize YOLOE model with a pre-trained model file.
+
+        Args:
+            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+            task (str, optional): Task type for the model. Auto-detected if None.
+            verbose (bool): If True, prints additional information during initialization.
+        """
+        super().__init__(model=model, task=task, verbose=verbose)
+
+        # Assign default COCO class names when there are no custom names
+        if not hasattr(self.model, "names"):
+            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
+
+    @property
+    def task_map(self):
+        """Map head to model, validator, and predictor classes."""
+        return {
+            "detect": {
+                "model": YOLOEModel,
+                "validator": yolo.yoloe.YOLOEDetectValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.yoloe.YOLOETrainer,
+            },
+            "segment": {
+                "model": YOLOESegModel,
+                "validator": yolo.yoloe.YOLOESegValidator,
+                "predictor": yolo.segment.SegmentationPredictor,
+                "trainer": yolo.yoloe.YOLOESegTrainer,
+            },
+        }
+
+    def get_text_pe(self, texts):
+        """Get text positional embeddings for the given texts."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_text_pe(texts)
+
+    def get_visual_pe(self, img, visual):
+        """Get visual positional embeddings for the given image and visual features."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_visual_pe(img, visual)
+
+    def set_vocab(self, vocab, names):
+        """Set vocabulary and class names for the model."""
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_vocab(vocab, names=names)
+
+    def get_vocab(self, names):
+        """Get vocabulary for the given class names."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_vocab(names)
+
+    def set_classes(self, classes, embeddings):
+        """
+        Set the model's class names and embeddings for detection.
+
+        Args:
+            classes (list[str]): A list of categories i.e. ["person"].
+            embeddings (torch.Tensor): Embeddings corresponding to the classes.
+        """
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_classes(classes, embeddings)
+        # Verify no background class is present
+        assert " " not in classes
+        self.model.names = classes
+
+        # Reset method class names
+        if self.predictor:
+            self.predictor.model.names = classes
+
+    def val(
+        self,
+        validator=None,
+        load_vp=False,
+        refer_data=None,
+        **kwargs,
+    ):
+        """
+        Validate the model using text or visual prompts.
+
+        Args:
+            validator (callable, optional): A callable validator function. If None, a default validator is loaded.
+            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
+            refer_data (str, optional): Path to the reference data for visual prompts.
+            **kwargs (Any): Additional keyword arguments to override default settings.
+
+        Returns:
+            (dict): Validation statistics containing metrics computed during validation.
+        """
+        custom = {"rect": not load_vp}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # highest priority args on the right
+
+        validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
+        validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
+        self.metrics = validator.metrics
+        return validator.metrics
+
+    def predict(
+        self,
+        source=None,
+        stream: bool = False,
+        visual_prompts: dict = {},
+        refer_image=None,
+        predictor=None,
+        **kwargs,
+    ):
+        """
+        Run prediction on images, videos, directories, streams, etc.
+
+        Args:
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
+                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
+                generator as they are computed.
+            visual_prompts (dict): Dictionary containing visual prompts for the model. Must include 'bboxes' and
+                'cls' keys when non-empty.
+            refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
+                loaded based on the task.
+            **kwargs (Any): Additional keyword arguments passed to the predictor.
+
+        Returns:
+            (List | generator): List of Results objects or generator of Results objects if stream=True.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s-seg.pt")
+            >>> results = model.predict("path/to/image.jpg")
+            >>> # With visual prompts
+            >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+            >>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
+        """
+        if len(visual_prompts):
+            assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
+                f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
+            )
+            assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
+                f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
+                f"{len(visual_prompts['cls'])} respectively"
+            )
+        self.predictor = (predictor or self._smart_load("predictor"))(
+            overrides={"task": "segment", "mode": "predict", "save": False, "verbose": False}, _callbacks=self.callbacks
+        )
+
+        if len(visual_prompts):
+            num_cls = (
+                max(len(set(c)) for c in visual_prompts["cls"])
+                if isinstance(source, list)  # means multiple images
+                else len(set(visual_prompts["cls"]))
+            )
+            self.model.model[-1].nc = num_cls
+            self.model.names = [f"object{i}" for i in range(num_cls)]
+            self.predictor.set_prompts(visual_prompts)
+
+        self.predictor.setup_model(model=self.model)
+        if refer_image is not None and len(visual_prompts):
+            vpe = self.predictor.get_vpe(refer_image)
+            self.model.set_classes(self.model.names, vpe)
+            self.predictor = None  # reset predictor
+
+        return super().predict(source, stream, **kwargs)
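Taken together, the new YOLOE class exposes both a text-prompt and a visual-prompt workflow. A minimal usage sketch, assuming YOLOE is importable from the package root and reusing the illustrative checkpoint, image path, and prompt values from the docstrings above; combining set_classes with get_text_pe for text prompts is an assumption about intended use, not something this diff states explicitly:

from ultralytics import YOLOE

model = YOLOE("yoloe-v8s-seg.pt")

# Text prompts: embed an open vocabulary and use it as the class set
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))
results = model.predict("path/to/image.jpg")

# Visual prompts: boxes plus labels steer detection, as in the Examples section above
prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
results = model.predict("path/to/image.jpg", visual_prompts=prompts)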
ultralytics/models/yolo/segment/val.py
CHANGED
@@ -364,29 +364,56 @@ class SegmentationValidator(DetectionValidator):
 
     def eval_json(self, stats):
         """Return COCO-style object detection evaluation metrics."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+        if self.args.save_json and (self.is_lvis or self.is_coco) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
-            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
-            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements("pycocotools>=2.0.6")
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
 
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                 for x in anno_json, pred_json:
                     assert x.is_file(), f"{x} file not found"
- [removed lines 378-382: content not rendered in the source view]
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
+                if self.is_coco:
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    vals = [COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))
+                    pred = anno._load_json(str(pred_json))
+                    vals = [LVISEval(anno, pred, "bbox"), LVISEval(anno, pred, "segm")]
+
+                for i, eval in enumerate(vals):
+                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                     eval.evaluate()
                     eval.accumulate()
                     eval.summarize()
+                    if self.is_lvis:
+                        eval.print_results()
                     idx = i * 4 + 2
- [removed lines 387-389: content not rendered in the source view]
+                    # update mAP50-95 and mAP50
+                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = (
+                        eval.stats[:2] if self.is_coco else [eval.results["AP"], eval.results["AP50"]]
+                    )
+                    if self.is_lvis:
+                        tag = "B" if i == 0 else "M"
+                        stats[f"metrics/APr({tag})"] = eval.results["APr"]
+                        stats[f"metrics/APc({tag})"] = eval.results["APc"]
+                        stats[f"metrics/APf({tag})"] = eval.results["APf"]
+
+                if self.is_lvis:
+                    stats["fitness"] = stats["metrics/mAP50-95(B)"]
+
             except Exception as e:
-                LOGGER.warning(f"
+                LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats
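With this change, eval_json dispatches to the lvis evaluator whenever the validation set is detected as LVIS rather than COCO. A hedged sketch of how the new branch would be reached; the dataset YAML name and weights are assumptions, and save_json is what triggers eval_json after validation:

from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")
metrics = model.val(data="lvis.yaml", save_json=True)  # LVIS-style data -> lvis evaluator, COCO -> pycocotools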
ultralytics/models/yolo/yoloe/__init__.py
ADDED
@@ -0,0 +1,21 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor
+from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOEVPTrainer
+from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer
+from .val import YOLOEDetectValidator, YOLOESegValidator
+
+__all__ = [
+    "YOLOETrainer",
+    "YOLOEPETrainer",
+    "YOLOESegTrainer",
+    "YOLOEDetectValidator",
+    "YOLOESegValidator",
+    "YOLOEPESegTrainer",
+    "YOLOESegTrainerFromScratch",
+    "YOLOESegVPTrainer",
+    "YOLOEVPTrainer",
+    "YOLOEPEFreeTrainer",
+    "YOLOEVPDetectPredictor",
+    "YOLOEVPSegPredictor",
+]
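These re-exports include the classes referenced by the YOLOE.task_map entries in model.py (yolo.yoloe.YOLOEDetectValidator, yolo.yoloe.YOLOETrainer, and so on); they can also be imported directly from the new subpackage, e.g.:

from ultralytics.models.yolo.yoloe import YOLOEDetectValidator, YOLOEVPSegPredictor  # names listed in __all__ above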
ultralytics/models/yolo/yoloe/predict.py
ADDED
@@ -0,0 +1,170 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+
+import numpy as np
+import torch
+
+from ultralytics.data.augment import LoadVisualPrompt
+from ultralytics.models.yolo.detect import DetectionPredictor
+from ultralytics.models.yolo.segment import SegmentationPredictor
+
+
+class YOLOEVPDetectPredictor(DetectionPredictor):
+    """
+    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+
+    This mixin provides common functionality for YOLO models that use visual prompting, including
+    model setup, prompt handling, and preprocessing transformations.
+
+    Attributes:
+        model (torch.nn.Module): The YOLO model for inference.
+        device (torch.device): Device to run the model on (CPU or CUDA).
+        prompts (dict): Visual prompts containing class indices and bounding boxes or masks.
+
+    Methods:
+        setup_model: Initialize the YOLO model and set it to evaluation mode.
+        set_return_vpe: Set whether to return visual prompt embeddings.
+        set_prompts: Set the visual prompts for the model.
+        pre_transform: Preprocess images and prompts before inference.
+        inference: Run inference with visual prompts.
+    """
+
+    def setup_model(self, model, verbose=True):
+        """
+        Sets up the model for prediction.
+
+        Args:
+            model (torch.nn.Module): Model to load or use.
+            verbose (bool): If True, provides detailed logging.
+        """
+        super().setup_model(model, verbose=verbose)
+        self.done_warmup = True
+
+    def set_prompts(self, prompts):
+        """
+        Set the visual prompts for the model.
+
+        Args:
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
+                Must include a 'cls' key with class indices.
+        """
+        self.prompts = prompts
+
+    def pre_transform(self, im):
+        """
+        Preprocess images and prompts before inference.
+
+        This method applies letterboxing to the input image and transforms the visual prompts
+        (bounding boxes or masks) accordingly.
+
+        Args:
+            im (list): List containing a single input image.
+
+        Returns:
+            (list): Preprocessed image ready for model inference.
+
+        Raises:
+            ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
+        """
+        img = super().pre_transform(im)
+        bboxes = self.prompts.pop("bboxes", None)
+        masks = self.prompts.pop("masks", None)
+        category = self.prompts["cls"]
+        if len(img) == 1:
+            visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
+            self.prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
+        else:
+            # NOTE: only supports bboxes as prompts for now
+            assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
+            # NOTE: needs List[np.ndarray]
+            assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
+                f"Expected List[np.ndarray], but got {bboxes}!"
+            )
+            assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
+                f"Expected List[np.ndarray], but got {category}!"
+            )
+            assert len(im) == len(category) == len(bboxes), (
+                f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
+            )
+            visuals = [
+                self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
+                for i in range(len(img))
+            ]
+            self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)
+
+        return img
+
+    def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
+        """
+        Processes a single image by resizing bounding boxes or masks and generating visuals.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) of the image.
+            src_shape (tuple): The original shape (height, width) of the image.
+            category (str): The category of the image for visual prompts.
+            bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2]. Defaults to None.
+            masks (np.ndarray, optional): A list of masks corresponding to the image. Defaults to None.
+
+        Returns:
+            visuals: The processed visuals for the image.
+
+        Raises:
+            ValueError: If neither `bboxes` nor `masks` are provided.
+        """
+        if bboxes is not None and len(bboxes):
+            bboxes = np.array(bboxes, dtype=np.float32)
+            if bboxes.ndim == 1:
+                bboxes = bboxes[None, :]
+            # Calculate scaling factor and adjust bounding boxes
+            gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # gain = old / new
+            bboxes *= gain
+            bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
+            bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
+        elif masks is not None:
+            # Resize and process masks
+            resized_masks = super().pre_transform(masks)
+            masks = np.stack(resized_masks)  # (N, H, W)
+            masks[masks == 114] = 0  # Reset padding values to 0
+        else:
+            raise ValueError("Please provide valid bboxes or masks")
+
+        # Generate visuals using the visual prompt loader
+        return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
+
+    def inference(self, im, *args, **kwargs):
+        """
+        Run inference with visual prompts.
+
+        Args:
+            im (torch.Tensor): Input image tensor.
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+
+        Returns:
+            (torch.Tensor): Model prediction results.
+        """
+        return super().inference(im, vpe=self.prompts, *args, **kwargs)
+
+    def get_vpe(self, source):
+        """
+        Processes the source to get the visual prompt embeddings (VPE).
+
+        Args:
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
+                of the image to make predictions on. Accepts various types including file paths, URLs, PIL
+                images, numpy arrays, and torch tensors.
+
+        Returns:
+            (torch.Tensor): The visual prompt embeddings (VPE) from the model.
+        """
+        self.setup_source(source)
+        assert len(self.dataset) == 1, "get_vpe only supports one image!"
+        for _, im0s, _ in self.dataset:
+            im = self.preprocess(im0s)
+            return self.model(im, vpe=self.prompts, return_vpe=True)
+
+
+class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
+    """Predictor for YOLOE VP segmentation."""
+
+    pass