ultralytics 8.3.153__py3-none-any.whl → 8.3.154__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. tests/test_python.py +1 -0
  2. ultralytics/__init__.py +1 -1
  3. ultralytics/cfg/__init__.py +2 -0
  4. ultralytics/engine/predictor.py +1 -1
  5. ultralytics/engine/validator.py +0 -6
  6. ultralytics/models/fastsam/val.py +0 -2
  7. ultralytics/models/rtdetr/val.py +28 -16
  8. ultralytics/models/yolo/classify/val.py +26 -23
  9. ultralytics/models/yolo/detect/train.py +4 -7
  10. ultralytics/models/yolo/detect/val.py +88 -90
  11. ultralytics/models/yolo/obb/val.py +52 -44
  12. ultralytics/models/yolo/pose/train.py +1 -35
  13. ultralytics/models/yolo/pose/val.py +77 -176
  14. ultralytics/models/yolo/segment/train.py +1 -41
  15. ultralytics/models/yolo/segment/val.py +64 -176
  16. ultralytics/models/yolo/yoloe/val.py +2 -1
  17. ultralytics/nn/autobackend.py +2 -2
  18. ultralytics/solutions/ai_gym.py +2 -3
  19. ultralytics/solutions/solutions.py +2 -0
  20. ultralytics/solutions/templates/similarity-search.html +31 -0
  21. ultralytics/utils/callbacks/comet.py +1 -1
  22. ultralytics/utils/metrics.py +146 -317
  23. ultralytics/utils/ops.py +4 -4
  24. ultralytics/utils/plotting.py +31 -56
  25. {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/METADATA +1 -1
  26. {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/RECORD +30 -30
  27. {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/WHEEL +0 -0
  28. {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/entry_points.txt +0 -0
  29. {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/licenses/LICENSE +0 -0
  30. {ultralytics-8.3.153.dist-info → ultralytics-8.3.154.dist-info}/top_level.txt +0 -0
ultralytics/utils/metrics.py
@@ -4,7 +4,7 @@
 import math
 import warnings
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union

 import numpy as np
 import torch
@@ -316,28 +316,22 @@ class ConfusionMatrix(DataExportMixin):
     Attributes:
         task (str): The type of task, either 'detect' or 'classify'.
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
-        nc (int): The number of classes.
-        conf (float): The confidence threshold for detections.
-        iou_thres (float): The Intersection over Union threshold.
+        nc (int): The number of classes.
+        names (List[str]): The names of the classes, used as labels on the plot.
     """

-    def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, names: tuple = (), task: str = "detect"):
+    def __init__(self, names: List[str] = [], task: str = "detect"):
         """
         Initialize a ConfusionMatrix instance.

         Args:
-            nc (int): Number of classes.
-            conf (float, optional): Confidence threshold for detections.
-            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
-            names (tuple, optional): Names of classes, used as labels on the plot.
+            names (List[str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
         """
         self.task = task
-        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
-        self.nc = nc  # number of classes
-        self.names = list(names)  # name of classes
-        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
-        self.iou_thres = iou_thres
+        self.nc = len(names)  # number of classes
+        self.matrix = np.zeros((self.nc + 1, self.nc + 1)) if self.task == "detect" else np.zeros((self.nc, self.nc))
+        self.names = names  # name of classes

     def process_cls_preds(self, preds, targets):
         """
@@ -351,41 +345,45 @@ class ConfusionMatrix(DataExportMixin):
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1

-    def process_batch(self, detections, gt_bboxes, gt_cls):
+    def process_batch(
+        self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
+    ) -> None:
         """
         Update confusion matrix for object detection task.

         Args:
-            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
-                Each row should contain (x1, y1, x2, y2, conf, class)
-                or with an additional element `angle` when it's obb.
-            gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
-            gt_cls (Array[M]): The class labels.
+            detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
+                information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
+                Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
+            batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4] | Array[M, 5])
+                and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
+            conf (float, optional): Confidence threshold for detections.
+            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
+        conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
+        gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        no_pred = len(detections["cls"]) == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
-            if detections is not None:
-                detections = detections[detections[:, 4] > self.conf]
-                detection_classes = detections[:, 5].int().tolist()
+            if not no_pred:
+                detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
+                detection_classes = detections["cls"].int().tolist()
                 for dc in detection_classes:
                     self.matrix[dc, self.nc] += 1  # false positives
             return
-        if detections is None:
+        if no_pred:
             gt_classes = gt_cls.int().tolist()
             for gc in gt_classes:
                 self.matrix[self.nc, gc] += 1  # background FN
             return

-        detections = detections[detections[:, 4] > self.conf]
+        detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
         gt_classes = gt_cls.int().tolist()
-        detection_classes = detections[:, 5].int().tolist()
-        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
-        iou = (
-            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-            if is_obb
-            else box_iou(gt_bboxes, detections[:, :4])
-        )
-
-        x = torch.where(iou > self.iou_thres)
+        detection_classes = detections["cls"].int().tolist()
+        bboxes = detections["bboxes"]
+        is_obb = bboxes.shape[1] == 5  # check if detections contains angle for OBB
+        iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
+
+        x = torch.where(iou > iou_thres)
         if x[0].shape[0]:
             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
             if x[0].shape[0] > 1:
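In practice the reworked API looks like the sketch below (not part of the diff): nc, conf, and iou_thres leave the constructor, the class count is derived from names, and both predictions and ground truth travel as dictionaries. The tensor values are illustrative only.

    import torch
    from ultralytics.utils.metrics import ConfusionMatrix

    cm = ConfusionMatrix(names=["person", "car"], task="detect")  # nc is now derived from names
    detections = {  # predictions as a dict rather than an [N, 6] tensor
        "bboxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),  # xyxy; an [N, 5] xywhr tensor would take the OBB path
        "conf": torch.tensor([0.9]),
        "cls": torch.tensor([0.0]),
    }
    batch = {  # ground truth now arrives via the batch dict
        "bboxes": torch.tensor([[12.0, 11.0, 49.0, 52.0]]),
        "cls": torch.tensor([0.0]),
    }
    cm.process_batch(detections, batch, conf=0.25, iou_thres=0.45)  # thresholds are now per-call arguments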
@@ -949,53 +947,76 @@ class DetMetrics(SimpleClass, DataExportMixin):
     Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).

     Attributes:
-        save_dir (Path): A path to the directory where the output plots will be saved.
-        plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
         names (Dict[int, str]): A dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'detect'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
     """

-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
         Initialize a DetMetrics instance with a save directory, plot flag, and class names.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
         self.names = names
         self.box = Metric()
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "detect"
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.nt_per_class = None
+        self.nt_per_image = None

-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
+    def update_stats(self, stat: Dict[str, Any]) -> None:
+        """
+        Update statistics by appending new values to existing stat collections.
+
+        Args:
+            stat (Dict[str, Any]): Dictionary containing new statistical values to append.
+                Keys should match existing keys in self.stats.
+        """
+        for k in self.stats.keys():
+            self.stats[k].append(stat[k])
+
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process predicted results for object detection and update metrics.

         Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()}  # to numpy
+        if len(stats) == 0:
+            return stats
         results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
+            stats["tp"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
+            save_dir=save_dir,
             names=self.names,
             on_plot=on_plot,
         )[2:]
         self.box.nc = len(self.names)
         self.box.update(results)
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
+        return stats
+
+    def clear_stats(self):
+        """Clear the stored statistics."""
+        for v in self.stats.values():
+            v.clear()

     @property
     def keys(self) -> List[str]:
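The intended validator-side lifecycle, as a minimal sketch (placeholder arrays, not part of the diff): each batch contributes one entry per stats key via update_stats, and process concatenates everything, runs ap_per_class, and fills nt_per_class and nt_per_image.

    import numpy as np
    from ultralytics.utils.metrics import DetMetrics

    dm = DetMetrics(names={0: "person", 1: "car"})
    for _ in range(2):  # one call per validation batch
        dm.update_stats(
            {
                "tp": np.zeros((3, 10), dtype=bool),  # [n_preds, n_iou_thresholds], placeholder values
                "conf": np.array([0.9, 0.8, 0.7]),
                "pred_cls": np.array([0, 1, 0]),
                "target_cls": np.array([0, 1]),
                "target_img": np.array([0, 1]),  # classes present in the image
            }
        )
    stats = dm.process()  # concatenates the lists, updates self.box, nt_per_class, nt_per_image
    dm.clear_stats()  # empty the accumulators for the next run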
@@ -1077,92 +1098,65 @@ class DetMetrics(SimpleClass, DataExportMixin):
         ]


-class SegmentMetrics(SimpleClass, DataExportMixin):
+class SegmentMetrics(DetMetrics):
     """
     Calculate and aggregate detection and segmentation metrics over a given set of classes.

     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and segmentation plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'segment'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
     """

-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
         Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        DetMetrics.__init__(self, names)
         self.seg = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "segment"
+        self.stats["tp_m"] = []  # add additional stats for masks

-    def process(
-        self,
-        tp: np.ndarray,
-        tp_m: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and segmentation metrics over the given set of predictions.

         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_m (np.ndarray): True positive array for masks.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, on_plot=on_plot)  # process box stats
         results_mask = ap_per_class(
-            tp_m,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_m"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Mask",
         )[2:]
         self.seg.nc = len(self.names)
         self.seg.update(results_mask)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats

     @property
     def keys(self) -> List[str]:
         """Return a list of keys for accessing metrics."""
-        return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+        return DetMetrics.keys.fget(self) + [
             "metrics/precision(M)",
             "metrics/recall(M)",
             "metrics/mAP50(M)",
@@ -1171,40 +1165,26 @@ class SegmentMetrics(SimpleClass, DataExportMixin):

     def mean_results(self) -> List[float]:
         """Return the mean metrics for bounding box and segmentation results."""
-        return self.box.mean_results() + self.seg.mean_results()
+        return DetMetrics.mean_results(self) + self.seg.mean_results()

     def class_result(self, i: int) -> List[float]:
         """Return classification results for a specified class index."""
-        return self.box.class_result(i) + self.seg.class_result(i)
+        return DetMetrics.class_result(self, i) + self.seg.class_result(i)

     @property
     def maps(self) -> np.ndarray:
         """Return mAP scores for object detection and semantic segmentation models."""
-        return self.box.maps + self.seg.maps
+        return DetMetrics.maps.fget(self) + self.seg.maps

     @property
     def fitness(self) -> float:
         """Return the fitness score for both segmentation and bounding box models."""
-        return self.seg.fitness() + self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the class indices (boxes and masks have the same ap_class_index)."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return results of object detection model for evaluation."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+        return self.seg.fitness() + DetMetrics.fitness.fget(self)

     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
-            "Precision-Recall(B)",
-            "F1-Confidence(B)",
-            "Precision-Confidence(B)",
-            "Recall-Confidence(B)",
+        return DetMetrics.curves.fget(self) + [
             "Precision-Recall(M)",
             "F1-Confidence(M)",
             "Precision-Confidence(M)",
@@ -1214,7 +1194,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.seg.curves_results
+        return DetMetrics.curves_results.fget(self) + self.seg.curves_results

     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1234,43 +1214,34 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
         >>> print(seg_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "mask-map": round(self.seg.map, decimals),
             "mask-map50": round(self.seg.map50, decimals),
             "mask-map75": round(self.seg.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
             "mask-p": self.seg.p,
             "mask-r": self.seg.r,
             "mask-f1": self.seg.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary


-class PoseMetrics(SegmentMetrics):
+class PoseMetrics(DetMetrics):
     """
     Calculate and aggregate detection and pose metrics over a given set of classes.

     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and pose plots.
         names (Dict[int, str]): Dictionary of class names.
         pose (Metric): An instance of the Metric class to calculate pose metrics.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'pose'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.

     Methods:
         process(tp_m, tp_b, conf, pred_cls, target_cls): Process metrics over the given set of predictions.
@@ -1282,79 +1253,50 @@ class PoseMetrics(SegmentMetrics):
         results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
     """

-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
         Initialize the PoseMetrics class with directory path, class names, and plotting options.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        super().__init__(save_dir, plot, names)
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        super().__init__(names)
         self.pose = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "pose"
+        self.stats["tp_p"] = []  # add additional stats for pose

-    def process(
-        self,
-        tp: np.ndarray,
-        tp_p: np.ndarray,
-        conf: np.ndarray,
-        pred_cls: np.ndarray,
-        target_cls: np.ndarray,
-        on_plot=None,
-    ):
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
         """
         Process the detection and pose metrics over the given set of predictions.

         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_p (np.ndarray): True positive array for keypoints.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
             on_plot (callable, optional): Function to call after plots are generated.
+
+        Returns:
+            (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, on_plot=on_plot)  # process box stats
         results_pose = ap_per_class(
-            tp_p,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_p"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Pose",
         )[2:]
         self.pose.nc = len(self.names)
         self.pose.update(results_pose)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats

     @property
     def keys(self) -> List[str]:
         """Return list of evaluation metric keys."""
-        return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+        return DetMetrics.keys.fget(self) + [
             "metrics/precision(P)",
             "metrics/recall(P)",
             "metrics/mAP50(P)",
@@ -1363,26 +1305,26 @@ class PoseMetrics(SegmentMetrics):

     def mean_results(self) -> List[float]:
         """Return the mean results of box and pose."""
-        return self.box.mean_results() + self.pose.mean_results()
+        return DetMetrics.mean_results(self) + self.pose.mean_results()

     def class_result(self, i: int) -> List[float]:
         """Return the class-wise detection results for a specific class i."""
-        return self.box.class_result(i) + self.pose.class_result(i)
+        return DetMetrics.class_result(self, i) + self.pose.class_result(i)

     @property
     def maps(self) -> np.ndarray:
         """Return the mean average precision (mAP) per class for both box and pose detections."""
-        return self.box.maps + self.pose.maps
+        return DetMetrics.maps.fget(self) + self.pose.maps

     @property
     def fitness(self) -> float:
         """Return combined fitness score for pose and box detection."""
-        return self.pose.fitness() + self.box.fitness()
+        return self.pose.fitness() + DetMetrics.fitness.fget(self)

     @property
     def curves(self) -> List[str]:
         """Return a list of curves for accessing specific metrics curves."""
-        return [
+        return DetMetrics.curves.fget(self) + [
             "Precision-Recall(B)",
             "F1-Confidence(B)",
             "Precision-Confidence(B)",
@@ -1396,7 +1338,7 @@ class PoseMetrics(SegmentMetrics):
     @property
     def curves_results(self) -> List[List]:
         """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.pose.curves_results
+        return DetMetrics.curves_results.fget(self) + self.pose.curves_results

     def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
         """
@@ -1416,29 +1358,19 @@ class PoseMetrics(SegmentMetrics):
         >>> print(pose_summary)
         """
         scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
             "pose-map": round(self.pose.map, decimals),
             "pose-map50": round(self.pose.map50, decimals),
             "pose-map75": round(self.pose.map75, decimals),
         }
         per_class = {
-            "box-p": self.box.p,
-            "box-r": self.box.r,
-            "box-f1": self.box.f1,
             "pose-p": self.pose.p,
             "pose-r": self.pose.r,
             "pose-f1": self.pose.f1,
         }
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
+        return summary


 class ClassifyMetrics(SimpleClass, DataExportMixin):
@@ -1516,133 +1448,30 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
         return [{"classify-top1": round(self.top1, decimals), "classify-top5": round(self.top5, decimals)}]


-class OBBMetrics(SimpleClass, DataExportMixin):
+class OBBMetrics(DetMetrics):
     """
     Metrics for evaluating oriented bounding box (OBB) detection.

     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection plots.
         names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
         speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'obb'.
+        stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.

     References:
         https://arxiv.org/pdf/2106.06072.pdf
     """

-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
+    def __init__(self, names: Dict[int, str] = {}) -> None:
         """
         Initialize an OBBMetrics instance with directory, plotting, and class names.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
             names (Dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        DetMetrics.__init__(self, names)
+        # TODO: probably remove task as well
         self.task = "obb"
-
-    def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
-        """
-        Process predicted results for object detection and update metrics.
-
-        Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
-        """
-        results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            on_plot=on_plot,
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results)
-
-    @property
-    def keys(self) -> List[str]:
-        """Return a list of keys for accessing specific metrics."""
-        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
-
-    def mean_results(self) -> List[float]:
-        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
-        return self.box.mean_results()
-
-    def class_result(self, i: int) -> Tuple[float, float, float, float]:
-        """Return the result of evaluating the performance of an object detection model on a specific class."""
-        return self.box.class_result(i)
-
-    @property
-    def maps(self) -> np.ndarray:
-        """Return mean Average Precision (mAP) scores per class."""
-        return self.box.maps
-
-    @property
-    def fitness(self) -> float:
-        """Return the fitness of box object."""
-        return self.box.fitness()
-
-    @property
-    def ap_class_index(self) -> List:
-        """Return the average precision index per class."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self) -> Dict[str, float]:
-        """Return dictionary of computed performance metrics and statistics."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    @property
-    def curves_results(self) -> List:
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
-        """
-        Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
-        scalar metrics (mAP, mAP50, mAP75) along with precision, recall, and F1-score for each class.
-
-        Args:
-            normalize (bool): For OBB metrics, everything is normalized by default [0-1].
-            decimals (int): Number of decimal places to round the metrics values to.
-
-        Returns:
-            (List[Dict[str, Union[str, float]]]): A list of dictionaries, each representing one class with detection metrics.
-
-        Examples:
-            >>> results = model.val(data="dota8.yaml")
-            >>> detection_summary = results.summary(decimals=4)
-            >>> print(detection_summary)
-        """
-        scalars = {
-            "box-map": round(self.box.map, decimals),
-            "box-map50": round(self.box.map50, decimals),
-            "box-map75": round(self.box.map75, decimals),
-        }
-        per_class = {"box-p": self.box.p, "box-r": self.box.r, "box-f1": self.box.f1}
-        return [
-            {
-                "class_name": self.names[self.ap_class_index[i]],
-                **{k: round(v[i], decimals) for k, v in per_class.items()},
-                **scalars,
-            }
-            for i in range(len(per_class["box-p"]))
-        ]
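Taken together, the refactor gives SegmentMetrics, PoseMetrics, and OBBMetrics the same accumulate-then-reduce lifecycle as DetMetrics, with task-specific true-positive arrays carried under extra stats keys ('tp_m', 'tp_p'). A hedged sketch with placeholder arrays, not part of the diff:

    import numpy as np
    from ultralytics.utils.metrics import SegmentMetrics

    sm = SegmentMetrics(names={0: "person"})
    sm.update_stats(
        {
            "tp": np.zeros((2, 10), dtype=bool),  # box true positives
            "tp_m": np.zeros((2, 10), dtype=bool),  # mask true positives, key added in __init__
            "conf": np.array([0.9, 0.6]),
            "pred_cls": np.array([0, 0]),
            "target_cls": np.array([0]),
            "target_img": np.array([0]),
        }
    )
    stats = sm.process()  # box AP via DetMetrics.process, then mask AP with prefix="Mask"
    print(sm.keys)  # the box (B) keys from DetMetrics plus the mask (M) keys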