cobra-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cobra_array/__init__.py +105 -0
- cobra_array/_core.py +392 -0
- cobra_array/_utils.py +93 -0
- cobra_array/array_api.py +117 -0
- cobra_array/compat/__init__.py +23 -0
- cobra_array/compat/_array.py +499 -0
- cobra_array/compat/_array.pyi +1816 -0
- cobra_array/compat/_base.py +53 -0
- cobra_array/compat/_namespace.py +305 -0
- cobra_array/compat/_namespace.pyi +833 -0
- cobra_array/convert.py +334 -0
- cobra_array/convert.pyi +73 -0
- cobra_array/default.py +142 -0
- cobra_array/exceptions.py +74 -0
- cobra_array/types.py +68 -0
- cobra_array-0.1.0.dist-info/METADATA +137 -0
- cobra_array-0.1.0.dist-info/RECORD +20 -0
- cobra_array-0.1.0.dist-info/WHEEL +5 -0
- cobra_array-0.1.0.dist-info/licenses/LICENSE +21 -0
- cobra_array-0.1.0.dist-info/top_level.txt +1 -0
cobra_array/convert.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
# src/cobra_array/convert.py
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Python version: 3.9
|
|
4
|
+
# @TianZhen
|
|
5
|
+
"""
Conversion utilities for :pkg:`cobra_array`.

Functions
---------
- :func:`to_numpy`: Convert an object to a `NumPy array`.
- :func:`to_tensor`: Convert an object to a `PyTorch tensor`.
- :func:`to_list`: Convert an object to a built-in `list`.
- :func:`to_xp` (alias :func:`to_array_namespace`): Convert an array library name or a `compatibility namespace` to an `array namespace`, or return the `array namespace` directly if it is a supported namespace.
- :func:`as_array`: Convert an object to an array in the specified `array namespace`.
"""
|
|
16
|
+
|
|
17
|
+
from collections import abc
|
|
18
|
+
import array_api_compat as api
|
|
19
|
+
|
|
20
|
+
from .array_api import (numpy_xp, torch_xp, resolve_device)
|
|
21
|
+
from ._utils import (warn, is_array_namespace, is_compat_namespace)
|
|
22
|
+
from .exceptions import (
|
|
23
|
+
ParameterIgnoredWarning,
|
|
24
|
+
MissingDependencyError,
|
|
25
|
+
ConvertNoneTypeError,
|
|
26
|
+
UnsupportedNamespaceError,
|
|
27
|
+
UnsupportedArrayLibraryNameError,
|
|
28
|
+
ArrayConversionError,
|
|
29
|
+
NumPyConversionError,
|
|
30
|
+
TorchConversionError
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
__all__ = [
    # conversions to concrete containers
    "to_numpy",
    "to_tensor",
    "to_list",
    # namespace resolution (`to_array_namespace` is an alias of `to_xp`)
    "to_xp",
    "to_array_namespace",
    # conversion into an arbitrary array namespace
    "as_array",
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def to_numpy(obj, /, *, dtype=None, copy=True):
    """
    Convert the given object to a `NumPy array` (i.e. :class:`np.ndarray` instance).

    Parameters
    ----------
    obj : object
        The object to be converted to a `NumPy array`.
        - _torch.Tensor_ (needs :pkg:`torch`): detached from the autograd graph and moved to CPU first;
        - _set_ (any :class:`collections.abc.Set`): its elements are gathered into a `list` first (order is not guaranteed);
        - _others_: passed to `NumPy` as-is.

    dtype : Optional[DTypeT], default to `None`
        The data type of the resulting `NumPy array`.
        - `None`: Let `NumPy` derive the data type from the object.

    copy : bool, default to `True`
        Control whether to create a copy of the object when converting to a `NumPy array`.
        - `True`: Always create a copy (`array(..., copy=True)`);
        - `False`: A copy will only be made if necessary (`asarray(...)`).

    Returns
    -------
    NDArray[Any]
        The converted `NumPy array` representation of the object.

    Raises
    ------
    MissingDependencyError
        If `NumPy` is not installed when calling this function.
    NumPyConversionError
        If `NumPy` raises while building the array; the original exception is chained as the cause.
    """
    if numpy_xp is None:
        raise MissingDependencyError("Dependency `NumPy` is required for `to_numpy()`.")

    # Normalise the inputs that need preprocessing before NumPy sees them.
    is_torch_tensor = torch_xp is not None and isinstance(obj, torch_xp.Tensor)
    if is_torch_tensor:
        # Drop autograd history and bring the data to host memory.
        source = obj.detach().cpu()
    elif isinstance(obj, abc.Set):
        # Materialise the set's elements so they become the array's entries.
        source = list(obj)
    else:
        source = obj

    try:
        if not copy:
            return numpy_xp.asarray(source, dtype=dtype)
        return numpy_xp.array(source, dtype=dtype, copy=True)
    except Exception as err:
        raise NumPyConversionError(
            "An error occurred during conversion to NumPy array."
        ) from err
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def to_tensor(obj, /, *, dtype=None, device=None, copy=True):
    """
    Convert the given object to a `PyTorch tensor` (i.e. :class:`torch.Tensor` instance).

    Parameters
    ----------
    obj : Any
        The object to be converted to a `PyTorch tensor`.
        - _torch.Tensor_: Cast/moved with :meth:`torch.Tensor.to` using `dtype`, `device` and `copy`;
        - _set_: Converted to a `PyTorch tensor` containing the elements of the set (order is not guaranteed);
        - `None`: Raises `ConvertNoneTypeError`;
        - _others_: Converted to a `PyTorch tensor` directly.

    dtype : Optional[DTypeT], default to `None`
        The data type of the resulting `PyTorch tensor`.
        - `None`: Use the default data type of the object.

    device : Optional[DeviceT], default to `None`
        The device on which the resulting `PyTorch tensor` will be allocated.
        - `None`: Use the default device (usually `"cpu"`).

    copy : bool, default to `True`
        Control whether to create a copy of the object when converting to a `PyTorch tensor`.
        - `True`: Create a copy of the object;
        - `False`: A copy will only be made if necessary.

    Returns
    -------
    torch.Tensor
        The converted `PyTorch tensor` representation of the object.

    Raises
    ------
    ConvertNoneTypeError
        If `None` is passed as the object to be converted.
    MissingDependencyError
        If `PyTorch` is not installed when calling this function.
    CUDAUnavailableError
        If a non-CPU device is specified but CUDA is not available
        (presumably raised by `resolve_device()`).
    TorchConversionError
        If an error occurs during conversion to a `PyTorch tensor`, including
        failures of :meth:`torch.Tensor.to` when `obj` is already a tensor.
        The original exception is chained as the cause.
    """
    if obj is None:
        raise ConvertNoneTypeError("Can not convert `NoneType` to a PyTorch tensor.")
    if torch_xp is None:
        raise MissingDependencyError("Dependency `PyTorch` is required for `to_tensor()`.")
    # Resolved before the `try` so device-resolution errors propagate unchanged
    # instead of being wrapped into `TorchConversionError`.
    device = resolve_device(device, xp="torch")

    try:
        if isinstance(obj, torch_xp.Tensor):
            # as torch.Tensor: now inside the `try`, so a bad `dtype`/`device`
            # is reported as `TorchConversionError` like every other input path.
            return obj.to(dtype=dtype, device=device, copy=copy)

        if isinstance(obj, abc.Set):
            # as set: torch cannot consume sets directly.
            obj = list(obj)

        if copy:
            return torch_xp.tensor(obj, dtype=dtype, device=device)
        return torch_xp.as_tensor(obj, dtype=dtype, device=device)  # type: ignore
    except Exception as e:
        raise TorchConversionError(
            "An error occurred during conversion to PyTorch tensor."
        ) from e
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def to_list(obj, /, *, copy=True):
    """
    Convert the given object to a built-in `list`.

    Parameters
    ----------
    obj : Any
        The object to be converted to a `list`.
        - _torch.Tensor_ (need :pkg:`torch`): Converted to a `list` after detaching and moving to CPU;
        - _np.ndarray_ (need :pkg:`numpy`): Converted to a `list`;
        - 0-d tensors/arrays: Converted to a single-element `list` holding the Python scalar;
        - _Iterable_: Converted to a `list` containing the elements of the iterable (order is preserved);
        - _scalar_ (including _str_ and _bytes_): Converted to a `list` containing the scalar value as its single element;
        - `None`: Raises `ConvertNoneTypeError`.

    copy : bool, default to `True`
        Control whether to create a copy when :param:`obj` is already a `list`. Other types of :param:`obj` will always be converted to a new `list`.

    Returns
    -------
    List[Any]
        The converted `list` representation of the object.

    Raises
    ------
    ConvertNoneTypeError
        If `None` is passed as the object to be converted.
    """
    if obj is None:
        # as NoneType
        raise ConvertNoneTypeError("Can not convert `NoneType` to a built-in list.")

    if torch_xp is not None and isinstance(obj, torch_xp.Tensor):
        # as torch.Tensor: detach from autograd and move to CPU before extracting values
        values = obj.detach().cpu().tolist()
        # `tolist()` yields a bare Python scalar for a 0-d tensor; wrap it so the
        # result is always a list, consistent with the scalar rule below.
        return values if isinstance(values, list) else [values]

    if numpy_xp is not None and isinstance(obj, numpy_xp.ndarray):
        # as numpy.ndarray (same 0-d caveat as for tensors)
        values = obj.tolist()
        return values if isinstance(values, list) else [values]

    if isinstance(obj, list):
        # as list
        return obj.copy() if copy else obj

    if isinstance(obj, abc.Iterable) and not isinstance(obj, (str, bytes)):
        # as iterable (not including str and bytes)
        return list(obj)
    # as scalar or others
    return [obj]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def to_xp(obj, /):
    """
    Resolve `obj` to an `array namespace` module.

    Accepts an array library name, an `array namespace` module or a
    `compatibility namespace`, and returns the matching `array namespace`.

    Parameters
    ----------
    obj : Union[Namespace, CompatNamespace, ArrayLibraryName]
        The value to resolve.
        - _ArrayLibraryName_ (`"numpy"` or `"torch"`): the corresponding `array namespace` module is returned;
        - _Namespace_: returned unchanged;
        - _CompatNamespace_: its underlying `array namespace` (the `.xp` attribute) is returned.

    Returns
    -------
    Namespace
        The `array namespace` module corresponding to the input.

    Raises
    ------
    MissingDependencyError
        If the library named by `obj` is not installed.
    UnsupportedArrayLibraryNameError
        If `obj` is a string other than `"numpy"` or `"torch"`.
    UnsupportedNamespaceError
        If `obj` is neither a string, an `array namespace` nor a `compatibility namespace`.
    """
    if isinstance(obj, str):
        # library name -> (namespace or None, message used when it is missing)
        known_libraries = {
            "numpy": (numpy_xp, "Dependency `NumPy` is required for using array namespace."),
            "torch": (torch_xp, "Dependency `PyTorch` is required for using array namespace."),
        }
        if obj not in known_libraries:
            raise UnsupportedArrayLibraryNameError(
                "Parameter `obj` of `to_xp()` must be a supported array namespace "
                f"name ('numpy', 'torch'), got {obj!r}."
            )
        namespace, missing_message = known_libraries[obj]
        if namespace is None:
            raise MissingDependencyError(missing_message)
        return namespace

    if is_array_namespace(obj):
        return obj  # type: ignore
    if is_compat_namespace(obj):
        return obj.xp  # type: ignore
    raise UnsupportedNamespaceError(
        f"Parameter `obj` of `to_xp()` is not a supported array namespace, got {obj!r}."
    )


# Descriptive alias of `to_xp()`.
to_array_namespace = to_xp
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def as_array(obj, xp, /, *, dtype=None, device=None, copy=False, arraylike_only=False):
    """
    Convert the given object to an array in the specified `array namespace` (e.g., `NumPy` or `PyTorch`).

    Parameters
    ----------
    obj : object
        The object to be converted to an array.

    xp : Union[object, ArrayLibraryName]
        The target `array namespace` or array library name for the conversion.
        - _ArrayLibraryName_ (`"numpy"` or `"torch"`): Converted to a `NumPy array` or `PyTorch tensor` respectively using the corresponding conversion functions;
        - _Namespace_ or _CompatNamespace_: Converted to an array using the `asarray()` function provided by the namespace module, which must be compatible with the array API standard.

    dtype : Optional[DTypeT], default to `None`
        The data type of the resulting array.
        - `None`: Use the default data type of the object.

    device : Optional[DeviceT], default to `None`
        The device on which the resulting array will be allocated (only if `array namespace` supports it).
        For `NumPy`, any device other than CPU (compared via `str(device)`) is ignored with a `ParameterIgnoredWarning`.

    copy : bool, default to `False`
        Control whether to create a copy of the object when converting to an array.
        - `True`: Always create a copy;
        - `False`: A copy will only be made if necessary. For other namespaces this is forwarded to
          `asarray()` as `copy=None`, because the array API standard reserves `copy=False` for
          "never copy" (raising when a copy would be required, e.g. for a Python `list`).

    arraylike_only : bool, default to `False`
        Whether to only convert array-like objects to arrays in the specified `array namespace`, and return the object itself if it is not array-like.

    Returns
    -------
    ArrayLike[Any]
        The converted array representation of the object in the specified `array namespace`.
    object
        If :param:`arraylike_only` is `True` and the object is not array-like.

    Raises
    ------
    Refer to :func:`convert.to_xp`, :func:`convert.to_numpy` and :func:`convert.to_tensor` for possible exceptions.

    AttributeError
        If the resolved `array namespace` does not provide an `asarray()` function.
    ArrayConversionError
        If an error occurs during array conversion in the specified `array namespace`.
    """
    if arraylike_only and not api.is_array_api_obj(obj):
        return obj

    arr_xp = to_xp(xp)
    if api.is_numpy_namespace(arr_xp):
        # as NumPy array. Compare via `str()` so device objects such as
        # `torch.device("cpu")` are recognised as CPU and do not warn.
        if device is not None and str(device) != "cpu":
            warn(
                "NumPy array does not support setting a non-CPU device. "
                f"Parameter `device` is ignored, got {device!r}.",
                category=ParameterIgnoredWarning
            )
        return to_numpy(obj, dtype=dtype, copy=copy)
    if api.is_torch_namespace(arr_xp):
        # as PyTorch tensor
        return to_tensor(obj, dtype=dtype, device=device, copy=copy)

    # as other namespace object. Look `asarray` up before calling it so that an
    # AttributeError raised *inside* `asarray` is not misreported as an
    # unsupported namespace.
    asarray = getattr(arr_xp, "asarray", None)
    if asarray is None:
        raise AttributeError(
            f"Parameter `xp` of `as_array()` is not a supported array namespace: {xp!r}"
        )
    try:
        # Array API `copy`: True = always, None = if needed, False = never.
        return asarray(obj, dtype=dtype, device=device, copy=True if copy else None)  # type: ignore
    except Exception as e:
        raise ArrayConversionError(
            f"An error occurred during array conversion in the specified array namespace {xp!r}"
        ) from e
|
cobra_array/convert.pyi
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# src/cobra_array/convert.pyi
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Python version: 3.9
|
|
4
|
+
# @TianZhen
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
from array_api_compat.common._typing import Namespace
|
|
10
|
+
from typing import (Any, Literal, Optional, List, Iterable, Union, overload)
|
|
11
|
+
|
|
12
|
+
from .types import (
|
|
13
|
+
T, StringT,
|
|
14
|
+
DTypeT, dtypeT, DType, Device,
|
|
15
|
+
ArrayLike, ArrayLibraryName
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Stub overloads for `to_numpy()`. An overloaded function in a stub file must
# consist of `@overload` signatures only (no implementation signature). The
# last two overloads already cover every call: `dtype` omitted/`None`, or any
# explicit `dtype`.
@overload
def to_numpy(obj: NDArray[dtypeT], /, *, dtype: None = ..., copy: bool = ...) -> NDArray[dtypeT]: ...
@overload
def to_numpy(obj: ArrayLike[dtypeT], /, *, dtype: None = ..., copy: bool = ...) -> NDArray[dtypeT]: ...
@overload
def to_numpy(obj: object, /, *, dtype: None = ..., copy: bool = ...) -> NDArray[Any]: ...
@overload
def to_numpy(obj: object, /, *, dtype: DTypeT, copy: bool = ...) -> NDArray[DTypeT]: ...
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Mirrors the runtime signature `to_tensor(obj, /, *, dtype=None, device=None, copy=True)`
# in `convert.py`; always returns a `torch.Tensor`.
def to_tensor(obj: object, /, *, dtype: Optional[DType] = None, device: Optional[Device] = None, copy: bool = True) -> Tensor: ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Stub overloads for `to_list()`. An overloaded function in a stub file must
# consist of `@overload` signatures only, so no implementation signature is
# declared here. The runtime implementation is `to_list(obj, /, *, copy=True)`
# in `convert.py`, and the final `object` overload covers any remaining call.
@overload
def to_list(obj: List[T], /, *, copy: bool = ...) -> List[T]: ...
@overload
def to_list(obj: StringT, /, *, copy: bool = ...) -> List[StringT]: ...
@overload
def to_list(obj: NDArray[dtypeT], /, *, copy: bool = ...) -> List[dtypeT]: ...
@overload
def to_list(obj: ArrayLike[dtypeT], /, *, copy: bool = ...) -> List[dtypeT]: ...
@overload
def to_list(obj: Iterable[T], /, *, copy: bool = ...) -> List[T]: ...
@overload
def to_list(obj: T, /, *, copy: bool = ...) -> List[T]: ...
@overload
def to_list(obj: object, /, *, copy: bool = ...) -> List[Any]: ...
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Resolve an array library name, an array namespace or a compatibility
# namespace to an array namespace module.
def to_xp(obj: Union[object, ArrayLibraryName], /) -> Namespace: ...


# Runtime alias exported through `convert.__all__`; declared here so type
# checkers know about it.
to_array_namespace = to_xp
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Stub overloads for `as_array()`.
# NumPy target: the element dtype type parameter is preserved when `dtype` is omitted.
@overload
def as_array(obj: NDArray[dtypeT], xp: Literal["numpy"], /, *, dtype: None = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: bool = ...) -> NDArray[dtypeT]: ...
@overload
def as_array(obj: ArrayLike[dtypeT], xp: Literal["numpy"], /, *, dtype: None = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: bool = ...) -> NDArray[dtypeT]: ...
@overload
def as_array(obj: object, xp: Literal["numpy"], /, *, dtype: None = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> NDArray[Any]: ...
@overload
def as_array(obj: object, xp: Literal["numpy"], /, *, dtype: DTypeT, device: Optional[Device] = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> NDArray[DTypeT]: ...
# PyTorch target.
@overload
def as_array(obj: object, xp: Literal["torch"], /, *, dtype: Optional[DType] = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> Tensor: ...
# Any other namespace (or a name not known statically).
@overload
def as_array(obj: NDArray[dtypeT], xp: Union[object, ArrayLibraryName], /, *, dtype: None = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: bool = ...) -> ArrayLike[dtypeT]: ...
@overload
def as_array(obj: ArrayLike[dtypeT], xp: Union[object, ArrayLibraryName], /, *, dtype: None = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: bool = ...) -> ArrayLike[dtypeT]: ...
@overload
def as_array(obj: object, xp: Union[object, ArrayLibraryName], /, *, dtype: None = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> ArrayLike[Any]: ...
@overload
def as_array(obj: object, xp: Union[object, ArrayLibraryName], /, *, dtype: DTypeT, device: Optional[Device] = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> ArrayLike[DTypeT]: ...
# `arraylike_only=True`: non-array-like inputs are returned unchanged.
@overload
def as_array(obj: T, xp: Union[object, ArrayLibraryName], /, *, dtype: Optional[DType] = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: Literal[True]) -> T: ...
# Catch-all for calls the overloads above do not match, e.g. a non-literal
# `bool` passed as `arraylike_only`. It is an `@overload` rather than an
# implementation signature, because a stub file must not contain an
# implementation for an overloaded function.
@overload
def as_array(obj: object, xp: Union[object, ArrayLibraryName], /, *, dtype: Optional[DType] = ..., device: Optional[Device] = ..., copy: bool = ..., arraylike_only: bool = ...) -> Any: ...
|
cobra_array/default.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# src/cobra_array/default.py
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# Python version: 3.9
|
|
4
|
+
# @TianZhen
|
|
5
|
+
"""
Default utilities for :pkg:`cobra_array`.

Attributes
----------
- :attr:`DEFAULT_DTYPE`: The default data type for arrays (`float`).
- :attr:`DEFAULT_DEVICE`: The default device for arrays (`"cpu"`).
- :attr:`TORCH_COMPAT_NAMESPACE`: The `compatibility namespace` for `PyTorch` (`None` if `PyTorch` is unavailable).
- :attr:`NUMPY_COMPAT_NAMESPACE`: The `compatibility namespace` for `NumPy` (`None` if `NumPy` is unavailable).

Functions
---------
- :func:`default_spec`: Get the default array specification, preferring `PyTorch` over `NumPy`.
- :func:`as_default`: Convert an object to a :class:`CompatArray` array in the default context.
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
from typing import (Any, Literal, TYPE_CHECKING, NamedTuple, overload)
|
|
22
|
+
|
|
23
|
+
from .compat import (CompatNamespace, wrap_arraylike)
|
|
24
|
+
from .convert import as_array
|
|
25
|
+
from .array_api import (numpy_xp, torch_xp)
|
|
26
|
+
from .exceptions import MissingDependencyError
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from numpy.typing import NDArray
|
|
30
|
+
from .compat import CompatArray
|
|
31
|
+
from .types import (ArrayLike, DType, dtypeT, T)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ArraySpec(NamedTuple):
    """
    Immutable bundle describing how arrays should be created.

    - `cxp`: the :class:`CompatNamespace` used to build arrays;
    - `dtype`: the data type (:data:`DType`) of created arrays;
    - `device`: the device on which arrays are allocated.
    """

    # Field order defines the tuple layout and the positional constructor.
    cxp: CompatNamespace
    dtype: DType
    device: Any

    @classmethod
    def create(cls, xp: object, dtype: Any, device: Any) -> ArraySpec:
        """
        Build an :class:`ArraySpec`, wrapping `xp` in a :class:`CompatNamespace`.

        NOTE(review): `xp` is always passed through `CompatNamespace(...)`, even
        when it already is one; whether that is a no-op depends on
        `CompatNamespace` — confirm.
        """
        compat_namespace = CompatNamespace(xp)
        return cls(cxp=compat_namespace, dtype=dtype, device=device)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# --- Module-level defaults ---------------------------------------------------
# Data type and device reported by `default_spec()`.
DEFAULT_DTYPE = float
DEFAULT_DEVICE = "cpu"

# Pre-built compatibility namespaces; `None` when the library is unavailable.
if numpy_xp is None:
    NUMPY_COMPAT_NAMESPACE = None
else:
    NUMPY_COMPAT_NAMESPACE = CompatNamespace(numpy_xp)

if torch_xp is None:
    TORCH_COMPAT_NAMESPACE = None
else:
    TORCH_COMPAT_NAMESPACE = CompatNamespace(torch_xp)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def default_spec() -> ArraySpec:
    """
    Pick the default `compatibility namespace` (`PyTorch` first, then `NumPy`) and return it with the default `dtype` and `device`.

    Returns
    -------
    ArraySpec
        An :class:`ArraySpec` named tuple containing the default `cxp` (`compatibility namespace`), `dtype` (`DEFAULT_DTYPE`) and `device` (`DEFAULT_DEVICE`).

    Raises
    ------
    MissingDependencyError
        If all default array libraries (`NumPy` and `PyTorch`) are missing.
    """
    # The module-level namespaces are `None` when the library is missing. Test
    # for that explicitly instead of relying on the truthiness of a
    # `CompatNamespace` instance.
    if TORCH_COMPAT_NAMESPACE is not None:
        default_xp = TORCH_COMPAT_NAMESPACE
    elif NUMPY_COMPAT_NAMESPACE is not None:
        default_xp = NUMPY_COMPAT_NAMESPACE
    else:
        raise MissingDependencyError(
            "Missing all default array libraries (`PyTorch` > `NumPy`)."
        )
    return ArraySpec(default_xp, DEFAULT_DTYPE, DEFAULT_DEVICE)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Typing-only signatures for `as_default()`; the runtime implementation is the
# undecorated `def as_default(...)` that follows.
# With `unify_dtype=False`, the input's dtype type parameter is carried into the result.
@overload
def as_default(obj: NDArray[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, Literal["cpu"]]: ...
@overload
def as_default(obj: ArrayLike[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, Any]: ...
# With dtype unification (the default), the element type is not tracked.
@overload
def as_default(obj: ArrayLike[Any], /, *, unify_dtype: Literal[True] = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[Any, Any]: ...
# Arbitrary objects are converted when `arraylike_only` is False ...
@overload
def as_default(obj: object, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> CompatArray[Any, Any]: ...
# ... while with `arraylike_only=True` the result is typed as the input's type.
@overload
def as_default(obj: T, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[True]) -> T: ...
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def as_default(
    obj: object,
    /, *,
    unify_dtype: bool = True,
    unify_device: bool = True,
    copy: bool = False,
    arraylike_only: bool = False
) -> Any:
    """
    Convert `obj` to a :class:`CompatArray` in the default `compatibility namespace`.

    The namespace comes from :func:`default_spec`. Its default `dtype` and
    `device` are applied only when the corresponding `unify_*` flag is set.

    Parameters
    ----------
    obj : object
        The object to be converted to a :class:`CompatArray` array.

    unify_dtype : bool, default to `True`
        If `True`, convert to the default `dtype`; otherwise keep the object's own data type.

    unify_device : bool, default to `True`
        If `True`, place the array on the default `device`; otherwise pass no device.

    copy : bool, default to `False`
        Forwarded to :func:`convert.as_array`; controls whether a copy is made.

    arraylike_only : bool, default to `False`
        If `True`, only array-like objects are converted; anything else is handed back by :func:`convert.as_array` unchanged.

    Returns
    -------
    CompatArray[Any, Any]
        The converted array, wrapped by `wrap_arraylike` with the default `compatibility namespace`.
    object
        If :param:`arraylike_only` is `True` and the object is not array-like.

    Raises
    ------
    Refer to :func:`convert.as_array`, :func:`default.default_spec` for possible exceptions.
    """
    spec = default_spec()
    target_dtype = spec.dtype if unify_dtype else None
    target_device = spec.device if unify_device else None
    converted = as_array(
        obj,
        spec.cxp,
        dtype=target_dtype,
        device=target_device,
        copy=copy,
        arraylike_only=arraylike_only,
    )
    return wrap_arraylike(converted, xp=spec.cxp)
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# src/cobra_array/exceptions.py
|
|
2
|
+
"""
Exceptions and warnings for :pkg:`cobra_array`.

All warnings derive from :class:`CobraArrayWarning`; all errors derive from
:class:`CobraArrayError`. Several errors also subclass a built-in exception
(e.g. `ImportError`, `TypeError`, `ValueError`, `AttributeError`) so that
generic handlers keep working.
"""
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# === WARNING ===
class CobraArrayWarning(Warning):
    """Root category of every warning emitted by :pkg:`cobra_array`."""


class ParameterIgnoredWarning(CobraArrayWarning):
    """Emitted when an argument is accepted but has no effect (e.g. a non-CPU `device` requested for `NumPy`)."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# === ERROR ===
class CobraArrayError(Exception):
    """Base error class for :pkg:`cobra_array`."""


class MissingDependencyError(CobraArrayError, ImportError):
    """Raised when a required dependency is missing."""


class ConvertNoneTypeError(CobraArrayError, TypeError):
    """Raised when conversion of `NoneType` is attempted."""


class UnsupportedNamespaceError(CobraArrayError, ValueError):
    """Raised when an unsupported `array namespace` is specified."""


class UnsupportedArrayLibraryNameError(CobraArrayError, ValueError):
    """Raised when an unsupported `array namespace` name is specified."""


class ArrayConversionError(CobraArrayError):
    """Raised when an error occurs during array conversion in an `array namespace`."""


# The library-specific conversion errors specialise `ArrayConversionError`.
# `as_array()` dispatches to `to_numpy()` / `to_tensor()` for NumPy / PyTorch,
# so a single `except ArrayConversionError` now catches conversion failures
# from every backend. They remain `CobraArrayError` subclasses, so existing
# handlers keep working.
class NumPyConversionError(ArrayConversionError):
    """Raised when an error occurs during array conversion in `NumPy`."""


class TorchConversionError(ArrayConversionError):
    """Raised when an error occurs during array conversion in `PyTorch`."""


class CUDAUnavailableError(CobraArrayError):
    """Raised when CUDA is not available but a CUDA device is specified."""


class DeviceNotSupportedError(CobraArrayError):
    """Raised when a specified device is not supported by the current array library."""


class NotArrayAPIObjectError(CobraArrayError, ValueError):
    """Raised when an array is not an array API compatible array object."""


class NoArrayInputsError(CobraArrayError, ValueError):
    """Raised when no array inputs are provided."""


class GetArrayNamespaceError(CobraArrayError):
    """Raised when an error occurs while determining the `array namespace`."""


class CompatNamespaceAttributeError(CobraArrayError, AttributeError):
    """Raised when an attribute is not supported in :class:`CompatNamespace`."""


class CompatArrayAttributeError(CobraArrayError, AttributeError):
    """Raised when an attribute is not supported in :class:`CompatArray`."""
|
cobra_array/types.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# src/cobra_array/types.py
|
|
2
|
+
"""
Type definitions for :pkg:`cobra_array`.

Contains scalar/dtype/device aliases, type variables, the :class:`ArrayLike`
protocol, array-or-scalar unions and the `Unique*Result` named tuples.
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
import torch
|
|
8
|
+
from typing import (Union, Protocol, TypeVar, Any, Literal, NamedTuple, Generic, TYPE_CHECKING)
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from .compat._array import CompatArray
|
|
12
|
+
|
|
13
|
+
# === Type Aliases ===
# Built-in Python scalar types.
Value = Union[int, float, complex, bool]

# Generic type variables.
T = TypeVar("T")
# Constrained (not bound) type variable: exactly `str` or `bytes`.
StringT = TypeVar("StringT", str, bytes)
ValueT = TypeVar("ValueT", bound=Value)

# === DType and Device ===
# Data types are left unconstrained so any library's dtype object is accepted.
DType = Any
# FIXME: cannot be inferred.
DeviceLiteral = Literal["cpu", "cuda", "xpu", "mkldnn", "opengl", "opencl", "ideep", "hip", "ve", "ort", "mlc", "xla", "lazy", "vulkan", "meta", "hpu"]
# NOTE(review): `torch.device` is evaluated at import time, so this module
# depends on the module-level `import torch`. It therefore cannot be imported
# without PyTorch, although `convert.py` treats PyTorch as optional
# (`torch_xp is None`). Consider a `TYPE_CHECKING` import plus a string
# forward reference.
Device = Union[DeviceLiteral, torch.device, str]

# Invariant dtype/device type variables. The lower-case and upper-case
# spellings are separate TypeVars with identical bounds.
dtypeT = TypeVar("dtypeT", bound=DType)
deviceT = TypeVar("deviceT", bound=Device)
DTypeT = TypeVar("DTypeT", bound=DType)
DeviceT = TypeVar("DeviceT", bound=Device)

# Covariant variants, used as parameters of read-only generic types
# (e.g. the `ArrayLike` protocol and the `Unique*Result` tuples).
DTypeT_co = TypeVar("DTypeT_co", bound=DType, covariant=True)
DeviceT_co = TypeVar("DeviceT_co", bound=Device, covariant=True)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# === Array Protocols ===
|
|
36
|
+
class ArrayLike(Protocol[DTypeT_co]):
    """
    Structural type for array objects exposing `dtype`, `shape` and `__array__()`.

    The type parameter is the (covariant) type of the `dtype` attribute.
    """
    @property
    def dtype(self) -> DTypeT_co: ...
    @property
    def shape(self) -> Any: ...
    def __array__(self) -> Any: ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Array library names accepted by `to_xp()` / `as_array()`.
ArrayLibraryName = Literal["numpy", "torch"]
# Type variable for any array-like type.
ArrayT = TypeVar("ArrayT", bound=ArrayLike[Any])
# Unions of an array-like object with various Python scalar types.
ArrayOrAny = Union[ArrayLike[Any], int, float, complex, bool]
ArrayOrScalar = Union[ArrayLike[Any], int, float, complex]
ArrayOrReal = Union[ArrayLike[Any], int, float]
ArrayOrIntLike = Union[ArrayLike[Any], int, bool]
ArrayOrbool = Union[ArrayLike[Any], bool]
ArrayOrInt = Union[ArrayLike[Any], int]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
import sys  # used only by the version gate below

# `class X(NamedTuple, Generic[...])` is accepted at runtime only from Python
# 3.11 on. Older interpreters reject multiple inheritance with `NamedTuple`
# (TypeError at import time), while this package targets Python 3.9.
# Type checkers always see the generic definitions. Python < 3.11 gets plain
# named tuples with the same fields. Annotations are not evaluated at runtime
# here (`from __future__ import annotations`), so reusing them is safe.
if TYPE_CHECKING or sys.version_info >= (3, 11):

    class UniqueAllResult(NamedTuple, Generic[DTypeT_co, DeviceT_co]):
        """Named result holding `values`, `indices`, `inverse_indices` and `counts` arrays."""

        values: CompatArray[DTypeT_co, DeviceT_co]
        indices: CompatArray[int, DeviceT_co]
        inverse_indices: CompatArray[int, DeviceT_co]
        counts: CompatArray[int, DeviceT_co]

    class UniqueCountsResult(NamedTuple, Generic[DTypeT_co, DeviceT_co]):
        """Named result holding `values` and `counts` arrays."""

        values: CompatArray[DTypeT_co, DeviceT_co]
        counts: CompatArray[int, DeviceT_co]

    class UniqueInverseResult(NamedTuple, Generic[DTypeT_co, DeviceT_co]):
        """Named result holding `values` and `inverse_indices` arrays."""

        values: CompatArray[DTypeT_co, DeviceT_co]
        inverse_indices: CompatArray[int, DeviceT_co]

else:

    class UniqueAllResult(NamedTuple):
        """Named result holding `values`, `indices`, `inverse_indices` and `counts` arrays."""

        values: CompatArray[DTypeT_co, DeviceT_co]
        indices: CompatArray[int, DeviceT_co]
        inverse_indices: CompatArray[int, DeviceT_co]
        counts: CompatArray[int, DeviceT_co]

    class UniqueCountsResult(NamedTuple):
        """Named result holding `values` and `counts` arrays."""

        values: CompatArray[DTypeT_co, DeviceT_co]
        counts: CompatArray[int, DeviceT_co]

    class UniqueInverseResult(NamedTuple):
        """Named result holding `values` and `inverse_indices` arrays."""

        values: CompatArray[DTypeT_co, DeviceT_co]
        inverse_indices: CompatArray[int, DeviceT_co]
|