xbarray 0.0.1a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- array_api_typing/__init__.py +9 -0
- array_api_typing/typing_2024_12/__init__.py +12 -0
- array_api_typing/typing_2024_12/_api_constant.py +32 -0
- array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
- array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
- array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
- array_api_typing/typing_2024_12/_api_typing.py +5855 -0
- array_api_typing/typing_2024_12/_array_typing.py +1265 -0
- array_api_typing/typing_compat/__init__.py +12 -0
- array_api_typing/typing_compat/_api_typing.py +27 -0
- array_api_typing/typing_compat/_array_typing.py +36 -0
- array_api_typing/typing_extra/__init__.py +12 -0
- array_api_typing/typing_extra/_api_typing.py +651 -0
- array_api_typing/typing_extra/_at.py +87 -0
- xbarray/__init__.py +1 -0
- xbarray/backends/_cls_base.py +9 -0
- xbarray/backends/_implementations/_common/implementations.py +87 -0
- xbarray/backends/_implementations/jax/__init__.py +33 -0
- xbarray/backends/_implementations/jax/_extra.py +127 -0
- xbarray/backends/_implementations/jax/_typing.py +15 -0
- xbarray/backends/_implementations/jax/random.py +115 -0
- xbarray/backends/_implementations/numpy/__init__.py +25 -0
- xbarray/backends/_implementations/numpy/_extra.py +98 -0
- xbarray/backends/_implementations/numpy/_typing.py +14 -0
- xbarray/backends/_implementations/numpy/random.py +105 -0
- xbarray/backends/_implementations/pytorch/__init__.py +26 -0
- xbarray/backends/_implementations/pytorch/_extra.py +135 -0
- xbarray/backends/_implementations/pytorch/_typing.py +13 -0
- xbarray/backends/_implementations/pytorch/random.py +101 -0
- xbarray/backends/base.py +218 -0
- xbarray/backends/jax.py +19 -0
- xbarray/backends/numpy.py +19 -0
- xbarray/backends/pytorch.py +22 -0
- xbarray/jax.py +4 -0
- xbarray/numpy.py +4 -0
- xbarray/pytorch.py +4 -0
- xbarray/transformations/pointcloud/__init__.py +1 -0
- xbarray/transformations/pointcloud/base.py +449 -0
- xbarray/transformations/pointcloud/jax.py +24 -0
- xbarray/transformations/pointcloud/numpy.py +23 -0
- xbarray/transformations/pointcloud/pytorch.py +23 -0
- xbarray/transformations/rotation_conversions/__init__.py +1 -0
- xbarray/transformations/rotation_conversions/base.py +713 -0
- xbarray/transformations/rotation_conversions/jax.py +41 -0
- xbarray/transformations/rotation_conversions/numpy.py +41 -0
- xbarray/transformations/rotation_conversions/pytorch.py +41 -0
- xbarray-0.0.1a13.dist-info/METADATA +20 -0
- xbarray-0.0.1a13.dist-info/RECORD +51 -0
- xbarray-0.0.1a13.dist-info/WHEEL +5 -0
- xbarray-0.0.1a13.dist-info/licenses/LICENSE +21 -0
- xbarray-0.0.1a13.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
from xbarray.backends.base import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
3
|
+
|
|
4
|
+
# Names exported by a star-import of this module.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]
|
|
14
|
+
|
|
15
|
+
def gather_pixel_value(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    values: BArrayType,
    pixel_coordinates: BArrayType,
) -> BArrayType:
    """
    Look up image values at integer pixel coordinates.

    Coordinates that fall outside the image are clamped to the nearest
    border pixel before the lookup.

    Args:
        backend (ComputeBackend): The compute backend to use.
        values (BArrayType): The input values of shape (..., H, W, C).
        pixel_coordinates (BArrayType): Integer pixel coordinates of shape (..., N, 2) in (x, y) order.
    Returns:
        BArrayType: The gathered values of shape (..., N, C).
    """
    assert backend.dtype_is_real_integer(pixel_coordinates.dtype), "pixel_coordinates must be of integer type."

    height = values.shape[-3]
    width = values.shape[-2]

    # Collapse the two spatial axes so a single row-major index addresses a pixel.
    pixels = backend.reshape(values, (*values.shape[:-3], -1, values.shape[-1]))  # (..., H * W, C)

    cols = backend.clip(pixel_coordinates[..., 0], 0, width - 1)   # (..., N)
    rows = backend.clip(pixel_coordinates[..., 1], 0, height - 1)  # (..., N)
    row_major = rows * width + cols  # (..., N)

    # The trailing singleton axis lets the index broadcast over the channels.
    return backend.take_along_axis(pixels, row_major[..., None], axis=-2)  # (..., N, C)
