dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -15,14 +15,16 @@ import torch
15
15
  from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings
16
16
 
17
17
  OKS_SIGMA = (
18
- np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
18
+ np.array(
19
+ [0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89],
20
+ dtype=np.float32,
21
+ )
19
22
  / 10.0
20
23
  )
21
24
 
22
25
 
23
26
  def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
24
- """
25
- Calculate the intersection over box2 area given box1 and box2.
27
+ """Calculate the intersection over box2 area given box1 and box2.
26
28
 
27
29
  Args:
28
30
  box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
@@ -53,8 +55,7 @@ def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float =
53
55
 
54
56
 
55
57
  def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
56
- """
57
- Calculate intersection-over-union (IoU) of boxes.
58
+ """Calculate intersection-over-union (IoU) of boxes.
58
59
 
59
60
  Args:
60
61
  box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
@@ -85,19 +86,17 @@ def bbox_iou(
85
86
  CIoU: bool = False,
86
87
  eps: float = 1e-7,
87
88
  ) -> torch.Tensor:
88
- """
89
- Calculate the Intersection over Union (IoU) between bounding boxes.
89
+ """Calculate the Intersection over Union (IoU) between bounding boxes.
90
90
 
91
- This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
92
- For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
93
- Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
94
- or (x1, y1, x2, y2) if `xywh=False`.
91
+ This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. For instance, you
92
+ may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). Internally, the code will split the last
93
+ dimension into (x, y, w, h) if `xywh=True`, or (x1, y1, x2, y2) if `xywh=False`.
95
94
 
96
95
  Args:
97
96
  box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
98
97
  box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
99
- xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
100
- (x1, y1, x2, y2) format.
98
+ xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in (x1, y1,
99
+ x2, y2) format.
101
100
  GIoU (bool, optional): If True, calculate Generalized IoU.
102
101
  DIoU (bool, optional): If True, calculate Distance IoU.
103
102
  CIoU (bool, optional): If True, calculate Complete IoU.
