dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
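Note: the largest structural change in `ultralytics/utils/ops.py` below is the removal of `nms_rotated` and `non_max_suppression`; together with the new `ultralytics/utils/nms.py` (+337 lines) in the list above, this suggests the NMS implementation moved into its own module. A minimal compatibility sketch, assuming the functions keep their names in the new module (the public API of `ultralytics.utils.nms` is not shown in this diff):

# Hedged sketch: import NMS helpers across both layouts. That
# ultralytics.utils.nms exports this name is an assumption here.
try:
    from ultralytics.utils.nms import non_max_suppression  # 8.3.224 layout (assumed)
except ImportError:
    from ultralytics.utils.ops import non_max_suppression  # 8.3.137 layout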
ultralytics/utils/ops.py CHANGED
@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 import contextlib
 import math
 import re
@@ -10,33 +12,38 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-from ultralytics.utils import LOGGER
-from ultralytics.utils.metrics import batch_probiou
+from ultralytics.utils import NOT_MACOS14


 class Profile(contextlib.ContextDecorator):
-    """
-    YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
+    """Ultralytics Profile class for timing code execution.
+
+    Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
+    measurements with CUDA synchronization support for GPU operations.

     Attributes:
-        t (float): Accumulated time.
+        t (float): Accumulated time in seconds.
         device (torch.device): Device used for model inference.
-        cuda (bool): Whether CUDA is being used.
+        cuda (bool): Whether CUDA is being used for timing synchronization.

     Examples:
-        >>> from ultralytics.utils.ops import Profile
+        Use as a context manager to time code execution
         >>> with Profile(device=device) as dt:
         ...     pass  # slow operation here
         >>> print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
+
+        Use as a decorator to time function execution
+        >>> @Profile()
+        ... def slow_function():
+        ...     time.sleep(0.1)
     """

-    def __init__(self, t=0.0, device: torch.device = None):
-        """
-        Initialize the Profile class.
+    def __init__(self, t: float = 0.0, device: torch.device | None = None):
+        """Initialize the Profile class.

         Args:
-            t (float): Initial time.
-            device (torch.device): Device used for model inference.
+            t (float): Initial accumulated time in seconds.
+            device (torch.device, optional): Device used for model inference to enable CUDA synchronization.
         """
         self.t = t
         self.device = device
@@ -47,36 +54,38 @@ class Profile(contextlib.ContextDecorator):
         self.start = self.time()
         return self

-    def __exit__(self, type, value, traceback):  # noqa
+    def __exit__(self, type, value, traceback):
         """Stop timing."""
         self.dt = self.time() - self.start  # delta-time
         self.t += self.dt  # accumulate dt

     def __str__(self):
-        """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
+        """Return a human-readable string representing the accumulated elapsed time."""
         return f"Elapsed time is {self.t} s"

     def time(self):
-        """Get current time."""
+        """Get current time with CUDA synchronization if applicable."""
         if self.cuda:
             torch.cuda.synchronize(self.device)
         return time.perf_counter()


-def segment2box(segment, width=640, height=640):
-    """
-    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
+def segment2box(segment, width: int = 640, height: int = 640):
+    """Convert segment coordinates to bounding box coordinates.
+
+    Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates. Applies
+    inside-image constraint and clips coordinates when necessary.

     Args:
-        segment (torch.Tensor): The segment label.
-        width (int): The width of the image.
-        height (int): The height of the image.
+        segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
+        width (int): Width of the image in pixels.
+        height (int): Height of the image in pixels.

     Returns:
-        (np.ndarray): The minimum and maximum x and y values of the segment.
+        (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].
     """
     x, y = segment.T  # segment xy
-    # any 3 out of 4 sides are outside the image, clip coordinates first, https://github.com/ultralytics/ultralytics/pull/18294
+    # Clip coordinates if 3 out of 4 sides are outside the image
     if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
         x = x.clip(0, width)
         y = y.clip(0, height)
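Note: the reworked Profile docstring above documents both usage modes. A minimal sketch of the context-manager form (falls back to CPU when no CUDA device is available):

import torch
from ultralytics.utils.ops import Profile

# Passing device= makes Profile.time() call torch.cuda.synchronize(), so the
# timer waits for queued GPU kernels instead of measuring only launch overhead.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.randn(1024, 1024, device=device)
with Profile(device=device) as dt:
    y = x @ x  # GPU matmul is asynchronous without synchronization
print(dt)  # e.g. "Elapsed time is 0.0031... s"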
@@ -90,46 +99,43 @@ def segment2box(segment, width=640, height=640):
     )  # xyxy


-def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
-    """
-    Rescale bounding boxes from img1_shape to img0_shape.
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
+    """Rescale bounding boxes from one image shape to another.
+
+    Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes. Supports
+    both xyxy and xywh box formats.

     Args:
-        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
-        boxes (torch.Tensor): The bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2).
-        img0_shape (tuple): The shape of the target image, in the format of (height, width).
-        ratio_pad (tuple): A tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
-            calculated based on the size difference between the two images.
-        padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
-            rescaling.
-        xywh (bool): The box format is xywh or not.
+        img1_shape (tuple): Shape of the source image (height, width).
+        boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).
+        img0_shape (tuple): Shape of the target image (height, width).
+        ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.
+        padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
+        xywh (bool): Whether box format is xywh (True) or xyxy (False).

     Returns:
-        (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2).
+        (torch.Tensor): Rescaled bounding boxes in the same format as input.
     """
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
-        pad = (
-            round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
-            round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
-        )  # wh padding
+        pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
+        pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
     else:
         gain = ratio_pad[0][0]
-        pad = ratio_pad[1]
+        pad_x, pad_y = ratio_pad[1]

     if padding:
-        boxes[..., 0] -= pad[0]  # x padding
-        boxes[..., 1] -= pad[1]  # y padding
+        boxes[..., 0] -= pad_x  # x padding
+        boxes[..., 1] -= pad_y  # y padding
         if not xywh:
-            boxes[..., 2] -= pad[0]  # x padding
-            boxes[..., 3] -= pad[1]  # y padding
+            boxes[..., 2] -= pad_x  # x padding
+            boxes[..., 3] -= pad_y  # y padding
     boxes[..., :4] /= gain
-    return clip_boxes(boxes, img0_shape)
+    return boxes if xywh else clip_boxes(boxes, img0_shape)


-def make_divisible(x, divisor):
-    """
-    Returns the nearest number that is divisible by the given divisor.
+def make_divisible(x: int, divisor):
+    """Return the nearest number that is divisible by the given divisor.

     Args:
         x (int): The number to make divisible.
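Note: one behavioral change in scale_boxes() above is easy to miss: xywh boxes are no longer passed through clip_boxes(), since clipping center/width/height columns against image edges would corrupt them; only xyxy boxes are still clipped. A minimal sketch of both return paths:

import torch
from ultralytics.utils.ops import scale_boxes

letterboxed = (640, 640)  # inference shape (h, w)
original = (480, 720)     # source image shape (h, w)
boxes = torch.tensor([[100.0, 120.0, 300.0, 340.0]])  # xyxy in 640x640 space
print(scale_boxes(letterboxed, boxes.clone(), original))             # rescaled + clipped xyxy
print(scale_boxes(letterboxed, boxes.clone(), original, xywh=True))  # rescaled, unclipped xywh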
@@ -143,276 +149,96 @@ def make_divisible(x, divisor):
     return math.ceil(x / divisor) * divisor


-def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
-    """
-    NMS for oriented bounding boxes using probiou and fast-nms.
-
-    Args:
-        boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
-        scores (torch.Tensor): Confidence scores, shape (N,).
-        threshold (float): IoU threshold.
-        use_triu (bool): Whether to use `torch.triu` operator. It'd be useful for disable it
-            when exporting obb models to some formats that do not support `torch.triu`.
-
-    Returns:
-        (torch.Tensor): Indices of boxes to keep after NMS.
-    """
-    sorted_idx = torch.argsort(scores, descending=True)
-    boxes = boxes[sorted_idx]
-    ious = batch_probiou(boxes, boxes)
-    if use_triu:
-        ious = ious.triu_(diagonal=1)
-        # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
-        # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
-        pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
-    else:
-        n = boxes.shape[0]
-        row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
-        col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
-        upper_mask = row_idx < col_idx
-        ious = ious * upper_mask
-        # Zeroing these scores ensures the additional indices would not affect the final results
-        scores[~((ious >= threshold).sum(0) <= 0)] = 0
-        # NOTE: return indices with fixed length to avoid TFLite reshape error
-        pick = torch.topk(scores, scores.shape[0]).indices
-    return sorted_idx[pick]
-
-
-def non_max_suppression(
-    prediction,
-    conf_thres=0.25,
-    iou_thres=0.45,
-    classes=None,
-    agnostic=False,
-    multi_label=False,
-    labels=(),
-    max_det=300,
-    nc=0,  # number of classes (optional)
-    max_time_img=0.05,
-    max_nms=30000,
-    max_wh=7680,
-    in_place=True,
-    rotated=False,
-    end2end=False,
-    return_idxs=False,
-):
-    """
-    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
-
-    Args:
-        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
-            containing the predicted boxes, classes, and masks. The tensor should be in the format
-            output by a model, such as YOLO.
-        conf_thres (float): The confidence threshold below which boxes will be filtered out.
-            Valid values are between 0.0 and 1.0.
-        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
-            Valid values are between 0.0 and 1.0.
-        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
-        agnostic (bool): If True, the model is agnostic to the number of classes, and all
-            classes will be considered as one.
-        multi_label (bool): If True, each box may have multiple labels.
-        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
-            list contains the apriori labels for a given image. The list should be in the format
-            output by a dataloader, with each label being a tuple of (class_index, x, y, w, h).
-        max_det (int): The maximum number of boxes to keep after NMS.
-        nc (int): The number of classes output by the model. Any indices after this will be considered masks.
-        max_time_img (float): The maximum time (seconds) for processing one image.
-        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
-        max_wh (int): The maximum box width and height in pixels.
-        in_place (bool): If True, the input prediction tensor will be modified in place.
-        rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
-        end2end (bool): If the model doesn't require NMS.
-        return_idxs (bool): Return the indices of the detections that were kept.
-
-    Returns:
-        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
-            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
-            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
-    """
-    import torchvision  # scope for faster 'import ultralytics'
-
-    # Checks
-    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
-    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
-    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
-        prediction = prediction[0]  # select only inference output
-    if classes is not None:
-        classes = torch.tensor(classes, device=prediction.device)
-
-    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
-        output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
-        if classes is not None:
-            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
-        return output
-
-    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
-    nc = nc or (prediction.shape[1] - 4)  # number of classes
-    nm = prediction.shape[1] - nc - 4  # number of masks
-    mi = 4 + nc  # mask start index
-    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
-    xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None]  # to track idxs
-
-    # Settings
-    # min_wh = 2  # (pixels) minimum box width and height
-    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
-    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
-
-    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
-    if not rotated:
-        if in_place:
-            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
-        else:
-            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy
-
-    t = time.time()
-    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
-    keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
-    for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
-        # Apply constraints
-        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
-        filt = xc[xi]  # confidence
-        x, xk = x[filt], xk[filt]
-
-        # Cat apriori labels if autolabelling
-        if labels and len(labels[xi]) and not rotated:
-            lb = labels[xi]
-            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
-            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
-            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
-            x = torch.cat((x, v), 0)
-
-        # If none remain process next image
-        if not x.shape[0]:
-            continue
-
-        # Detections matrix nx6 (xyxy, conf, cls)
-        box, cls, mask = x.split((4, nc, nm), 1)
-
-        if multi_label:
-            i, j = torch.where(cls > conf_thres)
-            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
-            xk = xk[i]
-        else:  # best class only
-            conf, j = cls.max(1, keepdim=True)
-            filt = conf.view(-1) > conf_thres
-            x = torch.cat((box, conf, j.float(), mask), 1)[filt]
-            xk = xk[filt]
-
-        # Filter by class
-        if classes is not None:
-            filt = (x[:, 5:6] == classes).any(1)
-            x, xk = x[filt], xk[filt]
-
-        # Check shape
-        n = x.shape[0]  # number of boxes
-        if not n:  # no boxes
-            continue
-        if n > max_nms:  # excess boxes
-            filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
-            x, xk = x[filt], xk[filt]
-
-        # Batched NMS
-        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
-        scores = x[:, 4]  # scores
-        if rotated:
-            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
-            i = nms_rotated(boxes, scores, iou_thres)
-        else:
-            boxes = x[:, :4] + c  # boxes (offset by class)
-            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
-        i = i[:max_det]  # limit detections
-
-        # # Experimental
-        # merge = False  # use merge-NMS
-        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
-        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
-        #     from .metrics import box_iou
-        #     iou = box_iou(boxes[i], boxes) > iou_thres  # IoU matrix
-        #     weights = iou * scores[None]  # box weights
-        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
-        #     redundant = True  # require redundant detections
-        #     if redundant:
-        #         i = i[iou.sum(1) > 1]  # require redundancy
-
-        output[xi], keepi[xi] = x[i], xk[i].reshape(-1)
-        if (time.time() - t) > time_limit:
-            LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
-            break  # time limit exceeded
-
-    return (output, keepi) if return_idxs else output
-
-
 def clip_boxes(boxes, shape):
-    """
-    Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
+    """Clip bounding boxes to image boundaries.

     Args:
-        boxes (torch.Tensor | numpy.ndarray): The bounding boxes to clip.
-        shape (tuple): The shape of the image.
+        boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.
+        shape (tuple): Image shape as HWC or HW (supports both).

     Returns:
-        (torch.Tensor | numpy.ndarray): The clipped boxes.
-    """
-    if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
-        boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
-        boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
-        boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
-        boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
+        (torch.Tensor | np.ndarray): Clipped bounding boxes.
+    """
+    h, w = shape[:2]  # supports both HWC or HW shapes
+    if isinstance(boxes, torch.Tensor):  # faster individually
+        if NOT_MACOS14:
+            boxes[..., 0].clamp_(0, w)  # x1
+            boxes[..., 1].clamp_(0, h)  # y1
+            boxes[..., 2].clamp_(0, w)  # x2
+            boxes[..., 3].clamp_(0, h)  # y2
+        else:  # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
+            boxes[..., 0] = boxes[..., 0].clamp(0, w)
+            boxes[..., 1] = boxes[..., 1].clamp(0, h)
+            boxes[..., 2] = boxes[..., 2].clamp(0, w)
+            boxes[..., 3] = boxes[..., 3].clamp(0, h)
     else:  # np.array (faster grouped)
-        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
-        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
+        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, w)  # x1, x2
+        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, h)  # y1, y2
     return boxes


 def clip_coords(coords, shape):
-    """
-    Clip line coordinates to the image boundaries.
+    """Clip line coordinates to image boundaries.

     Args:
-        coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
-        shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
+        coords (torch.Tensor | np.ndarray): Line coordinates to clip.
+        shape (tuple): Image shape as HWC or HW (supports both).

     Returns:
-        (torch.Tensor | numpy.ndarray): Clipped coordinates.
-    """
-    if isinstance(coords, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
-        coords[..., 0] = coords[..., 0].clamp(0, shape[1])  # x
-        coords[..., 1] = coords[..., 1].clamp(0, shape[0])  # y
-    else:  # np.array (faster grouped)
-        coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
-        coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y
+        (torch.Tensor | np.ndarray): Clipped coordinates.
+    """
+    h, w = shape[:2]  # supports both HWC or HW shapes
+    if isinstance(coords, torch.Tensor):
+        if NOT_MACOS14:
+            coords[..., 0].clamp_(0, w)  # x
+            coords[..., 1].clamp_(0, h)  # y
+        else:  # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
+            coords[..., 0] = coords[..., 0].clamp(0, w)
+            coords[..., 1] = coords[..., 1].clamp(0, h)
+    else:  # np.array
+        coords[..., 0] = coords[..., 0].clip(0, w)  # x
+        coords[..., 1] = coords[..., 1].clip(0, h)  # y
     return coords


 def scale_image(masks, im0_shape, ratio_pad=None):
-    """
-    Takes a mask, and resizes it to the original image size.
+    """Rescale masks to original image size.
+
+    Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding that
+    was applied during preprocessing.

     Args:
-        masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
-        im0_shape (tuple): The original image shape.
-        ratio_pad (tuple): The ratio of the padding to the original image.
+        masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
+        im0_shape (tuple): Original image shape as HWC or HW (supports both).
+        ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).

     Returns:
-        masks (np.ndarray): The masks that are being returned with shape [h, w, num].
+        (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
     """
     # Rescale coordinates (xyxy) from im1_shape to im0_shape
-    im1_shape = masks.shape
-    if im1_shape[:2] == im0_shape[:2]:
+    im0_h, im0_w = im0_shape[:2]  # supports both HWC or HW shapes
+    im1_h, im1_w, _ = masks.shape
+    if im1_h == im0_h and im1_w == im0_w:
         return masks
+
     if ratio_pad is None:  # calculate from im0_shape
-        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
-        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
+        gain = min(im1_h / im0_h, im1_w / im0_w)  # gain = old / new
+        pad = (im1_w - im0_w * gain) / 2, (im1_h - im0_h * gain) / 2  # wh padding
     else:
-        # gain = ratio_pad[0][0]
         pad = ratio_pad[1]
-    top, left = int(pad[1]), int(pad[0])  # y, x
-    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
+
+    pad_w, pad_h = pad
+    top = round(pad_h - 0.1)
+    left = round(pad_w - 0.1)
+    bottom = im1_h - round(pad_h + 0.1)
+    right = im1_w - round(pad_w + 0.1)

     if len(masks.shape) < 2:
         raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
     masks = masks[top:bottom, left:right]
-    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
+    # handle the cv2.resize 512 channels limitation: https://github.com/ultralytics/ultralytics/pull/21947
+    masks = [cv2.resize(array, (im0_w, im0_h)) for array in np.array_split(masks, masks.shape[-1] // 512 + 1, axis=-1)]
+    masks = np.concatenate(masks, axis=-1) if len(masks) > 1 else masks[0]
     if len(masks.shape) == 2:
         masks = masks[:, :, None]

@@ -420,35 +246,34 @@ def scale_image(masks, im0_shape, ratio_pad=None):


 def xyxy2xywh(x):
-    """
-    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
-    top-left corner and (x2, y2) is the bottom-right corner.
+    """Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is
+    the top-left corner and (x2, y2) is the bottom-right corner.

     Args:
-        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.

     Returns:
-        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
+        (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.
     """
     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
     y = empty_like(x)  # faster than clone/copy
-    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
-    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
-    y[..., 2] = x[..., 2] - x[..., 0]  # width
-    y[..., 3] = x[..., 3] - x[..., 1]  # height
+    x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
+    y[..., 0] = (x1 + x2) / 2  # x center
+    y[..., 1] = (y1 + y2) / 2  # y center
+    y[..., 2] = x2 - x1  # width
+    y[..., 3] = y2 - y1  # height
     return y


 def xywh2xyxy(x):
-    """
-    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
-    top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
+    """Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is
+    the top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.

     Args:
-        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.

     Returns:
-        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
+        (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
     """
     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
     y = empty_like(x)  # faster than clone/copy
@@ -459,65 +284,65 @@ def xywh2xyxy(x):
     return y


-def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
-    """
-    Convert normalized bounding box coordinates to pixel coordinates.
+def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
+    """Convert normalized bounding box coordinates to pixel coordinates.

     Args:
-        x (np.ndarray | torch.Tensor): The bounding box coordinates.
-        w (int): Width of the image.
-        h (int): Height of the image.
-        padw (int): Padding width.
-        padh (int): Padding height.
+        x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
+        w (int): Image width in pixels.
+        h (int): Image height in pixels.
+        padw (int): Padding width in pixels.
+        padh (int): Padding height in pixels.

     Returns:
-        y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
-            x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
+        y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is
+            the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
     """
     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
     y = empty_like(x)  # faster than clone/copy
-    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
-    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
-    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
-    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
+    xc, yc, xw, xh = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
+    half_w, half_h = xw / 2, xh / 2
+    y[..., 0] = w * (xc - half_w) + padw  # top left x
+    y[..., 1] = h * (yc - half_h) + padh  # top left y
+    y[..., 2] = w * (xc + half_w) + padw  # bottom right x
+    y[..., 3] = h * (yc + half_h) + padh  # bottom right y
     return y


-def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
-    """
-    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
+def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
+    """Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
     width and height are normalized to image dimensions.

     Args:
-        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
-        w (int): The width of the image.
-        h (int): The height of the image.
-        clip (bool): If True, the boxes will be clipped to the image boundaries.
-        eps (float): The minimum value of the box's width and height.
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
+        w (int): Image width in pixels.
+        h (int): Image height in pixels.
+        clip (bool): Whether to clip boxes to image boundaries.
+        eps (float): Minimum value for box width and height.

     Returns:
-        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
+        (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
     """
     if clip:
         x = clip_boxes(x, (h - eps, w - eps))
     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
     y = empty_like(x)  # faster than clone/copy
-    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
-    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
-    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
-    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
+    x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
+    y[..., 0] = ((x1 + x2) / 2) / w  # x center
+    y[..., 1] = ((y1 + y2) / 2) / h  # y center
+    y[..., 2] = (x2 - x1) / w  # width
+    y[..., 3] = (y2 - y1) / h  # height
     return y


 def xywh2ltwh(x):
-    """
-    Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
+    """Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.

     Args:
-        x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.

     Returns:
-        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
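Note: the two normalized-coordinate helpers above are exact inverses when no padding or clipping is involved; a minimal round-trip sketch:

import numpy as np
from ultralytics.utils.ops import xywhn2xyxy, xyxy2xywhn

xywhn = np.array([[0.5, 0.5, 0.25, 0.5]], dtype=np.float32)  # normalized (xc, yc, w, h)
xyxy = xywhn2xyxy(xywhn, w=640, h=480)  # -> [[240. 120. 400. 360.]] pixel corners
back = xyxy2xywhn(xyxy, w=640, h=480)   # -> recovers the normalized input
print(xyxy, back)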
@@ -526,14 +351,13 @@ def xywh2ltwh(x):


 def xyxy2ltwh(x):
-    """
-    Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
+    """Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.

     Args:
-        x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.

     Returns:
-        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 2] = x[..., 2] - x[..., 0]  # width
@@ -542,14 +366,13 @@ def xyxy2ltwh(x):


 def ltwh2xywh(x):
-    """
-    Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
+    """Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

     Args:
-        x (torch.Tensor): the input tensor
+        x (torch.Tensor): Input bounding box coordinates.

     Returns:
-        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
@@ -558,15 +381,14 @@ def ltwh2xywh(x):


 def xyxyxyxy2xywhr(x):
-    """
-    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
-    returned in radians from 0 to pi/2.
+    """Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.

     Args:
-        x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8).
+        x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.

     Returns:
-        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
+        (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5). Rotation
+            values are in radians from 0 to pi/2.
     """
     is_torch = isinstance(x, torch.Tensor)
     points = x.cpu().numpy() if is_torch else x
@@ -581,15 +403,14 @@ def xyxyxyxy2xywhr(x):


 def xywhr2xyxyxyxy(x):
-    """
-    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
-    be in radians from 0 to pi/2.
+    """Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.

     Args:
-        x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
+        x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5). Rotation
+            values should be in radians from 0 to pi/2.

     Returns:
-        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
+        (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
     """
     cos, sin, cat, stack = (
         (torch.cos, torch.sin, torch.cat, torch.stack)
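Note: a minimal round-trip sketch for the two OBB converters above (xyxyxyxy2xywhr uses cv2.minAreaRect internally, which may swap w/h and shift the angle by pi/2, so recovery is exact only up to that symmetry):

import numpy as np
from ultralytics.utils.ops import xywhr2xyxyxyxy, xyxyxyxy2xywhr

rbox = np.array([[100.0, 100.0, 60.0, 30.0, np.pi / 6]], dtype=np.float32)  # cx, cy, w, h, r
corners = xywhr2xyxyxyxy(rbox)                     # (1, 4, 2) corner points
recovered = xyxyxyxy2xywhr(corners.reshape(1, 8))  # ~= rbox up to w/h-angle symmetry
print(corners.shape, recovered)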
@@ -612,14 +433,13 @@ def xywhr2xyxyxyxy(x):


 def ltwh2xyxy(x):
-    """
-    Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
+    """Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

     Args:
-        x (np.ndarray | torch.Tensor): The input image.
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates.

     Returns:
-        (np.ndarray | torch.Tensor): The xyxy coordinates of the bounding boxes.
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 2] = x[..., 2] + x[..., 0]  # width
628
448
 
629
449
 
630
450
  def segments2boxes(segments):
631
- """
632
- Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
451
+ """Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
633
452
 
634
453
  Args:
635
- segments (list): List of segments, each segment is a list of points, each point is a list of x, y coordinates.
454
+ segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.
636
455
 
637
456
  Returns:
638
- (np.ndarray): The xywh coordinates of the bounding boxes.
457
+ (np.ndarray): Bounding box coordinates in xywh format.
639
458
  """
640
459
  boxes = []
641
460
  for s in segments:
@@ -644,16 +463,15 @@ def segments2boxes(segments):
     return xyxy2xywh(np.array(boxes))  # cls, xywh


-def resample_segments(segments, n=1000):
-    """
-    Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
+def resample_segments(segments, n: int = 1000):
+    """Resample segments to n points each using linear interpolation.

     Args:
-        segments (list): A list of (n,2) arrays, where n is the number of points in the segment.
-        n (int): Number of points to resample the segment to.
+        segments (list): List of (N, 2) arrays where N is the number of points in each segment.
+        n (int): Number of points to resample each segment to.

     Returns:
-        segments (list): The resampled segments.
+        (list): Resampled segments with n points each.
     """
     for i, s in enumerate(segments):
         if len(s) == n:
@@ -668,124 +486,122 @@ def resample_segments(segments, n=1000):
     return segments


-def crop_mask(masks, boxes):
-    """
-    Crop masks to bounding boxes.
+def crop_mask(masks: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
+    """Crop masks to bounding box regions.

     Args:
-        masks (torch.Tensor): [n, h, w] tensor of masks.
-        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form.
+        masks (torch.Tensor): Masks with shape (N, H, W).
+        boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.

     Returns:
         (torch.Tensor): Cropped masks.
     """
-    _, h, w = masks.shape
-    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
-    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
-    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)
-
-    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
+    if boxes.device != masks.device:
+        boxes = boxes.to(masks.device)
+    n, h, w = masks.shape
+    if n < 50 and not masks.is_cuda:  # faster for fewer masks (predict)
+        for i, (x1, y1, x2, y2) in enumerate(boxes.round().int()):
+            masks[i, :y1] = 0
+            masks[i, y2:] = 0
+            masks[i, :, :x1] = 0
+            masks[i, :, x2:] = 0
+        return masks
+    else:  # faster for more masks (val)
+        x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
+        r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
+        c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)
+        return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))


-def process_mask(protos, masks_in, bboxes, shape, upsample=False):
-    """
-    Apply masks to bounding boxes using the output of the mask head.
+def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
+    """Apply masks to bounding boxes using mask head output.

     Args:
-        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
-        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
-        bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
-        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
-        upsample (bool): A flag to indicate whether to upsample the mask to the original image size.
+        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
+        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
+        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
+        shape (tuple): Input image size as (height, width).
+        upsample (bool): Whether to upsample masks to original image size.

     Returns:
         (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
             are the height and width of the input image. The mask is applied to the bounding boxes.
     """
     c, mh, mw = protos.shape  # CHW
-    ih, iw = shape
     masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # CHW
-    width_ratio = mw / iw
-    height_ratio = mh / ih

-    downsampled_bboxes = bboxes.clone()
-    downsampled_bboxes[:, 0] *= width_ratio
-    downsampled_bboxes[:, 2] *= width_ratio
-    downsampled_bboxes[:, 3] *= height_ratio
-    downsampled_bboxes[:, 1] *= height_ratio
+    width_ratio = mw / shape[1]
+    height_ratio = mh / shape[0]
+    ratios = torch.tensor([[width_ratio, height_ratio, width_ratio, height_ratio]], device=bboxes.device)

-    masks = crop_mask(masks, downsampled_bboxes)  # CHW
+    masks = crop_mask(masks, boxes=bboxes * ratios)  # CHW
     if upsample:
-        masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
-    return masks.gt_(0.0)
+        masks = F.interpolate(masks[None], shape, mode="bilinear")[0]  # CHW
+    return masks.gt_(0.0).byte()


 def process_mask_native(protos, masks_in, bboxes, shape):
-    """
-    Apply masks to bounding boxes using the output of the mask head with native upsampling.
+    """Apply masks to bounding boxes using mask head output with native upsampling.

     Args:
-        protos (torch.Tensor): [mask_dim, mask_h, mask_w].
-        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
-        bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
-        shape (tuple): The size of the input image (h,w).
+        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
+        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
+        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
+        shape (tuple): Input image size as (height, width).

     Returns:
-        (torch.Tensor): The returned masks with dimensions [h, w, n].
+        (torch.Tensor): Binary mask tensor with shape (H, W, N).
     """
     c, mh, mw = protos.shape  # CHW
     masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
     masks = scale_masks(masks[None], shape)[0]  # CHW
     masks = crop_mask(masks, bboxes)  # CHW
-    return masks.gt_(0.0)
+    return masks.gt_(0.0).byte()


-def scale_masks(masks, shape, padding=True):
-    """
-    Rescale segment masks to shape.
+def scale_masks(masks, shape, padding: bool = True):
+    """Rescale segment masks to target shape.

     Args:
-        masks (torch.Tensor): (N, C, H, W).
-        shape (tuple): Height and width.
-        padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
-            rescaling.
+        masks (torch.Tensor): Masks with shape (N, C, H, W).
+        shape (tuple): Target height and width as (height, width).
+        padding (bool): Whether masks are based on YOLO-style augmented images with padding.

     Returns:
         (torch.Tensor): Rescaled masks.
     """
     mh, mw = masks.shape[2:]
     gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
-    pad = [mw - shape[1] * gain, mh - shape[0] * gain]  # wh padding
+    pad_w = mw - shape[1] * gain
+    pad_h = mh - shape[0] * gain
     if padding:
-        pad[0] /= 2
-        pad[1] /= 2
-    top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0)  # y, x
-    bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
-    masks = masks[..., top:bottom, left:right]
-
-    masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False)  # NCHW
-    return masks
+        pad_w /= 2
+        pad_h /= 2
+    top, left = (round(pad_h - 0.1), round(pad_w - 0.1)) if padding else (0, 0)
+    bottom = mh - round(pad_h + 0.1)
+    right = mw - round(pad_w + 0.1)
+    return F.interpolate(masks[..., top:bottom, left:right], shape, mode="bilinear")  # NCHW masks


-def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
-    """
-    Rescale segment coordinates (xy) from img1_shape to img0_shape.
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
+    """Rescale segment coordinates from img1_shape to img0_shape.

     Args:
-        img1_shape (tuple): The shape of the image that the coords are from.
-        coords (torch.Tensor): The coords to be scaled of shape n,2.
-        img0_shape (tuple): The shape of the image that the segmentation is being applied to.
-        ratio_pad (tuple): The ratio of the image size to the padded image size.
-        normalize (bool): If True, the coordinates will be normalized to the range [0, 1].
-        padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
-            rescaling.
+        img1_shape (tuple): Source image shape as HWC or HW (supports both).
+        coords (torch.Tensor): Coordinates to scale with shape (N, 2).
+        img0_shape (tuple): Image 0 shape as HWC or HW (supports both).
+        ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
+        normalize (bool): Whether to normalize coordinates to range [0, 1].
+        padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.

     Returns:
-        coords (torch.Tensor): The scaled coordinates.
+        (torch.Tensor): Scaled coordinates.
     """
+    img0_h, img0_w = img0_shape[:2]  # supports both HWC or HW shapes
     if ratio_pad is None:  # calculate from img0_shape
-        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
-        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+        img1_h, img1_w = img1_shape[:2]  # supports both HWC or HW shapes
+        gain = min(img1_h / img0_h, img1_w / img0_w)  # gain = old / new
+        pad = (img1_w - img0_w * gain) / 2, (img1_h - img0_h * gain) / 2  # wh padding
     else:
         gain = ratio_pad[0][0]
         pad = ratio_pad[1]
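Note: crop_mask() above now has two code paths: for fewer than 50 masks on CPU it zeroes the regions outside each (rounded) box in place, avoiding the full (N, H, W) comparison grid; larger or CUDA batches keep the broadcast formulation. For integer-aligned boxes the two paths should agree. A minimal sketch of the loop path:

import torch
from ultralytics.utils.ops import crop_mask

masks = torch.ones(2, 8, 8)
boxes = torch.tensor([[1.0, 1.0, 5.0, 5.0], [2.0, 0.0, 6.0, 4.0]])  # xyxy per mask
cropped = crop_mask(masks.clone(), boxes)  # n=2 < 50 on CPU -> in-place loop path
print(cropped[0].int())  # ones only inside rows/cols [1:5, 1:5]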
@@ -797,20 +613,19 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
     coords[..., 1] /= gain
     coords = clip_coords(coords, img0_shape)
     if normalize:
-        coords[..., 0] /= img0_shape[1]  # width
-        coords[..., 1] /= img0_shape[0]  # height
+        coords[..., 0] /= img0_w  # width
+        coords[..., 1] /= img0_h  # height
     return coords


 def regularize_rboxes(rboxes):
-    """
-    Regularize rotated boxes in range [0, pi/2].
+    """Regularize rotated bounding boxes to range [0, pi/2].

     Args:
-        rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format.
+        rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.

     Returns:
-        (torch.Tensor): The regularized boxes.
+        (torch.Tensor): Regularized rotated boxes.
     """
     x, y, w, h, t = rboxes.unbind(dim=-1)
     # Swap edge if t >= pi/2 while not being symmetrically opposite
@@ -821,21 +636,20 @@
     return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes


-def masks2segments(masks, strategy="all"):
-    """
-    Convert masks to segments.
+def masks2segments(masks, strategy: str = "all"):
+    """Convert masks to segments using contour detection.

     Args:
-        masks (torch.Tensor): The output of the model, which is a tensor of shape (batch_size, 160, 160).
-        strategy (str): 'all' or 'largest'.
+        masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).
+        strategy (str): Segmentation strategy, either 'all' or 'largest'.

     Returns:
-        (list): List of segment masks.
+        (list): List of segment masks as float32 arrays.
     """
     from ultralytics.data.converter import merge_multi_segment

     segments = []
-    for x in masks.int().cpu().numpy().astype("uint8"):
+    for x in masks.byte().cpu().numpy():
         c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
         if c:
             if strategy == "all":  # merge and concatenate all segments
@@ -853,21 +667,19 @@ def masks2segments(masks, strategy="all"):


 def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
-    """
-    Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
+    """Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.

     Args:
-        batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
+        batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.

     Returns:
-        (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
+        (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
     """
-    return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
+    return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).byte().cpu().numpy()


 def clean_str(s):
-    """
-    Cleans a string by replacing special characters with '_' character.
+    """Clean a string by replacing special characters with '_' character.

     Args:
         s (str): A string needing special characters replaced.
@@ -879,7 +691,7 @@ def clean_str(s):


 def empty_like(x):
-    """Creates empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
+    """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
     return (
         torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
     )
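Note: a minimal usage sketch for convert_torch2numpy_batch() above; .byte() in the new version is simply the terser spelling of .to(torch.uint8):

import torch
from ultralytics.utils.ops import convert_torch2numpy_batch

batch = torch.rand(4, 3, 64, 64)  # BCHW float32 in 0.0-1.0
out = convert_torch2numpy_batch(batch)
print(out.shape, out.dtype)  # (4, 64, 64, 3) uint8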