dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
- tests/conftest.py +5 -8
- tests/test_cli.py +1 -8
- tests/test_python.py +1 -2
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +34 -49
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +244 -323
- ultralytics/data/base.py +12 -22
- ultralytics/data/build.py +47 -40
- ultralytics/data/converter.py +32 -42
- ultralytics/data/dataset.py +43 -71
- ultralytics/data/loaders.py +22 -34
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +27 -36
- ultralytics/engine/exporter.py +49 -116
- ultralytics/engine/model.py +144 -180
- ultralytics/engine/predictor.py +18 -29
- ultralytics/engine/results.py +165 -231
- ultralytics/engine/trainer.py +11 -19
- ultralytics/engine/tuner.py +13 -23
- ultralytics/engine/validator.py +6 -10
- ultralytics/hub/__init__.py +7 -12
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +3 -6
- ultralytics/models/fastsam/model.py +6 -8
- ultralytics/models/fastsam/predict.py +5 -10
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +2 -4
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -18
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +13 -20
- ultralytics/models/sam/amg.py +12 -18
- ultralytics/models/sam/build.py +6 -9
- ultralytics/models/sam/model.py +16 -23
- ultralytics/models/sam/modules/blocks.py +62 -84
- ultralytics/models/sam/modules/decoders.py +17 -24
- ultralytics/models/sam/modules/encoders.py +40 -56
- ultralytics/models/sam/modules/memory_attention.py +10 -16
- ultralytics/models/sam/modules/sam.py +41 -47
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +17 -27
- ultralytics/models/sam/modules/utils.py +31 -42
- ultralytics/models/sam/predict.py +172 -209
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/classify/predict.py +8 -11
- ultralytics/models/yolo/classify/train.py +8 -16
- ultralytics/models/yolo/classify/val.py +13 -20
- ultralytics/models/yolo/detect/predict.py +4 -8
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +38 -48
- ultralytics/models/yolo/model.py +35 -47
- ultralytics/models/yolo/obb/predict.py +5 -8
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +20 -28
- ultralytics/models/yolo/pose/predict.py +5 -8
- ultralytics/models/yolo/pose/train.py +4 -8
- ultralytics/models/yolo/pose/val.py +31 -39
- ultralytics/models/yolo/segment/predict.py +9 -14
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +16 -26
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -16
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/autobackend.py +10 -18
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +99 -185
- ultralytics/nn/modules/conv.py +45 -90
- ultralytics/nn/modules/head.py +44 -98
- ultralytics/nn/modules/transformer.py +44 -76
- ultralytics/nn/modules/utils.py +14 -19
- ultralytics/nn/tasks.py +86 -146
- ultralytics/nn/text_model.py +25 -40
- ultralytics/solutions/ai_gym.py +10 -16
- ultralytics/solutions/analytics.py +7 -10
- ultralytics/solutions/config.py +4 -5
- ultralytics/solutions/distance_calculation.py +9 -12
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +8 -12
- ultralytics/solutions/object_cropper.py +5 -8
- ultralytics/solutions/parking_management.py +12 -14
- ultralytics/solutions/queue_management.py +4 -6
- ultralytics/solutions/region_counter.py +7 -10
- ultralytics/solutions/security_alarm.py +14 -19
- ultralytics/solutions/similarity_search.py +7 -12
- ultralytics/solutions/solutions.py +31 -53
- ultralytics/solutions/speed_estimation.py +6 -9
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/basetrack.py +2 -4
- ultralytics/trackers/bot_sort.py +6 -11
- ultralytics/trackers/byte_tracker.py +10 -15
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +6 -12
- ultralytics/trackers/utils/kalman_filter.py +35 -43
- ultralytics/trackers/utils/matching.py +6 -10
- ultralytics/utils/__init__.py +61 -100
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +11 -13
- ultralytics/utils/benchmarks.py +25 -35
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +2 -4
- ultralytics/utils/callbacks/comet.py +30 -44
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +4 -6
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +4 -6
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +29 -56
- ultralytics/utils/cpu.py +1 -2
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +17 -27
- ultralytics/utils/errors.py +6 -8
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -239
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +11 -17
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +10 -15
- ultralytics/utils/git.py +5 -7
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +11 -15
- ultralytics/utils/loss.py +8 -14
- ultralytics/utils/metrics.py +98 -138
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +47 -74
- ultralytics/utils/patches.py +11 -18
- ultralytics/utils/plotting.py +29 -42
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +45 -73
- ultralytics/utils/tqdm.py +6 -8
- ultralytics/utils/triton.py +9 -12
- ultralytics/utils/tuner.py +1 -2
- dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
|
@@ -12,8 +12,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
15
|
-
"""
|
|
16
|
-
A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
|
15
|
+
"""A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
|
17
16
|
|
|
18
17
|
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
|
|
19
18
|
of pose keypoints alongside bounding boxes.
|
|
@@ -39,8 +38,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
39
38
|
"""
|
|
40
39
|
|
|
41
40
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
|
42
|
-
"""
|
|
43
|
-
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
|
41
|
+
"""Initialize a PoseTrainer object for training YOLO pose estimation models.
|
|
44
42
|
|
|
45
43
|
Args:
|
|
46
44
|
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
|
@@ -68,8 +66,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
68
66
|
weights: str | Path | None = None,
|
|
69
67
|
verbose: bool = True,
|
|
70
68
|
) -> PoseModel:
|
|
71
|
-
"""
|
|
72
|
-
Get pose estimation model with specified configuration and weights.
|
|
69
|
+
"""Get pose estimation model with specified configuration and weights.
|
|
73
70
|
|
|
74
71
|
Args:
|
|
75
72
|
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
|
|
@@ -105,8 +102,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
105
102
|
)
|
|
106
103
|
|
|
107
104
|
def get_dataset(self) -> dict[str, Any]:
|
|
108
|
-
"""
|
|
109
|
-
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
|
105
|
+
"""Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
|
110
106
|
|
|
111
107
|
Returns:
|
|
112
108
|
(dict): A dictionary containing the training/validation/test dataset and category names.
|
|
@@ -14,11 +14,10 @@ from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class PoseValidator(DetectionValidator):
|
|
17
|
-
"""
|
|
18
|
-
A class extending the DetectionValidator class for validation based on a pose model.
|
|
17
|
+
"""A class extending the DetectionValidator class for validation based on a pose model.
|
|
19
18
|
|
|
20
|
-
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
|
21
|
-
|
|
19
|
+
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
|
|
20
|
+
metrics for pose evaluation.
|
|
22
21
|
|
|
23
22
|
Attributes:
|
|
24
23
|
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
|
@@ -33,8 +32,8 @@ class PoseValidator(DetectionValidator):
|
|
|
33
32
|
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
|
|
34
33
|
dimensions.
|
|
35
34
|
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
|
|
36
|
-
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
|
|
37
|
-
|
|
35
|
+
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
|
|
36
|
+
and ground truth.
|
|
38
37
|
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
|
39
38
|
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
|
|
40
39
|
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
|
|
@@ -49,8 +48,7 @@ class PoseValidator(DetectionValidator):
|
|
|
49
48
|
"""
|
|
50
49
|
|
|
51
50
|
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
|
52
|
-
"""
|
|
53
|
-
Initialize a PoseValidator object for pose estimation validation.
|
|
51
|
+
"""Initialize a PoseValidator object for pose estimation validation.
|
|
54
52
|
|
|
55
53
|
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
|
56
54
|
specialized metrics for pose evaluation.
|
|
@@ -106,8 +104,7 @@ class PoseValidator(DetectionValidator):
|
|
|
106
104
|
)
|
|
107
105
|
|
|
108
106
|
def init_metrics(self, model: torch.nn.Module) -> None:
|
|
109
|
-
"""
|
|
110
|
-
Initialize evaluation metrics for YOLO pose validation.
|
|
107
|
+
"""Initialize evaluation metrics for YOLO pose validation.
|
|
111
108
|
|
|
112
109
|
Args:
|
|
113
110
|
model (torch.nn.Module): Model to validate.
|
|
@@ -119,17 +116,15 @@ class PoseValidator(DetectionValidator):
|
|
|
119
116
|
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
|
120
117
|
|
|
121
118
|
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
122
|
-
"""
|
|
123
|
-
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
|
119
|
+
"""Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
|
124
120
|
|
|
125
|
-
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
(typically [N, 17, 3] for COCO pose format).
|
|
121
|
+
This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
|
|
122
|
+
predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
|
|
123
|
+
flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
|
|
129
124
|
|
|
130
125
|
Args:
|
|
131
|
-
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
|
|
132
|
-
|
|
126
|
+
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
|
|
127
|
+
scores, class predictions, and keypoint data.
|
|
133
128
|
|
|
134
129
|
Returns:
|
|
135
130
|
(dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
|
|
@@ -138,10 +133,10 @@ class PoseValidator(DetectionValidator):
|
|
|
138
133
|
- 'cls': Class predictions
|
|
139
134
|
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
|
|
140
135
|
|
|
141
|
-
|
|
142
|
-
If no keypoints are present in a prediction (empty keypoints), that prediction
|
|
143
|
-
|
|
144
|
-
|
|
136
|
+
Notes:
|
|
137
|
+
If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
|
|
138
|
+
to the next one. The keypoints are extracted from the 'extra' field which contains additional
|
|
139
|
+
task-specific data beyond basic detection.
|
|
145
140
|
"""
|
|
146
141
|
preds = super().postprocess(preds)
|
|
147
142
|
for pred in preds:
|
|
@@ -149,8 +144,7 @@ class PoseValidator(DetectionValidator):
|
|
|
149
144
|
return preds
|
|
150
145
|
|
|
151
146
|
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
|
152
|
-
"""
|
|
153
|
-
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
|
147
|
+
"""Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
|
154
148
|
|
|
155
149
|
Args:
|
|
156
150
|
si (int): Batch index.
|
|
@@ -173,18 +167,18 @@ class PoseValidator(DetectionValidator):
|
|
|
173
167
|
return pbatch
|
|
174
168
|
|
|
175
169
|
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
|
176
|
-
"""
|
|
177
|
-
|
|
170
|
+
"""Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
|
|
171
|
+
truth.
|
|
178
172
|
|
|
179
173
|
Args:
|
|
180
174
|
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
|
181
175
|
and 'keypoints' for keypoint predictions.
|
|
182
|
-
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
|
183
|
-
|
|
176
|
+
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
|
|
177
|
+
for bounding boxes, and 'keypoints' for keypoint annotations.
|
|
184
178
|
|
|
185
179
|
Returns:
|
|
186
|
-
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
|
187
|
-
|
|
180
|
+
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
|
|
181
|
+
positives across 10 IoU levels.
|
|
188
182
|
|
|
189
183
|
Notes:
|
|
190
184
|
`0.53` scale factor used in area computation is referenced from
|
|
@@ -203,11 +197,10 @@ class PoseValidator(DetectionValidator):
|
|
|
203
197
|
return tp
|
|
204
198
|
|
|
205
199
|
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
|
206
|
-
"""
|
|
207
|
-
Save YOLO pose detections to a text file in normalized coordinates.
|
|
200
|
+
"""Save YOLO pose detections to a text file in normalized coordinates.
|
|
208
201
|
|
|
209
202
|
Args:
|
|
210
|
-
predn (dict[str, torch.Tensor]):
|
|
203
|
+
predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
|
|
211
204
|
save_conf (bool): Whether to save confidence scores.
|
|
212
205
|
shape (tuple[int, int]): Shape of the original image (height, width).
|
|
213
206
|
file (Path): Output file path to save detections.
|
|
@@ -227,15 +220,14 @@ class PoseValidator(DetectionValidator):
|
|
|
227
220
|
).save_txt(file, save_conf=save_conf)
|
|
228
221
|
|
|
229
222
|
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
|
230
|
-
"""
|
|
231
|
-
Convert YOLO predictions to COCO JSON format.
|
|
223
|
+
"""Convert YOLO predictions to COCO JSON format.
|
|
232
224
|
|
|
233
|
-
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
|
|
234
|
-
|
|
225
|
+
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
|
|
226
|
+
format, and appends the results to the internal JSON dictionary (self.jdict).
|
|
235
227
|
|
|
236
228
|
Args:
|
|
237
|
-
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
|
238
|
-
|
|
229
|
+
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
|
|
230
|
+
tensors.
|
|
239
231
|
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
|
240
232
|
|
|
241
233
|
Notes:
|
|
@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class SegmentationPredictor(DetectionPredictor):
|
|
9
|
-
"""
|
|
10
|
-
A class extending the DetectionPredictor class for prediction based on a segmentation model.
|
|
9
|
+
"""A class extending the DetectionPredictor class for prediction based on a segmentation model.
|
|
11
10
|
|
|
12
11
|
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
|
|
13
12
|
prediction results.
|
|
@@ -31,8 +30,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
31
30
|
"""
|
|
32
31
|
|
|
33
32
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
34
|
-
"""
|
|
35
|
-
Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
|
|
33
|
+
"""Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
|
|
36
34
|
|
|
37
35
|
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
|
|
38
36
|
prediction results.
|
|
@@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
46
44
|
self.args.task = "segment"
|
|
47
45
|
|
|
48
46
|
def postprocess(self, preds, img, orig_imgs):
|
|
49
|
-
"""
|
|
50
|
-
Apply non-max suppression and process segmentation detections for each image in the input batch.
|
|
47
|
+
"""Apply non-max suppression and process segmentation detections for each image in the input batch.
|
|
51
48
|
|
|
52
49
|
Args:
|
|
53
50
|
preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
|
|
@@ -55,8 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
55
52
|
orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
|
|
56
53
|
|
|
57
54
|
Returns:
|
|
58
|
-
(list): List of Results objects containing the segmentation predictions for each image in the batch.
|
|
59
|
-
|
|
55
|
+
(list): List of Results objects containing the segmentation predictions for each image in the batch. Each
|
|
56
|
+
Results object includes both bounding boxes and segmentation masks.
|
|
60
57
|
|
|
61
58
|
Examples:
|
|
62
59
|
>>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
|
|
@@ -67,8 +64,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
67
64
|
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
|
|
68
65
|
|
|
69
66
|
def construct_results(self, preds, img, orig_imgs, protos):
|
|
70
|
-
"""
|
|
71
|
-
Construct a list of result objects from the predictions.
|
|
67
|
+
"""Construct a list of result objects from the predictions.
|
|
72
68
|
|
|
73
69
|
Args:
|
|
74
70
|
preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
|
|
@@ -77,8 +73,8 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
77
73
|
protos (list[torch.Tensor]): List of prototype masks.
|
|
78
74
|
|
|
79
75
|
Returns:
|
|
80
|
-
(list[Results]): List of result objects containing the original images, image paths, class names,
|
|
81
|
-
|
|
76
|
+
(list[Results]): List of result objects containing the original images, image paths, class names, bounding
|
|
77
|
+
boxes, and masks.
|
|
82
78
|
"""
|
|
83
79
|
return [
|
|
84
80
|
self.construct_result(pred, img, orig_img, img_path, proto)
|
|
@@ -86,8 +82,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
86
82
|
]
|
|
87
83
|
|
|
88
84
|
def construct_result(self, pred, img, orig_img, img_path, proto):
|
|
89
|
-
"""
|
|
90
|
-
Construct a single result object from the prediction.
|
|
85
|
+
"""Construct a single result object from the prediction.
|
|
91
86
|
|
|
92
87
|
Args:
|
|
93
88
|
pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
|
|
@@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, RANK
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
14
|
-
"""
|
|
15
|
-
A class extending the DetectionTrainer class for training based on a segmentation model.
|
|
14
|
+
"""A class extending the DetectionTrainer class for training based on a segmentation model.
|
|
16
15
|
|
|
17
16
|
This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
|
|
18
17
|
functionality including model initialization, validation, and visualization.
|
|
@@ -28,8 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
|
28
27
|
"""
|
|
29
28
|
|
|
30
29
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
|
|
31
|
-
"""
|
|
32
|
-
Initialize a SegmentationTrainer object.
|
|
30
|
+
"""Initialize a SegmentationTrainer object.
|
|
33
31
|
|
|
34
32
|
Args:
|
|
35
33
|
cfg (dict): Configuration dictionary with default training settings.
|
|
@@ -42,8 +40,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
|
42
40
|
super().__init__(cfg, overrides, _callbacks)
|
|
43
41
|
|
|
44
42
|
def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
|
|
45
|
-
"""
|
|
46
|
-
Initialize and return a SegmentationModel with specified configuration and weights.
|
|
43
|
+
"""Initialize and return a SegmentationModel with specified configuration and weights.
|
|
47
44
|
|
|
48
45
|
Args:
|
|
49
46
|
cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
|
|
@@ -17,11 +17,10 @@ from ultralytics.utils.metrics import SegmentMetrics, mask_iou
|
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class SegmentationValidator(DetectionValidator):
|
|
20
|
-
"""
|
|
21
|
-
A class extending the DetectionValidator class for validation based on a segmentation model.
|
|
20
|
+
"""A class extending the DetectionValidator class for validation based on a segmentation model.
|
|
22
21
|
|
|
23
|
-
This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
|
|
24
|
-
|
|
22
|
+
This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
|
|
23
|
+
compute metrics such as mAP for both detection and segmentation tasks.
|
|
25
24
|
|
|
26
25
|
Attributes:
|
|
27
26
|
plot_masks (list): List to store masks for plotting.
|
|
@@ -38,8 +37,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
38
37
|
"""
|
|
39
38
|
|
|
40
39
|
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
|
41
|
-
"""
|
|
42
|
-
Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
|
|
40
|
+
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
|
|
43
41
|
|
|
44
42
|
Args:
|
|
45
43
|
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
|
@@ -53,8 +51,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
53
51
|
self.metrics = SegmentMetrics()
|
|
54
52
|
|
|
55
53
|
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
|
56
|
-
"""
|
|
57
|
-
Preprocess batch of images for YOLO segmentation validation.
|
|
54
|
+
"""Preprocess batch of images for YOLO segmentation validation.
|
|
58
55
|
|
|
59
56
|
Args:
|
|
60
57
|
batch (dict[str, Any]): Batch containing images and annotations.
|
|
@@ -67,8 +64,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
67
64
|
return batch
|
|
68
65
|
|
|
69
66
|
def init_metrics(self, model: torch.nn.Module) -> None:
|
|
70
|
-
"""
|
|
71
|
-
Initialize metrics and select mask processing function based on save_json flag.
|
|
67
|
+
"""Initialize metrics and select mask processing function based on save_json flag.
|
|
72
68
|
|
|
73
69
|
Args:
|
|
74
70
|
model (torch.nn.Module): Model to validate.
|
|
@@ -96,8 +92,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
96
92
|
)
|
|
97
93
|
|
|
98
94
|
def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
|
|
99
|
-
"""
|
|
100
|
-
Post-process YOLO predictions and return output detections with proto.
|
|
95
|
+
"""Post-process YOLO predictions and return output detections with proto.
|
|
101
96
|
|
|
102
97
|
Args:
|
|
103
98
|
preds (list[torch.Tensor]): Raw predictions from the model.
|
|
@@ -122,8 +117,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
122
117
|
return preds
|
|
123
118
|
|
|
124
119
|
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
|
125
|
-
"""
|
|
126
|
-
Prepare a batch for training or inference by processing images and targets.
|
|
120
|
+
"""Prepare a batch for training or inference by processing images and targets.
|
|
127
121
|
|
|
128
122
|
Args:
|
|
129
123
|
si (int): Batch index.
|
|
@@ -149,8 +143,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
149
143
|
return prepared_batch
|
|
150
144
|
|
|
151
145
|
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
|
152
|
-
"""
|
|
153
|
-
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
|
146
|
+
"""Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
|
|
154
147
|
|
|
155
148
|
Args:
|
|
156
149
|
preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
|
|
@@ -159,14 +152,14 @@ class SegmentationValidator(DetectionValidator):
|
|
|
159
152
|
Returns:
|
|
160
153
|
(dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
|
|
161
154
|
|
|
162
|
-
Notes:
|
|
163
|
-
- If `masks` is True, the function computes IoU between predicted and ground truth masks.
|
|
164
|
-
- If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
|
|
165
|
-
|
|
166
155
|
Examples:
|
|
167
156
|
>>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
|
|
168
157
|
>>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
|
|
169
158
|
>>> correct_preds = validator._process_batch(preds, batch)
|
|
159
|
+
|
|
160
|
+
Notes:
|
|
161
|
+
- If `masks` is True, the function computes IoU between predicted and ground truth masks.
|
|
162
|
+
- If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
|
|
170
163
|
"""
|
|
171
164
|
tp = super()._process_batch(preds, batch)
|
|
172
165
|
gt_cls = batch["cls"]
|
|
@@ -179,8 +172,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
179
172
|
return tp
|
|
180
173
|
|
|
181
174
|
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
|
|
182
|
-
"""
|
|
183
|
-
Plot batch predictions with masks and bounding boxes.
|
|
175
|
+
"""Plot batch predictions with masks and bounding boxes.
|
|
184
176
|
|
|
185
177
|
Args:
|
|
186
178
|
batch (dict[str, Any]): Batch containing images and annotations.
|
|
@@ -195,8 +187,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
195
187
|
super().plot_predictions(batch, preds, ni, max_det=self.args.max_det) # plot bboxes
|
|
196
188
|
|
|
197
189
|
def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
|
198
|
-
"""
|
|
199
|
-
Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
|
190
|
+
"""Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
|
200
191
|
|
|
201
192
|
Args:
|
|
202
193
|
predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
|
|
@@ -215,8 +206,7 @@ class SegmentationValidator(DetectionValidator):
|
|
|
215
206
|
).save_txt(file, save_conf=save_conf)
|
|
216
207
|
|
|
217
208
|
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
|
218
|
-
"""
|
|
219
|
-
Save one JSON result for COCO evaluation.
|
|
209
|
+
"""Save one JSON result for COCO evaluation.
|
|
220
210
|
|
|
221
211
|
Args:
|
|
222
212
|
predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
|
|
@@ -24,8 +24,7 @@ def on_pretrain_routine_end(trainer) -> None:
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class WorldTrainer(DetectionTrainer):
|
|
27
|
-
"""
|
|
28
|
-
A trainer class for fine-tuning YOLO World models on close-set datasets.
|
|
27
|
+
"""A trainer class for fine-tuning YOLO World models on close-set datasets.
|
|
29
28
|
|
|
30
29
|
This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
|
|
31
30
|
features for improved object detection and understanding. It handles text embedding generation and caching to
|
|
@@ -54,8 +53,7 @@ class WorldTrainer(DetectionTrainer):
|
|
|
54
53
|
"""
|
|
55
54
|
|
|
56
55
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
|
57
|
-
"""
|
|
58
|
-
Initialize a WorldTrainer object with given arguments.
|
|
56
|
+
"""Initialize a WorldTrainer object with given arguments.
|
|
59
57
|
|
|
60
58
|
Args:
|
|
61
59
|
cfg (dict[str, Any]): Configuration for the trainer.
|
|
@@ -69,8 +67,7 @@ class WorldTrainer(DetectionTrainer):
|
|
|
69
67
|
self.text_embeddings = None
|
|
70
68
|
|
|
71
69
|
def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
|
|
72
|
-
"""
|
|
73
|
-
Return WorldModel initialized with specified config and weights.
|
|
70
|
+
"""Return WorldModel initialized with specified config and weights.
|
|
74
71
|
|
|
75
72
|
Args:
|
|
76
73
|
cfg (dict[str, Any] | str, optional): Model configuration.
|
|
@@ -95,8 +92,7 @@ class WorldTrainer(DetectionTrainer):
|
|
|
95
92
|
return model
|
|
96
93
|
|
|
97
94
|
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
|
98
|
-
"""
|
|
99
|
-
Build YOLO Dataset for training or validation.
|
|
95
|
+
"""Build YOLO Dataset for training or validation.
|
|
100
96
|
|
|
101
97
|
Args:
|
|
102
98
|
img_path (str): Path to the folder containing images.
|
|
@@ -115,11 +111,10 @@ class WorldTrainer(DetectionTrainer):
|
|
|
115
111
|
return dataset
|
|
116
112
|
|
|
117
113
|
def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
|
|
118
|
-
"""
|
|
119
|
-
Set text embeddings for datasets to accelerate training by caching category names.
|
|
114
|
+
"""Set text embeddings for datasets to accelerate training by caching category names.
|
|
120
115
|
|
|
121
|
-
This method collects unique category names from all datasets, then generates and caches text embeddings
|
|
122
|
-
|
|
116
|
+
This method collects unique category names from all datasets, then generates and caches text embeddings for
|
|
117
|
+
these categories to improve training efficiency.
|
|
123
118
|
|
|
124
119
|
Args:
|
|
125
120
|
datasets (list[Any]): List of datasets from which to extract category names.
|
|
@@ -141,8 +136,7 @@ class WorldTrainer(DetectionTrainer):
|
|
|
141
136
|
self.text_embeddings = text_embeddings
|
|
142
137
|
|
|
143
138
|
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
|
|
144
|
-
"""
|
|
145
|
-
Generate text embeddings for a list of text samples.
|
|
139
|
+
"""Generate text embeddings for a list of text samples.
|
|
146
140
|
|
|
147
141
|
Args:
|
|
148
142
|
texts (list[str]): List of text samples to encode.
|
|
@@ -10,8 +10,7 @@ from ultralytics.utils.torch_utils import unwrap_model
|
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class WorldTrainerFromScratch(WorldTrainer):
|
|
13
|
-
"""
|
|
14
|
-
A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
|
|
13
|
+
"""A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
|
|
15
14
|
|
|
16
15
|
This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
|
|
17
16
|
supporting training YOLO-World models with combined vision-language capabilities.
|
|
@@ -53,11 +52,10 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|
|
53
52
|
"""
|
|
54
53
|
|
|
55
54
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
56
|
-
"""
|
|
57
|
-
Initialize a WorldTrainerFromScratch object.
|
|
55
|
+
"""Initialize a WorldTrainerFromScratch object.
|
|
58
56
|
|
|
59
|
-
This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
|
|
60
|
-
|
|
57
|
+
This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
|
|
58
|
+
detection and grounding datasets for vision-language capabilities.
|
|
61
59
|
|
|
62
60
|
Args:
|
|
63
61
|
cfg (dict): Configuration dictionary with default parameters for model training.
|
|
@@ -87,11 +85,10 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|
|
87
85
|
super().__init__(cfg, overrides, _callbacks)
|
|
88
86
|
|
|
89
87
|
def build_dataset(self, img_path, mode="train", batch=None):
|
|
90
|
-
"""
|
|
91
|
-
Build YOLO Dataset for training or validation.
|
|
88
|
+
"""Build YOLO Dataset for training or validation.
|
|
92
89
|
|
|
93
|
-
This method constructs appropriate datasets based on the mode and input paths, handling both
|
|
94
|
-
|
|
90
|
+
This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
|
|
91
|
+
datasets and grounding datasets with different formats.
|
|
95
92
|
|
|
96
93
|
Args:
|
|
97
94
|
img_path (list[str] | str): Path to the folder containing images or list of paths.
|
|
@@ -122,11 +119,10 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|
|
122
119
|
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
|
|
123
120
|
|
|
124
121
|
def get_dataset(self):
|
|
125
|
-
"""
|
|
126
|
-
Get train and validation paths from data dictionary.
|
|
122
|
+
"""Get train and validation paths from data dictionary.
|
|
127
123
|
|
|
128
|
-
Processes the data configuration to extract paths for training and validation datasets,
|
|
129
|
-
|
|
124
|
+
Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
|
|
125
|
+
detection datasets and grounding datasets.
|
|
130
126
|
|
|
131
127
|
Returns:
|
|
132
128
|
train_path (str): Train dataset path.
|
|
@@ -187,8 +183,7 @@ class WorldTrainerFromScratch(WorldTrainer):
|
|
|
187
183
|
pass
|
|
188
184
|
|
|
189
185
|
def final_eval(self):
|
|
190
|
-
"""
|
|
191
|
-
Perform final evaluation and validation for the YOLO-World model.
|
|
186
|
+
"""Perform final evaluation and validation for the YOLO-World model.
|
|
192
187
|
|
|
193
188
|
Configures the validator with appropriate dataset and split information before running evaluation.
|
|
194
189
|
|
|
@@ -9,11 +9,10 @@ from ultralytics.models.yolo.segment import SegmentationPredictor
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class YOLOEVPDetectPredictor(DetectionPredictor):
|
|
12
|
-
"""
|
|
13
|
-
A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
|
|
12
|
+
"""A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
|
|
14
13
|
|
|
15
|
-
This mixin provides common functionality for YOLO models that use visual prompting, including
|
|
16
|
-
|
|
14
|
+
This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
|
|
15
|
+
handling, and preprocessing transformations.
|
|
17
16
|
|
|
18
17
|
Attributes:
|
|
19
18
|
model (torch.nn.Module): The YOLO model for inference.
|
|
@@ -29,8 +28,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|
|
29
28
|
"""
|
|
30
29
|
|
|
31
30
|
def setup_model(self, model, verbose: bool = True):
|
|
32
|
-
"""
|
|
33
|
-
Set up the model for prediction.
|
|
31
|
+
"""Set up the model for prediction.
|
|
34
32
|
|
|
35
33
|
Args:
|
|
36
34
|
model (torch.nn.Module): Model to load or use.
|
|
@@ -40,21 +38,19 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|
|
40
38
|
self.done_warmup = True
|
|
41
39
|
|
|
42
40
|
def set_prompts(self, prompts):
|
|
43
|
-
"""
|
|
44
|
-
Set the visual prompts for the model.
|
|
41
|
+
"""Set the visual prompts for the model.
|
|
45
42
|
|
|
46
43
|
Args:
|
|
47
|
-
prompts (dict): Dictionary containing class indices and bounding boxes or masks.
|
|
48
|
-
|
|
44
|
+
prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
|
|
45
|
+
with class indices.
|
|
49
46
|
"""
|
|
50
47
|
self.prompts = prompts
|
|
51
48
|
|
|
52
49
|
def pre_transform(self, im):
|
|
53
|
-
"""
|
|
54
|
-
Preprocess images and prompts before inference.
|
|
50
|
+
"""Preprocess images and prompts before inference.
|
|
55
51
|
|
|
56
|
-
This method applies letterboxing to the input image and transforms the visual prompts
|
|
57
|
-
|
|
52
|
+
This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
|
|
53
|
+
accordingly.
|
|
58
54
|
|
|
59
55
|
Args:
|
|
60
56
|
im (list): List containing a single input image.
|
|
@@ -94,8 +90,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|
|
94
90
|
return img
|
|
95
91
|
|
|
96
92
|
def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
|
|
97
|
-
"""
|
|
98
|
-
Process a single image by resizing bounding boxes or masks and generating visuals.
|
|
93
|
+
"""Process a single image by resizing bounding boxes or masks and generating visuals.
|
|
99
94
|
|
|
100
95
|
Args:
|
|
101
96
|
dst_shape (tuple): The target shape (height, width) of the image.
|
|
@@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|
|
131
126
|
return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
|
|
132
127
|
|
|
133
128
|
def inference(self, im, *args, **kwargs):
|
|
134
|
-
"""
|
|
135
|
-
Run inference with visual prompts.
|
|
129
|
+
"""Run inference with visual prompts.
|
|
136
130
|
|
|
137
131
|
Args:
|
|
138
132
|
im (torch.Tensor): Input image tensor.
|
|
@@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
|
|
|
145
139
|
return super().inference(im, vpe=self.prompts, *args, **kwargs)
|
|
146
140
|
|
|
147
141
|
def get_vpe(self, source):
|
|
148
|
-
"""
|
|
149
|
-
Process the source to get the visual prompt embeddings (VPE).
|
|
142
|
+
"""Process the source to get the visual prompt embeddings (VPE).
|
|
150
143
|
|
|
151
144
|
Args:
|
|
152
|
-
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
|
|
153
|
-
|
|
154
|
-
|
|
145
|
+
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
|
|
146
|
+
make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
|
|
147
|
+
torch tensors.
|
|
155
148
|
|
|
156
149
|
Returns:
|
|
157
150
|
(torch.Tensor): The visual prompt embeddings (VPE) from the model.
|