|
|
43
|
+
|
|
44
|
+
def bilinear_interpolate(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    values: BArrayType,
    pixel_coordinates: BArrayType,
    index_dtype: Optional[BDtypeType] = None,
    uniform_weights: bool = False
) -> BArrayType:
    """
    Obtain bilinearly interpolated values at given pixel coordinates.

    Coordinates are clamped to the image bounds. A corner sample containing
    NaN in any channel is treated as missing: it contributes nothing and the
    weights of the remaining corners are renormalised. If every contributing
    corner is missing the result is 0.

    Args:
        backend (ComputeBackend): The compute backend to use.
        values (BArrayType): The input values of shape (..., H, W, C).
        pixel_coordinates (BArrayType): The pixel coordinates of shape (..., N, 2) in (x, y) order.
        index_dtype (Optional[BDtypeType]): Integer dtype used for the corner indices.
            Defaults to the dtype of ``pixel_coordinates`` when that is a real integer
            dtype, otherwise ``backend.default_index_dtype``.
        uniform_weights (bool): If True, every non-missing corner gets the same weight
            instead of its bilinear weight.
    Returns:
        BArrayType: The interpolated values of shape (..., N, C).
    """
    H, W = values.shape[-3], values.shape[-2]

    if index_dtype is None:
        if backend.dtype_is_real_integer(pixel_coordinates.dtype):
            index_dtype = pixel_coordinates.dtype
        else:
            index_dtype = backend.default_index_dtype

    x = backend.clip(pixel_coordinates[..., 0], 0, W - 1)  # (..., N)
    y = backend.clip(pixel_coordinates[..., 1], 0, H - 1)  # (..., N)
    x0 = backend.astype(backend.floor(x), index_dtype)
    x1 = backend.astype(backend.clip(x0 + 1, max=W - 1), index_dtype)
    y0 = backend.astype(backend.floor(y), index_dtype)
    y1 = backend.astype(backend.clip(y0 + 1, max=H - 1), index_dtype)

    # Fetch all four corners with a single gather; order is 00, 01, 10, 11
    # where the first digit selects x0/x1 and the second y0/y1.
    all_queries = backend.concat([
        backend.stack([x0, y0], axis=-1),
        backend.stack([x0, y1], axis=-1),
        backend.stack([x1, y0], axis=-1),
        backend.stack([x1, y1], axis=-1),
    ], axis=-2)  # (..., 4 * N, 2)
    gathered_values = gather_pixel_value(backend, values, all_queries)  # (..., 4 * N, C)
    n = gathered_values.shape[-2] // 4

    corner_values = []
    corner_weights = []
    for k in range(4):
        corner = gathered_values[..., k * n:(k + 1) * n, :]  # (..., N, C)
        present = backend.all(backend.logical_not(backend.isnan(corner)), axis=-1)  # (..., N)
        corner_values.append(backend.where(present[..., None], corner, 0))
        corner_weights.append(backend.astype(present, values.dtype))

    if not uniform_weights:
        # Weight by the fractional offset from (x0, y0) rather than by the
        # distance to (x1, y1): on the right/bottom border x1 == x0 (or
        # y1 == y0) after clamping, and distance-based weights would all be 0
        # there, producing 0 instead of the border pixel.
        frac_x = x - x0
        frac_y = y - y0
        bilinear = (
            (1 - frac_x) * (1 - frac_y),
            (1 - frac_x) * frac_y,
            frac_x * (1 - frac_y),
            frac_x * frac_y,
        )
        for k, factor in enumerate(bilinear):
            corner_weights[k] *= factor

    weights_sum = corner_weights[0] + corner_weights[1] + corner_weights[2] + corner_weights[3]  # (..., N)
    # Guard the division when every contributing corner is missing.
    weights_sum = backend.clip(weights_sum, min=1e-6)
    for k in range(4):
        corner_weights[k] /= weights_sum

    interpolated_values = (
        corner_values[0] * corner_weights[0][..., None] +
        corner_values[1] * corner_weights[1][..., None] +
        corner_values[2] * corner_weights[2][..., None] +
        corner_values[3] * corner_weights[3][..., None]
    )  # (..., N, C)
    return interpolated_values
