dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/utils/loss.py CHANGED
@@ -1,5 +1,9 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
3
7
  import torch
4
8
  import torch.nn as nn
5
9
  import torch.nn.functional as F
@@ -14,23 +18,26 @@ from .tal import bbox2dist
14
18
 
15
19
 
16
20
  class VarifocalLoss(nn.Module):
17
- """
18
- Varifocal loss by Zhang et al.
21
+ """Varifocal loss by Zhang et al.
19
22
 
20
- https://arxiv.org/abs/2008.13367.
23
+ Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
24
+ hard-to-classify examples and balancing positive/negative samples.
21
25
 
22
- Args:
26
+ Attributes:
23
27
  gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
24
28
  alpha (float): The balancing factor used to address class imbalance.
29
+
30
+ References:
31
+ https://arxiv.org/abs/2008.13367
25
32
  """
26
33
 
27
- def __init__(self, gamma=2.0, alpha=0.75):
28
- """Initialize the VarifocalLoss class."""
34
+ def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
35
+ """Initialize the VarifocalLoss class with focusing and balancing parameters."""
29
36
  super().__init__()
30
37
  self.gamma = gamma
31
38
  self.alpha = alpha
32
39
 
33
- def forward(self, pred_score, gt_score, label):
40
+ def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
34
41
  """Compute varifocal loss between predictions and ground truth."""
35
42
  weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
36
43
  with autocast(enabled=False):
@@ -43,21 +50,23 @@ class VarifocalLoss(nn.Module):
43
50
 
44
51
 
45
52
  class FocalLoss(nn.Module):
46
- """
47
- Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
53
+ """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
48
54
 
49
- Args:
55
+ Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing on
56
+ hard negatives during training.
57
+
58
+ Attributes:
50
59
  gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
51
- alpha (float | list): The balancing factor used to address class imbalance.
60
+ alpha (torch.Tensor): The balancing factor used to address class imbalance.
52
61
  """
53
62
 
54
- def __init__(self, gamma=1.5, alpha=0.25):
55
- """Initialize FocalLoss class with no parameters."""
63
+ def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
64
+ """Initialize FocalLoss class with focusing and balancing parameters."""
56
65
  super().__init__()
57
66
  self.gamma = gamma
58
67
  self.alpha = torch.tensor(alpha)
59
68
 
60
- def forward(self, pred, label):
69
+ def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
61
70
  """Calculate focal loss with modulating factors for class imbalance."""
62
71
  loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
63
72
  # p_t = torch.exp(-loss)
@@ -78,12 +87,12 @@ class FocalLoss(nn.Module):
78
87
  class DFLoss(nn.Module):
79
88
  """Criterion class for computing Distribution Focal Loss (DFL)."""
80
89
 
81
- def __init__(self, reg_max=16) -> None:
90
+ def __init__(self, reg_max: int = 16) -> None:
82
91
  """Initialize the DFL module with regularization maximum."""
83
92
  super().__init__()
84
93
  self.reg_max = reg_max
85
94
 
86
- def __call__(self, pred_dist, target):
95
+ def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
87
96
  """Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
88
97
  target = target.clamp_(0, self.reg_max - 1 - 0.01)
89
98
  tl = target.long() # target left
@@ -99,12 +108,21 @@ class DFLoss(nn.Module):
99
108
  class BboxLoss(nn.Module):
100
109
  """Criterion class for computing training losses for bounding boxes."""
101
110
 
102
- def __init__(self, reg_max=16):
111
+ def __init__(self, reg_max: int = 16):
103
112
  """Initialize the BboxLoss module with regularization maximum and DFL settings."""
104
113
  super().__init__()
105
114
  self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
106
115
 
107
- def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
116
+ def forward(
117
+ self,
118
+ pred_dist: torch.Tensor,
119
+ pred_bboxes: torch.Tensor,
120
+ anchor_points: torch.Tensor,
121
+ target_bboxes: torch.Tensor,
122
+ target_scores: torch.Tensor,
123
+ target_scores_sum: torch.Tensor,
124
+ fg_mask: torch.Tensor,
125
+ ) -> tuple[torch.Tensor, torch.Tensor]:
108
126
  """Compute IoU and DFL losses for bounding boxes."""
109
127
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
110
128
  iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
@@ -124,11 +142,20 @@ class BboxLoss(nn.Module):
124
142
  class RotatedBboxLoss(BboxLoss):
125
143
  """Criterion class for computing training losses for rotated bounding boxes."""
126
144
 
127
- def __init__(self, reg_max):
128
- """Initialize the BboxLoss module with regularization maximum and DFL settings."""
145
+ def __init__(self, reg_max: int):
146
+ """Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
129
147
  super().__init__(reg_max)
