dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
@@ -21,8 +21,7 @@ OKS_SIGMA = (
21
21
 
22
22
 
23
23
  def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
24
- """
25
- Calculate the intersection over box2 area given box1 and box2.
24
+ """Calculate the intersection over box2 area given box1 and box2.
26
25
 
27
26
  Args:
28
27
  box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
@@ -53,8 +52,7 @@ def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float =
53
52
 
54
53
 
55
54
  def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
56
- """
57
- Calculate intersection-over-union (IoU) of boxes.
55
+ """Calculate intersection-over-union (IoU) of boxes.
58
56
 
59
57
  Args:
60
58
  box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
@@ -85,19 +83,17 @@ def bbox_iou(
85
83
  CIoU: bool = False,
86
84
  eps: float = 1e-7,
87
85
  ) -> torch.Tensor:
88
- """
89
- Calculate the Intersection over Union (IoU) between bounding boxes.
86
+ """Calculate the Intersection over Union (IoU) between bounding boxes.
90
87
 
91
- This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
92
- For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
93
- Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
94
- or (x1, y1, x2, y2) if `xywh=False`.
88
+ This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. For instance, you
89
+ may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). Internally, the code will split the last
90
+ dimension into (x, y, w, h) if `xywh=True`, or (x1, y1, x2, y2) if `xywh=False`.
95
91
 
96
92
  Args:
97
93
  box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
98
94
  box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
99
- xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
100
- (x1, y1, x2, y2) format.
95
+ xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in (x1, y1,
96
+ x2, y2) format.
101
97
  GIoU (bool, optional): If True, calculate Generalized IoU.
102
98
  DIoU (bool, optional): If True, calculate Distance IoU.
103
99
  CIoU (bool, optional): If True, calculate Complete IoU.
