dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,12 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import math
4
6
  import warnings
7
+ from collections.abc import Callable
5
8
  from pathlib import Path
6
- from typing import Callable, Dict, List, Optional, Union
9
+ from typing import Any
7
10
 
8
11
  import cv2
9
12
  import numpy as np
@@ -17,21 +20,21 @@ from ultralytics.utils.files import increment_path
17
20
 
18
21
 
19
22
  class Colors:
20
- """
21
- Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors.
23
+ """Ultralytics color palette for visualization and plotting.
22
24
 
23
- This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
24
- RGB values.
25
+ This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
26
+ values and accessing predefined color schemes for object detection and pose estimation.
25
27
 
26
28
  Attributes:
27
- palette (List[Tuple]): List of RGB color values.
29
+ palette (list[tuple]): List of RGB color tuples for general use.
28
30
  n (int): The number of colors in the palette.
29
31
  pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
30
32
 
31
33
  Examples:
32
34
  >>> from ultralytics.utils.plotting import Colors
33
35
  >>> colors = Colors()
34
- >>> colors(5, True) # ff6fdd or (255, 111, 221)
36
+ >>> colors(5, True) # Returns BGR format: (221, 111, 255)
37
+ >>> colors(5, False) # Returns RGB format: (255, 111, 221)
35
38
 
36
39
  ## Ultralytics Color Palette
37
40
 
@@ -85,7 +88,8 @@ class Colors:
85
88
 
86
89
  !!! note "Ultralytics Brand Colors"
87
90
 
88
- For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials.
91
+ For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
92
+ Please use the official Ultralytics colors for all marketing materials.
89
93
  """
90
94
 
91
95
  def __init__(self):
@@ -140,13 +144,21 @@ class Colors:
140
144
  dtype=np.uint8,
141
145
  )
142
146
 
143
- def __call__(self, i, bgr=False):
144
- """Convert hex color codes to RGB values."""
147
+ def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
148
+ """Convert hex color codes to RGB values.
149
+
150
+ Args:
151
+ i (int | torch.Tensor): Color index.
152
+ bgr (bool, optional): Whether to return BGR format instead of RGB.
153
+
154
+ Returns:
155
+ (tuple): RGB or BGR color tuple.
156
+ """
145
157
  c = self.palette[int(i) % self.n]
146
158
  return (c[2], c[1], c[0]) if bgr else c
147
159
 
148
160
  @staticmethod
149
- def hex2rgb(h):
161
+ def hex2rgb(h: str) -> tuple:
150
162
  """Convert hex color codes to RGB values (i.e. default PIL order)."""
151
163
  return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
152
164
 
@@ -155,17 +167,16 @@ colors = Colors() # create instance for 'from utils.plots import colors'
155
167
 
156
168
 
157
169
  class Annotator:
158
- """
159
- Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
170
+ """Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
160
171
 
161
172
  Attributes:
162
- im (Image.Image or np.ndarray): The image to annotate.
173
+ im (Image.Image | np.ndarray): The image to annotate.
163
174
  pil (bool): Whether to use PIL or cv2 for drawing annotations.
164
- font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
175
+ font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
165
176
  lw (float): Line width for drawing.
166
- skeleton (List[List[int]]): Skeleton structure for keypoints.
167
- limb_color (List[int]): Color palette for limbs.
168
- kpt_color (List[int]): Color palette for keypoints.
177
+ skeleton (list[list[int]]): Skeleton structure for keypoints.
178
+ limb_color (list[int]): Color palette for limbs.
179
+ kpt_color (list[int]): Color palette for keypoints.
169
180
  dark_colors (set): Set of colors considered dark for text contrast.
170
181
  light_colors (set): Set of colors considered light for text contrast.
171
182
 
@@ -173,14 +184,28 @@ class Annotator:
173
184
  >>> from ultralytics.utils.plotting import Annotator
174
185
  >>> im0 = cv2.imread("test.png")
175
186
  >>> annotator = Annotator(im0, line_width=10)
187
+ >>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
176
188
  """
177
189
 
