xbarray 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. array_api_typing/__init__.py +9 -0
  2. array_api_typing/typing_2024_12/__init__.py +12 -0
  3. array_api_typing/typing_2024_12/_api_constant.py +32 -0
  4. array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
  5. array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
  6. array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
  7. array_api_typing/typing_2024_12/_api_typing.py +5855 -0
  8. array_api_typing/typing_2024_12/_array_typing.py +1265 -0
  9. array_api_typing/typing_compat/__init__.py +12 -0
  10. array_api_typing/typing_compat/_api_typing.py +27 -0
  11. array_api_typing/typing_compat/_array_typing.py +36 -0
  12. array_api_typing/typing_extra/__init__.py +12 -0
  13. array_api_typing/typing_extra/_api_typing.py +651 -0
  14. array_api_typing/typing_extra/_at.py +87 -0
  15. xbarray/__init__.py +1 -0
  16. xbarray/backends/_cls_base.py +9 -0
  17. xbarray/backends/_implementations/_common/implementations.py +87 -0
  18. xbarray/backends/_implementations/jax/__init__.py +33 -0
  19. xbarray/backends/_implementations/jax/_extra.py +127 -0
  20. xbarray/backends/_implementations/jax/_typing.py +15 -0
  21. xbarray/backends/_implementations/jax/random.py +115 -0
  22. xbarray/backends/_implementations/numpy/__init__.py +25 -0
  23. xbarray/backends/_implementations/numpy/_extra.py +98 -0
  24. xbarray/backends/_implementations/numpy/_typing.py +14 -0
  25. xbarray/backends/_implementations/numpy/random.py +105 -0
  26. xbarray/backends/_implementations/pytorch/__init__.py +26 -0
  27. xbarray/backends/_implementations/pytorch/_extra.py +135 -0
  28. xbarray/backends/_implementations/pytorch/_typing.py +13 -0
  29. xbarray/backends/_implementations/pytorch/random.py +101 -0
  30. xbarray/backends/base.py +218 -0
  31. xbarray/backends/jax.py +19 -0
  32. xbarray/backends/numpy.py +19 -0
  33. xbarray/backends/pytorch.py +22 -0
  34. xbarray/jax.py +4 -0
  35. xbarray/numpy.py +4 -0
  36. xbarray/pytorch.py +4 -0
  37. xbarray/transformations/pointcloud/__init__.py +1 -0
  38. xbarray/transformations/pointcloud/base.py +449 -0
  39. xbarray/transformations/pointcloud/jax.py +24 -0
  40. xbarray/transformations/pointcloud/numpy.py +23 -0
  41. xbarray/transformations/pointcloud/pytorch.py +23 -0
  42. xbarray/transformations/rotation_conversions/__init__.py +1 -0
  43. xbarray/transformations/rotation_conversions/base.py +713 -0
  44. xbarray/transformations/rotation_conversions/jax.py +41 -0
  45. xbarray/transformations/rotation_conversions/numpy.py +41 -0
  46. xbarray/transformations/rotation_conversions/pytorch.py +41 -0
  47. xbarray-0.0.1a13.dist-info/METADATA +20 -0
  48. xbarray-0.0.1a13.dist-info/RECORD +51 -0
  49. xbarray-0.0.1a13.dist-info/WHEEL +5 -0
  50. xbarray-0.0.1a13.dist-info/licenses/LICENSE +21 -0
  51. xbarray-0.0.1a13.dist-info/top_level.txt +2 -0
