dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/obb/val.py

```diff
@@ -1,18 +1,21 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from pathlib import Path
+from typing import Any
 
+import numpy as np
 import torch
 
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.metrics import OBBMetrics, batch_probiou
-from ultralytics.utils.
+from ultralytics.utils.nms import TorchNMS
 
 
 class OBBValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
+    """A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
 
     This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
     satellite imagery where objects can appear at various orientations.
@@ -39,64 +42,78 @@ class OBBValidator(DetectionValidator):
         >>> validator(model=args["model"])
     """
 
-    def __init__(self, dataloader=None, save_dir=None,
-        """
-        Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
+        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
 
-        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
-
+        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models. It
+        extends the DetectionValidator class and configures it specifically for the OBB task.
 
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (str | Path, optional): Directory to save results.
-
-            args (dict, optional): Arguments containing validation parameters.
+            args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
             _callbacks (list, optional): List of callback functions to be called during validation.
         """
-        super().__init__(dataloader, save_dir,
+        super().__init__(dataloader, save_dir, args, _callbacks)
         self.args.task = "obb"
-        self.metrics = OBBMetrics(
+        self.metrics = OBBMetrics()
 
-    def init_metrics(self, model):
-        """Initialize evaluation metrics for YOLO.
+    def init_metrics(self, model: torch.nn.Module) -> None:
+        """Initialize evaluation metrics for YOLO obb validation.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
         self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
+        self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'
 
-    def _process_batch(self,
-        """
-        Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
+    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
+        """Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
 
         Args:
-
-
-
-
-            gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
+            preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
+                class labels and bounding boxes.
+            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth class
+                labels and bounding boxes.
 
         Returns:
-            (
-
+            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy array
+                with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy of
+                predictions compared to the ground truth.
 
         Examples:
            >>> detections = torch.rand(100, 7)  # 100 sample detections
            >>> gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
            >>> gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
-            >>> correct_matrix =
-
-        Note:
-            This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
+            >>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
-
-
+        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
+            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
+        iou = batch_probiou(batch["bboxes"], preds["bboxes"])
+        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
+
+    def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
+        """Postprocess OBB predictions.
 
-
+        Args:
+            preds (torch.Tensor): Raw predictions from the model.
+
+        Returns:
+            (list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
         """
-
+        preds = super().postprocess(preds)
+        for pred in preds:
+            pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1)  # concatenate angle
+        return preds
+
+    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
+        """Prepare batch data for OBB validation with proper scaling and formatting.
 
         Args:
             si (int): Batch index to process.
-            batch (dict): Dictionary containing batch data with keys:
+            batch (dict[str, Any]): Dictionary containing batch data with keys:
                 - batch_idx: Tensor of batch indices
                 - cls: Tensor of class labels
                 - bboxes: Tensor of bounding boxes
```
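To make the new dict-based `_process_batch` above easier to follow, here is a minimal standalone sketch of the true-positive matrix it returns: an IoU matrix between ground truth and predictions (shape M×N, produced by `batch_probiou` in the real code) is thresholded at the 10 IoU levels used for mAP (0.50 to 0.95), with class-mismatched pairs zeroed out first. The greedy per-ground-truth matching below is a simplification of `DetectionValidator.match_predictions`, and the random IoU matrix is a stand-in, so treat this as illustrative only.

```python
import numpy as np
import torch


def tp_matrix(pred_cls: torch.Tensor, gt_cls: torch.Tensor, iou: torch.Tensor) -> np.ndarray:
    """Simplified TP matrix: iou is (M_gt, N_pred), result is (N_pred, 10) bool."""
    thresholds = torch.linspace(0.5, 0.95, 10)  # the 10 IoU levels used by the mAP metric
    iou = iou * (gt_cls[:, None] == pred_cls[None, :])  # zero out class-mismatched pairs
    tp = np.zeros((pred_cls.shape[0], 10), dtype=bool)
    for j, t in enumerate(thresholds):
        for gi in range(iou.shape[0]):  # greedy: each gt validates its best-overlapping prediction
            pi = int(iou[gi].argmax())
            if iou[gi, pi] >= t:
                tp[pi, j] = True
    return tp


torch.manual_seed(0)
iou = torch.rand(5, 8)  # stand-in for batch_probiou(batch["bboxes"], preds["bboxes"])
print(tp_matrix(torch.randint(0, 3, (8,)), torch.randint(0, 3, (5,)), iou).shape)  # (8, 10)
```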
```diff
@@ -104,8 +121,8 @@ class OBBValidator(DetectionValidator):
                 - img: Batch of images
                 - ratio_pad: Ratio and padding information
 
-
-
+        Returns:
+            (dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
         """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
@@ -113,41 +130,23 @@ class OBBValidator(DetectionValidator):
         ori_shape = batch["ori_shape"][si]
         imgsz = batch["img"].shape[2:]
         ratio_pad = batch["ratio_pad"][si]
-        if
+        if cls.shape[0]:
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
-
-
-
-
-
-
-
-
-
+        return {
+            "cls": cls,
+            "bboxes": bbox,
+            "ori_shape": ori_shape,
+            "imgsz": imgsz,
+            "ratio_pad": ratio_pad,
+            "im_file": batch["im_file"][si],
+        }
+
+    def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
+        """Plot predicted bounding boxes on input images and save the result.
 
         Args:
-
-
-                - imgsz (tuple): Model input image size.
-                - ori_shape (tuple): Original image shape.
-                - ratio_pad (tuple): Ratio and padding information for scaling.
-
-        Returns:
-            (torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
-        """
-        predn = pred.clone()
-        ops.scale_boxes(
-            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
-        )  # native-space pred
-        return predn
-
-    def plot_predictions(self, batch, preds, ni):
-        """
-        Plot predicted bounding boxes on input images and save the result.
-
-        Args:
-            batch (dict): Batch data containing images, file paths, and other metadata.
-            preds (list): List of prediction tensors for each image in the batch.
+            batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
+            preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
             ni (int): Batch index used for naming the output file.
 
         Examples:
```
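The `[[1, 0, 1, 0]]` indexing in `_prepare_batch` above denormalizes the (x, y, w, h) part of each rotated box from the 0-1 range into input-image pixels: `imgsz` arrives as (height, width), so the index reorders it to match the box layout. A small sketch with made-up values:

```python
import torch

imgsz = (640, 480)  # (height, width), as taken from batch["img"].shape[2:]
bbox = torch.tensor([[0.5, 0.5, 0.2, 0.1, 0.3]])  # normalized (x, y, w, h, angle)

# [1, 0, 1, 0] reorders (h, w) into (w, h, w, h) to match the (x, y, w, h) layout
bbox[..., :4].mul_(torch.tensor(imgsz)[[1, 0, 1, 0]])
print(bbox)  # tensor([[240., 320., 96., 64., 0.3000]]); the angle column is untouched
```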
```diff
@@ -156,54 +155,50 @@ class OBBValidator(DetectionValidator):
            >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
            >>> validator.plot_predictions(batch, preds, 0)
         """
-
-
-
-
-            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
-            names=self.names,
-            on_plot=self.on_plot,
-        )  # pred
+        for p in preds:
+            # TODO: fix this duplicated `xywh2xyxy`
+            p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4])  # convert to xyxy format for plotting
+        super().plot_predictions(batch, preds, ni)  # plot bboxes
 
-    def pred_to_json(self, predn,
-        """
-        Convert YOLO predictions to COCO JSON format with rotated bounding box information.
+    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information.
 
         Args:
-            predn (torch.Tensor): Prediction
-
-
+            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys with
+                bounding box coordinates, confidence scores, and class predictions.
+            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
 
         Notes:
             This method processes rotated bounding box predictions and converts them to both rbox format
             (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
             to the JSON dictionary.
         """
-
+        path = Path(pbatch["im_file"])
+        stem = path.stem
         image_id = int(stem) if stem.isnumeric() else stem
-        rbox =
+        rbox = predn["bboxes"]
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
-        for
+        for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
             self.jdict.append(
                 {
                     "image_id": image_id,
-                    "
-                    "
+                    "file_name": path.name,
+                    "category_id": self.class_map[int(c)],
+                    "score": round(s, 5),
                     "rbox": [round(x, 3) for x in r],
                     "poly": [round(x, 3) for x in b],
                 }
             )
 
-    def save_one_txt(self, predn, save_conf, shape, file):
-        """
-        Save YOLO OBB (Oriented Bounding Box) detections to a text file in normalized coordinates.
+    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
+        """Save YOLO OBB detections to a text file in normalized coordinates.
 
         Args:
             predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
                 class predictions, and angles in format (x, y, w, h, conf, cls, angle).
             save_conf (bool): Whether to save confidence scores in the text file.
-            shape (tuple): Original image shape in format (height, width).
-            file (Path
+            shape (tuple[int, int]): Original image shape in format (height, width).
+            file (Path): Output file path to save detections.
 
         Examples:
            >>> validator = OBBValidator()
@@ -214,18 +209,31 @@ class OBBValidator(DetectionValidator):
 
         from ultralytics.engine.results import Results
 
-        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
-        # xywh, r, conf, cls
-        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
         Results(
             np.zeros((shape[0], shape[1]), dtype=np.uint8),
             path=None,
             names=self.names,
-            obb=
+            obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
         ).save_txt(file, save_conf=save_conf)
 
-    def
-        """
+    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
+        """Scales predictions to the original image size."""
+        return {
+            **predn,
+            "bboxes": ops.scale_boxes(
+                pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+            ),
+        }
+
+    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format.
+
+        Args:
+            stats (dict[str, Any]): Performance statistics dictionary.
+
+        Returns:
+            (dict[str, Any]): Updated performance statistics.
+        """
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
             import re
@@ -252,7 +260,7 @@ class OBBValidator(DetectionValidator):
             merged_results = defaultdict(list)
             LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
             for d in data:
-                image_id = d["image_id"].split("__")[0]
+                image_id = d["image_id"].split("__", 1)[0]
                 pattern = re.compile(r"\d+___\d+")
                 x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
                 bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
@@ -268,7 +276,7 @@ class OBBValidator(DetectionValidator):
                 b = bbox[:, :5].clone()
                 b[:, :2] += c
                 # 0.3 could get results close to the ones from official merging script, even slightly better.
-                i =
+                i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
                 bbox = bbox[i]
 
                 b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
```
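For context on the two one-line changes above (`split("__", 1)` and `TorchNMS.fast_nms`): DOTA evaluation slices large images into overlapping patches whose stems encode the patch size and offset, following the `stem__size__x___y` convention used by `ultralytics/data/split_dota.py`. Merging therefore recovers the original stem, shifts each patch-local box center by its (x, y) offset, and then suppresses duplicates with rotated NMS (`TorchNMS.fast_nms(..., iou_func=batch_probiou)` in the new code). A rough standalone sketch of the parsing and shifting, with a hypothetical patch name:

```python
import re

import torch

image_id = "P0006__1024__0___512"  # hypothetical DOTA patch id: stem__size__x___y
stem = image_id.split("__", 1)[0]  # "P0006"; maxsplit=1 stops after the first separator
x, y = (int(c) for c in re.findall(r"\d+___\d+", image_id)[0].split("___"))  # (0, 512)

# Patch-local rotated boxes (x_center, y_center, w, h, angle): shift centers into full-image coords
boxes = torch.tensor([[100.0, 200.0, 50.0, 30.0, 0.1]])
boxes[:, :2] += torch.tensor([float(x), float(y)])
print(stem, (x, y), boxes[0, :2].tolist())  # P0006 (0, 512) [100.0, 712.0]
```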
ultralytics/models/yolo/pose/predict.py

```diff
@@ -5,8 +5,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
 
 
 class PosePredictor(DetectionPredictor):
-    """
-    A class extending the DetectionPredictor class for prediction based on a pose model.
+    """A class extending the DetectionPredictor class for prediction based on a pose model.
 
     This class specializes in pose estimation, handling keypoints detection alongside standard object detection
     capabilities inherited from DetectionPredictor.
@@ -16,7 +15,7 @@ class PosePredictor(DetectionPredictor):
         model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
 
     Methods:
-        construct_result:
+        construct_result: Construct the result object from the prediction, including keypoints.
 
     Examples:
         >>> from ultralytics.utils import ASSETS
@@ -27,14 +26,13 @@ class PosePredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize PosePredictor, a specialized predictor for pose estimation tasks.
+        """Initialize PosePredictor for pose estimation tasks.
 
-
-
+        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific warnings
+        for Apple MPS.
 
         Args:
-            cfg (Any): Configuration for the predictor.
+            cfg (Any): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides that take precedence over cfg.
             _callbacks (list, optional): List of callback functions to be invoked during prediction.
 
@@ -54,11 +52,10 @@ class PosePredictor(DetectionPredictor):
         )
 
     def construct_result(self, pred, img, orig_img, img_path):
-        """
-        Construct the result object from the prediction, including keypoints.
+        """Construct the result object from the prediction, including keypoints.
 
-
-
+        Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
+        result object.
 
         Args:
             pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
@@ -68,11 +65,12 @@ class PosePredictor(DetectionPredictor):
             img_path (str): The path to the original image file.
 
         Returns:
-            (Results): The result object containing the original image, image path, class names, bounding boxes, and
+            (Results): The result object containing the original image, image path, class names, bounding boxes, and
+                keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
         # Extract keypoints from prediction and reshape according to model's keypoint shape
-        pred_kpts = pred[:, 6:].view(
+        pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
         # Scale keypoints coordinates to match the original image dimensions
         pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
```
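The completed `view` call above is the interesting part of `construct_result`: the flat keypoint block that follows the six box/score/class columns is reshaped to (num_detections, K, D) using the model's `kpt_shape`. A shape-only sketch with made-up numbers (COCO-style pose uses kpt_shape (17, 3): 17 keypoints with x, y, visibility):

```python
import torch

kpt_shape = (17, 3)  # (K keypoints, D dims), as in self.model.kpt_shape
pred = torch.rand(4, 6 + kpt_shape[0] * kpt_shape[1])  # 4 detections: 6 box/conf/cls cols + K*D keypoint values

pred_kpts = pred[:, 6:].view(pred.shape[0], *kpt_shape)
print(pred_kpts.shape)  # torch.Size([4, 17, 3]), ready for ops.scale_coords
```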
ultralytics/models/yolo/pose/train.py

```diff
@@ -1,16 +1,18 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from copy import copy
+from pathlib import Path
+from typing import Any
 
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import PoseModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER
-from ultralytics.utils.plotting import plot_images, plot_results
 
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training YOLO pose estimation models.
+    """A class extending the DetectionTrainer class for training YOLO pose estimation models.
 
     This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
     of pose keypoints alongside bounding boxes.
@@ -19,14 +21,14 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         args (dict): Configuration arguments for training.
         model (PoseModel): The pose estimation model being trained.
         data (dict): Dataset configuration including keypoint shape information.
-        loss_names (
+        loss_names (tuple): Names of the loss components used in training.
 
     Methods:
-        get_model:
-        set_model_attributes:
-        get_validator:
-        plot_training_samples:
-
+        get_model: Retrieve a pose estimation model with specified configuration.
+        set_model_attributes: Set keypoints shape attribute on the model.
+        get_validator: Create a validator instance for model evaluation.
+        plot_training_samples: Visualize training samples with keypoints.
+        get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
 
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseTrainer
@@ -35,12 +37,8 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
        >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a PoseTrainer object for training YOLO pose estimation models.
-
-        This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
-        handling specific configurations needed for keypoint detection models.
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
+        """Initialize a PoseTrainer object for training YOLO pose estimation models.
 
         Args:
             cfg (dict, optional): Default configuration dictionary containing training parameters.
@@ -50,12 +48,6 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         Notes:
             This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
             A warning is issued when using Apple MPS device due to known bugs with pose models.
-
-        Examples:
-            >>> from ultralytics.models.yolo.pose import PoseTrainer
-            >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
-            >>> trainer = PoseTrainer(overrides=args)
-            >>> trainer.train()
         """
         if overrides is None:
             overrides = {}
@@ -68,13 +60,17 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
                 "See https://github.com/ultralytics/ultralytics/issues/4031."
             )
 
-    def get_model(
-
-
+    def get_model(
+        self,
+        cfg: str | Path | dict[str, Any] | None = None,
+        weights: str | Path | None = None,
+        verbose: bool = True,
+    ) -> PoseModel:
+        """Get pose estimation model with specified configuration and weights.
 
         Args:
-            cfg (str | Path | dict
-            weights (str | Path
+            cfg (str | Path | dict, optional): Model configuration file path or dictionary.
+            weights (str | Path, optional): Path to the model weights file.
             verbose (bool): Whether to display model information.
 
         Returns:
@@ -89,58 +85,24 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         return model
 
     def set_model_attributes(self):
-        """
+        """Set keypoints shape attribute of PoseModel."""
         super().set_model_attributes()
         self.model.kpt_shape = self.data["kpt_shape"]
+        kpt_names = self.data.get("kpt_names")
+        if not kpt_names:
+            names = list(map(str, range(self.model.kpt_shape[0])))
+            kpt_names = {i: names for i in range(self.model.nc)}
+        self.model.kpt_names = kpt_names
 
     def get_validator(self):
-        """
+        """Return an instance of the PoseValidator class for validation."""
         self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
         return yolo.pose.PoseValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
 
-    def
-        """
-        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
-
-        Args:
-            batch (dict): Dictionary containing batch data with the following keys:
-                - img (torch.Tensor): Batch of images
-                - keypoints (torch.Tensor): Keypoints coordinates for pose estimation
-                - cls (torch.Tensor): Class labels
-                - bboxes (torch.Tensor): Bounding box coordinates
-                - im_file (list): List of image file paths
-                - batch_idx (torch.Tensor): Batch indices for each instance
-            ni (int): Current training iteration number used for filename
-
-        The function saves the plotted batch as an image in the trainer's save directory with the filename
-        'train_batch{ni}.jpg', where ni is the iteration number.
-        """
-        images = batch["img"]
-        kpts = batch["keypoints"]
-        cls = batch["cls"].squeeze(-1)
-        bboxes = batch["bboxes"]
-        paths = batch["im_file"]
-        batch_idx = batch["batch_idx"]
-        plot_images(
-            images,
-            batch_idx,
-            cls,
-            bboxes,
-            kpts=kpts,
-            paths=paths,
-            fname=self.save_dir / f"train_batch{ni}.jpg",
-            on_plot=self.on_plot,
-        )
-
-    def plot_metrics(self):
-        """Plots training/val metrics."""
-        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
-
-    def get_dataset(self):
-        """
-        Retrieves the dataset and ensures it contains the required `kpt_shape` key.
+    def get_dataset(self) -> dict[str, Any]:
+        """Retrieve the dataset and ensure it contains the required `kpt_shape` key.
 
         Returns:
             (dict): A dictionary containing the training/validation/test dataset and category names.
```