ultralytics-8.0.238-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. See the registry's advisory page for this release for more details.

Files changed (134):
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -117,11 +117,11 @@ class BaseMixTransform:
117
117
  if self.pre_transform is not None:
118
118
  for i, data in enumerate(mix_labels):
119
119
  mix_labels[i] = self.pre_transform(data)
120
- labels['mix_labels'] = mix_labels
120
+ labels["mix_labels"] = mix_labels
121
121
 
122
122
  # Mosaic or MixUp
123
123
  labels = self._mix_transform(labels)
124
- labels.pop('mix_labels', None)
124
+ labels.pop("mix_labels", None)
125
125
  return labels
126
126
 
127
127
  def _mix_transform(self, labels):
@@ -149,8 +149,8 @@ class Mosaic(BaseMixTransform):
149
149
 
150
150
  def __init__(self, dataset, imgsz=640, p=1.0, n=4):
151
151
  """Initializes the object with a dataset, image size, probability, and border."""
152
- assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
153
- assert n in (4, 9), 'grid must be equal to 4 or 9.'
152
+ assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
153
+ assert n in (4, 9), "grid must be equal to 4 or 9."
154
154
  super().__init__(dataset=dataset, p=p)
155
155
  self.dataset = dataset
156
156
  self.imgsz = imgsz
@@ -166,20 +166,21 @@ class Mosaic(BaseMixTransform):
166
166
 
167
167
  def _mix_transform(self, labels):
168
168
  """Apply mixup transformation to the input image and labels."""
169
- assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
170
- assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
171
- return self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(
172
- labels) # This code is modified for mosaic3 method.
169
+ assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive."
170
+ assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment."
171
+ return (
172
+ self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
173
+ ) # This code is modified for mosaic3 method.
173
174
 
174
175
  def _mosaic3(self, labels):
175
176
  """Create a 1x3 image mosaic."""
176
177
  mosaic_labels = []
177
178
  s = self.imgsz
178
179
  for i in range(3):
179
- labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
180
+ labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
180
181
  # Load image
181
- img = labels_patch['img']
182
- h, w = labels_patch.pop('resized_shape')
182
+ img = labels_patch["img"]
183
+ h, w = labels_patch.pop("resized_shape")
183
184
 
184
185
  # Place img in img3
185
186
  if i == 0: # center
@@ -194,7 +195,7 @@ class Mosaic(BaseMixTransform):
194
195
  padw, padh = c[:2]
195
196
  x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
196
197
 
197
- img3[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img3[ymin:ymax, xmin:xmax]
198
+ img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax]
198
199
  # hp, wp = h, w # height, width previous for next iteration
199
200
 
200
201
  # Labels assuming imgsz*2 mosaic size
@@ -202,7 +203,7 @@ class Mosaic(BaseMixTransform):
202
203
  mosaic_labels.append(labels_patch)
203
204
  final_labels = self._cat_labels(mosaic_labels)
204
205
 
205
- final_labels['img'] = img3[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
206
+ final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
206
207
  return final_labels
207
208
 
208
209
  def _mosaic4(self, labels):
@@ -211,10 +212,10 @@ class Mosaic(BaseMixTransform):
211
212
  s = self.imgsz
212
213
  yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
213
214
  for i in range(4):
214
- labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
215
+ labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
215
216
  # Load image
216
- img = labels_patch['img']
217
- h, w = labels_patch.pop('resized_shape')
217
+ img = labels_patch["img"]
218
+ h, w = labels_patch.pop("resized_shape")
218
219
 
219
220
  # Place img in img4
220
221
  if i == 0: # top left
@@ -238,7 +239,7 @@ class Mosaic(BaseMixTransform):
238
239
  labels_patch = self._update_labels(labels_patch, padw, padh)
239
240
  mosaic_labels.append(labels_patch)
240
241
  final_labels = self._cat_labels(mosaic_labels)
241
- final_labels['img'] = img4
242
+ final_labels["img"] = img4
242
243
  return final_labels
243
244
 
244
245
  def _mosaic9(self, labels):
@@ -247,10 +248,10 @@ class Mosaic(BaseMixTransform):
247
248
  s = self.imgsz
248
249
  hp, wp = -1, -1 # height, width previous
249
250
  for i in range(9):
250
- labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
251
+ labels_patch = labels if i == 0 else labels["mix_labels"][i - 1]
251
252
  # Load image
252
- img = labels_patch['img']
253
- h, w = labels_patch.pop('resized_shape')
253
+ img = labels_patch["img"]
254
+ h, w = labels_patch.pop("resized_shape")
254
255
 
255
256
  # Place img in img9
256
257
  if i == 0: # center
@@ -278,7 +279,7 @@ class Mosaic(BaseMixTransform):
278
279
  x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
279
280
 
280
281
  # Image
281
- img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
282
+ img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax]
282
283
  hp, wp = h, w # height, width previous for next iteration
