dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236) hide show
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -18,8 +18,7 @@ _imshow = cv2.imshow # copy to avoid recursion errors
18
18
 
19
19
 
20
20
  def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
21
- """
22
- Read an image from a file with multilanguage filename support.
21
+ """Read an image from a file with multilanguage filename support.
23
22
 
24
23
  Args:
25
24
  filename (str): Path to the file to read.
@@ -36,7 +35,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
36
35
  if filename.endswith((".tiff", ".tif")):
37
36
  success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
38
37
  if success:
39
- # Handle RGB images in tif/tiff format
38
+ # Handle multi-frame TIFFs and color images
40
39
  return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
41
40
  return None
42
41
  else:
@@ -45,8 +44,7 @@ def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
45
44
 
46
45
 
47
46
  def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
48
- """
49
- Write an image to a file with multilanguage filename support.
47
+ """Write an image to a file with multilanguage filename support.
50
48
 
51
49
  Args:
52
50
  filename (str): Path to the file to write.
@@ -71,15 +69,14 @@ def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) ->
71
69
 
72
70
 
73
71
  def imshow(winname: str, mat: np.ndarray) -> None:
74
- """
75
- Display an image in the specified window with multilanguage window name support.
72
+ """Display an image in the specified window with multilanguage window name support.
76
73
 
77
74
  This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
78
75
  multilanguage window names by encoding them properly for OpenCV compatibility.
79
76
 
80
77
  Args:
81
- winname (str): Name of the window where the image will be displayed. If a window with this name already
82
- exists, the image will be displayed in that window.
78
+ winname (str): Name of the window where the image will be displayed. If a window with this name already exists,
79
+ the image will be displayed in that window.
83
80
  mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
84
81
 
85
82
  Examples:
@@ -96,8 +93,7 @@ _torch_save = torch.save
96
93
 
97
94
 
98
95
  def torch_load(*args, **kwargs):
99
- """
100
- Load a PyTorch model with updated arguments to avoid warnings.
96
+ """Load a PyTorch model with updated arguments to avoid warnings.
101
97
 
102
98
  This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
103
99
 
@@ -109,8 +105,8 @@ def torch_load(*args, **kwargs):
109
105
  (Any): The loaded PyTorch object.
110
106
 
111
107
  Notes:
112
- For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False'
113
- if the argument is not provided, to avoid deprecation warnings.
108
+ For PyTorch versions 1.13 and above, this function automatically sets `weights_only=False` if the argument is
109
+ not provided, to avoid deprecation warnings.
114
110
  """
115
111
  from ultralytics.utils.torch_utils import TORCH_1_13
116
112
 
@@ -121,11 +117,10 @@ def torch_load(*args, **kwargs):
121
117
 
122
118
 
123
119
  def torch_save(*args, **kwargs):
124
- """
125
- Save PyTorch objects with retry mechanism for robustness.
120
+ """Save PyTorch objects with retry mechanism for robustness.
126
121
 
127
- This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
128
- due to device flushing delays or antivirus scanning.
122
+ This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur due
123
+ to device flushing delays or antivirus scanning.
129
124
 
130
125
  Args:
131
126
  *args (Any): Positional arguments to pass to torch.save.
@@ -146,8 +141,7 @@ def torch_save(*args, **kwargs):
146
141
 
147
142
  @contextmanager
148
143
  def arange_patch(args):
149
- """
150
- Workaround for ONNX torch.arange incompatibility with FP16.
144
+ """Workaround for ONNX torch.arange incompatibility with FP16.
151
145
 
152
146
  https://github.com/pytorch/pytorch/issues/148041.
153
147
  """
@@ -165,10 +159,28 @@ def arange_patch(args):
165
159
  yield
166
160
 
167
161
 
162
+ @contextmanager
163
+ def onnx_export_patch():
164
+ """Workaround for ONNX export issues in PyTorch 2.9+ with Dynamo enabled."""
165
+ from ultralytics.utils.torch_utils import TORCH_2_9
166
+
167
+ if TORCH_2_9:
168
+ func = torch.onnx.export
169
+
170
+ def torch_export(*args, **kwargs):
171
+ """Return a 1-D tensor of size with values from the interval and common difference."""
172
+ return func(*args, **kwargs, dynamo=False) # cast to dtype instead of passing dtype
173
+
174
+ torch.onnx.export = torch_export # patch
175
+ yield
176
+ torch.onnx.export = func # unpatch
177
+ else:
178
+ yield
179
+
180
+
168
181
  @contextmanager
