xbarray 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. array_api_typing/__init__.py +9 -0
  2. array_api_typing/typing_2024_12/__init__.py +12 -0
  3. array_api_typing/typing_2024_12/_api_constant.py +32 -0
  4. array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
  5. array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
  6. array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
  7. array_api_typing/typing_2024_12/_api_typing.py +5855 -0
  8. array_api_typing/typing_2024_12/_array_typing.py +1265 -0
  9. array_api_typing/typing_compat/__init__.py +12 -0
  10. array_api_typing/typing_compat/_api_typing.py +27 -0
  11. array_api_typing/typing_compat/_array_typing.py +36 -0
  12. array_api_typing/typing_extra/__init__.py +12 -0
  13. array_api_typing/typing_extra/_api_typing.py +651 -0
  14. array_api_typing/typing_extra/_at.py +87 -0
  15. xbarray/__init__.py +1 -0
  16. xbarray/backends/_cls_base.py +9 -0
  17. xbarray/backends/_implementations/_common/implementations.py +87 -0
  18. xbarray/backends/_implementations/jax/__init__.py +33 -0
  19. xbarray/backends/_implementations/jax/_extra.py +127 -0
  20. xbarray/backends/_implementations/jax/_typing.py +15 -0
  21. xbarray/backends/_implementations/jax/random.py +115 -0
  22. xbarray/backends/_implementations/numpy/__init__.py +25 -0
  23. xbarray/backends/_implementations/numpy/_extra.py +98 -0
  24. xbarray/backends/_implementations/numpy/_typing.py +14 -0
  25. xbarray/backends/_implementations/numpy/random.py +105 -0
  26. xbarray/backends/_implementations/pytorch/__init__.py +26 -0
  27. xbarray/backends/_implementations/pytorch/_extra.py +135 -0
  28. xbarray/backends/_implementations/pytorch/_typing.py +13 -0
  29. xbarray/backends/_implementations/pytorch/random.py +101 -0
  30. xbarray/backends/base.py +218 -0
  31. xbarray/backends/jax.py +19 -0
  32. xbarray/backends/numpy.py +19 -0
  33. xbarray/backends/pytorch.py +22 -0
  34. xbarray/jax.py +4 -0
  35. xbarray/numpy.py +4 -0
  36. xbarray/pytorch.py +4 -0
  37. xbarray/transformations/pointcloud/__init__.py +1 -0
  38. xbarray/transformations/pointcloud/base.py +449 -0
  39. xbarray/transformations/pointcloud/jax.py +24 -0
  40. xbarray/transformations/pointcloud/numpy.py +23 -0
  41. xbarray/transformations/pointcloud/pytorch.py +23 -0
  42. xbarray/transformations/rotation_conversions/__init__.py +1 -0
  43. xbarray/transformations/rotation_conversions/base.py +713 -0
  44. xbarray/transformations/rotation_conversions/jax.py +41 -0
  45. xbarray/transformations/rotation_conversions/numpy.py +41 -0
  46. xbarray/transformations/rotation_conversions/pytorch.py +41 -0
  47. xbarray-0.0.1a13.dist-info/METADATA +20 -0
  48. xbarray-0.0.1a13.dist-info/RECORD +51 -0
  49. xbarray-0.0.1a13.dist-info/WHEEL +5 -0
  50. xbarray-0.0.1a13.dist-info/licenses/LICENSE +21 -0
  51. xbarray-0.0.1a13.dist-info/top_level.txt +2 -0
@@ -0,0 +1,105 @@
1
+ from typing import Union, Optional, Tuple, Any
2
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
3
+ import numpy as np
4
+
5
# Public API of the NumPy random backend: one RNG factory plus six samplers.
__all__ = [
    "random_number_generator",
    "random_discrete_uniform",
    "random_uniform",
    "random_exponential",
    "random_normal",
    "random_geometric",
    "random_permutation",
]
14
+
15
def random_number_generator(
    seed: Optional[int] = None,
    *,
    device: Optional[DEVICE_TYPE] = None,
) -> RNG_TYPE:
    """Create a NumPy random generator.

    Args:
        seed: Seed forwarded to ``np.random.default_rng``; ``None`` is passed
            through unchanged.
        device: Accepted for interface parity; not used by this function.

    Returns:
        The generator returned by ``np.random.default_rng(seed)``.
    """
    generator = np.random.default_rng(seed)
    return generator