|
|
135
|
+
|
|
136
|
+
def pixel_coordinate_and_depth_to_world(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    pixel_coordinates: BArrayType,
    depth: BArrayType,
    intrinsic_matrix: BArrayType,
    extrinsic_matrix: Optional[BArrayType] = None
) -> BArrayType:
    """
    Convert pixel coordinates and depth to world coordinates.

    Leading (batch) dimensions of the inputs only need to broadcast against
    each other; they do not have to be identical.

    Args:
        backend (ComputeBackend): The compute backend to use.
        pixel_coordinates (BArrayType): The pixel coordinates of shape (..., N, 2) in (x, y) order.
        depth (BArrayType): The depth values of shape (..., N). Assume invalid depth is either nan or <= 0.
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
            Only the focal lengths ([0, 0], [1, 1]) and principal point ([0, 2], [1, 2]) are read.
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
            If None, the camera-frame coordinates are returned.
    Returns:
        BArrayType: The world coordinates of shape (..., N, 4). The last dimension is (x, y, z, valid_mask),
        with valid_mask stored as 1 / 0 in the dtype of the coordinates.
    """
    fx = intrinsic_matrix[..., None, 0, 0]  # (..., 1)
    fy = intrinsic_matrix[..., None, 1, 1]  # (..., 1)
    cx = intrinsic_matrix[..., None, 0, 2]  # (..., 1)
    cy = intrinsic_matrix[..., None, 1, 2]  # (..., 1)

    xs_norm = (pixel_coordinates[..., 0] - cx) / fx  # (..., N)
    ys_norm = (pixel_coordinates[..., 1] - cy) / fy  # (..., N)

    # Build unit-depth rays from arrays of one common shape, then scale by
    # depth out-of-place so that depth's batch dims may broadcast. Stacking
    # ones_like(depth) directly would require depth to already have exactly
    # the same shape as the normalised coordinates.
    rays = backend.stack([
        xs_norm,
        ys_norm,
        backend.ones_like(xs_norm)
    ], axis=-1)  # (..., N, 3)
    camera_coords = rays * depth[..., None]  # (..., N, 3)

    if extrinsic_matrix is not None:
        R = extrinsic_matrix[..., :3, :3]  # (..., 3, 3)
        t = extrinsic_matrix[..., :3, 3]  # (..., 3)
        # Row-vector form of R^T @ (p_cam - t).
        world_coords = backend.matmul(camera_coords - t[..., None, :], R)  # (..., N, 3)
    else:
        world_coords = camera_coords  # (..., N, 3)

    # NaN compares False, so this rejects both NaN and non-positive depth.
    valid_depth_mask = depth > 0  # (..., N)
    # Materialise the mask in the coordinate dtype and broadcast shape so the
    # concat below joins arrays of one dtype instead of relying on bool/float
    # promotion.
    valid_column = backend.where(
        valid_depth_mask[..., None],
        backend.ones_like(world_coords[..., :1]),
        0
    )  # (..., N, 1)
    return backend.concat([
        world_coords,
        valid_column
    ], axis=-1)  # (..., N, 4)
