ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (137)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -33,15 +33,55 @@ class Colors:
33
33
 
34
34
  def __init__(self):
35
35
  """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
36
- hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
37
- '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
38
- self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
36
+ hexs = (
37
+ "FF3838",
38
+ "FF9D97",
39
+ "FF701F",
40
+ "FFB21D",
41
+ "CFD231",
42
+ "48F90A",
43
+ "92CC17",
44
+ "3DDB86",
45
+ "1A9334",
46
+ "00D4BB",
47
+ "2C99A8",
48
+ "00C2FF",
49
+ "344593",
50
+ "6473FF",
51
+ "0018EC",
52
+ "8438FF",
53
+ "520085",
54
+ "CB38FF",
55
+ "FF95C8",
56
+ "FF37C7",
57
+ )
58
+ self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
39
59
  self.n = len(self.palette)
40
- self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
41
- [153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
42
- [255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
43
- [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
44
- dtype=np.uint8)
60
+ self.pose_palette = np.array(
61
+ [
62
+ [255, 128, 0],
63
+ [255, 153, 51],
64
+ [255, 178, 102],
65
+ [230, 230, 0],
66
+ [255, 153, 255],
67
+ [153, 204, 255],
68
+ [255, 102, 255],
69
+ [255, 51, 255],
70
+ [102, 178, 255],
71
+ [51, 153, 255],
72
+ [255, 153, 153],
73
+ [255, 102, 102],
74
+ [255, 51, 51],
75
+ [153, 255, 153],
76
+ [102, 255, 102],
77
+ [51, 255, 51],
78
+ [0, 255, 0],
79
+ [0, 0, 255],
80
+ [255, 0, 0],
81
+ [255, 255, 255],
82
+ ],
83
+ dtype=np.uint8,
84
+ )
45
85
 
46
86
  def __call__(self, i, bgr=False):
47
87
  """Converts hex color codes to RGB values."""
@@ -51,7 +91,7 @@ class Colors:
51
91
  @staticmethod
52
92
  def hex2rgb(h):
53
93
  """Converts hex color codes to RGB values (i.e. default PIL order)."""
54
- return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
94
+ return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
55
95
 
56
96
 
57
97
  colors = Colors() # create instance for 'from utils.plots import colors'
@@ -71,9 +111,9 @@ class Annotator:
71
111
  kpt_color (List[int]): Color palette for keypoints.
72
112
  """
73
113
 
74
- def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
114
+ def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
75
115
  """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
76
- assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
116
+ assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images."
77
117
  non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
78
118
  self.pil = pil or non_ascii
79
119
  self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
@@ -81,26 +121,45 @@ class Annotator:
81
121
  self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
82
122
  self.draw = ImageDraw.Draw(self.im)
83
123
  try:
84
- font = check_font('Arial.Unicode.ttf' if non_ascii else font)
124
+ font = check_font("Arial.Unicode.ttf" if non_ascii else font)
85
125
  size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
86
126
  self.font = ImageFont.truetype(str(font), size)
87
127
  except Exception:
88
128
  self.font = ImageFont.load_default()
89
129
  # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
90
- if check_version(pil_version, '9.2.0'):
130
+ if check_version(pil_version, "9.2.0"):
91
131
  self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
92
132
  else: # use cv2
93
133
  self.im = im if im.flags.writeable else im.copy()
94
134
  self.tf = max(self.lw - 1, 1) # font thickness
95
135
  self.sf = self.lw / 3 # font scale
96
136
  # Pose
97
- self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
98
- [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
137
+ self.skeleton = [
138
+ [16, 14],
139
+ [14, 12],
140
+ [17, 15],
141
+ [15, 13],
142
+ [12, 13],
143
+ [6, 12],
144
+ [7, 13],
145
+ [6, 7],
146
+ [6, 8],
147
+ [7, 9],
148
+ [8, 10],
149
+ [9, 11],
150
+ [2, 3],
151
+ [1, 2],
152
+ [1, 3],
153
+ [2, 4],
154
+ [3, 5],
155
+ [4, 6],
156
+ [5, 7],
157
+ ]
99
158
 
100
159
  self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
101
160
  self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
102
161
 
103
- def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
162
+ def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
104
163
  """Add one xyxy box to image with label."""
105
164
  if isinstance(box, torch.Tensor):
106
165
  box = box.tolist()
@@ -134,13 +193,16 @@ class Annotator:
134
193
  outside = p1[1] - h >= 3
135
194
  p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
136
195
  cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
137
- cv2.putText(self.im,
138
- label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
139
- 0,
140
- self.sf,
141
- txt_color,
142
- thickness=self.tf,
143
- lineType=cv2.LINE_AA)
196
+ cv2.putText(
197
+ self.im,
198
+ label,
199
+ (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
200
+ 0,
201
+ self.sf,
202
+ txt_color,
203
+ thickness=self.tf,
204
+ lineType=cv2.LINE_AA,
205
+ )
144
206
 
145
207
  def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
146
208
  """
