cobra-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,105 @@
1
# src/cobra_array/__init__.py
"""
cobra-array
===========

A backend-agnostic array utility library providing a unified array interface across
NumPy/PyTorch-style ecosystems: seamless conversion between array libraries, context
control, and cross-library operations.

Modules
-------
- :mod:`cobra_array.compat`: Compatibility utilities for `compatibility namespaces` and `compatibility arrays`.
- :mod:`cobra_array.convert`: Utilities for converting between different array types and namespaces.
- :mod:`cobra_array.default`: Default utilities for array specifications and namespaces.
- :mod:`cobra_array.array_api`: Utilities for working with array namespaces and devices.

Functions
---------
- :func:`array_spec`: Determine the `compatibility namespace`, `dtype` and `device` from array arguments.
- :func:`context_spec`: Get the `compatibility namespace`, `dtype` and `device` of the current array context.
- :func:`as_context`: Convert an object to an array in the current context `compatibility namespace`.
- :class:`array_context`: Context manager that sets the `compatibility namespace`, `dtype` and `device` for a block of code.
- :func:`unify_args`: A decorator factory to unify the array arguments of a function to the same namespace, dtype, and device.
- :func:`array_namespace_alias`: Get the alias of a supported `array namespace`.
- :func:`is_compat_namespace`: Check if an object is a `compatibility namespace`.
- :func:`is_array_namespace`: Check if an object is a supported `array namespace`.

Examples
--------
Basic conversions::

    import numpy as np
    from cobra_array.convert import to_numpy, to_tensor, to_list

    data = [[1, 2], [3, 4]]

    arr_np = to_numpy(data, dtype=np.float32)
    print(type(arr_np), arr_np.dtype)  # numpy.ndarray float32

    arr_torch = to_tensor(data, device="cpu")
    print(type(arr_torch), arr_torch.device)

    back_to_list = to_list(arr_np)
    print(back_to_list)  # [[1.0, 2.0], [3.0, 4.0]]

Context-based conversion::

    import numpy as np
    from cobra_array import array_context, as_context, context_spec

    with array_context(xp="numpy", dtype=np.float32, device="cpu"):
        x = as_context([1, 2, 3])
        y = as_context(np.array([4, 5]))
        spec = context_spec()
        print(spec.cxp.xp_name, spec.dtype, spec.device)
    print(x, y)

Auto-unify function arguments::

    import numpy as np
    from cobra_array import unify_args

    # `0`: use the first positional argument as the reference array
    @unify_args(0, unify_dtype=True, unify_device=True, arraylike_only=True)
    def add_and_mean(a, b):
        c = a + b
        return c.mean()

    out = add_and_mean(np.array([1, 2, 3]), [4, 5, 6])
    print(out)

Default backend strategy::

    from cobra_array.default import as_default, default_spec

    spec = default_spec()
    print(spec.cxp.xp_name, spec.dtype, spec.device)

    x = as_default([1, 2, 3], unify_dtype=True, unify_device=True)
    print(x, x.dtype)
"""
79
+
80
+ from ._core import (
81
+ array_spec,
82
+ context_spec,
83
+ as_context,
84
+ array_context,
85
+ unify_args
86
+ )
87
+ from ._utils import (
88
+ array_namespace_alias,
89
+ is_compat_namespace,
90
+ is_array_namespace
91
+ )
92
+
93
__author__ = "Zhen Tian"
__version__ = "0.1.0"