283
284
 
284
285
  # Labels assuming imgsz*2 mosaic size
@@ -286,16 +287,16 @@ class Mosaic(BaseMixTransform):
286
287
  mosaic_labels.append(labels_patch)
287
288
  final_labels = self._cat_labels(mosaic_labels)
288
289
 
289
- final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
290
+ final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]]
290
291
  return final_labels
291
292
 
292
293
  @staticmethod
293
294
  def _update_labels(labels, padw, padh):
294
295
  """Update labels."""
295
- nh, nw = labels['img'].shape[:2]
296
- labels['instances'].convert_bbox(format='xyxy')
297
- labels['instances'].denormalize(nw, nh)
298
- labels['instances'].add_padding(padw, padh)
296
+ nh, nw = labels["img"].shape[:2]
297
+ labels["instances"].convert_bbox(format="xyxy")
298
+ labels["instances"].denormalize(nw, nh)
299
+ labels["instances"].add_padding(padw, padh)
299
300
  return labels
300
301
 
301
302
  def _cat_labels(self, mosaic_labels):
@@ -306,18 +307,20 @@ class Mosaic(BaseMixTransform):
306
307
  instances = []
307
308
  imgsz = self.imgsz * 2 # mosaic imgsz
308
309
  for labels in mosaic_labels:
309
- cls.append(labels['cls'])
310
- instances.append(labels['instances'])
310
+ cls.append(labels["cls"])
311
+ instances.append(labels["instances"])
312
+ # Final labels
311
313
  final_labels = {
312
- 'im_file': mosaic_labels[0]['im_file'],
313
- 'ori_shape': mosaic_labels[0]['ori_shape'],
314
- 'resized_shape': (imgsz, imgsz),
315
- 'cls': np.concatenate(cls, 0),
316
- 'instances': Instances.concatenate(instances, axis=0),
317
- 'mosaic_border': self.border} # final_labels
318
- final_labels['instances'].clip(imgsz, imgsz)
319
- good = final_labels['instances'].remove_zero_area_boxes()
320
- final_labels['cls'] = final_labels['cls'][good]
314
+ "im_file": mosaic_labels[0]["im_file"],
315
+ "ori_shape": mosaic_labels[0]["ori_shape"],
316
+ "resized_shape": (imgsz, imgsz),
317
+ "cls": np.concatenate(cls, 0),
318
+ "instances": Instances.concatenate(instances, axis=0),
319
+ "mosaic_border": self.border,
320
+ }
321
+ final_labels["instances"].clip(imgsz, imgsz)
322
+ good = final_labels["instances"].remove_zero_area_boxes()
323
+ final_labels["cls"] = final_labels["cls"][good]
321
324
  return final_labels
322
325
 
323
326
 
@@ -335,10 +338,10 @@ class MixUp(BaseMixTransform):
335
338
  def _mix_transform(self, labels):
336
339
  """Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf."""
337
340
  r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
338
- labels2 = labels['mix_labels'][0]
339
- labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
340
- labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
341
- labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
341
+ labels2 = labels["mix_labels"][0]
342
+ labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
343
+ labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
344
+ labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
342
345
  return labels
343
346
 
344
347
 
