cobra-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cobra_array/__init__.py +105 -0
- cobra_array/_core.py +392 -0
- cobra_array/_utils.py +93 -0
- cobra_array/array_api.py +117 -0
- cobra_array/compat/__init__.py +23 -0
- cobra_array/compat/_array.py +499 -0
- cobra_array/compat/_array.pyi +1816 -0
- cobra_array/compat/_base.py +53 -0
- cobra_array/compat/_namespace.py +305 -0
- cobra_array/compat/_namespace.pyi +833 -0
- cobra_array/convert.py +334 -0
- cobra_array/convert.pyi +73 -0
- cobra_array/default.py +142 -0
- cobra_array/exceptions.py +74 -0
- cobra_array/types.py +68 -0
- cobra_array-0.1.0.dist-info/METADATA +137 -0
- cobra_array-0.1.0.dist-info/RECORD +20 -0
- cobra_array-0.1.0.dist-info/WHEEL +5 -0
- cobra_array-0.1.0.dist-info/licenses/LICENSE +21 -0
- cobra_array-0.1.0.dist-info/top_level.txt +1 -0
cobra_array/array_api.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# src/cobra_array/array_api.py
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Python version: 3.9
|
|
4
|
+
# @TianZhen
|
|
5
|
+
"""
|
|
6
|
+
Array API utilities for :pkg:`cobra_array`.
|
|
7
|
+
|
|
8
|
+
Attributes
|
|
9
|
+
----------
|
|
10
|
+
- :attr:`torch_xp`: The `PyTorch` array namespace from :pkg:`array_api_compat` if `PyTorch` is available, otherwise `None`.
|
|
11
|
+
- :attr:`numpy_xp`: The `NumPy` array namespace from :pkg:`array_api_compat` if `NumPy` is available, otherwise `None`.
|
|
12
|
+
- :attr:`CUDA_AVAILABLE`: A boolean indicating whether CUDA is available for `PyTorch`.
|
|
13
|
+
- :attr:`TORCH_SUPPORTED_DEVICES`: The set of devices supported by `PyTorch`.
|
|
14
|
+
- :attr:`NUMPY_SUPPORTED_DEVICES`: The set of devices supported by `NumPy`.
|
|
15
|
+
Functions
|
|
16
|
+
---------
|
|
17
|
+
- :func:`resolve_device`: Get the device string from an object or a device specification string, and check if it is compatible with the specified `array namespace` if provided.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
from typing import (Any, TYPE_CHECKING, Optional, Union)
|
|
22
|
+
|
|
23
|
+
from ._utils import (array_namespace_alias, is_compat_namespace)
|
|
24
|
+
from .exceptions import (
|
|
25
|
+
CUDAUnavailableError,
|
|
26
|
+
DeviceNotSupportedError
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from .types import ArrayLibraryName
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# === PyTorch ===
# Device types accepted by `resolve_device` for `PyTorch`. Legacy names such as
# "ort" and "mlc" are retained so that previously accepted strings keep working;
# "mps", "ipu", "fpga", "mtia" and "maia" are device types parsed by `torch.device`.
TORCH_SUPPORTED_DEVICES = {
    "cpu", "cuda", "xpu", "mkldnn", "opengl", "opencl", "ideep", "hip", "ve",
    "ort", "mlc", "xla", "lazy", "vulkan", "meta", "hpu",
    "mps", "ipu", "fpga", "mtia", "maia",
}
try:
    from array_api_compat import torch as torch_xp
    import torch
    # Evaluated once at import time.
    CUDA_AVAILABLE = torch.cuda.is_available()
except ImportError:
    # `array_api_compat` or `PyTorch` is not installed.
    torch_xp = None
    CUDA_AVAILABLE = False

# === NumPy ===
# `NumPy` arrays always live in host memory.
NUMPY_SUPPORTED_DEVICES = {"cpu"}
try:
    from array_api_compat import numpy as numpy_xp
except ImportError:
    numpy_xp = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def resolve_device(
    obj: object,
    /, *,
    xp: Optional[Union[Any, ArrayLibraryName]] = None
) -> Optional[str]:
    """
    Get the device string from an object or a device specification string, and check if it is compatible with the specified `array namespace` if provided.

    Parameters
    ----------
    obj : object
        The input object or device specification string to extract the device information from.
        - `str` or a device object (e.g. `torch.device`): parsed from its string form, e.g. ``"cuda:0"``.
        - An object exposing a non-callable `device` attribute (e.g. a `PyTorch` tensor, a `NumPy>=2` array or a :class:`CompatArray`): the value of that attribute is used.
        The device string is lower-cased and whitespace around the type and index is stripped.

    xp : Optional[Union[Namespace, CompatNamespace, ArrayLibraryName]], default is `None`
        The `array namespace` to check the device compatibility against.
        - `None`: No compatibility check will be performed.

    Returns
    -------
    str
        The device string extracted from :param:`obj`, formatted as ``"type"`` or ``"type:index"``.
    None
        If :param:`obj` is `None`.

    Raises
    ------
    Refer to :func:`array_namespace_alias` for possible exceptions.

    DeviceNotSupportedError
        If the extracted device is not compatible with the specified `array namespace`.
    CUDAUnavailableError
        If a CUDA device is specified but CUDA is not available for `PyTorch`.
    """
    # source
    if obj is None:
        return None
    # Arrays report their placement through `.device`; `str(array)` would render
    # the array's data instead of its device.
    if not isinstance(obj, str):
        device = getattr(obj, "device", None)
        if device is not None and not callable(device):
            obj = device
    source = str(obj).lower()

    # split "type[:index]"
    if ":" in source:
        s_type, s_index = source.split(":", 1)
    else:
        s_type, s_index = source, ""
    s_type = s_type.strip()
    s_index = s_index.strip()
    s_fmt = f"{s_type}:{s_index}" if s_index else s_type

    if xp is not None:
        # resolve the library name of `xp`
        if isinstance(xp, str):
            xp_name = xp
        elif is_compat_namespace(xp):
            xp_name = xp.xp_name
        else:
            xp_name = array_namespace_alias(xp)

        if xp_name in ("numpy", "NumPy"):
            # NumPy
            if s_type not in NUMPY_SUPPORTED_DEVICES:
                raise DeviceNotSupportedError(f"`NumPy` only support CPU device, got {s_fmt}.")
        elif xp_name in ("torch", "PyTorch"):
            # PyTorch
            if s_type == "cuda":
                # check if CUDA is available
                if not CUDA_AVAILABLE:
                    raise CUDAUnavailableError("`PyTorch` specified a CUDA device but CUDA is not available.")
            elif s_type not in TORCH_SUPPORTED_DEVICES:
                raise DeviceNotSupportedError(f"`PyTorch` supports devices {TORCH_SUPPORTED_DEVICES}, got {s_fmt}.")
    return s_fmt
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# src/cobra_array/compat/__init__.py
|
|
2
|
+
"""
|
|
3
|
+
Compatibility utilities for :pkg:`cobra_array`.
|
|
4
|
+
|
|
5
|
+
Functions
|
|
6
|
+
---------
|
|
7
|
+
- :func:`wrap_arraylike`: Wraps an array-like object in a :class:`CompatArray` array if it is an array API object.
|
|
8
|
+
- :func:`unwrap`: Unwraps a :class:`CompatArray` array to get the backend-specific array instance, or returns the object itself if it is not a :class:`CompatArray` array.
|
|
9
|
+
Classes
|
|
10
|
+
-------
|
|
11
|
+
- :class:`CompatArray`: A backend-agnostic array abstraction compliant with the `Python Array API standard`.
|
|
12
|
+
- :class:`CompatNamespace`: A wrapper around an `array namespace` providing a unified, backend-agnostic functional interface.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from ._array import (CompatArray, wrap_arraylike, unwrap)
|
|
16
|
+
from ._namespace import CompatNamespace
|
|
17
|
+
|
|
18
|
+
# Names re-exported as the public API of the `compat` subpackage.
__all__ = ["CompatArray", "wrap_arraylike", "unwrap", "CompatNamespace"]
|
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Python version: 3.9
|
|
3
|
+
# @TianZhen
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
import array_api_compat as api
|
|
7
|
+
from collections import namedtuple
|
|
8
|
+
|
|
9
|
+
from ._base import Compat
|
|
10
|
+
from ..convert import (to_numpy, to_tensor, to_list, to_xp, as_array)
|
|
11
|
+
from ..exceptions import (NotArrayAPIObjectError, CompatArrayAttributeError)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Result of `CompatArray.unique_all` (all four fields populated).
UniqueResult = namedtuple("UniqueResult", ["values", "indices", "inverse_indices", "counts"])
# Results of `unique_counts` / `unique_inverse`, laid out like the Array API standard's
# return types so that `values, counts = a.unique_counts()` unpacks correctly.
UniqueCountsResult = namedtuple("UniqueCountsResult", ["values", "counts"])
UniqueInverseResult = namedtuple("UniqueInverseResult", ["values", "inverse_indices"])