178
- def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
190
+ def __init__(
191
+ self,
192
+ im,
193
+ line_width: int | None = None,
194
+ font_size: int | None = None,
195
+ font: str = "Arial.ttf",
196
+ pil: bool = False,
197
+ example: str = "abc",
198
+ ):
179
199
  """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
180
200
  non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
181
201
  input_is_pil = isinstance(im, Image.Image)
182
202
  self.pil = pil or non_ascii or input_is_pil
183
203
  self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
204
+ if not input_is_pil:
205
+ if im.shape[2] == 1: # handle grayscale
206
+ im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
207
+ elif im.shape[2] > 3: # multispectral
208
+ im = np.ascontiguousarray(im[..., :3])
184
209
  if self.pil: # use PIL
185
210
  self.im = im if input_is_pil else Image.fromarray(im)
186
211
  if self.im.mode not in {"RGB", "RGBA"}: # multispectral
@@ -196,10 +221,6 @@ class Annotator:
196
221
  if check_version(pil_version, "9.2.0"):
197
222
  self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
198
223
  else: # use cv2
199
- if im.shape[2] == 1: # handle grayscale
200
- im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
201
- elif im.shape[2] > 3: # multispectral
202
- im = np.ascontiguousarray(im[..., :3])
203
224
  assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
204
225
  self.im = im if im.flags.writeable else im.copy()
205
226
  self.tf = max(self.lw - 1, 1) # font thickness
@@ -254,9 +275,8 @@ class Annotator:
254
275
  (104, 31, 17),
255
276
  }
256
277
 
257
- def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
258
- """
259
- Assign text color based on background color.
278
+ def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
279
+ """Assign text color based on background color.
260
280
 
261
281
  Args:
262
282
  color (tuple, optional): The background color of the rectangle for text (B, G, R).
@@ -278,16 +298,14 @@ class Annotator:
278
298
  else:
279
299
  return txt_color
280
300
 
281
- def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False):
282
- """
283
- Draw a bounding box on an image with a given label.
301
+ def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
302
+ """Draw a bounding box on an image with a given label.
284
303
 
285
304
  Args:
286
305
  box (tuple): The bounding box coordinates (x1, y1, x2, y2).
287
306
  label (str, optional): The text label to be displayed.
288
307
  color (tuple, optional): The background color of the rectangle (B, G, R).
289
308
  txt_color (tuple, optional): The color of the text (R, G, B).
290
- rotated (bool, optional): Whether the task is oriented bounding box detection.
291
309
 
292
310
  Examples:
293
311
  >>> from ultralytics.utils.plotting import Annotator
@@ -298,13 +316,13 @@ class Annotator:
298
316
  txt_color = self.get_txt_color(color, txt_color)
299
317
  if isinstance(box, torch.Tensor):
300
318
  box = box.tolist()
301
- if self.pil or not is_ascii(label):
302
- if rotated:
303
- p1 = box[0]
304
- self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box
305
- else:
306
- p1 = (box[0], box[1])
307
- self.draw.rectangle(box, width=self.lw, outline=color) # box
319
+
320
+ multi_points = isinstance(box[0], list) # multiple points with shape (n, 2)
321
+ p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))
322
+ if self.pil:
323
+ self.draw.polygon(
324
+ [tuple(b) for b in box], width=self.lw, outline=color
325
+ ) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)
308
326
  if label:
309
327
  w, h = self.font.getsize(label) # text width, height
310
328
  outside = p1[1] >= h # label fits outside box
@@ -317,12 +335,11 @@ class Annotator:
317
335
  # self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
318
336
  self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
319
337
  else: # cv2
320
- if rotated:
321
- p1 = [int(b) for b in box[0]]
322
- cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box
323
- else:
324
- p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
325
- cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
338
+ cv2.polylines(
339
+ self.im, [np.asarray(box, dtype=int)], True, color, self.lw
340
+ ) if multi_points else cv2.rectangle(
341
+ self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA
342
+ )
326
343
  if label:
327
344
  w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
328
345
  h += 3 # add pixels to pad text
@@ -342,45 +359,66 @@ class Annotator:
342
359
  lineType=cv2.LINE_AA,
343
360
  )
344
361
 
345
- def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
346
- """
347
- Plot masks on image.
362
+ def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
363
+ """Plot masks on image.
348
364
 
349
365
  Args:
350
- masks (torch.Tensor): Predicted masks on cuda, shape: [n, h, w]
351
- colors (List[List[int]]): Colors for predicted masks, [[r, g, b] * n]
352
- im_gpu (torch.Tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
366
+ masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
367
+ colors (list[list[int]]): Colors for predicted masks, [[r, g, b] * n]
368
+ im_gpu (torch.Tensor | None): Image is in cuda, shape: [3, h, w], range: [0, 1]
353
369
  alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.
354
370
  retina_masks (bool, optional): Whether to use high resolution masks or not.
355
371
  """