@@ -148,14 +144,13 @@ def bbox_iou(
148
144
 
149
145
 
150
146
  def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
151
- """
152
- Calculate masks IoU.
147
+ """Calculate masks IoU.
153
148
 
154
149
  Args:
155
150
  mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
156
- product of image width and height.
157
- mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
158
- product of image width and height.
151
+ product of image width and height.
152
+ mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the product
153
+ of image width and height.
159
154
  eps (float, optional): A small value to avoid division by zero.
160
155
 
161
156
  Returns:
@@ -169,8 +164,7 @@ def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> tor
169
164
  def kpt_iou(
170
165
  kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
171
166
  ) -> torch.Tensor:
172
- """
173
- Calculate Object Keypoint Similarity (OKS).
167
+ """Calculate Object Keypoint Similarity (OKS).
174
168
 
175
169
  Args:
176
170
  kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
@@ -191,8 +185,7 @@ def kpt_iou(
191
185
 
192
186
 
193
187
  def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194
- """
195
- Generate covariance matrix from oriented bounding boxes.
188
+ """Generate covariance matrix from oriented bounding boxes.
196
189
 
197
190
  Args:
198
191
  boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
@@ -211,8 +204,7 @@ def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Ten
211
204
 
212
205
 
213
206
  def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
214
- """
215
- Calculate probabilistic IoU between oriented bounding boxes.
207
+ """Calculate probabilistic IoU between oriented bounding boxes.
216
208
 
217
209
  Args:
218
210
  obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
@@ -257,8 +249,7 @@ def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: flo
257
249
 
258
250
 
259
251
  def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
260
- """
261
- Calculate the probabilistic IoU between oriented bounding boxes.
252
+ """Calculate the probabilistic IoU between oriented bounding boxes.
262
253
 
263
254
  Args:
264
255
  obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
@@ -294,8 +285,7 @@ def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarr
294
285
 
295
286
 
296
287
  def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
297
- """
298
- Compute smoothed positive and negative Binary Cross-Entropy targets.
288
+ """Compute smoothed positive and negative Binary Cross-Entropy targets.
299
289
 
300
290
  Args:
301
291
  eps (float, optional): The epsilon value for label smoothing.
@@ -311,8 +301,7 @@ def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
311
301
 
312
302
 
313
303
  class ConfusionMatrix(DataExportMixin):
314
- """
315
- A class for calculating and updating a confusion matrix for object detection and classification tasks.
304
+ """A class for calculating and updating a confusion matrix for object detection and classification tasks.
316
305
 
317
306
  Attributes:
318
307
  task (str): The type of task, either 'detect' or 'classify'.
@@ -323,8 +312,7 @@ class ConfusionMatrix(DataExportMixin):
323
312
  """
324
313
 
325
314
  def __init__(self, names: dict[int, str] = [], task: str = "detect", save_matches: bool = False):
326
- """
327
- Initialize a ConfusionMatrix instance.
315
+ """Initialize a ConfusionMatrix instance.
328
316
 
329
317
  Args:
330
318
  names (dict[int, str], optional): Names of classes, used as labels on the plot.
@@ -338,21 +326,20 @@ class ConfusionMatrix(DataExportMixin):
338
326
  self.matches = {} if save_matches else None
339
327
 
340
328
  def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
341
- """
342
- Append the matches to TP, FP, FN or GT list for the last batch.
329
+ """Append the matches to TP, FP, FN or GT list for the last batch.
343
330
 
344
- This method updates the matches dictionary by appending specific batch data
345
- to the appropriate match type (True Positive, False Positive, or False Negative).
331
+ This method updates the matches dictionary by appending specific batch data to the appropriate match type (True
332
+ Positive, False Positive, or False Negative).
346
333
 
347
334
  Args:
348
335
  mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
349
- batch (dict[str, Any]): Batch data containing detection results with keys
350
- like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'.
336
+ batch (dict[str, Any]): Batch data containing detection results with keys like 'bboxes', 'cls', 'conf',
337
+ 'keypoints', 'masks'.
351
338
  idx (int): Index of the specific detection to append from the batch.
352
339
 
353
- Note:
354
- For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
355
- it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
340
+ Notes:
341
+ For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, it indicates
342
+ overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
356
343
  """
357
344
  if self.matches is None:
358
345
  return
@@ -364,8 +351,7 @@ class ConfusionMatrix(DataExportMixin):
364
351
  self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
365
352
 
366
353
  def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
367
- """
368
- Update confusion matrix for classification task.
354
+ """Update confusion matrix for classification task.
369
355
 
370
356
  Args:
371
357
  preds (list[N, min(nc,5)]): Predicted class labels.
@@ -382,15 +368,14 @@ class ConfusionMatrix(DataExportMixin):
382
368
  conf: float = 0.25,
383
369
  iou_thres: float = 0.45,
384
370
  ) -> None:
385
- """
386
- Update confusion matrix for object detection task.
371
+ """Update confusion matrix for object detection task.
387
372
 
388
373
  Args:
389
- detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
390
- Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
391
- Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
392
- batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and
393
- 'cls' (Array[M]) keys, where M is the number of ground truth objects.
374
+ detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated
375
+ information. Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be Array[N, 4] for
376
+ regular boxes or Array[N, 5] for OBB with angle.
377
+ batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M,
378
+ 5]) and 'cls' (Array[M]) keys, where M is the number of ground truth objects.
394
379
  conf (float, optional): Confidence threshold for detections.
395
380
  iou_thres (float, optional): IoU threshold for matching detections to ground truth.
