ultralytics 8.3.97__py3-none-any.whl → 8.3.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. tests/test_python.py +56 -0
  2. ultralytics/__init__.py +3 -2
  3. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  4. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  5. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  6. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  7. ultralytics/data/augment.py +101 -5
  8. ultralytics/data/dataset.py +165 -12
  9. ultralytics/engine/exporter.py +13 -13
  10. ultralytics/engine/trainer.py +16 -7
  11. ultralytics/models/__init__.py +2 -2
  12. ultralytics/models/nas/model.py +1 -0
  13. ultralytics/models/nas/predict.py +4 -24
  14. ultralytics/models/nas/val.py +1 -4
  15. ultralytics/models/yolo/__init__.py +3 -3
  16. ultralytics/models/yolo/detect/val.py +6 -1
  17. ultralytics/models/yolo/model.py +182 -3
  18. ultralytics/models/yolo/segment/val.py +43 -16
  19. ultralytics/models/yolo/yoloe/__init__.py +21 -0
  20. ultralytics/models/yolo/yoloe/predict.py +170 -0
  21. ultralytics/models/yolo/yoloe/train.py +355 -0
  22. ultralytics/models/yolo/yoloe/train_seg.py +141 -0
  23. ultralytics/models/yolo/yoloe/val.py +187 -0
  24. ultralytics/nn/autobackend.py +3 -2
  25. ultralytics/nn/modules/__init__.py +18 -1
  26. ultralytics/nn/modules/block.py +17 -1
  27. ultralytics/nn/modules/head.py +359 -22
  28. ultralytics/nn/tasks.py +276 -10
  29. ultralytics/nn/text_model.py +193 -0
  30. ultralytics/utils/callbacks/comet.py +3 -6
  31. ultralytics/utils/downloads.py +6 -2
  32. ultralytics/utils/instance.py +7 -2
  33. ultralytics/utils/loss.py +67 -6
  34. ultralytics/utils/plotting.py +1 -1
  35. ultralytics/utils/tal.py +1 -1
  36. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/METADATA +69 -67
  37. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/RECORD +41 -31
  38. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
  39. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
  40. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
  41. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py

@@ -4,7 +4,16 @@ from pathlib import Path
 
 from ultralytics.engine.model import Model
 from ultralytics.models import yolo
-from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
+from ultralytics.nn.tasks import (
+    ClassificationModel,
+    DetectionModel,
+    OBBModel,
+    PoseModel,
+    SegmentationModel,
+    WorldModel,
+    YOLOEModel,
+    YOLOESegModel,
+)
 from ultralytics.utils import ROOT, yaml_load
 
 
@@ -12,12 +21,16 @@ class YOLO(Model):
     """YOLO (You Only Look Once) object detection model."""
 
     def __init__(self, model="yolo11n.pt", task=None, verbose=False):
-        """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
+        """Initialize YOLO model, switching to YOLOWorld/YOLOE if model filename contains '-world'/'yoloe'."""
         path = Path(model)
         if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(path, verbose=verbose)
             self.__class__ = type(new_instance)
             self.__dict__ = new_instance.__dict__
+        elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOE PyTorch model
+            new_instance = YOLOE(path, task=task, verbose=verbose)
+            self.__class__ = type(new_instance)
+            self.__dict__ = new_instance.__dict__
         else:
             # Continue with default YOLO initialization
             super().__init__(model=model, task=task, verbose=verbose)
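
The hunk above means the generic YOLO entry point now dispatches on the weights filename. A minimal sketch of the dispatch, assuming the yoloe-v8s-seg.pt weights name that appears as the YOLOE default later in this diff (any *.pt/*.yaml/*.yml path whose stem contains "yoloe" takes the same branch):

    from ultralytics import YOLO

    # Stem contains "yoloe" and suffix is .pt, so __init__ swaps the instance class
    model = YOLO("yoloe-v8s-seg.pt")
    print(type(model).__name__)  # YOLOE rather than YOLO

The same swap applies to the yoloe-11*.yaml and yoloe-v8*.yaml configs added in this release.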
@@ -96,7 +109,7 @@ class YOLOWorld(Model):
         Set the model's class names for detection.
 
         Args:
-            classes (List(str)): A list of categories i.e. ["person"].
+            classes (list[str]): A list of categories i.e. ["person"].
         """
         self.model.set_classes(classes)
         # Remove background if it's given
@@ -108,3 +121,169 @@ class YOLOWorld(Model):
         # Reset method class names
         if self.predictor:
             self.predictor.model.names = classes
+
+
+class YOLOE(Model):
+    """YOLOE object detection and segmentation model."""
+
+    def __init__(self, model="yoloe-v8s-seg.pt", task=None, verbose=False) -> None:
+        """
+        Initialize YOLOE model with a pre-trained model file.
+
+        Args:
+            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
+            task (str, optional): Task type for the model. Auto-detected if None.
+            verbose (bool): If True, prints additional information during initialization.
+        """
+        super().__init__(model=model, task=task, verbose=verbose)
+
+        # Assign default COCO class names when there are no custom names
+        if not hasattr(self.model, "names"):
+            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
+
+    @property
+    def task_map(self):
+        """Map head to model, validator, and predictor classes."""
+        return {
+            "detect": {
+                "model": YOLOEModel,
+                "validator": yolo.yoloe.YOLOEDetectValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+                "trainer": yolo.yoloe.YOLOETrainer,
+            },
+            "segment": {
+                "model": YOLOESegModel,
+                "validator": yolo.yoloe.YOLOESegValidator,
+                "predictor": yolo.segment.SegmentationPredictor,
+                "trainer": yolo.yoloe.YOLOESegTrainer,
+            },
+        }
+
+    def get_text_pe(self, texts):
+        """Get text positional embeddings for the given texts."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_text_pe(texts)
+
+    def get_visual_pe(self, img, visual):
+        """Get visual positional embeddings for the given image and visual features."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_visual_pe(img, visual)
+
+    def set_vocab(self, vocab, names):
+        """Set vocabulary and class names for the model."""
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_vocab(vocab, names=names)
+
+    def get_vocab(self, names):
+        """Get vocabulary for the given class names."""
+        assert isinstance(self.model, YOLOEModel)
+        return self.model.get_vocab(names)
+
+    def set_classes(self, classes, embeddings):
+        """
+        Set the model's class names and embeddings for detection.
+
+        Args:
+            classes (list[str]): A list of categories i.e. ["person"].
+            embeddings (torch.Tensor): Embeddings corresponding to the classes.
+        """
+        assert isinstance(self.model, YOLOEModel)
+        self.model.set_classes(classes, embeddings)
+        # Verify no background class is present
+        assert " " not in classes
+        self.model.names = classes
+
+        # Reset method class names
+        if self.predictor:
+            self.predictor.model.names = classes
+
+    def val(
+        self,
+        validator=None,
+        load_vp=False,
+        refer_data=None,
+        **kwargs,
+    ):
+        """
+        Validate the model using text or visual prompts.
+
+        Args:
+            validator (callable, optional): A callable validator function. If None, a default validator is loaded.
+            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
+            refer_data (str, optional): Path to the reference data for visual prompts.
+            **kwargs (Any): Additional keyword arguments to override default settings.
+
+        Returns:
+            (dict): Validation statistics containing metrics computed during validation.
+        """
+        custom = {"rect": not load_vp}  # method defaults
+        args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # highest priority args on the right
+
+        validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
+        validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
+        self.metrics = validator.metrics
+        return validator.metrics
+
+    def predict(
+        self,
+        source=None,
+        stream: bool = False,
+        visual_prompts: dict = {},
+        refer_image=None,
+        predictor=None,
+        **kwargs,
+    ):
+        """
+        Run prediction on images, videos, directories, streams, etc.
+
+        Args:
+            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
+                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
+            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
+                generator as they are computed.
+            visual_prompts (dict): Dictionary containing visual prompts for the model. Must include 'bboxes' and
+                'cls' keys when non-empty.
+            refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
+            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
+                loaded based on the task.
+            **kwargs (Any): Additional keyword arguments passed to the predictor.
+
+        Returns:
+            (List | generator): List of Results objects or generator of Results objects if stream=True.
+
+        Examples:
+            >>> model = YOLOE("yoloe-v8s-seg.pt")
+            >>> results = model.predict("path/to/image.jpg")
+            >>> # With visual prompts
+            >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
+            >>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
+        """
+        if len(visual_prompts):
+            assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
+                f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
+            )
+            assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
+                f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
+                f"{len(visual_prompts['cls'])} respectively"
+            )
+        self.predictor = (predictor or self._smart_load("predictor"))(
+            overrides={"task": "segment", "mode": "predict", "save": False, "verbose": False}, _callbacks=self.callbacks
+        )
+
+        if len(visual_prompts):
+            num_cls = (
+                max(len(set(c)) for c in visual_prompts["cls"])
+                if isinstance(source, list)  # means multiple images
+                else len(set(visual_prompts["cls"]))
+            )
+            self.model.model[-1].nc = num_cls
+            self.model.names = [f"object{i}" for i in range(num_cls)]
+            self.predictor.set_prompts(visual_prompts)
+
+        self.predictor.setup_model(model=self.model)
+        if refer_image is not None and len(visual_prompts):
+            vpe = self.predictor.get_vpe(refer_image)
+            self.model.set_classes(self.model.names, vpe)
+            self.predictor = None  # reset predictor
+
+        return super().predict(source, stream, **kwargs)
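
Taken together, the new class exposes two prompting modes. A hedged usage sketch built only from methods shown in this hunk (the weights name is the class default; the top-level YOLOE export is inferred from the ultralytics/__init__.py change in this release; paths and box values are illustrative):

    import numpy as np
    from ultralytics import YOLOE  # top-level export assumed per the __init__.py hunk

    model = YOLOE("yoloe-v8s-seg.pt")

    # Text prompts: derive embeddings, then register them as the class set
    names = ["person", "bus"]
    model.set_classes(names, model.get_text_pe(names))
    results = model.predict("path/to/image.jpg")  # illustrative path

    # Visual prompts: one labeled box on the query image (illustrative values)
    prompts = {"bboxes": np.array([[10, 20, 100, 200]]), "cls": np.array([0])}
    results = model.predict("path/to/image.jpg", visual_prompts=prompts)

Note that predict() rewrites the head's nc and the names list ("object0", ...) when visual prompts are given, and that passing refer_image computes the prompt embeddings from a reference frame instead of the query image.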
ultralytics/models/yolo/segment/val.py

@@ -364,29 +364,56 @@ class SegmentationValidator(DetectionValidator):
 
     def eval_json(self, stats):
         """Return COCO-style object detection evaluation metrics."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+        if self.args.save_json and (self.is_lvis or self.is_coco) and len(self.jdict):
             pred_json = self.save_dir / "predictions.json"  # predictions
-            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
-            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements("pycocotools>=2.0.6")
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
 
+            anno_json = (
+                self.data["path"]
+                / "annotations"
+                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
+            )  # annotations
+
+            pkg = "pycocotools" if self.is_coco else "lvis"
+            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
+            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                 for x in anno_json, pred_json:
                     assert x.is_file(), f"{x} file not found"
-                anno = COCO(str(anno_json))  # init annotations api
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
-                    if self.is_coco:
-                        eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
+                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
+                if self.is_coco:
+                    from pycocotools.coco import COCO  # noqa
+                    from pycocotools.cocoeval import COCOeval  # noqa
+
+                    anno = COCO(str(anno_json))  # init annotations api
+                    pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
+                    vals = [COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]
+                else:
+                    from lvis import LVIS, LVISEval
+
+                    anno = LVIS(str(anno_json))
+                    pred = anno._load_json(str(pred_json))
+                    vals = [LVISEval(anno, pred, "bbox"), LVISEval(anno, pred, "segm")]
+
+                for i, eval in enumerate(vals):
+                    eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                     eval.evaluate()
                     eval.accumulate()
                     eval.summarize()
+                    if self.is_lvis:
+                        eval.print_results()
                     idx = i * 4 + 2
-                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
-                        :2
-                    ]  # update mAP50-95 and mAP50
+                    # update mAP50-95 and mAP50
+                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = (
+                        eval.stats[:2] if self.is_coco else [eval.results["AP"], eval.results["AP50"]]
+                    )
+                    if self.is_lvis:
+                        tag = "B" if i == 0 else "M"
+                        stats[f"metrics/APr({tag})"] = eval.results["APr"]
+                        stats[f"metrics/APc({tag})"] = eval.results["APc"]
+                        stats[f"metrics/APf({tag})"] = eval.results["APf"]
+
+                if self.is_lvis:
+                    stats["fitness"] = stats["metrics/mAP50-95(B)"]
+
             except Exception as e:
-                LOGGER.warning(f"pycocotools unable to run: {e}")
+                LOGGER.warning(f"{pkg} unable to run: {e}")
         return stats
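
In practice the new branch activates when validating a segmentation model against an LVIS dataset config with JSON saving enabled. A hedged sketch, assuming the lvis.yaml dataset config that ships with ultralytics and a standard segmentation checkpoint (both names are assumptions, not shown in this diff):

    from ultralytics import YOLO

    model = YOLO("yolo11n-seg.pt")  # assumed checkpoint name
    metrics = model.val(data="lvis.yaml", save_json=True)  # eval_json runs after validation
    # LVIS-only stats keys populated above, for box (B) and mask (M) evals:
    #   metrics/APr(B), metrics/APc(B), metrics/APf(B)
    #   metrics/APr(M), metrics/APc(M), metrics/APf(M)
    # and fitness is pinned to metrics/mAP50-95(B)

Splitting AP into rare/common/frequent buckets is the standard LVIS protocol, which is why the LVISEval results dict is consulted instead of the COCOeval stats array.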
ultralytics/models/yolo/yoloe/__init__.py (new file)

@@ -0,0 +1,21 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor
+from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOEVPTrainer
+from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer
+from .val import YOLOEDetectValidator, YOLOESegValidator
+
+__all__ = [
+    "YOLOETrainer",
+    "YOLOEPETrainer",
+    "YOLOESegTrainer",
+    "YOLOEDetectValidator",
+    "YOLOESegValidator",
+    "YOLOEPESegTrainer",
+    "YOLOESegTrainerFromScratch",
+    "YOLOESegVPTrainer",
+    "YOLOEVPTrainer",
+    "YOLOEPEFreeTrainer",
+    "YOLOEVPDetectPredictor",
+    "YOLOEVPSegPredictor",
+]
ultralytics/models/yolo/yoloe/predict.py (new file)

@@ -0,0 +1,170 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+
+import numpy as np
+import torch
+
+from ultralytics.data.augment import LoadVisualPrompt
+from ultralytics.models.yolo.detect import DetectionPredictor
+from ultralytics.models.yolo.segment import SegmentationPredictor
+
+
+class YOLOEVPDetectPredictor(DetectionPredictor):
+    """
+    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+
+    This mixin provides common functionality for YOLO models that use visual prompting, including
+    model setup, prompt handling, and preprocessing transformations.
+
+    Attributes:
+        model (torch.nn.Module): The YOLO model for inference.
+        device (torch.device): Device to run the model on (CPU or CUDA).
+        prompts (dict): Visual prompts containing class indices and bounding boxes or masks.
+
+    Methods:
+        setup_model: Initialize the YOLO model and set it to evaluation mode.
+        set_return_vpe: Set whether to return visual prompt embeddings.
+        set_prompts: Set the visual prompts for the model.
+        pre_transform: Preprocess images and prompts before inference.
+        inference: Run inference with visual prompts.
+    """
+
+    def setup_model(self, model, verbose=True):
+        """
+        Sets up the model for prediction.
+
+        Args:
+            model (torch.nn.Module): Model to load or use.
+            verbose (bool): If True, provides detailed logging.
+        """
+        super().setup_model(model, verbose=verbose)
+        self.done_warmup = True
+
+    def set_prompts(self, prompts):
+        """
+        Set the visual prompts for the model.
+
+        Args:
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
+                Must include a 'cls' key with class indices.
+        """
+        self.prompts = prompts
+
+    def pre_transform(self, im):
+        """
+        Preprocess images and prompts before inference.
+
+        This method applies letterboxing to the input image and transforms the visual prompts
+        (bounding boxes or masks) accordingly.
+
+        Args:
+            im (list): List containing a single input image.
+
+        Returns:
+            (list): Preprocessed image ready for model inference.
+
+        Raises:
+            ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
+        """
+        img = super().pre_transform(im)
+        bboxes = self.prompts.pop("bboxes", None)
+        masks = self.prompts.pop("masks", None)
+        category = self.prompts["cls"]
+        if len(img) == 1:
+            visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
+            self.prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
+        else:
+            # NOTE: only supports bboxes as prompts for now
+            assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
+            # NOTE: needs List[np.ndarray]
+            assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
+                f"Expected List[np.ndarray], but got {bboxes}!"
+            )
+            assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
+                f"Expected List[np.ndarray], but got {category}!"
+            )
+            assert len(im) == len(category) == len(bboxes), (
+                f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
+            )
+            visuals = [
+                self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
+                for i in range(len(img))
+            ]
+            self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)
+
+        return img
+
+    def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
+        """
+        Processes a single image by resizing bounding boxes or masks and generating visuals.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) of the image.
+            src_shape (tuple): The original shape (height, width) of the image.
+            category (str): The category of the image for visual prompts.
+            bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2]. Defaults to None.
+            masks (np.ndarray, optional): A list of masks corresponding to the image. Defaults to None.
+
+        Returns:
+            visuals: The processed visuals for the image.
+
+        Raises:
+            ValueError: If neither `bboxes` nor `masks` are provided.
+        """
+        if bboxes is not None and len(bboxes):
+            bboxes = np.array(bboxes, dtype=np.float32)
+            if bboxes.ndim == 1:
+                bboxes = bboxes[None, :]
+            # Calculate scaling factor and adjust bounding boxes
+            gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # gain = old / new
+            bboxes *= gain
+            bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
+            bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
+        elif masks is not None:
+            # Resize and process masks
+            resized_masks = super().pre_transform(masks)
+            masks = np.stack(resized_masks)  # (N, H, W)
+            masks[masks == 114] = 0  # Reset padding values to 0
+        else:
+            raise ValueError("Please provide valid bboxes or masks")
+
+        # Generate visuals using the visual prompt loader
+        return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
+
+    def inference(self, im, *args, **kwargs):
+        """
+        Run inference with visual prompts.
+
+        Args:
+            im (torch.Tensor): Input image tensor.
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+
+        Returns:
+            (torch.Tensor): Model prediction results.
+        """
+        return super().inference(im, vpe=self.prompts, *args, **kwargs)
+
+    def get_vpe(self, source):
+        """
+        Processes the source to get the visual prompt embeddings (VPE).
+
+        Args:
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
+                of the image to make predictions on. Accepts various types including file paths, URLs, PIL
+                images, numpy arrays, and torch tensors.
+
+        Returns:
+            (torch.Tensor): The visual prompt embeddings (VPE) from the model.
+        """
+        self.setup_source(source)
+        assert len(self.dataset) == 1, "get_vpe only supports one image!"
+        for _, im0s, _ in self.dataset:
+            im = self.preprocess(im0s)
+            return self.model(im, vpe=self.prompts, return_vpe=True)
+
+
+class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
+    """Predictor for YOLOE VP segmentation."""
+
+    pass
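
For completeness, a hedged sketch of driving the new predictor explicitly rather than through the default task map (YOLOE.predict accepts a predictor class and calls set_prompts for you before setup_model; the weights name is the class default, the path and box values are illustrative):

    import numpy as np
    from ultralytics import YOLOE
    from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

    model = YOLOE("yoloe-v8s-seg.pt")
    prompts = {"bboxes": np.array([[48.0, 372.0, 165.0, 572.0]]), "cls": np.array([0])}  # illustrative
    results = model.predict("path/to/image.jpg", visual_prompts=prompts, predictor=YOLOEVPSegPredictor)

YOLOEVPSegPredictor adds no code of its own; it relies on Python's MRO so that YOLOEVPDetectPredictor's prompt handling composes with SegmentationPredictor's mask postprocessing.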