dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,394 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import torch
|
7
|
+
|
8
|
+
from ultralytics.models.yolo.detect import DetectionValidator
|
9
|
+
from ultralytics.utils import LOGGER, ops
|
10
|
+
from ultralytics.utils.checks import check_requirements
|
11
|
+
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
|
12
|
+
from ultralytics.utils.plotting import output_to_target, plot_images
|
13
|
+
|
14
|
+
|
15
|
+
class PoseValidator(DetectionValidator):
|
16
|
+
"""
|
17
|
+
A class extending the DetectionValidator class for validation based on a pose model.
|
18
|
+
|
19
|
+
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
20
|
+
specialized metrics for pose evaluation.
|
21
|
+
|
22
|
+
Attributes:
|
23
|
+
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
24
|
+
kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
|
25
|
+
args (dict): Arguments for the validator including task set to "pose".
|
26
|
+
metrics (PoseMetrics): Metrics object for pose evaluation.
|
27
|
+
|
28
|
+
Methods:
|
29
|
+
preprocess: Preprocesses batch data for pose validation.
|
30
|
+
get_desc: Returns description of evaluation metrics.
|
31
|
+
init_metrics: Initializes pose metrics for the model.
|
32
|
+
_prepare_batch: Prepares a batch for processing.
|
33
|
+
_prepare_pred: Prepares and scales predictions for evaluation.
|
34
|
+
update_metrics: Updates metrics with new predictions.
|
35
|
+
_process_batch: Processes batch to compute IoU between detections and ground truth.
|
36
|
+
plot_val_samples: Plots validation samples with ground truth annotations.
|
37
|
+
plot_predictions: Plots model predictions.
|
38
|
+
save_one_txt: Saves detections to a text file.
|
39
|
+
pred_to_json: Converts predictions to COCO JSON format.
|
40
|
+
eval_json: Evaluates model using COCO JSON format.
|
41
|
+
|
42
|
+
Examples:
|
43
|
+
>>> from ultralytics.models.yolo.pose import PoseValidator
|
44
|
+
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
45
|
+
>>> validator = PoseValidator(args=args)
|
46
|
+
>>> validator()
|
47
|
+
"""
|
48
|
+
|
49
|
+
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """
    Set up a validator for pose-estimation models.

    Common setup is delegated to DetectionValidator. Afterwards the task is forced to "pose", a PoseMetrics
    object is installed, and the keypoint-related attributes are left unset.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader supplying the validation batches.
        save_dir (Path | str, optional): Directory where results are written.
        pbar (Any, optional): Progress bar used to display progress.
        args (dict, optional): Validator arguments; the task is overwritten with "pose".
        _callbacks (list, optional): Callback functions executed during validation.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseValidator
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
        >>> validator = PoseValidator(args=args)
        >>> validator()

    Notes:
        A warning is logged when the configured device is the string "mps" (compared case-insensitively),
        pointing to the known Apple MPS issue with pose models.
    """
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    self.args.task = "pose"
    self.metrics = PoseMetrics(save_dir=self.save_dir)

    # Placeholders; init_metrics assigns the real values from the dataset configuration.
    self.kpt_shape = None
    self.sigma = None

    device = self.args.device
    runs_on_mps = isinstance(device, str) and device.lower() == "mps"
    if runs_on_mps:
        LOGGER.warning(
            "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
            "See https://github.com/ultralytics/ultralytics/issues/4031."
        )
|
84
|
+
|
85
|
+
def preprocess(self, batch):
    """Run the parent preprocessing, then move the batch keypoints to the validator device as floats."""
    prepared = super().preprocess(batch)
    keypoints = prepared["keypoints"]
    prepared["keypoints"] = keypoints.to(self.device).float()
    return prepared
|
90
|
+
|
91
|
+
def get_desc(self):
    """Return the header row naming the box and pose metric columns as one right-aligned string."""
    columns = (
        "Class",
        "Images",
        "Instances",
        "Box(P",
        "R",
        "mAP50",
        "mAP50-95)",
        "Pose(P",
        "R",
        "mAP50",
        "mAP50-95)",
    )
    # The first column is 22 characters wide, every following column 11.
    first, *rest = columns
    return "%22s" % first + "".join("%11s" % name for name in rest)
