xbarray 0.0.1a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- array_api_typing/__init__.py +9 -0
- array_api_typing/typing_2024_12/__init__.py +12 -0
- array_api_typing/typing_2024_12/_api_constant.py +32 -0
- array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
- array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
- array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
- array_api_typing/typing_2024_12/_api_typing.py +5855 -0
- array_api_typing/typing_2024_12/_array_typing.py +1265 -0
- array_api_typing/typing_compat/__init__.py +12 -0
- array_api_typing/typing_compat/_api_typing.py +27 -0
- array_api_typing/typing_compat/_array_typing.py +36 -0
- array_api_typing/typing_extra/__init__.py +12 -0
- array_api_typing/typing_extra/_api_typing.py +651 -0
- array_api_typing/typing_extra/_at.py +87 -0
- xbarray/__init__.py +1 -0
- xbarray/backends/_cls_base.py +9 -0
- xbarray/backends/_implementations/_common/implementations.py +87 -0
- xbarray/backends/_implementations/jax/__init__.py +33 -0
- xbarray/backends/_implementations/jax/_extra.py +127 -0
- xbarray/backends/_implementations/jax/_typing.py +15 -0
- xbarray/backends/_implementations/jax/random.py +115 -0
- xbarray/backends/_implementations/numpy/__init__.py +25 -0
- xbarray/backends/_implementations/numpy/_extra.py +98 -0
- xbarray/backends/_implementations/numpy/_typing.py +14 -0
- xbarray/backends/_implementations/numpy/random.py +105 -0
- xbarray/backends/_implementations/pytorch/__init__.py +26 -0
- xbarray/backends/_implementations/pytorch/_extra.py +135 -0
- xbarray/backends/_implementations/pytorch/_typing.py +13 -0
- xbarray/backends/_implementations/pytorch/random.py +101 -0
- xbarray/backends/base.py +218 -0
- xbarray/backends/jax.py +19 -0
- xbarray/backends/numpy.py +19 -0
- xbarray/backends/pytorch.py +22 -0
- xbarray/jax.py +4 -0
- xbarray/numpy.py +4 -0
- xbarray/pytorch.py +4 -0
- xbarray/transformations/pointcloud/__init__.py +1 -0
- xbarray/transformations/pointcloud/base.py +449 -0
- xbarray/transformations/pointcloud/jax.py +24 -0
- xbarray/transformations/pointcloud/numpy.py +23 -0
- xbarray/transformations/pointcloud/pytorch.py +23 -0
- xbarray/transformations/rotation_conversions/__init__.py +1 -0
- xbarray/transformations/rotation_conversions/base.py +713 -0
- xbarray/transformations/rotation_conversions/jax.py +41 -0
- xbarray/transformations/rotation_conversions/numpy.py +41 -0
- xbarray/transformations/rotation_conversions/pytorch.py +41 -0
- xbarray-0.0.1a13.dist-info/METADATA +20 -0
- xbarray-0.0.1a13.dist-info/RECORD +51 -0
- xbarray-0.0.1a13.dist-info/WHEEL +5 -0
- xbarray-0.0.1a13.dist-info/licenses/LICENSE +21 -0
- xbarray-0.0.1a13.dist-info/top_level.txt +2 -0

xbarray/backends/_implementations/numpy/random.py
ADDED

@@ -0,0 +1,105 @@
from typing import Union, Optional, Tuple, Any
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
import numpy as np

__all__ = [
    "random_number_generator",
    "random_discrete_uniform",
    "random_uniform",
    "random_exponential",
    "random_normal",
    "random_geometric",
    "random_permutation"
]

def random_number_generator(
    seed : Optional[int] = None,
    *,
    device : Optional[DEVICE_TYPE] = None
) -> RNG_TYPE:
    return np.random.default_rng(seed)

def random_discrete_uniform(
    shape : Union[int, Tuple[int, ...]],
    /,
    from_num : int,
    to_num : int,
    *,
    rng : RNG_TYPE,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = rng.integers(int(from_num), int(to_num), size=shape)
    if dtype is not None:
        t = t.astype(dtype)
    return rng, t

def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    low : float = 0.0, high : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = rng.uniform(float(low), float(high), size=shape)
    if dtype is not None:
        t = t.astype(dtype)
    return rng, t

def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    lambd : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = rng.exponential(1.0 / float(lambd), size=shape)
    if dtype is not None:
        t = t.astype(dtype)
    return rng, t

def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    mean : float = 0.0, std : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = rng.normal(mean, std, size=shape)
    if dtype is not None:
        t = t.astype(dtype)
    return rng, t


def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = rng.geometric(p, size=shape)
    if dtype is not None:
        t = t.astype(dtype)
    return rng, t

def random_permutation(
    n : int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = rng.permutation(n)
    if dtype is not None:
        t = t.astype(dtype)
    return rng, t
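
Every sampler above follows the same convention: the generator is threaded through explicitly as a keyword argument and returned alongside the sample. For NumPy the returned `Generator` is the same (mutated) object, but the `(rng, array)` pair keeps the signature compatible with stateless backends such as JAX, where a fresh key must be handed back after each draw. A minimal usage sketch, importing the module from its private path as published:

```python
# Sketch only: demonstrates the (rng, array) threading convention above.
from xbarray.backends._implementations.numpy import random as np_random

rng = np_random.random_number_generator(seed=42)
rng, u = np_random.random_uniform((3, 4), rng=rng, low=-1.0, high=1.0)
rng, p = np_random.random_permutation(10, rng=rng)  # NumPy hands back the same Generator
```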

xbarray/backends/_implementations/pytorch/__init__.py
ADDED

@@ -0,0 +1,26 @@
from array_api_compat.torch import *
from array_api_compat.torch import __array_api_version__, __array_namespace_info__
from array_api_compat.common._helpers import *

simplified_name = "pytorch"

from array_api_compat import torch as compat_module

# Import and bind all functions from array_api_extra before exposing them
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    if api_name.startswith('_'):
        continue

    if api_name in ['at', 'broadcast_shapes']:
        globals()[api_name] = getattr(array_api_extra, api_name)
    else:
        globals()[api_name] = partial(
            getattr(array_api_extra, api_name),
            xp=compat_module
        )

from ._typing import *
from ._extra import *
__import__(__package__ + ".random")
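
The loop above re-exports every public `array_api_extra` function with its `xp` keyword pre-bound to the torch-compat namespace via `functools.partial`; `at` and `broadcast_shapes` are passed through unchanged, presumably because they do not take an `xp` argument. A hedged sketch of the effect, assuming `array_api_extra` exposes `sinc` with an `xp` keyword as in current releases:

```python
import torch
from xbarray.backends._implementations import pytorch as backend

x = torch.linspace(-3.0, 3.0, 7)
# Equivalent to: array_api_extra.sinc(x, xp=array_api_compat.torch)
y = backend.sinc(x)
```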

xbarray/backends/_implementations/pytorch/_extra.py
ADDED

@@ -0,0 +1,135 @@
from typing import Any, Union, Optional
import numpy as np
import torch
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
from xbarray.backends.base import ComputeBackend, SupportsDLPack

PYTORCH_DTYPE_CAST_MAP = {
    torch.uint16: torch.int16,
    torch.uint32: torch.int32,
    torch.uint64: torch.int64,
    torch.float8_e4m3fn: torch.float16,
    torch.float8_e5m2: torch.float16,
}

__all__ = [
    "default_integer_dtype",
    "default_index_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    "serialize_device",
    "deserialize_device",
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    "abbreviate_array",
    "map_fn_over_arrays",
    "pad_dim",
]

default_integer_dtype = torch.int32
default_index_dtype = torch.long
default_floating_dtype = torch.float32
default_boolean_dtype = torch.bool

def serialize_device(device : Optional[DEVICE_TYPE]) -> Optional[str]:
    if device is None:
        return None
    if isinstance(device, str):
        return device
    assert isinstance(device, torch.device)
    return f"{device.type}:{device.index if device.index is not None else 0}"

def deserialize_device(device_str : Optional[str]) -> Optional[DEVICE_TYPE]:
    if device_str is None:
        return None
    if ":" in device_str:
        device_type, index_str = device_str.split(":")
        index = int(index_str)
        return torch.device(device_type, index)
    else:
        return torch.device(device_str)

def is_backendarray(data : Any) -> bool:
    return isinstance(data, torch.Tensor)

def from_numpy(
    data : np.ndarray,
    /,
    *,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> ARRAY_TYPE:
    t = torch.from_numpy(data)
    target_dtype = dtype if dtype is not None else PYTORCH_DTYPE_CAST_MAP.get(t.dtype, t.dtype)
    if target_dtype is not None or device is not None:
        t = t.to(device=device, dtype=target_dtype)
    return t

def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    dat_dlpack = other_backend.to_dlpack(data)
    return torch.from_dlpack(dat_dlpack)

def to_numpy(
    data : ARRAY_TYPE
) -> np.ndarray:
    # Torch bfloat16 is not supported by numpy
    if data.dtype == torch.bfloat16:
        data = data.to(torch.float32)
    return data.cpu().numpy()

def to_dlpack(
    data: ARRAY_TYPE,
    /,
) -> SupportsDLPack:
    return data

def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    # https://pytorch.org/docs/stable/tensors.html#id12
    return dtype in [
        torch.int8, torch.int16, torch.int32, torch.int64,
        torch.uint8,
        torch.int,
        torch.long
    ]

def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    return dtype in [
        torch.float16, torch.float32, torch.float64,
        torch.float, torch.double,
        torch.bfloat16
    ]

def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    return dtype == torch.bool

from .._common.implementations import *
from array_api_compat import torch as compat_module
abbreviate_array = get_abbreviate_array_function(
    compat_module,
    default_integer_dtype=default_integer_dtype,
    func_dtype_is_real_floating=dtype_is_real_floating,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_boolean=dtype_is_boolean
)
map_fn_over_arrays = get_map_fn_over_arrays_function(
    is_backendarray=is_backendarray,
)
pad_dim = get_pad_dim_function(
    backend=compat_module,
)
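
Two details above are easy to miss: `serialize_device` normalizes an index-less `torch.device` to index `0`, and `from_numpy` widens dtypes that torch supports only partially (`uint16/32/64` and the float8 variants) via `PYTORCH_DTYPE_CAST_MAP`. A small round-trip sketch; it assumes a torch build recent enough to define those dtypes (the module requires one, since the map references them at import time) and one where `torch.from_numpy` accepts `uint32` input:

```python
import numpy as np
import torch
from xbarray.backends._implementations.pytorch import _extra

dev = _extra.deserialize_device("cuda:1")                       # torch.device('cuda', 1)
assert _extra.serialize_device(dev) == "cuda:1"
assert _extra.serialize_device(torch.device("cpu")) == "cpu:0"  # index normalized to 0

t = _extra.from_numpy(np.arange(4, dtype=np.uint32))
assert t.dtype == torch.int32  # widened by PYTORCH_DTYPE_CAST_MAP
```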

xbarray/backends/_implementations/pytorch/random.py
ADDED

@@ -0,0 +1,101 @@
from typing import Union, Optional, Tuple, Any
from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
import torch

__all__ = [
    "random_number_generator",
    "random_discrete_uniform",
    "random_uniform",
    "random_exponential",
    "random_normal",
    "random_geometric",
    "random_permutation"
]

def random_number_generator(
    seed : Optional[int] = None,
    *,
    device : Optional[DEVICE_TYPE] = None
) -> RNG_TYPE:
    rng = torch.Generator(
        device=device
    )
    if seed is not None:
        rng = rng.manual_seed(seed)
    return rng


def random_discrete_uniform(
    shape : Union[int, Tuple[int, ...]],
    /,
    from_num : int,
    to_num : int,
    *,
    rng : RNG_TYPE,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = torch.randint(int(from_num), int(to_num), shape, generator=rng, dtype=dtype, device=device)
    return rng, t

def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    low : float = 0.0, high : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = torch.rand(shape, generator=rng, dtype=dtype, device=device)
    t = t * (high - low) + low
    return rng, t

def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    lambd : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = torch.empty(shape, dtype=dtype, device=device)
    t = t.exponential_(lambd, generator=rng)
    return rng, t

def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    mean : float = 0.0, std : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = torch.normal(mean, std, shape, generator=rng, dtype=dtype, device=device)
    return rng, t

def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = torch.empty(shape, dtype=dtype, device=device)
    t = t.geometric_(p, generator=rng)
    return rng, t

def random_permutation(
    n : int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    t = torch.randperm(n, generator=rng, dtype=dtype, device=device)
    return rng, t
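
Unlike the NumPy implementation, these wrappers pass `dtype` and `device` straight into the torch factory functions instead of casting afterwards. Note that stock `torch.randint` and `torch.normal` expect a size *tuple*, so although the annotations admit a bare `int`, tuple shapes are the safe way to call these helpers. A usage sketch:

```python
from xbarray.backends._implementations.pytorch import random as pt_random

rng = pt_random.random_number_generator(0, device="cpu")
rng, x = pt_random.random_normal((2, 3), rng=rng, mean=0.0, std=2.0)
rng, k = pt_random.random_discrete_uniform((5,), 0, 10, rng=rng)  # ints in [0, 10)
```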
xbarray/backends/base.py
ADDED

@@ -0,0 +1,218 @@
from typing import Optional, Generic, TypeVar, Dict, Union, Any, Sequence, SupportsFloat, Tuple, Type, Callable, Mapping, Protocol, Literal
import abc
from array_api_typing.typing_extra import *
import numpy as np

ArrayAPISetIndex = SetIndex
ArrayAPIGetIndex = GetIndex

__all__ = [
    "RNGBackend",
    "ComputeBackend",
    "BArrayType",
    "BDeviceType",
    "BDtypeType",
    "BRNGType",
    "SupportsDLPack",
    "ArrayAPIArray",
    "ArrayAPIDevice",
    "ArrayAPIDType",
    "ArrayAPINamespace",
    "ArrayAPISetIndex",
    "ArrayAPIGetIndex",
]

BArrayType = TypeVar("BArrayType", covariant=True, bound=ArrayAPIArray)
BDeviceType = TypeVar("BDeviceType", covariant=True, bound=ArrayAPIDevice)
BDtypeType = TypeVar("BDtypeType", covariant=True, bound=ArrayAPIDType)
BRNGType = TypeVar("BRNGType", covariant=True)
class RNGBackend(Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
    @abc.abstractmethod
    def random_number_generator(
        self,
        seed : Optional[int] = None,
        *,
        device : Optional[BDeviceType] = None
    ) -> BRNGType:
        raise NotImplementedError

    @abc.abstractmethod
    def random_discrete_uniform(
        self,
        shape : Union[int, Tuple[int, ...]],
        /,
        from_num : int,
        to_num : int,
        *,
        rng : BRNGType,
        dtype : Optional[BDtypeType] = None,
        device : Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """
        Sample from a discrete uniform distribution [from_num, to_num) with shape `shape`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def random_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng : BRNGType,
        low : float = 0.0, high : float = 1.0,
        dtype : Optional[BDtypeType] = None,
        device : Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        raise NotImplementedError

    @abc.abstractmethod
    def random_exponential(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng : BRNGType,
        lambd : float = 1.0,
        dtype : Optional[BDtypeType] = None,
        device : Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        raise NotImplementedError

    @abc.abstractmethod
    def random_normal(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng : BRNGType,
        mean : float = 0.0, std : float = 1.0,
        dtype : Optional[BDtypeType] = None,
        device : Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        raise NotImplementedError

    @abc.abstractmethod
    def random_geometric(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        p : float,
        rng : BRNGType,
        dtype : Optional[BDtypeType] = None,
        device : Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        raise NotImplementedError

    @abc.abstractmethod
    def random_permutation(
        self,
        n : int,
        /,
        *,
        rng : BRNGType,
        dtype : Optional[BDtypeType] = None,
        device : Optional[BDeviceType] = None
    ) -> Tuple[BRNGType, BArrayType]:
        raise NotImplementedError

class ComputeBackend(ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType], Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
    simplified_name : Literal['numpy', 'jax', 'pytorch']
    ARRAY_TYPE : Type[BArrayType]
    DEVICE_TYPE : Type[BDeviceType]
    DTYPE_TYPE : Type[BDtypeType]
    RNG_TYPE : Type[BRNGType]
    default_integer_dtype : BDtypeType
    default_index_dtype : BDtypeType
    default_floating_dtype : BDtypeType
    default_boolean_dtype : BDtypeType
    random : RNGBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]

    @abc.abstractmethod
    def serialize_device(self, device : Optional[BDeviceType]) -> Optional[str]:
        raise NotImplementedError

    @abc.abstractmethod
    def deserialize_device(self, device_str : Optional[str]) -> Optional[BDeviceType]:
        raise NotImplementedError

    @abc.abstractmethod
    def is_backendarray(self, data : Any) -> bool:
        raise NotImplementedError

    @abc.abstractmethod
    def from_numpy(
        self,
        data : np.ndarray,
        /,
        *,
        dtype : Optional[BDtypeType] = None,
        device : Optional[BDeviceType] = None
    ) -> BArrayType:
        raise NotImplementedError

    @abc.abstractmethod
    def from_other_backend(
        self,
        other_backend : "ComputeBackend",
        data : ArrayAPIArray,
        /
    ) -> BArrayType:
        """
        Convert an array from another backend to this backend.
        The other backend must be compatible with the ArrayAPI.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to_numpy(
        self,
        data : BArrayType
    ) -> np.ndarray:
        raise NotImplementedError

    @abc.abstractmethod
    def to_dlpack(
        self,
        data : BArrayType,
        /,
    ) -> SupportsDLPack:
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_integer(self, dtype : BDtypeType) -> bool:
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_floating(self, dtype : BDtypeType) -> bool:
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_boolean(self, dtype : BDtypeType) -> bool:
        raise NotImplementedError

    @abc.abstractmethod
    def abbreviate_array(self, x : BArrayType, try_cast_scalar: bool = True) -> Union[float, int, BArrayType]:
        """
        Abbreviates an array to a single element if possible.
        Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def map_fn_over_arrays(self, data : Any, func : Callable[[BArrayType], BArrayType]) -> Any:
        """
        Map a function over arrays in a data structure and produce a new data structure with the same shape.
        This is useful for applying a function to all arrays in a nested structure.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def pad_dim(self, x : BArrayType, dim : int, target_size : int, value : Union[float, int] = 0) -> BArrayType:
        """
        Pad a dimension of an array to a target size with a given value.
        If the dimension is already the target size, return the original array.
        If the dimension is larger than the target size, raise an error.
        """
        raise NotImplementedError
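
Together, `RNGBackend` and `ComputeBackend` let downstream code stay backend-agnostic: everything goes through the Array API namespace the protocol extends, plus the handful of extension methods declared above. A minimal sketch of generic code written against the protocol (`standardize` is a hypothetical helper, not part of the package; `mean` and `std` come from the inherited Array API namespace):

```python
from typing import Tuple
from xbarray.backends.base import ComputeBackend

def standardize(backend: ComputeBackend, rng, shape: Tuple[int, ...]):
    """Draw a normal sample and rescale it, using only protocol members."""
    rng, x = backend.random.random_normal(
        shape, rng=rng, dtype=backend.default_floating_dtype
    )
    return rng, (x - backend.mean(x)) / backend.std(x)
```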
xbarray/backends/jax.py
ADDED

@@ -0,0 +1,19 @@
from ._cls_base import ComputeBackendImplCls
from ._implementations import jax as jax_impl

class JaxComputeBackend(metaclass=ComputeBackendImplCls[jax_impl.ARRAY_TYPE, jax_impl.DEVICE_TYPE, jax_impl.DTYPE_TYPE, jax_impl.RNG_TYPE]):
    ARRAY_TYPE = jax_impl.ARRAY_TYPE
    DTYPE_TYPE = jax_impl.DTYPE_TYPE
    DEVICE_TYPE = jax_impl.DEVICE_TYPE
    RNG_TYPE = jax_impl.RNG_TYPE

for name in dir(jax_impl):
    if not name.startswith('_') or name in [
        '__array_namespace_info__',
        '__array_api_version__',
    ]:
        setattr(JaxComputeBackend, name, getattr(jax_impl, name))

__all__ = [
    'JaxComputeBackend',
]
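
Note that the copy loop runs at module level, after the `class` statement, and installs the implementation module's plain functions as class attributes; since the backend classes are used directly rather than instantiated, no `self` binding ever occurs. A sketch, assuming `jax` and `array-api-compat` are installed:

```python
from xbarray.backends.jax import JaxComputeBackend as B

rng = B.random.random_number_generator(0)
rng, x = B.random.random_uniform((2, 3), rng=rng)  # returns the next RNG state with the sample
```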

xbarray/backends/numpy.py
ADDED

@@ -0,0 +1,19 @@
from ._cls_base import ComputeBackendImplCls
from ._implementations import numpy as numpy_impl

class NumpyComputeBackend(metaclass=ComputeBackendImplCls[numpy_impl.ARRAY_TYPE, numpy_impl.DEVICE_TYPE, numpy_impl.DTYPE_TYPE, numpy_impl.RNG_TYPE]):
    ARRAY_TYPE = numpy_impl.ARRAY_TYPE
    DTYPE_TYPE = numpy_impl.DTYPE_TYPE
    DEVICE_TYPE = numpy_impl.DEVICE_TYPE
    RNG_TYPE = numpy_impl.RNG_TYPE

for name in dir(numpy_impl):
    if not name.startswith('_') or name in [
        '__array_namespace_info__',
        '__array_api_version__',
    ]:
        setattr(NumpyComputeBackend, name, getattr(numpy_impl, name))

__all__ = [
    'NumpyComputeBackend',
]
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from ._cls_base import ComputeBackendImplCls
|
|
2
|
+
from ._implementations import pytorch as pytorch_impl
|
|
3
|
+
|
|
4
|
+
class PytorchComputeBackend(metaclass=ComputeBackendImplCls[pytorch_impl.ARRAY_TYPE, pytorch_impl.DEVICE_TYPE, pytorch_impl.DTYPE_TYPE, pytorch_impl.RNG_TYPE]):
|
|
5
|
+
ARRAY_TYPE = pytorch_impl.ARRAY_TYPE
|
|
6
|
+
DTYPE_TYPE = pytorch_impl.DTYPE_TYPE
|
|
7
|
+
DEVICE_TYPE = pytorch_impl.DEVICE_TYPE
|
|
8
|
+
RNG_TYPE = pytorch_impl.RNG_TYPE
|
|
9
|
+
|
|
10
|
+
for name in dir(pytorch_impl):
|
|
11
|
+
if not name.startswith('_') or name in [
|
|
12
|
+
'__array_namespace_info__',
|
|
13
|
+
'__array_api_version__',
|
|
14
|
+
]:
|
|
15
|
+
try:
|
|
16
|
+
setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
|
|
17
|
+
except AttributeError:
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
'PytorchComputeBackend',
|
|
22
|
+
]
|
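
All three modules export a class with the same `ComputeBackend` surface (the PyTorch variant merely tolerates missing attributes via its `try`/`except`), so backends can be swapped wholesale. A sketch, assuming both frameworks are installed:

```python
from xbarray.backends.numpy import NumpyComputeBackend
from xbarray.backends.pytorch import PytorchComputeBackend

for backend in (NumpyComputeBackend, PytorchComputeBackend):
    rng = backend.random.random_number_generator(123)
    rng, x = backend.random.random_uniform((2, 2), rng=rng)
    print(backend.simplified_name, backend.to_numpy(x).shape)
```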
xbarray/jax.py
ADDED
xbarray/numpy.py
ADDED
xbarray/pytorch.py
ADDED

@@ -0,0 +1 @@
from .base import *