# Public API, re-exported from the private implementation modules imported above.
__all__ = [
    # from `._core`
    "array_spec",
    "context_spec",
    "as_context",
    "array_context",
    "unify_args",
    # from `._utils`
    "array_namespace_alias",
    "is_compat_namespace",
    "is_array_namespace",
]
cobra_array/_core.py ADDED
@@ -0,0 +1,392 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Python version: 3.9
3
+ # @TianZhen
4
+
5
+ from __future__ import annotations
6
+ from functools import wraps
7
+ from itertools import (islice, chain)
8
+ import array_api_compat as api
9
+ from contextvars import ContextVar
10
+ from typing import (Any, Literal, Union, Optional, Dict, overload, TYPE_CHECKING)
11
+
12
+ from .compat import wrap_arraylike
13
+ from .convert import (to_xp, as_array)
14
+ from .default import (ArraySpec, default_spec)
15
+ from .exceptions import (
16
+ NoArrayInputsError,
17
+ GetArrayNamespaceError,
18
+ MissingDependencyError,
19
+ NotArrayAPIObjectError
20
+ )
21
+
22
+ if TYPE_CHECKING:
23
+ from numpy.typing import NDArray
24
+ from .compat import CompatArray
25
+ from .types import (T, dtypeT, DType, Device, ArrayLike, ArrayLibraryName)
26
+
27
+
28
def array_spec(
    *arrays: object,
    kw_arrays: Optional[Dict[str, object]] = None,
    ref: Optional[Union[str, int]] = None,
    filter_arraylike: bool = False,
    api_version: Optional[str] = None,
    use_compat: Optional[bool] = None
) -> ArraySpec:
    """
    Determine the array API compatible `compatibility namespace`, `dtype` and `device` from the provided array arguments and the reference array.

    Parameters
    ----------
    arrays : object
        Positional array-like objects to determine the `compatibility namespace`.

    kw_arrays : Optional[Dict[str, object]], default to `None`
        Keyword array-like objects to determine the `compatibility namespace`.

    ref : Optional[Union[str, int]], default to `None`
        Reference array to determine the `compatibility namespace`, `dtype` and `device`.
        - `None`: Use all provided arrays in `arrays` and `kw_arrays` to determine the `compatibility namespace`, and return `None` for `dtype` and `device`;
        - _str_: Use the specified keyword array in `kw_arrays` as reference array to determine the `compatibility namespace`, `dtype` and `device`. It is looked up directly in `kw_arrays`, independently of :param:`filter_arraylike`;
        - _int_: Use the specified positional array in `arrays` and `kw_arrays` (in order) as reference array to determine the `compatibility namespace`, `dtype` and `device`. The index and valid range are affected by :param:`filter_arraylike`.

    filter_arraylike : bool, default to `False`
        Whether to filter the provided inputs to all array-likes via :func:`array_api_compat.is_array_api_obj` when determining the `compatibility namespace`.

    api_version : Optional[str], default to `None`
        The target array API version for the returned `compatibility namespace`. See also :param:`api_version` in :func:`array_api_compat.array_namespace`.

    use_compat : Optional[bool], default to `None`
        See also :param:`use_compat` in :func:`array_api_compat.array_namespace`.
        - `None`: Return the native namespace if it is already Array API–compatible, otherwise return a compat wrapper;
        - `True`: Always return the compat-wrapped namespace;
        - `False`: Return the native namespace.

        NOTE: The compat-wrapped namespace is NOT a `compatibility namespace`. The former is a wrapper in :pkg:`array_api_compat`.

    Returns
    -------
    ArraySpec
        An :class:`ArraySpec` named tuple containing the determined `cxp` (`compatibility namespace`), `dtype` and `device`.
        - If :param:`ref` is `None`, the returned `dtype` and `device` will be `None`.
        - A `compatibility namespace` is a wrapper of the native `array namespace` that provides a compatibility layer for backend-agnostic array operations. See also :class:`compat.CompatNamespace`.

    Raises
    ------
    NoArrayInputsError
        If no array inputs are provided in `arrays` and `kw_arrays`.
    GetArrayNamespaceError
        If an error occurs while determining the `compatibility namespace` from the provided inputs or the reference array.
    KeyError
        If `ref` is a string but not a key in `kw_arrays`.
    IndexError
        If `ref` is an integer but negative or out of range for the array inputs.
    TypeError
        If `ref` is not `None`, a string, or an integer.
    NotArrayAPIObjectError
        If the reference array determined by `ref` is not an array API compatible array object.
        This is checked for every reference array, including a keyword reference selected while :param:`filter_arraylike` is `True`.

    Examples
    --------
    Use the first positional input as the reference array, ignoring non-array inputs::

        import numpy as np

        spec = array_spec(np.zeros(3, dtype=np.float32), [1, 2], ref=0, filter_arraylike=True)
        print(spec.cxp, spec.dtype, spec.device)
    """
    kw_arrays = kw_arrays or {}

    if not arrays and not kw_arrays:
        # no arrays provided
        raise NoArrayInputsError("Expected at least one array input.")

    # lazy iterator over all inputs: positional ones first, then keyword values in insertion order
    all_arrays = chain(arrays, kw_arrays.values())
    if filter_arraylike:
        all_arrays = (a for a in all_arrays if api.is_array_api_obj(a))

    if ref is None:
        # use arrays and kw_arrays in order to determine the namespace
        try:
            return ArraySpec.create(
                api.array_namespace(
                    *all_arrays,
                    api_version=api_version,
                    use_compat=use_compat
                ), None, None
            )
        except Exception as e:
            raise GetArrayNamespaceError(
                "Failed to determine the `compatibility namespace` from the provided inputs."
            ) from e

    if isinstance(ref, str):
        # use the specified kw_array as reference array to determine the namespace
        try:
            ref_arr = kw_arrays[ref]
        except KeyError:
            raise KeyError(
                f"Parameter `ref` of `array_spec()` must be a key in `kw_arrays`, got {ref!r}."
            ) from None
    elif isinstance(ref, int):
        # use the specified array as reference array to determine the namespace;
        # `islice` raises `ValueError` for a negative start and yields nothing past the end
        try:
            ref_arr = next(islice(all_arrays, ref, ref + 1))
        except ValueError:
            raise IndexError(
                "Parameter `ref` of `array_spec()` must be a "
                f"non-negative index for the array inputs, got {ref!r}."
            ) from None
        except StopIteration:
            if filter_arraylike:
                raise IndexError(
                    "Parameter `ref` of `array_spec()` is out of range "
                    f"for the array-like inputs, got {ref!r}."
                ) from None
            raise IndexError(
                "Parameter `ref` of `array_spec()` must be in the range "
                f"[0, {len(arrays) + len(kw_arrays)}) for the array inputs, got {ref!r}."
            ) from None
    else:
        raise TypeError(
            "Parameter `ref` of `array_spec()` must be `str`, "
            f"`int` or `NoneType`, got {type(ref)}."
        )

    # Validate unconditionally: an `int` reference is already filtered when
    # `filter_arraylike` is set, but a `str` reference is read straight from
    # `kw_arrays` and may be any object. Skipping the check in that case would let
    # a non-array reach `api.device` below instead of raising NotArrayAPIObjectError.
    if not api.is_array_api_obj(ref_arr):
        raise NotArrayAPIObjectError(
            f"Reference array must be an array API compatible array object, got {ref_arr!r}."
        )

    dtype = getattr(ref_arr, "dtype", None)
    device = api.device(ref_arr)
    try:
        return ArraySpec.create(
            api.array_namespace(
                ref_arr,
                api_version=api_version,
                use_compat=use_compat
            ), dtype, device
        )
    except Exception as e:
        raise GetArrayNamespaceError(
            "Failed to determine the `compatibility namespace` from the reference array."
        ) from e