|
106
|
+
|
107
|
+
def init_metrics(self, model):
    """Initialize pose metrics: the keypoint shape, the OKS sigmas and the empty statistics buffers."""
    super().init_metrics(model)
    shape = self.data["kpt_shape"]
    self.kpt_shape = shape
    num_kpts = shape[0]
    if shape == [17, 3]:
        # The 17x3 layout uses the predefined OKS sigmas.
        self.sigma = OKS_SIGMA
    else:
        # Any other layout falls back to uniform sigmas of 1 / num_kpts.
        self.sigma = np.ones(num_kpts) / num_kpts
    self.stats = {key: [] for key in ("tp_p", "tp", "conf", "pred_cls", "target_cls", "target_img")}
|
115
|
+
|
116
|
+
def _prepare_batch(self, si, batch):
    """
    Prepare the ground truth of one image, adding its keypoints in original-image coordinates.

    Args:
        si (int): Index of the image inside the batch.
        batch (dict): Batch data; the keys 'keypoints' and 'batch_idx' are read here.

    Returns:
        pbatch (dict): The parent's prepared batch, extended with a 'kpts' entry.

    Notes:
        The keypoints belonging to image `si` are copied, their x and y values are multiplied by the width and
        height taken from `pbatch["imgsz"]`, and the result is passed through `ops.scale_coords` with the
        image's original shape and ratio/padding information.
    """
    pbatch = super()._prepare_batch(si, batch)
    belongs_to_image = batch["batch_idx"] == si
    height, width = pbatch["imgsz"]
    scaled = batch["keypoints"][belongs_to_image].clone()  # work on a copy, the batch tensor stays untouched
    scaled[..., 0] *= width
    scaled[..., 1] *= height
    pbatch["kpts"] = ops.scale_coords(pbatch["imgsz"], scaled, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
    return pbatch
|
140
|
+
|
141
|
+
def _prepare_pred(self, pred, pbatch):
    """
    Prepare the predictions of one image and expose their keypoints in original-image coordinates.

    The parent `_prepare_pred` produces the prediction tensor. Columns 6 onwards of that tensor are then
    viewed as one row of keypoints per prediction and handed to `ops.scale_coords`.

    Args:
        pred (torch.Tensor): Prediction tensor for a single image.
        pbatch (dict): Prepared batch for the same image; the keys 'kpts', 'imgsz', 'ori_shape' and
            'ratio_pad' are read here.

    Returns:
        predn (torch.Tensor): Predictions as returned by the parent `_prepare_pred`.
        pred_kpts (torch.Tensor): View of `predn[:, 6:]` with shape (len(predn), nk, -1), where nk is the
            second dimension of `pbatch["kpts"]`.
    """
    predn = super()._prepare_pred(pred, pbatch)
    num_preds = len(predn)
    num_kpts = pbatch["kpts"].shape[1]
    pred_kpts = predn[:, 6:].view(num_preds, num_kpts, -1)
    # The return value is discarded: presumably scale_coords rescales its argument in place (confirm in ops).
    ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
    return predn, pred_kpts
|
164
|
+
|
165
|
+
def update_metrics(self, preds, batch):
    """
    Accumulate validation statistics for every image of a batch.

    For each image the ground truth is prepared, predictions are matched against it on boxes and on keypoints,
    and the resulting per-image statistics are appended to `self.stats`. Depending on the validator arguments
    the confusion matrix is updated and predictions are exported to JSON and/or text files.

    Args:
        preds (List[torch.Tensor]): One prediction tensor per image of the batch.
        batch (dict): Batch data with the images' ground truth annotations and file names.
    """
    for si, pred in enumerate(preds):
        self.seen += 1
        num_pred = len(pred)
        pbatch = self._prepare_batch(si, batch)
        cls = pbatch.pop("cls")
        bbox = pbatch.pop("bbox")
        num_labels = len(cls)
        # Defaults describe "nothing matched"; they are overwritten below when predictions exist.
        stat = {
            "conf": torch.zeros(0, device=self.device),
            "pred_cls": torch.zeros(0, device=self.device),
            "tp": torch.zeros(num_pred, self.niou, dtype=torch.bool, device=self.device),
            "tp_p": torch.zeros(num_pred, self.niou, dtype=torch.bool, device=self.device),
            "target_cls": cls,
            "target_img": cls.unique(),
        }

        if num_pred == 0:
            # Without predictions only images that carry labels contribute to the statistics.
            if num_labels:
                for key, bucket in self.stats.items():
                    bucket.append(stat[key])
                if self.args.plots:
                    self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
            continue

        # Predictions
        if self.args.single_cls:
            pred[:, 5] = 0
        predn, pred_kpts = self._prepare_pred(pred, pbatch)
        stat["conf"] = predn[:, 4]
        stat["pred_cls"] = predn[:, 5]

        # Evaluate: once on boxes ("tp") and once on keypoints ("tp_p")
        if num_labels:
            stat["tp"] = self._process_batch(predn, bbox, cls)
            stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
        if self.args.plots:
            self.confusion_matrix.process_batch(predn, bbox, cls)

        for key, bucket in self.stats.items():
            bucket.append(stat[key])

        # Save
        image_file = batch["im_file"][si]
        if self.args.save_json:
            self.pred_to_json(predn, image_file)
        if self.args.save_txt:
            label_file = self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt"
            self.save_one_txt(predn, pred_kpts, self.args.save_conf, pbatch["ori_shape"], label_file)
|
226
|
+
|
227
|
+
def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
    """
    Build the correct-prediction matrix from box IoU or, when keypoints are given, from keypoint similarity.

    Args:
        detections (torch.Tensor): Detections; columns 0-3 are used as boxes and column 5 as the class.
        gt_bboxes (torch.Tensor): Ground truth boxes in xyxy format.
        gt_cls (torch.Tensor): Ground truth class indices.
        pred_kpts (torch.Tensor | None): Predicted keypoints. Keypoint matching is used only when both
            `pred_kpts` and `gt_kpts` are provided.
        gt_kpts (torch.Tensor | None): Ground truth keypoints.

    Returns:
        (torch.Tensor): The result of `self.match_predictions` for the computed similarity matrix.

    Notes:
        `0.53` scale factor used in area computation is referenced from
        https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
    """
    keypoint_mode = pred_kpts is not None and gt_kpts is not None
    if not keypoint_mode:  # boxes
        iou = box_iou(gt_bboxes, detections[:, :4])
    else:
        # Object area = width * height of the ground truth box, scaled by 0.53 (see Notes).
        box_wh = ops.xyxy2xywh(gt_bboxes)[:, 2:]
        area = box_wh.prod(1) * 0.53
        iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)

    pred_classes = detections[:, 5]
    return self.match_predictions(pred_classes, gt_cls, iou)
