dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
@@ -15,14 +15,17 @@ import torch
15
15
  from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings
16
16
 
17
17
  OKS_SIGMA = (
18
- np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
18
+ np.array(
19
+ [0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89],
20
+ dtype=np.float32,
21
+ )
19
22
  / 10.0
20
23
  )
24
+ RLE_WEIGHT = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5])
21
25
 
22
26
 
23
27
  def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
24
- """
25
- Calculate the intersection over box2 area given box1 and box2.
28
+ """Calculate the intersection over box2 area given box1 and box2.
26
29
 
27
30
  Args:
28
31
  box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
@@ -53,8 +56,7 @@ def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float =
53
56
 
54
57
 
55
58
  def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
56
- """
57
- Calculate intersection-over-union (IoU) of boxes.
59
+ """Calculate intersection-over-union (IoU) of boxes.
58
60
 
59
61
  Args:
60
62
  box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
@@ -85,19 +87,17 @@ def bbox_iou(
85
87
  CIoU: bool = False,
86
88
  eps: float = 1e-7,
87
89
  ) -> torch.Tensor:
88
- """
89
- Calculate the Intersection over Union (IoU) between bounding boxes.
90
+ """Calculate the Intersection over Union (IoU) between bounding boxes.
90
91
 
91
- This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
92
- For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
93
- Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
94
- or (x1, y1, x2, y2) if `xywh=False`.
92
+ This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. For instance, you
93
+ may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). Internally, the code will split the last
94
+ dimension into (x, y, w, h) if `xywh=True`, or (x1, y1, x2, y2) if `xywh=False`.
95
95
 
96
96
  Args:
97
97
  box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
98
98
  box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
99
- xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
100
- (x1, y1, x2, y2) format.
99
+ xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in (x1, y1,
100
+ x2, y2) format.
101
101
  GIoU (bool, optional): If True, calculate Generalized IoU.
102
102
  DIoU (bool, optional): If True, calculate Distance IoU.
103
103
  CIoU (bool, optional): If True, calculate Complete IoU.