@@ -171,7 +233,7 @@ class Annotator:
171
233
  im_gpu = im_gpu.flip(dims=[0]) # flip channel
172
234
  im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
173
235
  im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
174
- im_mask = (im_gpu * 255)
236
+ im_mask = im_gpu * 255
175
237
  im_mask_np = im_mask.byte().cpu().numpy()
176
238
  self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
177
239
  if self.pil:
@@ -230,9 +292,9 @@ class Annotator:
230
292
  """Add rectangle to image (PIL-only)."""
231
293
  self.draw.rectangle(xy, fill, outline, width)
232
294
 
233
- def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
295
+ def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
234
296
  """Adds text to an image using PIL or cv2."""
235
- if anchor == 'bottom': # start y from font bottom
297
+ if anchor == "bottom": # start y from font bottom
236
298
  w, h = self.font.getsize(text) # text width, height
237
299
  xy[1] += 1 - h
238
300
  if self.pil:
@@ -241,8 +303,8 @@ class Annotator:
241
303
  self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
242
304
  # Using `txt_color` for background and draw fg with white color
243
305
  txt_color = (255, 255, 255)
244
- if '\n' in text:
245
- lines = text.split('\n')
306
+ if "\n" in text:
307
+ lines = text.split("\n")
246
308
  _, h = self.font.getsize(text)
247
309
  for line in lines:
248
310
  self.draw.text(xy, line, fill=txt_color, font=self.font)
@@ -314,15 +376,12 @@ class Annotator:
314
376
  text_y = t_size_in[1]
315
377
 
316
378
  # Create a rounded rectangle for in_count
317
- cv2.rectangle(self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color,
318
- -1)
319
- cv2.putText(self.im,
320
- str(counts), (text_x, text_y + t_size_in[1]),
321
- 0,
322
- tl / 2,
323
- txt_color,
324
- self.tf,
325
- lineType=cv2.LINE_AA)
379
+ cv2.rectangle(
380
+ self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1
381
+ )
382
+ cv2.putText(
383
+ self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA
384
+ )
326
385
 
327
386
  @staticmethod
328
387
  def estimate_pose_angle(a, b, c):
@@ -375,7 +434,7 @@ class Annotator:
375
434
  center_kpt (int): centroid pose index for workout monitoring
376
435
  line_thickness (int): thickness for text display
