xbarray 0.0.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xbarray might be problematic. Click here for more details.

Files changed (35) hide show
  1. array_api_typing/__init__.py +9 -0
  2. array_api_typing/typing_2024_12/__init__.py +12 -0
  3. array_api_typing/typing_2024_12/_api_constant.py +32 -0
  4. array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
  5. array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
  6. array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
  7. array_api_typing/typing_2024_12/_api_typing.py +5855 -0
  8. array_api_typing/typing_2024_12/_array_typing.py +1265 -0
  9. array_api_typing/typing_compat/__init__.py +12 -0
  10. array_api_typing/typing_compat/_api_typing.py +27 -0
  11. array_api_typing/typing_compat/_array_typing.py +36 -0
  12. array_api_typing/typing_extra/__init__.py +12 -0
  13. array_api_typing/typing_extra/_api_typing.py +651 -0
  14. array_api_typing/typing_extra/_at.py +87 -0
  15. xbarray/__init__.py +1 -0
  16. xbarray/base/__init__.py +1 -0
  17. xbarray/base/base.py +199 -0
  18. xbarray/common/implementations.py +61 -0
  19. xbarray/jax/__init__.py +25 -0
  20. xbarray/jax/_extra.py +98 -0
  21. xbarray/jax/_typing.py +15 -0
  22. xbarray/jax/random.py +116 -0
  23. xbarray/numpy/__init__.py +19 -0
  24. xbarray/numpy/_extra.py +83 -0
  25. xbarray/numpy/_typing.py +14 -0
  26. xbarray/numpy/random.py +106 -0
  27. xbarray/pytorch/__init__.py +20 -0
  28. xbarray/pytorch/_extra.py +109 -0
  29. xbarray/pytorch/_typing.py +13 -0
  30. xbarray/pytorch/random.py +102 -0
  31. xbarray-0.0.1a1.dist-info/METADATA +14 -0
  32. xbarray-0.0.1a1.dist-info/RECORD +35 -0
  33. xbarray-0.0.1a1.dist-info/WHEEL +5 -0
  34. xbarray-0.0.1a1.dist-info/licenses/LICENSE +21 -0
  35. xbarray-0.0.1a1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,87 @@
1
+ from typing import Protocol, TypeVar, Optional, Any, Tuple, Union, Type, TypedDict, List
2
+ from abc import abstractmethod
3
+ from array_api_typing.typing_compat._api_typing import _NAMESPACE_ARRAY
4
+ from array_api_typing.typing_compat._array_typing import SetIndex
5
+
6
class AtResult(Protocol[_NAMESPACE_ARRAY]):
    """Structural type for the object produced by an ``at(x)`` helper.

    Indexing it narrows the selection; each update method takes a value ``y``
    (an array or a Python scalar) plus an optional ``copy`` flag and returns an
    array. The method set mirrors JAX's ``x.at[idx].<op>(y)`` API; the exact
    meaning of ``copy`` is left to the implementation (NOTE(review): presumably
    whether the input may be modified in place -- confirm against the
    implementation being typed).
    """

    @abstractmethod
    def __getitem__(self, idx: SetIndex, /) -> "AtResult[_NAMESPACE_ARRAY]":
        """Support the ``at(x)[start:stop:step]`` spelling.

        This reads better than ``at(x, slice(start, stop, step))`` and matches
        how the JAX documentation writes indexed updates.
        """
        ...

    @abstractmethod
    def set(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].set(y)``."""
        ...

    @abstractmethod
    def add(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].add(y)``."""
        ...

    @abstractmethod
    def subtract(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].subtract(y)``."""
        ...

    @abstractmethod
    def multiply(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].multiply(y)``."""
        ...

    @abstractmethod
    def divide(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].divide(y)``."""
        ...

    @abstractmethod
    def power(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].power(y)``."""
        ...

    @abstractmethod
    def min(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].min(y)``."""
        ...

    @abstractmethod
    def max(self, y: Union[_NAMESPACE_ARRAY, float, int, complex], /, copy: Optional[bool] = None) -> _NAMESPACE_ARRAY:
        """Indexed update analogous to JAX's ``.at[...].max(y)``."""
        ...
