ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +12 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +16 -8
  96. ultralytics/solutions/object_cropper.py +12 -5
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +215 -85
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.142.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/utils/ops.py CHANGED
@@ -4,6 +4,7 @@ import contextlib
4
4
  import math
5
5
  import re
6
6
  import time
7
+ from typing import Optional
7
8
 
8
9
  import cv2
9
10
  import numpy as np
@@ -16,27 +17,35 @@ from ultralytics.utils.metrics import batch_probiou
16
17
 
17
18
  class Profile(contextlib.ContextDecorator):
18
19
  """
19
- YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.
20
+ Ultralytics Profile class for timing code execution.
21
+
22
+ Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
23
+ measurements with CUDA synchronization support for GPU operations.
20
24
 
21
25
  Attributes:
22
- t (float): Accumulated time.
26
+ t (float): Accumulated time in seconds.
23
27
  device (torch.device): Device used for model inference.
24
- cuda (bool): Whether CUDA is being used.
28
+ cuda (bool): Whether CUDA is being used for timing synchronization.
25
29
 
26
30
  Examples:
27
- >>> from ultralytics.utils.ops import Profile
31
+ Use as a context manager to time code execution
28
32
  >>> with Profile(device=device) as dt:
29
33
  ... pass # slow operation here
30
34
  >>> print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
35
+
36
+ Use as a decorator to time function execution
37
+ >>> @Profile()
38
+ ... def slow_function():
39
+ ... time.sleep(0.1)
31
40
  """
32
41
 
33
- def __init__(self, t=0.0, device: torch.device = None):
42
+ def __init__(self, t: float = 0.0, device: Optional[torch.device] = None):
34
43
  """
35
44
  Initialize the Profile class.
36
45
 
37
46
  Args:
38
- t (float): Initial time.
39
- device (torch.device): Device used for model inference.
47
+ t (float): Initial accumulated time in seconds.
48
+ device (torch.device, optional): Device used for model inference to enable CUDA synchronization.
40
49
  """
41
50
  self.t = t
42
51
  self.device = device
@@ -53,30 +62,33 @@ class Profile(contextlib.ContextDecorator):
53
62
  self.t += self.dt # accumulate dt
54
63
 
55
64
  def __str__(self):
56
- """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
65
+ """Return a human-readable string representing the accumulated elapsed time."""
57
66
  return f"Elapsed time is {self.t} s"
58
67
 
59
68
  def time(self):
60
- """Get current time."""
69
+ """Get current time with CUDA synchronization if applicable."""
61
70
  if self.cuda:
62
71
  torch.cuda.synchronize(self.device)
63
72
  return time.perf_counter()
64
73
 
65
74
 
66
- def segment2box(segment, width=640, height=640):
75
+ def segment2box(segment, width: int = 640, height: int = 640):
67
76
  """
68
- Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
77
+ Convert segment coordinates to bounding box coordinates.
78
+
79
+ Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.
80
+ Applies inside-image constraint and clips coordinates when necessary.
69
81
 
70
82
  Args:
71
- segment (torch.Tensor): The segment label.
72
- width (int): The width of the image.
73
- height (int): The height of the image.
83
+ segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
84
+ width (int): Width of the image in pixels.
85
+ height (int): Height of the image in pixels.
74
86
 
75
87
  Returns:
76
- (np.ndarray): The minimum and maximum x and y values of the segment.
88
+ (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].
77
89
  """
78
90
  x, y = segment.T # segment xy
79
- # any 3 out of 4 sides are outside the image, clip coordinates first, https://github.com/ultralytics/ultralytics/pull/18294
91
+ # Clip coordinates if 3 out of 4 sides are outside the image
80
92
  if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
81
93
  x = x.clip(0, width)
82
94
  y = y.clip(0, height)
@@ -90,22 +102,23 @@ def segment2box(segment, width=640, height=640):
90
102
  ) # xyxy
91
103
 
92
104
 
93
- def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
105
+ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
94
106
  """