169
182
  def override_configs(args, overrides: dict[str, Any] | None = None):
170
- """
171
- Context manager to temporarily override configurations in args.
183
+ """Context manager to temporarily override configurations in args.
172
184
 
173
185
  Args:
174
186
  args (IterableSimpleNamespace): Original configuration arguments.
@@ -3,9 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import math
6
- import warnings
6
+ from collections.abc import Callable
7
7
  from pathlib import Path
8
- from typing import Any, Callable
8
+ from typing import Any
9
9
 
10
10
  import cv2
11
11
  import numpy as np
@@ -19,22 +19,10 @@ from ultralytics.utils.files import increment_path
19
19
 
20
20
 
21
21
  class Colors:
22
- """
23
- Ultralytics color palette for visualization and plotting.
24
-
25
- This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
26
- RGB values and accessing predefined color schemes for object detection and pose estimation.
22
+ """Ultralytics color palette for visualization and plotting.
27
23
 
28
- Attributes:
29
- palette (list[tuple]): List of RGB color tuples for general use.
30
- n (int): The number of colors in the palette.
31
- pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
32
-
33
- Examples:
34
- >>> from ultralytics.utils.plotting import Colors
35
- >>> colors = Colors()
36
- >>> colors(5, True) # Returns BGR format: (221, 111, 255)
37
- >>> colors(5, False) # Returns RGB format: (255, 111, 221)
24
+ This class provides methods to work with the Ultralytics color palette, including converting hex color codes to RGB
25
+ values and accessing predefined color schemes for object detection and pose estimation.
38
26
 
39
27
  ## Ultralytics Color Palette
40
28
 
@@ -90,6 +78,17 @@ class Colors:
90
78
 
91
79
  For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
92
80
  Please use the official Ultralytics colors for all marketing materials.