|
|
183
|
+
|
|
184
|
+
def depth_image_to_world(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    depth_image: BArrayType,
    intrinsic_matrix: BArrayType,
    extrinsic_matrix: Optional[BArrayType] = None
) -> BArrayType:
    """
    Back-project every pixel of a depth image into world coordinates.

    Builds the (x, y) pixel grid of the image, flattens grid and depth, and
    delegates to ``pixel_coordinate_and_depth_to_world``.

    Args:
        backend (ComputeBackend): The compute backend to use.
        depth_image (BArrayType): The depth image of shape (..., H, W).
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
    Returns:
        BArrayType: The world coordinates of shape (..., H, W, 4). The last dimension is (x, y, z, valid_mask).
    """
    height, width = depth_image.shape[-2:]
    batch_shape = list(depth_image.shape[:-2])
    device = backend.device(depth_image)

    column_axis = backend.arange(width, device=device, dtype=depth_image.dtype)
    row_axis = backend.arange(height, device=device, dtype=depth_image.dtype)
    # "xy" indexing yields (H, W) grids: grid_x varies along W, grid_y along H.
    grid_x, grid_y = backend.meshgrid(column_axis, row_axis, indexing="xy")
    assert grid_x.shape == (height, width) and grid_y.shape == (height, width)

    pixel_grid = backend.stack([grid_x, grid_y], axis=-1)  # (H, W, 2)
    # Singleton batch axes let the one grid broadcast over every image.
    pixel_grid = backend.reshape(pixel_grid, [1] * len(batch_shape) + [height * width, 2])
    flat_depth = backend.reshape(depth_image, batch_shape + [height * width])  # (..., H * W)

    flat_world = pixel_coordinate_and_depth_to_world(
        backend,
        pixel_grid,
        flat_depth,
        intrinsic_matrix,
        extrinsic_matrix
    )  # (..., H * W, 4)
    return backend.reshape(flat_world, batch_shape + [height, width, 4])  # (..., H, W, 4)
|
|
219
|
+
|
|
220
|
+
def world_to_pixel_coordinate_and_depth(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    world_coords: BArrayType,
    intrinsic_matrix: BArrayType,
    extrinsic_matrix: Optional[BArrayType] = None
) -> Tuple[BArrayType, BArrayType]:
    """
    Convert world coordinates to pixel coordinates and depth.

    Args:
        backend (ComputeBackend): The compute backend to use.
        world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask (> 0 means valid).
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
    Returns:
        BArrayType: The pixel coordinates xy of shape (..., N, 2). Invalid points have coordinates (0, 0).
        BArrayType: The depth values of shape (..., N). Invalid points (where valid mask is False, or whose camera-space depth is not > 0) will have depth 0.
    """
    assert world_coords.shape[-1] in (3, 4)

    xyz = world_coords[..., :3]  # (..., N, 3)
    # The 4th component is a validity flag, not a homogeneous coordinate:
    # keep it aside and always project with w = 1, so that masked-out points
    # are reliably zeroed below whatever the extrinsic's shape.
    point_valid = world_coords[..., 3] > 0 if world_coords.shape[-1] == 4 else None  # (..., N)

    if extrinsic_matrix is not None:
        world_coords_h = backend.pad_dim(
            xyz,
            dim=-1,
            target_size=4,
            value=1
        )  # (..., N, 4)
        camera_coords = backend.matmul(
            extrinsic_matrix,  # (..., 3, 4) or (..., 4, 4)
            backend.matrix_transpose(world_coords_h)  # (..., 4, N)
        )  # (..., 3, N) or (..., 4, N)
        camera_coords = backend.matrix_transpose(camera_coords)  # (..., N, 3) or (..., N, 4)
        if camera_coords.shape[-1] == 4:
            camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]
    else:
        camera_coords = xyz  # (..., N, 3)

    point_px_homogeneous = backend.matmul(
        intrinsic_matrix,  # (..., 3, 3)
        backend.matrix_transpose(camera_coords)  # (..., 3, N)
    )  # (..., 3, N)
    point_px_homogeneous = backend.matrix_transpose(point_px_homogeneous)  # (..., N, 3)
    point_px = point_px_homogeneous[..., :2] / point_px_homogeneous[..., 2:3]  # (..., N, 2)

    depth = camera_coords[..., 2]  # (..., N)
    # NaN compares False, so NaN depth is rejected here as well.
    depth_valid = depth > 0
    if point_valid is not None:
        depth_valid = backend.logical_and(depth_valid, point_valid)
    depth = backend.where(depth_valid, depth, 0)
    point_px = backend.where(
        depth_valid[..., None],
        point_px,
        0
    )
    return point_px, depth
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def world_to_depth(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    world_coords: BArrayType,
    extrinsic_matrix: Optional[BArrayType] = None
) -> BArrayType:
    """
    Convert world coordinates to camera-space depth.

    Args:
        backend (ComputeBackend): The compute backend to use.
        world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask (> 0 means valid).
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
    Returns:
        BArrayType: The depth values of shape (..., N). Invalid points (where valid mask is False, or whose camera-space depth is not > 0) will have depth 0.
    """
    assert world_coords.shape[-1] in (3, 4)

    xyz = world_coords[..., :3]  # (..., N, 3)
    # The 4th component is a validity flag, not a homogeneous coordinate.
    point_valid = world_coords[..., 3] > 0 if world_coords.shape[-1] == 4 else None  # (..., N)

    if extrinsic_matrix is not None:
        # Homogeneous coordinate must be 1 (not 0) so the extrinsic's
        # translation column is applied and a (4, 4) extrinsic does not
        # divide by zero; same padding as world_to_pixel_coordinate_and_depth.
        world_coords_h = backend.pad_dim(
            xyz,
            dim=-1,
            target_size=4,
            value=1
        )  # (..., N, 4)
        camera_coords = backend.matmul(
            extrinsic_matrix,  # (..., 3, 4) or (..., 4, 4)
            backend.matrix_transpose(world_coords_h)  # (..., 4, N)
        )  # (..., 3, N) or (..., 4, N)
        camera_coords = backend.matrix_transpose(camera_coords)  # (..., N, 3) or (..., N, 4)
        if camera_coords.shape[-1] == 4:
            camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]
    else:
        camera_coords = xyz  # (..., N, 3)

    depth = camera_coords[..., 2]  # (..., N)
    # NaN compares False, so NaN depth is rejected here as well.
    depth_valid = depth > 0
    if point_valid is not None:
        depth_valid = backend.logical_and(depth_valid, point_valid)
    return backend.where(depth_valid, depth, 0)
