dgenerate-ultralytics-headless: 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/pose/val.py

```diff
@@ -1,43 +1,44 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
 
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
-from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
-from ultralytics.utils.plotting import output_to_target, plot_images
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
 
 
 class PoseValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a pose model.
+    """A class extending the DetectionValidator class for validation based on a pose model.
 
-    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
-    specialized metrics for pose evaluation.
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
+    metrics for pose evaluation.
 
     Attributes:
         sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
-        kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
+        kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
         args (dict): Arguments for the validator including task set to "pose".
         metrics (PoseMetrics): Metrics object for pose evaluation.
 
     Methods:
-        preprocess:
-        get_desc:
-        init_metrics:
-        _prepare_batch:
-        _prepare_pred:
-        update_metrics:
-        _process_batch:
-        plot_val_samples:
-        plot_predictions:
-        save_one_txt:
-        pred_to_json:
-        eval_json:
+        preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
+        get_desc: Return description of evaluation metrics in string format.
+        init_metrics: Initialize pose estimation metrics for YOLO model.
+        _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
+            dimensions.
+        _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
+        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
+            and ground truth.
+        plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
+        plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
+        save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
+        pred_to_json: Convert YOLO predictions to COCO JSON format.
+        eval_json: Evaluate object detection model using COCO JSON format.
 
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseValidator
@@ -46,9 +47,8 @@ class PoseValidator(DetectionValidator):
         >>> validator()
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
-        Initialize a PoseValidator object for pose estimation validation.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+        """Initialize a PoseValidator object for pose estimation validation.
 
         This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
         specialized metrics for pose evaluation.
@@ -56,7 +56,6 @@ class PoseValidator(DetectionValidator):
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (Path | str, optional): Directory to save results.
-            pbar (Any, optional): Progress bar for displaying progress.
            args (dict, optional): Arguments for the validator including task set to "pose".
            _callbacks (list, optional): List of callback functions to be executed during validation.
 
@@ -71,24 +70,24 @@ class PoseValidator(DetectionValidator):
         for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
         due to a known bug with pose models.
         """
-        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        super().__init__(dataloader, save_dir, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
         self.args.task = "pose"
-        self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
+        self.metrics = PoseMetrics()
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
             LOGGER.warning(
                 "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                 "See https://github.com/ultralytics/ultralytics/issues/4031."
             )
 
-    def preprocess(self, batch):
+    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Preprocess batch by converting keypoints data to float and moving it to the device."""
         batch = super().preprocess(batch)
-        batch["keypoints"] = batch["keypoints"].to(self.device).float()
+        batch["keypoints"] = batch["keypoints"].float()
         return batch
 
-    def get_desc(self):
+    def get_desc(self) -> str:
         """Return description of evaluation metrics in string format."""
         return ("%22s" + "%11s" * 10) % (
             "Class",
@@ -104,25 +103,55 @@ class PoseValidator(DetectionValidator):
             "mAP50-95)",
         )
 
-    def init_metrics(self, model):
-        """Initialize pose estimation metrics for YOLO model."""
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """Initialize evaluation metrics for YOLO pose validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         self.kpt_shape = self.data["kpt_shape"]
         is_pose = self.kpt_shape == [17, 3]
         nkpt = self.kpt_shape[0]
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
-        self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
-    def _prepare_batch(self, si, batch):
+    def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
+        """Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
+        predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
+        flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
+
+        Args:
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
+                scores, class predictions, and keypoint data.
+
+        Returns:
+            (dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
+                - 'bboxes': Bounding box coordinates
+                - 'conf': Confidence scores
+                - 'cls': Class predictions
+                - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
+
+        Notes:
+            If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
+            to the next one. The keypoints are extracted from the 'extra' field which contains additional
+            task-specific data beyond basic detection.
         """
-        Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape)  # remove extra if exists
+        return preds
+
+    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
 
         Args:
             si (int): Batch index.
-            batch (dict): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
+            batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
 
         Returns:
-            (dict): Prepared batch with keypoints scaled to original image dimensions.
+            (dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
 
         Notes:
             This method extends the parent class's _prepare_batch method by adding keypoint processing.
```
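The refactor above moves the validator from flat prediction tensors to per-image dicts. A minimal sketch of the keypoint reshape the new `postprocess` performs; the shapes and values here are invented for illustration, and only the `view(-1, *kpt_shape)` call mirrors the diff:

```python
import torch

kpt_shape = (17, 3)  # COCO default: 17 keypoints with (x, y, visibility)
n = 5  # hypothetical number of detections in one image

pred = {
    "bboxes": torch.rand(n, 4),
    "conf": torch.rand(n),
    "cls": torch.zeros(n),
    "extra": torch.rand(n, kpt_shape[0] * kpt_shape[1]),  # keypoints flattened to 51 values
}
# Same reshape as the new PoseValidator.postprocess: pull the flattened
# keypoints out of the 'extra' field and view them back to (N, K, D)
pred["keypoints"] = pred.pop("extra").view(-1, *kpt_shape)
assert pred["keypoints"].shape == (n, 17, 3)
```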
ultralytics/models/yolo/pose/val.py (continued)

```diff
@@ -134,187 +163,46 @@ class PoseValidator(DetectionValidator):
         kpts = kpts.clone()
         kpts[..., 0] *= w
         kpts[..., 1] *= h
-        kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        pbatch["kpts"] = kpts
+        pbatch["keypoints"] = kpts
         return pbatch
 
-    def _prepare_pred(self, pred, pbatch):
-        """
-        Prepare and scale keypoints in predictions for pose processing.
-
-        This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls
-        the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates
-        to match the original image dimensions.
-
-        Args:
-            pred (torch.Tensor): Raw prediction tensor from the model.
-            pbatch (dict): Processed batch dictionary containing image information including:
-                - imgsz: Image size used for inference
-                - ori_shape: Original image shape
-                - ratio_pad: Ratio and padding information for coordinate scaling
-
-        Returns:
-            predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
-        """
-        predn = super()._prepare_pred(pred, pbatch)
-        nk = pbatch["kpts"].shape[1]
-        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
-        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
-        return predn, pred_kpts
-
-    def update_metrics(self, preds, batch):
-        """
-        Update metrics with new predictions and ground truth data.
-
-        This method processes each prediction, compares it with ground truth, and updates various statistics
-        for performance evaluation.
+    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
+        """Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
+        truth.
 
         Args:
-            preds (List[torch.Tensor]): List of prediction tensors from the model.
-            batch (dict): Batch data containing images and ground truth annotations.
-        """
-        for si, pred in enumerate(preds):
-            self.seen += 1
-            npr = len(pred)
-            stat = dict(
-                conf=torch.zeros(0, device=self.device),
-                pred_cls=torch.zeros(0, device=self.device),
-                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-            )
-            pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
-            nl = len(cls)
-            stat["target_cls"] = cls
-            stat["target_img"] = cls.unique()
-            if npr == 0:
-                if nl:
-                    for k in self.stats.keys():
-                        self.stats[k].append(stat[k])
-                    if self.args.plots:
-                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
-                continue
-
-            # Predictions
-            if self.args.single_cls:
-                pred[:, 5] = 0
-            predn, pred_kpts = self._prepare_pred(pred, pbatch)
-            stat["conf"] = predn[:, 4]
-            stat["pred_cls"] = predn[:, 5]
-
-            # Evaluate
-            if nl:
-                stat["tp"] = self._process_batch(predn, bbox, cls)
-                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
-            if self.args.plots:
-                self.confusion_matrix.process_batch(predn, bbox, cls)
-
-            for k in self.stats.keys():
-                self.stats[k].append(stat[k])
-
-            # Save
-            if self.args.save_json:
-                self.pred_to_json(predn, batch["im_file"][si])
-            if self.args.save_txt:
-                self.save_one_txt(
-                    predn,
-                    pred_kpts,
-                    self.args.save_conf,
-                    pbatch["ori_shape"],
-                    self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
-                )
-
-    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
-        """
-        Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
-
-        Args:
-            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
-                detection is of the format (x1, y1, x2, y2, conf, class).
-            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
-                box is of the format (x1, y1, x2, y2).
-            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
-            pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
-                51 corresponds to 17 keypoints each having 3 values.
-            gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
+            preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
+                and 'keypoints' for keypoint predictions.
+            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
+                for bounding boxes, and 'keypoints' for keypoint annotations.
 
         Returns:
-            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
-                where N is the number of detections.
+            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
+                positives across 10 IoU levels.
 
         Notes:
             `0.53` scale factor used in area computation is referenced from
             https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
-        if pred_kpts is not None and gt_kpts is not None:
+        tp = super()._process_batch(preds, batch)
+        gt_cls = batch["cls"]
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
+        else:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
-            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
-            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
-        else:  # boxes
-            iou = box_iou(gt_bboxes, detections[:, :4])
-
-        return self.match_predictions(detections[:, 5], gt_cls, iou)
-
-    def plot_val_samples(self, batch, ni):
-        """
-        Plot and save validation set samples with ground truth bounding boxes and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data with keys:
-                - img (torch.Tensor): Batch of images
-                - batch_idx (torch.Tensor): Batch indices for each image
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - keypoints (torch.Tensor): Keypoint coordinates
-                - im_file (list): List of image file paths
-            ni (int): Batch index used for naming the output file
-        """
-        plot_images(
-            batch["img"],
-            batch["batch_idx"],
-            batch["cls"].squeeze(-1),
-            batch["bboxes"],
-            kpts=batch["keypoints"],
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )
-
-    def plot_predictions(self, batch, preds, ni):
-        """
-        Plot and save model predictions with bounding boxes and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data including images, file paths, and other metadata.
-            preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
-                confidence scores, class predictions, and keypoints.
-            ni (int): Batch index used for naming the output file.
-
-        The function extracts keypoints from predictions, converts predictions to target format, and plots them
-        on the input images. The resulting visualization is saved to the specified save directory.
-        """
-        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
-        plot_images(
-            batch["img"],
-            *output_to_target(preds, max_det=self.args.max_det),
-            kpts=pred_kpts,
-            paths=batch["im_file"],
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
+            area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
+            tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_p": tp_p})  # update tp with kpts IoU
+        return tp
 
-    def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
-        """
-        Save YOLO pose detections to a text file in normalized coordinates.
+    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
+        """Save YOLO pose detections to a text file in normalized coordinates.
 
         Args:
-            predn (torch.Tensor): Prediction
-            pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
-                and D is the dimension (typically 3 for x, y, visibility).
+            predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
             save_conf (bool): Whether to save confidence scores.
-            shape (tuple): Shape of the original image (height, width).
+            shape (tuple[int, int]): Shape of the original image (height, width).
             file (Path): Output file path to save detections.
 
         Notes:
@@ -327,68 +215,45 @@ class PoseValidator(DetectionValidator):
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            boxes=predn[:, :6],
-            keypoints=pred_kpts,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            keypoints=predn["keypoints"],
         ).save_txt(file, save_conf=save_conf)
 
-    def pred_to_json(self, predn, filename):
-        """
-        Convert YOLO predictions to COCO JSON format.
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Convert YOLO predictions to COCO JSON format.
 
-        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
-        to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
+        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
+        format, and appends the results to the internal JSON dictionary (self.jdict).
 
         Args:
-            predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores, class
-                predictions, and keypoints, with shape (N, 6+K*D) where N is the number of predictions, K is the
-                number of keypoints, and D is the keypoint dimension.
-            filename (str | Path): Path to the image file for which predictions are being processed.
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
+                tensors.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
 
         Notes:
             The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
             converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
             before saving to the JSON dictionary.
         """
-        stem = Path(filename).stem
-        image_id = int(stem) if stem.isnumeric() else stem
-        box = ops.xyxy2xywh(predn[:, :4])  # xywh
-        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
-        for p, b in zip(predn.tolist(), box.tolist()):
-            self.jdict.append(
-                {
-                    "image_id": image_id,
-                    "category_id": self.class_map[int(p[5])],
-                    "bbox": [round(x, 3) for x in b],
-                    "keypoints": p[6:],
-                    "score": round(p[4], 5),
-                }
-            )
-
-    def eval_json(self, stats):
+        super().pred_to_json(predn, pbatch)
+        kpts = predn["kpts"]
+        for i, k in enumerate(kpts.flatten(1, 2).tolist()):
+            self.jdict[-len(kpts) + i]["keypoints"] = k  # keypoints
+
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return {
+            **super().scale_preds(predn, pbatch),
+            "kpts": ops.scale_coords(
+                pbatch["imgsz"],
+                predn["keypoints"].clone(),
+                pbatch["ori_shape"],
+                ratio_pad=pbatch["ratio_pad"],
+            ),
+        }
+
+    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
         """Evaluate object detection model using COCO JSON format."""
-        if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
-            pred_json = self.save_dir / "predictions.json"  # predictions
-            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
-            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements("pycocotools>=2.0.6")
-                from pycocotools.coco import COCO  # noqa
-                from pycocotools.cocoeval import COCOeval  # noqa
-
-                for x in anno_json, pred_json:
-                    assert x.is_file(), f"{x} file not found"
-                anno = COCO(str(anno_json))  # init annotations api
-                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
-                    if self.is_coco:
-                        eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
-                    eval.evaluate()
-                    eval.accumulate()
-                    eval.summarize()
-                    idx = i * 4 + 2
-                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
-                        :2
-                    ]  # update mAP50-95 and mAP50
-            except Exception as e:
-                LOGGER.warning(f"pycocotools unable to run: {e}")
-        return stats
+        anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
+        pred_json = self.save_dir / "predictions.json"  # predictions
+        return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
```
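The new `_process_batch` scores pose true positives with a keypoint-similarity (OKS) matrix rather than box IoU alone. Below is a self-contained sketch of such a matrix using the canonical COCO OKS formula and the same `0.53` area factor the diff cites from xtcocotools; it is an illustration under stated assumptions, not the library's `kpt_iou` implementation, and the sigmas are placeholders:

```python
import torch

def oks_matrix(gt_kpts, pred_kpts, gt_bboxes, sigma, eps=1e-7):
    """OKS between M ground-truth and N predicted poses: (M, 17, 3), (N, 17, 3), (M, 4) xyxy."""
    area = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).prod(-1) * 0.53  # 0.53 factor per xtcocotools
    d2 = (gt_kpts[:, None, :, 0] - pred_kpts[None, :, :, 0]) ** 2 + (
        gt_kpts[:, None, :, 1] - pred_kpts[None, :, :, 1]
    ) ** 2  # (M, N, 17) squared pixel distances
    k2 = (2 * sigma) ** 2  # canonical COCO OKS uses k_i = 2 * sigma_i
    e = d2 / (2 * (area[:, None, None] + eps) * k2)
    vis = (gt_kpts[..., 2] > 0).float()[:, None]  # only labeled keypoints contribute
    return (torch.exp(-e) * vis).sum(-1) / (vis.sum(-1) + eps)  # (M, N) similarity in [0, 1]

sigma = torch.full((17,), 0.05)  # placeholder sigmas; COCO's OKS_SIGMA varies per joint
iou = oks_matrix(
    torch.rand(2, 17, 3),
    torch.rand(3, 17, 3),
    torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.2, 0.1, 0.9, 0.8]]),
    sigma,
)
print(iou.shape)  # torch.Size([2, 3])
```

Thresholding a matrix like this at the 10 IoU levels (0.50:0.95) is what `match_predictions` converts into the boolean `tp_p` entries stored alongside the box `tp` matrix.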
ultralytics/models/yolo/segment/predict.py

```diff
@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
 
 
 class SegmentationPredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a segmentation model.
+    """A class extending the DetectionPredictor class for prediction based on a segmentation model.
 
     This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
     prediction results.
@@ -18,9 +17,9 @@ class SegmentationPredictor(DetectionPredictor):
         batch (list): Current batch of images being processed.
 
     Methods:
-        postprocess:
-        construct_results:
-        construct_result:
+        postprocess: Apply non-max suppression and process segmentation detections.
+        construct_results: Construct a list of result objects from predictions.
+        construct_result: Construct a single result object from a prediction.
 
     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -31,14 +30,13 @@ class SegmentationPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
 
         This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
         prediction results.
 
         Args:
-            cfg (dict): Configuration for the predictor.
+            cfg (dict): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.
         """
@@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
         self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Apply non-max suppression and process segmentation detections for each image in the input batch.
+        """Apply non-max suppression and process segmentation detections for each image in the input batch.
 
         Args:
             preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
@@ -55,8 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
             orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
 
         Returns:
-            (list): List of Results objects containing the segmentation predictions for each image in the batch.
-                Each Results object includes both bounding boxes and segmentation masks.
+            (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
+                Results object includes both bounding boxes and segmentation masks.
 
         Examples:
             >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
@@ -67,18 +64,17 @@ class SegmentationPredictor(DetectionPredictor):
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
 
     def construct_results(self, preds, img, orig_imgs, protos):
-        """
-        Construct a list of result objects from the predictions.
+        """Construct a list of result objects from the predictions.
 
         Args:
-            preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
+            preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
             img (torch.Tensor): The image after preprocessing.
-            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
-            protos (List[torch.Tensor]): List of prototype masks.
+            orig_imgs (list[np.ndarray]): List of original images before preprocessing.
+            protos (list[torch.Tensor]): List of prototype masks.
 
         Returns:
-            (list): List of result objects containing the original images, image paths, class names, bounding boxes,
-                and masks.
+            (list[Results]): List of result objects containing the original images, image paths, class names, bounding
+                boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -86,11 +82,10 @@ class SegmentationPredictor(DetectionPredictor):
         ]
 
     def construct_result(self, pred, img, orig_img, img_path, proto):
-        """
-        Construct a single result object from the prediction.
+        """Construct a single result object from the prediction.
 
         Args:
-            pred (np.ndarray): The predicted bounding boxes, scores, and masks.
+            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
             img (torch.Tensor): The image after preprocessing.
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
@@ -99,7 +94,7 @@ class SegmentationPredictor(DetectionPredictor):
         Returns:
             (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
         """
-        if not len(pred):  # save empty boxes
+        if pred.shape[0] == 0:  # save empty boxes
             masks = None
         elif self.args.retina_masks:
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
@@ -108,6 +103,7 @@ class SegmentationPredictor(DetectionPredictor):
             masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         if masks is not None:
-            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
-            pred, masks = pred[keep], masks[keep]
+            keep = masks.amax((-2, -1)) > 0  # only keep predictions with masks
+            if not all(keep):  # most predictions have masks
+                pred, masks = pred[keep], masks[keep]  # indexing is slow
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
```
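The last segmentation hunk changes the empty-mask filter: `amax` replaces `sum` for the keep test, and the boolean indexing now runs only when at least one detection has an empty mask. A small sketch of that behavior on dummy tensors (all shapes invented for illustration):

```python
import torch

pred = torch.rand(4, 38)  # hypothetical (N, 6 + mask coefficients) detections
masks = torch.zeros(4, 160, 160)
masks[0, 50:80, 40:90] = 1.0  # only detection 0 produced any mask pixels

keep = masks.amax((-2, -1)) > 0  # True where a mask has at least one positive pixel
if not all(keep):  # common case: every mask is kept, so skip the slow indexing
    pred, masks = pred[keep], masks[keep]
print(pred.shape, masks.shape)  # torch.Size([1, 38]) torch.Size([1, 160, 160])
```

Per the diff's own comments ("most predictions have masks", "indexing is slow"), guarding the fancy indexing keeps the common all-kept path free of tensor copies.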