dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
|
@@ -15,14 +15,16 @@ import torch
|
|
|
15
15
|
from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings
|
|
16
16
|
|
|
17
17
|
OKS_SIGMA = (
|
|
18
|
-
np.array(
|
|
18
|
+
np.array(
|
|
19
|
+
[0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89],
|
|
20
|
+
dtype=np.float32,
|
|
21
|
+
)
|
|
19
22
|
/ 10.0
|
|
20
23
|
)
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
|
|
24
|
-
"""
|
|
25
|
-
Calculate the intersection over box2 area given box1 and box2.
|
|
27
|
+
"""Calculate the intersection over box2 area given box1 and box2.
|
|
26
28
|
|
|
27
29
|
Args:
|
|
28
30
|
box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
|
|
@@ -53,8 +55,7 @@ def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float =
|
|
|
53
55
|
|
|
54
56
|
|
|
55
57
|
def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
|
56
|
-
"""
|
|
57
|
-
Calculate intersection-over-union (IoU) of boxes.
|
|
58
|
+
"""Calculate intersection-over-union (IoU) of boxes.
|
|
58
59
|
|
|
59
60
|
Args:
|
|
60
61
|
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
|
|
@@ -85,19 +86,17 @@ def bbox_iou(
|
|
|
85
86
|
CIoU: bool = False,
|
|
86
87
|
eps: float = 1e-7,
|
|
87
88
|
) -> torch.Tensor:
|
|
88
|
-
"""
|
|
89
|
-
Calculate the Intersection over Union (IoU) between bounding boxes.
|
|
89
|
+
"""Calculate the Intersection over Union (IoU) between bounding boxes.
|
|
90
90
|
|
|
91
|
-
This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
or (x1, y1, x2, y2) if `xywh=False`.
|
|
91
|
+
This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. For instance, you
|
|
92
|
+
may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). Internally, the code will split the last
|
|
93
|
+
dimension into (x, y, w, h) if `xywh=True`, or (x1, y1, x2, y2) if `xywh=False`.
|
|
95
94
|
|
|
96
95
|
Args:
|
|
97
96
|
box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
|
|
98
97
|
box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
|
|
99
|
-
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
|
|
100
|
-
|
|
98
|
+
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in (x1, y1,
|
|
99
|
+
x2, y2) format.
|
|
101
100
|
GIoU (bool, optional): If True, calculate Generalized IoU.
|
|
102
101
|
DIoU (bool, optional): If True, calculate Distance IoU.
|
|
103
102
|
CIoU (bool, optional): If True, calculate Complete IoU.
|
|
@@ -148,14 +147,13 @@ def bbox_iou(
|
|
|
148
147
|
|
|
149
148
|
|
|
150
149
|
def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
|
151
|
-
"""
|
|
152
|
-
Calculate masks IoU.
|
|
150
|
+
"""Calculate masks IoU.
|
|
153
151
|
|
|
154
152
|
Args:
|
|
155
153
|
mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
|
|
156
|
-
|
|
157
|
-
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
|
|
158
|
-
|
|
154
|
+
product of image width and height.
|
|
155
|
+
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the product
|
|
156
|
+
of image width and height.
|
|
159
157
|
eps (float, optional): A small value to avoid division by zero.
|
|
160
158
|
|
|
161
159
|
Returns:
|
|
@@ -169,8 +167,7 @@ def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> tor
|
|
|
169
167
|
def kpt_iou(
|
|
170
168
|
kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
|
|
171
169
|
) -> torch.Tensor:
|
|
172
|
-
"""
|
|
173
|
-
Calculate Object Keypoint Similarity (OKS).
|
|
170
|
+
"""Calculate Object Keypoint Similarity (OKS).
|
|
174
171
|
|
|
175
172
|
Args:
|
|
176
173
|
kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
|
|
@@ -191,14 +188,14 @@ def kpt_iou(
|
|
|
191
188
|
|
|
192
189
|
|
|
193
190
|
def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
194
|
-
"""
|
|
195
|
-
Generate covariance matrix from oriented bounding boxes.
|
|
191
|
+
"""Generate covariance matrix from oriented bounding boxes.
|
|
196
192
|
|
|
197
193
|
Args:
|
|
198
194
|
boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
|
|
199
195
|
|
|
200
196
|
Returns:
|
|
201
|
-
(torch.Tensor): Covariance
|
|
197
|
+
(tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Covariance matrix components (a, b, c) where the covariance
|
|
198
|
+
matrix is [[a, c], [c, b]], each of shape (N, 1).
|
|
202
199
|
"""
|
|
203
200
|
# Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
|
|
204
201
|
gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
|
|
@@ -211,8 +208,7 @@ def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Ten
|
|
|
211
208
|
|
|
212
209
|
|
|
213
210
|
def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
|
214
|
-
"""
|
|
215
|
-
Calculate probabilistic IoU between oriented bounding boxes.
|
|
211
|
+
"""Calculate probabilistic IoU between oriented bounding boxes.
|
|
216
212
|
|
|
217
213
|
Args:
|
|
218
214
|
obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
|
|
@@ -257,8 +253,7 @@ def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: flo
|
|
|
257
253
|
|
|
258
254
|
|
|
259
255
|
def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
|
|
260
|
-
"""
|
|
261
|
-
Calculate the probabilistic IoU between oriented bounding boxes.
|
|
256
|
+
"""Calculate the probabilistic IoU between oriented bounding boxes.
|
|
262
257
|
|
|
263
258
|
Args:
|
|
264
259
|
obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
|
|
@@ -294,8 +289,7 @@ def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarr
|
|
|
294
289
|
|
|
295
290
|
|
|
296
291
|
def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
|
|
297
|
-
"""
|
|
298
|
-
Compute smoothed positive and negative Binary Cross-Entropy targets.
|
|
292
|
+
"""Compute smoothed positive and negative Binary Cross-Entropy targets.
|
|
299
293
|
|
|
300
294
|
Args:
|
|
301
295
|
eps (float, optional): The epsilon value for label smoothing.
|
|
@@ -311,20 +305,18 @@ def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
|
|
|
311
305
|
|
|
312
306
|
|
|
313
307
|
class ConfusionMatrix(DataExportMixin):
|
|
314
|
-
"""
|
|
315
|
-
A class for calculating and updating a confusion matrix for object detection and classification tasks.
|
|
308
|
+
"""A class for calculating and updating a confusion matrix for object detection and classification tasks.
|
|
316
309
|
|
|
317
310
|
Attributes:
|
|
318
311
|
task (str): The type of task, either 'detect' or 'classify'.
|
|
319
312
|
matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
|
|
320
|
-
nc (int): The number of
|
|
313
|
+
nc (int): The number of classes.
|
|
321
314
|
names (list[str]): The names of the classes, used as labels on the plot.
|
|
322
315
|
matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
|
|
323
316
|
"""
|
|
324
317
|
|
|
325
318
|
def __init__(self, names: dict[int, str] = [], task: str = "detect", save_matches: bool = False):
|
|
326
|
-
"""
|
|
327
|
-
Initialize a ConfusionMatrix instance.
|
|
319
|
+
"""Initialize a ConfusionMatrix instance.
|
|
328
320
|
|
|
329
321
|
Args:
|
|
330
322
|
names (dict[int, str], optional): Names of classes, used as labels on the plot.
|
|
@@ -338,21 +330,20 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
338
330
|
self.matches = {} if save_matches else None
|
|
339
331
|
|
|
340
332
|
def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
|
|
341
|
-
"""
|
|
342
|
-
Append the matches to TP, FP, FN or GT list for the last batch.
|
|
333
|
+
"""Append the matches to TP, FP, FN or GT list for the last batch.
|
|
343
334
|
|
|
344
|
-
This method updates the matches dictionary by appending specific batch data
|
|
345
|
-
|
|
335
|
+
This method updates the matches dictionary by appending specific batch data to the appropriate match type (True
|
|
336
|
+
Positive, False Positive, or False Negative).
|
|
346
337
|
|
|
347
338
|
Args:
|
|
348
339
|
mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
|
|
349
|
-
batch (dict[str, Any]): Batch data containing detection results with keys
|
|
350
|
-
|
|
340
|
+
batch (dict[str, Any]): Batch data containing detection results with keys like 'bboxes', 'cls', 'conf',
|
|
341
|
+
'keypoints', 'masks'.
|
|
351
342
|
idx (int): Index of the specific detection to append from the batch.
|
|
352
343
|
|
|
353
|
-
|
|
354
|
-
For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
|
|
355
|
-
|
|
344
|
+
Notes:
|
|
345
|
+
For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, it indicates
|
|
346
|
+
overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
|
|
356
347
|
"""
|
|
357
348
|
if self.matches is None:
|
|
358
349
|
return
|
|
@@ -364,8 +355,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
364
355
|
self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
|
|
365
356
|
|
|
366
357
|
def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
|
|
367
|
-
"""
|
|
368
|
-
Update confusion matrix for classification task.
|
|
358
|
+
"""Update confusion matrix for classification task.
|
|
369
359
|
|
|
370
360
|
Args:
|
|
371
361
|
preds (list[N, min(nc,5)]): Predicted class labels.
|
|
@@ -382,15 +372,14 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
382
372
|
conf: float = 0.25,
|
|
383
373
|
iou_thres: float = 0.45,
|
|
384
374
|
) -> None:
|
|
385
|
-
"""
|
|
386
|
-
Update confusion matrix for object detection task.
|
|
375
|
+
"""Update confusion matrix for object detection task.
|
|
387
376
|
|
|
388
377
|
Args:
|
|
389
|
-
detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
|
|
393
|
-
'cls' (Array[M]) keys, where M is the number of ground truth objects.
|
|
378
|
+
detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
|
|
379
|
+
information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be Array[N, 4] for
|
|
380
|
+
regular boxes or Array[N, 5] for OBB with angle.
|
|
381
|
+
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
|
|
382
|
+
5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
|
|
394
383
|
conf (float, optional): Confidence threshold for detections.
|
|
395
384
|
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
|
|
396
385
|
"""
|
|
@@ -460,8 +449,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
460
449
|
return self.matrix
|
|
461
450
|
|
|
462
451
|
def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
|
|
463
|
-
"""
|
|
464
|
-
Return true positives and false positives.
|
|
452
|
+
"""Return true positives and false positives.
|
|
465
453
|
|
|
466
454
|
Returns:
|
|
467
455
|
tp (np.ndarray): True positives.
|
|
@@ -473,8 +461,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
473
461
|
return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1]) # remove background class if task=detect
|
|
474
462
|
|
|
475
463
|
def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
|
|
476
|
-
"""
|
|
477
|
-
Plot grid of GT, TP, FP, FN for each image.
|
|
464
|
+
"""Plot grid of GT, TP, FP, FN for each image.
|
|
478
465
|
|
|
479
466
|
Args:
|
|
480
467
|
img (torch.Tensor): Image to plot onto.
|
|
@@ -513,8 +500,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
513
500
|
@TryExcept(msg="ConfusionMatrix plot failure")
|
|
514
501
|
@plt_settings()
|
|
515
502
|
def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
|
|
516
|
-
"""
|
|
517
|
-
Plot the confusion matrix using matplotlib and save it to a file.
|
|
503
|
+
"""Plot the confusion matrix using matplotlib and save it to a file.
|
|
518
504
|
|
|
519
505
|
Args:
|
|
520
506
|
normalize (bool, optional): Whether to normalize the confusion matrix.
|
|
@@ -535,7 +521,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
535
521
|
array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
|
|
536
522
|
n = (self.nc + k - 1) // k # number of retained classes
|
|
537
523
|
nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
|
|
538
|
-
ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
|
|
524
|
+
ticklabels = ([*names, "background"]) if (0 < nn < 99) and (nn == nc) else "auto"
|
|
539
525
|
xy_ticks = np.arange(len(ticklabels))
|
|
540
526
|
tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
|
|
541
527
|
label_fontsize = max(6, 12 - 0.1 * nc)
|
|
@@ -582,7 +568,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
582
568
|
fig.savefig(plot_fname, dpi=250)
|
|
583
569
|
plt.close(fig)
|
|
584
570
|
if on_plot:
|
|
585
|
-
on_plot(plot_fname)
|
|
571
|
+
on_plot(plot_fname, {"type": "confusion_matrix", "matrix": self.matrix.tolist()})
|
|
586
572
|
|
|
587
573
|
def print(self):
|
|
588
574
|
"""Print the confusion matrix to the console."""
|
|
@@ -590,16 +576,17 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
590
576
|
LOGGER.info(" ".join(map(str, self.matrix[i])))
|
|
591
577
|
|
|
592
578
|
def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
|
|
593
|
-
"""
|
|
594
|
-
|
|
595
|
-
|
|
579
|
+
"""Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
|
|
580
|
+
normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON,
|
|
581
|
+
or SQL.
|
|
596
582
|
|
|
597
583
|
Args:
|
|
598
584
|
normalize (bool): Whether to normalize the confusion matrix values.
|
|
599
585
|
decimals (int): Number of decimal places to round the output values to.
|
|
600
586
|
|
|
601
587
|
Returns:
|
|
602
|
-
(list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
|
|
588
|
+
(list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
|
|
589
|
+
values for all actual classes.
|
|
603
590
|
|
|
604
591
|
Examples:
|
|
605
592
|
>>> results = model.val(data="coco8.yaml", plots=True)
|
|
@@ -608,7 +595,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
608
595
|
"""
|
|
609
596
|
import re
|
|
610
597
|
|
|
611
|
-
names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
|
|
598
|
+
names = list(self.names.values()) if self.task == "classify" else [*list(self.names.values()), "background"]
|
|
612
599
|
clean_names, seen = [], set()
|
|
613
600
|
for name in names:
|
|
614
601
|
clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
|
@@ -643,8 +630,7 @@ def plot_pr_curve(
|
|
|
643
630
|
names: dict[int, str] = {},
|
|
644
631
|
on_plot=None,
|
|
645
632
|
):
|
|
646
|
-
"""
|
|
647
|
-
Plot precision-recall curve.
|
|
633
|
+
"""Plot precision-recall curve.
|
|
648
634
|
|
|
649
635
|
Args:
|
|
650
636
|
px (np.ndarray): X values for the PR curve.
|
|
@@ -663,7 +649,7 @@ def plot_pr_curve(
|
|
|
663
649
|
for i, y in enumerate(py.T):
|
|
664
650
|
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
|
|
665
651
|
else:
|
|
666
|
-
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
|
|
652
|
+
ax.plot(px, py, linewidth=1, color="gray") # plot(recall, precision)
|
|
667
653
|
|
|
668
654
|
ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
|
|
669
655
|
ax.set_xlabel("Recall")
|
|
@@ -675,7 +661,9 @@ def plot_pr_curve(
|
|
|
675
661
|
fig.savefig(save_dir, dpi=250)
|
|
676
662
|
plt.close(fig)
|
|
677
663
|
if on_plot:
|
|
678
|
-
|
|
664
|
+
# Pass PR curve data for interactive plotting (class names stored at model level)
|
|
665
|
+
# Transpose py to match other curves: y[class][point] format
|
|
666
|
+
on_plot(save_dir, {"type": "pr_curve", "x": px.tolist(), "y": py.T.tolist(), "ap": ap.tolist()})
|
|
679
667
|
|
|
680
668
|
|
|
681
669
|
@plt_settings()
|
|
@@ -688,8 +676,7 @@ def plot_mc_curve(
|
|
|
688
676
|
ylabel: str = "Metric",
|
|
689
677
|
on_plot=None,
|
|
690
678
|
):
|
|
691
|
-
"""
|
|
692
|
-
Plot metric-confidence curve.
|
|
679
|
+
"""Plot metric-confidence curve.
|
|
693
680
|
|
|
694
681
|
Args:
|
|
695
682
|
px (np.ndarray): X values for the metric-confidence curve.
|
|
@@ -708,7 +695,7 @@ def plot_mc_curve(
|
|
|
708
695
|
for i, y in enumerate(py):
|
|
709
696
|
ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
|
|
710
697
|
else:
|
|
711
|
-
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
|
|
698
|
+
ax.plot(px, py.T, linewidth=1, color="gray") # plot(confidence, metric)
|
|
712
699
|
|
|
713
700
|
y = smooth(py.mean(0), 0.1)
|
|
714
701
|
ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
|
|
@@ -721,12 +708,12 @@ def plot_mc_curve(
|
|
|
721
708
|
fig.savefig(save_dir, dpi=250)
|
|
722
709
|
plt.close(fig)
|
|
723
710
|
if on_plot:
|
|
724
|
-
|
|
711
|
+
# Pass metric-confidence curve data for interactive plotting (class names stored at model level)
|
|
712
|
+
on_plot(save_dir, {"type": f"{ylabel.lower()}_curve", "x": px.tolist(), "y": py.tolist()})
|
|
725
713
|
|
|
726
714
|
|
|
727
715
|
def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
|
|
728
|
-
"""
|
|
729
|
-
Compute the average precision (AP) given the recall and precision curves.
|
|
716
|
+
"""Compute the average precision (AP) given the recall and precision curves.
|
|
730
717
|
|
|
731
718
|
Args:
|
|
732
719
|
recall (list): The recall curve.
|
|
@@ -769,8 +756,7 @@ def ap_per_class(
|
|
|
769
756
|
eps: float = 1e-16,
|
|
770
757
|
prefix: str = "",
|
|
771
758
|
) -> tuple:
|
|
772
|
-
"""
|
|
773
|
-
Compute the average precision per class for object detection evaluation.
|
|
759
|
+
"""Compute the average precision per class for object detection evaluation.
|
|
774
760
|
|
|
775
761
|
Args:
|
|
776
762
|
tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
|
|
@@ -855,8 +841,7 @@ def ap_per_class(
|
|
|
855
841
|
|
|
856
842
|
|
|
857
843
|
class Metric(SimpleClass):
|
|
858
|
-
"""
|
|
859
|
-
Class for computing evaluation metrics for Ultralytics YOLO models.
|
|
844
|
+
"""Class for computing evaluation metrics for Ultralytics YOLO models.
|
|
860
845
|
|
|
861
846
|
Attributes:
|
|
862
847
|
p (list): Precision for each class. Shape: (nc,).
|
|
@@ -894,8 +879,7 @@ class Metric(SimpleClass):
|
|
|
894
879
|
|
|
895
880
|
@property
|
|
896
881
|
def ap50(self) -> np.ndarray | list:
|
|
897
|
-
"""
|
|
898
|
-
Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
|
|
882
|
+
"""Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
|
|
899
883
|
|
|
900
884
|
Returns:
|
|
901
885
|
(np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
|
|
@@ -904,8 +888,7 @@ class Metric(SimpleClass):
|
|
|
904
888
|
|
|
905
889
|
@property
|
|
906
890
|
def ap(self) -> np.ndarray | list:
|
|
907
|
-
"""
|
|
908
|
-
Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
|
|
891
|
+
"""Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
|
|
909
892
|
|
|
910
893
|
Returns:
|
|
911
894
|
(np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
|
|
@@ -914,8 +897,7 @@ class Metric(SimpleClass):
|
|
|
914
897
|
|
|
915
898
|
@property
|
|
916
899
|
def mp(self) -> float:
|
|
917
|
-
"""
|
|
918
|
-
Return the Mean Precision of all classes.
|
|
900
|
+
"""Return the Mean Precision of all classes.
|
|
919
901
|
|
|
920
902
|
Returns:
|
|
921
903
|
(float): The mean precision of all classes.
|
|
@@ -924,8 +906,7 @@ class Metric(SimpleClass):
|
|
|
924
906
|
|
|
925
907
|
@property
|
|
926
908
|
def mr(self) -> float:
|
|
927
|
-
"""
|
|
928
|
-
Return the Mean Recall of all classes.
|
|
909
|
+
"""Return the Mean Recall of all classes.
|
|
929
910
|
|
|
930
911
|
Returns:
|
|
931
912
|
(float): The mean recall of all classes.
|
|
@@ -934,8 +915,7 @@ class Metric(SimpleClass):
|
|
|
934
915
|
|
|
935
916
|
@property
|
|
936
917
|
def map50(self) -> float:
|
|
937
|
-
"""
|
|
938
|
-
Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
|
|
918
|
+
"""Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
|
|
939
919
|
|
|
940
920
|
Returns:
|
|
941
921
|
(float): The mAP at an IoU threshold of 0.5.
|
|
@@ -944,8 +924,7 @@ class Metric(SimpleClass):
|
|
|
944
924
|
|
|
945
925
|
@property
|
|
946
926
|
def map75(self) -> float:
|
|
947
|
-
"""
|
|
948
|
-
Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
|
|
927
|
+
"""Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
|
|
949
928
|
|
|
950
929
|
Returns:
|
|
951
930
|
(float): The mAP at an IoU threshold of 0.75.
|
|
@@ -954,8 +933,7 @@ class Metric(SimpleClass):
|
|
|
954
933
|
|
|
955
934
|
@property
|
|
956
935
|
def map(self) -> float:
|
|
957
|
-
"""
|
|
958
|
-
Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
|
936
|
+
"""Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
|
959
937
|
|
|
960
938
|
Returns:
|
|
961
939
|
(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
|
@@ -984,8 +962,7 @@ class Metric(SimpleClass):
|
|
|
984
962
|
return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
|
|
985
963
|
|
|
986
964
|
def update(self, results: tuple):
|
|
987
|
-
"""
|
|
988
|
-
Update the evaluation metrics with a new set of results.
|
|
965
|
+
"""Update the evaluation metrics with a new set of results.
|
|
989
966
|
|
|
990
967
|
Args:
|
|
991
968
|
results (tuple): A tuple containing evaluation metrics:
|
|
@@ -1030,15 +1007,15 @@ class Metric(SimpleClass):
|
|
|
1030
1007
|
|
|
1031
1008
|
|
|
1032
1009
|
class DetMetrics(SimpleClass, DataExportMixin):
|
|
1033
|
-
"""
|
|
1034
|
-
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
|
1010
|
+
"""Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
|
1035
1011
|
|
|
1036
1012
|
Attributes:
|
|
1037
1013
|
names (dict[int, str]): A dictionary of class names.
|
|
1038
1014
|
box (Metric): An instance of the Metric class for storing detection results.
|
|
1039
1015
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1040
1016
|
task (str): The task type, set to 'detect'.
|
|
1041
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1017
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1018
|
+
target classes, and target images.
|
|
1042
1019
|
nt_per_class: Number of targets per class.
|
|
1043
1020
|
nt_per_image: Number of targets per image.
|
|
1044
1021
|
|
|
@@ -1059,8 +1036,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1059
1036
|
"""
|
|
1060
1037
|
|
|
1061
1038
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1062
|
-
"""
|
|
1063
|
-
Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
|
1039
|
+
"""Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
|
1064
1040
|
|
|
1065
1041
|
Args:
|
|
1066
1042
|
names (dict[int, str], optional): Dictionary of class names.
|
|
@@ -1074,19 +1050,17 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1074
1050
|
self.nt_per_image = None
|
|
1075
1051
|
|
|
1076
1052
|
def update_stats(self, stat: dict[str, Any]) -> None:
|
|
1077
|
-
"""
|
|
1078
|
-
Update statistics by appending new values to existing stat collections.
|
|
1053
|
+
"""Update statistics by appending new values to existing stat collections.
|
|
1079
1054
|
|
|
1080
1055
|
Args:
|
|
1081
|
-
stat (dict[str, any]): Dictionary containing new statistical values to append.
|
|
1082
|
-
|
|
1056
|
+
stat (dict[str, any]): Dictionary containing new statistical values to append. Keys should match existing
|
|
1057
|
+
keys in self.stats.
|
|
1083
1058
|
"""
|
|
1084
1059
|
for k in self.stats.keys():
|
|
1085
1060
|
self.stats[k].append(stat[k])
|
|
1086
1061
|
|
|
1087
1062
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
|
1088
|
-
"""
|
|
1089
|
-
Process predicted results for object detection and update metrics.
|
|
1063
|
+
"""Process predicted results for object detection and update metrics.
|
|
1090
1064
|
|
|
1091
1065
|
Args:
|
|
1092
1066
|
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
|
@@ -1152,8 +1126,8 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1152
1126
|
@property
|
|
1153
1127
|
def results_dict(self) -> dict[str, float]:
|
|
1154
1128
|
"""Return dictionary of computed performance metrics and statistics."""
|
|
1155
|
-
keys = self.keys
|
|
1156
|
-
values = ((float(x) if hasattr(x, "item") else x) for x in (self.mean_results()
|
|
1129
|
+
keys = [*self.keys, "fitness"]
|
|
1130
|
+
values = ((float(x) if hasattr(x, "item") else x) for x in ([*self.mean_results(), self.fitness]))
|
|
1157
1131
|
return dict(zip(keys, values))
|
|
1158
1132
|
|
|
1159
1133
|
@property
|
|
@@ -1167,16 +1141,16 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1167
1141
|
return self.box.curves_results
|
|
1168
1142
|
|
|
1169
1143
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
|
|
1170
|
-
"""
|
|
1171
|
-
|
|
1172
|
-
scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1144
|
+
"""Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes
|
|
1145
|
+
shared scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1173
1146
|
|
|
1174
1147
|
Args:
|
|
1175
|
-
|
|
1176
|
-
|
|
1148
|
+
normalize (bool): For Detect metrics, everything is normalized by default [0-1].
|
|
1149
|
+
decimals (int): Number of decimal places to round the metrics values to.
|
|
1177
1150
|
|
|
1178
1151
|
Returns:
|
|
1179
|
-
|
|
1152
|
+
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1153
|
+
values.
|
|
1180
1154
|
|
|
1181
1155
|
Examples:
|
|
1182
1156
|
>>> results = model.val(data="coco8.yaml")
|
|
@@ -1202,8 +1176,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1202
1176
|
|
|
1203
1177
|
|
|
1204
1178
|
class SegmentMetrics(DetMetrics):
|
|
1205
|
-
"""
|
|
1206
|
-
Calculate and aggregate detection and segmentation metrics over a given set of classes.
|
|
1179
|
+
"""Calculate and aggregate detection and segmentation metrics over a given set of classes.
|
|
1207
1180
|
|
|
1208
1181
|
Attributes:
|
|
1209
1182
|
names (dict[int, str]): Dictionary of class names.
|
|
@@ -1211,7 +1184,8 @@ class SegmentMetrics(DetMetrics):
|
|
|
1211
1184
|
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
|
1212
1185
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1213
1186
|
task (str): The task type, set to 'segment'.
|
|
1214
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1187
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1188
|
+
target classes, and target images.
|
|
1215
1189
|
nt_per_class: Number of targets per class.
|
|
1216
1190
|
nt_per_image: Number of targets per image.
|
|
1217
1191
|
|
|
@@ -1228,8 +1202,7 @@ class SegmentMetrics(DetMetrics):
|
|
|
1228
1202
|
"""
|
|
1229
1203
|
|
|
1230
1204
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1231
|
-
"""
|
|
1232
|
-
Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
|
1205
|
+
"""Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
|
1233
1206
|
|
|
1234
1207
|
Args:
|
|
1235
1208
|
names (dict[int, str], optional): Dictionary of class names.
|
|
@@ -1240,8 +1213,7 @@ class SegmentMetrics(DetMetrics):
|
|
|
1240
1213
|
self.stats["tp_m"] = [] # add additional stats for masks
|
|
1241
1214
|
|
|
1242
1215
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
|
1243
|
-
"""
|
|
1244
|
-
Process the detection and segmentation metrics over the given set of predictions.
|
|
1216
|
+
"""Process the detection and segmentation metrics over the given set of predictions.
|
|
1245
1217
|
|
|
1246
1218
|
Args:
|
|
1247
1219
|
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
|
@@ -1270,7 +1242,8 @@ class SegmentMetrics(DetMetrics):
|
|
|
1270
1242
|
@property
|
|
1271
1243
|
def keys(self) -> list[str]:
|
|
1272
1244
|
"""Return a list of keys for accessing metrics."""
|
|
1273
|
-
return
|
|
1245
|
+
return [
|
|
1246
|
+
*DetMetrics.keys.fget(self),
|
|
1274
1247
|
"metrics/precision(M)",
|
|
1275
1248
|
"metrics/recall(M)",
|
|
1276
1249
|
"metrics/mAP50(M)",
|
|
@@ -1298,7 +1271,8 @@ class SegmentMetrics(DetMetrics):
|
|
|
1298
1271
|
@property
|
|
1299
1272
|
def curves(self) -> list[str]:
|
|
1300
1273
|
"""Return a list of curves for accessing specific metrics curves."""
|
|
1301
|
-
return
|
|
1274
|
+
return [
|
|
1275
|
+
*DetMetrics.curves.fget(self),
|
|
1302
1276
|
"Precision-Recall(M)",
|
|
1303
1277
|
"F1-Confidence(M)",
|
|
1304
1278
|
"Precision-Confidence(M)",
|
|
@@ -1311,16 +1285,17 @@ class SegmentMetrics(DetMetrics):
|
|
|
1311
1285
|
return DetMetrics.curves_results.fget(self) + self.seg.curves_results
|
|
1312
1286
|
|
|
1313
1287
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
|
|
1314
|
-
"""
|
|
1315
|
-
|
|
1316
|
-
|
|
1288
|
+
"""Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes
|
|
1289
|
+
both box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for
|
|
1290
|
+
each class.
|
|
1317
1291
|
|
|
1318
1292
|
Args:
|
|
1319
|
-
normalize (bool): For Segment metrics, everything is normalized
|
|
1293
|
+
normalize (bool): For Segment metrics, everything is normalized by default [0-1].
|
|
1320
1294
|
decimals (int): Number of decimal places to round the metrics values to.
|
|
1321
1295
|
|
|
1322
1296
|
Returns:
|
|
1323
|
-
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1297
|
+
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1298
|
+
values.
|
|
1324
1299
|
|
|
1325
1300
|
Examples:
|
|
1326
1301
|
>>> results = model.val(data="coco8-seg.yaml")
|
|
@@ -1339,8 +1314,7 @@ class SegmentMetrics(DetMetrics):
|
|
|
1339
1314
|
|
|
1340
1315
|
|
|
1341
1316
|
class PoseMetrics(DetMetrics):
|
|
1342
|
-
"""
|
|
1343
|
-
Calculate and aggregate detection and pose metrics over a given set of classes.
|
|
1317
|
+
"""Calculate and aggregate detection and pose metrics over a given set of classes.
|
|
1344
1318
|
|
|
1345
1319
|
Attributes:
|
|
1346
1320
|
names (dict[int, str]): Dictionary of class names.
|
|
@@ -1348,7 +1322,8 @@ class PoseMetrics(DetMetrics):
|
|
|
1348
1322
|
box (Metric): An instance of the Metric class for storing detection results.
|
|
1349
1323
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1350
1324
|
task (str): The task type, set to 'pose'.
|
|
1351
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1325
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1326
|
+
target classes, and target images.
|
|
1352
1327
|
nt_per_class: Number of targets per class.
|
|
1353
1328
|
nt_per_image: Number of targets per image.
|
|
1354
1329
|
|
|
@@ -1365,8 +1340,7 @@ class PoseMetrics(DetMetrics):
|
|
|
1365
1340
|
"""
|
|
1366
1341
|
|
|
1367
1342
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1368
|
-
"""
|
|
1369
|
-
Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
|
1343
|
+
"""Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
|
1370
1344
|
|
|
1371
1345
|
Args:
|
|
1372
1346
|
names (dict[int, str], optional): Dictionary of class names.
|
|
@@ -1377,8 +1351,7 @@ class PoseMetrics(DetMetrics):
|
|
|
1377
1351
|
self.stats["tp_p"] = [] # add additional stats for pose
|
|
1378
1352
|
|
|
1379
1353
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
|
1380
|
-
"""
|
|
1381
|
-
Process the detection and pose metrics over the given set of predictions.
|
|
1354
|
+
"""Process the detection and pose metrics over the given set of predictions.
|
|
1382
1355
|
|
|
1383
1356
|
Args:
|
|
1384
1357
|
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
|
@@ -1407,7 +1380,8 @@ class PoseMetrics(DetMetrics):
|
|
|
1407
1380
|
@property
|
|
1408
1381
|
def keys(self) -> list[str]:
|
|
1409
1382
|
"""Return a list of evaluation metric keys."""
|
|
1410
|
-
return
|
|
1383
|
+
return [
|
|
1384
|
+
*DetMetrics.keys.fget(self),
|
|
1411
1385
|
"metrics/precision(P)",
|
|
1412
1386
|
"metrics/recall(P)",
|
|
1413
1387
|
"metrics/mAP50(P)",
|
|
@@ -1435,7 +1409,8 @@ class PoseMetrics(DetMetrics):
|
|
|
1435
1409
|
@property
|
|
1436
1410
|
def curves(self) -> list[str]:
|
|
1437
1411
|
"""Return a list of curves for accessing specific metrics curves."""
|
|
1438
|
-
return
|
|
1412
|
+
return [
|
|
1413
|
+
*DetMetrics.curves.fget(self),
|
|
1439
1414
|
"Precision-Recall(B)",
|
|
1440
1415
|
"F1-Confidence(B)",
|
|
1441
1416
|
"Precision-Confidence(B)",
|
|
@@ -1452,16 +1427,16 @@ class PoseMetrics(DetMetrics):
|
|
|
1452
1427
|
return DetMetrics.curves_results.fget(self) + self.pose.curves_results
|
|
1453
1428
|
|
|
1454
1429
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
|
|
1455
|
-
"""
|
|
1456
|
-
|
|
1457
|
-
pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1430
|
+
"""Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box
|
|
1431
|
+
and pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1458
1432
|
|
|
1459
1433
|
Args:
|
|
1460
|
-
normalize (bool): For Pose metrics, everything is normalized
|
|
1434
|
+
normalize (bool): For Pose metrics, everything is normalized by default [0-1].
|
|
1461
1435
|
decimals (int): Number of decimal places to round the metrics values to.
|
|
1462
1436
|
|
|
1463
1437
|
Returns:
|
|
1464
|
-
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1438
|
+
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1439
|
+
values.
|
|
1465
1440
|
|
|
1466
1441
|
Examples:
|
|
1467
1442
|
>>> results = model.val(data="coco8-pose.yaml")
|
|
@@ -1480,8 +1455,7 @@ class PoseMetrics(DetMetrics):
|
|
|
1480
1455
|
|
|
1481
1456
|
|
|
1482
1457
|
class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
1483
|
-
"""
|
|
1484
|
-
Class for computing classification metrics including top-1 and top-5 accuracy.
|
|
1458
|
+
"""Class for computing classification metrics including top-1 and top-5 accuracy.
|
|
1485
1459
|
|
|
1486
1460
|
Attributes:
|
|
1487
1461
|
top1 (float): The top-1 accuracy.
|
|
@@ -1507,8 +1481,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1507
1481
|
self.task = "classify"
|
|
1508
1482
|
|
|
1509
1483
|
def process(self, targets: torch.Tensor, pred: torch.Tensor):
|
|
1510
|
-
"""
|
|
1511
|
-
Process target classes and predicted classes to compute metrics.
|
|
1484
|
+
"""Process target classes and predicted classes to compute metrics.
|
|
1512
1485
|
|
|
1513
1486
|
Args:
|
|
1514
1487
|
targets (torch.Tensor): Target classes.
|
|
@@ -1527,7 +1500,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1527
1500
|
@property
|
|
1528
1501
|
def results_dict(self) -> dict[str, float]:
|
|
1529
1502
|
"""Return a dictionary with model's performance metrics and fitness score."""
|
|
1530
|
-
return dict(zip(self.keys
|
|
1503
|
+
return dict(zip([*self.keys, "fitness"], [self.top1, self.top5, self.fitness]))
|
|
1531
1504
|
|
|
1532
1505
|
@property
|
|
1533
1506
|
def keys(self) -> list[str]:
|
|
@@ -1545,11 +1518,10 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1545
1518
|
return []
|
|
1546
1519
|
|
|
1547
1520
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
|
|
1548
|
-
"""
|
|
1549
|
-
Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
|
|
1521
|
+
"""Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
|
|
1550
1522
|
|
|
1551
1523
|
Args:
|
|
1552
|
-
normalize (bool): For Classify metrics, everything is normalized
|
|
1524
|
+
normalize (bool): For Classify metrics, everything is normalized by default [0-1].
|
|
1553
1525
|
decimals (int): Number of decimal places to round the metrics values to.
|
|
1554
1526
|
|
|
1555
1527
|
Returns:
|
|
@@ -1564,15 +1536,15 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1564
1536
|
|
|
1565
1537
|
|
|
1566
1538
|
class OBBMetrics(DetMetrics):
|
|
1567
|
-
"""
|
|
1568
|
-
Metrics for evaluating oriented bounding box (OBB) detection.
|
|
1539
|
+
"""Metrics for evaluating oriented bounding box (OBB) detection.
|
|
1569
1540
|
|
|
1570
1541
|
Attributes:
|
|
1571
1542
|
names (dict[int, str]): Dictionary of class names.
|
|
1572
1543
|
box (Metric): An instance of the Metric class for storing detection results.
|
|
1573
1544
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1574
1545
|
task (str): The task type, set to 'obb'.
|
|
1575
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1546
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1547
|
+
target classes, and target images.
|
|
1576
1548
|
nt_per_class: Number of targets per class.
|
|
1577
1549
|
nt_per_image: Number of targets per image.
|
|
1578
1550
|
|
|
@@ -1581,8 +1553,7 @@ class OBBMetrics(DetMetrics):
|
|
|
1581
1553
|
"""
|
|
1582
1554
|
|
|
1583
1555
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1584
|
-
"""
|
|
1585
|
-
Initialize an OBBMetrics instance with directory, plotting, and class names.
|
|
1556
|
+
"""Initialize an OBBMetrics instance with directory, plotting, and class names.
|
|
1586
1557
|
|
|
1587
1558
|
Args:
|
|
1588
1559
|
names (dict[int, str], optional): Dictionary of class names.
|