130
148
 
131
- def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
149
+ def forward(
150
+ self,
151
+ pred_dist: torch.Tensor,
152
+ pred_bboxes: torch.Tensor,
153
+ anchor_points: torch.Tensor,
154
+ target_bboxes: torch.Tensor,
155
+ target_scores: torch.Tensor,
156
+ target_scores_sum: torch.Tensor,
157
+ fg_mask: torch.Tensor,
158
+ ) -> tuple[torch.Tensor, torch.Tensor]:
132
159
  """Compute IoU and DFL losses for rotated bounding boxes."""
133
160
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
134
161
  iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
@@ -148,12 +175,14 @@ class RotatedBboxLoss(BboxLoss):
148
175
  class KeypointLoss(nn.Module):
149
176
  """Criterion class for computing keypoint losses."""
150
177
 
151
- def __init__(self, sigmas) -> None:
178
+ def __init__(self, sigmas: torch.Tensor) -> None:
152
179
  """Initialize the KeypointLoss class with keypoint sigmas."""
153
180
  super().__init__()
154
181
  self.sigmas = sigmas
155
182
 
156
- def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
183
+ def forward(
184
+ self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
185
+ ) -> torch.Tensor:
157
186
  """Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
158
187
  d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
159
188
  kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
@@ -165,7 +194,7 @@ class KeypointLoss(nn.Module):
165
194
  class v8DetectionLoss:
166
195
  """Criterion class for computing training losses for YOLOv8 object detection."""
167
196
 
168
- def __init__(self, model, tal_topk=10): # model must be de-paralleled
197
+ def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
169
198
  """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
170
199
  device = next(model.parameters()).device # get model device
171
200
  h = model.args # hyperparameters
@@ -185,7 +214,7 @@ class v8DetectionLoss:
185
214
  self.bbox_loss = BboxLoss(m.reg_max).to(device)
186
215
  self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
187
216
 
188
- def preprocess(self, targets, batch_size, scale_tensor):
217
+ def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
189
218
  """Preprocess targets by converting to tensor format and scaling coordinates."""
190
219
  nl, ne = targets.shape
191
220
  if nl == 0:
@@ -202,7 +231,7 @@ class v8DetectionLoss:
202
231
  out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
203
232
  return out
204
233
 
205
- def bbox_decode(self, anchor_points, pred_dist):
234
+ def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
206
235
  """Decode predicted object bounding box coordinates from anchor points and distribution."""
207
236
  if self.use_dfl:
208
237
  b, a, c = pred_dist.shape # batch, anchors, channels
@@ -211,7 +240,7 @@ class v8DetectionLoss:
211
240
  # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
212
241
  return dist2bbox(pred_dist, anchor_points, xywh=False)
213
242
 
