cobra-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cobra_array/__init__.py +105 -0
- cobra_array/_core.py +392 -0
- cobra_array/_utils.py +93 -0
- cobra_array/array_api.py +117 -0
- cobra_array/compat/__init__.py +23 -0
- cobra_array/compat/_array.py +499 -0
- cobra_array/compat/_array.pyi +1816 -0
- cobra_array/compat/_base.py +53 -0
- cobra_array/compat/_namespace.py +305 -0
- cobra_array/compat/_namespace.pyi +833 -0
- cobra_array/convert.py +334 -0
- cobra_array/convert.pyi +73 -0
- cobra_array/default.py +142 -0
- cobra_array/exceptions.py +74 -0
- cobra_array/types.py +68 -0
- cobra_array-0.1.0.dist-info/METADATA +137 -0
- cobra_array-0.1.0.dist-info/RECORD +20 -0
- cobra_array-0.1.0.dist-info/WHEEL +5 -0
- cobra_array-0.1.0.dist-info/licenses/LICENSE +21 -0
- cobra_array-0.1.0.dist-info/top_level.txt +1 -0
cobra_array/__init__.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# src/cobra_array/__init__.py
|
|
2
|
+
"""
|
|
3
|
+
cobra-array
|
|
4
|
+
===========
|
|
5
|
+
|
|
6
|
+
A unified array interface for multiple array libraries, providing seamless interoperability and convenient utilities for array manipulation and conversion.
|
|
7
|
+
A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across NumPy/PyTorch-style ecosystems.
|
|
8
|
+
|
|
9
|
+
Modules
|
|
10
|
+
-------
|
|
11
|
+
- :mod:`cobra_array.compat`: Compatibility utilities for `compatibility namespaces` and `compatibility arrays`.
|
|
12
|
+
- :mod:`cobra_array.convert`: Utilities for converting between different array types and namespaces.
|
|
13
|
+
- :mod:`cobra_array.default`: Default utilities for array specifications and namespaces.
|
|
14
|
+
- :mod:`cobra_array.array_api`: Utilities for working with array namespaces and devices.
|
|
15
|
+
Functions
|
|
16
|
+
---------
|
|
17
|
+
- :func:`array_spec`: Get the array specification of an object.
|
|
18
|
+
- :func:`context_spec`: Get the context specification of an object.
|
|
19
|
+
- :func:`as_context`: Convert an array-like object to an array in the determined context namespace.
|
|
20
|
+
- :func:`array_context`: Get the context namespace for a given array or set of arrays.
|
|
21
|
+
- :func:`unify_args`: A decorator to unify the array arguments of a function to the same namespace, dtype, and device.
|
|
22
|
+
- :func:`array_namespace_alias`: Get the human-readable alias of a supported array namespace (raises for unsupported namespaces).
|
|
23
|
+
- :func:`is_compat_namespace`: Check if an object is a `compatibility namespace`.
|
|
24
|
+
- :func:`is_array_namespace`: Check if an object is an `array namespace`.
|
|
25
|
+
|
|
26
|
+
Examples
|
|
27
|
+
--------
|
|
28
|
+
Basic conversions::
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
from cobra_array.convert import to_numpy, to_tensor, to_list
|
|
32
|
+
|
|
33
|
+
data = [[1, 2], [3, 4]]
|
|
34
|
+
|
|
35
|
+
arr_np = to_numpy(data, dtype=np.float32)
|
|
36
|
+
print(type(arr_np), arr_np.dtype) # numpy.ndarray float32
|
|
37
|
+
|
|
38
|
+
arr_torch = to_tensor(data, device="cpu")
|
|
39
|
+
print(type(arr_torch), arr_torch.device)
|
|
40
|
+
|
|
41
|
+
back_to_list = to_list(arr_np)
|
|
42
|
+
print(back_to_list) # [[1.0, 2.0], [3.0, 4.0]]
|
|
43
|
+
|
|
44
|
+
Context-based conversion::
|
|
45
|
+
|
|
46
|
+
import numpy as np
|
|
47
|
+
from cobra_array import array_context, as_context, context_spec
|
|
48
|
+
|
|
49
|
+
with array_context(xp="numpy", dtype=np.float32, device="cpu"):
|
|
50
|
+
x = as_context([1, 2, 3])
|
|
51
|
+
y = as_context(np.array([4, 5]))
|
|
52
|
+
spec = context_spec()
|
|
53
|
+
print(spec.cxp.xp_name, spec.dtype, spec.device)
|
|
54
|
+
print(x, y)
|
|
55
|
+
|
|
56
|
+
Auto-unify function arguments::
|
|
57
|
+
|
|
58
|
+
import numpy as np
|
|
59
|
+
from cobra_array import unify_args
|
|
60
|
+
|
|
61
|
+
@unify_args(ref=0, unify_dtype=True, unify_device=True, arraylike_only=True)
|
|
62
|
+
def add_and_mean(a, b):
|
|
63
|
+
c = a + b
|
|
64
|
+
return c.mean()
|
|
65
|
+
|
|
66
|
+
out = add_and_mean(np.array([1, 2, 3]), [4, 5, 6])
|
|
67
|
+
print(out)
|
|
68
|
+
|
|
69
|
+
Default backend strategy::
|
|
70
|
+
|
|
71
|
+
from cobra_array.default import as_default, default_spec
|
|
72
|
+
|
|
73
|
+
spec = default_spec()
|
|
74
|
+
print(spec.cxp.xp_name, spec.dtype, spec.device)
|
|
75
|
+
|
|
76
|
+
x = as_default([1, 2, 3], unify_dtype=True, unify_device=True)
|
|
77
|
+
print(x, x.dtype)
|
|
78
|
+
"""
|
|
79
|
+
|
|
80
|
+
from ._core import (
|
|
81
|
+
array_spec,
|
|
82
|
+
context_spec,
|
|
83
|
+
as_context,
|
|
84
|
+
array_context,
|
|
85
|
+
unify_args
|
|
86
|
+
)
|
|
87
|
+
from ._utils import (
|
|
88
|
+
array_namespace_alias,
|
|
89
|
+
is_compat_namespace,
|
|
90
|
+
is_array_namespace
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# Package metadata.
__author__ = "Zhen Tian"
__version__ = "0.1.0"

# Public API of the package: the names imported above from the private
# `_core` and `_utils` modules. This list is what `from cobra_array import *`
# exposes.
__all__ = [
    "array_spec",
    "context_spec",
    "as_context",
    "array_context",
    "unify_args",
    "array_namespace_alias",
    "is_compat_namespace",
    "is_array_namespace"
]
|
cobra_array/_core.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Python version: 3.9
|
|
3
|
+
# @TianZhen
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from functools import wraps
|
|
7
|
+
from itertools import (islice, chain)
|
|
8
|
+
import array_api_compat as api
|
|
9
|
+
from contextvars import ContextVar
|
|
10
|
+
from typing import (Any, Literal, Union, Optional, Dict, overload, TYPE_CHECKING)
|
|
11
|
+
|
|
12
|
+
from .compat import wrap_arraylike
|
|
13
|
+
from .convert import (to_xp, as_array)
|
|
14
|
+
from .default import (ArraySpec, default_spec)
|
|
15
|
+
from .exceptions import (
|
|
16
|
+
NoArrayInputsError,
|
|
17
|
+
GetArrayNamespaceError,
|
|
18
|
+
MissingDependencyError,
|
|
19
|
+
NotArrayAPIObjectError
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from numpy.typing import NDArray
|
|
24
|
+
from .compat import CompatArray
|
|
25
|
+
from .types import (T, dtypeT, DType, Device, ArrayLike, ArrayLibraryName)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def array_spec(
    *arrays: object,
    kw_arrays: Optional[Dict[str, object]] = None,
    ref: Optional[Union[str, int]] = None,
    filter_arraylike: bool = False,
    api_version: Optional[str] = None,
    use_compat: Optional[bool] = None
) -> ArraySpec:
    """
    Determine the array API compatible `compatibility namespace`, `dtype` and `device` from the provided array arguments and the reference array.

    Parameters
    ----------
    arrays : object
        Positional array-like objects to determine the `compatibility namespace`.

    kw_arrays : Optional[Dict[str, object]], defaults to `None`
        Keyword array-like objects to determine the `compatibility namespace`.

    ref : Optional[Union[str, int]], defaults to `None`
        Reference array to determine the `compatibility namespace`, `dtype` and `device`.
        - `None`: Use all provided arrays in `arrays` and `kw_arrays` to determine the `compatibility namespace`, and return `None` for `dtype` and `device`;
        - _str_: Use the value stored under this key in `kw_arrays` as the reference array. This lookup is NOT affected by :param:`filter_arraylike`;
        - _int_: Use the array at this index of `arrays` followed by the values of `kw_arrays` (in that order) as the reference array. When :param:`filter_arraylike` is `True`, non-array inputs are skipped before indexing.

    filter_arraylike : bool, defaults to `False`
        Whether to keep only the inputs accepted by :func:`array_api_compat.is_array_api_obj` when determining the `compatibility namespace`.

    api_version : Optional[str], defaults to `None`
        The target array API version for the returned `compatibility namespace`. See also :param:`api_version` in :func:`array_api_compat.array_namespace`.

    use_compat : Optional[bool], defaults to `None`
        See also :param:`use_compat` in :func:`array_api_compat.array_namespace`.
        - `None`: Return the native namespace if it is already Array API–compatible, otherwise return a compat wrapper;
        - `True`: Always return the compat-wrapped namespace;
        - `False`: Return the native namespace.

        NOTE: The compat-wrapped namespace is NOT `compatibility namespace`. The former is a wrapper in :pkg:`array_api_compat`.

    Returns
    -------
    ArraySpec
        An :class:`ArraySpec` named tuple containing the determined `cxp`(`compatibility namespace`), `dtype` and `device`.
        - If :param:`ref` is `None`, the returned `dtype` and `device` will be `None`.
        - `compatibility namespace` is a wrapper of the native `array namespace` that provides a compatibility layer for backend-agnostic array operations. See also :class:`compat.CompatNamespace`.

    Raises
    ------
    NoArrayInputsError
        If both `arrays` and `kw_arrays` are empty.
    GetArrayNamespaceError
        If :func:`array_api_compat.array_namespace` fails for the (optionally filtered) inputs or for the reference array.
    KeyError
        If `ref` is a string but not a key in `kw_arrays`.
    IndexError
        If `ref` is a negative integer, or out of range for the (optionally filtered) array inputs.
    TypeError
        If `ref` is not `None`, a string, or an integer.
    NotArrayAPIObjectError
        If the reference array selected by `ref` is not an array API compatible array object. This is checked for both string and integer references.
    """
    kw_arrays = kw_arrays or {}

    if not arrays and not kw_arrays:
        # no arrays provided
        raise NoArrayInputsError("Expected at least one array input.")

    # lazy iterator over positional inputs followed by keyword inputs
    all_arrays = chain(arrays, kw_arrays.values())
    if filter_arraylike:
        all_arrays = (a for a in all_arrays if api.is_array_api_obj(a))

    if ref is None:
        # use arrays and kw_arrays in order to determine the namespace
        try:
            return ArraySpec.create(
                api.array_namespace(
                    *all_arrays,
                    api_version=api_version,
                    use_compat=use_compat
                ), None, None
            )
        except Exception as e:
            raise GetArrayNamespaceError(
                "Failed to determine the `compatibility namespace` from the provided inputs."
            ) from e

    if isinstance(ref, str):
        # use the specified kw_array as reference array to determine the namespace
        try:
            ref_arr = kw_arrays[ref]
        except KeyError:
            raise KeyError(
                f"Parameter `ref` of `array_spec()` must be a key in `kw_arrays`, got {ref!r}."
            ) from None
    elif isinstance(ref, int):
        # use the specified array as reference array to determine the namespace
        try:
            # `islice` rejects negative indices with ValueError
            ref_arr = next(islice(all_arrays, ref, ref + 1))
        except ValueError:
            raise IndexError(
                "Parameter `ref` of `array_spec()` must be a "
                f"non-negative index for the array inputs, got {ref!r}."
            ) from None
        except StopIteration:
            if filter_arraylike:
                raise IndexError(
                    "Parameter `ref` of `array_spec()` is out of range "
                    f"for the array-like inputs, got {ref!r}."
                ) from None
            raise IndexError(
                "Parameter `ref` of `array_spec()` must be in the range "
                f"[0, {len(arrays) + len(kw_arrays)}) for the array inputs, got {ref!r}."
            ) from None
    else:
        raise TypeError(
            "Parameter `ref` of `array_spec()` must be `str`, "
            f"`int` or `NoneType`, got {type(ref)}."
        )

    # Validate unconditionally: `filter_arraylike` only filters the int-index
    # lookup, so a keyword reference may still be a non-array object here.
    if not api.is_array_api_obj(ref_arr):
        raise NotArrayAPIObjectError(
            f"Reference array must be an array API compatible array object, got {ref_arr!r}."
        )

    dtype = getattr(ref_arr, "dtype", None)
    device = api.device(ref_arr)
    try:
        return ArraySpec.create(
            api.array_namespace(
                ref_arr,
                api_version=api_version,
                use_compat=use_compat
            ), dtype, device
        )
    except Exception as e:
        raise GetArrayNamespaceError(
            "Failed to determine the `compatibility namespace` from the reference array."
        ) from e
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# initialize the context variable for array specification
# Holds the `ArraySpec` of the innermost active `array_context`. It is set in
# `array_context.__enter__` and restored in `array_context.__exit__`. It has no
# default, so `context_spec()` catches `LookupError` and falls back to
# `default_spec()` when no context is active.
_arr_spec_var: ContextVar[ArraySpec] = ContextVar("arr_spec")
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def context_spec() -> ArraySpec:
    """
    Get the `compatibility namespace`, `dtype` and `device` of the innermost active :class:`array_context` in the current context.

    An :class:`array_context` becomes active either explicitly (``with array_context(...):``) or implicitly during a call to a function decorated with :func:`unify_args`.
    If no :class:`array_context` is active, return the default `compatibility namespace`, `dtype` and `device` from :func:`default.default_spec`.

    Returns
    -------
    ArraySpec
        An :class:`ArraySpec` named tuple containing the determined `cxp`(`compatibility namespace`), `dtype` and `device`.

    Raises
    ------
    Refer to :func:`default.default_spec` for possible exceptions (only raised when no :class:`array_context` is active).
    """
    try:
        return _arr_spec_var.get()
    except LookupError:
        # the context variable has never been set in this context
        return default_spec()
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@overload
def as_context(obj: NDArray[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, Literal["cpu"]]: ...
@overload
def as_context(obj: ArrayLike[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, Any]: ...
@overload
def as_context(obj: ArrayLike[Any], /, *, unify_dtype: Literal[True] = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[Any, Any]: ...
@overload
def as_context(obj: object, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> CompatArray[Any, Any]: ...
@overload
def as_context(obj: T, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[True]) -> T: ...


def as_context(
    obj: object,
    /,
    unify_dtype: bool = True,
    unify_device: bool = True,
    copy: bool = False,
    arraylike_only: bool = False
) -> Any:
    """
    Convert `obj` into the active context's `compatibility namespace` and wrap the result as a :class:`CompatArray`.

    The active context is the one reported by :func:`context_spec`. Its `dtype` and `device` are requested for the conversion only when the corresponding `unify_*` flag is set.

    Parameters
    ----------
    obj : object
        The object to be converted to a :class:`CompatArray` array.

    unify_dtype : bool, defaults to `True`
        Whether to request the context `dtype` for the converted array.

    unify_device : bool, defaults to `True`
        Whether to request the context `device` for the converted array.

    copy : bool, defaults to `False`
        Whether to return a copy of the array if it is already in the context namespace.

    arraylike_only : bool, defaults to `False`
        Whether to only convert array-like objects, returning any other object as-is.

    Returns
    -------
    CompatArray[Any, Any]
        The converted array in the context `compatibility namespace`, carrying the context `dtype` and `device` where requested.
    object
        The object itself, if :param:`arraylike_only` is `True` and the object is not array-like.

    Raises
    ------
    Refer to :func:`convert.as_array`, :func:`context_spec` for possible exceptions.
    """
    active = context_spec()
    target_xp = active.cxp

    # `None` means no explicit dtype/device is requested from `as_array`
    target_dtype = None
    if unify_dtype:
        target_dtype = active.dtype

    target_device = None
    if unify_device:
        target_device = active.device

    converted = as_array(
        obj,
        target_xp,
        dtype=target_dtype,
        device=target_device,
        copy=copy,
        arraylike_only=arraylike_only,
    )
    return wrap_arraylike(converted, xp=target_xp)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class array_context:
    """
    **Context Manager** to set the context `compatibility namespace`, `dtype` and `device` for the enclosed block of code.

    The target spec is resolved once, when the instance is created. Any of `xp`, `dtype` or `device` left as `None` is taken from :func:`context_spec` at construction time.
    While the block runs, :func:`context_spec` returns that spec. The same instance may be entered more than once, including nested re-entry.

    Examples
    --------
    ::

        import numpy as np
        from cobra_array import array_context, as_context

        with array_context(xp="numpy", dtype=np.float32, device="cpu"):
            x = as_context([1, 2, 3])
    """
    @classmethod
    def from_array_spec(cls, arr_spec: ArraySpec, /):
        """
        Create an :class:`array_context` from an :class:`ArraySpec` named tuple.

        Parameters
        ----------
        arr_spec : ArraySpec
            An :class:`ArraySpec` named tuple containing the `cxp`(`compatibility namespace`), `dtype` and `device`.
            Fields that are `None` are filled from the current context (see :meth:`__init__`).

        Returns
        -------
        array_context
            A new, not yet entered, context manager.
        """
        return cls(xp=arr_spec.cxp, dtype=arr_spec.dtype, device=arr_spec.device)

    def __init__(
        self,
        xp: Optional[Union[object, ArrayLibraryName]] = None,
        dtype: Optional[DType] = None,
        device: Optional[Device] = None
    ):
        """
        Initialize the context manager with the specified `array namespace`, `dtype` and `device`.

        Parameters
        ----------
        xp : Optional[Union[object, ArrayLibraryName]], defaults to `None`
            The target `array namespace` or array library name for the context, converted with :func:`convert.to_xp`.
            - `None`: Use the `compatibility namespace` from the current context;

        dtype : Optional[DType], defaults to `None`
            The target `dtype` for the context.
            - `None`: Use the `dtype` from the current context.

        device : Optional[Device], defaults to `None`
            The target `device` for the context.
            - `None`: Use the `device` from the current context.

        Raises
        ------
        Refer to :func:`convert.to_xp` and :func:`context_spec` for possible exceptions.
        """
        # Only consult the surrounding context (which may fall back to
        # `default_spec()` and raise) when some field must be inherited from it.
        if xp is None or dtype is None or device is None:
            outer = context_spec()
        else:
            outer = None
        self.cur_spec = ArraySpec.create(
            to_xp(xp) if xp is not None else outer.cxp,
            dtype if dtype is not None else outer.dtype,
            device if device is not None else outer.device
        )
        # One token per active `with` entry so that re-entering the same
        # instance (nested or sequential) resets the context variable correctly.
        self._tokens = []

    def __enter__(self):
        """Make this spec the active context spec; return the :class:`ArraySpec`."""
        self._tokens.append(_arr_spec_var.set(self.cur_spec))
        return self.cur_spec

    def __exit__(self, *args):
        """Restore the context spec that was active before the matching `__enter__`."""
        _arr_spec_var.reset(self._tokens.pop())
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def unify_args(
    ref: Optional[Union[str, int]] = 0,
    /,
    filter_arraylike: bool = True,
    api_version: Optional[str] = None,
    use_compat: Optional[bool] = None,
    unify_dtype: bool = False,
    unify_device: bool = True,
    arraylike_only: bool = True,
    strict: bool = True
):
    """
    **Decorator** to unify arguments of a function to the same `compatibility namespace`, `dtype` and `device` determined by the provided array arguments and the reference array.

    On every call, the spec is computed with :func:`array_spec` from the call's positional and keyword arguments. Each argument is then converted with :func:`as_context`, and the function runs inside the corresponding :class:`array_context`.

    Parameters
    ----------
    ref : Optional[Union[str, int]], defaults to `0`
        Reference array to determine the `compatibility namespace`, `dtype` and `device`.
        See also :param:`ref` in :func:`array_spec`.

    filter_arraylike : bool, defaults to `True`
        Whether to keep only inputs accepted by :func:`array_api_compat.is_array_api_obj` when determining the `compatibility namespace`.

    api_version : Optional[str], defaults to `None`
        See also :param:`api_version` in :func:`array_spec`.

    use_compat : Optional[bool], defaults to `None`
        See also :param:`use_compat` in :func:`array_spec`.

    unify_dtype : bool, defaults to `False`
        Whether to unify the `dtype` of arguments to that of the reference array.

    unify_device : bool, defaults to `True`
        Whether to unify the `device` of arguments to that of the reference array.

    arraylike_only : bool, defaults to `True`
        Whether to only convert array-like objects to arrays in the determined namespace, and return the object itself if it is not array-like.

    strict : bool, defaults to `True`
        Whether to raise exceptions when failing to determine the `compatibility namespace`.
        - `True`: Re-raise any exception from :func:`array_spec`;
        - `False`: Fall back to :func:`default.default_spec` on any exception from :func:`array_spec`. If that raises :class:`MissingDependencyError` (all default array libraries are missing), call the function with its original arguments.

    Returns
    -------
    Callable
        A decorator that wraps a function as described above.

    Raises
    ------
    Refer to :func:`array_spec`, :func:`default.default_spec`, :func:`as_context` for possible exceptions.

    Examples
    --------
    ::

        import numpy as np
        from cobra_array import unify_args

        @unify_args(ref=0, unify_dtype=True, unify_device=True, arraylike_only=True)
        def add_and_mean(a, b):
            return (a + b).mean()

        out = add_and_mean(np.array([1, 2, 3]), [4, 5, 6])
    """
    # NOTE(review): with `ref=None`, `array_spec` returns `None` for dtype and
    # device, so `array_context` fills them from the surrounding context (or
    # `default_spec()`), which may belong to a different array library than the
    # inputs — confirm this is intended when `unify_dtype`/`unify_device` are on.

    def _convert(value):
        # convert a single argument using the currently active context spec
        return as_context(
            value,
            unify_dtype=unify_dtype,
            unify_device=unify_device,
            arraylike_only=arraylike_only
        )

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # determine the namespace, dtype and device for the array inputs
            try:
                spec = array_spec(
                    *args, kw_arrays=kwargs, ref=ref,
                    filter_arraylike=filter_arraylike,
                    api_version=api_version,
                    use_compat=use_compat
                )
            except Exception:
                if strict:
                    raise
                try:
                    # fall back to the default namespace
                    spec = default_spec()
                except MissingDependencyError:
                    spec = None

            if spec is None:
                # No usable namespace: run the function without conversion.
                # Done outside the `except` block so errors raised by `func`
                # are not chained onto the swallowed `array_spec` failure.
                return func(*args, **kwargs)

            with array_context.from_array_spec(spec):
                out_args = tuple(_convert(a) for a in args)
                out_kwargs = {k: _convert(v) for k, v in kwargs.items()}
                return func(*out_args, **out_kwargs)
        return wrapper
    return decorator
|
cobra_array/_utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Python version: 3.9
|
|
3
|
+
# @TianZhen
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
import array_api_compat as api
|
|
7
|
+
import warnings
|
|
8
|
+
from types import ModuleType
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from .exceptions import UnsupportedNamespaceError
|
|
12
|
+
|
|
13
|
+
# Try to import `cobra_log.warning`.
|
|
14
|
+
try:
|
|
15
|
+
from cobra_log import warning
|
|
16
|
+
_WARN_AVAILABLE = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
_WARN_AVAILABLE = False
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def warn(msg: str, /, category: Any, stack: int = 2):
    """
    Emit `msg` as a warning.

    When `cobra_log` could be imported, the message is passed to `cobra_log.warning`, and `category` is not used.
    Otherwise :func:`warnings.warn` is used with `category`.

    Parameters
    ----------
    msg : str
        The warning message.
    category : Any
        Warning class for the :mod:`warnings` fallback.
    stack : int, defaults to `2`
        Stack depth forwarded to `cobra_log.warning`. The :mod:`warnings` fallback uses `stack + 1` as its `stacklevel`.
    """
    if not _WARN_AVAILABLE:
        # stdlib fallback: one extra level for this helper's own frame
        return warnings.warn(msg, category=category, stacklevel=stack + 1)
    return warning(msg, stack=stack)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def array_namespace_alias(xp: object) -> str:
    """
    Return the human-readable alias of a supported `array namespace`.

    Parameters
    ----------
    xp : object
        The `array namespace` (a module such as `numpy` or `torch`).

    Returns
    -------
    str
        One of `"NumPy"`, `"Cupy"`, `"PyTorch"`, `"NDONNX"`, `"Dask"`, `"JAX"`, `"sparse"` or `"array-api-strict"`.
        The first matching `array_api_compat.is_*_namespace` check, in that order, decides the alias.

    Raises
    ------
    UnsupportedNamespaceError
        If `xp` is not a module, or is a module not recognized by any of the checks.
    """
    if isinstance(xp, ModuleType):
        # (array_api_compat predicate name, alias); predicates are resolved lazily, in order
        detectors = (
            ("is_numpy_namespace", "NumPy"),
            ("is_cupy_namespace", "Cupy"),
            ("is_torch_namespace", "PyTorch"),
            ("is_ndonnx_namespace", "NDONNX"),
            ("is_dask_namespace", "Dask"),
            ("is_jax_namespace", "JAX"),
            ("is_pydata_sparse_namespace", "sparse"),
            ("is_array_api_strict_namespace", "array-api-strict"),
        )
        for predicate_name, alias in detectors:
            if getattr(api, predicate_name)(xp):
                return alias

    raise UnsupportedNamespaceError(
        f"Got unsupported array namespace of type {type(xp)}."
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def is_array_namespace(obj: object) -> bool:
    """
    Return `True` if `obj` is a supported `array namespace`, i.e. :func:`array_namespace_alias` accepts it.
    """
    try:
        array_namespace_alias(obj)
    except UnsupportedNamespaceError:
        return False
    else:
        return True
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def is_compat_namespace(xp: object) -> bool:
    """
    Return `True` if `xp` is a `compatibility namespace` wrapped by :class:`CompatNamespace`.

    Detection is name-based: the object's `__name__` must be a string containing the `"(compat)"` marker.
    Objects without a string `__name__` are never reported as compatibility namespaces.
    """
    name = getattr(xp, "__name__", None)
    # Guard against a missing or non-string `__name__` (e.g. `None`). Without it,
    # the `in` test could raise TypeError or match container elements.
    return isinstance(name, str) and "(compat)" in name
|