396
381
  """
@@ -460,8 +445,7 @@ class ConfusionMatrix(DataExportMixin):
460
445
  return self.matrix
461
446
 
462
447
  def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
463
- """
464
- Return true positives and false positives.
448
+ """Return true positives and false positives.
465
449
 
466
450
  Returns:
467
451
  tp (np.ndarray): True positives.
@@ -473,8 +457,7 @@ class ConfusionMatrix(DataExportMixin):
473
457
  return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1]) # remove background class if task=detect
474
458
 
475
459
  def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
476
- """
477
- Plot grid of GT, TP, FP, FN for each image.
460
+ """Plot grid of GT, TP, FP, FN for each image.
478
461
 
479
462
  Args:
480
463
  img (torch.Tensor): Image to plot onto.
@@ -513,8 +496,7 @@ class ConfusionMatrix(DataExportMixin):
513
496
  @TryExcept(msg="ConfusionMatrix plot failure")
514
497
  @plt_settings()
515
498
  def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
516
- """
517
- Plot the confusion matrix using matplotlib and save it to a file.
499
+ """Plot the confusion matrix using matplotlib and save it to a file.
518
500
 
519
501
  Args:
520
502
  normalize (bool, optional): Whether to normalize the confusion matrix.
@@ -590,16 +572,17 @@ class ConfusionMatrix(DataExportMixin):
590
572
  LOGGER.info(" ".join(map(str, self.matrix[i])))
591
573
 
592
574
  def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
593
- """
594
- Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
595
- normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON, or SQL.
575
+ """Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
576
+ normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON,
577
+ or SQL.
596
578
 
597
579
  Args:
598
580
  normalize (bool): Whether to normalize the confusion matrix values.
599
581
  decimals (int): Number of decimal places to round the output values to.
600
582
 
601
583
  Returns:
602
- (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding values for all actual classes.
584
+ (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding
585
+ values for all actual classes.
603
586
 
604
587
  Examples:
605
588
  >>> results = model.val(data="coco8.yaml", plots=True)
@@ -643,8 +626,7 @@ def plot_pr_curve(
643
626
  names: dict[int, str] = {},
644
627
  on_plot=None,
645
628
  ):
646
- """
647
- Plot precision-recall curve.
629
+ """Plot precision-recall curve.
648
630
 
649
631
  Args:
650
632
  px (np.ndarray): X values for the PR curve.
@@ -688,8 +670,7 @@ def plot_mc_curve(
688
670
  ylabel: str = "Metric",
689
671
  on_plot=None,
690
672
  ):
691
- """
692
- Plot metric-confidence curve.
673
+ """Plot metric-confidence curve.
693
674
 
694
675
  Args:
695
676
  px (np.ndarray): X values for the metric-confidence curve.
@@ -725,8 +706,7 @@ def plot_mc_curve(
725
706
 
726
707
 
727
708
  def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
728
- """
729
- Compute the average precision (AP) given the recall and precision curves.
709
+ """Compute the average precision (AP) given the recall and precision curves.
730
710
 
731
711
  Args:
732
712
  recall (list): The recall curve.
@@ -769,8 +749,7 @@ def ap_per_class(
769
749
  eps: float = 1e-16,
770
750
  prefix: str = "",
771
751
  ) -> tuple:
772
- """
773
- Compute the average precision per class for object detection evaluation.
752
+ """Compute the average precision per class for object detection evaluation.
774
753
 
775
754
  Args:
776
755
  tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
@@ -855,8 +834,7 @@ def ap_per_class(
855
834
 
856
835
 
857
836
  class Metric(SimpleClass):
858
- """
859
- Class for computing evaluation metrics for Ultralytics YOLO models.
837
+ """Class for computing evaluation metrics for Ultralytics YOLO models.
860
838
 
861
839
  Attributes:
862
840
  p (list): Precision for each class. Shape: (nc,).
@@ -894,8 +872,7 @@ class Metric(SimpleClass):
894
872
 
895
873
  @property
896
874
  def ap50(self) -> np.ndarray | list:
897
- """
898
- Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
875
+ """Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
899
876
 
900
877
  Returns:
901
878
  (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
@@ -904,8 +881,7 @@ class Metric(SimpleClass):
904
881
 
905
882
  @property
906
883
  def ap(self) -> np.ndarray | list:
907
- """
908
- Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
884
+ """Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
909
885
 
910
886
  Returns:
911
887
  (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
@@ -914,8 +890,7 @@ class Metric(SimpleClass):
914
890
 
915
891
  @property
916
892
  def mp(self) -> float:
917
- """
918
- Return the Mean Precision of all classes.
893
+ """Return the Mean Precision of all classes.
919
894
 
920
895
  Returns:
921
896
  (float): The mean precision of all classes.
@@ -924,8 +899,7 @@ class Metric(SimpleClass):
924
899
 
925
900
  @property
926
901
  def mr(self) -> float:
927
- """
928
- Return the Mean Recall of all classes.
902
+ """Return the Mean Recall of all classes.
929
903
 
930
904
  Returns:
931
905
  (float): The mean recall of all classes.
@@ -934,8 +908,7 @@ class Metric(SimpleClass):
934
908
 
935
909
  @property
936
910
  def map50(self) -> float:
937
- """
938
- Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
911
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
939
912
 