|
257
|
+
|
258
|
+
def plot_val_samples(self, batch, ni):
    """
    Plot the ground truth boxes and keypoints of a validation batch and save the image.

    Args:
        batch (dict): Batch data; the keys 'img', 'batch_idx', 'cls', 'bboxes', 'keypoints' and 'im_file'
            are read here.
        ni (int): Batch index, used in the output file name `val_batch{ni}_labels.jpg`.
    """
    out_file = self.save_dir / f"val_batch{ni}_labels.jpg"
    class_ids = batch["cls"].squeeze(-1)  # drop the trailing singleton dimension
    plot_images(
        batch["img"],
        batch["batch_idx"],
        class_ids,
        batch["bboxes"],
        kpts=batch["keypoints"],
        paths=batch["im_file"],
        fname=out_file,
        names=self.names,
        on_plot=self.on_plot,
    )
|
283
|
+
|
284
|
+
def plot_predictions(self, batch, preds, ni):
    """
    Plot the predicted boxes and keypoints of a batch and save the image.

    Args:
        batch (dict): Batch data; the keys 'img' and 'im_file' are read here.
        preds (List[torch.Tensor]): One prediction tensor per image; columns 6 onwards hold the flattened
            keypoints.
        ni (int): Batch index, used in the output file name `val_batch{ni}_pred.jpg`.
    """
    # Reshape each image's flattened keypoints to (-1, *kpt_shape) and stack them across the batch.
    per_image_kpts = []
    for p in preds:
        per_image_kpts.append(p[:, 6:].view(-1, *self.kpt_shape))
    pred_kpts = torch.cat(per_image_kpts, 0)

    targets = output_to_target(preds, max_det=self.args.max_det)
    out_file = self.save_dir / f"val_batch{ni}_pred.jpg"
    plot_images(
        batch["img"],
        *targets,
        kpts=pred_kpts,
        paths=batch["im_file"],
        fname=out_file,
        names=self.names,
        on_plot=self.on_plot,
    )  # pred
