dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243) hide show
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -18,8 +18,7 @@ _imshow = cv2.imshow # copy to avoid recursion errors
18
18
 
19
19
 
20
20
  def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
21
- """
22
- Read an image from a file with multilanguage filename support.
21
+ """Read an image from a file with multilanguage filename support.
23
22
 
24
23
  Args:
25
24
  filename (str): Path to the file to read.
@@ -36,7 +35,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
36
35
  if filename.endswith((".tiff", ".tif")):
37
36
  success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
38
37
  if success:
39
- # Handle RGB images in tif/tiff format
38
+ # Handle multi-frame TIFFs and color images
40
39
  return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
41
40
  return None
42
41
  else:
@@ -45,8 +44,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
45
44
 
46
45
 
47
46
  def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
48
- """
49
- Write an image to a file with multilanguage filename support.
47
+ """Write an image to a file with multilanguage filename support.
50
48
 
51
49
  Args:
52
50
  filename (str): Path to the file to write.
@@ -71,15 +69,14 @@ def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) ->
71
69
 
72
70
 
73
71
  def imshow(winname: str, mat: np.ndarray) -> None:
74
- """
75
- Display an image in the specified window with multilanguage window name support.
72
+ """Display an image in the specified window with multilanguage window name support.
76
73
 
77
74
  This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
78
75
  multilanguage window names by encoding them properly for OpenCV compatibility.
79
76
 
80
77
  Args:
81
- winname (str): Name of the window where the image will be displayed. If a window with this name already
82
- exists, the image will be displayed in that window.
78
+ winname (str): Name of the window where the image will be displayed. If a window with this name already exists,
79
+ the image will be displayed in that window.
83
80
  mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
84
81
 
85
82
  Examples:
@@ -96,8 +93,7 @@ _torch_save = torch.save
96
93
 
97
94
 
98
95
  def torch_load(*args, **kwargs):
99
- """
100
- Load a PyTorch model with updated arguments to avoid warnings.
96
+ """Load a PyTorch model with updated arguments to avoid warnings.
101
97
 
102
98
  This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
103
99
 
@@ -109,8 +105,8 @@ def torch_load(*args, **kwargs):
109
105
  (Any): The loaded PyTorch object.
110
106
 
111
107
  Notes:
112
- For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
113
- if the argument is not provided, to avoid deprecation warnings.
108
+ For PyTorch versions 1.13 and above, this function automatically sets `weights_only=False` if the argument is
109
+ not provided, to avoid deprecation warnings.
114
110
  """
115
111
  from ultralytics.utils.torch_utils import TORCH_1_13
116
112
 
@@ -121,11 +117,10 @@ def torch_load(*args, **kwargs):
121
117
 
122
118
 
123
119
  def torch_save(*args, **kwargs):
124
- """
125
- Save PyTorch objects with retry mechanism for robustness.
120
+ """Save PyTorch objects with retry mechanism for robustness.
126
121
 
127
- This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
128
- due to device flushing delays or antivirus scanning.
122
+ This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur due
123
+ to device flushing delays or antivirus scanning.
129
124
 
130
125
  Args:
131
126
  *args (Any): Positional arguments to pass to torch.save.
@@ -146,8 +141,7 @@ def torch_save(*args, **kwargs):
146
141
 
147
142
  @contextmanager
148
143
  def arange_patch(args):
149
- """
150
- Workaround for ONNX torch.arange incompatibility with FP16.
144
+ """Workaround for ONNX torch.arange incompatibility with FP16.
151
145
 
152
146
  https://github.com/pytorch/pytorch/issues/148041.
153
147
  """
@@ -165,10 +159,28 @@ def arange_patch(args):
165
159
  yield
166
160
 
167
161
 
162
+ @contextmanager
163
+ def onnx_export_patch():
164
+ """Workaround for ONNX export issues in PyTorch 2.9+ with Dynamo enabled."""
165
+ from ultralytics.utils.torch_utils import TORCH_2_9
166
+
167
+ if TORCH_2_9:
168
+ func = torch.onnx.export
169
+
170
+ def torch_export(*args, **kwargs):
171
+ """Return a 1-D tensor of size with values from the interval and common difference."""
172
+ return func(*args, **kwargs, dynamo=False) # cast to dtype instead of passing dtype
173
+
174
+ torch.onnx.export = torch_export # patch
175
+ yield
176
+ torch.onnx.export = func # unpatch
177
+ else:
178
+ yield
179
+
180
+
168
181
  @contextmanager
