dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
|
@@ -8,12 +8,11 @@ from typing import Any
|
|
|
8
8
|
|
|
9
9
|
from ultralytics.models import yolo
|
|
10
10
|
from ultralytics.nn.tasks import PoseModel
|
|
11
|
-
from ultralytics.utils import DEFAULT_CFG
|
|
11
|
+
from ultralytics.utils import DEFAULT_CFG
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
15
|
-
"""
|
|
16
|
-
A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
|
15
|
+
"""A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
|
17
16
|
|
|
18
17
|
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
|
|
19
18
|
of pose keypoints alongside bounding boxes.
|
|
@@ -33,14 +32,13 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
33
32
|
|
|
34
33
|
Examples:
|
|
35
34
|
>>> from ultralytics.models.yolo.pose import PoseTrainer
|
|
36
|
-
>>> args = dict(model="
|
|
35
|
+
>>> args = dict(model="yolo26n-pose.pt", data="coco8-pose.yaml", epochs=3)
|
|
37
36
|
>>> trainer = PoseTrainer(overrides=args)
|
|
38
37
|
>>> trainer.train()
|
|
39
38
|
"""
|
|
40
39
|
|
|
41
40
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
|
42
|
-
"""
|
|
43
|
-
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
|
41
|
+
"""Initialize a PoseTrainer object for training YOLO pose estimation models.
|
|
44
42
|
|
|
45
43
|
Args:
|
|
46
44
|
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
|
@@ -56,20 +54,13 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
56
54
|
overrides["task"] = "pose"
|
|
57
55
|
super().__init__(cfg, overrides, _callbacks)
|
|
58
56
|
|
|
59
|
-
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
|
60
|
-
LOGGER.warning(
|
|
61
|
-
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
|
62
|
-
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
|
63
|
-
)
|
|
64
|
-
|
|
65
57
|
def get_model(
|
|
66
58
|
self,
|
|
67
59
|
cfg: str | Path | dict[str, Any] | None = None,
|
|
68
60
|
weights: str | Path | None = None,
|
|
69
61
|
verbose: bool = True,
|
|
70
62
|
) -> PoseModel:
|
|
71
|
-
"""
|
|
72
|
-
Get pose estimation model with specified configuration and weights.
|
|
63
|
+
"""Get pose estimation model with specified configuration and weights.
|
|
73
64
|
|
|
74
65
|
Args:
|
|
75
66
|
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
|
|
@@ -91,17 +82,23 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
|
91
82
|
"""Set keypoints shape attribute of PoseModel."""
|
|
92
83
|
super().set_model_attributes()
|
|
93
84
|
self.model.kpt_shape = self.data["kpt_shape"]
|
|
85
|
+
kpt_names = self.data.get("kpt_names")
|
|
86
|
+
if not kpt_names:
|
|
87
|
+
names = list(map(str, range(self.model.kpt_shape[0])))
|
|
88
|
+
kpt_names = {i: names for i in range(self.model.nc)}
|
|
89
|
+
self.model.kpt_names = kpt_names
|
|
94
90
|
|
|
95
91
|
def get_validator(self):
|
|
96
92
|
"""Return an instance of the PoseValidator class for validation."""
|
|
97
93
|
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
|
|
94
|
+
if getattr(self.model.model[-1], "flow_model", None) is not None:
|
|
95
|
+
self.loss_names += ("rle_loss",)
|
|
98
96
|
return yolo.pose.PoseValidator(
|
|
99
97
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
|
100
98
|
)
|
|
101
99
|
|
|
102
100
|
def get_dataset(self) -> dict[str, Any]:
|
|
103
|
-
"""
|
|
104
|
-
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
|
101
|
+
"""Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
|
105
102
|
|
|
106
103
|
Returns:
|
|
107
104
|
(dict): A dictionary containing the training/validation/test dataset and category names.
|
|
@@ -9,16 +9,15 @@ import numpy as np
|
|
|
9
9
|
import torch
|
|
10
10
|
|
|
11
11
|
from ultralytics.models.yolo.detect import DetectionValidator
|
|
12
|
-
from ultralytics.utils import
|
|
12
|
+
from ultralytics.utils import ops
|
|
13
13
|
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class PoseValidator(DetectionValidator):
|
|
17
|
-
"""
|
|
18
|
-
A class extending the DetectionValidator class for validation based on a pose model.
|
|
17
|
+
"""A class extending the DetectionValidator class for validation based on a pose model.
|
|
19
18
|
|
|
20
|
-
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
|
21
|
-
|
|
19
|
+
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
|
|
20
|
+
metrics for pose evaluation.
|
|
22
21
|
|
|
23
22
|
Attributes:
|
|
24
23
|
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
|
@@ -33,8 +32,8 @@ class PoseValidator(DetectionValidator):
|
|
|
33
32
|
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
|
|
34
33
|
dimensions.
|
|
35
34
|
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
|
|
36
|
-
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
|
|
37
|
-
|
|
35
|
+
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
|
|
36
|
+
and ground truth.
|
|
38
37
|
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
|
39
38
|
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
|
|
40
39
|
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
|
|
@@ -43,45 +42,33 @@ class PoseValidator(DetectionValidator):
|
|
|
43
42
|
|
|
44
43
|
Examples:
|
|
45
44
|
>>> from ultralytics.models.yolo.pose import PoseValidator
|
|
46
|
-
>>> args = dict(model="
|
|
45
|
+
>>> args = dict(model="yolo26n-pose.pt", data="coco8-pose.yaml")
|
|
47
46
|
>>> validator = PoseValidator(args=args)
|
|
48
47
|
>>> validator()
|
|
48
|
+
|
|
49
|
+
Notes:
|
|
50
|
+
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
|
|
51
|
+
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
|
|
52
|
+
due to a known bug with pose models.
|
|
49
53
|
"""
|
|
50
54
|
|
|
51
55
|
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
|
52
|
-
"""
|
|
53
|
-
Initialize a PoseValidator object for pose estimation validation.
|
|
56
|
+
"""Initialize a PoseValidator object for pose estimation validation.
|
|
54
57
|
|
|
55
58
|
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
|
56
59
|
specialized metrics for pose evaluation.
|
|
57
60
|
|
|
58
61
|
Args:
|
|
59
|
-
dataloader (torch.utils.data.DataLoader, optional):
|
|
62
|
+
dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
|
|
60
63
|
save_dir (Path | str, optional): Directory to save results.
|
|
61
64
|
args (dict, optional): Arguments for the validator including task set to "pose".
|
|
62
65
|
_callbacks (list, optional): List of callback functions to be executed during validation.
|
|
63
|
-
|
|
64
|
-
Examples:
|
|
65
|
-
>>> from ultralytics.models.yolo.pose import PoseValidator
|
|
66
|
-
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
|
67
|
-
>>> validator = PoseValidator(args=args)
|
|
68
|
-
>>> validator()
|
|
69
|
-
|
|
70
|
-
Notes:
|
|
71
|
-
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
|
|
72
|
-
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
|
|
73
|
-
due to a known bug with pose models.
|
|
74
66
|
"""
|
|
75
67
|
super().__init__(dataloader, save_dir, args, _callbacks)
|
|
76
68
|
self.sigma = None
|
|
77
69
|
self.kpt_shape = None
|
|
78
70
|
self.args.task = "pose"
|
|
79
71
|
self.metrics = PoseMetrics()
|
|
80
|
-
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
|
81
|
-
LOGGER.warning(
|
|
82
|
-
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
|
83
|
-
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
|
84
|
-
)
|
|
85
72
|
|
|
86
73
|
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
|
87
74
|
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
|
|
@@ -106,8 +93,7 @@ class PoseValidator(DetectionValidator):
|
|
|
106
93
|
)
|
|
107
94
|
|
|
108
95
|
def init_metrics(self, model: torch.nn.Module) -> None:
|
|
109
|
-
"""
|
|
110
|
-
Initialize evaluation metrics for YOLO pose validation.
|
|
96
|
+
"""Initialize evaluation metrics for YOLO pose validation.
|
|
111
97
|
|
|
112
98
|
Args:
|
|
113
99
|
model (torch.nn.Module): Model to validate.
|
|
@@ -119,17 +105,15 @@ class PoseValidator(DetectionValidator):
|
|
|
119
105
|
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
|
120
106
|
|
|
121
107
|
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
|
|
122
|
-
"""
|
|
123
|
-
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
|
108
|
+
"""Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
|
124
109
|
|
|
125
|
-
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
(typically [N, 17, 3] for COCO pose format).
|
|
110
|
+
This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
|
|
111
|
+
predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
|
|
112
|
+
flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
|
|
129
113
|
|
|
130
114
|
Args:
|
|
131
|
-
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
|
|
132
|
-
|
|
115
|
+
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
|
|
116
|
+
scores, class predictions, and keypoint data.
|
|
133
117
|
|
|
134
118
|
Returns:
|
|
135
119
|
(dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
|
|
@@ -138,10 +122,10 @@ class PoseValidator(DetectionValidator):
|
|
|
138
122
|
- 'cls': Class predictions
|
|
139
123
|
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
|
|
140
124
|
|
|
141
|
-
|
|
142
|
-
If no keypoints are present in a prediction (empty keypoints), that prediction
|
|
143
|
-
|
|
144
|
-
|
|
125
|
+
Notes:
|
|
126
|
+
If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
|
|
127
|
+
to the next one. The keypoints are extracted from the 'extra' field which contains additional
|
|
128
|
+
task-specific data beyond basic detection.
|
|
145
129
|
"""
|
|
146
130
|
preds = super().postprocess(preds)
|
|
147
131
|
for pred in preds:
|
|
@@ -149,8 +133,7 @@ class PoseValidator(DetectionValidator):
|
|
|
149
133
|
return preds
|
|
150
134
|
|
|
151
135
|
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
|
152
|
-
"""
|
|
153
|
-
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
|
136
|
+
"""Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
|
154
137
|
|
|
155
138
|
Args:
|
|
156
139
|
si (int): Batch index.
|
|
@@ -173,18 +156,18 @@ class PoseValidator(DetectionValidator):
|
|
|
173
156
|
return pbatch
|
|
174
157
|
|
|
175
158
|
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
|
176
|
-
"""
|
|
177
|
-
|
|
159
|
+
"""Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
|
|
160
|
+
truth.
|
|
178
161
|
|
|
179
162
|
Args:
|
|
180
163
|
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
|
181
164
|
and 'keypoints' for keypoint predictions.
|
|
182
|
-
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
|
183
|
-
|
|
165
|
+
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
|
|
166
|
+
for bounding boxes, and 'keypoints' for keypoint annotations.
|
|
184
167
|
|
|
185
168
|
Returns:
|
|
186
|
-
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
|
187
|
-
|
|
169
|
+
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
|
|
170
|
+
positives across 10 IoU levels.
|
|
188
171
|
|
|
189
172
|
Notes:
|
|
190
173
|
`0.53` scale factor used in area computation is referenced from
|
|
@@ -203,11 +186,10 @@ class PoseValidator(DetectionValidator):
|
|
|
203
186
|
return tp
|
|
204
187
|
|
|
205
188
|
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
|
206
|
-
"""
|
|
207
|
-
Save YOLO pose detections to a text file in normalized coordinates.
|
|
189
|
+
"""Save YOLO pose detections to a text file in normalized coordinates.
|
|
208
190
|
|
|
209
191
|
Args:
|
|
210
|
-
predn (dict[str, torch.Tensor]):
|
|
192
|
+
predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints.
|
|
211
193
|
save_conf (bool): Whether to save confidence scores.
|
|
212
194
|
shape (tuple[int, int]): Shape of the original image (height, width).
|
|
213
195
|
file (Path): Output file path to save detections.
|
|
@@ -227,15 +209,14 @@ class PoseValidator(DetectionValidator):
|
|
|
227
209
|
).save_txt(file, save_conf=save_conf)
|
|
228
210
|
|
|
229
211
|
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
|
230
|
-
"""
|
|
231
|
-
Convert YOLO predictions to COCO JSON format.
|
|
212
|
+
"""Convert YOLO predictions to COCO JSON format.
|
|
232
213
|
|
|
233
|
-
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
|
|
234
|
-
|
|
214
|
+
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
|
|
215
|
+
format, and appends the results to the internal JSON dictionary (self.jdict).
|
|
235
216
|
|
|
236
217
|
Args:
|
|
237
|
-
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
|
238
|
-
|
|
218
|
+
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
|
|
219
|
+
tensors.
|
|
239
220
|
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
|
240
221
|
|
|
241
222
|
Notes:
|
|
@@ -6,8 +6,7 @@ from ultralytics.utils import DEFAULT_CFG, ops
|
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class SegmentationPredictor(DetectionPredictor):
|
|
9
|
-
"""
|
|
10
|
-
A class extending the DetectionPredictor class for prediction based on a segmentation model.
|
|
9
|
+
"""A class extending the DetectionPredictor class for prediction based on a segmentation model.
|
|
11
10
|
|
|
12
11
|
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
|
|
13
12
|
prediction results.
|
|
@@ -25,14 +24,13 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
25
24
|
Examples:
|
|
26
25
|
>>> from ultralytics.utils import ASSETS
|
|
27
26
|
>>> from ultralytics.models.yolo.segment import SegmentationPredictor
|
|
28
|
-
>>> args = dict(model="
|
|
27
|
+
>>> args = dict(model="yolo26n-seg.pt", source=ASSETS)
|
|
29
28
|
>>> predictor = SegmentationPredictor(overrides=args)
|
|
30
29
|
>>> predictor.predict_cli()
|
|
31
30
|
"""
|
|
32
31
|
|
|
33
32
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
|
34
|
-
"""
|
|
35
|
-
Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
|
|
33
|
+
"""Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
|
|
36
34
|
|
|
37
35
|
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
|
|
38
36
|
prediction results.
|
|
@@ -46,8 +44,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
46
44
|
self.args.task = "segment"
|
|
47
45
|
|
|
48
46
|
def postprocess(self, preds, img, orig_imgs):
|
|
49
|
-
"""
|
|
50
|
-
Apply non-max suppression and process segmentation detections for each image in the input batch.
|
|
47
|
+
"""Apply non-max suppression and process segmentation detections for each image in the input batch.
|
|
51
48
|
|
|
52
49
|
Args:
|
|
53
50
|
preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
|
|
@@ -55,20 +52,19 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
55
52
|
orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
|
|
56
53
|
|
|
57
54
|
Returns:
|
|
58
|
-
(list): List of Results objects containing the segmentation predictions for each image in the batch.
|
|
59
|
-
|
|
55
|
+
(list): List of Results objects containing the segmentation predictions for each image in the batch. Each
|
|
56
|
+
Results object includes both bounding boxes and segmentation masks.
|
|
60
57
|
|
|
61
58
|
Examples:
|
|
62
|
-
>>> predictor = SegmentationPredictor(overrides=dict(model="
|
|
59
|
+
>>> predictor = SegmentationPredictor(overrides=dict(model="yolo26n-seg.pt"))
|
|
63
60
|
>>> results = predictor.postprocess(preds, img, orig_img)
|
|
64
61
|
"""
|
|
65
62
|
# Extract protos - tuple if PyTorch model or array if exported
|
|
66
|
-
protos = preds[
|
|
63
|
+
protos = preds[0][1] if isinstance(preds[0], tuple) else preds[1]
|
|
67
64
|
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
|
|
68
65
|
|
|
69
66
|
def construct_results(self, preds, img, orig_imgs, protos):
|
|
70
|
-
"""
|
|
71
|
-
Construct a list of result objects from the predictions.
|
|
67
|
+
"""Construct a list of result objects from the predictions.
|
|
72
68
|
|
|
73
69
|
Args:
|
|
74
70
|
preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
|
|
@@ -77,8 +73,8 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
77
73
|
protos (list[torch.Tensor]): List of prototype masks.
|
|
78
74
|
|
|
79
75
|
Returns:
|
|
80
|
-
(list[Results]): List of result objects containing the original images, image paths, class names,
|
|
81
|
-
|
|
76
|
+
(list[Results]): List of result objects containing the original images, image paths, class names, bounding
|
|
77
|
+
boxes, and masks.
|
|
82
78
|
"""
|
|
83
79
|
return [
|
|
84
80
|
self.construct_result(pred, img, orig_img, img_path, proto)
|
|
@@ -86,8 +82,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
86
82
|
]
|
|
87
83
|
|
|
88
84
|
def construct_result(self, pred, img, orig_img, img_path, proto):
|
|
89
|
-
"""
|
|
90
|
-
Construct a single result object from the prediction.
|
|
85
|
+
"""Construct a single result object from the prediction.
|
|
91
86
|
|
|
92
87
|
Args:
|
|
93
88
|
pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
|
|
@@ -103,11 +98,12 @@ class SegmentationPredictor(DetectionPredictor):
|
|
|
103
98
|
masks = None
|
|
104
99
|
elif self.args.retina_masks:
|
|
105
100
|
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
|
106
|
-
masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) #
|
|
101
|
+
masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # NHW
|
|
107
102
|
else:
|
|
108
|
-
masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) #
|
|
103
|
+
masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # NHW
|
|
109
104
|
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
|
110
105
|
if masks is not None:
|
|
111
|
-
keep = masks.
|
|
112
|
-
|
|
106
|
+
keep = masks.amax((-2, -1)) > 0 # only keep predictions with masks
|
|
107
|
+
if not all(keep): # most predictions have masks
|
|
108
|
+
pred, masks = pred[keep], masks[keep] # indexing is slow
|
|
113
109
|
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
|
|
@@ -11,8 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, RANK
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
14
|
-
"""
|
|
15
|
-
A class extending the DetectionTrainer class for training based on a segmentation model.
|
|
14
|
+
"""A class extending the DetectionTrainer class for training based on a segmentation model.
|
|
16
15
|
|
|
17
16
|
This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
|
|
18
17
|
functionality including model initialization, validation, and visualization.
|
|
@@ -22,14 +21,13 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
|
22
21
|
|
|
23
22
|
Examples:
|
|
24
23
|
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
|
|
25
|
-
>>> args = dict(model="
|
|
24
|
+
>>> args = dict(model="yolo26n-seg.pt", data="coco8-seg.yaml", epochs=3)
|
|
26
25
|
>>> trainer = SegmentationTrainer(overrides=args)
|
|
27
26
|
>>> trainer.train()
|
|
28
27
|
"""
|
|
29
28
|
|
|
30
29
|
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
|
|
31
|
-
"""
|
|
32
|
-
Initialize a SegmentationTrainer object.
|
|
30
|
+
"""Initialize a SegmentationTrainer object.
|
|
33
31
|
|
|
34
32
|
Args:
|
|
35
33
|
cfg (dict): Configuration dictionary with default training settings.
|
|
@@ -42,8 +40,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
|
42
40
|
super().__init__(cfg, overrides, _callbacks)
|
|
43
41
|
|
|
44
42
|
def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
|
|
45
|
-
"""
|
|
46
|
-
Initialize and return a SegmentationModel with specified configuration and weights.
|
|
43
|
+
"""Initialize and return a SegmentationModel with specified configuration and weights.
|
|
47
44
|
|
|
48
45
|
Args:
|
|
49
46
|
cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
|
|
@@ -55,8 +52,8 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
|
55
52
|
|
|
56
53
|
Examples:
|
|
57
54
|
>>> trainer = SegmentationTrainer()
|
|
58
|
-
>>> model = trainer.get_model(cfg="
|
|
59
|
-
>>> model = trainer.get_model(weights="
|
|
55
|
+
>>> model = trainer.get_model(cfg="yolo26n-seg.yaml")
|
|
56
|
+
>>> model = trainer.get_model(weights="yolo26n-seg.pt", verbose=False)
|
|
60
57
|
"""
|
|
61
58
|
model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
|
|
62
59
|
if weights:
|
|
@@ -66,7 +63,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
|
|
66
63
|
|
|
67
64
|
def get_validator(self):
|
|
68
65
|
"""Return an instance of SegmentationValidator for validation of YOLO model."""
|
|
69
|
-
self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
|
|
66
|
+
self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss", "sem_loss"
|
|
70
67
|
return yolo.segment.SegmentationValidator(
|
|
71
68
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
|
72
69
|
)
|