21
+
22
def random_discrete_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    from_num: int,
    to_num: int,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw integers via ``rng.integers(from_num, to_num, size=shape)``.

    Args:
        shape: Output shape, forwarded as ``size``.
        from_num: Lower bound; coerced with ``int()``.
        to_num: Upper bound; coerced with ``int()``.
        rng: Generator used for sampling (it is returned as-is).
        dtype: If given, the samples are converted with ``astype``.
        device: Accepted for interface parity; not used by this function.

    Returns:
        ``(rng, samples)``.
    """
    lower, upper = int(from_num), int(to_num)
    samples = rng.integers(lower, upper, size=shape)
    if dtype is None:
        return rng, samples
    return rng, samples.astype(dtype)
36
+
37
def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    low: float = 0.0,
    high: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw samples via ``rng.uniform(low, high, size=shape)``.

    Args:
        shape: Output shape, forwarded as ``size``.
        rng: Generator used for sampling (it is returned as-is).
        low: Lower bound; coerced with ``float()``.
        high: Upper bound; coerced with ``float()``.
        dtype: If given, the samples are converted with ``astype``.
        device: Accepted for interface parity; not used by this function.

    Returns:
        ``(rng, samples)``.
    """
    lower, upper = float(low), float(high)
    samples = rng.uniform(lower, upper, size=shape)
    if dtype is None:
        return rng, samples
    return rng, samples.astype(dtype)
50
+
51
def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    lambd: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw exponential samples parameterised by the rate ``lambd``.

    ``rng.exponential`` is called with ``1.0 / float(lambd)`` as its first
    argument, i.e. ``lambd`` is converted from a rate to NumPy's scale.

    Args:
        shape: Output shape, forwarded as ``size``.
        rng: Generator used for sampling (it is returned as-is).
        lambd: Rate parameter.
        dtype: If given, the samples are converted with ``astype``.
        device: Accepted for interface parity; not used by this function.

    Returns:
        ``(rng, samples)``.
    """
    scale = 1.0 / float(lambd)
    samples = rng.exponential(scale, size=shape)
    if dtype is not None:
        samples = samples.astype(dtype)
    return rng, samples
64
+
65
def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    mean: float = 0.0,
    std: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw samples via ``rng.normal(mean, std, size=shape)``.

    Args:
        shape: Output shape, forwarded as ``size``.
        rng: Generator used for sampling (it is returned as-is).
        mean: Passed through unchanged as the first argument.
        std: Passed through unchanged as the second argument.
        dtype: If given, the samples are converted with ``astype``.
        device: Accepted for interface parity; not used by this function.

    Returns:
        ``(rng, samples)``.
    """
    samples = rng.normal(mean, std, size=shape)
    converted = samples if dtype is None else samples.astype(dtype)
    return rng, converted
78
+
79
+
80
def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw samples via ``rng.geometric(p, size=shape)``.

    Args:
        shape: Output shape, forwarded as ``size``.
        p: Passed through unchanged as the distribution parameter.
        rng: Generator used for sampling (it is returned as-is).
        dtype: If given, the samples are converted with ``astype``.
        device: Accepted for interface parity; not used by this function.

    Returns:
        ``(rng, samples)``.
    """
    samples = rng.geometric(p, size=shape)
    if dtype is None:
        return rng, samples
    return rng, samples.astype(dtype)
93
+
94
def random_permutation(
    n: int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Return the result of ``rng.permutation(n)``.

    Args:
        n: Forwarded unchanged to ``rng.permutation``.
        rng: Generator used for sampling (it is returned as-is).
        dtype: If given, the result is converted with ``astype``.
        device: Accepted for interface parity; not used by this function.

    Returns:
        ``(rng, permutation)``.
    """
    order = rng.permutation(n)
    shuffled = order if dtype is None else order.astype(dtype)
    return rng, shuffled
