ultralytics-8.3.89-py3-none-any.whl → ultralytics-8.3.91-py3-none-any.whl

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (156):
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/utils/tal.py CHANGED
@@ -21,6 +21,7 @@ class TaskAlignedAssigner(nn.Module):
21
21
  Attributes:
22
22
  topk (int): The number of top candidates to consider.
23
23
  num_classes (int): The number of object classes.
24
+ bg_idx (int): Background class index.
24
25
  alpha (float): The alpha parameter for the classification component of the task-aligned metric.
25
26
  beta (float): The beta parameter for the localization component of the task-aligned metric.
26
27
  eps (float): A small value to prevent division by zero.
@@ -39,23 +40,25 @@ class TaskAlignedAssigner(nn.Module):
39
40
  @torch.no_grad()
40
41
  def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
41
42
  """
42
- Compute the task-aligned assignment. Reference code is available at
43
- https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
43
+ Compute the task-aligned assignment.
44
44
 
45
45
  Args:
46
- pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
47
- pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
48
- anc_points (Tensor): shape(num_total_anchors, 2)
49
- gt_labels (Tensor): shape(bs, n_max_boxes, 1)
50
- gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
51
- mask_gt (Tensor): shape(bs, n_max_boxes, 1)
46
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
47
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
48
+ anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
49
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
50
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
51
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
52
52
 
53
53
  Returns:
54
- target_labels (Tensor): shape(bs, num_total_anchors)
55
- target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
56
- target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
57
- fg_mask (Tensor): shape(bs, num_total_anchors)
58
- target_gt_idx (Tensor): shape(bs, num_total_anchors)
54
+ target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
55
+ target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
56
+ target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
57
+ fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
58
+ target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
59
+
60
+ References:
61
+ https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
59
62
  """
60
63
  self.bs = pd_scores.shape[0]
61
64
  self.n_max_boxes = gt_bboxes.shape[1]
@@ -81,23 +84,22 @@ class TaskAlignedAssigner(nn.Module):
81
84
 
82
85
  def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
83
86
  """
84
- Compute the task-aligned assignment. Reference code is available at
85
- https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
87
+ Compute the task-aligned assignment.
86
88
 
87
89
  Args:
88
- pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
89
- pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
90
- anc_points (Tensor): shape(num_total_anchors, 2)
91
- gt_labels (Tensor): shape(bs, n_max_boxes, 1)
92
- gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
93
- mask_gt (Tensor): shape(bs, n_max_boxes, 1)
90
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
91
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
92
+ anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
93
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
94
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
95
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
94
96
 
95
97
  Returns:
96
- target_labels (Tensor): shape(bs, num_total_anchors)
97
- target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
98
- target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
99
- fg_mask (Tensor): shape(bs, num_total_anchors)
100
- target_gt_idx (Tensor): shape(bs, num_total_anchors)
98
+ target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
99
+ target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
100
+ target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
101
+ fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
102
+ target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
101
103
  """