|
|
316
|
+
|
|
317
|
+
def farthest_point_sampling(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    points: BArrayType,
    num_samples: int,
    rng: BRNGType,
    points_valid: Optional[BArrayType] = None
) -> Tuple[BRNGType, BArrayType, Optional[BArrayType]]:
    """
    Perform farthest point sampling on a set of points.

    The first sample is drawn uniformly at random when ``points_valid`` is
    None; otherwise it is the first valid point of each batch entry (``rng``
    is then returned unchanged). Every further sample is the point with the
    largest squared distance to the samples chosen so far.

    Args:
        backend (ComputeBackend): The compute backend to use.
        points (BArrayType): The input points of shape (..., N, D).
        num_samples (int): The number of points to sample.
        rng (BRNGType): The random number generator.
        points_valid (Optional[BArrayType]): A boolean mask of shape (..., N) indicating valid points. If None, all points are considered valid.
    Returns:
        BRNGType: The updated random number generator.
        BArrayType: The indices of the sampled points of shape (..., num_samples).
        Optional[BArrayType]: The valid mask of shape (..., num_samples). Returned only if points_valid is provided.
    """
    assert 0 < num_samples <= points.shape[-2], "num_samples must be in (0, N]"
    device = backend.device(points)
    batch_shape = list(points.shape[:-2])

    flat_points = backend.reshape(points, [-1, *points.shape[-2:]])  # (B, N, D)
    B, N, _ = flat_points.shape
    flat_points_valid = None if points_valid is None else backend.reshape(points_valid, [-1, N])  # (B, N)

    batch_indices = backend.arange(B, dtype=backend.default_index_dtype, device=device)

    # Squared distance of each point to its nearest chosen sample.
    distance = backend.full((B, N), backend.inf, device=device)  # (B, N)
    if flat_points_valid is None:
        rng, farthest_idx = backend.random.random_discrete_uniform(
            (B,),
            0, N,
            rng=rng,
            dtype=backend.default_index_dtype,
            device=device
        )  # initial random farthest point
    else:
        # -inf keeps invalid points from winning the argmax while any valid
        # point is left.
        distance = backend.where(flat_points_valid, distance, -backend.inf)
        # argmax over the 0/1 mask selects the first valid point.
        farthest_idx = backend.argmax(
            backend.astype(flat_points_valid, backend.default_index_dtype),
            axis=1
        )

    # Collect one (B,) column per step and stack at the end instead of
    # writing into a preallocated array: item assignment is not supported by
    # immutable array types such as JAX arrays.
    picked_idx = []
    picked_valid = []
    for step in range(num_samples):
        if step > 0:
            last_centroid = flat_points[batch_indices, farthest_idx][:, None, :]  # (B, 1, D)
            dist_to_last = backend.sum((flat_points - last_centroid) ** 2, axis=-1)  # (B, N)
            distance = backend.minimum(distance, dist_to_last)  # (B, N)
            farthest_idx = backend.argmax(distance, axis=1)  # (B,)
        picked_idx.append(backend.astype(farthest_idx, backend.default_index_dtype))
        if flat_points_valid is not None:
            picked_valid.append(backend.astype(
                flat_points_valid[batch_indices, farthest_idx],
                backend.default_boolean_dtype
            ))

    centroids_idx = backend.stack(picked_idx, axis=1)  # (B, num_samples)
    unflat_centroids_idx = backend.reshape(centroids_idx, batch_shape + [num_samples])  # (..., num_samples)
    if flat_points_valid is None:
        unflat_centroids_valid = None
    else:
        centroids_valid = backend.stack(picked_valid, axis=1)  # (B, num_samples)
        unflat_centroids_valid = backend.reshape(centroids_valid, batch_shape + [num_samples])  # (..., num_samples)
    return rng, unflat_centroids_idx, unflat_centroids_valid
