dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
|
@@ -8,8 +8,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class OBBPredictor(DetectionPredictor):
|
|
11
|
-
"""
|
|
12
|
-
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
|
11
|
+
"""A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
|
13
12
|
|
|
14
13
|
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
|
|
15
14
|
bounding boxes.
|
|
@@ -27,30 +26,22 @@ class OBBPredictor(DetectionPredictor):
|
|
|
27
26
|
"""
|
|
28
27
|
|
|
29
28
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
30
|
-
"""
|
|
31
|
-
Initialize OBBPredictor with optional model and data configuration overrides.
|
|
29
|
+
"""Initialize OBBPredictor with optional model and data configuration overrides.
|
|
32
30
|
|
|
33
31
|
Args:
|
|
34
32
|
cfg (dict, optional): Default configuration for the predictor.
|
|
35
33
|
overrides (dict, optional): Configuration overrides that take precedence over the default config.
|
|
36
34
|
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
|
37
|
-
|
|
38
|
-
Examples:
|
|
39
|
-
>>> from ultralytics.utils import ASSETS
|
|
40
|
-
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
|
41
|
-
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
|
|
42
|
-
>>> predictor = OBBPredictor(overrides=args)
|
|
43
35
|
"""
|
|
44
36
|
super().__init__(cfg, overrides, _callbacks)
|
|
45
37
|
self.args.task = "obb"
|
|
46
38
|
|
|
47
39
|
def construct_result(self, pred, img, orig_img, img_path):
|
|
48
|
-
"""
|
|
49
|
-
Construct the result object from the prediction.
|
|
40
|
+
"""Construct the result object from the prediction.
|
|
50
41
|
|
|
51
42
|
Args:
|
|
52
|
-
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
|
|
53
|
-
|
|
43
|
+
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where the
|
|
44
|
+
last dimension contains [x, y, w, h, confidence, class_id, angle].
|
|
54
45
|
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
|
|
55
46
|
orig_img (np.ndarray): The original image before preprocessing.
|
|
56
47
|
img_path (str): The path to the original image.
|
|
@@ -12,15 +12,14 @@ from ultralytics.utils import DEFAULT_CFG, RANK
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
15
|
-
"""
|
|
16
|
-
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
|
15
|
+
"""A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
|
17
16
|
|
|
18
|
-
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
|
|
19
|
-
|
|
17
|
+
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for detecting
|
|
18
|
+
objects at arbitrary angles rather than just axis-aligned rectangles.
|
|
20
19
|
|
|
21
20
|
Attributes:
|
|
22
|
-
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
|
|
23
|
-
|
|
21
|
+
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss, and
|
|
22
|
+
dfl_loss.
|
|
24
23
|
|
|
25
24
|
Methods:
|
|
26
25
|
get_model: Return OBBModel initialized with specified config and weights.
|
|
@@ -34,14 +33,13 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
|
34
33
|
"""
|
|
35
34
|
|
|
36
35
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
|
|
37
|
-
"""
|
|
38
|
-
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
|
36
|
+
"""Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
|
39
37
|
|
|
40
38
|
Args:
|
|
41
|
-
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
|
|
42
|
-
|
|
43
|
-
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
|
|
44
|
-
|
|
39
|
+
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and model
|
|
40
|
+
configuration.
|
|
41
|
+
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here will
|
|
42
|
+
take precedence over those in cfg.
|
|
45
43
|
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
|
|
46
44
|
"""
|
|
47
45
|
if overrides is None:
|
|
@@ -52,8 +50,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
|
52
50
|
def get_model(
|
|
53
51
|
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
|
|
54
52
|
) -> OBBModel:
|
|
55
|
-
"""
|
|
56
|
-
Return OBBModel initialized with specified config and weights.
|
|
53
|
+
"""Return OBBModel initialized with specified config and weights.
|
|
57
54
|
|
|
58
55
|
Args:
|
|
59
56
|
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
|
|
@@ -12,11 +12,11 @@ from ultralytics.models.yolo.detect import DetectionValidator
|
|
|
12
12
|
from ultralytics.utils import LOGGER, ops
|
|
13
13
|
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
|
|
14
14
|
from ultralytics.utils.nms import TorchNMS
|
|
15
|
+
from ultralytics.utils.plotting import plot_images
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
class OBBValidator(DetectionValidator):
|
|
18
|
-
"""
|
|
19
|
-
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
|
19
|
+
"""A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
|
20
20
|
|
|
21
21
|
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
|
|
22
22
|
satellite imagery where objects can appear at various orientations.
|
|
@@ -44,14 +44,13 @@ class OBBValidator(DetectionValidator):
|
|
|
44
44
|
"""
|
|
45
45
|
|
|
46
46
|
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
|
47
|
-
"""
|
|
48
|
-
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
|
47
|
+
"""Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
|
49
48
|
|
|
50
|
-
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
|
|
51
|
-
|
|
49
|
+
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
|
|
50
|
+
extends the DetectionValidator class and configures it specifically for the OBB task.
|
|
52
51
|
|
|
53
52
|
Args:
|
|
54
|
-
dataloader (torch.utils.data.DataLoader, optional):
|
|
53
|
+
dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
|
|
55
54
|
save_dir (str | Path, optional): Directory to save results.
|
|
56
55
|
args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
|
|
57
56
|
_callbacks (list, optional): List of callback functions to be called during validation.
|
|
@@ -61,8 +60,7 @@ class OBBValidator(DetectionValidator):
|
|
|
61
60
|
self.metrics = OBBMetrics()
|
|
62
61
|
|
|
63
62
|
def init_metrics(self, model: torch.nn.Module) -> None:
|
|
64
|
-
"""
|
|
65
|
-
Initialize evaluation metrics for YOLO obb validation.
|
|
63
|
+
"""Initialize evaluation metrics for YOLO obb validation.
|
|
66
64
|
|
|
67
65
|
Args:
|
|
68
66
|
model (torch.nn.Module): Model to validate.
|
|
@@ -73,19 +71,18 @@ class OBBValidator(DetectionValidator):
|
|
|
73
71
|
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
|
|
74
72
|
|
|
75
73
|
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
|
|
76
|
-
"""
|
|
77
|
-
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
|
74
|
+
"""Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
|
78
75
|
|
|
79
76
|
Args:
|
|
80
77
|
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
|
|
81
78
|
class labels and bounding boxes.
|
|
82
|
-
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
|
|
83
|
-
|
|
79
|
+
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
|
|
80
|
+
labels and bounding boxes.
|
|
84
81
|
|
|
85
82
|
Returns:
|
|
86
|
-
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
|
|
87
|
-
|
|
88
|
-
|
|
83
|
+
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
|
|
84
|
+
with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
|
|
85
|
+
predictions compared to the ground truth.
|
|
89
86
|
|
|
90
87
|
Examples:
|
|
91
88
|
>>> detections = torch.rand(100, 7) # 100 sample detections
|
|
@@ -99,7 +96,8 @@ class OBBValidator(DetectionValidator):
|
|
|
99
96
|
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
|
100
97
|
|
|
101
98
|
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
|
102
|
-
"""
|
|
99
|
+
"""Postprocess OBB predictions.
|
|
100
|
+
|
|
103
101
|
Args:
|
|
104
102
|
preds (torch.Tensor): Raw predictions from the model.
|
|
105
103
|
|
|
@@ -112,8 +110,7 @@ class OBBValidator(DetectionValidator):
|
|
|
112
110
|
return preds
|
|
113
111
|
|
|
114
112
|
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
|
115
|
-
"""
|
|
116
|
-
Prepare batch data for OBB validation with proper scaling and formatting.
|
|
113
|
+
"""Prepare batch data for OBB validation with proper scaling and formatting.
|
|
117
114
|
|
|
118
115
|
Args:
|
|
119
116
|
si (int): Batch index to process.
|
|
@@ -145,33 +142,41 @@ class OBBValidator(DetectionValidator):
|
|
|
145
142
|
"im_file": batch["im_file"][si],
|
|
146
143
|
}
|
|
147
144
|
|
|
148
|
-
def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
|
|
149
|
-
"""
|
|
150
|
-
Plot predicted bounding boxes on input images and save the result.
|
|
145
|
+
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
|
|
146
|
+
"""Plot predicted bounding boxes on input images and save the result.
|
|
151
147
|
|
|
152
148
|
Args:
|
|
153
149
|
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
|
|
154
|
-
preds (list[torch.Tensor]): List of prediction
|
|
150
|
+
preds (list[dict[str, torch.Tensor]]): List of prediction dictionaries for each image in the batch.
|
|
155
151
|
ni (int): Batch index used for naming the output file.
|
|
156
152
|
|
|
157
153
|
Examples:
|
|
158
154
|
>>> validator = OBBValidator()
|
|
159
155
|
>>> batch = {"img": images, "im_file": paths}
|
|
160
|
-
>>> preds = [torch.rand(10,
|
|
156
|
+
>>> preds = [{"bboxes": torch.rand(10, 5), "cls": torch.zeros(10), "conf": torch.rand(10)}]
|
|
161
157
|
>>> validator.plot_predictions(batch, preds, 0)
|
|
162
158
|
"""
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
159
|
+
if not preds:
|
|
160
|
+
return
|
|
161
|
+
for i, pred in enumerate(preds):
|
|
162
|
+
pred["batch_idx"] = torch.ones_like(pred["conf"]) * i
|
|
163
|
+
keys = preds[0].keys()
|
|
164
|
+
batched_preds = {k: torch.cat([x[k] for x in preds], dim=0) for k in keys}
|
|
165
|
+
plot_images(
|
|
166
|
+
images=batch["img"],
|
|
167
|
+
labels=batched_preds,
|
|
168
|
+
paths=batch["im_file"],
|
|
169
|
+
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
|
|
170
|
+
names=self.names,
|
|
171
|
+
on_plot=self.on_plot,
|
|
172
|
+
)
|
|
167
173
|
|
|
168
174
|
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
|
169
|
-
"""
|
|
170
|
-
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
|
175
|
+
"""Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
|
171
176
|
|
|
172
177
|
Args:
|
|
173
|
-
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
|
|
174
|
-
|
|
178
|
+
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
|
|
179
|
+
bounding box coordinates, confidence scores, and class predictions.
|
|
175
180
|
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
|
176
181
|
|
|
177
182
|
Notes:
|
|
@@ -197,8 +202,7 @@ class OBBValidator(DetectionValidator):
|
|
|
197
202
|
)
|
|
198
203
|
|
|
199
204
|
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
|
200
|
-
"""
|
|
201
|
-
Save YOLO OBB detections to a text file in normalized coordinates.
|
|
205
|
+
"""Save YOLO OBB detections to a text file in normalized coordinates.
|
|
202
206
|
|
|
203
207
|
Args:
|
|
204
208
|
predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
|
|
@@ -233,8 +237,7 @@ class OBBValidator(DetectionValidator):
|
|
|
233
237
|
}
|
|
234
238
|
|
|
235
239
|
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
|
236
|
-
"""
|
|
237
|
-
Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
|
240
|
+
"""Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
|
238
241
|
|
|
239
242
|
Args:
|
|
240
243
|
stats (dict[str, Any]): Performance statistics dictionary.
|
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
3
|
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
|
4
|
-
from ultralytics.utils import DEFAULT_CFG,
|
|
4
|
+
from ultralytics.utils import DEFAULT_CFG, ops
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class PosePredictor(DetectionPredictor):
|
|
8
|
-
"""
|
|
9
|
-
A class extending the DetectionPredictor class for prediction based on a pose model.
|
|
8
|
+
"""A class extending the DetectionPredictor class for prediction based on a pose model.
|
|
10
9
|
|
|
11
10
|
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
|
|
12
11
|
capabilities inherited from DetectionPredictor.
|
|
@@ -27,35 +26,21 @@ class PosePredictor(DetectionPredictor):
|
|
|
27
26
|
"""
|
|
28
27
|
|
|
29
28
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
30
|
-
"""
|
|
31
|
-
Initialize PosePredictor for pose estimation tasks.
|
|
29
|
+
"""Initialize PosePredictor for pose estimation tasks.
|
|
32
30
|
|
|
33
|
-
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
|
|
34
|
-
|
|
31
|
+
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
|
|
32
|
+
for Apple MPS.
|
|
35
33
|
|
|
36
34
|
Args:
|
|
37
35
|
cfg (Any): Configuration for the predictor.
|
|
38
36
|
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
|
39
37
|
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
|
40
|
-
|
|
41
|
-
Examples:
|
|
42
|
-
>>> from ultralytics.utils import ASSETS
|
|
43
|
-
>>> from ultralytics.models.yolo.pose import PosePredictor
|
|
44
|
-
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
|
|
45
|
-
>>> predictor = PosePredictor(overrides=args)
|
|
46
|
-
>>> predictor.predict_cli()
|
|
47
38
|
"""
|
|
48
39
|
super().__init__(cfg, overrides, _callbacks)
|
|
49
40
|
self.args.task = "pose"
|
|
50
|
-
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
|
51
|
-
LOGGER.warning(
|
|
52
|
-
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
|
53
|
-
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
|
54
|
-
)
|
|
55
41
|
|
|
56
42
|
def construct_result(self, pred, img, orig_img, img_path):
|
|
57
|
-
"""
|
|
58
|
-
Construct the result object from the prediction, including keypoints.
|
|
43
|
+
"""Construct the result object from the prediction, including keypoints.
|
|
59
44
|
|
|
60
45
|
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
|
|
61
46
|
result object.
|
|
@@ -8,12 +8,11 @@ from typing import Any
|
|
|
8
8
|
|
|
9
9
|
from ultralytics.models import yolo
|
|
10
10
|
from ultralytics.nn.tasks import PoseModel
|
|
11
|
-
from ultralytics.utils import DEFAULT_CFG
|
|
11
|
+
from ultralytics.utils import DEFAULT_CFG
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
15
|
-
"""
|
|
16
|
-
A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
|
15
|
+
"""A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
|
17
16
|
|
|
18
17
|
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
|
|
19
18
|
of pose keypoints alongside bounding boxes.
|
|
@@ -39,8 +38,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
39
38
|
"""
|
|
40
39
|
|
|
41
40
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
|
42
|
-
"""
|
|
43
|
-
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
|
41
|
+
"""Initialize a PoseTrainer object for training YOLO pose estimation models.
|
|
44
42
|
|
|
45
43
|
Args:
|
|
46
44
|
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
|
@@ -56,20 +54,13 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
56
54
|
overrides["task"] = "pose"
|
|
57
55
|
super().__init__(cfg, overrides, _callbacks)
|
|
58
56
|
|
|
59
|
-
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
|
60
|
-
LOGGER.warning(
|
|
61
|
-
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
|
62
|
-
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
|
63
|
-
)
|
|
64
|
-
|
|
65
57
|
def get_model(
|
|
66
58
|
self,
|
|
67
59
|
cfg: str | Path | dict[str, Any] | None = None,
|
|
68
60
|
weights: str | Path | None = None,
|
|
69
61
|
verbose: bool = True,
|
|
70
62
|
) -> PoseModel:
|
|
71
|
-
"""
|
|
72
|
-
Get pose estimation model with specified configuration and weights.
|
|
63
|
+
"""Get pose estimation model with specified configuration and weights.
|
|
73
64
|
|
|
74
65
|
Args:
|
|
75
66
|
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
|
|
@@ -91,6 +82,11 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
91
82
|
"""Set keypoints shape attribute of PoseModel."""
|
|
92
83
|
super().set_model_attributes()
|
|
93
84
|
self.model.kpt_shape = self.data["kpt_shape"]
|
|
85
|
+
kpt_names = self.data.get("kpt_names")
|
|
86
|
+
if not kpt_names:
|
|
87
|
+
names = list(map(str, range(self.model.kpt_shape[0])))
|
|
88
|
+
kpt_names = {i: names for i in range(self.model.nc)}
|
|
89
|
+
self.model.kpt_names = kpt_names
|
|
94
90
|
|
|
95
91
|
def get_validator(self):
|
|
96
92
|
"""Return an instance of the PoseValidator class for validation."""
|
|
@@ -100,8 +96,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
100
96
|
)
|
|
101
97
|
|
|
102
98
|
def get_dataset(self) -> dict[str, Any]:
|
|
103
|
-
"""
|
|
104
|
-
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
|
99
|
+
"""Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
|
105
100
|
|
|
106
101
|
Returns:
|
|
107
102
|
(dict): A dictionary containing the training/validation/test dataset and category names.
|
|
@@ -9,16 +9,15 @@ import numpy as np
|
|
|
9
9
|
import torch
|
|
10
10
|
|
|
11
11
|
from ultralytics.models.yolo.detect import DetectionValidator
|
|
12
|
-
from ultralytics.utils import
|
|
12
|
+
from ultralytics.utils import ops
|
|
13
13
|
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class PoseValidator(DetectionValidator):
|
|
17
|
-
"""
|
|
18
|
-
A class extending the DetectionValidator class for validation based on a pose model.
|
|
17
|
+
"""A class extending the DetectionValidator class for validation based on a pose model.
|
|
19
18
|
|
|
20
|
-
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
|
21
|
-
|
|
19
|
+
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
|
|
20
|
+
metrics for pose evaluation.
|
|
22
21
|
|
|
23
22
|
Attributes:
|
|
24
23
|
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
|
@@ -33,8 +32,8 @@ class PoseValidator(DetectionValidator):
|
|
|
33
32
|
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
|
|
34
33
|
dimensions.
|
|
35
34
|
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
|
|
36
|
-
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
|
|
37
|
-
|
|
35
|
+
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
|
|
36
|
+
and ground truth.
|
|
38
37
|
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
|
39
38
|
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
|
|
40
39
|
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
|
|
@@ -46,42 +45,30 @@ class PoseValidator(DetectionValidator):
|
|
|
46
45
|
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
|
47
46
|
>>> validator = PoseValidator(args=args)
|
|
48
47
|
>>> validator()
|
|
48
|
+
|
|
49
|
+
Notes:
|
|
50
|
+
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
|
|
51
|
+
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
|
|
52
|
+
due to a known bug with pose models.
|
|
49
53
|
"""
|
|
50
54
|
|
|
51
55
|
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
|
52
|
-
"""
|
|
53
|
-
Initialize a PoseValidator object for pose estimation validation.
|
|
56
|
+
"""Initialize a PoseValidator object for pose estimation validation.
|
|
54
57
|
|
|
55
58
|
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
|
56
59
|
specialized metrics for pose evaluation.
|
|
57
60
|
|
|
58
61
|
Args:
|
|
59
|
-
dataloader (torch.utils.data.DataLoader, optional):
|
|
62
|
+
dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
|
|
60
63
|
save_dir (Path | str, optional): Directory to save results.
|
|
61
64
|
args (dict, optional): Arguments for the validator including task set to "pose".
|
|
62
65
|
_callbacks (list, optional): List of callback functions to be executed during validation.
|
|
63
|
-
|
|
64
|
-
Examples:
|
|
65
|
-
>>> from ultralytics.models.yolo.pose import PoseValidator
|
|
66
|
-
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
|
67
|
-
>>> validator = PoseValidator(args=args)
|
|
68
|
-
>>> validator()
|
|
69
|
-
|
|
70
|
-
Notes:
|
|
71
|
-
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
|
|
72
|
-
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
|
|
73
|
-
due to a known bug with pose models.
|
|
74
66
|
"""
|
|
75
67
|
super().__init__(dataloader, save_dir, args, _callbacks)
|
|
76
68
|
self.sigma = None
|
|
77
69
|
self.kpt_shape = None
|
|
78
70
|
self.args.task = "pose"
|
|
79
71
|
self.metrics = PoseMetrics()
|
|
80
|
-
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
|
81
|
-
LOGGER.warning(
|
|
82
|
-
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
|
83
|
-
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
|
84
|
-
)
|
|
85
72
|
|
|
86
73
|
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
|
87
74
|
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
|
|
@@ -106,8 +93,7 @@ class PoseValidator(DetectionValidator):
|
|
|
106
93
|
)
|
|
107
94
|
|
|
108
95
|
def init_metrics(self, model: torch.nn.Module) -> None:
|
|
109
|
-
"""
|
|
110
|
-
Initialize evaluation metrics for YOLO pose validation.
|
|
96
|
+
"""Initialize evaluation metrics for YOLO pose validation.
|
|
111
97
|
|
|
112
98
|
Args:
|
|
113
99
|
model (torch.nn.Module): Model to validate.
|
|
@@ -119,17 +105,15 @@ class PoseValidator(DetectionValidator):
|
|
|
119
105
|
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
|
120
106
|
|
|
121
107
|
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
122
|
-
"""
|
|
123
|
-
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
|
108
|
+
"""Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
|
124
109
|
|
|
125
|
-
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
(typically [N, 17, 3] for COCO pose format).
|
|
110
|
+
This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
|
|
111
|
+
predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
|
|
112
|
+
flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
|
|
129
113
|
|
|
130
114
|
Args:
|
|
131
|
-
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
|
|
132
|
-
|
|
115
|
+
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
|
|
116
|
+
scores, class predictions, and keypoint data.
|
|
133
117
|
|
|
134
118
|
Returns:
|
|
135
119
|
(dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
|
|
@@ -138,10 +122,10 @@ class PoseValidator(DetectionValidator):
|
|
|
138
122
|
- 'cls': Class predictions
|
|
139
123
|
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
|
|
140
124
|
|
|
141
|
-
|
|
142
|
-
If no keypoints are present in a prediction (empty keypoints), that prediction
|
|
143
|
-
|
|
144
|
-
|
|
125
|
+
Notes:
|
|
126
|
+
If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
|
|
127
|
+
to the next one. The keypoints are extracted from the 'extra' field which contains additional
|
|
128
|
+
task-specific data beyond basic detection.
|
|
145
129
|
"""
|
|
146
130
|
preds = super().postprocess(preds)
|
|
147
131
|
for pred in preds:
|
|
@@ -149,8 +133,7 @@ class PoseValidator(DetectionValidator):
|
|
|
149
133
|
return preds
|
|
150
134
|
|
|
151
135
|
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
|
152
|
-
"""
|
|
153
|
-
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
|
136
|
+
"""Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
|
154
137
|
|
|
155
138
|
Args:
|
|
156
139
|
si (int): Batch index.
|
|
@@ -173,18 +156,18 @@ class PoseValidator(DetectionValidator):
|
|
|
173
156
|
return pbatch
|
|
174
157
|
|
|
175
158
|
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
|
176
|
-
"""
|
|
177
|
-
|
|
159
|
+
"""Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
|
|
160
|
+
truth.
|
|
178
161
|
|
|
179
162
|
Args:
|
|
180
163
|
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
|
181
164
|
and 'keypoints' for keypoint predictions.
|
|
182
|
-
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
|
183
|
-
|
|
165
|
+
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
|
|
166
|
+
for bounding boxes, and 'keypoints' for keypoint annotations.
|
|
184
167
|
|
|
185
168
|
Returns:
|
|
186
|
-
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
|
187
|
-
|
|
169
|
+
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
|
|
170
|
+
positives across 10 IoU levels.
|
|
188
171
|
|
|
189
172
|
Notes:
|
|
190
173
|
`0.53` scale factor used in area computation is referenced from
|
|
@@ -203,11 +186,10 @@ class PoseValidator(DetectionValidator):
|
|
|
203
186
|
return tp
|
|
204
187
|
|
|
205
188
|
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
|
206
|
-
"""
|
|
207
|
-
Save YOLO pose detections to a text file in normalized coordinates.
|
|
189
|
+
"""Save YOLO pose detections to a text file in normalized coordinates.
|
|
208
190
|
|
|
209
191
|
Args:
|
|
210
|
-
predn (dict[str, torch.Tensor]):
|
|
192
|
+
predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
|
|
211
193
|
save_conf (bool): Whether to save confidence scores.
|
|
212
194
|
shape (tuple[int, int]): Shape of the original image (height, width).
|
|
213
195
|
file (Path): Output file path to save detections.
|
|
@@ -227,15 +209,14 @@ class PoseValidator(DetectionValidator):
|
|
|
227
209
|
).save_txt(file, save_conf=save_conf)
|
|
228
210
|
|
|
229
211
|
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
|
230
|
-
"""
|
|
231
|
-
Convert YOLO predictions to COCO JSON format.
|
|
212
|
+
"""Convert YOLO predictions to COCO JSON format.
|
|
232
213
|
|
|
233
|
-
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
|
|
234
|
-
|
|
214
|
+
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
|
|
215
|
+
format, and appends the results to the internal JSON dictionary (self.jdict).
|
|
235
216
|
|
|
236
217
|
Args:
|
|
237
|
-
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
|
238
|
-
|
|
218
|
+
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
|
|
219
|
+
tensors.
|
|
239
220
|
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
|
240
221
|
|
|
241
222
|
Notes:
|