@@ -366,14 +369,9 @@ class RandomPerspective:
366
369
  box_candidates(box1, box2): Filters out bounding boxes that don't meet certain criteria post-transformation.
367
370
  """
368
371
 
369
- def __init__(self,
370
- degrees=0.0,
371
- translate=0.1,
372
- scale=0.5,
373
- shear=0.0,
374
- perspective=0.0,
375
- border=(0, 0),
376
- pre_transform=None):
372
+ def __init__(
373
+ self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None
374
+ ):
377
375
  """Initializes RandomPerspective object with transformation parameters."""
378
376
 
379
377
  self.degrees = degrees
@@ -519,18 +517,18 @@ class RandomPerspective:
519
517
  Args:
520
518
  labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
521
519
  """
522
- if self.pre_transform and 'mosaic_border' not in labels:
520
+ if self.pre_transform and "mosaic_border" not in labels:
523
521
  labels = self.pre_transform(labels)
524
- labels.pop('ratio_pad', None) # do not need ratio pad
522
+ labels.pop("ratio_pad", None) # do not need ratio pad
525
523
 
526
- img = labels['img']
527
- cls = labels['cls']
528
- instances = labels.pop('instances')
524
+ img = labels["img"]
525
+ cls = labels["cls"]
526
+ instances = labels.pop("instances")
529
527
  # Make sure the coord formats are right
530
- instances.convert_bbox(format='xyxy')
528
+ instances.convert_bbox(format="xyxy")
531
529
  instances.denormalize(*img.shape[:2][::-1])
532
530
 
533
- border = labels.pop('mosaic_border', self.border)
531
+ border = labels.pop("mosaic_border", self.border)
534
532
  self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
535
533
  # M is affine matrix
536
534
  # Scale for func:`box_candidates`
@@ -546,20 +544,20 @@ class RandomPerspective:
546
544
 
547
545
  if keypoints is not None:
548
546
  keypoints = self.apply_keypoints(keypoints, M)
549
- new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
547
+ new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
550
548
  # Clip
551
549
  new_instances.clip(*self.size)
552
550
 
553
551
  # Filter instances
554
552
  instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
555
553
  # Make the bboxes have the same scale with new_bboxes
556
- i = self.box_candidates(box1=instances.bboxes.T,
557
- box2=new_instances.bboxes.T,
558
- area_thr=0.01 if len(segments) else 0.10)
559
- labels['instances'] = new_instances[i]
560
- labels['cls'] = cls[i]
561
- labels['img'] = img
562
- labels['resized_shape'] = img.shape[:2]
554
+ i = self.box_candidates(
555
+ box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10
556
+ )
557
+ labels["instances"] = new_instances[i]
558
+ labels["cls"] = cls[i]
559
+ labels["img"] = img
560
+ labels["resized_shape"] = img.shape[:2]
563
561
  return labels
564
562
 
565
563
  def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
@@ -611,7 +609,7 @@ class RandomHSV:
611
609
 
612
610
  The modified image replaces the original image in the input 'labels' dict.
613
611
  """
614
- img = labels['img']
612
+ img = labels["img"]
615
613
  if self.hgain or self.sgain or self.vgain:
616
614
  r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
617
615
  hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
@@ -634,7 +632,7 @@ class RandomFlip:
634
632
  Also updates any instances (bounding boxes, keypoints, etc.) accordingly.
635
633
  """
636
634
 
637
- def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
635
+ def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
638
636
  """
639
637
  Initializes the RandomFlip class with probability and direction.
640
638
 
@@ -644,7 +642,7 @@ class RandomFlip:
644
642
  Default is 'horizontal'.
645
643
  flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
646
644
  """
647
- assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
645
+ assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
648
646
  assert 0 <= p <= 1.0
649
647
 
650
648
  self.p = p
@@ -662,25 +660,25 @@ class RandomFlip:
662
660
  Returns:
663
661
  (dict): The same dict with the flipped image and updated instances under the 'img' and 'instances' keys.
