ultralytics-8.3.100-py3-none-any.whl → ultralytics-8.3.102-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_solutions.py +140 -76
- ultralytics/__init__.py +1 -1
- ultralytics/engine/exporter.py +20 -5
- ultralytics/engine/model.py +1 -1
- ultralytics/engine/predictor.py +3 -1
- ultralytics/hub/__init__.py +29 -2
- ultralytics/hub/google/__init__.py +18 -1
- ultralytics/models/fastsam/predict.py +12 -1
- ultralytics/models/nas/predict.py +21 -3
- ultralytics/models/rtdetr/val.py +26 -2
- ultralytics/models/sam/amg.py +22 -1
- ultralytics/models/sam/modules/encoders.py +85 -4
- ultralytics/models/sam/modules/memory_attention.py +61 -3
- ultralytics/models/sam/modules/utils.py +108 -5
- ultralytics/models/utils/loss.py +38 -2
- ultralytics/models/utils/ops.py +15 -1
- ultralytics/models/yolo/classify/predict.py +11 -1
- ultralytics/models/yolo/classify/train.py +17 -1
- ultralytics/models/yolo/classify/val.py +82 -6
- ultralytics/models/yolo/detect/predict.py +20 -1
- ultralytics/models/yolo/model.py +69 -4
- ultralytics/models/yolo/obb/predict.py +16 -1
- ultralytics/models/yolo/obb/train.py +35 -2
- ultralytics/models/yolo/obb/val.py +87 -6
- ultralytics/models/yolo/pose/predict.py +18 -1
- ultralytics/models/yolo/pose/train.py +48 -3
- ultralytics/models/yolo/pose/val.py +113 -8
- ultralytics/models/yolo/segment/predict.py +27 -2
- ultralytics/models/yolo/segment/train.py +61 -3
- ultralytics/models/yolo/segment/val.py +10 -1
- ultralytics/models/yolo/world/train_world.py +29 -1
- ultralytics/models/yolo/yoloe/train.py +47 -3
- ultralytics/nn/modules/activation.py +26 -3
- ultralytics/nn/modules/block.py +89 -0
- ultralytics/nn/modules/head.py +3 -92
- ultralytics/nn/modules/utils.py +70 -4
- ultralytics/nn/tasks.py +2 -2
- ultralytics/nn/text_model.py +93 -17
- ultralytics/utils/benchmarks.py +1 -1
- ultralytics/utils/callbacks/base.py +22 -5
- ultralytics/utils/callbacks/comet.py +93 -5
- ultralytics/utils/callbacks/dvc.py +64 -5
- ultralytics/utils/callbacks/neptune.py +25 -2
- ultralytics/utils/callbacks/tensorboard.py +30 -2
- ultralytics/utils/callbacks/wb.py +16 -1
- ultralytics/utils/dist.py +35 -2
- ultralytics/utils/errors.py +27 -6
- ultralytics/utils/patches.py +33 -5
- ultralytics/utils/triton.py +16 -3
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/METADATA +1 -2
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/RECORD +55 -55
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.100.dist-info → ultralytics-8.3.102.dist-info}/top_level.txt +0 -0
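
Most of the churn below is documentation: one-line docstrings expanded into full Google-style docstrings with Args/Returns/Examples sections. The `+1 -1` change to `ultralytics/__init__.py` is presumably just the version-string bump, which gives a quick way to confirm which side of this diff is installed (a minimal sketch, assuming a standard pip install):

```python
# Check the installed package version; __version__ is defined in ultralytics/__init__.py.
import ultralytics

print(ultralytics.__version__)  # "8.3.100" before this diff, "8.3.102" after
```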
**ultralytics/models/yolo/obb/train.py**

```diff
@@ -26,14 +26,47 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
+
+        This trainer extends the DetectionTrainer class to specialize in training models that detect oriented
+        bounding boxes. It automatically sets the task to 'obb' in the configuration.
+
+        Args:
+            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
+                model configuration.
+            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
+                will take precedence over those in cfg.
+            _callbacks (list, optional): List of callback functions to be invoked during training.
+
+        Examples:
+            >>> from ultralytics.models.yolo.obb import OBBTrainer
+            >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
+            >>> trainer = OBBTrainer(overrides=args)
+            >>> trainer.train()
+        """
         if overrides is None:
             overrides = {}
         overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)
 
     def get_model(self, cfg=None, weights=None, verbose=True):
-        """
+        """
+        Return OBBModel initialized with specified config and weights.
+
+        Args:
+            cfg (str | dict | None): Model configuration. Can be a path to a YAML config file, a dictionary
+                containing configuration parameters, or None to use default configuration.
+            weights (str | Path | None): Path to pretrained weights file. If None, random initialization is used.
+            verbose (bool): Whether to display model information during initialization.
+
+        Returns:
+            (OBBModel): Initialized OBBModel with the specified configuration and weights.
+
+        Examples:
+            >>> trainer = OBBTrainer()
+            >>> model = trainer.get_model(cfg="yolov8n-obb.yaml", weights="yolov8n-obb.pt")
+        """
         model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
```
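
The Examples blocks added above are runnable as written. A standalone sketch combining them, assuming the `dota8.yaml` dataset and `yolo11n-obb.pt` weights resolve via the usual Ultralytics auto-download:

```python
# Built from the docstring Examples in the hunk above.
from ultralytics.models.yolo.obb import OBBTrainer

args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
trainer = OBBTrainer(overrides=args)  # __init__ forces task="obb" regardless of overrides
trainer.train()
```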
**ultralytics/models/yolo/obb/val.py**

```diff
@@ -40,7 +40,19 @@ class OBBValidator(DetectionValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
+        """
+        Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
+
+        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
+        It extends the DetectionValidator class and configures it specifically for the OBB task.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+            save_dir (str | Path, optional): Directory to save results.
+            pbar (bool, optional): Display progress bar during validation.
+            args (dict, optional): Arguments containing validation parameters.
+            _callbacks (list, optional): List of callback functions to be called during validation.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True)
@@ -79,7 +91,22 @@ class OBBValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def _prepare_batch(self, si, batch):
-        """
+        """
+        Prepare batch data for OBB validation with proper scaling and formatting.
+
+        Args:
+            si (int): Batch index to process.
+            batch (dict): Dictionary containing batch data with keys:
+                - batch_idx: Tensor of batch indices
+                - cls: Tensor of class labels
+                - bboxes: Tensor of bounding boxes
+                - ori_shape: Original image shapes
+                - img: Batch of images
+                - ratio_pad: Ratio and padding information
+
+        This method filters the batch data for a specific batch index, extracts class labels and bounding boxes,
+        and scales the bounding boxes to the original image dimensions.
+        """
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -92,7 +119,22 @@ class OBBValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """
+        Prepare predictions by scaling bounding boxes to original image dimensions.
+
+        This method takes prediction tensors containing bounding box coordinates and scales them from the model's
+        input dimensions to the original image dimensions using the provided batch information.
+
+        Args:
+            pred (torch.Tensor): Prediction tensor containing bounding box coordinates and other information.
+            pbatch (dict): Dictionary containing batch information with keys:
+                - imgsz (tuple): Model input image size.
+                - ori_shape (tuple): Original image shape.
+                - ratio_pad (tuple): Ratio and padding information for scaling.
+
+        Returns:
+            (torch.Tensor): Scaled prediction tensor with bounding boxes in original image dimensions.
+        """
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
@@ -100,7 +142,20 @@ class OBBValidator(DetectionValidator):
         return predn
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """
+        Plot predicted bounding boxes on input images and save the result.
+
+        Args:
+            batch (dict): Batch data containing images, file paths, and other metadata.
+            preds (list): List of prediction tensors for each image in the batch.
+            ni (int): Batch index used for naming the output file.
+
+        Examples:
+            >>> validator = OBBValidator()
+            >>> batch = {"img": images, "im_file": paths}
+            >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
+            >>> validator.plot_predictions(batch, preds, 0)
+        """
         plot_images(
             batch["img"],
             *output_to_rotated_target(preds, max_det=self.args.max_det),
@@ -111,7 +166,19 @@ class OBBValidator(DetectionValidator):
         )  # pred
 
     def pred_to_json(self, predn, filename):
-        """
+        """
+        Convert YOLO predictions to COCO JSON format with rotated bounding box information.
+
+        Args:
+            predn (torch.Tensor): Prediction tensor containing bounding box coordinates, confidence scores,
+                class predictions, and rotation angles with shape (N, 6+) where the last column is the angle.
+            filename (str | Path): Path to the image file for which predictions are being processed.
+
+        Notes:
+            This method processes rotated bounding box predictions and converts them to both rbox format
+            (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
+            to the JSON dictionary.
+        """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
@@ -128,7 +195,21 @@ class OBBValidator(DetectionValidator):
         )
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """
+        """
+        Save YOLO OBB (Oriented Bounding Box) detections to a text file in normalized coordinates.
+
+        Args:
+            predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
+                class predictions, and angles in format (x, y, w, h, conf, cls, angle).
+            save_conf (bool): Whether to save confidence scores in the text file.
+            shape (tuple): Original image shape in format (height, width).
+            file (Path | str): Output file path to save detections.
+
+        Examples:
+            >>> validator = OBBValidator()
+            >>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]])  # One detection: x,y,w,h,conf,cls,angle
+            >>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
+        """
         import numpy as np
 
         from ultralytics.engine.results import Results
```
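
The new `save_one_txt` docstring doubles as a usage sketch; here it is as a standalone snippet, with the (N, 7) row layout (x, y, w, h, conf, cls, angle) taken from the docstring above:

```python
# One synthetic OBB detection saved to a text file, per the docstring Examples.
import torch

from ultralytics.models.yolo.obb import OBBValidator

validator = OBBValidator()
predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]])  # x, y, w, h, conf, cls, angle
validator.save_one_txt(predn, True, (640, 480), "detection.txt")
```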
**ultralytics/models/yolo/pose/predict.py**

```diff
@@ -27,7 +27,24 @@ class PosePredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize PosePredictor, a specialized predictor for pose estimation tasks.
+
+        This initializer sets up a PosePredictor instance, configuring it for pose detection tasks and handling
+        device-specific warnings for Apple MPS.
+
+        Args:
+            cfg (Any): Configuration for the predictor. Default is DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides that take precedence over cfg.
+            _callbacks (list, optional): List of callback functions to be invoked during prediction.
+
+        Examples:
+            >>> from ultralytics.utils import ASSETS
+            >>> from ultralytics.models.yolo.pose import PosePredictor
+            >>> args = dict(model="yolov8n-pose.pt", source=ASSETS)
+            >>> predictor = PosePredictor(overrides=args)
+            >>> predictor.predict_cli()
+        """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "pose"
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
```
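
The added Examples block is itself a complete script; reproduced standalone (ASSETS ships with the package, and the weights auto-download on first use):

```python
# From the PosePredictor docstring Examples above.
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.pose import PosePredictor

args = dict(model="yolov8n-pose.pt", source=ASSETS)
predictor = PosePredictor(overrides=args)  # warns if device is Apple MPS
predictor.predict_cli()
```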
**ultralytics/models/yolo/pose/train.py**

```diff
@@ -36,7 +36,27 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize a PoseTrainer object for training YOLO pose estimation models.
+
+        This initializes a trainer specialized for pose estimation tasks, setting the task to 'pose' and
+        handling specific configurations needed for keypoint detection models.
+
+        Args:
+            cfg (dict, optional): Default configuration dictionary containing training parameters.
+            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list, optional): List of callback functions to be executed during training.
+
+        Notes:
+            This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
+            A warning is issued when using Apple MPS device due to known bugs with pose models.
+
+        Examples:
+            >>> from ultralytics.models.yolo.pose import PoseTrainer
+            >>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
+            >>> trainer = PoseTrainer(overrides=args)
+            >>> trainer.train()
+        """
         if overrides is None:
             overrides = {}
         overrides["task"] = "pose"
@@ -49,7 +69,17 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         )
 
     def get_model(self, cfg=None, weights=None, verbose=True):
-        """
+        """
+        Get pose estimation model with specified configuration and weights.
+
+        Args:
+            cfg (str | Path | dict | None): Model configuration file path or dictionary.
+            weights (str | Path | None): Path to the model weights file.
+            verbose (bool): Whether to display model information.
+
+        Returns:
+            (PoseModel): Initialized pose estimation model.
+        """
         model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
         if weights:
             model.load(weights)
@@ -69,7 +99,22 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         )
 
     def plot_training_samples(self, batch, ni):
-        """
+        """
+        Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.
+
+        Args:
+            batch (dict): Dictionary containing batch data with the following keys:
+                - img (torch.Tensor): Batch of images
+                - keypoints (torch.Tensor): Keypoints coordinates for pose estimation
+                - cls (torch.Tensor): Class labels
+                - bboxes (torch.Tensor): Bounding box coordinates
+                - im_file (list): List of image file paths
+                - batch_idx (torch.Tensor): Batch indices for each instance
+            ni (int): Current training iteration number used for filename
+
+        The function saves the plotted batch as an image in the trainer's save directory with the filename
+        'train_batch{ni}.jpg', where ni is the iteration number.
+        """
         images = batch["img"]
         kpts = batch["keypoints"]
         cls = batch["cls"].squeeze(-1)
```
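
As with the OBB trainer, the new docstring example runs as-is; note the documented caveats that the task is forced to "pose" and that Apple MPS triggers a warning:

```python
# From the PoseTrainer docstring Examples above; coco8-pose.yaml auto-downloads.
from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()
```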
**ultralytics/models/yolo/pose/val.py**

```diff
@@ -20,7 +20,7 @@ class PoseValidator(DetectionValidator):
     specialized metrics for pose evaluation.
 
     Attributes:
-        sigma (np.ndarray): Sigma values for OKS calculation, either
+        sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
         kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
         args (dict): Arguments for the validator including task set to "pose".
         metrics (PoseMetrics): Metrics object for pose evaluation.
@@ -47,7 +47,30 @@ class PoseValidator(DetectionValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
+        """
+        Initialize a PoseValidator object for pose estimation validation.
+
+        This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
+        specialized metrics for pose evaluation.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
+            save_dir (Path | str, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (dict, optional): Arguments for the validator including task set to "pose".
+            _callbacks (list, optional): List of callback functions to be executed during validation.
+
+        Examples:
+            >>> from ultralytics.models.yolo.pose import PoseValidator
+            >>> args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
+            >>> validator = PoseValidator(args=args)
+            >>> validator()
+
+        Notes:
+            This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
+            for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
+            due to a known bug with pose models.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
@@ -91,7 +114,20 @@ class PoseValidator(DetectionValidator):
         self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
     def _prepare_batch(self, si, batch):
-        """
+        """
+        Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
+
+        Args:
+            si (int): Batch index.
+            batch (dict): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
+
+        Returns:
+            pbatch (dict): Prepared batch with keypoints scaled to original image dimensions.
+
+        Notes:
+            This method extends the parent class's _prepare_batch method by adding keypoint processing.
+            Keypoints are scaled from normalized coordinates to original image dimensions.
+        """
         pbatch = super()._prepare_batch(si, batch)
         kpts = batch["keypoints"][batch["batch_idx"] == si]
         h, w = pbatch["imgsz"]
@@ -103,7 +139,23 @@ class PoseValidator(DetectionValidator):
         return pbatch
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """
+        Prepare and scale keypoints in predictions for pose processing.
+
+        This method extends the parent class's _prepare_pred method to handle keypoint scaling. It first calls
+        the parent method to get the basic prediction boxes, then extracts and scales the keypoint coordinates
+        to match the original image dimensions.
+
+        Args:
+            pred (torch.Tensor): Raw prediction tensor from the model.
+            pbatch (dict): Processed batch dictionary containing image information including:
+                - imgsz: Image size used for inference
+                - ori_shape: Original image shape
+                - ratio_pad: Ratio and padding information for coordinate scaling
+
+        Returns:
+            predn (torch.Tensor): Processed prediction boxes scaled to original image dimensions.
+        """
         predn = super()._prepare_pred(pred, pbatch)
         nk = pbatch["kpts"].shape[1]
         pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
@@ -204,7 +256,19 @@ class PoseValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def plot_val_samples(self, batch, ni):
-        """
+        """
+        Plot and save validation set samples with ground truth bounding boxes and keypoints.
+
+        Args:
+            batch (dict): Dictionary containing batch data with keys:
+                - img (torch.Tensor): Batch of images
+                - batch_idx (torch.Tensor): Batch indices for each image
+                - cls (torch.Tensor): Class labels
+                - bboxes (torch.Tensor): Bounding box coordinates
+                - keypoints (torch.Tensor): Keypoint coordinates
+                - im_file (list): List of image file paths
+            ni (int): Batch index used for naming the output file
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -218,7 +282,18 @@ class PoseValidator(DetectionValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """
+        Plot and save model predictions with bounding boxes and keypoints.
+
+        Args:
+            batch (dict): Dictionary containing batch data including images, file paths, and other metadata.
+            preds (List[torch.Tensor]): List of prediction tensors from the model, each containing bounding boxes,
+                confidence scores, class predictions, and keypoints.
+            ni (int): Batch index used for naming the output file.
+
+        The function extracts keypoints from predictions, converts predictions to target format, and plots them
+        on the input images. The resulting visualization is saved to the specified save directory.
+        """
         pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
         plot_images(
             batch["img"],
@@ -231,7 +306,21 @@ class PoseValidator(DetectionValidator):
         )  # pred
 
     def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
-        """
+        """
+        Save YOLO pose detections to a text file in normalized coordinates.
+
+        Args:
+            predn (torch.Tensor): Prediction boxes and scores with shape (N, 6) for (x1, y1, x2, y2, conf, cls).
+            pred_kpts (torch.Tensor): Predicted keypoints with shape (N, K, D) where K is the number of keypoints
+                and D is the dimension (typically 3 for x, y, visibility).
+            save_conf (bool): Whether to save confidence scores.
+            shape (tuple): Original image shape (height, width).
+            file (Path): Output file path to save detections.
+
+        Notes:
+            The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
+            normalized (x, y, visibility) values for each point.
+        """
         from ultralytics.engine.results import Results
 
         Results(
@@ -243,7 +332,23 @@ class PoseValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
-        """
+        """
+        Convert YOLO predictions to COCO JSON format.
+
+        This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
+        to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
+
+        Args:
+            predn (torch.Tensor): Prediction tensor containing bounding boxes, confidence scores, class IDs,
+                and keypoints, with shape (N, 6+K) where N is the number of predictions and K is the flattened
+                keypoints dimension.
+            filename (str | Path): Path to the image file for which predictions are being processed.
+
+        Notes:
+            The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
+            converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
+            before saving to the JSON dictionary.
+        """
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
```
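
The validator's new Examples block shows the call pattern; standalone, assuming the small coco8-pose sample dataset:

```python
# From the PoseValidator docstring Examples above; __call__ runs a full
# validation pass and computes pose (OKS-based) metrics.
from ultralytics.models.yolo.pose import PoseValidator

args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
validator = PoseValidator(args=args)
validator()
```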
**ultralytics/models/yolo/segment/predict.py**

```diff
@@ -31,12 +31,37 @@ class SegmentationPredictor(DetectionPredictor):
     """
 
    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
+
+        This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
+        prediction results.
+
+        Args:
+            cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
+            overrides (dict, optional): Configuration overrides that take precedence over cfg.
+            _callbacks (list, optional): List of callback functions to be invoked during prediction.
+        """
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
-        """
+        """
+        Apply non-max suppression and process segmentation detections for each image in the input batch.
+
+        Args:
+            preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
+            img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
+            orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
+
+        Returns:
+            (list): List of Results objects containing the segmentation predictions for each image in the batch.
+                Each Results object includes both bounding boxes and segmentation masks.
+
+        Examples:
+            >>> predictor = SegmentationPredictor(overrides=dict(model="yolov8n-seg.pt"))
+            >>> results = predictor.postprocess(preds, img, orig_img)
+        """
         # Extract protos - tuple if PyTorch model or array if exported
         protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
```
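
`postprocess` is normally invoked for you inside the prediction loop; a hedged end-to-end sketch (borrowing ASSETS as the sample source from the pose example above, an assumption rather than anything this diff shows):

```python
# SegmentationPredictor end to end; postprocess() runs inside predict_cli(),
# and each Result carries both boxes and masks per the docstring above.
from ultralytics.utils import ASSETS  # assumed sample-image source
from ultralytics.models.yolo.segment import SegmentationPredictor

predictor = SegmentationPredictor(overrides=dict(model="yolov8n-seg.pt", source=ASSETS))
predictor.predict_cli()
```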
**ultralytics/models/yolo/segment/train.py**

```diff
@@ -26,14 +26,45 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize a SegmentationTrainer object.
+
+        This initializes a trainer for segmentation tasks, extending the detection trainer with segmentation-specific
+        functionality. It sets the task to 'segment' and prepares the trainer for training segmentation models.
+
+        Args:
+            cfg (dict): Configuration dictionary with default training settings. Defaults to DEFAULT_CFG.
+            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+            _callbacks (list, optional): List of callback functions to be executed during training.
+
+        Examples:
+            >>> from ultralytics.models.yolo.segment import SegmentationTrainer
+            >>> args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
+            >>> trainer = SegmentationTrainer(overrides=args)
+            >>> trainer.train()
+        """
         if overrides is None:
             overrides = {}
         overrides["task"] = "segment"
         super().__init__(cfg, overrides, _callbacks)
 
     def get_model(self, cfg=None, weights=None, verbose=True):
-        """
+        """
+        Initialize and return a SegmentationModel with specified configuration and weights.
+
+        Args:
+            cfg (dict | str | None): Model configuration. Can be a dictionary, a path to a YAML file, or None.
+            weights (str | Path | None): Path to pretrained weights file.
+            verbose (bool): Whether to display model information during initialization.
+
+        Returns:
+            (SegmentationModel): Initialized segmentation model with loaded weights if specified.
+
+        Examples:
+            >>> trainer = SegmentationTrainer()
+            >>> model = trainer.get_model(cfg="yolov8n-seg.yaml")
+            >>> model = trainer.get_model(weights="yolov8n-seg.pt", verbose=False)
+        """
         model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
@@ -48,7 +79,34 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         )
 
     def plot_training_samples(self, batch, ni):
-        """
+        """
+        Plot training sample images with labels, bounding boxes, and masks.
+
+        This method creates a visualization of training batch images with their corresponding labels, bounding boxes,
+        and segmentation masks, saving the result to a file for inspection and debugging.
+
+        Args:
+            batch (dict): Dictionary containing batch data with the following keys:
+                'img': Images tensor
+                'batch_idx': Batch indices for each box
+                'cls': Class labels tensor (squeezed to remove last dimension)
+                'bboxes': Bounding box coordinates tensor
+                'masks': Segmentation masks tensor
+                'im_file': List of image file paths
+            ni (int): Current training iteration number, used for naming the output file.
+
+        Examples:
+            >>> trainer = SegmentationTrainer()
+            >>> batch = {
+            ...     "img": torch.rand(16, 3, 640, 640),
+            ...     "batch_idx": torch.zeros(16),
+            ...     "cls": torch.randint(0, 80, (16, 1)),
+            ...     "bboxes": torch.rand(16, 4),
+            ...     "masks": torch.rand(16, 640, 640),
+            ...     "im_file": ["image1.jpg", "image2.jpg"],
+            ... }
+            >>> trainer.plot_training_samples(batch, ni=5)
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
```
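
Again the docstring examples are directly runnable; combined into one script, assuming coco8-seg.yaml auto-downloads like the other sample datasets:

```python
# From the SegmentationTrainer docstring Examples above.
from ultralytics.models.yolo.segment import SegmentationTrainer

args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
```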
**ultralytics/models/yolo/segment/val.py**

```diff
@@ -215,7 +215,16 @@ class SegmentationValidator(DetectionValidator):
         )
 
     def finalize_metrics(self, *args, **kwargs):
-        """
+        """
+        Finalize evaluation metrics by setting the speed attribute in the metrics object.
+
+        This method is called at the end of validation to set the processing speed for the metrics calculations.
+        It transfers the validator's speed measurement to the metrics object for reporting.
+
+        Args:
+            *args (Any): Variable length argument list.
+            **kwargs (Any): Arbitrary keyword arguments.
+        """
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
```
**ultralytics/models/yolo/world/train_world.py**

```diff
@@ -43,7 +43,35 @@ class WorldTrainerFromScratch(WorldTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize a WorldTrainerFromScratch object.
+
+        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
+        object detection and grounding datasets for vision-language capabilities.
+
+        Args:
+            cfg (dict): Configuration dictionary with default parameters for model training.
+            overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.
+            _callbacks (list, optional): List of callback functions to be executed during different stages of training.
+
+        Examples:
+            >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+            >>> from ultralytics import YOLOWorld
+            >>> data = dict(
+            ...     train=dict(
+            ...         yolo_data=["Objects365.yaml"],
+            ...         grounding_data=[
+            ...             dict(
+            ...                 img_path="../datasets/flickr30k/images",
+            ...                 json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+            ...             ),
+            ...         ],
+            ...     ),
+            ...     val=dict(yolo_data=["lvis.yaml"]),
+            ... )
+            >>> model = YOLOWorld("yolov8s-worldv2.yaml")
+            >>> model.train(data=data, trainer=WorldTrainerFromScratch)
+        """
         if overrides is None:
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)
```