cobra-array 0.1.4__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cobra_array-0.1.4/src/cobra_array.egg-info → cobra_array-0.2.0}/PKG-INFO +1 -1
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/__init__.py +1 -5
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/_core.py +2 -2
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/_utils.py +27 -12
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/__init__.py +1 -1
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_array.py +51 -23
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_array.pyi +2 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_base.py +1 -1
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_namespace.py +37 -16
- {cobra_array-0.1.4 → cobra_array-0.2.0/src/cobra_array.egg-info}/PKG-INFO +1 -1
- cobra_array-0.2.0/tests/test_compat.py +482 -0
- cobra_array-0.1.4/tests/test_compat.py +0 -261
- {cobra_array-0.1.4 → cobra_array-0.2.0}/LICENSE +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/README.md +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/pyproject.toml +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/setup.cfg +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/array_api.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_namespace.pyi +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/convert.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/convert.pyi +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/default.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/exceptions.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/types.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/SOURCES.txt +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/dependency_links.txt +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/requires.txt +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/top_level.txt +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_backend.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_compat_namespace.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_convert.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_default.py +0 -0
- {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_wrap.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cobra-array
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
|
|
5
5
|
Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
|
|
@@ -32,11 +32,8 @@ Examples
|
|
|
32
32
|
from cobra_array.convert import to_numpy, to_tensor, to_list
|
|
33
33
|
|
|
34
34
|
data = [[1, 2], [3, 4]]
|
|
35
|
-
|
|
36
35
|
arr_np = to_numpy(data, dtype=np.float32) # numpy.ndarray float32
|
|
37
|
-
|
|
38
36
|
arr_torch = to_tensor(data, device="cpu")
|
|
39
|
-
|
|
40
37
|
back_to_list = to_list(arr_np) # [[1.0, 2.0], [3.0, 4.0]]
|
|
41
38
|
|
|
42
39
|
- Context-based conversion::
|
|
@@ -66,7 +63,6 @@ Examples
|
|
|
66
63
|
from cobra_array.default import as_default, default_spec
|
|
67
64
|
|
|
68
65
|
spec = default_spec()
|
|
69
|
-
|
|
70
66
|
x = as_default([1, 2, 3], unify_dtype=True, unify_device=True)
|
|
71
67
|
"""
|
|
72
68
|
|
|
@@ -84,7 +80,7 @@ from ._utils import (
|
|
|
84
80
|
)
|
|
85
81
|
|
|
86
82
|
__author__ = "Zhen Tian"
|
|
87
|
-
__version__ = "0.
|
|
83
|
+
__version__ = "0.2.0"
|
|
88
84
|
|
|
89
85
|
__all__ = [
|
|
90
86
|
"array_spec",
|
|
@@ -500,9 +500,9 @@ def unify_args(
|
|
|
500
500
|
return func(*args, **kwargs)
|
|
501
501
|
|
|
502
502
|
with array_context.from_array_spec(spec):
|
|
503
|
-
out_args =
|
|
503
|
+
out_args = [
|
|
504
504
|
as_context(a, unify_dtype=unify_dtype, unify_device=unify_device, arraylike_only=arraylike_only) for a in args
|
|
505
|
-
|
|
505
|
+
]
|
|
506
506
|
out_kwargs = {
|
|
507
507
|
k: as_context(v, unify_dtype=unify_dtype, unify_device=unify_device, arraylike_only=arraylike_only)
|
|
508
508
|
for k, v in kwargs.items()
|
|
@@ -6,7 +6,7 @@ from __future__ import annotations
|
|
|
6
6
|
import array_api_compat as api
|
|
7
7
|
import warnings
|
|
8
8
|
from types import ModuleType
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import (Any, overload, Optional, Literal)
|
|
10
10
|
|
|
11
11
|
from .exceptions import UnsupportedNamespaceError
|
|
12
12
|
|
|
@@ -25,7 +25,13 @@ def warn(msg: str, /, category: Any, stack: int = 2):
|
|
|
25
25
|
return warnings.warn(msg, category=category, stacklevel=stack+1)
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
|
|
28
|
+
@overload
|
|
29
|
+
def array_namespace_alias(xp: object, /, *, raise_on_unsupported: Literal[True] = ...) -> str: ...
|
|
30
|
+
@overload
|
|
31
|
+
def array_namespace_alias(xp: object, /, *, raise_on_unsupported: Literal[False]) -> Optional[str]: ...
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def array_namespace_alias(xp: object, /, *, raise_on_unsupported: bool = True) -> Optional[str]:
|
|
29
35
|
"""
|
|
30
36
|
Get the alias of the `array namespace`.
|
|
31
37
|
|
|
@@ -34,18 +40,23 @@ def array_namespace_alias(xp: object) -> str:
|
|
|
34
40
|
xp : object
|
|
35
41
|
The `array namespace`.
|
|
36
42
|
|
|
43
|
+
raise_on_unsupported : bool, optional
|
|
44
|
+
Whether to raise an error for unsupported array namespaces.
|
|
45
|
+
|
|
37
46
|
Returns
|
|
38
47
|
-------
|
|
39
48
|
str
|
|
40
49
|
The alias of the `array namespace`.
|
|
41
50
|
- Including: `"NumPy"`, `"Cupy"`, `"PyTorch"`, `"NDONNX"`, `"Dask"`, `"JAX"`, `"sparse"` and `"array-api-strict"`.
|
|
51
|
+
`None`
|
|
52
|
+
If the input is not a supported `array namespace` and `raise_on_unsupported` is `False`.
|
|
42
53
|
|
|
43
54
|
Raises
|
|
44
55
|
------
|
|
45
56
|
UnsupportedNameSpaceError
|
|
46
|
-
If the input object is not a supported `array namespace`.
|
|
57
|
+
If the input object is not a supported `array namespace` and :param:`raise_on_unsupported` is `True`.
|
|
47
58
|
"""
|
|
48
|
-
if
|
|
59
|
+
if type(xp) is ModuleType:
|
|
49
60
|
if api.is_numpy_namespace(xp):
|
|
50
61
|
return "NumPy"
|
|
51
62
|
|
|
@@ -70,20 +81,17 @@ def array_namespace_alias(xp: object) -> str:
|
|
|
70
81
|
if api.is_array_api_strict_namespace(xp):
|
|
71
82
|
return "array-api-strict"
|
|
72
83
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
84
|
+
if raise_on_unsupported:
|
|
85
|
+
raise UnsupportedNamespaceError(
|
|
86
|
+
f"Got unsupported array namespace of type {type(xp)}."
|
|
87
|
+
)
|
|
76
88
|
|
|
77
89
|
|
|
78
90
|
def is_array_namespace(obj: object) -> bool:
|
|
79
91
|
"""
|
|
80
92
|
Returns `True` if input is a supported `array namespace`.
|
|
81
93
|
"""
|
|
82
|
-
|
|
83
|
-
array_namespace_alias(obj)
|
|
84
|
-
return True
|
|
85
|
-
except UnsupportedNamespaceError:
|
|
86
|
-
return False
|
|
94
|
+
return array_namespace_alias(obj, raise_on_unsupported=False) is not None
|
|
87
95
|
|
|
88
96
|
|
|
89
97
|
def is_compat_namespace(xp: object) -> bool:
|
|
@@ -91,3 +99,10 @@ def is_compat_namespace(xp: object) -> bool:
|
|
|
91
99
|
Returns `True` if input is a `compatibility namespace` wrapped by :class:`CompatNamespace`
|
|
92
100
|
"""
|
|
93
101
|
return "(compat)" in getattr(xp, "__name__", "")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# def is_array_api_object(obj: object) -> bool:
|
|
105
|
+
# """
|
|
106
|
+
# Returns `True` if input is an `array API object` that belongs to a supported `array namespace`.
|
|
107
|
+
# """
|
|
108
|
+
# return api.is_array_api_obj(obj)
|
|
@@ -108,16 +108,17 @@ class CompatArray(Compat):
|
|
|
108
108
|
_xp = to_xp(xp)
|
|
109
109
|
return cls(
|
|
110
110
|
as_array(unwrap(obj), _xp, copy=copy),
|
|
111
|
-
xp=_xp
|
|
111
|
+
xp=_xp, check=False
|
|
112
112
|
)
|
|
113
113
|
|
|
114
114
|
def __new__(cls, arr, /, *, copy=False, **kwargs):
|
|
115
|
-
if
|
|
115
|
+
if type(arr) is cls:
|
|
116
116
|
# for `CompatArray` input
|
|
117
117
|
return arr.copy() if copy else arr
|
|
118
118
|
|
|
119
|
+
_check = kwargs.get("check", True)
|
|
119
120
|
# for non-`CompatArray` input
|
|
120
|
-
if not api.is_array_api_obj(arr):
|
|
121
|
+
if _check and not api.is_array_api_obj(arr):
|
|
121
122
|
raise NotArrayAPIObjectError(
|
|
122
123
|
f"Parameter `arr` of `CompatArray` must be an array API compatible array object, got {type(arr)}."
|
|
123
124
|
)
|
|
@@ -135,6 +136,8 @@ class CompatArray(Compat):
|
|
|
135
136
|
Convert `self` to a `NumPy array`.
|
|
136
137
|
See also :func:`convert.to_numpy`.
|
|
137
138
|
"""
|
|
139
|
+
if self._xp_name == "NumPy" and not copy:
|
|
140
|
+
return self._arr
|
|
138
141
|
return to_numpy(self._arr, copy=copy)
|
|
139
142
|
|
|
140
143
|
def to_tensor(self, *, device=None, copy=False):
|
|
@@ -142,6 +145,8 @@ class CompatArray(Compat):
|
|
|
142
145
|
Convert `self` to a `PyTorch tensor`.
|
|
143
146
|
See also :func:`convert.to_tensor`.
|
|
144
147
|
"""
|
|
148
|
+
if self._xp_name == "PyTorch" and not copy and (device is None or device == self.device):
|
|
149
|
+
return self._arr
|
|
145
150
|
return to_tensor(self._arr, device=device, copy=copy)
|
|
146
151
|
|
|
147
152
|
def to_list(self, *, copy=False):
|
|
@@ -197,7 +202,7 @@ class CompatArray(Compat):
|
|
|
197
202
|
All the arrays have the same shape.
|
|
198
203
|
"""
|
|
199
204
|
result = self._get_xp_attr("unstack")(self._arr, axis=axis)
|
|
200
|
-
return tuple(CompatArray(arr, xp=self.
|
|
205
|
+
return tuple([CompatArray(arr, xp=self._cxp, check=False) for arr in result])
|
|
201
206
|
|
|
202
207
|
# === Searching functions ===
|
|
203
208
|
def nonzero(self):
|
|
@@ -218,7 +223,7 @@ class CompatArray(Compat):
|
|
|
218
223
|
- If `self` has a boolean data type, non-zero elements are those elements which are equal to `True`.
|
|
219
224
|
"""
|
|
220
225
|
result = self._get_xp_attr("nonzero")(self._arr)
|
|
221
|
-
return tuple(CompatArray(arr, xp=self.
|
|
226
|
+
return tuple([CompatArray(arr, xp=self._cxp) for arr in result])
|
|
222
227
|
|
|
223
228
|
# === Set functions ===
|
|
224
229
|
def unique_all(self):
|
|
@@ -238,10 +243,10 @@ class CompatArray(Compat):
|
|
|
238
243
|
"""
|
|
239
244
|
result = self._get_xp_attr("unique_all")(self._arr)
|
|
240
245
|
return UniqueResult(
|
|
241
|
-
values=CompatArray(result.values, xp=self.
|
|
242
|
-
indices=CompatArray(result.indices, xp=self.
|
|
243
|
-
inverse_indices=CompatArray(result.inverse_indices, xp=self.
|
|
244
|
-
counts=CompatArray(result.counts, xp=self.
|
|
246
|
+
values=CompatArray(result.values, xp=self._cxp, check=False),
|
|
247
|
+
indices=CompatArray(result.indices, xp=self._cxp, check=False),
|
|
248
|
+
inverse_indices=CompatArray(result.inverse_indices, xp=self._cxp, check=False),
|
|
249
|
+
counts=CompatArray(result.counts, xp=self._cxp, check=False),
|
|
245
250
|
)
|
|
246
251
|
|
|
247
252
|
def unique_counts(self):
|
|
@@ -259,10 +264,10 @@ class CompatArray(Compat):
|
|
|
259
264
|
"""
|
|
260
265
|
result = self._get_xp_attr("unique_counts")(self._arr)
|
|
261
266
|
return UniqueResult(
|
|
262
|
-
values=CompatArray(result.values, xp=self.
|
|
267
|
+
values=CompatArray(result.values, xp=self._cxp, check=False),
|
|
263
268
|
indices=None,
|
|
264
269
|
inverse_indices=None,
|
|
265
|
-
counts=CompatArray(result.counts, xp=self.
|
|
270
|
+
counts=CompatArray(result.counts, xp=self._cxp, check=False),
|
|
266
271
|
)
|
|
267
272
|
|
|
268
273
|
def unique_inverse(self):
|
|
@@ -280,9 +285,9 @@ class CompatArray(Compat):
|
|
|
280
285
|
"""
|
|
281
286
|
result = self._get_xp_attr("unique_inverse")(self._arr)
|
|
282
287
|
return UniqueResult(
|
|
283
|
-
values=CompatArray(result.values, xp=self.
|
|
288
|
+
values=CompatArray(result.values, xp=self._cxp, check=False),
|
|
284
289
|
indices=None,
|
|
285
|
-
inverse_indices=CompatArray(result.inverse_indices, xp=self.
|
|
290
|
+
inverse_indices=CompatArray(result.inverse_indices, xp=self._cxp, check=False),
|
|
286
291
|
counts=None,
|
|
287
292
|
)
|
|
288
293
|
|
|
@@ -291,7 +296,7 @@ class CompatArray(Compat):
|
|
|
291
296
|
"""
|
|
292
297
|
Return a copy of `self` via :func:`convert.as_array`.
|
|
293
298
|
"""
|
|
294
|
-
return CompatArray.from_other(self._arr, xp=self.
|
|
299
|
+
return CompatArray.from_other(self._arr, xp=self._cxp, copy=True)
|
|
295
300
|
|
|
296
301
|
def _get_attr(self, name: str):
|
|
297
302
|
"""Try to get the attribute `name` from `self`."""
|
|
@@ -379,7 +384,7 @@ class CompatArray(Compat):
|
|
|
379
384
|
result = self._get_xp_attr("T")(self._arr)
|
|
380
385
|
except (AttributeError, TypeError):
|
|
381
386
|
result = self._get_attr("T")
|
|
382
|
-
return CompatArray(result, xp=self.
|
|
387
|
+
return CompatArray(result, xp=self._cxp, check=False)
|
|
383
388
|
|
|
384
389
|
@property
|
|
385
390
|
def mT(self):
|
|
@@ -391,19 +396,19 @@ class CompatArray(Compat):
|
|
|
391
396
|
result = self._get_xp_attr("mT")(self._arr)
|
|
392
397
|
except (AttributeError, TypeError):
|
|
393
398
|
result = self._get_attr("mT")
|
|
394
|
-
return CompatArray(result, xp=self.
|
|
399
|
+
return CompatArray(result, xp=self._cxp, check=False)
|
|
395
400
|
|
|
396
401
|
def __array__(self):
|
|
397
402
|
"""Allow implicit NumPy conversion."""
|
|
398
403
|
return self.to_numpy()
|
|
399
404
|
|
|
400
405
|
def __getattr__(self, name: str):
|
|
401
|
-
attr = self.
|
|
406
|
+
attr = self._get_xp_attr(name)
|
|
402
407
|
|
|
403
408
|
if callable(attr) and not isinstance(attr, type):
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
return
|
|
409
|
+
wrapped = _make_wrapper(self._xp_name, attr, self._cxp, self._arr)
|
|
410
|
+
self.__dict__[name] = wrapped
|
|
411
|
+
return wrapped
|
|
407
412
|
raise CompatArrayAttributeError(f"`CompatArray` `{self._xp_name}` does not support attribute `{name}`.")
|
|
408
413
|
|
|
409
414
|
def __len__(self):
|
|
@@ -540,7 +545,7 @@ def unwrap(obj):
|
|
|
540
545
|
"""
|
|
541
546
|
Unwraps a :class:`CompatArray` array to get the backend-specific array instance, or returns the object itself if it is not a :class:`CompatArray` array.
|
|
542
547
|
"""
|
|
543
|
-
return obj.arr if
|
|
548
|
+
return obj.arr if type(obj) is CompatArray else obj
|
|
544
549
|
|
|
545
550
|
|
|
546
551
|
def wrap_arraylike(arr, xp=None):
|
|
@@ -549,8 +554,8 @@ def wrap_arraylike(arr, xp=None):
|
|
|
549
554
|
"""
|
|
550
555
|
if api.is_array_api_obj(arr):
|
|
551
556
|
if xp is None:
|
|
552
|
-
return CompatArray(arr)
|
|
553
|
-
return CompatArray(arr, xp=xp)
|
|
557
|
+
return CompatArray(arr, check=False)
|
|
558
|
+
return CompatArray(arr, xp=xp, check=False)
|
|
554
559
|
return arr
|
|
555
560
|
|
|
556
561
|
|
|
@@ -559,3 +564,26 @@ def to_cxp(xp):
|
|
|
559
564
|
from ._namespace import CompatNamespace
|
|
560
565
|
|
|
561
566
|
return CompatNamespace(xp)
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def _make_wrapper(xp_name, attr, cxp, first, /):
|
|
570
|
+
"""Make a wrapper function for the attribute `name` of the `array namespace`."""
|
|
571
|
+
wrap_ = lambda x: wrap_arraylike(x, xp=cxp)
|
|
572
|
+
unwrap_ = unwrap
|
|
573
|
+
|
|
574
|
+
def wrapper(*args, **kwargs):
|
|
575
|
+
if xp_name == "NumPy":
|
|
576
|
+
return wrap_(attr(first, *args, **kwargs))
|
|
577
|
+
|
|
578
|
+
n = len(args)
|
|
579
|
+
if n == 0 and not kwargs:
|
|
580
|
+
return wrap_(attr(first))
|
|
581
|
+
if n == 1 and not kwargs:
|
|
582
|
+
return wrap_(attr(first, unwrap_(args[0])))
|
|
583
|
+
new_args = [unwrap_(a) for a in args]
|
|
584
|
+
if kwargs:
|
|
585
|
+
new_kwargs = {k: unwrap_(v) for k, v in kwargs.items()}
|
|
586
|
+
return wrap_(attr(first, *new_args, **new_kwargs))
|
|
587
|
+
return wrap_(attr(first, *new_args))
|
|
588
|
+
wrapper.__name__ = getattr(attr, "__name__", "wrapper")
|
|
589
|
+
return wrapper
|
|
@@ -50,6 +50,8 @@ class CompatArray(Compat, Generic[TT, DT]):
|
|
|
50
50
|
def __new__(cls, arr: CompatArray[dtypeT, deviceT], /, **kwargs) -> CompatArray[dtypeT, deviceT]: ...
|
|
51
51
|
@overload
|
|
52
52
|
def __new__(cls, arr: ArrayLike[dtypeT], /, **kwargs) -> CompatArray[dtypeT, AnyDevice]: ...
|
|
53
|
+
@overload
|
|
54
|
+
def __new__(cls, arr: object, /, **kwargs) -> CompatArray[Any, AnyDevice]: ...
|
|
53
55
|
def __new__(cls, arr: ArrayLike[Any], /, **kwargs) -> CompatArray[Any, AnyDevice]: ...
|
|
54
56
|
|
|
55
57
|
# === Conversion functions ===
|
|
@@ -50,7 +50,7 @@ class CompatNamespace(Compat):
|
|
|
50
50
|
AttributeError: ...
|
|
51
51
|
"""
|
|
52
52
|
def __new__(cls, xp, /):
|
|
53
|
-
if
|
|
53
|
+
if type(xp) is cls:
|
|
54
54
|
# for `CompatNamespace` input
|
|
55
55
|
return xp
|
|
56
56
|
# for `Namespace` input
|
|
@@ -91,10 +91,10 @@ class CompatNamespace(Compat):
|
|
|
91
91
|
Each returned array should have the same data type as the input arrays.
|
|
92
92
|
"""
|
|
93
93
|
result = self._get_xp_attr("meshgrid")(
|
|
94
|
-
*
|
|
94
|
+
*[unwrap(arr) for arr in arrays],
|
|
95
95
|
indexing=indexing
|
|
96
96
|
)
|
|
97
|
-
return [CompatArray(arr, xp=self
|
|
97
|
+
return [CompatArray(arr, xp=self) for arr in result]
|
|
98
98
|
|
|
99
99
|
# === Data Type functions ===
|
|
100
100
|
def can_cast(self, from_, to, /):
|
|
@@ -170,7 +170,7 @@ class CompatNamespace(Compat):
|
|
|
170
170
|
"""
|
|
171
171
|
return self._get_xp_attr("iinfo")(unwrap(type_))
|
|
172
172
|
|
|
173
|
-
def isdtype(self, dtype, kind)
|
|
173
|
+
def isdtype(self, dtype, kind):
|
|
174
174
|
"""
|
|
175
175
|
Returns a boolean indicating whether a provided :param:`dtype` is of a specified data type :param:`kind`.
|
|
176
176
|
|
|
@@ -214,7 +214,7 @@ class CompatNamespace(Compat):
|
|
|
214
214
|
DType
|
|
215
215
|
The dtype resulting from an operation involving the input arrays, scalars, and/or dtypes.
|
|
216
216
|
"""
|
|
217
|
-
return self._get_xp_attr("result_type")(*
|
|
217
|
+
return self._get_xp_attr("result_type")(*[unwrap(arr) for arr in arrays_and_dtypes])
|
|
218
218
|
|
|
219
219
|
# === Manipulation functions ===
|
|
220
220
|
def broadcast_arrays(self, *arrays):
|
|
@@ -233,8 +233,8 @@ class CompatNamespace(Compat):
|
|
|
233
233
|
Each array must have the same shape.
|
|
234
234
|
Each array must have the same dtype as its corresponding input array.
|
|
235
235
|
"""
|
|
236
|
-
result = self._get_xp_attr("broadcast_arrays")(*
|
|
237
|
-
return [CompatArray(arr, xp=self
|
|
236
|
+
result = self._get_xp_attr("broadcast_arrays")(*[unwrap(arr) for arr in arrays])
|
|
237
|
+
return [CompatArray(arr, xp=self) for arr in result]
|
|
238
238
|
|
|
239
239
|
# === Constants ===
|
|
240
240
|
@property
|
|
@@ -315,16 +315,37 @@ class CompatNamespace(Compat):
|
|
|
315
315
|
def __name__(self):
|
|
316
316
|
return "(compat)" + getattr(self._xp, "__name__", type(self._xp).__name__)
|
|
317
317
|
|
|
318
|
-
def __getattr__(self, name
|
|
318
|
+
def __getattr__(self, name):
|
|
319
319
|
attr = self._get_xp_attr(name)
|
|
320
320
|
|
|
321
321
|
if callable(attr):
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
new_args = tuple(unwrap(a) for a in args)
|
|
327
|
-
new_kwargs = {k: unwrap(v) for k, v in kwargs.items()} if kwargs else kwargs
|
|
328
|
-
return wrap_arraylike(attr(*new_args, **new_kwargs), xp=self._xp)
|
|
329
|
-
return wrapper
|
|
322
|
+
wrapped = _make_wrapper(self._xp_name, attr, self)
|
|
323
|
+
self.__dict__[name] = wrapped
|
|
324
|
+
return wrapped
|
|
330
325
|
raise CompatNamespaceAttributeError(f"`CompatNamespace` `{self._xp_name}` does not support attribute `{name}`.")
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def _make_wrapper(xp_name, attr, cxp, /):
|
|
329
|
+
"""Make a wrapper function for the attribute `name` of the `array namespace`."""
|
|
330
|
+
wrap_ = lambda x: wrap_arraylike(x, xp=cxp)
|
|
331
|
+
unwrap_ = unwrap
|
|
332
|
+
|
|
333
|
+
def wrapper(*args, **kwargs):
|
|
334
|
+
if xp_name == "NumPy":
|
|
335
|
+
return wrap_(attr(*args, **kwargs))
|
|
336
|
+
|
|
337
|
+
n = len(args)
|
|
338
|
+
if n == 0 and not kwargs:
|
|
339
|
+
return wrap_(attr())
|
|
340
|
+
if n == 1 and not kwargs:
|
|
341
|
+
return wrap_(attr(unwrap_(args[0])))
|
|
342
|
+
if n == 2 and not kwargs:
|
|
343
|
+
a0, a1 = args
|
|
344
|
+
return wrap_(attr(unwrap_(a0), unwrap_(a1)))
|
|
345
|
+
new_args = [unwrap_(a) for a in args]
|
|
346
|
+
if kwargs:
|
|
347
|
+
new_kwargs = {k: unwrap_(v) for k, v in kwargs.items()}
|
|
348
|
+
return wrap_(attr(*new_args, **new_kwargs))
|
|
349
|
+
return wrap_(attr(*new_args))
|
|
350
|
+
wrapper.__name__ = getattr(attr, "__name__", "wrapper")
|
|
351
|
+
return wrapper
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cobra-array
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
|
|
5
5
|
Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
|
|
6
6
|
Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
|
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pytest
|
|
6
|
+
import time
|
|
7
|
+
from types import ModuleType
|
|
8
|
+
|
|
9
|
+
sys.path.insert(0, "/data/tianzhen/my_packages/cobra-array/src")
|
|
10
|
+
|
|
11
|
+
from cobra_array.array_api import torch_xp
|
|
12
|
+
from cobra_array.compat import CompatArray, unwrap, wrap_arraylike
|
|
13
|
+
from cobra_array.exceptions import CompatArrayAttributeError, NotArrayAPIObjectError
|
|
14
|
+
from cobra_array.compat._array import UniqueResult
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _arr_1d():
|
|
18
|
+
return CompatArray(np.array([1.0, 2.0, 3.0], dtype=np.float32))
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _arr_2d():
|
|
22
|
+
return CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_from_other_with_numpy_namespace():
|
|
26
|
+
out = CompatArray.from_other([1, 2, 3], xp="numpy")
|
|
27
|
+
assert isinstance(out, CompatArray)
|
|
28
|
+
assert out.xp_name == "NumPy"
|
|
29
|
+
assert out.to_list() == [1, 2, 3]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def test_new_from_compatarray_copy_and_no_copy():
|
|
33
|
+
a = _arr_1d()
|
|
34
|
+
b = CompatArray(a)
|
|
35
|
+
c = CompatArray(a, copy=True)
|
|
36
|
+
|
|
37
|
+
assert b is a
|
|
38
|
+
assert c is not a
|
|
39
|
+
assert c.to_list() == a.to_list()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_new_invalid_input_raises():
|
|
43
|
+
with pytest.raises(NotArrayAPIObjectError):
|
|
44
|
+
CompatArray("not-array") # type: ignore[arg-type]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_to_numpy_to_list_and_array_protocol():
|
|
48
|
+
a = _arr_1d()
|
|
49
|
+
|
|
50
|
+
n1 = a.to_numpy(copy=False)
|
|
51
|
+
n2 = np.asarray(a)
|
|
52
|
+
|
|
53
|
+
assert isinstance(n1, np.ndarray)
|
|
54
|
+
assert isinstance(n2, np.ndarray)
|
|
55
|
+
assert n1.tolist() == [1.0, 2.0, 3.0]
|
|
56
|
+
assert n2.tolist() == [1.0, 2.0, 3.0]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_to_device_on_numpy_backend():
|
|
60
|
+
a = _arr_1d()
|
|
61
|
+
moved = a.to_device("cpu")
|
|
62
|
+
|
|
63
|
+
assert isinstance(moved, np.ndarray)
|
|
64
|
+
assert moved.tolist() == a.to_list()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@pytest.mark.skipif(torch_xp is None, reason="PyTorch not available")
|
|
68
|
+
def test_to_tensor_available_backend():
|
|
69
|
+
a = _arr_1d()
|
|
70
|
+
t = a.to_tensor(device="cpu")
|
|
71
|
+
|
|
72
|
+
assert isinstance(t, torch_xp.Tensor)
|
|
73
|
+
assert tuple(t.shape) == (3,)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def test_unstack_and_nonzero():
|
|
77
|
+
a = _arr_2d()
|
|
78
|
+
|
|
79
|
+
pieces = a.unstack(axis=0)
|
|
80
|
+
nz = a.nonzero()
|
|
81
|
+
|
|
82
|
+
assert isinstance(pieces, tuple)
|
|
83
|
+
assert len(pieces) == 2
|
|
84
|
+
assert all(isinstance(x, CompatArray) for x in pieces)
|
|
85
|
+
assert pieces[0].to_list() == [1.0, 2.0]
|
|
86
|
+
|
|
87
|
+
assert isinstance(nz, tuple)
|
|
88
|
+
assert len(nz) == 2
|
|
89
|
+
assert all(isinstance(x, CompatArray) for x in nz)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_unique_all_unique_counts_unique_inverse():
|
|
93
|
+
a = CompatArray(np.array([1, 2, 2, 1, 3], dtype=np.int32))
|
|
94
|
+
|
|
95
|
+
print(type(a) is CompatArray)
|
|
96
|
+
|
|
97
|
+
start = time.process_time()
|
|
98
|
+
all_result = a.cxp.unique_all(a)
|
|
99
|
+
# counts_result = a.unique_counts()
|
|
100
|
+
# inverse_result = a.unique_inverse()
|
|
101
|
+
|
|
102
|
+
# print(all_result.values)
|
|
103
|
+
|
|
104
|
+
# assert isinstance(all_result.values, np.ndarray)
|
|
105
|
+
# assert isinstance(all_result.indices, np.ndarray)
|
|
106
|
+
# assert isinstance(all_result.inverse_indices, np.ndarray)
|
|
107
|
+
# assert isinstance(all_result.counts, np.ndarray)
|
|
108
|
+
|
|
109
|
+
print(f"(cxp1)Unique operations test took {time.process_time() - start:.5f} seconds")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
start = time.process_time()
|
|
113
|
+
all_result = a.unique_all()
|
|
114
|
+
# counts_result = a.unique_counts()
|
|
115
|
+
# inverse_result = a.unique_inverse()
|
|
116
|
+
|
|
117
|
+
# assert isinstance(all_result.values, CompatArray)
|
|
118
|
+
# assert isinstance(all_result.indices, CompatArray)
|
|
119
|
+
# assert isinstance(all_result.inverse_indices, CompatArray)
|
|
120
|
+
# assert isinstance(all_result.counts, CompatArray)
|
|
121
|
+
|
|
122
|
+
print(f"Unique operations test took {time.process_time() - start:.5f} seconds")
|
|
123
|
+
|
|
124
|
+
start = time.process_time()
|
|
125
|
+
all_result = a.cxp.unique_all(a)
|
|
126
|
+
# counts_result = a.unique_counts()
|
|
127
|
+
# inverse_result = a.unique_inverse()
|
|
128
|
+
|
|
129
|
+
# print(all_result.values)
|
|
130
|
+
|
|
131
|
+
# assert isinstance(all_result.values, np.ndarray)
|
|
132
|
+
# assert isinstance(all_result.indices, np.ndarray)
|
|
133
|
+
# assert isinstance(all_result.inverse_indices, np.ndarray)
|
|
134
|
+
# assert isinstance(all_result.counts, np.ndarray)
|
|
135
|
+
|
|
136
|
+
print(f"(cxp)Unique operations test took {time.process_time() - start:.5f} seconds")
|
|
137
|
+
|
|
138
|
+
# assert isinstance(counts_result.values, CompatArray)
|
|
139
|
+
# assert isinstance(counts_result.counts, CompatArray)
|
|
140
|
+
|
|
141
|
+
# assert isinstance(inverse_result.values, CompatArray)
|
|
142
|
+
# assert isinstance(inverse_result.inverse_indices, CompatArray)
|
|
143
|
+
|
|
144
|
+
a = np.array([1, 2, 2, 1, 3], dtype=np.int32)
|
|
145
|
+
|
|
146
|
+
start = time.process_time()
|
|
147
|
+
all_result = np.unique(a, return_index=True, return_inverse=True, return_counts=True)
|
|
148
|
+
|
|
149
|
+
UniqueResult(
|
|
150
|
+
CompatArray(all_result[0], xp=np),
|
|
151
|
+
CompatArray(all_result[1], xp=np),
|
|
152
|
+
CompatArray(all_result[2], xp=np),
|
|
153
|
+
CompatArray(all_result[3], xp=np)
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
print(f"(RAW)Unique operations test took {time.process_time() - start:.5f} seconds")
|
|
157
|
+
|
|
158
|
+
# start = time.process_time()
|
|
159
|
+
# all_result = torch_xp.unique(torch.tensor(a), return_inverse=True, return_counts=True)
|
|
160
|
+
|
|
161
|
+
# print(f"API operations test took {time.process_time() - start:.4f} seconds")
|
|
162
|
+
|
|
163
|
+
from array_api_compat import get_namespace
|
|
164
|
+
xp = get_namespace(a)
|
|
165
|
+
|
|
166
|
+
start = time.process_time()
|
|
167
|
+
all_result = xp.unique_all(a)
|
|
168
|
+
|
|
169
|
+
UniqueResult(
|
|
170
|
+
CompatArray(all_result.values, xp=xp),
|
|
171
|
+
CompatArray(all_result.indices, xp=xp),
|
|
172
|
+
CompatArray(all_result.inverse_indices, xp=xp),
|
|
173
|
+
CompatArray(all_result.counts, xp=xp)
|
|
174
|
+
)
|
|
175
|
+
assert isinstance(all_result.values, xp.ndarray)
|
|
176
|
+
assert isinstance(all_result.indices, xp.ndarray)
|
|
177
|
+
assert isinstance(all_result.inverse_indices, xp.ndarray)
|
|
178
|
+
assert isinstance(all_result.counts, xp.ndarray)
|
|
179
|
+
|
|
180
|
+
print(f"API operations test took {time.process_time() - start:.5f} seconds")
|
|
181
|
+
|
|
182
|
+
print(type(xp) is ModuleType)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def test_add():
|
|
186
|
+
|
|
187
|
+
a = CompatArray(np.array([1, 2, 2, 1, 3], dtype=np.int32))
|
|
188
|
+
print(type(a) is CompatArray)
|
|
189
|
+
|
|
190
|
+
start = time.time()
|
|
191
|
+
all_result = a.add(1)
|
|
192
|
+
all_result = a.add(1)
|
|
193
|
+
# counts_result = a.unique_counts()
|
|
194
|
+
# inverse_result = a.unique_inverse()
|
|
195
|
+
|
|
196
|
+
print(f"add operations test took {time.time() - start:.5f} seconds")
|
|
197
|
+
|
|
198
|
+
start = time.time()
|
|
199
|
+
all_result = a.cxp.add(a, 1)
|
|
200
|
+
all_result = a.cxp.add(a, 1)
|
|
201
|
+
|
|
202
|
+
print(f"(cxp)add operations test took {time.time() - start:.5f} seconds")
|
|
203
|
+
|
|
204
|
+
a = np.array([1, 2, 2, 1, 3], dtype=np.int32)
|
|
205
|
+
|
|
206
|
+
start = time.time()
|
|
207
|
+
all_result = np.add(a, 1)
|
|
208
|
+
|
|
209
|
+
print(f"(RAW)add operations test took {time.time() - start:.5f} seconds")
|
|
210
|
+
|
|
211
|
+
from array_api_compat import get_namespace
|
|
212
|
+
xp = get_namespace(a)
|
|
213
|
+
|
|
214
|
+
start = time.time()
|
|
215
|
+
all_result = xp.add(a, 1)
|
|
216
|
+
|
|
217
|
+
print(f"API operations test took {time.time() - start:.5f} seconds")
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def test_copy_and_basic_properties():
|
|
221
|
+
a = _arr_2d()
|
|
222
|
+
b = a.copy()
|
|
223
|
+
|
|
224
|
+
assert isinstance(a.arr, np.ndarray)
|
|
225
|
+
assert isinstance(b, CompatArray)
|
|
226
|
+
assert b is not a
|
|
227
|
+
|
|
228
|
+
assert str(a.device) == "cpu"
|
|
229
|
+
assert a.shape == (2, 2)
|
|
230
|
+
assert a.ndim == 2
|
|
231
|
+
assert a.size == 4
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def test_transpose_properties():
|
|
235
|
+
a = _arr_2d()
|
|
236
|
+
|
|
237
|
+
t = a.T
|
|
238
|
+
|
|
239
|
+
assert isinstance(t, CompatArray)
|
|
240
|
+
assert t.shape == (2, 2)
|
|
241
|
+
|
|
242
|
+
with pytest.raises(AttributeError):
|
|
243
|
+
_ = a.mT
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def test_len_and_repr():
|
|
247
|
+
a = _arr_1d()
|
|
248
|
+
z = CompatArray(np.array(5))
|
|
249
|
+
|
|
250
|
+
assert len(a) == 3
|
|
251
|
+
assert "NumPy_Array" in repr(a)
|
|
252
|
+
|
|
253
|
+
with pytest.raises(TypeError):
|
|
254
|
+
len(z)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def test_getitem_and_setitem():
|
|
258
|
+
a = CompatArray(np.array([10, 20, 30], dtype=np.int32))
|
|
259
|
+
|
|
260
|
+
assert int(a[1]) == 20
|
|
261
|
+
|
|
262
|
+
a[1] = 99
|
|
263
|
+
assert a.to_list() == [10, 99, 30]
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def test_scalar_conversions():
|
|
267
|
+
b = CompatArray(np.array(True))
|
|
268
|
+
i = CompatArray(np.array(7))
|
|
269
|
+
f = CompatArray(np.array(1.5, dtype=np.float32))
|
|
270
|
+
c = CompatArray(np.array(2 + 3j, dtype=np.complex64))
|
|
271
|
+
|
|
272
|
+
assert bool(b) is True
|
|
273
|
+
assert int(i) == 7
|
|
274
|
+
assert i.__index__() == 7
|
|
275
|
+
assert float(f) == pytest.approx(1.5)
|
|
276
|
+
assert complex(c) == complex(2, 3)
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def test_numpy_operator_overloads():
    """Check Python operator overloads on raw NumPy arrays, then on CompatArray wrappers.

    Both passes share one table of (operation, expected list) cases and the same
    int32 operands. Each pass prints the CPU time (time.process_time) it took.
    """
    # Each entry: (operation on (x, y), expected element-wise result as a list).
    cases = (
        (lambda x, y: x.__abs__(), [1, 2, 3]),
        (lambda x, y: x + y, [4, 4, 4]),
        (lambda x, y: x - y, [-2, 0, 2]),
        (lambda x, y: x * y, [3, 4, 3]),
        (lambda x, y: x / 2, [0.5, 1.0, 1.5]),
        (lambda x, y: x // 2, [0, 1, 1]),
        (lambda x, y: x % 2, [1, 0, 1]),
        (lambda x, y: x & y, [1, 2, 1]),
        (lambda x, y: x | y, [3, 2, 3]),
        (lambda x, y: x ^ y, [2, 0, 2]),
        (lambda x, y: x << 1, [2, 4, 6]),
        (lambda x, y: x >> 1, [0, 1, 1]),
        (lambda x, y: -x, [-1, -2, -3]),
        (lambda x, y: +x, [1, 2, 3]),
        (lambda x, y: ~x, [-2, -3, -4]),
        (lambda x, y: x == y, [False, True, False]),
        (lambda x, y: x != y, [True, False, True]),
        (lambda x, y: x > y, [False, False, True]),
        (lambda x, y: x >= y, [False, True, True]),
        (lambda x, y: x < y, [True, False, False]),
        (lambda x, y: x <= y, [True, True, False]),
    )

    # Pass 1: plain NumPy arrays, converted with ndarray.tolist().
    lhs = np.array([1, 2, 3], dtype=np.int32)
    rhs = np.array([3, 2, 1], dtype=np.int32)
    t0 = time.process_time()
    for op, expected in cases:
        assert op(lhs, rhs).tolist() == expected
    print(f"NumPy operator overloads test took {time.process_time() - t0:.4f} seconds")

    # Pass 2: the same data wrapped in CompatArray, converted with to_list().
    lhs = CompatArray(np.array([1, 2, 3], dtype=np.int32))
    rhs = CompatArray(np.array([3, 2, 1], dtype=np.int32))
    t0 = time.process_time()
    for op, expected in cases:
        assert op(lhs, rhs).to_list() == expected
    print(f"Operator overloads test took {time.process_time() - t0:.4f} seconds")
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
# def test_torch_operator_overloads():
|
|
344
|
+
# a = CompatArray(torch.tensor([1, 2, 3], dtype=torch.int32))
|
|
345
|
+
# b = CompatArray(torch.tensor([3, 2, 1], dtype=torch.int32))
|
|
346
|
+
|
|
347
|
+
# start = time.process_time()
|
|
348
|
+
# assert a.__abs__().to_list() == [1, 2, 3]
|
|
349
|
+
# assert (a + b).to_list() == [4, 4, 4]
|
|
350
|
+
# assert (a - b).to_list() == [-2, 0, 2]
|
|
351
|
+
# assert (a * b).to_list() == [3, 4, 3]
|
|
352
|
+
# assert (a / 2).to_list() == [0.5, 1.0, 1.5]
|
|
353
|
+
# assert (a // 2).to_list() == [0, 1, 1]
|
|
354
|
+
# assert (a % 2).to_list() == [1, 0, 1]
|
|
355
|
+
# assert (a & b).to_list() == [1, 2, 1]
|
|
356
|
+
# assert (a | b).to_list() == [3, 2, 3]
|
|
357
|
+
# assert (a ^ b).to_list() == [2, 0, 2]
|
|
358
|
+
# assert (a << 1).to_list() == [2, 4, 6]
|
|
359
|
+
# assert (a >> 1).to_list() == [0, 1, 1]
|
|
360
|
+
|
|
361
|
+
# assert (-a).to_list() == [-1, -2, -3]
|
|
362
|
+
# assert (+a).to_list() == [1, 2, 3]
|
|
363
|
+
# assert (~a).to_list() == [-2, -3, -4]
|
|
364
|
+
|
|
365
|
+
# assert (a == b).to_list() == [False, True, False]
|
|
366
|
+
# assert (a != b).to_list() == [True, False, True]
|
|
367
|
+
# assert (a > b).to_list() == [False, False, True]
|
|
368
|
+
# assert (a >= b).to_list() == [False, True, True]
|
|
369
|
+
# assert (a < b).to_list() == [True, False, False]
|
|
370
|
+
# assert (a <= b).to_list() == [True, True, False]
|
|
371
|
+
|
|
372
|
+
# print(f"Operator overloads test took {time.process_time() - start:.4f} seconds")
|
|
373
|
+
|
|
374
|
+
# a = torch.tensor([1, 2, 3], dtype=torch.int32)
|
|
375
|
+
# b = torch.tensor([3, 2, 1], dtype=torch.int32)
|
|
376
|
+
|
|
377
|
+
# start = time.process_time()
|
|
378
|
+
# assert a.__abs__().tolist() == [1, 2, 3]
|
|
379
|
+
# assert (a + b).tolist() == [4, 4, 4]
|
|
380
|
+
# assert (a - b).tolist() == [-2, 0, 2]
|
|
381
|
+
# assert (a * b).tolist() == [3, 4, 3]
|
|
382
|
+
# assert (a / 2).tolist() == [0.5, 1.0, 1.5]
|
|
383
|
+
# assert (a // 2).tolist() == [0, 1, 1]
|
|
384
|
+
# assert (a % 2).tolist() == [1, 0, 1]
|
|
385
|
+
# assert (a & b).tolist() == [1, 2, 1]
|
|
386
|
+
# assert (a | b).tolist() == [3, 2, 3]
|
|
387
|
+
# assert (a ^ b).tolist() == [2, 0, 2]
|
|
388
|
+
# assert (a << 1).tolist() == [2, 4, 6]
|
|
389
|
+
# assert (a >> 1).tolist() == [0, 1, 1]
|
|
390
|
+
|
|
391
|
+
# assert (-a).tolist() == [-1, -2, -3]
|
|
392
|
+
# assert (+a).tolist() == [1, 2, 3]
|
|
393
|
+
# assert (~a).tolist() == [-2, -3, -4]
|
|
394
|
+
|
|
395
|
+
# assert (a == b).tolist() == [False, True, False]
|
|
396
|
+
# assert (a != b).tolist() == [True, False, True]
|
|
397
|
+
# assert (a > b).tolist() == [False, False, True]
|
|
398
|
+
# assert (a >= b).tolist() == [False, True, True]
|
|
399
|
+
# assert (a < b).tolist() == [True, False, False]
|
|
400
|
+
# assert (a <= b).tolist() == [True, True, False]
|
|
401
|
+
|
|
402
|
+
# print(f"Torch operator overloads test took {time.process_time() - start:.4f} seconds")
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def test_pow_operator_matches_current_implementation_behavior():
    """`**` on a CompatArray returns a CompatArray holding the element-wise powers."""
    squared = _arr_1d() ** 2

    assert isinstance(squared, CompatArray)
    assert squared.to_list() == [1.0, 4.0, 9.0]
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def test_matmul_operator():
    """`@` between two 2x2 float32 wrappers yields a CompatArray holding the matrix product."""
    left = CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
    right = CompatArray(np.array([[2.0, 0.0], [1.0, 2.0]], dtype=np.float32))

    product = left @ right

    assert isinstance(product, CompatArray)
    # [[1*2 + 2*1, 1*0 + 2*2], [3*2 + 4*1, 3*0 + 4*2]]
    assert product.to_list() == [[4.0, 4.0], [10.0, 8.0]]
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def test_getattr_rejects_non_callable_namespace_attrs():
    """Looking up `pi` on a wrapper raises CompatArrayAttributeError (non-callable namespace attr)."""
    wrapped = _arr_1d()
    with pytest.raises(CompatArrayAttributeError):
        _ = wrapped.pi
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def test_module_helpers_unwrap_wrap_to_cxp():
    """wrap_arraylike/unwrap round-trip a raw array; non-array objects pass through untouched."""
    raw = np.array([1, 2, 3])

    boxed = wrap_arraylike(raw)
    assert isinstance(boxed, CompatArray)
    # unwrap hands back the very same object, and is a no-op on an unwrapped array.
    assert unwrap(boxed) is raw
    assert unwrap(raw) is raw

    # A dict is returned unchanged (identity preserved).
    not_an_array: Any = {"k": 1}
    assert wrap_arraylike(not_an_array) is not_an_array

    # Passing the namespace module explicitly is accepted.
    assert CompatArray(raw, xp=np).xp_name == "NumPy"
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def test_some_pyi_annotated_methods_via_dynamic_dispatch():
    """astype/abs/add/sum/reshape/take called as methods all return CompatArray results."""
    base = _arr_2d()

    as_f64 = base.astype(np.float64)
    absolute = base.abs()
    plus_one = base.add(1)
    total = base.sum()
    flat = base.reshape((4,))
    first_col = base.take(np.array([0], dtype=np.int32), axis=1)

    for result in (as_f64, absolute, plus_one, total, flat, first_col):
        assert isinstance(result, CompatArray)

    assert str(as_f64.dtype) == "float64"
    assert flat.shape == (4,)
    assert first_col.shape == (2, 1)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def test_cxp_of_compatarray_matches_array_namespace():
    """A wrapper's `.cxp` namespace exists and reports the same backend name as the wrapper."""
    wrapped = _arr_1d()
    namespace = wrapped.cxp

    assert namespace is not None
    assert namespace.xp_name == wrapped.xp_name
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
if __name__ == "__main__":
    # Manual runner: invokes selected tests directly, separated by rule lines.
    _rule = "=" * 40

    test_unique_all_unique_counts_unique_inverse()
    print(_rule)

    test_numpy_operator_overloads()
    print(_rule)
    # test_torch_operator_overloads()  # disabled along with the torch test above
    print(_rule)

    test_add()
|
|
@@ -1,261 +0,0 @@
|
|
|
1
|
-
import sys
|
|
2
|
-
from typing import Any
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pytest
|
|
6
|
-
|
|
7
|
-
sys.path.insert(0, "/data/tianzhen/my_packages/cobra-array/src")
|
|
8
|
-
|
|
9
|
-
from cobra_array.array_api import torch_xp
|
|
10
|
-
from cobra_array.compat import CompatArray, unwrap, wrap_arraylike
|
|
11
|
-
from cobra_array.exceptions import CompatArrayAttributeError, NotArrayAPIObjectError
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def _arr_1d():
    """Build a fresh 1-D float32 CompatArray holding [1.0, 2.0, 3.0]."""
    data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    return CompatArray(data)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def _arr_2d():
    """Build a fresh 2x2 float32 CompatArray holding [[1.0, 2.0], [3.0, 4.0]]."""
    data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    return CompatArray(data)
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def test_from_other_with_numpy_namespace():
    """CompatArray.from_other converts a Python list using a namespace given by name."""
    converted = CompatArray.from_other([1, 2, 3], xp="numpy")

    assert isinstance(converted, CompatArray)
    assert converted.xp_name == "NumPy"
    assert converted.to_list() == [1, 2, 3]
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def test_new_from_compatarray_copy_and_no_copy():
    """Re-wrapping a CompatArray returns the same object unless copy=True is passed."""
    source = _arr_1d()
    same = CompatArray(source)
    cloned = CompatArray(source, copy=True)

    assert same is source
    assert cloned is not source
    assert cloned.to_list() == source.to_list()
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def test_new_invalid_input_raises():
    """Wrapping a plain string raises NotArrayAPIObjectError."""
    bogus = "not-array"
    with pytest.raises(NotArrayAPIObjectError):
        CompatArray(bogus)  # type: ignore[arg-type]
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def test_to_numpy_to_list_and_array_protocol():
    """to_numpy(copy=False) and np.asarray(wrapper) both yield ndarrays with the same values."""
    wrapped = _arr_1d()

    via_method = wrapped.to_numpy(copy=False)
    via_protocol = np.asarray(wrapped)

    for converted in (via_method, via_protocol):
        assert isinstance(converted, np.ndarray)
    for converted in (via_method, via_protocol):
        assert converted.tolist() == [1.0, 2.0, 3.0]
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def test_to_device_on_numpy_backend():
    """On the NumPy backend, to_device("cpu") returns a raw ndarray with unchanged values."""
    wrapped = _arr_1d()
    moved = wrapped.to_device("cpu")

    assert isinstance(moved, np.ndarray)
    assert moved.tolist() == wrapped.to_list()
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
@pytest.mark.skipif(torch_xp is None, reason="PyTorch not available")
def test_to_tensor_available_backend():
    """to_tensor(device="cpu") yields a torch_xp.Tensor of shape (3,); skipped without PyTorch."""
    tensor = _arr_1d().to_tensor(device="cpu")

    assert isinstance(tensor, torch_xp.Tensor)
    assert tuple(tensor.shape) == (3,)
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def test_unstack_and_nonzero():
    """unstack(axis=0) and nonzero() return tuples whose items are CompatArray wrappers."""
    matrix = _arr_2d()

    rows = matrix.unstack(axis=0)
    nonzero_idx = matrix.nonzero()

    assert isinstance(rows, tuple)
    assert len(rows) == 2
    assert all(isinstance(row, CompatArray) for row in rows)
    assert rows[0].to_list() == [1.0, 2.0]

    # Two entries for the 2-D input.
    assert isinstance(nonzero_idx, tuple)
    assert len(nonzero_idx) == 2
    assert all(isinstance(idx, CompatArray) for idx in nonzero_idx)
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
def test_unique_all_unique_counts_unique_inverse():
    """Every array field of the unique_all/unique_counts/unique_inverse results is a CompatArray."""
    data = CompatArray(np.array([1, 2, 2, 1, 3], dtype=np.int32))

    full = data.unique_all()
    counted = data.unique_counts()
    inverted = data.unique_inverse()

    # (result object, names of the fields expected to be wrapped), checked in order.
    checks = (
        (full, ("values", "indices", "inverse_indices", "counts")),
        (counted, ("values", "counts")),
        (inverted, ("values", "inverse_indices")),
    )
    for result, fields in checks:
        for field in fields:
            assert isinstance(getattr(result, field), CompatArray)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def test_copy_and_basic_properties():
    """copy() yields a distinct wrapper; device/shape/ndim/size describe the 2x2 CPU array."""
    original = _arr_2d()
    duplicate = original.copy()

    # The wrapped buffer is a NumPy ndarray; the copy is a different wrapper object.
    assert isinstance(original.arr, np.ndarray)
    assert isinstance(duplicate, CompatArray)
    assert duplicate is not original

    assert str(original.device) == "cpu"
    assert (original.shape, original.ndim, original.size) == ((2, 2), 2, 4)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def test_transpose_properties():
    """`.T` returns a wrapped 2x2 transpose; accessing `.mT` is expected to raise AttributeError."""
    matrix = _arr_2d()

    transposed = matrix.T
    assert isinstance(transposed, CompatArray)
    assert transposed.shape == (2, 2)

    # NOTE(review): NumPy >= 2.0 ndarrays expose `.mT`; this expectation relies on
    # CompatArray not forwarding it — confirm against the supported NumPy range.
    with pytest.raises(AttributeError):
        _ = matrix.mT
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
def test_len_and_repr():
    """len() works on a 1-D wrapper but raises TypeError on a 0-d one; repr names the backend."""
    vector = _arr_1d()
    scalar = CompatArray(np.array(5))

    assert len(vector) == 3
    assert "NumPy_Array" in repr(vector)

    # A 0-d array has no length.
    with pytest.raises(TypeError):
        len(scalar)
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
def test_getitem_and_setitem():
    """Indexing reads one element; item assignment writes through to the wrapped data."""
    values = CompatArray(np.array([10, 20, 30], dtype=np.int32))

    assert int(values[1]) == 20

    values[1] = 99
    assert values.to_list() == [10, 99, 30]
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
def test_scalar_conversions():
    """0-d wrappers convert to Python bool/int/float/complex and support __index__."""
    as_bool = CompatArray(np.array(True))
    as_int = CompatArray(np.array(7))
    as_float = CompatArray(np.array(1.5, dtype=np.float32))
    as_complex = CompatArray(np.array(2 + 3j, dtype=np.complex64))

    assert bool(as_bool) is True
    assert int(as_int) == 7
    assert as_int.__index__() == 7
    assert float(as_float) == pytest.approx(1.5)
    assert complex(as_complex) == complex(2, 3)
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def test_operator_overloads():
    """Python operators on CompatArray operands produce wrappers with the element-wise results."""
    lhs = CompatArray(np.array([1, 2, 3], dtype=np.int32))
    rhs = CompatArray(np.array([3, 2, 1], dtype=np.int32))

    # (zero-arg thunk computing the operation, expected list), checked in order.
    expectations = (
        (lambda: lhs.__abs__(), [1, 2, 3]),
        (lambda: lhs + rhs, [4, 4, 4]),
        (lambda: lhs - rhs, [-2, 0, 2]),
        (lambda: lhs * rhs, [3, 4, 3]),
        (lambda: lhs / 2, [0.5, 1.0, 1.5]),
        (lambda: lhs // 2, [0, 1, 1]),
        (lambda: lhs % 2, [1, 0, 1]),
        (lambda: lhs & rhs, [1, 2, 1]),
        (lambda: lhs | rhs, [3, 2, 3]),
        (lambda: lhs ^ rhs, [2, 0, 2]),
        (lambda: lhs << 1, [2, 4, 6]),
        (lambda: lhs >> 1, [0, 1, 1]),
        (lambda: -lhs, [-1, -2, -3]),
        (lambda: +lhs, [1, 2, 3]),
        (lambda: ~lhs, [-2, -3, -4]),
        (lambda: lhs == rhs, [False, True, False]),
        (lambda: lhs != rhs, [True, False, True]),
        (lambda: lhs > rhs, [False, False, True]),
        (lambda: lhs >= rhs, [False, True, True]),
        (lambda: lhs < rhs, [True, False, False]),
        (lambda: lhs <= rhs, [True, True, False]),
    )
    for compute, expected in expectations:
        assert compute().to_list() == expected
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
def test_pow_operator_matches_current_implementation_behavior():
    """`**` on a CompatArray returns a CompatArray holding the element-wise powers."""
    squared = _arr_1d() ** 2

    assert isinstance(squared, CompatArray)
    assert squared.to_list() == [1.0, 4.0, 9.0]
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
def test_matmul_operator():
    """`@` between two 2x2 float32 wrappers yields a CompatArray holding the matrix product."""
    left = CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
    right = CompatArray(np.array([[2.0, 0.0], [1.0, 2.0]], dtype=np.float32))

    product = left @ right

    assert isinstance(product, CompatArray)
    # [[1*2 + 2*1, 1*0 + 2*2], [3*2 + 4*1, 3*0 + 4*2]]
    assert product.to_list() == [[4.0, 4.0], [10.0, 8.0]]
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
def test_getattr_rejects_non_callable_namespace_attrs():
    """Looking up `pi` on a wrapper raises CompatArrayAttributeError (non-callable namespace attr)."""
    wrapped = _arr_1d()
    with pytest.raises(CompatArrayAttributeError):
        _ = wrapped.pi
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
def test_module_helpers_unwrap_wrap_to_cxp():
    """wrap_arraylike/unwrap round-trip a raw array; non-array objects pass through untouched."""
    raw = np.array([1, 2, 3])

    boxed = wrap_arraylike(raw)
    assert isinstance(boxed, CompatArray)
    # unwrap hands back the very same object, and is a no-op on an unwrapped array.
    assert unwrap(boxed) is raw
    assert unwrap(raw) is raw

    # A dict is returned unchanged (identity preserved).
    not_an_array: Any = {"k": 1}
    assert wrap_arraylike(not_an_array) is not_an_array

    # Passing the namespace module explicitly is accepted.
    assert CompatArray(raw, xp=np).xp_name == "NumPy"
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
def test_some_pyi_annotated_methods_via_dynamic_dispatch():
    """astype/abs/add/sum/reshape/take called as methods all return CompatArray results."""
    base = _arr_2d()

    as_f64 = base.astype(np.float64)
    absolute = base.abs()
    plus_one = base.add(1)
    total = base.sum()
    flat = base.reshape((4,))
    first_col = base.take(np.array([0], dtype=np.int32), axis=1)

    for result in (as_f64, absolute, plus_one, total, flat, first_col):
        assert isinstance(result, CompatArray)

    assert str(as_f64.dtype) == "float64"
    assert flat.shape == (4,)
    assert first_col.shape == (2, 1)
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
def test_cxp_of_compatarray_matches_array_namespace():
    """A wrapper's `.cxp` namespace exists and reports the same backend name as the wrapper."""
    wrapped = _arr_1d()
    namespace = wrapped.cxp

    assert namespace is not None
    assert namespace.xp_name == wrapped.xp_name
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|