356
372
  if self.pil:
357
373
  # Convert to numpy first
358
374
  self.im = np.asarray(self.im).copy()
359
- if len(masks) == 0:
360
- self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
361
- if im_gpu.device != masks.device:
362
- im_gpu = im_gpu.to(masks.device)
363
- colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
364
- colors = colors[:, None, None] # shape(n,1,1,3)
365
- masks = masks.unsqueeze(3) # shape(n,h,w,1)
366
- masks_color = masks * (colors * alpha) # shape(n,h,w,3)
367
-
368
- inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
369
- mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
370
-
371
- im_gpu = im_gpu.flip(dims=[0]) # flip channel
372
- im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
373
- im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
374
- im_mask = im_gpu * 255
375
- im_mask_np = im_mask.byte().cpu().numpy()
376
- self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
375
+ if im_gpu is None:
376
+ assert isinstance(masks, np.ndarray), "`masks` must be a np.ndarray if `im_gpu` is not provided."
377
+ overlay = self.im.copy()
378
+ for i, mask in enumerate(masks):
379
+ overlay[mask.astype(bool)] = colors[i]
380
+ self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
381
+ else:
382
+ assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
383
+ if len(masks) == 0:
384
+ self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
385
+ return
386
+ if im_gpu.device != masks.device:
387
+ im_gpu = im_gpu.to(masks.device)
388
+
389
+ ih, iw = self.im.shape[:2]
390
+ if not retina_masks:
391
+ # Use scale_masks to properly remove padding and upsample, convert bool to float first
392
+ masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
393
+ # Convert original BGR image to RGB tensor
394
+ im_gpu = (
395
+ torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
396
+ )
397
+
398
+ colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
399
+ colors = colors[:, None, None] # shape(n,1,1,3)
400
+ masks = masks.unsqueeze(3) # shape(n,h,w,1)
401
+ masks_color = masks * (colors * alpha) # shape(n,h,w,3)
402
+ inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
403
+ mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
404
+
405
+ im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous() # shape(h,w,3)
406
+ im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
407
+ self.im[:] = (im_gpu * 255).byte().cpu().numpy()
377
408
  if self.pil:
378
409
  # Convert im back to PIL and update draw
379
410
  self.fromarray(self.im)
380
411
 
381
- def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
382
- """
383
- Plot keypoints on the image.
412
+ def kpts(
413
+ self,
414
+ kpts,
415
+ shape: tuple = (640, 640),
416
+ radius: int | None = None,
417
+ kpt_line: bool = True,
418
+ conf_thres: float = 0.25,
419
+ kpt_color: tuple | None = None,
420
+ ):
421
+ """Plot keypoints on the image.
384
422
 
385
423
  Args:
386
424
  kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
@@ -390,7 +428,7 @@ class Annotator:
390
428
  conf_thres (float, optional): Confidence threshold.
391
429
  kpt_color (tuple, optional): Keypoint color (B, G, R).
392
430
 
393
- Note:
431
+ Notes:
394
432
  - `kpt_line=True` currently only supports human pose plotting.
395
433
  - Modifies self.im in-place.
396
434
  - If self.pil is True, converts image to numpy array and back to PIL.
@@ -438,16 +476,15 @@ class Annotator:
438
476
  # Convert im back to PIL and update draw
439
477
  self.fromarray(self.im)
440
478
 
441
- def rectangle(self, xy, fill=None, outline=None, width=1):
479
+ def rectangle(self, xy, fill=None, outline=None, width: int = 1):
442
480
  """Add rectangle to image (PIL-only)."""
443
481
  self.draw.rectangle(xy, fill, outline, width)
444
482
 
445
- def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_color=()):
446
- """
447
- Add text to an image using PIL or cv2.
483
+ def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
484
+ """Add text to an image using PIL or cv2.
448
485
 
