dgenerate_ultralytics_headless-8.3.253-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1028 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1580 -0
- ultralytics/engine/model.py +1125 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +977 -0
- ultralytics/engine/tuner.py +449 -0
- ultralytics/engine/validator.py +387 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +964 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +453 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1020 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +506 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1047 -0
- ultralytics/utils/tal.py +404 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +443 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +168 -0
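For context, the wheel above should be installable from PyPI under the project name in its filename, and it ships the standard `ultralytics` top-level package (per top_level.txt). A minimal smoke-test sketch, assuming published weights follow the usual Ultralytics naming (`yolo11n.pt` is an assumption, not part of this wheel):

# pip install dgenerate-ultralytics-headless==8.3.253  (assumed PyPI project name from the wheel filename)
from ultralytics import YOLO
from ultralytics.utils import ASSETS  # points at the bundled bus.jpg / zidane.jpg assets listed above

model = YOLO("yolo11n.pt")  # assumed weights name; typically auto-downloaded on first use
results = model(ASSETS / "bus.jpg")  # run inference on a bundled sample image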
ultralytics/models/yolo/pose/val.py
@@ -0,0 +1,248 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import ops
+from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
+
+
+class PoseValidator(DetectionValidator):
+    """A class extending the DetectionValidator class for validation based on a pose model.
+
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing specialized
+    metrics for pose evaluation.
+
+    Attributes:
+        sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
+        kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
+        args (dict): Arguments for the validator including task set to "pose".
+        metrics (PoseMetrics): Metrics object for pose evaluation.
+
+    Methods:
+        preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
+        get_desc: Return description of evaluation metrics in string format.
+        init_metrics: Initialize pose estimation metrics for YOLO model.
+        _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
+            dimensions.
+        _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
+        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between detections
+            and ground truth.
+        plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
+        plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
+        save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
+        pred_to_json: Convert YOLO predictions to COCO JSON format.
+        eval_json: Evaluate object detection model using COCO JSON format.
+
+    Examples:
+        >>> from ultralytics.models.yolo.pose import PoseValidator
+        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
+        >>> validator = PoseValidator(args=args)
+        >>> validator()
+
+    Notes:
+        This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
+        for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
+        due to a known bug with pose models.
+    """
+
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+        """Initialize a PoseValidator object for pose estimation validation.
+
+        This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
+        specialized metrics for pose evaluation.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
+            save_dir (Path | str, optional): Directory to save results.
+            args (dict, optional): Arguments for the validator including task set to "pose".
+            _callbacks (list, optional): List of callback functions to be executed during validation.
+        """
+        super().__init__(dataloader, save_dir, args, _callbacks)
+        self.sigma = None
+        self.kpt_shape = None
+        self.args.task = "pose"
+        self.metrics = PoseMetrics()
+
+    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
+        """Preprocess batch by converting keypoints data to float and moving it to the device."""
+        batch = super().preprocess(batch)
+        batch["keypoints"] = batch["keypoints"].float()
+        return batch
+
+    def get_desc(self) -> str:
+        """Return description of evaluation metrics in string format."""
+        return ("%22s" + "%11s" * 10) % (
+            "Class",
+            "Images",
+            "Instances",
+            "Box(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+            "Pose(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+        )
+
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """Initialize evaluation metrics for YOLO pose validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
+        super().init_metrics(model)
+        self.kpt_shape = self.data["kpt_shape"]
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]
+        self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
+
+    def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
+        """Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
+
+        This method extends the parent class postprocessing by extracting keypoints from the 'extra' field of
+        predictions and reshaping them according to the keypoint shape configuration. The keypoints are reshaped from a
+        flattened format to the proper dimensional structure (typically [N, 17, 3] for COCO pose format).
+
+        Args:
+            preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing bounding boxes, confidence
+                scores, class predictions, and keypoint data.
+
+        Returns:
+            (dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
+                - 'bboxes': Bounding box coordinates
+                - 'conf': Confidence scores
+                - 'cls': Class predictions
+                - 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
+
+        Notes:
+            If no keypoints are present in a prediction (empty keypoints), that prediction is skipped and continues
+            to the next one. The keypoints are extracted from the 'extra' field which contains additional
+            task-specific data beyond basic detection.
+        """
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape)  # remove extra if exists
+        return preds
+
+    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
+
+        Args:
+            si (int): Batch index.
+            batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
+
+        Returns:
+            (dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
+
+        Notes:
+            This method extends the parent class's _prepare_batch method by adding keypoint processing.
+            Keypoints are scaled from normalized coordinates to original image dimensions.
+        """
+        pbatch = super()._prepare_batch(si, batch)
+        kpts = batch["keypoints"][batch["batch_idx"] == si]
+        h, w = pbatch["imgsz"]
+        kpts = kpts.clone()
+        kpts[..., 0] *= w
+        kpts[..., 1] *= h
+        pbatch["keypoints"] = kpts
+        return pbatch
+
+    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
+        """Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground
+        truth.
+
+        Args:
+            preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
+                and 'keypoints' for keypoint predictions.
+            batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels, 'bboxes'
+                for bounding boxes, and 'keypoints' for keypoint annotations.
+
+        Returns:
+            (dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose true
+                positives across 10 IoU levels.
+
+        Notes:
+            `0.53` scale factor used in area computation is referenced from
+            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
+        """
+        tp = super()._process_batch(preds, batch)
+        gt_cls = batch["cls"]
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
+        else:
+            # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
+            area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
+            tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
+        tp.update({"tp_p": tp_p})  # update tp with kpts IoU
+        return tp
+
+    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
+        """Save YOLO pose detections to a text file in normalized coordinates.
+
+        Args:
+            predn (dict[str, torch.Tensor]): Prediction dict with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
+            save_conf (bool): Whether to save confidence scores.
+            shape (tuple[int, int]): Shape of the original image (height, width).
+            file (Path): Output file path to save detections.
+
+        Notes:
+            The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
+            normalized (x, y, visibility) values for each point.
+        """
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
+            keypoints=predn["keypoints"],
+        ).save_txt(file, save_conf=save_conf)
+
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Convert YOLO predictions to COCO JSON format.
+
+        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format to COCO
+        format, and appends the results to the internal JSON dictionary (self.jdict).
+
+        Args:
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls', and 'keypoints'
+                tensors.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
+
+        Notes:
+            The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
+            converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
+            before saving to the JSON dictionary.
+        """
+        super().pred_to_json(predn, pbatch)
+        kpts = predn["kpts"]
+        for i, k in enumerate(kpts.flatten(1, 2).tolist()):
+            self.jdict[-len(kpts) + i]["keypoints"] = k  # keypoints
+
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return {
+            **super().scale_preds(predn, pbatch),
+            "kpts": ops.scale_coords(
+                pbatch["imgsz"],
+                predn["keypoints"].clone(),
+                pbatch["ori_shape"],
+                ratio_pad=pbatch["ratio_pad"],
+            ),
+        }
+
+    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
+        """Evaluate object detection model using COCO JSON format."""
+        anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
+        pred_json = self.save_dir / "predictions.json"  # predictions
+        return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
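Taken from the PoseValidator docstring above, a standalone validation run would presumably look like the following sketch (coco8-pose.yaml is the small sample dataset shipped in ultralytics/cfg/datasets/):

from ultralytics.models.yolo.pose import PoseValidator

args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
validator = PoseValidator(args=args)  # sets task="pose" and PoseMetrics internally
validator()  # computes box mAP plus OKS-based pose mAP ('tp_p' across 10 IoU levels)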
ultralytics/models/yolo/segment/__init__.py
@@ -0,0 +1,7 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from .predict import SegmentationPredictor
+from .train import SegmentationTrainer
+from .val import SegmentationValidator
+
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
ultralytics/models/yolo/segment/predict.py
@@ -0,0 +1,109 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from ultralytics.engine.results import Results
+from ultralytics.models.yolo.detect.predict import DetectionPredictor
+from ultralytics.utils import DEFAULT_CFG, ops
+
+
+class SegmentationPredictor(DetectionPredictor):
+    """A class extending the DetectionPredictor class for prediction based on a segmentation model.
+
+    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
+    prediction results.
+
+    Attributes:
+        args (dict): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO segmentation model.
+        batch (list): Current batch of images being processed.
+
+    Methods:
+        postprocess: Apply non-max suppression and process segmentation detections.
+        construct_results: Construct a list of result objects from predictions.
+        construct_result: Construct a single result object from a prediction.
+
+    Examples:
+        >>> from ultralytics.utils import ASSETS
+        >>> from ultralytics.models.yolo.segment import SegmentationPredictor
+        >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
+        >>> predictor = SegmentationPredictor(overrides=args)
+        >>> predictor.predict_cli()
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+
+        This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
+        prediction results.
+
+        Args:
+            cfg (dict): Configuration for the predictor.
+            overrides (dict, optional): Configuration overrides that take precedence over cfg.
+            _callbacks (list, optional): List of callback functions to be invoked during prediction.
+        """
+        super().__init__(cfg, overrides, _callbacks)
+        self.args.task = "segment"
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Apply non-max suppression and process segmentation detections for each image in the input batch.
+
+        Args:
+            preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
+            img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
+            orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
+
+        Returns:
+            (list): List of Results objects containing the segmentation predictions for each image in the batch. Each
+                Results object includes both bounding boxes and segmentation masks.
+
+        Examples:
+            >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
+            >>> results = predictor.postprocess(preds, img, orig_img)
+        """
+        # Extract protos - tuple if PyTorch model or array if exported
+        protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
+        return super().postprocess(preds[0], img, orig_imgs, protos=protos)
+
+    def construct_results(self, preds, img, orig_imgs, protos):
+        """Construct a list of result objects from the predictions.
+
+        Args:
+            preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
+            img (torch.Tensor): The image after preprocessing.
+            orig_imgs (list[np.ndarray]): List of original images before preprocessing.
+            protos (list[torch.Tensor]): List of prototype masks.
+
+        Returns:
+            (list[Results]): List of result objects containing the original images, image paths, class names, bounding
+                boxes, and masks.
+        """
+        return [
+            self.construct_result(pred, img, orig_img, img_path, proto)
+            for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
+        ]
+
+    def construct_result(self, pred, img, orig_img, img_path, proto):
+        """Construct a single result object from the prediction.
+
+        Args:
+            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
+            img (torch.Tensor): The image after preprocessing.
+            orig_img (np.ndarray): The original image before preprocessing.
+            img_path (str): The path to the original image.
+            proto (torch.Tensor): The prototype masks.
+
+        Returns:
+            (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
+        """
+        if pred.shape[0] == 0:  # save empty boxes
+            masks = None
+        elif self.args.retina_masks:
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # NHW
+        else:
+            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # NHW
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+        if masks is not None:
+            keep = masks.amax((-2, -1)) > 0  # only keep predictions with masks
+            if not all(keep):  # most predictions have masks
+                pred, masks = pred[keep], masks[keep]  # indexing is slow
+        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
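The SegmentationPredictor docstring gives the intended standalone usage; a sketch under those assumptions (ASSETS resolves to the bundled sample images listed above):

from ultralytics.utils import ASSETS
from ultralytics.models.yolo.segment import SegmentationPredictor

args = dict(model="yolo11n-seg.pt", source=ASSETS)
predictor = SegmentationPredictor(overrides=args)  # forces task="segment"
predictor.predict_cli()  # streams Results with boxes and masks for each image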
ultralytics/models/yolo/segment/train.py
@@ -0,0 +1,69 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from copy import copy
+from pathlib import Path
+
+from ultralytics.models import yolo
+from ultralytics.nn.tasks import SegmentationModel
+from ultralytics.utils import DEFAULT_CFG, RANK
+
+
+class SegmentationTrainer(yolo.detect.DetectionTrainer):
+    """A class extending the DetectionTrainer class for training based on a segmentation model.
+
+    This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
+    functionality including model initialization, validation, and visualization.
+
+    Attributes:
+        loss_names (tuple[str]): Names of the loss components used during training.
+
+    Examples:
+        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
+        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
+        >>> trainer = SegmentationTrainer(overrides=args)
+        >>> trainer.train()
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
+        """Initialize a SegmentationTrainer object.
+
+        Args:
+            cfg (dict): Configuration dictionary with default training settings.
+            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list, optional): List of callback functions to be executed during training.
+        """
+        if overrides is None:
+            overrides = {}
+        overrides["task"] = "segment"
+        super().__init__(cfg, overrides, _callbacks)
+
+    def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
+        """Initialize and return a SegmentationModel with specified configuration and weights.
+
+        Args:
+            cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
+            weights (str | Path, optional): Path to pretrained weights file.
+            verbose (bool): Whether to display model information during initialization.
+
+        Returns:
+            (SegmentationModel): Initialized segmentation model with loaded weights if specified.
+
+        Examples:
+            >>> trainer = SegmentationTrainer()
+            >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
+            >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
+        """
+        model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
+        if weights:
+            model.load(weights)
+
+        return model
+
+    def get_validator(self):
+        """Return an instance of SegmentationValidator for validation of YOLO model."""
+        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
+        return yolo.segment.SegmentationValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
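And a short training sketch matching the SegmentationTrainer docstring (coco8-seg.yaml is the bundled 8-image sample dataset, so epochs=3 keeps the run small):

from ultralytics.models.yolo.segment import SegmentationTrainer

args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
trainer = SegmentationTrainer(overrides=args)  # injects task="segment" into overrides
trainer.train()  # trains with box/seg/cls/dfl losses, validates via SegmentationValidator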