xbarray 0.0.1a7__tar.gz → 0.0.1a10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xbarray might be problematic. Click here for more details.

Files changed (63)
  1. {xbarray-0.0.1a7/xbarray.egg-info → xbarray-0.0.1a10}/PKG-INFO +6 -1
  2. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/pyproject.toml +8 -1
  3. xbarray-0.0.1a10/xbarray/__init__.py +1 -0
  4. xbarray-0.0.1a10/xbarray/backends/_cls_base.py +9 -0
  5. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/_common/implementations.py +32 -7
  6. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/jax/_extra.py +4 -1
  7. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/jax/random.py +1 -1
  8. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/numpy/_extra.py +6 -2
  9. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/pytorch/_extra.py +5 -2
  10. {xbarray-0.0.1a7/xbarray/base → xbarray-0.0.1a10/xbarray/backends}/base.py +11 -1
  11. {xbarray-0.0.1a7/xbarray → xbarray-0.0.1a10/xbarray/backends}/jax.py +3 -3
  12. {xbarray-0.0.1a7/xbarray → xbarray-0.0.1a10/xbarray/backends}/numpy.py +3 -3
  13. xbarray-0.0.1a10/xbarray/backends/pytorch.py +22 -0
  14. xbarray-0.0.1a10/xbarray/jax.py +4 -0
  15. xbarray-0.0.1a10/xbarray/numpy.py +4 -0
  16. xbarray-0.0.1a10/xbarray/pytorch.py +4 -0
  17. xbarray-0.0.1a10/xbarray/transformations/pointcloud/__init__.py +1 -0
  18. xbarray-0.0.1a10/xbarray/transformations/pointcloud/base.py +185 -0
  19. xbarray-0.0.1a10/xbarray/transformations/pointcloud/jax.py +15 -0
  20. xbarray-0.0.1a10/xbarray/transformations/pointcloud/numpy.py +15 -0
  21. xbarray-0.0.1a10/xbarray/transformations/pointcloud/pytorch.py +15 -0
  22. xbarray-0.0.1a10/xbarray/transformations/rotation_conversions/__init__.py +1 -0
  23. xbarray-0.0.1a10/xbarray/transformations/rotation_conversions/base.py +712 -0
  24. xbarray-0.0.1a10/xbarray/transformations/rotation_conversions/jax.py +41 -0
  25. xbarray-0.0.1a10/xbarray/transformations/rotation_conversions/numpy.py +41 -0
  26. xbarray-0.0.1a10/xbarray/transformations/rotation_conversions/pytorch.py +41 -0
  27. {xbarray-0.0.1a7 → xbarray-0.0.1a10/xbarray.egg-info}/PKG-INFO +6 -1
  28. xbarray-0.0.1a10/xbarray.egg-info/SOURCES.txt +54 -0
  29. xbarray-0.0.1a7/xbarray/__init__.py +0 -1
  30. xbarray-0.0.1a7/xbarray/_src/serialization/__init__.py +0 -1
  31. xbarray-0.0.1a7/xbarray/_src/serialization/serialization_map.py +0 -31
  32. xbarray-0.0.1a7/xbarray/base/__init__.py +0 -1
  33. xbarray-0.0.1a7/xbarray/cls_impl/cls_base.py +0 -10
  34. xbarray-0.0.1a7/xbarray/pytorch.py +0 -19
  35. xbarray-0.0.1a7/xbarray.egg-info/SOURCES.txt +0 -44
  36. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/LICENSE +0 -0
  37. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/README.md +0 -0
  38. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/__init__.py +0 -0
  39. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_2024_12/__init__.py +0 -0
  40. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_2024_12/_api_constant.py +0 -0
  41. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_2024_12/_api_fft_typing.py +0 -0
  42. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_2024_12/_api_linalg_typing.py +0 -0
  43. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_2024_12/_api_return_typing.py +0 -0
  44. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_2024_12/_api_typing.py +0 -0
  45. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_2024_12/_array_typing.py +0 -0
  46. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_compat/__init__.py +0 -0
  47. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_compat/_api_typing.py +0 -0
  48. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_compat/_array_typing.py +0 -0
  49. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_extra/__init__.py +0 -0
  50. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_extra/_api_typing.py +0 -0
  51. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/array_api_typing/typing_extra/_at.py +0 -0
  52. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/setup.cfg +0 -0
  53. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/jax/__init__.py +0 -0
  54. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/jax/_typing.py +0 -0
  55. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/numpy/__init__.py +0 -0
  56. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/numpy/_typing.py +0 -0
  57. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/numpy/random.py +0 -0
  58. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/pytorch/__init__.py +0 -0
  59. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/pytorch/_typing.py +0 -0
  60. {xbarray-0.0.1a7/xbarray/_src/implementations → xbarray-0.0.1a10/xbarray/backends/_implementations}/pytorch/random.py +0 -0
  61. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/xbarray.egg-info/dependency_links.txt +0 -0
  62. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/xbarray.egg-info/requires.txt +0 -0
  63. {xbarray-0.0.1a7 → xbarray-0.0.1a10}/xbarray.egg-info/top_level.txt +0 -0
