ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py CHANGED
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import torch
4
4
  import torch.nn as nn
@@ -7,6 +7,8 @@ import torch.nn.functional as F
7
7
  from ultralytics.utils.metrics import OKS_SIGMA
8
8
  from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
9
9
  from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
10
+ from ultralytics.utils.torch_utils import autocast
11
+
10
12
  from .metrics import bbox_iou, probiou
11
13
  from .tal import bbox2dist
12
14
 
@@ -26,7 +28,7 @@ class VarifocalLoss(nn.Module):
26
28
  def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
27
29
  """Computes varfocal loss."""
28
30
  weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
29
- with torch.cuda.amp.autocast(enabled=False):
31
+ with autocast(enabled=False):
30
32
  loss = (
31
33
  (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
32
34
  .mean(1)
@@ -60,39 +62,22 @@ class FocalLoss(nn.Module):
60
62
  return loss.mean(1).sum()
61
63
 
62
64
 
63
- class BboxLoss(nn.Module):
64
- """Criterion class for computing training losses during training."""
65
+ class DFLoss(nn.Module):
66
+ """Criterion class for computing DFL losses during training."""
65
67
 
66
- def __init__(self, reg_max, use_dfl=False):
67
- """Initialize the BboxLoss module with regularization maximum and DFL settings."""
68
+ def __init__(self, reg_max=16) -> None:
69
+ """Initialize the DFL module."""
68
70
  super().__init__()
69
71
  self.reg_max = reg_max
70
- self.use_dfl = use_dfl
71
-
72
- def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
73
- """IoU loss."""
74
- weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
75
- iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
76
- loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
77
72
 
78
- # DFL loss
79
- if self.use_dfl:
80
- target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
81
- loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
82
- loss_dfl = loss_dfl.sum() / target_scores_sum
83
- else:
84
- loss_dfl = torch.tensor(0.0).to(pred_dist.device)
85
-
86
- return loss_iou, loss_dfl
87
-
88
- @staticmethod
89
- def _df_loss(pred_dist, target):
73
+ def __call__(self, pred_dist, target):
90
74
  """
91
75
  Return sum of left and right DFL losses.
92
76
 
93
77
  Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
94
78
  https://ieeexplore.ieee.org/document/9792391
95
79
  """
80
+ target = target.clamp_(0, self.reg_max - 1 - 0.01)
96
81
  tl = target.long() # target left
97
82
  tr = tl + 1 # target right
98
83
  wl = tr - target # weight left
@@ -103,12 +88,37 @@ class BboxLoss(nn.Module):
103
88
  ).mean(-1, keepdim=True)
104
89
 
105
90
 
91
+ class BboxLoss(nn.Module):
92
+ """Criterion class for computing training losses during training."""
93
+
94
+ def __init__(self, reg_max=16):
95
+ """Initialize the BboxLoss module with regularization maximum and DFL settings."""
96
+ super().__init__()
97
+ self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
98
+
99
+ def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
100
+ """IoU loss."""
101
+ weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
102
+ iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
103
+ loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
104
+
105
+ # DFL loss
106
+ if self.dfl_loss:
107
+ target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
108
+ loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
109
+ loss_dfl = loss_dfl.sum() / target_scores_sum
110
+ else:
111
+ loss_dfl = torch.tensor(0.0).to(pred_dist.device)
112
+
113
+ return loss_iou, loss_dfl
114
+
115
+
106
116
  class RotatedBboxLoss(BboxLoss):
107
117
  """Criterion class for computing training losses during training."""
108
118
 
109
- def __init__(self, reg_max, use_dfl=False):
119
+ def __init__(self, reg_max):
110
120
  """Initialize the BboxLoss module with regularization maximum and DFL settings."""
111
- super().__init__(reg_max, use_dfl)
121
+ super().__init__(reg_max)
112
122
 
113
123
  def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
114
124
  """IoU loss."""
@@ -117,9 +127,9 @@ class RotatedBboxLoss(BboxLoss):
117
127
  loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
118
128
 
119
129
  # DFL loss
120
- if self.use_dfl:
121
- target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
122
- loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
130
+ if self.dfl_loss:
131
+ target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
132
+ loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
123
133
  loss_dfl = loss_dfl.sum() / target_scores_sum
124
134
  else:
125
135
  loss_dfl = torch.tensor(0.0).to(pred_dist.device)
@@ -147,7 +157,7 @@ class KeypointLoss(nn.Module):
147
157
  class v8DetectionLoss:
148
158
  """Criterion class for computing training losses."""
149
159
 
150
- def __init__(self, model): # model must be de-paralleled
160
+ def __init__(self, model, tal_topk=10): # model must be de-paralleled
151
161
  """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
152
162
  device = next(model.parameters()).device # get model device
153
163
  h = model.args # hyperparameters
@@ -157,29 +167,29 @@ class v8DetectionLoss:
157
167
  self.hyp = h
158
168
  self.stride = m.stride # model strides
159
169
  self.nc = m.nc # number of classes
160
- self.no = m.no
170
+ self.no = m.nc + m.reg_max * 4
161
171
  self.reg_max = m.reg_max
162
172
  self.device = device
163
173
 
164
174
  self.use_dfl = m.reg_max > 1
165
175
 
166
- self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
167
- self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
176
+ self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
177
+ self.bbox_loss = BboxLoss(m.reg_max).to(device)
168
178
  self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
169
179
 
170
180
  def preprocess(self, targets, batch_size, scale_tensor):
171
181
  """Preprocesses the target counts and matches with the input batch size to output a tensor."""
172
- if targets.shape[0] == 0:
173
- out = torch.zeros(batch_size, 0, 5, device=self.device)
182
+ nl, ne = targets.shape
183
+ if nl == 0:
184
+ out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
174
185
  else:
175
186
  i = targets[:, 0] # image index
176
187
  _, counts = i.unique(return_counts=True)
177
188
  counts = counts.to(dtype=torch.int32)
178
- out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
189
+ out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
179
190
  for j in range(batch_size):
180
191
  matches = i == j
181
- n = matches.sum()
182
- if n:
192
+ if n := matches.sum():
183
193
  out[j, :n] = targets[matches, 1:]
184
194
  out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
185
195
  return out
@@ -213,12 +223,15 @@ class v8DetectionLoss:
213
223
  targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
214
224
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
215
225
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
216
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
226
+ mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
217
227
 
218
228
  # Pboxes
219
229
  pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
230
+ # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
231
+ # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
220
232
 
221
233
  _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
234
+ # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
222
235
  pred_scores.detach().sigmoid(),
223
236
  (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
224
237
  anchor_points * stride_tensor,
@@ -279,7 +292,7 @@ class v8SegmentationLoss(v8DetectionLoss):
279
292
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
280
293
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
281
294
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
282
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
295
+ mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
283
296
  except RuntimeError as e:
284
297
  raise TypeError(
285
298
  "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
@@ -466,7 +479,7 @@ class v8PoseLoss(v8DetectionLoss):
466
479
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
467
480
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
468
481
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
469
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
482
+ mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
470
483
 
471
484
  # Pboxes
472
485
  pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
@@ -538,9 +551,8 @@ class v8PoseLoss(v8DetectionLoss):
538
551
  pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
539
552
 
540
553
  Returns:
541
- (tuple): Returns a tuple containing:
542
- - kpts_loss (torch.Tensor): The keypoints loss.
543
- - kpts_obj_loss (torch.Tensor): The keypoints object loss.
554
+ kpts_loss (torch.Tensor): The keypoints loss.
555
+ kpts_obj_loss (torch.Tensor): The keypoints object loss.
544
556
  """
545
557
  batch_idx = batch_idx.flatten()
546
558
  batch_size = len(masks)
@@ -591,21 +603,20 @@ class v8ClassificationLoss:
591
603
 
592
604
  def __call__(self, preds, batch):
593
605
  """Compute the classification loss between predictions and true labels."""
594
- loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
606
+ preds = preds[1] if isinstance(preds, (list, tuple)) else preds
607
+ loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
595
608
  loss_items = loss.detach()
596
609
  return loss, loss_items
597
610
 
598
611
 
599
612
  class v8OBBLoss(v8DetectionLoss):
600
- def __init__(self, model):
601
- """
602
- Initializes v8OBBLoss with model, assigner, and rotated bbox loss.
613
+ """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
603
614
 
604
- Note model must be de-paralleled.
605
- """
615
+ def __init__(self, model):
616
+ """Initializes v8OBBLoss with model, assigner, and rotated bbox loss; note model must be de-paralleled."""
606
617
  super().__init__(model)
607
618
  self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
608
- self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)
619
+ self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
609
620
 
610
621
  def preprocess(self, targets, batch_size, scale_tensor):
611
622
  """Preprocesses the target counts and matches with the input batch size to output a tensor."""
@@ -618,8 +629,7 @@ class v8OBBLoss(v8DetectionLoss):
618
629
  out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
619
630
  for j in range(batch_size):
620
631
  matches = i == j
621
- n = matches.sum()
622
- if n:
632
+ if n := matches.sum():
623
633
  bboxes = targets[matches, 2:]
624
634
  bboxes[..., :4].mul_(scale_tensor)
625
635
  out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
@@ -651,7 +661,7 @@ class v8OBBLoss(v8DetectionLoss):
651
661
  targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
652
662
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
653
663
  gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
654
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
664
+ mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
655
665
  except RuntimeError as e:
656
666
  raise TypeError(
657
667
  "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
@@ -713,3 +723,21 @@ class v8OBBLoss(v8DetectionLoss):
713
723
  b, a, c = pred_dist.shape # batch, anchors, channels
714
724
  pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
715
725
  return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
726
+
727
+
728
+ class E2EDetectLoss:
729
+ """Criterion class for computing training losses."""
730
+
731
+ def __init__(self, model):
732
+ """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
733
+ self.one2many = v8DetectionLoss(model, tal_topk=10)
734
+ self.one2one = v8DetectionLoss(model, tal_topk=1)
735
+
736
+ def __call__(self, preds, batch):
737
+ """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
738
+ preds = preds[1] if isinstance(preds, tuple) else preds
739
+ one2many = preds["one2many"]
740
+ loss_one2many = self.one2many(one2many, batch)
741
+ one2one = preds["one2one"]
742
+ loss_one2one = self.one2one(one2one, batch)
743
+ return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
ultralytics/utils/metrics.py CHANGED
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
  """Model validation metrics."""
3
3
 
4
4
  import math
@@ -30,7 +30,6 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
30
30
  Returns:
31
31
  (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
32
32
  """
33
-
34
33
  # Get the coordinates of bounding boxes
35
34
  b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
36
35
  b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
@@ -53,7 +52,7 @@ def bbox_ioa(box1, box2, iou=False, eps=1e-7):
53
52
  def box_iou(box1, box2, eps=1e-7):
54
53
  """
55
54
  Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
56
- Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
55
+ Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
57
56
 
58
57
  Args:
59
58
  box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
@@ -63,9 +62,9 @@ def box_iou(box1, box2, eps=1e-7):
63
62
  Returns:
64
63
  (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
65
64
  """
66
-
65
+ # NOTE: Need .float() to get accurate iou values
67
66
  # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
68
- (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
67
+ (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
69
68
  inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
70
69
 
71
70
  # IoU = inter / (area1 + area2 - inter)
@@ -74,11 +73,16 @@ def box_iou(box1, box2, eps=1e-7):
74
73
 
75
74
  def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
76
75
  """
77
- Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
76
+ Calculates the Intersection over Union (IoU) between bounding boxes.
77
+
78
+ This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
79
+ For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
80
+ Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
81
+ or (x1, y1, x2, y2) if `xywh=False`.
78
82
 
79
83
  Args:
80
- box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
81
- box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
84
+ box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
85
+ box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
82
86
  xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
83
87
  (x1, y1, x2, y2) format. Defaults to True.
84
88
  GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
@@ -89,7 +93,6 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
89
93
  Returns:
90
94
  (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
91
95
  """
92
-
93
96
  # Get the coordinates of bounding boxes
94
97
  if xywh: # transform from xywh to xyxy
95
98
  (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
@@ -167,7 +170,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
167
170
  d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)
168
171
  sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
169
172
  kpt_mask = kpt1[..., 2] != 0 # (N, 17)
170
- e = d / (2 * sigma).pow(2) / (area[:, None, None] + eps) / 2 # from cocoeval
173
+ e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval
171
174
  # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula
172
175
  return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
173
176
 
@@ -180,7 +183,7 @@ def _get_covariance_matrix(boxes):
180
183
  boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
181
184
 
182
185
  Returns:
183
- (torch.Tensor): Covariance metrixs corresponding to original rotated bounding boxes.
186
+ (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
184
187
  """
185
188
  # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
186
189
  gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
@@ -194,15 +197,22 @@ def _get_covariance_matrix(boxes):
194
197
 
195
198
  def probiou(obb1, obb2, CIoU=False, eps=1e-7):
196
199
  """
197
- Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
200
+ Calculate probabilistic IoU between oriented bounding boxes.
201
+
202
+ Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
198
203
 
199
204
  Args:
200
- obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
201
- obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
202
- eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
205
+ obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
206
+ obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
207
+ CIoU (bool, optional): If True, calculate CIoU. Defaults to False.
208
+ eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
203
209
 
204
210
  Returns:
205
- (torch.Tensor): A tensor of shape (N, ) representing obb similarities.
211
+ (torch.Tensor): OBB similarities, shape (N,).
212
+
213
+ Note:
214
+ OBB format: [center_x, center_y, width, height, rotation_angle].
215
+ If CIoU is True, returns CIoU instead of IoU.
206
216
  """
207
217
  x1, y1 = obb1[..., :2].split(1, dim=-1)
208
218
  x2, y2 = obb2[..., :2].split(1, dim=-1)
@@ -265,7 +275,7 @@ def batch_probiou(obb1, obb2, eps=1e-7):
265
275
  return 1 - hd
266
276
 
267
277
 
268
- def smooth_BCE(eps=0.1):
278
+ def smooth_bce(eps=0.1):
269
279
  """
270
280
  Computes smoothed positive and negative Binary Cross-Entropy targets.
271
281
 
@@ -298,7 +308,7 @@ class ConfusionMatrix:
298
308
  self.task = task
299
309
  self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
300
310
  self.nc = nc # number of classes
301
- self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
311
+ self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
302
312
  self.iou_thres = iou_thres
303
313
 
304
314
  def process_cls_preds(self, preds, targets):
@@ -367,10 +377,9 @@ class ConfusionMatrix:
367
377
  else:
368
378
  self.matrix[self.nc, gc] += 1 # true background
369
379
 
370
- if n:
371
- for i, dc in enumerate(detection_classes):
372
- if not any(m1 == i):
373
- self.matrix[dc, self.nc] += 1 # predicted background
380
+ for i, dc in enumerate(detection_classes):
381
+ if not any(m1 == i):
382
+ self.matrix[dc, self.nc] += 1 # predicted background
374
383
 
375
384
  def matrix(self):
376
385
  """Returns the confusion matrix."""
@@ -395,19 +404,19 @@ class ConfusionMatrix:
395
404
  names (tuple): Names of classes, used as labels on the plot.
396
405
  on_plot (func): An optional callback to pass plots path and data when they are rendered.
397
406
  """
398
- import seaborn as sn
407
+ import seaborn # scope for faster 'import ultralytics'
399
408
 
400
409
  array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
401
410
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
402
411
 
403
412
  fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
404
413
  nc, nn = self.nc, len(names) # number of classes, names
405
- sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
414
+ seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
406
415
  labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
407
416
  ticklabels = (list(names) + ["background"]) if labels else "auto"
408
417
  with warnings.catch_warnings():
409
418
  warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
410
- sn.heatmap(
419
+ seaborn.heatmap(
411
420
  array,
412
421
  ax=ax,
413
422
  annot=nc < 30,
@@ -423,7 +432,7 @@ class ConfusionMatrix:
423
432
  ax.set_xlabel("True")
424
433
  ax.set_ylabel("Predicted")
425
434
  ax.set_title(title)
426
- plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
435
+ plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png"
427
436
  fig.savefig(plot_fname, dpi=250)
428
437
  plt.close(fig)
429
438
  if on_plot:
@@ -444,7 +453,7 @@ def smooth(y, f=0.05):
444
453
 
445
454
 
446
455
  @plt_settings()
447
- def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=None):
456
+ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
448
457
  """Plots a precision-recall curve."""
449
458
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
450
459
  py = np.stack(py, axis=1)
@@ -455,7 +464,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=N
455
464
  else:
456
465
  ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
457
466
 
458
- ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
467
+ ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
459
468
  ax.set_xlabel("Recall")
460
469
  ax.set_ylabel("Precision")
461
470
  ax.set_xlim(0, 1)
@@ -469,7 +478,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names=(), on_plot=N
469
478
 
470
479
 
471
480
  @plt_settings()
472
- def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names=(), xlabel="Confidence", ylabel="Metric", on_plot=None):
481
+ def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
473
482
  """Plots a metric-confidence curve."""
474
483
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
475
484
 
@@ -506,7 +515,6 @@ def compute_ap(recall, precision):
506
515
  (np.ndarray): Precision envelope curve.
507
516
  (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
508
517
  """
509
-
510
518
  # Append sentinel values to beginning and end
511
519
  mrec = np.concatenate(([0.0], recall, [1.0]))
512
520
  mpre = np.concatenate(([1.0], precision, [0.0]))
@@ -527,7 +535,7 @@ def compute_ap(recall, precision):
527
535
 
528
536
 
529
537
  def ap_per_class(
530
- tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names=(), eps=1e-16, prefix=""
538
+ tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
531
539
  ):
532
540
  """
533
541
  Computes the average precision per class for object detection evaluation.
@@ -540,26 +548,24 @@ def ap_per_class(
540
548
  plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
541
549
  on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
542
550
  save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
543
- names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
551
+ names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty dict.
544
552
  eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
545
553
  prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
546
554
 
547
555
  Returns:
548
- (tuple): A tuple of six arrays and one array of unique classes, where:
549
- tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
550
- fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
551
- p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
552
- r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
553
- f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
554
- ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
555
- unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
556
- p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
557
- r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
558
- f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
559
- x (np.ndarray): X-axis values for the curves. Shape: (1000,).
560
- prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
556
+ tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
557
+ fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
558
+ p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
559
+ r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
560
+ f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
561
+ ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
562
+ unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
563
+ p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
564
+ r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
565
+ f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
566
+ x (np.ndarray): X-axis values for the curves. Shape: (1000,).
567
+ prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
561
568
  """
562
-
563
569
  # Sort by objectness
564
570
  i = np.argsort(-conf)
565
571
  tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
@@ -595,7 +601,7 @@ def ap_per_class(
595
601
  # AP from recall-precision curve
596
602
  for j in range(tp.shape[1]):
597
603
  ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
598
- if plot and j == 0:
604
+ if j == 0:
599
605
  prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5
600
606
 
601
607
  prec_values = np.array(prec_values) # (nc, 1000)
@@ -791,20 +797,20 @@ class Metric(SimpleClass):
791
797
 
792
798
  class DetMetrics(SimpleClass):
793
799
  """
794
- This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
795
- (mAP) of an object detection model.
800
+ Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
801
+ object detection model.
796
802
 
797
803
  Args:
798
804
  save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
799
805
  plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
800
806
  on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
801
- names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
807
+ names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty dict.
802
808
 
803
809
  Attributes:
804
810
  save_dir (Path): A path to the directory where the output plots will be saved.
805
811
  plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
806
812
  on_plot (func): An optional callback to pass plots path and data when they are rendered.
807
- names (tuple of str): A tuple of strings that represents the names of the classes.
813
+ names (dict of str): A dict of strings that represents the names of the classes.
808
814
  box (Metric): An instance of the Metric class for storing the results of the detection metrics.
809
815
  speed (dict): A dictionary for storing the execution time of different parts of the detection process.
810
816
 
@@ -821,7 +827,7 @@ class DetMetrics(SimpleClass):
821
827
  curves_results: TODO
822
828
  """
823
829
 
824
- def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
830
+ def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names={}) -> None:
825
831
  """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names."""
826
832
  self.save_dir = save_dir
827
833
  self.plot = plot
@@ -941,7 +947,6 @@ class SegmentMetrics(SimpleClass):
941
947
  pred_cls (list): List of predicted classes.
942
948
  target_cls (list): List of target classes.
943
949
  """
944
-
945
950
  results_mask = ap_per_class(
946
951
  tp_m,
947
952
  conf,
@@ -1083,7 +1088,6 @@ class PoseMetrics(SegmentMetrics):
1083
1088
  pred_cls (list): List of predicted classes.
1084
1089
  target_cls (list): List of target classes.
1085
1090
  """
1086
-
1087
1091
  results_pose = ap_per_class(
1088
1092
  tp_p,
1089
1093
  conf,
@@ -1171,8 +1175,6 @@ class ClassifyMetrics(SimpleClass):
1171
1175
  top1 (float): The top-1 accuracy.
1172
1176
  top5 (float): The top-5 accuracy.
1173
1177
  speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
1174
-
1175
- Properties:
1176
1178
  fitness (float): The fitness of the model, which is equal to top-5 accuracy.
1177
1179
  results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
1178
1180
  keys (List[str]): A list of keys for the results_dict.
@@ -1222,7 +1224,10 @@ class ClassifyMetrics(SimpleClass):
1222
1224
 
1223
1225
 
1224
1226
  class OBBMetrics(SimpleClass):
1227
+ """Metrics for evaluating oriented bounding box (OBB) detection, see https://arxiv.org/pdf/2106.06072.pdf."""
1228
+
1225
1229
  def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
1230
+ """Initialize an OBBMetrics instance with directory, plotting, callback, and class names."""
1226
1231
  self.save_dir = save_dir
1227
1232
  self.plot = plot
1228
1233
  self.on_plot = on_plot