dgenerate-ultralytics-headless 8.3.152__py3-none-any.whl → 8.3.154__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/RECORD +30 -30
  3. tests/test_python.py +1 -0
  4. ultralytics/__init__.py +1 -1
  5. ultralytics/cfg/__init__.py +2 -0
  6. ultralytics/engine/predictor.py +1 -1
  7. ultralytics/engine/validator.py +0 -6
  8. ultralytics/models/fastsam/val.py +0 -2
  9. ultralytics/models/rtdetr/val.py +28 -16
  10. ultralytics/models/yolo/classify/val.py +26 -23
  11. ultralytics/models/yolo/detect/train.py +4 -7
  12. ultralytics/models/yolo/detect/val.py +88 -90
  13. ultralytics/models/yolo/obb/val.py +52 -44
  14. ultralytics/models/yolo/pose/train.py +1 -35
  15. ultralytics/models/yolo/pose/val.py +77 -176
  16. ultralytics/models/yolo/segment/train.py +1 -41
  17. ultralytics/models/yolo/segment/val.py +64 -176
  18. ultralytics/models/yolo/yoloe/val.py +2 -1
  19. ultralytics/nn/autobackend.py +2 -2
  20. ultralytics/solutions/ai_gym.py +2 -3
  21. ultralytics/solutions/solutions.py +2 -0
  22. ultralytics/solutions/templates/similarity-search.html +31 -0
  23. ultralytics/utils/callbacks/comet.py +1 -1
  24. ultralytics/utils/metrics.py +152 -307
  25. ultralytics/utils/ops.py +4 -4
  26. ultralytics/utils/plotting.py +31 -56
  27. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/WHEEL +0 -0
  28. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/entry_points.txt +0 -0
  29. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/licenses/LICENSE +0 -0
  30. {dgenerate_ultralytics_headless-8.3.152.dist-info → dgenerate_ultralytics_headless-8.3.154.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@
4
4
  import math
5
5
  import warnings
6
6
  from pathlib import Path
7
- from typing import Dict, List, Tuple, Union
7
+ from typing import Any, Dict, List, Tuple, Union
8
8
 
9
9
  import numpy as np
10
10
  import torch
@@ -316,28 +316,22 @@ class ConfusionMatrix(DataExportMixin):
316
316
  Attributes:
317
317
  task (str): The type of task, either 'detect' or 'classify'.
318
318
  matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
319
- nc (int): The number of classes.
320
- conf (float): The confidence threshold for detections.
321
- iou_thres (float): The Intersection over Union threshold.
319
+ nc (int): The number of classes.
320
+ names (List[str]): The names of the classes, used as labels on the plot.
322
321
  """
323
322
 
324
- def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, names: tuple = (), task: str = "detect"):
323
+ def __init__(self, names: List[str] = [], task: str = "detect"):
325
324
  """
326
325
  Initialize a ConfusionMatrix instance.
327
326
 
328
327
  Args:
329
- nc (int): Number of classes.
330
- conf (float, optional): Confidence threshold for detections.
331
- iou_thres (float, optional): IoU threshold for matching detections to ground truth.
332
- names (tuple, optional): Names of classes, used as labels on the plot.
328
+ names (List[str], optional): Names of classes, used as labels on the plot.
333
329
  task (str, optional): Type of task, either 'detect' or 'classify'.
334
330
  """
335
331
  self.task = task
336
- self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
337
- self.nc = nc # number of classes
338
- self.names = list(names) # name of classes
339
- self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
340
- self.iou_thres = iou_thres
332
+ self.nc = len(names) # number of classes
333
+ self.matrix = np.zeros((self.nc + 1, self.nc + 1)) if self.task == "detect" else np.zeros((self.nc, self.nc))
334
+ self.names = names # name of classes
341
335
 
342
336
  def process_cls_preds(self, preds, targets):
343
337
  """
@@ -351,41 +345,45 @@ class ConfusionMatrix(DataExportMixin):
351
345
  for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
352
346
  self.matrix[p][t] += 1
353
347
 
354
- def process_batch(self, detections, gt_bboxes, gt_cls):
348
+ def process_batch(
349
+ self, detections: Dict[str, torch.Tensor], batch: Dict[str, Any], conf: float = 0.25, iou_thres: float = 0.45
350
+ ) -> None:
355
351
  """
356
352
  Update confusion matrix for object detection task.
357
353
 
358
354
  Args:
359
- detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
360
- Each row should contain (x1, y1, x2, y2, conf, class)
361
- or with an additional element `angle` when it's obb.
362
- gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
363
- gt_cls (Array[M]): The class labels.
355
+ detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
356
+ Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
357
+ Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
358
+ batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and
359
+ 'cls' (Array[M]) keys, where M is the number of ground truth objects.
360
+ conf (float, optional): Confidence threshold for detections.
361
+ iou_thres (float, optional): IoU threshold for matching detections to ground truth.
364
362
  """
363
+ conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
364
+ gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
365
+ no_pred = len(detections["cls"]) == 0
365
366
  if gt_cls.shape[0] == 0: # Check if labels is empty
366
- if detections is not None:
367
- detections = detections[detections[:, 4] > self.conf]
368
- detection_classes = detections[:, 5].int().tolist()
367
+ if not no_pred:
368
+ detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
369
+ detection_classes = detections["cls"].int().tolist()
369
370
  for dc in detection_classes:
370
371
  self.matrix[dc, self.nc] += 1 # false positives
371
372
  return
372
- if detections is None:
373
+ if no_pred:
373
374
  gt_classes = gt_cls.int().tolist()
374
375
  for gc in gt_classes:
375
376
  self.matrix[self.nc, gc] += 1 # background FN
376
377
  return
377
378
 
378
- detections = detections[detections[:, 4] > self.conf]
379
+ detections = {k: detections[k][detections["conf"] > conf] for k in {"cls", "bboxes"}}
379
380
  gt_classes = gt_cls.int().tolist()
380
- detection_classes = detections[:, 5].int().tolist()
381
- is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension
382
- iou = (
383
- batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
384
- if is_obb
385
- else box_iou(gt_bboxes, detections[:, :4])
386
- )
387
-
388
- x = torch.where(iou > self.iou_thres)
381
+ detection_classes = detections["cls"].int().tolist()
382
+ bboxes = detections["bboxes"]
383
+ is_obb = bboxes.shape[1] == 5 # check if detections contains angle for OBB
384
+ iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
385
+
386
+ x = torch.where(iou > iou_thres)
389
387
  if x[0].shape[0]:
390
388
  matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
391
389
  if x[0].shape[0] > 1:
@@ -949,53 +947,76 @@ class DetMetrics(SimpleClass, DataExportMixin):
949
947
  Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
950
948
 
951
949
  Attributes:
952
- save_dir (Path): A path to the directory where the output plots will be saved.
953
- plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
954
950
  names (Dict[int, str]): A dictionary of class names.
955
951
  box (Metric): An instance of the Metric class for storing detection results.
956
952
  speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
957
953
  task (str): The task type, set to 'detect'.
954
+ stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
955
+ nt_per_class: Number of targets per class.
956
+ nt_per_image: Number of targets per image.
958
957
  """
959
958
 
960
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
959
+ def __init__(self, names: Dict[int, str] = {}) -> None:
961
960
  """
962
961
  Initialize a DetMetrics instance with class names.
963
962
 
964
963
  Args:
965
- save_dir (Path, optional): Directory to save plots.
966
- plot (bool, optional): Whether to plot precision-recall curves.
967
964
  names (Dict[int, str], optional): Dictionary of class names.
968
965
  """
969
- self.save_dir = save_dir
970
- self.plot = plot
971
966
  self.names = names
972
967
  self.box = Metric()
973
968
  self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
974
969
  self.task = "detect"
970
+ self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
971
+ self.nt_per_class = None
972
+ self.nt_per_image = None
973
+
974
+ def update_stats(self, stat: Dict[str, Any]) -> None:
975
+ """
976
+ Update statistics by appending new values to existing stat collections.
977
+
978
+ Args:
979
+ stat (Dict[str, Any]): Dictionary containing new statistical values to append.
980
+ Keys should match existing keys in self.stats.
981
+ """
982
+ for k in self.stats.keys():
983
+ self.stats[k].append(stat[k])
975
984
 
976
- def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
985
+ def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
977
986
  """
978
987
  Process predicted results for object detection and update metrics.
979
988
 
980
989
  Args:
981
- tp (np.ndarray): True positive array.
982
- conf (np.ndarray): Confidence array.
983
- pred_cls (np.ndarray): Predicted class indices array.
984
- target_cls (np.ndarray): Target class indices array.
985
- on_plot (callable, optional): Function to call after plots are generated.
990
+ save_dir (Path): Directory to save plots. Defaults to Path(".").
991
+ plot (bool): Whether to plot precision-recall curves. Defaults to False.
992
+ on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
993
+
994
+ Returns:
995
+ (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
986
996
  """
997
+ stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()} # to numpy
998
+ if len(stats) == 0:
999
+ return stats
987
1000
  results = ap_per_class(
988
- tp,
989
- conf,
990
- pred_cls,
991
- target_cls,
992
- plot=self.plot,
993
- save_dir=self.save_dir,
1001
+ stats["tp"],
1002
+ stats["conf"],
1003
+ stats["pred_cls"],
1004
+ stats["target_cls"],
1005
+ plot=plot,
1006
+ save_dir=save_dir,
994
1007
  names=self.names,
995
1008
  on_plot=on_plot,
996
1009
  )[2:]
997
1010
  self.box.nc = len(self.names)
998
1011
  self.box.update(results)
1012
+ self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
1013
+ self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
1014
+ return stats
1015
+
1016
+ def clear_stats(self):
1017
+ """Clear the stored statistics."""
1018
+ for v in self.stats.values():
1019
+ v.clear()
999
1020
 
1000
1021
  @property
1001
1022
  def keys(self) -> List[str]:
@@ -1068,97 +1089,74 @@ class DetMetrics(SimpleClass, DataExportMixin):
1068
1089
  "box-f1": self.box.f1,
1069
1090
  }
1070
1091
  return [
1071
- {"class_name": self.names[i], **{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars}
1072
- for i in range(len(next(iter(per_class.values()), [])))
1092
+ {
1093
+ "class_name": self.names[self.ap_class_index[i]],
1094
+ **{k: round(v[i], decimals) for k, v in per_class.items()},
1095
+ **scalars,
1096
+ }
1097
+ for i in range(len(per_class["box-p"]))
1073
1098
  ]
1074
1099
 
1075
1100
 
1076
- class SegmentMetrics(SimpleClass, DataExportMixin):
1101
+ class SegmentMetrics(DetMetrics):
1077
1102
  """
1078
1103
  Calculate and aggregate detection and segmentation metrics over a given set of classes.
1079
1104
 
1080
1105
  Attributes:
1081
- save_dir (Path): Path to the directory where the output plots should be saved.
1082
- plot (bool): Whether to save the detection and segmentation plots.
1083
1106
  names (Dict[int, str]): Dictionary of class names.
1084
1107
  box (Metric): An instance of the Metric class for storing detection results.
1085
1108
  seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
1086
1109
  speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1087
1110
  task (str): The task type, set to 'segment'.
1111
+ stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1112
+ nt_per_class: Number of targets per class.
1113
+ nt_per_image: Number of targets per image.
1088
1114
  """
1089
1115
 
1090
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1116
+ def __init__(self, names: Dict[int, str] = {}) -> None:
1091
1117
  """
1092
1118
  Initialize a SegmentMetrics instance with class names.
1093
1119
 
1094
1120
  Args:
1095
- save_dir (Path, optional): Directory to save plots.
1096
- plot (bool, optional): Whether to plot precision-recall curves.
1097
1121
  names (Dict[int, str], optional): Dictionary of class names.
1098
1122
  """
1099
- self.save_dir = save_dir
1100
- self.plot = plot
1101
- self.names = names
1102
- self.box = Metric()
1123
+ DetMetrics.__init__(self, names)
1103
1124
  self.seg = Metric()
1104
- self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
1105
1125
  self.task = "segment"
1126
+ self.stats["tp_m"] = [] # add additional stats for masks
1106
1127
 
1107
- def process(
1108
- self,
1109
- tp: np.ndarray,
1110
- tp_m: np.ndarray,
1111
- conf: np.ndarray,
1112
- pred_cls: np.ndarray,
1113
- target_cls: np.ndarray,
1114
- on_plot=None,
1115
- ):
1128
+ def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
1116
1129
  """
1117
1130
  Process the detection and segmentation metrics over the given set of predictions.
1118
1131
 
1119
1132
  Args:
1120
- tp (np.ndarray): True positive array for boxes.
1121
- tp_m (np.ndarray): True positive array for masks.
1122
- conf (np.ndarray): Confidence array.
1123
- pred_cls (np.ndarray): Predicted class indices array.
1124
- target_cls (np.ndarray): Target class indices array.
1125
- on_plot (callable, optional): Function to call after plots are generated.
1133
+ save_dir (Path): Directory to save plots. Defaults to Path(".").
1134
+ plot (bool): Whether to plot precision-recall curves. Defaults to False.
1135
+ on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
1136
+
1137
+ Returns:
1138
+ (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
1126
1139
  """
1140
+ stats = DetMetrics.process(self, on_plot=on_plot) # process box stats
1127
1141
  results_mask = ap_per_class(
1128
- tp_m,
1129
- conf,
1130
- pred_cls,
1131
- target_cls,
1132
- plot=self.plot,
1142
+ stats["tp_m"],
1143
+ stats["conf"],
1144
+ stats["pred_cls"],
1145
+ stats["target_cls"],
1146
+ plot=plot,
1133
1147
  on_plot=on_plot,
1134
- save_dir=self.save_dir,
1148
+ save_dir=save_dir,
1135
1149
  names=self.names,
1136
1150
  prefix="Mask",
1137
1151
  )[2:]
1138
1152
  self.seg.nc = len(self.names)
1139
1153
  self.seg.update(results_mask)
1140
- results_box = ap_per_class(
1141
- tp,
1142
- conf,
1143
- pred_cls,
1144
- target_cls,
1145
- plot=self.plot,
1146
- on_plot=on_plot,
1147
- save_dir=self.save_dir,
1148
- names=self.names,
1149
- prefix="Box",
1150
- )[2:]
1151
- self.box.nc = len(self.names)
1152
- self.box.update(results_box)
1154
+ return stats
1153
1155
 
1154
1156
  @property
1155
1157
  def keys(self) -> List[str]:
1156
1158
  """Return a list of keys for accessing metrics."""
1157
- return [
1158
- "metrics/precision(B)",
1159
- "metrics/recall(B)",
1160
- "metrics/mAP50(B)",
1161
- "metrics/mAP50-95(B)",
1159
+ return DetMetrics.keys.fget(self) + [
1162
1160
  "metrics/precision(M)",
1163
1161
  "metrics/recall(M)",
1164
1162
  "metrics/mAP50(M)",
@@ -1167,40 +1165,26 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
1167
1165
 
1168
1166
  def mean_results(self) -> List[float]:
1169
1167
  """Return the mean metrics for bounding box and segmentation results."""
1170
- return self.box.mean_results() + self.seg.mean_results()
1168
+ return DetMetrics.mean_results(self) + self.seg.mean_results()
1171
1169
 
1172
1170
  def class_result(self, i: int) -> List[float]:
1173
1171
  """Return classification results for a specified class index."""
1174
- return self.box.class_result(i) + self.seg.class_result(i)
1172
+ return DetMetrics.class_result(self, i) + self.seg.class_result(i)
1175
1173
 
1176
1174
  @property
1177
1175
  def maps(self) -> np.ndarray:
1178
1176
  """Return mAP scores for object detection and semantic segmentation models."""
1179
- return self.box.maps + self.seg.maps
1177
+ return DetMetrics.maps.fget(self) + self.seg.maps
1180
1178
 
1181
1179
  @property
1182
1180
  def fitness(self) -> float:
1183
1181
  """Return the fitness score for both segmentation and bounding box models."""
1184
- return self.seg.fitness() + self.box.fitness()
1185
-
1186
- @property
1187
- def ap_class_index(self) -> List:
1188
- """Return the class indices (boxes and masks have the same ap_class_index)."""
1189
- return self.box.ap_class_index
1190
-
1191
- @property
1192
- def results_dict(self) -> Dict[str, float]:
1193
- """Return results of object detection model for evaluation."""
1194
- return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
1182
+ return self.seg.fitness() + DetMetrics.fitness.fget(self)
1195
1183
 
1196
1184
  @property
1197
1185
  def curves(self) -> List[str]:
1198
1186
  """Return a list of curves for accessing specific metrics curves."""
1199
- return [
1200
- "Precision-Recall(B)",
1201
- "F1-Confidence(B)",
1202
- "Precision-Confidence(B)",
1203
- "Recall-Confidence(B)",
1187
+ return DetMetrics.curves.fget(self) + [
1204
1188
  "Precision-Recall(M)",
1205
1189
  "F1-Confidence(M)",
1206
1190
  "Precision-Confidence(M)",
@@ -1210,7 +1194,7 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
1210
1194
  @property
1211
1195
  def curves_results(self) -> List[List]:
1212
1196
  """Return the list of computed curve results for box and segmentation metrics."""
1213
- return self.box.curves_results + self.seg.curves_results
1197
+ return DetMetrics.curves_results.fget(self) + self.seg.curves_results
1214
1198
 
1215
1199
  def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
1216
1200
  """
@@ -1230,39 +1214,34 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
1230
1214
  >>> print(seg_summary)
1231
1215
  """
1232
1216
  scalars = {
1233
- "box-map": round(self.box.map, decimals),
1234
- "box-map50": round(self.box.map50, decimals),
1235
- "box-map75": round(self.box.map75, decimals),
1236
1217
  "mask-map": round(self.seg.map, decimals),
1237
1218
  "mask-map50": round(self.seg.map50, decimals),
1238
1219
  "mask-map75": round(self.seg.map75, decimals),
1239
1220
  }
1240
1221
  per_class = {
1241
- "box-p": self.box.p,
1242
- "box-r": self.box.r,
1243
- "box-f1": self.box.f1,
1244
1222
  "mask-p": self.seg.p,
1245
1223
  "mask-r": self.seg.r,
1246
1224
  "mask-f1": self.seg.f1,
1247
1225
  }
1248
- return [
1249
- {"class_name": self.names[i], **{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars}
1250
- for i in range(len(next(iter(per_class.values()), [])))
1251
- ]
1226
+ summary = DetMetrics.summary(self, normalize, decimals) # get box summary
1227
+ for i, s in enumerate(summary):
1228
+ s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
1229
+ return summary
1252
1230
 
1253
1231
 
1254
- class PoseMetrics(SegmentMetrics):
1232
+ class PoseMetrics(DetMetrics):
1255
1233
  """
1256
1234
  Calculate and aggregate detection and pose metrics over a given set of classes.
1257
1235
 
1258
1236
  Attributes:
1259
- save_dir (Path): Path to the directory where the output plots should be saved.
1260
- plot (bool): Whether to save the detection and pose plots.
1261
1237
  names (Dict[int, str]): Dictionary of class names.
1262
1238
  pose (Metric): An instance of the Metric class to calculate pose metrics.
1263
1239
  box (Metric): An instance of the Metric class for storing detection results.
1264
1240
  speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1265
1241
  task (str): The task type, set to 'pose'.
1242
+ stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1243
+ nt_per_class: Number of targets per class.
1244
+ nt_per_image: Number of targets per image.
1266
1245
 
1267
1246
  Methods:
1268
1247
  process(save_dir, plot, on_plot): Process metrics over the accumulated statistics.
@@ -1274,79 +1253,50 @@ class PoseMetrics(SegmentMetrics):
1274
1253
  results_dict: Return the dictionary containing all the detection and pose metrics and fitness score.
1275
1254
  """
1276
1255
 
1277
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1256
+ def __init__(self, names: Dict[int, str] = {}) -> None:
1278
1257
  """
1279
1258
  Initialize the PoseMetrics class with class names.
1280
1259
 
1281
1260
  Args:
1282
- save_dir (Path, optional): Directory to save plots.
1283
- plot (bool, optional): Whether to plot precision-recall curves.
1284
1261
  names (Dict[int, str], optional): Dictionary of class names.
1285
1262
  """
1286
- super().__init__(save_dir, plot, names)
1287
- self.save_dir = save_dir
1288
- self.plot = plot
1289
- self.names = names
1290
- self.box = Metric()
1263
+ super().__init__(names)
1291
1264
  self.pose = Metric()
1292
- self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
1293
1265
  self.task = "pose"
1266
+ self.stats["tp_p"] = [] # add additional stats for pose
1294
1267
 
1295
- def process(
1296
- self,
1297
- tp: np.ndarray,
1298
- tp_p: np.ndarray,
1299
- conf: np.ndarray,
1300
- pred_cls: np.ndarray,
1301
- target_cls: np.ndarray,
1302
- on_plot=None,
1303
- ):
1268
+ def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]:
1304
1269
  """
1305
1270
  Process the detection and pose metrics over the given set of predictions.
1306
1271
 
1307
1272
  Args:
1308
- tp (np.ndarray): True positive array for boxes.
1309
- tp_p (np.ndarray): True positive array for keypoints.
1310
- conf (np.ndarray): Confidence array.
1311
- pred_cls (np.ndarray): Predicted class indices array.
1312
- target_cls (np.ndarray): Target class indices array.
1273
+ save_dir (Path): Directory to save plots. Defaults to Path(".").
1274
+ plot (bool): Whether to plot precision-recall curves. Defaults to False.
1313
1275
  on_plot (callable, optional): Function to call after plots are generated.
1276
+
1277
+ Returns:
1278
+ (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
1314
1279
  """
1280
+ stats = DetMetrics.process(self, on_plot=on_plot) # process box stats
1315
1281
  results_pose = ap_per_class(
1316
- tp_p,
1317
- conf,
1318
- pred_cls,
1319
- target_cls,
1320
- plot=self.plot,
1282
+ stats["tp_p"],
1283
+ stats["conf"],
1284
+ stats["pred_cls"],
1285
+ stats["target_cls"],
1286
+ plot=plot,
1321
1287
  on_plot=on_plot,
1322
- save_dir=self.save_dir,
1288
+ save_dir=save_dir,
1323
1289
  names=self.names,
1324
1290
  prefix="Pose",
1325
1291
  )[2:]
1326
1292
  self.pose.nc = len(self.names)
1327
1293
  self.pose.update(results_pose)
1328
- results_box = ap_per_class(
1329
- tp,
1330
- conf,
1331
- pred_cls,
1332
- target_cls,
1333
- plot=self.plot,
1334
- on_plot=on_plot,
1335
- save_dir=self.save_dir,
1336
- names=self.names,
1337
- prefix="Box",
1338
- )[2:]
1339
- self.box.nc = len(self.names)
1340
- self.box.update(results_box)
1294
+ return stats
1341
1295
 
1342
1296
  @property
1343
1297
  def keys(self) -> List[str]:
1344
1298
  """Return list of evaluation metric keys."""
1345
- return [
1346
- "metrics/precision(B)",
1347
- "metrics/recall(B)",
1348
- "metrics/mAP50(B)",
1349
- "metrics/mAP50-95(B)",
1299
+ return DetMetrics.keys.fget(self) + [
1350
1300
  "metrics/precision(P)",
1351
1301
  "metrics/recall(P)",
1352
1302
  "metrics/mAP50(P)",
@@ -1355,26 +1305,26 @@ class PoseMetrics(SegmentMetrics):
1355
1305
 
1356
1306
  def mean_results(self) -> List[float]:
1357
1307
  """Return the mean results of box and pose."""
1358
- return self.box.mean_results() + self.pose.mean_results()
1308
+ return DetMetrics.mean_results(self) + self.pose.mean_results()
1359
1309
 
1360
1310
  def class_result(self, i: int) -> List[float]:
1361
1311
  """Return the class-wise detection results for a specific class i."""
1362
- return self.box.class_result(i) + self.pose.class_result(i)
1312
+ return DetMetrics.class_result(self, i) + self.pose.class_result(i)
1363
1313
 
1364
1314
  @property
1365
1315
  def maps(self) -> np.ndarray:
1366
1316
  """Return the mean average precision (mAP) per class for both box and pose detections."""
1367
- return self.box.maps + self.pose.maps
1317
+ return DetMetrics.maps.fget(self) + self.pose.maps
1368
1318
 
1369
1319
  @property
1370
1320
  def fitness(self) -> float:
1371
1321
  """Return combined fitness score for pose and box detection."""
1372
- return self.pose.fitness() + self.box.fitness()
1322
+ return self.pose.fitness() + DetMetrics.fitness.fget(self)
1373
1323
 
1374
1324
  @property
1375
1325
  def curves(self) -> List[str]:
1376
1326
  """Return a list of curves for accessing specific metrics curves."""
1377
- return [
1327
+ return DetMetrics.curves.fget(self) + [
1378
1328
  "Precision-Recall(B)",
1379
1329
  "F1-Confidence(B)",
1380
1330
  "Precision-Confidence(B)",
@@ -1388,7 +1338,7 @@ class PoseMetrics(SegmentMetrics):
1388
1338
  @property
1389
1339
  def curves_results(self) -> List[List]:
1390
1340
  """Return the list of computed curve results for box and pose metrics."""
1391
- return self.box.curves_results + self.pose.curves_results
1341
+ return DetMetrics.curves_results.fget(self) + self.pose.curves_results
1392
1342
 
1393
1343
  def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
1394
1344
  """
@@ -1408,25 +1358,19 @@ class PoseMetrics(SegmentMetrics):
1408
1358
  >>> print(pose_summary)
1409
1359
  """
1410
1360
  scalars = {
1411
- "box-map": round(self.box.map, decimals),
1412
- "box-map50": round(self.box.map50, decimals),
1413
- "box-map75": round(self.box.map75, decimals),
1414
1361
  "pose-map": round(self.pose.map, decimals),
1415
1362
  "pose-map50": round(self.pose.map50, decimals),
1416
1363
  "pose-map75": round(self.pose.map75, decimals),
1417
1364
  }
1418
1365
  per_class = {
1419
- "box-p": self.box.p,
1420
- "box-r": self.box.r,
1421
- "box-f1": self.box.f1,
1422
1366
  "pose-p": self.pose.p,
1423
1367
  "pose-r": self.pose.r,
1424
1368
  "pose-f1": self.pose.f1,
1425
1369
  }
1426
- return [
1427
- {"class_name": self.names[i], **{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars}
1428
- for i in range(len(next(iter(per_class.values()), [])))
1429
- ]
1370
+ summary = DetMetrics.summary(self, normalize, decimals) # get box summary
1371
+ for i, s in enumerate(summary):
1372
+ s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars})
1373
+ return summary
1430
1374
 
1431
1375
 
1432
1376
  class ClassifyMetrics(SimpleClass, DataExportMixin):
@@ -1504,129 +1448,30 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1504
1448
  return [{"classify-top1": round(self.top1, decimals), "classify-top5": round(self.top5, decimals)}]
1505
1449
 
1506
1450
 
1507
- class OBBMetrics(SimpleClass, DataExportMixin):
1451
+ class OBBMetrics(DetMetrics):
1508
1452
  """
1509
1453
  Metrics for evaluating oriented bounding box (OBB) detection.
1510
1454
 
1511
1455
  Attributes:
1512
- save_dir (Path): Path to the directory where the output plots should be saved.
1513
- plot (bool): Whether to save the detection plots.
1514
1456
  names (Dict[int, str]): Dictionary of class names.
1515
1457
  box (Metric): An instance of the Metric class for storing detection results.
1516
1458
  speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1517
1459
  task (str): The task type, set to 'obb'.
1460
+ stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1461
+ nt_per_class: Number of targets per class.
1462
+ nt_per_image: Number of targets per image.
1518
1463
 
1519
1464
  References:
1520
1465
  https://arxiv.org/pdf/2106.06072.pdf
1521
1466
  """
1522
1467
 
1523
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1468
+ def __init__(self, names: Dict[int, str] = {}) -> None:
1524
1469
  """
1525
1470
  Initialize an OBBMetrics instance with class names.
1526
1471
 
1527
1472
  Args:
1528
- save_dir (Path, optional): Directory to save plots.
1529
- plot (bool, optional): Whether to plot precision-recall curves.
1530
1473
  names (Dict[int, str], optional): Dictionary of class names.
1531
1474
  """
1532
- self.save_dir = save_dir
1533
- self.plot = plot
1534
- self.names = names
1535
- self.box = Metric()
1536
- self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
1475
+ DetMetrics.__init__(self, names)
1476
+ # TODO: probably remove task as well
1537
1477
  self.task = "obb"
1538
-
1539
- def process(self, tp: np.ndarray, conf: np.ndarray, pred_cls: np.ndarray, target_cls: np.ndarray, on_plot=None):
1540
- """
1541
- Process predicted results for object detection and update metrics.
1542
-
1543
- Args:
1544
- tp (np.ndarray): True positive array.
1545
- conf (np.ndarray): Confidence array.
1546
- pred_cls (np.ndarray): Predicted class indices array.
1547
- target_cls (np.ndarray): Target class indices array.
1548
- on_plot (callable, optional): Function to call after plots are generated.
1549
- """
1550
- results = ap_per_class(
1551
- tp,
1552
- conf,
1553
- pred_cls,
1554
- target_cls,
1555
- plot=self.plot,
1556
- save_dir=self.save_dir,
1557
- names=self.names,
1558
- on_plot=on_plot,
1559
- )[2:]
1560
- self.box.nc = len(self.names)
1561
- self.box.update(results)
1562
-
1563
- @property
1564
- def keys(self) -> List[str]:
1565
- """Return a list of keys for accessing specific metrics."""
1566
- return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
1567
-
1568
- def mean_results(self) -> List[float]:
1569
- """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
1570
- return self.box.mean_results()
1571
-
1572
- def class_result(self, i: int) -> Tuple[float, float, float, float]:
1573
- """Return the result of evaluating the performance of an object detection model on a specific class."""
1574
- return self.box.class_result(i)
1575
-
1576
- @property
1577
- def maps(self) -> np.ndarray:
1578
- """Return mean Average Precision (mAP) scores per class."""
1579
- return self.box.maps
1580
-
1581
- @property
1582
- def fitness(self) -> float:
1583
- """Return the fitness of box object."""
1584
- return self.box.fitness()
1585
-
1586
- @property
1587
- def ap_class_index(self) -> List:
1588
- """Return the average precision index per class."""
1589
- return self.box.ap_class_index
1590
-
1591
- @property
1592
- def results_dict(self) -> Dict[str, float]:
1593
- """Return dictionary of computed performance metrics and statistics."""
1594
- return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
1595
-
1596
- @property
1597
- def curves(self) -> List:
1598
- """Return a list of curves for accessing specific metrics curves."""
1599
- return []
1600
-
1601
- @property
1602
- def curves_results(self) -> List:
1603
- """Return a list of curves for accessing specific metrics curves."""
1604
- return []
1605
-
1606
- def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Union[str, float]]]:
1607
- """
1608
- Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
1609
- scalar metrics (mAP, mAP50, mAP75) along with precision, recall, and F1-score for each class.
1610
-
1611
- Args:
1612
- normalize (bool): For OBB metrics, everything is normalized by default [0-1].
1613
- decimals (int): Number of decimal places to round the metrics values to.
1614
-
1615
- Returns:
1616
- (List[Dict[str, Union[str, float]]]): A list of dictionaries, each representing one class with detection metrics.
1617
-
1618
- Examples:
1619
- >>> results = model.val(data="dota8.yaml")
1620
- >>> detection_summary = results.summary(decimals=4)
1621
- >>> print(detection_summary)
1622
- """
1623
- scalars = {
1624
- "box-map": round(self.box.map, decimals),
1625
- "box-map50": round(self.box.map50, decimals),
1626
- "box-map75": round(self.box.map75, decimals),
1627
- }
1628
- per_class = {"box-p": self.box.p, "box-r": self.box.r, "box-f1": self.box.f1}
1629
- return [
1630
- {"class_name": self.names[i], **{k: round(v[i], decimals) for k, v in per_class.items()}, **scalars}
1631
- for i in range(len(next(iter(per_class.values()), [])))
1632
- ]