169
182
  def override_configs(args, overrides: dict[str, Any] | None = None):
170
- """
171
- Context manager to temporarily override configurations in args.
183
+ """Context manager to temporarily override configurations in args.
172
184
 
173
185
  Args:
174
186
  args (IterableSimpleNamespace): Original configuration arguments.
@@ -3,9 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import math
6
- import warnings
6
+ from collections.abc import Callable
7
7
  from pathlib import Path
8
- from typing import Any, Callable
8
+ from typing import Any
9
9
 
10
10
  import cv2
11
11
  import numpy as np
@@ -19,22 +19,10 @@ from ultralytics.utils.files import increment_path
19
19
 
20
20
 
21
21
  class Colors:
22
- """
23
- Ultralytics color palette for visualization and plotting.
22
+ """Ultralytics color palette for visualization and plotting.
24
23
 
25
- This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
26
- RGB values and accessing predefined color schemes for object detection and pose estimation.
27
-
28
- Attributes:
29
- palette (list[tuple]): List of RGB color tuples for general use.
30
- n (int): The number of colors in the palette.
31
- pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
32
-
33
- Examples:
34
- >>> from ultralytics.utils.plotting import Colors
35
- >>> colors = Colors()
36
- >>> colors(5, True) # Returns BGR format: (221, 111, 255)
37
- >>> colors(5, False) # Returns RGB format: (255, 111, 221)
24
+ This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
25
+ values and accessing predefined color schemes for object detection and pose estimation.
38
26
 
39
27
  ## Ultralytics Color Palette
40
28
 
@@ -90,6 +78,17 @@ class Colors:
90
78
 
91
79
  For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
92
80
  Please use the official Ultralytics colors for all marketing materials.
