xbarray 0.0.1a5__tar.gz → 0.0.1a7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xbarray might be problematic. Click here for more details.
- {xbarray-0.0.1a5/xbarray.egg-info → xbarray-0.0.1a7}/PKG-INFO +1 -1
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_2024_12/_api_return_typing.py +2 -2
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/pyproject.toml +1 -1
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/jax/__init__.py +3 -1
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/numpy/__init__.py +2 -1
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/pytorch/__init__.py +2 -1
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/jax.py +8 -2
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/numpy.py +8 -2
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/pytorch.py +8 -2
- {xbarray-0.0.1a5 → xbarray-0.0.1a7/xbarray.egg-info}/PKG-INFO +1 -1
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/LICENSE +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/README.md +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/__init__.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_2024_12/__init__.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_2024_12/_api_constant.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_2024_12/_api_fft_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_2024_12/_api_linalg_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_2024_12/_api_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_2024_12/_array_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_compat/__init__.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_compat/_api_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_compat/_array_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_extra/__init__.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_extra/_api_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/array_api_typing/typing_extra/_at.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/setup.cfg +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/__init__.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/_common/implementations.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/jax/_extra.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/jax/_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/jax/random.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/numpy/_extra.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/numpy/_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/numpy/random.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/pytorch/_extra.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/pytorch/_typing.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/implementations/pytorch/random.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/serialization/__init__.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/_src/serialization/serialization_map.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/base/__init__.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/base/base.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray/cls_impl/cls_base.py +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray.egg-info/SOURCES.txt +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray.egg-info/dependency_links.txt +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray.egg-info/requires.txt +0 -0
- {xbarray-0.0.1a5 → xbarray-0.0.1a7}/xbarray.egg-info/top_level.txt +0 -0
|
@@ -98,6 +98,6 @@ class Info(Protocol):
|
|
|
98
98
|
...
|
|
99
99
|
|
|
100
100
|
def dtypes(
|
|
101
|
-
self, *, device: Optional[Device], kind: Optional[Union[str, Tuple[str, ...]]]
|
|
102
|
-
) -> DataTypes:
|
|
101
|
+
self, *, device: Optional[Device] = None, kind: Optional[Union[str, Tuple[str, ...]]] = None
|
|
102
|
+
) -> DataTypes: # In the original signature there's no default parameters for `device` and `kind`, but we add them since they exist in `array_api_compat` modules
|
|
103
103
|
...
|
|
@@ -3,9 +3,11 @@ import jax.numpy
|
|
|
3
3
|
if hasattr(jax.numpy, "__array_api_version__"):
|
|
4
4
|
compat_module = jax.numpy
|
|
5
5
|
from jax.numpy import *
|
|
6
|
+
from jax.numpy import __array_api_version__, __array_namespace_info__
|
|
6
7
|
else:
|
|
7
8
|
import jax.experimental.array_api as compat_module
|
|
8
9
|
from jax.experimental.array_api import *
|
|
10
|
+
from jax.experimental.array_api import __array_api_version__, __array_namespace_info__
|
|
9
11
|
|
|
10
12
|
from array_api_compat.common._helpers import *
|
|
11
13
|
|
|
@@ -18,7 +20,7 @@ for api_name in dir(array_api_extra):
|
|
|
18
20
|
if api_name.startswith('_'):
|
|
19
21
|
continue
|
|
20
22
|
|
|
21
|
-
if api_name
|
|
23
|
+
if api_name in ['at', 'broadcast_shapes']:
|
|
22
24
|
globals()[api_name] = getattr(array_api_extra, api_name)
|
|
23
25
|
else:
|
|
24
26
|
globals()[api_name] = partial(
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from array_api_compat.numpy import *
|
|
2
|
+
from array_api_compat.numpy import __array_api_version__, __array_namespace_info__
|
|
2
3
|
from array_api_compat.common._helpers import *
|
|
3
4
|
|
|
4
5
|
simplified_name = "numpy"
|
|
@@ -11,7 +12,7 @@ for api_name in dir(array_api_extra):
|
|
|
11
12
|
if api_name.startswith('_'):
|
|
12
13
|
continue
|
|
13
14
|
|
|
14
|
-
if api_name
|
|
15
|
+
if api_name in ['at', 'broadcast_shapes']:
|
|
15
16
|
globals()[api_name] = getattr(array_api_extra, api_name)
|
|
16
17
|
else:
|
|
17
18
|
globals()[api_name] = partial(
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from array_api_compat.torch import *
|
|
2
|
+
from array_api_compat.torch import __array_api_version__, __array_namespace_info__
|
|
2
3
|
from array_api_compat.common._helpers import *
|
|
3
4
|
|
|
4
5
|
simplified_name = "pytorch"
|
|
@@ -12,7 +13,7 @@ for api_name in dir(array_api_extra):
|
|
|
12
13
|
if api_name.startswith('_'):
|
|
13
14
|
continue
|
|
14
15
|
|
|
15
|
-
if api_name
|
|
16
|
+
if api_name in ['at', 'broadcast_shapes']:
|
|
16
17
|
globals()[api_name] = getattr(array_api_extra, api_name)
|
|
17
18
|
else:
|
|
18
19
|
globals()[api_name] = partial(
|
|
@@ -2,10 +2,16 @@ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
|
|
|
2
2
|
from ._src.implementations import jax as jax_impl
|
|
3
3
|
|
|
4
4
|
class JaxComputeBackend(metaclass=ComputeBackendImplCls):
|
|
5
|
-
|
|
5
|
+
ARRAY_TYPE = jax_impl.ARRAY_TYPE
|
|
6
|
+
DTYPE_TYPE = jax_impl.DTYPE_TYPE
|
|
7
|
+
DEVICE_TYPE = jax_impl.DEVICE_TYPE
|
|
8
|
+
RNG_TYPE = jax_impl.RNG_TYPE
|
|
6
9
|
|
|
7
10
|
for name in dir(jax_impl):
|
|
8
|
-
if not name.startswith('_')
|
|
11
|
+
if not name.startswith('_') or name in [
|
|
12
|
+
'__array_namespace_info__',
|
|
13
|
+
'__array_api_version__',
|
|
14
|
+
]:
|
|
9
15
|
setattr(JaxComputeBackend, name, getattr(jax_impl, name))
|
|
10
16
|
|
|
11
17
|
__all__ = [
|
|
@@ -2,10 +2,16 @@ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
|
|
|
2
2
|
from ._src.implementations import numpy as numpy_impl
|
|
3
3
|
|
|
4
4
|
class NumpyComputeBackend(metaclass=ComputeBackendImplCls):
|
|
5
|
-
|
|
5
|
+
ARRAY_TYPE = numpy_impl.ARRAY_TYPE
|
|
6
|
+
DTYPE_TYPE = numpy_impl.DTYPE_TYPE
|
|
7
|
+
DEVICE_TYPE = numpy_impl.DEVICE_TYPE
|
|
8
|
+
RNG_TYPE = numpy_impl.RNG_TYPE
|
|
6
9
|
|
|
7
10
|
for name in dir(numpy_impl):
|
|
8
|
-
if not name.startswith('_')
|
|
11
|
+
if not name.startswith('_') or name in [
|
|
12
|
+
'__array_namespace_info__',
|
|
13
|
+
'__array_api_version__',
|
|
14
|
+
]:
|
|
9
15
|
setattr(NumpyComputeBackend, name, getattr(numpy_impl, name))
|
|
10
16
|
|
|
11
17
|
__all__ = [
|
|
@@ -2,10 +2,16 @@ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
|
|
|
2
2
|
from ._src.implementations import pytorch as pytorch_impl
|
|
3
3
|
|
|
4
4
|
class PytorchComputeBackend(metaclass=ComputeBackendImplCls):
|
|
5
|
-
|
|
5
|
+
ARRAY_TYPE = pytorch_impl.ARRAY_TYPE
|
|
6
|
+
DTYPE_TYPE = pytorch_impl.DTYPE_TYPE
|
|
7
|
+
DEVICE_TYPE = pytorch_impl.DEVICE_TYPE
|
|
8
|
+
RNG_TYPE = pytorch_impl.RNG_TYPE
|
|
6
9
|
|
|
7
10
|
for name in dir(pytorch_impl):
|
|
8
|
-
if not name.startswith('_')
|
|
11
|
+
if not name.startswith('_') or name in [
|
|
12
|
+
'__array_namespace_info__',
|
|
13
|
+
'__array_api_version__',
|
|
14
|
+
]:
|
|
9
15
|
setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
|
|
10
16
|
|
|
11
17
|
__all__ = [
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|