377
436
  """
378
- angle_text, count_text, stage_text = (f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}')
437
+ angle_text, count_text, stage_text = (f" {angle_text:.2f}", "Steps : " + f"{count_text}", f" {stage_text}")
379
438
  font_scale = 0.6 + (line_thickness / 10.0)
380
439
 
381
440
  # Draw angle
@@ -383,21 +442,37 @@ class Annotator:
383
442
  angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
384
443
  angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
385
444
  angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2))
386
- cv2.rectangle(self.im, angle_background_position, (angle_background_position[0] + angle_background_size[0],
387
- angle_background_position[1] + angle_background_size[1]),
388
- (255, 255, 255), -1)
445
+ cv2.rectangle(
446
+ self.im,
447
+ angle_background_position,
448
+ (
449
+ angle_background_position[0] + angle_background_size[0],
450
+ angle_background_position[1] + angle_background_size[1],
451
+ ),
452
+ (255, 255, 255),
453
+ -1,
454
+ )
389
455
  cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness)
390
456
 
391
457
  # Draw Counts
392
458
  (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness)
393
459
  count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
394
- count_background_position = (angle_background_position[0],
395
- angle_background_position[1] + angle_background_size[1] + 5)
460
+ count_background_position = (
461
+ angle_background_position[0],
462
+ angle_background_position[1] + angle_background_size[1] + 5,
463
+ )
396
464
  count_background_size = (count_text_width + 10, count_text_height + 10 + (line_thickness * 2))
397
465
 
398
- cv2.rectangle(self.im, count_background_position, (count_background_position[0] + count_background_size[0],
399
- count_background_position[1] + count_background_size[1]),
400
- (255, 255, 255), -1)
466
+ cv2.rectangle(
467
+ self.im,
468
+ count_background_position,
469
+ (
470
+ count_background_position[0] + count_background_size[0],
471
+ count_background_position[1] + count_background_size[1],
472
+ ),
473
+ (255, 255, 255),
474
+ -1,
475
+ )
401
476
  cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness)
402
477
 
403
478
  # Draw Stage
@@ -406,9 +481,16 @@ class Annotator:
406
481
  stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
407
482
  stage_background_size = (stage_text_width + 10, stage_text_height + 10)
408
483
 
409
- cv2.rectangle(self.im, stage_background_position, (stage_background_position[0] + stage_background_size[0],
410
- stage_background_position[1] + stage_background_size[1]),
411
- (255, 255, 255), -1)
484
+ cv2.rectangle(
485
+ self.im,
486
+ stage_background_position,
487
+ (
488
+ stage_background_position[0] + stage_background_size[0],
489
+ stage_background_position[1] + stage_background_size[1],
490
+ ),
491
+ (255, 255, 255),
492
+ -1,
493
+ )
412
494
  cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)
413
495
 
414
496
  def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
@@ -423,14 +505,20 @@ class Annotator:
423
505
  """
424
506
  cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
425
507
 
426
- label = f'Track ID: {track_label}' if track_label else det_label
508
+ label = f"Track ID: {track_label}" if track_label else det_label
427
509
  text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
428
510
 
429
- cv2.rectangle(self.im, (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
430
- (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), mask_color, -1)
511
+ cv2.rectangle(
512
+ self.im,
513
+ (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
514
+ (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)),
515
+ mask_color,
516
+ -1,
517
+ )
431
518
 
432
- cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255),
433
- 2)
519
+ cv2.putText(
520
+ self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255), 2
521
+ )
434
522
 
435
523
  def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