81
+
82
+ Attributes:
83
+ palette (list[tuple]): List of RGB color tuples for general use.
84
+ n (int): The number of colors in the palette.
85
+ pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
86
+
87
+ Examples:
88
+ >>> from ultralytics.utils.plotting import Colors
89
+ >>> colors = Colors()
90
+ >>> colors(5, True) # Returns BGR format: (221, 111, 255)
91
+ >>> colors(5, False) # Returns RGB format: (255, 111, 221)
93
92
  """
94
93
 
95
94
  def __init__(self):
@@ -145,8 +144,7 @@ class Colors:
145
144
  )
146
145
 
147
146
  def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
148
- """
149
- Convert hex color codes to RGB values.
147
+ """Convert hex color codes to RGB values.
150
148
 
151
149
  Args:
152
150
  i (int | torch.Tensor): Color index.
@@ -168,8 +166,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
168
166
 
169
167
 
170
168
  class Annotator:
171
- """
172
- Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
169
+ """Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
173
170
 
174
171
  Attributes:
175
172
  im (Image.Image | np.ndarray): The image to annotate.
@@ -206,10 +203,12 @@ class Annotator:
206
203
  if not input_is_pil:
207
204
  if im.shape[2] == 1: # handle grayscale
208
205
  im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
206
+ elif im.shape[2] == 2: # handle 2-channel images
207
+ im = np.ascontiguousarray(np.dstack((im, np.zeros_like(im[..., :1]))))
209
208
  elif im.shape[2] > 3: # multispectral
210
209
  im = np.ascontiguousarray(im[..., :3])
211
210
  if self.pil: # use PIL
212
- self.im = im if input_is_pil else Image.fromarray(im)
211
+ self.im = im if input_is_pil else Image.fromarray(im) # stay in BGR since color palette is in BGR
213
212
  if self.im.mode not in {"RGB", "RGBA"}: # multispectral
214
213
  self.im = self.im.convert("RGB")
215
214
  self.draw = ImageDraw.Draw(self.im, "RGBA")
@@ -278,8 +277,7 @@ class Annotator:
278
277
  }
279
278
 
280
279
  def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
281
- """
282
- Assign text color based on background color.
280
+ """Assign text color based on background color.
283
281
 
284
282
  Args:
285
283
  color (tuple, optional): The background color of the rectangle for text (B, G, R).
@@ -302,8 +300,7 @@ class Annotator:
302
300
  return txt_color
303
301
 
304
302
  def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
305
- """
306
- Draw a bounding box on an image with a given label.
303
+ """Draw a bounding box on an image with a given label.
307
304
 
308
305
  Args:
309
306
  box (tuple): The bounding box coordinates (x1, y1, x2, y2).
@@ -364,8 +361,7 @@ class Annotator:
364
361
  )
365
362
 
366
363
  def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
367
- """
368
- Plot masks on image.
364
+ """Plot masks on image.
369
365
 
370
366
  Args:
371
367
  masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
@@ -384,25 +380,32 @@ class Annotator:
384
380
  overlay[mask.astype(bool)] = colors[i]
385
381
  self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
386
382
  else:
387
- assert isinstance(masks, torch.Tensor), "`masks` must be a torch.Tensor if `im_gpu` is provided."
383
+ assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
388
384
  if len(masks) == 0:
389
385
  self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
386
+ return
390
387
  if im_gpu.device != masks.device:
391
388
  im_gpu = im_gpu.to(masks.device)
389
+
390
+ ih, iw = self.im.shape[:2]
391
+ if not retina_masks:
392
+ # Use scale_masks to properly remove padding and upsample, convert bool to float first
393
+ masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
394
+ # Convert original BGR image to RGB tensor
395
+ im_gpu = (
396
+ torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
397
+ )
398
+
392
399
  colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
393
400
  colors = colors[:, None, None] # shape(n,1,1,3)
394
401
  masks = masks.unsqueeze(3) # shape(n,h,w,1)
395
402
  masks_color = masks * (colors * alpha) # shape(n,h,w,3)
396
-
397
403
  inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
398
- mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
404
+ mcs = masks_color.max(dim=0).values # shape(h,w,3)
399
405
 
400
- im_gpu = im_gpu.flip(dims=[0]) # flip channel
401
- im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
406
+ im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous() # shape(h,w,3)
402
407
  im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
403
- im_mask = im_gpu * 255
404
- im_mask_np = im_mask.byte().cpu().numpy()
405
- self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
408
+ self.im[:] = (im_gpu * 255).byte().cpu().numpy()
406
409
  if self.pil:
407
410
  # Convert im back to PIL and update draw
408
411
  self.fromarray(self.im)
@@ -416,8 +419,7 @@ class Annotator:
416
419
  conf_thres: float = 0.25,
417
420
  kpt_color: tuple | None = None,
418
421
  ):
419
- """
420
- Plot keypoints on the image.
422
+ """Plot keypoints on the image.
421
423
 
422
424
  Args:
423
425
  kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
@@ -427,7 +429,7 @@ class Annotator:
427
429
  conf_thres (float, optional): Confidence threshold.
428
430
  kpt_color (tuple, optional): Keypoint color (B, G, R).
429
431
 
430
- Note:
432
+ Notes:
431
433
  - `kpt_line=True` currently only supports human pose plotting.
432
434
  - Modifies self.im in-place.
433
435
  - If self.pil is True, converts image to numpy array and back to PIL.
@@ -480,8 +482,7 @@ class Annotator:
480
482
  self.draw.rectangle(xy, fill, outline, width)
481
483
 
482
484
  def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
483
- """
484
- Add text to an image using PIL or cv2.
485
+ """Add text to an image using PIL or cv2.
485
486
 
486
487
  Args:
487
488
  xy (list[int]): Top-left coordinates for text placement.
@@ -511,18 +512,19 @@ class Annotator:
511
512
  cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
512
513
 
513
514
  def fromarray(self, im):
514
- """Update self.im from a numpy array."""
515
+ """Update `self.im` from a NumPy array or PIL image."""
515
516
  self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
516
517
  self.draw = ImageDraw.Draw(self.im)
517
518
 
518
- def result(self):
519
- """Return annotated image as array."""
520
- return np.asarray(self.im)
519
+ def result(self, pil=False):
520
+ """Return annotated image as array or PIL image."""
521
+ im = np.asarray(self.im) # self.im is in BGR
522
+ return Image.fromarray(im[..., ::-1]) if pil else im
521
523
 
522
524
  def show(self, title: str | None = None):
523
525
  """Show the annotated image."""
524
- im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
525
- if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
526
+ im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert BGR NumPy array to RGB PIL Image
527
+ if IS_COLAB or IS_KAGGLE: # cannot use IS_JUPYTER as it runs for all IPython environments
526
528
  try:
527
529
  display(im) # noqa - display() function only available in ipython environments
528
530
  except ImportError as e:
@@ -535,12 +537,11 @@ class Annotator:
535
537
  cv2.imwrite(filename, np.asarray(self.im))
536
538
 
537
539
  @staticmethod
538
- def get_bbox_dimension(bbox: tuple | None = None):
539
- """
540
- Calculate the dimensions and area of a bounding box.
540
+ def get_bbox_dimension(bbox: tuple | list):
541
+ """Calculate the dimensions and area of a bounding box.
541
542
 
542
543
  Args:
543
- bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
544
+ bbox (tuple | list): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
544
545
 
545
546
  Returns:
546
547
  width (float): Width of the bounding box.
@@ -562,8 +563,7 @@ class Annotator:
562
563
  @TryExcept()
563
564
  @plt_settings()
564
565
  def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
565
- """
566
- Plot training labels including class histograms and box statistics.
566
+ """Plot training labels including class histograms and box statistics.
567
567
 
568
568
  Args:
569
569
  boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
@@ -576,10 +576,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
576
576
  import polars
577
577
  from matplotlib.colors import LinearSegmentedColormap
578
578
 
579
- # Filter matplotlib>=3.7.2 warning
580
- warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
581
- warnings.filterwarnings("ignore", category=FutureWarning)
582
-
583
579
  # Plot dataset labels
584
580
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
585
581
  nc = int(cls.max() + 1) # number of classes
@@ -601,8 +597,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
601
597
  ax[0].set_xlabel("classes")
602
598
  boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
603
599
  img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
604
- for cls, box in zip(cls[:500], boxes[:500]):
605
- ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls)) # plot
600
+ for class_id, box in zip(cls[:500], boxes[:500]):
601
+ ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(class_id)) # plot
606
602
  ax[1].imshow(img)
607
603
  ax[1].axis("off")
608
604
 
@@ -633,12 +629,11 @@ def save_one_box(
633
629
  BGR: bool = False,
634
630
  save: bool = True,
635
631
  ):
636
- """
637
- Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
632
+ """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
638
633
 
639
- This function takes a bounding box and an image, and then saves a cropped portion of the image according
640
- to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
641
- adjustments to the bounding box.
634
+ This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
635
+ bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
636
+ bounding box.
642
637
 
643
638
  Args:
644
639
  xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
@@ -691,11 +686,11 @@ def plot_images(
691
686
  save: bool = True,
692
687
  conf_thres: float = 0.25,
693
688
  ) -> np.ndarray | None:
694
- """
695
- Plot image grid with labels, bounding boxes, masks, and keypoints.
689
+ """Plot image grid with labels, bounding boxes, masks, and keypoints.
696
690
 
697
691
  Args:
698
- labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.
692
+ labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
693
+ 'keypoints', 'batch_idx', 'img'.
699
694
  images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
700
695
  paths (Optional[list[str]]): List of file paths for each image in the batch.
701
696
  fname (str): Output filename for the plotted image grid.
@@ -709,9 +704,15 @@ def plot_images(
709
704
  Returns:
710
705
  (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
711
706
 
712
- Note:
707
+ Notes:
713
708
  This function supports both tensor and numpy array inputs. It will automatically
714
709
  convert tensor inputs to numpy arrays for processing.
710
+
711
+ Channel Support:
712
+ - 1 channel: Grayscale
713
+ - 2 channels: Third channel added as zeros
714
+ - 3 channels: Used as-is (standard RGB)
715
+ - 4+ channels: Cropped to first 3 channels
715
716
  """
