xbarray 0.0.1a8__py3-none-any.whl → 0.0.1a10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xbarray might be problematic. Click here for more details.
- xbarray/__init__.py +1 -1
- xbarray/backends/_cls_base.py +9 -0
- xbarray/{_src/implementations → backends/_implementations}/_common/implementations.py +32 -7
- xbarray/{_src/implementations → backends/_implementations}/jax/_extra.py +4 -1
- xbarray/{_src/implementations → backends/_implementations}/numpy/_extra.py +6 -2
- xbarray/{_src/implementations → backends/_implementations}/pytorch/_extra.py +5 -2
- xbarray/{base → backends}/base.py +11 -1
- xbarray/backends/jax.py +19 -0
- xbarray/backends/numpy.py +19 -0
- xbarray/backends/pytorch.py +22 -0
- xbarray/jax.py +4 -19
- xbarray/numpy.py +4 -19
- xbarray/pytorch.py +4 -19
- xbarray/transformations/pointcloud/__init__.py +1 -0
- xbarray/transformations/pointcloud/base.py +185 -0
- xbarray/transformations/pointcloud/jax.py +15 -0
- xbarray/transformations/pointcloud/numpy.py +15 -0
- xbarray/transformations/pointcloud/pytorch.py +15 -0
- xbarray/transformations/rotation_conversions/__init__.py +1 -0
- xbarray/transformations/rotation_conversions/base.py +712 -0
- xbarray/transformations/rotation_conversions/jax.py +41 -0
- xbarray/transformations/rotation_conversions/numpy.py +41 -0
- xbarray/transformations/rotation_conversions/pytorch.py +41 -0
- {xbarray-0.0.1a8.dist-info → xbarray-0.0.1a10.dist-info}/METADATA +6 -1
- xbarray-0.0.1a10.dist-info/RECORD +51 -0
- xbarray/_src/serialization/__init__.py +0 -1
- xbarray/_src/serialization/serialization_map.py +0 -31
- xbarray/base/__init__.py +0 -1
- xbarray/cls_impl/cls_base.py +0 -10
- xbarray-0.0.1a8.dist-info/RECORD +0 -41
- /xbarray/{_src/implementations → backends/_implementations}/jax/__init__.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/jax/_typing.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/jax/random.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/numpy/__init__.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/numpy/_typing.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/numpy/random.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/pytorch/__init__.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/pytorch/_typing.py +0 -0
- /xbarray/{_src/implementations → backends/_implementations}/pytorch/random.py +0 -0
- {xbarray-0.0.1a8.dist-info → xbarray-0.0.1a10.dist-info}/WHEEL +0 -0
- {xbarray-0.0.1a8.dist-info → xbarray-0.0.1a10.dist-info}/licenses/LICENSE +0 -0
- {xbarray-0.0.1a8.dist-info → xbarray-0.0.1a10.dist-info}/top_level.txt +0 -0
xbarray/__init__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
from .base import * # Import Abstract Typings
|
|
1
|
+
from .backends.base import * # Import Abstract Typings
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from typing import Type, Generic
|
|
2
|
+
from .base import BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
3
|
+
|
|
4
|
+
class ComputeBackendImplCls(Generic[BArrayType, BDeviceType, BDtypeType, BRNGType], Type):
    """
    Metaclass-style base whose instances (backend classes) print as their
    ``simplified_name`` attribute for both ``str()`` and ``repr()``.
    """

    def __repr__(self):
        # NOTE(review): ``simplified_name`` is not defined in this block; it is
        # presumably attached to the backend class elsewhere - confirm.
        return self.simplified_name

    # str() and repr() deliberately render the same text.
    __str__ = __repr__