|
307
|
+
|
308
|
+
def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
    """
    Write the pose detections of one image to a text file via a Results object.

    Args:
        predn (torch.Tensor): Predictions; the first six columns are passed on as boxes.
        pred_kpts (torch.Tensor): Predicted keypoints belonging to `predn`.
        save_conf (bool): Forwarded to `Results.save_txt` to control whether confidences are written.
        shape (tuple): Original image shape; `shape[0]` and `shape[1]` size the placeholder image.
        file (Path): Output file path.
    """
    from ultralytics.engine.results import Results

    # Results needs an image argument; an all-zero uint8 array of the original size stands in for it.
    placeholder = np.zeros((shape[0], shape[1]), dtype=np.uint8)
    result = Results(
        placeholder,
        path=None,
        names=self.names,
        boxes=predn[:, :6],
        keypoints=pred_kpts,
    )
    result.save_txt(file, save_conf=save_conf)
|
333
|
+
|
334
|
+
def pred_to_json(self, predn, filename):
    """
    Convert the predictions of one image to COCO JSON records and append them to `self.jdict`.

    Args:
        predn (torch.Tensor): Prediction tensor; columns 0-3 are xyxy boxes, column 4 the confidence,
            column 5 the class index and columns 6 onwards the flattened keypoints.
        filename (str | Path): Path of the image the predictions belong to.

    Notes:
        The image id is the file name stem, converted to int when it consists only of decimal digits and kept
        as a string otherwise. Boxes are converted from xyxy to xywh and shifted from the box center to the
        top-left corner before being stored.
    """
    stem = Path(filename).stem
    # isdecimal() matches what int() can parse; isnumeric() also accepts characters such as "½" or "²",
    # for which int() raises ValueError.
    image_id = int(stem) if stem.isdecimal() else stem
    box = ops.xyxy2xywh(predn[:, :4])  # xywh
    box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
    for p, b in zip(predn.tolist(), box.tolist()):
        self.jdict.append(
            {
                "image_id": image_id,
                "category_id": self.class_map[int(p[5])],
                "bbox": [round(x, 3) for x in b],
                "keypoints": p[6:],
                "score": round(p[4], 5),
            }
        )
|
366
|
+
|
367
|
+
def eval_json(self, stats):
    """
    Evaluate the saved COCO JSON predictions with pycocotools and update the mAP entries in `stats`.

    The evaluation runs only when JSON saving is enabled, the dataset is COCO and predictions were collected.
    Any failure during the evaluation is logged as a warning and `stats` is returned as it stands.

    Args:
        stats (dict): Statistics dictionary to update with the pycocotools results.

    Returns:
        (dict): The same `stats` object; for both "bbox" and "keypoints" evaluation the mAP50-95 and mAP50
            entries are overwritten when pycocotools ran successfully.
    """
    if not (self.args.save_json and self.is_coco and len(self.jdict)):
        return stats

    anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
    pred_json = self.save_dir / "predictions.json"  # predictions
    LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
    try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
        check_requirements("pycocotools>=2.0.6")
        from pycocotools.coco import COCO  # noqa
        from pycocotools.cocoeval import COCOeval  # noqa

        for x in anno_json, pred_json:
            # Explicit raise instead of assert so the check also runs under `python -O`.
            if not x.is_file():
                raise FileNotFoundError(f"{x} file not found")
        anno = COCO(str(anno_json))  # init annotations api
        pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
        evaluators = [COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]
        for i, coco_eval in enumerate(evaluators):
            # is_coco is already guaranteed by the guard at the top of this method.
            coco_eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
            idx = i * 4 + 2  # offset of this evaluator's mAP50 key within self.metrics.keys
            # update mAP50-95 and mAP50
            stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = coco_eval.stats[:2]
    except Exception as e:
        # Best-effort evaluation: a missing package or file must not abort validation.
        LOGGER.warning(f"pycocotools unable to run: {e}")
    return stats