716
717
  for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
717
718
  if k not in labels:
@@ -731,7 +732,13 @@ def plot_images(
731
732
 
732
733
  if len(images) and isinstance(images, torch.Tensor):
733
734
  images = images.cpu().float().numpy()
734
- if images.shape[1] > 3:
735
+
736
+ # Handle 2-ch and n-ch images
737
+ c = images.shape[1]
738
+ if c == 2:
739
+ zero = np.zeros_like(images[:, :1])
740
+ images = np.concatenate((images, zero), axis=1) # pad 2-ch with a black channel
741
+ elif c > 3:
735
742
  images = images[:, :3] # crop multispectral images to first 3 channels
736
743
 
737
744
  bs, _, h, w = images.shape # batch size, _, height, width
@@ -766,10 +773,10 @@ def plot_images(
766
773
  idx = batch_idx == i
767
774
  classes = cls[idx].astype("int")
768
775
  labels = confs is None
776
+ conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
769
777
 
770
778
  if len(bboxes):
771
779
  boxes = bboxes[idx]
772
- conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
773
780
  if len(boxes):
774
781
  if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
775
782
  boxes[..., [0, 2]] *= w # scale to pixels
@@ -779,7 +786,6 @@ def plot_images(
779
786
  boxes[..., 0] += x
780
787
  boxes[..., 1] += y
781
788
  is_obb = boxes.shape[-1] == 5 # xywhr
782
- # TODO: this transformation might be unnecessary
783
789
  boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
784
790
  for j, box in enumerate(boxes.astype(np.int64).tolist()):
785
791
  c = classes[j]
@@ -793,7 +799,8 @@ def plot_images(
793
799
  for c in classes:
794
800
  color = colors(c)
795
801
  c = names.get(c, c) if names else c
796
- annotator.text([x, y], f"{c}", txt_color=color, box_color=(64, 64, 64, 128))
802
+ label = f"{c}" if labels else f"{c} {conf[0]:.1f}"
803
+ annotator.text([x, y], label, txt_color=color, box_color=(64, 64, 64, 128))
797
804
 
798
805
  # Plot keypoints
799
806
  if len(kpts):
@@ -812,14 +819,13 @@ def plot_images(
812
819
 
813
820
  # Plot masks
814
821
  if len(masks):
815
- if idx.shape[0] == masks.shape[0]: # overlap_mask=False
822
+ if idx.shape[0] == masks.shape[0] and masks.max() <= 1: # overlap_mask=False
816
823
  image_masks = masks[idx]
817
824
  else: # overlap_mask=True
818
825
  image_masks = masks[[i]] # (1, 640, 640)
819
826
  nl = idx.sum()
820
- index = np.arange(nl).reshape((nl, 1, 1)) + 1
821
- image_masks = np.repeat(image_masks, nl, axis=0)
822
- image_masks = np.where(image_masks == index, 1.0, 0.0)
827
+ index = np.arange(1, nl + 1).reshape((nl, 1, 1))
828
+ image_masks = (image_masks == index).astype(np.float32)
823
829
 
824
830
  im = np.asarray(annotator.im).copy()
825
831
  for j in range(len(image_masks)):
@@ -847,24 +853,14 @@ def plot_images(
847
853
 
848
854
 
849
855
  @plt_settings()
850
- def plot_results(
851
- file: str = "path/to/results.csv",
852
- dir: str = "",
853
- segment: bool = False,
854
- pose: bool = False,
855
- classify: bool = False,
856
- on_plot: Callable | None = None,
857
- ):
858
- """
859
- Plot training results from a results CSV file. The function supports various types of data including segmentation,
860
- pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
856
+ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
857
+ """Plot training results from a results CSV file. The function supports various types of data including
858
+ segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
859
+ CSV is located.
861
860
 
862
861
  Args:
863
862
  file (str, optional): Path to the CSV file containing the training results.
864
863
  dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
865
- segment (bool, optional): Flag to indicate if the data is for segmentation.
866
- pose (bool, optional): Flag to indicate if the data is for pose estimation.
867
- classify (bool, optional): Flag to indicate if the data is for classification.
868
864
  on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
869
865
 
870
866
  Examples:
@@ -876,34 +872,31 @@ def plot_results(
876
872
  from scipy.ndimage import gaussian_filter1d
877
873
 
878
874
  save_dir = Path(file).parent if file else Path(dir)
879
- if classify:
880
- fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
881
- index = [2, 5, 3, 4]
882
- elif segment:
883
- fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
884
- index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
885
- elif pose:
886
- fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
887
- index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
888
- else:
889
- fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
890
- index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
891
- ax = ax.ravel()
892
875
  files = list(save_dir.glob("results*.csv"))
893
876
  assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
894
- for f in files:
877
+
878
+ loss_keys, metric_keys = [], []
879
+ for i, f in enumerate(files):
895
880
  try:
896
881
  data = pl.read_csv(f, infer_schema_length=None)
897
- s = [x.strip() for x in data.columns]
882
+ if i == 0:
883
+ for c in data.columns:
884
+ if "loss" in c:
885
+ loss_keys.append(c)
886
+ elif "metric" in c:
887
+ metric_keys.append(c)
888
+ loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
889
+ columns = (
890
+ loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
891
+ )
892
+ fig, ax = plt.subplots(2, len(columns) // 2, figsize=(len(columns) + 2, 6), tight_layout=True)
893
+ ax = ax.ravel()
898
894
  x = data.select(data.columns[0]).to_numpy().flatten()
899
- for i, j in enumerate(index):
900
- y = data.select(data.columns[j]).to_numpy().flatten().astype("float")
901
- # y[y == 0] = np.nan # don't show zero values
895
+ for i, j in enumerate(columns):
896
+ y = data.select(j).to_numpy().flatten().astype("float")
902
897
  ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
903
898
  ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
904
- ax[i].set_title(s[j], fontsize=12)
905
- # if j in {8, 9, 10}: # share train and val loss y axes
906
- # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
899
+ ax[i].set_title(j, fontsize=12)
907
900
  except Exception as e:
908
901
  LOGGER.error(f"Plotting error for {f}: {e}")
909
902
  ax[1].legend()
@@ -915,8 +908,7 @@ def plot_results(
915
908
 
916
909
 
917
910
  def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
918
- """
919
- Plot a scatter plot with points colored based on a 2D histogram.
911
+ """Plot a scatter plot with points colored based on a 2D histogram.
920
912
 
921
913
  Args:
922
914
  v (array-like): Values for the x-axis.
@@ -948,13 +940,14 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
948
940
 
949
941
 
950
942
  @plt_settings()
951
- def plot_tune_results(csv_file: str = "tune_results.csv"):
952
- """
953
- Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
954
- in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
943
+ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
944
+ """Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
945
+ key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
946
+ the plots.
955
947
 
956
948
  Args:
957
949
  csv_file (str, optional): Path to the CSV file containing the tuning results.
950
+ exclude_zero_fitness_points (bool, optional): Don't include points with zero fitness in tuning plots.
958
951
 
959
952
  Examples:
960
953
  >>> plot_tune_results("path/to/tune_results.csv")
@@ -976,6 +969,17 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
976
969
  keys = [x.strip() for x in data.columns][num_metrics_columns:]
977
970
  x = data.to_numpy()
978
971
  fitness = x[:, 0] # fitness
972
+ if exclude_zero_fitness_points:
973
+ mask = fitness > 0 # exclude zero-fitness points
974
+ x, fitness = x[mask], fitness[mask]
975
+ # Iterative sigma rejection on lower bound only
976
+ for _ in range(3): # max 3 iterations
977
+ mean, std = fitness.mean(), fitness.std()
978
+ lower_bound = mean - 3 * std
979
+ mask = fitness >= lower_bound
980
+ if mask.all(): # no more outliers
981
+ break
982
+ x, fitness = x[mask], fitness[mask]
979
983
  j = np.argmax(fitness) # max fitness index
980
984
  n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
981
985
  plt.figure(figsize=(10, 10), tight_layout=True)
@@ -1006,8 +1010,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv"):
1006
1010
 
1007
1011
  @plt_settings()
1008
1012
  def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
1009
- """
1010
- Visualize feature maps of a given model module during inference.
1013
+ """Visualize feature maps of a given model module during inference.
1011
1014
 
1012
1015
  Args:
1013
1016
  x (torch.Tensor): Features to be visualized.