dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
--- a/ultralytics/models/yolo/segment/predict.py
+++ b/ultralytics/models/yolo/segment/predict.py
@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class SegmentationPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a segmentation model.
+    """A class extending the DetectionPredictor class for prediction based on a segmentation model.
 
     This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
     prediction results.
@@ -31,8 +30,7 @@ class SegmentationPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
 
         This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
         prediction results.
@@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
         self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Apply non-max suppression and process segmentation detections for each image in the input batch.
+        """Apply non-max suppression and process segmentation detections for each image in the input batch.
 
         Args:
             preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
@@ -55,8 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
             orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
 
         Returns:
-            (list): List of Results objects containing the segmentation predictions for each image in the batch.
-                Each Results object includes both bounding boxes and segmentation masks.
+            (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
+                Results object includes both bounding boxes and segmentation masks.
 
         Examples:
             >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
@@ -67,8 +64,7 @@ class SegmentationPredictor(DetectionPredictor):
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
 
     def construct_results(self, preds, img, orig_imgs, protos):
-        """
-        Construct a list of result objects from the predictions.
+        """Construct a list of result objects from the predictions.
 
         Args:
             preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -77,8 +73,8 @@ class SegmentationPredictor(DetectionPredictor):
             protos (list[torch.Tensor]): List of prototype masks.
 
         Returns:
-            (list[Results]): List of result objects containing the original images, image paths, class names,
-                bounding boxes, and masks.
+            (list[Results]): List of result objects containing the original images, image paths, class names, bounding
+                boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -86,8 +82,7 @@ class SegmentationPredictor(DetectionPredictor):
         ]
 
     def construct_result(self, pred, img, orig_img, img_path, proto):
-        """
-        Construct a single result object from the prediction.
+        """Construct a single result object from the prediction.
 
         Args:
             pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
@@ -103,11 +98,12 @@ class SegmentationPredictor(DetectionPredictor):
             masks = None
         elif self.args.retina_masks:
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
+            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # NHW
         else:
-            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
+            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # NHW
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         if masks is not None:
-            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
-            pred, masks = pred[keep], masks[keep]
+            keep = masks.amax((-2, -1)) > 0  # only keep predictions with masks
+            if not all(keep):  # most predictions have masks
+                pred, masks = pred[keep], masks[keep]  # indexing is slow
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
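The notable behavioral change in `construct_result` above is the mask-filtering step: empty masks are now detected with `amax` instead of a sum, and the comparatively slow boolean indexing only runs when at least one mask is actually empty. A minimal standalone sketch of that keep logic (the toy tensors are made up for illustration):

```python
import torch

# Toy stand-ins for the predictor's tensors: 3 predictions, 4x4 masks.
masks = torch.zeros(3, 4, 4)
masks[0, 1, 1] = 1.0  # only the first prediction has a non-empty mask
pred = torch.rand(3, 6)

keep = masks.amax((-2, -1)) > 0  # per-mask max > 0 means the mask has foreground
if not all(keep):  # common case is all-True, so the indexing is usually skipped
    pred, masks = pred[keep], masks[keep]

print(keep.tolist())  # [True, False, False]
print(pred.shape, masks.shape)  # torch.Size([1, 6]) torch.Size([1, 4, 4])
```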
--- a/ultralytics/models/yolo/segment/train.py
+++ b/ultralytics/models/yolo/segment/train.py
@@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, RANK
 
 
 class SegmentationTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on a segmentation model.
+    """A class extending the DetectionTrainer class for training based on a segmentation model.
 
     This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
     functionality including model initialization, validation, and visualization.
@@ -28,8 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
-        """
-        Initialize a SegmentationTrainer object.
+        """Initialize a SegmentationTrainer object.
 
         Args:
             cfg (dict): Configuration dictionary with default training settings.
@@ -42,8 +40,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
        super().__init__(cfg, overrides, _callbacks)
 
     def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
-        """
-        Initialize and return a SegmentationModel with specified configuration and weights.
+        """Initialize and return a SegmentationModel with specified configuration and weights.
 
         Args:
             cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
--- a/ultralytics/models/yolo/segment/val.py
+++ b/ultralytics/models/yolo/segment/val.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from multiprocessing.pool import ThreadPool
 from pathlib import Path
 from typing import Any
 
@@ -11,17 +10,16 @@ import torch
 import torch.nn.functional as F
 
 from ultralytics.models.yolo.detect import DetectionValidator
-from ultralytics.utils import LOGGER, NUM_THREADS, ops
+from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import SegmentMetrics, mask_iou
 
 
 class SegmentationValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a segmentation model.
+    """A class extending the DetectionValidator class for validation based on a segmentation model.
 
-    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
-    to compute metrics such as mAP for both detection and segmentation tasks.
+    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
+    compute metrics such as mAP for both detection and segmentation tasks.
 
     Attributes:
         plot_masks (list): List to store masks for plotting.
@@ -38,11 +36,10 @@ class SegmentationValidator(DetectionValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
+        """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (Path, optional): Directory to save results.
             args (namespace, optional): Arguments for the validator.
             _callbacks (list, optional): List of callback functions.
@@ -53,8 +50,7 @@ class SegmentationValidator(DetectionValidator):
         self.metrics = SegmentMetrics()
 
     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Preprocess batch of images for YOLO segmentation validation.
+        """Preprocess batch of images for YOLO segmentation validation.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -67,8 +63,7 @@ class SegmentationValidator(DetectionValidator):
         return batch
 
     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize metrics and select mask processing function based on save_json flag.
+        """Initialize metrics and select mask processing function based on save_json flag.
 
         Args:
             model (torch.nn.Module): Model to validate.
@@ -96,8 +91,7 @@ class SegmentationValidator(DetectionValidator):
         )
 
     def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
-        """
-        Post-process YOLO predictions and return output detections with proto.
+        """Post-process YOLO predictions and return output detections with proto.
 
         Args:
             preds (list[torch.Tensor]): Raw predictions from the model.
@@ -122,8 +116,7 @@ class SegmentationValidator(DetectionValidator):
         return preds
 
     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch for training or inference by processing images and targets.
+        """Prepare a batch for training or inference by processing images and targets.
 
         Args:
             si (int): Batch index.
@@ -149,8 +142,7 @@ class SegmentationValidator(DetectionValidator):
         return prepared_batch
 
     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
+        """Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
 
         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
@@ -159,28 +151,27 @@ class SegmentationValidator(DetectionValidator):
         Returns:
             (dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
 
-        Notes:
-            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
-            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
-
         Examples:
             >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
             >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
             >>> correct_preds = validator._process_batch(preds, batch)
+
+        Notes:
+            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
+            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
         """
         tp = super()._process_batch(preds, batch)
         gt_cls = batch["cls"]
         if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
             tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
         else:
-            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
+            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1).float())  # float, uint8
             tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
         tp.update({"tp_m": tp_m})  # update tp with mask IoU
         return tp
 
     def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
-        """
-        Plot batch predictions with masks and bounding boxes.
+        """Plot batch predictions with masks and bounding boxes.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -195,8 +186,7 @@ class SegmentationValidator(DetectionValidator):
         super().plot_predictions(batch, preds, ni, max_det=self.args.max_det)  # plot bboxes
 
     def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
@@ -215,24 +205,84 @@ class SegmentationValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Save one JSON result for COCO evaluation.
+        """Save one JSON result for COCO evaluation.
 
         Args:
             predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
-        from faster_coco_eval.core.mask import encode  # noqa
-
-        def single_encode(x):
-            """Encode predicted masks as RLE and append results to jdict."""
-            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
-            rle["counts"] = rle["counts"].decode("utf-8")
-            return rle
 
-        pred_masks = np.transpose(predn["masks"], (2, 0, 1))
-        with ThreadPool(NUM_THREADS) as pool:
-            rles = pool.map(single_encode, pred_masks)
+        def to_string(counts: list[int]) -> str:
+            """Converts the RLE object into a compact string representation. Each count is delta-encoded and
+            variable-length encoded as a string.
+
+            Args:
+                counts (list[int]): List of RLE counts.
+            """
+            result = []
+
+            for i in range(len(counts)):
+                x = int(counts[i])
+
+                # Apply delta encoding for all counts after the second entry
+                if i > 2:
+                    x -= int(counts[i - 2])
+
+                # Variable-length encode the value
+                while True:
+                    c = x & 0x1F  # Take 5 bits
+                    x >>= 5
+
+                    # If the sign bit (0x10) is set, continue if x != -1;
+                    # otherwise, continue if x != 0
+                    more = (x != -1) if (c & 0x10) else (x != 0)
+                    if more:
+                        c |= 0x20  # Set continuation bit
+                    c += 48  # Shift to ASCII
+                    result.append(chr(c))
+                    if not more:
+                        break
+
+            return "".join(result)
+
+        def multi_encode(pixels: torch.Tensor) -> list[int]:
+            """Convert multiple binary masks using Run-Length Encoding (RLE).
+
+            Args:
+                pixels (torch.Tensor): A 2D tensor where each row represents a flattened binary mask with shape [N,
+                    H*W].
+
+            Returns:
+                (list[int]): A list of RLE counts for each mask.
+            """
+            transitions = pixels[:, 1:] != pixels[:, :-1]
+            row_idx, col_idx = torch.where(transitions)
+            col_idx = col_idx + 1
+
+            # Compute run lengths
+            counts = []
+            for i in range(pixels.shape[0]):
+                positions = col_idx[row_idx == i]
+                if len(positions):
+                    count = torch.diff(positions).tolist()
+                    count.insert(0, positions[0].item())
+                    count.append(len(pixels[i]) - positions[-1].item())
+                else:
+                    count = [len(pixels[i])]
+
+                # Ensure starting with background (0) count
+                if pixels[i][0].item() == 1:
+                    count = [0, *count]
+                counts.append(count)
+
+            return counts
+
+        pred_masks = predn["masks"].transpose(2, 1).contiguous().view(len(predn["masks"]), -1)  # N, H*W
+        h, w = predn["masks"].shape[1:3]
+        counts = multi_encode(pred_masks)
+        rles = []
+        for c in counts:
+            rles.append({"size": [h, w], "counts": to_string(c)})
         super().pred_to_json(predn, pbatch)
         for i, r in enumerate(rles):
             self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation
@@ -241,11 +291,9 @@ class SegmentationValidator(DetectionValidator):
         """Scales predictions to the original image size."""
         return {
             **super().scale_preds(predn, pbatch),
-            "masks": ops.scale_image(
-                predn["masks"].permute(1, 2, 0).contiguous().cpu().numpy(),
-                pbatch["ori_shape"],
-                ratio_pad=pbatch["ratio_pad"],
-            ),
+            "masks": ops.scale_masks(predn["masks"][None], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])[
+                0
+            ].byte(),
         }
 
     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
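The big change in `pred_to_json` above drops the `faster_coco_eval` RLE encoder (and the `ThreadPool` import) in favor of a pure-Python/torch implementation of COCO's compressed RLE: column-major run-length counts, then delta encoding plus a 5-bit varint packed into printable ASCII. A compact standalone sketch of the same scheme follows; the helper names here are mine, while the in-tree versions are `multi_encode` and `to_string` as shown in the diff:

```python
import torch

def rle_counts(flat: torch.Tensor) -> list[int]:
    """Run lengths of a flattened binary mask, always starting with the background (0) run."""
    idx = torch.where(flat[1:] != flat[:-1])[0] + 1  # indices where the value changes
    if len(idx):
        counts = [idx[0].item(), *torch.diff(idx).tolist(), len(flat) - idx[-1].item()]
    else:
        counts = [len(flat)]  # constant mask: a single run
    return [0, *counts] if flat[0].item() == 1 else counts  # prepend an empty background run

def counts_to_string(counts: list[int]) -> str:
    """COCO compressed RLE: delta-encode counts after the second, then emit 5-bit varint chars."""
    out = []
    for i in range(len(counts)):
        x = int(counts[i]) - (int(counts[i - 2]) if i > 2 else 0)  # delta vs. previous same-color run
        while True:
            c = x & 0x1F  # low 5 bits
            x >>= 5
            more = (x != -1) if (c & 0x10) else (x != 0)  # sign-aware continuation test
            if more:
                c |= 0x20  # continuation bit
            out.append(chr(c + 48))  # shift into printable ASCII
            if not more:
                break
    return "".join(out)

mask = torch.zeros(4, 4, dtype=torch.uint8)
mask[1:3, 1:3] = 1  # a 2x2 foreground square
flat = mask.t().contiguous().view(-1)  # transpose first: COCO RLE is column-major
print({"size": [4, 4], "counts": counts_to_string(rle_counts(flat))})  # {'size': [4, 4], 'counts': '52203'}
```

The column-major convention is also why the in-tree code transposes with `predn["masks"].transpose(2, 1)` before viewing the batch as `[N, H*W]`.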
--- a/ultralytics/models/yolo/world/train.py
+++ b/ultralytics/models/yolo/world/train.py
@@ -24,8 +24,7 @@ def on_pretrain_routine_end(trainer) -> None:
 
 
 class WorldTrainer(DetectionTrainer):
-    """
-    A trainer class for fine-tuning YOLO World models on close-set datasets.
+    """A trainer class for fine-tuning YOLO World models on close-set datasets.
 
     This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
     features for improved object detection and understanding. It handles text embedding generation and caching to
@@ -54,8 +53,7 @@ class WorldTrainer(DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a WorldTrainer object with given arguments.
+        """Initialize a WorldTrainer object with given arguments.
 
         Args:
             cfg (dict[str, Any]): Configuration for the trainer.
@@ -69,8 +67,7 @@ class WorldTrainer(DetectionTrainer):
         self.text_embeddings = None
 
     def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
-        """
-        Return WorldModel initialized with specified config and weights.
+        """Return WorldModel initialized with specified config and weights.
 
         Args:
             cfg (dict[str, Any] | str, optional): Model configuration.
@@ -95,8 +92,7 @@ class WorldTrainer(DetectionTrainer):
         return model
 
     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -115,11 +111,10 @@ class WorldTrainer(DetectionTrainer):
         return dataset
 
     def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
-        """
-        Set text embeddings for datasets to accelerate training by caching category names.
+        """Set text embeddings for datasets to accelerate training by caching category names.
 
-        This method collects unique category names from all datasets, then generates and caches text embeddings
-        for these categories to improve training efficiency.
+        This method collects unique category names from all datasets, then generates and caches text embeddings for
+        these categories to improve training efficiency.
 
         Args:
             datasets (list[Any]): List of datasets from which to extract category names.
@@ -141,8 +136,7 @@ class WorldTrainer(DetectionTrainer):
         self.text_embeddings = text_embeddings
 
     def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
-        """
-        Generate text embeddings for a list of text samples.
+        """Generate text embeddings for a list of text samples.
 
         Args:
             texts (list[str]): List of text samples to encode.
--- a/ultralytics/models/yolo/world/train_world.py
+++ b/ultralytics/models/yolo/world/train_world.py
@@ -10,8 +10,7 @@ from ultralytics.utils.torch_utils import unwrap_model
 
 
 class WorldTrainerFromScratch(WorldTrainer):
-    """
-    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
+    """A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
 
     This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
     supporting training YOLO-World models with combined vision-language capabilities.
@@ -53,45 +52,25 @@ class WorldTrainerFromScratch(WorldTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a WorldTrainerFromScratch object.
+        """Initialize a WorldTrainerFromScratch object.
 
-        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
-        object detection and grounding datasets for vision-language capabilities.
+        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
+        detection and grounding datasets for vision-language capabilities.
 
         Args:
             cfg (dict): Configuration dictionary with default parameters for model training.
             overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.
             _callbacks (list, optional): List of callback functions to be executed during different stages of training.
-
-        Examples:
-            >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
-            >>> from ultralytics import YOLOWorld
-            >>> data = dict(
-            ...     train=dict(
-            ...         yolo_data=["Objects365.yaml"],
-            ...         grounding_data=[
-            ...             dict(
-            ...                 img_path="flickr30k/images",
-            ...                 json_file="flickr30k/final_flickr_separateGT_train.json",
-            ...             ),
-            ...         ],
-            ...     ),
-            ...     val=dict(yolo_data=["lvis.yaml"]),
-            ... )
-            >>> model = YOLOWorld("yolov8s-worldv2.yaml")
-            >>> model.train(data=data, trainer=WorldTrainerFromScratch)
         """
         if overrides is None:
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)
 
     def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
-        This method constructs appropriate datasets based on the mode and input paths, handling both
-        standard YOLO datasets and grounding datasets with different formats.
+        This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
+        datasets and grounding datasets with different formats.
 
         Args:
             img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -122,11 +101,10 @@ class WorldTrainerFromScratch(WorldTrainer):
         return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
 
     def get_dataset(self):
-        """
-        Get train and validation paths from data dictionary.
+        """Get train and validation paths from data dictionary.
 
-        Processes the data configuration to extract paths for training and validation datasets,
-        handling both YOLO detection datasets and grounding datasets.
+        Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
+        detection datasets and grounding datasets.
 
         Returns:
             train_path (str): Train dataset path.
@@ -187,8 +165,7 @@ class WorldTrainerFromScratch(WorldTrainer):
             pass
 
     def final_eval(self):
-        """
-        Perform final evaluation and validation for the YOLO-World model.
+        """Perform final evaluation and validation for the YOLO-World model.
 
         Configures the validator with appropriate dataset and split information before running evaluation.
 
--- a/ultralytics/models/yolo/yoloe/__init__.py
+++ b/ultralytics/models/yolo/yoloe/__init__.py
@@ -6,17 +6,17 @@ from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromSc
 from .val import YOLOEDetectValidator, YOLOESegValidator
 
 __all__ = [
-    "YOLOETrainer",
-    "YOLOEPETrainer",
-    "YOLOESegTrainer",
     "YOLOEDetectValidator",
-    "YOLOESegValidator",
+    "YOLOEPEFreeTrainer",
     "YOLOEPESegTrainer",
+    "YOLOEPETrainer",
+    "YOLOESegTrainer",
     "YOLOESegTrainerFromScratch",
     "YOLOESegVPTrainer",
-    "YOLOEVPTrainer",
-    "YOLOEPEFreeTrainer",
+    "YOLOESegValidator",
+    "YOLOETrainer",
+    "YOLOETrainerFromScratch",
     "YOLOEVPDetectPredictor",
     "YOLOEVPSegPredictor",
-    "YOLOETrainerFromScratch",
+    "YOLOEVPTrainer",
 ]
--- a/ultralytics/models/yolo/yoloe/predict.py
+++ b/ultralytics/models/yolo/yoloe/predict.py
@@ -9,11 +9,10 @@ from ultralytics.models.yolo.segment import SegmentationPredictor
 
 
 class YOLOEVPDetectPredictor(DetectionPredictor):
-    """
-    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+    """A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
 
-    This mixin provides common functionality for YOLO models that use visual prompting, including
-    model setup, prompt handling, and preprocessing transformations.
+    This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
+    handling, and preprocessing transformations.
 
     Attributes:
         model (torch.nn.Module): The YOLO model for inference.
@@ -29,8 +28,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
     """
 
     def setup_model(self, model, verbose: bool = True):
-        """
-        Set up the model for prediction.
+        """Set up the model for prediction.
 
         Args:
             model (torch.nn.Module): Model to load or use.
@@ -40,21 +38,19 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         self.done_warmup = True
 
     def set_prompts(self, prompts):
-        """
-        Set the visual prompts for the model.
+        """Set the visual prompts for the model.
 
         Args:
-            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
-                Must include a 'cls' key with class indices.
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
+                with class indices.
         """
         self.prompts = prompts
 
     def pre_transform(self, im):
-        """
-        Preprocess images and prompts before inference.
+        """Preprocess images and prompts before inference.
 
-        This method applies letterboxing to the input image and transforms the visual prompts
-        (bounding boxes or masks) accordingly.
+        This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
+        accordingly.
 
         Args:
             im (list): List containing a single input image.
@@ -94,8 +90,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return img
 
     def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
-        """
-        Process a single image by resizing bounding boxes or masks and generating visuals.
+        """Process a single image by resizing bounding boxes or masks and generating visuals.
 
         Args:
             dst_shape (tuple): The target shape (height, width) of the image.
@@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
 
     def inference(self, im, *args, **kwargs):
-        """
-        Run inference with visual prompts.
+        """Run inference with visual prompts.
 
         Args:
             im (torch.Tensor): Input image tensor.
@@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return super().inference(im, vpe=self.prompts, *args, **kwargs)
 
     def get_vpe(self, source):
-        """
-        Process the source to get the visual prompt embeddings (VPE).
+        """Process the source to get the visual prompt embeddings (VPE).
 
         Args:
-            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
-                of the image to make predictions on. Accepts various types including file paths, URLs,
-                PIL images, numpy arrays, and torch tensors.
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
+                make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+                torch tensors.
 
         Returns:
             (torch.Tensor): The visual prompt embeddings (VPE) from the model.