dgenerate-ultralytics-headless 8.3.152__py3-none-any.whl → 8.3.154__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/RECORD +30 -30
- tests/test_python.py +1 -0
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +2 -0
- ultralytics/engine/predictor.py +1 -1
- ultralytics/engine/validator.py +0 -6
- ultralytics/models/fastsam/val.py +0 -2
- ultralytics/models/rtdetr/val.py +28 -16
- ultralytics/models/yolo/classify/val.py +26 -23
- ultralytics/models/yolo/detect/train.py +4 -7
- ultralytics/models/yolo/detect/val.py +88 -90
- ultralytics/models/yolo/obb/val.py +52 -44
- ultralytics/models/yolo/pose/train.py +1 -35
- ultralytics/models/yolo/pose/val.py +77 -176
- ultralytics/models/yolo/segment/train.py +1 -41
- ultralytics/models/yolo/segment/val.py +64 -176
- ultralytics/models/yolo/yoloe/val.py +2 -1
- ultralytics/nn/autobackend.py +2 -2
- ultralytics/solutions/ai_gym.py +2 -3
- ultralytics/solutions/solutions.py +2 -0
- ultralytics/solutions/templates/similarity-search.html +31 -0
- ultralytics/utils/callbacks/comet.py +1 -1
- ultralytics/utils/metrics.py +152 -307
- ultralytics/utils/ops.py +4 -4
- ultralytics/utils/plotting.py +31 -56
- {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4
4
|
import math
|
5
5
|
import warnings
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Dict, List, Tuple, Union
|
7
|
+
from typing import Any, Dict, List, Tuple, Union
|
8
8
|
|
9
9
|
import numpy as np
|
10
10
|
import torch
|
@@ -316,28 +316,22 @@ class ConfusionMatrix(DataExportMixin):
|
|
316
316
|
Attributes:
|
317
317
|
task (str): The type of task, either 'detect' or 'classify'.
|
318
318
|
matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
|
319
|
-
nc (int): The number of
|
320
|
-
|
321
|
-
iou_thres (float): The Intersection over Union threshold.
|
319
|
+
nc (int): The number of category.
|
320
|
+
names (List[str]): The names of the classes, used as labels on the plot.
|
322
321
|
"""
|
323
322
|
|
324
|
-
def __init__(self,
|
323
|
+
def __init__(self, names: List[str] = [], task: str = "detect"):
|
325
324
|
"""
|
326
325
|
Initialize a ConfusionMatrix instance.
|
327
326
|
|
328
327
|
Args:
|
329
|
-
|
330
|
-
conf (float, optional): Confidence threshold for detections.
|
331
|
-
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
|
332
|
-
names (tuple, optional): Names of classes, used as labels on the plot.
|
328
|
+
names (List[str], optional): Names of classes, used as labels on the plot.
|
333
329
|
task (str, optional): Type of task, either 'detect' or 'classify'.
|
334
330
|
"""
|
335
331
|
self.task = task
|
336
|
-
self.
|
337
|
-
self.
|
338
|
-
self.names =
|
339
|
-
self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
|
340
|
-
self.iou_thres = iou_thres
|
332
|
+
self.nc = len(names) # number of classes
|
333
|
+
self.matrix = np.zeros((self.nc + 1, self.nc + 1)) if self.task == "detect" else np.zeros((self.nc, self.nc))
|
334
|
+
self.names = names # name of classes
|
341
335
|
|
342
336
|
def process_cls_preds(self, preds, targets):
|
343
337
|
"""
|
@@ -351,41 +345,45 @@ class ConfusionMatrix(DataExportMixin):
|
|
351
345
|
for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
|
352
346
|
self.matrix[p][t] += 1
|
353
347
|
|
354
|
-
def process_batch(
|
348
|
+
def process_batch(
|
349
|
+
self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
|
350
|
+
) -> None:
|
355
351
|
"""
|
356
352
|
Update confusion matrix for object detection task.
|
357
353
|
|
358
354
|
Args:
|
359
|
-
detections (
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
355
|
+
detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
|
356
|
+
Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
|
357
|
+
Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
|
358
|
+
batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and
|
359
|
+
'cls' (Array[M]) keys, where M is the number of ground truth objects.
|
360
|
+
conf (float, optional): Confidence threshold for detections.
|
361
|
+
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
|
364
362
|
"""
|
363
|
+
conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
|
364
|
+
gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
|
365
|
+
no_pred = len(detections["cls"]) == 0
|
365
366
|
if gt_cls.shape[0] == 0: # Check if labels is empty
|
366
|
-
if
|
367
|
-
detections = detections[detections[
|
368
|
-
detection_classes = detections[
|
367
|
+
if not no_pred:
|
368
|
+
detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
|
369
|
+
detection_classes = detections["cls"].int().tolist()
|
369
370
|
for dc in detection_classes:
|
370
371
|
self.matrix[dc, self.nc] += 1 # false positives
|
371
372
|
return
|
372
|
-
if
|
373
|
+
if no_pred:
|
373
374
|
gt_classes = gt_cls.int().tolist()
|
374
375
|
for gc in gt_classes:
|
375
376
|
self.matrix[self.nc, gc] += 1 # background FN
|
376
377
|
return
|
377
378
|
|
378
|
-
detections = detections[detections[
|
379
|
+
detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
|
379
380
|
gt_classes = gt_cls.int().tolist()
|
380
|
-
detection_classes = detections[
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
)
|
387
|
-
|
388
|
-
x = torch.where(iou > self.iou_thres)
|
381
|
+
detection_classes = detections["cls"].int().tolist()
|
382
|
+
bboxes = detections["bboxes"]
|
383
|
+
is_obb = bboxes.shape[1] == 5 # check if detections contains angle for OBB
|
384
|
+
iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
|
385
|
+
|
386
|
+
x = torch.where(iou > iou_thres)
|
389
387
|
if x[0].shape[0]:
|
390
388
|
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
391
389
|
if x[0].shape[0] > 1:
|
@@ -949,53 +947,76 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
949
947
|
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
950
948
|
|
951
949
|
Attributes:
|
952
|
-
save_dir (Path): A path to the directory where the output plots will be saved.
|
953
|
-
plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
|
954
950
|
names (Dict[int, str]): A dictionary of class names.
|
955
951
|
box (Metric): An instance of the Metric class for storing detection results.
|
956
952
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
957
953
|
task (str): The task type, set to 'detect'.
|
954
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
955
|
+
nt_per_class: Number of targets per class.
|
956
|
+
nt_per_image: Number of targets per image.
|
958
957
|
"""
|
959
958
|
|
960
|
-
def __init__(self,
|
959
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
961
960
|
"""
|
962
961
|
Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
963
962
|
|
964
963
|
Args:
|
965
|
-
save_dir (Path, optional): Directory to save plots.
|
966
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
967
964
|
names (Dict[int, str], optional): Dictionary of class names.
|
968
965
|
"""
|
969
|
-
self.save_dir = save_dir
|
970
|
-
self.plot = plot
|
971
966
|
self.names = names
|
972
967
|
self.box = Metric()
|
973
968
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
974
969
|
self.task = "detect"
|
970
|
+
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
|
971
|
+
self.nt_per_class = None
|
972
|
+
self.nt_per_image = None
|
973
|
+
|
974
|
+
def update_stats(self, stat: Dict[str, Any]) -> None:
|
975
|
+
"""
|
976
|
+
Update statistics by appending new values to existing stat collections.
|
977
|
+
|
978
|
+
Args:
|
979
|
+
stat (Dict[str, any]): Dictionary containing new statistical values to append.
|
980
|
+
Keys should match existing keys in self.stats.
|
981
|
+
"""
|
982
|
+
for k in self.stats.keys():
|
983
|
+
self.stats[k].append(stat[k])
|
975
984
|
|
976
|
-
def process(self,
|
985
|
+
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
|
977
986
|
"""
|
978
987
|
Process predicted results for object detection and update metrics.
|
979
988
|
|
980
989
|
Args:
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
990
|
+
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
991
|
+
plot (bool): Whether to plot precision-recall curves. Defaults to False.
|
992
|
+
on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
|
993
|
+
|
994
|
+
Returns:
|
995
|
+
(Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
|
986
996
|
"""
|
997
|
+
stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()} # to numpy
|
998
|
+
if len(stats) == 0:
|
999
|
+
return stats
|
987
1000
|
results = ap_per_class(
|
988
|
-
tp,
|
989
|
-
conf,
|
990
|
-
pred_cls,
|
991
|
-
target_cls,
|
992
|
-
plot=
|
993
|
-
save_dir=
|
1001
|
+
stats["tp"],
|
1002
|
+
stats["conf"],
|
1003
|
+
stats["pred_cls"],
|
1004
|
+
stats["target_cls"],
|
1005
|
+
plot=plot,
|
1006
|
+
save_dir=save_dir,
|
994
1007
|
names=self.names,
|
995
1008
|
on_plot=on_plot,
|
996
1009
|
)[2:]
|
997
1010
|
self.box.nc = len(self.names)
|
998
1011
|
self.box.update(results)
|
1012
|
+
self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
|
1013
|
+
self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
|
1014
|
+
return stats
|
1015
|
+
|
1016
|
+
def clear_stats(self):
|
1017
|
+
"""Clear the stored statistics."""
|
1018
|
+
for v in self.stats.values():
|
1019
|
+
v.clear()
|
999
1020
|
|
1000
1021
|
@property
|
1001
1022
|
def keys(self) -> List[str]:
|
@@ -1068,97 +1089,74 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
1068
1089
|
"box-f1": self.box.f1,
|
1069
1090
|
}
|
1070
1091
|
return [
|
1071
|
-
{
|
1072
|
-
|
1092
|
+
{
|
1093
|
+
"class_name": self.names[self.ap_class_index[i]],
|
1094
|
+
**{k: round(v[i], decimals) for k, v in per_class.items()},
|
1095
|
+
**scalars,
|
1096
|
+
}
|
1097
|
+
for i in range(len(per_class["box-p"]))
|
1073
1098
|
]
|
1074
1099
|
|
1075
1100
|
|
1076
|
-
class SegmentMetrics(
|
1101
|
+
class SegmentMetrics(DetMetrics):
|
1077
1102
|
"""
|
1078
1103
|
Calculate and aggregate detection and segmentation metrics over a given set of classes.
|
1079
1104
|
|
1080
1105
|
Attributes:
|
1081
|
-
save_dir (Path): Path to the directory where the output plots should be saved.
|
1082
|
-
plot (bool): Whether to save the detection and segmentation plots.
|
1083
1106
|
names (Dict[int, str]): Dictionary of class names.
|
1084
1107
|
box (Metric): An instance of the Metric class for storing detection results.
|
1085
1108
|
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
1086
1109
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
1087
1110
|
task (str): The task type, set to 'segment'.
|
1111
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
1112
|
+
nt_per_class: Number of targets per class.
|
1113
|
+
nt_per_image: Number of targets per image.
|
1088
1114
|
"""
|
1089
1115
|
|
1090
|
-
def __init__(self,
|
1116
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
1091
1117
|
"""
|
1092
1118
|
Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
1093
1119
|
|
1094
1120
|
Args:
|
1095
|
-
save_dir (Path, optional): Directory to save plots.
|
1096
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
1097
1121
|
names (Dict[int, str], optional): Dictionary of class names.
|
1098
1122
|
"""
|
1099
|
-
self
|
1100
|
-
self.plot = plot
|
1101
|
-
self.names = names
|
1102
|
-
self.box = Metric()
|
1123
|
+
DetMetrics.__init__(self, names)
|
1103
1124
|
self.seg = Metric()
|
1104
|
-
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1105
1125
|
self.task = "segment"
|
1126
|
+
self.stats["tp_m"] = [] # add additional stats for masks
|
1106
1127
|
|
1107
|
-
def process(
|
1108
|
-
self,
|
1109
|
-
tp: np.ndarray,
|
1110
|
-
tp_m: np.ndarray,
|
1111
|
-
conf: np.ndarray,
|
1112
|
-
pred_cls: np.ndarray,
|
1113
|
-
target_cls: np.ndarray,
|
1114
|
-
on_plot=None,
|
1115
|
-
):
|
1128
|
+
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
|
1116
1129
|
"""
|
1117
1130
|
Process the detection and segmentation metrics over the given set of predictions.
|
1118
1131
|
|
1119
1132
|
Args:
|
1120
|
-
|
1121
|
-
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1133
|
+
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
1134
|
+
plot (bool): Whether to plot precision-recall curves. Defaults to False.
|
1135
|
+
on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
|
1136
|
+
|
1137
|
+
Returns:
|
1138
|
+
(Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
|
1126
1139
|
"""
|
1140
|
+
stats = DetMetrics.process(self, on_plot=on_plot) # process box stats
|
1127
1141
|
results_mask = ap_per_class(
|
1128
|
-
tp_m,
|
1129
|
-
conf,
|
1130
|
-
pred_cls,
|
1131
|
-
target_cls,
|
1132
|
-
plot=
|
1142
|
+
stats["tp_m"],
|
1143
|
+
stats["conf"],
|
1144
|
+
stats["pred_cls"],
|
1145
|
+
stats["target_cls"],
|
1146
|
+
plot=plot,
|
1133
1147
|
on_plot=on_plot,
|
1134
|
-
save_dir=
|
1148
|
+
save_dir=save_dir,
|
1135
1149
|
names=self.names,
|
1136
1150
|
prefix="Mask",
|
1137
1151
|
)[2:]
|
1138
1152
|
self.seg.nc = len(self.names)
|
1139
1153
|
self.seg.update(results_mask)
|
1140
|
-
|
1141
|
-
tp,
|
1142
|
-
conf,
|
1143
|
-
pred_cls,
|
1144
|
-
target_cls,
|
1145
|
-
plot=self.plot,
|
1146
|
-
on_plot=on_plot,
|
1147
|
-
save_dir=self.save_dir,
|
1148
|
-
names=self.names,
|
1149
|
-
prefix="Box",
|
1150
|
-
)[2:]
|
1151
|
-
self.box.nc = len(self.names)
|
1152
|
-
self.box.update(results_box)
|
1154
|
+
return stats
|
1153
1155
|
|
1154
1156
|
@property
|
1155
1157
|
def keys(self) -> List[str]:
|
1156
1158
|
"""Return a list of keys for accessing metrics."""
|
1157
|
-
return [
|
1158
|
-
"metrics/precision(B)",
|
1159
|
-
"metrics/recall(B)",
|
1160
|
-
"metrics/mAP50(B)",
|
1161
|
-
"metrics/mAP50-95(B)",
|
1159
|
+
return DetMetrics.keys.fget(self) + [
|
1162
1160
|
"metrics/precision(M)",
|
1163
1161
|
"metrics/recall(M)",
|
1164
1162
|
"metrics/mAP50(M)",
|
@@ -1167,40 +1165,26 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1167
1165
|
|
1168
1166
|
def mean_results(self) -> List[float]:
|
1169
1167
|
"""Return the mean metrics for bounding box and segmentation results."""
|
1170
|
-
return
|
1168
|
+
return DetMetrics.mean_results(self) + self.seg.mean_results()
|
1171
1169
|
|
1172
1170
|
def class_result(self, i: int) -> List[float]:
|
1173
1171
|
"""Return classification results for a specified class index."""
|
1174
|
-
return
|
1172
|
+
return DetMetrics.class_result(self, i) + self.seg.class_result(i)
|
1175
1173
|
|
1176
1174
|
@property
|
1177
1175
|
def maps(self) -> np.ndarray:
|
1178
1176
|
"""Return mAP scores for object detection and semantic segmentation models."""
|
1179
|
-
return
|
1177
|
+
return DetMetrics.maps.fget(self) + self.seg.maps
|
1180
1178
|
|
1181
1179
|
@property
|
1182
1180
|
def fitness(self) -> float:
|
1183
1181
|
"""Return the fitness score for both segmentation and bounding box models."""
|
1184
|
-
return self.seg.fitness() +
|
1185
|
-
|
1186
|
-
@property
|
1187
|
-
def ap_class_index(self) -> List:
|
1188
|
-
"""Return the class indices (boxes and masks have the same ap_class_index)."""
|
1189
|
-
return self.box.ap_class_index
|
1190
|
-
|
1191
|
-
@property
|
1192
|
-
def results_dict(self) -> Dict[str, float]:
|
1193
|
-
"""Return results of object detection model for evaluation."""
|
1194
|
-
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1182
|
+
return self.seg.fitness() + DetMetrics.fitness.fget(self)
|
1195
1183
|
|
1196
1184
|
@property
|
1197
1185
|
def curves(self) -> List[str]:
|
1198
1186
|
"""Return a list of curves for accessing specific metrics curves."""
|
1199
|
-
return [
|
1200
|
-
"Precision-Recall(B)",
|
1201
|
-
"F1-Confidence(B)",
|
1202
|
-
"Precision-Confidence(B)",
|
1203
|
-
"Recall-Confidence(B)",
|
1187
|
+
return DetMetrics.curves.fget(self) + [
|
1204
1188
|
"Precision-Recall(M)",
|
1205
1189
|
"F1-Confidence(M)",
|
1206
1190
|
"Precision-Confidence(M)",
|
@@ -1210,7 +1194,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1210
1194
|
@property
|
1211
1195
|
def curves_results(self) -> List[List]:
|
1212
1196
|
"""Return dictionary of computed performance metrics and statistics."""
|
1213
|
-
return
|
1197
|
+
return DetMetrics.curves_results.fget(self) + self.seg.curves_results
|
1214
1198
|
|
1215
1199
|
def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
|
1216
1200
|
"""
|
@@ -1230,39 +1214,34 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1230
1214
|
>>> print(seg_summary)
|
1231
1215
|
"""
|
1232
1216
|
scalars = {
|
1233
|
-
"box-map": round(self.box.map, decimals),
|
1234
|
-
"box-map50": round(self.box.map50, decimals),
|
1235
|
-
"box-map75": round(self.box.map75, decimals),
|
1236
1217
|
"mask-map": round(self.seg.map, decimals),
|
1237
1218
|
"mask-map50": round(self.seg.map50, decimals),
|
1238
1219
|
"mask-map75": round(self.seg.map75, decimals),
|
1239
1220
|
}
|
1240
1221
|
per_class = {
|
1241
|
-
"box-p": self.box.p,
|
1242
|
-
"box-r": self.box.r,
|
1243
|
-
"box-f1": self.box.f1,
|
1244
1222
|
"mask-p": self.seg.p,
|
1245
1223
|
"mask-r": self.seg.r,
|
1246
1224
|
"mask-f1": self.seg.f1,
|
1247
1225
|
}
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1226
|
+
summary = DetMetrics.summary(self, normalize, decimals) # get box summary
|
1227
|
+
for i, s in enumerate(summary):
|
1228
|
+
s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
|
1229
|
+
return summary
|
1252
1230
|
|
1253
1231
|
|
1254
|
-
class PoseMetrics(
|
1232
|
+
class PoseMetrics(DetMetrics):
|
1255
1233
|
"""
|
1256
1234
|
Calculate and aggregate detection and pose metrics over a given set of classes.
|
1257
1235
|
|
1258
1236
|
Attributes:
|
1259
|
-
save_dir (Path): Path to the directory where the output plots should be saved.
|
1260
|
-
plot (bool): Whether to save the detection and pose plots.
|
1261
1237
|
names (Dict[int, str]): Dictionary of class names.
|
1262
1238
|
pose (Metric): An instance of the Metric class to calculate pose metrics.
|
1263
1239
|
box (Metric): An instance of the Metric class for storing detection results.
|
1264
1240
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
1265
1241
|
task (str): The task type, set to 'pose'.
|
1242
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
1243
|
+
nt_per_class: Number of targets per class.
|
1244
|
+
nt_per_image: Number of targets per image.
|
1266
1245
|
|
1267
1246
|
Methods:
|
1268
1247
|
process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.
|
@@ -1274,79 +1253,50 @@ class PoseMetrics(SegmentMetrics):
|
|
1274
1253
|
results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
|
1275
1254
|
"""
|
1276
1255
|
|
1277
|
-
def __init__(self,
|
1256
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
1278
1257
|
"""
|
1279
1258
|
Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
1280
1259
|
|
1281
1260
|
Args:
|
1282
|
-
save_dir (Path, optional): Directory to save plots.
|
1283
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
1284
1261
|
names (Dict[int, str], optional): Dictionary of class names.
|
1285
1262
|
"""
|
1286
|
-
super().__init__(
|
1287
|
-
self.save_dir = save_dir
|
1288
|
-
self.plot = plot
|
1289
|
-
self.names = names
|
1290
|
-
self.box = Metric()
|
1263
|
+
super().__init__(names)
|
1291
1264
|
self.pose = Metric()
|
1292
|
-
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1293
1265
|
self.task = "pose"
|
1266
|
+
self.stats["tp_p"] = [] # add additional stats for pose
|
1294
1267
|
|
1295
|
-
def process(
|
1296
|
-
self,
|
1297
|
-
tp: np.ndarray,
|
1298
|
-
tp_p: np.ndarray,
|
1299
|
-
conf: np.ndarray,
|
1300
|
-
pred_cls: np.ndarray,
|
1301
|
-
target_cls: np.ndarray,
|
1302
|
-
on_plot=None,
|
1303
|
-
):
|
1268
|
+
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
|
1304
1269
|
"""
|
1305
1270
|
Process the detection and pose metrics over the given set of predictions.
|
1306
1271
|
|
1307
1272
|
Args:
|
1308
|
-
|
1309
|
-
|
1310
|
-
conf (np.ndarray): Confidence array.
|
1311
|
-
pred_cls (np.ndarray): Predicted class indices array.
|
1312
|
-
target_cls (np.ndarray): Target class indices array.
|
1273
|
+
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
1274
|
+
plot (bool): Whether to plot precision-recall curves. Defaults to False.
|
1313
1275
|
on_plot (callable, optional): Function to call after plots are generated.
|
1276
|
+
|
1277
|
+
Returns:
|
1278
|
+
(Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
|
1314
1279
|
"""
|
1280
|
+
stats = DetMetrics.process(self, on_plot=on_plot) # process box stats
|
1315
1281
|
results_pose = ap_per_class(
|
1316
|
-
tp_p,
|
1317
|
-
conf,
|
1318
|
-
pred_cls,
|
1319
|
-
target_cls,
|
1320
|
-
plot=
|
1282
|
+
stats["tp_p"],
|
1283
|
+
stats["conf"],
|
1284
|
+
stats["pred_cls"],
|
1285
|
+
stats["target_cls"],
|
1286
|
+
plot=plot,
|
1321
1287
|
on_plot=on_plot,
|
1322
|
-
save_dir=
|
1288
|
+
save_dir=save_dir,
|
1323
1289
|
names=self.names,
|
1324
1290
|
prefix="Pose",
|
1325
1291
|
)[2:]
|
1326
1292
|
self.pose.nc = len(self.names)
|
1327
1293
|
self.pose.update(results_pose)
|
1328
|
-
|
1329
|
-
tp,
|
1330
|
-
conf,
|
1331
|
-
pred_cls,
|
1332
|
-
target_cls,
|
1333
|
-
plot=self.plot,
|
1334
|
-
on_plot=on_plot,
|
1335
|
-
save_dir=self.save_dir,
|
1336
|
-
names=self.names,
|
1337
|
-
prefix="Box",
|
1338
|
-
)[2:]
|
1339
|
-
self.box.nc = len(self.names)
|
1340
|
-
self.box.update(results_box)
|
1294
|
+
return stats
|
1341
1295
|
|
1342
1296
|
@property
|
1343
1297
|
def keys(self) -> List[str]:
|
1344
1298
|
"""Return list of evaluation metric keys."""
|
1345
|
-
return [
|
1346
|
-
"metrics/precision(B)",
|
1347
|
-
"metrics/recall(B)",
|
1348
|
-
"metrics/mAP50(B)",
|
1349
|
-
"metrics/mAP50-95(B)",
|
1299
|
+
return DetMetrics.keys.fget(self) + [
|
1350
1300
|
"metrics/precision(P)",
|
1351
1301
|
"metrics/recall(P)",
|
1352
1302
|
"metrics/mAP50(P)",
|
@@ -1355,26 +1305,26 @@ class PoseMetrics(SegmentMetrics):
|
|
1355
1305
|
|
1356
1306
|
def mean_results(self) -> List[float]:
|
1357
1307
|
"""Return the mean results of box and pose."""
|
1358
|
-
return
|
1308
|
+
return DetMetrics.mean_results(self) + self.pose.mean_results()
|
1359
1309
|
|
1360
1310
|
def class_result(self, i: int) -> List[float]:
|
1361
1311
|
"""Return the class-wise detection results for a specific class i."""
|
1362
|
-
return
|
1312
|
+
return DetMetrics.class_result(self, i) + self.pose.class_result(i)
|
1363
1313
|
|
1364
1314
|
@property
|
1365
1315
|
def maps(self) -> np.ndarray:
|
1366
1316
|
"""Return the mean average precision (mAP) per class for both box and pose detections."""
|
1367
|
-
return
|
1317
|
+
return DetMetrics.maps.fget(self) + self.pose.maps
|
1368
1318
|
|
1369
1319
|
@property
|
1370
1320
|
def fitness(self) -> float:
|
1371
1321
|
"""Return combined fitness score for pose and box detection."""
|
1372
|
-
return self.pose.fitness() +
|
1322
|
+
return self.pose.fitness() + DetMetrics.fitness.fget(self)
|
1373
1323
|
|
1374
1324
|
@property
|
1375
1325
|
def curves(self) -> List[str]:
|
1376
1326
|
"""Return a list of curves for accessing specific metrics curves."""
|
1377
|
-
return [
|
1327
|
+
return DetMetrics.curves.fget(self) + [
|
1378
1328
|
"Precision-Recall(B)",
|
1379
1329
|
"F1-Confidence(B)",
|
1380
1330
|
"Precision-Confidence(B)",
|
@@ -1388,7 +1338,7 @@ class PoseMetrics(SegmentMetrics):
|
|
1388
1338
|
@property
|
1389
1339
|
def curves_results(self) -> List[List]:
|
1390
1340
|
"""Return dictionary of computed performance metrics and statistics."""
|
1391
|
-
return
|
1341
|
+
return DetMetrics.curves_results.fget(self) + self.pose.curves_results
|
1392
1342
|
|
1393
1343
|
def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
|
1394
1344
|
"""
|
@@ -1408,25 +1358,19 @@ class PoseMetrics(SegmentMetrics):
|
|
1408
1358
|
>>> print(pose_summary)
|
1409
1359
|
"""
|
1410
1360
|
scalars = {
|
1411
|
-
"box-map": round(self.box.map, decimals),
|
1412
|
-
"box-map50": round(self.box.map50, decimals),
|
1413
|
-
"box-map75": round(self.box.map75, decimals),
|
1414
1361
|
"pose-map": round(self.pose.map, decimals),
|
1415
1362
|
"pose-map50": round(self.pose.map50, decimals),
|
1416
1363
|
"pose-map75": round(self.pose.map75, decimals),
|
1417
1364
|
}
|
1418
1365
|
per_class = {
|
1419
|
-
"box-p": self.box.p,
|
1420
|
-
"box-r": self.box.r,
|
1421
|
-
"box-f1": self.box.f1,
|
1422
1366
|
"pose-p": self.pose.p,
|
1423
1367
|
"pose-r": self.pose.r,
|
1424
1368
|
"pose-f1": self.pose.f1,
|
1425
1369
|
}
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
|
1370
|
+
summary = DetMetrics.summary(self, normalize, decimals) # get box summary
|
1371
|
+
for i, s in enumerate(summary):
|
1372
|
+
s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
|
1373
|
+
return summary
|
1430
1374
|
|
1431
1375
|
|
1432
1376
|
class ClassifyMetrics(SimpleClass, DataExportMixin):
|
@@ -1504,129 +1448,30 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
1504
1448
|
return [{"classify-top1": round(self.top1, decimals), "classify-top5": round(self.top5, decimals)}]
|
1505
1449
|
|
1506
1450
|
|
1507
|
-
class OBBMetrics(
|
1451
|
+
class OBBMetrics(DetMetrics):
|
1508
1452
|
"""
|
1509
1453
|
Metrics for evaluating oriented bounding box (OBB) detection.
|
1510
1454
|
|
1511
1455
|
Attributes:
|
1512
|
-
save_dir (Path): Path to the directory where the output plots should be saved.
|
1513
|
-
plot (bool): Whether to save the detection plots.
|
1514
1456
|
names (Dict[int, str]): Dictionary of class names.
|
1515
1457
|
box (Metric): An instance of the Metric class for storing detection results.
|
1516
1458
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
1517
1459
|
task (str): The task type, set to 'obb'.
|
1460
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
1461
|
+
nt_per_class: Number of targets per class.
|
1462
|
+
nt_per_image: Number of targets per image.
|
1518
1463
|
|
1519
1464
|
References:
|
1520
1465
|
https://arxiv.org/pdf/2106.06072.pdf
|
1521
1466
|
"""
|
1522
1467
|
|
1523
|
-
def __init__(self,
|
1468
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
1524
1469
|
"""
|
1525
1470
|
Initialize an OBBMetrics instance with directory, plotting, and class names.
|
1526
1471
|
|
1527
1472
|
Args:
|
1528
|
-
save_dir (Path, optional): Directory to save plots.
|
1529
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
1530
1473
|
names (Dict[int, str], optional): Dictionary of class names.
|
1531
1474
|
"""
|
1532
|
-
self
|
1533
|
-
|
1534
|
-
self.names = names
|
1535
|
-
self.box = Metric()
|
1536
|
-
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1475
|
+
DetMetrics.__init__(self, names)
|
1476
|
+
# TODO: probably remove task as well
|
1537
1477
|
self.task = "obb"
|
1538
|
-
|
1539
|
-
def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
|
1540
|
-
"""
|
1541
|
-
Process predicted results for object detection and update metrics.
|
1542
|
-
|
1543
|
-
Args:
|
1544
|
-
tp (np.ndarray): True positive array.
|
1545
|
-
conf (np.ndarray): Confidence array.
|
1546
|
-
pred_cls (np.ndarray): Predicted class indices array.
|
1547
|
-
target_cls (np.ndarray): Target class indices array.
|
1548
|
-
on_plot (callable, optional): Function to call after plots are generated.
|
1549
|
-
"""
|
1550
|
-
results = ap_per_class(
|
1551
|
-
tp,
|
1552
|
-
conf,
|
1553
|
-
pred_cls,
|
1554
|
-
target_cls,
|
1555
|
-
plot=self.plot,
|
1556
|
-
save_dir=self.save_dir,
|
1557
|
-
names=self.names,
|
1558
|
-
on_plot=on_plot,
|
1559
|
-
)[2:]
|
1560
|
-
self.box.nc = len(self.names)
|
1561
|
-
self.box.update(results)
|
1562
|
-
|
1563
|
-
@property
|
1564
|
-
def keys(self) -> List[str]:
|
1565
|
-
"""Return a list of keys for accessing specific metrics."""
|
1566
|
-
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
1567
|
-
|
1568
|
-
def mean_results(self) -> List[float]:
|
1569
|
-
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
1570
|
-
return self.box.mean_results()
|
1571
|
-
|
1572
|
-
def class_result(self, i: int) -> Tuple[float, float, float, float]:
|
1573
|
-
"""Return the result of evaluating the performance of an object detection model on a specific class."""
|
1574
|
-
return self.box.class_result(i)
|
1575
|
-
|
1576
|
-
@property
|
1577
|
-
def maps(self) -> np.ndarray:
|
1578
|
-
"""Return mean Average Precision (mAP) scores per class."""
|
1579
|
-
return self.box.maps
|
1580
|
-
|
1581
|
-
@property
|
1582
|
-
def fitness(self) -> float:
|
1583
|
-
"""Return the fitness of box object."""
|
1584
|
-
return self.box.fitness()
|
1585
|
-
|
1586
|
-
@property
|
1587
|
-
def ap_class_index(self) -> List:
|
1588
|
-
"""Return the average precision index per class."""
|
1589
|
-
return self.box.ap_class_index
|
1590
|
-
|
1591
|
-
@property
|
1592
|
-
def results_dict(self) -> Dict[str, float]:
|
1593
|
-
"""Return dictionary of computed performance metrics and statistics."""
|
1594
|
-
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1595
|
-
|
1596
|
-
@property
|
1597
|
-
def curves(self) -> List:
|
1598
|
-
"""Return a list of curves for accessing specific metrics curves."""
|
1599
|
-
return []
|
1600
|
-
|
1601
|
-
@property
|
1602
|
-
def curves_results(self) -> List:
|
1603
|
-
"""Return a list of curves for accessing specific metrics curves."""
|
1604
|
-
return []
|
1605
|
-
|
1606
|
-
def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
|
1607
|
-
"""
|
1608
|
-
Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
|
1609
|
-
scalar metrics (mAP, mAP50, mAP75) along with precision, recall, and F1-score for each class.
|
1610
|
-
|
1611
|
-
Args:
|
1612
|
-
normalize (bool): For OBB metrics, everything is normalized by default [0-1].
|
1613
|
-
decimals (int): Number of decimal places to round the metrics values to.
|
1614
|
-
|
1615
|
-
Returns:
|
1616
|
-
(List[Dict[str, Union[str, float]]]): A list of dictionaries, each representing one class with detection metrics.
|
1617
|
-
|
1618
|
-
Examples:
|
1619
|
-
>>> results = model.val(data="dota8.yaml")
|
1620
|
-
>>> detection_summary = results.summary(decimals=4)
|
1621
|
-
>>> print(detection_summary)
|
1622
|
-
"""
|
1623
|
-
scalars = {
|
1624
|
-
"box-map": round(self.box.map, decimals),
|
1625
|
-
"box-map50": round(self.box.map50, decimals),
|
1626
|
-
"box-map75": round(self.box.map75, decimals),
|
1627
|
-
}
|
1628
|
-
per_class = {"box-p": self.box.p, "box-r": self.box.r, "box-f1": self.box.f1}
|
1629
|
-
return [
|
1630
|
-
{"class_name": self.names[i], **{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars}
|
1631
|
-
for i in range(len(next(iter(per_class.values()), [])))
|
1632
|
-
]
|