664
662
  """
665
- img = labels['img']
666
- instances = labels.pop('instances')
667
- instances.convert_bbox(format='xywh')
663
+ img = labels["img"]
664
+ instances = labels.pop("instances")
665
+ instances.convert_bbox(format="xywh")
668
666
  h, w = img.shape[:2]
669
667
  h = 1 if instances.normalized else h
670
668
  w = 1 if instances.normalized else w
671
669
 
672
670
  # Flip up-down
673
- if self.direction == 'vertical' and random.random() < self.p:
671
+ if self.direction == "vertical" and random.random() < self.p:
674
672
  img = np.flipud(img)
675
673
  instances.flipud(h)
676
- if self.direction == 'horizontal' and random.random() < self.p:
674
+ if self.direction == "horizontal" and random.random() < self.p:
677
675
  img = np.fliplr(img)
678
676
  instances.fliplr(w)
679
677
  # For keypoints
680
678
  if self.flip_idx is not None and instances.keypoints is not None:
681
679
  instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
682
- labels['img'] = np.ascontiguousarray(img)
683
- labels['instances'] = instances
680
+ labels["img"] = np.ascontiguousarray(img)
681
+ labels["instances"] = instances
684
682
  return labels
685
683
 
686
684
 
@@ -700,9 +698,9 @@ class LetterBox:
700
698
  """Return updated labels and image with added border."""
701
699
  if labels is None:
702
700
  labels = {}
703
- img = labels.get('img') if image is None else image
701
+ img = labels.get("img") if image is None else image
704
702
  shape = img.shape[:2] # current shape [height, width]
705
- new_shape = labels.pop('rect_shape', self.new_shape)
703
+ new_shape = labels.pop("rect_shape", self.new_shape)
706
704
  if isinstance(new_shape, int):
707
705
  new_shape = (new_shape, new_shape)
708
706
 
@@ -730,25 +728,26 @@ class LetterBox:
730
728
  img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
731
729
  top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
732
730
  left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
733
- img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
734
- value=(114, 114, 114)) # add border
735
- if labels.get('ratio_pad'):
736
- labels['ratio_pad'] = (labels['ratio_pad'], (left, top)) # for evaluation
731
+ img = cv2.copyMakeBorder(
732
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
733
+ ) # add border
734
+ if labels.get("ratio_pad"):
735
+ labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
737
736
 
738
737
  if len(labels):
739
738
  labels = self._update_labels(labels, ratio, dw, dh)
740
- labels['img'] = img
741
- labels['resized_shape'] = new_shape
739
+ labels["img"] = img
740
+ labels["resized_shape"] = new_shape
742
741
  return labels
743
742
  else:
744
743
  return img
745
744
 
746
745
  def _update_labels(self, labels, ratio, padw, padh):
747
746
  """Update labels."""
748
- labels['instances'].convert_bbox(format='xyxy')
749
- labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
750
- labels['instances'].scale(*ratio)
751
- labels['instances'].add_padding(padw, padh)
747
+ labels["instances"].convert_bbox(format="xyxy")
748
+ labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
749
+ labels["instances"].scale(*ratio)
750
+ labels["instances"].add_padding(padw, padh)
752
751
  return labels
753
752
 
754
753
 
@@ -785,11 +784,11 @@ class CopyPaste:
785
784
  1. Instances are expected to have 'segments' as one of their attributes for this augmentation to work.
786
785
  2. This method modifies the input dictionary 'labels' in place.
787
786
  """
788
- im = labels['img']
789
- cls = labels['cls']
787
+ im = labels["img"]
788
+ cls = labels["cls"]
790
789
  h, w = im.shape[:2]
791
- instances = labels.pop('instances')
792
- instances.convert_bbox(format='xyxy')
790
+ instances = labels.pop("instances")
791
+ instances.convert_bbox(format="xyxy")
793
792
  instances.denormalize(w, h)
794
793
  if self.p and len(instances.segments):
795
794
  n = len(instances)
@@ -812,9 +811,9 @@ class CopyPaste:
812
811
  i = cv2.flip(im_new, 1).astype(bool)
813
812
  im[i] = result[i]
814
813
 