449
486
  Args:
450
- xy (List[int]): Top-left coordinates for text placement.
487
+ xy (list[int]): Top-left coordinates for text placement.
451
488
  text (str): Text to be drawn.
452
489
  txt_color (tuple, optional): Text color (R, G, B).
453
490
  anchor (str, optional): Text anchor position ('top' or 'bottom').
@@ -482,7 +519,7 @@ class Annotator:
482
519
  """Return annotated image as array."""
483
520
  return np.asarray(self.im)
484
521
 
485
- def show(self, title=None):
522
+ def show(self, title: str | None = None):
486
523
  """Show the annotated image."""
487
524
  im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
488
525
  if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
@@ -493,14 +530,13 @@ class Annotator:
493
530
  else:
494
531
  im.show(title=title)
495
532
 
496
- def save(self, filename="image.jpg"):
533
+ def save(self, filename: str = "image.jpg"):
497
534
  """Save the annotated image to 'filename'."""
498
535
  cv2.imwrite(filename, np.asarray(self.im))
499
536
 
500
537
  @staticmethod
501
- def get_bbox_dimension(bbox=None):
502
- """
503
- Calculate the dimensions and area of a bounding box.
538
+ def get_bbox_dimension(bbox: tuple | None = None):
539
+ """Calculate the dimensions and area of a bounding box.
504
540
 
505
541
  Args:
506
542
  bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
@@ -522,11 +558,10 @@ class Annotator:
522
558
  return width, height, width * height
523
559
 
524
560
 
525
- @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
561
+ @TryExcept()
526
562
  @plt_settings()
527
563
  def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
528
- """
529
- Plot training labels including class histograms and box statistics.
564
+ """Plot training labels including class histograms and box statistics.
530
565
 
531
566
  Args:
532
567
  boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
@@ -536,7 +571,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
536
571
  on_plot (Callable, optional): Function to call after plot is saved.
