xbarray 0.0.1a1__tar.gz → 0.0.1a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xbarray might be problematic. Click here for more details.

Files changed (47) hide show
  1. {xbarray-0.0.1a1/xbarray.egg-info → xbarray-0.0.1a2}/PKG-INFO +2 -1
  2. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/README.md +13 -7
  3. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/pyproject.toml +2 -1
  4. xbarray-0.0.1a2/xbarray/__init__.py +1 -0
  5. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/jax/__init__.py +10 -4
  6. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/jax/_extra.py +1 -1
  7. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/jax/random.py +1 -2
  8. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/numpy/__init__.py +8 -4
  9. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/numpy/_extra.py +1 -1
  10. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/numpy/random.py +1 -2
  11. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/pytorch/__init__.py +8 -4
  12. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/pytorch/_extra.py +1 -1
  13. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/pytorch/random.py +1 -2
  14. xbarray-0.0.1a2/xbarray/_src/serialization/__init__.py +1 -0
  15. xbarray-0.0.1a2/xbarray/_src/serialization/serialization_map.py +31 -0
  16. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/xbarray/base/base.py +2 -2
  17. xbarray-0.0.1a2/xbarray/cls_impl/cls_base.py +10 -0
  18. xbarray-0.0.1a2/xbarray/jax.py +13 -0
  19. xbarray-0.0.1a2/xbarray/numpy.py +13 -0
  20. xbarray-0.0.1a2/xbarray/pytorch.py +13 -0
  21. {xbarray-0.0.1a1 → xbarray-0.0.1a2/xbarray.egg-info}/PKG-INFO +2 -1
  22. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/xbarray.egg-info/SOURCES.txt +19 -13
  23. xbarray-0.0.1a1/xbarray/__init__.py +0 -1
  24. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/LICENSE +0 -0
  25. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/__init__.py +0 -0
  26. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_2024_12/__init__.py +0 -0
  27. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_2024_12/_api_constant.py +0 -0
  28. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_2024_12/_api_fft_typing.py +0 -0
  29. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_2024_12/_api_linalg_typing.py +0 -0
  30. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_2024_12/_api_return_typing.py +0 -0
  31. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_2024_12/_api_typing.py +0 -0
  32. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_2024_12/_array_typing.py +0 -0
  33. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_compat/__init__.py +0 -0
  34. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_compat/_api_typing.py +0 -0
  35. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_compat/_array_typing.py +0 -0
  36. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_extra/__init__.py +0 -0
  37. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_extra/_api_typing.py +0 -0
  38. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/array_api_typing/typing_extra/_at.py +0 -0
  39. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/setup.cfg +0 -0
  40. {xbarray-0.0.1a1/xbarray/common → xbarray-0.0.1a2/xbarray/_src/implementations/_common}/implementations.py +0 -0
  41. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/jax/_typing.py +0 -0
  42. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/numpy/_typing.py +0 -0
  43. {xbarray-0.0.1a1/xbarray → xbarray-0.0.1a2/xbarray/_src/implementations}/pytorch/_typing.py +0 -0
  44. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/xbarray/base/__init__.py +0 -0
  45. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/xbarray.egg-info/dependency_links.txt +0 -0
  46. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/xbarray.egg-info/requires.txt +0 -0
  47. {xbarray-0.0.1a1 → xbarray-0.0.1a2}/xbarray.egg-info/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xbarray
3
- Version: 0.0.1a1
3
+ Version: 0.0.1a2
4
+ Summary: Cross-backend Python array library based on the Array API Standard.
4
5
  Requires-Python: >=3.10
5
6
  License-File: LICENSE
6
7
  Requires-Dist: typing_extensions>=4.5
@@ -1,17 +1,23 @@
1
- # xarray
2
- Cross-backend Python N-dimensional array library based on Array API.
1
+ # xbarray
2
+ Cross-backend Python array library based on the Array API Standard.
3
3
 
4
4
  This allows you to write array transformations that can run on different backends like NumPy, PyTorch, and Jax.
5
5
 
6
+ ## Installation
7
+
8
+ ```bash
9
+ pip install xbarray
10
+ ```
11
+
6
12
  ## Usage:
7
13
 
8
14
  Abstract typing:
9
15
 