xbarray/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .base import *
@@ -0,0 +1 @@
1
+ from .base import *
xbarray/base/base.py ADDED
@@ -0,0 +1,199 @@
1
+ from typing import Optional, Generic, TypeVar, Dict, Union, Any, Sequence, SupportsFloat, Tuple, Type, Callable, Mapping, Protocol
2
+ import abc
3
+ from array_api_typing.typing_extra import *
4
+ import numpy as np
5
+
6
# xbarray-prefixed aliases for the array-API index types.
ArrayAPISetIndex = SetIndex
ArrayAPIGetIndex = GetIndex

__all__ = [
    # Backend protocols defined in this module.
    "RNGBackend",
    "ComputeBackend",
    # Type variables that parameterize the protocols.
    "BArrayType",
    "BDeviceType",
    "BDtypeType",
    "BRNGType",
    # Names re-exported from array_api_typing.typing_extra.
    "SupportsDLPack",
    "ArrayAPIArray",
    "ArrayAPIDevice",
    "ArrayAPIDType",
    "ArrayAPINamespace",
    "ArrayAPISetIndex",
    "ArrayAPIGetIndex",
]

# Covariant type parameters; the first three are bounded by the array-API
# protocols, the generator type is left unconstrained.
BArrayType = TypeVar("BArrayType", bound=ArrayAPIArray, covariant=True)
BDeviceType = TypeVar("BDeviceType", bound=ArrayAPIDevice, covariant=True)
BDtypeType = TypeVar("BDtypeType", bound=ArrayAPIDType, covariant=True)
BRNGType = TypeVar("BRNGType", covariant=True)
29
class RNGBackend(Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """Protocol for a backend's random-number facilities.

    Generator state is passed explicitly through the keyword-only ``rng``
    argument. Every sampling method returns a ``(rng, samples)`` pair, so an
    implementation can hand back an updated generator together with the
    samples.
    """

    @abc.abstractmethod
    def random_number_generator(
        self,
        seed: Optional[int] = None,
        *,
        device: Optional[BDeviceType] = None
    ) -> BRNGType:
        """Create a generator, seeded with ``seed`` when one is given."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_discrete_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        from_num: int,
        to_num: int,
        /,
        *,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """
        Sample from a discrete uniform distribution [from_num, to_num) with shape `shape`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def random_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        low: float = 0.0,
        high: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample continuous uniform values between ``low`` and ``high``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_exponential(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        lambd: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample from an exponential distribution parameterized by ``lambd``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_normal(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        mean: float = 0.0,
        std: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample from a normal distribution with the given ``mean`` and ``std``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_geometric(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        p: float,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample from a geometric distribution with parameter ``p``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_permutation(
        self,
        n: int,
        /,
        *,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None
    ) -> Tuple[BRNGType, BArrayType]:
        """Return a random permutation of ``n`` elements."""
        raise NotImplementedError
119
+
120
class ComputeBackend(ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType], Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """Array-API namespace extended with xbarray's backend helpers.

    Attributes:
        simplified_name: short identifier of the backend.
        ARRAY_TYPE / DEVICE_TYPE / DTYPE_TYPE / RNG_TYPE: concrete types used
            by this backend for arrays, devices, dtypes and RNG state.
        default_integer_dtype / default_floating_dtype / default_boolean_dtype:
            dtypes used when callers do not specify one.
        random: the backend's ``RNGBackend`` implementation.
    """
    simplified_name: str
    ARRAY_TYPE: Type[BArrayType]
    DEVICE_TYPE: Type[BDeviceType]
    DTYPE_TYPE: Type[BDtypeType]
    RNG_TYPE: Type[BRNGType]
    default_integer_dtype: BDtypeType
    default_floating_dtype: BDtypeType
    default_boolean_dtype: BDtypeType
    random: RNGBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]

    @abc.abstractmethod
    def is_backendarray(self, data: Any) -> bool:
        """Return whether ``data`` is an array native to this backend."""
        raise NotImplementedError

    @abc.abstractmethod
    def from_numpy(
        self,
        data: np.ndarray,
        /,
        *,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None
    ) -> BArrayType:
        """Convert a NumPy array, optionally to ``dtype`` and onto ``device``."""
        raise NotImplementedError

    @abc.abstractmethod
    def from_other_backend(
        self,
        other_backend: "ComputeBackend",
        data: ArrayAPIArray,
        /
    ) -> BArrayType:
        """
        Convert an array from another backend to this backend.
        The other backend must be compatible with the ArrayAPI.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to_numpy(
        self,
        data: BArrayType
    ) -> np.ndarray:
        """Convert a backend array to a NumPy array."""
        raise NotImplementedError

    @abc.abstractmethod
    def to_dlpack(
        self,
        data: BArrayType,
        /,
    ) -> SupportsDLPack:
        """Return an object supporting the DLPack protocol for ``data``."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_integer(self, dtype: BDtypeType) -> bool:
        """Return whether ``dtype`` is a real integer dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_floating(self, dtype: BDtypeType) -> bool:
        """Return whether ``dtype`` is a real floating-point dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_boolean(self, dtype: BDtypeType) -> bool:
        """Return whether ``dtype`` is the boolean dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def abbreviate_array(self, x: BArrayType, try_cast_scalar: bool = True) -> Union[float, int, BArrayType]:
        """
        Abbreviates an array to a single element if possible.
        Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).
        """
        # Previously ``pass`` (silently returning None); raise like every other
        # abstract member so an accidental super() call fails loudly.
        raise NotImplementedError

    @abc.abstractmethod
    def map_fn_over_arrays(self, data: Any, func: Callable[[BArrayType], BArrayType]) -> Any:
        """
        Map a function over arrays in a data structure and produce a new data structure with the same shape.
        This is useful for applying a function to all arrays in a nested structure.
        """
        # Previously the body was only a docstring (implicitly returning None).
        raise NotImplementedError
@@ -0,0 +1,61 @@
1
+ from typing import Any, Union, Callable
2
+ from array_api_typing.typing_compat import ArrayAPINamespace as CompatNamespace, ArrayAPIArray as CompatArray, ArrayAPIDType as CompatDType
3
+
4
# Public factory functions shared by the concrete backends.
__all__ = ["get_abbreviate_array_function", "get_map_fn_over_arrays_function"]
8
+
9
def get_abbreviate_array_function(
    backend: "CompatNamespace[CompatArray, Any, Any]",
    default_integer_dtype: "CompatDType",
    func_dtype_is_real_floating: "Callable[[CompatDType], bool]",
    func_dtype_is_real_integer: "Callable[[CompatDType], bool]",
    func_dtype_is_boolean: "Callable[[CompatDType], bool]",
):
    """Build an ``abbreviate_array`` function bound to ``backend``.

    Args:
        backend: array-API namespace providing ``zeros``, ``device``, ``take``
            and ``all``.
        default_integer_dtype: dtype used for the internal index array.
        func_dtype_is_real_floating / func_dtype_is_real_integer /
        func_dtype_is_boolean: dtype predicates used to pick the Python scalar
            type when an abbreviated array collapses to one element.

    Returns:
        The ``abbreviate_array(array, try_cast_scalar=True)`` closure.
    """

    def abbreviate_array(array: "CompatArray", try_cast_scalar: bool = True) -> "Union[float, int, CompatArray]":
        """
        Abbreviates an array to a single element if possible.
        Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).

        Each axis whose values are all equal to its first slice is shrunk to
        length 1. If ``try_cast_scalar`` is true and every axis ends up with
        length 1, the single element is returned as a Python ``float``, ``int``
        or ``bool``; otherwise the (possibly shrunk) array is returned.

        Raises:
            ValueError: a scalar cast was requested but the dtype is not real
                floating, integer or boolean.
        """
        abbr_array = array
        idx = backend.zeros(1, dtype=default_integer_dtype, device=backend.device(abbr_array))
        for dim_i in range(len(array.shape)):
            dim_size = abbr_array.shape[dim_i]
            # A length-1 axis is already minimal, and a length-0 axis has no
            # first element to take (``take`` would fail on it).
            if dim_size is not None and dim_size <= 1:
                continue
            first_elem = backend.take(abbr_array, idx, axis=dim_i)
            if backend.all(abbr_array == first_elem):
                abbr_array = first_elem

        if try_cast_scalar and all(i == 1 for i in abbr_array.shape):
            elem = abbr_array[(0,) * len(abbr_array.shape)]
            if func_dtype_is_real_floating(elem.dtype):
                return float(elem)
            if func_dtype_is_real_integer(elem.dtype):
                return int(elem)
            if func_dtype_is_boolean(elem.dtype):
                return bool(elem)
            raise ValueError(f"Abbreviated array element dtype must be a real floating or integer or boolean type, actual dtype: {elem.dtype}")
        # Return the abbreviated array, not the untouched input.
        return abbr_array

    return abbreviate_array
43
+
44
def get_map_fn_over_arrays_function(
    is_backendarray: Callable[[Any], bool],
):
    """Build a ``map_fn_over_arrays`` that detects arrays with ``is_backendarray``."""

    def map_fn_over_arrays(data: Any, func: Callable[[CompatArray], CompatArray]) -> Any:
        """
        Map a function to the data.

        Recurses through dicts (rebuilt as plain dicts), tuples and lists,
        applying ``func`` to every backend array and returning any other
        value unchanged.
        """
        if is_backendarray(data):
            return func(data)
        if isinstance(data, dict):
            return {key: map_fn_over_arrays(value, func) for key, value in data.items()}
        if isinstance(data, (tuple, list)):
            mapped = [map_fn_over_arrays(item, func) for item in data]
            return tuple(mapped) if isinstance(data, tuple) else mapped
        return data

    return map_fn_over_arrays
@@ -0,0 +1,25 @@
1
import jax.numpy

simplified_name = "jax"

if hasattr(jax.numpy, "__array_api_version__"):
    compat_module = jax.numpy
    from jax.numpy import *
else:
    import jax.experimental.array_api as compat_module
    from jax.experimental.array_api import *

# Import and bind all functions from array_api_extra before exposing them.
import array_api_extra
import inspect as _inspect
from functools import partial


def _accepts_xp(obj) -> bool:
    """Return True when ``obj`` is callable and declares an ``xp`` parameter usable by keyword."""
    if not callable(obj):
        return False
    try:
        parameters = _inspect.signature(obj).parameters
    except (TypeError, ValueError):
        # Signature not introspectable; don't guess whether xp is accepted.
        return False
    param = parameters.get("xp")
    return param is not None and param.kind in (
        _inspect.Parameter.POSITIONAL_OR_KEYWORD,
        _inspect.Parameter.KEYWORD_ONLY,
    )


def _export_array_api_extra(namespace, xp) -> None:
    """Copy public ``array_api_extra`` names into ``namespace``, pre-binding ``xp`` where accepted."""
    for api_name in dir(array_api_extra):
        if api_name.startswith('_'):
            continue
        api_obj = getattr(array_api_extra, api_name)
        # Binding xp onto non-callables raises at import time, and binding it
        # onto callables without an xp parameter (e.g. the ``at`` class) makes
        # every later call fail; expose those unchanged instead.
        namespace[api_name] = partial(api_obj, xp=xp) if _accepts_xp(api_obj) else api_obj


_export_array_api_extra(globals(), compat_module)
del _accepts_xp, _export_array_api_extra

from ._typing import *
from ._extra import *
__import__(__package__ + ".random")
xbarray/jax/_extra.py ADDED
@@ -0,0 +1,98 @@
1
+ from typing import Any, Union, Optional, Callable
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
6
+ from xbarray.base import ComputeBackend, SupportsDLPack
7
+
8
# Public helpers re-exported by ``xbarray.jax`` via ``from ._extra import *``.
__all__ = [
    "default_integer_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    "abbreviate_array",
    # Required by the ComputeBackend protocol; previously defined below but
    # never exported, so ``xbarray.jax.map_fn_over_arrays`` did not exist.
    "map_fn_over_arrays",
]
22
+
23
# Default dtypes are the Python builtin types; JAX maps them onto its own
# canonical integer / floating / boolean dtypes.
default_integer_dtype, default_floating_dtype, default_boolean_dtype = int, float, bool
26
+
27
def is_backendarray(data: Any) -> bool:
    """Return True when ``data`` is a ``jax.Array`` instance."""
    array_cls = jax.Array
    return isinstance(data, array_cls)
29
+
30
def from_numpy(
    data: np.ndarray,
    /,
    *,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> ARRAY_TYPE:
    """Convert a NumPy array to a JAX array via ``jax.numpy.asarray``.

    ``dtype`` and ``device`` are forwarded unchanged; ``None`` keeps JAX's defaults.
    """
    conversion_options = {"dtype": dtype, "device": device}
    return jax.numpy.asarray(data, **conversion_options)
38
+
39
def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    """Convert ``data`` from ``other_backend`` into a JAX array via DLPack."""
    # NOTE: a fallback through NumPy (``other_backend.to_numpy`` followed by
    # ``from_numpy``) was sketched for JAX tiling issues with DLPack-converted
    # data, but it is not enabled; conversion failures propagate to the caller.
    return jax.dlpack.from_dlpack(other_backend.to_dlpack(data))
50
+
51
def to_numpy(
    data: ARRAY_TYPE
) -> np.ndarray:
    """Convert a JAX array to a NumPy array.

    bfloat16 input is upcast to float32 first, so the result uses a standard
    NumPy floating dtype.
    """
    is_bfloat16 = data.dtype == jax.dtypes.bfloat16
    source = data.astype(np.float32) if is_bfloat16 else data
    return np.asarray(source)
57
+
58
def to_dlpack(
    data: "ARRAY_TYPE",
    /,
) -> "SupportsDLPack":
    """Return ``data`` as a DLPack-capable object.

    JAX arrays are passed through unchanged; they are expected to expose the
    DLPack protocol themselves.

    The original signature carried a stray ``self`` parameter although this is
    a module-level function. Calling it as ``backend.to_dlpack(data)`` therefore
    raised TypeError for the missing ``data`` argument.
    """
    return data
64
+
65
def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is a NumPy integer kind (signed or unsigned)."""
    integer_kind = np.integer
    return np.issubdtype(dtype, integer_kind)
69
+
70
def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is a real floating dtype, including bfloat16."""
    # bfloat16 is checked explicitly because it is not part of NumPy's own
    # floating hierarchy.
    if dtype == jax.dtypes.bfloat16:
        return True
    return np.issubdtype(dtype, np.floating)
74
+
75
def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` compares equal to NumPy's or Python's bool."""
    matches_numpy_bool = dtype == np.bool_
    return matches_numpy_bool or dtype == bool
79
+
80
from xbarray.common.implementations import *

# Use jax.numpy directly when it advertises array-API support; otherwise
# fall back to the experimental array-API namespace.
if not hasattr(jax.numpy, "__array_api_version__"):
    import jax.experimental.array_api as compat_module
else:
    compat_module = jax.numpy

abbreviate_array = get_abbreviate_array_function(
    backend=compat_module,
    default_integer_dtype=default_integer_dtype,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_real_floating=dtype_is_real_floating,
    func_dtype_is_boolean=dtype_is_boolean,
)
92
def map_fn_over_arrays(
    data: Any, func: Callable[[ARRAY_TYPE], ARRAY_TYPE]
):
    """Apply ``func`` to every JAX array inside the pytree ``data``.

    ``jax.tree.map`` visits every leaf, including Python scalars and strings.
    Non-array leaves are passed through unchanged, which matches the
    ComputeBackend contract ("map a function over arrays") and the generic
    implementation used by the other backends.
    """
    return jax.tree.map(
        lambda leaf: func(leaf) if is_backendarray(leaf) else leaf,
        data,
    )
xbarray/jax/_typing.py ADDED
@@ -0,0 +1,15 @@
1
+ from typing import Union, Any, Optional
2
+ import jax
3
+ import numpy as np
4
+
5
__all__ = ['ARRAY_TYPE', 'DTYPE_TYPE', 'DEVICE_TYPE', 'RNG_TYPE']

# Concrete JAX types that instantiate the generic backend type parameters.
ARRAY_TYPE = jax.Array
# JAX PRNG state is itself a jax.Array.
RNG_TYPE = ARRAY_TYPE
DTYPE_TYPE = np.dtype
# Placement target: a single device or a sharding.
DEVICE_TYPE = Union[jax.Device, jax.sharding.Sharding]
xbarray/jax/random.py ADDED
@@ -0,0 +1,116 @@
1
+ from typing import Union, Optional, Tuple, Any
2
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
3
+ import numpy as np
4
+ import jax
5
+
6
# Public sampling API of the JAX backend.
__all__ = [
    "random_number_generator",
    "random_discrete_uniform", "random_uniform", "random_exponential",
    "random_normal", "random_geometric", "random_permutation",
]
15
+
16
def random_number_generator(
    seed: Optional[int] = None,
    *,
    device: "Optional[DEVICE_TYPE]" = None
) -> "RNG_TYPE":
    """Create a JAX PRNG key.

    Args:
        seed: integer seed; when ``None`` a seed is drawn from NumPy's global RNG.
        device: if given, the key is placed there with ``jax.device_put``.

    Returns:
        A new key created by ``jax.random.key``.
    """
    if seed is None:
        # ``np.random.randint(0)`` always raised ValueError (empty range [0, 0)),
        # so unseeded generators could never be created; draw from a real range.
        seed = np.random.randint(0, np.iinfo(np.int32).max)
    rng = jax.random.key(seed=seed)
    if device is not None:
        # Previously the device argument was silently ignored.
        rng = jax.device_put(rng, device)
    return rng
26
+
27
def random_discrete_uniform(
    shape: Union[int, Tuple[int, ...]],
    from_num: int,
    to_num: int,
    /,
    *,
    rng: "RNG_TYPE",
    dtype: "Optional[DTYPE_TYPE]" = None,
    device: "Optional[DEVICE_TYPE]" = None
) -> "Tuple[RNG_TYPE, ARRAY_TYPE]":
    """Sample integers uniformly from ``[from_num, to_num)``.

    Returns ``(new_rng, samples)``: the key is split, one half is consumed for
    sampling and the other is returned for later calls. ``dtype`` defaults to
    ``int``; ``device``, if given, is applied with ``jax.device_put``.
    """
    # jax.random rejects a bare int shape, which the backend protocol allows.
    shape = (int(shape),) if isinstance(shape, (int, np.integer)) else tuple(shape)
    # Explicit None check: ``dtype or int`` would discard dtype objects that
    # evaluate as falsy (e.g. np.dtype instances with no fields).
    dtype = int if dtype is None else dtype
    new_rng, rng = jax.random.split(rng)
    t = jax.random.randint(rng, shape, minval=int(from_num), maxval=int(to_num), dtype=dtype)
    if device is not None:
        t = jax.device_put(t, device)
    return new_rng, t
42
+
43
def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: "RNG_TYPE",
    low: float = 0.0, high: float = 1.0,
    dtype: "Optional[DTYPE_TYPE]" = None,
    device: "Optional[DEVICE_TYPE]" = None
) -> "Tuple[RNG_TYPE, ARRAY_TYPE]":
    """Sample floats uniformly between ``low`` and ``high`` (``jax.random.uniform``).

    Returns ``(new_rng, samples)``. ``dtype`` defaults to ``float``; ``device``,
    if given, is applied with ``jax.device_put``.
    """
    # jax.random rejects a bare int shape, which the backend protocol allows.
    shape = (int(shape),) if isinstance(shape, (int, np.integer)) else tuple(shape)
    # Explicit None check so falsy-evaluating dtype objects are honoured.
    dtype = float if dtype is None else dtype
    new_rng, rng = jax.random.split(rng)
    data = jax.random.uniform(rng, shape, dtype=dtype, minval=low, maxval=high)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