815
- labels['img'] = im
816
- labels['cls'] = cls
817
- labels['instances'] = instances
814
+ labels["img"] = im
815
+ labels["cls"] = cls
816
+ labels["instances"] = instances
818
817
  return labels
819
818
 
820
819
 
@@ -831,12 +830,13 @@ class Albumentations:
831
830
  """Initialize the transform object for YOLO bbox formatted params."""
832
831
  self.p = p
833
832
  self.transform = None
834
- prefix = colorstr('albumentations: ')
833
+ prefix = colorstr("albumentations: ")
835
834
  try:
836
835
  import albumentations as A
837
836
 
838
- check_version(A.__version__, '1.0.3', hard=True) # version requirement
837
+ check_version(A.__version__, "1.0.3", hard=True) # version requirement
839
838
 
839
+ # Transforms
840
840
  T = [
841
841
  A.Blur(p=0.01),
842
842
  A.MedianBlur(p=0.01),
@@ -844,31 +844,32 @@ class Albumentations:
844
844
  A.CLAHE(p=0.01),
845
845
  A.RandomBrightnessContrast(p=0.0),
846
846
  A.RandomGamma(p=0.0),
847
- A.ImageCompression(quality_lower=75, p=0.0)] # transforms
848
- self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
847
+ A.ImageCompression(quality_lower=75, p=0.0),
848
+ ]
849
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
849
850
 
850
- LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
851
+ LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
851
852
  except ImportError: # package not installed, skip
852
853
  pass
853
854
  except Exception as e:
854
- LOGGER.info(f'{prefix}{e}')
855
+ LOGGER.info(f"{prefix}{e}")
855
856
 
856
857
  def __call__(self, labels):
857
858
  """Generates object detections and returns a dictionary with detection results."""
858
- im = labels['img']
859
- cls = labels['cls']
859
+ im = labels["img"]
860
+ cls = labels["cls"]
860
861
  if len(cls):
861
- labels['instances'].convert_bbox('xywh')
862
- labels['instances'].normalize(*im.shape[:2][::-1])
863
- bboxes = labels['instances'].bboxes
862
+ labels["instances"].convert_bbox("xywh")
863
+ labels["instances"].normalize(*im.shape[:2][::-1])
864
+ bboxes = labels["instances"].bboxes
864
865
  # TODO: add supports of segments and keypoints
865
866
  if self.transform and random.random() < self.p:
866
867
  new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
867
- if len(new['class_labels']) > 0: # skip update if no bbox in new im
868
- labels['img'] = new['image']
869
- labels['cls'] = np.array(new['class_labels'])
870
- bboxes = np.array(new['bboxes'], dtype=np.float32)
871
- labels['instances'].update(bboxes=bboxes)
868
+ if len(new["class_labels"]) > 0: # skip update if no bbox in new im
869
+ labels["img"] = new["image"]
870
+ labels["cls"] = np.array(new["class_labels"])
871
+ bboxes = np.array(new["bboxes"], dtype=np.float32)
872
+ labels["instances"].update(bboxes=bboxes)
872
873
  return labels
873
874
 
874
875
 
@@ -888,15 +889,17 @@ class Format:
888
889
  batch_idx (bool): Keep batch indexes. Default is True.
889
890
  """
890
891
 
891
- def __init__(self,
892
- bbox_format='xywh',
893
- normalize=True,
894
- return_mask=False,
895
- return_keypoint=False,
896
- return_obb=False,
897
- mask_ratio=4,
898
- mask_overlap=True,
899
- batch_idx=True):
892
+ def __init__(
893
+ self,
894
+ bbox_format="xywh",
895
+ normalize=True,
896
+ return_mask=False,
897
+ return_keypoint=False,
898
+ return_obb=False,
899
+ mask_ratio=4,
900
+ mask_overlap=True,
901
+ batch_idx=True,
902
+ ):
900
903
  """Initializes the Format class with given parameters."""
901
904
  self.bbox_format = bbox_format
902
905
  self.normalize = normalize
@@ -909,10 +912,10 @@ class Format:
909
912
 
910
913
  def __call__(self, labels):
911
914
  """Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
