ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
25
25
  box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
26
26
  box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
27
27
  iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
28
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
28
+ eps (float, optional): A small value to avoid division by zero.
29
29
 
30
30
  Returns:
31
31
  (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
@@ -57,7 +57,7 @@ def box_iou(box1, box2, eps=1e-7):
57
57
  Args:
58
58
  box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
59
59
  box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
60
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
60
+ eps (float, optional): A small value to avoid division by zero.
61
61
 
62
62
  Returns:
63
63
  (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
@@ -73,7 +73,7 @@ def box_iou(box1, box2, eps=1e-7):
73
73
 
74
74
  def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
75
75
  """
76
- Calculates the Intersection over Union (IoU) between bounding boxes.
76
+ Calculate the Intersection over Union (IoU) between bounding boxes.
77
77
 
78
78
  This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
79
79
  For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
@@ -84,11 +84,11 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
84
84
  box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
85
85
  box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
86
86
  xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
87
- (x1, y1, x2, y2) format. Defaults to True.
88
- GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
89
- DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
90
- CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
91
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
87
+ (x1, y1, x2, y2) format.
88
+ GIoU (bool, optional): If True, calculate Generalized IoU.
89
+ DIoU (bool, optional): If True, calculate Distance IoU.
90
+ CIoU (bool, optional): If True, calculate Complete IoU.
91
+ eps (float, optional): A small value to avoid division by zero.
92
92
 
93
93
  Returns:
94
94
  (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
@@ -143,7 +143,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
143
143
  product of image width and height.
144
144
  mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
145
145
  product of image width and height.
146
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
146
+ eps (float, optional): A small value to avoid division by zero.
147
147
 
148
148
  Returns:
149
149
  (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
@@ -162,7 +162,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
162
162
  kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
163
163
  area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
164
164
  sigma (list): A list containing 17 values representing keypoint scales.
165
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
165
+ eps (float, optional): A small value to avoid division by zero.
166
166
 
167
167
  Returns:
168
168
  (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
@@ -177,7 +177,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
177
177
 
178
178
  def _get_covariance_matrix(boxes):
179
179
  """
180
- Generating covariance matrix from obbs.
180
+ Generate covariance matrix from oriented bounding boxes.
181
181
 
182
182
  Args:
183
183
  boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
@@ -199,20 +199,18 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
199
199
  """
200
200
  Calculate probabilistic IoU between oriented bounding boxes.
201
201
 
202
- Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
203
-
204
202
  Args:
205
203
  obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
206
204
  obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
207
- CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
208
- eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
205
+ CIoU (bool, optional): If True, calculate CIoU.
206
+ eps (float, optional): Small value to avoid division by zero.
209
207
 
210
208
  Returns:
211
209
  (torch.Tensor): OBB similarities, shape (N,).
212
210
 
213
- Note:
214
- OBB format: [center_x, center_y, width, height, rotation_angle].
215
- If CIoU is True, returns CIoU instead of IoU.
211
+ Notes:
212
+ - OBB format: [center_x, center_y, width, height, rotation_angle].
213
+ - Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
216
214
  """
217
215
  x1, y1 = obb1[..., :2].split(1, dim=-1)
218
216
  x2, y2 = obb2[..., :2].split(1, dim=-1)
@@ -243,15 +241,18 @@ def probiou(obb1, obb2, CIoU=False, eps=1e-7):
243
241
 
244
242
  def batch_probiou(obb1, obb2, eps=1e-7):
245
243
  """
246
- Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
244
+ Calculate the probabilistic IoU between oriented bounding boxes.
247
245
 
248
246
  Args:
249
247
  obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
250
248
  obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
251
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
249
+ eps (float, optional): A small value to avoid division by zero.
252
250
 
253
251
  Returns:
254
252
  (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
253
+
254
+ References:
255
+ https://arxiv.org/pdf/2106.06072v1.pdf
255
256
  """
256
257
  obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
257
258
  obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
@@ -277,16 +278,16 @@ def batch_probiou(obb1, obb2, eps=1e-7):
277
278
 
278
279
  def smooth_bce(eps=0.1):
279
280
  """
280
- Computes smoothed positive and negative Binary Cross-Entropy targets.
281
-
282
- This function calculates positive and negative label smoothing BCE targets based on a given epsilon value.
283
- For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441.
281
+ Compute smoothed positive and negative Binary Cross-Entropy targets.
284
282
 
285
283
  Args:
286
- eps (float, optional): The epsilon value for label smoothing. Defaults to 0.1.
284
+ eps (float, optional): The epsilon value for label smoothing.
287
285
 
288
286
  Returns:
289
287
  (tuple): A tuple containing the positive and negative label smoothing BCE targets.
288
+
289
+ References:
290
+ https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
290
291
  """
291
292
  return 1.0 - 0.5 * eps, 0.5 * eps
292
293
 
@@ -304,7 +305,15 @@ class ConfusionMatrix:
304
305
  """
305
306
 
306
307
  def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
307
- """Initialize attributes for the YOLO model."""
308
+ """
309
+ Initialize a ConfusionMatrix instance.
310
+
311
+ Args:
312
+ nc (int): Number of classes.
313
+ conf (float, optional): Confidence threshold for detections.
314
+ iou_thres (float, optional): IoU threshold for matching detections to ground truth.
315
+ task (str, optional): Type of task, either 'detect' or 'classify'.
316
+ """
308
317
  self.task = task
309
318
  self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
310
319
  self.nc = nc # number of classes
@@ -382,11 +391,16 @@ class ConfusionMatrix:
382
391
  self.matrix[dc, self.nc] += 1 # predicted background
383
392
 
384
393
  def matrix(self):
385
- """Returns the confusion matrix."""
394
+ """Return the confusion matrix."""
386
395
  return self.matrix
387
396
 
388
397
  def tp_fp(self):
389
- """Returns true positives and false positives."""
398
+ """
399
+ Return true positives and false positives.
400
+
401
+ Returns:
402
+ (tuple): True positives and false positives.
403
+ """
390
404
  tp = self.matrix.diagonal() # true positives
391
405
  fp = self.matrix.sum(1) - tp # false positives
392
406
  # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
@@ -454,7 +468,17 @@ def smooth(y, f=0.05):
454
468
 
455
469
  @plt_settings()
456
470
  def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
457
- """Plots a precision-recall curve."""
471
+ """
472
+ Plot precision-recall curve.
473
+
474
+ Args:
475
+ px (np.ndarray): X values for the PR curve.
476
+ py (np.ndarray): Y values for the PR curve.
477
+ ap (np.ndarray): Average precision values.
478
+ save_dir (Path, optional): Path to save the plot.
479
+ names (dict, optional): Dictionary mapping class indices to class names.
480
+ on_plot (callable, optional): Function to call after plot is saved.
481
+ """
458
482
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
459
483
  py = np.stack(py, axis=1)
460
484
 
@@ -479,7 +503,18 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=N
479
503
 
480
504
  @plt_settings()
481
505
  def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
482
- """Plots a metric-confidence curve."""
506
+ """
507
+ Plot metric-confidence curve.
508
+
509
+ Args:
510
+ px (np.ndarray): X values for the metric-confidence curve.
511
+ py (np.ndarray): Y values for the metric-confidence curve.
512
+ save_dir (Path, optional): Path to save the plot.
513
+ names (dict, optional): Dictionary mapping class indices to class names.
514
+ xlabel (str, optional): X-axis label.
515
+ ylabel (str, optional): Y-axis label.
516
+ on_plot (callable, optional): Function to call after plot is saved.
517
+ """
483
518
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
484
519
 
485
520
  if 0 < len(names) < 21: # display per-class legend if < 21 classes
@@ -538,33 +573,33 @@ def ap_per_class(
538
573
  tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
539
574
  ):
540
575
  """
541
- Computes the average precision per class for object detection evaluation.
576
+ Compute the average precision per class for object detection evaluation.
542
577
 
543
578
  Args:
544
579
  tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
545
580
  conf (np.ndarray): Array of confidence scores of the detections.
546
581
  pred_cls (np.ndarray): Array of predicted classes of the detections.
547
582
  target_cls (np.ndarray): Array of true classes of the detections.
548
- plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
549
- on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
550
- save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
551
- names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty tuple.
552
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
553
- prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
583
+ plot (bool, optional): Whether to plot PR curves or not.
584
+ on_plot (func, optional): A callback to pass plots path and data when they are rendered.
585
+ save_dir (Path, optional): Directory to save the PR curves.
586
+ names (dict, optional): Dict of class names to plot PR curves.
587
+ eps (float, optional): A small value to avoid division by zero.
588
+ prefix (str, optional): A prefix string for saving the plot files.
554
589
 
555
590
  Returns:
556
- tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
557
- fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
558
- p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
559
- r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
560
- f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
561
- ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
562
- unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
563
- p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
564
- r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
565
- f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
566
- x (np.ndarray): X-axis values for the curves. Shape: (1000,).
567
- prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
591
+ tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.
592
+ fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.
593
+ p (np.ndarray): Precision values at threshold given by max F1 metric for each class.
594
+ r (np.ndarray): Recall values at threshold given by max F1 metric for each class.
595
+ f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.
596
+ ap (np.ndarray): Average precision for each class at different IoU thresholds.
597
+ unique_classes (np.ndarray): An array of unique classes that have data.
598
+ p_curve (np.ndarray): Precision curves for each class.
599
+ r_curve (np.ndarray): Recall curves for each class.
600
+ f1_curve (np.ndarray): F1-score curves for each class.
601
+ x (np.ndarray): X-axis values for the curves.
602
+ prec_values (np.ndarray): Precision values at mAP@0.5 for each class.
568
603
  """
569
604
  # Sort by objectness
570
605
  i = np.argsort(-conf)
@@ -651,7 +686,7 @@ class Metric(SimpleClass):
651
686
  """
652
687
 
653
688
  def __init__(self) -> None:
654
- """Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model."""
689
+ """Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model."""
655
690
  self.p = [] # (nc, )
656
691
  self.r = [] # (nc, )
657
692
  self.f1 = [] # (nc, )
@@ -662,7 +697,7 @@ class Metric(SimpleClass):
662
697
  @property
663
698
  def ap50(self):
664
699
  """
665
- Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
700
+ Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
666
701
 
667
702
  Returns:
668
703
  (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
@@ -672,7 +707,7 @@ class Metric(SimpleClass):
672
707
  @property
673
708
  def ap(self):
674
709
  """
675
- Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
710
+ Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
676
711
 
677
712
  Returns:
678
713
  (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
@@ -682,7 +717,7 @@ class Metric(SimpleClass):
682
717
  @property
683
718
  def mp(self):
684
719
  """
685
- Returns the Mean Precision of all classes.
720
+ Return the Mean Precision of all classes.
686
721
 
687
722
  Returns:
688
723
  (float): The mean precision of all classes.
@@ -692,7 +727,7 @@ class Metric(SimpleClass):
692
727
  @property
693
728
  def mr(self):
694
729
  """
695
- Returns the Mean Recall of all classes.
730
+ Return the Mean Recall of all classes.
696
731
 
697
732
  Returns:
698
733
  (float): The mean recall of all classes.
@@ -702,7 +737,7 @@ class Metric(SimpleClass):
702
737
  @property
703
738
  def map50(self):
704
739
  """
705
- Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.
740
+ Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
706
741
 
707
742
  Returns:
708
743
  (float): The mAP at an IoU threshold of 0.5.
@@ -712,7 +747,7 @@ class Metric(SimpleClass):
712
747
  @property
713
748
  def map75(self):
714
749
  """
715
- Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.
750
+ Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
716
751
 
717
752
  Returns:
718
753
  (float): The mAP at an IoU threshold of 0.75.
@@ -722,7 +757,7 @@ class Metric(SimpleClass):
722
757
  @property
723
758
  def map(self):
724
759
  """
725
- Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
760
+ Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
726
761
 
727
762
  Returns:
728
763
  (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
@@ -730,41 +765,42 @@ class Metric(SimpleClass):
730
765
  return self.all_ap.mean() if len(self.all_ap) else 0.0
731
766
 
732
767
  def mean_results(self):
733
- """Mean of results, return mp, mr, map50, map."""
768
+ """Return mean of results, mp, mr, map50, map."""
734
769
  return [self.mp, self.mr, self.map50, self.map]
735
770
 
736
771
  def class_result(self, i):
737
- """Class-aware result, return p[i], r[i], ap50[i], ap[i]."""
772
+ """Return class-aware result, p[i], r[i], ap50[i], ap[i]."""
738
773
  return self.p[i], self.r[i], self.ap50[i], self.ap[i]
739
774
 
740
775
  @property
741
776
  def maps(self):
742
- """MAP of each class."""
777
+ """Return mAP of each class."""
743
778
  maps = np.zeros(self.nc) + self.map
744
779
  for i, c in enumerate(self.ap_class_index):
745
780
  maps[c] = self.ap[i]
746
781
  return maps
747
782
 
748
783
  def fitness(self):
749
- """Model fitness as a weighted combination of metrics."""
784
+ """Return model fitness as a weighted combination of metrics."""
750
785
  w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
751
786
  return (np.array(self.mean_results()) * w).sum()
752
787
 
753
788
  def update(self, results):
754
789
  """
755
- Updates the evaluation metrics of the model with a new set of results.
790
+ Update the evaluation metrics with a new set of results.
756
791
 
757
792
  Args:
758
- results (tuple): A tuple containing the following evaluation metrics:
759
- - p (list): Precision for each class. Shape: (nc,).
760
- - r (list): Recall for each class. Shape: (nc,).
761
- - f1 (list): F1 score for each class. Shape: (nc,).
762
- - all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
763
- - ap_class_index (list): Index of class for each AP score. Shape: (nc,).
764
-
765
- Side Effects:
766
- Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based
767
- on the values provided in the `results` tuple.
793
+ results (tuple): A tuple containing evaluation metrics:
794
+ - p (list): Precision for each class.
795
+ - r (list): Recall for each class.
796
+ - f1 (list): F1 score for each class.
797
+ - all_ap (list): AP scores for all classes and all IoU thresholds.
798
+ - ap_class_index (list): Index of class for each AP score.
799
+ - p_curve (list): Precision curve for each class.
800
+ - r_curve (list): Recall curve for each class.
801
+ - f1_curve (list): F1 curve for each class.
802
+ - px (list): X values for the curves.
803
+ - prec_values (list): Precision values for each class.
768
804
  """
769
805
  (
770
806
  self.p,
@@ -781,12 +817,12 @@ class Metric(SimpleClass):
781
817
 
782
818
  @property
783
819
  def curves(self):
784
- """Returns a list of curves for accessing specific metrics curves."""
820
+ """Return a list of curves for accessing specific metrics curves."""
785
821
  return []
786
822
 
787
823
  @property
788
824
  def curves_results(self):
789
- """Returns a list of curves for accessing specific metrics curves."""
825
+ """Return a list of curves for accessing specific metrics curves."""
790
826
  return [
791
827
  [self.px, self.prec_values, "Recall", "Precision"],
792
828
  [self.px, self.f1_curve, "Confidence", "F1"],
@@ -797,36 +833,26 @@ class Metric(SimpleClass):
797
833
 
798
834
  class DetMetrics(SimpleClass):
799
835
  """
800
- Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
801
- object detection model.
802
-
803
- Args:
804
- save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
805
- plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
806
- names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.
836
+ Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
807
837
 
808
838
  Attributes:
809
839
  save_dir (Path): A path to the directory where the output plots will be saved.
810
- plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
811
- names (dict of str): A dict of strings that represents the names of the classes.
812
- box (Metric): An instance of the Metric class for storing the results of the detection metrics.
813
- speed (dict): A dictionary for storing the execution time of different parts of the detection process.
814
-
815
- Methods:
816
- process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
817
- keys: Returns a list of keys for accessing the computed detection metrics.
818
- mean_results: Returns a list of mean values for the computed detection metrics.
819
- class_result(i): Returns a list of values for the computed detection metrics for a specific class.
820
- maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
821
- fitness: Computes the fitness score based on the computed detection metrics.
822
- ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
823
- results_dict: Returns a dictionary that maps detection metric keys to their computed values.
824
- curves: TODO
825
- curves_results: TODO
840
+ plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
841
+ names (dict): A dictionary of class names.
842
+ box (Metric): An instance of the Metric class for storing detection results.
843
+ speed (dict): A dictionary for storing execution times of different parts of the detection process.
844
+ task (str): The task type, set to 'detect'.
826
845
  """
827
846
 
828
847
  def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
829
- """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
848
+ """
849
+ Initialize a DetMetrics instance with a save directory, plot flag, and class names.
850
+
851
+ Args:
852
+ save_dir (Path, optional): Directory to save plots.
853
+ plot (bool, optional): Whether to plot precision-recall curves.
854
+ names (dict, optional): Dictionary mapping class indices to names.
855
+ """
830
856
  self.save_dir = save_dir
831
857
  self.plot = plot
832
858
  self.names = names
@@ -835,7 +861,16 @@ class DetMetrics(SimpleClass):
835
861
  self.task = "detect"
836
862
 
837
863
  def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
838
- """Process predicted results for object detection and update metrics."""
864
+ """
865
+ Process predicted results for object detection and update metrics.
866
+
867
+ Args:
868
+ tp (np.ndarray): True positive array.
869
+ conf (np.ndarray): Confidence array.
870
+ pred_cls (np.ndarray): Predicted class indices array.
871
+ target_cls (np.ndarray): Target class indices array.
872
+ on_plot (callable, optional): Function to call after plots are generated.
873
+ """
839
874
  results = ap_per_class(
840
875
  tp,
841
876
  conf,
@@ -851,7 +886,7 @@ class DetMetrics(SimpleClass):
851
886
 
852
887
  @property
853
888
  def keys(self):
854
- """Returns a list of keys for accessing specific metrics."""
889
+ """Return a list of keys for accessing specific metrics."""
855
890
  return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
856
891
 
857
892
  def mean_results(self):
@@ -864,32 +899,32 @@ class DetMetrics(SimpleClass):
864
899
 
865
900
  @property
866
901
  def maps(self):
867
- """Returns mean Average Precision (mAP) scores per class."""
902
+ """Return mean Average Precision (mAP) scores per class."""
868
903
  return self.box.maps
869
904
 
870
905
  @property
871
906
  def fitness(self):
872
- """Returns the fitness of box object."""
907
+ """Return the fitness of box object."""
873
908
  return self.box.fitness()
874
909
 
875
910
  @property
876
911
  def ap_class_index(self):
877
- """Returns the average precision index per class."""
912
+ """Return the average precision index per class."""
878
913
  return self.box.ap_class_index
879
914
 
880
915
  @property
881
916
  def results_dict(self):
882
- """Returns dictionary of computed performance metrics and statistics."""
917
+ """Return dictionary of computed performance metrics and statistics."""
883
918
  return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
884
919
 
885
920
  @property
886
921
  def curves(self):
887
- """Returns a list of curves for accessing specific metrics curves."""
922
+ """Return a list of curves for accessing specific metrics curves."""
888
923
  return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
889
924
 
890
925
  @property
891
926
  def curves_results(self):
892
- """Returns dictionary of computed performance metrics and statistics."""
927
+ """Return dictionary of computed performance metrics and statistics."""
893
928
  return self.box.curves_results
894
929
 
895
930
 
@@ -897,31 +932,25 @@ class SegmentMetrics(SimpleClass):
897
932
  """
898
933
  Calculates and aggregates detection and segmentation metrics over a given set of classes.
899
934
 
900
- Args:
901
- save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
902
- plot (bool): Whether to save the detection and segmentation plots. Default is False.
903
- names (list): List of class names. Default is an empty list.
904
-
905
935
  Attributes:
906
936
  save_dir (Path): Path to the directory where the output plots should be saved.
907
937
  plot (bool): Whether to save the detection and segmentation plots.
908
- names (list): List of class names.
938
+ names (dict): Dictionary of class names.
909
939
  box (Metric): An instance of the Metric class to calculate box detection metrics.
910
940
  seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
911
941
  speed (dict): Dictionary to store the time taken in different phases of inference.
912
-
913
- Methods:
914
- process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
915
- mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
916
- class_result(i): Returns the detection and segmentation metrics of class `i`.
917
- maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
918
- fitness: Returns the fitness scores, which are a single weighted combination of metrics.
919
- ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
920
- results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
942
+ task (str): The task type, set to 'segment'.
921
943
  """
922
944
 
923
945
  def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
924
- """Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names."""
946
+ """
947
+ Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
948
+
949
+ Args:
950
+ save_dir (Path, optional): Directory to save plots.
951
+ plot (bool, optional): Whether to plot precision-recall curves.
952
+ names (dict, optional): Dictionary mapping class indices to names.
953
+ """
925
954
  self.save_dir = save_dir
926
955
  self.plot = plot
927
956
  self.names = names
@@ -932,15 +961,15 @@ class SegmentMetrics(SimpleClass):
932
961
 
933
962
  def process(self, tp, tp_m, conf, pred_cls, target_cls, on_plot=None):
934
963
  """
935
- Processes the detection and segmentation metrics over the given set of predictions.
964
+ Process the detection and segmentation metrics over the given set of predictions.
936
965
 
937
966
  Args:
938
- tp (list): List of True Positive boxes.
939
- tp_m (list): List of True Positive masks.
940
- conf (list): List of confidence scores.
941
- pred_cls (list): List of predicted classes.
942
- target_cls (list): List of target classes.
943
- on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
967
+ tp (np.ndarray): True positive array for boxes.
968
+ tp_m (np.ndarray): True positive array for masks.
969
+ conf (np.ndarray): Confidence array.
970
+ pred_cls (np.ndarray): Predicted class indices array.
971
+ target_cls (np.ndarray): Target class indices array.
972
+ on_plot (callable, optional): Function to call after plots are generated.
944
973
  """
945
974
  results_mask = ap_per_class(
946
975
  tp_m,
@@ -971,7 +1000,7 @@ class SegmentMetrics(SimpleClass):
971
1000
 
972
1001
  @property
973
1002
  def keys(self):
974
- """Returns a list of keys for accessing metrics."""
1003
+ """Return a list of keys for accessing metrics."""
975
1004
  return [
976
1005
  "metrics/precision(B)",
977
1006
  "metrics/recall(B)",
@@ -988,32 +1017,36 @@ class SegmentMetrics(SimpleClass):
988
1017
  return self.box.mean_results() + self.seg.mean_results()
989
1018
 
990
1019
  def class_result(self, i):
991
- """Returns classification results for a specified class index."""
1020
+ """Return classification results for a specified class index."""
992
1021
  return self.box.class_result(i) + self.seg.class_result(i)
993
1022
 
994
1023
  @property
995
1024
  def maps(self):
996
- """Returns mAP scores for object detection and semantic segmentation models."""
1025
+ """Return mAP scores for object detection and semantic segmentation models."""
997
1026
  return self.box.maps + self.seg.maps
998
1027
 
999
1028
  @property
1000
1029
  def fitness(self):
1001
- """Get the fitness score for both segmentation and bounding box models."""
1030
+ """Return the fitness score for both segmentation and bounding box models."""
1002
1031
  return self.seg.fitness() + self.box.fitness()
1003
1032
 
1004
1033
  @property
1005
1034
  def ap_class_index(self):
1006
- """Boxes and masks have the same ap_class_index."""
1035
+ """
1036
+ Return the class indices.
1037
+
1038
+ Boxes and masks have the same ap_class_index.
1039
+ """
1007
1040
  return self.box.ap_class_index
1008
1041
 
1009
1042
  @property
1010
1043
  def results_dict(self):
1011
- """Returns results of object detection model for evaluation."""
1044
+ """Return results of object detection model for evaluation."""
1012
1045
  return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
1013
1046
 
1014
1047
  @property
1015
1048
  def curves(self):
1016
- """Returns a list of curves for accessing specific metrics curves."""
1049
+ """Return a list of curves for accessing specific metrics curves."""
1017
1050
  return [
1018
1051
  "Precision-Recall(B)",
1019
1052
  "F1-Confidence(B)",
@@ -1027,7 +1060,7 @@ class SegmentMetrics(SimpleClass):
1027
1060
 
1028
1061
  @property
1029
1062
  def curves_results(self):
1030
- """Returns dictionary of computed performance metrics and statistics."""
1063
+ """Return dictionary of computed performance metrics and statistics."""
1031
1064
  return self.box.curves_results + self.seg.curves_results
1032
1065
 
1033
1066
 
@@ -1035,18 +1068,14 @@ class PoseMetrics(SegmentMetrics):
1035
1068
  """
1036
1069
  Calculates and aggregates detection and pose metrics over a given set of classes.
1037
1070
 
1038
- Args:
1039
- save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
1040
- plot (bool): Whether to save the detection and segmentation plots. Default is False.
1041
- names (list): List of class names. Default is an empty list.
1042
-
1043
1071
  Attributes:
1044
1072
  save_dir (Path): Path to the directory where the output plots should be saved.
1045
- plot (bool): Whether to save the detection and segmentation plots.
1046
- names (list): List of class names.
1073
+ plot (bool): Whether to save the detection and pose plots.
1074
+ names (dict): Dictionary of class names.
1047
1075
  box (Metric): An instance of the Metric class to calculate box detection metrics.
1048
- pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.
1076
+ pose (Metric): An instance of the Metric class to calculate pose metrics.
1049
1077
  speed (dict): Dictionary to store the time taken in different phases of inference.
1078
+ task (str): The task type, set to 'pose'.
1050
1079
 
1051
1080
  Methods:
1052
1081
  process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
@@ -1059,7 +1088,14 @@ class PoseMetrics(SegmentMetrics):
1059
1088
  """
1060
1089
 
1061
1090
  def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
1062
- """Initialize the PoseMetrics class with directory path, class names, and plotting options."""
1091
+ """
1092
+ Initialize the PoseMetrics class with directory path, class names, and plotting options.
1093
+
1094
+ Args:
1095
+ save_dir (Path, optional): Directory to save plots.
1096
+ plot (bool, optional): Whether to plot precision-recall curves.
1097
+ names (dict, optional): Dictionary mapping class indices to names.
1098
+ """
1063
1099
  super().__init__(save_dir, plot, names)
1064
1100
  self.save_dir = save_dir
1065
1101
  self.plot = plot
@@ -1071,15 +1107,15 @@ class PoseMetrics(SegmentMetrics):
1071
1107
 
1072
1108
  def process(self, tp, tp_p, conf, pred_cls, target_cls, on_plot=None):
1073
1109
  """
1074
- Processes the detection and pose metrics over the given set of predictions.
1110
+ Process the detection and pose metrics over the given set of predictions.
1075
1111
 
1076
1112
  Args:
1077
- tp (list): List of True Positive boxes.
1078
- tp_p (list): List of True Positive keypoints.
1079
- conf (list): List of confidence scores.
1080
- pred_cls (list): List of predicted classes.
1081
- target_cls (list): List of target classes.
1082
- on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
1113
+ tp (np.ndarray): True positive array for boxes.
1114
+ tp_p (np.ndarray): True positive array for keypoints.
1115
+ conf (np.ndarray): Confidence array.
1116
+ pred_cls (np.ndarray): Predicted class indices array.
1117
+ target_cls (np.ndarray): Target class indices array.
1118
+ on_plot (callable, optional): Function to call after plots are generated.
1083
1119
  """
1084
1120
  results_pose = ap_per_class(
1085
1121
  tp_p,
@@ -1110,7 +1146,7 @@ class PoseMetrics(SegmentMetrics):
1110
1146
 
1111
1147
  @property
1112
1148
  def keys(self):
1113
- """Returns list of evaluation metric keys."""
1149
+ """Return list of evaluation metric keys."""
1114
1150
  return [
1115
1151
  "metrics/precision(B)",
1116
1152
  "metrics/recall(B)",
@@ -1132,17 +1168,17 @@ class PoseMetrics(SegmentMetrics):
1132
1168
 
1133
1169
  @property
1134
1170
  def maps(self):
1135
- """Returns the mean average precision (mAP) per class for both box and pose detections."""
1171
+ """Return the mean average precision (mAP) per class for both box and pose detections."""
1136
1172
  return self.box.maps + self.pose.maps
1137
1173
 
1138
1174
  @property
1139
1175
  def fitness(self):
1140
- """Computes classification metrics and speed using the `targets` and `pred` inputs."""
1176
+ """Return combined fitness score for pose and box detection."""
1141
1177
  return self.pose.fitness() + self.box.fitness()
1142
1178
 
1143
1179
  @property
1144
1180
  def curves(self):
1145
- """Returns a list of curves for accessing specific metrics curves."""
1181
+ """Return a list of curves for accessing specific metrics curves."""
1146
1182
  return [
1147
1183
  "Precision-Recall(B)",
1148
1184
  "F1-Confidence(B)",
@@ -1156,7 +1192,7 @@ class PoseMetrics(SegmentMetrics):
1156
1192
 
1157
1193
  @property
1158
1194
  def curves_results(self):
1159
- """Returns dictionary of computed performance metrics and statistics."""
1195
+ """Return dictionary of computed performance metrics and statistics."""
1160
1196
  return self.box.curves_results + self.pose.curves_results
1161
1197
 
1162
1198
 
@@ -1167,13 +1203,8 @@ class ClassifyMetrics(SimpleClass):
1167
1203
  Attributes:
1168
1204
  top1 (float): The top-1 accuracy.
1169
1205
  top5 (float): The top-5 accuracy.
1170
- speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
1171
- fitness (float): The fitness of the model, which is equal to top-5 accuracy.
1172
- results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
1173
- keys (List[str]): A list of keys for the results_dict.
1174
-
1175
- Methods:
1176
- process(targets, pred): Processes the targets and predictions to compute classification metrics.
1206
+ speed (dict): A dictionary containing the time taken for each step in the pipeline.
1207
+ task (str): The task type, set to 'classify'.
1177
1208
  """
1178
1209
 
1179
1210
  def __init__(self) -> None:
@@ -1184,7 +1215,13 @@ class ClassifyMetrics(SimpleClass):
1184
1215
  self.task = "classify"
1185
1216
 
1186
1217
  def process(self, targets, pred):
1187
- """Target classes and predicted classes."""
1218
+ """
1219
+ Process target classes and predicted classes to compute metrics.
1220
+
1221
+ Args:
1222
+ targets (torch.Tensor): Target classes.
1223
+ pred (torch.Tensor): Predicted classes.
1224
+ """
1188
1225
  pred, targets = torch.cat(pred), torch.cat(targets)
1189
1226
  correct = (targets[:, None] == pred).float()
1190
1227
  acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
@@ -1192,35 +1229,54 @@ class ClassifyMetrics(SimpleClass):
1192
1229
 
1193
1230
  @property
1194
1231
  def fitness(self):
1195
- """Returns mean of top-1 and top-5 accuracies as fitness score."""
1232
+ """Return mean of top-1 and top-5 accuracies as fitness score."""
1196
1233
  return (self.top1 + self.top5) / 2
1197
1234
 
1198
1235
  @property
1199
1236
  def results_dict(self):
1200
- """Returns a dictionary with model's performance metrics and fitness score."""
1237
+ """Return a dictionary with model's performance metrics and fitness score."""
1201
1238
  return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
1202
1239
 
1203
1240
  @property
1204
1241
  def keys(self):
1205
- """Returns a list of keys for the results_dict property."""
1242
+ """Return a list of keys for the results_dict property."""
1206
1243
  return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
1207
1244
 
1208
1245
  @property
1209
1246
  def curves(self):
1210
- """Returns a list of curves for accessing specific metrics curves."""
1247
+ """Return a list of curves for accessing specific metrics curves."""
1211
1248
  return []
1212
1249
 
1213
1250
  @property
1214
1251
  def curves_results(self):
1215
- """Returns a list of curves for accessing specific metrics curves."""
1252
+ """Return a list of curves for accessing specific metrics curves."""
1216
1253
  return []
1217
1254
 
1218
1255
 
1219
1256
  class OBBMetrics(SimpleClass):
1220
- """Metrics for evaluating oriented bounding box (OBB) detection, see https://arxiv.org/pdf/2106.06072.pdf."""
1257
+ """
1258
+ Metrics for evaluating oriented bounding box (OBB) detection.
1259
+
1260
+ Attributes:
1261
+ save_dir (Path): Path to the directory where the output plots should be saved.
1262
+ plot (bool): Whether to save the detection plots.
1263
+ names (dict): Dictionary of class names.
1264
+ box (Metric): An instance of the Metric class for storing detection results.
1265
+ speed (dict): A dictionary for storing execution times of different parts of the detection process.
1266
+
1267
+ References:
1268
+ https://arxiv.org/pdf/2106.06072.pdf
1269
+ """
1221
1270
 
1222
1271
  def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
1223
- """Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
1272
+ """
1273
+ Initialize an OBBMetrics instance with directory, plotting, and class names.
1274
+
1275
+ Args:
1276
+ save_dir (Path, optional): Directory to save plots.
1277
+ plot (bool, optional): Whether to plot precision-recall curves.
1278
+ names (dict, optional): Dictionary mapping class indices to names.
1279
+ """
1224
1280
  self.save_dir = save_dir
1225
1281
  self.plot = plot
1226
1282
  self.names = names
@@ -1228,7 +1284,16 @@ class OBBMetrics(SimpleClass):
1228
1284
  self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
1229
1285
 
1230
1286
  def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
1231
- """Process predicted results for object detection and update metrics."""
1287
+ """
1288
+ Process predicted results for object detection and update metrics.
1289
+
1290
+ Args:
1291
+ tp (np.ndarray): True positive array.
1292
+ conf (np.ndarray): Confidence array.
1293
+ pred_cls (np.ndarray): Predicted class indices array.
1294
+ target_cls (np.ndarray): Target class indices array.
1295
+ on_plot (callable, optional): Function to call after plots are generated.
1296
+ """
1232
1297
  results = ap_per_class(
1233
1298
  tp,
1234
1299
  conf,
@@ -1244,7 +1309,7 @@ class OBBMetrics(SimpleClass):
1244
1309
 
1245
1310
  @property
1246
1311
  def keys(self):
1247
- """Returns a list of keys for accessing specific metrics."""
1312
+ """Return a list of keys for accessing specific metrics."""
1248
1313
  return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
1249
1314
 
1250
1315
  def mean_results(self):
@@ -1257,30 +1322,30 @@ class OBBMetrics(SimpleClass):
1257
1322
 
1258
1323
  @property
1259
1324
  def maps(self):
1260
- """Returns mean Average Precision (mAP) scores per class."""
1325
+ """Return mean Average Precision (mAP) scores per class."""
1261
1326
  return self.box.maps
1262
1327
 
1263
1328
  @property
1264
1329
  def fitness(self):
1265
- """Returns the fitness of box object."""
1330
+ """Return the fitness of box object."""
1266
1331
  return self.box.fitness()
1267
1332
 
1268
1333
  @property
1269
1334
  def ap_class_index(self):
1270
- """Returns the average precision index per class."""
1335
+ """Return the average precision index per class."""
1271
1336
  return self.box.ap_class_index
1272
1337
 
1273
1338
  @property
1274
1339
  def results_dict(self):
1275
- """Returns dictionary of computed performance metrics and statistics."""
1340
+ """Return dictionary of computed performance metrics and statistics."""
1276
1341
  return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
1277
1342
 
1278
1343
  @property
1279
1344
  def curves(self):
1280
- """Returns a list of curves for accessing specific metrics curves."""
1345
+ """Return a list of curves for accessing specific metrics curves."""
1281
1346
  return []
1282
1347
 
1283
1348
  @property
1284
1349
  def curves_results(self):
1285
- """Returns a list of curves for accessing specific metrics curves."""
1350
+ """Return a list of curves for accessing specific metrics curves."""
1286
1351
  return []