@@ -0,0 +1,26 @@
1
+ from array_api_compat.torch import *
2
+ from array_api_compat.torch import __array_api_version__, __array_namespace_info__
3
+ from array_api_compat.common._helpers import *
4
+
5
simplified_name = "pytorch"

from array_api_compat import torch as compat_module

# Re-export the public names of array_api_extra from this module.
# Callables are pre-bound to the torch compat namespace through ``xp=`` so
# callers do not have to pass it themselves.
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    if api_name.startswith('_'):
        continue

    # 'at' and 'broadcast_shapes' are exposed unwrapped (presumably because
    # they do not take an ``xp`` argument -- TODO confirm).
    # Non-callable public attributes (e.g. a submodule that happens to be
    # imported already) are also passed through untouched: functools.partial
    # raises TypeError when its first argument is not callable, which would
    # otherwise abort the import of this package.
    if api_name in ['at', 'broadcast_shapes'] or not callable(getattr(array_api_extra, api_name)):
        globals()[api_name] = getattr(array_api_extra, api_name)
    else:
        globals()[api_name] = partial(
            getattr(array_api_extra, api_name),
            xp=compat_module
        )

from ._typing import *
from ._extra import *
# Import the sibling ``random`` submodule by its absolute dotted name.
__import__(__package__ + ".random")
@@ -0,0 +1,135 @@
1
+ from typing import Any, Union, Optional
2
+ import numpy as np
3
+ import torch
4
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
5
+ from xbarray.backends.base import ComputeBackend, SupportsDLPack
6
+
7
+ PYTORCH_DTYPE_CAST_MAP = {
8
+ torch.uint16: torch.int16,
9
+ torch.uint32: torch.int32,
10
+ torch.uint64: torch.int64,
11
+ torch.float8_e4m3fn: torch.float16,
12
+ torch.float8_e5m2: torch.float16,
13
+ }
14
+
15
# Names exported by ``from ._extra import *``.
__all__ = [
    # default dtypes
    "default_integer_dtype",
    "default_index_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    # device (de)serialisation
    "serialize_device",
    "deserialize_device",
    # conversions
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    # dtype predicates
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    # helpers built from the shared implementations
    "abbreviate_array",
    "map_fn_over_arrays",
    "pad_dim",
]
34
+
35
# Default dtypes advertised by the PyTorch backend.
default_integer_dtype, default_index_dtype = torch.int32, torch.long
default_floating_dtype, default_boolean_dtype = torch.float32, torch.bool
39
+
40
def serialize_device(device: Optional[DEVICE_TYPE]) -> Optional[str]:
    """Turn a device into a ``"<type>:<index>"`` string.

    ``None`` and plain strings are returned unchanged. For a ``torch.device``
    without an index, index ``0`` is written.
    """
    if device is None or isinstance(device, str):
        return device
    assert isinstance(device, torch.device)
    ordinal = 0 if device.index is None else device.index
    return f"{device.type}:{ordinal}"
47
+
48
def deserialize_device(device_str: Optional[str]) -> Optional[DEVICE_TYPE]:
    """Build a ``torch.device`` from a string; ``None`` maps to ``None``.

    A string containing ``":"`` is split into ``type`` and an integer index;
    any other string is handed to ``torch.device`` directly.
    """
    if device_str is None:
        return None
    if ":" not in device_str:
        return torch.device(device_str)
    kind, ordinal_text = device_str.split(":")
    return torch.device(kind, int(ordinal_text))
57
+
58
def is_backendarray(data: Any) -> bool:
    """Return ``True`` when ``data`` is a ``torch.Tensor`` instance."""
    tensor_cls = torch.Tensor
    return isinstance(data, tensor_cls)