57
+
58
def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: "RNG_TYPE",
    lambd: float = 1.0,
    dtype: "Optional[DTYPE_TYPE]" = None,
    device: "Optional[DEVICE_TYPE]" = None
) -> "Tuple[RNG_TYPE, ARRAY_TYPE]":
    """Sample from an exponential distribution with rate ``lambd``.

    Unit-rate samples from ``jax.random.exponential`` are divided by ``lambd``.
    Returns ``(new_rng, samples)``. ``dtype`` defaults to ``float``; ``device``,
    if given, is applied with ``jax.device_put``.
    """
    # jax.random rejects a bare int shape, which the backend protocol allows.
    shape = (int(shape),) if isinstance(shape, (int, np.integer)) else tuple(shape)
    # Explicit None check so falsy-evaluating dtype objects are honoured.
    dtype = float if dtype is None else dtype
    new_rng, rng = jax.random.split(rng)
    data = jax.random.exponential(rng, shape, dtype=dtype) / lambd
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
72
+
73
def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: "RNG_TYPE",
    mean: float = 0.0, std: float = 1.0,
    dtype: "Optional[DTYPE_TYPE]" = None,
    device: Optional[Any] = None
) -> "Tuple[RNG_TYPE, ARRAY_TYPE]":
    """Sample from a normal distribution with the given ``mean`` and ``std``.

    Standard-normal samples are scaled by ``std`` and shifted by ``mean``.
    Returns ``(new_rng, samples)``. ``dtype`` defaults to ``float``; ``device``,
    if given, is applied with ``jax.device_put``.
    """
    # The original carried a stray ``@classmethod`` decorator. On a
    # module-level function that makes the name a classmethod object, so every
    # call failed; the decorator has been removed.
    # jax.random rejects a bare int shape, which the backend protocol allows.
    shape = (int(shape),) if isinstance(shape, (int, np.integer)) else tuple(shape)
    # Explicit None check so falsy-evaluating dtype objects are honoured.
    dtype = float if dtype is None else dtype
    new_rng, rng = jax.random.split(rng)
    data = jax.random.normal(rng, shape, dtype=dtype) * std + mean
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
88
+
89
def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: "RNG_TYPE",
    dtype: "Optional[DTYPE_TYPE]" = None,
    device: Optional[Any] = None
) -> "Tuple[RNG_TYPE, ARRAY_TYPE]":
    """Sample from a geometric distribution with parameter ``p`` (``jax.random.geometric``).

    Returns ``(new_rng, samples)``. ``dtype`` defaults to ``int``; ``device``,
    if given, is applied with ``jax.device_put``.
    """
    # jax.random rejects a bare int shape, which the backend protocol allows.
    shape = (int(shape),) if isinstance(shape, (int, np.integer)) else tuple(shape)
    # Explicit None check so falsy-evaluating dtype objects are honoured.
    dtype = int if dtype is None else dtype
    new_rng, rng = jax.random.split(rng)
    data = jax.random.geometric(rng, p=p, shape=shape, dtype=dtype)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