436
524
  """
@@ -452,24 +540,24 @@ class Annotator:
452
540
 
453
541
  @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
454
542
  @plt_settings()
455
- def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
543
+ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
456
544
  """Plot training labels including class histograms and box statistics."""
457
545
  import pandas as pd
458
546
  import seaborn as sn
459
547
 
460
548
  # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
461
- warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
462
- warnings.filterwarnings('ignore', category=FutureWarning)
549
+ warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
550
+ warnings.filterwarnings("ignore", category=FutureWarning)
463
551
 
464
552
  # Plot dataset labels
465
553
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
466
554
  nc = int(cls.max() + 1) # number of classes
467
555
  boxes = boxes[:1000000] # limit to 1M boxes
468
- x = pd.DataFrame(boxes, columns=['x', 'y', 'width', 'height'])
556
+ x = pd.DataFrame(boxes, columns=["x", "y", "width", "height"])
469
557
 
470
558
  # Seaborn correlogram
471
- sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
472
- plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
559
+ sn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
560
+ plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
473
561
  plt.close()
474
562
 
475
563
  # Matplotlib labels
@@ -477,14 +565,14 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
477
565
  y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
478
566
  for i in range(nc):
479
567
  y[2].patches[i].set_color([x / 255 for x in colors(i)])
480
- ax[0].set_ylabel('instances')
568
+ ax[0].set_ylabel("instances")
481
569
  if 0 < len(names) < 30:
482
570
  ax[0].set_xticks(range(len(names)))
483
571
  ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
484
572
  else:
485
- ax[0].set_xlabel('classes')
486
- sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
487
- sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
573
+ ax[0].set_xlabel("classes")
574
+ sn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
575
+ sn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
488
576
 
489
577
  # Rectangles
490
578
  boxes[:, 0:2] = 0.5 # center
@@ -493,20 +581,20 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
493
581
  for cls, box in zip(cls[:500], boxes[:500]):
494
582
  ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
495
583
  ax[1].imshow(img)
496
- ax[1].axis('off')
584
+ ax[1].axis("off")
497
585
 
498
586
  for a in [0, 1, 2, 3]:
499
- for s in ['top', 'right', 'left', 'bottom']:
587
+ for s in ["top", "right", "left", "bottom"]:
500
588
  ax[a].spines[s].set_visible(False)
501
589
 
502
- fname = save_dir / 'labels.jpg'
590
+ fname = save_dir / "labels.jpg"
503
591
  plt.savefig(fname, dpi=200)
504
592
  plt.close()
505
593
  if on_plot:
506
594
  on_plot(fname)
507
595
 
508
596
 
509
- def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
597
+ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
510
598
  """
511
599
  Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
512
600
 
@@ -545,29 +633,31 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
545
633
  b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
546
634
  xyxy = ops.xywh2xyxy(b).long()
547
635
  xyxy = ops.clip_boxes(xyxy, im.shape)
548
- crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
636
+ crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
549
637
  if save:
550
638
  file.parent.mkdir(parents=True, exist_ok=True) # make directory
551
- f = str(increment_path(file).with_suffix('.jpg'))
639
+ f = str(increment_path(file).with_suffix(".jpg"))
552
640
  # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
553
641
  Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
554
642
  return crop
555
643
 
556
644
 
557
645
  @threaded
558
- def plot_images(images,
559
- batch_idx,
560
- cls,
561
- bboxes=np.zeros(0, dtype=np.float32),
562
- confs=None,
563
- masks=np.zeros(0, dtype=np.uint8),
564
- kpts=np.zeros((0, 51), dtype=np.float32),
565
- paths=None,
566
- fname='images.jpg',
567
- names=None,
568
- on_plot=None,
569
- max_subplots=16,
570
- save=True):
646
+ def plot_images(
647
+ images,
648
+ batch_idx,
649
+ cls,
650
+ bboxes=np.zeros(0, dtype=np.float32),
651
+ confs=None,
652
+ masks=np.zeros(0, dtype=np.uint8),
653
+ kpts=np.zeros((0, 51), dtype=np.float32),
654
+ paths=None,
655
+ fname="images.jpg",
656
+ names=None,
657
+ on_plot=None,
658
+ max_subplots=16,
659
+ save=True,
660
+ ):
571
661
  """Plot image grid with labels."""
572
662
  if isinstance(images, torch.Tensor):
573
663
  images = images.cpu().float().numpy()
