dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,7 @@ from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy,
11
11
 
12
12
 
13
13
  def _ntuple(n):
14
- """From PyTorch internals."""
14
+ """Create a function that converts input to n-tuple by repeating singleton values."""
15
15
 
16
16
  def parse(x):
17
17
  """Parse input to return n-tuple by repeating singleton values n times."""
@@ -33,16 +33,29 @@ __all__ = ("Bboxes", "Instances") # tuple or list
33
33
 
34
34
  class Bboxes:
35
35
  """
36
- A class for handling bounding boxes.
36
+ A class for handling bounding boxes in multiple formats.
37
37
 
38
- The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'.
39
- Bounding box data should be provided in numpy arrays.
38
+ The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format
39
+ conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.
40
40
 
41
41
  Attributes:
42
42
  bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).
43
43
  format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
44
44
 
45
- Note:
45
+ Methods:
46
+ convert: Convert bounding box format from one type to another.
47
+ areas: Calculate the area of bounding boxes.
48
+ mul: Multiply bounding box coordinates by scale factor(s).
49
+ add: Add offset to bounding box coordinates.
50
+ concatenate: Concatenate multiple Bboxes objects.
51
+
52
+ Examples:
53
+ Create bounding boxes in YOLO format
54
+ >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format="xywh")
55
+ >>> bboxes.convert("xyxy")
56
+ >>> print(bboxes.areas())
57
+
58
+ Notes:
46
59
  This class does not handle normalization or denormalization of bounding boxes.
47
60
  """
48
61
 
@@ -60,7 +73,6 @@ class Bboxes:
60
73
  assert bboxes.shape[1] == 4
61
74
  self.bboxes = bboxes
62
75
  self.format = format
63
- # self.normalized = normalized
64
76
 
65
77
  def convert(self, format):
66
78
  """
@@ -82,36 +94,20 @@ class Bboxes:
82
94
  self.format = format
83
95
 
84
96
  def areas(self):
85
- """Return box areas."""
97
+ """Calculate the area of bounding boxes."""
86
98
  return (
87
99
  (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy
88
100
  if self.format == "xyxy"
89
101
  else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh
90
102
  )
91
103
 
92
- # def denormalize(self, w, h):
93
- # if not self.normalized:
94
- # return
95
- # assert (self.bboxes <= 1.0).all()
96
- # self.bboxes[:, 0::2] *= w
97
- # self.bboxes[:, 1::2] *= h
98
- # self.normalized = False
99
- #
100
- # def normalize(self, w, h):
101
- # if self.normalized:
102
- # return
103
- # assert (self.bboxes > 1.0).any()
104
- # self.bboxes[:, 0::2] /= w
105
- # self.bboxes[:, 1::2] /= h
106
- # self.normalized = True
107
-
108
104
  def mul(self, scale):
109
105
  """
110
106
  Multiply bounding box coordinates by scale factor(s).
111
107
 
112
108
  Args:
113
- scale (int | tuple | list): Scale factor(s) for four coordinates.
114
- If int, the same scale is applied to all coordinates.
109
+ scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to
110
+ all coordinates.
115
111
  """
116
112
  if isinstance(scale, Number):
117
113
  scale = to_4tuple(scale)
@@ -127,8 +123,8 @@ class Bboxes:
127
123
  Add offset to bounding box coordinates.
128
124
 
129
125
  Args:
130
- offset (int | tuple | list): Offset(s) for four coordinates.
131
- If int, the same offset is applied to all coordinates.
126
+ offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to
127
+ all coordinates.
132
128
  """
133
129
  if isinstance(offset, Number):
134
130
  offset = to_4tuple(offset)
@@ -140,7 +136,7 @@ class Bboxes:
140
136
  self.bboxes[:, 3] += offset[3]
141
137
 
142
138
  def __len__(self):
143
- """Return the number of boxes."""
139
+ """Return the number of bounding boxes."""
144
140
  return len(self.bboxes)
145
141
 
146
142
  @classmethod
@@ -155,7 +151,7 @@ class Bboxes:
155
151
  Returns:
156
152
  (Bboxes): A new Bboxes object containing the concatenated bounding boxes.
157
153
 
158
- Note:
154
+ Notes:
159
155
  The input should be a list or tuple of Bboxes objects.