171
+
172
+
173
# Context variable holding the `ArraySpec` installed by the innermost active
# `array_context` (set on enter, reset on exit). It is created without a default,
# so `.get()` raises `LookupError` while no such context is active.
_arr_spec_var: ContextVar[ArraySpec] = ContextVar("arr_spec")
175
+
176
+
177
def context_spec() -> ArraySpec:
    """
    Get the `compatibility namespace`, `dtype` and `device` of the current array context.

    The current array context is the one set by the innermost active :class:`array_context` block in the current context.
    This includes the body of every :func:`unify_args`-decorated function call, which runs inside such a block.
    If no :class:`array_context` is active, return the default `compatibility namespace`, `dtype` and `device` from :func:`default.default_spec`.

    Returns
    -------
    ArraySpec
        An :class:`ArraySpec` named tuple containing the `cxp` (`compatibility namespace`), `dtype` and `device`.

    Raises
    ------
    Refer to :func:`default.default_spec` for possible exceptions (only reached when no :class:`array_context` is active).
    """
    try:
        return _arr_spec_var.get()
    except LookupError:
        # the context variable has no default, so it is unset outside any `array_context`
        return default_spec()
195
+
196
+
197
# Static typing overloads for `as_context`; the runtime implementation follows below.
# Overloads are matched in declaration order, so the more specific signatures
# (NumPy input / `unify_dtype=False`) are listed before the general ones.
# NOTE(review): these overloads declare the flags keyword-only (`*`), whereas the
# implementation also accepts them positionally — confirm which is intended.