912
- img = labels.pop('img')
915
+ img = labels.pop("img")
913
916
  h, w = img.shape[:2]
914
- cls = labels.pop('cls')
915
- instances = labels.pop('instances')
917
+ cls = labels.pop("cls")
918
+ instances = labels.pop("instances")
916
919
  instances.convert_bbox(format=self.bbox_format)
917
920
  instances.denormalize(w, h)
918
921
  nl = len(instances)
@@ -922,22 +925,24 @@ class Format:
922
925
  masks, instances, cls = self._format_segments(instances, cls, w, h)
923
926
  masks = torch.from_numpy(masks)
924
927
  else:
925
- masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
926
- img.shape[1] // self.mask_ratio)
927
- labels['masks'] = masks
928
+ masks = torch.zeros(
929
+ 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio
930
+ )
931
+ labels["masks"] = masks
928
932
  if self.normalize:
929
933
  instances.normalize(w, h)
930
- labels['img'] = self._format_img(img)
931
- labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
932
- labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
934
+ labels["img"] = self._format_img(img)
935
+ labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
936
+ labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
933
937
  if self.return_keypoint:
934
- labels['keypoints'] = torch.from_numpy(instances.keypoints)
938
+ labels["keypoints"] = torch.from_numpy(instances.keypoints)
935
939
  if self.return_obb:
936
- labels['bboxes'] = xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(
937
- instances.segments) else torch.zeros((0, 5))
940
+ labels["bboxes"] = (
941
+ xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5))
942
+ )
938
943
  # Then we can use collate_fn
939
944
  if self.batch_idx:
940
- labels['batch_idx'] = torch.zeros(nl)
945
+ labels["batch_idx"] = torch.zeros(nl)
941
946
  return labels
942
947
 
943
948
  def _format_img(self, img):
@@ -964,33 +969,39 @@ class Format:
964
969
 
965
970
  def v8_transforms(dataset, imgsz, hyp, stretch=False):
966
971
  """Convert images to a size suitable for YOLOv8 training."""
967
- pre_transform = Compose([
968
- Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
969
- CopyPaste(p=hyp.copy_paste),
970
- RandomPerspective(
971
- degrees=hyp.degrees,
972
- translate=hyp.translate,
973
- scale=hyp.scale,
974
- shear=hyp.shear,
975
- perspective=hyp.perspective,
976
- pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
977
- )])
978
- flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation
972
+ pre_transform = Compose(
973
+ [
974
+ Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
975
+ CopyPaste(p=hyp.copy_paste),
976
+ RandomPerspective(
977
+ degrees=hyp.degrees,
978
+ translate=hyp.translate,
979
+ scale=hyp.scale,
980
+ shear=hyp.shear,
981
+ perspective=hyp.perspective,
982
+ pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
983
+ ),
984
+ ]
985
+ )
986
+ flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
979
987
  if dataset.use_keypoints:
980
- kpt_shape = dataset.data.get('kpt_shape', None)
988
+ kpt_shape = dataset.data.get("kpt_shape", None)
981
989
  if len(flip_idx) == 0 and hyp.fliplr > 0.0:
982
990
  hyp.fliplr = 0.0
983
991
  LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
984
992
  elif flip_idx and (len(flip_idx) != kpt_shape[0]):
985
- raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
993
+ raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}")
986
994
 