|
|
@@ -1,10 +1,12 @@
|
|
|
1
|
-
from typing import Any, Union, Callable
|
|
1
|
+
from typing import Any, Union, Callable, Mapping, Sequence
|
|
2
2
|
from array_api_typing.typing_compat import ArrayAPINamespace as CompatNamespace, ArrayAPIArray as CompatArray, ArrayAPIDType as CompatDType
|
|
3
3
|
import array_api_compat
|
|
4
|
+
import dataclasses
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
6
7
|
"get_abbreviate_array_function",
|
|
7
8
|
"get_map_fn_over_arrays_function",
|
|
9
|
+
"get_pad_dim_function",
|
|
8
10
|
]
|
|
9
11
|
|
|
10
12
|
def get_abbreviate_array_function(
|
|
@@ -51,12 +53,35 @@ def get_map_fn_over_arrays_function(
|
|
|
51
53
|
"""
|
|
52
54
|
if is_backendarray(data):
|
|
53
55
|
return func(data)
|
|
54
|
-
elif isinstance(data,
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
56
|
+
elif isinstance(data, Mapping):
|
|
57
|
+
ret = {k: map_fn_over_arrays(v, func) for k, v in data.items()}
|
|
58
|
+
try:
|
|
59
|
+
return type(data)(**ret) # try to keep the same mapping type
|
|
60
|
+
except:
|
|
61
|
+
return ret
|
|
62
|
+
elif isinstance(data, Sequence):
|
|
63
|
+
ret = [map_fn_over_arrays(i, func) for i in data]
|
|
64
|
+
try:
|
|
65
|
+
return type(data)(ret) # try to keep the same sequence type
|
|
66
|
+
except:
|
|
67
|
+
return ret
|
|
68
|
+
elif dataclasses.is_dataclass(data):
|
|
69
|
+
return type(data)(**map_fn_over_arrays(dataclasses.asdict(data), func))
|
|
60
70
|
else:
|
|
61
71
|
return data
|
|
62
72
|
return map_fn_over_arrays
|
|
73
|
+
|
|
74
|
+
def get_pad_dim_function(
    backend : CompatNamespace[CompatArray, Any, Any],
):
    """
    Build a ``pad_dim`` function bound to the given array-API namespace.

    Args:
        backend: Array-API compatible namespace; its ``full`` and ``concat``
            functions are used to create and append the padding.
    Returns:
        A function ``pad_dim(x, dim, target_size, value=0)``.
    """
    def pad_dim(x : CompatArray, dim : int, target_size : int, value : Union[float, int] = 0) -> CompatArray:
        """
        Pad dimension ``dim`` of ``x`` at its end up to ``target_size`` using ``value``.

        Args:
            x: The array to pad.
            dim: Index of the dimension to pad (used directly to index ``x.shape``).
            target_size: Desired size of that dimension after padding.
            value: Fill value for the padded region.
        Returns:
            ``x`` itself when the dimension already has the target size, otherwise
            a new array with the padding concatenated after ``x`` along ``dim``.
        Raises:
            ValueError: If the dimension is already larger than ``target_size``.
        """
        size_at_target = x.shape[dim]
        if size_at_target > target_size:
            raise ValueError(f"Cannot pad dimension {dim} of size {size_at_target} to smaller target size {target_size}")
        if size_at_target == target_size:
            return x
        # Shape of the filler block: identical to ``x`` except along ``dim``,
        # where it covers the missing extent. Kept in its own variable so the
        # integer ``target_size`` is not overwritten before the subtraction.
        pad_shape = list(x.shape)
        pad_shape[dim] = target_size - size_at_target
        pad_value = backend.full(tuple(pad_shape), value, dtype=x.dtype, device=array_api_compat.device(x))
        return backend.concat([x, pad_value], axis=dim)
    return pad_dim
|
|
@@ -3,7 +3,7 @@ import jax
|
|
|
3
3
|
import jax.numpy as jnp
|
|
4
4
|
import numpy as np
|
|
5
5
|
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
6
|
-
from xbarray.base import ComputeBackend, SupportsDLPack
|
|
6
|
+
from xbarray.backends.base import ComputeBackend, SupportsDLPack
|
|
7
7
|
|
|
8
8
|
__all__ = [
|
|
9
9
|
"default_integer_dtype",
|
|
@@ -96,3 +96,6 @@ def map_fn_over_arrays(
|
|
|
96
96
|
func,
|
|
97
97
|
data
|
|
98
98
|
)
|
|
99
|
+
pad_dim = get_pad_dim_function(
|
|
100
|
+
backend=compat_module,
|
|
101
|
+
)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Any, Union, Optional
|
|
2
2
|
import numpy as np
|
|
3
3
|
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
4
|
-
from xbarray.base import ComputeBackend, SupportsDLPack
|
|
4
|
+
from xbarray.backends.base import ComputeBackend, SupportsDLPack
|
|
5
5
|
|
|
6
6
|
__all__ = [
|
|
7
7
|
"default_integer_dtype",
|
|
@@ -68,7 +68,7 @@ def dtype_is_boolean(
|
|
|
68
68
|
) -> bool:
|
|
69
69
|
return dtype == np.bool_ or dtype == bool
|
|
70
70
|
|
|
71
|
-
from .._common.implementations import
|
|
71
|
+
from .._common.implementations import *
|
|
72
72
|
from array_api_compat import numpy as compat_module
|
|
73
73
|
abbreviate_array = get_abbreviate_array_function(
|
|
74
74
|
backend=compat_module,
|
|
@@ -80,4 +80,8 @@ abbreviate_array = get_abbreviate_array_function(
|
|
|
80
80
|
|
|
81
81
|
map_fn_over_arrays = get_map_fn_over_arrays_function(
|
|
82
82
|
is_backendarray=is_backendarray,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
pad_dim = get_pad_dim_function(
|
|
86
|
+
backend=compat_module,
|
|
83
87
|
)
|
|
@@ -2,7 +2,7 @@ from typing import Any, Union, Optional
|
|
|
2
2
|
import numpy as np
|
|
3
3
|
import torch
|
|
4
4
|
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
5
|
-
from xbarray.base import ComputeBackend, SupportsDLPack
|
|
5
|
+
from xbarray.backends.base import ComputeBackend, SupportsDLPack
|
|
6
6
|
|
|
7
7
|
PYTORCH_DTYPE_CAST_MAP = {
|
|
8
8
|
torch.uint16: torch.int16,
|
|
@@ -95,7 +95,7 @@ def dtype_is_boolean(
|
|
|
95
95
|
) -> bool:
|
|
96
96
|
return dtype == torch.bool
|
|
97
97
|
|
|
98
|
-
from .._common.implementations import
|
|
98
|
+
from .._common.implementations import *
|
|
99
99
|
from array_api_compat import torch as compat_module
|
|
100
100
|
abbreviate_array = get_abbreviate_array_function(
|
|
101
101
|
compat_module,
|
|
@@ -106,4 +106,7 @@ abbreviate_array = get_abbreviate_array_function(
|
|
|
106
106
|
)
|
|
107
107
|
map_fn_over_arrays = get_map_fn_over_arrays_function(
|
|
108
108
|
is_backendarray=is_backendarray,
|
|
109
|
+
)
|
|
110
|
+
pad_dim = get_pad_dim_function(
|
|
111
|
+
backend=compat_module,
|
|
109
112
|
)
|
|
@@ -189,7 +189,7 @@ class ComputeBackend(ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType], Pro
|
|
|
189
189
|
Abbreviates an array to a single element if possible.
|
|
190
190
|
Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).
|
|
191
191
|
"""
|
|
192
|
-
|
|
192
|
+
raise NotImplementedError
|
|
193
193
|
|
|
194
194
|
@abc.abstractmethod
|
|
195
195
|
def map_fn_over_arrays(self, data : Any, func : Callable[[BArrayType], BArrayType]) -> Any:
|
|
@@ -197,3 +197,13 @@ class ComputeBackend(ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType], Pro
|
|
|
197
197
|
Map a function over arrays in a data structure and produce a new data structure with the same shape.
|
|
198
198
|
This is useful for applying a function to all arrays in a nested structure.
|
|
199
199
|
"""
|
|
200
|
+
raise NotImplementedError
|
|
201
|
+
|
|
202
|
+
@abc.abstractmethod
def pad_dim(self, x : BArrayType, dim : int, target_size : int, value : Union[float, int] = 0) -> BArrayType:
    """
    Grow one dimension of an array to a requested size by filling with a constant.

    Args:
        x: The array to pad.
        dim: The dimension to pad.
        target_size: The size the dimension should have afterwards.
        value: The fill value used for the padding.
    Returns:
        The padded array; the original array when the dimension already has the target size.
    Raises:
        An error when the dimension is larger than the target size.
    """
    raise NotImplementedError
|
xbarray/backends/jax.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from ._cls_base import ComputeBackendImplCls
|
|
2
|
+
from ._implementations import jax as jax_impl
|
|
3
|
+
|
|
4
|
+
class JaxComputeBackend(metaclass=ComputeBackendImplCls[jax_impl.ARRAY_TYPE, jax_impl.DEVICE_TYPE, jax_impl.DTYPE_TYPE, jax_impl.RNG_TYPE]):
    """Backend class that exposes the JAX implementation module's API as class attributes."""
    # Type aliases re-exported from the implementation module.
    ARRAY_TYPE = jax_impl.ARRAY_TYPE
    DEVICE_TYPE = jax_impl.DEVICE_TYPE
    DTYPE_TYPE = jax_impl.DTYPE_TYPE
    RNG_TYPE = jax_impl.RNG_TYPE

# Copy every public name of the implementation module onto the class, plus the
# two array-API dunder attributes that would otherwise be filtered out.
for name in dir(jax_impl):
    if name.startswith('_') and name not in (
        '__array_namespace_info__',
        '__array_api_version__',
    ):
        continue
    setattr(JaxComputeBackend, name, getattr(jax_impl, name))

__all__ = ['JaxComputeBackend']
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from ._cls_base import ComputeBackendImplCls
|
|
2
|
+
from ._implementations import numpy as numpy_impl
|
|
3
|
+
|
|
4
|
+
class NumpyComputeBackend(metaclass=ComputeBackendImplCls[numpy_impl.ARRAY_TYPE, numpy_impl.DEVICE_TYPE, numpy_impl.DTYPE_TYPE, numpy_impl.RNG_TYPE]):
    """Backend class that exposes the NumPy implementation module's API as class attributes."""
    # Type aliases re-exported from the implementation module.
    ARRAY_TYPE = numpy_impl.ARRAY_TYPE
    DEVICE_TYPE = numpy_impl.DEVICE_TYPE
    DTYPE_TYPE = numpy_impl.DTYPE_TYPE
    RNG_TYPE = numpy_impl.RNG_TYPE

# Copy every public name of the implementation module onto the class, plus the
# two array-API dunder attributes that would otherwise be filtered out.
for name in dir(numpy_impl):
    if name.startswith('_') and name not in (
        '__array_namespace_info__',
        '__array_api_version__',
    ):
        continue
    setattr(NumpyComputeBackend, name, getattr(numpy_impl, name))

__all__ = ['NumpyComputeBackend']
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from ._cls_base import ComputeBackendImplCls
|
|
2
|
+
from ._implementations import pytorch as pytorch_impl
|
|
3
|
+
|
|
4
|
+
class PytorchComputeBackend(metaclass=ComputeBackendImplCls[pytorch_impl.ARRAY_TYPE, pytorch_impl.DEVICE_TYPE, pytorch_impl.DTYPE_TYPE, pytorch_impl.RNG_TYPE]):
    """Backend class that exposes the PyTorch implementation module's API as class attributes."""
    # Type aliases re-exported from the implementation module.
    ARRAY_TYPE = pytorch_impl.ARRAY_TYPE
    DEVICE_TYPE = pytorch_impl.DEVICE_TYPE
    DTYPE_TYPE = pytorch_impl.DTYPE_TYPE
    RNG_TYPE = pytorch_impl.RNG_TYPE

# Copy every public name of the implementation module onto the class, plus the
# two array-API dunder attributes that would otherwise be filtered out.
for name in dir(pytorch_impl):
    if name.startswith('_') and name not in (
        '__array_namespace_info__',
        '__array_api_version__',
    ):
        continue
    try:
        setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
    except AttributeError:
        # Best-effort copy: names that raise AttributeError are skipped.
        # NOTE(review): which names trigger this is not visible here - confirm.
        pass

__all__ = ['PytorchComputeBackend']
|
xbarray/jax.py
CHANGED
|
@@ -1,19 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
ARRAY_TYPE = jax_impl.ARRAY_TYPE
|
|
6
|
-
DTYPE_TYPE = jax_impl.DTYPE_TYPE
|
|
7
|
-
DEVICE_TYPE = jax_impl.DEVICE_TYPE
|
|
8
|
-
RNG_TYPE = jax_impl.RNG_TYPE
|
|
9
|
-
|
|
10
|
-
for name in dir(jax_impl):
|
|
11
|
-
if not name.startswith('_') or name in [
|
|
12
|
-
'__array_namespace_info__',
|
|
13
|
-
'__array_api_version__',
|
|
14
|
-
]:
|
|
15
|
-
setattr(JaxComputeBackend, name, getattr(jax_impl, name))
|
|
16
|
-
|
|
17
|
-
__all__ = [
|
|
18
|
-
'JaxComputeBackend',
|
|
19
|
-
]
|
|
1
|
+
""" This file is generated for backward compatibility.
|
|
2
|
+
If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.jax` instead.
|
|
3
|
+
"""
|
|
4
|
+
from .backends.jax import *
|
xbarray/numpy.py
CHANGED
|
@@ -1,19 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
ARRAY_TYPE = numpy_impl.ARRAY_TYPE
|
|
6
|
-
DTYPE_TYPE = numpy_impl.DTYPE_TYPE
|
|
7
|
-
DEVICE_TYPE = numpy_impl.DEVICE_TYPE
|
|
8
|
-
RNG_TYPE = numpy_impl.RNG_TYPE
|
|
9
|
-
|
|
10
|
-
for name in dir(numpy_impl):
|
|
11
|
-
if not name.startswith('_') or name in [
|
|
12
|
-
'__array_namespace_info__',
|
|
13
|
-
'__array_api_version__',
|
|
14
|
-
]:
|
|
15
|
-
setattr(NumpyComputeBackend, name, getattr(numpy_impl, name))
|
|
16
|
-
|
|
17
|
-
__all__ = [
|
|
18
|
-
'NumpyComputeBackend',
|
|
19
|
-
]
|
|
1
|
+
""" This file is generated for backward compatibility.
|
|
2
|
+
If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.numpy` instead.
|
|
3
|
+
"""
|
|
4
|
+
from .backends.numpy import *
|
xbarray/pytorch.py
CHANGED
|
@@ -1,19 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
ARRAY_TYPE = pytorch_impl.ARRAY_TYPE
|
|
6
|
-
DTYPE_TYPE = pytorch_impl.DTYPE_TYPE
|
|
7
|
-
DEVICE_TYPE = pytorch_impl.DEVICE_TYPE
|
|
8
|
-
RNG_TYPE = pytorch_impl.RNG_TYPE
|
|
9
|
-
|
|
10
|
-
for name in dir(pytorch_impl):
|
|
11
|
-
if not name.startswith('_') or name in [
|
|
12
|
-
'__array_namespace_info__',
|
|
13
|
-
'__array_api_version__',
|
|
14
|
-
]:
|
|
15
|
-
setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
|
|
16
|
-
|
|
17
|
-
__all__ = [
|
|
18
|
-
'PytorchComputeBackend',
|
|
19
|
-
]
|
|
1
|
+
""" This file is generated for backward compatibility.
|
|
2
|
+
If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.pytorch` instead.
|
|
3
|
+
"""
|
|
4
|
+
from .backends.pytorch import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import *
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from xbarray.backends.base import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"pixel_coordinate_and_depth_to_world",
|
|
6
|
+
"depth_image_to_world",
|
|
7
|
+
"world_to_pixel_coordinate_and_depth",
|
|
8
|
+
"world_to_depth",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
def pixel_coordinate_and_depth_to_world(
    backend : ComputeBackend,
    pixel_coordinates : BArrayType,
    depth : BArrayType,
    intrinsic_matrix : BArrayType,
    extrinsic_matrix : BArrayType
) -> BArrayType:
    """
    Convert pixel coordinates and depth to world coordinates.
    Args:
        backend (ComputeBackend): The compute backend to use.
        pixel_coordinates (BArrayType): The pixel coordinates of shape (..., N, 2), ordered (x, y).
        depth (BArrayType): The depth values of shape (..., N). Assume invalid depth is either nan or <= 0.
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (BArrayType): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
    Returns:
        BArrayType: The world coordinates of shape (..., N, 4). The last dimension is (x, y, z, valid_mask).
    """
    xs = pixel_coordinates[..., 0] # (..., N)
    ys = pixel_coordinates[..., 1] # (..., N)
    # Undo the intrinsics: subtract the principal point, divide by the focal length.
    xs_norm = (xs - intrinsic_matrix[..., None, 0, 2]) / intrinsic_matrix[..., None, 0, 0] # (..., N)
    ys_norm = (ys - intrinsic_matrix[..., None, 1, 2]) / intrinsic_matrix[..., None, 1, 1] # (..., N)

    # The array-API keyword is ``axis`` (``dim`` is not accepted by the standard
    # ``stack``/``concat``), matching the ``concat(..., axis=dim)`` used by ``pad_dim``.
    unit_depth_rays = backend.stack([
        xs_norm,
        ys_norm,
        backend.ones_like(depth)
    ], axis=-1) # (..., N, 3)
    # Out-of-place multiply so the result does not depend on in-place support.
    camera_coords = unit_depth_rays * depth[..., None] # (..., N, 3)

    R = extrinsic_matrix[..., :3, :3] # (..., 3, 3)
    t = extrinsic_matrix[..., :3, 3] # (..., 3)

    # Row-vector form of R^T @ (p_cam - t): remove the translation, then right-multiply by R.
    shifted_camera_coords = camera_coords - t[..., None, :] # (..., N, 3)
    world_coords = backend.matmul(shifted_camera_coords, R) # (..., N, 3)

    # A depth is valid when it is neither nan nor <= 0.
    valid_depth_mask = backend.logical_not(backend.logical_or(
        backend.isnan(depth),
        depth <= 0
    )) # (..., N)
    return backend.concat([
        world_coords,
        valid_depth_mask[..., None]
    ], axis=-1) # (..., N, 4)
|
|
55
|
+
|
|
56
|
+
def depth_image_to_world(
    backend : ComputeBackend,
    depth_image : BArrayType,
    intrinsic_matrix : BArrayType,
    extrinsic_matrix : BArrayType
) -> BArrayType:
    """
    Convert a depth image to world coordinates.
    Args:
        backend (ComputeBackend): The compute backend to use.
        depth_image (BArrayType): The depth image of shape (..., H, W).
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (BArrayType): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4).
    Returns:
        BArrayType: The world coordinates of shape (..., H, W, 4). The last dimension is (x, y, z, valid_mask).
    """
    H, W = depth_image.shape[-2:]
    # Normalise to a tuple so it can be concatenated with other tuples below
    # (the original mixed a tuple slice with a list, which raises TypeError).
    batch_shape = tuple(depth_image.shape[:-2])
    device = backend.device(depth_image)

    ys, xs = backend.meshgrid(
        backend.arange(H, device=device, dtype=depth_image.dtype),
        backend.arange(W, device=device, dtype=depth_image.dtype),
        indexing='ij'
    ) # (H, W), (H, W)
    # Array-API ``stack`` takes ``axis``, not ``dim``.
    pixel_coordinates = backend.stack([xs, ys], axis=-1) # (H, W, 2)
    # Singleton batch dimensions let the pixel grid broadcast against the depth batch.
    pixel_coordinates = backend.reshape(
        pixel_coordinates,
        (1,) * len(batch_shape) + (H * W, 2)
    ) # (1, ..., 1, H * W, 2)

    flat_depth = backend.reshape(depth_image, batch_shape + (H * W,)) # (..., H * W)
    world_coords = pixel_coordinate_and_depth_to_world(
        backend,
        pixel_coordinates,
        flat_depth,
        intrinsic_matrix,
        extrinsic_matrix
    ) # (..., H * W, 4)
    return backend.reshape(world_coords, batch_shape + (H, W, 4)) # (..., H, W, 4)
|
|
89
|
+
|
|
90
|
+
def world_to_pixel_coordinate_and_depth(
    backend : ComputeBackend,
    world_coords : BArrayType,
    intrinsic_matrix : BArrayType,
    extrinsic_matrix : Optional[BArrayType] = None
) -> BArrayType:
    """
    Convert world coordinates to pixel coordinates and depth.
    Args:
        backend (ComputeBackend): The compute backend to use.
        world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask.
        intrinsic_matrix (BArrayType): The camera intrinsic matrix of shape (..., 3, 3).
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
    Returns:
        A tuple ``(point_px, depth)`` (note: the return annotation names a single array):
        BArrayType: The pixel coordinates xy of shape (..., N, 2). Set to 0 where the camera-frame depth is not > 0.
        BArrayType: The depth values of shape (..., N). Set to 0 where the camera-frame depth is not > 0.
    """
    if world_coords.shape[-1] == 3:
        # Append the homogeneous component. It must be 1 (not 0) so that the
        # extrinsic translation is applied and the 4x4 perspective divide is by 1;
        # ``pad_dim`` also requires the target size of the padded dimension.
        world_coords_h = backend.pad_dim(
            world_coords,
            dim=-1,
            target_size=4,
            value=1
        )
    else:
        assert world_coords.shape[-1] == 4
        # NOTE(review): the 4th component is used directly as the homogeneous
        # coordinate; entries with mask 0 are not filtered separately - confirm intent.
        world_coords_h = world_coords

    if extrinsic_matrix is not None:
        camera_coords = backend.matmul(
            extrinsic_matrix, # (..., 3, 4) or (..., 4, 4)
            backend.matrix_transpose(world_coords_h) # (..., 4, N)
        ) # (..., 3, N) or (..., 4, N)
        camera_coords = backend.matrix_transpose(camera_coords) # (..., N, 3) or (..., N, 4)
        if camera_coords.shape[-1] == 4:
            # A 4x4 extrinsic yields homogeneous output: divide by w.
            camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]
    else:
        camera_coords = world_coords_h[..., :3] # (..., N, 3)

    # Project with the intrinsics, then divide by the third component.
    point_px_homogeneous = backend.matmul(
        intrinsic_matrix, # (..., 3, 3)
        backend.matrix_transpose(camera_coords) # (..., 3, N)
    ) # (..., 3, N)
    point_px_homogeneous = backend.matrix_transpose(point_px_homogeneous) # (..., N, 3)
    point_px = point_px_homogeneous[..., :2] / point_px_homogeneous[..., 2:3] # (..., N, 2)

    depth = camera_coords[..., 2] # (..., N)
    depth_valid = depth > 0
    # Zero out both outputs for points that are not in front of the camera.
    depth = backend.where(depth_valid, depth, 0)
    point_px = backend.where(
        depth_valid[..., None],
        point_px,
        0
    )
    return point_px, depth
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def world_to_depth(
    backend : ComputeBackend,
    world_coords : BArrayType,
    extrinsic_matrix : Optional[BArrayType] = None
) -> BArrayType:
    """
    Convert world coordinates to camera-frame depth.
    Args:
        backend (ComputeBackend): The compute backend to use.
        world_coords (BArrayType): The world coordinates of shape (..., N, 3) or (..., N, 4). If the last dimension is 4, the last element is treated as a valid mask.
        extrinsic_matrix (Optional[BArrayType]): The camera extrinsic matrix of shape (..., 3, 4) or (..., 4, 4). If None, assume identity matrix.
    Returns:
        BArrayType: The depth values of shape (..., N). Set to 0 where the camera-frame depth is not > 0.
    """
    if world_coords.shape[-1] == 3:
        # Append the homogeneous component. It must be 1 (not 0) so that the
        # extrinsic translation is applied and the 4x4 perspective divide is by 1;
        # ``pad_dim`` also requires the target size of the padded dimension.
        world_coords_h = backend.pad_dim(
            world_coords,
            dim=-1,
            target_size=4,
            value=1
        )
    else:
        assert world_coords.shape[-1] == 4
        # NOTE(review): the 4th component is used directly as the homogeneous
        # coordinate; entries with mask 0 are not filtered separately - confirm intent.
        world_coords_h = world_coords

    if extrinsic_matrix is not None:
        camera_coords = backend.matmul(
            extrinsic_matrix, # (..., 3, 4) or (..., 4, 4)
            backend.matrix_transpose(world_coords_h) # (..., 4, N)
        ) # (..., 3, N) or (..., 4, N)
        camera_coords = backend.matrix_transpose(camera_coords) # (..., N, 3) or (..., N, 4)
        if camera_coords.shape[-1] == 4:
            # A 4x4 extrinsic yields homogeneous output: divide by w.
            camera_coords = camera_coords[..., :3] / camera_coords[..., 3:4]
    else:
        camera_coords = world_coords_h[..., :3] # (..., N, 3)

    depth = camera_coords[..., 2] # (..., N)
    # Points that are not in front of the camera report depth 0.
    return backend.where(depth > 0, depth, 0)
|
|
185
|
+
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from . import base as base_impl
|
|
2
|
+
from functools import partial
|
|
3
|
+
from xbarray.backends.jax import JaxComputeBackend as BindingBackend
|
|
4
|
+
|
|
5
|
+
__all__ = [
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
]

# Each public name is the generic implementation from ``base`` with the JAX
# backend pre-bound as its first positional argument.
pixel_coordinate_and_depth_to_world = partial(
    base_impl.pixel_coordinate_and_depth_to_world,
    BindingBackend,
)
depth_image_to_world = partial(
    base_impl.depth_image_to_world,
    BindingBackend,
)
world_to_pixel_coordinate_and_depth = partial(
    base_impl.world_to_pixel_coordinate_and_depth,
    BindingBackend,
)
world_to_depth = partial(
    base_impl.world_to_depth,
    BindingBackend,
)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from . import base as base_impl
|
|
2
|
+
from functools import partial
|
|
3
|
+
from xbarray.backends.numpy import NumpyComputeBackend as BindingBackend
|
|
4
|
+
|
|
5
|
+
__all__ = [
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
]

# Each public name is the generic implementation from ``base`` with the NumPy
# backend pre-bound as its first positional argument.
pixel_coordinate_and_depth_to_world = partial(
    base_impl.pixel_coordinate_and_depth_to_world,
    BindingBackend,
)
depth_image_to_world = partial(
    base_impl.depth_image_to_world,
    BindingBackend,
)
world_to_pixel_coordinate_and_depth = partial(
    base_impl.world_to_pixel_coordinate_and_depth,
    BindingBackend,
)
world_to_depth = partial(
    base_impl.world_to_depth,
    BindingBackend,
)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from . import base as base_impl
|
|
2
|
+
from functools import partial
|
|
3
|
+
from xbarray.backends.pytorch import PytorchComputeBackend as BindingBackend
|
|
4
|
+
|
|
5
|
+
__all__ = [
    "pixel_coordinate_and_depth_to_world",
    "depth_image_to_world",
    "world_to_pixel_coordinate_and_depth",
    "world_to_depth",
]

# Each public name is the generic implementation from ``base`` with the PyTorch
# backend pre-bound as its first positional argument.
pixel_coordinate_and_depth_to_world = partial(
    base_impl.pixel_coordinate_and_depth_to_world,
    BindingBackend,
)
depth_image_to_world = partial(
    base_impl.depth_image_to_world,
    BindingBackend,
)
world_to_pixel_coordinate_and_depth = partial(
    base_impl.world_to_pixel_coordinate_and_depth,
    BindingBackend,
)
world_to_depth = partial(
    base_impl.world_to_depth,
    BindingBackend,
)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import *
|