81
+
82
+ Attributes:
83
+ palette (list[tuple]): List of RGB color tuples for general use.
84
+ n (int): The number of colors in the palette.
85
+ pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
86
+
87
+ Examples:
88
+ >>> from ultralytics.utils.plotting import Colors
89
+ >>> colors = Colors()
90
+ >>> colors(5, True) # Returns BGR format: (221, 111, 255)
91
+ >>> colors(5, False) # Returns RGB format: (255, 111, 221)
93
92
  """
94
93
 
95
94
  def __init__(self):
@@ -145,8 +144,7 @@ class Colors:
145
144
  )
146
145
 
147
146
  def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
148
- """
149
- Convert hex color codes to RGB values.
147
+ """Convert hex color codes to RGB values.
150
148
 
151
149
  Args:
152
150
  i (int | torch.Tensor): Color index.
@@ -168,8 +166,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
168
166
 
169
167
 
170
168
  class Annotator:
171
- """
172
- Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
169
+ """Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
173
170
 
174
171
  Attributes:
175
172
  im (Image.Image | np.ndarray): The image to annotate.
@@ -206,10 +203,12 @@ class Annotator:
206
203
  if not input_is_pil:
207
204
  if im.shape[2] == 1: # handle grayscale
208
205
  im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
206
+ elif im.shape[2] == 2: # handle 2-channel images
207
+ im = np.ascontiguousarray(np.dstack((im, np.zeros_like(im[..., :1]))))
209
208
  elif im.shape[2] > 3: # multispectral
210
209
  im = np.ascontiguousarray(im[..., :3])
211
210
  if self.pil: # use PIL
212
- self.im = im if input_is_pil else Image.fromarray(im)
211
+ self.im = im if input_is_pil else Image.fromarray(im) # stay in BGR since color palette is in BGR
213
212
  if self.im.mode not in {"RGB", "RGBA"}: # multispectral
214
213
  self.im = self.im.convert("RGB")
215
214
  self.draw = ImageDraw.Draw(self.im, "RGBA")
@@ -278,8 +277,7 @@ class Annotator:
278
277
  }
279
278
 
280
279
  def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
281
- """
282
- Assign text color based on background color.
280
+ """Assign text color based on background color.
283
281
 
284
282
  Args:
285
283
  color (tuple, optional): The background color of the rectangle for text (B, G, R).
@@ -302,8 +300,7 @@ class Annotator:
302
300
  return txt_color
303
301
 
304
302
  def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
305
- """
306
- Draw a bounding box on an image with a given label.
303
+ """Draw a bounding box on an image with a given label.
307
304
 
308
305
  Args:
309
306
  box (tuple): The bounding box coordinates (x1, y1, x2, y2).
@@ -364,8 +361,7 @@ class Annotator:
364
361
  )
365
362
 
366
363
  def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
367
- """
368
- Plot masks on image.
364
+ """Plot masks on image.
369
365
 
370
366
  Args:
371
367
  masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
@@ -384,25 +380,32 @@ class Annotator:
384
380
  overlay[mask.astype(bool)] = colors[i]
385
381
  self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
386
382
  else:
387
- assert isinstance(masks, torch.Tensor), "`masks` must be a torch.Tensor if `im_gpu` is provided."
383
+ assert isinstance(masks, torch.Tensor), "'masks' must be a torch.Tensor if 'im_gpu' is provided."
388
384
  if len(masks) == 0:
389
385
  self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
386
+ return
390
387
  if im_gpu.device != masks.device:
391
388
  im_gpu = im_gpu.to(masks.device)
389
+
390
+ ih, iw = self.im.shape[:2]
391
+ if not retina_masks:
392
+ # Use scale_masks to properly remove padding and upsample, convert bool to float first
393
+ masks = ops.scale_masks(masks[None].float(), (ih, iw))[0] > 0.5
394
+ # Convert original BGR image to RGB tensor
395
+ im_gpu = (
396
+ torch.from_numpy(self.im).to(masks.device).permute(2, 0, 1).flip(0).contiguous().float() / 255.0
397
+ )
398
+
392
399
  colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
393
400
  colors = colors[:, None, None] # shape(n,1,1,3)
394
401
  masks = masks.unsqueeze(3) # shape(n,h,w,1)
395
402
  masks_color = masks * (colors * alpha) # shape(n,h,w,3)
396
-
397
403
  inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
398
- mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
404
+ mcs = masks_color.max(dim=0).values # shape(h,w,3)
399
405
 
400
- im_gpu = im_gpu.flip(dims=[0]) # flip channel
401
- im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
406
+ im_gpu = im_gpu.flip(dims=[0]).permute(1, 2, 0).contiguous() # shape(h,w,3)
402
407
  im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
403
- im_mask = im_gpu * 255
404
- im_mask_np = im_mask.byte().cpu().numpy()
405
- self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
408
+ self.im[:] = (im_gpu * 255).byte().cpu().numpy()
406
409
  if self.pil:
407
410
  # Convert im back to PIL and update draw
408
411
  self.fromarray(self.im)
@@ -416,8 +419,7 @@ class Annotator:
416
419
  conf_thres: float = 0.25,
417
420
  kpt_color: tuple | None = None,
418
421
  ):
419
- """
420
- Plot keypoints on the image.
422
+ """Plot keypoints on the image.
421
423
 
422
424
  Args:
423
425
  kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
@@ -427,7 +429,7 @@ class Annotator:
427
429
  conf_thres (float, optional): Confidence threshold.
428
430
  kpt_color (tuple, optional): Keypoint color (B, G, R).
429
431
 
430
- Note:
432
+ Notes:
431
433
  - `kpt_line=True` currently only supports human pose plotting.
432
434
  - Modifies self.im in-place.
433
435
  - If self.pil is True, converts image to numpy array and back to PIL.
@@ -480,8 +482,7 @@ class Annotator:
480
482
  self.draw.rectangle(xy, fill, outline, width)
481
483
 
482
484
  def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
483
- """
484
- Add text to an image using PIL or cv2.
485
+ """Add text to an image using PIL or cv2.
485
486
 
486
487
  Args:
487
488
  xy (list[int]): Top-left coordinates for text placement.
@@ -511,18 +512,19 @@ class Annotator:
511
512
  cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
512
513
 
513
514
  def fromarray(self, im):
514
- """Update self.im from a numpy array."""
515
+ """Update `self.im` from a NumPy array or PIL image."""
515
516
  self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
516
517
  self.draw = ImageDraw.Draw(self.im)
517
518
 
518
- def result(self):
519
- """Return annotated image as array."""
520
- return np.asarray(self.im)
519
+ def result(self, pil=False):
520
+ """Return annotated image as array or PIL image."""
521
+ im = np.asarray(self.im) # self.im is in BGR
522
+ return Image.fromarray(im[..., ::-1]) if pil else im
521
523
 
522
524
  def show(self, title: str | None = None):
523
525
  """Show the annotated image."""
524
- im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR
525
- if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
526
+ im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert BGR NumPy array to RGB PIL Image
527
+ if IS_COLAB or IS_KAGGLE: # cannot use IS_JUPYTER as it runs for all IPython environments
526
528
  try:
527
529
  display(im) # noqa - display() function only available in ipython environments
528
530
  except ImportError as e:
@@ -535,12 +537,11 @@ class Annotator:
535
537
  cv2.imwrite(filename, np.asarray(self.im))
536
538
 
537
539
  @staticmethod
538
- def get_bbox_dimension(bbox: tuple | None = None):
539
- """
540
- Calculate the dimensions and area of a bounding box.
540
+ def get_bbox_dimension(bbox: tuple | list):
541
+ """Calculate the dimensions and area of a bounding box.
541
542
 
542
543
  Args:
543
- bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
544
+ bbox (tuple | list): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
544
545
 
545
546
  Returns:
546
547
  width (float): Width of the bounding box.
@@ -562,8 +563,7 @@ class Annotator:
562
563
  @TryExcept()
563
564
  @plt_settings()
564
565
  def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
565
- """
566
- Plot training labels including class histograms and box statistics.
566
+ """Plot training labels including class histograms and box statistics.
567
567
 
568
568
  Args:
569
569
  boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
@@ -576,10 +576,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
576
576
  import polars
577
577
  from matplotlib.colors import LinearSegmentedColormap
578
578
 
579
- # Filter matplotlib>=3.7.2 warning
580
- warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
581
- warnings.filterwarnings("ignore", category=FutureWarning)
582
-
583
579
  # Plot dataset labels
584
580
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
585
581
  nc = int(cls.max() + 1) # number of classes
@@ -601,8 +597,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
601
597
  ax[0].set_xlabel("classes")
602
598
  boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
603
599
  img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
604
- for cls, box in zip(cls[:500], boxes[:500]):
605
- ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls)) # plot
600
+ for class_id, box in zip(cls[:500], boxes[:500]):
601
+ ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(class_id)) # plot
606
602
  ax[1].imshow(img)
607
603
  ax[1].axis("off")
608
604
 
@@ -633,12 +629,11 @@ def save_one_box(
633
629
  BGR: bool = False,
634
630
  save: bool = True,
635
631
  ):
636
- """
637
- Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
632
+ """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
638
633
 
639
- This function takes a bounding box and an image, and then saves a cropped portion of the image according
640
- to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
641
- adjustments to the bounding box.
634
+ This function takes a bounding box and an image, and then saves a cropped portion of the image according to the
635
+ bounding box. Optionally, the crop can be squared, and the function allows for gain and padding adjustments to the
636
+ bounding box.
642
637
 
643
638
  Args:
644
639
  xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
@@ -691,11 +686,11 @@ def plot_images(
691
686
  save: bool = True,
692
687
  conf_thres: float = 0.25,
693
688
  ) -> np.ndarray | None:
694
- """
695
- Plot image grid with labels, bounding boxes, masks, and keypoints.
689
+ """Plot image grid with labels, bounding boxes, masks, and keypoints.
696
690
 
697
691
  Args:
698
- labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.
692
+ labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks',
693
+ 'keypoints', 'batch_idx', 'img'.
699
694
  images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width).
700
695
  paths (Optional[list[str]]): List of file paths for each image in the batch.
701
696
  fname (str): Output filename for the plotted image grid.
@@ -709,12 +704,12 @@ def plot_images(
709
704
  Returns:
710
705
  (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
711
706
 
712
- Note:
707
+ Notes:
713
708
  This function supports both tensor and numpy array inputs. It will automatically
714
709
  convert tensor inputs to numpy arrays for processing.
715
710
 
716
711
  Channel Support:
717
- - 1 channel: Greyscale
712
+ - 1 channel: Grayscale
718
713
  - 2 channels: Third channel added as zeros
719
714
  - 3 channels: Used as-is (standard RGB)
720
715
  - 4+ channels: Cropped to first 3 channels
@@ -791,7 +786,6 @@ def plot_images(
791
786
  boxes[..., 0] += x
792
787
  boxes[..., 1] += y
793
788
  is_obb = boxes.shape[-1] == 5 # xywhr
794
- # TODO: this transformation might be unnecessary
795
789
  boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
796
790
  for j, box in enumerate(boxes.astype(np.int64).tolist()):
797
791
  c = classes[j]
@@ -860,9 +854,9 @@ def plot_images(
860
854
 
861
855
  @plt_settings()
862
856
  def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
863
- """
864
- Plot training results from a results CSV file. The function supports various types of data including segmentation,
865
- pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
857
+ """Plot training results from a results CSV file. The function supports various types of data including
858
+ segmentation, pose estimation, and classification. Plots are saved as 'results.png' in the directory where the
859
+ CSV is located.
866
860
 
867
861
  Args:
868
862
  file (str, optional): Path to the CSV file containing the training results.
@@ -914,8 +908,7 @@ def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Call
914
908
 
915
909
 
916
910
  def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
917
- """
918
- Plot a scatter plot with points colored based on a 2D histogram.
911
+ """Plot a scatter plot with points colored based on a 2D histogram.
919
912
 
920
913
  Args:
921
914
  v (array-like): Values for the x-axis.
@@ -948,9 +941,9 @@ def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float
948
941
 
949
942
  @plt_settings()
950
943
  def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
951
- """
952
- Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
953
- in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
944
+ """Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each
945
+ key in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on
946
+ the plots.
954
947
 
955
948
  Args:
956
949
  csv_file (str, optional): Path to the CSV file containing the tuning results.
@@ -1017,8 +1010,7 @@ def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_p
1017
1010
 
1018
1011
  @plt_settings()
1019
1012
  def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
1020
- """
1021
- Visualize feature maps of a given model module during inference.
1013
+ """Visualize feature maps of a given model module during inference.
1022
1014
 
1023
1015
  Args:
1024
1016
  x (torch.Tensor): Features to be visualized.
ultralytics/utils/tal.py CHANGED
@@ -10,8 +10,7 @@ from .torch_utils import TORCH_1_11
10
10
 
11
11
 
12
12
  class TaskAlignedAssigner(nn.Module):
13
- """
14
- A task-aligned assigner for object detection.
13
+ """A task-aligned assigner for object detection.
15
14
 
16
15
  This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
17
16
  classification and localization information.
@@ -25,8 +24,7 @@ class TaskAlignedAssigner(nn.Module):
25
24
  """