537
572
  """
538
573
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
539
- import pandas
574
+ import polars
540
575
  from matplotlib.colors import LinearSegmentedColormap
541
576
 
542
577
  # Filter matplotlib>=3.7.2 warning
@@ -547,16 +582,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
547
582
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
548
583
  nc = int(cls.max() + 1) # number of classes
549
584
  boxes = boxes[:1000000] # limit to 1M boxes
550
- x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])
551
-
552
- try: # Seaborn correlogram
553
- import seaborn
554
-
555
- seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
556
- plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
557
- plt.close()
558
- except ImportError:
559
- pass # Skip if seaborn is not installed
585
+ x = polars.DataFrame(boxes, schema=["x", "y", "width", "height"])
560
586
 
561
587
  # Matplotlib labels
562
588
  subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
@@ -568,12 +594,13 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
568
594
  if 0 < len(names) < 30:
569
595
  ax[0].set_xticks(range(len(names)))
570
596
  ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
597
+ ax[0].bar_label(y[2])
571
598
  else:
572
599
  ax[0].set_xlabel("classes")
573
600
  boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
574
601
  img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
575
602
  for cls, box in zip(cls[:500], boxes[:500]):
576
- ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
603
+ ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls)) # plot
577
604
  ax[1].imshow(img)
578
605
  ax[1].axis("off")
579
606
 
@@ -583,8 +610,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
583
610
  ax[3].hist2d(x["width"], x["height"], bins=50, cmap=subplot_3_4_color)
584
611
  ax[3].set_xlabel("width")
585
612
  ax[3].set_ylabel("height")
586
- for a in [0, 1, 2, 3]:
587
- for s in ["top", "right", "left", "bottom"]:
613
+ for a in {0, 1, 2, 3}:
614
+ for s in {"top", "right", "left", "bottom"}:
588
615
  ax[a].spines[s].set_visible(False)
589
616
 
590
617
  fname = save_dir / "labels.jpg"
@@ -594,13 +621,21 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
594
621
  on_plot(fname)
595
622
 
596
623
 
597
- def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
598
- """
599
- Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
624
+ def save_one_box(
625
+ xyxy,
626
+ im,
627
+ file: Path = Path("im.jpg"),
628
+ gain: float = 1.02,
629
+ pad: int = 10,
630
+ square: bool = False,
631
+ BGR: bool = False,
632
+ save: bool = True,
633
+ ):
634
+ """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
600
635
 
601
- This function takes a bounding box and an image, and then saves a cropped portion of the image according
602
- to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
603
- adjustments to the bounding box.
636
+ This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
637
+ bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
638
+ bounding box.
604
639
 
605
640
  Args:
606
641
  xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
@@ -609,7 +644,7 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
609
644
  gain (float, optional): A multiplicative factor to increase the size of the bounding box.
610
645
  pad (int, optional): The number of pixels to add to the width and height of the bounding box.
611
646
  square (bool, optional): If True, the bounding box will be transformed into a square.
612
- BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB.
647
+ BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.
613
648
  save (bool, optional): If True, the cropped image will be saved to disk.
614
649
 
615
650
  Returns:
@@ -629,73 +664,83 @@ def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False,
629
664
  b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
630
665
  xyxy = ops.xywh2xyxy(b).long()
631
666
  xyxy = ops.clip_boxes(xyxy, im.shape)
632
- crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
667
+ grayscale = im.shape[2] == 1 # grayscale image
668
+ crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]
633
669
  if save:
634
670
  file.parent.mkdir(parents=True, exist_ok=True) # make directory
635
671
  f = str(increment_path(file).with_suffix(".jpg"))
636
672
  # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
637
- Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
673
+ crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop
674
+ Image.fromarray(crop).save(f, quality=95, subsampling=0) # save RGB
638
675
  return crop
639
676
 
640
677
 
641
678
  @threaded
642
679
  def plot_images(
643
- images: Union[torch.Tensor, np.ndarray],
644
- batch_idx: Union[torch.Tensor, np.ndarray],
645
- cls: Union[torch.Tensor, np.ndarray],
646
- bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
647
- confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
648
- masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
649
- kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
650
- paths: Optional[List[str]] = None,
680
+ labels: dict[str, Any],
681
+ images: torch.Tensor | np.ndarray = np.zeros((0, 3, 640, 640), dtype=np.float32),
682
+ paths: list[str] | None = None,
651
683
  fname: str = "images.jpg",
652
- names: Optional[Dict[int, str]] = None,
653
- on_plot: Optional[Callable] = None,
684
+ names: dict[int, str] | None = None,
685
+ on_plot: Callable | None = None,
654
686
  max_size: int = 1920,
655
687
  max_subplots: int = 16,
656
688
  save: bool = True,
657
689
  conf_thres: float = 0.25,
658
- ) -> Optional[np.ndarray]:
659
- """
660
- Plot image grid with labels, bounding boxes, masks, and keypoints.
690
+ ) -> np.ndarray | None:
691
+ """Plot image grid with labels, bounding boxes, masks, and keypoints.
661
692
 
662
693
  Args:
663
- images: Batch of images to plot. Shape: (batch_size, channels, height, width).
664
- batch_idx: Batch indices for each detection. Shape: (num_detections,).
665
- cls: Class labels for each detection. Shape: (num_detections,).
666
- bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
667
- confs: Confidence scores for each detection. Shape: (num_detections,).
668
- masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
669
- kpts: Keypoints for each detection. Shape: (num_detections, 51).
670
- paths: List of file paths for each image in the batch.
671
- fname: Output filename for the plotted image grid.
672
- names: Dictionary mapping class indices to class names.
673
- on_plot: Optional callback function to be called after saving the plot.
674
- max_size: Maximum size of the output image grid.
675
- max_subplots: Maximum number of subplots in the image grid.
676
- save: Whether to save the plotted image grid to a file.
677
- conf_thres: Confidence threshold for displaying detections.
694
+ labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
695
+ 'keypoints', 'batch_idx', 'img'.
696
+ images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
697
+ paths (Optional[list[str]]): List of file paths for each image in the batch.
698
+ fname (str): Output filename for the plotted image grid.
699
+ names (Optional[dict[int, str]]): Dictionary mapping class indices to class names.
700
+ on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.
701
+ max_size (int): Maximum size of the output image grid.
702
+ max_subplots (int): Maximum number of subplots in the image grid.
703
+ save (bool): Whether to save the plotted image grid to a file.
704
+ conf_thres (float): Confidence threshold for displaying detections.
678
705
 
679
706
  Returns:
680
707
  (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
681
708
 
682
- Note:
709
+ Notes:
683
710
  This function supports both tensor and numpy array inputs. It will automatically
684
711
  convert tensor inputs to numpy arrays for processing.
712
+
713
+ Channel Support:
714
+ - 1 channel: Grayscale
715
+ - 2 channels: Third channel added as zeros
716
+ - 3 channels: Used as-is (standard RGB)
717
+ - 4+ channels: Cropped to first 3 channels
685
718
  """