|
|
389
|
+
|
|
390
|
+
def random_point_sampling(
    backend: ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    points: BArrayType,
    num_samples: int,
    rng: BRNGType,
    points_valid: Optional[BArrayType] = None
) -> Tuple[BRNGType, BArrayType, Optional[BArrayType]]:
    """
    Perform random point sampling (without replacement) on a set of points.

    When ``points_valid`` is given, only valid points are drawn. A batch
    entry with fewer than ``num_samples`` valid points has its remaining
    slots filled with index 0 and flagged False in the returned mask.

    Args:
        backend (ComputeBackend): The compute backend to use.
        points (BArrayType): The input points of shape (..., N, D).
        num_samples (int): The number of points to sample.
        rng (BRNGType): The random number generator.
        points_valid (Optional[BArrayType]): A boolean mask of shape (..., N) indicating valid points. If None, all points are considered valid.
    Returns:
        BRNGType: The updated random number generator.
        BArrayType: The indices of the sampled points of shape (..., num_samples).
        Optional[BArrayType]: The valid mask of shape (..., num_samples). Returned only if points_valid is provided.
    """
    assert 0 < num_samples <= points.shape[-2], "num_samples must be in (0, N]"
    device = backend.device(points)
    batch_shape = list(points.shape[:-2])
    index_dtype = backend.default_index_dtype
    boolean_dtype = backend.default_boolean_dtype

    flat_points = backend.reshape(points, [-1, *points.shape[-2:]])  # (B, N, D)
    B, N = flat_points.shape[0], flat_points.shape[1]
    flat_points_valid = None if points_valid is None else backend.reshape(points_valid, [-1, N])  # (B, N)

    # Rows are collected and stacked rather than written into a preallocated
    # array: item assignment is not supported by immutable array types such
    # as JAX arrays.
    idx_rows = []
    valid_rows = []
    if flat_points_valid is None:
        for _ in range(B):
            rng, permutation = backend.random.random_permutation(
                N,
                rng=rng,
                device=device
            )
            idx_rows.append(backend.astype(permutation[:num_samples], index_dtype))
    else:
        slots = backend.arange(num_samples, dtype=index_dtype, device=device)  # (num_samples,)
        for b in range(B):
            valid_indices_b = backend.nonzero(flat_points_valid[b])[0]  # (valid_count_b,)
            rng, permutation = backend.random.random_permutation(
                valid_indices_b.shape[0],
                rng=rng,
                device=device
            )
            chosen = backend.astype(valid_indices_b[permutation[:num_samples]], index_dtype)
            count = chosen.shape[0]  # min(valid_count_b, num_samples)
            if count < num_samples:
                # Not enough valid points: pad with index 0, flagged invalid below.
                padding = backend.zeros((num_samples - count,), dtype=index_dtype, device=device)
                chosen = backend.concat([chosen, padding], axis=0)
            idx_rows.append(chosen)
            valid_rows.append(backend.astype(slots < count, boolean_dtype))

    if idx_rows:
        sampled_idx = backend.stack(idx_rows, axis=0)  # (B, num_samples)
    else:
        # Empty batch: nothing to stack.
        sampled_idx = backend.zeros((B, num_samples), dtype=index_dtype, device=device)
    unflat_sampled_idx = backend.reshape(sampled_idx, batch_shape + [num_samples])

    if flat_points_valid is None:
        return rng, unflat_sampled_idx, None

    if valid_rows:
        sampled_valid = backend.stack(valid_rows, axis=0)  # (B, num_samples)
    else:
        sampled_valid = backend.zeros((B, num_samples), dtype=boolean_dtype, device=device)
    unflat_sampled_valid = backend.reshape(sampled_valid, batch_shape + [num_samples])
    return rng, unflat_sampled_idx, unflat_sampled_valid
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from . import base as base_impl
|
|
2
|
+
from functools import partial
|
|
3
|
+
from xbarray.backends.jax import JaxComputeBackend as BindingBackend
|
|
4
|
+
|
|
5
|
+
# Public API: the generic ``base`` implementations with the JAX compute
# backend pre-bound as their first argument.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]