@@ -585,7 +675,7 @@ def plot_images(images,
585
675
  max_size = 1920 # max image size
586
676
  bs, _, h, w = images.shape # batch size, _, height, width
587
677
  bs = min(bs, max_subplots) # limit plot images
588
- ns = np.ceil(bs ** 0.5) # number of subplots (square)
678
+ ns = np.ceil(bs**0.5) # number of subplots (square)
589
679
  if np.max(images[0]) <= 1:
590
680
  images *= 255 # de-normalise (optional)
591
681
 
@@ -593,7 +683,7 @@ def plot_images(images,
593
683
  mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
594
684
  for i in range(bs):
595
685
  x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
596
- mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0)
686
+ mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
597
687
 
598
688
  # Resize (optional)
599
689
  scale = max_size / ns / max(h, w)
@@ -612,7 +702,7 @@ def plot_images(images,
612
702
  annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
613
703
  if len(cls) > 0:
614
704
  idx = batch_idx == i
615
- classes = cls[idx].astype('int')
705
+ classes = cls[idx].astype("int")
616
706
  labels = confs is None
617
707
 
618
708
  if len(bboxes):
@@ -633,14 +723,14 @@ def plot_images(images,
633
723
  color = colors(c)
634
724
  c = names.get(c, c) if names else c
635
725
  if labels or conf[j] > 0.25: # 0.25 conf thresh
636
- label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
726
+ label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
637
727
  annotator.box_label(box, label, color=color, rotated=is_obb)
638
728
 
639
729
  elif len(classes):
640
730
  for c in classes:
641
731
  color = colors(c)
642
732
  c = names.get(c, c) if names else c
643
- annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
733
+ annotator.text((x, y), f"{c}", txt_color=color, box_style=True)
644
734
 
645
735
  # Plot keypoints
646
736
  if len(kpts):
@@ -680,7 +770,9 @@ def plot_images(images,
680
770
  else:
681
771
  mask = image_masks[j].astype(bool)
682
772
  with contextlib.suppress(Exception):
683
- im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
773
+ im[y : y + h, x : x + w, :][mask] = (
774
+ im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
775
+ )
684
776
  annotator.fromarray(im)
685
777
  if save:
686
778
  annotator.im.save(fname) # save
@@ -691,7 +783,7 @@ def plot_images(images,
691
783
 
692
784
 
693
785
  @plt_settings()
694
- def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
786
+ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
695
787
  """
696
788
  Plot training results from a results CSV file. The function supports various types of data including segmentation,
697
789
  pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
@@ -714,6 +806,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
714
806
  """
715
807
  import pandas as pd
716
808
  from scipy.ndimage import gaussian_filter1d
809
+
717
810
  save_dir = Path(file).parent if file else Path(dir)
718
811
  if classify:
719
812
  fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
@@ -728,32 +821,32 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
728
821
  fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
729
822
  index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
730
823
  ax = ax.ravel()
731
- files = list(save_dir.glob('results*.csv'))
732
- assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
824
+ files = list(save_dir.glob("results*.csv"))
825
+ assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
733
826
  for f in files:
734
827
  try:
735
828
  data = pd.read_csv(f)
736
829
  s = [x.strip() for x in data.columns]
737
830
  x = data.values[:, 0]
738
831
  for i, j in enumerate(index):
739
- y = data.values[:, j].astype('float')
832
+ y = data.values[:, j].astype("float")
740
833
  # y[y == 0] = np.nan # don't show zero values
741
- ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
742
- ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
834
+ ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
835
+ ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
743
836
  ax[i].set_title(s[j], fontsize=12)
744
837
  # if j in [8, 9, 10]: # share train and val loss y axes
745
838
  # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
746
839
  except Exception as e:
747
- LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
840
+ LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
748
841
  ax[1].legend()
749
- fname = save_dir / 'results.png'
842
+ fname = save_dir / "results.png"
750
843
  fig.savefig(fname, dpi=200)
751
844
  plt.close()
752
845
  if on_plot:
753
846
  on_plot(fname)
754
847
 
755
848
 
756
- def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
849
+ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
757
850
  """
758
851
  Plots a scatter plot with points colored based on a 2D histogram.
759
852
 
@@ -774,14 +867,18 @@ def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none
774
867
  # Calculate 2D histogram and corresponding colors
775
868
  hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
776
869
  colors = [
777
- hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
778
- min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
870
+ hist[
871
+ min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
872
+ min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
873
+ ]
874
+ for i in range(len(v))
875
+ ]
779
876
 
780
877
  # Scatter plot
781
878
  plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
782
879
 
783
880
 
784
- def plot_tune_results(csv_file='tune_results.csv'):
881
+ def plot_tune_results(csv_file="tune_results.csv"):
785
882
  """
786
883
  Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
787
884
  in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
@@ -810,33 +907,33 @@ def plot_tune_results(csv_file='tune_results.csv'):
810
907
  v = x[:, i + num_metrics_columns]
811
908
  mu = v[j] # best single result
812
909
  plt.subplot(n, n, i + 1)
813
- plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
814
- plt.plot(mu, fitness.max(), 'k+', markersize=15)
815
- plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
816
- plt.tick_params(axis='both', labelsize=8) # Set axis label size to 8
910
+ plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
911
+ plt.plot(mu, fitness.max(), "k+", markersize=15)
912
+ plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters
913
+ plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
817
914
  if i % n != 0:
818
915
  plt.yticks([])
819
916
 
820
- file = csv_file.with_name('tune_scatter_plots.png') # filename
917
+ file = csv_file.with_name("tune_scatter_plots.png") # filename
821
918
  plt.savefig(file, dpi=200)
822
919
  plt.close()
823
- LOGGER.info(f'Saved {file}')
920
+ LOGGER.info(f"Saved {file}")
824
921
 
825
922
  # Fitness vs iteration
826
923
  x = range(1, len(fitness) + 1)
827
924
  plt.figure(figsize=(10, 6), tight_layout=True)
828
- plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
829
- plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2) # smoothing line
830
- plt.title('Fitness vs Iteration')
831
- plt.xlabel('Iteration')
832
- plt.ylabel('Fitness')
925
+ plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
926
+ plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
927
+ plt.title("Fitness vs Iteration")
928
+ plt.xlabel("Iteration")
929
+ plt.ylabel("Fitness")
833
930
  plt.grid(True)
834
931
  plt.legend()
835
932
 
836
- file = csv_file.with_name('tune_fitness.png') # filename
933
+ file = csv_file.with_name("tune_fitness.png") # filename
837
934
  plt.savefig(file, dpi=200)
838
935
  plt.close()
839
- LOGGER.info(f'Saved {file}')
936
+ LOGGER.info(f"Saved {file}")
840
937
 
841
938
 
842
939
  def output_to_target(output, max_det=300):
@@ -861,7 +958,7 @@ def output_to_rotated_target(output, max_det=300):
861
958
  return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
862
959
 
863
960
 
864
- def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
961
+ def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
865
962
  """
866
963
  Visualize feature maps of a given model module during inference.
867
964
 
@@ -872,7 +969,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
872
969
  n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
873
970
  save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
874
971
  """
875
- for m in ['Detect', 'Pose', 'Segment']:
972
+ for m in ["Detect", "Pose", "Segment"]:
876
973
  if m in module_type:
877
974
  return
878
975
  batch, channels, height, width = x.shape # batch, channels, height, width
@@ -886,9 +983,9 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
886
983
  plt.subplots_adjust(wspace=0.05, hspace=0.05)
887
984
  for i in range(n):
888
985
  ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
889
- ax[i].axis('off')
986
+ ax[i].axis("off")
890
987
 
891
- LOGGER.info(f'Saving {f}... ({n}/{channels})')
892
- plt.savefig(f, dpi=300, bbox_inches='tight')
988
+ LOGGER.info(f"Saving {f}... ({n}/{channels})")
989
+ plt.savefig(f, dpi=300, bbox_inches="tight")
893
990
  plt.close()
894
- np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
991
+ np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save