dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
|
@@ -15,14 +15,17 @@ import torch
|
|
|
15
15
|
from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings
|
|
16
16
|
|
|
17
17
|
OKS_SIGMA = (
|
|
18
|
-
np.array(
|
|
18
|
+
np.array(
|
|
19
|
+
[0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89],
|
|
20
|
+
dtype=np.float32,
|
|
21
|
+
)
|
|
19
22
|
/ 10.0
|
|
20
23
|
)
|
|
24
|
+
RLE_WEIGHT = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5])
|
|
21
25
|
|
|
22
26
|
|
|
23
27
|
def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
|
|
24
|
-
"""
|
|
25
|
-
Calculate the intersection over box2 area given box1 and box2.
|
|
28
|
+
"""Calculate the intersection over box2 area given box1 and box2.
|
|
26
29
|
|
|
27
30
|
Args:
|
|
28
31
|
box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
|
|
@@ -53,8 +56,7 @@ def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float =
|
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
|
56
|
-
"""
|
|
57
|
-
Calculate intersection-over-union (IoU) of boxes.
|
|
59
|
+
"""Calculate intersection-over-union (IoU) of boxes.
|
|
58
60
|
|
|
59
61
|
Args:
|
|
60
62
|
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
|
|
@@ -85,19 +87,17 @@ def bbox_iou(
|
|
|
85
87
|
CIoU: bool = False,
|
|
86
88
|
eps: float = 1e-7,
|
|
87
89
|
) -> torch.Tensor:
|
|
88
|
-
"""
|
|
89
|
-
Calculate the Intersection over Union (IoU) between bounding boxes.
|
|
90
|
+
"""Calculate the Intersection over Union (IoU) between bounding boxes.
|
|
90
91
|
|
|
91
|
-
This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
or (x1, y1, x2, y2) if `xywh=False`.
|
|
92
|
+
This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. For instance, you
|
|
93
|
+
may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). Internally, the code will split the last
|
|
94
|
+
dimension into (x, y, w, h) if `xywh=True`, or (x1, y1, x2, y2) if `xywh=False`.
|
|
95
95
|
|
|
96
96
|
Args:
|
|
97
97
|
box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
|
|
98
98
|
box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
|
|
99
|
-
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
|
|
100
|
-
|
|
99
|
+
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in (x1, y1,
|
|
100
|
+
x2, y2) format.
|
|
101
101
|
GIoU (bool, optional): If True, calculate Generalized IoU.
|
|
102
102
|
DIoU (bool, optional): If True, calculate Distance IoU.
|
|
103
103
|
CIoU (bool, optional): If True, calculate Complete IoU.
|
|
@@ -148,14 +148,13 @@ def bbox_iou(
|
|
|
148
148
|
|
|
149
149
|
|
|
150
150
|
def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
|
151
|
-
"""
|
|
152
|
-
Calculate masks IoU.
|
|
151
|
+
"""Calculate masks IoU.
|
|
153
152
|
|
|
154
153
|
Args:
|
|
155
154
|
mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
|
|
156
|
-
|
|
157
|
-
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
|
|
158
|
-
|
|
155
|
+
product of image width and height.
|
|
156
|
+
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the product
|
|
157
|
+
of image width and height.
|
|
159
158
|
eps (float, optional): A small value to avoid division by zero.
|
|
160
159
|
|
|
161
160
|
Returns:
|
|
@@ -169,8 +168,7 @@ def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> tor
|
|
|
169
168
|
def kpt_iou(
|
|
170
169
|
kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
|
|
171
170
|
) -> torch.Tensor:
|
|
172
|
-
"""
|
|
173
|
-
Calculate Object Keypoint Similarity (OKS).
|
|
171
|
+
"""Calculate Object Keypoint Similarity (OKS).
|
|
174
172
|
|
|
175
173
|
Args:
|
|
176
174
|
kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
|
|
@@ -191,14 +189,14 @@ def kpt_iou(
|
|
|
191
189
|
|
|
192
190
|
|
|
193
191
|
def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
194
|
-
"""
|
|
195
|
-
Generate covariance matrix from oriented bounding boxes.
|
|
192
|
+
"""Generate covariance matrix from oriented bounding boxes.
|
|
196
193
|
|
|
197
194
|
Args:
|
|
198
195
|
boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
|
|
199
196
|
|
|
200
197
|
Returns:
|
|
201
|
-
(torch.Tensor): Covariance
|
|
198
|
+
(tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Covariance matrix components (a, b, c) where the covariance
|
|
199
|
+
matrix is [[a, c], [c, b]], each of shape (N, 1).
|
|
202
200
|
"""
|
|
203
201
|
# Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
|
|
204
202
|
gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
|
|
@@ -211,8 +209,7 @@ def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Ten
|
|
|
211
209
|
|
|
212
210
|
|
|
213
211
|
def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
|
214
|
-
"""
|
|
215
|
-
Calculate probabilistic IoU between oriented bounding boxes.
|
|
212
|
+
"""Calculate probabilistic IoU between oriented bounding boxes.
|
|
216
213
|
|
|
217
214
|
Args:
|
|
218
215
|
obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
|
|
@@ -257,8 +254,7 @@ def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: flo
|
|
|
257
254
|
|
|
258
255
|
|
|
259
256
|
def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
|
|
260
|
-
"""
|
|
261
|
-
Calculate the probabilistic IoU between oriented bounding boxes.
|
|
257
|
+
"""Calculate the probabilistic IoU between oriented bounding boxes.
|
|
262
258
|
|
|
263
259
|
Args:
|
|
264
260
|
obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
|
|
@@ -294,8 +290,7 @@ def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarr
|
|
|
294
290
|
|
|
295
291
|
|
|
296
292
|
def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
|
|
297
|
-
"""
|
|
298
|
-
Compute smoothed positive and negative Binary Cross-Entropy targets.
|
|
293
|
+
"""Compute smoothed positive and negative Binary Cross-Entropy targets.
|
|
299
294
|
|
|
300
295
|
Args:
|
|
301
296
|
eps (float, optional): The epsilon value for label smoothing.
|
|
@@ -311,20 +306,18 @@ def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
|
|
|
311
306
|
|
|
312
307
|
|
|
313
308
|
class ConfusionMatrix(DataExportMixin):
|
|
314
|
-
"""
|
|
315
|
-
A class for calculating and updating a confusion matrix for object detection and classification tasks.
|
|
309
|
+
"""A class for calculating and updating a confusion matrix for object detection and classification tasks.
|
|
316
310
|
|
|
317
311
|
Attributes:
|
|
318
312
|
task (str): The type of task, either 'detect' or 'classify'.
|
|
319
313
|
matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
|
|
320
|
-
nc (int): The number of
|
|
314
|
+
nc (int): The number of classes.
|
|
321
315
|
names (list[str]): The names of the classes, used as labels on the plot.
|
|
322
316
|
matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
|
|
323
317
|
"""
|
|
324
318
|
|
|
325
|
-
def __init__(self, names: dict[int, str] =
|
|
326
|
-
"""
|
|
327
|
-
Initialize a ConfusionMatrix instance.
|
|
319
|
+
def __init__(self, names: dict[int, str] = {}, task: str = "detect", save_matches: bool = False):
|
|
320
|
+
"""Initialize a ConfusionMatrix instance.
|
|
328
321
|
|
|
329
322
|
Args:
|
|
330
323
|
names (dict[int, str], optional): Names of classes, used as labels on the plot.
|
|
@@ -338,21 +331,20 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
338
331
|
self.matches = {} if save_matches else None
|
|
339
332
|
|
|
340
333
|
def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
|
|
341
|
-
"""
|
|
342
|
-
Append the matches to TP, FP, FN or GT list for the last batch.
|
|
334
|
+
"""Append the matches to TP, FP, FN or GT list for the last batch.
|
|
343
335
|
|
|
344
|
-
This method updates the matches dictionary by appending specific batch data
|
|
345
|
-
|
|
336
|
+
This method updates the matches dictionary by appending specific batch data to the appropriate match type (True
|
|
337
|
+
Positive, False Positive, or False Negative).
|
|
346
338
|
|
|
347
339
|
Args:
|
|
348
340
|
mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
|
|
349
|
-
batch (dict[str, Any]): Batch data containing detection results with keys
|
|
350
|
-
|
|
341
|
+
batch (dict[str, Any]): Batch data containing detection results with keys like 'bboxes', 'cls', 'conf',
|
|
342
|
+
'keypoints', 'masks'.
|
|
351
343
|
idx (int): Index of the specific detection to append from the batch.
|
|
352
344
|
|
|
353
|
-
|
|
354
|
-
For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
|
|
355
|
-
|
|
345
|
+
Notes:
|
|
346
|
+
For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, it indicates
|
|
347
|
+
overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
|
|
356
348
|
"""
|
|
357
349
|
if self.matches is None:
|
|
358
350
|
return
|
|
@@ -364,8 +356,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
364
356
|
self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
|
|
365
357
|
|
|
366
358
|
def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
|
|
367
|
-
"""
|
|
368
|
-
Update confusion matrix for classification task.
|
|
359
|
+
"""Update confusion matrix for classification task.
|
|
369
360
|
|
|
370
361
|
Args:
|
|
371
362
|
preds (list[N, min(nc,5)]): Predicted class labels.
|
|
@@ -382,15 +373,14 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
382
373
|
conf: float = 0.25,
|
|
383
374
|
iou_thres: float = 0.45,
|
|
384
375
|
) -> None:
|
|
385
|
-
"""
|
|
386
|
-
Update confusion matrix for object detection task.
|
|
376
|
+
"""Update confusion matrix for object detection task.
|
|
387
377
|
|
|
388
378
|
Args:
|
|
389
|
-
detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
|
|
393
|
-
'cls' (Array[M]) keys, where M is the number of ground truth objects.
|
|
379
|
+
detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
|
|
380
|
+
information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be Array[N, 4] for
|
|
381
|
+
regular boxes or Array[N, 5] for OBB with angle.
|
|
382
|
+
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
|
|
383
|
+
5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
|
|
394
384
|
conf (float, optional): Confidence threshold for detections.
|
|
395
385
|
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
|
|
396
386
|
"""
|
|
@@ -460,8 +450,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
460
450
|
return self.matrix
|
|
461
451
|
|
|
462
452
|
def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
|
|
463
|
-
"""
|
|
464
|
-
Return true positives and false positives.
|
|
453
|
+
"""Return true positives and false positives.
|
|
465
454
|
|
|
466
455
|
Returns:
|
|
467
456
|
tp (np.ndarray): True positives.
|
|
@@ -473,8 +462,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
473
462
|
return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1]) # remove background class if task=detect
|
|
474
463
|
|
|
475
464
|
def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
|
|
476
|
-
"""
|
|
477
|
-
Plot grid of GT, TP, FP, FN for each image.
|
|
465
|
+
"""Plot grid of GT, TP, FP, FN for each image.
|
|
478
466
|
|
|
479
467
|
Args:
|
|
480
468
|
img (torch.Tensor): Image to plot onto.
|
|
@@ -513,8 +501,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
513
501
|
@TryExcept(msg="ConfusionMatrix plot failure")
|
|
514
502
|
@plt_settings()
|
|
515
503
|
def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
|
|
516
|
-
"""
|
|
517
|
-
Plot the confusion matrix using matplotlib and save it to a file.
|
|
504
|
+
"""Plot the confusion matrix using matplotlib and save it to a file.
|
|
518
505
|
|
|
519
506
|
Args:
|
|
520
507
|
normalize (bool, optional): Whether to normalize the confusion matrix.
|
|
@@ -535,7 +522,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
535
522
|
array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
|
|
536
523
|
n = (self.nc + k - 1) // k # number of retained classes
|
|
537
524
|
nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
|
|
538
|
-
ticklabels = (names
|
|
525
|
+
ticklabels = ([*names, "background"]) if (0 < nn < 99) and (nn == nc) else "auto"
|
|
539
526
|
xy_ticks = np.arange(len(ticklabels))
|
|
540
527
|
tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
|
|
541
528
|
label_fontsize = max(6, 12 - 0.1 * nc)
|
|
@@ -582,7 +569,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
582
569
|
fig.savefig(plot_fname, dpi=250)
|
|
583
570
|
plt.close(fig)
|
|
584
571
|
if on_plot:
|
|
585
|
-
on_plot(plot_fname)
|
|
572
|
+
on_plot(plot_fname, {"type": "confusion_matrix", "matrix": self.matrix.tolist()})
|
|
586
573
|
|
|
587
574
|
def print(self):
|
|
588
575
|
"""Print the confusion matrix to the console."""
|
|
@@ -590,16 +577,17 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
590
577
|
LOGGER.info(" ".join(map(str, self.matrix[i])))
|
|
591
578
|
|
|
592
579
|
def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
|
|
593
|
-
"""
|
|
594
|
-
|
|
595
|
-
|
|
580
|
+
"""Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
|
|
581
|
+
normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON,
|
|
582
|
+
or SQL.
|
|
596
583
|
|
|
597
584
|
Args:
|
|
598
585
|
normalize (bool): Whether to normalize the confusion matrix values.
|
|
599
586
|
decimals (int): Number of decimal places to round the output values to.
|
|
600
587
|
|
|
601
588
|
Returns:
|
|
602
|
-
(list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
|
|
589
|
+
(list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
|
|
590
|
+
values for all actual classes.
|
|
603
591
|
|
|
604
592
|
Examples:
|
|
605
593
|
>>> results = model.val(data="coco8.yaml", plots=True)
|
|
@@ -608,7 +596,7 @@ class ConfusionMatrix(DataExportMixin):
|
|
|
608
596
|
"""
|
|
609
597
|
import re
|
|
610
598
|
|
|
611
|
-
names = list(self.names.values()) if self.task == "classify" else list(self.names.values())
|
|
599
|
+
names = list(self.names.values()) if self.task == "classify" else [*list(self.names.values()), "background"]
|
|
612
600
|
clean_names, seen = [], set()
|
|
613
601
|
for name in names:
|
|
614
602
|
clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
|
@@ -643,8 +631,7 @@ def plot_pr_curve(
|
|
|
643
631
|
names: dict[int, str] = {},
|
|
644
632
|
on_plot=None,
|
|
645
633
|
):
|
|
646
|
-
"""
|
|
647
|
-
Plot precision-recall curve.
|
|
634
|
+
"""Plot precision-recall curve.
|
|
648
635
|
|
|
649
636
|
Args:
|
|
650
637
|
px (np.ndarray): X values for the PR curve.
|
|
@@ -663,7 +650,7 @@ def plot_pr_curve(
|
|
|
663
650
|
for i, y in enumerate(py.T):
|
|
664
651
|
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
|
|
665
652
|
else:
|
|
666
|
-
ax.plot(px, py, linewidth=1, color="
|
|
653
|
+
ax.plot(px, py, linewidth=1, color="gray") # plot(recall, precision)
|
|
667
654
|
|
|
668
655
|
ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
|
|
669
656
|
ax.set_xlabel("Recall")
|
|
@@ -675,7 +662,9 @@ def plot_pr_curve(
|
|
|
675
662
|
fig.savefig(save_dir, dpi=250)
|
|
676
663
|
plt.close(fig)
|
|
677
664
|
if on_plot:
|
|
678
|
-
|
|
665
|
+
# Pass PR curve data for interactive plotting (class names stored at model level)
|
|
666
|
+
# Transpose py to match other curves: y[class][point] format
|
|
667
|
+
on_plot(save_dir, {"type": "pr_curve", "x": px.tolist(), "y": py.T.tolist(), "ap": ap.tolist()})
|
|
679
668
|
|
|
680
669
|
|
|
681
670
|
@plt_settings()
|
|
@@ -688,8 +677,7 @@ def plot_mc_curve(
|
|
|
688
677
|
ylabel: str = "Metric",
|
|
689
678
|
on_plot=None,
|
|
690
679
|
):
|
|
691
|
-
"""
|
|
692
|
-
Plot metric-confidence curve.
|
|
680
|
+
"""Plot metric-confidence curve.
|
|
693
681
|
|
|
694
682
|
Args:
|
|
695
683
|
px (np.ndarray): X values for the metric-confidence curve.
|
|
@@ -708,7 +696,7 @@ def plot_mc_curve(
|
|
|
708
696
|
for i, y in enumerate(py):
|
|
709
697
|
ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
|
|
710
698
|
else:
|
|
711
|
-
ax.plot(px, py.T, linewidth=1, color="
|
|
699
|
+
ax.plot(px, py.T, linewidth=1, color="gray") # plot(confidence, metric)
|
|
712
700
|
|
|
713
701
|
y = smooth(py.mean(0), 0.1)
|
|
714
702
|
ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
|
|
@@ -721,12 +709,12 @@ def plot_mc_curve(
|
|
|
721
709
|
fig.savefig(save_dir, dpi=250)
|
|
722
710
|
plt.close(fig)
|
|
723
711
|
if on_plot:
|
|
724
|
-
|
|
712
|
+
# Pass metric-confidence curve data for interactive plotting (class names stored at model level)
|
|
713
|
+
on_plot(save_dir, {"type": f"{ylabel.lower()}_curve", "x": px.tolist(), "y": py.tolist()})
|
|
725
714
|
|
|
726
715
|
|
|
727
716
|
def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
|
|
728
|
-
"""
|
|
729
|
-
Compute the average precision (AP) given the recall and precision curves.
|
|
717
|
+
"""Compute the average precision (AP) given the recall and precision curves.
|
|
730
718
|
|
|
731
719
|
Args:
|
|
732
720
|
recall (list): The recall curve.
|
|
@@ -769,8 +757,7 @@ def ap_per_class(
|
|
|
769
757
|
eps: float = 1e-16,
|
|
770
758
|
prefix: str = "",
|
|
771
759
|
) -> tuple:
|
|
772
|
-
"""
|
|
773
|
-
Compute the average precision per class for object detection evaluation.
|
|
760
|
+
"""Compute the average precision per class for object detection evaluation.
|
|
774
761
|
|
|
775
762
|
Args:
|
|
776
763
|
tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
|
|
@@ -855,8 +842,7 @@ def ap_per_class(
|
|
|
855
842
|
|
|
856
843
|
|
|
857
844
|
class Metric(SimpleClass):
|
|
858
|
-
"""
|
|
859
|
-
Class for computing evaluation metrics for Ultralytics YOLO models.
|
|
845
|
+
"""Class for computing evaluation metrics for Ultralytics YOLO models.
|
|
860
846
|
|
|
861
847
|
Attributes:
|
|
862
848
|
p (list): Precision for each class. Shape: (nc,).
|
|
@@ -894,8 +880,7 @@ class Metric(SimpleClass):
|
|
|
894
880
|
|
|
895
881
|
@property
|
|
896
882
|
def ap50(self) -> np.ndarray | list:
|
|
897
|
-
"""
|
|
898
|
-
Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
|
|
883
|
+
"""Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
|
|
899
884
|
|
|
900
885
|
Returns:
|
|
901
886
|
(np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
|
|
@@ -904,8 +889,7 @@ class Metric(SimpleClass):
|
|
|
904
889
|
|
|
905
890
|
@property
|
|
906
891
|
def ap(self) -> np.ndarray | list:
|
|
907
|
-
"""
|
|
908
|
-
Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
|
|
892
|
+
"""Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
|
|
909
893
|
|
|
910
894
|
Returns:
|
|
911
895
|
(np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
|
|
@@ -914,8 +898,7 @@ class Metric(SimpleClass):
|
|
|
914
898
|
|
|
915
899
|
@property
|
|
916
900
|
def mp(self) -> float:
|
|
917
|
-
"""
|
|
918
|
-
Return the Mean Precision of all classes.
|
|
901
|
+
"""Return the Mean Precision of all classes.
|
|
919
902
|
|
|
920
903
|
Returns:
|
|
921
904
|
(float): The mean precision of all classes.
|
|
@@ -924,8 +907,7 @@ class Metric(SimpleClass):
|
|
|
924
907
|
|
|
925
908
|
@property
|
|
926
909
|
def mr(self) -> float:
|
|
927
|
-
"""
|
|
928
|
-
Return the Mean Recall of all classes.
|
|
910
|
+
"""Return the Mean Recall of all classes.
|
|
929
911
|
|
|
930
912
|
Returns:
|
|
931
913
|
(float): The mean recall of all classes.
|
|
@@ -934,8 +916,7 @@ class Metric(SimpleClass):
|
|
|
934
916
|
|
|
935
917
|
@property
|
|
936
918
|
def map50(self) -> float:
|
|
937
|
-
"""
|
|
938
|
-
Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
|
|
919
|
+
"""Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
|
|
939
920
|
|
|
940
921
|
Returns:
|
|
941
922
|
(float): The mAP at an IoU threshold of 0.5.
|
|
@@ -944,8 +925,7 @@ class Metric(SimpleClass):
|
|
|
944
925
|
|
|
945
926
|
@property
|
|
946
927
|
def map75(self) -> float:
|
|
947
|
-
"""
|
|
948
|
-
Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
|
|
928
|
+
"""Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
|
|
949
929
|
|
|
950
930
|
Returns:
|
|
951
931
|
(float): The mAP at an IoU threshold of 0.75.
|
|
@@ -954,8 +934,7 @@ class Metric(SimpleClass):
|
|
|
954
934
|
|
|
955
935
|
@property
|
|
956
936
|
def map(self) -> float:
|
|
957
|
-
"""
|
|
958
|
-
Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
|
937
|
+
"""Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
|
959
938
|
|
|
960
939
|
Returns:
|
|
961
940
|
(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
|
@@ -984,8 +963,7 @@ class Metric(SimpleClass):
|
|
|
984
963
|
return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
|
|
985
964
|
|
|
986
965
|
def update(self, results: tuple):
|
|
987
|
-
"""
|
|
988
|
-
Update the evaluation metrics with a new set of results.
|
|
966
|
+
"""Update the evaluation metrics with a new set of results.
|
|
989
967
|
|
|
990
968
|
Args:
|
|
991
969
|
results (tuple): A tuple containing evaluation metrics:
|
|
@@ -1030,15 +1008,15 @@ class Metric(SimpleClass):
|
|
|
1030
1008
|
|
|
1031
1009
|
|
|
1032
1010
|
class DetMetrics(SimpleClass, DataExportMixin):
|
|
1033
|
-
"""
|
|
1034
|
-
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
|
1011
|
+
"""Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
|
1035
1012
|
|
|
1036
1013
|
Attributes:
|
|
1037
1014
|
names (dict[int, str]): A dictionary of class names.
|
|
1038
1015
|
box (Metric): An instance of the Metric class for storing detection results.
|
|
1039
1016
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1040
1017
|
task (str): The task type, set to 'detect'.
|
|
1041
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1018
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1019
|
+
target classes, and target images.
|
|
1042
1020
|
nt_per_class: Number of targets per class.
|
|
1043
1021
|
nt_per_image: Number of targets per image.
|
|
1044
1022
|
|
|
@@ -1059,8 +1037,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1059
1037
|
"""
|
|
1060
1038
|
|
|
1061
1039
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1062
|
-
"""
|
|
1063
|
-
Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
|
1040
|
+
"""Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
|
1064
1041
|
|
|
1065
1042
|
Args:
|
|
1066
1043
|
names (dict[int, str], optional): Dictionary of class names.
|
|
@@ -1074,19 +1051,17 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1074
1051
|
self.nt_per_image = None
|
|
1075
1052
|
|
|
1076
1053
|
def update_stats(self, stat: dict[str, Any]) -> None:
|
|
1077
|
-
"""
|
|
1078
|
-
Update statistics by appending new values to existing stat collections.
|
|
1054
|
+
"""Update statistics by appending new values to existing stat collections.
|
|
1079
1055
|
|
|
1080
1056
|
Args:
|
|
1081
|
-
stat (dict[str, any]): Dictionary containing new statistical values to append.
|
|
1082
|
-
|
|
1057
|
+
stat (dict[str, any]): Dictionary containing new statistical values to append. Keys should match existing
|
|
1058
|
+
keys in self.stats.
|
|
1083
1059
|
"""
|
|
1084
1060
|
for k in self.stats.keys():
|
|
1085
1061
|
self.stats[k].append(stat[k])
|
|
1086
1062
|
|
|
1087
1063
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
|
1088
|
-
"""
|
|
1089
|
-
Process predicted results for object detection and update metrics.
|
|
1064
|
+
"""Process predicted results for object detection and update metrics.
|
|
1090
1065
|
|
|
1091
1066
|
Args:
|
|
1092
1067
|
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
|
@@ -1152,8 +1127,8 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1152
1127
|
@property
|
|
1153
1128
|
def results_dict(self) -> dict[str, float]:
|
|
1154
1129
|
"""Return dictionary of computed performance metrics and statistics."""
|
|
1155
|
-
keys = self.keys
|
|
1156
|
-
values = ((float(x) if hasattr(x, "item") else x) for x in (self.mean_results()
|
|
1130
|
+
keys = [*self.keys, "fitness"]
|
|
1131
|
+
values = ((float(x) if hasattr(x, "item") else x) for x in ([*self.mean_results(), self.fitness]))
|
|
1157
1132
|
return dict(zip(keys, values))
|
|
1158
1133
|
|
|
1159
1134
|
@property
|
|
@@ -1167,16 +1142,16 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1167
1142
|
return self.box.curves_results
|
|
1168
1143
|
|
|
1169
1144
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
|
|
1170
|
-
"""
|
|
1171
|
-
|
|
1172
|
-
scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1145
|
+
"""Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes
|
|
1146
|
+
shared scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1173
1147
|
|
|
1174
1148
|
Args:
|
|
1175
|
-
|
|
1176
|
-
|
|
1149
|
+
normalize (bool): For Detect metrics, everything is normalized by default [0-1].
|
|
1150
|
+
decimals (int): Number of decimal places to round the metrics values to.
|
|
1177
1151
|
|
|
1178
1152
|
Returns:
|
|
1179
|
-
|
|
1153
|
+
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1154
|
+
values.
|
|
1180
1155
|
|
|
1181
1156
|
Examples:
|
|
1182
1157
|
>>> results = model.val(data="coco8.yaml")
|
|
@@ -1202,8 +1177,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
|
1202
1177
|
|
|
1203
1178
|
|
|
1204
1179
|
class SegmentMetrics(DetMetrics):
|
|
1205
|
-
"""
|
|
1206
|
-
Calculate and aggregate detection and segmentation metrics over a given set of classes.
|
|
1180
|
+
"""Calculate and aggregate detection and segmentation metrics over a given set of classes.
|
|
1207
1181
|
|
|
1208
1182
|
Attributes:
|
|
1209
1183
|
names (dict[int, str]): Dictionary of class names.
|
|
@@ -1211,7 +1185,8 @@ class SegmentMetrics(DetMetrics):
|
|
|
1211
1185
|
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
|
1212
1186
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1213
1187
|
task (str): The task type, set to 'segment'.
|
|
1214
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1188
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1189
|
+
target classes, and target images.
|
|
1215
1190
|
nt_per_class: Number of targets per class.
|
|
1216
1191
|
nt_per_image: Number of targets per image.
|
|
1217
1192
|
|
|
@@ -1228,8 +1203,7 @@ class SegmentMetrics(DetMetrics):
|
|
|
1228
1203
|
"""
|
|
1229
1204
|
|
|
1230
1205
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1231
|
-
"""
|
|
1232
|
-
Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
|
1206
|
+
"""Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
|
1233
1207
|
|
|
1234
1208
|
Args:
|
|
1235
1209
|
names (dict[int, str], optional): Dictionary of class names.
|
|
@@ -1240,8 +1214,7 @@ class SegmentMetrics(DetMetrics):
|
|
|
1240
1214
|
self.stats["tp_m"] = [] # add additional stats for masks
|
|
1241
1215
|
|
|
1242
1216
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
|
1243
|
-
"""
|
|
1244
|
-
Process the detection and segmentation metrics over the given set of predictions.
|
|
1217
|
+
"""Process the detection and segmentation metrics over the given set of predictions.
|
|
1245
1218
|
|
|
1246
1219
|
Args:
|
|
1247
1220
|
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
|
@@ -1270,7 +1243,8 @@ class SegmentMetrics(DetMetrics):
|
|
|
1270
1243
|
@property
|
|
1271
1244
|
def keys(self) -> list[str]:
|
|
1272
1245
|
"""Return a list of keys for accessing metrics."""
|
|
1273
|
-
return
|
|
1246
|
+
return [
|
|
1247
|
+
*DetMetrics.keys.fget(self),
|
|
1274
1248
|
"metrics/precision(M)",
|
|
1275
1249
|
"metrics/recall(M)",
|
|
1276
1250
|
"metrics/mAP50(M)",
|
|
@@ -1298,7 +1272,8 @@ class SegmentMetrics(DetMetrics):
|
|
|
1298
1272
|
@property
|
|
1299
1273
|
def curves(self) -> list[str]:
|
|
1300
1274
|
"""Return a list of curves for accessing specific metrics curves."""
|
|
1301
|
-
return
|
|
1275
|
+
return [
|
|
1276
|
+
*DetMetrics.curves.fget(self),
|
|
1302
1277
|
"Precision-Recall(M)",
|
|
1303
1278
|
"F1-Confidence(M)",
|
|
1304
1279
|
"Precision-Confidence(M)",
|
|
@@ -1311,16 +1286,17 @@ class SegmentMetrics(DetMetrics):
|
|
|
1311
1286
|
return DetMetrics.curves_results.fget(self) + self.seg.curves_results
|
|
1312
1287
|
|
|
1313
1288
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
|
|
1314
|
-
"""
|
|
1315
|
-
|
|
1316
|
-
|
|
1289
|
+
"""Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes
|
|
1290
|
+
both box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for
|
|
1291
|
+
each class.
|
|
1317
1292
|
|
|
1318
1293
|
Args:
|
|
1319
|
-
normalize (bool): For Segment metrics, everything is normalized
|
|
1294
|
+
normalize (bool): For Segment metrics, everything is normalized by default [0-1].
|
|
1320
1295
|
decimals (int): Number of decimal places to round the metrics values to.
|
|
1321
1296
|
|
|
1322
1297
|
Returns:
|
|
1323
|
-
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1298
|
+
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1299
|
+
values.
|
|
1324
1300
|
|
|
1325
1301
|
Examples:
|
|
1326
1302
|
>>> results = model.val(data="coco8-seg.yaml")
|
|
@@ -1339,8 +1315,7 @@ class SegmentMetrics(DetMetrics):
|
|
|
1339
1315
|
|
|
1340
1316
|
|
|
1341
1317
|
class PoseMetrics(DetMetrics):
|
|
1342
|
-
"""
|
|
1343
|
-
Calculate and aggregate detection and pose metrics over a given set of classes.
|
|
1318
|
+
"""Calculate and aggregate detection and pose metrics over a given set of classes.
|
|
1344
1319
|
|
|
1345
1320
|
Attributes:
|
|
1346
1321
|
names (dict[int, str]): Dictionary of class names.
|
|
@@ -1348,7 +1323,8 @@ class PoseMetrics(DetMetrics):
|
|
|
1348
1323
|
box (Metric): An instance of the Metric class for storing detection results.
|
|
1349
1324
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1350
1325
|
task (str): The task type, set to 'pose'.
|
|
1351
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1326
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1327
|
+
target classes, and target images.
|
|
1352
1328
|
nt_per_class: Number of targets per class.
|
|
1353
1329
|
nt_per_image: Number of targets per image.
|
|
1354
1330
|
|
|
@@ -1365,8 +1341,7 @@ class PoseMetrics(DetMetrics):
|
|
|
1365
1341
|
"""
|
|
1366
1342
|
|
|
1367
1343
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1368
|
-
"""
|
|
1369
|
-
Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
|
1344
|
+
"""Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
|
1370
1345
|
|
|
1371
1346
|
Args:
|
|
1372
1347
|
names (dict[int, str], optional): Dictionary of class names.
|
|
@@ -1377,8 +1352,7 @@ class PoseMetrics(DetMetrics):
|
|
|
1377
1352
|
self.stats["tp_p"] = [] # add additional stats for pose
|
|
1378
1353
|
|
|
1379
1354
|
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
|
|
1380
|
-
"""
|
|
1381
|
-
Process the detection and pose metrics over the given set of predictions.
|
|
1355
|
+
"""Process the detection and pose metrics over the given set of predictions.
|
|
1382
1356
|
|
|
1383
1357
|
Args:
|
|
1384
1358
|
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
|
@@ -1407,7 +1381,8 @@ class PoseMetrics(DetMetrics):
|
|
|
1407
1381
|
@property
|
|
1408
1382
|
def keys(self) -> list[str]:
|
|
1409
1383
|
"""Return a list of evaluation metric keys."""
|
|
1410
|
-
return
|
|
1384
|
+
return [
|
|
1385
|
+
*DetMetrics.keys.fget(self),
|
|
1411
1386
|
"metrics/precision(P)",
|
|
1412
1387
|
"metrics/recall(P)",
|
|
1413
1388
|
"metrics/mAP50(P)",
|
|
@@ -1435,7 +1410,8 @@ class PoseMetrics(DetMetrics):
|
|
|
1435
1410
|
@property
|
|
1436
1411
|
def curves(self) -> list[str]:
|
|
1437
1412
|
"""Return a list of curves for accessing specific metrics curves."""
|
|
1438
|
-
return
|
|
1413
|
+
return [
|
|
1414
|
+
*DetMetrics.curves.fget(self),
|
|
1439
1415
|
"Precision-Recall(B)",
|
|
1440
1416
|
"F1-Confidence(B)",
|
|
1441
1417
|
"Precision-Confidence(B)",
|
|
@@ -1452,16 +1428,16 @@ class PoseMetrics(DetMetrics):
|
|
|
1452
1428
|
return DetMetrics.curves_results.fget(self) + self.pose.curves_results
|
|
1453
1429
|
|
|
1454
1430
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
|
|
1455
|
-
"""
|
|
1456
|
-
|
|
1457
|
-
pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1431
|
+
"""Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box
|
|
1432
|
+
and pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
|
|
1458
1433
|
|
|
1459
1434
|
Args:
|
|
1460
|
-
normalize (bool): For Pose metrics, everything is normalized
|
|
1435
|
+
normalize (bool): For Pose metrics, everything is normalized by default [0-1].
|
|
1461
1436
|
decimals (int): Number of decimal places to round the metrics values to.
|
|
1462
1437
|
|
|
1463
1438
|
Returns:
|
|
1464
|
-
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1439
|
+
(list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
|
|
1440
|
+
values.
|
|
1465
1441
|
|
|
1466
1442
|
Examples:
|
|
1467
1443
|
>>> results = model.val(data="coco8-pose.yaml")
|
|
@@ -1480,8 +1456,7 @@ class PoseMetrics(DetMetrics):
|
|
|
1480
1456
|
|
|
1481
1457
|
|
|
1482
1458
|
class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
1483
|
-
"""
|
|
1484
|
-
Class for computing classification metrics including top-1 and top-5 accuracy.
|
|
1459
|
+
"""Class for computing classification metrics including top-1 and top-5 accuracy.
|
|
1485
1460
|
|
|
1486
1461
|
Attributes:
|
|
1487
1462
|
top1 (float): The top-1 accuracy.
|
|
@@ -1507,8 +1482,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1507
1482
|
self.task = "classify"
|
|
1508
1483
|
|
|
1509
1484
|
def process(self, targets: torch.Tensor, pred: torch.Tensor):
|
|
1510
|
-
"""
|
|
1511
|
-
Process target classes and predicted classes to compute metrics.
|
|
1485
|
+
"""Process target classes and predicted classes to compute metrics.
|
|
1512
1486
|
|
|
1513
1487
|
Args:
|
|
1514
1488
|
targets (torch.Tensor): Target classes.
|
|
@@ -1527,7 +1501,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1527
1501
|
@property
|
|
1528
1502
|
def results_dict(self) -> dict[str, float]:
|
|
1529
1503
|
"""Return a dictionary with model's performance metrics and fitness score."""
|
|
1530
|
-
return dict(zip(self.keys
|
|
1504
|
+
return dict(zip([*self.keys, "fitness"], [self.top1, self.top5, self.fitness]))
|
|
1531
1505
|
|
|
1532
1506
|
@property
|
|
1533
1507
|
def keys(self) -> list[str]:
|
|
@@ -1545,11 +1519,10 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1545
1519
|
return []
|
|
1546
1520
|
|
|
1547
1521
|
def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
|
|
1548
|
-
"""
|
|
1549
|
-
Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
|
|
1522
|
+
"""Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
|
|
1550
1523
|
|
|
1551
1524
|
Args:
|
|
1552
|
-
normalize (bool): For Classify metrics, everything is normalized
|
|
1525
|
+
normalize (bool): For Classify metrics, everything is normalized by default [0-1].
|
|
1553
1526
|
decimals (int): Number of decimal places to round the metrics values to.
|
|
1554
1527
|
|
|
1555
1528
|
Returns:
|
|
@@ -1564,15 +1537,15 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
|
1564
1537
|
|
|
1565
1538
|
|
|
1566
1539
|
class OBBMetrics(DetMetrics):
|
|
1567
|
-
"""
|
|
1568
|
-
Metrics for evaluating oriented bounding box (OBB) detection.
|
|
1540
|
+
"""Metrics for evaluating oriented bounding box (OBB) detection.
|
|
1569
1541
|
|
|
1570
1542
|
Attributes:
|
|
1571
1543
|
names (dict[int, str]): Dictionary of class names.
|
|
1572
1544
|
box (Metric): An instance of the Metric class for storing detection results.
|
|
1573
1545
|
speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
|
1574
1546
|
task (str): The task type, set to 'obb'.
|
|
1575
|
-
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1547
|
+
stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
|
|
1548
|
+
target classes, and target images.
|
|
1576
1549
|
nt_per_class: Number of targets per class.
|
|
1577
1550
|
nt_per_image: Number of targets per image.
|
|
1578
1551
|
|
|
@@ -1581,8 +1554,7 @@ class OBBMetrics(DetMetrics):
|
|
|
1581
1554
|
"""
|
|
1582
1555
|
|
|
1583
1556
|
def __init__(self, names: dict[int, str] = {}) -> None:
|
|
1584
|
-
"""
|
|
1585
|
-
Initialize an OBBMetrics instance with directory, plotting, and class names.
|
|
1557
|
+
"""Initialize an OBBMetrics instance with directory, plotting, and class names.
|
|
1586
1558
|
|
|
1587
1559
|
Args:
|
|
1588
1560
|
names (dict[int, str], optional): Dictionary of class names.
|