class CompatArray(Compat):
    """
    A backend-agnostic array abstraction compliant with the [`Python Array API standard`](https://data-apis.org/array-api/2024.12/API_specification/index.html).

    :class:`CompatArray` provides a unified interface for numerical computation across multiple array backends (e.g., `NumPy`, `PyTorch`), strictly adhering to the `Python Array API standard`.
    Detailed documentation is provided for all supported operations to ensure consistent and predictable behavior.

    Notes
    -----
    - All operations follow the semantics defined by the `Python Array API standard`.
    - Methods correspond directly to standard functions, but are exposed in an object-oriented form.
    - All methods guarantee that any array-like objects in the returned value are automatically wrapped as :class:`CompatArray`. This applies recursively to arrays contained in Python containers (e.g., `tuple`, `list`, `dict`). Non-array objects remain unchanged.
    """
    # Backend-specific array instance wrapped by this object (assigned in `__new__`).
    _arr = None
    # Compatibility namespace used to dispatch functions not defined on this class (assigned in `__new__`).
    _cxp = None

    @classmethod
    def from_other(cls, obj, /, *, xp, copy=False):
        """
        Create a :class:`CompatArray` array from another object using the specified `array namespace`.

        Parameters
        ----------
        obj : object
            The object to be converted to a :class:`CompatArray` array.

        xp : Union[Any, ArrayLibraryName]
            The `array namespace` or `compatibility namespace` to use for conversion.

        copy : bool, default to `False`
            Whether to create a copy of the data during conversion via :func:`convert.as_array`.

        Returns
        -------
        CompatArray
            The converted array, bound to the namespace resolved from :param:`xp`.

        Raises
        ------
        Refer to :func:`convert.as_array` for possible exceptions.
        """
        _xp = to_xp(xp)
        return cls(
            as_array(unwrap(obj), _xp, copy=copy),
            xp=_xp
        )

    def __new__(cls, arr, /, *, copy=False, **kwargs):
        """
        Wrap :param:`arr` as a :class:`CompatArray`.

        Parameters
        ----------
        arr : object
            Either a :class:`CompatArray` (returned as-is, or copied when :param:`copy` is `True`; `xp` is ignored in that case) or an array API compatible array object.

        copy : bool, default to `False`
            Whether to copy the data via :func:`convert.as_array`.

        **kwargs
            xp : optional
                The `array namespace` or `compatibility namespace` to bind to. When omitted or `None`, it is inferred from :param:`arr` with `array_api_compat.array_namespace`.

        Raises
        ------
        NotArrayAPIObjectError
            If :param:`arr` is neither a :class:`CompatArray` nor an array API compatible array object.
        """
        if isinstance(arr, CompatArray):
            # for `CompatArray` input
            return arr.copy() if copy else arr

        # for non-`CompatArray` input
        if not api.is_array_api_obj(arr):
            raise NotArrayAPIObjectError(
                f"Parameter `arr` of `CompatArray` must be an array API compatible array object, got {type(arr)}."
            )
        # Only infer the namespace when none was given; an explicit `xp=None` means "infer" too.
        xp = kwargs.get("xp")
        if xp is None:
            xp = api.array_namespace(arr)
        _cxp = to_cxp(xp)
        _xp = _cxp.xp
        obj = super().__new__(cls, _xp)
        obj._arr = as_array(arr, _xp, copy=True) if copy else arr
        obj._cxp = _cxp

        return obj

    # === Conversion functions ===
    def to_numpy(self, *, copy=False):
        """
        Convert `self` to a `NumPy array`.
        See also :func:`convert.to_numpy`.
        """
        return to_numpy(self._arr, copy=copy)

    def to_tensor(self, *, device=None, copy=False):
        """
        Convert `self` to a `PyTorch tensor`.
        See also :func:`convert.to_tensor`.
        """
        return to_tensor(self._arr, device=device, copy=copy)

    def to_list(self, *, copy=False):
        """
        Convert `self` to a built-in `list`.
        See also :func:`convert.to_list`.
        """
        return to_list(self._arr, copy=copy)

    # === Device functions ===
    def to_device(self, device, /, *, stream=None):
        """
        Copy `self` from the device on which it currently resides to the specified `device`.

        Parameters
        ----------
        device : Device
            A device object or name.

        stream : Optional[Union[int, Any]], default to `None`
            Stream object to use during copy.
            In addition to the types supported in `array.__dlpack__`, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable.

        Returns
        -------
        CompatArray
            An array with the same data and data type as `self` and located on the specified device.

        Notes
        -----
        - For `NumPy`, this function effectively does nothing since the only supported device is the `CPU`;
        - For `CuPy`, this method supports CuPy CUDA Device <cupy.cuda.Device> and Stream <cupy.cuda.Stream> objects.
        - For `PyTorch`, this is the same as `self.to(device)` <torch.Tensor.to> (the stream argument is not supported in PyTorch).
        """
        moved = api.to_device(self._arr, device, stream=stream)
        # Wrap the backend result so the documented `CompatArray` return type holds.
        return CompatArray(moved, xp=self._xp)

    # === Manipulation functions ===
    def unstack(self, *, axis=0):
        """
        Splits `self` into a sequence of arrays along the given axis.

        Parameters
        ----------
        axis : int, default to `0`
            Axis along which the array will be split.
            A valid :param:`axis` must be on the interval `[-N, N)`, where `N` is the rank (number of dimensions) of `self`.
            If provided an :param:`axis` outside of the required interval, the function must raise an exception.

        Returns
        -------
        Tuple[CompatArray[TT, DT], ...]
            Tuple of slices along the given dimension.
            All the arrays have the same shape.
        """
        result = self._get_xp_attr("unstack")(self._arr, axis=axis)
        return tuple(CompatArray(arr, xp=self._xp) for arr in result)

    # === Searching functions ===
    def nonzero(self):
        """
        Returns the indices of `self` elements which are non-zero.
        - `self` must have a positive rank. If `self` is zero-dimensional, the function must raise an exception.

        Returns
        -------
        Tuple[CompatArray[int, DT], ...]
            A tuple of `k` arrays, one for each dimension of `self` and each of size `n` (where `n` is the total number of non-zero elements), containing the indices of the non-zero elements in that dimension.
            The indices must be returned in row-major, C-style order.
            The returned array must have the default array index data type.

        Notes
        -----
        - If `self` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero;
        - If `self` has a boolean data type, non-zero elements are those elements which are equal to `True`.
        """
        result = self._get_xp_attr("nonzero")(self._arr)
        return tuple(CompatArray(arr, xp=self._xp) for arr in result)

    # === Set functions ===
    def unique_all(self):
        """
        Returns the unique elements of `self`, the first occurring indices for each unique element in `self`, the indices from the set of unique elements that reconstruct `self`, and the corresponding counts for each unique element in `self`.
        - `self`:
            - more than one dimension: the function must flatten `self` and return the unique elements of the flattened `self`.

        Returns
        -------
        UniqueAllResult[CompatArray[TT, DT]]
            A namedtuple (`values`, `indices`, `inverse_indices`, `counts`):
            1. :attr:`values`: A one-dimensional array containing the unique elements of `self`. The array must have the same data type as `self`;
            2. :attr:`indices`: An array containing the indices (first occurrences) of a flattened `self` that result in :attr:`values`. The array must have the same shape as :attr:`values` and must have the default array index data type;
            3. :attr:`inverse_indices`: An array containing the indices of :attr:`values` that reconstruct `self`. The array must have the same shape as `self` and must have the default array index data type;
            4. :attr:`counts`: An array containing the number of times each unique element occurs in `self`. The order of the returned counts must match the order of :attr:`values`, such that a specific element in :attr:`counts` corresponds to the respective unique element in :attr:`values`. The returned array must have same shape as :attr:`values` and must have the default array index data type.
        """
        result = self._get_xp_attr("unique_all")(self._arr)
        return UniqueResult(
            values=CompatArray(result.values, xp=self._xp),
            indices=CompatArray(result.indices, xp=self._xp),
            inverse_indices=CompatArray(result.inverse_indices, xp=self._xp),
            counts=CompatArray(result.counts, xp=self._xp),
        )

    def unique_counts(self):
        """
        Returns the unique elements of `self` and the corresponding counts for each unique element in `self`.
        - `self`:
            - more than one dimension: the function must flatten `self` and return the unique elements of the flattened `self`.

        Returns
        -------
        UniqueCountsResult[CompatArray[TT, DT]]
            A namedtuple (`values`, `counts`):
            1. :attr:`values`: A one-dimensional array containing the unique elements of `self`. The array must have the same data type as `self`;
            2. :attr:`counts`: An array containing the number of times each unique element occurs in `self`. The order of the returned counts must match the order of :attr:`values`, such that a specific element in :attr:`counts` corresponds to the respective unique element in :attr:`values`. The returned array must have same shape as :attr:`values` and must have the default array index data type.
        """
        result = self._get_xp_attr("unique_counts")(self._arr)
        return UniqueCountsResult(
            values=CompatArray(result.values, xp=self._xp),
            counts=CompatArray(result.counts, xp=self._xp),
        )

    def unique_inverse(self):
        """
        Returns the unique elements of `self` and the indices from the set of unique elements that reconstruct `self`.
        - `self`:
            - more than one dimension: the function must flatten `self` and return the unique elements of the flattened `self`.

        Returns
        -------
        UniqueInverseResult[CompatArray[TT, DT]]
            A namedtuple (`values`, `inverse_indices`):
            1. :attr:`values`: A one-dimensional array containing the unique elements of `self`. The array must have the same data type as `self`;
            2. :attr:`inverse_indices`: An array containing the indices of :attr:`values` that reconstruct `self`. The array must have the same shape as `self` and must have the default array index data type.
        """
        result = self._get_xp_attr("unique_inverse")(self._arr)
        return UniqueInverseResult(
            values=CompatArray(result.values, xp=self._xp),
            inverse_indices=CompatArray(result.inverse_indices, xp=self._xp),
        )

    # === Others ===
    def copy(self):
        """
        Return a copy of `self` via :func:`convert.as_array`.
        """
        return CompatArray.from_other(self._arr, xp=self._xp, copy=True)

    def _get_attr(self, name: str):
        """Try to get the attribute `name` from the wrapped backend array."""
        try:
            return getattr(self._arr, name)
        except AttributeError:
            raise AttributeError(f"`CompatArray` `{self._xp_name}` has no attribute `{name}`.") from None

    def _get_cxp_attr(self, name: str):
        """Try to get the attribute `name` from the `compatibility namespace`."""
        try:
            return getattr(self._cxp, name)
        except AttributeError:
            raise AttributeError(f"Compatibility namespace `{self._xp_name}` of `{self.__class__.__name__}` has no attribute `{name}`.") from None

    @staticmethod
    def _unwrap_key(key):
        """Replace :class:`CompatArray` entries of an index `key` (or of a tuple key) with their backend arrays."""
        if isinstance(key, tuple):
            return tuple(unwrap(k) for k in key)
        return unwrap(key)

    @property
    def arr(self):
        """
        The backend-specific array instance managed by :class:`CompatArray`.
        """
        return self._arr

    @property
    def dtype(self):
        """
        Data type of the elements of `self`.
        """
        try:
            return self._get_xp_attr("dtype")(self._arr)
        except (AttributeError, TypeError):
            return self._get_attr("dtype")

    @property
    def device(self):
        """
        DeviceT on which `self` is stored.
        """
        return api.device(self._arr)

    @property
    def shape(self):
        """
        Dimensions of `self`.
        """
        try:
            result = self._get_xp_attr("shape")(self._arr)
        except (AttributeError, TypeError):
            result = self._get_attr("shape")
        return tuple(result)

    @property
    def ndim(self):
        """
        Number of `self` dimensions (axes).
        """
        try:
            return self._get_xp_attr("ndim")(self._arr)
        except (AttributeError, TypeError):
            return self._get_attr("ndim")

    @property
    def size(self):
        """
        Number of elements in `self`.
        """
        try:
            return self._get_xp_attr("size")(self._arr)
        except (AttributeError, TypeError):
            return self._get_attr("size")

    @property
    def T(self):
        """
        Transpose of `self`.
        - If `self` has fewer than two dimensions, an error should be raised.
        """
        try:
            result = self._get_xp_attr("T")(self._arr)
        except (AttributeError, TypeError):
            result = self._get_attr("T")
        return CompatArray(result, xp=self._xp)

    @property
    def mT(self):
        """
        Transpose of a matrix (or a stack of matrices).
        - If `self` has fewer than two dimensions, an error should be raised.
        """
        try:
            result = self._get_xp_attr("mT")(self._arr)
        except (AttributeError, TypeError):
            result = self._get_attr("mT")
        return CompatArray(result, xp=self._xp)

    def __array__(self, dtype=None, copy=None):
        """
        Allow implicit NumPy conversion (e.g. via `numpy.asarray`).

        Parameters
        ----------
        dtype : optional
            NumPy data type requested by the caller; the result is cast to it when given.

        copy : Optional[bool], default to `None`
            NumPy 2 copy keyword. `True` forces a copy; `None` and `False` let :meth:`to_numpy` avoid copying where it can.

        Notes
        -----
        - `copy=False` is not enforced as "never copy": a conversion that requires a copy is performed rather than rejected.
        """
        result = self.to_numpy(copy=bool(copy))
        if dtype is not None:
            result = result.astype(dtype, copy=False)
        return result

    def __getattr__(self, name: str):
        """
        Dispatch unknown attributes to the `compatibility namespace`.

        Callable namespace members (other than classes) are returned as bound wrappers that pass the backend array as the first argument.
        """
        attr = self._get_cxp_attr(name)

        if callable(attr) and not isinstance(attr, type):
            def wrapper(*args, **kwargs):
                return attr(self._arr, *args, **kwargs)
            return wrapper
        raise CompatArrayAttributeError(f"`CompatArray` `{self._xp_name}` does not support attribute `{name}`.")

    def __len__(self):
        shape = self.shape
        if len(shape) == 0:
            raise TypeError("`len()` of a 0-D compatible array.")
        return shape[0]

    def __repr__(self):
        return f"{self._xp_name}_Array({self._arr})"

    def __abs__(self):
        """See also :func:`CompatArray.abs`."""
        return self.abs()

    def __add__(self, other, /):
        """See also :func:`CompatArray.add`."""
        return self.add(other)

    def __and__(self, other, /):
        """See also :func:`CompatArray.bitwise_and`."""
        return self.bitwise_and(other)

    def __bool__(self):
        """Converts `self` to a Python `bool` object."""
        return bool(self._arr)

    def __complex__(self):
        """Converts `self` to a Python `complex` object."""
        return complex(self._arr)  # type: ignore

    def __eq__(self, other, /):
        """See also :func:`CompatArray.equal`."""
        return self.equal(other)

    def __float__(self):
        """Converts `self` to a Python `float` object."""
        return float(self._arr)  # type: ignore

    def __floordiv__(self, other, /):
        """See also :func:`CompatArray.floor_divide`."""
        return self.floor_divide(other)

    def __ge__(self, other, /):
        """See also :func:`CompatArray.greater_equal`."""
        return self.greater_equal(other)

    def __getitem__(self, key, /):
        """
        Returns `self[key]`.

        :class:`CompatArray` index arrays in `key` are unwrapped before indexing. An array result is wrapped as :class:`CompatArray`; any other result is returned unchanged.
        """
        result = self._arr[self._unwrap_key(key)]  # type: ignore
        if api.is_array_api_obj(result):
            return CompatArray(result, xp=self._xp)
        return result

    def __gt__(self, other, /):
        """See also :func:`CompatArray.greater`."""
        return self.greater(other)

    def __index__(self):
        """
        Converts `self` to a Python `int` object for use as an index.

        Unlike :meth:`__int__`, this defers to the backend's `__index__` protocol, so backends that reject non-integer data (e.g. floating-point arrays) raise `TypeError` instead of silently truncating.
        """
        import operator
        return operator.index(self._arr)

    def __int__(self):
        """ Converts `self` to a Python `int` object."""
        return int(self._arr)  # type: ignore

    def __invert__(self):
        """See also :func:`CompatArray.bitwise_invert`."""
        return self.bitwise_invert()

    def __le__(self, other, /):
        """See also :func:`CompatArray.less_equal`."""
        return self.less_equal(other)

    def __lshift__(self, other, /):
        """See also :func:`CompatArray.bitwise_left_shift`."""
        return self.bitwise_left_shift(other)

    def __lt__(self, other, /):
        """See also :func:`CompatArray.less`."""
        return self.less(other)

    def __matmul__(self, other, /):
        """See also :func:`CompatArray.matmul`."""
        return self.matmul(other)

    def __mod__(self, other, /):
        """See also :func:`CompatArray.remainder`."""
        return self.remainder(other)

    def __mul__(self, other, /):
        """See also :func:`CompatArray.multiply`."""
        return self.multiply(other)

    def __ne__(self, other, /):
        """See also :func:`CompatArray.not_equal`."""
        return self.not_equal(other)

    def __neg__(self):
        """See also :func:`CompatArray.negative`."""
        return self.negative()

    def __or__(self, other, /):
        """See also :func:`CompatArray.bitwise_or`."""
        return self.bitwise_or(other)

    def __pos__(self):
        """See also :func:`CompatArray.positive`."""
        return self.positive()

    def __pow__(self, other, /):
        """See also :func:`CompatArray.power`."""
        return self.power(other)

    def __rshift__(self, other, /):
        """See also :func:`CompatArray.bitwise_right_shift`."""
        return self.bitwise_right_shift(other)

    def __setitem__(self, key, value, /):
        """Sets `self[key]` to `value`, unwrapping :class:`CompatArray` keys and values first."""
        self._arr[self._unwrap_key(key)] = unwrap(value)  # type: ignore

    def __sub__(self, other, /):
        """See also :func:`CompatArray.subtract`."""
        return self.subtract(other)

    def __truediv__(self, other, /):
        """See also :func:`CompatArray.divide`."""
        return self.divide(other)

    def __xor__(self, other, /):
        """See also :func:`CompatArray.bitwise_xor`."""
        return self.bitwise_xor(other)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def unwrap(obj):
    """
    Return the backend-specific array held by a :class:`CompatArray`; any other object is passed through untouched.
    """
    if not isinstance(obj, CompatArray):
        return obj
    return obj.arr
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def wrap_arraylike(arr, xp=None):
    """
    Wrap :param:`arr` in a :class:`CompatArray` when it is an array API object; return it unchanged otherwise.

    :param:`xp` is forwarded to :class:`CompatArray` only when it is not `None`, so the namespace is otherwise inferred from :param:`arr`.
    """
    if not api.is_array_api_obj(arr):
        return arr
    options = {} if xp is None else {"xp": xp}
    return CompatArray(arr, **options)
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def to_cxp(xp):
    """Build a :class:`CompatNamespace` from an `array namespace` or `compatibility namespace`."""
    # Imported lazily rather than at module top; presumably this avoids an import
    # cycle with `._namespace` (TODO confirm).
    from ._namespace import CompatNamespace as _CompatNamespace

    namespace = _CompatNamespace(xp)
    return namespace
|