dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
```diff
@@ -1,14 +1,18 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """Model validation metrics."""
 
+from __future__ import annotations
+
 import math
 import warnings
+from collections import defaultdict
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
 
-from ultralytics.utils import LOGGER, SimpleClass, TryExcept, checks, plt_settings
+from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings
 
 OKS_SIGMA = (
     np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
@@ -16,18 +20,17 @@ OKS_SIGMA = (
 )
 
 
-def bbox_ioa(box1, box2, iou=False, eps=1e-7):
-    """
-    Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
+def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
+    """Calculate the intersection over box2 area given box1 and box2.
 
     Args:
-        box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
-        box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
-        iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
+        box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
+        box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.
+        iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.
         eps (float, optional): A small value to avoid division by zero.
 
     Returns:
-        (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
+        (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.
     """
     # Get the coordinates of bounding boxes
     b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
@@ -48,18 +51,19 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
     return inter_area / (area + eps)
 
 
-def box_iou(box1, box2, eps=1e-7):
-    """
-    Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
-    Based on https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py.
+def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate intersection-over-union (IoU) of boxes.
 
     Args:
-        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
-        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
+        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
+        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.
         eps (float, optional): A small value to avoid division by zero.
 
     Returns:
         (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
+
+    References:
+        https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py
     """
     # NOTE: Need .float() to get accurate iou values
     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
@@ -70,20 +74,26 @@ def box_iou(box1, box2, eps=1e-7):
     return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
 
 
-def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
-    """
-    Calculate the Intersection over Union (IoU) between bounding boxes.
+def bbox_iou(
+    box1: torch.Tensor,
+    box2: torch.Tensor,
+    xywh: bool = True,
+    GIoU: bool = False,
+    DIoU: bool = False,
+    CIoU: bool = False,
+    eps: float = 1e-7,
+) -> torch.Tensor:
+    """Calculate the Intersection over Union (IoU) between bounding boxes.
 
-    This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
-    For instance, you may pass tensors shaped like (4,) or (N, 4) or (B, N, 4) or (B, N, 1, 4).
-    Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
-    or (x1, y1, x2, y2) if `xywh=False`.
+    This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. For instance, you
+    may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). Internally, the code will split the last
+    dimension into (x, y, w, h) if `xywh=True`, or (x1, y1, x2, y2) if `xywh=False`.
 
     Args:
         box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
         box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
-        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
-            (x1, y1, x2, y2) format.
+        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in (x1, y1,
+            x2, y2) format.
         GIoU (bool, optional): If True, calculate Generalized IoU.
         DIoU (bool, optional): If True, calculate Distance IoU.
         CIoU (bool, optional): If True, calculate Complete IoU.
@@ -133,15 +143,14 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
     return iou  # IoU
 
 
-def mask_iou(mask1, mask2, eps=1e-7):
-    """
-    Calculate masks IoU.
+def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate masks IoU.
 
     Args:
         mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
-            product of image width and height.
-        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
-            product of image width and height.
+            product of image width and height.
+        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the product
+            of image width and height.
         eps (float, optional): A small value to avoid division by zero.
 
     Returns:
@@ -152,9 +161,10 @@ def mask_iou(mask1, mask2, eps=1e-7):
     return intersection / (union + eps)
 
 
-def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
-    """
-    Calculate Object Keypoint Similarity (OKS).
+def kpt_iou(
+    kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
+) -> torch.Tensor:
+    """Calculate Object Keypoint Similarity (OKS).
 
     Args:
         kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
@@ -174,9 +184,8 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
     return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
 
 
-def _get_covariance_matrix(boxes):
-    """
-    Generate covariance matrix from oriented bounding boxes.
+def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Generate covariance matrix from oriented bounding boxes.
 
     Args:
         boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
@@ -194,9 +203,8 @@ def _get_covariance_matrix(boxes):
     return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin
 
 
-def probiou(obb1, obb2, CIoU=False, eps=1e-7):
-    """
-    Calculate probabilistic IoU between oriented bounding boxes.
+def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate probabilistic IoU between oriented bounding boxes.
 
     Args:
         obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
@@ -208,8 +216,10 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
         (torch.Tensor): OBB similarities, shape (N,).
 
     Notes:
-        OBB format: [center_x, center_y, width, height, rotation_angle].
-        Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
+        OBB format: [center_x, center_y, width, height, rotation_angle].
+
+    References:
+        https://arxiv.org/pdf/2106.06072v1.pdf
     """
     x1, y1 = obb1[..., :2].split(1, dim=-1)
     x2, y2 = obb2[..., :2].split(1, dim=-1)
@@ -238,9 +248,8 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
     return iou
 
 
-def batch_probiou(obb1, obb2, eps=1e-7):
-    """
-    Calculate the probabilistic IoU between oriented bounding boxes.
+def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate the probabilistic IoU between oriented bounding boxes.
 
     Args:
         obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
@@ -275,15 +284,15 @@ def batch_probiou(obb1, obb2, eps=1e-7):
     return 1 - hd
 
 
-def smooth_bce(eps=0.1):
-    """
-    Compute smoothed positive and negative Binary Cross-Entropy targets.
+def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
+    """Compute smoothed positive and negative Binary Cross-Entropy targets.
 
     Args:
         eps (float, optional): The epsilon value for label smoothing.
 
     Returns:
-        (tuple): A tuple containing the positive and negative label smoothing BCE targets.
+        pos (float): Positive label smoothing BCE target.
+        neg (float): Negative label smoothing BCE target.
 
     References:
         https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
```
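Most of the churn above is mechanical: the module gains `from __future__ import annotations`, `defaultdict`, `typing.Any`, and the new `DataExportMixin` import, every IoU helper gets full type hints, and reference URLs move into dedicated `References:` docstring sections. Call patterns for the helpers are unchanged; here is a minimal sketch with invented toy boxes:

```python
import torch
from ultralytics.utils.metrics import bbox_iou, box_iou, smooth_bce

gt = torch.tensor([[0.0, 0.0, 10.0, 10.0]])  # one ground-truth box, (x1, y1, x2, y2)
preds = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])  # two predicted boxes

print(box_iou(gt, preds).shape)  # pairwise IoU matrix -> torch.Size([1, 2])
print(bbox_iou(gt, preds, xywh=False, CIoU=True))  # Complete IoU; any shapes with last dim 4
print(smooth_bce(eps=0.1))  # label-smoothed BCE targets -> (0.95, 0.05)
```

The largest change in the file is the `ConfusionMatrix` rewrite: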
```diff
@@ -291,81 +300,115 @@ def smooth_bce(eps=0.1):
     return 1.0 - 0.5 * eps, 0.5 * eps
 
 
-class ConfusionMatrix:
-    """
-    A class for calculating and updating a confusion matrix for object detection and classification tasks.
+class ConfusionMatrix(DataExportMixin):
+    """A class for calculating and updating a confusion matrix for object detection and classification tasks.
 
     Attributes:
         task (str): The type of task, either 'detect' or 'classify'.
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
-        nc (int): The number of classes.
-        conf (float): The confidence threshold for detections.
-        iou_thres (float): The Intersection over Union threshold.
+        nc (int): The number of category.
+        names (list[str]): The names of the classes, used as labels on the plot.
+        matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
     """
 
-    def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
-        """
-        Initialize a ConfusionMatrix instance.
+    def __init__(self, names: dict[int, str] = [], task: str = "detect", save_matches: bool = False):
+        """Initialize a ConfusionMatrix instance.
 
         Args:
-            nc (int): Number of classes.
-            conf (float, optional): Confidence threshold for detections.
-            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
+            names (dict[int, str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
+            save_matches (bool, optional): Save the indices of GTs, TPs, FPs, FNs for visualization.
         """
         self.task = task
-        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
-        self.nc = nc  # number of classes
-        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
-        self.iou_thres = iou_thres
+        self.nc = len(names)  # number of classes
+        self.matrix = np.zeros((self.nc, self.nc)) if self.task == "classify" else np.zeros((self.nc + 1, self.nc + 1))
+        self.names = names  # name of classes
+        self.matches = {} if save_matches else None
+
+    def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
+        """Append the matches to TP, FP, FN or GT list for the last batch.
 
-    def process_cls_preds(self, preds, targets):
+        This method updates the matches dictionary by appending specific batch data to the appropriate match type (True
+        Positive, False Positive, or False Negative).
+
+        Args:
+            mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
+            batch (dict[str, Any]): Batch data containing detection results with keys like 'bboxes', 'cls', 'conf',
+                'keypoints', 'masks'.
+            idx (int): Index of the specific detection to append from the batch.
+
+        Notes:
+            For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, it indicates
+            overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
         """
-        Update confusion matrix for classification task.
+        if self.matches is None:
+            return
+        for k, v in batch.items():
+            if k in {"bboxes", "cls", "conf", "keypoints"}:
+                self.matches[mtype][k] += v[[idx]]
+            elif k == "masks":
+                # NOTE: masks.max() > 1.0 means overlap_mask=True with (1, H, W) shape
+                self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
+
+    def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
+        """Update confusion matrix for classification task.
 
         Args:
-            preds (Array[N, min(nc,5)]): Predicted class labels.
-            targets (Array[N, 1]): Ground truth class labels.
+            preds (list[N, min(nc,5)]): Predicted class labels.
+            targets (list[N, 1]): Ground truth class labels.
         """
         preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1
 
-    def process_batch(self, detections, gt_bboxes, gt_cls):
-        """
-        Update confusion matrix for object detection task.
+    def process_batch(
+        self,
+        detections: dict[str, torch.Tensor],
+        batch: dict[str, Any],
+        conf: float = 0.25,
+        iou_thres: float = 0.45,
+    ) -> None:
+        """Update confusion matrix for object detection task.
 
         Args:
-            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
-                Each row should contain (x1, y1, x2, y2, conf, class)
-                or with an additional element `angle` when it's obb.
-            gt_bboxes (Array[M, 4] | Array[M, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
-            gt_cls (Array[M]): The class labels.
+            detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
+                information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be Array[N, 4] for
+                regular boxes or Array[N, 5] for OBB with angle.
+            batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
+                5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
+            conf (float, optional): Confidence threshold for detections.
+            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
+        gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        if self.matches is not None:  # only if visualization is enabled
+            self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}}
+            for i in range(gt_cls.shape[0]):
+                self._append_matches("GT", batch, i)  # store GT
+        is_obb = gt_bboxes.shape[1] == 5  # check if boxes contains angle for OBB
+        conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf  # apply 0.25 if default val conf is passed
+        no_pred = detections["cls"].shape[0] == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
-            if detections is not None:
-                detections = detections[detections[:, 4] > self.conf]
-                detection_classes = detections[:, 5].int()
-                for dc in detection_classes:
-                    self.matrix[dc, self.nc] += 1  # predicted background
+            if not no_pred:
+                detections = {k: detections[k][detections["conf"] > conf] for k in detections}
+                detection_classes = detections["cls"].int().tolist()
+                for i, dc in enumerate(detection_classes):
+                    self.matrix[dc, self.nc] += 1  # FP
+                    self._append_matches("FP", detections, i)
             return
-        if detections is None:
-            gt_classes = gt_cls.int()
-            for gc in gt_classes:
-                self.matrix[self.nc, gc] += 1  # background FN
+        if no_pred:
+            gt_classes = gt_cls.int().tolist()
+            for i, gc in enumerate(gt_classes):
+                self.matrix[self.nc, gc] += 1  # FN
+                self._append_matches("FN", batch, i)
             return
 
-        detections = detections[detections[:, 4] > self.conf]
-        gt_classes = gt_cls.int()
-        detection_classes = detections[:, 5].int()
-        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
-        iou = (
-            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-            if is_obb
-            else box_iou(gt_bboxes, detections[:, :4])
-        )
+        detections = {k: detections[k][detections["conf"] > conf] for k in detections}
+        gt_classes = gt_cls.int().tolist()
+        detection_classes = detections["cls"].int().tolist()
+        bboxes = detections["bboxes"]
+        iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
 
-        x = torch.where(iou > self.iou_thres)
+        x = torch.where(iou > iou_thres)
         if x[0].shape[0]:
             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
             if x[0].shape[0] > 1:
@@ -381,59 +424,100 @@ class ConfusionMatrix:
         for i, gc in enumerate(gt_classes):
             j = m0 == i
             if n and sum(j) == 1:
-                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
+                dc = detection_classes[m1[j].item()]
+                self.matrix[dc, gc] += 1  # TP if class is correct else both an FP and an FN
+                if dc == gc:
+                    self._append_matches("TP", detections, m1[j].item())
+                else:
+                    self._append_matches("FP", detections, m1[j].item())
+                    self._append_matches("FN", batch, i)
             else:
-                self.matrix[self.nc, gc] += 1  # true background
+                self.matrix[self.nc, gc] += 1  # FN
+                self._append_matches("FN", batch, i)
 
         for i, dc in enumerate(detection_classes):
             if not any(m1 == i):
-                self.matrix[dc, self.nc] += 1  # predicted background
+                self.matrix[dc, self.nc] += 1  # FP
+                self._append_matches("FP", detections, i)
 
     def matrix(self):
         """Return the confusion matrix."""
         return self.matrix
 
-    def tp_fp(self):
-        """
-        Return true positives and false positives.
+    def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
+        """Return true positives and false positives.
 
         Returns:
-            (tuple): True positives and false positives.
+            tp (np.ndarray): True positives.
+            fp (np.ndarray): False positives.
         """
         tp = self.matrix.diagonal()  # true positives
         fp = self.matrix.sum(1) - tp  # false positives
         # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
-        return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp)
+        return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1])  # remove background class if task=detect
+
+    def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
+        """Plot grid of GT, TP, FP, FN for each image.
+
+        Args:
+            img (torch.Tensor): Image to plot onto.
+            im_file (str): Image filename to save visualizations.
+            save_dir (Path): Location to save the visualizations to.
+        """
+        if not self.matches:
+            return
+        from .ops import xyxy2xywh
+        from .plotting import plot_images
+
+        # Create batch of 4 (GT, TP, FP, FN)
+        labels = defaultdict(list)
+        for i, mtype in enumerate(["GT", "FP", "TP", "FN"]):
+            mbatch = self.matches[mtype]
+            if "conf" not in mbatch:
+                mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device)
+            mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i
+            for k in mbatch.keys():
+                labels[k] += mbatch[k]
+
+        labels = {k: torch.stack(v, 0) if len(v) else torch.empty(0) for k, v in labels.items()}
+        if self.task != "obb" and labels["bboxes"].shape[0]:
+            labels["bboxes"] = xyxy2xywh(labels["bboxes"])
+        (save_dir / "visualizations").mkdir(parents=True, exist_ok=True)
+        plot_images(
+            labels,
+            img.repeat(4, 1, 1, 1),
+            paths=["Ground Truth", "False Positives", "True Positives", "False Negatives"],
+            fname=save_dir / "visualizations" / Path(im_file).name,
+            names=self.names,
+            max_subplots=4,
+            conf_thres=0.001,
+        )
 
     @TryExcept(msg="ConfusionMatrix plot failure")
     @plt_settings()
-    def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
-        """
-        Plot the confusion matrix using matplotlib and save it to a file.
+    def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
+        """Plot the confusion matrix using matplotlib and save it to a file.
 
         Args:
-            normalize (bool): Whether to normalize the confusion matrix.
-            save_dir (str): Directory where the plot will be saved.
-            names (tuple): Names of classes, used as labels on the plot.
-            on_plot (func): An optional callback to pass plots path and data when they are rendered.
+            normalize (bool, optional): Whether to normalize the confusion matrix.
+            save_dir (str, optional): Directory where the plot will be saved.
+            on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered.
         """
         import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
 
         array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)  # normalize columns
         array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)
 
-        names = list(names)
         fig, ax = plt.subplots(1, 1, figsize=(12, 9))
+        names, n = list(self.names.values()), self.nc
         if self.nc >= 100:  # downsample for large class count
             k = max(2, self.nc // 60)  # step size for downsampling, always > 1
             keep_idx = slice(None, None, k)  # create slice instead of array
             names = names[keep_idx]  # slice class names
             array = array[keep_idx, :][:, keep_idx]  # slice matrix rows and cols
             n = (self.nc + k - 1) // k  # number of retained classes
-            nc = nn = n if self.task == "classify" else n + 1
-        else:
-            nc = nn = self.nc if self.task == "classify" else self.nc + 1
-        ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
+        nc = nn = n if self.task == "classify" else n + 1  # adjust for background if needed
+        ticklabels = ([*names, "background"]) if (0 < nn < 99) and (nn == nc) else "auto"
         xy_ticks = np.arange(len(ticklabels))
         tick_fontsize = max(6, 15 - 0.1 * nc)  # Minimum size is 6
         label_fontsize = max(6, 12 - 0.1 * nc)
@@ -444,6 +528,7 @@ class ConfusionMatrix:
         im = ax.imshow(array, cmap="Blues", vmin=0.0, interpolation="none")
         ax.xaxis.set_label_position("bottom")
         if nc < 30:  # Add score for each cell of confusion matrix
+            color_threshold = 0.45 * (1 if normalize else np.nanmax(array))  # text color threshold
             for i, row in enumerate(array[:nc]):
                 for j, val in enumerate(row[:nc]):
                     val = array[i, j]
@@ -456,7 +541,7 @@ class ConfusionMatrix:
                         ha="center",
                         va="center",
                         fontsize=10,
-                        color="white" if val > 0.45 else "black",
+                        color="white" if val > color_threshold else "black",
                     )
         cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05)
         title = "Confusion Matrix" + " Normalized" * normalize
@@ -470,7 +555,7 @@ class ConfusionMatrix:
         if ticklabels != "auto":
             ax.set_xticklabels(ticklabels, fontsize=tick_fontsize, rotation=90, ha="center")
             ax.set_yticklabels(ticklabels, fontsize=tick_fontsize)
-        for s in ["left", "right", "bottom", "top", "outline"]:
+        for s in {"left", "right", "bottom", "top", "outline"}:
             if s != "outline":
                 ax.spines[s].set_visible(False)  # Confusion matrix plot don't have outline
             cbar.ax.spines[s].set_visible(False)
@@ -486,8 +571,45 @@ class ConfusionMatrix:
         for i in range(self.matrix.shape[0]):
             LOGGER.info(" ".join(map(str, self.matrix[i])))
 
+    def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
+        """Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
+        normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON,
+        or SQL.
+
+        Args:
+            normalize (bool): Whether to normalize the confusion matrix values.
+            decimals (int): Number of decimal places to round the output values to.
+
+        Returns:
+            (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
+                values for all actual classes.
+
+        Examples:
+            >>> results = model.val(data="coco8.yaml", plots=True)
+            >>> cm_dict = results.confusion_matrix.summary(normalize=True, decimals=5)
+            >>> print(cm_dict)
+        """
+        import re
+
+        names = list(self.names.values()) if self.task == "classify" else [*list(self.names.values()), "background"]
+        clean_names, seen = [], set()
+        for name in names:
+            clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
+            original_clean = clean_name
+            counter = 1
+            while clean_name.lower() in seen:
+                clean_name = f"{original_clean}_{counter}"
+                counter += 1
+            seen.add(clean_name.lower())
+            clean_names.append(clean_name)
+        array = (self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)).round(decimals)
+        return [
+            dict({"Predicted": clean_names[i]}, **{clean_names[j]: array[i, j] for j in range(len(clean_names))})
+            for i in range(len(clean_names))
+        ]
 
-def smooth(y, f=0.05):
+
+def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray:
     """Box filter of fraction f."""
     nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
     p = np.ones(nf // 2)  # ones padding
```
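`ConfusionMatrix` now mixes in `DataExportMixin`, derives `nc` from a `names` dict, moves the `conf`/`iou_thres` thresholds from the constructor into `process_batch()`, and consumes predictions and ground truth as dictionaries instead of packed `(N, 6)` tensors. A minimal sketch of the new calling convention (class names and box values are invented):

```python
import torch
from ultralytics.utils.metrics import ConfusionMatrix

cm = ConfusionMatrix(names={0: "person", 1: "car"}, task="detect")
detections = {  # predictions as per-key tensors rather than one (N, 6) tensor
    "bboxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    "conf": torch.tensor([0.9]),
    "cls": torch.tensor([0.0]),
}
batch = {  # ground truth uses the same keys, minus confidence
    "bboxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    "cls": torch.tensor([0.0]),
}
cm.process_batch(detections, batch, conf=0.25, iou_thres=0.45)
print(cm.matrix)  # 3x3 array: two classes plus a background row/column
print(cm.summary())  # new export-friendly per-class breakdown
```

Passing `save_matches=True` additionally records GT/TP/FP/FN indices per image so the new `plot_matches()` can render a four-panel visualization. The plotting and AP helpers follow: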
```diff
@@ -496,16 +618,22 @@ def smooth(y, f=0.05):
 
 
 @plt_settings()
-def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
-    """
-    Plot precision-recall curve.
+def plot_pr_curve(
+    px: np.ndarray,
+    py: np.ndarray,
+    ap: np.ndarray,
+    save_dir: Path = Path("pr_curve.png"),
+    names: dict[int, str] = {},
+    on_plot=None,
+):
+    """Plot precision-recall curve.
 
     Args:
         px (np.ndarray): X values for the PR curve.
         py (np.ndarray): Y values for the PR curve.
         ap (np.ndarray): Average precision values.
         save_dir (Path, optional): Path to save the plot.
-        names (dict, optional): Dictionary mapping class indices to class names.
+        names (dict[int, str], optional): Dictionary mapping class indices to class names.
         on_plot (callable, optional): Function to call after plot is saved.
     """
     import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
@@ -517,7 +645,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
         for i, y in enumerate(py.T):
             ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
     else:
-        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)
+        ax.plot(px, py, linewidth=1, color="gray")  # plot(recall, precision)
 
     ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
     ax.set_xlabel("Recall")
@@ -533,15 +661,22 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
 
 
 @plt_settings()
-def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
-    """
-    Plot metric-confidence curve.
+def plot_mc_curve(
+    px: np.ndarray,
+    py: np.ndarray,
+    save_dir: Path = Path("mc_curve.png"),
+    names: dict[int, str] = {},
+    xlabel: str = "Confidence",
+    ylabel: str = "Metric",
+    on_plot=None,
+):
+    """Plot metric-confidence curve.
 
     Args:
         px (np.ndarray): X values for the metric-confidence curve.
         py (np.ndarray): Y values for the metric-confidence curve.
         save_dir (Path, optional): Path to save the plot.
-        names (dict, optional): Dictionary mapping class indices to class names.
+        names (dict[int, str], optional): Dictionary mapping class indices to class names.
         xlabel (str, optional): X-axis label.
         ylabel (str, optional): Y-axis label.
         on_plot (callable, optional): Function to call after plot is saved.
@@ -554,7 +689,7 @@ def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confi
         for i, y in enumerate(py):
             ax.plot(px, y, linewidth=1, label=f"{names[i]}")  # plot(confidence, metric)
     else:
-        ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)
+        ax.plot(px, py.T, linewidth=1, color="gray")  # plot(confidence, metric)
 
     y = smooth(py.mean(0), 0.1)
     ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
@@ -570,18 +705,17 @@ def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confi
         on_plot(save_dir)
 
 
-def compute_ap(recall, precision):
-    """
-    Compute the average precision (AP) given the recall and precision curves.
+def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
+    """Compute the average precision (AP) given the recall and precision curves.
 
     Args:
         recall (list): The recall curve.
         precision (list): The precision curve.
 
     Returns:
-        (float): Average precision.
-        (np.ndarray): Precision envelope curve.
-        (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
+        ap (float): Average precision.
+        mpre (np.ndarray): Precision envelope curve.
+        mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
     """
     # Append sentinel values to beginning and end
     mrec = np.concatenate(([0.0], recall, [1.0]))
@@ -604,10 +738,18 @@ def compute_ap(recall, precision):
 
 
 def ap_per_class(
-    tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
-):
-    """
-    Compute the average precision per class for object detection evaluation.
+    tp: np.ndarray,
+    conf: np.ndarray,
+    pred_cls: np.ndarray,
+    target_cls: np.ndarray,
+    plot: bool = False,
+    on_plot=None,
+    save_dir: Path = Path(),
+    names: dict[int, str] = {},
+    eps: float = 1e-16,
+    prefix: str = "",
+) -> tuple:
+    """Compute the average precision per class for object detection evaluation.
 
     Args:
         tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
@@ -615,9 +757,9 @@ def ap_per_class(
         pred_cls (np.ndarray): Array of predicted classes of the detections.
         target_cls (np.ndarray): Array of true classes of the detections.
         plot (bool, optional): Whether to plot PR curves or not.
-        on_plot (func, optional): A callback to pass plots path and data when they are rendered.
+        on_plot (callable, optional): A callback to pass plots path and data when they are rendered.
         save_dir (Path, optional): Directory to save the PR curves.
-        names (dict, optional): Dict of class names to plot PR curves.
+        names (dict[int, str], optional): Dictionary of class names to plot PR curves.
         eps (float, optional): A small value to avoid division by zero.
         prefix (str, optional): A prefix string for saving the plot files.
 
@@ -677,8 +819,7 @@ def ap_per_class(
 
     # Compute F1 (harmonic mean of precision and recall)
     f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
-    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
-    names = dict(enumerate(names))  # to dict
+    names = {i: names[k] for i, k in enumerate(unique_classes) if k in names}  # dict: only classes that have data
     if plot:
        plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
        plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
```
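`compute_ap` is functionally unchanged; its docstring now names the three returned values (`ap`, `mpre`, `mrec`). A toy run with invented recall/precision points:

```python
from ultralytics.utils.metrics import compute_ap

recall = [0.2, 0.5, 1.0]
precision = [1.0, 0.8, 0.5]
ap, mpre, mrec = compute_ap(recall, precision)
print(round(float(ap), 3))  # area under the interpolated precision envelope
print(mrec)  # recall curve with 0.0/1.0 sentinel values added
```

The `Metric` and `DetMetrics` containers follow: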
```diff
@@ -693,8 +834,7 @@ def ap_per_class(
 
 
 class Metric(SimpleClass):
-    """
-    Class for computing evaluation metrics for Ultralytics YOLO models.
+    """Class for computing evaluation metrics for Ultralytics YOLO models.
 
     Attributes:
         p (list): Precision for each class. Shape: (nc,).
@@ -705,18 +845,20 @@ class Metric(SimpleClass):
         nc (int): Number of classes.
 
     Methods:
-        ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,).
-        ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,).
-        mp(): Mean precision of all classes. Returns: Float.
-        mr(): Mean recall of all classes. Returns: Float.
-        map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
-        map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
-        map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
-        mean_results(): Mean of results, returns mp, mr, map50, map.
-        class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
-        maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
-        fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
-        update(results): Update metric attributes with new evaluation results.
+        ap50: AP at IoU threshold of 0.5 for all classes.
+        ap: AP at IoU thresholds from 0.5 to 0.95 for all classes.
+        mp: Mean precision of all classes.
+        mr: Mean recall of all classes.
+        map50: Mean AP at IoU threshold of 0.5 for all classes.
+        map75: Mean AP at IoU threshold of 0.75 for all classes.
+        map: Mean AP at IoU thresholds from 0.5 to 0.95 for all classes.
+        mean_results: Mean of results, returns mp, mr, map50, map.
+        class_result: Class-aware result, returns p[i], r[i], ap50[i], ap[i].
+        maps: mAP of each class.
+        fitness: Model fitness as a weighted combination of metrics.
+        update: Update metric attributes with new evaluation results.
+        curves: Provides a list of curves for accessing specific metrics like precision, recall, F1, etc.
+        curves_results: Provide a list of results for accessing specific metrics like precision, recall, F1, etc.
     """
 
     def __init__(self) -> None:
@@ -729,29 +871,26 @@ class Metric(SimpleClass):
         self.nc = 0
 
     @property
-    def ap50(self):
-        """
-        Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
+    def ap50(self) -> np.ndarray | list:
+        """Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
 
         Returns:
-            (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
+            (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
         """
         return self.all_ap[:, 0] if len(self.all_ap) else []
 
     @property
-    def ap(self):
-        """
-        Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
+    def ap(self) -> np.ndarray | list:
+        """Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
 
         Returns:
-            (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
+            (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
         """
         return self.all_ap.mean(1) if len(self.all_ap) else []
 
     @property
-    def mp(self):
-        """
-        Return the Mean Precision of all classes.
+    def mp(self) -> float:
+        """Return the Mean Precision of all classes.
 
         Returns:
             (float): The mean precision of all classes.
@@ -759,9 +898,8 @@ class Metric(SimpleClass):
         return self.p.mean() if len(self.p) else 0.0
 
     @property
-    def mr(self):
-        """
-        Return the Mean Recall of all classes.
+    def mr(self) -> float:
+        """Return the Mean Recall of all classes.
 
         Returns:
             (float): The mean recall of all classes.
@@ -769,9 +907,8 @@ class Metric(SimpleClass):
         return self.r.mean() if len(self.r) else 0.0
 
     @property
-    def map50(self):
-        """
-        Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
+    def map50(self) -> float:
+        """Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
 
         Returns:
             (float): The mAP at an IoU threshold of 0.5.
@@ -779,9 +916,8 @@ class Metric(SimpleClass):
         return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
 
     @property
-    def map75(self):
-        """
-        Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
+    def map75(self) -> float:
+        """Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
 
         Returns:
             (float): The mAP at an IoU threshold of 0.75.
@@ -789,39 +925,37 @@ class Metric(SimpleClass):
         return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
 
     @property
-    def map(self):
-        """
-        Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
+    def map(self) -> float:
+        """Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
 
         Returns:
             (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
         """
         return self.all_ap.mean() if len(self.all_ap) else 0.0
 
-    def mean_results(self):
+    def mean_results(self) -> list[float]:
         """Return mean of results, mp, mr, map50, map."""
         return [self.mp, self.mr, self.map50, self.map]
 
-    def class_result(self, i):
+    def class_result(self, i: int) -> tuple[float, float, float, float]:
         """Return class-aware result, p[i], r[i], ap50[i], ap[i]."""
         return self.p[i], self.r[i], self.ap50[i], self.ap[i]
 
     @property
-    def maps(self):
+    def maps(self) -> np.ndarray:
         """Return mAP of each class."""
         maps = np.zeros(self.nc) + self.map
         for i, c in enumerate(self.ap_class_index):
             maps[c] = self.ap[i]
         return maps
 
-    def fitness(self):
+    def fitness(self) -> float:
         """Return model fitness as a weighted combination of metrics."""
-        w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
+        w = [0.0, 0.0, 0.0, 1.0]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
         return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
 
-    def update(self, results):
-        """
-        Update the evaluation metrics with a new set of results.
+    def update(self, results: tuple):
+        """Update the evaluation metrics with a new set of results.
 
         Args:
             results (tuple): A tuple containing evaluation metrics:
@@ -850,12 +984,12 @@ class Metric(SimpleClass):
         ) = results
 
     @property
-    def curves(self):
+    def curves(self) -> list:
         """Return a list of curves for accessing specific metrics curves."""
         return []
 
     @property
-    def curves_results(self):
+    def curves_results(self) -> list[list]:
         """Return a list of curves for accessing specific metrics curves."""
         return [
             [self.px, self.prec_values, "Recall", "Precision"],
@@ -865,227 +999,273 @@ class Metric(SimpleClass):
         ]
 
 
-class DetMetrics(SimpleClass):
-    """
-    Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
+class DetMetrics(SimpleClass, DataExportMixin):
+    """Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
 
     Attributes:
-        save_dir (Path): A path to the directory where the output plots will be saved.
-        plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
-        names (dict): A dictionary of class names.
+        names (dict[int, str]): A dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
-        speed (dict): A dictionary for storing execution times of different parts of the detection process.
+        speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'detect'.
+        stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
+            target classes, and target images.
+        nt_per_class: Number of targets per class.
+        nt_per_image: Number of targets per image.
+
+    Methods:
+        update_stats: Update statistics by appending new values to existing stat collections.
+        process: Process predicted results for object detection and update metrics.
+        clear_stats: Clear the stored statistics.
+        keys: Return a list of keys for accessing specific metrics.
+        mean_results: Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.
+        class_result: Return the result of evaluating the performance of an object detection model on a specific class.
+        maps: Return mean Average Precision (mAP) scores per class.
+        fitness: Return the fitness of box object.
+        ap_class_index: Return the average precision index per class.
+        results_dict: Return dictionary of computed performance metrics and statistics.
+        curves: Return a list of curves for accessing specific metrics curves.
+        curves_results: Return a list of computed performance metrics and statistics.
+        summary: Generate a summarized representation of per-class detection metrics as a list of dictionaries.
     """
 
-    def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
-        """
-        Initialize a DetMetrics instance with a save directory, plot flag, and class names.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize a DetMetrics instance with a save directory, plot flag, and class names.
 
         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
         self.names = names
         self.box = Metric()
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "detect"
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.nt_per_class = None
+        self.nt_per_image = None
 
-    def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
+    def update_stats(self, stat: dict[str, Any]) -> None:
+        """Update statistics by appending new values to existing stat collections.
+
+        Args:
+            stat (dict[str, any]): Dictionary containing new statistical values to append. Keys should match existing
+                keys in self.stats.
         """
-        Process predicted results for object detection and update metrics.
+        for k in self.stats.keys():
+            self.stats[k].append(stat[k])
+
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
+        """Process predicted results for object detection and update metrics.
 
         Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()}  # to numpy
+        if not stats:
+            return stats
         results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
+            stats["tp"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
+            save_dir=save_dir,
             names=self.names,
             on_plot=on_plot,
+            prefix="Box",
         )[2:]
         self.box.nc = len(self.names)
         self.box.update(results)
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
+        return stats
+
+    def clear_stats(self):
+        """Clear the stored statistics."""
+        for v in self.stats.values():
+            v.clear()
 
     @property
-    def keys(self):
+    def keys(self) -> list[str]:
         """Return a list of keys for accessing specific metrics."""
         return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
 
-    def mean_results(self):
+    def mean_results(self) -> list[float]:
         """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
         return self.box.mean_results()
 
-    def class_result(self, i):
+    def class_result(self, i: int) -> tuple[float, float, float, float]:
         """Return the result of evaluating the performance of an object detection model on a specific class."""
         return self.box.class_result(i)
 
     @property
-    def maps(self):
+    def maps(self) -> np.ndarray:
         """Return mean Average Precision (mAP) scores per class."""
         return self.box.maps
 
     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return the fitness of box object."""
         return self.box.fitness()
 
     @property
-    def ap_class_index(self):
+    def ap_class_index(self) -> list:
         """Return the average precision index per class."""
         return self.box.ap_class_index
 
     @property
-    def results_dict(self):
+    def results_dict(self) -> dict[str, float]:
         """Return dictionary of computed performance metrics and statistics."""
```
|
|
952
|
-
|
|
1122
|
+
keys = [*self.keys, "fitness"]
|
|
1123
|
+
values = ((float(x) if hasattr(x, "item") else x) for x in ([*self.mean_results(), self.fitness]))
|
|
1124
|
+
return dict(zip(keys, values))
|
|
953
1125
|
|
|
954
1126
|
@property
|
|
955
|
-
def curves(self):
|
|
1127
|
+
def curves(self) -> list[str]:
|
|
956
1128
|
"""Return a list of curves for accessing specific metrics curves."""
|
|
957
1129
|
return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
|
|
958
1130
|
|
|
959
1131
|
@property
|
|
960
|
-
def curves_results(self):
|
|
961
|
-
"""Return
|
|
1132
|
+
def curves_results(self) -> list[list]:
|
|
1133
|
+
"""Return a list of computed performance metrics and statistics."""
|
|
962
1134
|
return self.box.curves_results
|
|
963
1135
|
|
|
1136
|
+
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
|
|
1137
|
+
"""Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes
|
|
1138
|
+
shared scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
964
1139
|
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
1140
|
+
Args:
|
|
1141
|
+
normalize (bool): For Detect metrics, everything is normalized by default [0-1].
|
|
1142
|
+
decimals (int): Number of decimal places to round the metrics values to.
|
|
1143
|
+
|
|
1144
|
+
Returns:
|
|
1145
|
+
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1146
|
+
values.
|
|
1147
|
+
|
|
1148
|
+
Examples:
|
|
1149
|
+
>>> results = model.val(data="coco8.yaml")
|
|
1150
|
+
>>> detection_summary = results.summary()
|
|
1151
|
+
>>> print(detection_summary)
|
|
1152
|
+
"""
|
|
1153
|
+
per_class = {
|
|
1154
|
+
"Box-P": self.box.p,
|
|
1155
|
+
"Box-R": self.box.r,
|
|
1156
|
+
"Box-F1": self.box.f1,
|
|
1157
|
+
}
|
|
1158
|
+
return [
|
|
1159
|
+
{
|
|
1160
|
+
"Class": self.names[self.ap_class_index[i]],
|
|
1161
|
+
"Images": self.nt_per_image[self.ap_class_index[i]],
|
|
1162
|
+
"Instances": self.nt_per_class[self.ap_class_index[i]],
|
|
1163
|
+
**{k: round(v[i], decimals) for k, v in per_class.items()},
|
|
1164
|
+
"mAP50": round(self.class_result(i)[2], decimals),
|
|
1165
|
+
"mAP50-95": round(self.class_result(i)[3], decimals),
|
|
1166
|
+
}
|
|
1167
|
+
for i in range(len(per_class["Box-P"]))
|
|
1168
|
+
]
|
|
1169
|
+
|
|
1170
|
+
|
|
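The refactor above moves per-batch statistics accumulation out of the validators and into `DetMetrics` itself (`update_stats` / `process` / `clear_stats`). Below is a minimal sketch of that flow with synthetic arrays; the class names and values are illustrative, not taken from the diff, and the array shapes follow the `(n_preds, n_iou_thresholds)` layout that `ap_per_class` expects:

```python
import numpy as np
from ultralytics.utils.metrics import DetMetrics

dm = DetMetrics(names={0: "person", 1: "car"})  # hypothetical two-class setup
# One batch of synthetic stats: 4 predictions scored at 10 IoU thresholds.
dm.update_stats(
    {
        "tp": np.random.rand(4, 10) > 0.5,       # true-positive matrix
        "conf": np.array([0.9, 0.8, 0.7, 0.6]),  # prediction confidences
        "pred_cls": np.array([0, 1, 0, 1]),      # predicted class per detection
        "target_cls": np.array([0, 1, 1]),       # ground-truth classes in the batch
        "target_img": np.array([0, 1]),          # classes present per image
    }
)
stats = dm.process(plot=False)  # concatenates the lists and runs ap_per_class
print(dm.results_dict)          # precision/recall/mAP50/mAP50-95 plus 'fitness'
dm.clear_stats()                # empty the lists before the next evaluation run
```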
```diff
+class SegmentMetrics(DetMetrics):
+    """Calculate and aggregate detection and segmentation metrics over a given set of classes.
 
     Attributes:
-
-
-        names (dict): Dictionary of class names.
-        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        names (dict[int, str]): Dictionary of class names.
+        box (Metric): An instance of the Metric class for storing detection results.
         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
-        speed (dict):
+        speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'segment'.
+        stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
+            target classes, and target images.
+        nt_per_class: Number of targets per class.
+        nt_per_image: Number of targets per image.
+
+    Methods:
+        process: Process the detection and segmentation metrics over the given set of predictions.
+        keys: Return a list of keys for accessing metrics.
+        mean_results: Return the mean metrics for bounding box and segmentation results.
+        class_result: Return classification results for a specified class index.
+        maps: Return mAP scores for object detection and semantic segmentation models.
+        fitness: Return the fitness score for both segmentation and bounding box models.
+        curves: Return a list of curves for accessing specific metrics curves.
+        curves_results: Provide a list of computed performance metrics and statistics.
+        summary: Generate a summarized representation of per-class segmentation metrics as a list of dictionaries.
     """
 
-    def __init__(self,
-        """
-        Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
 
         Args:
-
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        self
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        DetMetrics.__init__(self, names)
         self.seg = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "segment"
+        self.stats["tp_m"] = []  # add additional stats for masks
 
-    def process(self,
-        """
-        Process the detection and segmentation metrics over the given set of predictions.
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
+        """Process the detection and segmentation metrics over the given set of predictions.
 
         Args:
-
-
-
-
-
-
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_mask = ap_per_class(
-            tp_m,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=
+            stats["tp_m"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=
+            save_dir=save_dir,
             names=self.names,
             prefix="Mask",
         )[2:]
         self.seg.nc = len(self.names)
         self.seg.update(results_mask)
-
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
 
     @property
-    def keys(self):
+    def keys(self) -> list[str]:
         """Return a list of keys for accessing metrics."""
         return [
-
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+            *DetMetrics.keys.fget(self),
             "metrics/precision(M)",
             "metrics/recall(M)",
             "metrics/mAP50(M)",
             "metrics/mAP50-95(M)",
         ]
 
-    def mean_results(self):
+    def mean_results(self) -> list[float]:
         """Return the mean metrics for bounding box and segmentation results."""
-        return
+        return DetMetrics.mean_results(self) + self.seg.mean_results()
 
-    def class_result(self, i):
+    def class_result(self, i: int) -> list[float]:
         """Return classification results for a specified class index."""
-        return
+        return DetMetrics.class_result(self, i) + self.seg.class_result(i)
 
     @property
-    def maps(self):
+    def maps(self) -> np.ndarray:
         """Return mAP scores for object detection and semantic segmentation models."""
-        return
+        return DetMetrics.maps.fget(self) + self.seg.maps
 
     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return the fitness score for both segmentation and bounding box models."""
-        return self.seg.fitness() +
+        return self.seg.fitness() + DetMetrics.fitness.fget(self)
 
     @property
-    def
-        """
-        Return the class indices.
-
-        Boxes and masks have the same ap_class_index.
-        """
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self):
-        """Return results of object detection model for evaluation."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self):
+    def curves(self) -> list[str]:
         """Return a list of curves for accessing specific metrics curves."""
         return [
-
-            "F1-Confidence(B)",
-            "Precision-Confidence(B)",
-            "Recall-Confidence(B)",
+            *DetMetrics.curves.fget(self),
             "Precision-Recall(M)",
             "F1-Confidence(M)",
             "Precision-Confidence(M)",
@@ -1093,127 +1273,137 @@ class SegmentMetrics(SimpleClass):
         ]
 
     @property
-    def curves_results(self):
-        """Return
-        return
+    def curves_results(self) -> list[list]:
+        """Return a list of computed performance metrics and statistics."""
+        return DetMetrics.curves_results.fget(self) + self.seg.curves_results
 
+    def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
+        """Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes
+        both box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for
+        each class.
 
-
-
-
+        Args:
+            normalize (bool): For Segment metrics, everything is normalized by default [0-1].
+            decimals (int): Number of decimal places to round the metrics values to.
+
+        Returns:
+            (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
+                values.
+
+        Examples:
+            >>> results = model.val(data="coco8-seg.yaml")
+            >>> seg_summary = results.summary(decimals=4)
+            >>> print(seg_summary)
+        """
+        per_class = {
+            "Mask-P": self.seg.p,
+            "Mask-R": self.seg.r,
+            "Mask-F1": self.seg.f1,
+        }
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})
+        return summary
+
+
```
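Every subclass now reuses the parent implementation by calling `DetMetrics` methods explicitly, and by reaching properties through their raw getter (`DetMetrics.keys.fget(self)`), since a property object is not directly callable. A self-contained sketch of that pattern with generic names (not taken from the diff):

```python
class Base:
    @property
    def keys(self) -> list[str]:
        return ["metrics/precision(B)", "metrics/recall(B)"]


class Child(Base):
    @property
    def keys(self) -> list[str]:
        # Base.keys is a property object; fget is its underlying getter function
        return [*Base.keys.fget(self), "metrics/precision(M)", "metrics/recall(M)"]


print(Child().keys)
# ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/precision(M)', 'metrics/recall(M)']
```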
```diff
+class PoseMetrics(DetMetrics):
+    """Calculate and aggregate detection and pose metrics over a given set of classes.
 
     Attributes:
-
-        plot (bool): Whether to save the detection and pose plots.
-        names (dict): Dictionary of class names.
-        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        names (dict[int, str]): Dictionary of class names.
         pose (Metric): An instance of the Metric class to calculate pose metrics.
-
+        box (Metric): An instance of the Metric class for storing detection results.
+        speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'pose'.
+        stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
+            target classes, and target images.
+        nt_per_class: Number of targets per class.
+        nt_per_image: Number of targets per image.
 
     Methods:
-        process
-
-
-
-
-
-
+        process: Process the detection and pose metrics over the given set of predictions.
+        keys: Return a list of keys for accessing metrics.
+        mean_results: Return the mean results of box and pose.
+        class_result: Return the class-wise detection results for a specific class i.
+        maps: Return the mean average precision (mAP) per class for both box and pose detections.
+        fitness: Return combined fitness score for pose and box detection.
+        curves: Return a list of curves for accessing specific metrics curves.
+        curves_results: Provide a list of computed performance metrics and statistics.
+        summary: Generate a summarized representation of per-class pose metrics as a list of dictionaries.
     """
 
-    def __init__(self,
-        """
-        Initialize the PoseMetrics class with directory path, class names, and plotting options.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize the PoseMetrics class with directory path, class names, and plotting options.
 
         Args:
-
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        super().__init__(
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        super().__init__(names)
         self.pose = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "pose"
+        self.stats["tp_p"] = []  # add additional stats for pose
 
-    def process(self,
-        """
-        Process the detection and pose metrics over the given set of predictions.
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
+        """Process the detection and pose metrics over the given set of predictions.
 
         Args:
-
-
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
             on_plot (callable, optional): Function to call after plots are generated.
+
+        Returns:
+            (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_pose = ap_per_class(
-            tp_p,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=
+            stats["tp_p"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=
+            save_dir=save_dir,
             names=self.names,
             prefix="Pose",
         )[2:]
         self.pose.nc = len(self.names)
         self.pose.update(results_pose)
-
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
 
     @property
-    def keys(self):
-        """Return list of evaluation metric keys."""
+    def keys(self) -> list[str]:
+        """Return a list of evaluation metric keys."""
         return [
-
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+            *DetMetrics.keys.fget(self),
             "metrics/precision(P)",
             "metrics/recall(P)",
             "metrics/mAP50(P)",
             "metrics/mAP50-95(P)",
         ]
 
-    def mean_results(self):
+    def mean_results(self) -> list[float]:
         """Return the mean results of box and pose."""
-        return
+        return DetMetrics.mean_results(self) + self.pose.mean_results()
 
-    def class_result(self, i):
+    def class_result(self, i: int) -> list[float]:
         """Return the class-wise detection results for a specific class i."""
-        return
+        return DetMetrics.class_result(self, i) + self.pose.class_result(i)
 
     @property
-    def maps(self):
+    def maps(self) -> np.ndarray:
         """Return the mean average precision (mAP) per class for both box and pose detections."""
-        return
+        return DetMetrics.maps.fget(self) + self.pose.maps
 
     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return combined fitness score for pose and box detection."""
-        return self.pose.fitness() +
+        return self.pose.fitness() + DetMetrics.fitness.fget(self)
 
     @property
-    def curves(self):
+    def curves(self) -> list[str]:
         """Return a list of curves for accessing specific metrics curves."""
         return [
+            *DetMetrics.curves.fget(self),
             "Precision-Recall(B)",
             "F1-Confidence(B)",
             "Precision-Confidence(B)",
@@ -1225,20 +1415,55 @@ class PoseMetrics(SegmentMetrics):
         ]
 
     @property
-    def curves_results(self):
-        """Return
-        return
+    def curves_results(self) -> list[list]:
+        """Return a list of computed performance metrics and statistics."""
+        return DetMetrics.curves_results.fget(self) + self.pose.curves_results
 
+    def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
+        """Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box
+        and pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
 
-
-
-
+        Args:
+            normalize (bool): For Pose metrics, everything is normalized by default [0-1].
+            decimals (int): Number of decimal places to round the metrics values to.
+
+        Returns:
+            (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
+                values.
+
+        Examples:
+            >>> results = model.val(data="coco8-pose.yaml")
+            >>> pose_summary = results.summary(decimals=4)
+            >>> print(pose_summary)
+        """
+        per_class = {
+            "Pose-P": self.pose.p,
+            "Pose-R": self.pose.r,
+            "Pose-F1": self.pose.f1,
+        }
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})
+        return summary
+
+
```
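End to end, the pose subclass works the same way as the segmentation one: box statistics are processed first, then the extra `tp_p` entry feeds a second `ap_per_class` pass. A sketch of how the merged per-class summary surfaces to users, assuming a standard `ultralytics` install where the `yolo11n-pose.pt` weights and `coco8-pose.yaml` dataset can be fetched:

```python
from ultralytics import YOLO

model = YOLO("yolo11n-pose.pt")
metrics = model.val(data="coco8-pose.yaml")  # returns a PoseMetrics instance
print(metrics.keys)                          # '(B)' box keys first, then '(P)' pose keys
for row in metrics.summary(decimals=4):      # box columns plus Pose-P/Pose-R/Pose-F1
    print(row["Class"], row["Box-F1"], row["Pose-F1"])
```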
```diff
+class ClassifyMetrics(SimpleClass, DataExportMixin):
+    """Class for computing classification metrics including top-1 and top-5 accuracy.
 
     Attributes:
         top1 (float): The top-1 accuracy.
         top5 (float): The top-5 accuracy.
         speed (dict): A dictionary containing the time taken for each step in the pipeline.
         task (str): The task type, set to 'classify'.
+
+    Methods:
+        process: Process target classes and predicted classes to compute metrics.
+        fitness: Return mean of top-1 and top-5 accuracies as fitness score.
+        results_dict: Return a dictionary with model's performance metrics and fitness score.
+        keys: Return a list of keys for the results_dict property.
+        curves: Return a list of curves for accessing specific metrics curves.
+        curves_results: Provide a list of computed performance metrics and statistics.
+        summary: Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
     """
 
     def __init__(self) -> None:
@@ -1248,9 +1473,8 @@ class ClassifyMetrics(SimpleClass):
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "classify"
 
-    def process(self, targets, pred):
-        """
-        Process target classes and predicted classes to compute metrics.
+    def process(self, targets: torch.Tensor, pred: torch.Tensor):
+        """Process target classes and predicted classes to compute metrics.
 
         Args:
             targets (torch.Tensor): Target classes.
@@ -1262,124 +1486,71 @@ class ClassifyMetrics(SimpleClass):
         self.top1, self.top5 = acc.mean(0).tolist()
 
     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return mean of top-1 and top-5 accuracies as fitness score."""
         return (self.top1 + self.top5) / 2
 
     @property
-    def results_dict(self):
+    def results_dict(self) -> dict[str, float]:
         """Return a dictionary with model's performance metrics and fitness score."""
-        return dict(zip(self.keys
+        return dict(zip([*self.keys, "fitness"], [self.top1, self.top5, self.fitness]))
 
     @property
-    def keys(self):
+    def keys(self) -> list[str]:
         """Return a list of keys for the results_dict property."""
         return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
 
     @property
-    def curves(self):
+    def curves(self) -> list:
         """Return a list of curves for accessing specific metrics curves."""
         return []
 
     @property
-    def curves_results(self):
+    def curves_results(self) -> list:
         """Return a list of curves for accessing specific metrics curves."""
         return []
 
+    def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
+        """Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
 
-
-
-
+        Args:
+            normalize (bool): For Classify metrics, everything is normalized by default [0-1].
+            decimals (int): Number of decimal places to round the metrics values to.
+
+        Returns:
+            (list[dict[str, float]]): A list with one dictionary containing Top-1 and Top-5 classification accuracy.
+
+        Examples:
+            >>> results = model.val(data="imagenet10")
+            >>> classify_summary = results.summary(decimals=4)
+            >>> print(classify_summary)
+        """
+        return [{"top1_acc": round(self.top1, decimals), "top5_acc": round(self.top5, decimals)}]
+
+
```
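The classification path is simpler: `process` reduces targets and top-5 predictions to two scalars. The method body is elided between the hunks above; this toy sketch assumes it still concatenates per-batch tensors the way the validator feeds them, so single-batch lists are passed, and the class indices below are made up:

```python
import torch
from ultralytics.utils.metrics import ClassifyMetrics

cm = ClassifyMetrics()
targets = [torch.tensor([0, 1, 2, 1])]  # ground-truth classes, one batch
pred = [                                # top-5 class indices per sample
    torch.tensor([[0, 1, 2, 3, 4], [1, 0, 2, 3, 4], [0, 2, 1, 3, 4], [2, 1, 0, 3, 4]])
]
cm.process(targets, pred)
print(cm.top1, cm.top5)        # 0.5 1.0 for these made-up predictions
print(cm.results_dict)         # adds 'fitness' = mean of top-1 and top-5
print(cm.summary(decimals=3))  # [{'top1_acc': 0.5, 'top5_acc': 1.0}]
```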
```diff
+class OBBMetrics(DetMetrics):
+    """Metrics for evaluating oriented bounding box (OBB) detection.
 
     Attributes:
-
-        plot (bool): Whether to save the detection plots.
-        names (dict): Dictionary of class names.
+        names (dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
-        speed (dict): A dictionary for storing execution times of different parts of the detection process.
+        speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
+        task (str): The task type, set to 'obb'.
+        stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
+            target classes, and target images.
+        nt_per_class: Number of targets per class.
+        nt_per_image: Number of targets per image.
 
     References:
         https://arxiv.org/pdf/2106.06072.pdf
     """
 
-    def __init__(self,
-        """
-        Initialize an OBBMetrics instance with directory, plotting, and class names.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize an OBBMetrics instance with directory, plotting, and class names.
 
         Args:
-
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        self
-
-        self.
-        self.box = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
-
-    def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
-        """
-        Process predicted results for object detection and update metrics.
-
-        Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
-        """
-        results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            on_plot=on_plot,
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results)
-
-    @property
-    def keys(self):
-        """Return a list of keys for accessing specific metrics."""
-        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
-
-    def mean_results(self):
-        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
-        return self.box.mean_results()
-
-    def class_result(self, i):
-        """Return the result of evaluating the performance of an object detection model on a specific class."""
-        return self.box.class_result(i)
-
-    @property
-    def maps(self):
-        """Return mean Average Precision (mAP) scores per class."""
-        return self.box.maps
-
-    @property
-    def fitness(self):
-        """Return the fitness of box object."""
-        return self.box.fitness()
-
-    @property
-    def ap_class_index(self):
-        """Return the average precision index per class."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self):
-        """Return dictionary of computed performance metrics and statistics."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self):
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    @property
-    def curves_results(self):
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
+        DetMetrics.__init__(self, names)
+        # TODO: probably remove task as well
+        self.task = "obb"
```