160
156
  """
161
157
  assert isinstance(boxes_list, (list, tuple))
@@ -172,18 +168,14 @@ class Bboxes:
172
168
  Retrieve a specific bounding box or a set of bounding boxes using indexing.
173
169
 
174
170
  Args:
175
- index (int | slice | np.ndarray): The index, slice, or boolean array to select
176
- the desired bounding boxes.
171
+ index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.
177
172
 
178
173
  Returns:
179
174
  (Bboxes): A new Bboxes object containing the selected bounding boxes.
180
175
 
181
- Raises:
182
- AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.
183
-
184
- Note:
185
- When using boolean indexing, make sure to provide a boolean array with the same
186
- length as the number of bounding boxes.
176
+ Notes:
177
+ When using boolean indexing, make sure to provide a boolean array with the same length as the number of
178
+ bounding boxes.
187
179
  """
188
180
  if isinstance(index, int):
189
181
  return Bboxes(self.bboxes[index].reshape(1, -1))
@@ -196,6 +188,10 @@ class Instances:
196
188
  """
197
189
  Container for bounding boxes, segments, and keypoints of detected objects in an image.
198
190
 
191
+ This class provides a unified interface for handling different types of object annotations including bounding
192
+ boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,
193
+ and format conversion.
194
+
199
195
  Attributes:
200
196
  _bboxes (Bboxes): Internal object for handling bounding box operations.
201
197
  keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).
@@ -216,6 +212,7 @@ class Instances:
216
212
  concatenate: Concatenate multiple Instances objects.
217
213
 
218
214
  Examples:
215
+ Create instances with bounding boxes and segments
219
216
  >>> instances = Instances(
220
217
  ... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
221
218
  ... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
@@ -225,14 +222,14 @@ class Instances:
225
222
 
226
223
  def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
227
224
  """
228
- Initialize the object with bounding boxes, segments, and keypoints.
225
+ Initialize the Instances object with bounding boxes, segments, and keypoints.
229
226
 
230
227
  Args:
231
- bboxes (np.ndarray): Bounding boxes, shape (N, 4).
228
+ bboxes (np.ndarray): Bounding boxes with shape (N, 4).
232
229
  segments (List | np.ndarray, optional): Segmentation masks.
233
- keypoints (np.ndarray, optional): Keypoints, shape (N, 17, 3) in format (x, y, visible).
234
- bbox_format (str, optional): Format of bboxes.
235
- normalized (bool, optional): Whether the coordinates are normalized.
230
+ keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible).
231
+ bbox_format (str): Format of bboxes.
232
+ normalized (bool): Whether the coordinates are normalized.
236
233
  """
237
234
  self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
238
235
  self.keypoints = keypoints
@@ -333,9 +330,9 @@ class Instances:
333
330
  Returns:
334
331
  (Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.
335
332
 
336
- Note:
337
- When using boolean indexing, make sure to provide a boolean array with the same
338
- length as the number of instances.
333
+ Notes:
334
+ When using boolean indexing, make sure to provide a boolean array with the same length as the number of
335
+ instances.
339
336
  """
340
337
  segments = self.segments[index] if len(self.segments) else self.segments
341
338
  keypoints = self.keypoints[index] if self.keypoints is not None else None
@@ -442,7 +439,7 @@ class Instances:
442
439
  self.keypoints = keypoints
443
440
 
444
441
  def __len__(self):
445
- """Return the length of the instance list."""
442
+ """Return the number of instances."""
446
443
  return len(self.bboxes)
447
444
 
448
445
  @classmethod
@@ -455,13 +452,12 @@ class Instances:
455
452
  axis (int, optional): The axis along which the arrays will be concatenated.
456
453
 
457
454
  Returns:
458
- (Instances): A new Instances object containing the concatenated bounding boxes,
459
- segments, and keypoints if present.
455
+ (Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints
456
+ if present.
460
457
 
461
- Note:
462
- The `Instances` objects in the list should have the same properties, such as
463
- the format of the bounding boxes, whether keypoints are present, and if the
464
- coordinates are normalized.
458
+ Notes:
459
+ The `Instances` objects in the list should have the same properties, such as the format of the bounding
460
+ boxes, whether keypoints are present, and if the coordinates are normalized.
465
461
  """
466
462
  assert isinstance(instances_list, (list, tuple))
467
463
  if not instances_list:
ultralytics/utils/loss.py CHANGED
@@ -1,5 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from typing import Any, Dict, List, Tuple
4
+
3
5
  import torch
4
6
  import torch.nn as nn
5
7
  import torch.nn.functional as F
@@ -17,20 +19,24 @@ class VarifocalLoss(nn.Module):
17
19
  """
18
20
  Varifocal loss by Zhang et al.
19
21
 
20
- https://arxiv.org/abs/2008.13367.
22
+ Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
23
+ hard-to-classify examples and balancing positive/negative samples.
21
24
 
22
- Args:
25
+ Attributes:
23
26
  gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
24
27
  alpha (float): The balancing factor used to address class imbalance.
28
+
29
+ References:
30
+ https://arxiv.org/abs/2008.13367
25
31
  """
26
32
 
27
- def __init__(self, gamma=2.0, alpha=0.75):
28
- """Initialize the VarifocalLoss class."""
33
+ def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
34
+ """Initialize the VarifocalLoss class with focusing and balancing parameters."""
29
35
  super().__init__()
30
36
  self.gamma = gamma
31
37
  self.alpha = alpha
32
38
 
33
- def forward(self, pred_score, gt_score, label):
39
+ def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
34
40
  """Compute varifocal loss between predictions and ground truth."""
35
41
  weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
36
42
  with autocast(enabled=False):
@@ -46,18 +52,21 @@ class FocalLoss(nn.Module):
46
52
  """
47
53
  Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
48
54
 
49
- Args:
55
+ Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
56
+ on hard negatives during training.
57
+
58
+ Attributes:
50
59
  gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
51
- alpha (float | list): The balancing factor used to address class imbalance.
60
+ alpha (torch.Tensor): The balancing factor used to address class imbalance.
52
61
  """
53
62
 
54
- def __init__(self, gamma=1.5, alpha=0.25):
55
- """Initialize FocalLoss class with no parameters."""
63
+ def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
64
+ """Initialize FocalLoss class with focusing and balancing parameters."""
56
65
  super().__init__()
57
66
  self.gamma = gamma
58
67
  self.alpha = torch.tensor(alpha)
59
68
 
60
- def forward(self, pred, label):
69
+ def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
61
70
  """Calculate focal loss with modulating factors for class imbalance."""
62
71
  loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
63
72
  # p_t = torch.exp(-loss)
@@ -78,12 +87,12 @@ class FocalLoss(nn.Module):
78
87
  class DFLoss(nn.Module):
79
88
  """Criterion class for computing Distribution Focal Loss (DFL)."""
80
89
 
81
- def __init__(self, reg_max=16) -> None:
90
+ def __init__(self, reg_max: int = 16) -> None:
82
91
  """Initialize the DFL module with regularization maximum."""
83
92
  super().__init__()
84
93
  self.reg_max = reg_max
85
94
 
86
- def __call__(self, pred_dist, target):
95
+ def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
87
96
  """Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
88
97
  target = target.clamp_(0, self.reg_max - 1 - 0.01)
89
98
  tl = target.long() # target left
@@ -99,12 +108,21 @@ class DFLoss(nn.Module):
99
108
  class BboxLoss(nn.Module):
100
109
  """Criterion class for computing training losses for bounding boxes."""
101
110
 
102
- def __init__(self, reg_max=16):
111
+ def __init__(self, reg_max: int = 16):
103
112
  """Initialize the BboxLoss module with regularization maximum and DFL settings."""
104
113
  super().__init__()
105
114
  self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
106
115
 
107
- def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
116
+ def forward(
117
+ self,
118
+ pred_dist: torch.Tensor,
119
+ pred_bboxes: torch.Tensor,
120
+ anchor_points: torch.Tensor,
121
+ target_bboxes: torch.Tensor,
122
+ target_scores: torch.Tensor,
123
+ target_scores_sum: torch.Tensor,
124
+ fg_mask: torch.Tensor,
125
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
108
126
  """Compute IoU and DFL losses for bounding boxes."""
109
127
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
110
128
  iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
@@ -124,11 +142,20 @@ class BboxLoss(nn.Module):
124
142
  class RotatedBboxLoss(BboxLoss):
125
143
  """Criterion class for computing training losses for rotated bounding boxes."""
126
144
 
127
- def __init__(self, reg_max):
128
- """Initialize the BboxLoss module with regularization maximum and DFL settings."""
145
+ def __init__(self, reg_max: int):
146
+ """Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
129
147
  super().__init__(reg_max)
130
148
 
131
- def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
149
+ def forward(
150
+ self,
151
+ pred_dist: torch.Tensor,
152
+ pred_bboxes: torch.Tensor,
153
+ anchor_points: torch.Tensor,
154
+ target_bboxes: torch.Tensor,
155
+ target_scores: torch.Tensor,
156
+ target_scores_sum: torch.Tensor,
157
+ fg_mask: torch.Tensor,
158
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
132
159
  """Compute IoU and DFL losses for rotated bounding boxes."""
133
160
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
134
161
  iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
@@ -148,12 +175,14 @@ class RotatedBboxLoss(BboxLoss):
148
175
  class KeypointLoss(nn.Module):
149
176
  """Criterion class for computing keypoint losses."""
150
177
 
151
- def __init__(self, sigmas) -> None:
178
+ def __init__(self, sigmas: torch.Tensor) -> None:
152
179
  """Initialize the KeypointLoss class with keypoint sigmas."""
153
180
  super().__init__()
154
181
  self.sigmas = sigmas
155
182
 
156
- def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
183
+ def forward(
184
+ self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
185
+ ) -> torch.Tensor:
157
186
  """Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
158
187
  d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
159
188
  kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
@@ -165,7 +194,7 @@ class KeypointLoss(nn.Module):
165
194
  class v8DetectionLoss:
166
195
  """Criterion class for computing training losses for YOLOv8 object detection."""
167
196
 
168
- def __init__(self, model, tal_topk=10): # model must be de-paralleled
197
+ def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
169
198
  """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
170
199
  device = next(model.parameters()).device # get model device
171
200
  h = model.args # hyperparameters
@@ -185,7 +214,7 @@ class v8DetectionLoss:
185
214
  self.bbox_loss = BboxLoss(m.reg_max).to(device)
186
215
  self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
187
216
 
188
- def preprocess(self, targets, batch_size, scale_tensor):
217
+ def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
189
218
  """Preprocess targets by converting to tensor format and scaling coordinates."""
190
219
  nl, ne = targets.shape
191
220
  if nl == 0:
@@ -202,7 +231,7 @@ class v8DetectionLoss:
202
231
  out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
203
232
  return out
204
233
 
205
- def bbox_decode(self, anchor_points, pred_dist):
234
+ def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
206
235
  """Decode predicted object bounding box coordinates from anchor points and distribution."""
207
236
  if self.use_dfl:
208
237
  b, a, c = pred_dist.shape # batch, anchors, channels
@@ -211,7 +240,7 @@ class v8DetectionLoss:
211
240
  # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
212
241
  return dist2bbox(pred_dist, anchor_points, xywh=False)
213
242
 
214
- def __call__(self, preds, batch):
243
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
215
244
  """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
216
245
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
217
246
  feats = preds[1] if isinstance(preds, tuple) else preds
@@ -276,7 +305,7 @@ class v8SegmentationLoss(v8DetectionLoss):
276
305
  super().__init__(model)
277
306
  self.overlap = model.args.overlap_mask
278
307
 
279
- def __call__(self, preds, batch):
308
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
280
309
  """Calculate and return the combined loss for detection and segmentation."""
281
310
  loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
282
311
  feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
@@ -367,11 +396,11 @@ class v8SegmentationLoss(v8DetectionLoss):
367
396
  Compute the instance segmentation loss for a single image.
368
397
 
369
398
  Args:
370
- gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
371
- pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
399
+ gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
400
+ pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
372
401
  proto (torch.Tensor): Prototype masks of shape (32, H, W).
373
- xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
374
- area (torch.Tensor): Area of each ground truth bounding box of shape (n,).
402
+ xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
403
+ area (torch.Tensor): Area of each ground truth bounding box of shape (N,).
375
404
 
376
405
  Returns:
377
406
  (torch.Tensor): The calculated mask loss for a single image.
@@ -464,7 +493,7 @@ class v8PoseLoss(v8DetectionLoss):
464
493
  sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
465
494
  self.keypoint_loss = KeypointLoss(sigmas=sigmas)
466
495
 
467
- def __call__(self, preds, batch):
496
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
468
497
  """Calculate the total loss and detach it for pose estimation."""
469
498
  loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
470
499
  feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
@@ -531,7 +560,7 @@ class v8PoseLoss(v8DetectionLoss):
531
560
  return loss * batch_size, loss.detach() # loss(box, cls, dfl)
532
561
 
533
562
  @staticmethod
534
- def kpts_decode(anchor_points, pred_kpts):
563
+ def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
535
564
  """Decode predicted keypoints to image coordinates."""
536
565
  y = pred_kpts.clone()
537
566
  y[..., :2] *= 2.0
@@ -540,8 +569,15 @@ class v8PoseLoss(v8DetectionLoss):
540
569
  return y
541
570
 
542
571
  def calculate_keypoints_loss(
543
- self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
544
- ):
572
+ self,
573
+ masks: torch.Tensor,
574
+ target_gt_idx: torch.Tensor,
575
+ keypoints: torch.Tensor,
576
+ batch_idx: torch.Tensor,
577
+ stride_tensor: torch.Tensor,
578
+ target_bboxes: torch.Tensor,
579
+ pred_kpts: torch.Tensor,
580
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
545
581
  """
546
582
  Calculate the keypoints loss for the model.
547
583
 
@@ -609,7 +645,7 @@ class v8PoseLoss(v8DetectionLoss):
609
645
  class v8ClassificationLoss:
610
646
  """Criterion class for computing training losses for classification."""
611
647
 
612
- def __call__(self, preds, batch):
648
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
613
649
  """Compute the classification loss between predictions and true labels."""
614
650
  preds = preds[1] if isinstance(preds, (list, tuple)) else preds
615
651
  loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
@@ -625,7 +661,7 @@ class v8OBBLoss(v8DetectionLoss):
625
661
  self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
626
662
  self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
627
663
 
628
- def preprocess(self, targets, batch_size, scale_tensor):
664
+ def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
629
665
  """Preprocess targets for oriented bounding box detection."""
630
666
  if targets.shape[0] == 0:
631
667
  out = torch.zeros(batch_size, 0, 6, device=self.device)
@@ -642,7 +678,7 @@ class v8OBBLoss(v8DetectionLoss):
642
678
  out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
643
679
  return out
644
680
 
645
- def __call__(self, preds, batch):
681
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
646
682
  """Calculate and return the loss for oriented bounding box detection."""
647
683
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
648
684
  feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
@@ -714,7 +750,9 @@ class v8OBBLoss(v8DetectionLoss):
714
750
 
715
751
  return loss * batch_size, loss.detach() # loss(box, cls, dfl)
716
752
 
717
- def bbox_decode(self, anchor_points, pred_dist, pred_angle):
753
+ def bbox_decode(
754
+ self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
755
+ ) -> torch.Tensor:
718
756
  """
719
757
  Decode predicted object bounding box coordinates from anchor points and distribution.
720
758
 
@@ -740,7 +778,7 @@ class E2EDetectLoss:
740
778
  self.one2many = v8DetectionLoss(model, tal_topk=10)
741
779
  self.one2one = v8DetectionLoss(model, tal_topk=1)
742
780
 
743
- def __call__(self, preds, batch):
781
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
744
782
  """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
745
783
  preds = preds[1] if isinstance(preds, tuple) else preds
746
784
  one2many = preds["one2many"]
@@ -761,7 +799,7 @@ class TVPDetectLoss:
761
799
  self.ori_no = self.vp_criterion.no
762
800
  self.ori_reg_max = self.vp_criterion.reg_max
763
801
 
764
- def __call__(self, preds, batch):
802
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
765
803
  """Calculate the loss for text-visual prompt detection."""
766
804
  feats = preds[1] if isinstance(preds, tuple) else preds
767
805
  assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
@@ -775,7 +813,7 @@ class TVPDetectLoss:
775
813
  box_loss = vp_loss[0][1]
776
814
  return box_loss, vp_loss[1]
777
815
 
778
- def _get_vp_features(self, feats):
816
+ def _get_vp_features(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
779
817
  """Extract visual-prompt features from the model output."""
780
818
  vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
781
819
 
@@ -797,7 +835,7 @@ class TVPSegmentLoss(TVPDetectLoss):
797
835
  super().__init__(model)
798
836
  self.vp_criterion = v8SegmentationLoss(model)
799
837
 
800
- def __call__(self, preds, batch):
838
+ def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
801
839
  """Calculate the loss for text-visual prompt segmentation."""
802
840
  feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
803
841
  assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it