987
- return Compose([
988
- pre_transform,
989
- MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
990
- Albumentations(p=1.0),
991
- RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
992
- RandomFlip(direction='vertical', p=hyp.flipud),
993
- RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
995
+ return Compose(
996
+ [
997
+ pre_transform,
998
+ MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
999
+ Albumentations(p=1.0),
1000
+ RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
1001
+ RandomFlip(direction="vertical", p=hyp.flipud),
1002
+ RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),
1003
+ ]
1004
+ ) # transforms
994
1005
 
995
1006
 
996
1007
  # Classification augmentations -----------------------------------------------------------------------------------------
@@ -1031,10 +1042,13 @@ def classify_transforms(
1031
1042
  tfl = [T.Resize(scale_size)]
1032
1043
  tfl += [T.CenterCrop(size)]
1033
1044
 
1034
- tfl += [T.ToTensor(), T.Normalize(
1035
- mean=torch.tensor(mean),
1036
- std=torch.tensor(std),
1037
- )]
1045
+ tfl += [
1046
+ T.ToTensor(),
1047
+ T.Normalize(
1048
+ mean=torch.tensor(mean),
1049
+ std=torch.tensor(std),
1050
+ ),
1051
+ ]
1038
1052
 
1039
1053
  return T.Compose(tfl)
1040
1054
 
@@ -1053,7 +1067,7 @@ def classify_augmentations(
1053
1067
  hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
1054
1068
  hsv_v=0.4, # image HSV-Value augmentation (fraction)
1055
1069
  force_color_jitter=False,
1056
- erasing=0.,
1070
+ erasing=0.0,
1057
1071
  interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
1058
1072
  ):
1059
1073
  """
@@ -1080,13 +1094,13 @@ def classify_augmentations(
1080
1094
  """
1081
1095
  # Transforms to apply if albumentations not installed
1082
1096
  if not isinstance(size, int):
1083
- raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
1097
+ raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
1084
1098
  scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
1085
- ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
1099
+ ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
1086
1100
  primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
1087
- if hflip > 0.:
1101
+ if hflip > 0.0:
1088
1102
  primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
1089
- if vflip > 0.:
1103
+ if vflip > 0.0:
1090
1104
  primary_tfl += [T.RandomVerticalFlip(p=vflip)]
1091
1105
 
1092
1106
  secondary_tfl = []
@@ -1097,27 +1111,29 @@ def classify_augmentations(
1097
1111
  # this allows override without breaking old hparm cfgs
1098
1112
  disable_color_jitter = not force_color_jitter
1099
1113
 
1100
- if auto_augment == 'randaugment':
1114
+ if auto_augment == "randaugment":
1101
1115
  if TORCHVISION_0_11:
1102
1116
  secondary_tfl += [T.RandAugment(interpolation=interpolation)]
1103
1117
  else:
1104
1118
  LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
1105
1119
 
1106
- elif auto_augment == 'augmix':
1120
+ elif auto_augment == "augmix":
1107
1121
  if TORCHVISION_0_13:
1108
1122
  secondary_tfl += [T.AugMix(interpolation=interpolation)]
1109
1123
  else:
1110
1124
  LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
1111
1125
 
1112
- elif auto_augment == 'autoaugment':
1126
+ elif auto_augment == "autoaugment":
1113
1127
  if TORCHVISION_0_10:
1114
1128
  secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
1115
1129
  else:
1116
1130
  LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
1117
1131
 
1118
1132
  else:
1119
- raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
1120
- f'"augmix", "autoaugment" or None')
1133
+ raise ValueError(
1134
+ f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
1135
+ f'"augmix", "autoaugment" or None'
1136
+ )
1121
1137
 
1122
1138
  if not disable_color_jitter:
1123
1139
  secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
@@ -1125,7 +1141,8 @@ def classify_augmentations(
1125
1141
  final_tfl = [
1126
1142
  T.ToTensor(),
1127
1143
  T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
1128
- T.RandomErasing(p=erasing, inplace=True)]
1144
+ T.RandomErasing(p=erasing, inplace=True),
1145
+ ]
1129
1146
 
1130
1147
  return T.Compose(primary_tfl + secondary_tfl + final_tfl)
1131
1148
 
@@ -1177,7 +1194,7 @@ class ClassifyLetterBox:
1177
1194
 
1178
1195
  # Create padded image
1179
1196
  im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
1180
- im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
1197
+ im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
1181
1198
  return im_out
1182
1199
 
1183
1200
 
@@ -1205,7 +1222,7 @@ class CenterCrop:
1205
1222
  imh, imw = im.shape[:2]
1206
1223
  m = min(imh, imw) # min dimension
1207
1224
  top, left = (imh - m) // 2, (imw - m) // 2
1208
- return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
1225
+ return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
1209
1226
 
1210
1227
 
1211
1228
  # NOTE: keep this class for backward compatibility