xbarray 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. array_api_typing/__init__.py +9 -0
  2. array_api_typing/typing_2024_12/__init__.py +12 -0
  3. array_api_typing/typing_2024_12/_api_constant.py +32 -0
  4. array_api_typing/typing_2024_12/_api_fft_typing.py +717 -0
  5. array_api_typing/typing_2024_12/_api_linalg_typing.py +897 -0
  6. array_api_typing/typing_2024_12/_api_return_typing.py +103 -0
  7. array_api_typing/typing_2024_12/_api_typing.py +5855 -0
  8. array_api_typing/typing_2024_12/_array_typing.py +1265 -0
  9. array_api_typing/typing_compat/__init__.py +12 -0
  10. array_api_typing/typing_compat/_api_typing.py +27 -0
  11. array_api_typing/typing_compat/_array_typing.py +36 -0
  12. array_api_typing/typing_extra/__init__.py +12 -0
  13. array_api_typing/typing_extra/_api_typing.py +651 -0
  14. array_api_typing/typing_extra/_at.py +87 -0
  15. xbarray/__init__.py +1 -0
  16. xbarray/backends/_cls_base.py +9 -0
  17. xbarray/backends/_implementations/_common/implementations.py +87 -0
  18. xbarray/backends/_implementations/jax/__init__.py +33 -0
  19. xbarray/backends/_implementations/jax/_extra.py +127 -0
  20. xbarray/backends/_implementations/jax/_typing.py +15 -0
  21. xbarray/backends/_implementations/jax/random.py +115 -0
  22. xbarray/backends/_implementations/numpy/__init__.py +25 -0
  23. xbarray/backends/_implementations/numpy/_extra.py +98 -0
  24. xbarray/backends/_implementations/numpy/_typing.py +14 -0
  25. xbarray/backends/_implementations/numpy/random.py +105 -0
  26. xbarray/backends/_implementations/pytorch/__init__.py +26 -0
  27. xbarray/backends/_implementations/pytorch/_extra.py +135 -0
  28. xbarray/backends/_implementations/pytorch/_typing.py +13 -0
  29. xbarray/backends/_implementations/pytorch/random.py +101 -0
  30. xbarray/backends/base.py +218 -0
  31. xbarray/backends/jax.py +19 -0
  32. xbarray/backends/numpy.py +19 -0
  33. xbarray/backends/pytorch.py +22 -0
  34. xbarray/jax.py +4 -0
  35. xbarray/numpy.py +4 -0
  36. xbarray/pytorch.py +4 -0
  37. xbarray/transformations/pointcloud/__init__.py +1 -0
  38. xbarray/transformations/pointcloud/base.py +449 -0
  39. xbarray/transformations/pointcloud/jax.py +24 -0
  40. xbarray/transformations/pointcloud/numpy.py +23 -0
  41. xbarray/transformations/pointcloud/pytorch.py +23 -0
  42. xbarray/transformations/rotation_conversions/__init__.py +1 -0
  43. xbarray/transformations/rotation_conversions/base.py +713 -0
  44. xbarray/transformations/rotation_conversions/jax.py +41 -0
  45. xbarray/transformations/rotation_conversions/numpy.py +41 -0
  46. xbarray/transformations/rotation_conversions/pytorch.py +41 -0
  47. xbarray-0.0.1a13.dist-info/METADATA +20 -0
  48. xbarray-0.0.1a13.dist-info/RECORD +51 -0
  49. xbarray-0.0.1a13.dist-info/WHEEL +5 -0
  50. xbarray-0.0.1a13.dist-info/licenses/LICENSE +21 -0
  51. xbarray-0.0.1a13.dist-info/top_level.txt +2 -0
