dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/patches.py

@@ -18,8 +18,7 @@ _imshow = cv2.imshow  # copy to avoid recursion errors
 
 
 def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
-    """
-    Read an image from a file with multilanguage filename support.
+    """Read an image from a file with multilanguage filename support.
 
     Args:
         filename (str): Path to the file to read.
@@ -36,17 +35,58 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
     if filename.endswith((".tiff", ".tif")):
         success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
         if success:
-            # Handle RGB images in tif/tiff format
+            # Handle multi-frame TIFFs and color images
             return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
         return None
     else:
         im = cv2.imdecode(file_bytes, flags)
+        # Fallback for formats OpenCV imdecode may not support (AVIF, HEIC)
+        if im is None and filename.lower().endswith((".avif", ".heic")):
+            im = _imread_pil(filename, flags)
         return im[..., None] if im is not None and im.ndim == 2 else im  # Always ensure 3 dimensions
 
 
-def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
+_pil_plugins_registered = False
+
+
+def _imread_pil(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
+    """Read an image using PIL as fallback for formats not supported by OpenCV.
+
+    Args:
+        filename (str): Path to the file to read.
+        flags (int, optional): OpenCV imread flags (used to determine grayscale conversion).
+
+    Returns:
+        (np.ndarray | None): The read image array in BGR format, or None if reading fails.
     """
-    Write an image to a file with multilanguage filename support.
+    global _pil_plugins_registered
+    try:
+        from PIL import Image
+
+        # Register HEIF/AVIF plugins once
+        if not _pil_plugins_registered:
+            try:
+                import pillow_heif
+
+                pillow_heif.register_heif_opener()
+            except ImportError:
+                pass
+            try:
+                import pillow_avif  # noqa: F401
+            except ImportError:
+                pass
+            _pil_plugins_registered = True
+
+        with Image.open(filename) as img:
+            if flags == cv2.IMREAD_GRAYSCALE:
+                return np.asarray(img.convert("L"))
+            return cv2.cvtColor(np.asarray(img.convert("RGB")), cv2.COLOR_RGB2BGR)
+    except Exception:
+        return None
+
+
+def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
+    """Write an image to a file with multilanguage filename support.
 
     Args:
         filename (str): Path to the file to write.
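As a rough illustration of the new fallback from the caller's side (not part of the diff; the file names are hypothetical, and HEIC/AVIF decoding assumes pillow-heif or pillow-avif-plugin is installed):

import cv2
from ultralytics.utils.patches import imread

im_jpg = imread("photo.jpg")  # decoded by cv2.imdecode as before
im_heic = imread("photo.heic")  # cv2 returns None, so the PIL fallback supplies a BGR array (or None)
im_gray = imread("photo.avif", cv2.IMREAD_GRAYSCALE)  # PIL converts to "L"; imread expands it to HxWx1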
@@ -71,15 +111,14 @@ def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) ->
 
 
 def imshow(winname: str, mat: np.ndarray) -> None:
-    """
-    Display an image in the specified window with multilanguage window name support.
+    """Display an image in the specified window with multilanguage window name support.
 
     This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
     multilanguage window names by encoding them properly for OpenCV compatibility.
 
     Args:
-        winname (str): Name of the window where the image will be displayed. If a window with this name already
-            exists, the image will be displayed in that window.
+        winname (str): Name of the window where the image will be displayed. If a window with this name already exists,
+            the image will be displayed in that window.
         mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
 
     Examples:
@@ -96,8 +135,7 @@ _torch_save = torch.save
 
 
 def torch_load(*args, **kwargs):
-    """
-    Load a PyTorch model with updated arguments to avoid warnings.
+    """Load a PyTorch model with updated arguments to avoid warnings.
 
     This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
 
@@ -109,8 +147,8 @@ def torch_load(*args, **kwargs):
         (Any): The loaded PyTorch object.
 
     Notes:
-        For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
-        if the argument is not provided, to avoid deprecation warnings.
+        For PyTorch versions 1.13 and above, this function automatically sets `weights_only=False` if the argument is
+        not provided, to avoid deprecation warnings.
     """
     from ultralytics.utils.torch_utils import TORCH_1_13
 
@@ -121,11 +159,10 @@
 
 
 def torch_save(*args, **kwargs):
-    """
-    Save PyTorch objects with retry mechanism for robustness.
+    """Save PyTorch objects with retry mechanism for robustness.
 
-    This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
-    due to device flushing delays or antivirus scanning.
+    This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur due
+    to device flushing delays or antivirus scanning.
 
     Args:
         *args (Any): Positional arguments to pass to torch.save.
@@ -146,8 +183,7 @@ def torch_save(*args, **kwargs):
 
 @contextmanager
 def arange_patch(args):
-    """
-    Workaround for ONNX torch.arange incompatibility with FP16.
+    """Workaround for ONNX torch.arange incompatibility with FP16.
 
     https://github.com/pytorch/pytorch/issues/148041.
     """
@@ -165,10 +201,28 @@ def arange_patch(args):
         yield
 
 
+@contextmanager
+def onnx_export_patch():
+    """Workaround for ONNX export issues in PyTorch 2.9+ with Dynamo enabled."""
+    from ultralytics.utils.torch_utils import TORCH_2_9
+
+    if TORCH_2_9:
+        func = torch.onnx.export
+
+        def torch_export(*args, **kwargs):
+            """Return a 1-D tensor of size with values from the interval and common difference."""
+            return func(*args, **kwargs, dynamo=False)  # cast to dtype instead of passing dtype
+
+        torch.onnx.export = torch_export  # patch
+        yield
+        torch.onnx.export = func  # unpatch
+    else:
+        yield
+
+
 @contextmanager
 def override_configs(args, overrides: dict[str, Any] | None = None):
-    """
-    Context manager to temporarily override configurations in args.
+    """Context manager to temporarily override configurations in args.
 
     Args:
         args (IterableSimpleNamespace): Original configuration arguments.
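A minimal usage sketch of the new context manager (illustrative model and output path, not from the package): inside the block, torch.onnx.export is forced to dynamo=False on PyTorch 2.9+ and restored afterwards; on older versions the patch is a no-op.

import torch
from ultralytics.utils.patches import onnx_export_patch

model = torch.nn.Linear(4, 2).eval()  # placeholder model
dummy = torch.randn(1, 4)

with onnx_export_patch():  # forces dynamo=False on PyTorch 2.9+, otherwise unchanged
    torch.onnx.export(model, (dummy,), "linear.onnx")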
ultralytics/utils/plotting.py

@@ -3,9 +3,9 @@
 from __future__ import annotations
 
 import math
-import warnings
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any
 
 import cv2
 import numpy as np
@@ -19,22 +19,10 @@ from ultralytics.utils.files import increment_path
 
 
 class Colors:
-    """
-    Ultralytics color palette for visualization and plotting.
-
-    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
-    RGB values and accessing predefined color schemes for object detection and pose estimation.
+    """Ultralytics color palette for visualization and plotting.
 
-    Attributes:
-        palette (list[tuple]): List of RGB color tuples for general use.
-        n (int): The number of colors in the palette.
-        pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
-
-    Examples:
-        >>> from ultralytics.utils.plotting import Colors
-        >>> colors = Colors()
-        >>> colors(5, True)  # Returns BGR format: (221, 111, 255)
-        >>> colors(5, False)  # Returns RGB format: (255, 111, 221)
+    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
+    values and accessing predefined color schemes for object detection and pose estimation.
 
     ## Ultralytics Color Palette
 
@@ -90,6 +78,17 @@ class Colors:
 
     For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
     Please use the official Ultralytics colors for all marketing materials.
+
+    Attributes:
+        palette (list[tuple]): List of RGB color tuples for general use.
+        n (int): The number of colors in the palette.
+        pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
+
+    Examples:
+        >>> from ultralytics.utils.plotting import Colors
+        >>> colors = Colors()
+        >>> colors(5, True)  # Returns BGR format: (221, 111, 255)
+        >>> colors(5, False)  # Returns RGB format: (255, 111, 221)
     """
 
     def __init__(self):
@@ -145,8 +144,7 @@ class Colors:
         )
 
     def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
-        """
-        Convert hex color codes to RGB values.
+        """Convert hex color codes to RGB values.
 
         Args:
             i (int | torch.Tensor): Color index.
@@ -168,8 +166,7 @@ colors = Colors()  # create instance for 'from utils.plots import colors'
 
 
 class Annotator:
-    """
-    Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
+    """Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
 
     Attributes:
         im (Image.Image | np.ndarray): The image to annotate.
@@ -206,10 +203,12 @@ class Annotator:
         if not input_is_pil:
             if im.shape[2] == 1:  # handle grayscale
                 im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
+            elif im.shape[2] == 2:  # handle 2-channel images
+                im = np.ascontiguousarray(np.dstack((im, np.zeros_like(im[..., :1]))))
             elif im.shape[2] > 3:  # multispectral
                 im = np.ascontiguousarray(im[..., :3])
         if self.pil:  # use PIL
-            self.im = im if input_is_pil else Image.fromarray(im)
+            self.im = im if input_is_pil else Image.fromarray(im)  # stay in BGR since color palette is in BGR
             if self.im.mode not in {"RGB", "RGBA"}:  # multispectral
                 self.im = self.im.convert("RGB")
             self.draw = ImageDraw.Draw(self.im, "RGBA")
@@ -278,8 +277,7 @@ class Annotator:
         }
 
     def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
-        """
-        Assign text color based on background color.
+        """Assign text color based on background color.
 
         Args:
             color (tuple, optional): The background color of the rectangle for text (B, G, R).
@@ -302,8 +300,7 @@
         return txt_color
 
     def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
-        """
-        Draw a bounding box on an image with a given label.
+        """Draw a bounding box on an image with a given label.
 
         Args:
             box (tuple): The bounding box coordinates (x1, y1, x2, y2).
@@ -364,8 +361,7 @@
         )
 
     def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
-        """
-        Plot masks on image.
+        """Plot masks on image.
 
         Args:
             masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
@@ -384,25 +380,32 @@
                 overlay[mask.astype(bool)] = colors[i]
             self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
         else:
-            assert isinstance(masks, torch.Tensor), "`masks` must be a torch.Tensor if `im_gpu` is provided."
+            assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
             if len(masks) == 0:
                 self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
+                return
             if im_gpu.device != masks.device:
                 im_gpu = im_gpu.to(masks.device)
+
+            ih, iw = self.im.shape[:2]
+            if not retina_masks:
+                # Use scale_masks to properly remove padding and upsample, convert bool to float first
+                masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
+                # Convert original BGR image to RGB tensor
+                im_gpu = (
+                    torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
+                )
+
             colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3)
             colors = colors[:, None, None]  # shape(n,1,1,3)
             masks = masks.unsqueeze(3)  # shape(n,h,w,1)
             masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
-
             inv_alpha_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
-            mcs = masks_color.max(dim=0).values  # shape(n,h,w,3)
+            mcs = masks_color.max(dim=0).values  # shape(h,w,3)
 
-            im_gpu = im_gpu.flip(dims=[0])  # flip channel
-            im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
+            im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous()  # shape(h,w,3)
             im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
-            im_mask = im_gpu * 255
-            im_mask_np = im_mask.byte().cpu().numpy()
-            self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
+            self.im[:] = (im_gpu * 255).byte().cpu().numpy()
         if self.pil:
             # Convert im back to PIL and update draw
             self.fromarray(self.im)
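A sketch of how the reworked tensor path behaves (shapes, colors, and the blank canvas below are illustrative only, not from the package): with retina_masks=False the method now upsamples the masks via ops.scale_masks and rebuilds the RGB tensor from the annotator's own BGR frame, so the caller's im_gpu mainly serves to select the tensor branch.

import numpy as np
import torch
from ultralytics.utils.plotting import Annotator, colors

frame = np.zeros((640, 640, 3), dtype=np.uint8)  # BGR canvas
masks = torch.zeros((2, 160, 160), dtype=torch.bool)  # two low-resolution masks
masks[0, 40:80, 40:80] = True
masks[1, 90:130, 20:60] = True

ann = Annotator(frame)
ann.masks(masks, colors=[colors(i, True) for i in range(2)], im_gpu=torch.zeros(3, 160, 160))
blended = ann.result()  # BGR array with both masks alpha-blended onto the frame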
@@ -416,8 +419,7 @@
         conf_thres: float = 0.25,
         kpt_color: tuple | None = None,
     ):
-        """
-        Plot keypoints on the image.
+        """Plot keypoints on the image.
 
         Args:
             kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
@@ -427,7 +429,7 @@
             conf_thres (float, optional): Confidence threshold.
             kpt_color (tuple, optional): Keypoint color (B, G, R).
 
-        Note:
+        Notes:
             - `kpt_line=True` currently only supports human pose plotting.
             - Modifies self.im in-place.
            - If self.pil is True, converts image to numpy array and back to PIL.
@@ -480,8 +482,7 @@
             self.draw.rectangle(xy, fill, outline, width)
 
     def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
-        """
-        Add text to an image using PIL or cv2.
+        """Add text to an image using PIL or cv2.
 
         Args:
             xy (list[int]): Top-left coordinates for text placement.
@@ -511,18 +512,19 @@
             cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
 
     def fromarray(self, im):
-        """Update self.im from a numpy array."""
+        """Update `self.im` from a NumPy array or PIL image."""
         self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
         self.draw = ImageDraw.Draw(self.im)
 
-    def result(self):
-        """Return annotated image as array."""
-        return np.asarray(self.im)
+    def result(self, pil=False):
+        """Return annotated image as array or PIL image."""
+        im = np.asarray(self.im)  # self.im is in BGR
+        return Image.fromarray(im[..., ::-1]) if pil else im
 
     def show(self, title: str | None = None):
         """Show the annotated image."""
-        im = Image.fromarray(np.asarray(self.im)[..., ::-1])  # Convert numpy array to PIL Image with RGB to BGR
-        if IS_COLAB or IS_KAGGLE:  # can not use IS_JUPYTER as will run for all ipython environments
+        im = Image.fromarray(np.asarray(self.im)[..., ::-1])  # Convert BGR NumPy array to RGB PIL Image
+        if IS_COLAB or IS_KAGGLE:  # cannot use IS_JUPYTER as it runs for all IPython environments
             try:
                 display(im)  # noqa - display() function only available in ipython environments
             except ImportError as e:
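A short sketch of the new result(pil=...) flag (placeholder image and box, not from the diff): the default still returns the BGR array, while pil=True flips channels and wraps the result in a PIL image.

import numpy as np
from ultralytics.utils.plotting import Annotator

ann = Annotator(np.zeros((480, 640, 3), dtype=np.uint8))
ann.box_label((40, 40, 200, 160), label="object")

arr = ann.result()  # np.ndarray in BGR, unchanged behavior
img = ann.result(pil=True)  # PIL.Image.Image in RGB, new in this release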
@@ -535,12 +537,11 @@
         cv2.imwrite(filename, np.asarray(self.im))
 
     @staticmethod
-    def get_bbox_dimension(bbox: tuple | None = None):
-        """
-        Calculate the dimensions and area of a bounding box.
+    def get_bbox_dimension(bbox: tuple | list):
+        """Calculate the dimensions and area of a bounding box.
 
         Args:
-            bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
+            bbox (tuple | list): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
 
         Returns:
             width (float): Width of the bounding box.
@@ -562,8 +563,7 @@
 @TryExcept()
 @plt_settings()
 def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
-    """
-    Plot training labels including class histograms and box statistics.
+    """Plot training labels including class histograms and box statistics.
 
     Args:
         boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
@@ -576,10 +576,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
     import polars
     from matplotlib.colors import LinearSegmentedColormap
 
-    # Filter matplotlib>=3.7.2 warning
-    warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
-    warnings.filterwarnings("ignore", category=FutureWarning)
-
     # Plot dataset labels
     LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
     nc = int(cls.max() + 1)  # number of classes
@@ -601,8 +597,8 @@
     ax[0].set_xlabel("classes")
     boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
     img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
-    for cls, box in zip(cls[:500], boxes[:500]):
-        ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls))  # plot
+    for class_id, box in zip(cls[:500], boxes[:500]):
+        ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(class_id))  # plot
     ax[1].imshow(img)
     ax[1].axis("off")
 
@@ -633,12 +629,11 @@ def save_one_box(
     BGR: bool = False,
     save: bool = True,
 ):
-    """
-    Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
+    """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
 
-    This function takes a bounding box and an image, and then saves a cropped portion of the image according
-    to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
-    adjustments to the bounding box.
+    This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
+    bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
+    bounding box.
 
     Args:
         xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
@@ -691,11 +686,11 @@ def plot_images(
     save: bool = True,
     conf_thres: float = 0.25,
 ) -> np.ndarray | None:
-    """
-    Plot image grid with labels, bounding boxes, masks, and keypoints.
+    """Plot image grid with labels, bounding boxes, masks, and keypoints.
 
     Args:
-        labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.
+        labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
+            'keypoints', 'batch_idx', 'img'.
         images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
         paths (Optional[list[str]]): List of file paths for each image in the batch.
         fname (str): Output filename for the plotted image grid.
@@ -709,12 +704,12 @@
     Returns:
         (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
 
-    Note:
+    Notes:
         This function supports both tensor and numpy array inputs. It will automatically
         convert tensor inputs to numpy arrays for processing.
 
    Channel Support:
-        - 1 channel: Greyscale
+        - 1 channel: Grayscale
        - 2 channels: Third channel added as zeros
        - 3 channels: Used as-is (standard RGB)
        - 4+ channels: Cropped to first 3 channels
@@ -791,7 +786,6 @@ def plot_images(
                 boxes[..., 0] += x
                 boxes[..., 1] += y
                 is_obb = boxes.shape[-1] == 5  # xywhr
-                # TODO: this transformation might be unnecessary
                 boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                 for j, box in enumerate(boxes.astype(np.int64).tolist()):
                     c = classes[j]
@@ -860,9 +854,9 @@
 
 @plt_settings()
 def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
-    """
-    Plot training results from a results CSV file. The function supports various types of data including segmentation,
-    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
+    """Plot training results from a results CSV file. The function supports various types of data including
+    segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
+    CSV is located.
 
     Args:
         file (str, optional): Path to the CSV file containing the training results.
@@ -914,8 +908,7 @@ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Call
 
 
 def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
-    """
-    Plot a scatter plot with points colored based on a 2D histogram.
+    """Plot a scatter plot with points colored based on a 2D histogram.
 
     Args:
         v (array-like): Values for the x-axis.
@@ -948,9 +941,9 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
 
 @plt_settings()
 def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
-    """
-    Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
-    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
+    """Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
+    key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
+    the plots.
 
     Args:
         csv_file (str, optional): Path to the CSV file containing the tuning results.
@@ -979,6 +972,9 @@ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_p
     if exclude_zero_fitness_points:
         mask = fitness > 0  # exclude zero-fitness points
         x, fitness = x[mask], fitness[mask]
+    if len(fitness) == 0:
+        LOGGER.warning("No valid fitness values to plot (all iterations may have failed)")
+        return
     # Iterative sigma rejection on lower bound only
     for _ in range(3):  # max 3 iterations
         mean, std = fitness.mean(), fitness.std()
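For reference, a minimal call (hypothetical CSV path, not from the diff): with the new guard, a tuning run in which every iteration failed now logs a warning and returns instead of failing on an empty fitness array.

from ultralytics.utils.plotting import plot_tune_results

plot_tune_results("runs/tune/exp/tune_results.csv")  # saves the tuning plots next to the CSV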
@@ -1017,8 +1013,7 @@
 
 @plt_settings()
 def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
-    """
-    Visualize feature maps of a given model module during inference.
+    """Visualize feature maps of a given model module during inference.
 
     Args:
         x (torch.Tensor): Features to be visualized.