cobra-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,53 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Python version: 3.9
3
+ # @TianZhen
4
+
5
+ from __future__ import annotations
6
+ from typing import (Any, TYPE_CHECKING)
7
+
8
+ from .._utils import array_namespace_alias
9
+
10
+ if TYPE_CHECKING:
11
+ from array_api_compat.common._typing import Namespace
12
+
13
+
14
class Compat:
    """
    Common base of the compatibility wrappers :class:`CompatNamespace` and :class:`CompatArray`.

    Each instance stores the wrapped `array namespace` in ``_xp`` and the alias
    returned by ``array_namespace_alias`` for that namespace in ``_xp_name``.
    """
    _xp: Namespace
    _xp_name: str

    # Class-level flag; it is not read anywhere inside this class.
    # NOTE(review): presumably consulted by the unwrapping helpers elsewhere — confirm.
    _UNWRAP_COMPAT: bool = True

    def __new__(cls, xp: Any, /):
        """Create a wrapper bound to the `array namespace` ``xp``."""
        instance = super().__new__(cls)
        # The right-hand side is evaluated first (alias, then xp); the names are then
        # assigned left to right, so `_xp_name` is set before `_xp`.
        instance._xp_name, instance._xp = array_namespace_alias(xp), xp
        return instance

    def _get_xp_attr(self, name: str):
        """
        Look up ``name`` on the wrapped `array namespace`.

        Raises
        ------
        AttributeError
            When the namespace has no such attribute. The original lookup error is
            suppressed (``from None``) in favour of a message naming the namespace alias.
        """
        try:
            return getattr(self._xp, name)
        except AttributeError:
            raise AttributeError(
                f"Namespace `{self._xp_name}` of "
                f"`{self.__class__.__name__}` has no attribute `{name}`."
            ) from None

    @property
    def xp(self) -> Namespace:
        """The wrapped `array namespace` object."""
        return self._xp

    @property
    def xp_name(self) -> str:
        """Alias of the wrapped `array namespace`, as computed at construction time."""
        return self._xp_name

    def __array_namespace__(self, *, api_version=None) -> Any:
        """
        Array API hook returning the namespace object.

        The base class does not provide one and always raises ``NotImplementedError``.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(f"`__array_namespace__()` is not implemented for `{owner}`.")
@@ -0,0 +1,305 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Python version: 3.9
3
+ # @TianZhen
4
+
5
+ from __future__ import annotations
6
+
7
+ from ._base import Compat
8
+ from ._array import (CompatArray, wrap_arraylike, unwrap)
9
+ from ..exceptions import CompatNamespaceAttributeError
10
+
11
+
12
class CompatNamespace(Compat):
    """
    A wrapper around an `array namespace` that provides a compatibility layer for backend-agnostic array operations.

    :class:`CompatNamespace` exposes functions from the underlying `array namespace` (e.g., `NumPy`, `PyTorch`) while ensuring compliance with the `Python Array API standard`.
    It includes detailed documentation for functions that are not suitable for an object-oriented interface.

    All functions preserve the semantics of the underlying namespace, with additional guarantees on input and output handling.

    Notes
    -----
    - Functions correspond directly to those defined in the underlying `array namespace`, following the `Python Array API standard`.
    - This namespace complements :class:`CompatArray` by providing a functional interface for operations that are not naturally expressed as methods.
    - All functions guarantee that any array-like objects in the returned value are automatically wrapped as :class:`CompatArray`. This conversion is applied recursively to arrays contained in Python containers (e.g., `tuple`, `list`, `dict`). Non-array objects remain unchanged.
    - Passing an existing :class:`CompatNamespace` to the constructor returns that same object.
    """
    def __new__(cls, xp, /):
        if isinstance(xp, CompatNamespace):
            # Already wrapped: return it unchanged so wrapping is idempotent.
            return xp
        # Raw `Namespace` input: `Compat.__new__` records the namespace and its alias.
        return super().__new__(cls, xp)

    # === Creation functions ===
    def meshgrid(self, *arrays, indexing="xy"):
        """
        Returns coordinate matrices from coordinate vectors.

        Parameters
        ----------
        arrays : ArrayLike[Any]
            An arbitrary number of one-dimensional arrays representing grid coordinates.
            Each array should have the same numeric data type.

        indexing : Literal["xy", "ij"], default `"xy"`
            Cartesian `"xy"` or matrix `"ij"` indexing of output.
            If provided zero or one one-dimensional vector(s) (i.e., the zero- and one-dimensional cases, respectively), the indexing keyword has no effect and should be ignored.

        Returns
        -------
        List[CompatArray]
            List of `N` :class:`CompatArray` arrays, where `N` is the number of provided one-dimensional input arrays.
            Each returned array must have rank `N`.
            For `N` one-dimensional arrays having lengths `Ni = len(xi)`,
            - `matrix indexing ij`: Each returned array must have the shape `(N1, N2, N3, ..., Nn)`;
            - `Cartesian indexing xy`: Each returned array must have shape `(N2, N1, N3, ..., Nn)`.

            Accordingly, for the two-dimensional case with input one-dimensional arrays of length `M` and `N`,
            - `matrix indexing ij`: Each returned array must have shape `(M, N)`;
            - `Cartesian indexing xy`: Each returned array must have shape `(N, M)`.

            Similarly, for the three-dimensional case with input one-dimensional arrays of length `M`, `N`, and `P`,
            - `matrix indexing ij`: Each returned array must have shape `(M, N, P)`;
            - `Cartesian indexing xy`: Each returned array must have shape `(N, M, P)`.

            Each returned array should have the same data type as the input arrays.
        """
        # `indexing` defaults to "xy" as in the Array API standard; previously it was a
        # required keyword-only argument, so `meshgrid(x, y)` raised TypeError.
        result = self._get_xp_attr("meshgrid")(
            *(unwrap(arr) for arr in arrays),
            indexing=indexing,
        )
        return [CompatArray(arr, xp=self._xp) for arr in result]

    # === Data Type functions ===
    def can_cast(self, from_, to, /):
        """
        Determines whether an array can be cast to a different data type according to type promotion rules.

        Parameters
        ----------
        from_ : Union[DType, ArrayLike[Any]]
            Input array or data type. A :class:`CompatArray` is unwrapped before the call.

        to : DType
            Output data type.

        Returns
        -------
        bool
            A boolean indicating whether the cast is possible.
        """
        return self._get_xp_attr("can_cast")(unwrap(from_), to)

    def finfo(self, type_, /):
        """
        Machine limits for floating-point data types.

        Parameters
        ----------
        type_ : Union[DType, ArrayLike[Any]]
            The kind of floating-point data-type about which to get information.
            If a complex data type is given, the information is about its component data type.

        Returns
        -------
        finfo_object
            The backend's object (returned unwrapped) having the following attributes:
            - `bits`: _int_
              Number of bits occupied by the real-valued floating-point data type.
            - `eps`: _float_
              Difference between `1.0` and the next smallest representable real-valued floating-point number larger than 1.0 according to the IEEE-754 standard.
            - `max`: _float_
              Largest representable real-valued number.
            - `min`: _float_
              Smallest representable real-valued number.
            - `smallest_normal`: _float_
              Smallest positive real-valued floating-point number with full precision.
            - `dtype`: _dtype_
              Real-valued floating-point data type.
        """
        return self._get_xp_attr("finfo")(unwrap(type_))

    def iinfo(self, type_, /):
        """
        Machine limits for integer data types.

        Parameters
        ----------
        type_ : Union[DType, ArrayLike[Any]]
            The kind of integer data-type about which to get information.

        Returns
        -------
        iinfo_object
            The backend's object (returned unwrapped) having the following attributes:
            - `bits`: _int_
              Number of bits occupied by the integer data type.
            - `max`: _int_
              Largest representable integer value.
            - `min`: _int_
              Smallest representable integer value.
            - `dtype`: _dtype_
              Integer data type.
        """
        return self._get_xp_attr("iinfo")(unwrap(type_))

    def isdtype(self, dtype, kind) -> bool:
        """
        Returns a boolean indicating whether a provided `dtype` is of a specified data type `kind`.

        Parameters
        ----------
        dtype : DType
            The input dtype.

        kind : Union[str, DType, Tuple[Union[str, DType], ...]]
            The data type kind.
            If kind is a dtype, the function must return a boolean indicating whether the input dtype is equal to the dtype specified by kind.
            - _string_: The function must return a boolean indicating whether the input dtype is of a specified data type kind. The following dtype kinds must be supported:

              - `bool`: boolean data types (e.g., bool);
              - `signed integer`: signed integer data types (e.g., int8, int16, int32, int64);
              - `unsigned integer`: unsigned integer data types (e.g., uint8, uint16, uint32, uint64);
              - `integral`: integer data types. Shorthand for (`signed integer`, `unsigned integer`);
              - `real floating`: real-valued floating-point data types (e.g., float32, float64);
              - `complex floating`: complex floating-point data types (e.g., complex64, complex128);
              - `numeric`: numeric data types. Shorthand for (`integral`, `real floating`, `complex floating`).
            - _tuple_: The tuple specifies a union of `dtypes` and/or `kinds`, and the function must return a boolean indicating whether the input `dtype` is either equal to a specified dtype or belongs to at least one specified data type kind.

        Returns
        -------
        bool
            A boolean indicating whether the input dtype is of the specified data type kind.
        """
        return self._get_xp_attr("isdtype")(dtype, kind)

    def result_type(self, *arrays_and_dtypes):
        """
        Returns the `dtype` that results from applying type promotion rules (see Type Promotion Rules) to the arguments.

        Parameters
        ----------
        arrays_and_dtypes : Union[ArrayOrAny, DType]
            An arbitrary number of input arrays, scalars, and/or dtypes.
            Each argument is passed through `unwrap` before the call.

        Returns
        -------
        DType
            The dtype resulting from an operation involving the input arrays, scalars, and/or dtypes.
        """
        return self._get_xp_attr("result_type")(*(unwrap(arr) for arr in arrays_and_dtypes))

    # === Manipulation functions ===
    def broadcast_arrays(self, *arrays):
        """
        Broadcasts one or more arrays against one another.

        Parameters
        ----------
        arrays : ArrayLike[Any]
            An arbitrary number of to-be broadcasted arrays.

        Returns
        -------
        List[CompatArray]
            A list of broadcasted :class:`CompatArray` arrays.
            Each array must have the same shape.
            Each array must have the same dtype as its corresponding input array.
        """
        result = self._get_xp_attr("broadcast_arrays")(*(unwrap(arr) for arr in arrays))
        # Always return a list, whatever sequence type the backend produced.
        return [CompatArray(arr, xp=self._xp) for arr in result]

    # === Constants ===
    @property
    def e(self):
        """The constant `e` of the underlying `array namespace`."""
        return self._get_xp_attr("e")

    @property
    def pi(self):
        """The constant `pi` of the underlying `array namespace`."""
        return self._get_xp_attr("pi")

    @property
    def inf(self):
        """The constant `inf` of the underlying `array namespace`."""
        return self._get_xp_attr("inf")

    @property
    def nan(self):
        """The constant `nan` of the underlying `array namespace`."""
        return self._get_xp_attr("nan")

    @property
    def newaxis(self):
        """An alias for None which is useful for indexing arrays."""
        return self._get_xp_attr("newaxis")

    # === Data type ===
    @property
    def int8(self):
        """The `int8` data type of the underlying `array namespace`."""
        return self._get_xp_attr("int8")

    @property
    def int16(self):
        """The `int16` data type of the underlying `array namespace`."""
        return self._get_xp_attr("int16")

    @property
    def int32(self):
        """The `int32` data type of the underlying `array namespace`."""
        return self._get_xp_attr("int32")

    @property
    def int64(self):
        """The `int64` data type of the underlying `array namespace`."""
        return self._get_xp_attr("int64")

    @property
    def uint8(self):
        """The `uint8` data type of the underlying `array namespace`."""
        return self._get_xp_attr("uint8")

    @property
    def uint16(self):
        """The `uint16` data type of the underlying `array namespace`."""
        return self._get_xp_attr("uint16")

    @property
    def uint32(self):
        """The `uint32` data type of the underlying `array namespace`."""
        return self._get_xp_attr("uint32")

    @property
    def uint64(self):
        """The `uint64` data type of the underlying `array namespace`."""
        return self._get_xp_attr("uint64")

    @property
    def float32(self):
        """The `float32` data type of the underlying `array namespace`."""
        return self._get_xp_attr("float32")

    @property
    def float64(self):
        """The `float64` data type of the underlying `array namespace`."""
        return self._get_xp_attr("float64")

    @property
    def complex64(self):
        """The `complex64` data type of the underlying `array namespace`."""
        return self._get_xp_attr("complex64")

    @property
    def complex128(self):
        """The `complex128` data type of the underlying `array namespace`."""
        return self._get_xp_attr("complex128")

    @property
    def bool(self):
        """The `bool` data type of the underlying `array namespace`."""
        return self._get_xp_attr("bool")

    @property
    def __name__(self):
        """Name of the wrapped namespace (or its type name), prefixed with `(compat)`."""
        return "(compat)" + getattr(self._xp, "__name__", type(self._xp).__name__)

    def __getattr__(self, name: str):
        """
        Fallback lookup for attributes not defined on this class.

        Callable attributes of the underlying namespace are returned as a wrapper that
        passes every positional and keyword argument through `unwrap` and passes the
        result through `wrap_arraylike`.

        Raises
        ------
        AttributeError
            If the underlying namespace has no attribute `name` (raised by `_get_xp_attr`).
        CompatNamespaceAttributeError
            If the attribute exists but is not callable and is not exposed explicitly above.
        """
        # `__getattr__` only runs after normal lookup fails. For these names that means
        # the instance was never initialised by `Compat.__new__`; fail fast instead of
        # recursing through `self._xp` -> `__getattr__` -> `self._xp` -> ...
        if name in ("_xp", "_xp_name"):
            raise AttributeError(f"`{type(self).__name__}` object has no attribute `{name}`.")

        attr = self._get_xp_attr(name)
        if not callable(attr):
            raise CompatNamespaceAttributeError(f"`CompatNamespace` `{self._xp_name}` does not support attribute `{name}`.")

        def wrapper(*args, **kwargs):
            new_args = tuple(unwrap(a) for a in args)
            new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
            return wrap_arraylike(attr(*new_args, **new_kwargs), xp=self._xp)

        return wrapper