@@ -0,0 +1,449 @@
1
+ from typing import Optional, Tuple
2
+ from xbarray.backends.base import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
3
+
4
# Names re-exported by ``from ... import *``; one entry per public function
# defined in this module.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]
14
+
15
def gather_pixel_value(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    values : BArrayType,
    pixel_coordinates : BArrayType
) -> BArrayType:
    """
    Look up the values stored at integer pixel locations.

    The two image axes are flattened into one so that a single
    ``take_along_axis`` call fetches every requested pixel. Coordinates that
    fall outside the image are clamped to the nearest border pixel.

    Args:
        backend (ComputeBackend): The compute backend to use.
        values (BArrayType): The input values of shape (..., H, W, C).
        pixel_coordinates (BArrayType): Integer pixel coordinates of shape (..., N, 2) in (x, y) order.
    Returns:
        BArrayType: The gathered values of shape (..., N, C).
    """
    assert backend.dtype_is_real_integer(pixel_coordinates.dtype), "pixel_coordinates must be of integer type."
    height, width = values.shape[-3], values.shape[-2]
    channels = values.shape[-1]

    # Collapse (H, W) into a single row-major pixel axis.
    pixels = backend.reshape(values, (*values.shape[:-3], -1, channels))  # (..., H * W, C)

    # Clamp each coordinate into the image before building the flat index.
    col = backend.clip(pixel_coordinates[..., 0], 0, width - 1)  # (..., N)
    row = backend.clip(pixel_coordinates[..., 1], 0, height - 1)  # (..., N)
    linear_index = row * width + col  # (..., N)

    # The trailing singleton axis broadcasts the same index over all channels.
    return backend.take_along_axis(pixels, linear_index[..., None], axis=-2)  # (..., N, C)
43
+
44
def bilinear_interpolate(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    values : BArrayType,
    pixel_coordinates : BArrayType,
    index_dtype : Optional[BDtypeType] = None,
    uniform_weights : bool = False
) -> BArrayType:
    """
    Obtain bilinearly interpolated values at given pixel coordinates.

    Coordinates are clamped to the image. A neighbouring pixel with a NaN in
    any channel is treated as missing: it contributes nothing and the weights
    of the remaining neighbours are renormalised. If all four neighbours are
    missing the result is 0.

    Args:
        backend (ComputeBackend): The compute backend to use.
        values (BArrayType): The input values of shape (..., H, W, C).
        pixel_coordinates (BArrayType): The pixel coordinates of shape (..., N, 2) in (x, y) order.
        index_dtype (Optional[BDtypeType]): Integer dtype used for the neighbour indices. Defaults to the
            dtype of ``pixel_coordinates`` when that is an integer dtype, otherwise to
            ``backend.default_index_dtype``.
        uniform_weights (bool): If True, skip the distance-based weights and average the valid
            neighbours equally.
    Returns:
        BArrayType: The interpolated values of shape (..., N, C).
    """
    H, W = values.shape[-3], values.shape[-2]
    x = pixel_coordinates[..., 0]  # (..., N)
    y = pixel_coordinates[..., 1]  # (..., N)

    if index_dtype is None:
        coord_dtype = pixel_coordinates.dtype
        index_dtype = coord_dtype if backend.dtype_is_real_integer(coord_dtype) else backend.default_index_dtype

    x = backend.clip(x, 0, W - 1)
    y = backend.clip(y, 0, H - 1)
    x0 = backend.astype(backend.floor(x), index_dtype)
    x1 = backend.astype(backend.clip(x0 + 1, max=W - 1), index_dtype)
    y0 = backend.astype(backend.floor(y), index_dtype)
    y1 = backend.astype(backend.clip(y0 + 1, max=H - 1), index_dtype)

    # Fetch the four neighbours with a single gather; order is 00, 01, 10, 11.
    corners = [(x0, y0), (x0, y1), (x1, y0), (x1, y1)]
    all_queries = backend.concat(
        [backend.stack([cx, cy], axis=-1) for cx, cy in corners],
        axis=-2
    )  # (..., 4 * N, 2)
    gathered_values = gather_pixel_value(backend, values, all_queries)  # (..., 4 * N, C)
    n = gathered_values.shape[-2] // 4
    corner_values = [gathered_values[..., k * n:(k + 1) * n, :] for k in range(4)]  # 4 x (..., N, C)

    # A neighbour is usable only if none of its channels is NaN.
    corner_valid = [
        backend.all(backend.logical_not(backend.isnan(v)), axis=-1)
        for v in corner_values
    ]  # 4 x (..., N)
    corner_values = [
        backend.where(ok[..., None], v, 0)
        for v, ok in zip(corner_values, corner_valid)
    ]
    weights = [backend.astype(ok, values.dtype) for ok in corner_valid]  # 4 x (..., N)

    if not uniform_weights:
        # Weight by the fractional offset from the lower neighbour instead of by
        # (x1 - x) / (y1 - y): on the last row/column x1 == x0 (resp. y1 == y0)
        # after clamping, which would zero every weight and return 0 rather
        # than the border pixel.
        fx = x - x0
        fy = y - y0
        spatial = [
            (1 - fx) * (1 - fy),
            (1 - fx) * fy,
            fx * (1 - fy),
            fx * fy,
        ]
        for k in range(4):
            weights[k] *= spatial[k]

    weights_sum = weights[0] + weights[1] + weights[2] + weights[3]  # (..., N)
    # Guard the division when every neighbour is missing.
    weights_sum = backend.clip(weights_sum, min=1e-6)
    for k in range(4):
        weights[k] /= weights_sum

    interpolated_values = (
        corner_values[0] * weights[0][..., None] +
        corner_values[1] * weights[1][..., None] +
        corner_values[2] * weights[2][..., None] +
        corner_values[3] * weights[3][..., None]
    )  # (..., N, C)
    return interpolated_values