95
- Rescale bounding boxes from img1_shape to img0_shape.
107
+ Rescale bounding boxes from one image shape to another.
108
+
109
+ Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.
110
+ Supports both xyxy and xywh box formats.
96
111
 
97
112
  Args:
98
- img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
99
- boxes (torch.Tensor): The bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2).
100
- img0_shape (tuple): The shape of the target image, in the format of (height, width).
101
- ratio_pad (tuple): A tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
102
- calculated based on the size difference between the two images.
103
- padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
104
- rescaling.
105
- xywh (bool): The box format is xywh or not.
113
+ img1_shape (tuple): Shape of the source image (height, width).
114
+ boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).
115
+ img0_shape (tuple): Shape of the target image (height, width).
116
+ ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.
117
+ padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
118
+ xywh (bool): Whether box format is xywh (True) or xyxy (False).
106
119
 
107
120
  Returns:
108
- (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2).
121
+ (torch.Tensor): Rescaled bounding boxes in the same format as input.
109
122
  """
110
123
  if ratio_pad is None: # calculate from img0_shape
111
124
  gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
@@ -127,9 +140,9 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xyw
127
140
  return clip_boxes(boxes, img0_shape)
128
141
 
129
142
 
130
- def make_divisible(x, divisor):
143
+ def make_divisible(x: int, divisor):
131
144
  """
132
- Returns the nearest number that is divisible by the given divisor.
145
+ Return the nearest number that is divisible by the given divisor.
133
146
 
134
147
  Args:
135
148
  x (int): The number to make divisible.
@@ -143,16 +156,15 @@ def make_divisible(x, divisor):
143
156
  return math.ceil(x / divisor) * divisor
144
157
 
145
158
 
146
- def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
159
+ def nms_rotated(boxes, scores, threshold: float = 0.45, use_triu: bool = True):
147
160
  """
148
- NMS for oriented bounding boxes using probiou and fast-nms.
161
+ Perform NMS on oriented bounding boxes using probiou and fast-nms.
149
162
 
150
163
  Args:
151
- boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
152
- scores (torch.Tensor): Confidence scores, shape (N,).
153
- threshold (float): IoU threshold.
154
- use_triu (bool): Whether to use `torch.triu` operator. It'd be useful for disable it
155
- when exporting obb models to some formats that do not support `torch.triu`.
164
+ boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.
165
+ scores (torch.Tensor): Confidence scores with shape (N,).
166
+ threshold (float): IoU threshold for NMS.
167
+ use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.
156
168
 
157
169
  Returns:
158
170
  (torch.Tensor): Indices of boxes to keep after NMS.
@@ -162,7 +174,6 @@ def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
162
174
  ious = batch_probiou(boxes, boxes)
163
175
  if use_triu:
164
176
  ious = ious.triu_(diagonal=1)
165
- # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
166
177
  # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
167
178
  pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
168
179
  else:
@@ -180,54 +191,51 @@ def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
180
191
 
181
192
  def non_max_suppression(
182
193
  prediction,
183
- conf_thres=0.25,
184
- iou_thres=0.45,
194
+ conf_thres: float = 0.25,
195
+ iou_thres: float = 0.45,
185
196
  classes=None,
186
- agnostic=False,
187
- multi_label=False,
197
+ agnostic: bool = False,
198
+ multi_label: bool = False,
188
199
  labels=(),
189
- max_det=300,
190
- nc=0, # number of classes (optional)
191
- max_time_img=0.05,
192
- max_nms=30000,
193
- max_wh=7680,
194
- in_place=True,
195
- rotated=False,
196
- end2end=False,
197
- return_idxs=False,
200
+ max_det: int = 300,
201
+ nc: int = 0, # number of classes (optional)
202
+ max_time_img: float = 0.05,
203
+ max_nms: int = 30000,
204
+ max_wh: int = 7680,
205
+ in_place: bool = True,
206
+ rotated: bool = False,
207
+ end2end: bool = False,
208
+ return_idxs: bool = False,
198
209
  ):
199
210
  """
