ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +13 -19
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +64 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
CHANGED
```diff
@@ -93,7 +93,7 @@ class YOLOWorld(Model):
 
     def set_classes(self, classes):
         """
-        Set
+        Set the model's class names for detection.
 
         Args:
             classes (List(str)): A list of categories i.e. ["person"].
@@ -106,6 +106,5 @@ class YOLOWorld(Model):
         self.model.names = classes
 
         # Reset method class names
-        # self.predictor = None  # reset predictor otherwise old names remain
         if self.predictor:
             self.predictor.model.names = classes
```
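The `set_classes` change keeps a live predictor's class names in sync instead of discarding the predictor. A minimal usage sketch (the weights filename and image source are illustrative, not taken from this diff):

```python
from ultralytics import YOLOWorld

model = YOLOWorld("yolov8s-world.pt")  # illustrative weights file
model.set_classes(["person", "bus"])  # updates model.names and any existing predictor's names
results = model("path/to/image.jpg")  # later predictions report the new class names
```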
ultralytics/models/yolo/obb/predict.py
CHANGED
```diff
@@ -11,6 +11,13 @@ class OBBPredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
 
+    This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
+    bounding boxes.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO OBB model.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.obb import OBBPredictor
@@ -20,17 +27,18 @@ class OBBPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """Initialize OBBPredictor with optional model and data configuration overrides."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "obb"
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-
+        Construct the result object from the prediction.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles
-
+            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 6) where
+                the last dimension contains [x, y, w, h, confidence, class_id, angle].
+            img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
             orig_img (np.ndarray): The original image before preprocessing.
             img_path (str): The path to the original image.
 
```
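The expanded docstring pins down the prediction tensor layout, and its Examples block follows the standard predictor pattern. A minimal sketch, assuming the `yolo11n-obb.pt` weights resolve locally:

```python
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.obb import OBBPredictor

args = dict(model="yolo11n-obb.pt", source=ASSETS)  # assumed weights; ASSETS ships with the package
predictor = OBBPredictor(overrides=args)
predictor.predict_cli()  # rotated boxes land on each result's .obb attribute
```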
ultralytics/models/yolo/obb/train.py
CHANGED
```diff
@@ -11,6 +11,13 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
 
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used during training.
+
+    Methods:
+        get_model: Return OBBModel initialized with specified config and weights.
+        get_validator: Return an instance of OBBValidator for validation of YOLO model.
+
     Examples:
         >>> from ultralytics.models.yolo.obb import OBBTrainer
         >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
```
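Per the Examples block in the new docstring, a short training sketch (assumes `yolo11n-obb.pt` and the `dota8.yaml` sample dataset download or resolve locally):

```python
from ultralytics.models.yolo.obb import OBBTrainer

args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
trainer = OBBTrainer(overrides=args)  # get_model/get_validator wire up the OBB-specific pieces
trainer.train()
```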
ultralytics/models/yolo/obb/val.py
CHANGED
```diff
@@ -14,6 +14,24 @@ class OBBValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
 
+    This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
+    satellite imagery where objects can appear at various orientations.
+
+    Attributes:
+        args (Dict): Configuration arguments for the validator.
+        metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
+        is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
+
+    Methods:
+        init_metrics: Initialize evaluation metrics for YOLO.
+        _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
+        _prepare_batch: Prepare batch data for OBB validation.
+        _prepare_pred: Prepare predictions with scaled and padded bounding boxes.
+        plot_predictions: Plot predicted bounding boxes on input images.
+        pred_to_json: Serialize YOLO predictions to COCO json format.
+        save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
+        eval_json: Evaluate YOLO output in JSON format and return performance statistics.
+
     Examples:
         >>> from ultralytics.models.yolo.obb import OBBValidator
         >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
@@ -31,7 +49,7 @@ class OBBValidator(DetectionValidator):
         """Initialize evaluation metrics for YOLO."""
         super().init_metrics(model)
         val = self.data.get(self.args.split, "")  # validation path
-        self.is_dota = isinstance(val, str) and "DOTA" in val  # is
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
 
     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
@@ -61,7 +79,7 @@ class OBBValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def _prepare_batch(self, si, batch):
-        """
+        """Prepare batch data for OBB validation with proper scaling and formatting."""
         idx = batch["batch_idx"] == si
         cls = batch["cls"][idx].squeeze(-1)
         bbox = batch["bboxes"][idx]
@@ -74,7 +92,7 @@ class OBBValidator(DetectionValidator):
         return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """Prepare predictions by scaling bounding boxes to original image dimensions."""
         predn = pred.clone()
         ops.scale_boxes(
             pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
@@ -82,7 +100,7 @@ class OBBValidator(DetectionValidator):
         return predn
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """Plot predicted bounding boxes on input images and save the result."""
         plot_images(
             batch["img"],
             *output_to_rotated_target(preds, max_det=self.args.max_det),
@@ -93,7 +111,7 @@ class OBBValidator(DetectionValidator):
         )  # pred
 
     def pred_to_json(self, predn, filename):
-        """
+        """Convert YOLO predictions to COCO JSON format with rotated bounding box information."""
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
@@ -110,7 +128,7 @@ class OBBValidator(DetectionValidator):
         )
 
     def save_one_txt(self, predn, save_conf, shape, file):
-        """Save YOLO detections to a txt file in normalized coordinates
+        """Save YOLO detections to a txt file in normalized coordinates using the Results class."""
         import numpy as np
 
         from ultralytics.engine.results import Results
@@ -126,7 +144,7 @@ class OBBValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def eval_json(self, stats):
-        """
+        """Evaluate YOLO output in JSON format and save predictions in DOTA format."""
         if self.args.save_json and self.is_dota and len(self.jdict):
             import json
             import re
```
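`pred_to_json` reassembles each prediction into an `[x, y, w, h, r]` rotated box before serialization. A hedged sketch of that step with a dummy tensor, assuming `ops.xywhr2xyxyxyxy` converts rotated boxes to 4-corner polygons as in current releases:

```python
import torch
from ultralytics.utils import ops

# Dummy (N, 7) prediction row: [x, y, w, h, confidence, class_id, angle]
predn = torch.tensor([[320.0, 240.0, 100.0, 40.0, 0.9, 0.0, 0.3]])
rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)  # (N, 5) xywhr, as in pred_to_json
poly = ops.xywhr2xyxyxyxy(rbox)  # (N, 4, 2) corner points suitable for DOTA-style output
```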
ultralytics/models/yolo/pose/predict.py
CHANGED
```diff
@@ -8,6 +8,16 @@ class PosePredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on a pose model.
 
+    This class specializes in pose estimation, handling keypoints detection alongside standard object detection
+    capabilities inherited from DetectionPredictor.
+
+    Attributes:
+        args (namespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
+
+    Methods:
+        construct_result: Constructs the result object from the prediction, including keypoints.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.pose import PosePredictor
@@ -17,7 +27,7 @@ class PosePredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """Initialize PosePredictor, set task to 'pose' and log a warning for using 'mps' as device."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "pose"
         if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
@@ -28,19 +38,25 @@ class PosePredictor(DetectionPredictor):
 
     def construct_result(self, pred, img, orig_img, img_path):
         """
-
+        Construct the result object from the prediction, including keypoints.
+
+        This method extends the parent class implementation by extracting keypoint data from predictions
+        and adding them to the result object.
 
         Args:
-            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints
-
-
-
+            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
+                the number of detections, K is the number of keypoints, and D is the keypoint dimension.
+            img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
+            orig_img (np.ndarray): The original unprocessed image as a numpy array.
+            img_path (str): The path to the original image file.
 
         Returns:
             (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
         """
         result = super().construct_result(pred, img, orig_img, img_path)
+        # Extract keypoints from prediction and reshape according to model's keypoint shape
         pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+        # Scale keypoints coordinates to match the original image dimensions
        pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
         result.update(keypoints=pred_kpts)
         return result
```
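The new shape documentation makes the keypoint reshape in `construct_result` easy to verify in isolation. A self-contained sketch with dummy tensors (COCO-style `kpt_shape` assumed):

```python
import torch

kpt_shape = (17, 3)  # 17 keypoints, each (x, y, visibility), as in COCO pose models
pred = torch.rand(5, 6 + 17 * 3)  # dummy (N, 6 + K*D) predictions for N=5 detections
pred_kpts = pred[:, 6:].view(len(pred), *kpt_shape) if len(pred) else pred[:, 6:]
print(pred_kpts.shape)  # torch.Size([5, 17, 3])
```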
ultralytics/models/yolo/pose/train.py
CHANGED
```diff
@@ -10,7 +10,23 @@ from ultralytics.utils.plotting import plot_images, plot_results
 
 class PoseTrainer(yolo.detect.DetectionTrainer):
     """
-    A class extending the DetectionTrainer class for training
+    A class extending the DetectionTrainer class for training YOLO pose estimation models.
+
+    This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
+    of pose keypoints alongside bounding boxes.
+
+    Attributes:
+        args (Dict): Configuration arguments for training.
+        model (PoseModel): The pose estimation model being trained.
+        data (Dict): Dataset configuration including keypoint shape information.
+        loss_names (Tuple[str]): Names of the loss components used in training.
+
+    Methods:
+        get_model: Retrieves a pose estimation model with specified configuration.
+        set_model_attributes: Sets keypoints shape attribute on the model.
+        get_validator: Creates a validator instance for model evaluation.
+        plot_training_samples: Visualizes training samples with keypoints.
+        plot_metrics: Generates and saves training/validation metric plots.
 
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseTrainer
```
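Following the docstring's Examples, a minimal training sketch (assumes the listed weights and the `coco8-pose.yaml` sample dataset are available):

```python
from ultralytics.models.yolo.pose import PoseTrainer

args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
trainer = PoseTrainer(overrides=args)
trainer.train()
```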
ultralytics/models/yolo/pose/val.py
CHANGED
```diff
@@ -16,6 +16,29 @@ class PoseValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on a pose model.
 
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
+    specialized metrics for pose evaluation.
+
+    Attributes:
+        sigma (np.ndarray): Sigma values for OKS calculation, either from OKS_SIGMA or ones divided by number of keypoints.
+        kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
+        args (Dict): Arguments for the validator including task set to "pose".
+        metrics (PoseMetrics): Metrics object for pose evaluation.
+
+    Methods:
+        preprocess: Preprocesses batch data for pose validation.
+        get_desc: Returns description of evaluation metrics.
+        init_metrics: Initializes pose metrics for the model.
+        _prepare_batch: Prepares a batch for processing.
+        _prepare_pred: Prepares and scales predictions for evaluation.
+        update_metrics: Updates metrics with new predictions.
+        _process_batch: Processes batch to compute IoU between detections and ground truth.
+        plot_val_samples: Plots validation samples with ground truth annotations.
+        plot_predictions: Plots model predictions.
+        save_one_txt: Saves detections to a text file.
+        pred_to_json: Converts predictions to COCO JSON format.
+        eval_json: Evaluates model using COCO JSON format.
+
     Examples:
         >>> from ultralytics.models.yolo.pose import PoseValidator
         >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
@@ -24,7 +47,7 @@ class PoseValidator(DetectionValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initialize a
+        """Initialize a PoseValidator object with custom parameters and assigned attributes."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
@@ -37,13 +60,13 @@ class PoseValidator(DetectionValidator):
         )
 
     def preprocess(self, batch):
-        """
+        """Preprocess batch by converting keypoints data to float and moving it to the device."""
         batch = super().preprocess(batch)
         batch["keypoints"] = batch["keypoints"].to(self.device).float()
         return batch
 
     def get_desc(self):
-        """
+        """Return description of evaluation metrics in string format."""
         return ("%22s" + "%11s" * 10) % (
             "Class",
             "Images",
@@ -59,7 +82,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def init_metrics(self, model):
-        """
+        """Initialize pose estimation metrics for YOLO model."""
         super().init_metrics(model)
         self.kpt_shape = self.data["kpt_shape"]
         is_pose = self.kpt_shape == [17, 3]
@@ -68,7 +91,7 @@ class PoseValidator(DetectionValidator):
         self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
     def _prepare_batch(self, si, batch):
-        """
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions."""
         pbatch = super()._prepare_batch(si, batch)
         kpts = batch["keypoints"][batch["batch_idx"] == si]
         h, w = pbatch["imgsz"]
@@ -80,7 +103,7 @@ class PoseValidator(DetectionValidator):
         return pbatch
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """Prepare and scale keypoints in predictions for pose processing."""
         predn = super()._prepare_pred(pred, pbatch)
         nk = pbatch["kpts"].shape[1]
         pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
@@ -88,7 +111,16 @@ class PoseValidator(DetectionValidator):
         return predn, pred_kpts
 
     def update_metrics(self, preds, batch):
-        """
+        """
+        Update metrics with new predictions and ground truth data.
+
+        This method processes each prediction, compares it with ground truth, and updates various statistics
+        for performance evaluation.
+
+        Args:
+            preds (List[torch.Tensor]): List of prediction tensors from the model.
+            batch (Dict): Batch data containing images and ground truth annotations.
+        """
         for si, pred in enumerate(preds):
             self.seen += 1
             npr = len(pred)
@@ -158,16 +190,9 @@ class PoseValidator(DetectionValidator):
             (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
                 where N is the number of detections.
 
-
-
-
-            >>> gt_cls = torch.randint(0, 2, (50,))  # 50 ground truth class indices
-            >>> pred_kpts = torch.rand(100, 51)  # 100 predicted keypoints
-            >>> gt_kpts = torch.rand(50, 51)  # 50 ground truth keypoints
-            >>> correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
-
-        Note:
-            `0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
+        Notes:
+            `0.53` scale factor used in area computation is referenced from
+            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
         if pred_kpts is not None and gt_kpts is not None:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
@@ -179,7 +204,7 @@ class PoseValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def plot_val_samples(self, batch, ni):
-        """
+        """Plot and save validation set samples with ground truth bounding boxes and keypoints."""
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -193,7 +218,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """Plot and save model predictions with bounding boxes and keypoints."""
         pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
         plot_images(
             batch["img"],
@@ -218,7 +243,7 @@ class PoseValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
-        """
+        """Convert YOLO predictions to COCO JSON format."""
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
@@ -235,7 +260,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def eval_json(self, stats):
-        """
+        """Evaluate object detection model using COCO JSON format."""
         if self.args.save_json and self.is_coco and len(self.jdict):
             anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
             pred_json = self.save_dir / "predictions.json"  # predictions
```
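The `0.53` factor noted above scales the ground-truth box area before OKS-based keypoint matching. A minimal sketch with dummy tensors, assuming `kpt_iou` and `OKS_SIGMA` from `ultralytics.utils.metrics` keep their current signatures:

```python
import torch
from ultralytics.utils.metrics import OKS_SIGMA, kpt_iou

gt_bboxes = torch.tensor([[10.0, 10.0, 110.0, 210.0]])  # dummy xyxy ground-truth box
gt_kpts = torch.rand(1, 17, 3) * 100  # dummy COCO-format keypoints (x, y, visibility)
pred_kpts = gt_kpts + torch.randn_like(gt_kpts)  # slightly perturbed predictions
area = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * 0.53
iou = kpt_iou(gt_kpts, pred_kpts, sigma=OKS_SIGMA, area=area)  # OKS matrix, one row per GT
```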
ultralytics/models/yolo/segment/predict.py
CHANGED
```diff
@@ -9,6 +9,19 @@ class SegmentationPredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on a segmentation model.
 
+    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
+    prediction results.
+
+    Attributes:
+        args (Dict): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO segmentation model.
+        batch (List): Current batch of images being processed.
+
+    Methods:
+        postprocess: Applies non-max suppression and processes detections.
+        construct_results: Constructs a list of result objects from predictions.
+        construct_result: Constructs a single result object from a prediction.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.yolo.segment import SegmentationPredictor
@@ -18,19 +31,19 @@ class SegmentationPredictor(DetectionPredictor):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        # tuple if PyTorch model or array if exported
+        """Apply non-max suppression and process detections for each image in the input batch."""
+        # Extract protos - tuple if PyTorch model or array if exported
         protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
 
     def construct_results(self, preds, img, orig_imgs, protos):
         """
-
+        Construct a list of result objects from the predictions.
 
         Args:
             preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -39,7 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
             protos (List[torch.Tensor]): List of prototype masks.
 
         Returns:
-            (
+            (List[Results]): List of result objects containing the original images, image paths, class names,
+                bounding boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -48,7 +62,7 @@ class SegmentationPredictor(DetectionPredictor):
 
     def construct_result(self, pred, img, orig_img, img_path, proto):
         """
-
+        Construct a single result object from the prediction.
 
         Args:
             pred (np.ndarray): The predicted bounding boxes, scores, and masks.
@@ -58,7 +72,7 @@ class SegmentationPredictor(DetectionPredictor):
             proto (torch.Tensor): The prototype masks.
 
         Returns:
-            (Results):
+            (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
         """
         if not len(pred):  # save empty boxes
             masks = None
@@ -69,6 +83,6 @@ class SegmentationPredictor(DetectionPredictor):
             masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         if masks is not None:
-            keep = masks.sum((-2, -1)) > 0  # only keep
+            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
             pred, masks = pred[keep], masks[keep]
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
```
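The clarified `# only keep predictions with masks` comment describes a simple filter that drops detections whose mask is empty after processing. Reproduced in isolation with dummy tensors:

```python
import torch

pred = torch.rand(4, 38)  # dummy (N, 6 + nm) detections
masks = torch.zeros(4, 160, 160)
masks[0, 40:80, 40:80] = 1.0  # only the first detection has a non-empty mask
keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
pred, masks = pred[keep], masks[keep]
print(pred.shape, masks.shape)  # torch.Size([1, 38]) torch.Size([1, 160, 160])
```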
ultralytics/models/yolo/segment/train.py
CHANGED
```diff
@@ -12,6 +12,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on a segmentation model.
 
+    This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
+    functionality including model initialization, validation, and visualization.
+
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used during training.
+
     Examples:
         >>> from ultralytics.models.yolo.segment import SegmentationTrainer
         >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
```
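As with the other trainers, the docstring's Examples imply the usual flow. A minimal sketch (assumes `yolo11n-seg.pt` and the `coco8-seg.yaml` sample dataset resolve locally):

```python
from ultralytics.models.yolo.segment import SegmentationTrainer

args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
trainer = SegmentationTrainer(overrides=args)
trainer.train()
```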