@@ -1,7 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xbarray
3
- Version: 0.0.1a7
3
+ Version: 0.0.1a10
4
4
  Summary: Cross-backend Python array library based on the Array API Standard.
5
+ Project-URL: Homepage, https://github.com/UniEnvOrg/XBArray
6
+ Project-URL: Documentation, https://github.com/UniEnvOrg/XBArray
7
+ Project-URL: Repository, https://github.com/UniEnvOrg/XBArray
8
+ Project-URL: Issues, https://github.com/UniEnvOrg/XBArray/issues
9
+ Project-URL: Changelog, https://github.com/UniEnvOrg/XBArray/blob/main/CHANGELOG.md
5
10
  Requires-Python: >=3.10
6
11
  License-File: LICENSE
7
12
  Requires-Dist: typing_extensions>=4.5
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "xbarray"
3
3
  description = "Cross-backend Python array library based on the Array API Standard."
4
- version = "0.0.1a7"
4
+ version = "0.0.1a10"
5
5
  requires-python = ">= 3.10"
6
6
  dependencies = [
7
7
  "typing_extensions>=4.5",
@@ -18,6 +18,13 @@ jax = [
18
18
  "jax"
19
19
  ]
20
20
 
21
+ [project.urls]
22
+ Homepage = "https://github.com/UniEnvOrg/XBArray"
23
+ Documentation = "https://github.com/UniEnvOrg/XBArray"
24
+ Repository = "https://github.com/UniEnvOrg/XBArray"
25
+ Issues = "https://github.com/UniEnvOrg/XBArray/issues"
26
+ Changelog = "https://github.com/UniEnvOrg/XBArray/blob/main/CHANGELOG.md"
27
+
21
28
  [build-system]
22
29
  requires = ["setuptools"]
23
30
 
@@ -0,0 +1 @@
1
+ from .backends.base import * # Import Abstract Typings
@@ -0,0 +1,9 @@
1
+ from typing import Type, Generic
2
+ from .base import BArrayType, BDeviceType, BDtypeType, BRNGType
3
+
4
class ComputeBackendImplCls(Generic[BArrayType, BDeviceType, BDtypeType, BRNGType], Type):
    """Metaclass of the concrete compute-backend classes.

    Backend classes are declared as
    ``class XComputeBackend(metaclass=ComputeBackendImplCls[ARRAY, DEVICE, DTYPE, RNG])``.
    Because this is a metaclass, ``self`` in the methods below is the backend
    *class* itself. Both ``str()`` and ``repr()`` of such a class return its
    ``simplified_name`` attribute. That attribute is presumably copied in from
    the backend's implementation module (TODO confirm).
    """

    def __str__(self):
        # The display name is whatever the backend class exposes as simplified_name.
        display_name = self.simplified_name
        return display_name

    # repr() of a backend class is intentionally the same as str().
    __repr__ = __str__
@@ -1,10 +1,12 @@
1
- from typing import Any, Union, Callable
1
+ from typing import Any, Union, Callable, Mapping, Sequence
2
2
  from array_api_typing.typing_compat import ArrayAPINamespace as CompatNamespace, ArrayAPIArray as CompatArray, ArrayAPIDType as CompatDType
3
3
  import array_api_compat
4
+ import dataclasses
4
5
 
5
6
  __all__ = [
6
7
  "get_abbreviate_array_function",
7
8
  "get_map_fn_over_arrays_function",
9
+ "get_pad_dim_function",
8
10
  ]
9
11
 
10
12
  def get_abbreviate_array_function(
@@ -51,12 +53,35 @@ def get_map_fn_over_arrays_function(
51
53
  """
52
54
  if is_backendarray(data):
53
55
  return func(data)
54
- elif isinstance(data, dict):
55
- return {k: map_fn_over_arrays(v, func) for k, v in data.items()}
56
- elif isinstance(data, tuple):
57
- return tuple(map_fn_over_arrays(i, func) for i in data)
58
- elif isinstance(data, list):
59
- return [map_fn_over_arrays(i, func) for i in data]
56
+ elif isinstance(data, Mapping):
57
+ ret = {k: map_fn_over_arrays(v, func) for k, v in data.items()}
58
+ try:
59
+ return type(data)(**ret) # try to keep the same mapping type
60
+ except:
61
+ return ret
62
+ elif isinstance(data, Sequence):
63
+ ret = [map_fn_over_arrays(i, func) for i in data]
64
+ try:
65
+ return type(data)(ret) # try to keep the same sequence type
66
+ except:
67
+ return ret
68
+ elif dataclasses.is_dataclass(data):
69
+ return type(data)(**map_fn_over_arrays(dataclasses.asdict(data), func))
60
70
  else:
61
71
  return data
62
72
  return map_fn_over_arrays
73
+
74
def get_pad_dim_function(
    backend : CompatNamespace[CompatArray, Any, Any],
):
    """Build a ``pad_dim`` function bound to an Array API ``backend`` namespace.

    Args:
        backend: Array API compatible namespace. The returned function uses its
            ``full`` and ``concat`` functions.

    Returns:
        ``pad_dim(x, dim, target_size, value=0)``; see its docstring.
    """
    def pad_dim(x : CompatArray, dim : int, target_size : int, value : Union[float, int] = 0) -> CompatArray:
        """Pad dimension ``dim`` of ``x`` up to ``target_size`` entries filled with ``value``.

        The padding is appended after the existing entries of ``dim``. ``dim``
        may be negative (it is used directly as a Python index into the shape).

        Args:
            x: Array to pad.
            dim: Index of the dimension to pad.
            target_size: Desired size of dimension ``dim``.
            value: Fill value for the appended entries.

        Returns:
            ``x`` itself if dimension ``dim`` already has ``target_size``
            entries. Otherwise a new array of ``x``'s dtype and device.

        Raises:
            ValueError: If dimension ``dim`` is larger than ``target_size``.
        """
        size_at_target = x.shape[dim]
        if size_at_target > target_size:
            raise ValueError(f"Cannot pad dimension {dim} of size {size_at_target} to smaller target size {target_size}")
        if size_at_target == target_size:
            return x
        # Shape of the block to append: identical to x except along `dim`.
        # Keep it in its own variable; `target_size` must not be overwritten
        # because it is still needed to compute the padding amount.
        pad_shape = list(x.shape)
        pad_shape[dim] = target_size - size_at_target
        pad_value = backend.full(tuple(pad_shape), value, dtype=x.dtype, device=array_api_compat.device(x))
        return backend.concat([x, pad_value], axis=dim)
    return pad_dim
@@ -3,7 +3,7 @@ import jax
3
3
  import jax.numpy as jnp
4
4
  import numpy as np
5
5
  from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
6
- from xbarray.base import ComputeBackend, SupportsDLPack
6
+ from xbarray.backends.base import ComputeBackend, SupportsDLPack
7
7
 
8
8
  __all__ = [
9
9
  "default_integer_dtype",
@@ -96,3 +96,6 @@ def map_fn_over_arrays(
96
96
  func,
97
97
  data
98
98
  )
99
+ pad_dim = get_pad_dim_function(
100
+ backend=compat_module,
101
+ )
@@ -18,7 +18,7 @@ def random_number_generator(
18
18
  *,
19
19
  device : Optional[DEVICE_TYPE] = None
20
20
  ) -> RNG_TYPE:
21
- rng_seed = np.random.randint(0) if seed is None else seed
21
+ rng_seed = np.random.randint(65535) if seed is None else seed
22
22
  rng = jax.random.key(
23
23
  seed=rng_seed
24
24
  )
@@ -1,7 +1,7 @@
1
1
  from typing import Any, Union, Optional
2
2
  import numpy as np
3
3
  from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
4
- from xbarray.base import ComputeBackend, SupportsDLPack
4
+ from xbarray.backends.base import ComputeBackend, SupportsDLPack
5
5
 
6
6
  __all__ = [
7
7
  "default_integer_dtype",
@@ -68,7 +68,7 @@ def dtype_is_boolean(
68
68
  ) -> bool:
69
69
  return dtype == np.bool_ or dtype == bool
70
70
 
71
- from .._common.implementations import get_abbreviate_array_function, get_map_fn_over_arrays_function
71
+ from .._common.implementations import *
72
72
  from array_api_compat import numpy as compat_module
73
73
  abbreviate_array = get_abbreviate_array_function(
74
74
  backend=compat_module,
@@ -80,4 +80,8 @@ abbreviate_array = get_abbreviate_array_function(
80
80
 
81
81
  map_fn_over_arrays = get_map_fn_over_arrays_function(
82
82
  is_backendarray=is_backendarray,
83
+ )
84
+
85
+ pad_dim = get_pad_dim_function(
86
+ backend=compat_module,
83
87
  )
@@ -2,7 +2,7 @@ from typing import Any, Union, Optional
2
2
  import numpy as np
3
3
  import torch
4
4
  from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
5
- from xbarray.base import ComputeBackend, SupportsDLPack
5
+ from xbarray.backends.base import ComputeBackend, SupportsDLPack
6
6
 
7
7
  PYTORCH_DTYPE_CAST_MAP = {
8
8
  torch.uint16: torch.int16,
@@ -95,7 +95,7 @@ def dtype_is_boolean(
95
95
  ) -> bool:
96
96
  return dtype == torch.bool
97
97
 
98
- from .._common.implementations import get_abbreviate_array_function, get_map_fn_over_arrays_function
98
+ from .._common.implementations import *
99
99
  from array_api_compat import torch as compat_module
100
100
  abbreviate_array = get_abbreviate_array_function(
101
101
  compat_module,
@@ -106,4 +106,7 @@ abbreviate_array = get_abbreviate_array_function(
106
106
  )
107
107
  map_fn_over_arrays = get_map_fn_over_arrays_function(
108
108
  is_backendarray=is_backendarray,
109
+ )
110
+ pad_dim = get_pad_dim_function(
111
+ backend=compat_module,
109
112
  )
@@ -189,7 +189,7 @@ class ComputeBackend(ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType], Pro
189
189
  Abbreviates an array to a single element if possible.
190
190
  Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).
191
191
  """
192
- pass
192
+ raise NotImplementedError
193
193
 
194
194
  @abc.abstractmethod
195
195
  def map_fn_over_arrays(self, data : Any, func : Callable[[BArrayType], BArrayType]) -> Any:
@@ -197,3 +197,13 @@ class ComputeBackend(ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType], Pro
197
197
  Map a function over arrays in a data structure and produce a new data structure with the same shape.
198
198
  This is useful for applying a function to all arrays in a nested structure.
199
199
  """
200
+ raise NotImplementedError
201
+
202
# Abstract method of ComputeBackend (the class header lies outside this hunk).
    @abc.abstractmethod
    def pad_dim(self, x : BArrayType, dim : int, target_size : int, value : Union[float, int] = 0) -> BArrayType:
        """
        Pad a dimension of an array to a target size with a given value.
        If the dimension is already the target size, return the original array.
        If the dimension is larger than the target size, raise an error.

        The shared Array API implementation (get_pad_dim_function) appends the
        padding after the existing entries and raises ValueError in the
        "larger than target" case.

        Args:
            x: The array to pad.
            dim: Index of the dimension to pad.
            target_size: Desired size of dimension ``dim`` after padding.
            value: Fill value for the padded entries (default ``0``).

        Returns:
            An array whose dimension ``dim`` has size ``target_size``.
        """
        raise NotImplementedError
@@ -1,7 +1,7 @@
1
- from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
- from ._src.implementations import jax as jax_impl
1
+ from ._cls_base import ComputeBackendImplCls
2
+ from ._implementations import jax as jax_impl
3
3
 
4
- class JaxComputeBackend(metaclass=ComputeBackendImplCls):
4
+ class JaxComputeBackend(metaclass=ComputeBackendImplCls[jax_impl.ARRAY_TYPE, jax_impl.DEVICE_TYPE, jax_impl.DTYPE_TYPE, jax_impl.RNG_TYPE]):
5
5
  ARRAY_TYPE = jax_impl.ARRAY_TYPE
6
6
  DTYPE_TYPE = jax_impl.DTYPE_TYPE
7
7
  DEVICE_TYPE = jax_impl.DEVICE_TYPE
@@ -1,7 +1,7 @@
1
- from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
- from ._src.implementations import numpy as numpy_impl
1
+ from ._cls_base import ComputeBackendImplCls
2
+ from ._implementations import numpy as numpy_impl
3
3
 
4
- class NumpyComputeBackend(metaclass=ComputeBackendImplCls):
4
+ class NumpyComputeBackend(metaclass=ComputeBackendImplCls[numpy_impl.ARRAY_TYPE, numpy_impl.DEVICE_TYPE, numpy_impl.DTYPE_TYPE, numpy_impl.RNG_TYPE]):
5
5
  ARRAY_TYPE = numpy_impl.ARRAY_TYPE
6
6
  DTYPE_TYPE = numpy_impl.DTYPE_TYPE
7
7
  DEVICE_TYPE = numpy_impl.DEVICE_TYPE
@@ -0,0 +1,22 @@
1
+ from ._cls_base import ComputeBackendImplCls
2
+ from ._implementations import pytorch as pytorch_impl
3
+
4
class PytorchComputeBackend(
    metaclass=ComputeBackendImplCls[
        pytorch_impl.ARRAY_TYPE,
        pytorch_impl.DEVICE_TYPE,
        pytorch_impl.DTYPE_TYPE,
        pytorch_impl.RNG_TYPE,
    ]
):
    """PyTorch compute backend.

    This class is used as a namespace. Immediately after its definition, every
    public attribute of the PyTorch implementation module (plus two Array API
    dunder attributes) is copied onto it.
    """
    ARRAY_TYPE = pytorch_impl.ARRAY_TYPE
    DTYPE_TYPE = pytorch_impl.DTYPE_TYPE
    DEVICE_TYPE = pytorch_impl.DEVICE_TYPE
    RNG_TYPE = pytorch_impl.RNG_TYPE

# Mirror the implementation module onto the backend class. Private names are
# skipped, except the two Array API dunders that must stay reachable.
for name in dir(pytorch_impl):
    if name.startswith('_') and name not in ('__array_namespace_info__', '__array_api_version__'):
        continue
    try:
        setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
    except AttributeError:
        # Attributes that cannot be read or set are skipped on purpose.
        pass

__all__ = ['PytorchComputeBackend']
@@ -0,0 +1,4 @@
1
+ """ This file is generated for backward compatibility.
2
+ If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.jax` instead.
3
+ """
4
+ from .backends.jax import *
@@ -0,0 +1,4 @@
1
+ """ This file is generated for backward compatibility.
2
+ If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.numpy` instead.
3
+ """
4
+ from .backends.numpy import *
@@ -0,0 +1,4 @@
1
+ """ This file is generated for backward compatibility.
2
+ If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.pytorch` instead.
3
+ """
4
+ from .backends.pytorch import *
@@ -0,0 +1 @@
1
+ from .base import *
@@ -0,0 +1,185 @@
1
+ from typing import Optional
2
+ from xbarray.backends.base import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
3
+
4
+ __all__ = [
5
+ "pixel_coordinate_and_depth_to_world",
6
+ "depth_image_to_world",
7
+ "world_to_pixel_coordinate_and_depth",
8
+ "world_to_depth",
9
+ ]
10
+
11
+ def pixel_coordinate_and_depth_to_world(
12
+ backend : ComputeBackend,
13
+ pixel_coordinates : BArrayType,
14
+ depth : BArrayType,
15
+ intrinsic_matrix : BArrayType,
16
+ extrinsic_matrix : BArrayType
17
+ ) -> BArrayType:
18
+ """
19
+ Convert pixel coordinates and depth to world coordinates.
20
+ Args:
21
+ backend (ComputeBackend): The compute backend to use.
22
+ pixel_coordinates (BArrayType): The pixel coordinates of shape (..., N, 2).
23
+ depth (BArrayType): The depth values of shape (..., N). Assume invalid depth is either nan or <= 0.
24
+ intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
25
+ extrinsic_matrix (BArrayType): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
26
+ Returns:
27
+ BArrayType: The world coordinates of shape (..., N, 4). The last dimension is (x, y, z, valid_mask).
28
+ """
29
+ xs = pixel_coordinates[..., 0] # (..., N)
30
+ ys = pixel_coordinates[..., 1] # (..., N)
31
+ xs_norm = (xs - intrinsic_matrix[..., None, 0, 2]) / intrinsic_matrix[..., None, 0, 0] # (..., N)
32
+ ys_norm = (ys - intrinsic_matrix[..., None, 1, 2]) / intrinsic_matrix[..., None, 1, 1] # (..., N)
33
+
34
+ camera_coords = backend.stack([
35
+ xs_norm,
36
+ ys_norm,
37
+ backend.ones_like(depth)
38
+ ], dim=-1) # (..., N, 3)
39
+ camera_coords *= depth[..., None] # (..., N, 3)
40
+
41
+ R = extrinsic_matrix[..., :3, :3] # (..., 3, 3)
42
+ t = extrinsic_matrix[..., :3, 3] # (..., 3)
43
+
44
+ shifted_camera_coords = camera_coords - t[..., None, :] # (..., N, 3)
45
+ world_coords = backend.matmul(shifted_camera_coords, R) # (..., N, 3)
46
+
47
+ valid_depth_mask = backend.logical_not(backend.logical_or(
48
+ backend.isnan(depth),
49
+ depth <= 0
50
+ )) # (..., N)
51
+ return backend.concat([
52
+ world_coords,
53
+ valid_depth_mask[..., None]
54
+ ], dim=-1) # (..., N, 4)
55
+
56
+ def depth_image_to_world(
57
+ backend : ComputeBackend,
58
+ depth_image : BArrayType,
59
+ intrinsic_matrix : BArrayType,
60
+ extrinsic_matrix : BArrayType
61
+ ) -> BArrayType:
62
+ """
63
+ Convert a depth image to world coordinates.
64
+ Args:
65
+ backend (ComputeBackend): The compute backend to use.
66
+ depth_image (BArrayType): The depth image of shape (..., H, W).
67
+ intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
68
+ extrinsic_matrix (BArrayType): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
69
+ Returns:
70
+ BArrayType: The world coordinates of shape (..., H, W, 4). The last dimension is (x, y, z, valid_mask).
71
+ """
72
+ H, W = depth_image.shape[-2:]
73
+ ys, xs = backend.meshgrid(
74
+ backend.arange(H, device=backend.device(depth_image), dtype=depth_image.dtype),
75
+ backend.arange(W, device=backend.device(depth_image), dtype=depth_image.dtype),
76
+ indexing='ij'
77
+ ) # (H, W), (H, W)
78
+ pixel_coordinates = backend.stack([xs, ys], dim=-1) # (H, W, 2)
79
+ pixel_coordinates = backend.reshape(pixel_coordinates, [1] * (len(depth_image.shape) - 2) + [H * W, 2]) # (..., H * W, 2)
80
+ world_coords = pixel_coordinate_and_depth_to_world(
81
+ backend,
82
+ pixel_coordinates,
83
+ depth_image.reshape(depth_image.shape[:-2] + [H * W]), # (..., H * W)
84
+ intrinsic_matrix,
85
+ extrinsic_matrix
86
+ ) # (..., H * W, 4)
87
+ world_coords = backend.reshape(world_coords, depth_image.shape[:-2] + [H, W, 4]) # (..., H, W, 4)
88
+ return world_coords
89
+
90
+ def world_to_pixel_coordinate_and_depth(
91
+ backend : ComputeBackend,
92
+ world_coords : BArrayType,
93
+ intrinsic_matrix : BArrayType,
94
+ extrinsic_matrix : Optional[BArrayType] = None
95
+ ) -> BArrayType:
96
+ """
97
+ Convert world coordinates to pixel coordinates and depth.
98
+ Args:
99
+ backend (ComputeBackend): The compute backend to use.
100
+ world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask.
101
+ intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
102
+ extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
103
+ Returns:
104
+ BArrayType: The pixel coordinates xy of shape (..., N, 2).
105
+ BArrayType: The depth values of shape (..., N). Invalid points (where valid mask is False) will have depth 0.
106
+ """
107
+ if world_coords.shape[-1] == 3:
108
+ world_coords_h = backend.pad_dim(
109
+ world_coords,
110
+ dim=-1,
111
+ value=0
112
+ )
113
+ else:
114
+ assert world_coords.shape[-1] == 4
115
+ world_coords_h = world_coords
116
+
117
+ if extrinsic_matrix is not None:
118
+ camera_coords = backend.matmul(
119
+ extrinsic_matrix, # (..., 3, 4) or (..., 4, 4)
120
+ backend.matrix_transpose(world_coords_h) # (..., 4, N)
121
+ ) # (..., 3, N) or (..., 4, N)
122
+ camera_coords = backend.matrix_transpose(camera_coords) # (..., N, 3) or (..., N, 4)
123
+ if camera_coords.shape[-1] == 4:
124
+ camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]
125
+ else:
126
+ camera_coords = world_coords_h[..., :3] # (..., N, 3)
127
+
128
+ point_px_homogeneous = backend.matmul(
129
+ intrinsic_matrix, # (..., 3, 3)
130
+ backend.matrix_transpose(camera_coords) # (..., 3, N)
131
+ ) # (..., 3, N)
132
+ point_px_homogeneous = backend.matrix_transpose(point_px_homogeneous) # (..., N, 3)
133
+ point_px = point_px_homogeneous[..., :2] / point_px_homogeneous[..., 2:3] # (..., N, 2)
134
+
135
+ depth = camera_coords[..., 2] # (..., N)
136
+ depth_valid = depth > 0
137
+ depth = backend.where(depth_valid, depth, 0)
138
+ point_px = backend.where(
139
+ depth_valid[..., None],
140
+ point_px,
141
+ 0
142
+ )
143
+ return point_px, depth
144
+
145
+
146
+ def world_to_depth(
147
+ backend : ComputeBackend,
148
+ world_coords : BArrayType,
149
+ extrinsic_matrix : Optional[BArrayType] = None
150
+ ) -> BArrayType:
151
+ """
152
+ Convert world coordinates to pixel coordinates and depth.
153
+ Args:
154
+ backend (ComputeBackend): The compute backend to use.
155
+ world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask.
156
+ extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
157
+ Returns:
158
+ BArrayType: The depth values of shape (..., N). Invalid points (where valid mask is False) will have depth 0.
159
+ """
160
+ if world_coords.shape[-1] == 3:
161
+ world_coords_h = backend.pad_dim(
162
+ world_coords,
163
+ dim=-1,
164
+ value=0
165
+ )
166
+ else:
167
+ assert world_coords.shape[-1] == 4
168
+ world_coords_h = world_coords
169
+
170
+ if extrinsic_matrix is not None:
171
+ camera_coords = backend.matmul(
172
+ extrinsic_matrix, # (..., 3, 4) or (..., 4, 4)
173
+ backend.matrix_transpose(world_coords_h) # (..., 4, N)
174
+ ) # (..., 3, N) or (..., 4, N)
175
+ camera_coords = backend.matrix_transpose(camera_coords) # (..., N, 3) or (..., N, 4)
176
+ if camera_coords.shape[-1] == 4:
177
+ camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]
178
+ else:
179
+ camera_coords = world_coords_h[..., :3] # (..., N, 3)
180
+
181
+ depth = camera_coords[..., 2] # (..., N)
182
+ depth_valid = depth > 0
183
+ depth = backend.where(depth_valid, depth, 0)
184
+ return depth
185
+
@@ -0,0 +1,15 @@
1
+ from . import base as base_impl
2
+ from functools import partial
3
+ from xbarray.backends.jax import JaxComputeBackend as BindingBackend
4
+
5
__all__ = [
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
]


def _bind(fn):
    """Return ``fn`` with ``BindingBackend`` pre-filled as its first (``backend``) argument."""
    return partial(fn, BindingBackend)


# JAX-bound versions of the backend-agnostic point-cloud functions in `base`.
pixel_coordinate_and_depth_to_world = _bind(base_impl.pixel_coordinate_and_depth_to_world)
depth_image_to_world = _bind(base_impl.depth_image_to_world)
world_to_pixel_coordinate_and_depth = _bind(base_impl.world_to_pixel_coordinate_and_depth)
world_to_depth = _bind(base_impl.world_to_depth)
@@ -0,0 +1,15 @@
1
+ from . import base as base_impl
2
+ from functools import partial
3
+ from xbarray.backends.numpy import NumpyComputeBackend as BindingBackend
4
+
5
__all__ = [
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
]


def _bind(fn):
    """Return ``fn`` with ``BindingBackend`` pre-filled as its first (``backend``) argument."""
    return partial(fn, BindingBackend)


# NumPy-bound versions of the backend-agnostic point-cloud functions in `base`.
pixel_coordinate_and_depth_to_world = _bind(base_impl.pixel_coordinate_and_depth_to_world)
depth_image_to_world = _bind(base_impl.depth_image_to_world)
world_to_pixel_coordinate_and_depth = _bind(base_impl.world_to_pixel_coordinate_and_depth)
world_to_depth = _bind(base_impl.world_to_depth)
@@ -0,0 +1,15 @@
1
+ from . import base as base_impl
2
+ from functools import partial
3
+ from xbarray.backends.pytorch import PytorchComputeBackend as BindingBackend
4
+
5
__all__ = [
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
]


def _bind(fn):
    """Return ``fn`` with ``BindingBackend`` pre-filled as its first (``backend``) argument."""
    return partial(fn, BindingBackend)


# PyTorch-bound versions of the backend-agnostic point-cloud functions in `base`.
pixel_coordinate_and_depth_to_world = _bind(base_impl.pixel_coordinate_and_depth_to_world)
depth_image_to_world = _bind(base_impl.depth_image_to_world)
world_to_pixel_coordinate_and_depth = _bind(base_impl.world_to_pixel_coordinate_and_depth)
world_to_depth = _bind(base_impl.world_to_depth)