@@ -148,14 +147,13 @@ def bbox_iou(
148
147
 
149
148
 
150
149
  def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
151
- """
152
- Calculate masks IoU.
150
+ """Calculate masks IoU.
153
151
 
154
152
  Args:
155
153
  mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
156
- product of image width and height.
157
- mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
158
- product of image width and height.
154
+ product of image width and height.
155
+ mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the product
156
+ of image width and height.
159
157
  eps (float, optional): A small value to avoid division by zero.
160
158
 
161
159
  Returns:
@@ -169,8 +167,7 @@ def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> tor
169
167
  def kpt_iou(
170
168
  kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
171
169
  ) -> torch.Tensor:
172
- """
173
- Calculate Object Keypoint Similarity (OKS).
170
+ """Calculate Object Keypoint Similarity (OKS).
174
171
 
175
172
  Args:
176
173
  kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
@@ -191,14 +188,14 @@ def kpt_iou(
191
188
 
192
189
 
193
190
  def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194
- """
195
- Generate covariance matrix from oriented bounding boxes.
191
+ """Generate covariance matrix from oriented bounding boxes.
196
192
 
197
193
  Args:
198
194
  boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
199
195
 
200
196
  Returns:
201
- (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
197
+ (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Covariance matrix components (a, b, c) where the covariance
198
+ matrix is [[a, c], [c, b]], each of shape (N, 1).
202
199
  """
203
200
  # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
204
201
  gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
@@ -211,8 +208,7 @@ def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Ten
211
208
 
212
209
 
213
210
  def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
214
- """
215
- Calculate probabilistic IoU between oriented bounding boxes.
211
+ """Calculate probabilistic IoU between oriented bounding boxes.
216
212
 
217
213
  Args:
218
214
  obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
@@ -257,8 +253,7 @@ def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: flo
257
253
 
258
254
 
259
255
  def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
260
- """
261
- Calculate the probabilistic IoU between oriented bounding boxes.
256
+ """Calculate the probabilistic IoU between oriented bounding boxes.
262
257
 
263
258
  Args:
264
259
  obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
@@ -294,8 +289,7 @@ def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarr
294
289
 
295
290
 
296
291
  def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
297
- """
298
- Compute smoothed positive and negative Binary Cross-Entropy targets.
292
+ """Compute smoothed positive and negative Binary Cross-Entropy targets.
299
293
 
300
294
  Args:
301
295
  eps (float, optional): The epsilon value for label smoothing.
@@ -311,20 +305,18 @@ def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
311
305
 
312
306
 
313
307
  class ConfusionMatrix(DataExportMixin):
314
- """
315
- A class for calculating and updating a confusion matrix for object detection and classification tasks.
308
+ """A class for calculating and updating a confusion matrix for object detection and classification tasks.
316
309
 
317
310
  Attributes:
318
311
  task (str): The type of task, either 'detect' or 'classify'.
319
312
  matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
320
- nc (int): The number of category.
313
+ nc (int): The number of classes.
321
314
  names (list[str]): The names of the classes, used as labels on the plot.
322
315
  matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
323
316
  """
324
317
 
325
318
  def __init__(self, names: dict[int, str] = [], task: str = "detect", save_matches: bool = False):
326
- """
327
- Initialize a ConfusionMatrix instance.
319
+ """Initialize a ConfusionMatrix instance.
328
320
 
329
321
  Args:
330
322
  names (dict[int, str], optional): Names of classes, used as labels on the plot.
@@ -338,21 +330,20 @@ class ConfusionMatrix(DataExportMixin):
338
330
  self.matches = {} if save_matches else None
339
331
 
340
332
  def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
341
- """
342
- Append the matches to TP, FP, FN or GT list for the last batch.
333
+ """Append the matches to TP, FP, FN or GT list for the last batch.
343
334
 
344
- This method updates the matches dictionary by appending specific batch data
345
- to the appropriate match type (True Positive, False Positive, or False Negative).
335
+ This method updates the matches dictionary by appending specific batch data to the appropriate match type (True
336
+ Positive, False Positive, or False Negative).
346
337
 
347
338
  Args:
348
339
  mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
349
- batch (dict[str, Any]): Batch data containing detection results with keys
350
- like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'.
340
+ batch (dict[str, Any]): Batch data containing detection results with keys like 'bboxes', 'cls', 'conf',
341
+ 'keypoints', 'masks'.
351
342
  idx (int): Index of the specific detection to append from the batch.
352
343
 
353
- Note:
354
- For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
355
- it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
344
+ Notes:
345
+ For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, it indicates
346
+ overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
356
347
  """
357
348
  if self.matches is None:
358
349
  return
@@ -364,8 +355,7 @@ class ConfusionMatrix(DataExportMixin):
364
355
  self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
365
356
 
366
357
  def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
367
- """
368
- Update confusion matrix for classification task.
358
+ """Update confusion matrix for classification task.
369
359
 
370
360
  Args:
371
361
  preds (list[N, min(nc,5)]): Predicted class labels.
@@ -382,15 +372,14 @@ class ConfusionMatrix(DataExportMixin):
382
372
  conf: float = 0.25,
383
373
  iou_thres: float = 0.45,
384
374
  ) -> None:
385
- """
386
- Update confusion matrix for object detection task.
375
+ """Update confusion matrix for object detection task.
387
376
 
388
377
  Args:
389
- detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
390
- Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
391
- Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
392
- batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and
393
- 'cls' (Array[M]) keys, where M is the number of ground truth objects.
378
+ detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
379
+ information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be Array[N, 4] for
380
+ regular boxes or Array[N, 5] for OBB with angle.
381
+ batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
382
+ 5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
394
383
  conf (float, optional): Confidence threshold for detections.
395
384
  iou_thres (float, optional): IoU threshold for matching detections to ground truth.
396
385
  """
@@ -460,8 +449,7 @@ class ConfusionMatrix(DataExportMixin):
460
449
  return self.matrix
461
450
 
462
451
  def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
463
- """
464
- Return true positives and false positives.
452
+ """Return true positives and false positives.
465
453
 
466
454
  Returns:
467
455
  tp (np.ndarray): True positives.
@@ -473,8 +461,7 @@ class ConfusionMatrix(DataExportMixin):
473
461
  return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1]) # remove background class if task=detect
474
462
 
475
463
  def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
476
- """
477
- Plot grid of GT, TP, FP, FN for each image.
464
+ """Plot grid of GT, TP, FP, FN for each image.
478
465
 
479
466
  Args:
480
467
  img (torch.Tensor): Image to plot onto.
@@ -513,8 +500,7 @@ class ConfusionMatrix(DataExportMixin):
513
500
  @TryExcept(msg="ConfusionMatrix plot failure")
514
501
  @plt_settings()
515
502
  def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
516
- """
517
- Plot the confusion matrix using matplotlib and save it to a file.
503
+ """Plot the confusion matrix using matplotlib and save it to a file.
518
504
 
519
505
  Args:
520
506
  normalize (bool, optional): Whether to normalize the confusion matrix.
@@ -535,7 +521,7 @@ class ConfusionMatrix(DataExportMixin):
535
521
  array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
536
522
  n = (self.nc + k - 1) // k # number of retained classes
537
523
  nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
538
- ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
524
+ ticklabels = ([*names, "background"]) if (0 < nn < 99) and (nn == nc) else "auto"
539
525
  xy_ticks = np.arange(len(ticklabels))
540
526
  tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
541
527
  label_fontsize = max(6, 12 - 0.1 * nc)
@@ -582,7 +568,7 @@ class ConfusionMatrix(DataExportMixin):
582
568
  fig.savefig(plot_fname, dpi=250)
583
569
  plt.close(fig)
584
570
  if on_plot:
585
- on_plot(plot_fname)
571
+ on_plot(plot_fname, {"type": "confusion_matrix", "matrix": self.matrix.tolist()})
586
572
 
587
573
  def print(self):
588
574
  """Print the confusion matrix to the console."""
@@ -590,16 +576,17 @@ class ConfusionMatrix(DataExportMixin):
590
576
  LOGGER.info(" ".join(map(str, self.matrix[i])))
591
577
 
592
578
  def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
593
- """
594
- Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
595
- normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON, or SQL.
579
+ """Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
580
+ normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON,
581
+ or SQL.
596
582
 
597
583
  Args:
598
584
  normalize (bool): Whether to normalize the confusion matrix values.
599
585
  decimals (int): Number of decimal places to round the output values to.
600
586
 
601
587
  Returns:
602
- (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding values for all actual classes.
588
+ (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
589
+ values for all actual classes.
603
590
 
604
591
  Examples:
605
592
  >>> results = model.val(data="coco8.yaml", plots=True)
@@ -608,7 +595,7 @@ class ConfusionMatrix(DataExportMixin):
608
595
  """
609
596
  import re
610
597
 
611
- names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
598
+ names = list(self.names.values()) if self.task == "classify" else [*list(self.names.values()), "background"]
612
599
  clean_names, seen = [], set()
613
600
  for name in names:
614
601
  clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
@@ -643,8 +630,7 @@ def plot_pr_curve(
643
630
  names: dict[int, str] = {},
644
631
  on_plot=None,
645
632
  ):
646
- """
647
- Plot precision-recall curve.
633
+ """Plot precision-recall curve.
648
634
 
649
635
  Args:
650
636
  px (np.ndarray): X values for the PR curve.
@@ -663,7 +649,7 @@ def plot_pr_curve(
663
649
  for i, y in enumerate(py.T):
664
650
  ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
665
651
  else:
666
- ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
652
+ ax.plot(px, py, linewidth=1, color="gray") # plot(recall, precision)
667
653
 
668
654
  ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
669
655
  ax.set_xlabel("Recall")
@@ -675,7 +661,9 @@ def plot_pr_curve(
675
661
  fig.savefig(save_dir, dpi=250)
676
662
  plt.close(fig)
677
663
  if on_plot:
678
- on_plot(save_dir)
664
+ # Pass PR curve data for interactive plotting (class names stored at model level)
665
+ # Transpose py to match other curves: y[class][point] format
666
+ on_plot(save_dir, {"type": "pr_curve", "x": px.tolist(), "y": py.T.tolist(), "ap": ap.tolist()})
679
667
 
680
668
 
681
669
  @plt_settings()
@@ -688,8 +676,7 @@ def plot_mc_curve(
688
676
  ylabel: str = "Metric",
689
677
  on_plot=None,
690
678
  ):
691
- """
692
- Plot metric-confidence curve.
679
+ """Plot metric-confidence curve.
693
680
 
694
681
  Args:
695
682
  px (np.ndarray): X values for the metric-confidence curve.
@@ -708,7 +695,7 @@ def plot_mc_curve(
708
695
  for i, y in enumerate(py):
709
696
  ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
710
697
  else:
711
- ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
698
+ ax.plot(px, py.T, linewidth=1, color="gray") # plot(confidence, metric)
712
699
 
713
700
  y = smooth(py.mean(0), 0.1)
714
701
  ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
@@ -721,12 +708,12 @@ def plot_mc_curve(
721
708
  fig.savefig(save_dir, dpi=250)
722
709
  plt.close(fig)
723
710
  if on_plot:
724
- on_plot(save_dir)
711
+ # Pass metric-confidence curve data for interactive plotting (class names stored at model level)
712
+ on_plot(save_dir, {"type": f"{ylabel.lower()}_curve", "x": px.tolist(), "y": py.tolist()})
725
713
 
726
714
 
727
715
  def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
728
- """
729
- Compute the average precision (AP) given the recall and precision curves.
716
+ """Compute the average precision (AP) given the recall and precision curves.
730
717
 
731
718
  Args:
732
719
  recall (list): The recall curve.
@@ -769,8 +756,7 @@ def ap_per_class(
769
756
  eps: float = 1e-16,
770
757
  prefix: str = "",
771
758
  ) -> tuple:
772
- """
773
- Compute the average precision per class for object detection evaluation.
759
+ """Compute the average precision per class for object detection evaluation.
774
760
 
775
761
  Args:
776
762
  tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
@@ -855,8 +841,7 @@ def ap_per_class(
855
841
 
856
842
 
857
843
  class Metric(SimpleClass):
858
- """
859
- Class for computing evaluation metrics for Ultralytics YOLO models.
844
+ """Class for computing evaluation metrics for Ultralytics YOLO models.
860
845
 
861
846
  Attributes:
862
847
  p (list): Precision for each class. Shape: (nc,).
@@ -894,8 +879,7 @@ class Metric(SimpleClass):
894
879
 
895
880
  @property
896
881
  def ap50(self) -> np.ndarray | list:
897
- """
898
- Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
882
+ """Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
899
883
 
900
884
  Returns:
901
885
  (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
@@ -904,8 +888,7 @@ class Metric(SimpleClass):
904
888
 
905
889
  @property
906
890
  def ap(self) -> np.ndarray | list:
907
- """
908
- Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
891
+ """Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
909
892
 
910
893
  Returns:
911
894
  (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
@@ -914,8 +897,7 @@ class Metric(SimpleClass):
914
897
 
915
898
  @property
916
899
  def mp(self) -> float:
917
- """
918
- Return the Mean Precision of all classes.
900
+ """Return the Mean Precision of all classes.
919
901
 
920
902
  Returns:
921
903
  (float): The mean precision of all classes.
@@ -924,8 +906,7 @@ class Metric(SimpleClass):
924
906
 
925
907
  @property
926
908
  def mr(self) -> float:
927
- """
928
- Return the Mean Recall of all classes.
909
+ """Return the Mean Recall of all classes.
929
910
 
930
911
  Returns:
931
912
  (float): The mean recall of all classes.
@@ -934,8 +915,7 @@ class Metric(SimpleClass):
934
915
 
935
916
  @property
936
917
  def map50(self) -> float:
937
- """
938
- Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
918
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
939
919
 
940
920
  Returns:
941
921
  (float): The mAP at an IoU threshold of 0.5.
@@ -944,8 +924,7 @@ class Metric(SimpleClass):
944
924
 
945
925
  @property
946
926
  def map75(self) -> float:
947
- """
948
- Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
927
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
949
928
 
950
929
  Returns:
951
930
  (float): The mAP at an IoU threshold of 0.75.
@@ -954,8 +933,7 @@ class Metric(SimpleClass):
954
933
 
955
934
  @property
956
935
  def map(self) -> float:
957
- """
958
- Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
936
+ """Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
959
937
 
960
938
  Returns:
961
939
  (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
@@ -984,8 +962,7 @@ class Metric(SimpleClass):
984
962
  return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
985
963
 
986
964
  def update(self, results: tuple):
987
- """
988
- Update the evaluation metrics with a new set of results.
965
+ """Update the evaluation metrics with a new set of results.
989
966
 
990
967
  Args:
991
968
  results (tuple): A tuple containing evaluation metrics:
@@ -1030,15 +1007,15 @@ class Metric(SimpleClass):
1030
1007
 
1031
1008
 
1032
1009
  class DetMetrics(SimpleClass, DataExportMixin):
1033
- """
1034
- Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
1010
+ """Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
1035
1011
 
1036
1012
  Attributes:
1037
1013
  names (dict[int, str]): A dictionary of class names.
1038
1014
  box (Metric): An instance of the Metric class for storing detection results.
1039
1015
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1040
1016
  task (str): The task type, set to 'detect'.
1041
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1017
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1018
+ target classes, and target images.
1042
1019
  nt_per_class: Number of targets per class.
1043
1020
  nt_per_image: Number of targets per image.
1044
1021
 
@@ -1059,8 +1036,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
1059
1036
  """
1060
1037
 
1061
1038
  def __init__(self, names: dict[int, str] = {}) -> None:
1062
- """
1063
- Initialize a DetMetrics instance with a save directory, plot flag, and class names.
1039
+ """Initialize a DetMetrics instance with a save directory, plot flag, and class names.
1064
1040
 
1065
1041
  Args:
1066
1042
  names (dict[int, str], optional): Dictionary of class names.
@@ -1074,19 +1050,17 @@ class DetMetrics(SimpleClass, DataExportMixin):
1074
1050
  self.nt_per_image = None
1075
1051
 
1076
1052
  def update_stats(self, stat: dict[str, Any]) -> None:
1077
- """
1078
- Update statistics by appending new values to existing stat collections.
1053
+ """Update statistics by appending new values to existing stat collections.
1079
1054
 
1080
1055
  Args:
1081
- stat (dict[str, any]): Dictionary containing new statistical values to append.
1082
- Keys should match existing keys in self.stats.
1056
+ stat (dict[str, any]): Dictionary containing new statistical values to append. Keys should match existing
1057
+ keys in self.stats.
1083
1058
  """
1084
1059
  for k in self.stats.keys():
1085
1060
  self.stats[k].append(stat[k])
1086
1061
 
1087
1062
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1088
- """
1089
- Process predicted results for object detection and update metrics.
1063
+ """Process predicted results for object detection and update metrics.
1090
1064
 
1091
1065
  Args:
1092
1066
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1152,8 +1126,8 @@ class DetMetrics(SimpleClass, DataExportMixin):
1152
1126
  @property
1153
1127
  def results_dict(self) -> dict[str, float]:
1154
1128
  """Return dictionary of computed performance metrics and statistics."""
1155
- keys = self.keys + ["fitness"]
1156
- values = ((float(x) if hasattr(x, "item") else x) for x in (self.mean_results() + [self.fitness]))
1129
+ keys = [*self.keys, "fitness"]
1130
+ values = ((float(x) if hasattr(x, "item") else x) for x in ([*self.mean_results(), self.fitness]))
1157
1131
  return dict(zip(keys, values))
1158
1132
 
1159
1133
  @property
@@ -1167,16 +1141,16 @@ class DetMetrics(SimpleClass, DataExportMixin):
1167
1141
  return self.box.curves_results
1168
1142
 
1169
1143
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1170
- """
1171
- Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
1172
- scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1144
+ """Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes
1145
+ shared scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1173
1146
 
1174
1147
  Args:
1175
- normalize (bool): For Detect metrics, everything is normalized by default [0-1].
1176
- decimals (int): Number of decimal places to round the metrics values to.
1148
+ normalize (bool): For Detect metrics, everything is normalized by default [0-1].
1149
+ decimals (int): Number of decimal places to round the metrics values to.
1177
1150
 
1178
1151
  Returns:
1179
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1152
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1153
+ values.
1180
1154
 
1181
1155
  Examples:
1182
1156
  >>> results = model.val(data="coco8.yaml")
@@ -1202,8 +1176,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
1202
1176
 
1203
1177
 
1204
1178
  class SegmentMetrics(DetMetrics):
1205
- """
1206
- Calculate and aggregate detection and segmentation metrics over a given set of classes.
1179
+ """Calculate and aggregate detection and segmentation metrics over a given set of classes.
1207
1180
 
1208
1181
  Attributes:
1209
1182
  names (dict[int, str]): Dictionary of class names.
@@ -1211,7 +1184,8 @@ class SegmentMetrics(DetMetrics):
1211
1184
  seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
1212
1185
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1213
1186
  task (str): The task type, set to 'segment'.
1214
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1187
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1188
+ target classes, and target images.
1215
1189
  nt_per_class: Number of targets per class.
1216
1190
  nt_per_image: Number of targets per image.
1217
1191
 
@@ -1228,8 +1202,7 @@ class SegmentMetrics(DetMetrics):
1228
1202
  """
1229
1203
 
1230
1204
  def __init__(self, names: dict[int, str] = {}) -> None:
1231
- """
1232
- Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1205
+ """Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1233
1206
 
1234
1207
  Args:
1235
1208
  names (dict[int, str], optional): Dictionary of class names.
@@ -1240,8 +1213,7 @@ class SegmentMetrics(DetMetrics):
1240
1213
  self.stats["tp_m"] = [] # add additional stats for masks
1241
1214
 
1242
1215
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1243
- """
1244
- Process the detection and segmentation metrics over the given set of predictions.
1216
+ """Process the detection and segmentation metrics over the given set of predictions.
1245
1217
 
1246
1218
  Args:
1247
1219
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1270,7 +1242,8 @@ class SegmentMetrics(DetMetrics):
1270
1242
  @property
1271
1243
  def keys(self) -> list[str]:
1272
1244
  """Return a list of keys for accessing metrics."""
1273
- return DetMetrics.keys.fget(self) + [
1245
+ return [
1246
+ *DetMetrics.keys.fget(self),
1274
1247
  "metrics/precision(M)",
1275
1248
  "metrics/recall(M)",
1276
1249
  "metrics/mAP50(M)",
@@ -1298,7 +1271,8 @@ class SegmentMetrics(DetMetrics):
1298
1271
  @property
1299
1272
  def curves(self) -> list[str]:
1300
1273
  """Return a list of curves for accessing specific metrics curves."""
1301
- return DetMetrics.curves.fget(self) + [
1274
+ return [
1275
+ *DetMetrics.curves.fget(self),
1302
1276
  "Precision-Recall(M)",
1303
1277
  "F1-Confidence(M)",
1304
1278
  "Precision-Confidence(M)",
@@ -1311,16 +1285,17 @@ class SegmentMetrics(DetMetrics):
1311
1285
  return DetMetrics.curves_results.fget(self) + self.seg.curves_results
1312
1286
 
1313
1287
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1314
- """
1315
- Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes both
1316
- box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1288
+ """Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes
1289
+ both box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for
1290
+ each class.
1317
1291
 
1318
1292
  Args:
1319
- normalize (bool): For Segment metrics, everything is normalized by default [0-1].
1293
+ normalize (bool): For Segment metrics, everything is normalized by default [0-1].
1320
1294
  decimals (int): Number of decimal places to round the metrics values to.
1321
1295
 
1322
1296
  Returns:
1323
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1297
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1298
+ values.
1324
1299
 
1325
1300
  Examples:
1326
1301
  >>> results = model.val(data="coco8-seg.yaml")
@@ -1339,8 +1314,7 @@ class SegmentMetrics(DetMetrics):
1339
1314
 
1340
1315
 
1341
1316
  class PoseMetrics(DetMetrics):
1342
- """
1343
- Calculate and aggregate detection and pose metrics over a given set of classes.
1317
+ """Calculate and aggregate detection and pose metrics over a given set of classes.
1344
1318
 
1345
1319
  Attributes:
1346
1320
  names (dict[int, str]): Dictionary of class names.
@@ -1348,7 +1322,8 @@ class PoseMetrics(DetMetrics):
1348
1322
  box (Metric): An instance of the Metric class for storing detection results.
1349
1323
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1350
1324
  task (str): The task type, set to 'pose'.
1351
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1325
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1326
+ target classes, and target images.
1352
1327
  nt_per_class: Number of targets per class.
1353
1328
  nt_per_image: Number of targets per image.
1354
1329
 
@@ -1365,8 +1340,7 @@ class PoseMetrics(DetMetrics):
1365
1340
  """
1366
1341
 
1367
1342
  def __init__(self, names: dict[int, str] = {}) -> None:
1368
- """
1369
- Initialize the PoseMetrics class with directory path, class names, and plotting options.
1343
+ """Initialize the PoseMetrics class with directory path, class names, and plotting options.
1370
1344
 
1371
1345
  Args:
1372
1346
  names (dict[int, str], optional): Dictionary of class names.
@@ -1377,8 +1351,7 @@ class PoseMetrics(DetMetrics):
1377
1351
  self.stats["tp_p"] = [] # add additional stats for pose
1378
1352
 
1379
1353
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1380
- """
1381
- Process the detection and pose metrics over the given set of predictions.
1354
+ """Process the detection and pose metrics over the given set of predictions.
1382
1355
 
1383
1356
  Args:
1384
1357
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1407,7 +1380,8 @@ class PoseMetrics(DetMetrics):
1407
1380
  @property
1408
1381
  def keys(self) -> list[str]:
1409
1382
  """Return a list of evaluation metric keys."""
1410
- return DetMetrics.keys.fget(self) + [
1383
+ return [
1384
+ *DetMetrics.keys.fget(self),
1411
1385
  "metrics/precision(P)",
1412
1386
  "metrics/recall(P)",
1413
1387
  "metrics/mAP50(P)",
@@ -1435,7 +1409,8 @@ class PoseMetrics(DetMetrics):
1435
1409
  @property
1436
1410
  def curves(self) -> list[str]:
1437
1411
  """Return a list of curves for accessing specific metrics curves."""
1438
- return DetMetrics.curves.fget(self) + [
1412
+ return [
1413
+ *DetMetrics.curves.fget(self),
1439
1414
  "Precision-Recall(B)",
1440
1415
  "F1-Confidence(B)",
1441
1416
  "Precision-Confidence(B)",
@@ -1452,16 +1427,16 @@ class PoseMetrics(DetMetrics):
1452
1427
  return DetMetrics.curves_results.fget(self) + self.pose.curves_results
1453
1428
 
1454
1429
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1455
- """
1456
- Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box and
1457
- pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1430
+ """Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box
1431
+ and pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1458
1432
 
1459
1433
  Args:
1460
- normalize (bool): For Pose metrics, everything is normalized by default [0-1].
1434
+ normalize (bool): For Pose metrics, everything is normalized by default [0-1].
1461
1435
  decimals (int): Number of decimal places to round the metrics values to.
1462
1436
 
1463
1437
  Returns:
1464
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1438
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1439
+ values.
1465
1440
 
1466
1441
  Examples:
1467
1442
  >>> results = model.val(data="coco8-pose.yaml")
@@ -1480,8 +1455,7 @@ class PoseMetrics(DetMetrics):
1480
1455
 
1481
1456
 
1482
1457
  class ClassifyMetrics(SimpleClass, DataExportMixin):
1483
- """
1484
- Class for computing classification metrics including top-1 and top-5 accuracy.
1458
+ """Class for computing classification metrics including top-1 and top-5 accuracy.
1485
1459
 
1486
1460
  Attributes:
1487
1461
  top1 (float): The top-1 accuracy.
@@ -1507,8 +1481,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1507
1481
  self.task = "classify"
1508
1482
 
1509
1483
  def process(self, targets: torch.Tensor, pred: torch.Tensor):
1510
- """
1511
- Process target classes and predicted classes to compute metrics.
1484
+ """Process target classes and predicted classes to compute metrics.
1512
1485
 
1513
1486
  Args:
1514
1487
  targets (torch.Tensor): Target classes.
@@ -1527,7 +1500,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1527
1500
  @property
1528
1501
  def results_dict(self) -> dict[str, float]:
1529
1502
  """Return a dictionary with model's performance metrics and fitness score."""
1530
- return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
1503
+ return dict(zip([*self.keys, "fitness"], [self.top1, self.top5, self.fitness]))
1531
1504
 
1532
1505
  @property
1533
1506
  def keys(self) -> list[str]:
@@ -1545,11 +1518,10 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1545
1518
  return []
1546
1519
 
1547
1520
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
1548
- """
1549
- Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
1521
+ """Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
1550
1522
 
1551
1523
  Args:
1552
- normalize (bool): For Classify metrics, everything is normalized by default [0-1].
1524
+ normalize (bool): For Classify metrics, everything is normalized by default [0-1].
1553
1525
  decimals (int): Number of decimal places to round the metrics values to.
1554
1526
 
1555
1527
  Returns:
@@ -1564,15 +1536,15 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1564
1536
 
1565
1537
 
1566
1538
  class OBBMetrics(DetMetrics):
1567
- """
1568
- Metrics for evaluating oriented bounding box (OBB) detection.
1539
+ """Metrics for evaluating oriented bounding box (OBB) detection.
1569
1540
 
1570
1541
  Attributes:
1571
1542
  names (dict[int, str]): Dictionary of class names.
1572
1543
  box (Metric): An instance of the Metric class for storing detection results.
1573
1544
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1574
1545
  task (str): The task type, set to 'obb'.
1575
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1546
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1547
+ target classes, and target images.
1576
1548
  nt_per_class: Number of targets per class.
1577
1549
  nt_per_image: Number of targets per image.
1578
1550
 
@@ -1581,8 +1553,7 @@ class OBBMetrics(DetMetrics):
1581
1553
  """
1582
1554
 
1583
1555
  def __init__(self, names: dict[int, str] = {}) -> None:
1584
- """
1585
- Initialize an OBBMetrics instance with directory, plotting, and class names.
1556
+ """Initialize an OBBMetrics instance with directory, plotting, and class names.
1586
1557
 
1587
1558
  Args:
1588
1559
  names (dict[int, str], optional): Dictionary of class names.