200
- Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
211
+ Perform non-maximum suppression (NMS) on prediction results.
212
+
213
+ Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
214
+ detection formats including standard boxes, rotated boxes, and masks.
201
215
 
202
216
  Args:
203
- prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
204
- containing the predicted boxes, classes, and masks. The tensor should be in the format
205
- output by a model, such as YOLO.
206
- conf_thres (float): The confidence threshold below which boxes will be filtered out.
207
- Valid values are between 0.0 and 1.0.
208
- iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
209
- Valid values are between 0.0 and 1.0.
210
- classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
211
- agnostic (bool): If True, the model is agnostic to the number of classes, and all
212
- classes will be considered as one.
213
- multi_label (bool): If True, each box may have multiple labels.
214
- labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
215
- list contains the apriori labels for a given image. The list should be in the format
216
- output by a dataloader, with each label being a tuple of (class_index, x, y, w, h).
217
- max_det (int): The maximum number of boxes to keep after NMS.
218
- nc (int): The number of classes output by the model. Any indices after this will be considered masks.
219
- max_time_img (float): The maximum time (seconds) for processing one image.
220
- max_nms (int): The maximum number of boxes into torchvision.ops.nms().
221
- max_wh (int): The maximum box width and height in pixels.
222
- in_place (bool): If True, the input prediction tensor will be modified in place.
223
- rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
224
- end2end (bool): If the model doesn't require NMS.
225
- return_idxs (bool): Return the indices of the detections that were kept.
217
+ prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
218
+ containing boxes, classes, and optional masks.
219
+ conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
220
+ iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
221
+ classes (List[int], optional): List of class indices to consider. If None, all classes are considered.
222
+ agnostic (bool): Whether to perform class-agnostic NMS.
223
+ multi_label (bool): Whether each box can have multiple labels.
224
+ labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.
225
+ max_det (int): Maximum number of detections to keep per image.
226
+ nc (int): Number of classes. Indices after this are considered masks.
227
+ max_time_img (float): Maximum time in seconds for processing one image.
228
+ max_nms (int): Maximum number of boxes for torchvision.ops.nms().
229
+ max_wh (int): Maximum box width and height in pixels.
230
+ in_place (bool): Whether to modify the input prediction tensor in place.
231
+ rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
232
+ end2end (bool): Whether the model is end-to-end and doesn't require NMS.
233
+ return_idxs (bool): Whether to return the indices of kept detections.
226
234
 
227
235
  Returns:
228
- (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
229
- shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
230
- (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
236
+ output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
237
+ containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
238
+ keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True.
231
239
  """
232
240
  import torchvision # scope for faster 'import ultralytics'
233
241
 
@@ -322,18 +330,6 @@ def non_max_suppression(
322
330
  i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
323
331
  i = i[:max_det] # limit detections
324
332
 
325
- # # Experimental
326
- # merge = False # use merge-NMS
327
- # if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
328
- # # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
329
- # from .metrics import box_iou
330
- # iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix
331
- # weights = iou * scores[None] # box weights
332
- # x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
333
- # redundant = True # require redundant detections
334
- # if redundant:
335
- # i = i[iou.sum(1) > 1] # require redundancy
336
-
337
333
  output[xi], keepi[xi] = x[i], xk[i].reshape(-1)
338
334
  if (time.time() - t) > time_limit:
339
335
  LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
@@ -344,14 +340,14 @@ def non_max_suppression(
344
340
 
345
341
  def clip_boxes(boxes, shape):
346
342
  """
347
- Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
343
+ Clip bounding boxes to image boundaries.
348
344
 
349
345
  Args:
350
- boxes (torch.Tensor | numpy.ndarray): The bounding boxes to clip.
351
- shape (tuple): The shape of the image.
346
+ boxes (torch.Tensor | numpy.ndarray): Bounding boxes to clip.
347
+ shape (tuple): Image shape as (height, width).
352
348
 
353
349
  Returns:
354
- (torch.Tensor | numpy.ndarray): The clipped boxes.
350
+ (torch.Tensor | numpy.ndarray): Clipped bounding boxes.
355
351
  """
356
352
  if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
357
353
  boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
@@ -366,11 +362,11 @@ def clip_boxes(boxes, shape):
366
362
 
367
363
  def clip_coords(coords, shape):
368
364
  """
369
- Clip line coordinates to the image boundaries.
365
+ Clip line coordinates to image boundaries.
370
366
 
371
367
  Args:
372
- coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
373
- shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
368
+ coords (torch.Tensor | numpy.ndarray): Line coordinates to clip.
369
+ shape (tuple): Image shape as (height, width).
374
370
 
375
371
  Returns:
376
372
  (torch.Tensor | numpy.ndarray): Clipped coordinates.
@@ -386,15 +382,18 @@ def clip_coords(coords, shape):
386
382
 
387
383
  def scale_image(masks, im0_shape, ratio_pad=None):
388
384
  """
389
- Takes a mask, and resizes it to the original image size.
385
+ Rescale masks to original image size.
386
+
387
+ Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
388
+ that was applied during preprocessing.
390
389
 
391
390
  Args:
392
- masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
393
- im0_shape (tuple): The original image shape.
394
- ratio_pad (tuple): The ratio of the padding to the original image.
391
+ masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
392
+ im0_shape (tuple): Original image shape as (height, width).
393
+ ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
395
394
 
396
395
  Returns:
397
- masks (np.ndarray): The masks that are being returned with shape [h, w, num].
396
+ (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
398
397
  """
399
398
  # Rescale coordinates (xyxy) from im1_shape to im0_shape
400
399
  im1_shape = masks.shape
@@ -404,7 +403,6 @@ def scale_image(masks, im0_shape, ratio_pad=None):
404
403
  gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
405
404
  pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
406
405
  else:
407
- # gain = ratio_pad[0][0]
408
406
  pad = ratio_pad[1]
409
407
  top, left = int(pad[1]), int(pad[0]) # y, x
410
408
  bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
@@ -425,10 +423,10 @@ def xyxy2xywh(x):
425
423
  top-left corner and (x2, y2) is the bottom-right corner.
426
424
 
427
425
  Args:
428
- x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
426
+ x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
429
427
 
430
428
  Returns:
431
- y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
429
+ (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.
432
430
  """
433
431
  assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
434
432
  y = empty_like(x) # faster than clone/copy
@@ -445,10 +443,10 @@ def xywh2xyxy(x):
445
443
  top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
446
444
 
447
445
  Args:
448
- x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
446
+ x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.
449
447
 
450
448
  Returns:
451
- y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
449
+ (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
452
450
  """
453
451
  assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
454
452
  y = empty_like(x) # faster than clone/copy
@@ -459,16 +457,16 @@ def xywh2xyxy(x):
459
457
  return y
460
458
 
461
459
 
462
- def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
460
+ def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
463
461
  """
464
462
  Convert normalized bounding box coordinates to pixel coordinates.
465
463
 
466
464
  Args:
467
- x (np.ndarray | torch.Tensor): The bounding box coordinates.
468
- w (int): Width of the image.
469
- h (int): Height of the image.
470
- padw (int): Padding width.
471
- padh (int): Padding height.
465
+ x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
466
+ w (int): Image width in pixels.
467
+ h (int): Image height in pixels.
468
+ padw (int): Padding width in pixels.
469
+ padh (int): Padding height in pixels.
472
470
 
473
471
  Returns:
474
472
  y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
@@ -483,20 +481,20 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
483
481
  return y
484
482
 
485
483
 
486
- def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
484
+ def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
487
485
  """
488
486
  Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
489
487
  width and height are normalized to image dimensions.
490
488
 
491
489
  Args:
492
- x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
493
- w (int): The width of the image.
494
- h (int): The height of the image.
495
- clip (bool): If True, the boxes will be clipped to the image boundaries.
496
- eps (float): The minimum value of the box's width and height.
490
+ x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
491
+ w (int): Image width in pixels.
492
+ h (int): Image height in pixels.
493
+ clip (bool): Whether to clip boxes to image boundaries.
494
+ eps (float): Minimum value for box width and height.
497
495
 
498
496
  Returns:
499
- y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
497
+ (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
500
498
  """
501
499
  if clip:
502
500
  x = clip_boxes(x, (h - eps, w - eps))
@@ -511,13 +509,13 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
511
509
 
512
510
  def xywh2ltwh(x):
513
511
  """
514
- Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
512
+ Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.
515
513
 
516
514
  Args:
517
- x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
515
+ x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.
518
516
 
519
517
  Returns:
520
- y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
518
+ (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
521
519
  """
522
520
  y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
523
521
  y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
@@ -527,13 +525,13 @@ def xywh2ltwh(x):
527
525
 
528
526
  def xyxy2ltwh(x):
529
527
  """
530
- Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.
528
+ Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.
531
529
 
532
530
  Args:
533
- x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
531
+ x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.
534
532
 
535
533
  Returns:
536
- y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
534
+ (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
537
535
  """
538
536
  y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
539
537
  y[..., 2] = x[..., 2] - x[..., 0] # width
@@ -543,13 +541,13 @@ def xyxy2ltwh(x):
543
541
 
544
542
  def ltwh2xywh(x):
545
543
  """
546
- Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
544
+ Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
547
545
 
548
546
  Args:
549
- x (torch.Tensor): the input tensor
547
+ x (torch.Tensor): Input bounding box coordinates.
550
548
 
551
549
  Returns:
552
- y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
550
+ (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
553
551
  """
554
552
  y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
555
553
  y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
@@ -559,14 +557,14 @@ def ltwh2xywh(x):
559
557
 
560
558
  def xyxyxyxy2xywhr(x):
561
559
  """
562
- Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
563
- returned in radians from 0 to pi/2.
560
+ Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.
564
561
 
565
562
  Args:
566
- x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8).
563
+ x (numpy.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.
567
564
 
568
565
  Returns:
569
- (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
566
+ (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).
567
+ Rotation values are in radians from 0 to pi/2.
570
568
  """
571
569
  is_torch = isinstance(x, torch.Tensor)
572
570
  points = x.cpu().numpy() if is_torch else x
@@ -582,14 +580,14 @@ def xyxyxyxy2xywhr(x):
582
580
 
583
581
  def xywhr2xyxyxyxy(x):
584
582
  """
585
- Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
586
- be in radians from 0 to pi/2.
583
+ Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.
587
584
 
588
585
  Args:
589
- x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).
586
+ x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).
587
+ Rotation values should be in radians from 0 to pi/2.
590
588
 
591
589
  Returns:
592
- (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
590
+ (numpy.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
593
591
  """
594
592
  cos, sin, cat, stack = (
595
593
  (torch.cos, torch.sin, torch.cat, torch.stack)
@@ -616,10 +614,10 @@ def ltwh2xyxy(x):
616
614
  Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
617
615
 
618
616
  Args:
619
- x (np.ndarray | torch.Tensor): The input image.
617
+ x (np.ndarray | torch.Tensor): Input bounding box coordinates.
620
618
 
621
619
  Returns:
622
- (np.ndarray | torch.Tensor): The xyxy coordinates of the bounding boxes.
620
+ (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
623
621
  """
624
622
  y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
625
623
  y[..., 2] = x[..., 2] + x[..., 0] # width
@@ -632,10 +630,10 @@ def segments2boxes(segments):
632
630
  Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
633
631
 
634
632
  Args:
635
- segments (list): List of segments, each segment is a list of points, each point is a list of x, y coordinates.
633
+ segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.
636
634
 
637
635
  Returns:
638
- (np.ndarray): The xywh coordinates of the bounding boxes.
636
+ (np.ndarray): Bounding box coordinates in xywh format.
639
637
  """
640
638
  boxes = []
641
639
  for s in segments:
@@ -644,16 +642,16 @@ def segments2boxes(segments):
644
642
  return xyxy2xywh(np.array(boxes)) # cls, xywh
645
643
 
646
644
 
647
- def resample_segments(segments, n=1000):
645
+ def resample_segments(segments, n: int = 1000):
648
646
  """
649
- Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
647
+ Resample segments to n points each using linear interpolation.
650
648
 
651
649
  Args:
652
- segments (list): A list of (n,2) arrays, where n is the number of points in the segment.
653
- n (int): Number of points to resample the segment to.
650
+ segments (list): List of (N, 2) arrays where N is the number of points in each segment.
651
+ n (int): Number of points to resample each segment to.
654
652
 
655
653
  Returns:
656
- segments (list): The resampled segments.
654
+ (list): Resampled segments with n points each.
657
655
  """
658
656
  for i, s in enumerate(segments):
659
657
  if len(s) == n:
@@ -670,11 +668,11 @@ def resample_segments(segments, n=1000):
670
668
 
671
669
  def crop_mask(masks, boxes):
672
670
  """
673
- Crop masks to bounding boxes.
671
+ Crop masks to bounding box regions.
674
672
 
675
673
  Args:
676
- masks (torch.Tensor): [n, h, w] tensor of masks.
677
- boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form.
674
+ masks (torch.Tensor): Masks with shape (N, H, W).
675
+ boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.
678
676
 
679
677
  Returns:
680
678
  (torch.Tensor): Cropped masks.
@@ -687,16 +685,16 @@ def crop_mask(masks, boxes):
687
685
  return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
688
686
 
689
687
 
690
- def process_mask(protos, masks_in, bboxes, shape, upsample=False):
688
+ def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
691
689
  """
692
- Apply masks to bounding boxes using the output of the mask head.
690
+ Apply masks to bounding boxes using mask head output.
693
691
 
694
692
  Args:
695
- protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
696
- masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
697
- bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
698
- shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
699
- upsample (bool): A flag to indicate whether to upsample the mask to the original image size.
693
+ protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
694
+ masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
695
+ bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
696
+ shape (tuple): Input image size as (height, width).
697
+ upsample (bool): Whether to upsample masks to original image size.
700
698
 
701
699
  Returns:
702
700
  (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
@@ -722,16 +720,16 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
722
720
 
723
721
  def process_mask_native(protos, masks_in, bboxes, shape):
724
722
  """
725
- Apply masks to bounding boxes using the output of the mask head with native upsampling.
723
+ Apply masks to bounding boxes using mask head output with native upsampling.
726
724
 
727
725
  Args:
728
- protos (torch.Tensor): [mask_dim, mask_h, mask_w].
729
- masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
730
- bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
731
- shape (tuple): The size of the input image (h,w).
726
+ protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
727
+ masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
728
+ bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
729
+ shape (tuple): Input image size as (height, width).
732
730
 
733
731
  Returns:
734
- (torch.Tensor): The returned masks with dimensions [h, w, n].
732
+ (torch.Tensor): Binary mask tensor with shape (H, W, N).
735
733
  """
736
734
  c, mh, mw = protos.shape # CHW
737
735
  masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
@@ -740,15 +738,14 @@ def process_mask_native(protos, masks_in, bboxes, shape):
740
738
  return masks.gt_(0.0)
741
739
 
742
740
 
743
- def scale_masks(masks, shape, padding=True):
741
+ def scale_masks(masks, shape, padding: bool = True):
744
742
  """
745
- Rescale segment masks to shape.
743
+ Rescale segment masks to target shape.
746
744
 
747
745
  Args:
748
- masks (torch.Tensor): (N, C, H, W).
749
- shape (tuple): Height and width.
750
- padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
751
- rescaling.
746
+ masks (torch.Tensor): Masks with shape (N, C, H, W).
747
+ shape (tuple): Target height and width as (height, width).
748
+ padding (bool): Whether masks are based on YOLO-style augmented images with padding.
752
749
 
753
750
  Returns:
754
751
  (torch.Tensor): Rescaled masks.
@@ -767,21 +764,20 @@ def scale_masks(masks, shape, padding=True):
767
764
  return masks
768
765
 
769
766
 
770
- def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
767
+ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
771
768
  """
772
- Rescale segment coordinates (xy) from img1_shape to img0_shape.
769
+ Rescale segment coordinates from img1_shape to img0_shape.
773
770
 
774
771
  Args:
775
- img1_shape (tuple): The shape of the image that the coords are from.
776
- coords (torch.Tensor): The coords to be scaled of shape n,2.
777
- img0_shape (tuple): The shape of the image that the segmentation is being applied to.
778
- ratio_pad (tuple): The ratio of the image size to the padded image size.
779
- normalize (bool): If True, the coordinates will be normalized to the range [0, 1].
780
- padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
781
- rescaling.
772
+ img1_shape (tuple): Shape of the source image.
773
+ coords (torch.Tensor): Coordinates to scale with shape (N, 2).
774
+ img0_shape (tuple): Shape of the target image.
775
+ ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
776
+ normalize (bool): Whether to normalize coordinates to range [0, 1].
777
+ padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.
782
778
 
783
779
  Returns:
784
- coords (torch.Tensor): The scaled coordinates.
780
+ (torch.Tensor): Scaled coordinates.
785
781
  """
786
782
  if ratio_pad is None: # calculate from img0_shape
787
783
  gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
@@ -804,13 +800,13 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
804
800
 
805
801
  def regularize_rboxes(rboxes):
806
802
  """
807
- Regularize rotated boxes in range [0, pi/2].
803
+ Regularize rotated bounding boxes to range [0, pi/2].
808
804
 
809
805
  Args:
810
- rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format.
806
+ rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.
811
807
 
812
808
  Returns:
813
- (torch.Tensor): The regularized boxes.
809
+ (torch.Tensor): Regularized rotated boxes.
814
810
  """
815
811
  x, y, w, h, t = rboxes.unbind(dim=-1)
816
812
  # Swap edge if t >= pi/2 while not being symmetrically opposite
@@ -821,16 +817,16 @@ def regularize_rboxes(rboxes):
821
817
  return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
822
818
 
823
819
 
824
- def masks2segments(masks, strategy="all"):
820
+ def masks2segments(masks, strategy: str = "all"):
825
821
  """
826
- Convert masks to segments.
822
+ Convert masks to segments using contour detection.
827
823
 
828
824
  Args:
829
- masks (torch.Tensor): The output of the model, which is a tensor of shape (batch_size, 160, 160).
830
- strategy (str): 'all' or 'largest'.
825
+ masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).
826
+ strategy (str): Segmentation strategy, either 'all' or 'largest'.
831
827
 
832
828
  Returns:
833
- (list): List of segment masks.
829
+ (list): List of segment masks as float32 arrays.
834
830
  """
835
831
  from ultralytics.data.converter import merge_multi_segment
836
832
 
@@ -854,20 +850,20 @@ def masks2segments(masks, strategy="all"):
854
850
 
855
851
  def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
856
852
  """
857
- Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.
853
+ Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.
858
854
 
859
855
  Args:
860
- batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.
856
+ batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.
861
857
 
862
858
  Returns:
863
- (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
859
+ (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
864
860
  """
865
861
  return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
866
862
 
867
863
 
868
864
  def clean_str(s):
869
865
  """
870
- Cleans a string by replacing special characters with '_' character.
866
+ Clean a string by replacing special characters with '_' character.
871
867
 
872
868
  Args:
873
869
  s (str): A string needing special characters replaced.
@@ -879,7 +875,7 @@ def clean_str(s):
879
875
 
880
876
 
881
877
  def empty_like(x):
882
- """Creates empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
878
+ """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
883
879
  return (
884
880
  torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
885
881
  )