@@ -0,0 +1,87 @@
1
+ from typing import Protocol, TypeVar, Optional, Any, Tuple, Union, Type, TypedDict, List
2
+ from abc import abstractmethod
3
+ from array_api_typing.typing_compat._api_typing import _NAMESPACE_ARRAY
4
+ from array_api_typing.typing_compat._array_typing import SetIndex
5
+
6
class AtResult(Protocol[_NAMESPACE_ARRAY]):
    """
    Protocol for the indexable helper object used with the ``at(x)[idx]`` syntax.

    Indexing yields another ``AtResult``; each update method accepts an array
    or a Python scalar ``y`` plus an optional ``copy`` flag and returns an array.
    NOTE(review): the exact update semantics are defined by the implementing
    object, not by this protocol.
    """

    @abstractmethod
    def __getitem__(self, idx: SetIndex, /) -> "AtResult[_NAMESPACE_ARRAY]":
        """
        Allow for the alternate syntax ``at(x)[start:stop:step]``.

        It looks prettier than ``at(x, slice(start, stop, step))``
        and feels more intuitive coming from the JAX documentation.
        """
        ...

    @abstractmethod
    def set(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``set`` update with ``y``; returns an array."""
        ...

    @abstractmethod
    def add(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``add`` update with ``y``; returns an array."""
        ...

    @abstractmethod
    def subtract(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``subtract`` update with ``y``; returns an array."""
        ...

    @abstractmethod
    def multiply(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``multiply`` update with ``y``; returns an array."""
        ...

    @abstractmethod
    def divide(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``divide`` update with ``y``; returns an array."""
        ...

    @abstractmethod
    def power(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``power`` update with ``y``; returns an array."""
        ...

    @abstractmethod
    def min(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``min`` update with ``y``; returns an array."""
        ...

    @abstractmethod
    def max(
        self,
        y: Union[_NAMESPACE_ARRAY, float, int, complex],
        /,
        copy: Optional[bool] = None,
    ) -> _NAMESPACE_ARRAY:
        """Abstract ``max`` update with ``y``; returns an array."""
        ...
xbarray/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .backends.base import * # Import Abstract Typings
@@ -0,0 +1,9 @@
1
+ from typing import Type, Generic
2
+ from .base import BArrayType, BDeviceType, BDtypeType, BRNGType
3
+
4
class ComputeBackendImplCls(Generic[BArrayType, BDeviceType, BDtypeType, BRNGType], Type):
    """
    Generic, ``Type``-derived class whose string forms are ``simplified_name``.

    NOTE(review): ``simplified_name`` is not defined in this class; presumably
    it is provided by each concrete backend -- confirm against the backends.
    """

    def __repr__(self):
        # Same text as ``__str__``: the short backend name.
        return self.simplified_name

    def __str__(self):
        return self.simplified_name
@@ -0,0 +1,87 @@
1
+ from typing import Any, Union, Callable, Mapping, Sequence
2
+ from array_api_typing.typing_compat import ArrayAPINamespace as CompatNamespace, ArrayAPIArray as CompatArray, ArrayAPIDType as CompatDType
3
+ import array_api_compat
4
+ import dataclasses
5
+
6
# Public factory functions exported by ``from ... import *``.
__all__ = [
    "get_abbreviate_array_function",
    "get_map_fn_over_arrays_function",
    "get_pad_dim_function",
]
11
+
12
def get_abbreviate_array_function(
    backend : CompatNamespace[CompatArray, Any, Any],
    default_integer_dtype : CompatDType,
    func_dtype_is_real_floating : Callable[[CompatDType], bool],
    func_dtype_is_real_integer : Callable[[CompatDType], bool],
    func_dtype_is_boolean : Callable[[CompatDType], bool],
):
    """
    Build an ``abbreviate_array`` function bound to one array namespace.

    Args:
        backend: Array namespace providing ``zeros``, ``take`` and ``all``.
        default_integer_dtype: Dtype used for the index array passed to ``take``.
        func_dtype_is_real_floating: Predicate selecting dtypes cast with ``float``.
        func_dtype_is_real_integer: Predicate selecting dtypes cast with ``int``.
        func_dtype_is_boolean: Predicate selecting dtypes cast with ``bool``.

    Returns:
        The ``abbreviate_array`` closure.
    """
    def abbreviate_array(array : CompatArray, try_cast_scalar : bool = True) -> Union[float, int, CompatArray]:
        """
        Abbreviates an array to a single element if possible.
        Or, if some dimensions are the same, abbreviates to a smaller array (but with the same number of dimensions).

        Args:
            array: Array to abbreviate.
            try_cast_scalar: When true and every dimension collapsed to size 1,
                return a Python scalar instead of an array.

        Returns:
            A Python ``float``/``int``/``bool`` when a scalar cast is requested
            and possible, otherwise the abbreviated array.

        Raises:
            ValueError: If a scalar cast is requested and possible but the
                dtype is neither real floating, real integer nor boolean.
        """
        abbr_array = array
        idx = backend.zeros(1, dtype=default_integer_dtype, device=array_api_compat.device(abbr_array))
        for dim_i in range(len(array.shape)):
            if abbr_array.shape[dim_i] == 0:
                # Nothing to take from an empty axis; leave it untouched.
                continue
            # Slice of size 1 along dim_i; broadcasting compares it to every slice.
            first_elem = backend.take(abbr_array, idx, axis=dim_i)
            if backend.all(abbr_array == first_elem):
                abbr_array = first_elem

        if try_cast_scalar and all(i == 1 for i in abbr_array.shape):
            elem = abbr_array[tuple([0] * len(abbr_array.shape))]
            if func_dtype_is_real_floating(elem.dtype):
                return float(elem)
            elif func_dtype_is_real_integer(elem.dtype):
                return int(elem)
            elif func_dtype_is_boolean(elem.dtype):
                return bool(elem)
            else:
                raise ValueError(f"Abbreviated array element dtype must be a real floating or integer or boolean type, actual dtype: {elem.dtype}")

        # Every non-scalar path returns the abbreviated array (never ``None``).
        return abbr_array
    return abbreviate_array
46
+
47
def get_map_fn_over_arrays_function(
    is_backendarray : Callable[[Any], bool],
):
    """
    Build a ``map_fn_over_arrays`` function for one backend.

    Args:
        is_backendarray: Predicate telling whether a value is a backend array.

    Returns:
        The ``map_fn_over_arrays`` closure.
    """
    def map_fn_over_arrays(data : Any, func : "Callable[[CompatArray], CompatArray]") -> Any:
        """
        Map a function to the data.

        ``func`` is applied to every backend array found in ``data``; mappings,
        non-string sequences and dataclass instances are traversed recursively
        and rebuilt, and any other value is returned unchanged.
        """
        if is_backendarray(data):
            return func(data)
        elif isinstance(data, Mapping):
            ret = {k: map_fn_over_arrays(v, func) for k, v in data.items()}
            try:
                return type(data)(**ret) # try to keep the same mapping type
            except Exception:
                # e.g. non-string keys or an incompatible constructor.
                return ret
        elif isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
            ret = [map_fn_over_arrays(i, func) for i in data]
            is_namedtuple = isinstance(data, tuple) and hasattr(data, "_fields")
            try:
                if is_namedtuple:
                    # namedtuple constructors take the fields positionally.
                    return type(data)(*ret)
                return type(data)(ret) # try to keep the same sequence type
            except Exception:
                return ret
        elif dataclasses.is_dataclass(data) and not isinstance(data, type):
            # Shallow field walk: ``dataclasses.asdict`` would turn nested
            # dataclasses into plain dicts and include ``init=False`` fields.
            mapped = {
                f.name: map_fn_over_arrays(getattr(data, f.name), func)
                for f in dataclasses.fields(data)
                if f.init
            }
            return type(data)(**mapped)
        else:
            return data
    return map_fn_over_arrays
73
+
74
def get_pad_dim_function(
    backend : CompatNamespace[CompatArray, Any, Any],
):
    """
    Build a ``pad_dim`` function bound to one array namespace.

    ``backend`` must provide ``full`` and ``concat``.
    """
    def pad_dim(x : CompatArray, dim : int, target_size : int, value : Union[float, int] = 0) -> CompatArray:
        """
        Append ``value``-filled entries along ``dim`` until it has ``target_size`` entries.

        Returns ``x`` itself when the dimension already has the target size and
        raises ``ValueError`` when it is larger than the target.
        """
        size_at_target = x.shape[dim]
        if size_at_target == target_size:
            return x
        if size_at_target > target_size:
            raise ValueError(f"Cannot pad dimension {dim} of size {size_at_target} to smaller target size {target_size}")

        # Shape of the filler block: same as x except along ``dim``.
        filler_shape = [*x.shape]
        filler_shape[dim] = target_size - size_at_target
        filler = backend.full(
            filler_shape,
            value,
            dtype=x.dtype,
            device=array_api_compat.device(x),
        )
        return backend.concat([x, filler], axis=dim)
    return pad_dim
@@ -0,0 +1,33 @@
1
import jax.numpy

# Use jax.numpy as the array namespace when it advertises an array API
# version; otherwise fall back to jax.experimental.array_api.
if hasattr(jax.numpy, "__array_api_version__"):
    compat_module = jax.numpy
    from jax.numpy import *
    from jax.numpy import __array_api_version__, __array_namespace_info__
else:
    import jax.experimental.array_api as compat_module
    from jax.experimental.array_api import *
    from jax.experimental.array_api import __array_api_version__, __array_namespace_info__

from array_api_compat.common._helpers import *

# Short backend name exposed as a module attribute.
simplified_name = "jax"

# Import and bind all functions from array_api_extra before exposing them
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    # Skip private and dunder attributes.
    if api_name.startswith('_'):
        continue

    if api_name in ['at', 'broadcast_shapes']:
        # These two are re-exported unchanged (no ``xp`` binding).
        globals()[api_name] = getattr(array_api_extra, api_name)
    else:
        # Every other public attribute is wrapped with ``xp`` pre-bound to
        # the namespace selected above.
        globals()[api_name] = partial(
            getattr(array_api_extra, api_name),
            xp=compat_module
        )

from ._typing import *
from ._extra import *
# Import the sibling ``random`` submodule of this package.
__import__(__package__ + ".random")
@@ -0,0 +1,127 @@
1
+ from typing import Any, Union, Optional, Callable
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
6
+ from xbarray.backends.base import ComputeBackend, SupportsDLPack
7
+
8
# Names exported by ``from ._extra import *``.
__all__ = [
    "default_integer_dtype",
    "default_index_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    "serialize_device",
    "deserialize_device",
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    "abbreviate_array",
    "map_fn_over_arrays",
    "pad_dim",
]

# Default dtypes are expressed as the Python builtin types.
default_integer_dtype = int
default_index_dtype = int
default_floating_dtype = float
default_boolean_dtype = bool
32
+
33
def serialize_device(device : Optional[Any]) -> Optional[str]:
    """
    Encode a device as ``"<platform>:<id>"``.

    Returns ``None`` for ``None``. The device must expose ``platform`` and
    ``id`` attributes (checked with an ``assert``).
    """
    if device is not None:
        assert hasattr(device, "platform") and hasattr(device, "id")
        return ":".join((device.platform, str(device.id)))
    return None
38
+
39
def deserialize_device(device_str : Optional[str]) -> Optional[Any]:
    """
    Resolve a ``"<platform>:<id>"`` string to a JAX device.

    Args:
        device_str: String of the form ``"<platform>:<id>"``, or ``None``.

    Returns:
        ``None`` for ``None``; otherwise the device on that platform with the
        matching id, or the first device of that platform if no id matches.

    Raises:
        ValueError: If the string is malformed or no device of the platform
            is found.
    """
    if device_str is None:
        return None
    platform, id_str = device_str.split(":")
    device_id = int(id_str)  # avoid shadowing the builtin ``id``
    # Enumerate once and reuse for both the exact match and the fallback.
    candidates = jax.devices(platform)
    # Find the device with matching platform and id
    for device in candidates:
        if device.platform == platform and device.id == device_id:
            return device
    # Fallback with the same platform
    for device in candidates:
        if device.platform == platform:
            return device
    raise ValueError(f"Device with platform '{platform}' and id '{device_id}' not found.")
53
+
54
def is_backendarray(data : Any) -> bool:
    """Return whether ``data`` is a ``jax.Array`` instance."""
    array_cls = jax.Array
    return isinstance(data, array_cls)
56
+
57
def from_numpy(
    data : np.ndarray,
    /,
    *,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> ARRAY_TYPE:
    """Convert a NumPy array with ``jax.numpy.asarray``, forwarding ``dtype`` and ``device``."""
    converted = jnp.asarray(data, dtype=dtype, device=device)
    return converted
65
+
66
def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    """
    Import ``data`` from another backend through DLPack.

    The other backend's ``to_dlpack`` output is handed to
    ``jax.dlpack.from_dlpack``.
    """
    # NOTE: a NumPy round-trip fallback (for DLPack tiling issues in jax) was
    # previously sketched here but is not active.
    return jax.dlpack.from_dlpack(other_backend.to_dlpack(data))
77
+
78
def to_numpy(
    data : ARRAY_TYPE
) -> np.ndarray:
    """Convert to a NumPy array; bfloat16 input is first cast to float32."""
    is_bfloat16 = data.dtype == jax.dtypes.bfloat16
    source = data.astype(np.float32) if is_bfloat16 else data
    return np.asarray(source)
84
+
85
def to_dlpack(
    data: ARRAY_TYPE,
    /,
) -> SupportsDLPack:
    """Return ``data`` unchanged; the array itself is used as the DLPack-capable object."""
    return data
90
+
91
def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is a sub-dtype of ``np.integer``."""
    is_integer = np.issubdtype(dtype, np.integer)
    return is_integer
95
+
96
def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is bfloat16 or a sub-dtype of ``np.floating``."""
    if dtype == jax.dtypes.bfloat16:
        return True
    return np.issubdtype(dtype, np.floating)
100
+
101
def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` compares equal to ``np.bool_`` or the builtin ``bool``."""
    if dtype == np.bool_:
        return True
    return dtype == bool
105
+
106
from .._common.implementations import *
# Select the array namespace: jax.numpy when it advertises an array API
# version, otherwise jax.experimental.array_api.
if hasattr(jax.numpy, "__array_api_version__"):
    compat_module = jax.numpy
else:
    import jax.experimental.array_api as compat_module
# Bind the shared abbreviate_array implementation to this backend's
# namespace, default integer dtype and dtype predicates.
abbreviate_array = get_abbreviate_array_function(
    backend=compat_module,
    default_integer_dtype=default_integer_dtype,
    func_dtype_is_real_floating=dtype_is_real_floating,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_boolean=dtype_is_boolean
)
118
def map_fn_over_arrays(
    data : Any, func : "Callable[[ARRAY_TYPE], ARRAY_TYPE]"
):
    """
    Apply ``func`` to every ``jax.Array`` leaf of the pytree ``data``.

    Leaves that are not ``jax.Array`` instances (strings, Python scalars, ...)
    are returned unchanged instead of being passed to ``func``.

    Args:
        data: Arbitrary pytree.
        func: Function applied to each array leaf.

    Returns:
        A pytree with the same structure as ``data``.
    """
    def _apply_to_array_leaf(leaf):
        # Only touch backend arrays; pass every other leaf through.
        return func(leaf) if isinstance(leaf, jax.Array) else leaf

    return jax.tree.map(_apply_to_array_leaf, data)
125
# Bind the shared pad_dim implementation to the selected array namespace.
pad_dim = get_pad_dim_function(
    backend=compat_module,
)
@@ -0,0 +1,15 @@
1
+ from typing import Union, Any, Optional
2
+ import jax
3
+ import numpy as np
4
+
5
# Type aliases exported by ``from ._typing import *``.
__all__ = [
    'ARRAY_TYPE',
    'DTYPE_TYPE',
    'DEVICE_TYPE',
    'RNG_TYPE',
]

# Arrays are jax.Array instances.
ARRAY_TYPE = jax.Array
# Dtypes are typed as NumPy dtypes.
DTYPE_TYPE = np.dtype
# A "device" may be a concrete jax.Device or a jax.sharding.Sharding.
DEVICE_TYPE = Union[jax.Device, jax.sharding.Sharding]
# RNG state is typed as a jax.Array.
RNG_TYPE = jax.Array
@@ -0,0 +1,115 @@
1
+ from typing import Union, Optional, Tuple, Any
2
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
3
+ import numpy as np
4
+ import jax
5
+
6
# Random-sampling functions exported by this module.
__all__ = [
    "random_number_generator",
    "random_discrete_uniform",
    "random_uniform",
    "random_exponential",
    "random_normal",
    "random_geometric",
    "random_permutation"
]
15
+
16
def random_number_generator(
    seed : Optional[int] = None,
    *,
    device : Optional[DEVICE_TYPE] = None
) -> RNG_TYPE:
    """
    Create a JAX random key.

    When ``seed`` is ``None`` a seed is drawn with ``np.random.randint(65535)``.
    The ``device`` argument is accepted but not used by this function.
    """
    if seed is None:
        rng_seed = np.random.randint(65535)
    else:
        rng_seed = seed
    return jax.random.key(seed=rng_seed)
26
+
27
def random_discrete_uniform(
    shape : Union[int, Tuple[int, ...]],
    /,
    from_num : int,
    to_num : int,
    *,
    rng : RNG_TYPE,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw integers with ``jax.random.randint(minval=from_num, maxval=to_num)``.

    Args:
        shape: Output shape; a bare int is treated as a 1-D shape.
        from_num: Passed as ``minval`` after ``int()`` conversion.
        to_num: Passed as ``maxval`` after ``int()`` conversion.
        rng: Key that is split; one half samples, the other is returned.
        dtype: Output dtype; falls back to ``int``.
        device: If given, the result is moved there with ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    if isinstance(shape, (int, np.integer)):
        # jax.random samplers expect a sequence shape; wrap a bare int.
        shape = (int(shape),)
    new_rng, rng = jax.random.split(rng)
    t = jax.random.randint(rng, shape, minval=int(from_num), maxval=int(to_num), dtype=dtype or int)
    if device is not None:
        t = jax.device_put(t, device)
    return new_rng, t
42
+
43
def random_uniform(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    low : float = 0.0, high : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples with ``jax.random.uniform(minval=low, maxval=high)``.

    Args:
        shape: Output shape; a bare int is treated as a 1-D shape.
        rng: Key that is split; one half samples, the other is returned.
        low: Passed as ``minval``.
        high: Passed as ``maxval``.
        dtype: Output dtype; falls back to ``float``.
        device: If given, the result is moved there with ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    if isinstance(shape, (int, np.integer)):
        # jax.random samplers expect a sequence shape; wrap a bare int.
        shape = (int(shape),)
    new_rng, rng = jax.random.split(rng)
    data = jax.random.uniform(rng, shape, dtype=dtype or float, minval=low, maxval=high)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
57
+
58
def random_exponential(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    lambd : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples with ``jax.random.exponential`` and divide them by ``lambd``.

    Args:
        shape: Output shape; a bare int is treated as a 1-D shape.
        rng: Key that is split; one half samples, the other is returned.
        lambd: Divisor applied to the samples.
        dtype: Output dtype; falls back to ``float``.
        device: If given, the result is moved there with ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    if isinstance(shape, (int, np.integer)):
        # jax.random samplers expect a sequence shape; wrap a bare int.
        shape = (int(shape),)
    new_rng, rng = jax.random.split(rng)
    data = jax.random.exponential(rng, shape, dtype=dtype or float) / lambd
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
72
+
73
def random_normal(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    rng : RNG_TYPE,
    mean : float = 0.0, std : float = 1.0,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples with ``jax.random.normal``, then scale by ``std`` and shift by ``mean``.

    Args:
        shape: Output shape; a bare int is treated as a 1-D shape.
        rng: Key that is split; one half samples, the other is returned.
        mean: Added to the scaled samples.
        std: Multiplier applied to the raw samples.
        dtype: Output dtype; falls back to ``float``.
        device: If given, the result is moved there with ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    if isinstance(shape, (int, np.integer)):
        # jax.random samplers expect a sequence shape; wrap a bare int.
        shape = (int(shape),)
    new_rng, rng = jax.random.split(rng)
    data = jax.random.normal(rng, shape, dtype=dtype or float) * std + mean
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
87
+
88
def random_geometric(
    shape: Union[int, Tuple[int, ...]],
    /,
    *,
    p: float,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[Any] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Draw samples with ``jax.random.geometric(p=p)``.

    Args:
        shape: Output shape; a bare int is treated as a 1-D shape.
        p: Forwarded to ``jax.random.geometric``.
        rng: Key that is split; one half samples, the other is returned.
        dtype: Output dtype; falls back to ``int``.
        device: If given, the result is moved there with ``jax.device_put``.

    Returns:
        ``(new_rng, samples)``.
    """
    if isinstance(shape, (int, np.integer)):
        # jax.random samplers expect a sequence shape; wrap a bare int.
        shape = (int(shape),)
    new_rng, rng = jax.random.split(rng)
    data = jax.random.geometric(rng, p=p, shape=shape, dtype=dtype or int)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
102
+
103
def random_permutation(
    n : int,
    /,
    *,
    rng: RNG_TYPE,
    dtype: Optional[DTYPE_TYPE] = None,
    device: Optional[DEVICE_TYPE] = None
) -> Tuple[RNG_TYPE, ARRAY_TYPE]:
    """
    Return a random permutation produced by ``jax.random.permutation(rng, n)``.

    Args:
        n: Forwarded to ``jax.random.permutation``.
        rng: Key that is split; one half samples, the other is returned.
        dtype: If given, the permutation is cast to this dtype.
        device: If given, the result is moved there with ``jax.device_put``.

    Returns:
        ``(new_rng, permutation)``.
    """
    new_rng, rng = jax.random.split(rng)
    # jax.random.permutation takes no ``dtype`` keyword, so cast afterwards.
    data = jax.random.permutation(rng, n)
    if dtype is not None:
        data = data.astype(dtype)
    if device is not None:
        data = jax.device_put(data, device)
    return new_rng, data
@@ -0,0 +1,25 @@
1
# Re-export the array_api_compat NumPy namespace and its helper functions.
from array_api_compat.numpy import *
from array_api_compat.numpy import __array_api_version__, __array_namespace_info__
from array_api_compat.common._helpers import *

# Short backend name exposed as a module attribute.
simplified_name = "numpy"

from array_api_compat import numpy as compat_module
# Import and bind all functions from array_api_extra before exposing them
import array_api_extra
from functools import partial
for api_name in dir(array_api_extra):
    # Skip private and dunder attributes.
    if api_name.startswith('_'):
        continue

    if api_name in ['at', 'broadcast_shapes']:
        # These two are re-exported unchanged (no ``xp`` binding).
        globals()[api_name] = getattr(array_api_extra, api_name)
    else:
        # Every other public attribute is wrapped with ``xp`` pre-bound to
        # the array_api_compat NumPy namespace.
        globals()[api_name] = partial(
            getattr(array_api_extra, api_name),
            xp=compat_module
        )

from ._typing import *
from ._extra import *
# Import the sibling ``random`` submodule of this package.
__import__(__package__ + ".random")
@@ -0,0 +1,98 @@
1
+ from typing import Any, Union, Optional
2
+ import numpy as np
3
+ from ._typing import ARRAY_TYPE, DTYPE_TYPE, DEVICE_TYPE, RNG_TYPE
4
+ from xbarray.backends.base import ComputeBackend, SupportsDLPack
5
+
6
# Names exported by ``from ._extra import *``.
__all__ = [
    "default_integer_dtype",
    "default_index_dtype",
    "default_floating_dtype",
    "default_boolean_dtype",
    "serialize_device",
    "deserialize_device",
    "is_backendarray",
    "from_numpy",
    "from_other_backend",
    "to_numpy",
    "to_dlpack",
    "dtype_is_real_integer",
    "dtype_is_real_floating",
    "dtype_is_boolean",
    "abbreviate_array",
    "map_fn_over_arrays",
    "pad_dim",
]

# Default dtypes are expressed as the Python builtin types.
default_integer_dtype = int
default_index_dtype = int
default_floating_dtype = float
default_boolean_dtype = bool
30
+
31
def serialize_device(device : Optional[DEVICE_TYPE]) -> Optional[str]:
    """Always return ``None``: NumPy does not have a device concept."""
    return None
33
+
34
def deserialize_device(device_str : Optional[str]) -> Optional[DEVICE_TYPE]:
    """Always return ``None``: NumPy does not have a device concept."""
    return None
36
+
37
def is_backendarray(data : Any) -> bool:
    """Return whether ``data`` is an ``np.ndarray`` instance."""
    array_cls = np.ndarray
    return isinstance(data, array_cls)
39
+
40
def from_numpy(
    data : np.ndarray,
    /,
    *,
    dtype : Optional[DTYPE_TYPE] = None,
    device : Optional[DEVICE_TYPE] = None
) -> ARRAY_TYPE:
    """
    Return ``data`` as a backend (NumPy) array, honouring ``dtype``.

    Args:
        data: Source NumPy array.
        dtype: Optional target dtype; when ``None`` the input is returned as is.
        device: Accepted for interface parity; not used by this function.

    Returns:
        ``data`` itself when no dtype is requested, otherwise
        ``np.asarray(data, dtype=dtype)``.
    """
    if dtype is None:
        return data
    # Previously ``dtype`` was silently ignored.
    return np.asarray(data, dtype=dtype)
48
+
49
def from_other_backend(
    other_backend: ComputeBackend,
    data: Any,
    /,
) -> ARRAY_TYPE:
    """Convert ``data`` by delegating to the other backend's ``to_numpy``."""
    converted = other_backend.to_numpy(data)
    return converted
55
+
56
def to_numpy(
    data : ARRAY_TYPE
) -> np.ndarray:
    """Return ``data`` unchanged; backend arrays are already NumPy arrays."""
    return data
60
+
61
def to_dlpack(
    data: ARRAY_TYPE,
    /,
) -> SupportsDLPack:
    """Return ``data`` unchanged; the array itself is used as the DLPack-capable object."""
    return data
66
+
67
def dtype_is_real_integer(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is a sub-dtype of ``np.integer``."""
    is_integer = np.issubdtype(dtype, np.integer)
    return is_integer
71
+
72
def dtype_is_real_floating(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` is a sub-dtype of ``np.floating``."""
    is_floating = np.issubdtype(dtype, np.floating)
    return is_floating
76
+
77
def dtype_is_boolean(
    dtype: DTYPE_TYPE
) -> bool:
    """Return whether ``dtype`` compares equal to ``np.bool_`` or the builtin ``bool``."""
    if dtype == np.bool_:
        return True
    return dtype == bool
81
+
82
from .._common.implementations import *
from array_api_compat import numpy as compat_module
# Bind the shared abbreviate_array implementation to the NumPy namespace,
# this backend's default integer dtype and its dtype predicates.
abbreviate_array = get_abbreviate_array_function(
    backend=compat_module,
    default_integer_dtype=default_integer_dtype,
    func_dtype_is_real_floating=dtype_is_real_floating,
    func_dtype_is_real_integer=dtype_is_real_integer,
    func_dtype_is_boolean=dtype_is_boolean,
)

# Recursive mapper that applies a function to np.ndarray values only.
map_fn_over_arrays = get_map_fn_over_arrays_function(
    is_backendarray=is_backendarray,
)

# Bind the shared pad_dim implementation to the NumPy namespace.
pad_dim = get_pad_dim_function(
    backend=compat_module,
)
@@ -0,0 +1,14 @@
1
+ from typing import Union, Any, Optional
2
+ import numpy as np
3
+
4
# Type aliases exported by ``from ._typing import *``.
__all__ = [
    'ARRAY_TYPE',
    'DTYPE_TYPE',
    'DEVICE_TYPE',
    'RNG_TYPE',
]

# Arrays are np.ndarray instances.
ARRAY_TYPE = np.ndarray
# Dtypes are NumPy dtypes.
DTYPE_TYPE = np.dtype
# Devices are left untyped (``Any``).
DEVICE_TYPE = Any
# RNG state is a NumPy Generator.
RNG_TYPE = np.random.Generator