135
+
136
def pixel_coordinate_and_depth_to_world(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    pixel_coordinates : BArrayType,
    depth : BArrayType,
    intrinsic_matrix : BArrayType,
    extrinsic_matrix : Optional[BArrayType] = None
) -> BArrayType:
    """
    Convert pixel coordinates and depth to world coordinates.

    Each pixel is back-projected through the intrinsics to a camera-frame
    point at the given depth. When an extrinsic matrix ``[R | t]`` is given,
    the point is then mapped to the world frame as ``R^T (p - t)``.

    Args:
        backend (ComputeBackend): The compute backend to use.
        pixel_coordinates (BArrayType): The pixel coordinates of shape (..., N, 2).
        depth (BArrayType): The depth values of shape (..., N). Assume invalid depth is either nan or <= 0.
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
            If None, camera coordinates are returned unchanged.
    Returns:
        BArrayType: The world coordinates of shape (..., N, 4). The last dimension is (x, y, z, valid_mask),
        with valid_mask stored as 1 / 0 in the dtype of the coordinates.
    """
    xs = pixel_coordinates[..., 0]  # (..., N)
    ys = pixel_coordinates[..., 1]  # (..., N)
    # Remove the principal point and focal length; the inserted axis lets the
    # per-camera scalars broadcast over the N points.
    xs_norm = (xs - intrinsic_matrix[..., None, 0, 2]) / intrinsic_matrix[..., None, 0, 0]  # (..., N)
    ys_norm = (ys - intrinsic_matrix[..., None, 1, 2]) / intrinsic_matrix[..., None, 1, 1]  # (..., N)

    camera_coords = backend.stack([
        xs_norm,
        ys_norm,
        backend.ones_like(depth)
    ], axis=-1)  # (..., N, 3)
    camera_coords *= depth[..., None]  # (..., N, 3)

    if extrinsic_matrix is not None:
        R = extrinsic_matrix[..., :3, :3]  # (..., 3, 3)
        t = extrinsic_matrix[..., :3, 3]  # (..., 3)

        # Row-vector form of R^T (p - t): right-multiplying by R applies R^T.
        shifted_camera_coords = camera_coords - t[..., None, :]  # (..., N, 3)
        world_coords = backend.matmul(shifted_camera_coords, R)  # (..., N, 3)
    else:
        world_coords = camera_coords  # (..., N, 3)

    valid_depth_mask = backend.logical_not(backend.logical_or(
        backend.isnan(depth),
        depth <= 0
    ))  # (..., N)
    # Cast explicitly so the concatenation does not depend on how a backend
    # promotes boolean and floating-point inputs.
    valid_channel = backend.astype(valid_depth_mask, world_coords.dtype)[..., None]  # (..., N, 1)
    return backend.concat([
        world_coords,
        valid_channel
    ], axis=-1)  # (..., N, 4)