|
@@ -0,0 +1,7 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from .predict import SegmentationPredictor
|
4
|
+
from .train import SegmentationTrainer
|
5
|
+
from .val import SegmentationValidator
|
6
|
+
|
7
|
+
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
|
@@ -0,0 +1,113 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from ultralytics.engine.results import Results
|
4
|
+
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
5
|
+
from ultralytics.utils import DEFAULT_CFG, ops
|
6
|
+
|
7
|
+
|
8
|
+
class SegmentationPredictor(DetectionPredictor):
    """
    Predictor for YOLO segmentation models, built on top of DetectionPredictor.

    On top of the detection pipeline inherited from the parent class, this predictor turns the model's prototype
    masks and per-detection mask coefficients into instance masks that are attached to every Results object.

    Attributes:
        args (dict): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO segmentation model.
        batch (list): Current batch of images being processed.

    Methods:
        postprocess: Extract prototypes and delegate non-max suppression to the parent class.
        construct_results: Build one Results object per image in the batch.
        construct_result: Build the Results object for a single image.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.segment import SegmentationPredictor
        >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
        >>> predictor = SegmentationPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Set up the predictor and force its task to "segment".

        Args:
            cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """
        Pull the mask prototypes out of the raw predictions and run the parent post-processing.

        Args:
            preds (tuple): Model predictions; `preds[0]` is forwarded to the parent post-processing and `preds[1]`
                carries the mask prototypes.
            img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
            orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.

        Returns:
            (list): List of Results objects containing the segmentation predictions for each image in the batch.
                Each Results object includes both bounding boxes and segmentation masks.

        Examples:
            >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
            >>> results = predictor.postprocess(preds, img, orig_img)
        """
        proto_output = preds[1]
        if isinstance(proto_output, tuple):
            # PyTorch model: prototypes are the last element of the tuple
            protos = proto_output[-1]
        else:
            # Exported model: the second output already is the prototype array
            protos = proto_output
        return super().postprocess(preds[0], img, orig_imgs, protos=protos)

    def construct_results(self, preds, img, orig_imgs, protos):
        """
        Build the Results object of every image in the batch.

        Args:
            preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
            img (torch.Tensor): The image after preprocessing.
            orig_imgs (List[np.ndarray]): List of original images before preprocessing.
            protos (List[torch.Tensor]): List of prototype masks.

        Returns:
            (List[Results]): List of result objects containing the original images, image paths, class names,
                bounding boxes, and masks.
        """
        results = []
        # self.batch[0] supplies the image path that pairs with each prediction
        for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos):
            results.append(self.construct_result(pred, img, orig_img, img_path, proto))
        return results

    def construct_result(self, pred, img, orig_img, img_path, proto):
        """
        Build the Results object for a single image.

        Boxes in `pred[:, :4]` are rescaled in place to the original image size, masks are generated from the
        prototypes and the coefficients in `pred[:, 6:]`, and detections whose mask is entirely empty are dropped.

        Args:
            pred (torch.Tensor): Detections for one image; columns 0-3 are the box, columns 0-5 are passed to
                Results as boxes and the remaining columns are the mask coefficients.
            img (torch.Tensor): The image after preprocessing.
            orig_img (np.ndarray): The original image before preprocessing.
            img_path (str): The path to the original image.
            proto (torch.Tensor): The prototype masks.

        Returns:
            (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
        """
        masks = None  # stays None when there is nothing to segment, so empty boxes are still saved
        if len(pred):
            model_hw = img.shape[2:]
            if self.args.retina_masks:
                # Native-resolution masks: rescale boxes first, then build masks at the original image size
                pred[:, :4] = ops.scale_boxes(model_hw, pred[:, :4], orig_img.shape)
                masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
            else:
                # Masks are built at model input size with the unscaled boxes, boxes are rescaled afterwards
                masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], model_hw, upsample=True)  # HWC
                pred[:, :4] = ops.scale_boxes(model_hw, pred[:, :4], orig_img.shape)
        if masks is not None:
            non_empty = masks.sum((-2, -1)) > 0  # only keep predictions with masks
            pred = pred[non_empty]
            masks = masks[non_empty]
        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
from copy import copy
|
4
|
+
|
5
|
+
from ultralytics.models import yolo
|
6
|
+
from ultralytics.nn.tasks import SegmentationModel
|
7
|
+
from ultralytics.utils import DEFAULT_CFG, RANK
|
8
|
+
from ultralytics.utils.plotting import plot_images, plot_results
|
9
|
+
|
10
|
+
|
11
|
+
class SegmentationTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a segmentation model.

    This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
    functionality including model initialization, validation, and visualization.

    Attributes:
        loss_names (Tuple[str]): Names of the loss components used during training; assigned in `get_validator`.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
        >>> trainer = SegmentationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a SegmentationTrainer object.

        The task is always forced to 'segment'. The caller's `overrides` dictionary is copied before the task is
        set, so the object passed in is left untouched.

        Args:
            cfg (dict): Configuration dictionary with default training settings. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.

        Examples:
            >>> from ultralytics.models.yolo.segment import SegmentationTrainer
            >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
            >>> trainer = SegmentationTrainer(overrides=args)
            >>> trainer.train()
        """
        # Work on a shallow copy: writing "task" into the caller's dict would leak state out of the constructor
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return a SegmentationModel with specified configuration and weights.

        The number of classes and input channels are read from `self.data["nc"]` and `self.data["channels"]`.

        Args:
            cfg (dict | str | None): Model configuration. Can be a dictionary, a path to a YAML file, or None.
            weights (str | Path | None): Path to pretrained weights file.
            verbose (bool): Whether to display model information during initialization. Only honoured when
                RANK == -1.

        Returns:
            (SegmentationModel): Initialized segmentation model with loaded weights if specified.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
            >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
        """
        model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """
        Return an instance of SegmentationValidator for validation of YOLO model.

        As a side effect this also sets `self.loss_names` to the four segmentation loss names.

        Returns:
            (SegmentationValidator): Validator built from the test loader, save directory, a copy of the trainer
                arguments and the trainer callbacks.
        """
        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """
        Plot training sample images with labels, bounding boxes, and masks.

        This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
        and segmentation masks, saving the result to `train_batch{ni}.jpg` inside the save directory.

        Args:
            batch (dict): Dictionary containing batch data with the following keys:
                'img': Images tensor
                'batch_idx': Batch indices for each box
                'cls': Class labels tensor (squeezed to remove last dimension)
                'bboxes': Bounding box coordinates tensor
                'masks': Segmentation masks tensor
                'im_file': List of image file paths
            ni (int): Current training iteration number, used for naming the output file.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> batch = {
            ...     "img": torch.rand(16, 3, 640, 640),
            ...     "batch_idx": torch.zeros(16),
            ...     "cls": torch.randint(0, 80, (16, 1)),
            ...     "bboxes": torch.rand(16, 4),
            ...     "masks": torch.rand(16, 640, 640),
            ...     "im_file": ["image1.jpg", "image2.jpg"],
            ... }
            >>> trainer.plot_training_samples(batch, ni=5)
        """
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plot training/val metrics from the results CSV with segmentation plots enabled."""
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
|