# NumPy array input keeping its dtype: annotated as a CPU-resident CompatArray.
@overload
def as_context(obj: NDArray[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, Literal["cpu"]]: ...
# Any array-like keeping its dtype: the dtype parameter is preserved, the device is unknown.
@overload
def as_context(obj: ArrayLike[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, Any]: ...
# Array-like cast to the context dtype: the dtype is no longer statically known.
@overload
def as_context(obj: ArrayLike[Any], /, *, unify_dtype: Literal[True] = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[Any, Any]: ...
# Arbitrary object with `arraylike_only=False`: always annotated as converted.
@overload
def as_context(obj: object, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> CompatArray[Any, Any]: ...
# Non-array-like object with `arraylike_only=True`: annotated as returned with its own type.
@overload
def as_context(obj: T, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[True]) -> T: ...
207
+
208
+
209
def as_context(
    obj: object,
    /,
    unify_dtype: bool = True,
    unify_device: bool = True,
    copy: bool = False,
    arraylike_only: bool = False
) -> Any:
    """
    Convert ``obj`` into a :class:`CompatArray` of the current context `compatibility namespace`.

    The current context is read with :func:`context_spec`. Its `dtype` and `device` are forwarded to :func:`convert.as_array` only when requested; otherwise `None` is forwarded in their place.

    Parameters
    ----------
    obj : object
        The object to be converted to a :class:`CompatArray` array.

    unify_dtype : bool, default to `True`
        Whether to request the `dtype` of the current context for the converted array.

    unify_device : bool, default to `True`
        Whether to request the `device` of the current context for the converted array.

    copy : bool, default to `False`
        Whether to return a copy of the array if it is already in the current context namespace.

    arraylike_only : bool, default to `False`
        Whether to only convert array-like objects, returning the object itself if it is not array-like.

    Returns
    -------
    CompatArray[Any, Any]
        The converted array in the current context `compatibility namespace`, wrapped by :func:`compat.wrap_arraylike`.
    object
        If :param:`arraylike_only` is `True` and the object is not array-like.

    Raises
    ------
    Refer to :func:`convert.as_array`, :func:`context_spec` for possible exceptions.
    """
    ctx = context_spec()
    target_xp = ctx.cxp
    target_dtype = ctx.dtype if unify_dtype else None
    target_device = ctx.device if unify_device else None

    converted = as_array(
        obj,
        target_xp,
        dtype=target_dtype,
        device=target_device,
        copy=copy,
        arraylike_only=arraylike_only,
    )
    # wrap with the same namespace that was used for the conversion
    return wrap_arraylike(converted, xp=target_xp)
256
+
257
+
258
class array_context:
    """
    **Context Manager** to set the context `compatibility namespace`, `dtype` and `device` for the enclosed block of code.

    The target specification is resolved once, in :meth:`__init__`, against the context that is current at construction time.
    Entering makes it the current context seen by :func:`context_spec` and :func:`as_context`. Exiting restores the context that was current before entering.
    The same instance may be entered more than once, including in nested ``with`` statements.

    Examples
    --------
    ::

        with array_context(xp="numpy", device="cpu") as spec:
            x = as_context([1, 2, 3])
    """
    @classmethod
    def from_array_spec(cls, arr_spec: ArraySpec, /):
        """
        Create an :class:`array_context` from an :class:`ArraySpec` named tuple.

        Parameters
        ----------
        arr_spec : ArraySpec
            An :class:`ArraySpec` named tuple containing the `cxp` (`compatibility namespace`), `dtype` and `device`.
            Fields that are `None` fall back to the current context, as in :meth:`__init__`.

        Returns
        -------
        array_context
            A new context manager that has not been entered yet.
        """
        return cls(xp=arr_spec.cxp, dtype=arr_spec.dtype, device=arr_spec.device)

    def __init__(
        self,
        xp: Optional[Union[object, ArrayLibraryName]] = None,
        dtype: Optional[DType] = None,
        device: Optional[Device] = None
    ):
        """
        Initialize the context manager with the specified `array namespace`, `dtype` and `device`.

        Parameters
        ----------
        xp : Optional[Union[object, ArrayLibraryName]], default to `None`
            The target `array namespace` or array library name for the context, converted with :func:`convert.to_xp`.
            - `None`: Use the `compatibility namespace` from the current context.

        dtype : Optional[DType], default to `None`
            The target `dtype` for the context.
            - `None`: Use the `dtype` from the current context.

        device : Optional[Device], default to `None`
            The target `device` for the context.
            - `None`: Use the `device` from the current context.

        NOTE: Each field falls back independently. For example, a given `xp` may be combined with the `dtype` of the enclosing context.

        Raises
        ------
        Refer to :func:`context_spec` and :func:`convert.to_xp` for possible exceptions.
        """
        spec = context_spec()
        self.cur_spec = ArraySpec.create(
            to_xp(xp) if xp is not None else spec.cxp,
            dtype if dtype is not None else spec.dtype,
            device if device is not None else spec.device
        )
        # One `ContextVar` token per active `__enter__`, popped by the matching `__exit__`.
        # A stack (rather than a single token) keeps re-entry of the same instance working:
        # a single attribute would be overwritten by the inner `with` and the outer exit
        # would then try to reset an already-used token.
        self._tokens = []

    def __enter__(self):
        """Make :attr:`cur_spec` the current context and return it."""
        self._tokens.append(_arr_spec_var.set(self.cur_spec))
        return self.cur_spec

    def __exit__(self, *args):
        """Restore the context that was current before the matching :meth:`__enter__`; exceptions are not suppressed."""
        _arr_spec_var.reset(self._tokens.pop())
312
+
313
+
314
def unify_args(
    ref: Optional[Union[str, int]] = 0,
    filter_arraylike: bool = True,
    api_version: Optional[str] = None,
    use_compat: Optional[bool] = None,
    unify_dtype: bool = False,
    unify_device: bool = True,
    arraylike_only: bool = True,
    strict: bool = True
):
    """
    **Decorator factory** to unify arguments of a function to the same `compatibility namespace`, `dtype` and `device` determined by the provided array arguments and the reference array.

    It must be called to obtain the decorator, e.g. ``@unify_args()`` or ``@unify_args(0, unify_dtype=True)``.
    The decorated function is called inside an :class:`array_context` built from the determined specification.
    As a result, :func:`context_spec` and :func:`as_context` used inside it see that specification.

    Parameters
    ----------
    ref : Optional[Union[str, int]], default to `0`
        Reference array to determine the `compatibility namespace`, `dtype` and `device`.
        Can be passed positionally or as a keyword (``ref=...``).
        See also :param:`ref` in :func:`array_spec`.

    filter_arraylike : bool, default to `True`
        Whether to filter the provided inputs to all array-likes via :func:`array_api_compat.is_array_api_obj` when determining the `compatibility namespace`.

    api_version : Optional[str], default to `None`
        See also :param:`api_version` in :func:`array_spec`.

    use_compat : Optional[bool], default to `None`
        See also :param:`use_compat` in :func:`array_spec`.

    unify_dtype : bool, default to `False`
        Whether to unify the `dtype` of arguments to that of the reference array.

    unify_device : bool, default to `True`
        Whether to unify the `device` of arguments to that of the reference array.

    arraylike_only : bool, default to `True`
        Whether to only convert array-like objects to arrays in the determined namespace, and return the object itself if it is not array-like.

    strict : bool, default to `True`
        Whether to raise exceptions when failing to determine the `compatibility namespace`.
        - `True`: Raise exceptions when failing to determine the `compatibility namespace` from the provided inputs or the reference array;
        - `False`: Fall back to the default `compatibility namespace` if an error occurs. If all default array libraries are missing, just run the function without conversion.

    Returns
    -------
    Callable
        A decorator. It wraps a function so that the function's positional and keyword arguments are converted with :func:`as_context` before the call.

    Raises
    ------
    Refer to :func:`array_spec` (when `strict` is `True`), :func:`default.default_spec`, :func:`as_context` for possible exceptions.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # determine the namespace, dtype and device for the array inputs
            try:
                spec = array_spec(
                    *args, kw_arrays=kwargs, ref=ref,
                    filter_arraylike=filter_arraylike,
                    api_version=api_version,
                    use_compat=use_compat
                )
            except Exception:
                if strict:
                    raise
                try:
                    # fall back to the default namespace
                    spec = default_spec()
                except MissingDependencyError:
                    # no default array library available: run the function without conversion
                    return func(*args, **kwargs)

            with array_context.from_array_spec(spec):
                out_args = tuple(
                    as_context(a, unify_dtype=unify_dtype, unify_device=unify_device, arraylike_only=arraylike_only)
                    for a in args
                )
                out_kwargs = {
                    k: as_context(v, unify_dtype=unify_dtype, unify_device=unify_device, arraylike_only=arraylike_only)
                    for k, v in kwargs.items()
                }
                # call inside the context so nested `context_spec()` / `as_context()` see `spec`
                return func(*out_args, **out_kwargs)
        return wrapper
    return decorator
cobra_array/_utils.py ADDED
@@ -0,0 +1,93 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Python version: 3.9
3
+ # @TianZhen
4
+
5
+ from __future__ import annotations
6
+ import array_api_compat as api
7
+ import warnings
8
+ from types import ModuleType
9
+ from typing import Any
10
+
11
+ from .exceptions import UnsupportedNamespaceError
12
+
13
+ # Try to import `cobra_log.warning`.
14
+ try:
15
+ from cobra_log import warning
16
+ _WARN_AVAILABLE = True
17
+ except ImportError:
18
+ _WARN_AVAILABLE = False
19
+
20
+
21
def warn(msg: str, /, category: Any, stack: int = 2):
    """
    Emit ``msg`` as a warning.

    If ``cobra_log`` could be imported, the message is routed to ``cobra_log.warning`` with ``stack`` (``category`` is not forwarded there).
    Otherwise it falls back to :func:`warnings.warn` with ``category`` and ``stacklevel=stack + 1``.
    """
    if not _WARN_AVAILABLE:
        # one extra level so the reported location skips this wrapper's own frame
        return warnings.warn(msg, category=category, stacklevel=stack + 1)
    return warning(msg, stack=stack)
26
+
27
+
28
def array_namespace_alias(xp: object) -> str:
    """
    Get the alias of the `array namespace`.

    Parameters
    ----------
    xp : object
        The `array namespace`. Only module objects are considered.

    Returns
    -------
    str
        The alias of the `array namespace`. The :mod:`array_api_compat` namespace checks run in this order, and the first match wins:
        `"NumPy"`, `"Cupy"`, `"PyTorch"`, `"NDONNX"`, `"Dask"`, `"JAX"`, `"sparse"` and `"array-api-strict"`.

    Raises
    ------
    UnsupportedNamespaceError
        If the input object is not a module, or is a module that none of the checks above recognize.
    """
    if isinstance(xp, ModuleType):
        # The predicates are looked up and called lazily, one at a time, so later
        # checks are never evaluated once an earlier one matches.
        if api.is_numpy_namespace(xp):
            return "NumPy"

        if api.is_cupy_namespace(xp):
            return "Cupy"

        if api.is_torch_namespace(xp):
            return "PyTorch"

        if api.is_ndonnx_namespace(xp):
            return "NDONNX"

        if api.is_dask_namespace(xp):
            return "Dask"

        if api.is_jax_namespace(xp):
            return "JAX"

        if api.is_pydata_sparse_namespace(xp):
            return "sparse"

        if api.is_array_api_strict_namespace(xp):
            return "array-api-strict"

    raise UnsupportedNamespaceError(
        f"Got unsupported array namespace of type {type(xp)}."
    )
76
+
77
+
78
def is_array_namespace(obj: object) -> bool:
    """
    Check whether ``obj`` is an `array namespace` supported by :func:`array_namespace_alias`.

    Returns
    -------
    bool
        `True` if :func:`array_namespace_alias` returns an alias for ``obj``, `False` if it raises :class:`UnsupportedNamespaceError`.
    """
    try:
        array_namespace_alias(obj)
    except UnsupportedNamespaceError:
        return False
    else:
        return True
87
+
88
+
89
def is_compat_namespace(xp: object) -> bool:
    """
    Returns `True` if input is a `compatibility namespace` wrapped by :class:`CompatNamespace`.

    Detection relies on the object's ``__name__`` containing the ``"(compat)"`` marker.
    Objects without a ``__name__``, or whose ``__name__`` is not a string, are reported as `False` instead of raising.
    """
    name = getattr(xp, "__name__", "")
    # `__name__` may be overridden with a non-string value; `in` would then raise `TypeError`
    return isinstance(name, str) and "(compat)" in name