183
+
184
def depth_image_to_world(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    depth_image : BArrayType,
    intrinsic_matrix : BArrayType,
    extrinsic_matrix : Optional[BArrayType] = None
) -> BArrayType:
    """
    Back-project every pixel of a depth image into world coordinates.

    A grid of (x, y) pixel coordinates is built for the image, flattened, and
    passed together with the flattened depth to
    ``pixel_coordinate_and_depth_to_world``; the result is reshaped back to
    the image layout.

    Args:
        backend (ComputeBackend): The compute backend to use.
        depth_image (BArrayType): The depth image of shape (..., H, W).
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (BArrayType): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
    Returns:
        BArrayType: The world coordinates of shape (..., H, W, 4). The last dimension is (x, y, z, valid_mask).
    """
    H, W = depth_image.shape[-2:]
    batch_shape = list(depth_image.shape[:-2])
    device = backend.device(depth_image)
    coord_dtype = depth_image.dtype

    # "xy" indexing makes both grids (H, W): xs varies along columns, ys along rows.
    column_index = backend.arange(W, device=device, dtype=coord_dtype)
    row_index = backend.arange(H, device=device, dtype=coord_dtype)
    xs, ys = backend.meshgrid(column_index, row_index, indexing="xy")  # (H, W), (H, W)
    assert xs.shape == (H, W) and ys.shape == (H, W)

    # Leading singleton axes let the shared grid broadcast over the batch.
    grid = backend.stack([xs, ys], axis=-1)  # (H, W, 2)
    flat_grid = backend.reshape(grid, [1] * len(batch_shape) + [H * W, 2])  # (..., H * W, 2)
    flat_depth = backend.reshape(depth_image, batch_shape + [H * W])  # (..., H * W)

    flat_world = pixel_coordinate_and_depth_to_world(
        backend,
        flat_grid,
        flat_depth,
        intrinsic_matrix,
        extrinsic_matrix
    )  # (..., H * W, 4)
    return backend.reshape(flat_world, batch_shape + [H, W, 4])  # (..., H, W, 4)
219
+
220
def world_to_pixel_coordinate_and_depth(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    world_coords : BArrayType,
    intrinsic_matrix : BArrayType,
    extrinsic_matrix : Optional[BArrayType] = None
) -> Tuple[BArrayType, BArrayType]:
    """
    Project world-frame points into the image and report their depth.

    Args:
        backend (ComputeBackend): The compute backend to use.
        world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask.
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
    Returns:
        BArrayType: The pixel coordinates xy of shape (..., N, 2). Points with non-positive depth get (0, 0).
        BArrayType: The depth values of shape (..., N). Points with non-positive depth get depth 0.
    """
    point_dim = world_coords.shape[-1]
    if point_dim == 3:
        # Append a homogeneous 1 so the extrinsic translation is applied.
        world_coords_h = backend.pad_dim(
            world_coords,
            dim=-1,
            target_size=4,
            value=1
        )
    else:
        assert point_dim == 4
        world_coords_h = world_coords

    if extrinsic_matrix is None:
        camera_coords = world_coords_h[..., :3]  # (..., N, 3)
    else:
        # The matrix acts on column vectors, hence the transposes around matmul.
        camera_coords = backend.matrix_transpose(
            backend.matmul(
                extrinsic_matrix,  # (..., 3, 4) or (..., 4, 4)
                backend.matrix_transpose(world_coords_h)  # (..., 4, N)
            )
        )  # (..., N, 3) or (..., N, 4)
        if camera_coords.shape[-1] == 4:
            camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]

    projected = backend.matrix_transpose(
        backend.matmul(
            intrinsic_matrix,  # (..., 3, 3)
            backend.matrix_transpose(camera_coords)  # (..., 3, N)
        )
    )  # (..., N, 3)
    point_px = projected[..., :2] / projected[..., 2:3]  # (..., N, 2)

    depth = camera_coords[..., 2]  # (..., N)
    in_front = depth > 0
    # Zero out anything at or behind the camera plane.
    depth = backend.where(in_front, depth, 0)
    point_px = backend.where(in_front[..., None], point_px, 0)
    return point_px, depth