10
16
  ```python
11
- from xarray import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
17
+ from xbarray import ComputeBackend, BArrayType, BDeviceType, BDtypeType, BRNGType
12
18
  from typing import Generic
13
19
 
14
- class ABC(Generic[BArrayType, BDeviceType, BDtypeType, BRNGType]):
20
+ class Behavior(Generic[BArrayType, BDeviceType, BDtypeType, BRNGType]):
15
21
  def __init__(self, backend : ComputeBackend[BArrayType, BDeviceType, BDtypeType, BRNGType]) -> None:
16
22
  self.backend = backend
17
23
 
@@ -22,8 +28,8 @@ class ABC(Generic[BArrayType, BDeviceType, BDtypeType, BRNGType]):
22
28
  Concrete usage:
23
29
 
24
30
  ```python
25
- from xarray import pytorch as pytorch_backend
31
+ from xbarray.pytorch import PyTorchComputeBackend
26
32
 
27
- abc_pytorch_instance = ABC(pytorch_backend)
28
- abc_pytorch_array = abc_pytorch_instance.create_array()
33
+ behavior_pytorch_instance = Behavior(PyTorchComputeBackend)
34
+ behavior_pytorch_array = behavior_pytorch_instance.create_array()
29
35
  ```
@@ -1,6 +1,7 @@
1
1
  [project]
2
2
  name = "xbarray"
3
- version = "0.0.1a1"
3
+ description = "Cross-backend Python array library based on the Array API Standard."
4
+ version = "0.0.1a2"
4
5
  requires-python = ">= 3.10"
5
6
  dependencies = [
6
7
  "typing_extensions>=4.5",
@@ -0,0 +1 @@
1
+ from .base import * # Import Abstract Typings
@@ -9,16 +9,22 @@ else:
9
9
  import jax.experimental.array_api as compat_module
10
10
  from jax.experimental.array_api import *
11
11
 
12
+ from array_api_compat.common._helpers import *
13
+
12
14
  # Import and bind all functions from array_api_extra before exposing them
13
15
  import array_api_extra
14
16
  from functools import partial
15
17
  for api_name in dir(array_api_extra):
16
18
  if api_name.startswith('_'):
17
19
  continue
18
- globals()[api_name] = partial(
19
- getattr(array_api_extra, api_name),
20
- xp=compat_module
21
- )
20
+
21
+ if api_name == 'at':
22
+ globals()[api_name] = getattr(array_api_extra, api_name)
23
+ else:
24
+ globals()[api_name] = partial(
25
+ getattr(array_api_extra, api_name),
26
+ xp=compat_module
27
+ )
22
28
 
23
29
  from ._typing import *
24
30
  from ._extra import *
@@ -77,7 +77,7 @@ def dtype_is_boolean(
77
77
  ) -> bool:
78
78
  return dtype == np.bool_ or dtype == bool
79
79
 
80
- from xbarray.common.implementations import *
80
+ from .._common.implementations import *
81
81
  if hasattr(jax.numpy, "__array_api_version__"):
82
82
  compat_module = jax.numpy
83
83
  else:
@@ -26,9 +26,9 @@ def random_number_generator(
26
26
 
27
27
  def random_discrete_uniform(
28
28
  shape : Union[int, Tuple[int, ...]],
29
+ /,
29
30
  from_num : int,
30
31
  to_num : int,
31
- /,
32
32
  *,
33
33
  rng : RNG_TYPE,
34
34
  dtype : Optional[DTYPE_TYPE] = None,
@@ -70,7 +70,6 @@ def random_exponential(
70
70
  data = jax.device_put(data, device)
71
71
  return new_rng, data
72
72
 
73
- @classmethod
74
73
  def random_normal(
75
74
  shape: Union[int, Tuple[int, ...]],
76
75
  /,
@@ -9,10 +9,14 @@ from functools import partial
9
9
  for api_name in dir(array_api_extra):
10
10
  if api_name.startswith('_'):
11
11
  continue
12
- globals()[api_name] = partial(
13
- getattr(array_api_extra, api_name),
14
- xp=compat_module
15
- )
12
+
13
+ if api_name == 'at':
14
+ globals()[api_name] = getattr(array_api_extra, api_name)
15
+ else:
16
+ globals()[api_name] = partial(
17
+ getattr(array_api_extra, api_name),
18
+ xp=compat_module
19
+ )
16
20
 
17
21
  from ._typing import *
18
22
  from ._extra import *
@@ -68,7 +68,7 @@ def dtype_is_boolean(
68
68
  ) -> bool:
69
69
  return dtype == np.bool_ or dtype == bool
70
70
 
71
- from xbarray.common.implementations import get_abbreviate_array_function, get_map_fn_over_arrays_function
71
+ from .._common.implementations import get_abbreviate_array_function, get_map_fn_over_arrays_function
72
72
  from array_api_compat import numpy as compat_module
73
73
  abbreviate_array = get_abbreviate_array_function(
74
74
  backend=compat_module,
@@ -21,9 +21,9 @@ def random_number_generator(
21
21
 
22
22
  def random_discrete_uniform(
23
23
  shape : Union[int, Tuple[int, ...]],
24
+ /,
24
25
  from_num : int,
25
26
  to_num : int,
26
- /,
27
27
  *,
28
28
  rng : RNG_TYPE,
29
29
  dtype : Optional[DTYPE_TYPE] = None,
@@ -62,7 +62,6 @@ def random_exponential(
62
62
  t = t.astype(dtype)
63
63
  return rng, t
64
64
 
65
- @classmethod
66
65
  def random_normal(
67
66
  shape: Union[int, Tuple[int, ...]],
68
67
  /,
@@ -10,10 +10,14 @@ from functools import partial
10
10
  for api_name in dir(array_api_extra):
11
11
  if api_name.startswith('_'):
12
12
  continue
13
- globals()[api_name] = partial(
14
- getattr(array_api_extra, api_name),
15
- xp=compat_module
16
- )
13
+
14
+ if api_name == 'at':
15
+ globals()[api_name] = getattr(array_api_extra, api_name)
16
+ else:
17
+ globals()[api_name] = partial(
18
+ getattr(array_api_extra, api_name),
19
+ xp=compat_module
20
+ )
17
21
 
18
22
  from ._typing import *
19
23
  from ._extra import *
@@ -95,7 +95,7 @@ def dtype_is_boolean(
95
95
  ) -> bool:
96
96
  return dtype == torch.bool
97
97
 
98
- from xbarray.common.implementations import get_abbreviate_array_function, get_map_fn_over_arrays_function
98
+ from .._common.implementations import get_abbreviate_array_function, get_map_fn_over_arrays_function
99
99
  from array_api_compat import torch as compat_module
100
100
  abbreviate_array = get_abbreviate_array_function(
101
101
  compat_module,
@@ -27,9 +27,9 @@ def random_number_generator(
27
27
 
28
28
  def random_discrete_uniform(
29
29
  shape : Union[int, Tuple[int, ...]],
30
+ /,
30
31
  from_num : int,
31
32
  to_num : int,
32
- /,
33
33
  *,
34
34
  rng : RNG_TYPE,
35
35
  dtype : Optional[DTYPE_TYPE] = None,
@@ -64,7 +64,6 @@ def random_exponential(
64
64
  t = t.exponential_(lambd, generator=rng)
65
65
  return rng, t
66
66
 
67
- @classmethod
68
67
  def random_normal(
69
68
  shape: Union[int, Tuple[int, ...]],
70
69
  /,
@@ -0,0 +1 @@
1
+ from .serialization_map import *
@@ -0,0 +1,31 @@
1
+ from typing import Type, Dict, Any, Optional, List, Callable
2
+ from types import ModuleType
3
+ import importlib
4
+
5
+ __all__ = [
6
+ "implementation_module_to_name",
7
+ "name_to_implementation_module",
8
+ ]
9
+
10
+ def implementation_module_to_name(module : ModuleType) -> str:
11
+ """
12
+ Convert a backend module to its simplified name.
13
+ """
14
+ full_name = module.__name__
15
+ if not full_name.startswith("xbarray.implementations."):
16
+ raise ValueError(f"Module {full_name} is not a valid xbarray backend module.")
17
+
18
+ submodule_name = full_name[len("xbarray.implementations"):]
19
+ if '.' in submodule_name:
20
+ raise ValueError(f"Module {full_name} is not a valid xbarray backend module.")
21
+ return submodule_name
22
+
23
+ def name_to_implementation_module(name: str) -> Type[ModuleType]:
24
+ """
25
+ Convert a simplified backend name to its module.
26
+ """
27
+ try:
28
+ return importlib.import_module(f"xbarray.implementations.{name}")
29
+ except ImportError as e:
30
+ raise ImportError(f"Could not import backend module '{name}'.") from e
31
+
@@ -40,9 +40,9 @@ class RNGBackend(Protocol[BArrayType, BDeviceType, BDtypeType, BRNGType]):
40
40
  def random_discrete_uniform(
41
41
  self,
42
42
  shape : Union[int, Tuple[int, ...]],
43
- from_num : int,
44
- to_num : int,
45
43
  /,
44
+ from_num : int,
45
+ to_num : int,
46
46
  *,
47
47
  rng : BRNGType,
48
48
  dtype : Optional[BDtypeType] = None,
@@ -0,0 +1,10 @@
1
+ from types import ModuleType
2
+ from typing import Type
3
+ from xbarray._src.serialization import implementation_module_to_name, name_to_implementation_module
4
+
5
+ class ComputeBackendImplCls(Type):
6
+ def __str__(self):
7
+ return self.simplified_name
8
+
9
+ def __repr__(self):
10
+ return self.simplified_name
@@ -0,0 +1,13 @@
1
+ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
+ from ._src.implementations import jax as jax_impl
3
+
4
+ class JaxComputeBackend(metaclass=ComputeBackendImplCls):
5
+ pass
6
+
7
+ for name in dir(jax_impl):
8
+ if not name.startswith('_'):
9
+ setattr(JaxComputeBackend, name, getattr(jax_impl, name))
10
+
11
+ __all__ = [
12
+ 'JaxComputeBackend',
13
+ ]
@@ -0,0 +1,13 @@
1
+ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
+ from ._src.implementations import numpy as numpy_impl
3
+
4
+ class NumpyComputeBackend(metaclass=ComputeBackendImplCls):
5
+ pass
6
+
7
+ for name in dir(numpy_impl):
8
+ if not name.startswith('_'):
9
+ setattr(NumpyComputeBackend, name, getattr(numpy_impl, name))
10
+
11
+ __all__ = [
12
+ 'NumpyComputeBackend',
13
+ ]
@@ -0,0 +1,13 @@
1
+ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
+ from ._src.implementations import pytorch as pytorch_impl
3
+
4
+ class PytorchComputeBackend(metaclass=ComputeBackendImplCls):
5
+ pass
6
+
7
+ for name in dir(pytorch_impl):
8
+ if not name.startswith('_'):
9
+ setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
10
+
11
+ __all__ = [
12
+ 'PytorchComputeBackend',
13
+ ]
@@ -1,6 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xbarray
3
- Version: 0.0.1a1
3
+ Version: 0.0.1a2
4
+ Summary: Cross-backend Python array library based on the Array API Standard.
4
5
  Requires-Python: >=3.10
5
6
  License-File: LICENSE
6
7
  Requires-Dist: typing_extensions>=4.5
@@ -16,23 +16,29 @@ array_api_typing/typing_extra/__init__.py
16
16
  array_api_typing/typing_extra/_api_typing.py
17
17
  array_api_typing/typing_extra/_at.py
18
18
  xbarray/__init__.py
19
+ xbarray/jax.py
20
+ xbarray/numpy.py
21
+ xbarray/pytorch.py
19
22
  xbarray.egg-info/PKG-INFO
20
23
  xbarray.egg-info/SOURCES.txt
21
24
  xbarray.egg-info/dependency_links.txt
22
25
  xbarray.egg-info/requires.txt
23
26
  xbarray.egg-info/top_level.txt
27
+ xbarray/_src/implementations/_common/implementations.py
28
+ xbarray/_src/implementations/jax/__init__.py
29
+ xbarray/_src/implementations/jax/_extra.py
30
+ xbarray/_src/implementations/jax/_typing.py
31
+ xbarray/_src/implementations/jax/random.py
32
+ xbarray/_src/implementations/numpy/__init__.py
33
+ xbarray/_src/implementations/numpy/_extra.py
34
+ xbarray/_src/implementations/numpy/_typing.py
35
+ xbarray/_src/implementations/numpy/random.py
36
+ xbarray/_src/implementations/pytorch/__init__.py
37
+ xbarray/_src/implementations/pytorch/_extra.py
38
+ xbarray/_src/implementations/pytorch/_typing.py
39
+ xbarray/_src/implementations/pytorch/random.py
40
+ xbarray/_src/serialization/__init__.py
41
+ xbarray/_src/serialization/serialization_map.py
24
42
  xbarray/base/__init__.py
25
43
  xbarray/base/base.py
26
- xbarray/common/implementations.py
27
- xbarray/jax/__init__.py
28
- xbarray/jax/_extra.py
29
- xbarray/jax/_typing.py
30
- xbarray/jax/random.py
31
- xbarray/numpy/__init__.py
32
- xbarray/numpy/_extra.py
33
- xbarray/numpy/_typing.py
34
- xbarray/numpy/random.py
35
- xbarray/pytorch/__init__.py
36
- xbarray/pytorch/_extra.py
37
- xbarray/pytorch/_typing.py
38
- xbarray/pytorch/random.py
44
+ xbarray/cls_impl/cls_base.py
@@ -1 +0,0 @@
1
- from .base import *
File without changes
File without changes