686
- if isinstance(images, torch.Tensor):
719
+ for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
720
+ if k not in labels:
721
+ continue
722
+ if k == "cls" and labels[k].ndim == 2:
723
+ labels[k] = labels[k].squeeze(1) # squeeze if shape is (n, 1)
724
+ if isinstance(labels[k], torch.Tensor):
725
+ labels[k] = labels[k].cpu().numpy()
726
+
727
+ cls = labels.get("cls", np.zeros(0, dtype=np.int64))
728
+ batch_idx = labels.get("batch_idx", np.zeros(cls.shape, dtype=np.int64))
729
+ bboxes = labels.get("bboxes", np.zeros(0, dtype=np.float32))
730
+ confs = labels.get("conf", None)
731
+ masks = labels.get("masks", np.zeros(0, dtype=np.uint8))
732
+ kpts = labels.get("keypoints", np.zeros(0, dtype=np.float32))
733
+ images = labels.get("img", images) # default to input images
734
+
735
+ if len(images) and isinstance(images, torch.Tensor):
687
736
  images = images.cpu().float().numpy()
688
- if isinstance(cls, torch.Tensor):
689
- cls = cls.cpu().numpy()
690
- if isinstance(bboxes, torch.Tensor):
691
- bboxes = bboxes.cpu().numpy()
692
- if isinstance(masks, torch.Tensor):
693
- masks = masks.cpu().numpy().astype(int)
694
- if isinstance(kpts, torch.Tensor):
695
- kpts = kpts.cpu().numpy()
696
- if isinstance(batch_idx, torch.Tensor):
697
- batch_idx = batch_idx.cpu().numpy()
698
- if images.shape[1] > 3:
737
+
738
+ # Handle 2-ch and n-ch images
739
+ c = images.shape[1]
740
+ if c == 2:
741
+ zero = np.zeros_like(images[:, :1])
742
+ images = np.concatenate((images, zero), axis=1) # pad 2-ch with a black channel
743
+ elif c > 3:
699
744
  images = images[:, :3] # crop multispectral images to first 3 channels
700
745
 
701
746
  bs, _, h, w = images.shape # batch size, _, height, width