275
+
276
+
277
def world_to_depth(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    world_coords : BArrayType,
    extrinsic_matrix : Optional[BArrayType] = None
) -> BArrayType:
    """
    Convert world coordinates to depth along the camera z-axis.

    Args:
        backend (ComputeBackend): The compute backend to use.
        world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask.
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
    Returns:
        BArrayType: The depth values of shape (..., N). Points with non-positive depth get depth 0.
    """
    if world_coords.shape[-1] == 3:
        # Pad to homogeneous coordinates with w = 1, matching
        # world_to_pixel_coordinate_and_depth. A w of 0 would drop the
        # extrinsic translation (and divide by zero for a 4x4 matrix).
        world_coords_h = backend.pad_dim(
            world_coords,
            dim=-1,
            target_size=4,
            value=1
        )
    else:
        assert world_coords.shape[-1] == 4
        world_coords_h = world_coords

    if extrinsic_matrix is not None:
        # The matrix acts on column vectors, hence the transposes around matmul.
        camera_coords = backend.matmul(
            extrinsic_matrix,  # (..., 3, 4) or (..., 4, 4)
            backend.matrix_transpose(world_coords_h)  # (..., 4, N)
        )  # (..., 3, N) or (..., 4, N)
        camera_coords = backend.matrix_transpose(camera_coords)  # (..., N, 3) or (..., N, 4)
        if camera_coords.shape[-1] == 4:
            camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]
    else:
        camera_coords = world_coords_h[..., :3]  # (..., N, 3)

    depth = camera_coords[..., 2]  # (..., N)
    depth_valid = depth > 0
    # Points at or behind the camera plane report depth 0.
    depth = backend.where(depth_valid, depth, 0)
    return depth
316
+
317
def farthest_point_sampling(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    points : BArrayType,
    num_samples : int,
    rng : BRNGType,
    points_valid : Optional[BArrayType] = None
) -> Tuple[BRNGType, BArrayType, Optional[BArrayType]]:
    """
    Perform farthest point sampling on a set of points.

    Starting from one seed point per batch element, each step picks the point
    whose squared distance to the already chosen points is largest. Without a
    mask the seed is drawn uniformly at random; with a mask the seed is the
    first valid point and invalid points are only picked if nothing else is
    left.

    Args:
        backend (ComputeBackend): The compute backend to use.
        points (BArrayType): The input points of shape (..., N, D).
        num_samples (int): The number of points to sample.
        rng (BRNGType): The random number generator.
        points_valid (Optional[BArrayType]): A boolean mask of shape (..., N) indicating valid points. If None, all points are considered valid.
    Returns:
        BRNGType: The updated random number generator.
        BArrayType: The indices of the sampled points of shape (..., num_samples).
        Optional[BArrayType]: The valid mask of shape (..., num_samples). Returned only if points_valid is provided.
    """
    assert 0 < num_samples <= points.shape[-2], "num_samples must be in (0, N]"
    device = backend.device(points)
    index_dtype = backend.default_index_dtype
    batch_shape = list(points.shape[:-2])

    flat_points = backend.reshape(points, [-1, *points.shape[-2:]])  # (B, N, D)
    B, N = flat_points.shape[0], flat_points.shape[1]
    flat_points_valid = None if points_valid is None else backend.reshape(points_valid, [-1, N])  # (B, N)

    batch_indices = backend.arange(B, dtype=index_dtype, device=device)

    # Squared distance of each point to its nearest chosen centroid.
    distance = backend.full((B, N), backend.inf, device=device)
    if flat_points_valid is not None:
        # -inf keeps invalid points at the bottom of every argmax.
        distance = backend.where(
            flat_points_valid,
            distance,
            -backend.inf
        )
        # argmax over a 0/1 mask returns the first valid point (index 0 if none).
        farthest_idx = backend.argmax(
            backend.astype(flat_points_valid, index_dtype),
            axis=1
        )
    else:
        rng, farthest_idx = backend.random.random_discrete_uniform(
            (B,),
            0, N,
            rng=rng,
            dtype=index_dtype,
            device=device
        )  # initial random farthest point

    # Collect per-step picks and stack once at the end: item assignment into a
    # preallocated array does not work on immutable arrays such as JAX's.
    picked_idx = [farthest_idx]
    for _ in range(1, num_samples):
        last_centroid = flat_points[batch_indices, farthest_idx][:, None, :]  # (B, 1, D)
        dist_to_last_centroid = backend.sum((flat_points - last_centroid) ** 2, axis=-1)  # (B, N)
        distance = backend.minimum(distance, dist_to_last_centroid)  # (B, N)
        farthest_idx = backend.argmax(distance, axis=1)  # (B,)
        picked_idx.append(farthest_idx)

    centroids_idx = backend.stack(
        [backend.astype(idx, index_dtype) for idx in picked_idx],
        axis=1
    )  # (B, num_samples)
    unflat_centroids_idx = backend.reshape(centroids_idx, batch_shape + [num_samples])  # (..., num_samples)

    unflat_centroids_valid = None
    if flat_points_valid is not None:
        centroids_valid = backend.stack(
            [
                backend.astype(flat_points_valid[batch_indices, idx], backend.default_boolean_dtype)
                for idx in picked_idx
            ],
            axis=1
        )  # (B, num_samples)
        unflat_centroids_valid = backend.reshape(centroids_valid, batch_shape + [num_samples])  # (..., num_samples)
    return rng, unflat_centroids_idx, unflat_centroids_valid