940
913
  Returns:
941
914
  (float): The mAP at an IoU threshold of 0.5.
@@ -944,8 +917,7 @@ class Metric(SimpleClass):
944
917
 
945
918
  @property
946
919
  def map75(self) -> float:
947
- """
948
- Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
920
+ """Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
949
921
 
950
922
  Returns:
951
923
  (float): The mAP at an IoU threshold of 0.75.
@@ -954,8 +926,7 @@ class Metric(SimpleClass):
954
926
 
955
927
  @property
956
928
  def map(self) -> float:
957
- """
958
- Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
929
+ """Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
959
930
 
960
931
  Returns:
961
932
  (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
@@ -984,8 +955,7 @@ class Metric(SimpleClass):
984
955
  return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
985
956
 
986
957
  def update(self, results: tuple):
987
- """
988
- Update the evaluation metrics with a new set of results.
958
+ """Update the evaluation metrics with a new set of results.
989
959
 
990
960
  Args:
991
961
  results (tuple): A tuple containing evaluation metrics:
@@ -1030,15 +1000,15 @@ class Metric(SimpleClass):
1030
1000
 
1031
1001
 
1032
1002
  class DetMetrics(SimpleClass, DataExportMixin):
1033
- """
1034
- Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
1003
+ """Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
1035
1004
 
1036
1005
  Attributes:
1037
1006
  names (dict[int, str]): A dictionary of class names.
1038
1007
  box (Metric): An instance of the Metric class for storing detection results.
1039
1008
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1040
1009
  task (str): The task type, set to 'detect'.
1041
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1010
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1011
+ target classes, and target images.
1042
1012
  nt_per_class: Number of targets per class.
1043
1013
  nt_per_image: Number of targets per image.
1044
1014
 
@@ -1059,8 +1029,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
1059
1029
  """
1060
1030
 
1061
1031
  def __init__(self, names: dict[int, str] = {}) -> None:
1062
- """
1063
- Initialize a DetMetrics instance with a save directory, plot flag, and class names.
1032
+ """Initialize a DetMetrics instance with a save directory, plot flag, and class names.
1064
1033
 
1065
1034
  Args:
1066
1035
  names (dict[int, str], optional): Dictionary of class names.
@@ -1074,19 +1043,17 @@ class DetMetrics(SimpleClass, DataExportMixin):
1074
1043
  self.nt_per_image = None
1075
1044
 
1076
1045
  def update_stats(self, stat: dict[str, Any]) -> None:
1077
- """
1078
- Update statistics by appending new values to existing stat collections.
1046
+ """Update statistics by appending new values to existing stat collections.
1079
1047
 
1080
1048
  Args:
1081
- stat (dict[str, any]): Dictionary containing new statistical values to append.
1082
- Keys should match existing keys in self.stats.
1049
+ stat (dict[str, any]): Dictionary containing new statistical values to append. Keys should match existing
1050
+ keys in self.stats.
1083
1051
  """
1084
1052
  for k in self.stats.keys():
1085
1053
  self.stats[k].append(stat[k])
1086
1054
 
1087
1055
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1088
- """
1089
- Process predicted results for object detection and update metrics.
1056
+ """Process predicted results for object detection and update metrics.
1090
1057
 
1091
1058
  Args:
1092
1059
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1167,16 +1134,16 @@ class DetMetrics(SimpleClass, DataExportMixin):
1167
1134
  return self.box.curves_results
1168
1135
 
1169
1136
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1170
- """
1171
- Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
1172
- scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1137
+ """Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes
1138
+ shared scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1173
1139
 
1174
1140
  Args:
1175
- normalize (bool): For Detect metrics, everything is normalized by default [0-1].
1176
- decimals (int): Number of decimal places to round the metrics values to.
1141
+ normalize (bool): For Detect metrics, everything is normalized by default [0-1].
1142
+ decimals (int): Number of decimal places to round the metrics values to.
1177
1143
 
1178
1144
  Returns:
1179
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1145
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1146
+ values.
1180
1147
 
1181
1148
  Examples:
1182
1149
  >>> results = model.val(data="coco8.yaml")
@@ -1202,8 +1169,7 @@ class DetMetrics(SimpleClass, DataExportMixin):
1202
1169
 
1203
1170
 
1204
1171
  class SegmentMetrics(DetMetrics):
1205
- """
1206
- Calculate and aggregate detection and segmentation metrics over a given set of classes.
1172
+ """Calculate and aggregate detection and segmentation metrics over a given set of classes.
1207
1173
 
1208
1174
  Attributes:
1209
1175
  names (dict[int, str]): Dictionary of class names.
@@ -1211,7 +1177,8 @@ class SegmentMetrics(DetMetrics):
1211
1177
  seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
1212
1178
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1213
1179
  task (str): The task type, set to 'segment'.
1214
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1180
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1181
+ target classes, and target images.
1215
1182
  nt_per_class: Number of targets per class.
1216
1183
  nt_per_image: Number of targets per image.
1217
1184
 
@@ -1228,8 +1195,7 @@ class SegmentMetrics(DetMetrics):
1228
1195
  """
1229
1196
 
1230
1197
  def __init__(self, names: dict[int, str] = {}) -> None:
1231
- """
1232
- Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1198
+ """Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1233
1199
 
1234
1200
  Args:
1235
1201
  names (dict[int, str], optional): Dictionary of class names.
@@ -1240,8 +1206,7 @@ class SegmentMetrics(DetMetrics):
1240
1206
  self.stats["tp_m"] = [] # add additional stats for masks
1241
1207
 
1242
1208
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1243
- """
1244
- Process the detection and segmentation metrics over the given set of predictions.
1209
+ """Process the detection and segmentation metrics over the given set of predictions.
1245
1210
 
1246
1211
  Args:
1247
1212
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1313,16 +1278,17 @@ class SegmentMetrics(DetMetrics):
1313
1278
  return DetMetrics.curves_results.fget(self) + self.seg.curves_results
1314
1279
 
1315
1280
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1316
- """
1317
- Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes both
1318
- box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1281
+ """Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes
1282
+ both box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for
1283
+ each class.
1319
1284
 
1320
1285
  Args:
1321
- normalize (bool): For Segment metrics, everything is normalized by default [0-1].
1286
+ normalize (bool): For Segment metrics, everything is normalized by default [0-1].
1322
1287
  decimals (int): Number of decimal places to round the metrics values to.
1323
1288
 
1324
1289
  Returns:
1325
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1290
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1291
+ values.
1326
1292
 
1327
1293
  Examples:
1328
1294
  >>> results = model.val(data="coco8-seg.yaml")
@@ -1341,8 +1307,7 @@ class SegmentMetrics(DetMetrics):
1341
1307
 
1342
1308
 
1343
1309
  class PoseMetrics(DetMetrics):
1344
- """
1345
- Calculate and aggregate detection and pose metrics over a given set of classes.
1310
+ """Calculate and aggregate detection and pose metrics over a given set of classes.
1346
1311
 
1347
1312
  Attributes:
1348
1313
  names (dict[int, str]): Dictionary of class names.
@@ -1350,7 +1315,8 @@ class PoseMetrics(DetMetrics):
1350
1315
  box (Metric): An instance of the Metric class for storing detection results.
1351
1316
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1352
1317
  task (str): The task type, set to 'pose'.
1353
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1318
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1319
+ target classes, and target images.
1354
1320
  nt_per_class: Number of targets per class.
1355
1321
  nt_per_image: Number of targets per image.
1356
1322
 
@@ -1367,8 +1333,7 @@ class PoseMetrics(DetMetrics):
1367
1333
  """
1368
1334
 
1369
1335
  def __init__(self, names: dict[int, str] = {}) -> None:
1370
- """
1371
- Initialize the PoseMetrics class with directory path, class names, and plotting options.
1336
+ """Initialize the PoseMetrics class with class names.
1372
1337
 
1373
1338
  Args:
1374
1339
  names (dict[int, str], optional): Dictionary of class names.
@@ -1379,8 +1344,7 @@ class PoseMetrics(DetMetrics):
1379
1344
  self.stats["tp_p"] = [] # add additional stats for pose
1380
1345
 
1381
1346
  def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
1382
- """
1383
- Process the detection and pose metrics over the given set of predictions.
1347
+ """Process the detection and pose metrics over the given set of predictions.
1384
1348
 
1385
1349
  Args:
1386
1350
  save_dir (Path): Directory to save plots. Defaults to Path(".").
@@ -1456,16 +1420,16 @@ class PoseMetrics(DetMetrics):
1456
1420
  return DetMetrics.curves_results.fget(self) + self.pose.curves_results
1457
1421
 
1458
1422
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
1459
- """
1460
- Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box and
1461
- pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1423
+ """Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box
1424
+ and pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
1462
1425
 
1463
1426
  Args:
1464
- normalize (bool): For Pose metrics, everything is normalized by default [0-1].
1427
+ normalize (bool): For Pose metrics, everything is normalized by default [0-1].
1465
1428
  decimals (int): Number of decimal places to round the metrics values to.
1466
1429
 
1467
1430
  Returns:
1468
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
1431
+ (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric
1432
+ values.
1469
1433
 
1470
1434
  Examples:
1471
1435
  >>> results = model.val(data="coco8-pose.yaml")
@@ -1484,8 +1448,7 @@ class PoseMetrics(DetMetrics):
1484
1448
 
1485
1449
 
1486
1450
  class ClassifyMetrics(SimpleClass, DataExportMixin):
1487
- """
1488
- Class for computing classification metrics including top-1 and top-5 accuracy.
1451
+ """Class for computing classification metrics including top-1 and top-5 accuracy.
1489
1452
 
1490
1453
  Attributes:
1491
1454
  top1 (float): The top-1 accuracy.
@@ -1511,8 +1474,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1511
1474
  self.task = "classify"
1512
1475
 
1513
1476
  def process(self, targets: torch.Tensor, pred: torch.Tensor):
1514
- """
1515
- Process target classes and predicted classes to compute metrics.
1477
+ """Process target classes and predicted classes to compute metrics.
1516
1478
 
1517
1479
  Args:
1518
1480
  targets (torch.Tensor): Target classes.
@@ -1549,11 +1511,10 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1549
1511
  return []
1550
1512
 
1551
1513
  def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
1552
- """
1553
- Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
1514
+ """Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
1554
1515
 
1555
1516
  Args:
1556
- normalize (bool): For Classify metrics, everything is normalized by default [0-1].
1517
+ normalize (bool): For Classify metrics, everything is normalized by default [0-1].
1557
1518
  decimals (int): Number of decimal places to round the metrics values to.
1558
1519
 
1559
1520
  Returns:
@@ -1568,15 +1529,15 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1568
1529
 
1569
1530
 
1570
1531
  class OBBMetrics(DetMetrics):
1571
- """
1572
- Metrics for evaluating oriented bounding box (OBB) detection.
1532
+ """Metrics for evaluating oriented bounding box (OBB) detection.
1573
1533
 
1574
1534
  Attributes:
1575
1535
  names (dict[int, str]): Dictionary of class names.
1576
1536
  box (Metric): An instance of the Metric class for storing detection results.
1577
1537
  speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1578
1538
  task (str): The task type, set to 'obb'.
1579
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
1539
+ stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes,
1540
+ target classes, and target images.
1580
1541
  nt_per_class: Number of targets per class.
1581
1542
  nt_per_image: Number of targets per image.
1582
1543
 
@@ -1585,8 +1546,7 @@ class OBBMetrics(DetMetrics):
1585
1546
  """
1586
1547
 
1587
1548
  def __init__(self, names: dict[int, str] = {}) -> None:
1588
- """
1589
- Initialize an OBBMetrics instance with directory, plotting, and class names.
1549
+ """Initialize an OBBMetrics instance with class names.
1590
1550
 
1591
1551
  Args:
1592
1552
  names (dict[int, str], optional): Dictionary of class names.