dgenerate-ultralytics-headless 8.3.196-py3-none-any.whl → 8.3.248-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0

ultralytics/models/yolo/obb/val.py

@@ -12,11 +12,11 @@ from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.metrics import OBBMetrics, batch_probiou
 from ultralytics.utils.nms import TorchNMS
+from ultralytics.utils.plotting import plot_images


 class OBBValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

     This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
     satellite imagery where objects can appear at various orientations.
@@ -44,14 +44,13 @@
     """

     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
+        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.

-        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
-
+        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
+        extends the DetectionValidator class and configures it specifically for the OBB task.

         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
             save_dir (str | Path, optional): Directory to save results.
             args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
             _callbacks (list, optional): List of callback functions to be called during validation.
@@ -61,8 +60,7 @@
         self.metrics = OBBMetrics()

     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize evaluation metrics for YOLO obb validation.
+        """Initialize evaluation metrics for YOLO obb validation.

         Args:
             model (torch.nn.Module): Model to validate.
@@ -73,19 +71,18 @@
         self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'

     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
-        """
-        Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+        """Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.

         Args:
             preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
                 class labels and bounding boxes.
-            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
-
+            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
+                labels and bounding boxes.

         Returns:
-            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
-
-
+            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
+                with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
+                predictions compared to the ground truth.

         Examples:
             >>> detections = torch.rand(100, 7)  # 100 sample detections
@@ -93,13 +90,14 @@
             >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
             >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
-        if
-        return {"tp": np.zeros((
+        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
+            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
         iou = batch_probiou(batch["bboxes"], preds["bboxes"])
         return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}

     def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
-        """
+        """Postprocess OBB predictions.
+
         Args:
             preds (torch.Tensor): Raw predictions from the model.

@@ -112,8 +110,7 @@
         return preds

     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare batch data for OBB validation with proper scaling and formatting.
+        """Prepare batch data for OBB validation with proper scaling and formatting.

         Args:
             si (int): Batch index to process.
@@ -134,7 +131,7 @@
         ori_shape = batch["ori_shape"][si]
         imgsz = batch["img"].shape[2:]
         ratio_pad = batch["ratio_pad"][si]
-        if
+        if cls.shape[0]:
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
         return {
             "cls": cls,
@@ -145,33 +142,41 @@
             "im_file": batch["im_file"][si],
         }

-    def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
-        """
-        Plot predicted bounding boxes on input images and save the result.
+    def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
+        """Plot predicted bounding boxes on input images and save the result.

         Args:
             batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
-            preds (list[torch.Tensor]): List of prediction
+            preds (list[dict[str, torch.Tensor]]): List of prediction dictionaries for each image in the batch.
             ni (int): Batch index used for naming the output file.

         Examples:
             >>> validator = OBBValidator()
             >>> batch = {"img": images, "im_file": paths}
-            >>> preds = [torch.rand(10,
+            >>> preds = [{"bboxes": torch.rand(10, 5), "cls": torch.zeros(10), "conf": torch.rand(10)}]
             >>> validator.plot_predictions(batch, preds, 0)
         """
-
-
-
-
+        if not preds:
+            return
+        for i, pred in enumerate(preds):
+            pred["batch_idx"] = torch.ones_like(pred["conf"]) * i
+        keys = preds[0].keys()
+        batched_preds = {k: torch.cat([x[k] for x in preds], dim=0) for k in keys}
+        plot_images(
+            images=batch["img"],
+            labels=batched_preds,
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )

     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Convert YOLO predictions to COCO JSON format with rotated bounding box information.
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information.

         Args:
-            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
-
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Notes:
@@ -197,8 +202,7 @@
         )

     def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO OBB detections to a text file in normalized coordinates.
+        """Save YOLO OBB detections to a text file in normalized coordinates.

         Args:
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
@@ -233,8 +237,7 @@
         }

     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
-        """
-        Evaluate YOLO output in JSON format and save predictions in DOTA format.
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format.

         Args:
             stats (dict[str, Any]): Performance statistics dictionary.
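
The rewritten plot_predictions above tags every per-image prediction dict with a batch_idx and concatenates the dicts into one batch before calling plot_images. A minimal sketch of that batching step, with dummy tensors standing in for real OBB outputs:

    import torch

    # Two images' worth of hypothetical predictions (xywhr boxes, class ids, confidences)
    preds = [
        {"bboxes": torch.rand(3, 5), "cls": torch.zeros(3), "conf": torch.rand(3)},
        {"bboxes": torch.rand(2, 5), "cls": torch.ones(2), "conf": torch.rand(2)},
    ]
    for i, pred in enumerate(preds):
        pred["batch_idx"] = torch.ones_like(pred["conf"]) * i  # maps each row to its source image

    batched = {k: torch.cat([x[k] for x in preds], dim=0) for k in preds[0].keys()}
    print(batched["bboxes"].shape)  # torch.Size([5, 5])
    print(batched["batch_idx"])     # tensor([0., 0., 0., 1., 1.])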
ultralytics/models/yolo/pose/predict.py

@@ -1,12 +1,11 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from ultralytics.models.yolo.detect.predict import DetectionPredictor
-from ultralytics.utils import DEFAULT_CFG,
+from ultralytics.utils import DEFAULT_CFG, ops


 class PosePredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a pose model.
+    """A class extending the DetectionPredictor class for prediction based on a pose model.

     This class specializes in pose estimation, handling keypoints detection alongside standard object detection
     capabilities inherited from DetectionPredictor.
@@ -27,35 +26,21 @@
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize PosePredictor for pose estimation tasks.
+        """Initialize PosePredictor for pose estimation tasks.

-        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
-
+        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
+        for Apple MPS.

         Args:
             cfg (Any): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.
-
-        Examples:
-            >>> from ultralytics.utils import ASSETS
-            >>> from ultralytics.models.yolo.pose import PosePredictor
-            >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
-            >>> predictor = PosePredictor(overrides=args)
-            >>> predictor.predict_cli()
         """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "pose"
-        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
-            LOGGER.warning(
-                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
-                "See https://github.com/ultralytics/ultralytics/issues/4031."
-            )

     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction, including keypoints.
+        """Construct the result object from the prediction, including keypoints.

         Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
         result object.
@@ -73,7 +58,7 @@
         """
         result = super().construct_result(pred, img, orig_img, img_path)
         # Extract keypoints from prediction and reshape according to model's keypoint shape
-        pred_kpts = pred[:, 6:].view(
+        pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
         # Scale keypoints coordinates to match the original image dimensions
         pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
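
The construct_result change above reshapes the flattened keypoint columns using pred.shape[0]. A quick sketch of the reshape, assuming the usual layout of 6 box/conf/cls columns followed by flattened keypoints:

    import torch

    kpt_shape = (17, 3)               # COCO-style: 17 keypoints, (x, y, visibility)
    pred = torch.rand(8, 6 + 17 * 3)  # 8 hypothetical detections

    pred_kpts = pred[:, 6:].view(pred.shape[0], *kpt_shape)
    print(pred_kpts.shape)  # torch.Size([8, 17, 3])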
ultralytics/models/yolo/pose/train.py

@@ -8,13 +8,11 @@ from typing import Any

 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
-from ultralytics.utils import DEFAULT_CFG
-from ultralytics.utils.plotting import plot_results
+from ultralytics.utils import DEFAULT_CFG


 class PoseTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training YOLO pose estimation models.
+    """A class extending the DetectionTrainer class for training YOLO pose estimation models.

     This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
     of pose keypoints alongside bounding boxes.
@@ -30,7 +28,6 @@
         set_model_attributes: Set keypoints shape attribute on the model.
         get_validator: Create a validator instance for model evaluation.
         plot_training_samples: Visualize training samples with keypoints.
-        plot_metrics: Generate and save training/validation metric plots.
         get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.

     Examples:
@@ -41,8 +38,7 @@
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a PoseTrainer object for training YOLO pose estimation models.
+        """Initialize a PoseTrainer object for training YOLO pose estimation models.

         Args:
             cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -57,13 +53,6 @@
             overrides = {}
         overrides["task"] = "pose"
         super().__init__(cfg, overrides, _callbacks)
-        self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "keypoints"]
-
-        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
-            LOGGER.warning(
-                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
-                "See https://github.com/ultralytics/ultralytics/issues/4031."
-            )

     def get_model(
         self,
@@ -71,8 +60,7 @@
         weights: str | Path | None = None,
         verbose: bool = True,
     ) -> PoseModel:
-        """
-        Get pose estimation model with specified configuration and weights.
+        """Get pose estimation model with specified configuration and weights.

         Args:
             cfg (str | Path | dict, optional): Model configuration file path or dictionary.
@@ -94,6 +82,11 @@
         """Set keypoints shape attribute of PoseModel."""
         super().set_model_attributes()
         self.model.kpt_shape = self.data["kpt_shape"]
+        kpt_names = self.data.get("kpt_names")
+        if not kpt_names:
+            names = list(map(str, range(self.model.kpt_shape[0])))
+            kpt_names = {i: names for i in range(self.model.nc)}
+        self.model.kpt_names = kpt_names

     def get_validator(self):
         """Return an instance of the PoseValidator class for validation."""
@@ -102,13 +95,8 @@
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )

-    def plot_metrics(self):
-        """Plot training/validation metrics."""
-        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
-
     def get_dataset(self) -> dict[str, Any]:
-        """
-        Retrieve the dataset and ensure it contains the required `kpt_shape` key.
+        """Retrieve the dataset and ensure it contains the required `kpt_shape` key.

         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
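
The new lines in set_model_attributes fall back to stringified indices when the dataset provides no keypoint names. A standalone sketch of that fallback, with assumed kpt_shape and class-count values:

    kpt_shape = (17, 3)  # as read from a dataset YAML
    nc = 1               # number of classes

    names = list(map(str, range(kpt_shape[0])))
    kpt_names = {i: names for i in range(nc)}
    print(kpt_names[0][:3])  # ['0', '1', '2']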
ultralytics/models/yolo/pose/val.py

@@ -9,16 +9,15 @@ import numpy as np
 import torch

 from ultralytics.models.yolo.detect import DetectionValidator
-from ultralytics.utils import
+from ultralytics.utils import ops
 from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou


 class PoseValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a pose model.
+    """A class extending the DetectionValidator class for validation based on a pose model.

-    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
-
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
+    metrics for pose evaluation.

     Attributes:
         sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
@@ -33,8 +32,8 @@
         _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
             dimensions.
         _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
-        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
-
+        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
+            and ground truth.
         plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
         plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
         save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
@@ -46,42 +45,30 @@
         >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
         >>> validator = PoseValidator(args=args)
         >>> validator()
+
+    Notes:
+        This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
+        for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
+        due to a known bug with pose models.
     """

     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize a PoseValidator object for pose estimation validation.
+        """Initialize a PoseValidator object for pose estimation validation.

         This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
         specialized metrics for pose evaluation.

         Args:
-            dataloader (torch.utils.data.DataLoader, optional):
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
             save_dir (Path | str, optional): Directory to save results.
             args (dict, optional): Arguments for the validator including task set to "pose".
             _callbacks (list, optional): List of callback functions to be executed during validation.
-
-        Examples:
-            >>> from ultralytics.models.yolo.pose import PoseValidator
-            >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
-            >>> validator = PoseValidator(args=args)
-            >>> validator()
-
-        Notes:
-            This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
-            for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
-            due to a known bug with pose models.
         """
         super().__init__(dataloader, save_dir, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
         self.args.task = "pose"
         self.metrics = PoseMetrics()
-        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
-            LOGGER.warning(
-                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
-                "See https://github.com/ultralytics/ultralytics/issues/4031."
-            )

     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Preprocess batch by converting keypoints data to float and moving it to the device."""
@@ -106,8 +93,7 @@
         )

     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize evaluation metrics for YOLO pose validation.
+        """Initialize evaluation metrics for YOLO pose validation.

         Args:
             model (torch.nn.Module): Model to validate.
@@ -119,17 +105,15 @@
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt

     def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
-        """
-        Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+        """Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.

-        This method extends the parent class postprocessing by extracting keypoints from the 'extra'
-
-
-        (typically [N, 17, 3] for COCO pose format).
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
+        predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
+        flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).

         Args:
-            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
-
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
+                scores, class predictions, and keypoint data.

         Returns:
             (dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
@@ -138,10 +122,10 @@
             - 'cls': Class predictions
             - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)

-
-            If no keypoints are present in a prediction (empty keypoints), that prediction
-
-
+        Notes:
+            If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
+            to the next one. The keypoints are extracted from the 'extra' field which contains additional
+            task-specific data beyond basic detection.
         """
         preds = super().postprocess(preds)
         for pred in preds:
@@ -149,8 +133,7 @@
         return preds

     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.

         Args:
             si (int): Batch index.
@@ -173,18 +156,18 @@
         return pbatch

     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-
+        """Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
+        truth.

         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
                 and 'keypoints' for keypoint predictions.
-            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
-
+            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
+                for bounding boxes, and 'keypoints' for keypoint annotations.

         Returns:
-            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
-
+            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
+                positives across 10 IoU levels.

         Notes:
             `0.53` scale factor used in area computation is referenced from
@@ -192,8 +175,8 @@
         """
         tp = super()._process_batch(preds, batch)
         gt_cls = batch["cls"]
-        if
-            tp_p = np.zeros((
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
         else:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
             area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
@@ -203,11 +186,10 @@
         return tp

     def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO pose detections to a text file in normalized coordinates.
+        """Save YOLO pose detections to a text file in normalized coordinates.

         Args:
-            predn (dict[str, torch.Tensor]):
+            predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
             save_conf (bool): Whether to save confidence scores.
             shape (tuple[int, int]): Shape of the original image (height, width).
             file (Path): Output file path to save detections.
@@ -227,15 +209,14 @@
         ).save_txt(file, save_conf=save_conf)

     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Convert YOLO predictions to COCO JSON format.
+        """Convert YOLO predictions to COCO JSON format.

-        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
-
+        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
+        format, and appends the results to the internal JSON dictionary (self.jdict).

         Args:
-            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
-
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
+                tensors.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

         Notes:
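
The 0.53 factor kept in _process_batch scales ground-truth box areas for the OKS computation, following the xtcocotools reference linked in the code comment. A small sketch of that area term with a dummy xyxy box in place of batch["bboxes"]:

    import torch
    from ultralytics.utils import ops

    bboxes = torch.tensor([[0.0, 0.0, 10.0, 20.0]])     # one xyxy ground-truth box
    area = ops.xyxy2xywh(bboxes)[:, 2:].prod(1) * 0.53  # w * h * 0.53
    print(area)  # tensor([106.])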
ultralytics/models/yolo/segment/predict.py

@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops


 class SegmentationPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a segmentation model.
+    """A class extending the DetectionPredictor class for prediction based on a segmentation model.

     This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
     prediction results.
@@ -31,8 +30,7 @@
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.

         This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
         prediction results.
@@ -46,8 +44,7 @@
         self.args.task = "segment"

     def postprocess(self, preds, img, orig_imgs):
-        """
-        Apply non-max suppression and process segmentation detections for each image in the input batch.
+        """Apply non-max suppression and process segmentation detections for each image in the input batch.

         Args:
             preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
@@ -55,8 +52,8 @@
             orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.

         Returns:
-            (list): List of Results objects containing the segmentation predictions for each image in the batch.
-
+            (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
+                Results object includes both bounding boxes and segmentation masks.

         Examples:
             >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
@@ -67,8 +64,7 @@
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)

     def construct_results(self, preds, img, orig_imgs, protos):
-        """
-        Construct a list of result objects from the predictions.
+        """Construct a list of result objects from the predictions.

         Args:
             preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -77,8 +73,8 @@
             protos (list[torch.Tensor]): List of prototype masks.

         Returns:
-            (list[Results]): List of result objects containing the original images, image paths, class names,
-
+            (list[Results]): List of result objects containing the original images, image paths, class names, bounding
+                boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -86,11 +82,10 @@
         ]

     def construct_result(self, pred, img, orig_img, img_path, proto):
-        """
-        Construct a single result object from the prediction.
+        """Construct a single result object from the prediction.

         Args:
-            pred (
+            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
             img (torch.Tensor): The image after preprocessing.
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
@@ -99,15 +94,16 @@
         Returns:
             (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
         """
-        if
+        if pred.shape[0] == 0:  # save empty boxes
             masks = None
         elif self.args.retina_masks:
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  #
+            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # NHW
         else:
-            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  #
+            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # NHW
         pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         if masks is not None:
-            keep = masks.
-
+            keep = masks.amax((-2, -1)) > 0  # only keep predictions with masks
+            if not all(keep):  # most predictions have masks
+                pred, masks = pred[keep], masks[keep]  # indexing is slow
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)