60
+
61
def from_numpy(
    data: np.ndarray,
    /,
    *,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> ARRAY_TYPE:
    """Convert a NumPy array to a torch tensor.

    Args:
        data: Source array, wrapped with ``torch.from_numpy``.
        dtype: Target dtype. When omitted, the tensor's own dtype is looked up
            in ``PYTORCH_DTYPE_CAST_MAP`` and kept as-is if it has no entry.
        device: Forwarded to ``Tensor.to``.

    Returns:
        The tensor after ``.to(device=device, dtype=<resolved dtype>)``.
    """
    tensor = torch.from_numpy(data)
    if dtype is None:
        resolved_dtype = PYTORCH_DTYPE_CAST_MAP.get(tensor.dtype, tensor.dtype)
    else:
        resolved_dtype = dtype
    # The resolved dtype is never None, so the conversion always runs.
    return tensor.to(device=device, dtype=resolved_dtype)
73
+
74
def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    """Import an array owned by another backend through DLPack.

    ``other_backend.to_dlpack(data)`` produces the exchange object, which is
    then consumed by ``torch.from_dlpack``.
    """
    return torch.from_dlpack(other_backend.to_dlpack(data))
81
+
82
def to_numpy(
    data: ARRAY_TYPE
) -> np.ndarray:
    """Convert a torch tensor to a NumPy array.

    Args:
        data: Tensor to convert.

    Returns:
        The array produced by ``Tensor.numpy()`` after the tensor has been
        detached from autograd and moved to the CPU. ``bfloat16`` tensors are
        first converted to ``float32``.
    """
    # Torch bfloat16 is not supported by numpy
    if data.dtype == torch.bfloat16:
        data = data.to(torch.float32)
    # ``Tensor.numpy()`` raises RuntimeError on tensors that require grad, so
    # detach first; for tensors that do not require grad this is a no-op view.
    return data.detach().cpu().numpy()
89
+
90
def to_dlpack(
    data: ARRAY_TYPE,
    /,
) -> SupportsDLPack:
    """Return ``data`` itself as the DLPack exchange object."""
    exchange_object = data
    return exchange_object
95
+
96
def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is one of the integer dtypes listed below."""
    # https://pytorch.org/docs/stable/tensors.html#id12
    # torch.int and torch.long are aliases of int32 / int64, so they are
    # covered by the entries below.
    integer_dtypes = (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    )
    return dtype in integer_dtypes
106
+
107
def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is one of the floating dtypes listed below."""
    # torch.float and torch.double are aliases of float32 / float64.
    floating_dtypes = (
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
    )
    return dtype in floating_dtypes
115
+
116
def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` equals ``torch.bool``."""
    matches_bool = dtype == torch.bool
    return matches_bool
120
+
121
from .._common.implementations import *
from array_api_compat import torch as compat_module

# The three helpers below come from factories in the shared ``_common``
# implementations, configured here with the torch compat namespace and this
# module's own dtype predicates.
abbreviate_array = get_abbreviate_array_function(
    compat_module,
    default_integer_dtype=default_integer_dtype,
    func_dtype_is_real_floating=dtype_is_real_floating,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_boolean=dtype_is_boolean,
)
map_fn_over_arrays = get_map_fn_over_arrays_function(is_backendarray=is_backendarray)
pad_dim = get_pad_dim_function(backend=compat_module)
@@ -0,0 +1,13 @@
1
+ import torch
2
+
3
__all__ = ['ARRAY_TYPE', 'DTYPE_TYPE', 'DEVICE_TYPE', 'RNG_TYPE']

# Concrete torch classes behind the backend-neutral type names.
ARRAY_TYPE = torch.Tensor      # array objects
DTYPE_TYPE = torch.dtype       # element types
DEVICE_TYPE = torch.device     # device handles
RNG_TYPE = torch.Generator     # random number generator state
@@ -0,0 +1,101 @@
1
+ from typing import Union, Optional, Tuple, Any
2
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
3
+ import torch
4
+
5
# Public API of the PyTorch random backend: one RNG factory plus six samplers.
__all__ = [
    "random_number_generator",
    "random_discrete_uniform",
    "random_uniform",
    "random_exponential",
    "random_normal",
    "random_geometric",
    "random_permutation",
]
14
+
15
def random_number_generator(
    seed: Optional[int] = None,
    *,
    device: Optional[DEVICE_TYPE] = None,
) -> RNG_TYPE:
    """Create a ``torch.Generator``.

    Args:
        seed: When given, applied with ``Generator.manual_seed``.
        device: Forwarded to the ``torch.Generator`` constructor.

    Returns:
        The generator (the return value of ``manual_seed`` when seeded).
    """
    generator = torch.Generator(device=device)
    return generator if seed is None else generator.manual_seed(seed)
26
+
27
+
28
def random_discrete_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    from_num: int,
    to_num: int,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw integers with ``torch.randint(from_num, to_num, size, ...)``.

    Args:
        shape: Output shape; a bare ``int`` ``n`` is treated as ``(n,)``.
        from_num: Lower bound; coerced with ``int()``.
        to_num: Upper bound; coerced with ``int()``.
        rng: Generator used for sampling (it is returned as-is).
        dtype: Forwarded to ``torch.randint``.
        device: Forwarded to ``torch.randint``.

    Returns:
        ``(rng, samples)``.
    """
    # torch.randint requires ``size`` to be a sequence of ints; a bare int is
    # rejected, although the signature above allows it.
    size = (shape,) if isinstance(shape, int) else tuple(shape)
    t = torch.randint(int(from_num), int(to_num), size, generator=rng, dtype=dtype, device=device)
    return rng, t
40
+
41
def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    low: float = 0.0,
    high: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw ``torch.rand`` samples and rescale them as ``t * (high - low) + low``.

    Args:
        shape: Output shape, forwarded to ``torch.rand``.
        rng: Generator used for sampling (it is returned as-is).
        low: Offset added after scaling.
        high: Together with ``low`` defines the scale ``high - low``.
        dtype: Forwarded to ``torch.rand``.
        device: Forwarded to ``torch.rand``.

    Returns:
        ``(rng, samples)``.
    """
    unit_samples = torch.rand(shape, generator=rng, dtype=dtype, device=device)
    span = high - low
    return rng, unit_samples * span + low
53
+
54
def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    lambd: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Fill a new tensor in place with ``Tensor.exponential_(lambd)``.

    Args:
        shape: Shape of the tensor allocated with ``torch.empty``.
        rng: Generator used for sampling (it is returned as-is).
        lambd: Forwarded to ``exponential_``.
        dtype: Forwarded to ``torch.empty``.
        device: Forwarded to ``torch.empty``.

    Returns:
        ``(rng, samples)``.
    """
    buffer = torch.empty(shape, dtype=dtype, device=device)
    return rng, buffer.exponential_(lambd, generator=rng)
66
+
67
def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng: RNG_TYPE,
    mean: float = 0.0,
    std: float = 1.0,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Draw samples with ``torch.normal(mean, std, size, ...)``.

    Args:
        shape: Output shape; a bare ``int`` ``n`` is treated as ``(n,)``.
        rng: Generator used for sampling (it is returned as-is).
        mean: Forwarded to ``torch.normal``.
        std: Forwarded to ``torch.normal``.
        dtype: Forwarded to ``torch.normal``.
        device: Forwarded to ``torch.normal``.

    Returns:
        ``(rng, samples)``.
    """
    # The (float, float, size) overload of torch.normal needs ``size`` to be a
    # sequence of ints; a bare int is rejected, although the signature above
    # allows it.
    size = (shape,) if isinstance(shape, int) else tuple(shape)
    t = torch.normal(mean, std, size, generator=rng, dtype=dtype, device=device)
    return rng, t
78
+
79
def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Fill a new tensor in place with ``Tensor.geometric_(p)``.

    Args:
        shape: Shape of the tensor allocated with ``torch.empty``.
        p: Forwarded to ``geometric_``.
        rng: Generator used for sampling (it is returned as-is).
        dtype: Forwarded to ``torch.empty``.
        device: Forwarded to ``torch.empty``.

    Returns:
        ``(rng, samples)``.
    """
    buffer = torch.empty(shape, dtype=dtype, device=device)
    return rng, buffer.geometric_(p, generator=rng)
91
+
92
def random_permutation(
    n: int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None,
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """Return the result of ``torch.randperm(n, ...)``.

    Args:
        n: Forwarded to ``torch.randperm``.
        rng: Generator used for sampling (it is returned as-is).
        dtype: Forwarded to ``torch.randperm``.
        device: Forwarded to ``torch.randperm``.

    Returns:
        ``(rng, permutation)``.
    """
    permutation = torch.randperm(n, generator=rng, dtype=dtype, device=device)
    return rng, permutation
@@ -0,0 +1,218 @@
1
+ from typing import Optional, Generic, TypeVar, Dict, Union, Any, Sequence, SupportsFloat, Tuple, Type, Callable, Mapping, Protocol, Literal
2
+ import abc
3
+ from array_api_typing.typing_extra import *
4
+ import numpy as np
5
+
6
# Aliases for the index types pulled in by the star import above.
ArrayAPISetIndex = SetIndex
ArrayAPIGetIndex = GetIndex

__all__ = [
    # protocols defined in this module
    "RNGBackend",
    "ComputeBackend",
    # type variables
    "BArrayType",
    "BDeviceType",
    "BDtypeType",
    "BRNGType",
    # names re-exported from array_api_typing
    "SupportsDLPack",
    "ArrayAPIArray",
    "ArrayAPIDevice",
    "ArrayAPIDType",
    "ArrayAPINamespace",
    "ArrayAPISetIndex",
    "ArrayAPIGetIndex",
]

# Type parameters shared by the backend protocols: array, device, dtype and
# RNG types. All four are declared covariant; only the RNG type is unbounded.
BArrayType = TypeVar("BArrayType", bound=ArrayAPIArray, covariant=True)
BDeviceType = TypeVar("BDeviceType", bound=ArrayAPIDevice, covariant=True)
BDtypeType = TypeVar("BDtypeType", bound=ArrayAPIDType, covariant=True)
BRNGType = TypeVar("BRNGType", covariant=True)
29
class RNGBackend(Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """Structural interface for a backend's random-sampling functions.

    Every sampler takes the generator as the keyword-only ``rng`` argument and
    returns a ``(rng, array)`` tuple. All methods here are abstract stubs that
    raise ``NotImplementedError``.
    """

    @abc.abstractmethod
    def random_number_generator(
        self,
        seed: Optional[int] = None,
        *,
        device: Optional[BDeviceType] = None,
    ) -> BRNGType:
        """Create a generator object, optionally from ``seed``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_discrete_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        from_num: int,
        to_num: int,
        *,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """
        Sample from a discrete uniform distribution [from_num, to_num) with shape `shape`.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def random_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        low: float = 0.0,
        high: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample uniformly between ``low`` and ``high`` with shape `shape`."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_exponential(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        lambd: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample from an exponential distribution parameterised by ``lambd``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_normal(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        rng: BRNGType,
        mean: float = 0.0,
        std: float = 1.0,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample from a normal distribution given ``mean`` and ``std``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_geometric(
        self,
        shape: Union[int, Tuple[int, ...]],
        /,
        *,
        p: float,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Sample from a geometric distribution with parameter ``p``."""
        raise NotImplementedError

    @abc.abstractmethod
    def random_permutation(
        self,
        n: int,
        /,
        *,
        rng: BRNGType,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> Tuple[BRNGType, BArrayType]:
        """Produce a random permutation for the given ``n``."""
        raise NotImplementedError
119
+
120
class ComputeBackend(ArrayAPINamespace[BArrayType, BDeviceType, BDtypeType], Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
    """Protocol for a compute backend: an array-API namespace plus extras.

    On top of ``ArrayAPINamespace`` it declares the concrete type handles,
    default dtypes, an ``RNGBackend`` under ``random`` and a set of conversion
    and helper methods. All methods here are abstract stubs that raise
    ``NotImplementedError``.
    """

    # Short identifier of the backend.
    simplified_name: Literal['numpy', 'jax', 'pytorch']
    # Concrete classes used by the backend.
    ARRAY_TYPE: Type[BArrayType]
    DEVICE_TYPE: Type[BDeviceType]
    DTYPE_TYPE: Type[BDtypeType]
    RNG_TYPE: Type[BRNGType]
    # Default dtypes per kind.
    default_integer_dtype: BDtypeType
    default_index_dtype: BDtypeType
    default_floating_dtype: BDtypeType
    default_boolean_dtype: BDtypeType
    # Random sampling functions of the backend.
    random: RNGBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]

    @abc.abstractmethod
    def serialize_device(self, device: Optional[BDeviceType]) -> Optional[str]:
        """Convert a device (or ``None``) to an optional string."""
        raise NotImplementedError

    @abc.abstractmethod
    def deserialize_device(self, device_str: Optional[str]) -> Optional[BDeviceType]:
        """Convert an optional string back to a device (or ``None``)."""
        raise NotImplementedError

    @abc.abstractmethod
    def is_backendarray(self, data: Any) -> bool:
        """Tell whether ``data`` is an array of this backend."""
        raise NotImplementedError

    @abc.abstractmethod
    def from_numpy(
        self,
        data: np.ndarray,
        /,
        *,
        dtype: Optional[BDtypeType] = None,
        device: Optional[BDeviceType] = None,
    ) -> BArrayType:
        """Convert a NumPy array to a backend array."""
        raise NotImplementedError

    @abc.abstractmethod
    def from_other_backend(
        self,
        other_backend: "ComputeBackend",
        data: ArrayAPIArray,
        /,
    ) -> BArrayType:
        """
        Convert an array from another backend to this backend.
        The other backend must be compatible with the ArrayAPI.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def to_numpy(
        self,
        data: BArrayType,
    ) -> np.ndarray:
        """Convert a backend array to a NumPy array."""
        raise NotImplementedError

    @abc.abstractmethod
    def to_dlpack(
        self,
        data: BArrayType,
        /,
    ) -> SupportsDLPack:
        """Return an object supporting DLPack for ``data``."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_integer(self, dtype: BDtypeType) -> bool:
        """Tell whether ``dtype`` is a real integer dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_real_floating(self, dtype: BDtypeType) -> bool:
        """Tell whether ``dtype`` is a real floating-point dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def dtype_is_boolean(self, dtype: BDtypeType) -> bool:
        """Tell whether ``dtype`` is a boolean dtype."""
        raise NotImplementedError

    @abc.abstractmethod
    def abbreviate_array(self, x: BArrayType, try_cast_scalar: bool = True) -> Union[float, int, BArrayType]:
        """
        Abbreviates an array to a single element if possible.
        Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def map_fn_over_arrays(self, data: Any, func: Callable[[BArrayType], BArrayType]) -> Any:
        """
        Map a function over arrays in a data structure and produce a new data structure with the same shape.
        This is useful for applying a function to all arrays in a nested structure.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def pad_dim(self, x: BArrayType, dim: int, target_size: int, value: Union[float, int] = 0) -> BArrayType:
        """
        Pad a dimension of an array to a target size with a given value.
        If the dimension is already the target size, return the original array.
        If the dimension is larger than the target size, raise an error.
        """
        raise NotImplementedError
@@ -0,0 +1,19 @@
1
+ from ._cls_base import ComputeBackendImplCls
2
+ from ._implementations import jax as jax_impl
3
+
4
class JaxComputeBackend(metaclass=ComputeBackendImplCls[jax_impl.ARRAY_TYPE, jax_impl.DEVICE_TYPE, jax_impl.DTYPE_TYPE, jax_impl.RNG_TYPE]):
    # Type handles taken from the JAX implementation module.
    ARRAY_TYPE = jax_impl.ARRAY_TYPE
    DTYPE_TYPE = jax_impl.DTYPE_TYPE
    DEVICE_TYPE = jax_impl.DEVICE_TYPE
    RNG_TYPE = jax_impl.RNG_TYPE

# Copy every public attribute of the implementation module onto the class,
# plus the two array-API dunder attributes named below.
for name in dir(jax_impl):
    if name.startswith('_') and name not in [
        '__array_namespace_info__',
        '__array_api_version__',
    ]:
        continue
    setattr(JaxComputeBackend, name, getattr(jax_impl, name))

__all__ = [
    'JaxComputeBackend',
]
@@ -0,0 +1,19 @@
1
+ from ._cls_base import ComputeBackendImplCls
2
+ from ._implementations import numpy as numpy_impl
3
+
4
class NumpyComputeBackend(metaclass=ComputeBackendImplCls[numpy_impl.ARRAY_TYPE, numpy_impl.DEVICE_TYPE, numpy_impl.DTYPE_TYPE, numpy_impl.RNG_TYPE]):
    # Type handles taken from the NumPy implementation module.
    ARRAY_TYPE = numpy_impl.ARRAY_TYPE
    DTYPE_TYPE = numpy_impl.DTYPE_TYPE
    DEVICE_TYPE = numpy_impl.DEVICE_TYPE
    RNG_TYPE = numpy_impl.RNG_TYPE

# Copy every public attribute of the implementation module onto the class,
# plus the two array-API dunder attributes named below.
for name in dir(numpy_impl):
    if name.startswith('_') and name not in [
        '__array_namespace_info__',
        '__array_api_version__',
    ]:
        continue
    setattr(NumpyComputeBackend, name, getattr(numpy_impl, name))

__all__ = [
    'NumpyComputeBackend',
]
@@ -0,0 +1,22 @@
1
+ from ._cls_base import ComputeBackendImplCls
2
+ from ._implementations import pytorch as pytorch_impl
3
+
4
class PytorchComputeBackend(metaclass=ComputeBackendImplCls[pytorch_impl.ARRAY_TYPE, pytorch_impl.DEVICE_TYPE, pytorch_impl.DTYPE_TYPE, pytorch_impl.RNG_TYPE]):
    # Type handles taken from the PyTorch implementation module.
    ARRAY_TYPE = pytorch_impl.ARRAY_TYPE
    DTYPE_TYPE = pytorch_impl.DTYPE_TYPE
    DEVICE_TYPE = pytorch_impl.DEVICE_TYPE
    RNG_TYPE = pytorch_impl.RNG_TYPE

# Copy every public attribute of the implementation module onto the class,
# plus the two array-API dunder attributes named below.
for name in dir(pytorch_impl):
    if name.startswith('_') and name not in [
        '__array_namespace_info__',
        '__array_api_version__',
    ]:
        continue
    # Names whose lookup or assignment raises AttributeError are skipped
    # silently. NOTE(review): the numpy/jax variants have no such guard --
    # confirm which names need it here.
    try:
        setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
    except AttributeError:
        pass

__all__ = [
    'PytorchComputeBackend',
]
xbarray/jax.py ADDED
@@ -0,0 +1,4 @@
1
+ """ This file is generated for backward compatibility.
2
+ If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.jax` instead.
3
+ """
4
+ from .backends.jax import *
xbarray/numpy.py ADDED
@@ -0,0 +1,4 @@
1
+ """ This file is generated for backward compatibility.
2
+ If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.numpy` instead.
3
+ """
4
+ from .backends.numpy import *
xbarray/pytorch.py ADDED
@@ -0,0 +1,4 @@
1
+ """ This file is generated for backward compatibility.
2
+ If you are using xbarray >= 0.0.1a9, please import from `xbarray.backends.pytorch` instead.
3
+ """
4
+ from .backends.pytorch import *
@@ -0,0 +1 @@
1
+ from .base import *