ultralytics 8.3.101__py3-none-any.whl → 8.3.103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_exports.py +14 -5
- tests/test_solutions.py +140 -76
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +1 -1
- ultralytics/engine/exporter.py +23 -8
- ultralytics/engine/tuner.py +8 -2
- ultralytics/hub/__init__.py +29 -2
- ultralytics/hub/google/__init__.py +18 -1
- ultralytics/models/fastsam/predict.py +12 -1
- ultralytics/models/nas/predict.py +21 -3
- ultralytics/models/rtdetr/val.py +26 -2
- ultralytics/models/sam/amg.py +22 -1
- ultralytics/models/sam/modules/encoders.py +85 -4
- ultralytics/models/sam/modules/memory_attention.py +61 -3
- ultralytics/models/sam/modules/utils.py +108 -5
- ultralytics/models/utils/loss.py +38 -2
- ultralytics/models/utils/ops.py +15 -1
- ultralytics/models/yolo/classify/predict.py +11 -1
- ultralytics/models/yolo/classify/train.py +17 -1
- ultralytics/models/yolo/classify/val.py +82 -6
- ultralytics/models/yolo/detect/predict.py +20 -1
- ultralytics/models/yolo/model.py +55 -4
- ultralytics/models/yolo/obb/predict.py +16 -1
- ultralytics/models/yolo/obb/train.py +35 -2
- ultralytics/models/yolo/obb/val.py +87 -6
- ultralytics/models/yolo/pose/predict.py +18 -1
- ultralytics/models/yolo/pose/train.py +48 -3
- ultralytics/models/yolo/pose/val.py +113 -8
- ultralytics/models/yolo/segment/predict.py +27 -2
- ultralytics/models/yolo/segment/train.py +61 -3
- ultralytics/models/yolo/segment/val.py +10 -1
- ultralytics/models/yolo/world/train_world.py +29 -1
- ultralytics/models/yolo/yoloe/train.py +47 -3
- ultralytics/nn/autobackend.py +9 -8
- ultralytics/nn/modules/activation.py +26 -3
- ultralytics/nn/modules/block.py +89 -0
- ultralytics/nn/modules/head.py +3 -92
- ultralytics/nn/modules/utils.py +70 -4
- ultralytics/nn/tasks.py +3 -0
- ultralytics/nn/text_model.py +93 -17
- ultralytics/solutions/instance_segmentation.py +15 -7
- ultralytics/solutions/solutions.py +2 -47
- ultralytics/utils/benchmarks.py +1 -1
- ultralytics/utils/callbacks/base.py +22 -5
- ultralytics/utils/callbacks/comet.py +93 -5
- ultralytics/utils/callbacks/dvc.py +64 -5
- ultralytics/utils/callbacks/neptune.py +25 -2
- ultralytics/utils/callbacks/tensorboard.py +30 -2
- ultralytics/utils/callbacks/wb.py +16 -1
- ultralytics/utils/dist.py +35 -2
- ultralytics/utils/errors.py +27 -6
- ultralytics/utils/metrics.py +1 -1
- ultralytics/utils/patches.py +33 -5
- ultralytics/utils/torch_utils.py +14 -6
- ultralytics/utils/triton.py +16 -3
- ultralytics/utils/tuner.py +17 -9
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/METADATA +3 -4
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/RECORD +62 -62
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
CHANGED
@@ -22,7 +22,24 @@ class YOLO(Model):
|
|
22
22
|
"""YOLO (You Only Look Once) object detection model."""
|
23
23
|
|
24
24
|
def __init__(self, model="yolo11n.pt", task=None, verbose=False):
|
25
|
-
"""
|
25
|
+
"""
|
26
|
+
Initialize a YOLO model.
|
27
|
+
|
28
|
+
This constructor initializes a YOLO model, automatically switching to specialized model types
|
29
|
+
(YOLOWorld or YOLOE) based on the model filename.
|
30
|
+
|
31
|
+
Args:
|
32
|
+
model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolov8n.yaml'.
|
33
|
+
task (str | None): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
|
34
|
+
Defaults to auto-detection based on model.
|
35
|
+
verbose (bool): Display model info on load.
|
36
|
+
|
37
|
+
Examples:
|
38
|
+
>>> from ultralytics import YOLO
|
39
|
+
>>> model = YOLO("yolov8n.pt") # load a pretrained YOLOv8n detection model
|
40
|
+
>>> model = YOLO("yolov8n-seg.pt") # load a pretrained YOLOv8n segmentation model
|
41
|
+
>>> model = YOLO("yolo11n.pt") # load a pretrained YOLOv11n detection model
|
42
|
+
"""
|
26
43
|
path = Path(model)
|
27
44
|
if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
|
28
45
|
new_instance = YOLOWorld(path, verbose=verbose)
|
@@ -166,12 +183,46 @@ class YOLOE(Model):
|
|
166
183
|
return self.model.get_text_pe(texts)
|
167
184
|
|
168
185
|
def get_visual_pe(self, img, visual):
|
169
|
-
"""
|
186
|
+
"""
|
187
|
+
Get visual positional embeddings for the given image and visual features.
|
188
|
+
|
189
|
+
This method extracts positional embeddings from visual features based on the input image. It requires
|
190
|
+
that the model is an instance of YOLOEModel.
|
191
|
+
|
192
|
+
Args:
|
193
|
+
img (torch.Tensor): Input image tensor.
|
194
|
+
visual (torch.Tensor): Visual features extracted from the image.
|
195
|
+
|
196
|
+
Returns:
|
197
|
+
(torch.Tensor): Visual positional embeddings.
|
198
|
+
|
199
|
+
Examples:
|
200
|
+
>>> model = YOLOE("yoloe-v8s.pt")
|
201
|
+
>>> img = torch.rand(1, 3, 640, 640)
|
202
|
+
>>> visual_features = model.model.backbone(img)
|
203
|
+
>>> pe = model.get_visual_pe(img, visual_features)
|
204
|
+
"""
|
170
205
|
assert isinstance(self.model, YOLOEModel)
|
171
206
|
return self.model.get_visual_pe(img, visual)
|
172
207
|
|
173
208
|
def set_vocab(self, vocab, names):
|
174
|
-
"""
|
209
|
+
"""
|
210
|
+
Set vocabulary and class names for the YOLOE model.
|
211
|
+
|
212
|
+
This method configures the vocabulary and class names used by the model for text processing and
|
213
|
+
classification tasks. The model must be an instance of YOLOEModel.
|
214
|
+
|
215
|
+
Args:
|
216
|
+
vocab (list): Vocabulary list containing tokens or words used by the model for text processing.
|
217
|
+
names (list): List of class names that the model can detect or classify.
|
218
|
+
|
219
|
+
Raises:
|
220
|
+
AssertionError: If the model is not an instance of YOLOEModel.
|
221
|
+
|
222
|
+
Examples:
|
223
|
+
>>> model = YOLOE("yoloe-v8s.pt")
|
224
|
+
>>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
|
225
|
+
"""
|
175
226
|
assert isinstance(self.model, YOLOEModel)
|
176
227
|
self.model.set_vocab(vocab, names=names)
|
177
228
|
|
@@ -290,7 +341,7 @@ class YOLOE(Model):
|
|
290
341
|
|
291
342
|
self.predictor.setup_model(model=self.model)
|
292
343
|
|
293
|
-
if refer_image is None:
|
344
|
+
if refer_image is None and source:
|
294
345
|
dataset = load_inference_source(source)
|
295
346
|
if dataset.mode in {"video", "stream"}:
|
296
347
|
# NOTE: set the first frame as refer image for videos/streams inference
|
@@ -27,7 +27,22 @@ class OBBPredictor(DetectionPredictor):
|
|
27
27
|
"""
|
28
28
|
|
29
29
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
30
|
-
"""
|
30
|
+
"""
|
31
|
+
Initialize OBBPredictor with optional model and data configuration overrides.
|
32
|
+
|
33
|
+
This constructor sets up an OBBPredictor instance for oriented bounding box detection tasks.
|
34
|
+
|
35
|
+
Args:
|
36
|
+
cfg (dict, optional): Default configuration for the predictor.
|
37
|
+
overrides (dict, optional): Configuration overrides that take precedence over the default config.
|
38
|
+
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
39
|
+
|
40
|
+
Examples:
|
41
|
+
>>> from ultralytics.utils import ASSETS
|
42
|
+
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
43
|
+
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
|
44
|
+
>>> predictor = OBBPredictor(overrides=args)
|
45
|
+
"""
|
31
46
|
super().__init__(cfg, overrides, _callbacks)
|
32
47
|
self.args.task = "obb"
|
33
48
|
|
@@ -26,14 +26,47 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
|
|
26
26
|
"""
|
27
27
|
|
28
28
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
29
|
-
"""
|
29
|
+
"""
|
30
|
+
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
31
|
+
|
32
|
+
This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
|
33
|
+
bounding boxes. It automatically sets the task to 'obb' in the configuration.
|
34
|
+
|
35
|
+
Args:
|
36
|
+
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
|
37
|
+
model configuration.
|
38
|
+
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
|
39
|
+
will take precedence over those in cfg.
|
40
|
+
_callbacks (list, optional): List of callback functions to be invoked during training.
|
41
|
+
|
42
|
+
Examples:
|
43
|
+
>>> from ultralytics.models.yolo.obb import OBBTrainer
|
44
|
+
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
|
45
|
+
>>> trainer = OBBTrainer(overrides=args)
|
46
|
+
>>> trainer.train()
|
47
|
+
"""
|
30
48
|
if overrides is None:
|
31
49
|
overrides = {}
|
32
50
|
overrides["task"] = "obb"
|
33
51
|
super().__init__(cfg, overrides, _callbacks)
|
34
52
|
|
35
53
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
36
|
-
"""
|
54
|
+
"""
|
55
|
+
Return OBBModel initialized with specified config and weights.
|
56
|
+
|
57
|
+
Args:
|
58
|
+
cfg (str | dict | None): Model configuration. Can be a path to a YAML config file, a dictionary
|
59
|
+
containing configuration parameters, or None to use default configuration.
|
60
|
+
weights (str | Path | None): Path to pretrained weights file. If None, random initialization is used.
|
61
|
+
verbose (bool): Whether to display model information during initialization.
|
62
|
+
|
63
|
+
Returns:
|
64
|
+
(OBBModel): Initialized OBBModel with the specified configuration and weights.
|
65
|
+
|
66
|
+
Examples:
|
67
|
+
>>> trainer = OBBTrainer()
|
68
|
+
>>> model = trainer.get_model(cfg="yolov8n-obb.yaml", weights="yolov8n-obb.pt")
|
69
|
+
"""
|
37
70
|
model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
38
71
|
if weights:
|
39
72
|
model.load(weights)
|
@@ -40,7 +40,19 @@ class OBBValidator(DetectionValidator):
|
|
40
40
|
"""
|
41
41
|
|
42
42
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
43
|
-
"""
|
43
|
+
"""
|
44
|
+
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
45
|
+
|
46
|
+
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
|
47
|
+
It extends the DetectionValidator class and configures it specifically for the OBB task.
|
48
|
+
|
49
|
+
Args:
|
50
|
+
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
51
|
+
save_dir (str | Path, optional): Directory to save results.
|
52
|
+
pbar (bool, optional): Display progress bar during validation.
|
53
|
+
args (dict, optional): Arguments containing validation parameters.
|
54
|
+
_callbacks (list, optional): List of callback functions to be called during validation.
|
55
|
+
"""
|
44
56
|
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
45
57
|
self.args.task = "obb"
|
46
58
|
self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
|
@@ -79,7 +91,22 @@ class OBBValidator(DetectionValidator):
|
|
79
91
|
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
80
92
|
|
81
93
|
def _prepare_batch(self, si, batch):
|
82
|
-
"""
|
94
|
+
"""
|
95
|
+
Prepare batch data for OBB validation with proper scaling and formatting.
|
96
|
+
|
97
|
+
Args:
|
98
|
+
si (int): Batch index to process.
|
99
|
+
batch (dict): Dictionary containing batch data with keys:
|
100
|
+
- batch_idx: Tensor of batch indices
|
101
|
+
- cls: Tensor of class labels
|
102
|
+
- bboxes: Tensor of bounding boxes
|
103
|
+
- ori_shape: Original image shapes
|
104
|
+
- img: Batch of images
|
105
|
+
- ratio_pad: Ratio and padding information
|
106
|
+
|
107
|
+
This method filters the batch data for a specific batch index, extracts class labels and bounding boxes,
|
108
|
+
and scales the bounding boxes to the original image dimensions.
|
109
|
+
"""
|
83
110
|
idx = batch["batch_idx"] == si
|
84
111
|
cls = batch["cls"][idx].squeeze(-1)
|
85
112
|
bbox = batch["bboxes"][idx]
|
@@ -92,7 +119,22 @@ class OBBValidator(DetectionValidator):
|
|
92
119
|
return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
|
93
120
|
|
94
121
|
def _prepare_pred(self, pred, pbatch):
|
95
|
-
"""
|
122
|
+
"""
|
123
|
+
Prepare predictions by scaling bounding boxes to original image dimensions.
|
124
|
+
|
125
|
+
This method takes prediction tensors containing bounding box coordinates and scales them from the model's
|
126
|
+
input dimensions to the original image dimensions using the provided batch information.
|
127
|
+
|
128
|
+
Args:
|
129
|
+
pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
|
130
|
+
pbatch (dict): Dictionary containing batch information with keys:
|
131
|
+
- imgsz (tuple): Model input image size.
|
132
|
+
- ori_shape (tuple): Original image shape.
|
133
|
+
- ratio_pad (tuple): Ratio and padding information for scaling.
|
134
|
+
|
135
|
+
Returns:
|
136
|
+
(torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
|
137
|
+
"""
|
96
138
|
predn = pred.clone()
|
97
139
|
ops.scale_boxes(
|
98
140
|
pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
|
@@ -100,7 +142,20 @@ class OBBValidator(DetectionValidator):
|
|
100
142
|
return predn
|
101
143
|
|
102
144
|
def plot_predictions(self, batch, preds, ni):
|
103
|
-
"""
|
145
|
+
"""
|
146
|
+
Plot predicted bounding boxes on input images and save the result.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
batch (dict): Batch data containing images, file paths, and other metadata.
|
150
|
+
preds (list): List of prediction tensors for each image in the batch.
|
151
|
+
ni (int): Batch index used for naming the output file.
|
152
|
+
|
153
|
+
Examples:
|
154
|
+
>>> validator = OBBValidator()
|
155
|
+
>>> batch = {"img": images, "im_file": paths}
|
156
|
+
>>> preds = [torch.rand(10, 7)] # Example predictions for one image
|
157
|
+
>>> validator.plot_predictions(batch, preds, 0)
|
158
|
+
"""
|
104
159
|
plot_images(
|
105
160
|
batch["img"],
|
106
161
|
*output_to_rotated_target(preds, max_det=self.args.max_det),
|
@@ -111,7 +166,19 @@ class OBBValidator(DetectionValidator):
|
|
111
166
|
) # pred
|
112
167
|
|
113
168
|
def pred_to_json(self, predn, filename):
|
114
|
-
"""
|
169
|
+
"""
|
170
|
+
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
171
|
+
|
172
|
+
Args:
|
173
|
+
predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores,
|
174
|
+
class predictions, and rotation angles with shape (N, 6+) where the last column is the angle.
|
175
|
+
filename (str | Path): Path to the image file for which predictions are being processed.
|
176
|
+
|
177
|
+
Notes:
|
178
|
+
This method processes rotated bounding box predictions and converts them to both rbox format
|
179
|
+
(x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
|
180
|
+
to the JSON dictionary.
|
181
|
+
"""
|
115
182
|
stem = Path(filename).stem
|
116
183
|
image_id = int(stem) if stem.isnumeric() else stem
|
117
184
|
rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
|
@@ -128,7 +195,21 @@ class OBBValidator(DetectionValidator):
|
|
128
195
|
)
|
129
196
|
|
130
197
|
def save_one_txt(self, predn, save_conf, shape, file):
|
131
|
-
"""
|
198
|
+
"""
|
199
|
+
Save YOLO OBB (Oriented Bounding Box) detections to a text file in normalized coordinates.
|
200
|
+
|
201
|
+
Args:
|
202
|
+
predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
|
203
|
+
class predictions, and angles in format (x, y, w, h, conf, cls, angle).
|
204
|
+
save_conf (bool): Whether to save confidence scores in the text file.
|
205
|
+
shape (tuple): Original image shape in format (height, width).
|
206
|
+
file (Path | str): Output file path to save detections.
|
207
|
+
|
208
|
+
Examples:
|
209
|
+
>>> validator = OBBValidator()
|
210
|
+
>>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]]) # One detection: x,y,w,h,conf,cls,angle
|
211
|
+
>>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
|
212
|
+
"""
|
132
213
|
import numpy as np
|
133
214
|
|
134
215
|
from ultralytics.engine.results import Results
|
@@ -27,7 +27,24 @@ class PosePredictor(DetectionPredictor):
|
|
27
27
|
"""
|
28
28
|
|
29
29
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
30
|
-
"""
|
30
|
+
"""
|
31
|
+
Initialize PosePredictor, a specialized predictor for pose estimation tasks.
|
32
|
+
|
33
|
+
This initializer sets up a PosePredictor instance, configuring it for pose detection tasks and handling
|
34
|
+
device-specific warnings for Apple MPS.
|
35
|
+
|
36
|
+
Args:
|
37
|
+
cfg (Any): Configuration for the predictor. Default is DEFAULT_CFG.
|
38
|
+
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
39
|
+
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
40
|
+
|
41
|
+
Examples:
|
42
|
+
>>> from ultralytics.utils import ASSETS
|
43
|
+
>>> from ultralytics.models.yolo.pose import PosePredictor
|
44
|
+
>>> args = dict(model="yolov8n-pose.pt", source=ASSETS)
|
45
|
+
>>> predictor = PosePredictor(overrides=args)
|
46
|
+
>>> predictor.predict_cli()
|
47
|
+
"""
|
31
48
|
super().__init__(cfg, overrides, _callbacks)
|
32
49
|
self.args.task = "pose"
|
33
50
|
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
@@ -36,7 +36,27 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
36
36
|
"""
|
37
37
|
|
38
38
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
39
|
-
"""
|
39
|
+
"""
|
40
|
+
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
41
|
+
|
42
|
+
This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
|
43
|
+
handling specific configurations needed for keypoint detection models.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
47
|
+
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
48
|
+
_callbacks (list, optional): List of callback functions to be executed during training.
|
49
|
+
|
50
|
+
Notes:
|
51
|
+
This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
|
52
|
+
A warning is issued when using Apple MPS device due to known bugs with pose models.
|
53
|
+
|
54
|
+
Examples:
|
55
|
+
>>> from ultralytics.models.yolo.pose import PoseTrainer
|
56
|
+
>>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
|
57
|
+
>>> trainer = PoseTrainer(overrides=args)
|
58
|
+
>>> trainer.train()
|
59
|
+
"""
|
40
60
|
if overrides is None:
|
41
61
|
overrides = {}
|
42
62
|
overrides["task"] = "pose"
|
@@ -49,7 +69,17 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
49
69
|
)
|
50
70
|
|
51
71
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
52
|
-
"""
|
72
|
+
"""
|
73
|
+
Get pose estimation model with specified configuration and weights.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
cfg (str | Path | dict | None): Model configuration file path or dictionary.
|
77
|
+
weights (str | Path | None): Path to the model weights file.
|
78
|
+
verbose (bool): Whether to display model information.
|
79
|
+
|
80
|
+
Returns:
|
81
|
+
(PoseModel): Initialized pose estimation model.
|
82
|
+
"""
|
53
83
|
model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
|
54
84
|
if weights:
|
55
85
|
model.load(weights)
|
@@ -69,7 +99,22 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
|
69
99
|
)
|
70
100
|
|
71
101
|
def plot_training_samples(self, batch, ni):
|
72
|
-
"""
|
102
|
+
"""
|
103
|
+
Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
|
104
|
+
|
105
|
+
Args:
|
106
|
+
batch (dict): Dictionary containing batch data with the following keys:
|
107
|
+
- img (torch.Tensor): Batch of images
|
108
|
+
- keypoints (torch.Tensor): Keypoints coordinates for pose estimation
|
109
|
+
- cls (torch.Tensor): Class labels
|
110
|
+
- bboxes (torch.Tensor): Bounding box coordinates
|
111
|
+
- im_file (list): List of image file paths
|
112
|
+
- batch_idx (torch.Tensor): Batch indices for each instance
|
113
|
+
ni (int): Current training iteration number used for filename
|
114
|
+
|
115
|
+
The function saves the plotted batch as an image in the trainer's save directory with the filename
|
116
|
+
'train_batch{ni}.jpg', where ni is the iteration number.
|
117
|
+
"""
|
73
118
|
images = batch["img"]
|
74
119
|
kpts = batch["keypoints"]
|
75
120
|
cls = batch["cls"].squeeze(-1)
|
@@ -20,7 +20,7 @@ class PoseValidator(DetectionValidator):
|
|
20
20
|
specialized metrics for pose evaluation.
|
21
21
|
|
22
22
|
Attributes:
|
23
|
-
sigma (np.ndarray): Sigma values for OKS calculation, either
|
23
|
+
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
24
24
|
kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
|
25
25
|
args (dict): Arguments for the validator including task set to "pose".
|
26
26
|
metrics (PoseMetrics): Metrics object for pose evaluation.
|
@@ -47,7 +47,30 @@ class PoseValidator(DetectionValidator):
|
|
47
47
|
"""
|
48
48
|
|
49
49
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
50
|
-
"""
|
50
|
+
"""
|
51
|
+
Initialize a PoseValidator object for pose estimation validation.
|
52
|
+
|
53
|
+
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
54
|
+
specialized metrics for pose evaluation.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
58
|
+
save_dir (Path | str, optional): Directory to save results.
|
59
|
+
pbar (Any, optional): Progress bar for displaying progress.
|
60
|
+
args (dict, optional): Arguments for the validator including task set to "pose".
|
61
|
+
_callbacks (list, optional): List of callback functions to be executed during validation.
|
62
|
+
|
63
|
+
Examples:
|
64
|
+
>>> from ultralytics.models.yolo.pose import PoseValidator
|
65
|
+
>>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
|
66
|
+
>>> validator = PoseValidator(args=args)
|
67
|
+
>>> validator()
|
68
|
+
|
69
|
+
Notes:
|
70
|
+
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
|
71
|
+
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
|
72
|
+
due to a known bug with pose models.
|
73
|
+
"""
|
51
74
|
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
52
75
|
self.sigma = None
|
53
76
|
self.kpt_shape = None
|
@@ -91,7 +114,20 @@ class PoseValidator(DetectionValidator):
|
|
91
114
|
self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
|
92
115
|
|
93
116
|
def _prepare_batch(self, si, batch):
|
94
|
-
"""
|
117
|
+
"""
|
118
|
+
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
si (int): Batch index.
|
122
|
+
batch (dict): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
|
123
|
+
|
124
|
+
Returns:
|
125
|
+
pbatch (dict): Prepared batch with keypoints scaled to original image dimensions.
|
126
|
+
|
127
|
+
Notes:
|
128
|
+
This method extends the parent class's _prepare_batch method by adding keypoint processing.
|
129
|
+
Keypoints are scaled from normalized coordinates to original image dimensions.
|
130
|
+
"""
|
95
131
|
pbatch = super()._prepare_batch(si, batch)
|
96
132
|
kpts = batch["keypoints"][batch["batch_idx"] == si]
|
97
133
|
h, w = pbatch["imgsz"]
|
@@ -103,7 +139,23 @@ class PoseValidator(DetectionValidator):
|
|
103
139
|
return pbatch
|
104
140
|
|
105
141
|
def _prepare_pred(self, pred, pbatch):
|
106
|
-
"""
|
142
|
+
"""
|
143
|
+
Prepare and scale keypoints in predictions for pose processing.
|
144
|
+
|
145
|
+
This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls
|
146
|
+
the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates
|
147
|
+
to match the original image dimensions.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
pred (torch.Tensor): Raw prediction tensor from the model.
|
151
|
+
pbatch (dict): Processed batch dictionary containing image information including:
|
152
|
+
- imgsz: Image size used for inference
|
153
|
+
- ori_shape: Original image shape
|
154
|
+
- ratio_pad: Ratio and padding information for coordinate scaling
|
155
|
+
|
156
|
+
Returns:
|
157
|
+
predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
|
158
|
+
"""
|
107
159
|
predn = super()._prepare_pred(pred, pbatch)
|
108
160
|
nk = pbatch["kpts"].shape[1]
|
109
161
|
pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
|
@@ -204,7 +256,19 @@ class PoseValidator(DetectionValidator):
|
|
204
256
|
return self.match_predictions(detections[:, 5], gt_cls, iou)
|
205
257
|
|
206
258
|
def plot_val_samples(self, batch, ni):
|
207
|
-
"""
|
259
|
+
"""
|
260
|
+
Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
261
|
+
|
262
|
+
Args:
|
263
|
+
batch (dict): Dictionary containing batch data with keys:
|
264
|
+
- img (torch.Tensor): Batch of images
|
265
|
+
- batch_idx (torch.Tensor): Batch indices for each image
|
266
|
+
- cls (torch.Tensor): Class labels
|
267
|
+
- bboxes (torch.Tensor): Bounding box coordinates
|
268
|
+
- keypoints (torch.Tensor): Keypoint coordinates
|
269
|
+
- im_file (list): List of image file paths
|
270
|
+
ni (int): Batch index used for naming the output file
|
271
|
+
"""
|
208
272
|
plot_images(
|
209
273
|
batch["img"],
|
210
274
|
batch["batch_idx"],
|
@@ -218,7 +282,18 @@ class PoseValidator(DetectionValidator):
|
|
218
282
|
)
|
219
283
|
|
220
284
|
def plot_predictions(self, batch, preds, ni):
|
221
|
-
"""
|
285
|
+
"""
|
286
|
+
Plot and save model predictions with bounding boxes and keypoints.
|
287
|
+
|
288
|
+
Args:
|
289
|
+
batch (dict): Dictionary containing batch data including images, file paths, and other metadata.
|
290
|
+
preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
|
291
|
+
confidence scores, class predictions, and keypoints.
|
292
|
+
ni (int): Batch index used for naming the output file.
|
293
|
+
|
294
|
+
The function extracts keypoints from predictions, converts predictions to target format, and plots them
|
295
|
+
on the input images. The resulting visualization is saved to the specified save directory.
|
296
|
+
"""
|
222
297
|
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
|
223
298
|
plot_images(
|
224
299
|
batch["img"],
|
@@ -231,7 +306,21 @@ class PoseValidator(DetectionValidator):
|
|
231
306
|
) # pred
|
232
307
|
|
233
308
|
def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
|
234
|
-
"""
|
309
|
+
"""
|
310
|
+
Save YOLO pose detections to a text file in normalized coordinates.
|
311
|
+
|
312
|
+
Args:
|
313
|
+
predn (torch.Tensor): Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
|
314
|
+
pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
|
315
|
+
and D is the dimension (typically 3 for x, y, visibility).
|
316
|
+
save_conf (bool): Whether to save confidence scores.
|
317
|
+
shape (tuple): Original image shape (height, width).
|
318
|
+
file (Path): Output file path to save detections.
|
319
|
+
|
320
|
+
Notes:
|
321
|
+
The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
|
322
|
+
normalized (x, y, visibility) values for each point.
|
323
|
+
"""
|
235
324
|
from ultralytics.engine.results import Results
|
236
325
|
|
237
326
|
Results(
|
@@ -243,7 +332,23 @@ class PoseValidator(DetectionValidator):
|
|
243
332
|
).save_txt(file, save_conf=save_conf)
|
244
333
|
|
245
334
|
def pred_to_json(self, predn, filename):
|
246
|
-
"""
|
335
|
+
"""
|
336
|
+
Convert YOLO predictions to COCO JSON format.
|
337
|
+
|
338
|
+
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
|
339
|
+
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
|
340
|
+
|
341
|
+
Args:
|
342
|
+
predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class IDs,
|
343
|
+
and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened
|
344
|
+
keypoints dimension.
|
345
|
+
filename (str | Path): Path to the image file for which predictions are being processed.
|
346
|
+
|
347
|
+
Notes:
|
348
|
+
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
|
349
|
+
converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
|
350
|
+
before saving to the JSON dictionary.
|
351
|
+
"""
|
247
352
|
stem = Path(filename).stem
|
248
353
|
image_id = int(stem) if stem.isnumeric() else stem
|
249
354
|
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
@@ -31,12 +31,37 @@ class SegmentationPredictor(DetectionPredictor):
|
|
31
31
|
"""
|
32
32
|
|
33
33
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
34
|
-
"""
|
34
|
+
"""
|
35
|
+
Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
|
36
|
+
|
37
|
+
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
|
38
|
+
prediction results.
|
39
|
+
|
40
|
+
Args:
|
41
|
+
cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
|
42
|
+
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
43
|
+
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
44
|
+
"""
|
35
45
|
super().__init__(cfg, overrides, _callbacks)
|
36
46
|
self.args.task = "segment"
|
37
47
|
|
38
48
|
def postprocess(self, preds, img, orig_imgs):
|
39
|
-
"""
|
49
|
+
"""
|
50
|
+
Apply non-max suppression and process segmentation detections for each image in the input batch.
|
51
|
+
|
52
|
+
Args:
|
53
|
+
preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
|
54
|
+
img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
|
55
|
+
orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
|
56
|
+
|
57
|
+
Returns:
|
58
|
+
(list): List of Results objects containing the segmentation predictions for each image in the batch.
|
59
|
+
Each Results object includes both bounding boxes and segmentation masks.
|
60
|
+
|
61
|
+
Examples:
|
62
|
+
>>> predictor = SegmentationPredictor(overrides=dict(model="yolov8n-seg.pt"))
|
63
|
+
>>> results = predictor.postprocess(preds, img, orig_img)
|
64
|
+
"""
|
40
65
|
# Extract protos - tuple if PyTorch model or array if exported
|
41
66
|
protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
|
42
67
|
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
|