@@ -148,14 +148,13 @@ def bbox_iou(
148
148
 
149
149
 
150
150
  def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
151
- """
152
- Calculate masks IoU.
151
+ """Calculate masks IoU.
153
152
 
154
153
  Args:
155
154
  mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
156
- product of image width and height.
157
- mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
158
- product of image width and height.
155
+ product of image width and height.
156
+ mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the product
157
+ of image width and height.
159
158
  eps (float, optional): A small value to avoid division by zero.
160
159
 
161
160
  Returns:
@@ -169,8 +168,7 @@ def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> tor
169
168
  def kpt_iou(
170
169
  kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
171
170
  ) -> torch.Tensor:
172
- """
173
- Calculate Object Keypoint Similarity (OKS).
171
+ """Calculate Object Keypoint Similarity (OKS).
174
172
 
175
173
  Args:
176
174
  kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
@@ -191,14 +189,14 @@ def kpt_iou(
191
189
 
192
190
 
193
191
  def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194
- """
195
- Generate covariance matrix from oriented bounding boxes.
192
+ """Generate covariance matrix from oriented bounding boxes.
196
193
 
197
194
  Args:
198
195
  boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
199
196
 
200
197
  Returns:
201
- (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
198
+ (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Covariance matrix components (a, b, c) where the covariance
199
+ matrix is [[a, c], [c, b]], each of shape (N, 1).
202
200
  """
203
201
  # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
204
202
  gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
@@ -211,8 +209,7 @@ def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Ten
211
209
 
212
210
 
213
211
  def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
214
- """
215
- Calculate probabilistic IoU between oriented bounding boxes.
212
+ """Calculate probabilistic IoU between oriented bounding boxes.
216
213
 
217
214
  Args:
218
215
  obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
@@ -257,8 +254,7 @@ def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: flo
257
254
 
258
255
 
259
256
  def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
260
- """
261
- Calculate the probabilistic IoU between oriented bounding boxes.
257
+ """Calculate the probabilistic IoU between oriented bounding boxes.
262
258
 
263
259
  Args:
264
260
  obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
@@ -294,8 +290,7 @@ def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarr
294
290
 
295
291
 
296
292
  def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
297
- """
298
- Compute smoothed positive and negative Binary Cross-Entropy targets.
293
+ """Compute smoothed positive and negative Binary Cross-Entropy targets.
299
294
 
300
295
  Args:
301
296
  eps (float, optional): The epsilon value for label smoothing.
@@ -311,20 +306,18 @@ def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
311
306
 
312
307
 
313
308
  class ConfusionMatrix(DataExportMixin):
314
- """
315
- A class for calculating and updating a confusion matrix for object detection and classification tasks.
309
+ """A class for calculating and updating a confusion matrix for object detection and classification tasks.
316
310
 
317
311
  Attributes:
318
312
  task (str): The type of task, either 'detect' or 'classify'.
319
313
  matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
320
- nc (int): The number of category.
314
+ nc (int): The number of classes.
321
315
  names (list[str]): The names of the classes, used as labels on the plot.
322
316
  matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
323
317
  """
324
318
 
325
- def __init__(self, names: dict[int, str] = [], task: str = "detect", save_matches: bool = False):
326
- """
327
- Initialize a ConfusionMatrix instance.
319
+ def __init__(self, names: dict[int, str] = {}, task: str = "detect", save_matches: bool = False):
320
+ """Initialize a ConfusionMatrix instance.
328
321
 
329
322
  Args:
330
323
  names (dict[int, str], optional): Names of classes, used as labels on the plot.
@@ -338,21 +331,20 @@ class ConfusionMatrix(DataExportMixin):
338
331
  self.matches = {} if save_matches else None
339
332
 
340
333
  def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
341
- """
342
- Append the matches to TP, FP, FN or GT list for the last batch.
334
+ """Append the matches to TP, FP, FN or GT list for the last batch.
343
335
 
344
- This method updates the matches dictionary by appending specific batch data
345
- to the appropriate match type (True Positive, False Positive, or False Negative).
336
+ This method updates the matches dictionary by appending specific batch data to the appropriate match type (True
337
+ Positive, False Positive, or False Negative).
346
338
 
347
339
  Args:
348
340
  mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
349
- batch (dict[str, Any]): Batch data containing detection results with keys
350
- like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'.
341
+ batch (dict[str, Any]): Batch data containing detection results with keys like 'bboxes', 'cls', 'conf',
342
+ 'keypoints', 'masks'.
351
343
  idx (int): Index of the specific detection to append from the batch.
352
344
 
353
- Note:
354
- For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
355
- it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
345
+ Notes:
346
+ For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, it indicates
347
+ overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
356
348
  """
357
349
  if self.matches is None:
358
350
  return
@@ -364,8 +356,7 @@ class ConfusionMatrix(DataExportMixin):
364
356
  self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
365
357
 
366
358
  def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
367
- """
368
- Update confusion matrix for classification task.
359
+ """Update confusion matrix for classification task.
369
360
 
370
361
  Args:
371
362
  preds (list[N, min(nc,5)]): Predicted class labels.
@@ -382,15 +373,14 @@ class ConfusionMatrix(DataExportMixin):
382
373
  conf: float = 0.25,
383
374
  iou_thres: float = 0.45,
384
375
  ) -> None:
385
- """
386
- Update confusion matrix for object detection task.
376
+ """Update confusion matrix for object detection task.
387
377
 
388
378
  Args:
389
- detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
390
- Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
391
- Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
392
- batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and
393
- 'cls' (Array[M]) keys, where M is the number of ground truth objects.
379
+ detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
380
+ information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be Array[N, 4] for
381
+ regular boxes or Array[N, 5] for OBB with angle.
382
+ batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
383
+ 5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
394
384
  conf (float, optional): Confidence threshold for detections.
395
385
  iou_thres (float, optional): IoU threshold for matching detections to ground truth.
396
386
  """
@@ -460,8 +450,7 @@ class ConfusionMatrix(DataExportMixin):
460
450
  return self.matrix
461
451
 
462
452
  def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
463
- """
464
- Return true positives and false positives.
453
+ """Return true positives and false positives.
465
454
 
466
455
  Returns:
467
456
  tp (np.ndarray): True positives.
@@ -473,8 +462,7 @@ class ConfusionMatrix(DataExportMixin):
473
462
  return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1]) # remove background class if task=detect
474
463
 
475
464
  def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
476
- """
477
- Plot grid of GT, TP, FP, FN for each image.
465
+ """Plot grid of GT, TP, FP, FN for each image.
478
466
 
479
467
  Args:
480
468
  img (torch.Tensor): Image to plot onto.
@@ -513,8 +501,7 @@ class ConfusionMatrix(DataExportMixin):
513
501
  @TryExcept(msg="ConfusionMatrix plot failure")
514
502
  @plt_settings()
515
503
  def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
516
- """
517
- Plot the confusion matrix using matplotlib and save it to a file.
504
+ """Plot the confusion matrix using matplotlib and save it to a file.
518
505
 
519
506
  Args:
520
507
  normalize (bool, optional): Whether to normalize the confusion matrix.
@@ -535,7 +522,7 @@ class ConfusionMatrix(DataExportMixin):
535
522
  array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
536
523
  n = (self.nc + k - 1) // k # number of retained classes
537
524
  nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
538
- ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
525
+ ticklabels = ([*names, "background"]) if (0 < nn < 99) and (nn == nc) else "auto"
539
526
  xy_ticks = np.arange(len(ticklabels))
540
527
  tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
541
528
  label_fontsize = max(6, 12 - 0.1 * nc)
@@ -582,7 +569,7 @@ class ConfusionMatrix(DataExportMixin):
582
569
  fig.savefig(plot_fname, dpi=250)
583
570
  plt.close(fig)
584
571
  if on_plot:
585
- on_plot(plot_fname)
572
+ on_plot(plot_fname, {"type": "confusion_matrix", "matrix": self.matrix.tolist()})
586
573
 
587
574
  def print(self):
588
575
  """Print the confusion matrix to the console."""
@@ -590,16 +577,17 @@ class ConfusionMatrix(DataExportMixin):
590
577
  LOGGER.info(" ".join(map(str, self.matrix[i])))
591
578
 
592
579
  def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
593
- """
594
- Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
595
- normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON, or SQL.
580
+ """Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
581
+ normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON,
582
+ or SQL.
596
583
 
597
584
  Args:
598
585
  normalize (bool): Whether to normalize the confusion matrix values.
599
586
  decimals (int): Number of decimal places to round the output values to.
600
587
 
601
588
  Returns:
602
- (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding values for all actual classes.
589
+ (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
590
+ values for all actual classes.
603
591
 
604
592
  Examples:
605
593
  >>> results = model.val(data="coco8.yaml", plots=True)
@@ -608,7 +596,7 @@ class ConfusionMatrix(DataExportMixin):
608
596
  """
609
597
  import re
610
598
 
611
- names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
599
+ names = list(self.names.values()) if self.task == "classify" else [*list(self.names.values()), "background"]
612
600
  clean_names, seen = [], set()
613
601
  for name in names:
614
602
  clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
@@ -643,8 +631,7 @@ def plot_pr_curve(
643
631
  names: dict[int, str] = {},
644
632
  on_plot=None,
645
633
  ):
646
- """
647
- Plot precision-recall curve.
634
+ """Plot precision-recall curve.
648
635
 
649
636
  Args:
650
637
  px (np.ndarray): X values for the PR curve.
@@ -663,7 +650,7 @@ def plot_pr_curve(
663
650
  for i, y in enumerate(py.T):
664
651
  ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
665
652
  else:
666
- ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
653
+ ax.plot(px, py, linewidth=1, color="gray") # plot(recall, precision)
667
654
 
668
655
  ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
669
656
  ax.set_xlabel("Recall")
@@ -675,7 +662,9 @@ def plot_pr_curve(
675
662
  fig.savefig(save_dir, dpi=250)
676
663
  plt.close(fig)
677
664
  if on_plot:
678
- on_plot(save_dir)
665
+ # Pass PR curve data for interactive plotting (class names stored at model level)
666
+ # Transpose py to match other curves: y[class][point] format
667
+ on_plot(save_dir, {"type": "pr_curve", "x": px.tolist(), "y": py.T.tolist(), "ap": ap.tolist()})
679
668
 
680
669
 
681
670
  @plt_settings()
@@ -688,8 +677,7 @@ def plot_mc_curve(
688
677
  ylabel: str = "Metric",
689
678
  on_plot=None,
690
679
  ):
691
- """
692
- Plot metric-confidence curve.
680
+ """Plot metric-confidence curve.
693
681
 
694
682
  Args:
695
683
  px (np.ndarray): X values for the metric-confidence curve.
@@ -708,7 +696,7 @@ def plot_mc_curve(
708
696
  for i, y in enumerate(py):
709
697
  ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
710
698
  else:
711
- ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
699
+ ax.plot(px, py.T, linewidth=1, color="gray") # plot(confidence, metric)
712
700
 
713
701
  y = smooth(py.mean(0), 0.1)
714
702
  ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
@@ -721,12 +709,12 @@ def plot_mc_curve(
721
709
  fig.savefig(save_dir, dpi=250)
722
710
  plt.close(fig)
723
711
  if on_plot:
724
- on_plot(save_dir)
712
+ # Pass metric-confidence curve data for interactive plotting (class names stored at model level)
713
+ on_plot(save_dir, {"type": f"{ylabel.lower()}_curve", "x": px.tolist(), "y": py.tolist()})
725
714
 
726
715
 
727
716
  def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
728
- """
729
- Compute the average precision (AP) given the recall and precision curves.
717
+ """Compute the average precision (AP) given the recall and precision curves.
730
718
 
731
719
  Args:
732
720
  recall (list): The recall curve.
@@ -769,8 +757,7 @@ def ap_per_class(
769
757
  eps: float = 1e-16,
770
758
  prefix: str = "",
771
759
  ) -> tuple:
772
- """
773
- Compute the average precision per class for object detection evaluation.
760
+ """Compute the average precision per class for object detection evaluation.
774
761
 
775
762
  Args:
776
763
  tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
@@ -855,8 +842,7 @@ def ap_per_class(
855
842
 
856
843
 
857
844
  class Metric(SimpleClass):
858
- """
859
- Class for computing evaluation metrics for Ultralytics YOLO models.
845
+ """Class for computing evaluation metrics for Ultralytics YOLO models.
860
846
 
861
847
  Attributes:
862
848
  p (list): Precision for each class. Shape: (nc,).
@@ -894,8 +880,7 @@ class Metric(SimpleClass):
894
880
 
895
881
  @property
896
882
  def ap50(self) -> np.ndarray | list:
897
- """
898
- Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
883
+ """Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
899
884
 
900
885
  Returns:
901
886
  (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
@@ -904,8 +889,7 @@ class Metric(SimpleClass):
904
889
 
905
890
  @property
906
891
  def ap(self) -> np.ndarray | list:
907
- """
908
- Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
892
+ """Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
909
893
 
910
894
  Returns:
911
895
  (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
@@ -914,8 +898,7 @@ class Metric(SimpleClass):
914
898
 
915
899
  @property
916
900
  def mp(self) -> float:
917
- """
918
- Return the Mean Precision of all classes.
901
+ """Return the Mean Precision of all classes.
919
902
 
920
903
  Returns:
921
904
  (float): The mean precision of all classes.
@@ -924,8 +907,7 @@ class Metric(SimpleClass):
924
907
 
925
908
  @property
926
909
  def mr(self) -> float:
927
- """
928
- Return the Mean Recall of all classes.
910
+ """Return the Mean Recall of all classes.
929
911
 
930
912
  Returns:
931
913
  (float): The mean recall of all classes.
@@ -934,8 +916,7 @@ class Metric(SimpleClass):
934
916
 
935
917
  @property
936
918
  def map50(self) -> float:
937
- """
938
- Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
919
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
939
920
 
940
921
  Returns:
941
922
  (float): The mAP at an IoU threshold of 0.5.
@@ -944,8 +925,7 @@ class Metric(SimpleClass):
944
925
 
945
926
  @property
946
927
  def map75(self) -> float:
947
- """
948
- Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
928
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
949
929
 
950
930
  Returns:
951
931
  (float): The mAP at an IoU threshold of 0.75.
@@ -954,8 +934,7 @@ class Metric(SimpleClass):
954
934
 
955
935
  @property
956
936
  def map(self) -> float:
957
- """
958
- Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
937
+ """Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
959
938
 
960
939
  Returns:
961
940
  (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
@@ -984,8 +963,7 @@ class Metric(SimpleClass):
984
963
  return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
985
964
 
986
965
  def update(self, results: tuple):
987
- """
988
- Update the evaluation metrics with a new set of results.
966
+ """Update the evaluation metrics with a new set of results.
989
967
 
990
968
  Args:
991
969
  results (tuple): A tuple containing evaluation metrics:
@@ -1030,15 +1008,15 @@ class Metric(SimpleClass):
1030
1008
 
1031
1009
 
1032
1010
  class DetMetrics(SimpleClass, DataExportMixin):
1033
- """
1034
- Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
1011
+ """Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
1035
1012
 
1036
1013
  Attributes:
1037
1014
  names (dict[int, str]): A dictionary of class names.
1038
1015
  box (Metric): An instance of the Metric class for storing detection results.
1039
1016
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1040
1017
  task (str): The task type, set to 'detect'.
1041
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1018
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1019
+ target classes, and target images.
1042
1020
  nt_per_class: Number of targets per class.
1043
1021
  nt_per_image: Number of targets per image.
1044
1022
 
@@ -1059,8 +1037,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
1059
1037
  """
1060
1038
 
1061
1039
  def __init__(self, names: dict[int, str] = {}) -> None:
1062
- """
1063
- Initialize a DetMetrics instance with a save directory, plot flag, and class names.
1040
+ """Initialize a DetMetrics instance with a save directory, plot flag, and class names.
1064
1041
 
1065
1042
  Args:
1066
1043
  names (dict[int, str], optional): Dictionary of class names.
@@ -1074,19 +1051,17 @@ class DetMetrics(SimpleClass, DataExportMixin):
1074
1051
  self.nt_per_image = None
1075
1052
 
1076
1053
  def update_stats(self, stat: dict[str, Any]) -> None:
1077
- """
1078
- Update statistics by appending new values to existing stat collections.
1054
+ """Update statistics by appending new values to existing stat collections.
1079
1055
 
1080
1056
  Args:
1081
- stat (dict[str, any]): Dictionary containing new statistical values to append.
1082
- Keys should match existing keys in self.stats.
1057
+ stat (dict[str, any]): Dictionary containing new statistical values to append. Keys should match existing
1058
+ keys in self.stats.
1083
1059
  """
1084
1060
  for k in self.stats.keys():
1085
1061
  self.stats[k].append(stat[k])
1086
1062
 
1087
1063
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1088
- """
1089
- Process predicted results for object detection and update metrics.
1064
+ """Process predicted results for object detection and update metrics.
1090
1065
 
1091
1066
  Args:
1092
1067
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1152,8 +1127,8 @@ class DetMetrics(SimpleClass, DataExportMixin):
1152
1127
  @property
1153
1128
  def results_dict(self) -> dict[str, float]:
1154
1129
  """Return dictionary of computed performance metrics and statistics."""
1155
- keys = self.keys + ["fitness"]
1156
- values = ((float(x) if hasattr(x, "item") else x) for x in (self.mean_results() + [self.fitness]))
1130
+ keys = [*self.keys, "fitness"]
1131
+ values = ((float(x) if hasattr(x, "item") else x) for x in ([*self.mean_results(), self.fitness]))
1157
1132
  return dict(zip(keys, values))
1158
1133
 
1159
1134
  @property
@@ -1167,16 +1142,16 @@ class DetMetrics(SimpleClass, DataExportMixin):
1167
1142
  return self.box.curves_results
1168
1143
 
1169
1144
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1170
- """
1171
- Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
1172
- scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1145
+ """Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes
1146
+ shared scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1173
1147
 
1174
1148
  Args:
1175
- normalize (bool): For Detect metrics, everything is normalized by default [0-1].
1176
- decimals (int): Number of decimal places to round the metrics values to.
1149
+ normalize (bool): For Detect metrics, everything is normalized by default [0-1].
1150
+ decimals (int): Number of decimal places to round the metrics values to.
1177
1151
 
1178
1152
  Returns:
1179
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1153
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1154
+ values.
1180
1155
 
1181
1156
  Examples:
1182
1157
  >>> results = model.val(data="coco8.yaml")
@@ -1202,8 +1177,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
1202
1177
 
1203
1178
 
1204
1179
  class SegmentMetrics(DetMetrics):
1205
- """
1206
- Calculate and aggregate detection and segmentation metrics over a given set of classes.
1180
+ """Calculate and aggregate detection and segmentation metrics over a given set of classes.
1207
1181
 
1208
1182
  Attributes:
1209
1183
  names (dict[int, str]): Dictionary of class names.
@@ -1211,7 +1185,8 @@ class SegmentMetrics(DetMetrics):
1211
1185
  seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
1212
1186
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1213
1187
  task (str): The task type, set to 'segment'.
1214
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1188
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1189
+ target classes, and target images.
1215
1190
  nt_per_class: Number of targets per class.
1216
1191
  nt_per_image: Number of targets per image.
1217
1192
 
@@ -1228,8 +1203,7 @@ class SegmentMetrics(DetMetrics):
1228
1203
  """
1229
1204
 
1230
1205
  def __init__(self, names: dict[int, str] = {}) -> None:
1231
- """
1232
- Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1206
+ """Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1233
1207
 
1234
1208
  Args:
1235
1209
  names (dict[int, str], optional): Dictionary of class names.
@@ -1240,8 +1214,7 @@ class SegmentMetrics(DetMetrics):
1240
1214
  self.stats["tp_m"] = [] # add additional stats for masks
1241
1215
 
1242
1216
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1243
- """
1244
- Process the detection and segmentation metrics over the given set of predictions.
1217
+ """Process the detection and segmentation metrics over the given set of predictions.
1245
1218
 
1246
1219
  Args:
1247
1220
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1270,7 +1243,8 @@ class SegmentMetrics(DetMetrics):
1270
1243
  @property
1271
1244
  def keys(self) -> list[str]:
1272
1245
  """Return a list of keys for accessing metrics."""
1273
- return DetMetrics.keys.fget(self) + [
1246
+ return [
1247
+ *DetMetrics.keys.fget(self),
1274
1248
  "metrics/precision(M)",
1275
1249
  "metrics/recall(M)",
1276
1250
  "metrics/mAP50(M)",
@@ -1298,7 +1272,8 @@ class SegmentMetrics(DetMetrics):
1298
1272
  @property
1299
1273
  def curves(self) -> list[str]:
1300
1274
  """Return a list of curves for accessing specific metrics curves."""
1301
- return DetMetrics.curves.fget(self) + [
1275
+ return [
1276
+ *DetMetrics.curves.fget(self),
1302
1277
  "Precision-Recall(M)",
1303
1278
  "F1-Confidence(M)",
1304
1279
  "Precision-Confidence(M)",
@@ -1311,16 +1286,17 @@ class SegmentMetrics(DetMetrics):
1311
1286
  return DetMetrics.curves_results.fget(self) + self.seg.curves_results
1312
1287
 
1313
1288
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1314
- """
1315
- Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes both
1316
- box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1289
+ """Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes
1290
+ both box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for
1291
+ each class.
1317
1292
 
1318
1293
  Args:
1319
- normalize (bool): For Segment metrics, everything is normalized by default [0-1].
1294
+ normalize (bool): For Segment metrics, everything is normalized by default [0-1].
1320
1295
  decimals (int): Number of decimal places to round the metrics values to.
1321
1296
 
1322
1297
  Returns:
1323
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1298
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1299
+ values.
1324
1300
 
1325
1301
  Examples:
1326
1302
  >>> results = model.val(data="coco8-seg.yaml")
@@ -1339,8 +1315,7 @@ class SegmentMetrics(DetMetrics):
1339
1315
 
1340
1316
 
1341
1317
  class PoseMetrics(DetMetrics):
1342
- """
1343
- Calculate and aggregate detection and pose metrics over a given set of classes.
1318
+ """Calculate and aggregate detection and pose metrics over a given set of classes.
1344
1319
 
1345
1320
  Attributes:
1346
1321
  names (dict[int, str]): Dictionary of class names.
@@ -1348,7 +1323,8 @@ class PoseMetrics(DetMetrics):
1348
1323
  box (Metric): An instance of the Metric class for storing detection results.
1349
1324
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1350
1325
  task (str): The task type, set to 'pose'.
1351
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1326
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1327
+ target classes, and target images.
1352
1328
  nt_per_class: Number of targets per class.
1353
1329
  nt_per_image: Number of targets per image.
1354
1330
 
@@ -1365,8 +1341,7 @@ class PoseMetrics(DetMetrics):
1365
1341
  """
1366
1342
 
1367
1343
  def __init__(self, names: dict[int, str] = {}) -> None:
1368
- """
1369
- Initialize the PoseMetrics class with directory path, class names, and plotting options.
1344
+ """Initialize the PoseMetrics class with directory path, class names, and plotting options.
1370
1345
 
1371
1346
  Args:
1372
1347
  names (dict[int, str], optional): Dictionary of class names.
@@ -1377,8 +1352,7 @@ class PoseMetrics(DetMetrics):
1377
1352
  self.stats["tp_p"] = [] # add additional stats for pose
1378
1353
 
1379
1354
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1380
- """
1381
- Process the detection and pose metrics over the given set of predictions.
1355
+ """Process the detection and pose metrics over the given set of predictions.
1382
1356
 
1383
1357
  Args:
1384
1358
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1407,7 +1381,8 @@ class PoseMetrics(DetMetrics):
1407
1381
  @property
1408
1382
  def keys(self) -> list[str]:
1409
1383
  """Return a list of evaluation metric keys."""
1410
- return DetMetrics.keys.fget(self) + [
1384
+ return [
1385
+ *DetMetrics.keys.fget(self),
1411
1386
  "metrics/precision(P)",
1412
1387
  "metrics/recall(P)",
1413
1388
  "metrics/mAP50(P)",
@@ -1435,7 +1410,8 @@ class PoseMetrics(DetMetrics):
1435
1410
  @property
1436
1411
  def curves(self) -> list[str]:
1437
1412
  """Return a list of curves for accessing specific metrics curves."""
1438
- return DetMetrics.curves.fget(self) + [
1413
+ return [
1414
+ *DetMetrics.curves.fget(self),
1439
1415
  "Precision-Recall(B)",
1440
1416
  "F1-Confidence(B)",
1441
1417
  "Precision-Confidence(B)",
@@ -1452,16 +1428,16 @@ class PoseMetrics(DetMetrics):
1452
1428
  return DetMetrics.curves_results.fget(self) + self.pose.curves_results
1453
1429
 
1454
1430
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1455
- """
1456
- Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box and
1457
- pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1431
+ """Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box
1432
+ and pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1458
1433
 
1459
1434
  Args:
1460
- normalize (bool): For Pose metrics, everything is normalized by default [0-1].
1435
+ normalize (bool): For Pose metrics, everything is normalized by default [0-1].
1461
1436
  decimals (int): Number of decimal places to round the metrics values to.
1462
1437
 
1463
1438
  Returns:
1464
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1439
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1440
+ values.
1465
1441
 
1466
1442
  Examples:
1467
1443
  >>> results = model.val(data="coco8-pose.yaml")
@@ -1480,8 +1456,7 @@ class PoseMetrics(DetMetrics):
1480
1456
 
1481
1457
 
1482
1458
  class ClassifyMetrics(SimpleClass, DataExportMixin):
1483
- """
1484
- Class for computing classification metrics including top-1 and top-5 accuracy.
1459
+ """Class for computing classification metrics including top-1 and top-5 accuracy.
1485
1460
 
1486
1461
  Attributes:
1487
1462
  top1 (float): The top-1 accuracy.
@@ -1507,8 +1482,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1507
1482
  self.task = "classify"
1508
1483
 
1509
1484
  def process(self, targets: torch.Tensor, pred: torch.Tensor):
1510
- """
1511
- Process target classes and predicted classes to compute metrics.
1485
+ """Process target classes and predicted classes to compute metrics.
1512
1486
 
1513
1487
  Args:
1514
1488
  targets (torch.Tensor): Target classes.
@@ -1527,7 +1501,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1527
1501
  @property
1528
1502
  def results_dict(self) -> dict[str, float]:
1529
1503
  """Return a dictionary with model's performance metrics and fitness score."""
1530
- return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
1504
+ return dict(zip([*self.keys, "fitness"], [self.top1, self.top5, self.fitness]))
1531
1505
 
1532
1506
  @property
1533
1507
  def keys(self) -> list[str]:
@@ -1545,11 +1519,10 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1545
1519
  return []
1546
1520
 
1547
1521
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
1548
- """
1549
- Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
1522
+ """Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
1550
1523
 
1551
1524
  Args:
1552
- normalize (bool): For Classify metrics, everything is normalized by default [0-1].
1525
+ normalize (bool): For Classify metrics, everything is normalized by default [0-1].
1553
1526
  decimals (int): Number of decimal places to round the metrics values to.
1554
1527
 
1555
1528
  Returns:
@@ -1564,15 +1537,15 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1564
1537
 
1565
1538
 
1566
1539
  class OBBMetrics(DetMetrics):
1567
- """
1568
- Metrics for evaluating oriented bounding box (OBB) detection.
1540
+ """Metrics for evaluating oriented bounding box (OBB) detection.
1569
1541
 
1570
1542
  Attributes:
1571
1543
  names (dict[int, str]): Dictionary of class names.
1572
1544
  box (Metric): An instance of the Metric class for storing detection results.
1573
1545
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1574
1546
  task (str): The task type, set to 'obb'.
1575
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1547
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1548
+ target classes, and target images.
1576
1549
  nt_per_class: Number of targets per class.
1577
1550
  nt_per_image: Number of targets per image.
1578
1551
 
@@ -1581,8 +1554,7 @@ class OBBMetrics(DetMetrics):
1581
1554
  """
1582
1555
 
1583
1556
  def __init__(self, names: dict[int, str] = {}) -> None:
1584
- """
1585
- Initialize an OBBMetrics instance with directory, plotting, and class names.
1557
+ """Initialize an OBBMetrics instance with directory, plotting, and class names.
1586
1558
 
1587
1559
  Args:
1588
1560
  names (dict[int, str], optional): Dictionary of class names.