def _bound(generic_impl):
    """Return ``generic_impl`` with ``BindingBackend`` fixed as its first argument."""
    return partial(generic_impl, BindingBackend)


gather_pixel_value = _bound(base_impl.gather_pixel_value)
bilinear_interpolate = _bound(base_impl.bilinear_interpolate)
pixel_coordinate_and_depth_to_world = _bound(base_impl.pixel_coordinate_and_depth_to_world)
depth_image_to_world = _bound(base_impl.depth_image_to_world)
world_to_pixel_coordinate_and_depth = _bound(base_impl.world_to_pixel_coordinate_and_depth)
world_to_depth = _bound(base_impl.world_to_depth)
farthest_point_sampling = _bound(base_impl.farthest_point_sampling)
random_point_sampling = _bound(base_impl.random_point_sampling)

# The helper is only needed while the module is being initialised.
del _bound
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from . import base as base_impl
|
|
2
|
+
from functools import partial
|
|
3
|
+
from xbarray.backends.numpy import NumpyComputeBackend as BindingBackend
|
|
4
|
+
|
|
5
|
+
# Public API: the generic ``base`` implementations with the NumPy compute
# backend pre-bound as their first argument.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]


def _bound(generic_impl):
    """Return ``generic_impl`` with ``BindingBackend`` fixed as its first argument."""
    return partial(generic_impl, BindingBackend)


gather_pixel_value = _bound(base_impl.gather_pixel_value)
bilinear_interpolate = _bound(base_impl.bilinear_interpolate)
pixel_coordinate_and_depth_to_world = _bound(base_impl.pixel_coordinate_and_depth_to_world)
depth_image_to_world = _bound(base_impl.depth_image_to_world)
world_to_pixel_coordinate_and_depth = _bound(base_impl.world_to_pixel_coordinate_and_depth)
world_to_depth = _bound(base_impl.world_to_depth)
farthest_point_sampling = _bound(base_impl.farthest_point_sampling)
random_point_sampling = _bound(base_impl.random_point_sampling)

# The helper is only needed while the module is being initialised.
del _bound
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from . import base as base_impl
|
|
2
|
+
from functools import partial
|
|
3
|
+
from xbarray.backends.pytorch import PytorchComputeBackend as BindingBackend
|
|
4
|
+
|
|
5
|
+
# Public API: the generic ``base`` implementations with the PyTorch compute
# backend pre-bound as their first argument.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]


def _bound(generic_impl):
    """Return ``generic_impl`` with ``BindingBackend`` fixed as its first argument."""
    return partial(generic_impl, BindingBackend)


gather_pixel_value = _bound(base_impl.gather_pixel_value)
bilinear_interpolate = _bound(base_impl.bilinear_interpolate)
pixel_coordinate_and_depth_to_world = _bound(base_impl.pixel_coordinate_and_depth_to_world)
depth_image_to_world = _bound(base_impl.depth_image_to_world)
world_to_pixel_coordinate_and_depth = _bound(base_impl.world_to_pixel_coordinate_and_depth)
world_to_depth = _bound(base_impl.world_to_depth)
farthest_point_sampling = _bound(base_impl.farthest_point_sampling)
random_point_sampling = _bound(base_impl.random_point_sampling)

# The helper is only needed while the module is being initialised.
del _bound
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import *
|