dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
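
The detailed hunks below are from ultralytics/utils/metrics.py (item 202 above, +649 -478); diffs for the remaining files are not reproduced in this section.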
@@ -1,14 +1,18 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """Model validation metrics."""

+from __future__ import annotations
+
 import math
 import warnings
+from collections import defaultdict
 from pathlib import Path
+from typing import Any

 import numpy as np
 import torch

-from ultralytics.utils import LOGGER, SimpleClass, TryExcept, checks, plt_settings
+from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings

 OKS_SIGMA = (
     np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
@@ -16,18 +20,17 @@ OKS_SIGMA = (
 )


-def bbox_ioa(box1, box2, iou=False, eps=1e-7):
-    """
-    Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
+def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
+    """Calculate the intersection over box2 area given box1 and box2.

     Args:
-        box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
-        box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
-        iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
+        box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
+        box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.
+        iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.
         eps (float, optional): A small value to avoid division by zero.

     Returns:
-        (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
+        (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.
     """
     # Get the coordinates of bounding boxes
     b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
@@ -48,18 +51,19 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
     return inter_area / (area + eps)


-def box_iou(box1, box2, eps=1e-7):
-    """
-    Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
-    Based on https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py.
+def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate intersection-over-union (IoU) of boxes.

     Args:
-        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
-        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
+        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
+        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.
         eps (float, optional): A small value to avoid division by zero.

     Returns:
         (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
+
+    References:
+        https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py
     """
     # NOTE: Need .float() to get accurate iou values
     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
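
As a quick check of the pairwise semantics documented above (the box_iou behavior itself is unchanged by this diff), here is an illustrative call with made-up boxes:

    import torch
    from ultralytics.utils.metrics import box_iou

    a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])  # (N, 4) in x1y1x2y2
    b = torch.tensor([[5.0, 5.0, 15.0, 15.0]])  # (M, 4) in x1y1x2y2
    print(box_iou(a, b))  # tensor([[0.1429]]): inter 5*5=25, union 100+100-25=175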
@@ -70,20 +74,26 @@ def box_iou(box1, box2, eps=1e-7):
     return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)


-def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
-    """
-    Calculate the Intersection over Union (IoU) between bounding boxes.
+def bbox_iou(
+    box1: torch.Tensor,
+    box2: torch.Tensor,
+    xywh: bool = True,
+    GIoU: bool = False,
+    DIoU: bool = False,
+    CIoU: bool = False,
+    eps: float = 1e-7,
+) -> torch.Tensor:
+    """Calculate the Intersection over Union (IoU) between bounding boxes.

-    This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
-    For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
-    Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
-    or (x1, y1, x2, y2) if `xywh=False`.
+    This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. For instance, you
+    may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). Internally, the code will split the last
+    dimension into (x, y, w, h) if `xywh=True`, or (x1, y1, x2, y2) if `xywh=False`.

     Args:
         box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
         box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
-        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
-            (x1, y1, x2, y2) format.
+        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in (x1, y1,
+            x2, y2) format.
         GIoU (bool, optional): If True, calculate Generalized IoU.
         DIoU (bool, optional): If True, calculate Distance IoU.
         CIoU (bool, optional): If True, calculate Complete IoU.
@@ -133,15 +143,14 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
     return iou  # IoU


-def mask_iou(mask1, mask2, eps=1e-7):
-    """
-    Calculate masks IoU.
+def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate masks IoU.

     Args:
         mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
-                              product of image width and height.
-        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
-                              product of image width and height.
+            product of image width and height.
+        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the product
+            of image width and height.
         eps (float, optional): A small value to avoid division by zero.

     Returns:
@@ -152,9 +161,10 @@ def mask_iou(mask1, mask2, eps=1e-7):
     return intersection / (union + eps)


-def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
-    """
-    Calculate Object Keypoint Similarity (OKS).
+def kpt_iou(
+    kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
+) -> torch.Tensor:
+    """Calculate Object Keypoint Similarity (OKS).

     Args:
         kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
@@ -174,9 +184,8 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
     return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)


-def _get_covariance_matrix(boxes):
-    """
-    Generate covariance matrix from oriented bounding boxes.
+def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Generate covariance matrix from oriented bounding boxes.

     Args:
         boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
@@ -194,9 +203,8 @@ def _get_covariance_matrix(boxes):
     return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin


-def probiou(obb1, obb2, CIoU=False, eps=1e-7):
-    """
-    Calculate probabilistic IoU between oriented bounding boxes.
+def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate probabilistic IoU between oriented bounding boxes.

     Args:
         obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
@@ -208,8 +216,10 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
         (torch.Tensor): OBB similarities, shape (N,).

     Notes:
-        - OBB format: [center_x, center_y, width, height, rotation_angle].
-        - Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
+        OBB format: [center_x, center_y, width, height, rotation_angle].
+
+    References:
+        https://arxiv.org/pdf/2106.06072v1.pdf
     """
     x1, y1 = obb1[..., :2].split(1, dim=-1)
     x2, y2 = obb2[..., :2].split(1, dim=-1)
@@ -238,9 +248,8 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
     return iou


-def batch_probiou(obb1, obb2, eps=1e-7):
-    """
-    Calculate the probabilistic IoU between oriented bounding boxes.
+def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
+    """Calculate the probabilistic IoU between oriented bounding boxes.

     Args:
         obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
@@ -275,15 +284,15 @@ def batch_probiou(obb1, obb2, eps=1e-7):
     return 1 - hd


-def smooth_bce(eps=0.1):
-    """
-    Compute smoothed positive and negative Binary Cross-Entropy targets.
+def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
+    """Compute smoothed positive and negative Binary Cross-Entropy targets.

     Args:
         eps (float, optional): The epsilon value for label smoothing.

     Returns:
-        (tuple): A tuple containing the positive and negative label smoothing BCE targets.
+        pos (float): Positive label smoothing BCE target.
+        neg (float): Negative label smoothing BCE target.

     References:
         https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
@@ -291,81 +300,115 @@ def smooth_bce(eps=0.1):
     return 1.0 - 0.5 * eps, 0.5 * eps


-class ConfusionMatrix:
-    """
-    A class for calculating and updating a confusion matrix for object detection and classification tasks.
+class ConfusionMatrix(DataExportMixin):
+    """A class for calculating and updating a confusion matrix for object detection and classification tasks.

     Attributes:
         task (str): The type of task, either 'detect' or 'classify'.
         matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
-        nc (int): The number of classes.
-        conf (float): The confidence threshold for detections.
-        iou_thres (float): The Intersection over Union threshold.
+        nc (int): The number of category.
+        names (list[str]): The names of the classes, used as labels on the plot.
+        matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
     """

-    def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
-        """
-        Initialize a ConfusionMatrix instance.
+    def __init__(self, names: dict[int, str] = [], task: str = "detect", save_matches: bool = False):
+        """Initialize a ConfusionMatrix instance.

         Args:
-            nc (int): Number of classes.
-            conf (float, optional): Confidence threshold for detections.
-            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
+            names (dict[int, str], optional): Names of classes, used as labels on the plot.
             task (str, optional): Type of task, either 'detect' or 'classify'.
+            save_matches (bool, optional): Save the indices of GTs, TPs, FPs, FNs for visualization.
         """
         self.task = task
-        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
-        self.nc = nc  # number of classes
-        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
-        self.iou_thres = iou_thres
+        self.nc = len(names)  # number of classes
+        self.matrix = np.zeros((self.nc, self.nc)) if self.task == "classify" else np.zeros((self.nc + 1, self.nc + 1))
+        self.names = names  # name of classes
+        self.matches = {} if save_matches else None
+
+    def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
+        """Append the matches to TP, FP, FN or GT list for the last batch.

-    def process_cls_preds(self, preds, targets):
+        This method updates the matches dictionary by appending specific batch data to the appropriate match type (True
+        Positive, False Positive, or False Negative).
+
+        Args:
+            mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
+            batch (dict[str, Any]): Batch data containing detection results with keys like 'bboxes', 'cls', 'conf',
+                'keypoints', 'masks'.
+            idx (int): Index of the specific detection to append from the batch.
+
+        Notes:
+            For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, it indicates
+            overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
         """
-        Update confusion matrix for classification task.
+        if self.matches is None:
+            return
+        for k, v in batch.items():
+            if k in {"bboxes", "cls", "conf", "keypoints"}:
+                self.matches[mtype][k] += v[[idx]]
+            elif k == "masks":
+                # NOTE: masks.max() > 1.0 means overlap_mask=True with (1, H, W) shape
+                self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
+
+    def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
+        """Update confusion matrix for classification task.

         Args:
-            preds (Array[N, min(nc,5)]): Predicted class labels.
-            targets (Array[N, 1]): Ground truth class labels.
+            preds (list[N, min(nc,5)]): Predicted class labels.
+            targets (list[N, 1]): Ground truth class labels.
         """
         preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
             self.matrix[p][t] += 1

-    def process_batch(self, detections, gt_bboxes, gt_cls):
-        """
-        Update confusion matrix for object detection task.
+    def process_batch(
+        self,
+        detections: dict[str, torch.Tensor],
+        batch: dict[str, Any],
+        conf: float = 0.25,
+        iou_thres: float = 0.45,
+    ) -> None:
+        """Update confusion matrix for object detection task.

         Args:
-            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
-                Each row should contain (x1, y1, x2, y2, conf, class)
-                or with an additional element `angle` when it's obb.
-            gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
-            gt_cls (Array[M]): The class labels.
+            detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
+                information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be Array[N, 4] for
+                regular boxes or Array[N, 5] for OBB with angle.
+            batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
+                5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
+            conf (float, optional): Confidence threshold for detections.
+            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
         """
+        gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
+        if self.matches is not None:  # only if visualization is enabled
+            self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}}
+            for i in range(gt_cls.shape[0]):
+                self._append_matches("GT", batch, i)  # store GT
+        is_obb = gt_bboxes.shape[1] == 5  # check if boxes contains angle for OBB
+        conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf  # apply 0.25 if default val conf is passed
+        no_pred = detections["cls"].shape[0] == 0
         if gt_cls.shape[0] == 0:  # Check if labels is empty
-            if detections is not None:
-                detections = detections[detections[:, 4] > self.conf]
-                detection_classes = detections[:, 5].int()
-                for dc in detection_classes:
-                    self.matrix[dc, self.nc] += 1  # false positives
+            if not no_pred:
+                detections = {k: detections[k][detections["conf"] > conf] for k in detections}
+                detection_classes = detections["cls"].int().tolist()
+                for i, dc in enumerate(detection_classes):
+                    self.matrix[dc, self.nc] += 1  # FP
+                    self._append_matches("FP", detections, i)
             return
-        if detections is None:
-            gt_classes = gt_cls.int()
-            for gc in gt_classes:
-                self.matrix[self.nc, gc] += 1  # background FN
+        if no_pred:
+            gt_classes = gt_cls.int().tolist()
+            for i, gc in enumerate(gt_classes):
+                self.matrix[self.nc, gc] += 1  # FN
+                self._append_matches("FN", batch, i)
             return

-        detections = detections[detections[:, 4] > self.conf]
-        gt_classes = gt_cls.int()
-        detection_classes = detections[:, 5].int()
-        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
-        iou = (
-            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
-            if is_obb
-            else box_iou(gt_bboxes, detections[:, :4])
-        )
+        detections = {k: detections[k][detections["conf"] > conf] for k in detections}
+        gt_classes = gt_cls.int().tolist()
+        detection_classes = detections["cls"].int().tolist()
+        bboxes = detections["bboxes"]
+        iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)

-        x = torch.where(iou > self.iou_thres)
+        x = torch.where(iou > iou_thres)
         if x[0].shape[0]:
             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
             if x[0].shape[0] > 1:
@@ -381,59 +424,100 @@ class ConfusionMatrix:
             for i, gc in enumerate(gt_classes):
                 j = m0 == i
                 if n and sum(j) == 1:
-                    self.matrix[detection_classes[m1[j]], gc] += 1  # correct
+                    dc = detection_classes[m1[j].item()]
+                    self.matrix[dc, gc] += 1  # TP if class is correct else both an FP and an FN
+                    if dc == gc:
+                        self._append_matches("TP", detections, m1[j].item())
+                    else:
+                        self._append_matches("FP", detections, m1[j].item())
+                        self._append_matches("FN", batch, i)
                 else:
-                    self.matrix[self.nc, gc] += 1  # true background
+                    self.matrix[self.nc, gc] += 1  # FN
+                    self._append_matches("FN", batch, i)

             for i, dc in enumerate(detection_classes):
                 if not any(m1 == i):
-                    self.matrix[dc, self.nc] += 1  # predicted background
+                    self.matrix[dc, self.nc] += 1  # FP
+                    self._append_matches("FP", detections, i)

     def matrix(self):
         """Return the confusion matrix."""
         return self.matrix

-    def tp_fp(self):
-        """
-        Return true positives and false positives.
+    def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
+        """Return true positives and false positives.

         Returns:
-            (tuple): True positives and false positives.
+            tp (np.ndarray): True positives.
+            fp (np.ndarray): False positives.
         """
         tp = self.matrix.diagonal()  # true positives
         fp = self.matrix.sum(1) - tp  # false positives
         # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
-        return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp)  # remove background class if task=detect
+        return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1])  # remove background class if task=detect
+
+    def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
+        """Plot grid of GT, TP, FP, FN for each image.
+
+        Args:
+            img (torch.Tensor): Image to plot onto.
+            im_file (str): Image filename to save visualizations.
+            save_dir (Path): Location to save the visualizations to.
+        """
+        if not self.matches:
+            return
+        from .ops import xyxy2xywh
+        from .plotting import plot_images
+
+        # Create batch of 4 (GT, TP, FP, FN)
+        labels = defaultdict(list)
+        for i, mtype in enumerate(["GT", "FP", "TP", "FN"]):
+            mbatch = self.matches[mtype]
+            if "conf" not in mbatch:
+                mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device)
+            mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i
+            for k in mbatch.keys():
+                labels[k] += mbatch[k]
+
+        labels = {k: torch.stack(v, 0) if len(v) else torch.empty(0) for k, v in labels.items()}
+        if self.task != "obb" and labels["bboxes"].shape[0]:
+            labels["bboxes"] = xyxy2xywh(labels["bboxes"])
+        (save_dir / "visualizations").mkdir(parents=True, exist_ok=True)
+        plot_images(
+            labels,
+            img.repeat(4, 1, 1, 1),
+            paths=["Ground Truth", "False Positives", "True Positives", "False Negatives"],
+            fname=save_dir / "visualizations" / Path(im_file).name,
+            names=self.names,
+            max_subplots=4,
+            conf_thres=0.001,
+        )

     @TryExcept(msg="ConfusionMatrix plot failure")
     @plt_settings()
-    def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
-        """
-        Plot the confusion matrix using matplotlib and save it to a file.
+    def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
+        """Plot the confusion matrix using matplotlib and save it to a file.

         Args:
-            normalize (bool): Whether to normalize the confusion matrix.
-            save_dir (str): Directory where the plot will be saved.
-            names (tuple): Names of classes, used as labels on the plot.
-            on_plot (func): An optional callback to pass plots path and data when they are rendered.
+            normalize (bool, optional): Whether to normalize the confusion matrix.
+            save_dir (str, optional): Directory where the plot will be saved.
+            on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered.
         """
         import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'

         array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)  # normalize columns
         array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

-        names = list(names)
         fig, ax = plt.subplots(1, 1, figsize=(12, 9))
+        names, n = list(self.names.values()), self.nc
         if self.nc >= 100:  # downsample for large class count
             k = max(2, self.nc // 60)  # step size for downsampling, always > 1
             keep_idx = slice(None, None, k)  # create slice instead of array
             names = names[keep_idx]  # slice class names
             array = array[keep_idx, :][:, keep_idx]  # slice matrix rows and cols
             n = (self.nc + k - 1) // k  # number of retained classes
-            nc = nn = n if self.task == "classify" else n + 1  # adjust for background if needed
-        else:
-            nc = nn = self.nc if self.task == "classify" else self.nc + 1
-        ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
+        nc = nn = n if self.task == "classify" else n + 1  # adjust for background if needed
+        ticklabels = ([*names, "background"]) if (0 < nn < 99) and (nn == nc) else "auto"
         xy_ticks = np.arange(len(ticklabels))
         tick_fontsize = max(6, 15 - 0.1 * nc)  # Minimum size is 6
         label_fontsize = max(6, 12 - 0.1 * nc)
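
The hunks above replace the array-based ConfusionMatrix API (class count and thresholds fixed in the constructor, detections as an Array[N, 6]) with a name-keyed, dict-based one. A minimal sketch of the new calling convention, inferred only from the signatures added in this diff (the class map and box values are made-up placeholders):

    import torch
    from ultralytics.utils.metrics import ConfusionMatrix

    names = {0: "person", 1: "car"}  # hypothetical class map

    # 8.3.137 style (removed lines above):
    #   cm = ConfusionMatrix(nc=2, conf=0.25, iou_thres=0.45, task="detect")
    #   cm.process_batch(detections, gt_bboxes, gt_cls)  # detections: Array[N, 6]

    # 8.3.224 style (added lines above): thresholds move to process_batch(),
    # and predictions/ground truth travel as dicts keyed by 'bboxes'/'conf'/'cls'.
    cm = ConfusionMatrix(names=names, task="detect", save_matches=False)
    detections = {
        "bboxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),  # (N, 4) xyxy
        "conf": torch.tensor([0.9]),
        "cls": torch.tensor([0.0]),
    }
    batch = {
        "bboxes": torch.tensor([[12.0, 11.0, 49.0, 48.0]]),  # (M, 4) ground truth
        "cls": torch.tensor([0.0]),
    }
    cm.process_batch(detections, batch, conf=0.25, iou_thres=0.45)
    print(cm.matrix)  # (nc + 1, nc + 1) counts; last row/column is background

With save_matches=True, the added plot_matches() method can additionally render a GT/TP/FP/FN grid per image, as the code in the next hunks shows.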
@@ -444,6 +528,7 @@ class ConfusionMatrix:
444
528
  im = ax.imshow(array, cmap="Blues", vmin=0.0, interpolation="none")
445
529
  ax.xaxis.set_label_position("bottom")
446
530
  if nc < 30: # Add score for each cell of confusion matrix
531
+ color_threshold = 0.45 * (1 if normalize else np.nanmax(array)) # text color threshold
447
532
  for i, row in enumerate(array[:nc]):
448
533
  for j, val in enumerate(row[:nc]):
449
534
  val = array[i, j]
@@ -456,7 +541,7 @@ class ConfusionMatrix:
456
541
  ha="center",
457
542
  va="center",
458
543
  fontsize=10,
459
- color="white" if val > (0.7 if normalize else 2) else "black",
544
+ color="white" if val > color_threshold else "black",
460
545
  )
461
546
  cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05)
462
547
  title = "Confusion Matrix" + " Normalized" * normalize
@@ -470,7 +555,7 @@ class ConfusionMatrix:
470
555
  if ticklabels != "auto":
471
556
  ax.set_xticklabels(ticklabels, fontsize=tick_fontsize, rotation=90, ha="center")
472
557
  ax.set_yticklabels(ticklabels, fontsize=tick_fontsize)
473
- for s in ["left", "right", "bottom", "top", "outline"]:
558
+ for s in {"left", "right", "bottom", "top", "outline"}:
474
559
  if s != "outline":
475
560
  ax.spines[s].set_visible(False) # Confusion matrix plot don't have outline
476
561
  cbar.ax.spines[s].set_visible(False)
@@ -486,8 +571,45 @@ class ConfusionMatrix:
486
571
  for i in range(self.matrix.shape[0]):
487
572
  LOGGER.info(" ".join(map(str, self.matrix[i])))
488
573
 
574
+ def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
575
+ """Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
576
+ normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON,
577
+ or SQL.
578
+
579
+ Args:
580
+ normalize (bool): Whether to normalize the confusion matrix values.
581
+ decimals (int): Number of decimal places to round the output values to.
582
+
583
+ Returns:
584
+ (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
585
+ values for all actual classes.
586
+
587
+ Examples:
588
+ >>> results = model.val(data="coco8.yaml", plots=True)
589
+ >>> cm_dict = results.confusion_matrix.summary(normalize=True, decimals=5)
590
+ >>> print(cm_dict)
591
+ """
592
+ import re
593
+
594
+ names = list(self.names.values()) if self.task == "classify" else [*list(self.names.values()), "background"]
595
+ clean_names, seen = [], set()
596
+ for name in names:
597
+ clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
598
+ original_clean = clean_name
599
+ counter = 1
600
+ while clean_name.lower() in seen:
601
+ clean_name = f"{original_clean}_{counter}"
602
+ counter += 1
603
+ seen.add(clean_name.lower())
604
+ clean_names.append(clean_name)
605
+ array = (self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)).round(decimals)
606
+ return [
607
+ dict({"Predicted": clean_names[i]}, **{clean_names[j]: array[i, j] for j in range(len(clean_names))})
608
+ for i in range(len(clean_names))
609
+ ]
489
610
 
490
- def smooth(y, f=0.05):
611
+
612
+ def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray:
491
613
  """Box filter of fraction f."""
492
614
  nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
493
615
  p = np.ones(nf // 2) # ones padding
@@ -496,16 +618,22 @@ def smooth(y, f=0.05):
496
618
 
497
619
 
498
620
  @plt_settings()
499
- def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
500
- """
501
- Plot precision-recall curve.
621
+ def plot_pr_curve(
622
+ px: np.ndarray,
623
+ py: np.ndarray,
624
+ ap: np.ndarray,
625
+ save_dir: Path = Path("pr_curve.png"),
626
+ names: dict[int, str] = {},
627
+ on_plot=None,
628
+ ):
629
+ """Plot precision-recall curve.
502
630
 
503
631
  Args:
504
632
  px (np.ndarray): X values for the PR curve.
505
633
  py (np.ndarray): Y values for the PR curve.
506
634
  ap (np.ndarray): Average precision values.
507
635
  save_dir (Path, optional): Path to save the plot.
508
- names (dict, optional): Dictionary mapping class indices to class names.
636
+ names (dict[int, str], optional): Dictionary mapping class indices to class names.
509
637
  on_plot (callable, optional): Function to call after plot is saved.
510
638
  """
511
639
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
@@ -517,7 +645,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
517
645
  for i, y in enumerate(py.T):
518
646
  ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
519
647
  else:
520
- ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
648
+ ax.plot(px, py, linewidth=1, color="gray") # plot(recall, precision)
521
649
 
522
650
  ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
523
651
  ax.set_xlabel("Recall")
@@ -533,15 +661,22 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
533
661
 
534
662
 
535
663
  @plt_settings()
536
- def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
537
- """
538
- Plot metric-confidence curve.
664
+ def plot_mc_curve(
665
+ px: np.ndarray,
666
+ py: np.ndarray,
667
+ save_dir: Path = Path("mc_curve.png"),
668
+ names: dict[int, str] = {},
669
+ xlabel: str = "Confidence",
670
+ ylabel: str = "Metric",
671
+ on_plot=None,
672
+ ):
673
+ """Plot metric-confidence curve.
539
674
 
540
675
  Args:
541
676
  px (np.ndarray): X values for the metric-confidence curve.
542
677
  py (np.ndarray): Y values for the metric-confidence curve.
543
678
  save_dir (Path, optional): Path to save the plot.
544
- names (dict, optional): Dictionary mapping class indices to class names.
679
+ names (dict[int, str], optional): Dictionary mapping class indices to class names.
545
680
  xlabel (str, optional): X-axis label.
546
681
  ylabel (str, optional): Y-axis label.
547
682
  on_plot (callable, optional): Function to call after plot is saved.
@@ -554,7 +689,7 @@ def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confi
554
689
  for i, y in enumerate(py):
555
690
  ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
556
691
  else:
557
- ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
692
+ ax.plot(px, py.T, linewidth=1, color="gray") # plot(confidence, metric)
558
693
 
559
694
  y = smooth(py.mean(0), 0.1)
560
695
  ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
@@ -570,18 +705,17 @@ def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confi
570
705
  on_plot(save_dir)
571
706
 
572
707
 
573
- def compute_ap(recall, precision):
574
- """
575
- Compute the average precision (AP) given the recall and precision curves.
708
+ def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
709
+ """Compute the average precision (AP) given the recall and precision curves.
576
710
 
577
711
  Args:
578
712
  recall (list): The recall curve.
579
713
  precision (list): The precision curve.
580
714
 
581
715
  Returns:
582
- (float): Average precision.
583
- (np.ndarray): Precision envelope curve.
584
- (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
716
+ ap (float): Average precision.
717
+ mpre (np.ndarray): Precision envelope curve.
718
+ mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
585
719
  """
586
720
  # Append sentinel values to beginning and end
587
721
  mrec = np.concatenate(([0.0], recall, [1.0]))
@@ -604,10 +738,18 @@ def compute_ap(recall, precision):
604
738
 
605
739
 
606
740
  def ap_per_class(
607
- tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
608
- ):
609
- """
610
- Compute the average precision per class for object detection evaluation.
741
+ tp: np.ndarray,
742
+ conf: np.ndarray,
743
+ pred_cls: np.ndarray,
744
+ target_cls: np.ndarray,
745
+ plot: bool = False,
746
+ on_plot=None,
747
+ save_dir: Path = Path(),
748
+ names: dict[int, str] = {},
749
+ eps: float = 1e-16,
750
+ prefix: str = "",
751
+ ) -> tuple:
752
+ """Compute the average precision per class for object detection evaluation.
611
753
 
612
754
  Args:
613
755
  tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
@@ -615,9 +757,9 @@ def ap_per_class(
615
757
  pred_cls (np.ndarray): Array of predicted classes of the detections.
616
758
  target_cls (np.ndarray): Array of true classes of the detections.
617
759
  plot (bool, optional): Whether to plot PR curves or not.
618
- on_plot (func, optional): A callback to pass plots path and data when they are rendered.
760
+ on_plot (callable, optional): A callback to pass plots path and data when they are rendered.
619
761
  save_dir (Path, optional): Directory to save the PR curves.
620
- names (dict, optional): Dict of class names to plot PR curves.
762
+ names (dict[int, str], optional): Dictionary of class names to plot PR curves.
621
763
  eps (float, optional): A small value to avoid division by zero.
622
764
  prefix (str, optional): A prefix string for saving the plot files.
623
765
 
@@ -677,8 +819,7 @@ def ap_per_class(
677
819
 
678
820
  # Compute F1 (harmonic mean of precision and recall)
679
821
  f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
680
- names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
681
- names = dict(enumerate(names)) # to dict
822
+ names = {i: names[k] for i, k in enumerate(unique_classes) if k in names} # dict: only classes that have data
682
823
  if plot:
683
824
  plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
684
825
  plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
@@ -693,8 +834,7 @@ def ap_per_class(
693
834
 
694
835
 
695
836
  class Metric(SimpleClass):
696
- """
697
- Class for computing evaluation metrics for Ultralytics YOLO models.
837
+ """Class for computing evaluation metrics for Ultralytics YOLO models.
698
838
 
699
839
  Attributes:
700
840
  p (list): Precision for each class. Shape: (nc,).
@@ -705,18 +845,20 @@ class Metric(SimpleClass):
705
845
  nc (int): Number of classes.
706
846
 
707
847
  Methods:
708
- ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
709
- ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
710
- mp(): Mean precision of all classes. Returns: Float.
711
- mr(): Mean recall of all classes. Returns: Float.
712
- map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
713
- map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
714
- map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
715
- mean_results(): Mean of results, returns mp, mr, map50, map.
716
- class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
717
- maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
718
- fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
719
- update(results): Update metric attributes with new evaluation results.
848
+ ap50: AP at IoU threshold of 0.5 for all classes.
849
+ ap: AP at IoU thresholds from 0.5 to 0.95 for all classes.
850
+ mp: Mean precision of all classes.
851
+ mr: Mean recall of all classes.
852
+ map50: Mean AP at IoU threshold of 0.5 for all classes.
853
+ map75: Mean AP at IoU threshold of 0.75 for all classes.
854
+ map: Mean AP at IoU thresholds from 0.5 to 0.95 for all classes.
855
+ mean_results: Mean of results, returns mp, mr, map50, map.
856
+ class_result: Class-aware result, returns p[i], r[i], ap50[i], ap[i].
857
+ maps: mAP of each class.
858
+ fitness: Model fitness as a weighted combination of metrics.
859
+ update: Update metric attributes with new evaluation results.
860
+ curves: Provides a list of curves for accessing specific metrics like precision, recall, F1, etc.
861
+ curves_results: Provide a list of results for accessing specific metrics like precision, recall, F1, etc.
720
862
  """
721
863
 
722
864
  def __init__(self) -> None:
@@ -729,29 +871,26 @@ class Metric(SimpleClass):
729
871
  self.nc = 0
730
872
 
731
873
  @property
732
- def ap50(self):
733
- """
734
- Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
874
+ def ap50(self) -> np.ndarray | list:
875
+ """Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
735
876
 
736
877
  Returns:
737
- (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
878
+ (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
738
879
  """
739
880
  return self.all_ap[:, 0] if len(self.all_ap) else []
740
881
 
741
882
  @property
742
- def ap(self):
743
- """
744
- Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
883
+ def ap(self) -> np.ndarray | list:
884
+ """Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
745
885
 
746
886
  Returns:
747
- (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
887
+ (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
748
888
  """
749
889
  return self.all_ap.mean(1) if len(self.all_ap) else []
750
890
 
751
891
  @property
752
- def mp(self):
753
- """
754
- Return the Mean Precision of all classes.
892
+ def mp(self) -> float:
893
+ """Return the Mean Precision of all classes.
755
894
 
756
895
  Returns:
757
896
  (float): The mean precision of all classes.
@@ -759,9 +898,8 @@ class Metric(SimpleClass):
759
898
  return self.p.mean() if len(self.p) else 0.0
760
899
 
761
900
  @property
762
- def mr(self):
763
- """
764
- Return the Mean Recall of all classes.
901
+ def mr(self) -> float:
902
+ """Return the Mean Recall of all classes.
765
903
 
766
904
  Returns:
767
905
  (float): The mean recall of all classes.
@@ -769,9 +907,8 @@ class Metric(SimpleClass):
769
907
  return self.r.mean() if len(self.r) else 0.0
770
908
 
771
909
  @property
772
- def map50(self):
773
- """
774
- Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
910
+ def map50(self) -> float:
911
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
775
912
 
776
913
  Returns:
777
914
  (float): The mAP at an IoU threshold of 0.5.
@@ -779,9 +916,8 @@ class Metric(SimpleClass):
779
916
  return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
780
917
 
781
918
  @property
782
- def map75(self):
783
- """
784
- Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
919
+ def map75(self) -> float:
920
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
785
921
 
786
922
  Returns:
787
923
  (float): The mAP at an IoU threshold of 0.75.
@@ -789,39 +925,37 @@ class Metric(SimpleClass):
789
925
  return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
790
926
 
791
927
  @property
792
- def map(self):
793
- """
794
- Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
928
+ def map(self) -> float:
929
+ """Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
795
930
 
796
931
  Returns:
797
932
  (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
798
933
  """
799
934
  return self.all_ap.mean() if len(self.all_ap) else 0.0
800
935
 
801
- def mean_results(self):
936
+ def mean_results(self) -> list[float]:
802
937
  """Return mean of results, mp, mr, map50, map."""
803
938
  return [self.mp, self.mr, self.map50, self.map]
804
939
 
805
- def class_result(self, i):
940
+ def class_result(self, i: int) -> tuple[float, float, float, float]:
806
941
  """Return class-aware result, p[i], r[i], ap50[i], ap[i]."""
807
942
  return self.p[i], self.r[i], self.ap50[i], self.ap[i]
808
943
 
809
944
  @property
810
- def maps(self):
945
+ def maps(self) -> np.ndarray:
811
946
  """Return mAP of each class."""
812
947
  maps = np.zeros(self.nc) + self.map
813
948
  for i, c in enumerate(self.ap_class_index):
814
949
  maps[c] = self.ap[i]
815
950
  return maps
816
951
 
817
- def fitness(self):
952
+ def fitness(self) -> float:
818
953
  """Return model fitness as a weighted combination of metrics."""
819
- w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
954
+ w = [0.0, 0.0, 0.0, 1.0] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
820
955
  return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
821
956
 
822
- def update(self, results):
823
- """
824
- Update the evaluation metrics with a new set of results.
957
+ def update(self, results: tuple):
958
+ """Update the evaluation metrics with a new set of results.
825
959
 
826
960
  Args:
827
961
  results (tuple): A tuple containing evaluation metrics:
@@ -850,12 +984,12 @@ class Metric(SimpleClass):
850
984
  ) = results
851
985
 
852
986
  @property
853
- def curves(self):
987
+ def curves(self) -> list:
854
988
  """Return a list of curves for accessing specific metrics curves."""
855
989
  return []
856
990
 
857
991
  @property
858
- def curves_results(self):
992
+ def curves_results(self) -> list[list]:
859
993
  """Return a list of curves for accessing specific metrics curves."""
860
994
  return [
861
995
  [self.px, self.prec_values, "Recall", "Precision"],
@@ -865,227 +999,273 @@ class Metric(SimpleClass):
865
999
  ]
866
1000
 
867
1001
 
868
- class DetMetrics(SimpleClass):
869
- """
870
- Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
1002
+ class DetMetrics(SimpleClass, DataExportMixin):
1003
+ """Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
871
1004
 
872
1005
  Attributes:
873
- save_dir (Path): A path to the directory where the output plots will be saved.
874
- plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
875
- names (dict): A dictionary of class names.
1006
+ names (dict[int, str]): A dictionary of class names.
876
1007
  box (Metric): An instance of the Metric class for storing detection results.
877
- speed (dict): A dictionary for storing execution times of different parts of the detection process.
1008
+ speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
878
1009
  task (str): The task type, set to 'detect'.
1010
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1011
+ target classes, and target images.
1012
+ nt_per_class: Number of targets per class.
1013
+ nt_per_image: Number of targets per image.
1014
+
1015
+ Methods:
1016
+ update_stats: Update statistics by appending new values to existing stat collections.
1017
+ process: Process predicted results for object detection and update metrics.
1018
+ clear_stats: Clear the stored statistics.
1019
+ keys: Return a list of keys for accessing specific metrics.
1020
+ mean_results: Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.
1021
+ class_result: Return the result of evaluating the performance of an object detection model on a specific class.
1022
+ maps: Return mean Average Precision (mAP) scores per class.
1023
+ fitness: Return the fitness of box object.
1024
+ ap_class_index: Return the average precision index per class.
1025
+ results_dict: Return dictionary of computed performance metrics and statistics.
1026
+ curves: Return a list of curves for accessing specific metrics curves.
1027
+ curves_results: Return a list of computed performance metrics and statistics.
1028
+ summary: Generate a summarized representation of per-class detection metrics as a list of dictionaries.
879
1029
  """
880
1030
 
881
-    def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
-        """
-        Initialize a DetMetrics instance with a save directory, plot flag, and class names.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize a DetMetrics instance with class names.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
         self.names = names
         self.box = Metric()
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "detect"
+        self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
+        self.nt_per_class = None
+        self.nt_per_image = None

-    def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
+    def update_stats(self, stat: dict[str, Any]) -> None:
+        """Update statistics by appending new values to existing stat collections.
+
+        Args:
+            stat (dict[str, Any]): Dictionary containing new statistical values to append. Keys should match existing
+                keys in self.stats.
         """
-        Process predicted results for object detection and update metrics.
+        for k in self.stats.keys():
+            self.stats[k].append(stat[k])
+
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
+        """Process predicted results for object detection and update metrics.

         Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()}  # to numpy
+        if not stats:
+            return stats
         results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
+            stats["tp"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
+            save_dir=save_dir,
             names=self.names,
             on_plot=on_plot,
+            prefix="Box",
         )[2:]
         self.box.nc = len(self.names)
         self.box.update(results)
+        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
+        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
+        return stats
+
+    def clear_stats(self):
+        """Clear the stored statistics."""
+        for v in self.stats.values():
+            v.clear()
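
A note on the new flow: DetMetrics now accumulates per-batch statistics via update_stats() and defers metric computation to a single process() call. A minimal sketch of that flow using synthetic NumPy arrays (the shapes and the two-class names dict are illustrative assumptions, not real validator output):

    import numpy as np
    from ultralytics.utils.metrics import DetMetrics

    metrics = DetMetrics(names={0: "person", 1: "car"})
    for _ in range(3):  # one update per validation batch
        metrics.update_stats(
            {
                "tp": np.random.rand(8, 10) > 0.5,  # (n_pred, n_iou_thresholds) true-positive flags
                "conf": np.random.rand(8),  # per-prediction confidence
                "pred_cls": np.random.randint(0, 2, 8),  # predicted class indices
                "target_cls": np.random.randint(0, 2, 8),  # ground-truth class indices
                "target_img": np.arange(2),  # classes present in the batch's images
            }
        )
    stats = metrics.process()  # concatenates the lists and runs ap_per_class once
    print(metrics.results_dict)  # {'metrics/precision(B)': ..., 'fitness': ...}
    metrics.clear_stats()  # empties the lists for the next evaluation run
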

     @property
-    def keys(self):
+    def keys(self) -> list[str]:
         """Return a list of keys for accessing specific metrics."""
         return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]

-    def mean_results(self):
+    def mean_results(self) -> list[float]:
         """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
         return self.box.mean_results()

-    def class_result(self, i):
+    def class_result(self, i: int) -> tuple[float, float, float, float]:
         """Return the result of evaluating the performance of an object detection model on a specific class."""
         return self.box.class_result(i)

     @property
-    def maps(self):
+    def maps(self) -> np.ndarray:
         """Return mean Average Precision (mAP) scores per class."""
         return self.box.maps

     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return the fitness of box object."""
         return self.box.fitness()

     @property
-    def ap_class_index(self):
+    def ap_class_index(self) -> list:
         """Return the average precision index per class."""
         return self.box.ap_class_index

     @property
-    def results_dict(self):
+    def results_dict(self) -> dict[str, float]:
         """Return dictionary of computed performance metrics and statistics."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
+        keys = [*self.keys, "fitness"]
+        values = ((float(x) if hasattr(x, "item") else x) for x in ([*self.mean_results(), self.fitness]))
+        return dict(zip(keys, values))
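
The float(x) coercion in results_dict is a serialization fix: NumPy scalars such as np.float32 are not JSON-serializable, so the getter now yields plain Python floats. A small illustration (hypothetical value):

    import json

    import numpy as np

    x = np.float32(0.7312)
    # json.dumps({"recall": x}) raises TypeError: Object of type float32 is not JSON serializable
    print(json.dumps({"recall": float(x) if hasattr(x, "item") else x}))  # works with the coerced value
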

     @property
-    def curves(self):
+    def curves(self) -> list[str]:
         """Return a list of curves for accessing specific metrics curves."""
         return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]

     @property
-    def curves_results(self):
-        """Return dictionary of computed performance metrics and statistics."""
+    def curves_results(self) -> list[list]:
+        """Return a list of computed performance metrics and statistics."""
         return self.box.curves_results

+    def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
+        """Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes
+        shared scalar metrics (mAP50, mAP50-95) alongside precision, recall, and F1-score for each class.

-class SegmentMetrics(SimpleClass):
-    """
-    Calculates and aggregates detection and segmentation metrics over a given set of classes.
+        Args:
+            normalize (bool): For Detect metrics, values are already normalized to [0, 1].
+            decimals (int): Number of decimal places to round the metrics values to.
+
+        Returns:
+            (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
+                values.
+
+        Examples:
+            >>> results = model.val(data="coco8.yaml")
+            >>> detection_summary = results.summary()
+            >>> print(detection_summary)
+        """
+        per_class = {
+            "Box-P": self.box.p,
+            "Box-R": self.box.r,
+            "Box-F1": self.box.f1,
+        }
+        return [
+            {
+                "Class": self.names[self.ap_class_index[i]],
+                "Images": self.nt_per_image[self.ap_class_index[i]],
+                "Instances": self.nt_per_class[self.ap_class_index[i]],
+                **{k: round(v[i], decimals) for k, v in per_class.items()},
+                "mAP50": round(self.class_result(i)[2], decimals),
+                "mAP50-95": round(self.class_result(i)[3], decimals),
+            }
+            for i in range(len(per_class["Box-P"]))
+        ]
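
Because summary() returns a list of plain dicts (one per class), its output drops straight into standard tabular tooling. A hedged sketch, assuming dm is a DetMetrics instance on which process() has already been run:

    import csv

    rows = dm.summary(decimals=4)  # e.g. [{"Class": "person", "Images": 3, "Instances": 8, "Box-P": 0.91, ...}, ...]
    with open("per_class_metrics.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]))
        writer.writeheader()
        writer.writerows(rows)
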
+
+
+class SegmentMetrics(DetMetrics):
+    """Calculate and aggregate detection and segmentation metrics over a given set of classes.

     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and segmentation plots.
-        names (dict): Dictionary of class names.
-        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        names (dict[int, str]): Dictionary of class names.
+        box (Metric): An instance of the Metric class for storing detection results.
         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
-        speed (dict): Dictionary to store the time taken in different phases of inference.
+        speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'segment'.
+        stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
+            target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.
+
+    Methods:
+        process: Process the detection and segmentation metrics over the given set of predictions.
+        keys: Return a list of keys for accessing metrics.
+        mean_results: Return the mean metrics for bounding box and segmentation results.
+        class_result: Return classification results for a specified class index.
+        maps: Return mAP scores for object detection and semantic segmentation models.
+        fitness: Return the fitness score for both segmentation and bounding box models.
+        curves: Return a list of curves for accessing specific metrics curves.
+        curves_results: Provide a list of computed performance metrics and statistics.
+        summary: Generate a summarized representation of per-class segmentation metrics as a list of dictionaries.
     """

-    def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
-        """
-        Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize a SegmentMetrics instance with class names.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        DetMetrics.__init__(self, names)
         self.seg = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "segment"
+        self.stats["tp_m"] = []  # add additional stats for masks

-    def process(self, tp, tp_m, conf, pred_cls, target_cls, on_plot=None):
-        """
-        Process the detection and segmentation metrics over the given set of predictions.
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
+        """Process the detection and segmentation metrics over the given set of predictions.

         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_m (np.ndarray): True positive array for masks.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
+            on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
+
+        Returns:
+            (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_mask = ap_per_class(
-            tp_m,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_m"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Mask",
         )[2:]
         self.seg.nc = len(self.names)
         self.seg.update(results_mask)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats
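
The subclass pattern is now uniform: register an extra stat key in __init__ (here "tp_m"), let DetMetrics.process() handle the box branch, then run ap_per_class once more for the task-specific branch. Under that reading, a caller's update_stats dict simply gains one key; a sketch with hypothetical array variables:

    sm = SegmentMetrics(names={0: "cat"})
    sm.update_stats(
        {
            "tp": box_tp,  # hypothetical (n_pred, n_iou) true-positive array for boxes
            "tp_m": mask_tp,  # same shape, scored against mask IoU
            "conf": conf,
            "pred_cls": pred_cls,
            "target_cls": target_cls,
            "target_img": target_img,
        }
    )
    sm.process()  # fills sm.box and sm.seg in one call
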

     @property
-    def keys(self):
+    def keys(self) -> list[str]:
         """Return a list of keys for accessing metrics."""
         return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+            *DetMetrics.keys.fget(self),
             "metrics/precision(M)",
             "metrics/recall(M)",
             "metrics/mAP50(M)",
             "metrics/mAP50-95(M)",
         ]

-    def mean_results(self):
+    def mean_results(self) -> list[float]:
         """Return the mean metrics for bounding box and segmentation results."""
-        return self.box.mean_results() + self.seg.mean_results()
+        return DetMetrics.mean_results(self) + self.seg.mean_results()

-    def class_result(self, i):
+    def class_result(self, i: int) -> list[float]:
         """Return classification results for a specified class index."""
-        return self.box.class_result(i) + self.seg.class_result(i)
+        return DetMetrics.class_result(self, i) + self.seg.class_result(i)

     @property
-    def maps(self):
+    def maps(self) -> np.ndarray:
         """Return mAP scores for object detection and semantic segmentation models."""
-        return self.box.maps + self.seg.maps
+        return DetMetrics.maps.fget(self) + self.seg.maps

     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return the fitness score for both segmentation and bounding box models."""
-        return self.seg.fitness() + self.box.fitness()
+        return self.seg.fitness() + DetMetrics.fitness.fget(self)
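
The recurring DetMetrics.maps.fget(self) idiom calls the parent's property getter explicitly: plain self.maps inside an override would dispatch back to the subclass and recurse. A self-contained illustration of the idiom (generic classes, not library code):

    class A:
        @property
        def x(self):
            return [1]


    class B(A):
        @property
        def x(self):
            # A.x is a property object; .fget is its underlying getter function
            return A.x.fget(self) + [2]


    assert B().x == [1, 2]
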

     @property
-    def ap_class_index(self):
-        """
-        Return the class indices.
-
-        Boxes and masks have the same ap_class_index.
-        """
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self):
-        """Return results of object detection model for evaluation."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self):
+    def curves(self) -> list[str]:
         """Return a list of curves for accessing specific metrics curves."""
         return [
-            "Precision-Recall(B)",
-            "F1-Confidence(B)",
-            "Precision-Confidence(B)",
-            "Recall-Confidence(B)",
+            *DetMetrics.curves.fget(self),
             "Precision-Recall(M)",
             "F1-Confidence(M)",
             "Precision-Confidence(M)",
@@ -1093,127 +1273,137 @@ class SegmentMetrics(SimpleClass):
         ]

     @property
-    def curves_results(self):
-        """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.seg.curves_results
+    def curves_results(self) -> list[list]:
+        """Return a list of computed performance metrics and statistics."""
+        return DetMetrics.curves_results.fget(self) + self.seg.curves_results

+    def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
+        """Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes
+        both box and mask scalar metrics (mAP50, mAP50-95) alongside precision, recall, and F1-score for
+        each class.

-class PoseMetrics(SegmentMetrics):
-    """
-    Calculates and aggregates detection and pose metrics over a given set of classes.
+        Args:
+            normalize (bool): For Segment metrics, values are already normalized to [0, 1].
+            decimals (int): Number of decimal places to round the metrics values to.
+
+        Returns:
+            (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
+                values.
+
+        Examples:
+            >>> results = model.val(data="coco8-seg.yaml")
+            >>> seg_summary = results.summary(decimals=4)
+            >>> print(seg_summary)
+        """
+        per_class = {
+            "Mask-P": self.seg.p,
+            "Mask-R": self.seg.r,
+            "Mask-F1": self.seg.f1,
+        }
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})
+        return summary
+
+
+class PoseMetrics(DetMetrics):
+    """Calculate and aggregate detection and pose metrics over a given set of classes.

     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection and pose plots.
-        names (dict): Dictionary of class names.
-        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        names (dict[int, str]): Dictionary of class names.
         pose (Metric): An instance of the Metric class to calculate pose metrics.
-        speed (dict): Dictionary to store the time taken in different phases of inference.
+        box (Metric): An instance of the Metric class for storing detection results.
+        speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'pose'.
+        stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
+            target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.

     Methods:
-        process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
-        mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
-        class_result(i): Returns the detection and segmentation metrics of class `i`.
-        maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
-        fitness: Returns the fitness scores, which are a single weighted combination of metrics.
-        ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
-        results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
+        process: Process the detection and pose metrics over the given set of predictions.
+        keys: Return a list of keys for accessing metrics.
+        mean_results: Return the mean results of box and pose.
+        class_result: Return the class-wise detection results for a specific class i.
+        maps: Return the mean average precision (mAP) per class for both box and pose detections.
+        fitness: Return combined fitness score for pose and box detection.
+        curves: Return a list of curves for accessing specific metrics curves.
+        curves_results: Provide a list of computed performance metrics and statistics.
+        summary: Generate a summarized representation of per-class pose metrics as a list of dictionaries.
     """

-    def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
-        """
-        Initialize the PoseMetrics class with directory path, class names, and plotting options.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize the PoseMetrics class with class names.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        super().__init__(save_dir, plot, names)
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
+        super().__init__(names)
         self.pose = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "pose"
+        self.stats["tp_p"] = []  # add additional stats for pose

-    def process(self, tp, tp_p, conf, pred_cls, target_cls, on_plot=None):
-        """
-        Process the detection and pose metrics over the given set of predictions.
+    def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
+        """Process the detection and pose metrics over the given set of predictions.

         Args:
-            tp (np.ndarray): True positive array for boxes.
-            tp_p (np.ndarray): True positive array for keypoints.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
+            save_dir (Path): Directory to save plots. Defaults to Path(".").
+            plot (bool): Whether to plot precision-recall curves. Defaults to False.
             on_plot (callable, optional): Function to call after plots are generated.
+
+        Returns:
+            (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
         """
+        stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot)  # process box stats
         results_pose = ap_per_class(
-            tp_p,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
+            stats["tp_p"],
+            stats["conf"],
+            stats["pred_cls"],
+            stats["target_cls"],
+            plot=plot,
             on_plot=on_plot,
-            save_dir=self.save_dir,
+            save_dir=save_dir,
             names=self.names,
             prefix="Pose",
         )[2:]
         self.pose.nc = len(self.names)
         self.pose.update(results_pose)
-        results_box = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            on_plot=on_plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            prefix="Box",
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results_box)
+        return stats

     @property
-    def keys(self):
-        """Return list of evaluation metric keys."""
+    def keys(self) -> list[str]:
+        """Return a list of evaluation metric keys."""
         return [
-            "metrics/precision(B)",
-            "metrics/recall(B)",
-            "metrics/mAP50(B)",
-            "metrics/mAP50-95(B)",
+            *DetMetrics.keys.fget(self),
             "metrics/precision(P)",
             "metrics/recall(P)",
             "metrics/mAP50(P)",
             "metrics/mAP50-95(P)",
         ]

-    def mean_results(self):
+    def mean_results(self) -> list[float]:
         """Return the mean results of box and pose."""
-        return self.box.mean_results() + self.pose.mean_results()
+        return DetMetrics.mean_results(self) + self.pose.mean_results()

-    def class_result(self, i):
+    def class_result(self, i: int) -> list[float]:
         """Return the class-wise detection results for a specific class i."""
-        return self.box.class_result(i) + self.pose.class_result(i)
+        return DetMetrics.class_result(self, i) + self.pose.class_result(i)

     @property
-    def maps(self):
+    def maps(self) -> np.ndarray:
         """Return the mean average precision (mAP) per class for both box and pose detections."""
-        return self.box.maps + self.pose.maps
+        return DetMetrics.maps.fget(self) + self.pose.maps

     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return combined fitness score for pose and box detection."""
-        return self.pose.fitness() + self.box.fitness()
+        return self.pose.fitness() + DetMetrics.fitness.fget(self)

     @property
-    def curves(self):
+    def curves(self) -> list[str]:
         """Return a list of curves for accessing specific metrics curves."""
         return [
+            *DetMetrics.curves.fget(self),
             "Precision-Recall(B)",
             "F1-Confidence(B)",
             "Precision-Confidence(B)",
@@ -1225,20 +1415,55 @@ class PoseMetrics(SegmentMetrics):
         ]

     @property
-    def curves_results(self):
-        """Return dictionary of computed performance metrics and statistics."""
-        return self.box.curves_results + self.pose.curves_results
+    def curves_results(self) -> list[list]:
+        """Return a list of computed performance metrics and statistics."""
+        return DetMetrics.curves_results.fget(self) + self.pose.curves_results

+    def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
+        """Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box
+        and pose scalar metrics (mAP50, mAP50-95) alongside precision, recall, and F1-score for each class.

-class ClassifyMetrics(SimpleClass):
-    """
-    Class for computing classification metrics including top-1 and top-5 accuracy.
+        Args:
+            normalize (bool): For Pose metrics, values are already normalized to [0, 1].
+            decimals (int): Number of decimal places to round the metrics values to.
+
+        Returns:
+            (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
+                values.
+
+        Examples:
+            >>> results = model.val(data="coco8-pose.yaml")
+            >>> pose_summary = results.summary(decimals=4)
+            >>> print(pose_summary)
+        """
+        per_class = {
+            "Pose-P": self.pose.p,
+            "Pose-R": self.pose.r,
+            "Pose-F1": self.pose.f1,
+        }
+        summary = DetMetrics.summary(self, normalize, decimals)  # get box summary
+        for i, s in enumerate(summary):
+            s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})
+        return summary
+
+
+class ClassifyMetrics(SimpleClass, DataExportMixin):
+    """Class for computing classification metrics including top-1 and top-5 accuracy.

     Attributes:
         top1 (float): The top-1 accuracy.
         top5 (float): The top-5 accuracy.
         speed (dict): A dictionary containing the time taken for each step in the pipeline.
         task (str): The task type, set to 'classify'.
+
+    Methods:
+        process: Process target classes and predicted classes to compute metrics.
+        fitness: Return mean of top-1 and top-5 accuracies as fitness score.
+        results_dict: Return a dictionary with model's performance metrics and fitness score.
+        keys: Return a list of keys for the results_dict property.
+        curves: Return a list of curves for accessing specific metrics curves.
+        curves_results: Provide a list of computed performance metrics and statistics.
+        summary: Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
     """

     def __init__(self) -> None:
@@ -1248,9 +1473,8 @@ class ClassifyMetrics(SimpleClass):
         self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
         self.task = "classify"

-    def process(self, targets, pred):
-        """
-        Process target classes and predicted classes to compute metrics.
+    def process(self, targets: torch.Tensor, pred: torch.Tensor):
+        """Process target classes and predicted classes to compute metrics.

         Args:
             targets (torch.Tensor): Target classes.
@@ -1262,124 +1486,71 @@ class ClassifyMetrics(SimpleClass):
         self.top1, self.top5 = acc.mean(0).tolist()
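
The hunk elides the body between the Args block and the final line, but the visible tail implies acc holds per-sample (top-1, top-k) correctness columns. A hedged reconstruction of that computation on toy tensors (shapes assumed: pred carries each sample's top-5 class indices, best first):

    import torch

    targets = torch.tensor([1, 0, 3])
    pred = torch.tensor([[1, 2, 0, 4, 3], [2, 1, 0, 3, 4], [3, 0, 1, 2, 4]])
    correct = (targets[:, None] == pred).float()  # (n, 5) hit mask
    acc = torch.stack((correct[:, 0], correct.amax(1)), dim=1)  # per-sample top-1 and top-5
    top1, top5 = acc.mean(0).tolist()
    print(top1, top5)  # 0.666..., 1.0
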

     @property
-    def fitness(self):
+    def fitness(self) -> float:
         """Return mean of top-1 and top-5 accuracies as fitness score."""
         return (self.top1 + self.top5) / 2

     @property
-    def results_dict(self):
+    def results_dict(self) -> dict[str, float]:
         """Return a dictionary with model's performance metrics and fitness score."""
-        return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
+        return dict(zip([*self.keys, "fitness"], [self.top1, self.top5, self.fitness]))

     @property
-    def keys(self):
+    def keys(self) -> list[str]:
         """Return a list of keys for the results_dict property."""
         return ["metrics/accuracy_top1", "metrics/accuracy_top5"]

     @property
-    def curves(self):
+    def curves(self) -> list:
         """Return a list of curves for accessing specific metrics curves."""
         return []

     @property
-    def curves_results(self):
+    def curves_results(self) -> list:
         """Return a list of curves for accessing specific metrics curves."""
         return []

+    def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
+        """Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).

-class OBBMetrics(SimpleClass):
-    """
-    Metrics for evaluating oriented bounding box (OBB) detection.
+        Args:
+            normalize (bool): For Classify metrics, values are already normalized to [0, 1].
+            decimals (int): Number of decimal places to round the metrics values to.
+
+        Returns:
+            (list[dict[str, float]]): A list with one dictionary containing Top-1 and Top-5 classification accuracy.
+
+        Examples:
+            >>> results = model.val(data="imagenet10")
+            >>> classify_summary = results.summary(decimals=4)
+            >>> print(classify_summary)
+        """
+        return [{"top1_acc": round(self.top1, decimals), "top5_acc": round(self.top5, decimals)}]
+
+
+class OBBMetrics(DetMetrics):
+    """Metrics for evaluating oriented bounding box (OBB) detection.

     Attributes:
-        save_dir (Path): Path to the directory where the output plots should be saved.
-        plot (bool): Whether to save the detection plots.
-        names (dict): Dictionary of class names.
+        names (dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
-        speed (dict): A dictionary for storing execution times of different parts of the detection process.
+        speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
+        task (str): The task type, set to 'obb'.
+        stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
+            target classes, and target images.
+        nt_per_class (np.ndarray): Number of targets per class.
+        nt_per_image (np.ndarray): Number of targets per image.

     References:
         https://arxiv.org/pdf/2106.06072.pdf
     """

-    def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
-        """
-        Initialize an OBBMetrics instance with directory, plotting, and class names.
+    def __init__(self, names: dict[int, str] = {}) -> None:
+        """Initialize an OBBMetrics instance with class names.

         Args:
-            save_dir (Path, optional): Directory to save plots.
-            plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (dict[int, str], optional): Dictionary of class names.
         """
-        self.save_dir = save_dir
-        self.plot = plot
-        self.names = names
-        self.box = Metric()
-        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
-
-    def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
-        """
-        Process predicted results for object detection and update metrics.
-
-        Args:
-            tp (np.ndarray): True positive array.
-            conf (np.ndarray): Confidence array.
-            pred_cls (np.ndarray): Predicted class indices array.
-            target_cls (np.ndarray): Target class indices array.
-            on_plot (callable, optional): Function to call after plots are generated.
-        """
-        results = ap_per_class(
-            tp,
-            conf,
-            pred_cls,
-            target_cls,
-            plot=self.plot,
-            save_dir=self.save_dir,
-            names=self.names,
-            on_plot=on_plot,
-        )[2:]
-        self.box.nc = len(self.names)
-        self.box.update(results)
-
-    @property
-    def keys(self):
-        """Return a list of keys for accessing specific metrics."""
-        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
-
-    def mean_results(self):
-        """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
-        return self.box.mean_results()
-
-    def class_result(self, i):
-        """Return the result of evaluating the performance of an object detection model on a specific class."""
-        return self.box.class_result(i)
-
-    @property
-    def maps(self):
-        """Return mean Average Precision (mAP) scores per class."""
-        return self.box.maps
-
-    @property
-    def fitness(self):
-        """Return the fitness of box object."""
-        return self.box.fitness()
-
-    @property
-    def ap_class_index(self):
-        """Return the average precision index per class."""
-        return self.box.ap_class_index
-
-    @property
-    def results_dict(self):
-        """Return dictionary of computed performance metrics and statistics."""
-        return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
-
-    @property
-    def curves(self):
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
-
-    @property
-    def curves_results(self):
-        """Return a list of curves for accessing specific metrics curves."""
-        return []
+        DetMetrics.__init__(self, names)
+        # TODO: probably remove task as well
+        self.task = "obb"