ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

This version of ultralytics might be problematic.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/utils/instance.py CHANGED
@@ -26,9 +26,9 @@ to_4tuple = _ntuple(4)
  # `xyxy` means left top and right bottom
  # `xywh` means center x, center y and width, height(YOLO format)
  # `ltwh` means left top and width, height(COCO format)
- _formats = ['xyxy', 'xywh', 'ltwh']
+ _formats = ["xyxy", "xywh", "ltwh"]

- __all__ = 'Bboxes', # tuple or list
+ __all__ = ("Bboxes",) # tuple or list


  class Bboxes:
@@ -46,9 +46,9 @@ class Bboxes:
  This class does not handle normalization or denormalization of bounding boxes.
  """

- def __init__(self, bboxes, format='xyxy') -> None:
+ def __init__(self, bboxes, format="xyxy") -> None:
  """Initializes the Bboxes class with bounding box data in a specified format."""
- assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
+ assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
  bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
  assert bboxes.ndim == 2
  assert bboxes.shape[1] == 4
@@ -58,21 +58,21 @@ class Bboxes:

  def convert(self, format):
  """Converts bounding box format from one type to another."""
- assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
+ assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
  if self.format == format:
  return
- elif self.format == 'xyxy':
- func = xyxy2xywh if format == 'xywh' else xyxy2ltwh
- elif self.format == 'xywh':
- func = xywh2xyxy if format == 'xyxy' else xywh2ltwh
+ elif self.format == "xyxy":
+ func = xyxy2xywh if format == "xywh" else xyxy2ltwh
+ elif self.format == "xywh":
+ func = xywh2xyxy if format == "xyxy" else xywh2ltwh
  else:
- func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh
+ func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
  self.bboxes = func(self.bboxes)
  self.format = format

  def areas(self):
  """Return box areas."""
- self.convert('xyxy')
+ self.convert("xyxy")
  return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])

  # def denormalize(self, w, h):
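For orientation, a minimal usage sketch of the Bboxes conversion API touched above. This is a sketch only; it assumes Bboxes is importable from ultralytics.utils.instance, the module this diff modifies (file 118 in the list):

import numpy as np
from ultralytics.utils.instance import Bboxes

# One box as left-top / right-bottom absolute pixels ("xyxy").
boxes = Bboxes(np.array([[10.0, 20.0, 50.0, 80.0]]), format="xyxy")
boxes.convert("xywh")  # now center x, center y, width, height (YOLO format)
print(boxes.bboxes)    # [[30. 50. 40. 60.]]
print(boxes.areas())   # areas() converts back to "xyxy" internally -> [2400.]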
@@ -124,7 +124,7 @@ class Bboxes:
  return len(self.bboxes)

  @classmethod
- def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
+ def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
  """
  Concatenate a list of Bboxes objects into a single Bboxes object.

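And a short sketch of the concatenate classmethod whose annotations are requoted above, under the same assumptions as the previous sketch:

a = Bboxes(np.array([[0.0, 0.0, 4.0, 4.0]]))  # default format is "xyxy"
b = Bboxes(np.array([[1.0, 1.0, 3.0, 5.0]]))
merged = Bboxes.concatenate([a, b])           # np.concatenate along axis 0
assert len(merged) == 2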
@@ -148,7 +148,7 @@ class Bboxes:
  return boxes_list[0]
  return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))

- def __getitem__(self, index) -> 'Bboxes':
+ def __getitem__(self, index) -> "Bboxes":
  """
  Retrieve a specific bounding box or a set of bounding boxes using indexing.

@@ -169,7 +169,7 @@ class Bboxes:
  if isinstance(index, int):
  return Bboxes(self.bboxes[index].view(1, -1))
  b = self.bboxes[index]
- assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
+ assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
  return Bboxes(b)


@@ -205,7 +205,7 @@ class Instances:
  This class does not perform input validation, and it assumes the inputs are well-formed.
  """

- def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
+ def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
  """
  Args:
  bboxes (ndarray): bboxes with shape [N, 4].
@@ -263,7 +263,7 @@ class Instances:

  def add_padding(self, padw, padh):
  """Handle rect and mosaic situation."""
- assert not self.normalized, 'you should add padding with absolute coordinates.'
+ assert not self.normalized, "you should add padding with absolute coordinates."
  self._bboxes.add(offset=(padw, padh, padw, padh))
  self.segments[..., 0] += padw
  self.segments[..., 1] += padh
@@ -271,7 +271,7 @@ class Instances:
  self.keypoints[..., 0] += padw
  self.keypoints[..., 1] += padh

- def __getitem__(self, index) -> 'Instances':
+ def __getitem__(self, index) -> "Instances":
  """
  Retrieve a specific instance or a set of instances using indexing.

@@ -301,7 +301,7 @@ class Instances:

  def flipud(self, h):
  """Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
- if self._bboxes.format == 'xyxy':
+ if self._bboxes.format == "xyxy":
  y1 = self.bboxes[:, 1].copy()
  y2 = self.bboxes[:, 3].copy()
  self.bboxes[:, 1] = h - y2
@@ -314,7 +314,7 @@ class Instances:

  def fliplr(self, w):
  """Reverses the order of the bounding boxes and segments horizontally."""
- if self._bboxes.format == 'xyxy':
+ if self._bboxes.format == "xyxy":
  x1 = self.bboxes[:, 0].copy()
  x2 = self.bboxes[:, 2].copy()
  self.bboxes[:, 0] = w - x2
@@ -328,10 +328,10 @@ class Instances:
  def clip(self, w, h):
  """Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
  ori_format = self._bboxes.format
- self.convert_bbox(format='xyxy')
+ self.convert_bbox(format="xyxy")
  self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
  self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
- if ori_format != 'xyxy':
+ if ori_format != "xyxy":
  self.convert_bbox(format=ori_format)
  self.segments[..., 0] = self.segments[..., 0].clip(0, w)
  self.segments[..., 1] = self.segments[..., 1].clip(0, h)
@@ -367,7 +367,7 @@ class Instances:
  return len(self.bboxes)

  @classmethod
- def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
+ def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
  """
  Concatenates a list of Instances objects into a single Instances object.

ultralytics/utils/loss.py CHANGED
@@ -28,22 +28,27 @@ class VarifocalLoss(nn.Module):
  """Computes varfocal loss."""
  weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
  with torch.cuda.amp.autocast(enabled=False):
- loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
- weight).mean(1).sum()
+ loss = (
+ (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
+ .mean(1)
+ .sum()
+ )
  return loss


  class FocalLoss(nn.Module):
  """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""

- def __init__(self, ):
+ def __init__(
+ self,
+ ):
  """Initializer for FocalLoss class with no parameters."""
  super().__init__()

  @staticmethod
  def forward(pred, label, gamma=1.5, alpha=0.25):
  """Calculates and updates confusion matrix for object detection/classification tasks."""
- loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
+ loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
  # p_t = torch.exp(-loss)
  # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability

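For context on the hunk above: the varifocal weighting down-weights easy negatives by the predicted score raised to gamma, while positives are weighted by the IoU-aware target score. A standalone sketch of the same computation; the defaults gamma=2.0 and alpha=0.75 are assumptions taken from the method's usual signature, which this hunk does not show:

import torch
import torch.nn.functional as F

def varifocal(pred_score, gt_score, label, gamma=2.0, alpha=0.75):
    # Same weighting as the context line above: negatives scaled by
    # sigmoid(pred)^gamma, positives by the IoU-aware gt_score.
    weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
    bce = F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none")
    return (bce * weight).mean(1).sum()

logits = torch.randn(2, 5)  # (batch, classes) raw scores
gt = torch.rand(2, 5)       # IoU-aware target scores
lbl = (gt > 0.5).float()    # positive mask
print(varifocal(logits, gt, lbl))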
@@ -91,8 +96,10 @@ class BboxLoss(nn.Module):
  tr = tl + 1 # target right
  wl = tr - target # weight left
  wr = 1 - wl # weight right
- return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
- F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
+ return (
+ F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+ + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
+ ).mean(-1, keepdim=True)


  class RotatedBboxLoss(BboxLoss):
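The reformatted return expression above is the Distribution Focal Loss: each continuous regression target is split between its two nearest integer bins with linear weights, and cross-entropy is taken against both. A self-contained sketch with illustrative shapes:

import torch
import torch.nn.functional as F

def df_loss(pred_dist, target):
    tl = target.long()  # left bin index
    tr = tl + 1         # right bin index
    wl = tr - target    # weight on the left bin
    wr = 1 - wl         # weight on the right bin
    return (
        F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
        + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
    ).mean(-1, keepdim=True)

pred = torch.randn(8 * 4, 16)   # 4 box sides per target, reg_max = 16 bins
target = torch.rand(8, 4) * 14  # continuous distances in [0, 15)
print(df_loss(pred, target).shape)  # torch.Size([8, 1])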
@@ -145,7 +152,7 @@ class v8DetectionLoss:
  h = model.args # hyperparameters

  m = model.model[-1] # Detect() module
- self.bce = nn.BCEWithLogitsLoss(reduction='none')
+ self.bce = nn.BCEWithLogitsLoss(reduction="none")
  self.hyp = h
  self.stride = m.stride # model strides
  self.nc = m.nc # number of classes
@@ -190,7 +197,8 @@ class v8DetectionLoss:
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
  feats = preds[1] if isinstance(preds, tuple) else preds
  pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )

  pred_scores = pred_scores.permute(0, 2, 1).contiguous()
  pred_distri = pred_distri.permute(0, 2, 1).contiguous()
@@ -201,7 +209,7 @@ class v8DetectionLoss:
  anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

  # Targets
- targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
+ targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
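The targets tensor built in this hunk packs one row per ground-truth box. A tiny sketch of the layout; the box columns are the dataloader's normalized xywh values before preprocess() rescales them:

import torch

batch_idx = torch.tensor([[0.0], [0.0], [1.0]])  # image index within the batch
cls = torch.tensor([[3.0], [7.0], [3.0]])        # class ids
bboxes = torch.rand(3, 4)                        # normalized xywh boxes
targets = torch.cat((batch_idx, cls, bboxes), 1)
assert targets.shape == (3, 6)                   # rows: [batch_idx, cls, x, y, w, h]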
@@ -210,8 +218,13 @@ class v8DetectionLoss:
  pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)

  _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
- pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+ pred_scores.detach().sigmoid(),
+ (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )


  target_scores_sum = max(target_scores.sum(), 1)
@@ -222,8 +235,9 @@ class v8DetectionLoss:
  # Bbox loss
  if fg_mask.sum():
  target_bboxes /= stride_tensor
- loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
- target_scores_sum, fg_mask)
+ loss[0], loss[2] = self.bbox_loss(
+ pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ )

  loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.cls # cls gain
@@ -246,7 +260,8 @@ class v8SegmentationLoss(v8DetectionLoss):
  feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
  batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
  pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )

  # B, grids, ..
  pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@@ -259,24 +274,31 @@ class v8SegmentationLoss(v8DetectionLoss):

  # Targets
  try:
- batch_idx = batch['batch_idx'].view(-1, 1)
- targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
+ batch_idx = batch["batch_idx"].view(-1, 1)
+ targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
  except RuntimeError as e:
- raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
- "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
- 'as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help.') from e
+ raise TypeError(
+ "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
+ "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
+ "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
+ "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
+ ) from e

  # Pboxes
  pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)

  _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+ pred_scores.detach().sigmoid(),
+ (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )


  target_scores_sum = max(target_scores.sum(), 1)
@@ -286,15 +308,23 @@ class v8SegmentationLoss(v8DetectionLoss):

  if fg_mask.sum():
  # Bbox loss
- loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
- target_scores, target_scores_sum, fg_mask)
+ loss[0], loss[3] = self.bbox_loss(
+ pred_distri,
+ pred_bboxes,
+ anchor_points,
+ target_bboxes / stride_tensor,
+ target_scores,
+ target_scores_sum,
+ fg_mask,
+ )
  # Masks loss
- masks = batch['masks'].to(self.device).float()
+ masks = batch["masks"].to(self.device).float()
  if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
- masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
+ masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

- loss[1] = self.calculate_segmentation_loss(fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto,
- pred_masks, imgsz, self.overlap)
+ loss[1] = self.calculate_segmentation_loss(
+ fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
+ )

  # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
  else:
@@ -308,8 +338,9 @@ class v8SegmentationLoss(v8DetectionLoss):
  return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)

  @staticmethod
- def single_mask_loss(gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor,
- area: torch.Tensor) -> torch.Tensor:
+ def single_mask_loss(
+ gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
+ ) -> torch.Tensor:
  """
  Compute the instance segmentation loss for a single image.

@@ -327,8 +358,8 @@ class v8SegmentationLoss(v8DetectionLoss):
  The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
  predicted masks from the prototype masks and predicted mask coefficients.
  """
- pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
- loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
+ pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
+ loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
  return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

  def calculate_segmentation_loss(
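The einsum requoted above is the core of YOLO-style instance segmentation: each instance's 32 mask coefficients linearly combine 32 shared prototype masks. In isolation:

import torch

n, c, h, w = 3, 32, 80, 80
pred = torch.randn(n, c)      # mask coefficients, one row per instance
proto = torch.randn(c, h, w)  # shared prototype masks
masks = torch.einsum("in,nhw->ihw", pred, proto)
assert masks.shape == (n, h, w)  # (3, 32) @ (32, 80, 80) -> (3, 80, 80)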
@@ -387,8 +418,9 @@ class v8SegmentationLoss(v8DetectionLoss):
  else:
  gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

- loss += self.single_mask_loss(gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i],
- marea_i[fg_mask_i])
+ loss += self.single_mask_loss(
+ gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
+ )

  # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
  else:
@@ -415,7 +447,8 @@ class v8PoseLoss(v8DetectionLoss):
  loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
  feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
  pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )

  # B, grids, ..
  pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@@ -428,8 +461,8 @@ class v8PoseLoss(v8DetectionLoss):

  # Targets
  batch_size = pred_scores.shape[0]
- batch_idx = batch['batch_idx'].view(-1, 1)
- targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
+ batch_idx = batch["batch_idx"].view(-1, 1)
+ targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
@@ -439,8 +472,13 @@ class v8PoseLoss(v8DetectionLoss):
  pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)

  _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+ pred_scores.detach().sigmoid(),
+ (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )


  target_scores_sum = max(target_scores.sum(), 1)
@@ -451,14 +489,16 @@ class v8PoseLoss(v8DetectionLoss):
  # Bbox loss
  if fg_mask.sum():
  target_bboxes /= stride_tensor
- loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
- target_scores_sum, fg_mask)
- keypoints = batch['keypoints'].to(self.device).float().clone()
+ loss[0], loss[4] = self.bbox_loss(
+ pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ )
+ keypoints = batch["keypoints"].to(self.device).float().clone()
  keypoints[..., 0] *= imgsz[1]
  keypoints[..., 1] *= imgsz[0]

- loss[1], loss[2] = self.calculate_keypoints_loss(fg_mask, target_gt_idx, keypoints, batch_idx,
- stride_tensor, target_bboxes, pred_kpts)
+ loss[1], loss[2] = self.calculate_keypoints_loss(
+ fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+ )

  loss[0] *= self.hyp.box # box gain
  loss[1] *= self.hyp.pose # pose gain
@@ -477,8 +517,9 @@ class v8PoseLoss(v8DetectionLoss):
  y[..., 1] += anchor_points[:, [1]] - 0.5
  return y

- def calculate_keypoints_loss(self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes,
- pred_kpts):
+ def calculate_keypoints_loss(
+ self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
+ ):
  """
  Calculate the keypoints loss for the model.

@@ -507,21 +548,23 @@ class v8PoseLoss(v8DetectionLoss):
  max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

  # Create a tensor to hold batched keypoints
- batched_keypoints = torch.zeros((batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]),
- device=keypoints.device)
+ batched_keypoints = torch.zeros(
+ (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
+ )

  # TODO: any idea how to vectorize this?
  # Fill batched_keypoints with keypoints based on batch_idx
  for i in range(batch_size):
  keypoints_i = keypoints[batch_idx == i]
- batched_keypoints[i, :keypoints_i.shape[0]] = keypoints_i
+ batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

  # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
  target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

  # Use target_gt_idx_expanded to select keypoints from batched_keypoints
  selected_keypoints = batched_keypoints.gather(
- 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2]))
+ 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
+ )

  # Divide coordinates by stride
  selected_keypoints /= stride_tensor.view(1, -1, 1, 1)
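The gather pattern reformatted above selects, for every anchor, the keypoints of its assigned ground-truth object. A shape-level sketch with illustrative sizes:

import torch

bs, max_kpts, nk, nd, na = 2, 5, 17, 3, 8  # batch, max GT per image, keypoints, dims, anchors
batched_keypoints = torch.randn(bs, max_kpts, nk, nd)
target_gt_idx = torch.randint(0, max_kpts, (bs, na))  # assigned GT index per anchor
idx = target_gt_idx.unsqueeze(-1).unsqueeze(-1)       # (bs, na, 1, 1)
selected = batched_keypoints.gather(1, idx.expand(-1, -1, nk, nd))
assert selected.shape == (bs, na, nk, nd)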
@@ -547,13 +590,12 @@ class v8ClassificationLoss:

  def __call__(self, preds, batch):
  """Compute the classification loss between predictions and true labels."""
- loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='mean')
+ loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction="mean")
  loss_items = loss.detach()
  return loss, loss_items


  class v8OBBLoss(v8DetectionLoss):
-
  def __init__(self, model): # model must be de-paralleled
  super().__init__(model)
  self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
@@ -583,7 +625,8 @@ class v8OBBLoss(v8DetectionLoss):
  feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
  batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width
  pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1)
+ (self.reg_max * 4, self.nc), 1
+ )

  # b, grids, ..
  pred_scores = pred_scores.permute(0, 2, 1).contiguous()
@@ -596,19 +639,21 @@ class v8OBBLoss(v8DetectionLoss):

  # targets
  try:
- batch_idx = batch['batch_idx'].view(-1, 1)
- targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].view(-1, 5)), 1)
+ batch_idx = batch["batch_idx"].view(-1, 1)
+ targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
  rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
  targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
  targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
  mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
  except RuntimeError as e:
- raise TypeError('ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n'
- "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
- 'as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help.') from e
+ raise TypeError(
+ "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
+ "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
+ "i.e. 'yolo train model=yolov8n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
+ "correctly formatted 'OBB' dataset using 'data=coco8-obb.yaml' "
+ "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
+ ) from e

  # Pboxes
  pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)
@@ -616,10 +661,14 @@ class v8OBBLoss(v8DetectionLoss):
  bboxes_for_assigner = pred_bboxes.clone().detach()
  # Only the first four elements need to be scaled
  bboxes_for_assigner[..., :4] *= stride_tensor
- _, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(),
- bboxes_for_assigner.type(gt_bboxes.dtype),
- anchor_points * stride_tensor, gt_labels, gt_bboxes,
- mask_gt)
+ _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+ pred_scores.detach().sigmoid(),
+ bboxes_for_assigner.type(gt_bboxes.dtype),
+ anchor_points * stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ mask_gt,
+ )


  target_scores_sum = max(target_scores.sum(), 1)
@@ -630,8 +679,9 @@ class v8OBBLoss(v8DetectionLoss):
  # Bbox loss
  if fg_mask.sum():
  target_bboxes[..., :4] /= stride_tensor
- loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
- target_scores_sum, fg_mask)
+ loss[0], loss[2] = self.bbox_loss(
+ pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
+ )
  else:
  loss[0] += (pred_angle * 0).sum()