389
+
390
def random_point_sampling(
    backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType],
    points : BArrayType,
    num_samples : int,
    rng : BRNGType,
    points_valid : Optional[BArrayType] = None
) -> Tuple[BRNGType, BArrayType, Optional[BArrayType]]:
    """
    Perform random point sampling on a set of points.

    Each batch element is sampled independently and without replacement via a
    random permutation. When a mask is given, only valid points are drawn; a
    batch element with fewer than ``num_samples`` valid points is padded with
    index 0 and the padded slots are marked False in the returned mask.

    Args:
        backend (ComputeBackend): The compute backend to use.
        points (BArrayType): The input points of shape (..., N, D).
        num_samples (int): The number of points to sample.
        rng (BRNGType): The random number generator.
        points_valid (Optional[BArrayType]): A boolean mask of shape (..., N) indicating valid points. If None, all points are considered valid.
    Returns:
        BRNGType: The updated random number generator.
        BArrayType: The indices of the sampled points of shape (..., num_samples).
        Optional[BArrayType]: The valid mask of shape (..., num_samples). Returned only if points_valid is provided.
    """
    assert 0 < num_samples <= points.shape[-2], "num_samples must be in (0, N]"
    device = backend.device(points)
    index_dtype = backend.default_index_dtype
    bool_dtype = backend.default_boolean_dtype
    batch_shape = list(points.shape[:-2])

    flat_points = backend.reshape(points, [-1, *points.shape[-2:]])  # (B, N, D)
    B, N = flat_points.shape[0], flat_points.shape[1]
    flat_points_valid = None if points_valid is None else backend.reshape(points_valid, [-1, N])  # (B, N)

    # Rows are collected in lists and stacked once at the end: item assignment
    # into a preallocated array does not work on immutable arrays such as JAX's.
    idx_rows = []

    if flat_points_valid is None:
        for _ in range(B):
            rng, permutation = backend.random.random_permutation(
                N,
                rng=rng,
                device=device
            )
            idx_rows.append(backend.astype(permutation[:num_samples], index_dtype))
        if idx_rows:
            sampled_idx = backend.stack(idx_rows, axis=0)  # (B, num_samples)
        else:
            # Empty batch: stack() cannot take an empty list.
            sampled_idx = backend.zeros((0, num_samples), dtype=index_dtype, device=device)
        unflat_sampled_idx = backend.reshape(sampled_idx, batch_shape + [num_samples])
        return rng, unflat_sampled_idx, None

    valid_rows = []
    slot = backend.arange(num_samples, dtype=index_dtype, device=device)  # (num_samples,)
    for b in range(B):
        valid_indices_b = backend.nonzero(flat_points_valid[b])[0]  # (valid_count_b,)
        rng, permutation = backend.random.random_permutation(
            valid_indices_b.shape[0],
            rng=rng,
            device=device
        )
        chosen = backend.astype(valid_indices_b[permutation[:num_samples]], index_dtype)  # (<= num_samples,)
        num_chosen = chosen.shape[0]
        # Pad short rows with index 0; the mask below marks those slots invalid.
        padding = backend.zeros((num_samples - num_chosen,), dtype=index_dtype, device=device)
        idx_rows.append(backend.concat([chosen, padding], axis=0))  # (num_samples,)
        valid_rows.append(backend.astype(slot < num_chosen, bool_dtype))  # (num_samples,)

    if idx_rows:
        sampled_idx = backend.stack(idx_rows, axis=0)  # (B, num_samples)
        sampled_valid = backend.stack(valid_rows, axis=0)  # (B, num_samples)
    else:
        # Empty batch: stack() cannot take an empty list.
        sampled_idx = backend.zeros((0, num_samples), dtype=index_dtype, device=device)
        sampled_valid = backend.zeros((0, num_samples), dtype=bool_dtype, device=device)
    unflat_sampled_idx = backend.reshape(sampled_idx, batch_shape + [num_samples])
    unflat_sampled_valid = backend.reshape(sampled_valid, batch_shape + [num_samples])
    return rng, unflat_sampled_idx, unflat_sampled_valid