103
+
104
def random_permutation(
    n: int,
    /,
    *,
    rng: "RNG_TYPE",
    dtype: "Optional[DTYPE_TYPE]" = None,
    device: "Optional[DEVICE_TYPE]" = None
) -> "Tuple[RNG_TYPE, ARRAY_TYPE]":
    """Return a random permutation of ``0 .. n-1``.

    Returns ``(new_rng, permutation)``. If ``dtype`` is given the result is cast
    to it; otherwise JAX's default integer dtype is kept. ``device``, if given,
    is applied with ``jax.device_put``.
    """
    new_rng, rng = jax.random.split(rng)
    # ``jax.random.permutation`` has no ``dtype`` keyword (passing one raised
    # TypeError), so cast afterwards instead.
    data = jax.random.permutation(rng, n)
    if dtype is not None:
        data = data.astype(dtype)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
@@ -0,0 +1,19 @@
1
from array_api_compat.numpy import *

simplified_name = "numpy"

from array_api_compat import numpy as compat_module
# Import and bind all functions from array_api_extra before exposing them.
import array_api_extra
import inspect as _inspect
from functools import partial


def _accepts_xp(obj) -> bool:
    """Return True when ``obj`` is callable and declares an ``xp`` parameter usable by keyword."""
    if not callable(obj):
        return False
    try:
        parameters = _inspect.signature(obj).parameters
    except (TypeError, ValueError):
        # Signature not introspectable; don't guess whether xp is accepted.
        return False
    param = parameters.get("xp")
    return param is not None and param.kind in (
        _inspect.Parameter.POSITIONAL_OR_KEYWORD,
        _inspect.Parameter.KEYWORD_ONLY,
    )


def _export_array_api_extra(namespace, xp) -> None:
    """Copy public ``array_api_extra`` names into ``namespace``, pre-binding ``xp`` where accepted."""
    for api_name in dir(array_api_extra):
        if api_name.startswith('_'):
            continue
        api_obj = getattr(array_api_extra, api_name)
        # Binding xp onto non-callables raises at import time, and binding it
        # onto callables without an xp parameter (e.g. the ``at`` class) makes
        # every later call fail; expose those unchanged instead.
        namespace[api_name] = partial(api_obj, xp=xp) if _accepts_xp(api_obj) else api_obj


_export_array_api_extra(globals(), compat_module)
del _accepts_xp, _export_array_api_extra

from ._typing import *
from ._extra import *
__import__(__package__ + ".random")