ultralytics 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +52 -51
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +191 -161
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +4 -6
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +2 -2
- ultralytics/solutions/instance_segmentation.py +7 -4
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -11
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +189 -79
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +45 -29
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/METADATA +2 -2
- ultralytics-8.3.145.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
@@ -4,6 +4,7 @@
|
|
4
4
|
import math
|
5
5
|
import warnings
|
6
6
|
from pathlib import Path
|
7
|
+
from typing import Dict, List, Tuple, Union
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
import torch
|
@@ -16,18 +17,18 @@ OKS_SIGMA = (
|
|
16
17
|
)
|
17
18
|
|
18
19
|
|
19
|
-
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
20
|
+
def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
|
20
21
|
"""
|
21
|
-
Calculate the intersection over box2 area given box1 and box2.
|
22
|
+
Calculate the intersection over box2 area given box1 and box2.
|
22
23
|
|
23
24
|
Args:
|
24
|
-
box1 (np.ndarray): A numpy array of shape (
|
25
|
-
box2 (np.ndarray): A numpy array of shape (
|
26
|
-
iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
|
25
|
+
box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
|
26
|
+
box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.
|
27
|
+
iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.
|
27
28
|
eps (float, optional): A small value to avoid division by zero.
|
28
29
|
|
29
30
|
Returns:
|
30
|
-
(np.ndarray): A numpy array of shape (
|
31
|
+
(np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.
|
31
32
|
"""
|
32
33
|
# Get the coordinates of bounding boxes
|
33
34
|
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
|
@@ -48,18 +49,20 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
|
48
49
|
return inter_area / (area + eps)
|
49
50
|
|
50
51
|
|
51
|
-
def box_iou(box1, box2, eps=1e-7):
|
52
|
+
def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
52
53
|
"""
|
53
|
-
Calculate intersection-over-union (IoU) of boxes.
|
54
|
-
Based on https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py.
|
54
|
+
Calculate intersection-over-union (IoU) of boxes.
|
55
55
|
|
56
56
|
Args:
|
57
|
-
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
|
58
|
-
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
|
57
|
+
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
|
58
|
+
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.
|
59
59
|
eps (float, optional): A small value to avoid division by zero.
|
60
60
|
|
61
61
|
Returns:
|
62
62
|
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
|
63
|
+
|
64
|
+
References:
|
65
|
+
https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py
|
63
66
|
"""
|
64
67
|
# NOTE: Need .float() to get accurate iou values
|
65
68
|
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
@@ -70,7 +73,15 @@ def box_iou(box1, box2, eps=1e-7):
|
|
70
73
|
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
|
71
74
|
|
72
75
|
|
73
|
-
def bbox_iou(
|
76
|
+
def bbox_iou(
|
77
|
+
box1: torch.Tensor,
|
78
|
+
box2: torch.Tensor,
|
79
|
+
xywh: bool = True,
|
80
|
+
GIoU: bool = False,
|
81
|
+
DIoU: bool = False,
|
82
|
+
CIoU: bool = False,
|
83
|
+
eps: float = 1e-7,
|
84
|
+
) -> torch.Tensor:
|
74
85
|
"""
|
75
86
|
Calculate the Intersection over Union (IoU) between bounding boxes.
|
76
87
|
|
@@ -133,7 +144,7 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
|
133
144
|
return iou # IoU
|
134
145
|
|
135
146
|
|
136
|
-
def mask_iou(mask1, mask2, eps=1e-7):
|
147
|
+
def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
|
137
148
|
"""
|
138
149
|
Calculate masks IoU.
|
139
150
|
|
@@ -152,7 +163,9 @@ def mask_iou(mask1, mask2, eps=1e-7):
|
|
152
163
|
return intersection / (union + eps)
|
153
164
|
|
154
165
|
|
155
|
-
def kpt_iou(
|
166
|
+
def kpt_iou(
|
167
|
+
kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: List[float], eps: float = 1e-7
|
168
|
+
) -> torch.Tensor:
|
156
169
|
"""
|
157
170
|
Calculate Object Keypoint Similarity (OKS).
|
158
171
|
|
@@ -174,7 +187,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
|
|
174
187
|
return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
|
175
188
|
|
176
189
|
|
177
|
-
def _get_covariance_matrix(boxes):
|
190
|
+
def _get_covariance_matrix(boxes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
178
191
|
"""
|
179
192
|
Generate covariance matrix from oriented bounding boxes.
|
180
193
|
|
@@ -194,7 +207,7 @@ def _get_covariance_matrix(boxes):
|
|
194
207
|
return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin
|
195
208
|
|
196
209
|
|
197
|
-
def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
210
|
+
def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
198
211
|
"""
|
199
212
|
Calculate probabilistic IoU between oriented bounding boxes.
|
200
213
|
|
@@ -208,8 +221,10 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|
208
221
|
(torch.Tensor): OBB similarities, shape (N,).
|
209
222
|
|
210
223
|
Notes:
|
211
|
-
|
212
|
-
|
224
|
+
OBB format: [center_x, center_y, width, height, rotation_angle].
|
225
|
+
|
226
|
+
References:
|
227
|
+
https://arxiv.org/pdf/2106.06072v1.pdf
|
213
228
|
"""
|
214
229
|
x1, y1 = obb1[..., :2].split(1, dim=-1)
|
215
230
|
x2, y2 = obb2[..., :2].split(1, dim=-1)
|
@@ -238,7 +253,9 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
|
|
238
253
|
return iou
|
239
254
|
|
240
255
|
|
241
|
-
def batch_probiou(
|
256
|
+
def batch_probiou(
|
257
|
+
obb1: Union[torch.Tensor, np.ndarray], obb2: Union[torch.Tensor, np.ndarray], eps: float = 1e-7
|
258
|
+
) -> torch.Tensor:
|
242
259
|
"""
|
243
260
|
Calculate the probabilistic IoU between oriented bounding boxes.
|
244
261
|
|
@@ -275,7 +292,7 @@ def batch_probiou(obb1, obb2, eps=1e-7):
|
|
275
292
|
return 1 - hd
|
276
293
|
|
277
294
|
|
278
|
-
def smooth_bce(eps=0.1):
|
295
|
+
def smooth_bce(eps: float = 0.1) -> Tuple[float, float]:
|
279
296
|
"""
|
280
297
|
Compute smoothed positive and negative Binary Cross-Entropy targets.
|
281
298
|
|
@@ -283,7 +300,8 @@ def smooth_bce(eps=0.1):
|
|
283
300
|
eps (float, optional): The epsilon value for label smoothing.
|
284
301
|
|
285
302
|
Returns:
|
286
|
-
(
|
303
|
+
pos (float): Positive label smoothing BCE target.
|
304
|
+
neg (float): Negative label smoothing BCE target.
|
287
305
|
|
288
306
|
References:
|
289
307
|
https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
|
@@ -303,7 +321,7 @@ class ConfusionMatrix:
|
|
303
321
|
iou_thres (float): The Intersection over Union threshold.
|
304
322
|
"""
|
305
323
|
|
306
|
-
def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
|
324
|
+
def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, task: str = "detect"):
|
307
325
|
"""
|
308
326
|
Initialize a ConfusionMatrix instance.
|
309
327
|
|
@@ -393,12 +411,13 @@ class ConfusionMatrix:
|
|
393
411
|
"""Return the confusion matrix."""
|
394
412
|
return self.matrix
|
395
413
|
|
396
|
-
def tp_fp(self):
|
414
|
+
def tp_fp(self) -> Tuple[np.ndarray, np.ndarray]:
|
397
415
|
"""
|
398
416
|
Return true positives and false positives.
|
399
417
|
|
400
418
|
Returns:
|
401
|
-
(
|
419
|
+
tp (np.ndarray): True positives.
|
420
|
+
fp (np.ndarray): False positives.
|
402
421
|
"""
|
403
422
|
tp = self.matrix.diagonal() # true positives
|
404
423
|
fp = self.matrix.sum(1) - tp # false positives
|
@@ -407,15 +426,15 @@ class ConfusionMatrix:
|
|
407
426
|
|
408
427
|
@TryExcept(msg="ConfusionMatrix plot failure")
|
409
428
|
@plt_settings()
|
410
|
-
def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
|
429
|
+
def plot(self, normalize: bool = True, save_dir: str = "", names: tuple = (), on_plot=None):
|
411
430
|
"""
|
412
431
|
Plot the confusion matrix using matplotlib and save it to a file.
|
413
432
|
|
414
433
|
Args:
|
415
|
-
normalize (bool): Whether to normalize the confusion matrix.
|
416
|
-
save_dir (str): Directory where the plot will be saved.
|
417
|
-
names (tuple): Names of classes, used as labels on the plot.
|
418
|
-
on_plot (
|
434
|
+
normalize (bool, optional): Whether to normalize the confusion matrix.
|
435
|
+
save_dir (str, optional): Directory where the plot will be saved.
|
436
|
+
names (tuple, optional): Names of classes, used as labels on the plot.
|
437
|
+
on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered.
|
419
438
|
"""
|
420
439
|
import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
|
421
440
|
|
@@ -487,7 +506,7 @@ class ConfusionMatrix:
|
|
487
506
|
LOGGER.info(" ".join(map(str, self.matrix[i])))
|
488
507
|
|
489
508
|
|
490
|
-
def smooth(y, f=0.05):
|
509
|
+
def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray:
|
491
510
|
"""Box filter of fraction f."""
|
492
511
|
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
|
493
512
|
p = np.ones(nf // 2) # ones padding
|
@@ -496,7 +515,14 @@ def smooth(y, f=0.05):
|
|
496
515
|
|
497
516
|
|
498
517
|
@plt_settings()
|
499
|
-
def plot_pr_curve(
|
518
|
+
def plot_pr_curve(
|
519
|
+
px: np.ndarray,
|
520
|
+
py: np.ndarray,
|
521
|
+
ap: np.ndarray,
|
522
|
+
save_dir: Path = Path("pr_curve.png"),
|
523
|
+
names: dict = {},
|
524
|
+
on_plot=None,
|
525
|
+
):
|
500
526
|
"""
|
501
527
|
Plot precision-recall curve.
|
502
528
|
|
@@ -533,7 +559,15 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
|
|
533
559
|
|
534
560
|
|
535
561
|
@plt_settings()
|
536
|
-
def plot_mc_curve(
|
562
|
+
def plot_mc_curve(
|
563
|
+
px: np.ndarray,
|
564
|
+
py: np.ndarray,
|
565
|
+
save_dir: Path = Path("mc_curve.png"),
|
566
|
+
names: dict = {},
|
567
|
+
xlabel: str = "Confidence",
|
568
|
+
ylabel: str = "Metric",
|
569
|
+
on_plot=None,
|
570
|
+
):
|
537
571
|
"""
|
538
572
|
Plot metric-confidence curve.
|
539
573
|
|
@@ -570,7 +604,7 @@ def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confi
|
|
570
604
|
on_plot(save_dir)
|
571
605
|
|
572
606
|
|
573
|
-
def compute_ap(recall, precision):
|
607
|
+
def compute_ap(recall: List[float], precision: List[float]) -> Tuple[float, np.ndarray, np.ndarray]:
|
574
608
|
"""
|
575
609
|
Compute the average precision (AP) given the recall and precision curves.
|
576
610
|
|
@@ -579,9 +613,9 @@ def compute_ap(recall, precision):
|
|
579
613
|
precision (list): The precision curve.
|
580
614
|
|
581
615
|
Returns:
|
582
|
-
(float): Average precision.
|
583
|
-
(np.ndarray): Precision envelope curve.
|
584
|
-
(np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
|
616
|
+
ap (float): Average precision.
|
617
|
+
mpre (np.ndarray): Precision envelope curve.
|
618
|
+
mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
|
585
619
|
"""
|
586
620
|
# Append sentinel values to beginning and end
|
587
621
|
mrec = np.concatenate(([0.0], recall, [1.0]))
|
@@ -604,8 +638,17 @@ def compute_ap(recall, precision):
|
|
604
638
|
|
605
639
|
|
606
640
|
def ap_per_class(
|
607
|
-
tp
|
608
|
-
|
641
|
+
tp: np.ndarray,
|
642
|
+
conf: np.ndarray,
|
643
|
+
pred_cls: np.ndarray,
|
644
|
+
target_cls: np.ndarray,
|
645
|
+
plot: bool = False,
|
646
|
+
on_plot=None,
|
647
|
+
save_dir: Path = Path(),
|
648
|
+
names: dict = {},
|
649
|
+
eps: float = 1e-16,
|
650
|
+
prefix: str = "",
|
651
|
+
) -> Tuple:
|
609
652
|
"""
|
610
653
|
Compute the average precision per class for object detection evaluation.
|
611
654
|
|
@@ -615,7 +658,7 @@ def ap_per_class(
|
|
615
658
|
pred_cls (np.ndarray): Array of predicted classes of the detections.
|
616
659
|
target_cls (np.ndarray): Array of true classes of the detections.
|
617
660
|
plot (bool, optional): Whether to plot PR curves or not.
|
618
|
-
on_plot (
|
661
|
+
on_plot (callable, optional): A callback to pass plots path and data when they are rendered.
|
619
662
|
save_dir (Path, optional): Directory to save the PR curves.
|
620
663
|
names (dict, optional): Dict of class names to plot PR curves.
|
621
664
|
eps (float, optional): A small value to avoid division by zero.
|
@@ -729,27 +772,27 @@ class Metric(SimpleClass):
|
|
729
772
|
self.nc = 0
|
730
773
|
|
731
774
|
@property
|
732
|
-
def ap50(self):
|
775
|
+
def ap50(self) -> Union[np.ndarray, List]:
|
733
776
|
"""
|
734
777
|
Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
|
735
778
|
|
736
779
|
Returns:
|
737
|
-
(np.ndarray
|
780
|
+
(np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
|
738
781
|
"""
|
739
782
|
return self.all_ap[:, 0] if len(self.all_ap) else []
|
740
783
|
|
741
784
|
@property
|
742
|
-
def ap(self):
|
785
|
+
def ap(self) -> Union[np.ndarray, List]:
|
743
786
|
"""
|
744
787
|
Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
|
745
788
|
|
746
789
|
Returns:
|
747
|
-
(np.ndarray
|
790
|
+
(np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
|
748
791
|
"""
|
749
792
|
return self.all_ap.mean(1) if len(self.all_ap) else []
|
750
793
|
|
751
794
|
@property
|
752
|
-
def mp(self):
|
795
|
+
def mp(self) -> float:
|
753
796
|
"""
|
754
797
|
Return the Mean Precision of all classes.
|
755
798
|
|
@@ -759,7 +802,7 @@ class Metric(SimpleClass):
|
|
759
802
|
return self.p.mean() if len(self.p) else 0.0
|
760
803
|
|
761
804
|
@property
|
762
|
-
def mr(self):
|
805
|
+
def mr(self) -> float:
|
763
806
|
"""
|
764
807
|
Return the Mean Recall of all classes.
|
765
808
|
|
@@ -769,7 +812,7 @@ class Metric(SimpleClass):
|
|
769
812
|
return self.r.mean() if len(self.r) else 0.0
|
770
813
|
|
771
814
|
@property
|
772
|
-
def map50(self):
|
815
|
+
def map50(self) -> float:
|
773
816
|
"""
|
774
817
|
Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
|
775
818
|
|
@@ -779,7 +822,7 @@ class Metric(SimpleClass):
|
|
779
822
|
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
|
780
823
|
|
781
824
|
@property
|
782
|
-
def map75(self):
|
825
|
+
def map75(self) -> float:
|
783
826
|
"""
|
784
827
|
Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
|
785
828
|
|
@@ -789,7 +832,7 @@ class Metric(SimpleClass):
|
|
789
832
|
return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
|
790
833
|
|
791
834
|
@property
|
792
|
-
def map(self):
|
835
|
+
def map(self) -> float:
|
793
836
|
"""
|
794
837
|
Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
795
838
|
|
@@ -798,28 +841,28 @@ class Metric(SimpleClass):
|
|
798
841
|
"""
|
799
842
|
return self.all_ap.mean() if len(self.all_ap) else 0.0
|
800
843
|
|
801
|
-
def mean_results(self):
|
844
|
+
def mean_results(self) -> List[float]:
|
802
845
|
"""Return mean of results, mp, mr, map50, map."""
|
803
846
|
return [self.mp, self.mr, self.map50, self.map]
|
804
847
|
|
805
|
-
def class_result(self, i):
|
848
|
+
def class_result(self, i: int) -> Tuple[float, float, float, float]:
|
806
849
|
"""Return class-aware result, p[i], r[i], ap50[i], ap[i]."""
|
807
850
|
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
|
808
851
|
|
809
852
|
@property
|
810
|
-
def maps(self):
|
853
|
+
def maps(self) -> np.ndarray:
|
811
854
|
"""Return mAP of each class."""
|
812
855
|
maps = np.zeros(self.nc) + self.map
|
813
856
|
for i, c in enumerate(self.ap_class_index):
|
814
857
|
maps[c] = self.ap[i]
|
815
858
|
return maps
|
816
859
|
|
817
|
-
def fitness(self):
|
860
|
+
def fitness(self) -> float:
|
818
861
|
"""Return model fitness as a weighted combination of metrics."""
|
819
862
|
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
|
820
863
|
return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
|
821
864
|
|
822
|
-
def update(self, results):
|
865
|
+
def update(self, results: tuple):
|
823
866
|
"""
|
824
867
|
Update the evaluation metrics with a new set of results.
|
825
868
|
|
@@ -850,12 +893,12 @@ class Metric(SimpleClass):
|
|
850
893
|
) = results
|
851
894
|
|
852
895
|
@property
|
853
|
-
def curves(self):
|
896
|
+
def curves(self) -> List:
|
854
897
|
"""Return a list of curves for accessing specific metrics curves."""
|
855
898
|
return []
|
856
899
|
|
857
900
|
@property
|
858
|
-
def curves_results(self):
|
901
|
+
def curves_results(self) -> List[List]:
|
859
902
|
"""Return a list of curves for accessing specific metrics curves."""
|
860
903
|
return [
|
861
904
|
[self.px, self.prec_values, "Recall", "Precision"],
|
@@ -878,7 +921,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
878
921
|
task (str): The task type, set to 'detect'.
|
879
922
|
"""
|
880
923
|
|
881
|
-
def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
|
924
|
+
def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: dict = {}) -> None:
|
882
925
|
"""
|
883
926
|
Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
884
927
|
|
@@ -894,7 +937,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
894
937
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
895
938
|
self.task = "detect"
|
896
939
|
|
897
|
-
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
|
940
|
+
def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
|
898
941
|
"""
|
899
942
|
Process predicted results for object detection and update metrics.
|
900
943
|
|
@@ -919,50 +962,50 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
919
962
|
self.box.update(results)
|
920
963
|
|
921
964
|
@property
|
922
|
-
def keys(self):
|
965
|
+
def keys(self) -> List[str]:
|
923
966
|
"""Return a list of keys for accessing specific metrics."""
|
924
967
|
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
925
968
|
|
926
|
-
def mean_results(self):
|
969
|
+
def mean_results(self) -> List[float]:
|
927
970
|
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
928
971
|
return self.box.mean_results()
|
929
972
|
|
930
|
-
def class_result(self, i):
|
973
|
+
def class_result(self, i: int) -> Tuple[float, float, float, float]:
|
931
974
|
"""Return the result of evaluating the performance of an object detection model on a specific class."""
|
932
975
|
return self.box.class_result(i)
|
933
976
|
|
934
977
|
@property
|
935
|
-
def maps(self):
|
978
|
+
def maps(self) -> np.ndarray:
|
936
979
|
"""Return mean Average Precision (mAP) scores per class."""
|
937
980
|
return self.box.maps
|
938
981
|
|
939
982
|
@property
|
940
|
-
def fitness(self):
|
983
|
+
def fitness(self) -> float:
|
941
984
|
"""Return the fitness of box object."""
|
942
985
|
return self.box.fitness()
|
943
986
|
|
944
987
|
@property
|
945
|
-
def ap_class_index(self):
|
988
|
+
def ap_class_index(self) -> List:
|
946
989
|
"""Return the average precision index per class."""
|
947
990
|
return self.box.ap_class_index
|
948
991
|
|
949
992
|
@property
|
950
|
-
def results_dict(self):
|
993
|
+
def results_dict(self) -> Dict[str, float]:
|
951
994
|
"""Return dictionary of computed performance metrics and statistics."""
|
952
995
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
953
996
|
|
954
997
|
@property
|
955
|
-
def curves(self):
|
998
|
+
def curves(self) -> List[str]:
|
956
999
|
"""Return a list of curves for accessing specific metrics curves."""
|
957
1000
|
return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
|
958
1001
|
|
959
1002
|
@property
|
960
|
-
def curves_results(self):
|
1003
|
+
def curves_results(self) -> List[List]:
|
961
1004
|
"""Return dictionary of computed performance metrics and statistics."""
|
962
1005
|
return self.box.curves_results
|
963
1006
|
|
964
|
-
def summary(self, **kwargs):
|
965
|
-
"""
|
1007
|
+
def summary(self, **kwargs) -> List[Dict[str, Union[str, float]]]:
|
1008
|
+
"""Return per-class detection metrics with shared scalar values included."""
|
966
1009
|
scalars = {
|
967
1010
|
"box-map": self.box.map,
|
968
1011
|
"box-map50": self.box.map50,
|
@@ -985,7 +1028,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
985
1028
|
|
986
1029
|
class SegmentMetrics(SimpleClass, DataExportMixin):
|
987
1030
|
"""
|
988
|
-
|
1031
|
+
Calculate and aggregate detection and segmentation metrics over a given set of classes.
|
989
1032
|
|
990
1033
|
Attributes:
|
991
1034
|
save_dir (Path): Path to the directory where the output plots should be saved.
|
@@ -997,14 +1040,14 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
997
1040
|
task (str): The task type, set to 'segment'.
|
998
1041
|
"""
|
999
1042
|
|
1000
|
-
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
1043
|
+
def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
|
1001
1044
|
"""
|
1002
1045
|
Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
1003
1046
|
|
1004
1047
|
Args:
|
1005
1048
|
save_dir (Path, optional): Directory to save plots.
|
1006
1049
|
plot (bool, optional): Whether to plot precision-recall curves.
|
1007
|
-
names (
|
1050
|
+
names (tuple, optional): Tuple mapping class indices to names.
|
1008
1051
|
"""
|
1009
1052
|
self.save_dir = save_dir
|
1010
1053
|
self.plot = plot
|
@@ -1014,7 +1057,15 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1014
1057
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1015
1058
|
self.task = "segment"
|
1016
1059
|
|
1017
|
-
def process(
|
1060
|
+
def process(
|
1061
|
+
self,
|
1062
|
+
tp: np.ndarray,
|
1063
|
+
tp_m: np.ndarray,
|
1064
|
+
conf: np.ndarray,
|
1065
|
+
pred_cls: np.ndarray,
|
1066
|
+
target_cls: np.ndarray,
|
1067
|
+
on_plot=None,
|
1068
|
+
):
|
1018
1069
|
"""
|
1019
1070
|
Process the detection and segmentation metrics over the given set of predictions.
|
1020
1071
|
|
@@ -1054,7 +1105,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1054
1105
|
self.box.update(results_box)
|
1055
1106
|
|
1056
1107
|
@property
|
1057
|
-
def keys(self):
|
1108
|
+
def keys(self) -> List[str]:
|
1058
1109
|
"""Return a list of keys for accessing metrics."""
|
1059
1110
|
return [
|
1060
1111
|
"metrics/precision(B)",
|
@@ -1067,40 +1118,36 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1067
1118
|
"metrics/mAP50-95(M)",
|
1068
1119
|
]
|
1069
1120
|
|
1070
|
-
def mean_results(self):
|
1121
|
+
def mean_results(self) -> List[float]:
|
1071
1122
|
"""Return the mean metrics for bounding box and segmentation results."""
|
1072
1123
|
return self.box.mean_results() + self.seg.mean_results()
|
1073
1124
|
|
1074
|
-
def class_result(self, i):
|
1125
|
+
def class_result(self, i: int) -> List[float]:
|
1075
1126
|
"""Return classification results for a specified class index."""
|
1076
1127
|
return self.box.class_result(i) + self.seg.class_result(i)
|
1077
1128
|
|
1078
1129
|
@property
|
1079
|
-
def maps(self):
|
1130
|
+
def maps(self) -> np.ndarray:
|
1080
1131
|
"""Return mAP scores for object detection and semantic segmentation models."""
|
1081
1132
|
return self.box.maps + self.seg.maps
|
1082
1133
|
|
1083
1134
|
@property
|
1084
|
-
def fitness(self):
|
1135
|
+
def fitness(self) -> float:
|
1085
1136
|
"""Return the fitness score for both segmentation and bounding box models."""
|
1086
1137
|
return self.seg.fitness() + self.box.fitness()
|
1087
1138
|
|
1088
1139
|
@property
|
1089
|
-
def ap_class_index(self):
|
1090
|
-
"""
|
1091
|
-
Return the class indices.
|
1092
|
-
|
1093
|
-
Boxes and masks have the same ap_class_index.
|
1094
|
-
"""
|
1140
|
+
def ap_class_index(self) -> List:
|
1141
|
+
"""Return the class indices (boxes and masks have the same ap_class_index)."""
|
1095
1142
|
return self.box.ap_class_index
|
1096
1143
|
|
1097
1144
|
@property
|
1098
|
-
def results_dict(self):
|
1145
|
+
def results_dict(self) -> Dict[str, float]:
|
1099
1146
|
"""Return results of object detection model for evaluation."""
|
1100
1147
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1101
1148
|
|
1102
1149
|
@property
|
1103
|
-
def curves(self):
|
1150
|
+
def curves(self) -> List[str]:
|
1104
1151
|
"""Return a list of curves for accessing specific metrics curves."""
|
1105
1152
|
return [
|
1106
1153
|
"Precision-Recall(B)",
|
@@ -1114,12 +1161,12 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1114
1161
|
]
|
1115
1162
|
|
1116
1163
|
@property
|
1117
|
-
def curves_results(self):
|
1164
|
+
def curves_results(self) -> List[List]:
|
1118
1165
|
"""Return dictionary of computed performance metrics and statistics."""
|
1119
1166
|
return self.box.curves_results + self.seg.curves_results
|
1120
1167
|
|
1121
|
-
def summary(self, **kwargs):
|
1122
|
-
"""
|
1168
|
+
def summary(self, **kwargs) -> List[Dict[str, Union[str, float]]]:
|
1169
|
+
"""Return per-class segmentation metrics with shared scalar values included (box + mask)."""
|
1123
1170
|
scalars = {
|
1124
1171
|
"box-map": self.box.map,
|
1125
1172
|
"box-map50": self.box.map50,
|
@@ -1144,7 +1191,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1144
1191
|
|
1145
1192
|
class PoseMetrics(SegmentMetrics):
|
1146
1193
|
"""
|
1147
|
-
|
1194
|
+
Calculate and aggregate detection and pose metrics over a given set of classes.
|
1148
1195
|
|
1149
1196
|
Attributes:
|
1150
1197
|
save_dir (Path): Path to the directory where the output plots should be saved.
|
@@ -1156,23 +1203,23 @@ class PoseMetrics(SegmentMetrics):
|
|
1156
1203
|
task (str): The task type, set to 'pose'.
|
1157
1204
|
|
1158
1205
|
Methods:
|
1159
|
-
process(tp_m, tp_b, conf, pred_cls, target_cls):
|
1160
|
-
mean_results():
|
1161
|
-
class_result(i):
|
1162
|
-
maps:
|
1163
|
-
fitness:
|
1164
|
-
ap_class_index:
|
1165
|
-
results_dict:
|
1206
|
+
process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.
|
1207
|
+
mean_results(): Return the mean of the detection and segmentation metrics over all the classes.
|
1208
|
+
class_result(i): Return the detection and segmentation metrics of class `i`.
|
1209
|
+
maps: Return the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
|
1210
|
+
fitness: Return the fitness scores, which are a single weighted combination of metrics.
|
1211
|
+
ap_class_index: Return the list of indices of classes used to compute Average Precision (AP).
|
1212
|
+
results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
|
1166
1213
|
"""
|
1167
1214
|
|
1168
|
-
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
1215
|
+
def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
|
1169
1216
|
"""
|
1170
1217
|
Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
1171
1218
|
|
1172
1219
|
Args:
|
1173
1220
|
save_dir (Path, optional): Directory to save plots.
|
1174
1221
|
plot (bool, optional): Whether to plot precision-recall curves.
|
1175
|
-
names (
|
1222
|
+
names (tuple, optional): Tuple mapping class indices to names.
|
1176
1223
|
"""
|
1177
1224
|
super().__init__(save_dir, plot, names)
|
1178
1225
|
self.save_dir = save_dir
|
@@ -1183,7 +1230,15 @@ class PoseMetrics(SegmentMetrics):
|
|
1183
1230
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1184
1231
|
self.task = "pose"
|
1185
1232
|
|
1186
|
-
def process(
|
1233
|
+
def process(
|
1234
|
+
self,
|
1235
|
+
tp: np.ndarray,
|
1236
|
+
tp_p: np.ndarray,
|
1237
|
+
conf: np.ndarray,
|
1238
|
+
pred_cls: np.ndarray,
|
1239
|
+
target_cls: np.ndarray,
|
1240
|
+
on_plot=None,
|
1241
|
+
):
|
1187
1242
|
"""
|
1188
1243
|
Process the detection and pose metrics over the given set of predictions.
|
1189
1244
|
|
@@ -1223,7 +1278,7 @@ class PoseMetrics(SegmentMetrics):
|
|
1223
1278
|
self.box.update(results_box)
|
1224
1279
|
|
1225
1280
|
@property
|
1226
|
-
def keys(self):
|
1281
|
+
def keys(self) -> List[str]:
|
1227
1282
|
"""Return list of evaluation metric keys."""
|
1228
1283
|
return [
|
1229
1284
|
"metrics/precision(B)",
|
@@ -1236,26 +1291,26 @@ class PoseMetrics(SegmentMetrics):
|
|
1236
1291
|
"metrics/mAP50-95(P)",
|
1237
1292
|
]
|
1238
1293
|
|
1239
|
-
def mean_results(self):
|
1294
|
+
def mean_results(self) -> List[float]:
|
1240
1295
|
"""Return the mean results of box and pose."""
|
1241
1296
|
return self.box.mean_results() + self.pose.mean_results()
|
1242
1297
|
|
1243
|
-
def class_result(self, i):
|
1298
|
+
def class_result(self, i: int) -> List[float]:
|
1244
1299
|
"""Return the class-wise detection results for a specific class i."""
|
1245
1300
|
return self.box.class_result(i) + self.pose.class_result(i)
|
1246
1301
|
|
1247
1302
|
@property
|
1248
|
-
def maps(self):
|
1303
|
+
def maps(self) -> np.ndarray:
|
1249
1304
|
"""Return the mean average precision (mAP) per class for both box and pose detections."""
|
1250
1305
|
return self.box.maps + self.pose.maps
|
1251
1306
|
|
1252
1307
|
@property
|
1253
|
-
def fitness(self):
|
1308
|
+
def fitness(self) -> float:
|
1254
1309
|
"""Return combined fitness score for pose and box detection."""
|
1255
1310
|
return self.pose.fitness() + self.box.fitness()
|
1256
1311
|
|
1257
1312
|
@property
|
1258
|
-
def curves(self):
|
1313
|
+
def curves(self) -> List[str]:
|
1259
1314
|
"""Return a list of curves for accessing specific metrics curves."""
|
1260
1315
|
return [
|
1261
1316
|
"Precision-Recall(B)",
|
@@ -1269,12 +1324,12 @@ class PoseMetrics(SegmentMetrics):
|
|
1269
1324
|
]
|
1270
1325
|
|
1271
1326
|
@property
|
1272
|
-
def curves_results(self):
|
1327
|
+
def curves_results(self) -> List[List]:
|
1273
1328
|
"""Return dictionary of computed performance metrics and statistics."""
|
1274
1329
|
return self.box.curves_results + self.pose.curves_results
|
1275
1330
|
|
1276
|
-
def summary(self, **kwargs):
|
1277
|
-
"""
|
1331
|
+
def summary(self, **kwargs) -> List[Dict[str, Union[str, float]]]:
|
1332
|
+
"""Return per-class pose metrics with shared scalar values included (box + pose)."""
|
1278
1333
|
scalars = {
|
1279
1334
|
"box-map": self.box.map,
|
1280
1335
|
"box-map50": self.box.map50,
|
@@ -1315,7 +1370,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
1315
1370
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1316
1371
|
self.task = "classify"
|
1317
1372
|
|
1318
|
-
def process(self, targets, pred):
|
1373
|
+
def process(self, targets: torch.Tensor, pred: torch.Tensor):
|
1319
1374
|
"""
|
1320
1375
|
Process target classes and predicted classes to compute metrics.
|
1321
1376
|
|
@@ -1329,32 +1384,32 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
1329
1384
|
self.top1, self.top5 = acc.mean(0).tolist()
|
1330
1385
|
|
1331
1386
|
@property
|
1332
|
-
def fitness(self):
|
1387
|
+
def fitness(self) -> float:
|
1333
1388
|
"""Return mean of top-1 and top-5 accuracies as fitness score."""
|
1334
1389
|
return (self.top1 + self.top5) / 2
|
1335
1390
|
|
1336
1391
|
@property
|
1337
|
-
def results_dict(self):
|
1392
|
+
def results_dict(self) -> Dict[str, float]:
|
1338
1393
|
"""Return a dictionary with model's performance metrics and fitness score."""
|
1339
1394
|
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
|
1340
1395
|
|
1341
1396
|
@property
|
1342
|
-
def keys(self):
|
1397
|
+
def keys(self) -> List[str]:
|
1343
1398
|
"""Return a list of keys for the results_dict property."""
|
1344
1399
|
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
1345
1400
|
|
1346
1401
|
@property
|
1347
|
-
def curves(self):
|
1402
|
+
def curves(self) -> List:
|
1348
1403
|
"""Return a list of curves for accessing specific metrics curves."""
|
1349
1404
|
return []
|
1350
1405
|
|
1351
1406
|
@property
|
1352
|
-
def curves_results(self):
|
1407
|
+
def curves_results(self) -> List:
|
1353
1408
|
"""Return a list of curves for accessing specific metrics curves."""
|
1354
1409
|
return []
|
1355
1410
|
|
1356
|
-
def summary(self, **kwargs):
|
1357
|
-
"""
|
1411
|
+
def summary(self, **kwargs) -> List[Dict[str, float]]:
|
1412
|
+
"""Return a single-row summary for classification metrics (top1/top5)."""
|
1358
1413
|
return [{"classify-top1": self.top1, "classify-top5": self.top5}]
|
1359
1414
|
|
1360
1415
|
|
@@ -1368,19 +1423,20 @@ class OBBMetrics(SimpleClass, DataExportMixin):
|
|
1368
1423
|
names (dict): Dictionary of class names.
|
1369
1424
|
box (Metric): An instance of the Metric class for storing detection results.
|
1370
1425
|
speed (dict): A dictionary for storing execution times of different parts of the detection process.
|
1426
|
+
task (str): The task type, set to 'obb'.
|
1371
1427
|
|
1372
1428
|
References:
|
1373
1429
|
https://arxiv.org/pdf/2106.06072.pdf
|
1374
1430
|
"""
|
1375
1431
|
|
1376
|
-
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
|
1432
|
+
def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
|
1377
1433
|
"""
|
1378
1434
|
Initialize an OBBMetrics instance with directory, plotting, and class names.
|
1379
1435
|
|
1380
1436
|
Args:
|
1381
1437
|
save_dir (Path, optional): Directory to save plots.
|
1382
1438
|
plot (bool, optional): Whether to plot precision-recall curves.
|
1383
|
-
names (
|
1439
|
+
names (tuple, optional): Tuple mapping class indices to names.
|
1384
1440
|
"""
|
1385
1441
|
self.save_dir = save_dir
|
1386
1442
|
self.plot = plot
|
@@ -1389,7 +1445,7 @@ class OBBMetrics(SimpleClass, DataExportMixin):
|
|
1389
1445
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1390
1446
|
self.task = "obb"
|
1391
1447
|
|
1392
|
-
def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
|
1448
|
+
def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
|
1393
1449
|
"""
|
1394
1450
|
Process predicted results for object detection and update metrics.
|
1395
1451
|
|
@@ -1414,50 +1470,50 @@ class OBBMetrics(SimpleClass, DataExportMixin):
|
|
1414
1470
|
self.box.update(results)
|
1415
1471
|
|
1416
1472
|
@property
|
1417
|
-
def keys(self):
|
1473
|
+
def keys(self) -> List[str]:
|
1418
1474
|
"""Return a list of keys for accessing specific metrics."""
|
1419
1475
|
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
1420
1476
|
|
1421
|
-
def mean_results(self):
|
1477
|
+
def mean_results(self) -> List[float]:
|
1422
1478
|
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
1423
1479
|
return self.box.mean_results()
|
1424
1480
|
|
1425
|
-
def class_result(self, i):
|
1481
|
+
def class_result(self, i: int) -> Tuple[float, float, float, float]:
|
1426
1482
|
"""Return the result of evaluating the performance of an object detection model on a specific class."""
|
1427
1483
|
return self.box.class_result(i)
|
1428
1484
|
|
1429
1485
|
@property
|
1430
|
-
def maps(self):
|
1486
|
+
def maps(self) -> np.ndarray:
|
1431
1487
|
"""Return mean Average Precision (mAP) scores per class."""
|
1432
1488
|
return self.box.maps
|
1433
1489
|
|
1434
1490
|
@property
|
1435
|
-
def fitness(self):
|
1491
|
+
def fitness(self) -> float:
|
1436
1492
|
"""Return the fitness of box object."""
|
1437
1493
|
return self.box.fitness()
|
1438
1494
|
|
1439
1495
|
@property
|
1440
|
-
def ap_class_index(self):
|
1496
|
+
def ap_class_index(self) -> List:
|
1441
1497
|
"""Return the average precision index per class."""
|
1442
1498
|
return self.box.ap_class_index
|
1443
1499
|
|
1444
1500
|
@property
|
1445
|
-
def results_dict(self):
|
1501
|
+
def results_dict(self) -> Dict[str, float]:
|
1446
1502
|
"""Return dictionary of computed performance metrics and statistics."""
|
1447
1503
|
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1448
1504
|
|
1449
1505
|
@property
|
1450
|
-
def curves(self):
|
1506
|
+
def curves(self) -> List:
|
1451
1507
|
"""Return a list of curves for accessing specific metrics curves."""
|
1452
1508
|
return []
|
1453
1509
|
|
1454
1510
|
@property
|
1455
|
-
def curves_results(self):
|
1511
|
+
def curves_results(self) -> List:
|
1456
1512
|
"""Return a list of curves for accessing specific metrics curves."""
|
1457
1513
|
return []
|
1458
1514
|
|
1459
|
-
def summary(self, **kwargs):
|
1460
|
-
"""
|
1515
|
+
def summary(self, **kwargs) -> List[Dict[str, Union[str, float]]]:
|
1516
|
+
"""Return per-class detection metrics with shared scalar values included."""
|
1461
1517
|
scalars = {
|
1462
1518
|
"box-map": self.box.map,
|
1463
1519
|
"box-map50": self.box.map50,
|