dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/utils/nms.py CHANGED
@@ -27,11 +27,10 @@ def non_max_suppression(
27
27
  end2end: bool = False,
28
28
  return_idxs: bool = False,
29
29
  ):
30
- """
31
- Perform non-maximum suppression (NMS) on prediction results.
30
+ """Perform non-maximum suppression (NMS) on prediction results.
32
31
 
33
- Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
34
- detection formats including standard boxes, rotated boxes, and masks.
32
+ Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple detection
33
+ formats including standard boxes, rotated boxes, and masks.
35
34
 
36
35
  Args:
37
36
  prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
@@ -52,8 +51,8 @@ def non_max_suppression(
52
51
  return_idxs (bool): Whether to return the indices of kept detections.
53
52
 
54
53
  Returns:
55
- output (list[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
56
- containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
54
+ output (list[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks) containing (x1,
55
+ y1, x2, y2, confidence, class, mask1, mask2, ...).
57
56
  keepi (list[torch.Tensor]): Indices of kept detections if return_idxs=True.
58
57
  """
59
58
  # Checks
@@ -168,8 +167,7 @@ def non_max_suppression(
168
167
 
169
168
 
170
169
  class TorchNMS:
171
- """
172
- Ultralytics custom NMS implementation optimized for YOLO.
170
+ """Ultralytics custom NMS implementation optimized for YOLO.
173
171
 
174
172
  This class provides static methods for performing non-maximum suppression (NMS) operations on bounding boxes,
175
173
  including both standard NMS and batched NMS for multi-class scenarios.
@@ -194,8 +192,7 @@ class TorchNMS:
194
192
  iou_func=box_iou,
195
193
  exit_early: bool = True,
196
194
  ) -> torch.Tensor:
197
- """
198
- Fast-NMS implementation from https://arxiv.org/pdf/1904.02689 using upper triangular matrix operations.
195
+ """Fast-NMS implementation from https://arxiv.org/pdf/1904.02689 using upper triangular matrix operations.
199
196
 
200
197
  Args:
201
198
  boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
@@ -231,15 +228,16 @@ class TorchNMS:
231
228
  upper_mask = row_idx < col_idx
232
229
  ious = ious * upper_mask
233
230
  # Zeroing these scores ensures the additional indices would not affect the final results
234
- scores[~((ious >= iou_threshold).sum(0) <= 0)] = 0
231
+ scores_ = scores[sorted_idx]
232
+ scores_[~((ious >= iou_threshold).sum(0) <= 0)] = 0
233
+ scores[sorted_idx] = scores_ # update original tensor for NMSModel
235
234
  # NOTE: return indices with fixed length to avoid TFLite reshape error
236
- pick = torch.topk(scores, scores.shape[0]).indices
235
+ pick = torch.topk(scores_, scores_.shape[0]).indices
237
236
  return sorted_idx[pick]
238
237
 
239
238
  @staticmethod
240
239
  def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
241
- """
242
- Optimized NMS with early termination that matches torchvision behavior exactly.
240
+ """Optimized NMS with early termination that matches torchvision behavior exactly.
243
241
 
244
242
  Args:
245
243
  boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
@@ -305,8 +303,7 @@ class TorchNMS:
305
303
  iou_threshold: float,
306
304
  use_fast_nms: bool = False,
307
305
  ) -> torch.Tensor:
308
- """
309
- Batched NMS for class-aware suppression.
306
+ """Batched NMS for class-aware suppression.
310
307
 
311
308
  Args:
312
309
  boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
ultralytics/utils/ops.py CHANGED
@@ -16,8 +16,7 @@ from ultralytics.utils import NOT_MACOS14
16
16
 
17
17
 
18
18
  class Profile(contextlib.ContextDecorator):
19
- """
20
- Ultralytics Profile class for timing code execution.
19
+ """Ultralytics Profile class for timing code execution.
21
20
 
22
21
  Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
23
22
  measurements with CUDA synchronization support for GPU operations.
@@ -40,8 +39,7 @@ class Profile(contextlib.ContextDecorator):
40
39
  """
41
40
 
42
41
  def __init__(self, t: float = 0.0, device: torch.device | None = None):
43
- """
44
- Initialize the Profile class.
42
+ """Initialize the Profile class.
45
43
 
46
44
  Args:
47
45
  t (float): Initial accumulated time in seconds.
@@ -56,7 +54,7 @@ class Profile(contextlib.ContextDecorator):
56
54
  self.start = self.time()
57
55
  return self
58
56
 
59
- def __exit__(self, type, value, traceback): # noqa
57
+ def __exit__(self, type, value, traceback):
60
58
  """Stop timing."""
61
59
  self.dt = self.time() - self.start # delta-time
62
60
  self.t += self.dt # accumulate dt
@@ -73,11 +71,10 @@ class Profile(contextlib.ContextDecorator):
73
71
 
74
72
 
75
73
  def segment2box(segment, width: int = 640, height: int = 640):
76
- """
77
- Convert segment coordinates to bounding box coordinates.
74
+ """Convert segment coordinates to bounding box coordinates.
78
75
 
79
- Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.
80
- Applies inside-image constraint and clips coordinates when necessary.
76
+ Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates. Applies
77
+ inside-image constraint and clips coordinates when necessary.
81
78
 
82
79
  Args:
83
80
  segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
@@ -103,11 +100,10 @@ def segment2box(segment, width: int = 640, height: int = 640):
103
100
 
104
101
 
105
102
  def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
106
- """
107
- Rescale bounding boxes from one image shape to another.
103
+ """Rescale bounding boxes from one image shape to another.
108
104
 
109
- Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.
110
- Supports both xyxy and xywh box formats.
105
+ Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes. Supports
106
+ both xyxy and xywh box formats.
111
107
 
112
108
  Args:
113
109
  img1_shape (tuple): Shape of the source image (height, width).
@@ -139,8 +135,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = T
139
135
 
140
136
 
141
137
  def make_divisible(x: int, divisor):
142
- """
143
- Return the nearest number that is divisible by the given divisor.
138
+ """Return the nearest number that is divisible by the given divisor.
144
139
 
145
140
  Args:
146
141
  x (int): The number to make divisible.
@@ -155,8 +150,7 @@ def make_divisible(x: int, divisor):
155
150
 
156
151
 
157
152
  def clip_boxes(boxes, shape):
158
- """
159
- Clip bounding boxes to image boundaries.
153
+ """Clip bounding boxes to image boundaries.
160
154
 
161
155
  Args:
162
156
  boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.
@@ -184,8 +178,7 @@ def clip_boxes(boxes, shape):
184
178
 
185
179
 
186
180
  def clip_coords(coords, shape):
187
- """
188
- Clip line coordinates to image boundaries.
181
+ """Clip line coordinates to image boundaries.
189
182
 
190
183
  Args:
191
184
  coords (torch.Tensor | np.ndarray): Line coordinates to clip.
@@ -208,55 +201,9 @@ def clip_coords(coords, shape):
208
201
  return coords
209
202
 
210
203
 
211
- def scale_image(masks, im0_shape, ratio_pad=None):
212
- """
213
- Rescale masks to original image size.
214
-
215
- Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
216
- that was applied during preprocessing.
217
-
218
- Args:
219
- masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
220
- im0_shape (tuple): Original image shape as HWC or HW (supports both).
221
- ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
222
-
223
- Returns:
224
- (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
225
- """
226
- # Rescale coordinates (xyxy) from im1_shape to im0_shape
227
- im0_h, im0_w = im0_shape[:2] # supports both HWC or HW shapes
228
- im1_h, im1_w, _ = masks.shape
229
- if im1_h == im0_h and im1_w == im0_w:
230
- return masks
231
-
232
- if ratio_pad is None: # calculate from im0_shape
233
- gain = min(im1_h / im0_h, im1_w / im0_w) # gain = old / new
234
- pad = (im1_w - im0_w * gain) / 2, (im1_h - im0_h * gain) / 2 # wh padding
235
- else:
236
- pad = ratio_pad[1]
237
-
238
- pad_w, pad_h = pad
239
- top = int(round(pad_h - 0.1))
240
- left = int(round(pad_w - 0.1))
241
- bottom = im1_h - int(round(pad_h + 0.1))
242
- right = im1_w - int(round(pad_w + 0.1))
243
-
244
- if len(masks.shape) < 2:
245
- raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
246
- masks = masks[top:bottom, left:right]
247
- # handle the cv2.resize 512 channels limitation: https://github.com/ultralytics/ultralytics/pull/21947
248
- masks = [cv2.resize(array, (im0_w, im0_h)) for array in np.array_split(masks, masks.shape[-1] // 512 + 1, axis=-1)]
249
- masks = np.concatenate(masks, axis=-1) if len(masks) > 1 else masks[0]
250
- if len(masks.shape) == 2:
251
- masks = masks[:, :, None]
252
-
253
- return masks
254
-
255
-
256
204
  def xyxy2xywh(x):
257
- """
258
- Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
259
- top-left corner and (x2, y2) is the bottom-right corner.
205
+ """Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is
206
+ the top-left corner and (x2, y2) is the bottom-right corner.
260
207
 
261
208
  Args:
262
209
  x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
@@ -275,9 +222,8 @@ def xyxy2xywh(x):
275
222
 
276
223
 
277
224
  def xywh2xyxy(x):
278
- """
279
- Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
280
- top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
225
+ """Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is
226
+ the top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
281
227
 
282
228
  Args:
283
229
  x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.
@@ -295,8 +241,7 @@ def xywh2xyxy(x):
295
241
 
296
242
 
297
243
  def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
298
- """
299
- Convert normalized bounding box coordinates to pixel coordinates.
244
+ """Convert normalized bounding box coordinates to pixel coordinates.
300
245
 
301
246
  Args:
302
247
  x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
@@ -306,8 +251,8 @@ def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
306
251
  padh (int): Padding height in pixels.
307
252
 
308
253
  Returns:
309
- y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
310
- x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
254
+ y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is
255
+ the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
311
256
  """
312
257
  assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
313
258
  y = empty_like(x) # faster than clone/copy
@@ -321,8 +266,7 @@ def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
321
266
 
322
267
 
323
268
  def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
324
- """
325
- Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
269
+ """Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
326
270
  width and height are normalized to image dimensions.
327
271
 
328
272
  Args:
@@ -348,14 +292,13 @@ def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0
348
292
 
349
293
 
350
294
  def xywh2ltwh(x):
351
- """
352
- Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.
295
+ """Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.
353
296
 
354
297
  Args:
355
298
  x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.
356
299
 
357
300
  Returns:
358
- (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
301
+ (np.ndarray | torch.Tensor): Bounding box coordinates in ltwh format.
359
302
  """
360
303
  y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
361
304
  y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
@@ -364,14 +307,13 @@ def xywh2ltwh(x):
364
307
 
365
308
 
366
309
  def xyxy2ltwh(x):
367
- """
368
- Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.
310
+ """Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.
369
311
 
370
312
  Args:
371
313
  x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.
372
314
 
373
315
  Returns:
374
- (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
316
+ (np.ndarray | torch.Tensor): Bounding box coordinates in ltwh format.
375
317
  """
376
318
  y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
377
319
  y[..., 2] = x[..., 2] - x[..., 0] # width
@@ -380,11 +322,10 @@ def xyxy2ltwh(x):
380
322
 
381
323
 
382
324
  def ltwh2xywh(x):
383
- """
384
- Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
325
+ """Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
385
326
 
386
327
  Args:
387
- x (torch.Tensor): Input bounding box coordinates.
328
+ x (np.ndarray | torch.Tensor): Input bounding box coordinates.
388
329
 
389
330
  Returns:
390
331
  (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
@@ -396,15 +337,14 @@ def ltwh2xywh(x):
396
337
 
397
338
 
398
339
  def xyxyxyxy2xywhr(x):
399
- """
400
- Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.
340
+ """Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.
401
341
 
402
342
  Args:
403
343
  x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.
404
344
 
405
345
  Returns:
406
- (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).
407
- Rotation values are in radians from 0 to pi/2.
346
+ (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5). Rotation
347
+ values are in radians from [-pi/4, 3pi/4).
408
348
  """
409
349
  is_torch = isinstance(x, torch.Tensor)
410
350
  points = x.cpu().numpy() if is_torch else x
@@ -414,17 +354,25 @@ def xyxyxyxy2xywhr(x):
414
354
  # NOTE: Use cv2.minAreaRect to get accurate xywhr,
415
355
  # especially some objects are cut off by augmentations in dataloader.
416
356
  (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
417
- rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
357
+ # convert angle to radian and normalize to [-pi/4, 3pi/4)
358
+ theta = angle / 180 * np.pi
359
+ if w < h:
360
+ w, h = h, w
361
+ theta += np.pi / 2
362
+ while theta >= 3 * np.pi / 4:
363
+ theta -= np.pi
364
+ while theta < -np.pi / 4:
365
+ theta += np.pi
366
+ rboxes.append([cx, cy, w, h, theta])
418
367
  return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
419
368
 
420
369
 
421
370
  def xywhr2xyxyxyxy(x):
422
- """
423
- Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.
371
+ """Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.
424
372
 
425
373
  Args:
426
- x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).
427
- Rotation values should be in radians from 0 to pi/2.
374
+ x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5). Rotation
375
+ values should be in radians from 0 to pi/2.
428
376
 
429
377
  Returns:
430
378
  (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
@@ -450,8 +398,7 @@ def xywhr2xyxyxyxy(x):
450
398
 
451
399
 
452
400
  def ltwh2xyxy(x):
453
- """
454
- Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
401
+ """Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
455
402
 
456
403
  Args:
457
404
  x (np.ndarray | torch.Tensor): Input bounding box coordinates.
@@ -460,14 +407,13 @@ def ltwh2xyxy(x):
460
407
  (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
461
408
  """
462
409
  y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
463
- y[..., 2] = x[..., 2] + x[..., 0] # width
464
- y[..., 3] = x[..., 3] + x[..., 1] # height
410
+ y[..., 2] = x[..., 2] + x[..., 0] # x2
411
+ y[..., 3] = x[..., 3] + x[..., 1] # y2
465
412
  return y
466
413
 
467
414
 
468
415
  def segments2boxes(segments):
469
- """
470
- Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
416
+ """Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
471
417
 
472
418
  Args:
473
419
  segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.
@@ -483,8 +429,7 @@ def segments2boxes(segments):
483
429
 
484
430
 
485
431
  def resample_segments(segments, n: int = 1000):
486
- """
487
- Resample segments to n points each using linear interpolation.
432
+ """Resample segments to n points each using linear interpolation.
488
433
 
489
434
  Args:
490
435
  segments (list): List of (N, 2) arrays where N is the number of points in each segment.
@@ -506,9 +451,8 @@ def resample_segments(segments, n: int = 1000):
506
451
  return segments
507
452
 
508
453
 
509
- def crop_mask(masks, boxes):
510
- """
511
- Crop masks to bounding box regions.
454
+ def crop_mask(masks: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
455
+ """Crop masks to bounding box regions.
512
456
 
513
457
  Args:
514
458
  masks (torch.Tensor): Masks with shape (N, H, W).
@@ -517,17 +461,25 @@ def crop_mask(masks, boxes):
517
461
  Returns:
518
462
  (torch.Tensor): Cropped masks.
519
463
  """
520
- _, h, w = masks.shape
521
- x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
522
- r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
523
- c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
524
-
525
- return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
464
+ if boxes.device != masks.device:
465
+ boxes = boxes.to(masks.device)
466
+ n, h, w = masks.shape
467
+ if n < 50 and not masks.is_cuda: # faster for fewer masks (predict)
468
+ for i, (x1, y1, x2, y2) in enumerate(boxes.round().int()):
469
+ masks[i, :y1] = 0
470
+ masks[i, y2:] = 0
471
+ masks[i, :, :x1] = 0
472
+ masks[i, :, x2:] = 0
473
+ return masks
474
+ else: # faster for more masks (val)
475
+ x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
476
+ r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
477
+ c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
478
+ return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
526
479
 
527
480
 
528
481
  def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
529
- """
530
- Apply masks to bounding boxes using mask head output.
482
+ """Apply masks to bounding boxes using mask head output.
531
483
 
532
484
  Args:
533
485
  protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
@@ -541,26 +493,20 @@ def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
541
493
  are the height and width of the input image. The mask is applied to the bounding boxes.
542
494
  """
543
495
  c, mh, mw = protos.shape # CHW
544
- ih, iw = shape
545
- masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW
546
- width_ratio = mw / iw
547
- height_ratio = mh / ih
548
-
549
- downsampled_bboxes = bboxes.clone()
550
- downsampled_bboxes[:, 0] *= width_ratio
551
- downsampled_bboxes[:, 2] *= width_ratio
552
- downsampled_bboxes[:, 3] *= height_ratio
553
- downsampled_bboxes[:, 1] *= height_ratio
554
-
555
- masks = crop_mask(masks, downsampled_bboxes) # CHW
496
+ masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # NHW
497
+
498
+ width_ratio = mw / shape[1]
499
+ height_ratio = mh / shape[0]
500
+ ratios = torch.tensor([[width_ratio, height_ratio, width_ratio, height_ratio]], device=bboxes.device)
501
+
502
+ masks = crop_mask(masks, boxes=bboxes * ratios) # NHW
556
503
  if upsample:
557
- masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
558
- return masks.gt_(0.0)
504
+ masks = F.interpolate(masks[None], shape, mode="bilinear")[0] # NHW
505
+ return masks.gt_(0.0).byte()
559
506
 
560
507
 
561
508
  def process_mask_native(protos, masks_in, bboxes, shape):
562
- """
563
- Apply masks to bounding boxes using mask head output with native upsampling.
509
+ """Apply masks to bounding boxes using mask head output with native upsampling.
564
510
 
565
511
  Args:
566
512
  protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
@@ -569,43 +515,53 @@ def process_mask_native(protos, masks_in, bboxes, shape):
569
515
  shape (tuple): Input image size as (height, width).
570
516
 
571
517
  Returns:
572
- (torch.Tensor): Binary mask tensor with shape (H, W, N).
518
+ (torch.Tensor): Binary mask tensor with shape (N, H, W).
573
519
  """
574
520
  c, mh, mw = protos.shape # CHW
575
521
  masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
576
- masks = scale_masks(masks[None], shape)[0] # CHW
577
- masks = crop_mask(masks, bboxes) # CHW
578
- return masks.gt_(0.0)
522
+ masks = scale_masks(masks[None], shape)[0] # NHW
523
+ masks = crop_mask(masks, bboxes) # NHW
524
+ return masks.gt_(0.0).byte()
579
525
 
580
526
 
581
- def scale_masks(masks, shape, padding: bool = True):
582
- """
583
- Rescale segment masks to target shape.
527
+ def scale_masks(
528
+ masks: torch.Tensor,
529
+ shape: tuple[int, int],
530
+ ratio_pad: tuple[tuple[int, int], tuple[int, int]] | None = None,
531
+ padding: bool = True,
532
+ ) -> torch.Tensor:
533
+ """Rescale segment masks to target shape.
584
534
 
585
535
  Args:
586
536
  masks (torch.Tensor): Masks with shape (N, C, H, W).
587
- shape (tuple): Target height and width as (height, width).
537
+ shape (tuple[int, int]): Target height and width as (height, width).
538
+ ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
588
539
  padding (bool): Whether masks are based on YOLO-style augmented images with padding.
589
540
 
590
541
  Returns:
591
542
  (torch.Tensor): Rescaled masks.
592
543
  """
593
- mh, mw = masks.shape[2:]
594
- gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
595
- pad_w = mw - shape[1] * gain
596
- pad_h = mh - shape[0] * gain
597
- if padding:
598
- pad_w /= 2
599
- pad_h /= 2
600
- top, left = (int(round(pad_h - 0.1)), int(round(pad_w - 0.1))) if padding else (0, 0)
601
- bottom = mh - int(round(pad_h + 0.1))
602
- right = mw - int(round(pad_w + 0.1))
603
- return F.interpolate(masks[..., top:bottom, left:right], shape, mode="bilinear", align_corners=False) # NCHW masks
544
+ im1_h, im1_w = masks.shape[2:]
545
+ im0_h, im0_w = shape[:2]
546
+ if im1_h == im0_h and im1_w == im0_w:
547
+ return masks
548
+
549
+ if ratio_pad is None: # calculate from im0_shape
550
+ gain = min(im1_h / im0_h, im1_w / im0_w) # gain = old / new
551
+ pad_w, pad_h = (im1_w - im0_w * gain), (im1_h - im0_h * gain) # wh padding
552
+ if padding:
553
+ pad_w /= 2
554
+ pad_h /= 2
555
+ else:
556
+ pad_w, pad_h = ratio_pad[1]
557
+ top, left = (round(pad_h - 0.1), round(pad_w - 0.1)) if padding else (0, 0)
558
+ bottom = im1_h - round(pad_h + 0.1)
559
+ right = im1_w - round(pad_w + 0.1)
560
+ return F.interpolate(masks[..., top:bottom, left:right].float(), shape, mode="bilinear") # NCHW masks
604
561
 
605
562
 
606
563
  def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
607
- """
608
- Rescale segment coordinates from img1_shape to img0_shape.
564
+ """Rescale segment coordinates from img1_shape to img0_shape.
609
565
 
610
566
  Args:
611
567
  img1_shape (tuple): Source image shape as HWC or HW (supports both).
@@ -640,8 +596,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool
640
596
 
641
597
 
642
598
  def regularize_rboxes(rboxes):
643
- """
644
- Regularize rotated bounding boxes to range [0, pi/2].
599
+ """Regularize rotated bounding boxes to range [0, pi/2].
645
600
 
646
601
  Args:
647
602
  rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.
@@ -658,12 +613,11 @@ def regularize_rboxes(rboxes):
658
613
  return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
659
614
 
660
615
 
661
- def masks2segments(masks, strategy: str = "all"):
662
- """
663
- Convert masks to segments using contour detection.
616
+ def masks2segments(masks: np.ndarray | torch.Tensor, strategy: str = "all") -> list[np.ndarray]:
617
+ """Convert masks to segments using contour detection.
664
618
 
665
619
  Args:
666
- masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).
620
+ masks (np.ndarray | torch.Tensor): Binary masks with shape (batch_size, 160, 160).
667
621
  strategy (str): Segmentation strategy, either 'all' or 'largest'.
668
622
 
669
623
  Returns:
@@ -671,8 +625,9 @@ def masks2segments(masks, strategy: str = "all"):
671
625
  """
672
626
  from ultralytics.data.converter import merge_multi_segment
673
627
 
628
+ masks = masks.astype("uint8") if isinstance(masks, np.ndarray) else masks.byte().cpu().numpy()
674
629
  segments = []
675
- for x in masks.int().cpu().numpy().astype("uint8"):
630
+ for x in np.ascontiguousarray(masks):
676
631
  c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
677
632
  if c:
678
633
  if strategy == "all": # merge and concatenate all segments
@@ -690,8 +645,7 @@ def masks2segments(masks, strategy: str = "all"):
690
645
 
691
646
 
692
647
  def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
693
- """
694
- Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.
648
+ """Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.
695
649
 
696
650
  Args:
697
651
  batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.
@@ -699,12 +653,11 @@ def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
699
653
  Returns:
700
654
  (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
701
655
  """
702
- return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
656
+ return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).byte().cpu().numpy()
703
657
 
704
658
 
705
659
  def clean_str(s):
706
- """
707
- Clean a string by replacing special characters with '_' character.
660
+ """Clean a string by replacing special characters with '_' character.
708
661
 
709
662
  Args:
710
663
  s (str): A string needing special characters replaced.
@@ -712,11 +665,9 @@ def clean_str(s):
712
665
  Returns:
713
666
  (str): A string with special characters replaced by an underscore _.
714
667
  """
715
- return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
668
+ return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨`><+]", repl="_", string=s)
716
669
 
717
670
 
718
671
  def empty_like(x):
719
672
  """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
720
- return (
721
- torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
722
- )
673
+ return torch.empty_like(x, dtype=x.dtype) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=x.dtype)