@@ -0,0 +1,24 @@
1
+ from . import base as base_impl
2
+ from functools import partial
3
+ from xbarray.backends.jax import JaxComputeBackend as BindingBackend
4
+
5
# Public API of the JAX binding: the functions of ``base`` with the JAX
# compute backend pre-bound as the first positional argument, so callers omit
# ``backend``.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]

# Image sampling.
gather_pixel_value = partial(base_impl.gather_pixel_value, BindingBackend)
bilinear_interpolate = partial(base_impl.bilinear_interpolate, BindingBackend)

# Pixel/depth <-> world conversions.
pixel_coordinate_and_depth_to_world = partial(base_impl.pixel_coordinate_and_depth_to_world, BindingBackend)
depth_image_to_world = partial(base_impl.depth_image_to_world, BindingBackend)
world_to_pixel_coordinate_and_depth = partial(base_impl.world_to_pixel_coordinate_and_depth, BindingBackend)
world_to_depth = partial(base_impl.world_to_depth, BindingBackend)

# Point subsampling.
farthest_point_sampling = partial(base_impl.farthest_point_sampling, BindingBackend)
random_point_sampling = partial(base_impl.random_point_sampling, BindingBackend)
@@ -0,0 +1,23 @@
1
+ from . import base as base_impl
2
+ from functools import partial
3
+ from xbarray.backends.numpy import NumpyComputeBackend as BindingBackend
4
+
5
# Public API of the NumPy binding: the functions of ``base`` with the NumPy
# compute backend pre-bound as the first positional argument, so callers omit
# ``backend``.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]

# Image sampling.
gather_pixel_value = partial(base_impl.gather_pixel_value, BindingBackend)
bilinear_interpolate = partial(base_impl.bilinear_interpolate, BindingBackend)

# Pixel/depth <-> world conversions.
pixel_coordinate_and_depth_to_world = partial(base_impl.pixel_coordinate_and_depth_to_world, BindingBackend)
depth_image_to_world = partial(base_impl.depth_image_to_world, BindingBackend)
world_to_pixel_coordinate_and_depth = partial(base_impl.world_to_pixel_coordinate_and_depth, BindingBackend)
world_to_depth = partial(base_impl.world_to_depth, BindingBackend)

# Point subsampling.
farthest_point_sampling = partial(base_impl.farthest_point_sampling, BindingBackend)
random_point_sampling = partial(base_impl.random_point_sampling, BindingBackend)
@@ -0,0 +1,23 @@
1
+ from . import base as base_impl
2
+ from functools import partial
3
+ from xbarray.backends.pytorch import PytorchComputeBackend as BindingBackend
4
+
5
# Public API of the PyTorch binding: the functions of ``base`` with the PyTorch
# compute backend pre-bound as the first positional argument, so callers omit
# ``backend``.
__all__ = [
    "gather_pixel_value",
    "bilinear_interpolate",
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
    "farthest_point_sampling",
    "random_point_sampling",
]

# Image sampling.
gather_pixel_value = partial(base_impl.gather_pixel_value, BindingBackend)
bilinear_interpolate = partial(base_impl.bilinear_interpolate, BindingBackend)

# Pixel/depth <-> world conversions.
pixel_coordinate_and_depth_to_world = partial(base_impl.pixel_coordinate_and_depth_to_world, BindingBackend)
depth_image_to_world = partial(base_impl.depth_image_to_world, BindingBackend)
world_to_pixel_coordinate_and_depth = partial(base_impl.world_to_pixel_coordinate_and_depth, BindingBackend)
world_to_depth = partial(base_impl.world_to_depth, BindingBackend)

# Point subsampling.
farthest_point_sampling = partial(base_impl.farthest_point_sampling, BindingBackend)
random_point_sampling = partial(base_impl.random_point_sampling, BindingBackend)
@@ -0,0 +1 @@
1
+ from .base import *