26
25
 
27
26
  def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
28
- """
29
- Initialize a TaskAlignedAssigner object with customizable hyperparameters.
27
+ """Initialize a TaskAlignedAssigner object with customizable hyperparameters.
30
28
 
31
29
  Args:
32
30
  topk (int, optional): The number of top candidates to consider.
@@ -44,8 +42,7 @@ class TaskAlignedAssigner(nn.Module):
44
42
 
45
43
  @torch.no_grad()
46
44
  def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
47
- """
48
- Compute the task-aligned assignment.
45
+ """Compute the task-aligned assignment.
49
46
 
50
47
  Args:
51
48
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -88,8 +85,7 @@ class TaskAlignedAssigner(nn.Module):
88
85
  return tuple(t.to(device) for t in result)
89
86
 
90
87
  def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
91
- """
92
- Compute the task-aligned assignment.
88
+ """Compute the task-aligned assignment.
93
89
 
94
90
  Args:
95
91
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -125,8 +121,7 @@ class TaskAlignedAssigner(nn.Module):
125
121
  return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
126
122
 
127
123
  def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
128
- """
129
- Get positive mask for each ground truth box.
124
+ """Get positive mask for each ground truth box.
130
125
 
131
126
  Args:
132
127
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -139,7 +134,7 @@ class TaskAlignedAssigner(nn.Module):
139
134
  Returns:
140
135
  mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
141
136
  align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
142
- overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
137
+ overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
143
138
  """
144
139
  mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
145
140
  # Get anchor_align metric, (b, max_num_obj, h*w)
@@ -152,8 +147,7 @@ class TaskAlignedAssigner(nn.Module):
152
147
  return mask_pos, align_metric, overlaps
153
148
 
154
149
  def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
155
- """
156
- Compute alignment metric given predicted and ground truth bounding boxes.
150
+ """Compute alignment metric given predicted and ground truth bounding boxes.
157
151
 
158
152
  Args:
159
153
  pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
@@ -186,8 +180,7 @@ class TaskAlignedAssigner(nn.Module):
186
180
  return align_metric, overlaps
187
181
 
188
182
  def iou_calculation(self, gt_bboxes, pd_bboxes):
189
- """
190
- Calculate IoU for horizontal bounding boxes.
183
+ """Calculate IoU for horizontal bounding boxes.
191
184
 
192
185
  Args:
193
186
  gt_bboxes (torch.Tensor): Ground truth boxes.
@@ -199,14 +192,13 @@ class TaskAlignedAssigner(nn.Module):
199
192
  return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
200
193
 
201
194
  def select_topk_candidates(self, metrics, topk_mask=None):
202
- """
203
- Select the top-k candidates based on the given metrics.
195
+ """Select the top-k candidates based on the given metrics.
204
196
 
205
197
  Args:
206
198
  metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
207
199
  the maximum number of objects, and h*w represents the total number of anchor points.
208
- topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
209
- topk is the number of top candidates to consider. If not provided, the top-k values are automatically
200
+ topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
201
+ is the number of top candidates to consider. If not provided, the top-k values are automatically
210
202
  computed based on the given metrics.
211
203
 
212
204
  Returns:
@@ -231,18 +223,16 @@ class TaskAlignedAssigner(nn.Module):
231
223
  return count_tensor.to(metrics.dtype)
232
224
 
233
225
  def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
234
- """
235
- Compute target labels, target bounding boxes, and target scores for the positive anchor points.
226
+ """Compute target labels, target bounding boxes, and target scores for the positive anchor points.
236
227
 
237
228
  Args:
238
- gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
239
- batch size and max_num_obj is the maximum number of objects.
229
+ gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
230
+ max_num_obj is the maximum number of objects.
240
231
  gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
241
- target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
242
- anchor points, with shape (b, h*w), where h*w is the total
243
- number of anchor points.
244
- fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
245
- (foreground) anchor points.
232
+ target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
233
+ shape (b, h*w), where h*w is the total number of anchor points.
234
+ fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
235
+ points.
246
236
 
247
237
  Returns:
248
238
  target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
@@ -275,8 +265,7 @@ class TaskAlignedAssigner(nn.Module):
275
265
 
276
266
  @staticmethod
277
267
  def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
278
- """
279
- Select positive anchor centers within ground truth bounding boxes.
268
+ """Select positive anchor centers within ground truth bounding boxes.
280
269
 
281
270
  Args:
282
271
  xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
@@ -286,9 +275,9 @@ class TaskAlignedAssigner(nn.Module):
286
275
  Returns:
287
276
  (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
288
277
 
289
- Note:
290
- b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
291
- Bounding box format: [x_min, y_min, x_max, y_max].
278
+ Notes:
279
+ - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
280
+ - Bounding box format: [x_min, y_min, x_max, y_max].
292
281
  """
293
282
  n_anchors = xy_centers.shape[0]
294
283
  bs, n_boxes, _ = gt_bboxes.shape
@@ -298,8 +287,7 @@ class TaskAlignedAssigner(nn.Module):
298
287
 
299
288
  @staticmethod
300
289
  def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
301
- """
302
- Select anchor boxes with highest IoU when assigned to multiple ground truths.
290
+ """Select anchor boxes with highest IoU when assigned to multiple ground truths.
303
291
 
304
292
  Args:
305
293
  mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
@@ -336,8 +324,7 @@ class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
336
324
 
337
325
  @staticmethod
338
326
  def select_candidates_in_gts(xy_centers, gt_bboxes):
339
- """
340
- Select the positive anchor center in gt for rotated bounding boxes.
327
+ """Select the positive anchor center in gt for rotated bounding boxes.
341
328
 
342
329
  Args:
343
330
  xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
@@ -396,8 +383,7 @@ def bbox2dist(anchor_points, bbox, reg_max):
396
383
 
397
384
 
398
385
  def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
399
- """
400
- Decode predicted rotated bounding box coordinates from anchor points and distribution.
386
+ """Decode predicted rotated bounding box coordinates from anchor points and distribution.
401
387
 
402
388
  Args:
403
389
  pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).