xbarray 0.0.1a5__py3-none-any.whl → 0.0.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xbarray might be problematic. The original page linked to further details, which are not reproduced in this text.

@@ -98,6 +98,6 @@ class Info(Protocol):
98
98
  ...
99
99
 
100
100
  def dtypes(
101
- self, *, device: Optional[Device], kind: Optional[Union[str, Tuple[str, ...]]]
102
- ) -> DataTypes:
101
+ self, *, device: Optional[Device] = None, kind: Optional[Union[str, Tuple[str, ...]]] = None
102
+ ) -> DataTypes: # In the original signature there's no default parameters for `device` and `kind`, but we add them since they exist in `array_api_compat` modules
103
103
  ...
@@ -3,9 +3,11 @@ import jax.numpy
3
3
  if hasattr(jax.numpy, "__array_api_version__"):
4
4
  compat_module = jax.numpy
5
5
  from jax.numpy import *
6
+ from jax.numpy import __array_api_version__, __array_namespace_info__
6
7
  else:
7
8
  import jax.experimental.array_api as compat_module
8
9
  from jax.experimental.array_api import *
10
+ from jax.experimental.array_api import __array_api_version__, __array_namespace_info__
9
11
 
10
12
  from array_api_compat.common._helpers import *
11
13
 
@@ -18,7 +20,7 @@ for api_name in dir(array_api_extra):
18
20
  if api_name.startswith('_'):
19
21
  continue
20
22
 
21
- if api_name == 'at':
23
+ if api_name in ['at', 'broadcast_shapes']:
22
24
  globals()[api_name] = getattr(array_api_extra, api_name)
23
25
  else:
