xbarray 0.0.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xbarray might be problematic. Click here for more details.
- array_api_typing/__init__.py +9 -0
- array_api_typing/typing_2024_12/__init__.py +12 -0
- array_api_typing/typing_2024_12/_api_constant.py +32 -0
- array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
- array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
- array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
- array_api_typing/typing_2024_12/_api_typing.py +5855 -0
- array_api_typing/typing_2024_12/_array_typing.py +1265 -0
- array_api_typing/typing_compat/__init__.py +12 -0
- array_api_typing/typing_compat/_api_typing.py +27 -0
- array_api_typing/typing_compat/_array_typing.py +36 -0
- array_api_typing/typing_extra/__init__.py +12 -0
- array_api_typing/typing_extra/_api_typing.py +651 -0
- array_api_typing/typing_extra/_at.py +87 -0
- xbarray/__init__.py +1 -0
- xbarray/base/__init__.py +1 -0
- xbarray/base/base.py +199 -0
- xbarray/common/implementations.py +61 -0
- xbarray/jax/__init__.py +25 -0
- xbarray/jax/_extra.py +98 -0
- xbarray/jax/_typing.py +15 -0
- xbarray/jax/random.py +116 -0
- xbarray/numpy/__init__.py +19 -0
- xbarray/numpy/_extra.py +83 -0
- xbarray/numpy/_typing.py +14 -0
- xbarray/numpy/random.py +106 -0
- xbarray/pytorch/__init__.py +20 -0
- xbarray/pytorch/_extra.py +109 -0
- xbarray/pytorch/_typing.py +13 -0
- xbarray/pytorch/random.py +102 -0
- xbarray-0.0.1a1.dist-info/METADATA +14 -0
- xbarray-0.0.1a1.dist-info/RECORD +35 -0
- xbarray-0.0.1a1.dist-info/WHEEL +5 -0
- xbarray-0.0.1a1.dist-info/licenses/LICENSE +21 -0
- xbarray-0.0.1a1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from typing import Protocol, TypeVar, Optional, Any, Tuple, Union, Type, TypedDict, List
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from array_api_typing.typing_compat._api_typing import _NAMESPACE_ARRAY
|
|
4
|
+
from array_api_typing.typing_compat._array_typing import SetIndex
|
|
5
|
+
|
|
6
|
+
class AtResult(Protocol[_NAMESPACE_ARRAY]):
    """
    Protocol for the object produced by an ``at(x)`` style helper.

    Indexing the object yields another ``AtResult``; every update method
    takes the operand ``y`` positionally plus an optional ``copy`` flag and
    is annotated to return an array of the namespace's array type.

    NOTE(review): the meaning of ``copy`` (``None`` / ``True`` / ``False``)
    is not defined in this file -- confirm against the implementing library.
    """

    @abstractmethod
    def __getitem__(self, idx: SetIndex, /) -> "AtResult[_NAMESPACE_ARRAY]":
        """
        Support the ``at(x)[start:stop:step]`` spelling.

        This reads better than ``at(x, slice(start, stop, step))`` and is
        the form people coming from the JAX documentation expect.
        """

    @abstractmethod
    def set(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``set`` update with operand ``y``."""

    @abstractmethod
    def add(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``add`` update with operand ``y``."""

    @abstractmethod
    def subtract(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``subtract`` update with operand ``y``."""

    @abstractmethod
    def multiply(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``multiply`` update with operand ``y``."""

    @abstractmethod
    def divide(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``divide`` update with operand ``y``."""

    @abstractmethod
    def power(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``power`` update with operand ``y``."""

    @abstractmethod
    def min(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``min`` update with operand ``y``."""

    @abstractmethod
    def max(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Declare the ``max`` update with operand ``y``."""
|
xbarray/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import *
|
xbarray/base/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .base import *
|
xbarray/base/base.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from typing import Optional, Generic, TypeVar, Dict, Union, Any, Sequence, SupportsFloat, Tuple, Type, Callable, Mapping, Protocol
|
|
2
|
+
import abc
|
|
3
|
+
from array_api_typing.typing_extra import *
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
# Re-export the indexing aliases under ArrayAPI-prefixed names.
# NOTE(review): SetIndex / GetIndex are not imported by name here; they are
# presumably supplied by the star-import from array_api_typing.typing_extra.
ArrayAPISetIndex = SetIndex
ArrayAPIGetIndex = GetIndex

# Names exported by ``from <this module> import *``.
__all__ = [
    "RNGBackend",
    "ComputeBackend",
    "BArrayType",
    "BDeviceType",
    "BDtypeType",
    "BRNGType",
    "SupportsDLPack",
    "ArrayAPIArray",
    "ArrayAPIDevice",
    "ArrayAPIDType",
    "ArrayAPINamespace",
    "ArrayAPISetIndex",
    "ArrayAPIGetIndex",
]

# Type parameters shared by the backend protocols below. All four are
# declared covariant; the first three are bounded by the ArrayAPI types,
# the RNG type is unbounded.
BArrayType = TypeVar("BArrayType", covariant=True, bound=ArrayAPIArray)
BDeviceType = TypeVar("BDeviceType", covariant=True, bound=ArrayAPIDevice)
BDtypeType = TypeVar("BDtypeType", covariant=True, bound=ArrayAPIDType)
BRNGType = TypeVar("BRNGType", covariant=True)
|
|
29
|
+
class RNGBackend(Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """
    Protocol describing the random-number interface of a compute backend.

    Every sampling method receives the generator state through the
    keyword-only ``rng`` argument and is annotated to return a
    ``(rng, array)`` tuple. All methods are abstract and raise
    ``NotImplementedError`` when called on the protocol itself.
    """

    @abc.abstractmethod
    def random_number_generator(
        self,
        seed: Optional[int] = None,
        *,
        device: Optional[BDeviceType] = None
    ) -> BRNGType:
        """Create a generator state, optionally from ``seed`` and for ``device``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_discrete_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        from_num: int,
        to_num: int,
        /,
        *,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """
        Sample from a discrete uniform distribution [from_num, to_num) with shape `shape`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def random_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        low: float = 0.0,
        high: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample a uniform distribution parameterised by ``low`` and ``high``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_exponential(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        lambd: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample an exponential distribution parameterised by ``lambd``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_normal(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        mean: float = 0.0,
        std: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample a normal distribution parameterised by ``mean`` and ``std``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_geometric(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        p: float,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample a geometric distribution parameterised by ``p``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_permutation(
        self,
        n: int,
        /,
        *,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None
    ) -> Tuple[BRNGType, BArrayType]:
        """Produce a random permutation for the integer argument ``n``."""
        raise NotImplementedError
|
|
119
|
+
|
|
120
|
+
class ComputeBackend(
    ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType],
    Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType],
):
    """
    Protocol for a full compute backend.

    It extends ``ArrayAPINamespace`` with backend metadata attributes, a
    ``random`` sub-interface and conversion / dtype-inspection helpers.
    """

    # Short backend identifier string.
    simplified_name: str
    # Concrete classes used by the backend for arrays, devices, dtypes, RNG state.
    ARRAY_TYPE: Type[BArrayType]
    DEVICE_TYPE: Type[BDeviceType]
    DTYPE_TYPE: Type[BDtypeType]
    RNG_TYPE: Type[BRNGType]
    # Default dtypes of the backend.
    default_integer_dtype: BDtypeType
    default_floating_dtype: BDtypeType
    default_boolean_dtype: BDtypeType
    # Random-number interface of the backend.
    random: RNGBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]

    @abc.abstractmethod
    def is_backendarray(self, data: Any) -> bool:
        """Tell whether ``data`` is an array of this backend."""
        raise NotImplementedError

    @abc.abstractmethod
    def from_numpy(
        self,
        data: np.ndarray,
        /,
        *,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None
    ) -> BArrayType:
        """Convert a NumPy array into a backend array."""
        raise NotImplementedError

    @abc.abstractmethod
    def from_other_backend(
        self,
        other_backend: "ComputeBackend",
        data: ArrayAPIArray,
        /
    ) -> BArrayType:
        """
        Convert an array from another backend to this backend.
        The other backend must be compatible with the ArrayAPI.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to_numpy(
        self,
        data: BArrayType
    ) -> np.ndarray:
        """Convert a backend array into a NumPy array."""
        raise NotImplementedError

    @abc.abstractmethod
    def to_dlpack(
        self,
        data: BArrayType,
        /,
    ) -> SupportsDLPack:
        """Return an object annotated as supporting DLPack for ``data``."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_integer(self, dtype: BDtypeType) -> bool:
        """Tell whether ``dtype`` is a real integer dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_floating(self, dtype: BDtypeType) -> bool:
        """Tell whether ``dtype`` is a real floating dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_boolean(self, dtype: BDtypeType) -> bool:
        """Tell whether ``dtype`` is a boolean dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def abbreviate_array(self, x: BArrayType, try_cast_scalar: bool = True) -> Union[float, int, BArrayType]:
        """
        Abbreviates an array to a single element if possible.
        Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).
        """

    @abc.abstractmethod
    def map_fn_over_arrays(self, data: Any, func: Callable[[BArrayType], BArrayType]) -> Any:
        """
        Map a function over arrays in a data structure and produce a new data structure with the same shape.
        This is useful for applying a function to all arrays in a nested structure.
        """
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Any, Union, Callable
|
|
2
|
+
from array_api_typing.typing_compat import ArrayAPINamespace as CompatNamespace, ArrayAPIArray as CompatArray, ArrayAPIDType as CompatDType
|
|
3
|
+
|
|
4
|
+
# Public factory functions exported by a star-import of this module.
__all__ = [
    "get_abbreviate_array_function",
    "get_map_fn_over_arrays_function",
]
|
|
8
|
+
|
|
9
|
+
def get_abbreviate_array_function(
    backend: CompatNamespace[CompatArray, Any, Any],
    default_integer_dtype: CompatDType,
    func_dtype_is_real_floating: Callable[[CompatDType], bool],
    func_dtype_is_real_integer: Callable[[CompatDType], bool],
    func_dtype_is_boolean: Callable[[CompatDType], bool],
):
    """
    Build an ``abbreviate_array`` function bound to one array namespace.

    Args:
        backend: Namespace providing ``zeros``, ``device``, ``take`` and ``all``.
        default_integer_dtype: Dtype used for the index array passed to ``take``.
        func_dtype_is_real_floating: Predicate selecting dtypes cast with ``float``.
        func_dtype_is_real_integer: Predicate selecting dtypes cast with ``int``.
        func_dtype_is_boolean: Predicate selecting dtypes cast with ``bool``.

    Returns:
        The ``abbreviate_array(array, try_cast_scalar=True)`` closure.
    """
    def abbreviate_array(array: CompatArray, try_cast_scalar: bool = True) -> Union[float, int, CompatArray]:
        """
        Abbreviates an array to a single element if possible.
        Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).

        Raises:
            ValueError: if the array collapses to one element whose dtype is
                neither real floating, real integer nor boolean while
                ``try_cast_scalar`` is true.
        """
        abbr_array = array
        # Single index 0, used to pick the first slice along each axis.
        idx = backend.zeros(1, dtype=default_integer_dtype, device=backend.device(abbr_array))
        for dim_i in range(len(array.shape)):
            first_elem = backend.take(abbr_array, idx, axis=dim_i)
            # Collapse the axis to length 1 when every slice equals the first.
            if backend.all(abbr_array == first_elem):
                abbr_array = first_elem

        if try_cast_scalar and all(i == 1 for i in abbr_array.shape):
            elem = abbr_array[(0,) * len(abbr_array.shape)]
            if func_dtype_is_real_floating(elem.dtype):
                return float(elem)
            if func_dtype_is_real_integer(elem.dtype):
                return int(elem)
            if func_dtype_is_boolean(elem.dtype):
                return bool(elem)
            raise ValueError(f"Abbreviated array element dtype must be a real floating or integer or boolean type, actual dtype: {elem.dtype}")

        # Return the abbreviated array; the previous code handed back the
        # untouched input (or fell through to None), discarding the work above.
        return abbr_array
    return abbreviate_array
|
|
43
|
+
|
|
44
|
+
def get_map_fn_over_arrays_function(
    is_backendarray: Callable[[Any], bool],
):
    """
    Build a ``map_fn_over_arrays`` function for one backend.

    ``is_backendarray`` decides which values are treated as arrays; the
    returned closure recurses through dicts, tuples and lists and leaves
    every other value untouched.
    """
    def map_fn_over_arrays(data: Any, func: Callable[[CompatArray], CompatArray]) -> Any:
        """
        Apply ``func`` to every backend array found in ``data``.

        Dicts keep their keys, tuples come back as plain tuples and lists
        as lists; anything else is returned unchanged.
        """
        if is_backendarray(data):
            return func(data)
        if isinstance(data, dict):
            mapped = {}
            for key, value in data.items():
                mapped[key] = map_fn_over_arrays(value, func)
            return mapped
        if isinstance(data, tuple):
            return tuple([map_fn_over_arrays(item, func) for item in data])
        if isinstance(data, list):
            return list(map(lambda item: map_fn_over_arrays(item, func), data))
        return data
    return map_fn_over_arrays
|
xbarray/jax/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import jax.numpy
|
|
2
|
+
|
|
3
|
+
simplified_name = "jax"

# Use jax.numpy as the array-API namespace when it advertises
# ``__array_api_version__``; otherwise fall back to jax.experimental.array_api.
if hasattr(jax.numpy, "__array_api_version__"):
    compat_module = jax.numpy
    from jax.numpy import *
else:
    import jax.experimental.array_api as compat_module
    from jax.experimental.array_api import *

# Import and bind all functions from array_api_extra before exposing them
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    if api_name.startswith('_'):
        continue
    if not callable(getattr(array_api_extra, api_name)):
        # partial() raises TypeError for a non-callable first argument, which
        # would abort the import of this package; skip such attributes.
        continue
    globals()[api_name] = partial(
        getattr(array_api_extra, api_name),
        xp=compat_module
    )

from ._typing import *
from ._extra import *
# Import the ``random`` submodule by name so it becomes an attribute of this package.
__import__(__package__ + ".random")
|
xbarray/jax/_extra.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from typing import Any, Union, Optional, Callable
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import numpy as np
|
|
5
|
+
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
6
|
+
from xbarray.base import ComputeBackend, SupportsDLPack
|
|
7
|
+
|
|
8
|
+
# Names exported by ``from ._extra import *``.
# ``map_fn_over_arrays`` is defined further down in this module and is part of
# the ComputeBackend interface, so it has to be listed here to reach the
# package namespace through the star-import.
__all__ = [
    "default_integer_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    "abbreviate_array",
    "map_fn_over_arrays",
]

# Default dtypes are expressed as the Python builtin scalar types.
default_integer_dtype = int
default_floating_dtype = float
default_boolean_dtype = bool
|
|
26
|
+
|
|
27
|
+
def is_backendarray(data: Any) -> bool:
    """Report whether ``data`` is an instance of ``jax.Array``."""
    array_cls = jax.Array
    return isinstance(data, array_cls)
|
|
29
|
+
|
|
30
|
+
def from_numpy(
    data: np.ndarray,
    /,
    *,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> ARRAY_TYPE:
    """
    Convert a NumPy array into a JAX array.

    ``dtype`` and ``device`` are forwarded unchanged to ``jax.numpy.asarray``.
    """
    conversion_kwargs = {"dtype": dtype, "device": device}
    return jnp.asarray(data, **conversion_kwargs)
|
|
38
|
+
|
|
39
|
+
def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    """
    Convert an array owned by ``other_backend`` into a JAX array.

    The array is exported with ``other_backend.to_dlpack`` and imported
    with ``jax.dlpack.from_dlpack``.
    """
    # NOTE: a NumPy round-trip fallback (other_backend.to_numpy followed by
    # from_numpy) was once sketched here for DLPack tiling problems in JAX;
    # it is not active.
    return jax.dlpack.from_dlpack(other_backend.to_dlpack(data))
|
|
50
|
+
|
|
51
|
+
def to_numpy(
    data: ARRAY_TYPE
) -> np.ndarray:
    """
    Convert a JAX array into a NumPy array.

    bfloat16 input is cast to float32 first (presumably because NumPy has
    no built-in bfloat16 dtype); every other dtype is converted as-is.
    """
    needs_widening = data.dtype == jax.dtypes.bfloat16
    source = data.astype(np.float32) if needs_widening else data
    return np.asarray(source)
|
|
57
|
+
|
|
58
|
+
def to_dlpack(
    data: ARRAY_TYPE,
    /,
) -> SupportsDLPack:
    """
    Return ``data`` itself as the DLPack-capable object.

    This is a module-level function, so it takes no ``self``: callers such
    as ``from_other_backend`` invoke it as ``backend.to_dlpack(data)`` with a
    single argument, which the previous ``(self, data)`` signature rejected
    with ``TypeError``.

    NOTE(review): returning the array unchanged presumes ``jax.Array``
    already implements the DLPack protocol -- confirm for the supported JAX
    versions.
    """
    return data
|
|
64
|
+
|
|
65
|
+
def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    """Report whether ``dtype`` is a sub-dtype of ``np.integer``."""
    is_integer = np.issubdtype(dtype, np.integer)
    return is_integer
|
|
69
|
+
|
|
70
|
+
def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    """
    Report whether ``dtype`` is a real floating dtype.

    bfloat16 is checked explicitly before falling back to the
    ``np.floating`` sub-dtype test.
    """
    if dtype == jax.dtypes.bfloat16:
        return True
    return np.issubdtype(dtype, np.floating)
|
|
74
|
+
|
|
75
|
+
def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    """Report whether ``dtype`` equals ``np.bool_`` or the builtin ``bool``."""
    if dtype == np.bool_:
        return True
    return dtype == bool
|
|
79
|
+
|
|
80
|
+
from xbarray.common.implementations import *
|
|
81
|
+
# Select the array-API namespace for JAX: jax.numpy when it advertises
# ``__array_api_version__``, otherwise jax.experimental.array_api.
if hasattr(jax.numpy, "__array_api_version__"):
    compat_module = jax.numpy
else:
    import jax.experimental.array_api as compat_module
# Bind the shared abbreviate_array implementation to that namespace and to
# this module's default integer dtype and dtype predicates.
abbreviate_array = get_abbreviate_array_function(
    backend=compat_module,
    default_integer_dtype=default_integer_dtype,
    func_dtype_is_real_floating=dtype_is_real_floating,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_boolean=dtype_is_boolean
)
|
|
92
|
+
def map_fn_over_arrays(
    data: Any,
    func: Callable[[ARRAY_TYPE], ARRAY_TYPE],
):
    """
    Apply ``func`` across ``data`` by delegating to ``jax.tree.map``.

    The result is whatever ``jax.tree.map(func, data)`` returns.
    """
    mapped = jax.tree.map(func, data)
    return mapped
|
xbarray/jax/_typing.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from typing import Union, Any, Optional
|
|
2
|
+
import jax
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
# Type aliases exported by a star-import of this module.
__all__ = [
    'ARRAY_TYPE',
    'DTYPE_TYPE',
    'DEVICE_TYPE',
    'RNG_TYPE',
]

# Arrays are jax.Array instances.
ARRAY_TYPE = jax.Array
# Dtypes are described with NumPy's dtype class.
DTYPE_TYPE = np.dtype
# A "device" may be either a concrete jax.Device or a jax.sharding.Sharding.
DEVICE_TYPE = Union[jax.Device, jax.sharding.Sharding]
# RNG state is itself a jax.Array.
RNG_TYPE = jax.Array
|
xbarray/jax/random.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from typing import Union, Optional, Tuple, Any
|
|
2
|
+
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
|
|
3
|
+
import numpy as np
|
|
4
|
+
import jax
|
|
5
|
+
|
|
6
|
+
# Public random-sampling functions exported by a star-import of this module.
__all__ = [
    "random_number_generator",
    "random_discrete_uniform",
    "random_uniform",
    "random_exponential",
    "random_normal",
    "random_geometric",
    "random_permutation"
]
|
|
15
|
+
|
|
16
|
+
def random_number_generator(
    seed: Optional[int] = None,
    *,
    device: Optional[DEVICE_TYPE] = None
) -> RNG_TYPE:
    """
    Create a JAX PRNG key.

    Args:
        seed: Integer seed. When ``None`` a seed is drawn from NumPy's
            global random state.
        device: Optional device/sharding the key is placed on with
            ``jax.device_put``, matching the other functions in this module.

    Returns:
        The key produced by ``jax.random.key``.
    """
    if seed is None:
        # np.random.randint(0) requests a draw from the empty range [0, 0)
        # and raises ValueError; draw from a non-empty 31-bit range instead.
        rng_seed = int(np.random.randint(0, 2**31 - 1))
    else:
        rng_seed = seed
    rng = jax.random.key(
        seed=rng_seed
    )
    if device is not None:
        rng = jax.device_put(rng, device)
    return rng
|
|
26
|
+
|
|
27
|
+
def random_discrete_uniform(
    shape: Union[int, Tuple[int, ...]],
    from_num: int,
    to_num: int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw integers via ``jax.random.randint`` with ``minval=from_num`` and
    ``maxval=to_num``.

    Args:
        shape: Shape passed to ``jax.random.randint``.
        from_num: Lower bound, converted with ``int``.
        to_num: Upper bound, converted with ``int``.
        rng: Key that is split; one half samples, the other is returned.
        dtype: Result dtype; ``int`` when ``None``.
        device: Optional target for ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    new_rng, rng = jax.random.split(rng)
    # Explicit None check: ``dtype or int`` relies on truthiness, and an
    # np.dtype instance without fields has length 0 and can test falsy,
    # which would silently replace the requested dtype with the default.
    sample_dtype = int if dtype is None else dtype
    t = jax.random.randint(rng, shape, minval=int(from_num), maxval=int(to_num), dtype=sample_dtype)
    if device is not None:
        t = jax.device_put(t, device)
    return new_rng, t
|
|
42
|
+
|
|
43
|
+
def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    low: float = 0.0, high: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples via ``jax.random.uniform`` with ``minval=low`` and
    ``maxval=high``.

    Args:
        shape: Shape passed to ``jax.random.uniform``.
        rng: Key that is split; one half samples, the other is returned.
        low: Value passed as ``minval``.
        high: Value passed as ``maxval``.
        dtype: Result dtype; ``float`` when ``None``.
        device: Optional target for ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    new_rng, rng = jax.random.split(rng)
    # Explicit None check: ``dtype or float`` relies on truthiness, and an
    # np.dtype instance without fields has length 0 and can test falsy,
    # which would silently replace the requested dtype with the default.
    sample_dtype = float if dtype is None else dtype
    data = jax.random.uniform(rng, shape, dtype=sample_dtype, minval=low, maxval=high)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
57
|
+
|
|
58
|
+
def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    lambd: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples via ``jax.random.exponential`` and divide them by ``lambd``.

    Args:
        shape: Shape passed to ``jax.random.exponential``.
        rng: Key that is split; one half samples, the other is returned.
        lambd: Divisor applied to the raw samples.
        dtype: Result dtype; ``float`` when ``None``.
        device: Optional target for ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    new_rng, rng = jax.random.split(rng)
    # Explicit None check: ``dtype or float`` relies on truthiness, and an
    # np.dtype instance without fields has length 0 and can test falsy,
    # which would silently replace the requested dtype with the default.
    sample_dtype = float if dtype is None else dtype
    data = jax.random.exponential(rng, shape, dtype=sample_dtype) / lambd
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
72
|
+
|
|
73
|
+
def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    mean: float = 0.0, std: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples via ``jax.random.normal``, scaled by ``std`` and shifted
    by ``mean``.

    This is a plain module-level function: the ``@classmethod`` decorator it
    used to carry wrapped it in a classmethod object, so it could not be
    called as ``random_normal(...)`` like its siblings.

    Args:
        shape: Shape passed to ``jax.random.normal``.
        rng: Key that is split; one half samples, the other is returned.
        mean: Offset added to the scaled samples.
        std: Factor the raw samples are multiplied by.
        dtype: Result dtype; ``float`` when ``None``.
        device: Optional target for ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    new_rng, rng = jax.random.split(rng)
    # Explicit None check: ``dtype or float`` relies on truthiness, and an
    # np.dtype instance without fields has length 0 and can test falsy,
    # which would silently replace the requested dtype with the default.
    sample_dtype = float if dtype is None else dtype
    data = jax.random.normal(rng, shape, dtype=sample_dtype) * std + mean
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
88
|
+
|
|
89
|
+
def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples via ``jax.random.geometric`` with parameter ``p``.

    Args:
        shape: Shape passed to ``jax.random.geometric``.
        p: Distribution parameter forwarded as ``p``.
        rng: Key that is split; one half samples, the other is returned.
        dtype: Result dtype; ``int`` when ``None``.
        device: Optional target for ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    new_rng, rng = jax.random.split(rng)
    # Explicit None check: ``dtype or int`` relies on truthiness, and an
    # np.dtype instance without fields has length 0 and can test falsy,
    # which would silently replace the requested dtype with the default.
    sample_dtype = int if dtype is None else dtype
    data = jax.random.geometric(rng, p=p, shape=shape, dtype=sample_dtype)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
103
|
+
|
|
104
|
+
def random_permutation(
    n: int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Produce a random permutation via ``jax.random.permutation(key, n)``.

    Args:
        n: Integer argument forwarded to ``jax.random.permutation``.
        rng: Key that is split; one half samples, the other is returned.
        dtype: When given, the result is cast to it with ``astype``.
        device: Optional target for ``jax.device_put``.

    Returns:
        ``(new_rng, permutation)``.
    """
    new_rng, rng = jax.random.split(rng)
    # jax.random.permutation accepts no ``dtype`` keyword, so passing one
    # raised TypeError on every call; cast the result afterwards instead.
    data = jax.random.permutation(rng, n)
    if dtype is not None:
        data = data.astype(dtype)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from array_api_compat.numpy import *
|
|
2
|
+
|
|
3
|
+
simplified_name = "numpy"

# array_api_compat's NumPy wrapper is the namespace bound into the helpers below.
from array_api_compat import numpy as compat_module
# Import and bind all functions from array_api_extra before exposing them
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    if api_name.startswith('_'):
        continue
    if not callable(getattr(array_api_extra, api_name)):
        # partial() raises TypeError for a non-callable first argument, which
        # would abort the import of this package; skip such attributes.
        continue
    globals()[api_name] = partial(
        getattr(array_api_extra, api_name),
        xp=compat_module
    )

from ._typing import *
from ._extra import *
# Import the ``random`` submodule by name so it becomes an attribute of this package.
__import__(__package__ + ".random")
|