xbarray 0.0.1a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- array_api_typing/__init__.py +9 -0
- array_api_typing/typing_2024_12/__init__.py +12 -0
- array_api_typing/typing_2024_12/_api_constant.py +32 -0
- array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
- array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
- array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
- array_api_typing/typing_2024_12/_api_typing.py +5855 -0
- array_api_typing/typing_2024_12/_array_typing.py +1265 -0
- array_api_typing/typing_compat/__init__.py +12 -0
- array_api_typing/typing_compat/_api_typing.py +27 -0
- array_api_typing/typing_compat/_array_typing.py +36 -0
- array_api_typing/typing_extra/__init__.py +12 -0
- array_api_typing/typing_extra/_api_typing.py +651 -0
- array_api_typing/typing_extra/_at.py +87 -0
- xbarray/__init__.py +1 -0
- xbarray/backends/_cls_base.py +9 -0
- xbarray/backends/_implementations/_common/implementations.py +87 -0
- xbarray/backends/_implementations/jax/__init__.py +33 -0
- xbarray/backends/_implementations/jax/_extra.py +127 -0
- xbarray/backends/_implementations/jax/_typing.py +15 -0
- xbarray/backends/_implementations/jax/random.py +115 -0
- xbarray/backends/_implementations/numpy/__init__.py +25 -0
- xbarray/backends/_implementations/numpy/_extra.py +98 -0
- xbarray/backends/_implementations/numpy/_typing.py +14 -0
- xbarray/backends/_implementations/numpy/random.py +105 -0
- xbarray/backends/_implementations/pytorch/__init__.py +26 -0
- xbarray/backends/_implementations/pytorch/_extra.py +135 -0
- xbarray/backends/_implementations/pytorch/_typing.py +13 -0
- xbarray/backends/_implementations/pytorch/random.py +101 -0
- xbarray/backends/base.py +218 -0
- xbarray/backends/jax.py +19 -0
- xbarray/backends/numpy.py +19 -0
- xbarray/backends/pytorch.py +22 -0
- xbarray/jax.py +4 -0
- xbarray/numpy.py +4 -0
- xbarray/pytorch.py +4 -0
- xbarray/transformations/pointcloud/__init__.py +1 -0
- xbarray/transformations/pointcloud/base.py +449 -0
- xbarray/transformations/pointcloud/jax.py +24 -0
- xbarray/transformations/pointcloud/numpy.py +23 -0
- xbarray/transformations/pointcloud/pytorch.py +23 -0
- xbarray/transformations/rotation_conversions/__init__.py +1 -0
- xbarray/transformations/rotation_conversions/base.py +713 -0
- xbarray/transformations/rotation_conversions/jax.py +41 -0
- xbarray/transformations/rotation_conversions/numpy.py +41 -0
- xbarray/transformations/rotation_conversions/pytorch.py +41 -0
- xbarray-0.0.1a13.dist-info/METADATA +20 -0
- xbarray-0.0.1a13.dist-info/RECORD +51 -0
- xbarray-0.0.1a13.dist-info/WHEEL +5 -0
- xbarray-0.0.1a13.dist-info/licenses/LICENSE +21 -0
- xbarray-0.0.1a13.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from typing import Protocol, TypeVar, Optional, Any, Tuple, Union, Type, TypedDict, List
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from array_api_typing.typing_compat._api_typing import _NAMESPACE_ARRAY
|
|
4
|
+
from array_api_typing.typing_compat._array_typing import SetIndex
|
|
5
|
+
|
|
6
|
+
class AtResult(Protocol[_NAMESPACE_ARRAY]):
    """Structural type of the indexed-update helper returned by ``at(x, idx)``.

    Every update method takes a value ``y`` (an array of the same namespace or a
    Python ``float``/``int``/``complex``) plus an optional keyword ``copy`` flag,
    and returns an array of the namespace. Only signatures are declared here.
    """

    @abstractmethod
    def __getitem__(self, idx: SetIndex, /) -> "AtResult[_NAMESPACE_ARRAY]":
        """
        Allow for the alternate syntax ``at(x)[start:stop:step]``.

        It looks prettier than ``at(x, slice(start, stop, step))``
        and feels more intuitive coming from the JAX documentation.
        """
        ...

    @abstractmethod
    def set(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...

    @abstractmethod
    def add(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...

    @abstractmethod
    def subtract(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...

    @abstractmethod
    def multiply(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...

    @abstractmethod
    def divide(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...

    @abstractmethod
    def power(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...

    @abstractmethod
    def min(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...

    @abstractmethod
    def max(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        ...
|
xbarray/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .backends.base import * # Import Abstract Typings
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from typing import Type, Generic
|
|
2
|
+
from .base import BArrayType, BDeviceType, BDtypeType, BRNGType
|
|
3
|
+
|
|
4
|
+
class ComputeBackendImplCls(Generic[BArrayType, BDeviceType, BDtypeType, BRNGType], Type):
    """Generic over the backend's array/device/dtype/RNG types.

    Both ``str()`` and ``repr()`` render as the ``simplified_name`` attribute.
    NOTE(review): it derives from ``Type``, so it presumably serves as the type
    of backend classes/modules — confirm against its users.
    """

    def __str__(self):
        display_name = self.simplified_name
        return display_name

    def __repr__(self):
        display_name = self.simplified_name
        return display_name
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from typing import Any, Union, Callable, Mapping, Sequence
|
|
2
|
+
from array_api_typing.typing_compat import ArrayAPINamespace as CompatNamespace, ArrayAPIArray as CompatArray, ArrayAPIDType as CompatDType
|
|
3
|
+
import array_api_compat
|
|
4
|
+
import dataclasses
|
|
5
|
+
|
|
6
|
+
# Public factory functions. Each backend's `_extra` module calls them to build
# its backend-specific `abbreviate_array`, `map_fn_over_arrays` and `pad_dim`.
__all__ = [
    "get_abbreviate_array_function",
    "get_map_fn_over_arrays_function",
    "get_pad_dim_function",
]
|
|
11
|
+
|
|
12
|
+
def get_abbreviate_array_function(
    backend : CompatNamespace[CompatArray, Any, Any],
    default_integer_dtype : CompatDType,
    func_dtype_is_real_floating : Callable[[CompatDType], bool],
    func_dtype_is_real_integer : Callable[[CompatDType], bool],
    func_dtype_is_boolean : Callable[[CompatDType], bool],
):
    """Build an ``abbreviate_array`` function bound to ``backend``.

    Args:
        backend: array-API namespace providing ``zeros``, ``take`` and ``all``.
        default_integer_dtype: dtype of the index array passed to ``take``.
        func_dtype_is_real_floating / func_dtype_is_real_integer /
        func_dtype_is_boolean: dtype predicates used to pick the Python scalar
            type when the array collapses to a single element.
    """
    def abbreviate_array(array : CompatArray, try_cast_scalar : bool = True) -> Union[float, int, bool, CompatArray]:
        """
        Abbreviate an array to a single element if possible.

        Every axis along which all slices equal the first slice is collapsed to
        length 1, so the result keeps the same number of dimensions. If
        ``try_cast_scalar`` is true and every remaining axis has length 1, the
        element is returned as a Python ``float``, ``int`` or ``bool``;
        otherwise the (possibly collapsed) array is returned.

        Raises:
            ValueError: if a scalar cast is attempted on an element whose dtype
                is not real floating, real integer or boolean.
        """
        # An array with a zero-length axis has no first slice; `take` with
        # index 0 would be out of bounds, so there is nothing to abbreviate.
        if any(size == 0 for size in array.shape):
            return array

        abbr_array = array
        idx = backend.zeros(1, dtype=default_integer_dtype, device=array_api_compat.device(array))
        for dim_i in range(len(array.shape)):
            # `take` with a length-1 index keeps the axis (size 1), so the
            # comparison below broadcasts the first slice against all slices.
            first_elem = backend.take(abbr_array, idx, axis=dim_i)
            if backend.all(abbr_array == first_elem):
                abbr_array = first_elem

        if try_cast_scalar and all(i == 1 for i in abbr_array.shape):
            elem = abbr_array[(0,) * len(abbr_array.shape)]
            if func_dtype_is_real_floating(elem.dtype):
                return float(elem)
            elif func_dtype_is_real_integer(elem.dtype):
                return int(elem)
            elif func_dtype_is_boolean(elem.dtype):
                return bool(elem)
            else:
                raise ValueError(f"Abbreviated array element dtype must be a real floating or integer or boolean type, actual dtype: {elem.dtype}")
        # Return the collapsed array (not the input), as documented above.
        return abbr_array
    return abbreviate_array
|
|
46
|
+
|
|
47
|
+
def get_map_fn_over_arrays_function(
    is_backendarray : Callable[[Any], bool],
):
    """Build a ``map_fn_over_arrays`` function for one backend.

    ``is_backendarray`` decides which values count as arrays of that backend.
    """
    def map_fn_over_arrays(data : Any, func : "Callable[[CompatArray], CompatArray]") -> Any:
        """
        Recursively apply ``func`` to every backend array inside ``data``.

        Containers are rebuilt as follows:

        * Mappings keep their type when ``type(data)(**items)`` succeeds; otherwise they become a plain ``dict``.
        * namedtuples keep their type, with fields passed positionally.
        * Other sequences (except ``str``/``bytes``) keep their type when ``type(data)(items)`` succeeds; otherwise they become a ``list``.
        * Dataclass instances are rebuilt with ``dataclasses.replace``, mapping each ``init`` field. This is shallow, so nested dataclasses keep their type.

        Any other value is returned unchanged.
        """
        if is_backendarray(data):
            return func(data)
        if isinstance(data, Mapping):
            ret = {k: map_fn_over_arrays(v, func) for k, v in data.items()}
            try:
                return type(data)(**ret)  # try to keep the same mapping type
            except Exception:
                # Best effort: non-str keys or a custom constructor signature.
                return ret
        if isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
            ret = [map_fn_over_arrays(i, func) for i in data]
            try:
                if isinstance(data, tuple) and hasattr(data, "_fields"):
                    # namedtuple constructors take one positional arg per field.
                    return type(data)(*ret)
                return type(data)(ret)  # try to keep the same sequence type
            except Exception:
                return ret
        # `is_dataclass` is also true for dataclass *classes*; only map instances.
        if dataclasses.is_dataclass(data) and not isinstance(data, type):
            changes = {
                f.name: map_fn_over_arrays(getattr(data, f.name), func)
                for f in dataclasses.fields(data)
                if f.init  # replace() rejects init=False fields
            }
            return dataclasses.replace(data, **changes)
        return data
    return map_fn_over_arrays
|
|
73
|
+
|
|
74
|
+
def get_pad_dim_function(
    backend : CompatNamespace[CompatArray, Any, Any],
):
    """Build a ``pad_dim`` function bound to ``backend`` (needs ``full`` and ``concat``)."""
    def pad_dim(x : CompatArray, dim : int, target_size : int, value : Union[float, int] = 0) -> CompatArray:
        """Append ``value`` along axis ``dim`` until ``x`` has ``target_size`` entries there.

        Returns ``x`` itself when it already has ``target_size`` entries.
        The padding has ``x``'s dtype and device.

        Raises:
            ValueError: if ``x`` is already longer than ``target_size`` along ``dim``.
        """
        size_at_target = x.shape[dim]
        if size_at_target > target_size:
            raise ValueError(f"Cannot pad dimension {dim} of size {size_at_target} to smaller target size {target_size}")
        if size_at_target == target_size:
            return x
        pad_shape = list(x.shape)
        pad_shape[dim] = target_size - size_at_target
        # The array API standard types `full`'s shape as an int or a tuple of
        # ints; pass a tuple so strict namespaces accept it.
        pad_value = backend.full(tuple(pad_shape), value, dtype=x.dtype, device=array_api_compat.device(x))
        return backend.concat([x, pad_value], axis=dim)
    return pad_dim
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import jax.numpy
|
|
2
|
+
|
|
3
|
+
if hasattr(jax.numpy, "__array_api_version__"):
|
|
4
|
+
compat_module = jax.numpy
|
|
5
|
+
from jax.numpy import *
|
|
6
|
+
from jax.numpy import __array_api_version__, __array_namespace_info__
|
|
7
|
+
else:
|
|
8
|
+
import jax.experimental.array_api as compat_module
|
|
9
|
+
from jax.experimental.array_api import *
|
|
10
|
+
from jax.experimental.array_api import __array_api_version__, __array_namespace_info__
|
|
11
|
+
|
|
12
|
+
from array_api_compat.common._helpers import *
|
|
13
|
+
|
|
14
|
+
simplified_name = "jax"

# Import and bind all functions from array_api_extra before exposing them:
# each public callable is re-exported with `xp` bound to this backend's namespace.
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    if api_name.startswith('_'):
        continue

    if api_name in ['at', 'broadcast_shapes']:
        # Re-exported unbound (presumably because they take no `xp` keyword).
        globals()[api_name] = getattr(array_api_extra, api_name)
    else:
        api_fn = getattr(array_api_extra, api_name)
        # dir() also lists non-callable public attributes, e.g. a submodule
        # that has already been imported. functools.partial raises TypeError
        # on those, which would abort importing this backend, so skip them.
        if not callable(api_fn):
            continue
        globals()[api_name] = partial(api_fn, xp=compat_module)
|
|
30
|
+
|
|
31
|
+
from ._typing import *
|
|
32
|
+
from ._extra import *
|
|
33
|
+
__import__(__package__ + ".random")
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from typing import Any, Union, Optional, Callable
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import numpy as np
|
|
5
|
+
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
6
|
+
from xbarray.backends.base import ComputeBackend, SupportsDLPack
|
|
7
|
+
|
|
8
|
+
# Public names of this module; the jax backend package re-exports them with
# `from ._extra import *`.
__all__ = [
    "default_integer_dtype",
    "default_index_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    "serialize_device",
    "deserialize_device",
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    "abbreviate_array",
    "map_fn_over_arrays",
    "pad_dim",
]

# Default dtypes are given as Python builtin types and resolved by JAX.
# NOTE(review): the concrete width presumably depends on JAX's x64 setting — confirm.
default_integer_dtype = int
default_index_dtype = int
default_floating_dtype = float
default_boolean_dtype = bool
|
|
32
|
+
|
|
33
|
+
def serialize_device(device : Optional[Any]) -> Optional[str]:
    """Encode a JAX device as ``"<platform>:<id>"``; ``None`` stays ``None``.

    Raises AssertionError if ``device`` lacks ``platform`` or ``id``.
    """
    if device is not None:
        assert all(hasattr(device, attr) for attr in ("platform", "id"))
        return ":".join((device.platform, str(device.id)))
    return None
|
|
38
|
+
|
|
39
|
+
def deserialize_device(device_str : Optional[str]) -> Optional[Any]:
    """Rebuild a JAX device from the ``"<platform>:<id>"`` form written by ``serialize_device``.

    Returns the device of that platform whose ``id`` matches. If none matches,
    it falls back to the first device of the same platform. ``None`` maps to ``None``.

    Raises:
        ValueError: if ``device_str`` is not ``"<platform>:<integer id>"``, or
            no device reports that platform.
    """
    if device_str is None:
        return None
    parts = device_str.split(":")
    if len(parts) != 2:
        raise ValueError(f"Malformed device string {device_str!r}; expected '<platform>:<id>'.")
    platform, id_str = parts
    try:
        device_id = int(id_str)
    except ValueError:
        raise ValueError(f"Malformed device id in {device_str!r}; expected an integer.") from None
    # Query the platform's devices once and keep only those reporting that platform.
    candidates = [device for device in jax.devices(platform) if device.platform == platform]
    for device in candidates:
        if device.id == device_id:
            return device
    # Fallback with the same platform
    if candidates:
        return candidates[0]
    raise ValueError(f"Device with platform '{platform}' and id '{device_id}' not found.")
|
|
53
|
+
|
|
54
|
+
def is_backendarray(data : Any) -> bool:
    """Return True iff ``data`` is an instance of ``jax.Array``."""
    array_cls = jax.Array
    return isinstance(data, array_cls)
|
|
56
|
+
|
|
57
|
+
def from_numpy(
    data : np.ndarray,
    /,
    *,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> ARRAY_TYPE:
    """Convert a NumPy array to a JAX array, optionally casting to ``dtype`` and placing it on ``device``."""
    converted = jnp.asarray(data, dtype=dtype, device=device)
    return converted
|
|
65
|
+
|
|
66
|
+
def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    """Import an array owned by ``other_backend`` into JAX via DLPack.

    ``other_backend.to_dlpack`` produces the DLPack-capable object consumed by
    ``jax.dlpack.from_dlpack``.

    NOTE: the source previously carried a commented-out fallback through NumPy
    (``other_backend.to_numpy`` + ``from_numpy``). It was meant for cases where
    JAX has tiling issues with DLPack-converted data. It is not active.
    """
    dlpack_obj = other_backend.to_dlpack(data)
    return jax.dlpack.from_dlpack(dlpack_obj)
|
|
77
|
+
|
|
78
|
+
def to_numpy(
    data : ARRAY_TYPE
) -> np.ndarray:
    """Convert a JAX array to a NumPy array; bfloat16 data is cast to float32 first."""
    source = data.astype(np.float32) if data.dtype == jax.dtypes.bfloat16 else data
    return np.asarray(source)
|
|
84
|
+
|
|
85
|
+
def to_dlpack(
    data: ARRAY_TYPE,
    /,
) -> SupportsDLPack:
    """Return ``data`` itself as the DLPack source object.

    NOTE(review): relies on ``jax.Array`` implementing the DLPack protocol.
    """
    dlpack_source = data
    return dlpack_source
|
|
90
|
+
|
|
91
|
+
def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    """True for signed or unsigned integer dtypes (subtypes of ``np.integer``)."""
    is_integer = np.issubdtype(dtype, np.integer)
    return is_integer
|
|
95
|
+
|
|
96
|
+
def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    """True for bfloat16 and for any subtype of ``np.floating``."""
    if dtype == jax.dtypes.bfloat16:
        # Special-cased explicitly (presumably not covered by np.floating).
        return True
    return np.issubdtype(dtype, np.floating)
|
|
100
|
+
|
|
101
|
+
def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is the boolean dtype.

    Uses ``np.issubdtype``, like ``dtype_is_real_integer`` and
    ``dtype_is_real_floating``. It therefore accepts the same dtype-like inputs:
    dtype objects, ``bool``/``np.bool_`` and dtype strings such as ``"bool"``.
    """
    return np.issubdtype(dtype, np.bool_)
|
|
105
|
+
|
|
106
|
+
from .._common.implementations import *
|
|
107
|
+
if hasattr(jax.numpy, "__array_api_version__"):
|
|
108
|
+
compat_module = jax.numpy
|
|
109
|
+
else:
|
|
110
|
+
import jax.experimental.array_api as compat_module
|
|
111
|
+
# Bind the shared abbreviation helper to this backend's namespace and dtype predicates.
abbreviate_array = get_abbreviate_array_function(
    func_dtype_is_boolean=dtype_is_boolean,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_real_floating=dtype_is_real_floating,
    default_integer_dtype=default_integer_dtype,
    backend=compat_module,
)
|
|
118
|
+
def map_fn_over_arrays(
    data : Any, func : Callable[[ARRAY_TYPE], ARRAY_TYPE]
) -> Any:
    """Apply ``func`` to every JAX array leaf of the pytree ``data``.

    ``jax.tree.map`` rebuilds the container structure. Leaves that are not JAX
    arrays (strings, Python scalars, ...) are returned unchanged. This matches
    the shared ``get_map_fn_over_arrays_function`` helper, which only calls
    ``func`` on backend arrays.
    """
    return jax.tree.map(
        lambda leaf: func(leaf) if is_backendarray(leaf) else leaf,
        data,
    )
|
|
125
|
+
# Bind the shared padding helper to this backend's array-API namespace.
pad_dim = get_pad_dim_function(compat_module)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from typing import Union, Any, Optional
|
|
2
|
+
import jax
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
# Type aliases re-exported by the jax backend package via `from ._typing import *`.
__all__ = [
    'ARRAY_TYPE',
    'DTYPE_TYPE',
    'DEVICE_TYPE',
    'RNG_TYPE',
]

ARRAY_TYPE = jax.Array  # array type produced by this backend
DTYPE_TYPE = np.dtype  # dtypes are NumPy dtype objects
DEVICE_TYPE = Union[jax.Device, jax.sharding.Sharding]  # placement target: a device or a sharding
RNG_TYPE = jax.Array  # PRNG key array (random.py builds it with jax.random.key)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from typing import Union, Optional, Tuple, Any
|
|
2
|
+
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
3
|
+
import numpy as np
|
|
4
|
+
import jax
|
|
5
|
+
|
|
6
|
+
# Public samplers. Each (except the key constructor) splits the given key and
# returns `(new_rng, samples)`.
__all__ = [
    "random_number_generator",
    "random_discrete_uniform",
    "random_uniform",
    "random_exponential",
    "random_normal",
    "random_geometric",
    "random_permutation"
]
|
|
15
|
+
|
|
16
|
+
def random_number_generator(
    seed : Optional[int] = None,
    *,
    device : Optional[DEVICE_TYPE] = None
) -> RNG_TYPE:
    """Create a JAX PRNG key.

    Args:
        seed: integer seed. When ``None``, one is drawn with
            ``np.random.randint(65535)`` from NumPy's global random state.
        device: if given, the key is placed there with ``jax.device_put``,
            like the results of the other samplers in this module.
    """
    rng_seed = np.random.randint(65535) if seed is None else seed
    rng = jax.random.key(seed=rng_seed)
    if device is not None:
        rng = jax.device_put(rng, device)
    return rng
|
|
26
|
+
|
|
27
|
+
def random_discrete_uniform(
    shape : Union[int, Tuple[int, ...]],
    /,
    from_num : int,
    to_num : int,
    *,
    rng : RNG_TYPE,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Sample integers with ``jax.random.randint(minval=from_num, maxval=to_num)``.

    ``rng`` is split; one key is consumed here and the other is returned.
    Returns ``(new_rng, samples)``. ``shape`` may be an int or a tuple of ints;
    ``dtype`` defaults to ``int``. The result is moved with ``jax.device_put``
    when ``device`` is given.
    """
    new_rng, rng = jax.random.split(rng)
    # jax.random only accepts sequence shapes; normalise a bare integer.
    shape = (shape,) if np.ndim(shape) == 0 else tuple(shape)
    t = jax.random.randint(rng, shape, minval=int(from_num), maxval=int(to_num), dtype=int if dtype is None else dtype)
    if device is not None:
        t = jax.device_put(t, device)
    return new_rng, t
|
|
42
|
+
|
|
43
|
+
def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    low : float = 0.0, high : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Sample floats with ``jax.random.uniform(minval=low, maxval=high)``.

    ``rng`` is split; one key is consumed here and the other is returned.
    Returns ``(new_rng, samples)``. ``shape`` may be an int or a tuple of ints;
    ``dtype`` defaults to ``float``. The result is moved with ``jax.device_put``
    when ``device`` is given.
    """
    new_rng, rng = jax.random.split(rng)
    # jax.random only accepts sequence shapes; normalise a bare integer.
    shape = (shape,) if np.ndim(shape) == 0 else tuple(shape)
    data = jax.random.uniform(rng, shape, dtype=float if dtype is None else dtype, minval=low, maxval=high)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
57
|
+
|
|
58
|
+
def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    lambd : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Sample ``jax.random.exponential`` draws divided by the rate ``lambd``.

    ``rng`` is split; one key is consumed here and the other is returned.
    Returns ``(new_rng, samples)``. ``shape`` may be an int or a tuple of ints;
    ``dtype`` defaults to ``float``. The result is moved with ``jax.device_put``
    when ``device`` is given.
    """
    new_rng, rng = jax.random.split(rng)
    # jax.random only accepts sequence shapes; normalise a bare integer.
    shape = (shape,) if np.ndim(shape) == 0 else tuple(shape)
    data = jax.random.exponential(rng, shape, dtype=float if dtype is None else dtype) / lambd
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
72
|
+
|
|
73
|
+
def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    mean : float = 0.0, std : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Sample ``jax.random.normal`` draws scaled by ``std`` and shifted by ``mean``.

    ``rng`` is split; one key is consumed here and the other is returned.
    Returns ``(new_rng, samples)``. ``shape`` may be an int or a tuple of ints;
    ``dtype`` defaults to ``float``. The result is moved with ``jax.device_put``
    when ``device`` is given.
    """
    new_rng, rng = jax.random.split(rng)
    # jax.random only accepts sequence shapes; normalise a bare integer.
    shape = (shape,) if np.ndim(shape) == 0 else tuple(shape)
    data = jax.random.normal(rng, shape, dtype=float if dtype is None else dtype) * std + mean
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
87
|
+
|
|
88
|
+
def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Sample ``jax.random.geometric`` draws with success probability ``p``.

    ``rng`` is split; one key is consumed here and the other is returned.
    Returns ``(new_rng, samples)``. ``shape`` may be an int or a tuple of ints;
    ``dtype`` defaults to ``int``. The result is moved with ``jax.device_put``
    when ``device`` is given.
    """
    new_rng, rng = jax.random.split(rng)
    # jax.random only accepts sequence shapes; normalise a bare integer.
    shape = (shape,) if np.ndim(shape) == 0 else tuple(shape)
    data = jax.random.geometric(rng, p=p, shape=shape, dtype=int if dtype is None else dtype)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
102
|
+
|
|
103
|
+
def random_permutation(
    n : int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Return ``(new_rng, perm)``, where ``perm`` is a random permutation from ``jax.random.permutation(key, n)``.

    ``jax.random.permutation`` has no ``dtype`` parameter, so a requested
    ``dtype`` is applied afterwards with ``astype``. Without it, JAX's default
    integer dtype is kept. The result is moved with ``jax.device_put`` when
    ``device`` is given.
    """
    new_rng, rng = jax.random.split(rng)
    data = jax.random.permutation(rng, n)
    if dtype is not None:
        data = data.astype(dtype)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from array_api_compat.numpy import *
|
|
2
|
+
from array_api_compat.numpy import __array_api_version__, __array_namespace_info__
|
|
3
|
+
from array_api_compat.common._helpers import *
|
|
4
|
+
|
|
5
|
+
simplified_name = "numpy"

from array_api_compat import numpy as compat_module
# Import and bind all functions from array_api_extra before exposing them:
# each public callable is re-exported with `xp` bound to this backend's namespace.
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    if api_name.startswith('_'):
        continue

    if api_name in ['at', 'broadcast_shapes']:
        # Re-exported unbound (presumably because they take no `xp` keyword).
        globals()[api_name] = getattr(array_api_extra, api_name)
    else:
        api_fn = getattr(array_api_extra, api_name)
        # dir() also lists non-callable public attributes, e.g. a submodule
        # that has already been imported. functools.partial raises TypeError
        # on those, which would abort importing this backend, so skip them.
        if not callable(api_fn):
            continue
        globals()[api_name] = partial(api_fn, xp=compat_module)
|
|
22
|
+
|
|
23
|
+
from ._typing import *
|
|
24
|
+
from ._extra import *
|
|
25
|
+
__import__(__package__ + ".random")
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from typing import Any, Union, Optional
|
|
2
|
+
import numpy as np
|
|
3
|
+
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
4
|
+
from xbarray.backends.base import ComputeBackend, SupportsDLPack
|
|
5
|
+
|
|
6
|
+
# Public names of this module; the numpy backend package re-exports them with
# `from ._extra import *`.
__all__ = [
    "default_integer_dtype",
    "default_index_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    "serialize_device",
    "deserialize_device",
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    "abbreviate_array",
    "map_fn_over_arrays",
    "pad_dim",
]

# Default dtypes are given as Python builtin types; NumPy resolves them to its
# default dtypes (NOTE(review): platform-dependent widths — confirm if it matters).
default_integer_dtype = int
default_index_dtype = int
default_floating_dtype = float
default_boolean_dtype = bool
|
|
30
|
+
|
|
31
|
+
def serialize_device(device : Optional[DEVICE_TYPE]) -> Optional[str]:
    """Always ``None``: there is no device to encode for NumPy arrays."""
    return None  # NumPy does not have device concept
|
|
33
|
+
|
|
34
|
+
def deserialize_device(device_str : Optional[str]) -> Optional[DEVICE_TYPE]:
    """Always ``None``: any serialized device string is ignored for NumPy."""
    return None  # NumPy does not have device concept
|
|
36
|
+
|
|
37
|
+
def is_backendarray(data : Any) -> bool:
    """Return True iff ``data`` is an ``np.ndarray`` (NumPy scalars do not count)."""
    array_cls = np.ndarray
    return isinstance(data, array_cls)
|
|
39
|
+
|
|
40
|
+
def from_numpy(
    data : np.ndarray,
    /,
    *,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> ARRAY_TYPE:
    """Return ``data`` as a backend (NumPy) array.

    When ``dtype`` is given, the result is converted to it with ``np.asarray``,
    which returns ``data`` itself if the dtype already matches. This mirrors the
    JAX backend. ``device`` is accepted for interface compatibility and ignored,
    since NumPy has no device concept.
    """
    if dtype is None:
        return data
    return np.asarray(data, dtype=dtype)
|
|
48
|
+
|
|
49
|
+
def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    """Import an array owned by ``other_backend`` by asking that backend for a NumPy array."""
    numpy_data = other_backend.to_numpy(data)
    return numpy_data
|
|
55
|
+
|
|
56
|
+
def to_numpy(
    data : ARRAY_TYPE
) -> np.ndarray:
    """Identity: backend arrays already are NumPy arrays (no copy is made)."""
    numpy_data = data
    return numpy_data
|
|
60
|
+
|
|
61
|
+
def to_dlpack(
    data: ARRAY_TYPE,
    /,
) -> SupportsDLPack:
    """Return ``data`` itself as the DLPack source object.

    NOTE(review): relies on ``np.ndarray`` implementing ``__dlpack__``, which
    depends on the NumPy version — confirm the minimum supported version.
    """
    dlpack_source = data
    return dlpack_source
|
|
66
|
+
|
|
67
|
+
def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    """True for signed or unsigned integer dtypes (subtypes of ``np.integer``)."""
    is_integer = np.issubdtype(dtype, np.integer)
    return is_integer
|
|
71
|
+
|
|
72
|
+
def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    """True for real floating-point dtypes (subtypes of ``np.floating``)."""
    is_floating = np.issubdtype(dtype, np.floating)
    return is_floating
|
|
76
|
+
|
|
77
|
+
def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is the boolean dtype.

    Uses ``np.issubdtype``, like ``dtype_is_real_integer`` and
    ``dtype_is_real_floating``. It therefore accepts the same dtype-like inputs:
    dtype objects, ``bool``/``np.bool_`` and dtype strings such as ``"bool"``.
    """
    return np.issubdtype(dtype, np.bool_)
|
|
81
|
+
|
|
82
|
+
from .._common.implementations import *
|
|
83
|
+
from array_api_compat import numpy as compat_module
|
|
84
|
+
# Bind the shared abbreviation helper to array_api_compat's NumPy namespace
# and this module's dtype predicates.
abbreviate_array = get_abbreviate_array_function(
    func_dtype_is_boolean=dtype_is_boolean,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_real_floating=dtype_is_real_floating,
    default_integer_dtype=default_integer_dtype,
    backend=compat_module,
)
|
|
91
|
+
|
|
92
|
+
# Recursive mapper that treats np.ndarray values (per is_backendarray) as leaves.
map_fn_over_arrays = get_map_fn_over_arrays_function(is_backendarray)
|
|
95
|
+
|
|
96
|
+
# Bind the shared padding helper to array_api_compat's NumPy namespace.
pad_dim = get_pad_dim_function(compat_module)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from typing import Union, Any, Optional
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
# Type aliases re-exported by the numpy backend package via `from ._typing import *`.
__all__ = [
    'ARRAY_TYPE',
    'DTYPE_TYPE',
    'DEVICE_TYPE',
    'RNG_TYPE',
]

ARRAY_TYPE = np.ndarray  # array type produced by this backend
DTYPE_TYPE = np.dtype  # dtypes are NumPy dtype objects
DEVICE_TYPE = Any  # NumPy has no device concept; devices are ignored by this backend
RNG_TYPE = np.random.Generator  # NumPy Generator object used as the RNG state
|