dgenerate-ultralytics-headless 8.3.143-py3-none-any.whl → 8.3.145-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
- dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +52 -51
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +191 -161
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +4 -6
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +2 -2
- ultralytics/solutions/instance_segmentation.py +7 -4
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -11
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +189 -79
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +45 -29
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
ultralytics/utils/ops.py
CHANGED
@@ -4,6 +4,7 @@ import contextlib
 import math
 import re
 import time
+from typing import Optional
 
 import cv2
 import numpy as np
@@ -16,27 +17,35 @@ from ultralytics.utils.metrics import batch_probiou
 
 class Profile(contextlib.ContextDecorator):
     """
-
+    Ultralytics Profile class for timing code execution.
+
+    Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
+    measurements with CUDA synchronization support for GPU operations.
 
     Attributes:
-        t (float): Accumulated time.
+        t (float): Accumulated time in seconds.
         device (torch.device): Device used for model inference.
-        cuda (bool): Whether CUDA is being used.
+        cuda (bool): Whether CUDA is being used for timing synchronization.
 
     Examples:
-
+        Use as a context manager to time code execution
        >>> with Profile(device=device) as dt:
        ...     pass  # slow operation here
        >>> print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
+
+        Use as a decorator to time function execution
+        >>> @Profile()
+        ... def slow_function():
+        ...     time.sleep(0.1)
     """
 
-    def __init__(self, t=0.0, device: torch.device = None):
+    def __init__(self, t: float = 0.0, device: Optional[torch.device] = None):
         """
         Initialize the Profile class.
 
         Args:
-            t (float): Initial time.
-            device (torch.device): Device used for model inference.
+            t (float): Initial accumulated time in seconds.
+            device (torch.device, optional): Device used for model inference to enable CUDA synchronization.
         """
         self.t = t
         self.device = device
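Example (illustrative, not part of the package diff): a minimal sketch of the two `Profile` usage modes documented above, assuming only the public `ultralytics.utils.ops` API:

```python
import time

import torch
from ultralytics.utils.ops import Profile

# Context-manager form: accumulated seconds are stored on the instance (dt.t)
with Profile(device=torch.device("cpu")) as dt:
    time.sleep(0.01)  # slow operation here
print(dt)  # e.g. "Elapsed time is 0.01... s"

# Decorator form: every call to slow_function() adds to the same Profile's total
@Profile()
def slow_function():
    time.sleep(0.1)

slow_function()
```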
@@ -53,30 +62,33 @@ class Profile(contextlib.ContextDecorator):
         self.t += self.dt  # accumulate dt
 
     def __str__(self):
-        """
+        """Return a human-readable string representing the accumulated elapsed time."""
         return f"Elapsed time is {self.t} s"
 
     def time(self):
-        """Get current time."""
+        """Get current time with CUDA synchronization if applicable."""
         if self.cuda:
             torch.cuda.synchronize(self.device)
         return time.perf_counter()
 
 
-def segment2box(segment, width=640, height=640):
+def segment2box(segment, width: int = 640, height: int = 640):
     """
-    Convert
+    Convert segment coordinates to bounding box coordinates.
+
+    Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.
+    Applies inside-image constraint and clips coordinates when necessary.
 
     Args:
-        segment (torch.Tensor):
-        width (int):
-        height (int):
+        segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
+        width (int): Width of the image in pixels.
+        height (int): Height of the image in pixels.
 
     Returns:
-        (np.ndarray):
+        (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].
     """
     x, y = segment.T  # segment xy
-    #
+    # Clip coordinates if 3 out of 4 sides are outside the image
     if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
         x = x.clip(0, width)
         y = y.clip(0, height)
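Example (illustrative): the (N, 2) → xyxy behavior described in the new `segment2box` docstring, with made-up coordinates:

```python
import numpy as np
from ultralytics.utils.ops import segment2box

# A 3-point segment inside a 640x640 image
segment = np.array([[10.0, 20.0], [100.0, 200.0], [50.0, 5.0]], dtype=np.float32)
box = segment2box(segment, width=640, height=640)
print(box)  # xyxy: [ 10.   5. 100. 200.]
```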
@@ -90,22 +102,23 @@ def segment2box(segment, width=640, height=640):
     )  # xyxy
 
 
-def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
+def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
     """
-    Rescale bounding boxes from
+    Rescale bounding boxes from one image shape to another.
+
+    Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.
+    Supports both xyxy and xywh box formats.
 
     Args:
-        img1_shape (tuple):
-        boxes (torch.Tensor):
-        img0_shape (tuple):
-        ratio_pad (tuple):
-
-
-            rescaling.
-        xywh (bool): The box format is xywh or not.
+        img1_shape (tuple): Shape of the source image (height, width).
+        boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).
+        img0_shape (tuple): Shape of the target image (height, width).
+        ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.
+        padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
+        xywh (bool): Whether box format is xywh (True) or xyxy (False).
 
     Returns:
-        (torch.Tensor):
+        (torch.Tensor): Rescaled bounding boxes in the same format as input.
     """
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
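Example (illustrative): mapping boxes predicted on a 640×640 letterboxed input back to a 480×640 source frame, per the rescaling described above:

```python
import torch
from ultralytics.utils.ops import scale_boxes

img1_shape = (640, 640)  # letterboxed inference shape
img0_shape = (480, 640)  # original image shape
boxes = torch.tensor([[100.0, 120.0, 200.0, 260.0]])  # xyxy on the 640x640 input

# ratio_pad=None: gain and padding are derived from the two shapes
rescaled = scale_boxes(img1_shape, boxes.clone(), img0_shape, padding=True)
print(rescaled)  # xyxy in the 480x640 frame, clipped to image bounds
```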
@@ -127,9 +140,9 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
     return clip_boxes(boxes, img0_shape)
 
 
-def make_divisible(x, divisor):
+def make_divisible(x: int, divisor):
     """
-
+    Return the nearest number that is divisible by the given divisor.
 
     Args:
         x (int): The number to make divisible.
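Example (illustrative): `make_divisible` rounds up to the nearest multiple:

```python
from ultralytics.utils.ops import make_divisible

print(make_divisible(97, 8))  # 104, since ceil(97 / 8) * 8 = 104
print(make_divisible(64, 8))  # 64, already divisible
```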
@@ -143,16 +156,15 @@ def make_divisible(x, divisor):
     return math.ceil(x / divisor) * divisor
 
 
-def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
+def nms_rotated(boxes, scores, threshold: float = 0.45, use_triu: bool = True):
     """
-    NMS
+    Perform NMS on oriented bounding boxes using probiou and fast-nms.
 
     Args:
-        boxes (torch.Tensor): Rotated bounding boxes
-        scores (torch.Tensor): Confidence scores
-        threshold (float): IoU threshold.
-        use_triu (bool): Whether to use
-            when exporting obb models to some formats that do not support `torch.triu`.
+        boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.
+        scores (torch.Tensor): Confidence scores with shape (N,).
+        threshold (float): IoU threshold for NMS.
+        use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.
 
     Returns:
         (torch.Tensor): Indices of boxes to keep after NMS.
@@ -162,7 +174,6 @@ def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
     ious = batch_probiou(boxes, boxes)
     if use_triu:
         ious = ious.triu_(diagonal=1)
-        # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
         # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
         pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
     else:
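Example (illustrative, synthetic values): two heavily overlapping xywhr boxes plus one distant box; the lower-scoring duplicate should be suppressed:

```python
import torch
from ultralytics.utils.ops import nms_rotated

boxes = torch.tensor(
    [[50.0, 50.0, 20.0, 10.0, 0.1], [51.0, 50.0, 20.0, 10.0, 0.1], [200.0, 200.0, 30.0, 15.0, 0.5]]
)  # (N, 5) in xywhr format, angles in radians
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms_rotated(boxes, scores, threshold=0.45)
print(keep)  # indices of kept boxes, e.g. tensor([0, 2])
```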
@@ -180,54 +191,51 @@ def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
 
 def non_max_suppression(
     prediction,
-    conf_thres=0.25,
-    iou_thres=0.45,
+    conf_thres: float = 0.25,
+    iou_thres: float = 0.45,
     classes=None,
-    agnostic=False,
-    multi_label=False,
+    agnostic: bool = False,
+    multi_label: bool = False,
     labels=(),
-    max_det=300,
-    nc=0,  # number of classes (optional)
-    max_time_img=0.05,
-    max_nms=30000,
-    max_wh=7680,
-    in_place=True,
-    rotated=False,
-    end2end=False,
-    return_idxs=False,
+    max_det: int = 300,
+    nc: int = 0,  # number of classes (optional)
+    max_time_img: float = 0.05,
+    max_nms: int = 30000,
+    max_wh: int = 7680,
+    in_place: bool = True,
+    rotated: bool = False,
+    end2end: bool = False,
+    return_idxs: bool = False,
 ):
     """
-    Perform non-maximum suppression (NMS) on
+    Perform non-maximum suppression (NMS) on prediction results.
+
+    Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
+    detection formats including standard boxes, rotated boxes, and masks.
 
     Args:
-        prediction (torch.Tensor):
-            containing
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
-        max_wh (int): The maximum box width and height in pixels.
-        in_place (bool): If True, the input prediction tensor will be modified in place.
-        rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
-        end2end (bool): If the model doesn't require NMS.
-        return_idxs (bool): Return the indices of the detections that were kept.
+        prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
+            containing boxes, classes, and optional masks.
+        conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
+        iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
+        classes (List[int], optional): List of class indices to consider. If None, all classes are considered.
+        agnostic (bool): Whether to perform class-agnostic NMS.
+        multi_label (bool): Whether each box can have multiple labels.
+        labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.
+        max_det (int): Maximum number of detections to keep per image.
+        nc (int): Number of classes. Indices after this are considered masks.
+        max_time_img (float): Maximum time in seconds for processing one image.
+        max_nms (int): Maximum number of boxes for torchvision.ops.nms().
+        max_wh (int): Maximum box width and height in pixels.
+        in_place (bool): Whether to modify the input prediction tensor in place.
+        rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
+        end2end (bool): Whether the model is end-to-end and doesn't require NMS.
+        return_idxs (bool): Whether to return the indices of kept detections.
 
     Returns:
-        (List[torch.Tensor]):
-
-
+        output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
+            containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
+        keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True.
     """
     import torchvision  # scope for faster 'import ultralytics'
 
@@ -322,18 +330,6 @@ def non_max_suppression(
         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
         i = i[:max_det]  # limit detections
 
-        # # Experimental
-        # merge = False  # use merge-NMS
-        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
-        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
-        #     from .metrics import box_iou
-        #     iou = box_iou(boxes[i], boxes) > iou_thres  # IoU matrix
-        #     weights = iou * scores[None]  # box weights
-        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
-        #     redundant = True  # require redundant detections
-        #     if redundant:
-        #         i = i[iou.sum(1) > 1]  # require redundancy
-
         output[xi], keepi[xi] = x[i], xk[i].reshape(-1)
         if (time.time() - t) > time_limit:
             LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
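Example (illustrative): calling `non_max_suppression` on a raw tensor with the layout documented above (random values, 80 classes, no masks):

```python
import torch
from ultralytics.utils.ops import non_max_suppression

# Fake raw model output: (batch_size=1, 4 box coords + 80 class scores, 8400 candidates)
prediction = torch.rand(1, 84, 8400)
output = non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, nc=80, max_det=300)
print(output[0].shape)  # (num_kept, 6): x1, y1, x2, y2, confidence, class

# return_idxs=True also returns the indices of the kept candidates
output, keepi = non_max_suppression(prediction, nc=80, return_idxs=True)
```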
@@ -344,14 +340,14 @@
 
 def clip_boxes(boxes, shape):
     """
-
+    Clip bounding boxes to image boundaries.
 
     Args:
-        boxes (torch.Tensor | numpy.ndarray):
-        shape (tuple):
+        boxes (torch.Tensor | numpy.ndarray): Bounding boxes to clip.
+        shape (tuple): Image shape as (height, width).
 
     Returns:
-        (torch.Tensor | numpy.ndarray):
+        (torch.Tensor | numpy.ndarray): Clipped bounding boxes.
     """
     if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
         boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
@@ -366,11 +362,11 @@ def clip_boxes(boxes, shape):
 
 def clip_coords(coords, shape):
     """
-    Clip line coordinates to
+    Clip line coordinates to image boundaries.
 
     Args:
-        coords (torch.Tensor | numpy.ndarray):
-        shape (tuple):
+        coords (torch.Tensor | numpy.ndarray): Line coordinates to clip.
+        shape (tuple): Image shape as (height, width).
 
     Returns:
         (torch.Tensor | numpy.ndarray): Clipped coordinates.
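Example (illustrative): both clippers bound coordinates to an image's (height, width):

```python
import torch
from ultralytics.utils.ops import clip_boxes

boxes = torch.tensor([[-10.0, 5.0, 700.0, 300.0]])  # partially outside a 480x640 image
print(clip_boxes(boxes, (480, 640)))  # tensor([[  0.,   5., 640., 300.]])
```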
@@ -386,15 +382,18 @@ def clip_coords(coords, shape):
 
 def scale_image(masks, im0_shape, ratio_pad=None):
     """
-
+    Rescale masks to original image size.
+
+    Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
+    that was applied during preprocessing.
 
     Args:
-        masks (np.ndarray): Resized and padded masks
-        im0_shape (tuple):
-        ratio_pad (tuple):
+        masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
+        im0_shape (tuple): Original image shape as (height, width).
+        ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
 
     Returns:
-
+        (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
     """
     # Rescale coordinates (xyxy) from im1_shape to im0_shape
     im1_shape = masks.shape
@@ -404,7 +403,6 @@ def scale_image(masks, im0_shape, ratio_pad=None):
         gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
         pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
     else:
-        # gain = ratio_pad[0][0]
         pad = ratio_pad[1]
     top, left = int(pad[1]), int(pad[0])  # y, x
     bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
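Example (illustrative): removing letterbox padding from [H, W, N] masks, per the docstring above:

```python
import numpy as np
from ultralytics.utils.ops import scale_image

masks = np.zeros((640, 640, 2), dtype=np.float32)  # 2 masks on the padded 640x640 canvas
rescaled = scale_image(masks, (480, 640))  # ratio_pad=None: derived from the two shapes
print(rescaled.shape)  # (480, 640, 2)
```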
@@ -425,10 +423,10 @@ def xyxy2xywh(x):
     top-left corner and (x2, y2) is the bottom-right corner.
 
     Args:
-        x (np.ndarray | torch.Tensor):
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
 
     Returns:
-
+        (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.
     """
     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
     y = empty_like(x)  # faster than clone/copy
@@ -445,10 +443,10 @@ def xywh2xyxy(x):
     top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
 
     Args:
-        x (np.ndarray | torch.Tensor):
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.
 
     Returns:
-
+        (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
     """
     assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
     y = empty_like(x)  # faster than clone/copy
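Example (illustrative): the two conversions above are exact inverses:

```python
import torch
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh

xyxy = torch.tensor([[10.0, 20.0, 50.0, 80.0]])
xywh = xyxy2xywh(xyxy)
print(xywh)  # tensor([[30., 50., 40., 60.]])  center x, center y, width, height
print(xywh2xyxy(xywh))  # back to tensor([[10., 20., 50., 80.]])
```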
@@ -459,16 +457,16 @@ def xywh2xyxy(x):
     return y
 
 
-def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
+def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
     """
     Convert normalized bounding box coordinates to pixel coordinates.
 
     Args:
-        x (np.ndarray | torch.Tensor):
-        w (int):
-        h (int):
-        padw (int): Padding width.
-        padh (int): Padding height.
+        x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
+        w (int): Image width in pixels.
+        h (int): Image height in pixels.
+        padw (int): Padding width in pixels.
+        padh (int): Padding height in pixels.
 
     Returns:
         y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
@@ -483,20 +481,20 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
     return y
 
 
-def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
+def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
     """
     Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
     width and height are normalized to image dimensions.
 
     Args:
-        x (np.ndarray | torch.Tensor):
-        w (int):
-        h (int):
-        clip (bool):
-        eps (float):
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
+        w (int): Image width in pixels.
+        h (int): Image height in pixels.
+        clip (bool): Whether to clip boxes to image boundaries.
+        eps (float): Minimum value for box width and height.
 
     Returns:
-
+        (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
     """
     if clip:
         x = clip_boxes(x, (h - eps, w - eps))
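Example (illustrative): round-tripping a normalized YOLO label on a 640×480 image with the two functions above:

```python
import numpy as np
from ultralytics.utils.ops import xywhn2xyxy, xyxy2xywhn

label = np.array([[0.5, 0.5, 0.25, 0.5]], dtype=np.float32)  # normalized cx, cy, w, h
pixels = xywhn2xyxy(label, w=640, h=480)
print(pixels)  # [[240. 120. 400. 360.]] in pixel xyxy
print(xyxy2xywhn(pixels, w=640, h=480))  # back to [[0.5  0.5  0.25 0.5 ]]
```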
@@ -511,13 +509,13 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
 
 def xywh2ltwh(x):
     """
-    Convert
+    Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.
 
     Args:
-        x (np.ndarray | torch.Tensor):
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.
 
     Returns:
-
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
@@ -527,13 +525,13 @@ def xywh2ltwh(x):
 
 def xyxy2ltwh(x):
     """
-    Convert
+    Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.
 
     Args:
-        x (np.ndarray | torch.Tensor):
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.
 
     Returns:
-
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 2] = x[..., 2] - x[..., 0]  # width
@@ -543,13 +541,13 @@ def xyxy2ltwh(x):
 
 def ltwh2xywh(x):
     """
-    Convert
+    Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
 
     Args:
-        x (torch.Tensor):
+        x (torch.Tensor): Input bounding box coordinates.
 
     Returns:
-
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
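Example (illustrative): the `ltwh` helpers above shuffle the same four numbers between top-left-based and center-based conventions:

```python
import torch
from ultralytics.utils.ops import ltwh2xywh, xywh2ltwh

ltwh = torch.tensor([[10.0, 20.0, 40.0, 60.0]])  # top-left x, top-left y, width, height
xywh = ltwh2xywh(ltwh)
print(xywh)  # tensor([[30., 50., 40., 60.]])  center-based xywh
print(xywh2ltwh(xywh))  # back to tensor([[10., 20., 40., 60.]])
```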
@@ -559,14 +557,14 @@ def ltwh2xywh(x):
 
 def xyxyxyxy2xywhr(x):
     """
-    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation].
-        returned in radians from 0 to pi/2.
+    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.
 
     Args:
-        x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4]
+        x (numpy.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.
 
     Returns:
-        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format
+        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).
+            Rotation values are in radians from 0 to pi/2.
     """
     is_torch = isinstance(x, torch.Tensor)
     points = x.cpu().numpy() if is_torch else x
@@ -582,14 +580,14 @@ def xyxyxyxy2xywhr(x):
 
 def xywhr2xyxyxyxy(x):
     """
-    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4].
-        be in radians from 0 to pi/2.
+    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.
 
     Args:
-        x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format
+        x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).
+            Rotation values should be in radians from 0 to pi/2.
 
     Returns:
-        (numpy.ndarray | torch.Tensor): Converted corner points
+        (numpy.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
     """
     cos, sin, cat, stack = (
         (torch.cos, torch.sin, torch.cat, torch.stack)
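Example (illustrative): an OBB round trip with the two converters above; note that the corner-to-xywhr direction fits a minimum-area rectangle, so the returned (w, h, r) may be an equivalent re-parameterization of the same box rather than bitwise-identical values:

```python
import torch
from ultralytics.utils.ops import xywhr2xyxyxyxy, xyxyxyxy2xywhr

rbox = torch.tensor([[50.0, 40.0, 20.0, 10.0, 0.3]])  # cx, cy, w, h, r (radians)
corners = xywhr2xyxyxyxy(rbox)
print(corners.shape)  # (1, 4, 2): four corner points per box
print(xyxyxyxy2xywhr(corners.reshape(1, 8)))  # an equivalent (cx, cy, w, h, r)
```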
@@ -616,10 +614,10 @@ def ltwh2xyxy(x):
     Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
 
     Args:
-        x (np.ndarray | torch.Tensor):
+        x (np.ndarray | torch.Tensor): Input bounding box coordinates.
 
     Returns:
-        (np.ndarray | torch.Tensor):
+        (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 2] = x[..., 2] + x[..., 0]  # width
@@ -632,10 +630,10 @@ def segments2boxes(segments):
     Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
 
     Args:
-        segments (list): List of segments
+        segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.
 
     Returns:
-        (np.ndarray):
+        (np.ndarray): Bounding box coordinates in xywh format.
     """
     boxes = []
     for s in segments:
@@ -644,16 +642,16 @@ def segments2boxes(segments):
     return xyxy2xywh(np.array(boxes))  # cls, xywh
 
 
-def resample_segments(segments, n=1000):
+def resample_segments(segments, n: int = 1000):
     """
-
+    Resample segments to n points each using linear interpolation.
 
     Args:
-        segments (list):
-        n (int): Number of points to resample
+        segments (list): List of (N, 2) arrays where N is the number of points in each segment.
+        n (int): Number of points to resample each segment to.
 
     Returns:
-
+        (list): Resampled segments with n points each.
     """
     for i, s in enumerate(segments):
         if len(s) == n:
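Example (illustrative): densifying a sparse polygon by linear interpolation, as documented above:

```python
import numpy as np
from ultralytics.utils.ops import resample_segments

segment = np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]], dtype=np.float32)
(resampled,) = resample_segments([segment], n=100)
print(resampled.shape)  # (100, 2)
```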
@@ -670,11 +668,11 @@ def resample_segments(segments, n=1000):
 
 def crop_mask(masks, boxes):
     """
-    Crop masks to bounding
+    Crop masks to bounding box regions.
 
     Args:
-        masks (torch.Tensor):
-        boxes (torch.Tensor):
+        masks (torch.Tensor): Masks with shape (N, H, W).
+        boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.
 
     Returns:
         (torch.Tensor): Cropped masks.
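Example (illustrative): `crop_mask` zeroes everything outside each box:

```python
import torch
from ultralytics.utils.ops import crop_mask

masks = torch.ones(1, 8, 8)  # one all-ones 8x8 mask
boxes = torch.tensor([[2.0, 2.0, 6.0, 6.0]])  # (N, 4) xyxy in mask pixels
print(crop_mask(masks, boxes)[0].sum())  # tensor(16.): only the 4x4 box interior survives
```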
@@ -687,16 +685,16 @@ def crop_mask(masks, boxes):
     return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
 
 
-def process_mask(protos, masks_in, bboxes, shape, upsample=False):
+def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
     """
-    Apply masks to bounding boxes using
+    Apply masks to bounding boxes using mask head output.
 
     Args:
-        protos (torch.Tensor):
-        masks_in (torch.Tensor):
-        bboxes (torch.Tensor):
-        shape (tuple):
-        upsample (bool):
+        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
+        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
+        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
+        shape (tuple): Input image size as (height, width).
+        upsample (bool): Whether to upsample masks to original image size.
 
     Returns:
         (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
@@ -722,16 +720,16 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
 
 def process_mask_native(protos, masks_in, bboxes, shape):
     """
-    Apply masks to bounding boxes using
+    Apply masks to bounding boxes using mask head output with native upsampling.
 
     Args:
-        protos (torch.Tensor):
-        masks_in (torch.Tensor):
-        bboxes (torch.Tensor):
-        shape (tuple):
+        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
+        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
+        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
+        shape (tuple): Input image size as (height, width).
 
     Returns:
-        (torch.Tensor):
+        (torch.Tensor): Binary mask tensor with shape (H, W, N).
     """
     c, mh, mw = protos.shape  # CHW
     masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
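Example (illustrative, shapes only): `process_mask` with the tensor layouts from the new docstring (random prototypes and coefficients, 32 mask channels):

```python
import torch
from ultralytics.utils.ops import process_mask

protos = torch.rand(32, 160, 160)  # (mask_dim, mask_h, mask_w)
masks_in = torch.rand(5, 32)  # (N, mask_dim) coefficients for 5 detections kept by NMS
bboxes = torch.tensor([[10.0, 10.0, 300.0, 200.0]]).repeat(5, 1)  # (N, 4) on the 640x640 input
masks = process_mask(protos, masks_in, bboxes.clone(), shape=(640, 640), upsample=True)
print(masks.shape)  # torch.Size([5, 640, 640]), binary
```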
@@ -740,15 +738,14 @@ def process_mask_native(protos, masks_in, bboxes, shape):
     return masks.gt_(0.0)
 
 
-def scale_masks(masks, shape, padding=True):
+def scale_masks(masks, shape, padding: bool = True):
     """
-    Rescale segment masks to shape.
+    Rescale segment masks to target shape.
 
     Args:
-        masks (torch.Tensor): (N, C, H, W).
-        shape (tuple):
-        padding (bool):
-            rescaling.
+        masks (torch.Tensor): Masks with shape (N, C, H, W).
+        shape (tuple): Target height and width as (height, width).
+        padding (bool): Whether masks are based on YOLO-style augmented images with padding.
 
     Returns:
         (torch.Tensor): Rescaled masks.
@@ -767,21 +764,20 @@ def scale_masks(masks, shape, padding=True):
     return masks
 
 
-def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
     """
-    Rescale segment coordinates
+    Rescale segment coordinates from img1_shape to img0_shape.
 
     Args:
-        img1_shape (tuple):
-        coords (torch.Tensor):
-        img0_shape (tuple):
-        ratio_pad (tuple):
-        normalize (bool):
-        padding (bool):
-            rescaling.
+        img1_shape (tuple): Shape of the source image.
+        coords (torch.Tensor): Coordinates to scale with shape (N, 2).
+        img0_shape (tuple): Shape of the target image.
+        ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
+        normalize (bool): Whether to normalize coordinates to range [0, 1].
+        padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.
 
     Returns:
-
+        (torch.Tensor): Scaled coordinates.
     """
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
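Example (illustrative): `scale_coords` is the keypoint analogue of `scale_boxes`:

```python
import torch
from ultralytics.utils.ops import scale_coords

coords = torch.tensor([[320.0, 100.0], [320.0, 540.0]])  # (N, 2) points on the 640x640 input
scaled = scale_coords((640, 640), coords.clone(), (480, 640))
print(scaled)  # the same points mapped into the 480x640 frame
```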
@@ -804,13 +800,13 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
 
 def regularize_rboxes(rboxes):
     """
-    Regularize rotated boxes
+    Regularize rotated bounding boxes to range [0, pi/2].
 
     Args:
-        rboxes (torch.Tensor): Input boxes
+        rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.
 
     Returns:
-        (torch.Tensor):
+        (torch.Tensor): Regularized rotated boxes.
     """
     x, y, w, h, t = rboxes.unbind(dim=-1)
     # Swap edge if t >= pi/2 while not being symmetrically opposite
@@ -821,16 +817,16 @@ def regularize_rboxes(rboxes):
     return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes
 
 
-def masks2segments(masks, strategy="all"):
+def masks2segments(masks, strategy: str = "all"):
     """
-    Convert masks to segments.
+    Convert masks to segments using contour detection.
 
     Args:
-        masks (torch.Tensor):
-        strategy (str): 'all' or 'largest'.
+        masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).
+        strategy (str): Segmentation strategy, either 'all' or 'largest'.
 
     Returns:
-        (list): List of segment masks.
+        (list): List of segment masks as float32 arrays.
     """
     from ultralytics.data.converter import merge_multi_segment
 
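Example (illustrative): extracting a polygon from a binary mask with `masks2segments`:

```python
import torch
from ultralytics.utils.ops import masks2segments

masks = torch.zeros(1, 160, 160)
masks[0, 40:120, 40:120] = 1  # one square blob
(segment,) = masks2segments(masks, strategy="largest")
print(segment.shape, segment.dtype)  # (num_contour_points, 2) float32
```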
@@ -854,20 +850,20 @@ def masks2segments(masks, strategy="all"):
 
 def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
     """
-    Convert a batch of FP32 torch tensors
+    Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.
 
     Args:
-        batch (torch.Tensor): Input tensor batch
+        batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.
 
     Returns:
-        (np.ndarray): Output NumPy array batch
+        (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
     """
     return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
 
 
 def clean_str(s):
     """
-
+    Clean a string by replacing special characters with '_' character.
 
     Args:
         s (str): A string needing special characters replaced.
@@ -879,7 +875,7 @@ def clean_str(s):
 
 
 def empty_like(x):
-    """
+    """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
     return (
         torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
     )