dgenerate-ultralytics-headless 8.3.153__py3-none-any.whl → 8.3.154__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.153.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.153.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/RECORD +30 -30
- tests/test_python.py +1 -0
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +2 -0
- ultralytics/engine/predictor.py +1 -1
- ultralytics/engine/validator.py +0 -6
- ultralytics/models/fastsam/val.py +0 -2
- ultralytics/models/rtdetr/val.py +28 -16
- ultralytics/models/yolo/classify/val.py +26 -23
- ultralytics/models/yolo/detect/train.py +4 -7
- ultralytics/models/yolo/detect/val.py +88 -90
- ultralytics/models/yolo/obb/val.py +52 -44
- ultralytics/models/yolo/pose/train.py +1 -35
- ultralytics/models/yolo/pose/val.py +77 -176
- ultralytics/models/yolo/segment/train.py +1 -41
- ultralytics/models/yolo/segment/val.py +64 -176
- ultralytics/models/yolo/yoloe/val.py +2 -1
- ultralytics/nn/autobackend.py +2 -2
- ultralytics/solutions/ai_gym.py +2 -3
- ultralytics/solutions/solutions.py +2 -0
- ultralytics/solutions/templates/similarity-search.html +31 -0
- ultralytics/utils/callbacks/comet.py +1 -1
- ultralytics/utils/metrics.py +146 -317
- ultralytics/utils/ops.py +4 -4
- ultralytics/utils/plotting.py +31 -56
- {dgenerate_ultralytics_headless-8.3.153.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.153.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.153.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.153.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4
4
|
import math
|
5
5
|
import warnings
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Dict, List, Tuple, Union
|
7
|
+
from typing import Any, Dict, List, Tuple, Union
|
8
8
|
|
9
9
|
import numpy as np
|
10
10
|
import torch
|
@@ -316,28 +316,22 @@ class ConfusionMatrix(DataExportMixin):
|
|
316
316
|
Attributes:
|
317
317
|
task (str): The type of task, either 'detect' or 'classify'.
|
318
318
|
matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
|
319
|
-
nc (int): The number of
|
320
|
-
|
321
|
-
iou_thres (float): The Intersection over Union threshold.
|
319
|
+
nc (int): The number of category.
|
320
|
+
names (List[str]): The names of the classes, used as labels on the plot.
|
322
321
|
"""
|
323
322
|
|
324
|
-
def __init__(self,
|
323
|
+
def __init__(self, names: List[str] = [], task: str = "detect"):
|
325
324
|
"""
|
326
325
|
Initialize a ConfusionMatrix instance.
|
327
326
|
|
328
327
|
Args:
|
329
|
-
|
330
|
-
conf (float, optional): Confidence threshold for detections.
|
331
|
-
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
|
332
|
-
names (tuple, optional): Names of classes, used as labels on the plot.
|
328
|
+
names (List[str], optional): Names of classes, used as labels on the plot.
|
333
329
|
task (str, optional): Type of task, either 'detect' or 'classify'.
|
334
330
|
"""
|
335
331
|
self.task = task
|
336
|
-
self.
|
337
|
-
self.
|
338
|
-
self.names =
|
339
|
-
self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
|
340
|
-
self.iou_thres = iou_thres
|
332
|
+
self.nc = len(names) # number of classes
|
333
|
+
self.matrix = np.zeros((self.nc + 1, self.nc + 1)) if self.task == "detect" else np.zeros((self.nc, self.nc))
|
334
|
+
self.names = names # name of classes
|
341
335
|
|
342
336
|
def process_cls_preds(self, preds, targets):
|
343
337
|
"""
|
@@ -351,41 +345,45 @@ class ConfusionMatrix(DataExportMixin):
|
|
351
345
|
for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
|
352
346
|
self.matrix[p][t] += 1
|
353
347
|
|
354
|
-
def process_batch(
|
348
|
+
def process_batch(
|
349
|
+
self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
|
350
|
+
) -> None:
|
355
351
|
"""
|
356
352
|
Update confusion matrix for object detection task.
|
357
353
|
|
358
354
|
Args:
|
359
|
-
detections (
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
355
|
+
detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
|
356
|
+
Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
|
357
|
+
Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
|
358
|
+
batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and
|
359
|
+
'cls' (Array[M]) keys, where M is the number of ground truth objects.
|
360
|
+
conf (float, optional): Confidence threshold for detections.
|
361
|
+
iou_thres (float, optional): IoU threshold for matching detections to ground truth.
|
364
362
|
"""
|
363
|
+
conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
|
364
|
+
gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
|
365
|
+
no_pred = len(detections["cls"]) == 0
|
365
366
|
if gt_cls.shape[0] == 0: # Check if labels is empty
|
366
|
-
if
|
367
|
-
detections = detections[detections[
|
368
|
-
detection_classes = detections[
|
367
|
+
if not no_pred:
|
368
|
+
detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
|
369
|
+
detection_classes = detections["cls"].int().tolist()
|
369
370
|
for dc in detection_classes:
|
370
371
|
self.matrix[dc, self.nc] += 1 # false positives
|
371
372
|
return
|
372
|
-
if
|
373
|
+
if no_pred:
|
373
374
|
gt_classes = gt_cls.int().tolist()
|
374
375
|
for gc in gt_classes:
|
375
376
|
self.matrix[self.nc, gc] += 1 # background FN
|
376
377
|
return
|
377
378
|
|
378
|
-
detections = detections[detections[
|
379
|
+
detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
|
379
380
|
gt_classes = gt_cls.int().tolist()
|
380
|
-
detection_classes = detections[
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
)
|
387
|
-
|
388
|
-
x = torch.where(iou > self.iou_thres)
|
381
|
+
detection_classes = detections["cls"].int().tolist()
|
382
|
+
bboxes = detections["bboxes"]
|
383
|
+
is_obb = bboxes.shape[1] == 5 # check if detections contains angle for OBB
|
384
|
+
iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
|
385
|
+
|
386
|
+
x = torch.where(iou > iou_thres)
|
389
387
|
if x[0].shape[0]:
|
390
388
|
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
391
389
|
if x[0].shape[0] > 1:
|
@@ -949,53 +947,76 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
949
947
|
Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
|
950
948
|
|
951
949
|
Attributes:
|
952
|
-
save_dir (Path): A path to the directory where the output plots will be saved.
|
953
|
-
plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
|
954
950
|
names (Dict[int, str]): A dictionary of class names.
|
955
951
|
box (Metric): An instance of the Metric class for storing detection results.
|
956
952
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
957
953
|
task (str): The task type, set to 'detect'.
|
954
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
955
|
+
nt_per_class: Number of targets per class.
|
956
|
+
nt_per_image: Number of targets per image.
|
958
957
|
"""
|
959
958
|
|
960
|
-
def __init__(self,
|
959
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
961
960
|
"""
|
962
961
|
Initialize a DetMetrics instance with a save directory, plot flag, and class names.
|
963
962
|
|
964
963
|
Args:
|
965
|
-
save_dir (Path, optional): Directory to save plots.
|
966
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
967
964
|
names (Dict[int, str], optional): Dictionary of class names.
|
968
965
|
"""
|
969
|
-
self.save_dir = save_dir
|
970
|
-
self.plot = plot
|
971
966
|
self.names = names
|
972
967
|
self.box = Metric()
|
973
968
|
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
974
969
|
self.task = "detect"
|
970
|
+
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
|
971
|
+
self.nt_per_class = None
|
972
|
+
self.nt_per_image = None
|
975
973
|
|
976
|
-
def
|
974
|
+
def update_stats(self, stat: Dict[str, Any]) -> None:
|
975
|
+
"""
|
976
|
+
Update statistics by appending new values to existing stat collections.
|
977
|
+
|
978
|
+
Args:
|
979
|
+
stat (Dict[str, any]): Dictionary containing new statistical values to append.
|
980
|
+
Keys should match existing keys in self.stats.
|
981
|
+
"""
|
982
|
+
for k in self.stats.keys():
|
983
|
+
self.stats[k].append(stat[k])
|
984
|
+
|
985
|
+
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
|
977
986
|
"""
|
978
987
|
Process predicted results for object detection and update metrics.
|
979
988
|
|
980
989
|
Args:
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
990
|
+
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
991
|
+
plot (bool): Whether to plot precision-recall curves. Defaults to False.
|
992
|
+
on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
|
993
|
+
|
994
|
+
Returns:
|
995
|
+
(Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
|
986
996
|
"""
|
997
|
+
stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()} # to numpy
|
998
|
+
if len(stats) == 0:
|
999
|
+
return stats
|
987
1000
|
results = ap_per_class(
|
988
|
-
tp,
|
989
|
-
conf,
|
990
|
-
pred_cls,
|
991
|
-
target_cls,
|
992
|
-
plot=
|
993
|
-
save_dir=
|
1001
|
+
stats["tp"],
|
1002
|
+
stats["conf"],
|
1003
|
+
stats["pred_cls"],
|
1004
|
+
stats["target_cls"],
|
1005
|
+
plot=plot,
|
1006
|
+
save_dir=save_dir,
|
994
1007
|
names=self.names,
|
995
1008
|
on_plot=on_plot,
|
996
1009
|
)[2:]
|
997
1010
|
self.box.nc = len(self.names)
|
998
1011
|
self.box.update(results)
|
1012
|
+
self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
|
1013
|
+
self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
|
1014
|
+
return stats
|
1015
|
+
|
1016
|
+
def clear_stats(self):
|
1017
|
+
"""Clear the stored statistics."""
|
1018
|
+
for v in self.stats.values():
|
1019
|
+
v.clear()
|
999
1020
|
|
1000
1021
|
@property
|
1001
1022
|
def keys(self) -> List[str]:
|
@@ -1077,92 +1098,65 @@ class DetMetrics(SimpleClass, DataExportMixin):
|
|
1077
1098
|
]
|
1078
1099
|
|
1079
1100
|
|
1080
|
-
class SegmentMetrics(
|
1101
|
+
class SegmentMetrics(DetMetrics):
|
1081
1102
|
"""
|
1082
1103
|
Calculate and aggregate detection and segmentation metrics over a given set of classes.
|
1083
1104
|
|
1084
1105
|
Attributes:
|
1085
|
-
save_dir (Path): Path to the directory where the output plots should be saved.
|
1086
|
-
plot (bool): Whether to save the detection and segmentation plots.
|
1087
1106
|
names (Dict[int, str]): Dictionary of class names.
|
1088
1107
|
box (Metric): An instance of the Metric class for storing detection results.
|
1089
1108
|
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
1090
1109
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
1091
1110
|
task (str): The task type, set to 'segment'.
|
1111
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
1112
|
+
nt_per_class: Number of targets per class.
|
1113
|
+
nt_per_image: Number of targets per image.
|
1092
1114
|
"""
|
1093
1115
|
|
1094
|
-
def __init__(self,
|
1116
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
1095
1117
|
"""
|
1096
1118
|
Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
|
1097
1119
|
|
1098
1120
|
Args:
|
1099
|
-
save_dir (Path, optional): Directory to save plots.
|
1100
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
1101
1121
|
names (Dict[int, str], optional): Dictionary of class names.
|
1102
1122
|
"""
|
1103
|
-
self
|
1104
|
-
self.plot = plot
|
1105
|
-
self.names = names
|
1106
|
-
self.box = Metric()
|
1123
|
+
DetMetrics.__init__(self, names)
|
1107
1124
|
self.seg = Metric()
|
1108
|
-
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1109
1125
|
self.task = "segment"
|
1126
|
+
self.stats["tp_m"] = [] # add additional stats for masks
|
1110
1127
|
|
1111
|
-
def process(
|
1112
|
-
self,
|
1113
|
-
tp: np.ndarray,
|
1114
|
-
tp_m: np.ndarray,
|
1115
|
-
conf: np.ndarray,
|
1116
|
-
pred_cls: np.ndarray,
|
1117
|
-
target_cls: np.ndarray,
|
1118
|
-
on_plot=None,
|
1119
|
-
):
|
1128
|
+
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
|
1120
1129
|
"""
|
1121
1130
|
Process the detection and segmentation metrics over the given set of predictions.
|
1122
1131
|
|
1123
1132
|
Args:
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1133
|
+
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
1134
|
+
plot (bool): Whether to plot precision-recall curves. Defaults to False.
|
1135
|
+
on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
|
1136
|
+
|
1137
|
+
Returns:
|
1138
|
+
(Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
|
1130
1139
|
"""
|
1140
|
+
stats = DetMetrics.process(self, on_plot=on_plot) # process box stats
|
1131
1141
|
results_mask = ap_per_class(
|
1132
|
-
tp_m,
|
1133
|
-
conf,
|
1134
|
-
pred_cls,
|
1135
|
-
target_cls,
|
1136
|
-
plot=
|
1142
|
+
stats["tp_m"],
|
1143
|
+
stats["conf"],
|
1144
|
+
stats["pred_cls"],
|
1145
|
+
stats["target_cls"],
|
1146
|
+
plot=plot,
|
1137
1147
|
on_plot=on_plot,
|
1138
|
-
save_dir=
|
1148
|
+
save_dir=save_dir,
|
1139
1149
|
names=self.names,
|
1140
1150
|
prefix="Mask",
|
1141
1151
|
)[2:]
|
1142
1152
|
self.seg.nc = len(self.names)
|
1143
1153
|
self.seg.update(results_mask)
|
1144
|
-
|
1145
|
-
tp,
|
1146
|
-
conf,
|
1147
|
-
pred_cls,
|
1148
|
-
target_cls,
|
1149
|
-
plot=self.plot,
|
1150
|
-
on_plot=on_plot,
|
1151
|
-
save_dir=self.save_dir,
|
1152
|
-
names=self.names,
|
1153
|
-
prefix="Box",
|
1154
|
-
)[2:]
|
1155
|
-
self.box.nc = len(self.names)
|
1156
|
-
self.box.update(results_box)
|
1154
|
+
return stats
|
1157
1155
|
|
1158
1156
|
@property
|
1159
1157
|
def keys(self) -> List[str]:
|
1160
1158
|
"""Return a list of keys for accessing metrics."""
|
1161
|
-
return [
|
1162
|
-
"metrics/precision(B)",
|
1163
|
-
"metrics/recall(B)",
|
1164
|
-
"metrics/mAP50(B)",
|
1165
|
-
"metrics/mAP50-95(B)",
|
1159
|
+
return DetMetrics.keys.fget(self) + [
|
1166
1160
|
"metrics/precision(M)",
|
1167
1161
|
"metrics/recall(M)",
|
1168
1162
|
"metrics/mAP50(M)",
|
@@ -1171,40 +1165,26 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1171
1165
|
|
1172
1166
|
def mean_results(self) -> List[float]:
|
1173
1167
|
"""Return the mean metrics for bounding box and segmentation results."""
|
1174
|
-
return
|
1168
|
+
return DetMetrics.mean_results(self) + self.seg.mean_results()
|
1175
1169
|
|
1176
1170
|
def class_result(self, i: int) -> List[float]:
|
1177
1171
|
"""Return classification results for a specified class index."""
|
1178
|
-
return
|
1172
|
+
return DetMetrics.class_result(self, i) + self.seg.class_result(i)
|
1179
1173
|
|
1180
1174
|
@property
|
1181
1175
|
def maps(self) -> np.ndarray:
|
1182
1176
|
"""Return mAP scores for object detection and semantic segmentation models."""
|
1183
|
-
return
|
1177
|
+
return DetMetrics.maps.fget(self) + self.seg.maps
|
1184
1178
|
|
1185
1179
|
@property
|
1186
1180
|
def fitness(self) -> float:
|
1187
1181
|
"""Return the fitness score for both segmentation and bounding box models."""
|
1188
|
-
return self.seg.fitness() +
|
1189
|
-
|
1190
|
-
@property
|
1191
|
-
def ap_class_index(self) -> List:
|
1192
|
-
"""Return the class indices (boxes and masks have the same ap_class_index)."""
|
1193
|
-
return self.box.ap_class_index
|
1194
|
-
|
1195
|
-
@property
|
1196
|
-
def results_dict(self) -> Dict[str, float]:
|
1197
|
-
"""Return results of object detection model for evaluation."""
|
1198
|
-
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1182
|
+
return self.seg.fitness() + DetMetrics.fitness.fget(self)
|
1199
1183
|
|
1200
1184
|
@property
|
1201
1185
|
def curves(self) -> List[str]:
|
1202
1186
|
"""Return a list of curves for accessing specific metrics curves."""
|
1203
|
-
return [
|
1204
|
-
"Precision-Recall(B)",
|
1205
|
-
"F1-Confidence(B)",
|
1206
|
-
"Precision-Confidence(B)",
|
1207
|
-
"Recall-Confidence(B)",
|
1187
|
+
return DetMetrics.curves.fget(self) + [
|
1208
1188
|
"Precision-Recall(M)",
|
1209
1189
|
"F1-Confidence(M)",
|
1210
1190
|
"Precision-Confidence(M)",
|
@@ -1214,7 +1194,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1214
1194
|
@property
|
1215
1195
|
def curves_results(self) -> List[List]:
|
1216
1196
|
"""Return dictionary of computed performance metrics and statistics."""
|
1217
|
-
return
|
1197
|
+
return DetMetrics.curves_results.fget(self) + self.seg.curves_results
|
1218
1198
|
|
1219
1199
|
def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
|
1220
1200
|
"""
|
@@ -1234,43 +1214,34 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
|
|
1234
1214
|
>>> print(seg_summary)
|
1235
1215
|
"""
|
1236
1216
|
scalars = {
|
1237
|
-
"box-map": round(self.box.map, decimals),
|
1238
|
-
"box-map50": round(self.box.map50, decimals),
|
1239
|
-
"box-map75": round(self.box.map75, decimals),
|
1240
1217
|
"mask-map": round(self.seg.map, decimals),
|
1241
1218
|
"mask-map50": round(self.seg.map50, decimals),
|
1242
1219
|
"mask-map75": round(self.seg.map75, decimals),
|
1243
1220
|
}
|
1244
1221
|
per_class = {
|
1245
|
-
"box-p": self.box.p,
|
1246
|
-
"box-r": self.box.r,
|
1247
|
-
"box-f1": self.box.f1,
|
1248
1222
|
"mask-p": self.seg.p,
|
1249
1223
|
"mask-r": self.seg.r,
|
1250
1224
|
"mask-f1": self.seg.f1,
|
1251
1225
|
}
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
|
-
|
1256
|
-
**scalars,
|
1257
|
-
}
|
1258
|
-
for i in range(len(per_class["box-p"]))
|
1259
|
-
]
|
1226
|
+
summary = DetMetrics.summary(self, normalize, decimals) # get box summary
|
1227
|
+
for i, s in enumerate(summary):
|
1228
|
+
s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
|
1229
|
+
return summary
|
1260
1230
|
|
1261
1231
|
|
1262
|
-
class PoseMetrics(
|
1232
|
+
class PoseMetrics(DetMetrics):
|
1263
1233
|
"""
|
1264
1234
|
Calculate and aggregate detection and pose metrics over a given set of classes.
|
1265
1235
|
|
1266
1236
|
Attributes:
|
1267
|
-
save_dir (Path): Path to the directory where the output plots should be saved.
|
1268
|
-
plot (bool): Whether to save the detection and pose plots.
|
1269
1237
|
names (Dict[int, str]): Dictionary of class names.
|
1270
1238
|
pose (Metric): An instance of the Metric class to calculate pose metrics.
|
1271
1239
|
box (Metric): An instance of the Metric class for storing detection results.
|
1272
1240
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
1273
1241
|
task (str): The task type, set to 'pose'.
|
1242
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
1243
|
+
nt_per_class: Number of targets per class.
|
1244
|
+
nt_per_image: Number of targets per image.
|
1274
1245
|
|
1275
1246
|
Methods:
|
1276
1247
|
process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.
|
@@ -1282,79 +1253,50 @@ class PoseMetrics(SegmentMetrics):
|
|
1282
1253
|
results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
|
1283
1254
|
"""
|
1284
1255
|
|
1285
|
-
def __init__(self,
|
1256
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
1286
1257
|
"""
|
1287
1258
|
Initialize the PoseMetrics class with directory path, class names, and plotting options.
|
1288
1259
|
|
1289
1260
|
Args:
|
1290
|
-
save_dir (Path, optional): Directory to save plots.
|
1291
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
1292
1261
|
names (Dict[int, str], optional): Dictionary of class names.
|
1293
1262
|
"""
|
1294
|
-
super().__init__(
|
1295
|
-
self.save_dir = save_dir
|
1296
|
-
self.plot = plot
|
1297
|
-
self.names = names
|
1298
|
-
self.box = Metric()
|
1263
|
+
super().__init__(names)
|
1299
1264
|
self.pose = Metric()
|
1300
|
-
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1301
1265
|
self.task = "pose"
|
1266
|
+
self.stats["tp_p"] = [] # add additional stats for pose
|
1302
1267
|
|
1303
|
-
def process(
|
1304
|
-
self,
|
1305
|
-
tp: np.ndarray,
|
1306
|
-
tp_p: np.ndarray,
|
1307
|
-
conf: np.ndarray,
|
1308
|
-
pred_cls: np.ndarray,
|
1309
|
-
target_cls: np.ndarray,
|
1310
|
-
on_plot=None,
|
1311
|
-
):
|
1268
|
+
def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
|
1312
1269
|
"""
|
1313
1270
|
Process the detection and pose metrics over the given set of predictions.
|
1314
1271
|
|
1315
1272
|
Args:
|
1316
|
-
|
1317
|
-
|
1318
|
-
conf (np.ndarray): Confidence array.
|
1319
|
-
pred_cls (np.ndarray): Predicted class indices array.
|
1320
|
-
target_cls (np.ndarray): Target class indices array.
|
1273
|
+
save_dir (Path): Directory to save plots. Defaults to Path(".").
|
1274
|
+
plot (bool): Whether to plot precision-recall curves. Defaults to False.
|
1321
1275
|
on_plot (callable, optional): Function to call after plots are generated.
|
1276
|
+
|
1277
|
+
Returns:
|
1278
|
+
(Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
|
1322
1279
|
"""
|
1280
|
+
stats = DetMetrics.process(self, on_plot=on_plot) # process box stats
|
1323
1281
|
results_pose = ap_per_class(
|
1324
|
-
tp_p,
|
1325
|
-
conf,
|
1326
|
-
pred_cls,
|
1327
|
-
target_cls,
|
1328
|
-
plot=
|
1282
|
+
stats["tp_p"],
|
1283
|
+
stats["conf"],
|
1284
|
+
stats["pred_cls"],
|
1285
|
+
stats["target_cls"],
|
1286
|
+
plot=plot,
|
1329
1287
|
on_plot=on_plot,
|
1330
|
-
save_dir=
|
1288
|
+
save_dir=save_dir,
|
1331
1289
|
names=self.names,
|
1332
1290
|
prefix="Pose",
|
1333
1291
|
)[2:]
|
1334
1292
|
self.pose.nc = len(self.names)
|
1335
1293
|
self.pose.update(results_pose)
|
1336
|
-
|
1337
|
-
tp,
|
1338
|
-
conf,
|
1339
|
-
pred_cls,
|
1340
|
-
target_cls,
|
1341
|
-
plot=self.plot,
|
1342
|
-
on_plot=on_plot,
|
1343
|
-
save_dir=self.save_dir,
|
1344
|
-
names=self.names,
|
1345
|
-
prefix="Box",
|
1346
|
-
)[2:]
|
1347
|
-
self.box.nc = len(self.names)
|
1348
|
-
self.box.update(results_box)
|
1294
|
+
return stats
|
1349
1295
|
|
1350
1296
|
@property
|
1351
1297
|
def keys(self) -> List[str]:
|
1352
1298
|
"""Return list of evaluation metric keys."""
|
1353
|
-
return [
|
1354
|
-
"metrics/precision(B)",
|
1355
|
-
"metrics/recall(B)",
|
1356
|
-
"metrics/mAP50(B)",
|
1357
|
-
"metrics/mAP50-95(B)",
|
1299
|
+
return DetMetrics.keys.fget(self) + [
|
1358
1300
|
"metrics/precision(P)",
|
1359
1301
|
"metrics/recall(P)",
|
1360
1302
|
"metrics/mAP50(P)",
|
@@ -1363,26 +1305,26 @@ class PoseMetrics(SegmentMetrics):
|
|
1363
1305
|
|
1364
1306
|
def mean_results(self) -> List[float]:
|
1365
1307
|
"""Return the mean results of box and pose."""
|
1366
|
-
return
|
1308
|
+
return DetMetrics.mean_results(self) + self.pose.mean_results()
|
1367
1309
|
|
1368
1310
|
def class_result(self, i: int) -> List[float]:
|
1369
1311
|
"""Return the class-wise detection results for a specific class i."""
|
1370
|
-
return
|
1312
|
+
return DetMetrics.class_result(self, i) + self.pose.class_result(i)
|
1371
1313
|
|
1372
1314
|
@property
|
1373
1315
|
def maps(self) -> np.ndarray:
|
1374
1316
|
"""Return the mean average precision (mAP) per class for both box and pose detections."""
|
1375
|
-
return
|
1317
|
+
return DetMetrics.maps.fget(self) + self.pose.maps
|
1376
1318
|
|
1377
1319
|
@property
|
1378
1320
|
def fitness(self) -> float:
|
1379
1321
|
"""Return combined fitness score for pose and box detection."""
|
1380
|
-
return self.pose.fitness() +
|
1322
|
+
return self.pose.fitness() + DetMetrics.fitness.fget(self)
|
1381
1323
|
|
1382
1324
|
@property
|
1383
1325
|
def curves(self) -> List[str]:
|
1384
1326
|
"""Return a list of curves for accessing specific metrics curves."""
|
1385
|
-
return [
|
1327
|
+
return DetMetrics.curves.fget(self) + [
|
1386
1328
|
"Precision-Recall(B)",
|
1387
1329
|
"F1-Confidence(B)",
|
1388
1330
|
"Precision-Confidence(B)",
|
@@ -1396,7 +1338,7 @@ class PoseMetrics(SegmentMetrics):
|
|
1396
1338
|
@property
|
1397
1339
|
def curves_results(self) -> List[List]:
|
1398
1340
|
"""Return dictionary of computed performance metrics and statistics."""
|
1399
|
-
return
|
1341
|
+
return DetMetrics.curves_results.fget(self) + self.pose.curves_results
|
1400
1342
|
|
1401
1343
|
def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
|
1402
1344
|
"""
|
@@ -1416,29 +1358,19 @@ class PoseMetrics(SegmentMetrics):
|
|
1416
1358
|
>>> print(pose_summary)
|
1417
1359
|
"""
|
1418
1360
|
scalars = {
|
1419
|
-
"box-map": round(self.box.map, decimals),
|
1420
|
-
"box-map50": round(self.box.map50, decimals),
|
1421
|
-
"box-map75": round(self.box.map75, decimals),
|
1422
1361
|
"pose-map": round(self.pose.map, decimals),
|
1423
1362
|
"pose-map50": round(self.pose.map50, decimals),
|
1424
1363
|
"pose-map75": round(self.pose.map75, decimals),
|
1425
1364
|
}
|
1426
1365
|
per_class = {
|
1427
|
-
"box-p": self.box.p,
|
1428
|
-
"box-r": self.box.r,
|
1429
|
-
"box-f1": self.box.f1,
|
1430
1366
|
"pose-p": self.pose.p,
|
1431
1367
|
"pose-r": self.pose.r,
|
1432
1368
|
"pose-f1": self.pose.f1,
|
1433
1369
|
}
|
1434
|
-
|
1435
|
-
|
1436
|
-
|
1437
|
-
|
1438
|
-
**scalars,
|
1439
|
-
}
|
1440
|
-
for i in range(len(per_class["box-p"]))
|
1441
|
-
]
|
1370
|
+
summary = DetMetrics.summary(self, normalize, decimals) # get box summary
|
1371
|
+
for i, s in enumerate(summary):
|
1372
|
+
s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
|
1373
|
+
return summary
|
1442
1374
|
|
1443
1375
|
|
1444
1376
|
class ClassifyMetrics(SimpleClass, DataExportMixin):
|
@@ -1516,133 +1448,30 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
|
|
1516
1448
|
return [{"classify-top1": round(self.top1, decimals), "classify-top5": round(self.top5, decimals)}]
|
1517
1449
|
|
1518
1450
|
|
1519
|
-
class OBBMetrics(
|
1451
|
+
class OBBMetrics(DetMetrics):
|
1520
1452
|
"""
|
1521
1453
|
Metrics for evaluating oriented bounding box (OBB) detection.
|
1522
1454
|
|
1523
1455
|
Attributes:
|
1524
|
-
save_dir (Path): Path to the directory where the output plots should be saved.
|
1525
|
-
plot (bool): Whether to save the detection plots.
|
1526
1456
|
names (Dict[int, str]): Dictionary of class names.
|
1527
1457
|
box (Metric): An instance of the Metric class for storing detection results.
|
1528
1458
|
speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
|
1529
1459
|
task (str): The task type, set to 'obb'.
|
1460
|
+
stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
|
1461
|
+
nt_per_class: Number of targets per class.
|
1462
|
+
nt_per_image: Number of targets per image.
|
1530
1463
|
|
1531
1464
|
References:
|
1532
1465
|
https://arxiv.org/pdf/2106.06072.pdf
|
1533
1466
|
"""
|
1534
1467
|
|
1535
|
-
def __init__(self,
|
1468
|
+
def __init__(self, names: Dict[int, str] = {}) -> None:
|
1536
1469
|
"""
|
1537
1470
|
Initialize an OBBMetrics instance with directory, plotting, and class names.
|
1538
1471
|
|
1539
1472
|
Args:
|
1540
|
-
save_dir (Path, optional): Directory to save plots.
|
1541
|
-
plot (bool, optional): Whether to plot precision-recall curves.
|
1542
1473
|
names (Dict[int, str], optional): Dictionary of class names.
|
1543
1474
|
"""
|
1544
|
-
self
|
1545
|
-
|
1546
|
-
self.names = names
|
1547
|
-
self.box = Metric()
|
1548
|
-
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
1475
|
+
DetMetrics.__init__(self, names)
|
1476
|
+
# TODO: probably remove task as well
|
1549
1477
|
self.task = "obb"
|
1550
|
-
|
1551
|
-
def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
|
1552
|
-
"""
|
1553
|
-
Process predicted results for object detection and update metrics.
|
1554
|
-
|
1555
|
-
Args:
|
1556
|
-
tp (np.ndarray): True positive array.
|
1557
|
-
conf (np.ndarray): Confidence array.
|
1558
|
-
pred_cls (np.ndarray): Predicted class indices array.
|
1559
|
-
target_cls (np.ndarray): Target class indices array.
|
1560
|
-
on_plot (callable, optional): Function to call after plots are generated.
|
1561
|
-
"""
|
1562
|
-
results = ap_per_class(
|
1563
|
-
tp,
|
1564
|
-
conf,
|
1565
|
-
pred_cls,
|
1566
|
-
target_cls,
|
1567
|
-
plot=self.plot,
|
1568
|
-
save_dir=self.save_dir,
|
1569
|
-
names=self.names,
|
1570
|
-
on_plot=on_plot,
|
1571
|
-
)[2:]
|
1572
|
-
self.box.nc = len(self.names)
|
1573
|
-
self.box.update(results)
|
1574
|
-
|
1575
|
-
@property
|
1576
|
-
def keys(self) -> List[str]:
|
1577
|
-
"""Return a list of keys for accessing specific metrics."""
|
1578
|
-
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
|
1579
|
-
|
1580
|
-
def mean_results(self) -> List[float]:
|
1581
|
-
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
1582
|
-
return self.box.mean_results()
|
1583
|
-
|
1584
|
-
def class_result(self, i: int) -> Tuple[float, float, float, float]:
|
1585
|
-
"""Return the result of evaluating the performance of an object detection model on a specific class."""
|
1586
|
-
return self.box.class_result(i)
|
1587
|
-
|
1588
|
-
@property
|
1589
|
-
def maps(self) -> np.ndarray:
|
1590
|
-
"""Return mean Average Precision (mAP) scores per class."""
|
1591
|
-
return self.box.maps
|
1592
|
-
|
1593
|
-
@property
|
1594
|
-
def fitness(self) -> float:
|
1595
|
-
"""Return the fitness of box object."""
|
1596
|
-
return self.box.fitness()
|
1597
|
-
|
1598
|
-
@property
|
1599
|
-
def ap_class_index(self) -> List:
|
1600
|
-
"""Return the average precision index per class."""
|
1601
|
-
return self.box.ap_class_index
|
1602
|
-
|
1603
|
-
@property
|
1604
|
-
def results_dict(self) -> Dict[str, float]:
|
1605
|
-
"""Return dictionary of computed performance metrics and statistics."""
|
1606
|
-
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
|
1607
|
-
|
1608
|
-
@property
|
1609
|
-
def curves(self) -> List:
|
1610
|
-
"""Return a list of curves for accessing specific metrics curves."""
|
1611
|
-
return []
|
1612
|
-
|
1613
|
-
@property
|
1614
|
-
def curves_results(self) -> List:
|
1615
|
-
"""Return a list of curves for accessing specific metrics curves."""
|
1616
|
-
return []
|
1617
|
-
|
1618
|
-
def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
|
1619
|
-
"""
|
1620
|
-
Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
|
1621
|
-
scalar metrics (mAP, mAP50, mAP75) along with precision, recall, and F1-score for each class.
|
1622
|
-
|
1623
|
-
Args:
|
1624
|
-
normalize (bool): For OBB metrics, everything is normalized by default [0-1].
|
1625
|
-
decimals (int): Number of decimal places to round the metrics values to.
|
1626
|
-
|
1627
|
-
Returns:
|
1628
|
-
(List[Dict[str, Union[str, float]]]): A list of dictionaries, each representing one class with detection metrics.
|
1629
|
-
|
1630
|
-
Examples:
|
1631
|
-
>>> results = model.val(data="dota8.yaml")
|
1632
|
-
>>> detection_summary = results.summary(decimals=4)
|
1633
|
-
>>> print(detection_summary)
|
1634
|
-
"""
|
1635
|
-
scalars = {
|
1636
|
-
"box-map": round(self.box.map, decimals),
|
1637
|
-
"box-map50": round(self.box.map50, decimals),
|
1638
|
-
"box-map75": round(self.box.map75, decimals),
|
1639
|
-
}
|
1640
|
-
per_class = {"box-p": self.box.p, "box-r": self.box.r, "box-f1": self.box.f1}
|
1641
|
-
return [
|
1642
|
-
{
|
1643
|
-
"class_name": self.names[self.ap_class_index[i]],
|
1644
|
-
**{k: round(v[i], decimals) for k, v in per_class.items()},
|
1645
|
-
**scalars,
|
1646
|
-
}
|
1647
|
-
for i in range(len(per_class["box-p"]))
|
1648
|
-
]
|