214
- def __call__(self, preds, batch):
243
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
215
244
  """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
216
245
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
217
246
  feats = preds[1] if isinstance(preds, tuple) else preds
@@ -229,7 +258,7 @@ class v8DetectionLoss:
229
258
 
230
259
  # Targets
231
260
  targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
232
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
261
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
233
262
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
234
263
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
235
264
 
@@ -256,9 +285,14 @@ class v8DetectionLoss:
256
285
 
257
286
  # Bbox loss
258
287
  if fg_mask.sum():
259
- target_bboxes /= stride_tensor
260
288
  loss[0], loss[2] = self.bbox_loss(
261
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
289
+ pred_distri,
290
+ pred_bboxes,
291
+ anchor_points,
292
+ target_bboxes / stride_tensor,
293
+ target_scores,
294
+ target_scores_sum,
295
+ fg_mask,
262
296
  )
263
297
 
264
298
  loss[0] *= self.hyp.box # box gain
@@ -276,7 +310,7 @@ class v8SegmentationLoss(v8DetectionLoss):
276
310
  super().__init__(model)
277
311
  self.overlap = model.args.overlap_mask
278
312
 
279
- def __call__(self, preds, batch):
313
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
280
314
  """Calculate and return the combined loss for detection and segmentation."""
281
315
  loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
282
316
  feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
@@ -298,7 +332,7 @@ class v8SegmentationLoss(v8DetectionLoss):
298
332
  try:
299
333
  batch_idx = batch["batch_idx"].view(-1, 1)
300
334
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
301
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
335
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
302
336
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
303
337
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
304
338
  except RuntimeError as e:
@@ -357,21 +391,20 @@ class v8SegmentationLoss(v8DetectionLoss):
357
391
  loss[2] *= self.hyp.cls # cls gain
358
392
  loss[3] *= self.hyp.dfl # dfl gain
359
393
 
360
- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
394
+ return loss * batch_size, loss.detach() # loss(box, seg, cls, dfl)
361
395
 
362
396
  @staticmethod
363
397
  def single_mask_loss(
364
398
  gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
365
399
  ) -> torch.Tensor:
366
- """
367
- Compute the instance segmentation loss for a single image.
400
+ """Compute the instance segmentation loss for a single image.
368
401
 
369
402
  Args:
370
- gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
371
- pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
403
+ gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
404
+ pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
372
405
  proto (torch.Tensor): Prototype masks of shape (32, H, W).
373
- xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
374
- area (torch.Tensor): Area of each ground truth bounding box of shape (n,).
406
+ xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
407
+ area (torch.Tensor): Area of each ground truth bounding box of shape (N,).
375
408
 
376
409
  Returns:
377
410
  (torch.Tensor): The calculated mask loss for a single image.
@@ -396,8 +429,7 @@ class v8SegmentationLoss(v8DetectionLoss):
396
429
  imgsz: torch.Tensor,
397
430
  overlap: bool,
398
431
  ) -> torch.Tensor:
399
- """
400
- Calculate the loss for instance segmentation.
432
+ """Calculate the loss for instance segmentation.
401
433
 
402
434
  Args:
403
435
  fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
@@ -464,7 +496,7 @@ class v8PoseLoss(v8DetectionLoss):
464
496
  sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
465
497
  self.keypoint_loss = KeypointLoss(sigmas=sigmas)
466
498
 
467
- def __call__(self, preds, batch):
499
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
468
500
  """Calculate the total loss and detach it for pose estimation."""
469
501
  loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
470
502
  feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
@@ -485,7 +517,7 @@ class v8PoseLoss(v8DetectionLoss):
485
517
  batch_size = pred_scores.shape[0]
486
518
  batch_idx = batch["batch_idx"].view(-1, 1)
487
519
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
488
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
520
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
489
521
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
490
522
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
491
523
 
@@ -531,7 +563,7 @@ class v8PoseLoss(v8DetectionLoss):
531
563
  return loss * batch_size, loss.detach() # loss(box, cls, dfl)
532
564
 
533
565
  @staticmethod
534
- def kpts_decode(anchor_points, pred_kpts):
566
+ def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
535
567
  """Decode predicted keypoints to image coordinates."""
536
568
  y = pred_kpts.clone()
537
569
  y[..., :2] *= 2.0
@@ -540,10 +572,16 @@ class v8PoseLoss(v8DetectionLoss):
540
572
  return y
541
573
 
542
574
  def calculate_keypoints_loss(
543
- self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
544
- ):
545
- """
546
- Calculate the keypoints loss for the model.
575
+ self,
576
+ masks: torch.Tensor,
577
+ target_gt_idx: torch.Tensor,
578
+ keypoints: torch.Tensor,
579
+ batch_idx: torch.Tensor,
580
+ stride_tensor: torch.Tensor,
581
+ target_bboxes: torch.Tensor,
582
+ pred_kpts: torch.Tensor,
583
+ ) -> tuple[torch.Tensor, torch.Tensor]:
584
+ """Calculate the keypoints loss for the model.
547
585
 
548
586
  This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
549
587
  based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
@@ -609,12 +647,11 @@ class v8PoseLoss(v8DetectionLoss):
609
647
  class v8ClassificationLoss:
610
648
  """Criterion class for computing training losses for classification."""
611
649
 
612
- def __call__(self, preds, batch):
650
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
613
651
  """Compute the classification loss between predictions and true labels."""
614
652
  preds = preds[1] if isinstance(preds, (list, tuple)) else preds
615
653
  loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
616
- loss_items = loss.detach()
617
- return loss, loss_items
654
+ return loss, loss.detach()
618
655
 
619
656
 
620
657
  class v8OBBLoss(v8DetectionLoss):
@@ -626,7 +663,7 @@ class v8OBBLoss(v8DetectionLoss):
626
663
  self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
627
664
  self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
628
665
 
629
- def preprocess(self, targets, batch_size, scale_tensor):
666
+ def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
630
667
  """Preprocess targets for oriented bounding box detection."""
631
668
  if targets.shape[0] == 0:
632
669
  out = torch.zeros(batch_size, 0, 6, device=self.device)
@@ -643,7 +680,7 @@ class v8OBBLoss(v8DetectionLoss):
643
680
  out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
644
681
  return out
645
682
 
646
- def __call__(self, preds, batch):
683
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
647
684
  """Calculate and return the loss for oriented bounding box detection."""
648
685
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
649
686
  feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
@@ -667,7 +704,7 @@ class v8OBBLoss(v8DetectionLoss):
667
704
  targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
668
705
  rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
669
706
  targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
670
- targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
707
+ targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
671
708
  gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
672
709
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
673
710
  except RuntimeError as e:
@@ -715,9 +752,10 @@ class v8OBBLoss(v8DetectionLoss):
715
752
 
716
753
  return loss * batch_size, loss.detach() # loss(box, cls, dfl)
717
754
 
718
- def bbox_decode(self, anchor_points, pred_dist, pred_angle):
719
- """
720
- Decode predicted object bounding box coordinates from anchor points and distribution.
755
+ def bbox_decode(
756
+ self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
757
+ ) -> torch.Tensor:
758
+ """Decode predicted object bounding box coordinates from anchor points and distribution.
721
759
 
722
760
  Args:
723
761
  anchor_points (torch.Tensor): Anchor points, (h*w, 2).
@@ -741,7 +779,7 @@ class E2EDetectLoss:
741
779
  self.one2many = v8DetectionLoss(model, tal_topk=10)
742
780
  self.one2one = v8DetectionLoss(model, tal_topk=1)
743
781
 
744
- def __call__(self, preds, batch):
782
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
745
783
  """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
746
784
  preds = preds[1] if isinstance(preds, tuple) else preds
747
785
  one2many = preds["one2many"]
@@ -762,7 +800,7 @@ class TVPDetectLoss:
762
800
  self.ori_no = self.vp_criterion.no
763
801
  self.ori_reg_max = self.vp_criterion.reg_max
764
802
 
765
- def __call__(self, preds, batch):
803
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
766
804
  """Calculate the loss for text-visual prompt detection."""
767
805
  feats = preds[1] if isinstance(preds, tuple) else preds
768
806
  assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
@@ -776,7 +814,7 @@ class TVPDetectLoss:
776
814
  box_loss = vp_loss[0][1]
777
815
  return box_loss, vp_loss[1]
778
816
 
779
- def _get_vp_features(self, feats):
817
+ def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
780
818
  """Extract visual-prompt features from the model output."""
781
819
  vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
782
820
 
@@ -798,7 +836,7 @@ class TVPSegmentLoss(TVPDetectLoss):
798
836
  super().__init__(model)
799
837
  self.vp_criterion = v8SegmentationLoss(model)
800
838
 
801
- def __call__(self, preds, batch):
839
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
802
840
  """Calculate the loss for text-visual prompt segmentation."""
803
841
  feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
804
842
  assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it