102
104
  mask_pos, align_metric, overlaps = self.get_pos_mask(
103
105
  pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
@@ -118,7 +120,22 @@ class TaskAlignedAssigner(nn.Module):
118
120
  return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
119
121
 
120
122
  def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
121
- """Get in_gts mask, (b, max_num_obj, h*w)."""
123
+ """
124
+ Get positive mask for each ground truth box.
125
+
126
+ Args:
127
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
128
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
129
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
130
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
131
+ anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
132
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
133
+
134
+ Returns:
135
+ mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
136
+ align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
137
+ overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
138
+ """
122
139
  mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
123
140
  # Get anchor_align metric, (b, max_num_obj, h*w)
124
141
  align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
@@ -130,7 +147,20 @@ class TaskAlignedAssigner(nn.Module):
130
147
  return mask_pos, align_metric, overlaps
131
148
 
132
149
  def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
133
- """Compute alignment metric given predicted and ground truth bounding boxes."""
150
+ """
151
+ Compute alignment metric given predicted and ground truth bounding boxes.
152
+
153
+ Args:
154
+ pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
155
+ pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
156
+ gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
157
+ gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
158
+ mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).
159
+
160
+ Returns:
161
+ align_metric (torch.Tensor): Alignment metric combining classification and localization.
162
+ overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
163
+ """
134
164
  na = pd_bboxes.shape[-2]
135
165
  mask_gt = mask_gt.bool() # b, max_num_obj, h*w
136
166
  overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
@@ -151,7 +181,16 @@ class TaskAlignedAssigner(nn.Module):
151
181
  return align_metric, overlaps
152
182
 
153
183
  def iou_calculation(self, gt_bboxes, pd_bboxes):
154
- """IoU calculation for horizontal bounding boxes."""
184
+ """
185
+ Calculate IoU for horizontal bounding boxes.
186
+
187
+ Args:
188
+ gt_bboxes (torch.Tensor): Ground truth boxes.
189
+ pd_bboxes (torch.Tensor): Predicted boxes.
190
+
191
+ Returns:
192
+ (torch.Tensor): IoU values between each pair of boxes.
193
+ """
155
194
  return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
156
195
 
157
196
  def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
@@ -159,16 +198,16 @@ class TaskAlignedAssigner(nn.Module):
159
198
  Select the top-k candidates based on the given metrics.
160
199
 
161
200
  Args:
162
- metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
201
+ metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
163
202
  max_num_obj is the maximum number of objects, and h*w represents the
164
203
  total number of anchor points.
165
204
  largest (bool): If True, select the largest values; otherwise, select the smallest values.
166
- topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
205
+ topk_mask (torch.Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
167
206
  topk is the number of top candidates to consider. If not provided,
168
207
  the top-k values are automatically computed based on the given metrics.
169
208
 
170
209
  Returns:
171
- (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
210
+ (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
172
211
  """
173
212
  # (b, max_num_obj, topk)
174
213
  topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
@@ -183,7 +222,6 @@ class TaskAlignedAssigner(nn.Module):
183
222
  for k in range(self.topk):
184
223
  # Expand topk_idxs for each value of k and add 1 at the specified positions
185
224
  count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
186
- # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
187
225
  # Filter invalid bboxes
188
226
  count_tensor.masked_fill_(count_tensor > 1, 0)
189
227
 
@@ -194,24 +232,21 @@ class TaskAlignedAssigner(nn.Module):
194
232
  Compute target labels, target bounding boxes, and target scores for the positive anchor points.
195
233
 
196
234
  Args:
197
- gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
235
+ gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
198
236
  batch size and max_num_obj is the maximum number of objects.
199
- gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
200
- target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
237
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
238
+ target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
201
239
  anchor points, with shape (b, h*w), where h*w is the total
202
240
  number of anchor points.
203
- fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
241
+ fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
204
242
  (foreground) anchor points.
205
243
 
206
244
  Returns:
207
- (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
208
- - target_labels (Tensor): Shape (b, h*w), containing the target labels for
209
- positive anchor points.
210
- - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
211
- for positive anchor points.
212
- - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
213
- for positive anchor points, where num_classes is the number
214
- of object classes.
245
+ target_labels (torch.Tensor): Shape (b, h*w), containing the target labels for positive anchor points.
246
+ target_bboxes (torch.Tensor): Shape (b, h*w, 4), containing the target bounding boxes for positive
247
+ anchor points.
248
+ target_scores (torch.Tensor): Shape (b, h*w, num_classes), containing the target scores for positive
249
+ anchor points.
215
250
  """
216
251
  # Assigned target labels, (b, 1)
217
252
  batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
@@ -258,7 +293,6 @@ class TaskAlignedAssigner(nn.Module):
258
293
  bs, n_boxes, _ = gt_bboxes.shape
259
294
  lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
260
295
  bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
261
- # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
262
296
  return bbox_deltas.amin(3).gt_(eps)
263
297
 
264
298
  @staticmethod
@@ -275,9 +309,6 @@ class TaskAlignedAssigner(nn.Module):
275
309
  target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
276
310
  fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
277
311
  mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
278
-
279
- Note:
280
- b: batch size, h: height, w: width.
281
312
  """
282
313
  # Convert (b, n_max_boxes, h*w) -> (b, h*w)
283
314
  fg_mask = mask_pos.sum(-2)
@@ -299,7 +330,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
299
330
  """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
300
331
 
301
332
  def iou_calculation(self, gt_bboxes, pd_bboxes):
302
- """IoU calculation for rotated bounding boxes."""
333
+ """Calculate IoU for rotated bounding boxes."""
303
334
  return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
304
335
 
305
336
  @staticmethod
@@ -308,11 +339,11 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
308
339
  Select the positive anchor center in gt for rotated bounding boxes.
309
340
 
310
341
  Args:
311
- xy_centers (Tensor): shape(h*w, 2)
312
- gt_bboxes (Tensor): shape(b, n_boxes, 5)
342
+ xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
343
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
313
344
 
314
345
  Returns:
315
- (Tensor): shape(b, n_boxes, h*w)
346
+ (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
316
347
  """
317
348
  # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
318
349
  corners = xywhr2xyxyxyxy(gt_bboxes)
@@ -368,13 +399,13 @@ def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
368
399
  Decode predicted rotated bounding box coordinates from anchor points and distribution.
369
400
 
370
401
  Args:
371
- pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4).
372
- pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1).
373
- anchor_points (torch.Tensor): Anchor points, shape (h*w, 2).
402
+ pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
403
+ pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
404
+ anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
374
405
  dim (int, optional): Dimension along which to split. Defaults to -1.
375
406
 
376
407
  Returns:
377
- (torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4).
408
+ (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
378
409
  """
379
410
  lt, rb = pred_dist.split(2, dim=dim)
380
411
  cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
@@ -90,12 +90,12 @@ def autocast(enabled: bool, device: str = "cuda"):
90
90
  Returns:
91
91
  (torch.amp.autocast): The appropriate autocast context manager.
92
92
 
93
- Note:
93
+ Notes:
94
94
  - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
95
95
  - For older versions, it uses `torch.cuda.autocast`.
96
96
 
97
97
  Examples:
98
- >>> with autocast(amp=True):
98
+ >>> with autocast(enabled=True):
99
99
  ... # Your mixed precision operations here
100
100
  ... pass
101
101
  """
@@ -130,7 +130,7 @@ def get_gpu_info(index):
130
130
 
131
131
  def select_device(device="", batch=0, newline=False, verbose=True):
132
132
  """
133
- Selects the appropriate PyTorch device based on the provided arguments.
133
+ Select the appropriate PyTorch device based on the provided arguments.
134
134
 
135
135
  The function takes a string specifying the device or a torch.device object and returns a torch.device object
136
136
  representing the selected device. The function also validates the number of available devices and raises an
@@ -299,7 +299,18 @@ def fuse_deconv_and_bn(deconv, bn):
299
299
 
300
300
 
301
301
  def model_info(model, detailed=False, verbose=True, imgsz=640):
302
- """Print and return detailed model information layer by layer."""
302
+ """
303
+ Print and return detailed model information layer by layer.
304
+
305
+ Args:
306
+ model (nn.Module): Model to analyze.
307
+ detailed (bool, optional): Whether to print detailed layer information. Defaults to False.
308
+ verbose (bool, optional): Whether to print model information. Defaults to True.
309
+ imgsz (int | List, optional): Input image size. Defaults to 640.
310
+
311
+ Returns:
312
+ (Tuple[int, int, int, float]): Number of layers, parameters, gradients, and GFLOPs.
313
+ """
303
314
  if not verbose:
304
315
  return
305
316
  n_p = get_num_params(model) # number of parameters
@@ -343,6 +354,12 @@ def model_info_for_loggers(trainer):
343
354
  """
344
355
  Return model info dict with useful model information.
345
356
 
357
+ Args:
358
+ trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.
359
+
360
+ Returns:
361
+ (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.
362
+
346
363
  Examples:
347
364
  YOLOv8n info for loggers
348
365
  >>> results = {
@@ -368,7 +385,16 @@ def model_info_for_loggers(trainer):
368
385
 
369
386
 
370
387
  def get_flops(model, imgsz=640):
371
- """Return a YOLO model's FLOPs."""
388
+ """
389
+ Return a YOLO model's FLOPs.
390
+
391
+ Args:
392
+ model (nn.Module): The model to calculate FLOPs for.
393
+ imgsz (int | List[int], optional): Input image size. Defaults to 640.
394
+
395
+ Returns:
396
+ (float): The model's FLOPs in billions.
397
+ """
372
398
  if not thop:
373
399
  return 0.0 # if not installed return 0.0 GFLOPs
374
400
 
@@ -392,7 +418,16 @@ def get_flops(model, imgsz=640):
392
418
 
393
419
 
394
420
  def get_flops_with_torch_profiler(model, imgsz=640):
395
- """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
421
+ """
422
+ Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
423
+
424
+ Args:
425
+ model (nn.Module): The model to calculate FLOPs for.
426
+ imgsz (int | List[int], optional): Input image size. Defaults to 640.
427
+
428
+ Returns:
429
+ (float): The model's FLOPs in billions.
430
+ """
396
431
  if not TORCH_2_0: # torch profiler implemented in torch>=2.0
397
432
  return 0.0
398
433
  model = de_parallel(model)
@@ -430,7 +465,18 @@ def initialize_weights(model):
430
465
 
431
466
 
432
467
  def scale_img(img, ratio=1.0, same_shape=False, gs=32):
433
- """Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple."""
468
+ """
469
+ Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
470
+
471
+ Args:
472
+ img (torch.Tensor): Input image tensor.
473
+ ratio (float, optional): Scaling ratio. Defaults to 1.0.
474
+ same_shape (bool, optional): Whether to maintain the same shape. Defaults to False.
475
+ gs (int, optional): Grid size for padding. Defaults to 32.
476
+
477
+ Returns:
478
+ (torch.Tensor): Scaled and padded image tensor.
479
+ """
434
480
  if ratio == 1.0:
435
481
  return img
436
482
  h, w = img.shape[2:]
@@ -442,7 +488,15 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):
442
488
 
443
489
 
444
490
  def copy_attr(a, b, include=(), exclude=()):
445
- """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
491
+ """
492
+ Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
493
+
494
+ Args:
495
+ a (object): Destination object to copy attributes to.
496
+ b (object): Source object to copy attributes from.
497
+ include (tuple, optional): Attributes to include. If empty, all attributes are included. Defaults to ().
498
+ exclude (tuple, optional): Attributes to exclude. Defaults to ().
499
+ """
446
500
  for k, v in b.__dict__.items():
447
501
  if (len(include) and k not in include) or k.startswith("_") or k in exclude:
448
502
  continue
@@ -451,7 +505,12 @@ def copy_attr(a, b, include=(), exclude=()):
451
505
 
452
506
 
453
507
  def get_latest_opset():
454
- """Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
508
+ """
509
+ Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.
510
+
511
+ Returns:
512
+ (int): The ONNX opset version.
513
+ """
455
514
  if TORCH_1_13:
456
515
  # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset'
457
516
  return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
@@ -461,27 +520,69 @@ def get_latest_opset():
461
520
 
462
521
 
463
522
  def intersect_dicts(da, db, exclude=()):
464
- """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
523
+ """
524
+ Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
525
+
526
+ Args:
527
+ da (dict): First dictionary.
528
+ db (dict): Second dictionary.
529
+ exclude (tuple, optional): Keys to exclude. Defaults to ().
530
+
531
+ Returns:
532
+ (dict): Dictionary of intersecting keys with matching shapes.
533
+ """
465
534
  return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
466
535
 
467
536
 
468
537
  def is_parallel(model):
469
- """Returns True if model is of type DP or DDP."""
538
+ """
539
+ Returns True if model is of type DP or DDP.
540
+
541
+ Args:
542
+ model (nn.Module): Model to check.
543
+
544
+ Returns:
545
+ (bool): True if model is DataParallel or DistributedDataParallel.
546
+ """
470
547
  return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
471
548
 
472
549
 
473
550
  def de_parallel(model):
474
- """De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
551
+ """
552
+ De-parallelize a model: returns single-GPU model if model is of type DP or DDP.
553
+
554
+ Args:
555
+ model (nn.Module): Model to de-parallelize.
556
+
557
+ Returns:
558
+ (nn.Module): De-parallelized model.
559
+ """
475
560
  return model.module if is_parallel(model) else model
476
561
 
477
562
 
478
563
  def one_cycle(y1=0.0, y2=1.0, steps=100):
479
- """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
564
+ """
565
+ Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
566
+
567
+ Args:
568
+ y1 (float, optional): Initial value. Defaults to 0.0.
569
+ y2 (float, optional): Final value. Defaults to 1.0.
570
+ steps (int, optional): Number of steps. Defaults to 100.
571
+
572
+ Returns:
573
+ (function): Lambda function for computing the sinusoidal ramp.
574
+ """
480
575
  return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
481
576
 
482
577
 
483
578
  def init_seeds(seed=0, deterministic=False):
484
- """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
579
+ """
580
+ Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
581
+
582
+ Args:
583
+ seed (int, optional): Random seed. Defaults to 0.
584
+ deterministic (bool, optional): Whether to set deterministic algorithms. Defaults to False.
585
+ """
485
586
  random.seed(seed)
486
587
  np.random.seed(seed)
487
588
  torch.manual_seed(seed)
@@ -510,16 +611,30 @@ def unset_deterministic():
510
611
 
511
612
  class ModelEMA:
512
613
  """
513
- Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
514
- average of everything in the model state_dict (parameters and buffers).
614
+ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models.
515
615
 
616
+ Keeps a moving average of everything in the model state_dict (parameters and buffers).
516
617
  For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
517
618
 
518
619
  To disable EMA set the `enabled` attribute to `False`.
620
+
621
+ Attributes:
622
+ ema (nn.Module): Copy of the model in evaluation mode.
623
+ updates (int): Number of EMA updates.
624
+ decay (function): Decay function that determines the EMA weight.
625
+ enabled (bool): Whether EMA is enabled.
519
626
  """
520
627
 
521
628
  def __init__(self, model, decay=0.9999, tau=2000, updates=0):
522
- """Initialize EMA for 'model' with given arguments."""
629
+ """
630
+ Initialize EMA for 'model' with given arguments.
631
+
632
+ Args:
633
+ model (nn.Module): Model to create EMA for.
634
+ decay (float, optional): Maximum EMA decay rate. Defaults to 0.9999.
635
+ tau (int, optional): EMA decay time constant. Defaults to 2000.
636
+ updates (int, optional): Initial number of updates. Defaults to 0.
637
+ """
523
638
  self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
524
639
  self.updates = updates # number of EMA updates
525
640
  self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
@@ -528,7 +643,12 @@ class ModelEMA:
528
643
  self.enabled = True
529
644
 
530
645
  def update(self, model):
531
- """Update EMA parameters."""
646
+ """
647
+ Update EMA parameters.
648
+
649
+ Args:
650
+ model (nn.Module): Model to update EMA from.
651
+ """
532
652
  if self.enabled:
533
653
  self.updates += 1
534
654
  d = self.decay(self.updates)
@@ -541,7 +661,14 @@ class ModelEMA:
541
661
  # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
542
662
 
543
663
  def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
544
- """Updates attributes and saves stripped model with optimizer removed."""
664
+ """
665
+ Updates attributes and saves stripped model with optimizer removed.
666
+
667
+ Args:
668
+ model (nn.Module): Model to update attributes from.
669
+ include (tuple, optional): Attributes to include. Defaults to ().
670
+ exclude (tuple, optional): Attributes to exclude. Defaults to ("process_group", "reducer").
671
+ """
545
672
  if self.enabled:
546
673
  copy_attr(self.ema, model, include, exclude)
547
674
 
@@ -551,9 +678,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
551
678
  Strip optimizer from 'f' to finalize training, optionally save as 's'.
552
679
 
553
680
  Args:
554
- f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
555
- s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
556
- updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
681
+ f (str | Path): File path to model to strip the optimizer from. Defaults to 'best.pt'.
682
+ s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
683
+ updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.
557
684
 
558
685
  Returns:
559
686
  (dict): The combined checkpoint dictionary.
@@ -563,9 +690,6 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict
563
690
  >>> from ultralytics.utils.torch_utils import strip_optimizer
564
691
  >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
565
692
  >>> strip_optimizer(f)
566
-
567
- Note:
568
- Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
569
693
  """
570
694
  try:
571
695
  x = torch.load(f, map_location=torch.device("cpu"))
@@ -613,7 +737,11 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
613
737
  """
614
738
  Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
615
739
 
616
- This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
740
+ Args:
741
+ state_dict (dict): Optimizer state dictionary.
742
+
743
+ Returns:
744
+ (dict): Converted optimizer state dictionary with FP16 tensors.
617
745
  """
618
746
  for state in state_dict["state"].values():
619
747
  for k, v in state.items():
@@ -653,6 +781,16 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
653
781
  """
654
782
  Ultralytics speed, memory and FLOPs profiler.
655
783
 
784
+ Args:
785
+ input (torch.Tensor | List[torch.Tensor]): Input tensor(s) to profile.
786
+ ops (nn.Module | List[nn.Module]): Model or list of operations to profile.
787
+ n (int, optional): Number of iterations to average. Defaults to 10.
788
+ device (str | torch.device, optional): Device to profile on. Defaults to None.
789
+ max_num_obj (int, optional): Maximum number of objects for simulation. Defaults to 0.
790
+
791
+ Returns:
792
+ (List): Profile results for each operation.
793
+
656
794
  Examples:
657
795
  >>> from ultralytics.utils.torch_utils import profile
658
796
  >>> input = torch.randn(16, 3, 640, 640)
@@ -721,7 +859,15 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
721
859
 
722
860
 
723
861
  class EarlyStopping:
724
- """Early stopping class that stops training when a specified number of epochs have passed without improvement."""
862
+ """
863
+ Early stopping class that stops training when a specified number of epochs have passed without improvement.
864
+
865
+ Attributes:
866
+ best_fitness (float): Best fitness value observed.
867
+ best_epoch (int): Epoch where best fitness was observed.
868
+ patience (int): Number of epochs to wait after fitness stops improving before stopping.
869
+ possible_stop (bool): Flag indicating if stopping may occur next epoch.
870
+ """
725
871
 
726
872
  def __init__(self, patience=50):
727
873
  """
@@ -770,11 +916,12 @@ class FXModel(nn.Module):
770
916
  """
771
917
  A custom model class for torch.fx compatibility.
772
918
 
773
- This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph manipulation.
774
- It copies attributes from an existing model and explicitly sets the model attribute to ensure proper copying.
919
+ This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
920
+ manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
921
+ copying.
775
922
 
776
- Args:
777
- model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
923
+ Attributes:
924
+ model (nn.Module): The original model's layers.
778
925
  """
779
926
 
780
927
  def __init__(self, model):
@@ -782,7 +929,7 @@ class FXModel(nn.Module):
782
929
  Initialize the FXModel.
783
930
 
784
931
  Args:
785
- model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
932
+ model (nn.Module): The original model to wrap for torch.fx compatibility.
786
933
  """
787
934
  super().__init__()
788
935
  copy_attr(self, model)
@@ -793,7 +940,8 @@ class FXModel(nn.Module):
793
940
  """
794
941
  Forward pass through the model.
795
942
 
796
- This method performs the forward pass through the model, handling the dependencies between layers and saving intermediate outputs.
943
+ This method performs the forward pass through the model, handling the dependencies between layers and saving
944
+ intermediate outputs.
797
945
 
798
946
  Args:
799
947
  x (torch.Tensor): The input tensor to the model.