24
26
  globals()[api_name] = partial(
@@ -1,4 +1,5 @@
1
1
  from array_api_compat.numpy import *
2
+ from array_api_compat.numpy import __array_api_version__, __array_namespace_info__
2
3
  from array_api_compat.common._helpers import *
3
4
 
4
5
  simplified_name = "numpy"
@@ -11,7 +12,7 @@ for api_name in dir(array_api_extra):
11
12
  if api_name.startswith('_'):
12
13
  continue
13
14
 
14
- if api_name == 'at':
15
+ if api_name in ['at', 'broadcast_shapes']:
15
16
  globals()[api_name] = getattr(array_api_extra, api_name)
16
17
  else:
17
18
  globals()[api_name] = partial(
@@ -1,4 +1,5 @@
1
1
  from array_api_compat.torch import *
2
+ from array_api_compat.torch import __array_api_version__, __array_namespace_info__
2
3
  from array_api_compat.common._helpers import *
3
4
 
4
5
  simplified_name = "pytorch"
@@ -12,7 +13,7 @@ for api_name in dir(array_api_extra):
12
13
  if api_name.startswith('_'):
13
14
  continue
14
15
 
15
- if api_name == 'at':
16
+ if api_name in ['at', 'broadcast_shapes']:
16
17
  globals()[api_name] = getattr(array_api_extra, api_name)
17
18
  else:
18
19
  globals()[api_name] = partial(
xbarray/jax.py CHANGED
@@ -2,10 +2,16 @@ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
2
  from ._src.implementations import jax as jax_impl
3
3
 
4
4
  class JaxComputeBackend(metaclass=ComputeBackendImplCls):
5
- pass
5
+ ARRAY_TYPE = jax_impl.ARRAY_TYPE
6
+ DTYPE_TYPE = jax_impl.DTYPE_TYPE
7
+ DEVICE_TYPE = jax_impl.DEVICE_TYPE
8
+ RNG_TYPE = jax_impl.RNG_TYPE
6
9
 
7
10
  for name in dir(jax_impl):
8
- if not name.startswith('_'):
11
+ if not name.startswith('_') or name in [
12
+ '__array_namespace_info__',
13
+ '__array_api_version__',
14
+ ]:
9
15
  setattr(JaxComputeBackend, name, getattr(jax_impl, name))
10
16
 
11
17
  __all__ = [
xbarray/numpy.py CHANGED
@@ -2,10 +2,16 @@ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
2
  from ._src.implementations import numpy as numpy_impl
3
3
 
4
4
  class NumpyComputeBackend(metaclass=ComputeBackendImplCls):
5
- pass
5
+ ARRAY_TYPE = numpy_impl.ARRAY_TYPE
6
+ DTYPE_TYPE = numpy_impl.DTYPE_TYPE
7
+ DEVICE_TYPE = numpy_impl.DEVICE_TYPE
8
+ RNG_TYPE = numpy_impl.RNG_TYPE
6
9
 
7
10
  for name in dir(numpy_impl):
8
- if not name.startswith('_'):
11
+ if not name.startswith('_') or name in [
12
+ '__array_namespace_info__',
13
+ '__array_api_version__',
14
+ ]:
9
15
  setattr(NumpyComputeBackend, name, getattr(numpy_impl, name))
10
16
 
11
17
  __all__ = [
xbarray/pytorch.py CHANGED
@@ -2,10 +2,16 @@ from xbarray.cls_impl.cls_base import ComputeBackendImplCls
2
2
  from ._src.implementations import pytorch as pytorch_impl
3
3
 
4
4
  class PytorchComputeBackend(metaclass=ComputeBackendImplCls):
5
- pass
5
+ ARRAY_TYPE = pytorch_impl.ARRAY_TYPE
6
+ DTYPE_TYPE = pytorch_impl.DTYPE_TYPE
7
+ DEVICE_TYPE = pytorch_impl.DEVICE_TYPE
8
+ RNG_TYPE = pytorch_impl.RNG_TYPE
6
9
 
7
10
  for name in dir(pytorch_impl):
8
- if not name.startswith('_'):
11
+ if not name.startswith('_') or name in [
12
+ '__array_namespace_info__',
13
+ '__array_api_version__',
14
+ ]:
9
15
  setattr(PytorchComputeBackend, name, getattr(pytorch_impl, name))
10
16
 
11
17
  __all__ = [
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xbarray
3
- Version: 0.0.1a5
3
+ Version: 0.0.1a7
4
4
  Summary: Cross-backend Python array library based on the Array API Standard.
5
5
  Requires-Python: >=3.10
6
6
  License-File: LICENSE
@@ -3,7 +3,7 @@ array_api_typing/typing_2024_12/__init__.py,sha256=JuD2yojHl3eQI74U75tFhSX5rSy32
3
3
  array_api_typing/typing_2024_12/_api_constant.py,sha256=tahdzdi7vAZ3UnM8oDcUqvx2FFsMTAyUbuwgarU6K9U,676
4
4
  array_api_typing/typing_2024_12/_api_fft_typing.py,sha256=i48S-9trw3s72xiCJmeCNGlJLwwOVq0kJC4lD9se_co,38641
5
5
  array_api_typing/typing_2024_12/_api_linalg_typing.py,sha256=Dt5fDTwYb-GAmP7duajwPqH3zaJmwvxDEOOzdl5X0LE,48950
6
- array_api_typing/typing_2024_12/_api_return_typing.py,sha256=7TxSBRv0H0uQ9VnixbyuEXJQnl1uESo102RsAloI9SA,2302
6
+ array_api_typing/typing_2024_12/_api_return_typing.py,sha256=rIygF4PAr7G1PK3cVOwQNrev_kGCuMcfIJ9i8L0u1tI,2462
7
7
  array_api_typing/typing_2024_12/_api_typing.py,sha256=rFuF7E6b2RrijzQBrZx6_dYEn9Yslq8FqI31L1Ccrf4,294255
8
8
  array_api_typing/typing_2024_12/_array_typing.py,sha256=GMqs5LgshujRIh7i0mGtZy4Vt56FSXqBmNlBGidFRnc,57403
9
9
  array_api_typing/typing_compat/__init__.py,sha256=JuD2yojHl3eQI74U75tFhSX5rSy32TR5nKyi714ZfX4,347
@@ -13,19 +13,19 @@ array_api_typing/typing_extra/__init__.py,sha256=YfdhD-Sfk3SCfI9lHmA-PbVLzms1OFF
13
13
  array_api_typing/typing_extra/_api_typing.py,sha256=Jj_E61r35EgecWmBvAzpASV4qub5aQI_O4aL-ngEvQ8,23028
14
14
  array_api_typing/typing_extra/_at.py,sha256=S7_YjOwR3a8olZWgwpLDFEfnekRufRqrtfiMLwrWjgo,2202
15
15
  xbarray/__init__.py,sha256=k4Ipp7IoODqHWZ-eeBhcU7Ch8FudF8rag6KbeIurrgY,45
16
- xbarray/jax.py,sha256=3DHAFsWvfurCI0H6loqx4LtmBov9KbIxeHGmwLKrdh4,344
17
- xbarray/numpy.py,sha256=8thyuJZlm0MJXGx6YWlgxV9ASG-In9jn-pwUFFk7iB4,358
18
- xbarray/pytorch.py,sha256=NEKlF_QoiOQl0T-aP_CkUxJ3HgFJEm0hmLy9v5I7qAo,372
16
+ xbarray/jax.py,sha256=nEdufdTHGXsQSYBbCUwgLikR7saAoZ0zL1a4vrs4I2w,569
17
+ xbarray/numpy.py,sha256=1EaT-RVrJEsrgXE-JmbrKfH6nTVLi9sjQ3D81rWhd4Y,591
18
+ xbarray/pytorch.py,sha256=4j1UOD6wC29zIAFOR6s0tV2vY2bqKCiiJfua_k_jHKs,613
19
19
  xbarray/_src/implementations/_common/implementations.py,sha256=ba6fyoHMrawrrjTULSqt073-aMw7abqFWh5YnzeOHMU,2688
20
- xbarray/_src/implementations/jax/__init__.py,sha256=fZMtXujTLhfHofWJzeMtsOuNhUYSZzCEj06RulT4RqQ,827
20
+ xbarray/_src/implementations/jax/__init__.py,sha256=vOtCkwe98vNI1qQOu0Xv9OIaVucaJp6bEI9_mzicBH0,1014
21
21
  xbarray/_src/implementations/jax/_extra.py,sha256=QmPL7Syp-6y2NUbMYsTlI8uER0Gh5xjBbrNeyTIGZsc,2515
22
22
  xbarray/_src/implementations/jax/_typing.py,sha256=U9BUxHNEjFB0LHF1KMrFLbh6E5kVvpAF8bZUbLNf25E,278
23
23
  xbarray/_src/implementations/jax/random.py,sha256=k6vuNLTOSUClMKI6My0-Xekkeu6hJrUyNib1TziuGag,3312
24
- xbarray/_src/implementations/numpy/__init__.py,sha256=UcUDdf7xDp_oLw5O73G_n15qYzPd5XQsq4cnZoaX9y0,687
24
+ xbarray/_src/implementations/numpy/__init__.py,sha256=hJEcFR86pb1G72MFS1X15qwI_-BuPxy5SS5PXB0alQI,792
25
25
  xbarray/_src/implementations/numpy/_extra.py,sha256=JJs0_HNsnEkBPlXsomRjX5Jh254iXuw_o_3ZXK5c83s,1993
26
26
  xbarray/_src/implementations/numpy/_typing.py,sha256=pgjLLAipwFsIk0QdgrAA4PEvjF_ugHbzfSTbLpA__6o,241
27
27
  xbarray/_src/implementations/numpy/random.py,sha256=C1z2-pMcDMRa7nKhKNTuY5LtTlu9nXvDpH1wY8mdIug,2630
28
- xbarray/_src/implementations/pytorch/__init__.py,sha256=ZCGmZpbv4NaCWU_08G-ipEpI1xMr7QLXzEhwrhD7fI8,690
28
+ xbarray/_src/implementations/pytorch/__init__.py,sha256=s2mF3jWI356oedPZevKYARCFhlXH0hqS9OgBzxMSgHQ,795
29
29
  xbarray/_src/implementations/pytorch/_extra.py,sha256=HwytMuHZ0iyYKmYCj6QaTUKukMRWb47pme6xJ07LLjY,2878
30
30
  xbarray/_src/implementations/pytorch/_typing.py,sha256=qSnNZD3IpgSoOTl3TBRBv2ifSAHOw0aR9uyzfV5KYVw,204
31
31
  xbarray/_src/implementations/pytorch/random.py,sha256=p4ZDsTc--z5XslVzel-iEVDtoKizrv39_1tfR4-469o,2657
@@ -34,8 +34,8 @@ xbarray/_src/serialization/serialization_map.py,sha256=05zfqip8qPCt-LgM_cx_zLc4y
34
34
  xbarray/base/__init__.py,sha256=ERmmOxz_9mUkIuccNbzUa5Y6gVLLVDdyc4cCxbCCUbY,20
35
35
  xbarray/base/base.py,sha256=V1yAXu2wUu6q50okX7JQ8onkGhHIQ0lkPkYOolwc45A,5928
36
36
  xbarray/cls_impl/cls_base.py,sha256=MUvJeMm4UVW4jXwfVP02GiCzqftrAND-JYvThge4PUw,312
37
- xbarray-0.0.1a5.dist-info/licenses/LICENSE,sha256=6P0HCOancSfch0dNycuDIe8_qwS0Id97Ih_8hjJ2PFI,1067
38
- xbarray-0.0.1a5.dist-info/METADATA,sha256=MDImlQt3Kh9anmVussaPDDswLzAwb85TeIn1lW0-v9A,436
39
- xbarray-0.0.1a5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
- xbarray-0.0.1a5.dist-info/top_level.txt,sha256=VriXuFyU48Du4HQMzROSArhwqB6EZYY0n0mipgUqB9A,25
41
- xbarray-0.0.1a5.dist-info/RECORD,,
37
+ xbarray-0.0.1a7.dist-info/licenses/LICENSE,sha256=6P0HCOancSfch0dNycuDIe8_qwS0Id97Ih_8hjJ2PFI,1067
38
+ xbarray-0.0.1a7.dist-info/METADATA,sha256=MIiS1RQ5y8CdBbW_h7UvEMwO_QyaRnX_5BhI8CzV8Jc,436
39
+ xbarray-0.0.1a7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
+ xbarray-0.0.1a7.dist-info/top_level.txt,sha256=VriXuFyU48Du4HQMzROSArhwqB6EZYY0n0mipgUqB9A,25
41
+ xbarray-0.0.1a7.dist-info/RECORD,,