@@ -730,10 +775,10 @@ def plot_images(
730
775
  idx = batch_idx == i
731
776
  classes = cls[idx].astype("int")
732
777
  labels = confs is None
778
+ conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
733
779
 
734
780
  if len(bboxes):
735
781
  boxes = bboxes[idx]
736
- conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
737
782
  if len(boxes):
738
783
  if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
739
784
  boxes[..., [0, 2]] *= w # scale to pixels
@@ -743,6 +788,7 @@ def plot_images(
743
788
  boxes[..., 0] += x
744
789
  boxes[..., 1] += y
745
790
  is_obb = boxes.shape[-1] == 5 # xywhr
791
+ # TODO: this transformation might be unnecessary
746
792
  boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
747
793
  for j, box in enumerate(boxes.astype(np.int64).tolist()):
748
794
  c = classes[j]
@@ -750,13 +796,14 @@ def plot_images(
750
796
  c = names.get(c, c) if names else c
751
797
  if labels or conf[j] > conf_thres:
752
798
  label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
753
- annotator.box_label(box, label, color=color, rotated=is_obb)
799
+ annotator.box_label(box, label, color=color)
754
800
 
755
801
  elif len(classes):
756
802
  for c in classes:
757
803
  color = colors(c)
758
804
  c = names.get(c, c) if names else c
759
- annotator.text([x, y], f"{c}", txt_color=color, box_color=(64, 64, 64, 128))
805
+ label = f"{c}" if labels else f"{c} {conf[0]:.1f}"
806
+ annotator.text([x, y], label, txt_color=color, box_color=(64, 64, 64, 128))
760
807
 
761
808
  # Plot keypoints
762
809
  if len(kpts):
@@ -775,14 +822,13 @@ def plot_images(
775
822
 
776
823
  # Plot masks
777
824
  if len(masks):
778
- if idx.shape[0] == masks.shape[0]: # overlap_masks=False
825
+ if idx.shape[0] == masks.shape[0] and masks.max() <= 1: # overlap_mask=False
779
826
  image_masks = masks[idx]
780
- else: # overlap_masks=True
827
+ else: # overlap_mask=True
781
828
  image_masks = masks[[i]] # (1, 640, 640)
782
829
  nl = idx.sum()
783
- index = np.arange(nl).reshape((nl, 1, 1)) + 1
784
- image_masks = np.repeat(image_masks, nl, axis=0)
785
- image_masks = np.where(image_masks == index, 1.0, 0.0)
830
+ index = np.arange(1, nl + 1).reshape((nl, 1, 1))
831
+ image_masks = (image_masks == index).astype(np.float32)
786
832
 
787
833
  im = np.asarray(annotator.im).copy()
788
834
  for j in range(len(image_masks)):
@@ -810,17 +856,14 @@ def plot_images(
810
856
 
811
857
 
812
858
  @plt_settings()
813
- def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
814
- """
815
- Plot training results from a results CSV file. The function supports various types of data including segmentation,
816
- pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
859
+ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
860
+ """Plot training results from a results CSV file. The function supports various types of data including
861
+ segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
862
+ CSV is located.
817
863
 
818
864
  Args:
819
865
  file (str, optional): Path to the CSV file containing the training results.
820
866
  dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
821
- segment (bool, optional): Flag to indicate if the data is for segmentation.
822
- pose (bool, optional): Flag to indicate if the data is for pose estimation.
823
- classify (bool, optional): Flag to indicate if the data is for classification.
824
867
  on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
825
868
 
826
869
  Examples:
@@ -828,38 +871,35 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
828
871
  >>> plot_results("path/to/results.csv", segment=True)
829
872
  """
830
873
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
831
- import pandas as pd
874
+ import polars as pl
832
875
  from scipy.ndimage import gaussian_filter1d
833
876
 
834
877
  save_dir = Path(file).parent if file else Path(dir)
835
- if classify:
836
- fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
837
- index = [2, 5, 3, 4]
838
- elif segment:
839
- fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
840
- index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
841
- elif pose:
842
- fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
843
- index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
844
- else:
845
- fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
846
- index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
847
- ax = ax.ravel()
848
878
  files = list(save_dir.glob("results*.csv"))
849
879
  assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
850
- for f in files:
880
+
881
+ loss_keys, metric_keys = [], []
882
+ for i, f in enumerate(files):
851
883
  try:
852
- data = pd.read_csv(f)
853
- s = [x.strip() for x in data.columns]
854
- x = data.values[:, 0]
855
- for i, j in enumerate(index):
856
- y = data.values[:, j].astype("float")
857
- # y[y == 0] = np.nan # don't show zero values
884
+ data = pl.read_csv(f, infer_schema_length=None)
885
+ if i == 0:
886
+ for c in data.columns:
887
+ if "loss" in c:
888
+ loss_keys.append(c)
889
+ elif "metric" in c:
890
+ metric_keys.append(c)
891
+ loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
892
+ columns = (
893
+ loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
894
+ )
895
+ fig, ax = plt.subplots(2, len(columns) // 2, figsize=(len(columns) + 2, 6), tight_layout=True)
896
+ ax = ax.ravel()
897
+ x = data.select(data.columns[0]).to_numpy().flatten()
898
+ for i, j in enumerate(columns):
899
+ y = data.select(j).to_numpy().flatten().astype("float")
858
900
  ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
859
901
  ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
860
- ax[i].set_title(s[j], fontsize=12)
861
- # if j in {8, 9, 10}: # share train and val loss y axes
862
- # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
902
+ ax[i].set_title(j, fontsize=12)
863
903
  except Exception as e:
864
904
  LOGGER.error(f"Plotting error for {f}: {e}")
865
905
  ax[1].legend()
@@ -870,9 +910,8 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
870
910
  on_plot(fname)
871
911
 
872
912
 
873
- def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"):
874
- """
875
- Plot a scatter plot with points colored based on a 2D histogram.
913
+ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
914
+ """Plot a scatter plot with points colored based on a 2D histogram.
876
915
 
877
916
  Args:
878
917
  v (array-like): Values for the x-axis.
@@ -903,19 +942,21 @@ def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none
903
942
  plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
904
943
 
905
944
 
906
- def plot_tune_results(csv_file="tune_results.csv"):
907
- """
908
- Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
909
- in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
945
+ @plt_settings()
946
+ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
947
+ """Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
948
+ key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
949
+ the plots.
910
950
 
911
951
  Args:
912
952
  csv_file (str, optional): Path to the CSV file containing the tuning results.
953
+ exclude_zero_fitness_points (bool, optional): Don't include points with zero fitness in tuning plots.
913
954
 
914
955
  Examples:
915
956
  >>> plot_tune_results("path/to/tune_results.csv")
916
957
  """
917
958
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
918
- import pandas as pd
959
+ import polars as pl
919
960
  from scipy.ndimage import gaussian_filter1d
920
961
 
921
962
  def _save_one_file(file):
@@ -926,11 +967,22 @@ def plot_tune_results(csv_file="tune_results.csv"):
926
967
 
927
968
  # Scatter plots for each hyperparameter
928
969
  csv_file = Path(csv_file)
929
- data = pd.read_csv(csv_file)
970
+ data = pl.read_csv(csv_file, infer_schema_length=None)
930
971
  num_metrics_columns = 1
931
972
  keys = [x.strip() for x in data.columns][num_metrics_columns:]
932
- x = data.values
973
+ x = data.to_numpy()
933
974
  fitness = x[:, 0] # fitness
975
+ if exclude_zero_fitness_points:
976
+ mask = fitness > 0 # exclude zero-fitness points
977
+ x, fitness = x[mask], fitness[mask]
978
+ # Iterative sigma rejection on lower bound only
979
+ for _ in range(3): # max 3 iterations
980
+ mean, std = fitness.mean(), fitness.std()
981
+ lower_bound = mean - 3 * std
982
+ mask = fitness >= lower_bound
983
+ if mask.all(): # no more outliers
984
+ break
985
+ x, fitness = x[mask], fitness[mask]
934
986
  j = np.argmax(fitness) # max fitness index
935
987
  n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
936
988
  plt.figure(figsize=(10, 10), tight_layout=True)
@@ -959,31 +1011,9 @@ def plot_tune_results(csv_file="tune_results.csv"):
959
1011
  _save_one_file(csv_file.with_name("tune_fitness.png"))
960
1012
 
961
1013
 
962
- def output_to_target(output, max_det=300):
963
- """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
964
- targets = []
965
- for i, o in enumerate(output):
966
- box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
967
- j = torch.full((conf.shape[0], 1), i)
968
- targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
969
- targets = torch.cat(targets, 0).numpy()
970
- return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
971
-
972
-
973
- def output_to_rotated_target(output, max_det=300):
974
- """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
975
- targets = []
976
- for i, o in enumerate(output):
977
- box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1)
978
- j = torch.full((conf.shape[0], 1), i)
979
- targets.append(torch.cat((j, cls, box, angle, conf), 1))
980
- targets = torch.cat(targets, 0).numpy()
981
- return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1]
982
-
983
-
984
- def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
985
- """
986
- Visualize feature maps of a given model module during inference.
1014
+ @plt_settings()
1015
+ def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
1016
+ """Visualize feature maps of a given model module during inference.
987
1017
 
988
1018
  Args:
989
1019
  x (torch.Tensor): Features to be visualized.
@@ -1000,7 +1030,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detec
1000
1030
  if isinstance(x, torch.Tensor):
1001
1031
  _, channels, height, width = x.shape # batch, channels, height, width
1002
1032
  if height > 1 and width > 1:
1003
- f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
1033
+ f = save_dir / f"stage{stage}_{module_type.rsplit('.', 1)[-1]}_features.png" # filename
1004
1034
 
1005
1035
  blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
1006
1036
  n = min(n, channels) # number of plots