cobra-array 0.1.4__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {cobra_array-0.1.4/src/cobra_array.egg-info → cobra_array-0.2.0}/PKG-INFO +1 -1
  2. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/__init__.py +1 -5
  3. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/_core.py +2 -2
  4. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/_utils.py +27 -12
  5. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/__init__.py +1 -1
  6. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_array.py +51 -23
  7. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_array.pyi +2 -0
  8. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_base.py +1 -1
  9. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_namespace.py +37 -16
  10. {cobra_array-0.1.4 → cobra_array-0.2.0/src/cobra_array.egg-info}/PKG-INFO +1 -1
  11. cobra_array-0.2.0/tests/test_compat.py +482 -0
  12. cobra_array-0.1.4/tests/test_compat.py +0 -261
  13. {cobra_array-0.1.4 → cobra_array-0.2.0}/LICENSE +0 -0
  14. {cobra_array-0.1.4 → cobra_array-0.2.0}/README.md +0 -0
  15. {cobra_array-0.1.4 → cobra_array-0.2.0}/pyproject.toml +0 -0
  16. {cobra_array-0.1.4 → cobra_array-0.2.0}/setup.cfg +0 -0
  17. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/array_api.py +0 -0
  18. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/compat/_namespace.pyi +0 -0
  19. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/convert.py +0 -0
  20. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/convert.pyi +0 -0
  21. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/default.py +0 -0
  22. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/exceptions.py +0 -0
  23. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array/types.py +0 -0
  24. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/SOURCES.txt +0 -0
  25. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/dependency_links.txt +0 -0
  26. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/requires.txt +0 -0
  27. {cobra_array-0.1.4 → cobra_array-0.2.0}/src/cobra_array.egg-info/top_level.txt +0 -0
  28. {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_backend.py +0 -0
  29. {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_compat_namespace.py +0 -0
  30. {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_convert.py +0 -0
  31. {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_default.py +0 -0
  32. {cobra_array-0.1.4 → cobra_array-0.2.0}/tests/test_wrap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cobra-array
3
- Version: 0.1.4
3
+ Version: 0.2.0
4
4
  Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
5
5
  Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
@@ -32,11 +32,8 @@ Examples
32
32
  from cobra_array.convert import to_numpy, to_tensor, to_list
33
33
 
34
34
  data = [[1, 2], [3, 4]]
35
-
36
35
  arr_np = to_numpy(data, dtype=np.float32) # numpy.ndarray float32
37
-
38
36
  arr_torch = to_tensor(data, device="cpu")
39
-
40
37
  back_to_list = to_list(arr_np) # [[1.0, 2.0], [3.0, 4.0]]
41
38
 
42
39
  - Context-based conversion::
@@ -66,7 +63,6 @@ Examples
66
63
  from cobra_array.default import as_default, default_spec
67
64
 
68
65
  spec = default_spec()
69
-
70
66
  x = as_default([1, 2, 3], unify_dtype=True, unify_device=True)
71
67
  """
72
68
 
@@ -84,7 +80,7 @@ from ._utils import (
84
80
  )
85
81
 
86
82
  __author__ = "Zhen Tian"
87
- __version__ = "0.1.4"
83
+ __version__ = "0.2.0"
88
84
 
89
85
  __all__ = [
90
86
  "array_spec",
@@ -500,9 +500,9 @@ def unify_args(
500
500
  return func(*args, **kwargs)
501
501
 
502
502
  with array_context.from_array_spec(spec):
503
- out_args = tuple(
503
+ out_args = [
504
504
  as_context(a, unify_dtype=unify_dtype, unify_device=unify_device, arraylike_only=arraylike_only) for a in args
505
- )
505
+ ]
506
506
  out_kwargs = {
507
507
  k: as_context(v, unify_dtype=unify_dtype, unify_device=unify_device, arraylike_only=arraylike_only)
508
508
  for k, v in kwargs.items()
@@ -6,7 +6,7 @@ from __future__ import annotations
6
6
  import array_api_compat as api
7
7
  import warnings
8
8
  from types import ModuleType
9
- from typing import Any
9
+ from typing import (Any, overload, Optional, Literal)
10
10
 
11
11
  from .exceptions import UnsupportedNamespaceError
12
12
 
@@ -25,7 +25,13 @@ def warn(msg: str, /, category: Any, stack: int = 2):
25
25
  return warnings.warn(msg, category=category, stacklevel=stack+1)
26
26
 
27
27
 
28
- def array_namespace_alias(xp: object) -> str:
28
+ @overload
29
+ def array_namespace_alias(xp: object, /, *, raise_on_unsupported: Literal[True] = ...) -> str: ...
30
+ @overload
31
+ def array_namespace_alias(xp: object, /, *, raise_on_unsupported: Literal[False]) -> Optional[str]: ...
32
+
33
+
34
+ def array_namespace_alias(xp: object, /, *, raise_on_unsupported: bool = True) -> Optional[str]:
29
35
  """
30
36
  Get the alias of the `array namespace`.
31
37
 
@@ -34,18 +40,23 @@ def array_namespace_alias(xp: object) -> str:
34
40
  xp : object
35
41
  The `array namespace`.
36
42
 
43
+ raise_on_unsupported : bool, optional
44
+ Whether to raise an error for unsupported array namespaces.
45
+
37
46
  Returns
38
47
  -------
39
48
  str
40
49
  The alias of the `array namespace`.
41
50
  - Including: `"NumPy"`, `"Cupy"`, `"PyTorch"`, `"NDONNX"`, `"Dask"`, `"JAX"`, `"sparse"` and `"array-api-strict"`.
51
+ `None`
52
+ If the input is not a supported `array namespace` and `raise_on_unsupported` is `False`.
42
53
 
43
54
  Raises
44
55
  ------
45
56
  UnsupportedNameSpaceError
46
- If the input object is not a supported `array namespace`.
57
+ If the input object is not a supported `array namespace` and :param:`raise_on_unsupported` is `True`.
47
58
  """
48
- if isinstance(xp, ModuleType):
59
+ if type(xp) is ModuleType:
49
60
  if api.is_numpy_namespace(xp):
50
61
  return "NumPy"
51
62
 
@@ -70,20 +81,17 @@ def array_namespace_alias(xp: object) -> str:
70
81
  if api.is_array_api_strict_namespace(xp):
71
82
  return "array-api-strict"
72
83
 
73
- raise UnsupportedNamespaceError(
74
- f"Got unsupported array namespace of type {type(xp)}."
75
- )
84
+ if raise_on_unsupported:
85
+ raise UnsupportedNamespaceError(
86
+ f"Got unsupported array namespace of type {type(xp)}."
87
+ )
76
88
 
77
89
 
78
90
  def is_array_namespace(obj: object) -> bool:
79
91
  """
80
92
  Returns `True` if input is a supported `array namespace`.
81
93
  """
82
- try:
83
- array_namespace_alias(obj)
84
- return True
85
- except UnsupportedNamespaceError:
86
- return False
94
+ return array_namespace_alias(obj, raise_on_unsupported=False) is not None
87
95
 
88
96
 
89
97
  def is_compat_namespace(xp: object) -> bool:
@@ -91,3 +99,10 @@ def is_compat_namespace(xp: object) -> bool:
91
99
  Returns `True` if input is a `compatibility namespace` wrapped by :class:`CompatNamespace`
92
100
  """
93
101
  return "(compat)" in getattr(xp, "__name__", "")
102
+
103
+
104
+ # def is_array_api_object(obj: object) -> bool:
105
+ # """
106
+ # Returns `True` if input is an `array API object` that belongs to a supported `array namespace`.
107
+ # """
108
+ # return api.is_array_api_obj(obj)
@@ -1,6 +1,6 @@
1
1
  # src/cobra_array/compat/__init__.py
2
2
  """
3
- Compatibility utilities for :pkg:`cobra_color`.
3
+ Compatibility utilities for :pkg:`cobra_array`.
4
4
 
5
5
  Functions
6
6
  ---------
@@ -108,16 +108,17 @@ class CompatArray(Compat):
108
108
  _xp = to_xp(xp)
109
109
  return cls(
110
110
  as_array(unwrap(obj), _xp, copy=copy),
111
- xp=_xp
111
+ xp=_xp, check=False
112
112
  )
113
113
 
114
114
  def __new__(cls, arr, /, *, copy=False, **kwargs):
115
- if isinstance(arr, CompatArray):
115
+ if type(arr) is cls:
116
116
  # for `CompatArray` input
117
117
  return arr.copy() if copy else arr
118
118
 
119
+ _check = kwargs.get("check", True)
119
120
  # for non-`CompatArray` input
120
- if not api.is_array_api_obj(arr):
121
+ if _check and not api.is_array_api_obj(arr):
121
122
  raise NotArrayAPIObjectError(
122
123
  f"Parameter `arr` of `CompatArray` must be an array API compatible array object, got {type(arr)}."
123
124
  )
@@ -135,6 +136,8 @@ class CompatArray(Compat):
135
136
  Convert `self` to a `NumPy array`.
136
137
  See also :func:`convert.to_numpy`.
137
138
  """
139
+ if self._xp_name == "NumPy" and not copy:
140
+ return self._arr
138
141
  return to_numpy(self._arr, copy=copy)
139
142
 
140
143
  def to_tensor(self, *, device=None, copy=False):
@@ -142,6 +145,8 @@ class CompatArray(Compat):
142
145
  Convert `self` to a `PyTorch tensor`.
143
146
  See also :func:`convert.to_tensor`.
144
147
  """
148
+ if self._xp_name == "PyTorch" and not copy and (device is None or device == self.device):
149
+ return self._arr
145
150
  return to_tensor(self._arr, device=device, copy=copy)
146
151
 
147
152
  def to_list(self, *, copy=False):
@@ -197,7 +202,7 @@ class CompatArray(Compat):
197
202
  All the arrays have the same shape.
198
203
  """
199
204
  result = self._get_xp_attr("unstack")(self._arr, axis=axis)
200
- return tuple(CompatArray(arr, xp=self._xp) for arr in result)
205
+ return tuple([CompatArray(arr, xp=self._cxp, check=False) for arr in result])
201
206
 
202
207
  # === Searching functions ===
203
208
  def nonzero(self):
@@ -218,7 +223,7 @@ class CompatArray(Compat):
218
223
  - If `self` has a boolean data type, non-zero elements are those elements which are equal to `True`.
219
224
  """
220
225
  result = self._get_xp_attr("nonzero")(self._arr)
221
- return tuple(CompatArray(arr, xp=self._xp) for arr in result)
226
+ return tuple([CompatArray(arr, xp=self._cxp) for arr in result])
222
227
 
223
228
  # === Set functions ===
224
229
  def unique_all(self):
@@ -238,10 +243,10 @@ class CompatArray(Compat):
238
243
  """
239
244
  result = self._get_xp_attr("unique_all")(self._arr)
240
245
  return UniqueResult(
241
- values=CompatArray(result.values, xp=self._xp),
242
- indices=CompatArray(result.indices, xp=self._xp),
243
- inverse_indices=CompatArray(result.inverse_indices, xp=self._xp),
244
- counts=CompatArray(result.counts, xp=self._xp),
246
+ values=CompatArray(result.values, xp=self._cxp, check=False),
247
+ indices=CompatArray(result.indices, xp=self._cxp, check=False),
248
+ inverse_indices=CompatArray(result.inverse_indices, xp=self._cxp, check=False),
249
+ counts=CompatArray(result.counts, xp=self._cxp, check=False),
245
250
  )
246
251
 
247
252
  def unique_counts(self):
@@ -259,10 +264,10 @@ class CompatArray(Compat):
259
264
  """
260
265
  result = self._get_xp_attr("unique_counts")(self._arr)
261
266
  return UniqueResult(
262
- values=CompatArray(result.values, xp=self._xp),
267
+ values=CompatArray(result.values, xp=self._cxp, check=False),
263
268
  indices=None,
264
269
  inverse_indices=None,
265
- counts=CompatArray(result.counts, xp=self._xp),
270
+ counts=CompatArray(result.counts, xp=self._cxp, check=False),
266
271
  )
267
272
 
268
273
  def unique_inverse(self):
@@ -280,9 +285,9 @@ class CompatArray(Compat):
280
285
  """
281
286
  result = self._get_xp_attr("unique_inverse")(self._arr)
282
287
  return UniqueResult(
283
- values=CompatArray(result.values, xp=self._xp),
288
+ values=CompatArray(result.values, xp=self._cxp, check=False),
284
289
  indices=None,
285
- inverse_indices=CompatArray(result.inverse_indices, xp=self._xp),
290
+ inverse_indices=CompatArray(result.inverse_indices, xp=self._cxp, check=False),
286
291
  counts=None,
287
292
  )
288
293
 
@@ -291,7 +296,7 @@ class CompatArray(Compat):
291
296
  """
292
297
  Return a copy of `self` via :func:`convert.as_array`.
293
298
  """
294
- return CompatArray.from_other(self._arr, xp=self._xp, copy=True)
299
+ return CompatArray.from_other(self._arr, xp=self._cxp, copy=True)
295
300
 
296
301
  def _get_attr(self, name: str):
297
302
  """Try to get the attribute `name` from `self`."""
@@ -379,7 +384,7 @@ class CompatArray(Compat):
379
384
  result = self._get_xp_attr("T")(self._arr)
380
385
  except (AttributeError, TypeError):
381
386
  result = self._get_attr("T")
382
- return CompatArray(result, xp=self._xp)
387
+ return CompatArray(result, xp=self._cxp, check=False)
383
388
 
384
389
  @property
385
390
  def mT(self):
@@ -391,19 +396,19 @@ class CompatArray(Compat):
391
396
  result = self._get_xp_attr("mT")(self._arr)
392
397
  except (AttributeError, TypeError):
393
398
  result = self._get_attr("mT")
394
- return CompatArray(result, xp=self._xp)
399
+ return CompatArray(result, xp=self._cxp, check=False)
395
400
 
396
401
  def __array__(self):
397
402
  """Allow implicit NumPy conversion."""
398
403
  return self.to_numpy()
399
404
 
400
405
  def __getattr__(self, name: str):
401
- attr = self._get_cxp_attr(name)
406
+ attr = self._get_xp_attr(name)
402
407
 
403
408
  if callable(attr) and not isinstance(attr, type):
404
- def wrapper(*args, **kwargs):
405
- return attr(self._arr, *args, **kwargs)
406
- return wrapper
409
+ wrapped = _make_wrapper(self._xp_name, attr, self._cxp, self._arr)
410
+ self.__dict__[name] = wrapped
411
+ return wrapped
407
412
  raise CompatArrayAttributeError(f"`CompatArray` `{self._xp_name}` does not support attribute `{name}`.")
408
413
 
409
414
  def __len__(self):
@@ -540,7 +545,7 @@ def unwrap(obj):
540
545
  """
541
546
  Unwraps a :class:`CompatArray` array to get the backend-specific array instance, or returns the object itself if it is not a :class:`CompatArray` array.
542
547
  """
543
- return obj.arr if isinstance(obj, CompatArray) else obj
548
+ return obj.arr if type(obj) is CompatArray else obj
544
549
 
545
550
 
546
551
  def wrap_arraylike(arr, xp=None):
@@ -549,8 +554,8 @@ def wrap_arraylike(arr, xp=None):
549
554
  """
550
555
  if api.is_array_api_obj(arr):
551
556
  if xp is None:
552
- return CompatArray(arr)
553
- return CompatArray(arr, xp=xp)
557
+ return CompatArray(arr, check=False)
558
+ return CompatArray(arr, xp=xp, check=False)
554
559
  return arr
555
560
 
556
561
 
@@ -559,3 +564,26 @@ def to_cxp(xp):
559
564
  from ._namespace import CompatNamespace
560
565
 
561
566
  return CompatNamespace(xp)
567
+
568
+
569
+ def _make_wrapper(xp_name, attr, cxp, first, /):
570
+ """Make a wrapper function for the attribute `name` of the `array namespace`."""
571
+ wrap_ = lambda x: wrap_arraylike(x, xp=cxp)
572
+ unwrap_ = unwrap
573
+
574
+ def wrapper(*args, **kwargs):
575
+ if xp_name == "NumPy":
576
+ return wrap_(attr(first, *args, **kwargs))
577
+
578
+ n = len(args)
579
+ if n == 0 and not kwargs:
580
+ return wrap_(attr(first))
581
+ if n == 1 and not kwargs:
582
+ return wrap_(attr(first, unwrap_(args[0])))
583
+ new_args = [unwrap_(a) for a in args]
584
+ if kwargs:
585
+ new_kwargs = {k: unwrap_(v) for k, v in kwargs.items()}
586
+ return wrap_(attr(first, *new_args, **new_kwargs))
587
+ return wrap_(attr(first, *new_args))
588
+ wrapper.__name__ = getattr(attr, "__name__", "wrapper")
589
+ return wrapper
@@ -50,6 +50,8 @@ class CompatArray(Compat, Generic[TT, DT]):
50
50
  def __new__(cls, arr: CompatArray[dtypeT, deviceT], /, **kwargs) -> CompatArray[dtypeT, deviceT]: ...
51
51
  @overload
52
52
  def __new__(cls, arr: ArrayLike[dtypeT], /, **kwargs) -> CompatArray[dtypeT, AnyDevice]: ...
53
+ @overload
54
+ def __new__(cls, arr: object, /, **kwargs) -> CompatArray[Any, AnyDevice]: ...
53
55
  def __new__(cls, arr: ArrayLike[Any], /, **kwargs) -> CompatArray[Any, AnyDevice]: ...
54
56
 
55
57
  # === Conversion functions ===
@@ -3,7 +3,7 @@
3
3
  # @TianZhen
4
4
 
5
5
  from __future__ import annotations
6
- from typing import (Any, TYPE_CHECKING)
6
+ from typing import (Any, TYPE_CHECKING, Callable)
7
7
 
8
8
  from .._utils import array_namespace_alias
9
9
 
@@ -50,7 +50,7 @@ class CompatNamespace(Compat):
50
50
  AttributeError: ...
51
51
  """
52
52
  def __new__(cls, xp, /):
53
- if isinstance(xp, CompatNamespace):
53
+ if type(xp) is cls:
54
54
  # for `CompatNamespace` input
55
55
  return xp
56
56
  # for `Namespace` input
@@ -91,10 +91,10 @@ class CompatNamespace(Compat):
91
91
  Each returned array should have the same data type as the input arrays.
92
92
  """
93
93
  result = self._get_xp_attr("meshgrid")(
94
- *tuple(unwrap(arr) for arr in arrays),
94
+ *[unwrap(arr) for arr in arrays],
95
95
  indexing=indexing
96
96
  )
97
- return [CompatArray(arr, xp=self._xp) for arr in result]
97
+ return [CompatArray(arr, xp=self) for arr in result]
98
98
 
99
99
  # === Data Type functions ===
100
100
  def can_cast(self, from_, to, /):
@@ -170,7 +170,7 @@ class CompatNamespace(Compat):
170
170
  """
171
171
  return self._get_xp_attr("iinfo")(unwrap(type_))
172
172
 
173
- def isdtype(self, dtype, kind) -> bool:
173
+ def isdtype(self, dtype, kind):
174
174
  """
175
175
  Returns a boolean indicating whether a provided :param:`dtype` is of a specified data type :param:`kind`.
176
176
 
@@ -214,7 +214,7 @@ class CompatNamespace(Compat):
214
214
  DType
215
215
  The dtype resulting from an operation involving the input arrays, scalars, and/or dtypes.
216
216
  """
217
- return self._get_xp_attr("result_type")(*tuple(unwrap(arr) for arr in arrays_and_dtypes))
217
+ return self._get_xp_attr("result_type")(*[unwrap(arr) for arr in arrays_and_dtypes])
218
218
 
219
219
  # === Manipulation functions ===
220
220
  def broadcast_arrays(self, *arrays):
@@ -233,8 +233,8 @@ class CompatNamespace(Compat):
233
233
  Each array must have the same shape.
234
234
  Each array must have the same dtype as its corresponding input array.
235
235
  """
236
- result = self._get_xp_attr("broadcast_arrays")(*tuple(unwrap(arr) for arr in arrays))
237
- return [CompatArray(arr, xp=self._xp) for arr in result]
236
+ result = self._get_xp_attr("broadcast_arrays")(*[unwrap(arr) for arr in arrays])
237
+ return [CompatArray(arr, xp=self) for arr in result]
238
238
 
239
239
  # === Constants ===
240
240
  @property
@@ -315,16 +315,37 @@ class CompatNamespace(Compat):
315
315
  def __name__(self):
316
316
  return "(compat)" + getattr(self._xp, "__name__", type(self._xp).__name__)
317
317
 
318
- def __getattr__(self, name: str):
318
+ def __getattr__(self, name):
319
319
  attr = self._get_xp_attr(name)
320
320
 
321
321
  if callable(attr):
322
- def wrapper(*args, **kwargs):
323
- if not args and not kwargs:
324
- return wrap_arraylike(attr(), xp=self._xp)
325
-
326
- new_args = tuple(unwrap(a) for a in args)
327
- new_kwargs = {k: unwrap(v) for k, v in kwargs.items()} if kwargs else kwargs
328
- return wrap_arraylike(attr(*new_args, **new_kwargs), xp=self._xp)
329
- return wrapper
322
+ wrapped = _make_wrapper(self._xp_name, attr, self)
323
+ self.__dict__[name] = wrapped
324
+ return wrapped
330
325
  raise CompatNamespaceAttributeError(f"`CompatNamespace` `{self._xp_name}` does not support attribute `{name}`.")
326
+
327
+
328
+ def _make_wrapper(xp_name, attr, cxp, /):
329
+ """Make a wrapper function for the attribute `name` of the `array namespace`."""
330
+ wrap_ = lambda x: wrap_arraylike(x, xp=cxp)
331
+ unwrap_ = unwrap
332
+
333
+ def wrapper(*args, **kwargs):
334
+ if xp_name == "NumPy":
335
+ return wrap_(attr(*args, **kwargs))
336
+
337
+ n = len(args)
338
+ if n == 0 and not kwargs:
339
+ return wrap_(attr())
340
+ if n == 1 and not kwargs:
341
+ return wrap_(attr(unwrap_(args[0])))
342
+ if n == 2 and not kwargs:
343
+ a0, a1 = args
344
+ return wrap_(attr(unwrap_(a0), unwrap_(a1)))
345
+ new_args = [unwrap_(a) for a in args]
346
+ if kwargs:
347
+ new_kwargs = {k: unwrap_(v) for k, v in kwargs.items()}
348
+ return wrap_(attr(*new_args, **new_kwargs))
349
+ return wrap_(attr(*new_args))
350
+ wrapper.__name__ = getattr(attr, "__name__", "wrapper")
351
+ return wrapper
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cobra-array
3
- Version: 0.1.4
3
+ Version: 0.2.0
4
4
  Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
5
5
  Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
@@ -0,0 +1,482 @@
1
+ import sys
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import time
7
+ from types import ModuleType
8
+
9
+ sys.path.insert(0, "/data/tianzhen/my_packages/cobra-array/src")
10
+
11
+ from cobra_array.array_api import torch_xp
12
+ from cobra_array.compat import CompatArray, unwrap, wrap_arraylike
13
+ from cobra_array.exceptions import CompatArrayAttributeError, NotArrayAPIObjectError
14
+ from cobra_array.compat._array import UniqueResult
15
+
16
+
17
+ def _arr_1d():
18
+ return CompatArray(np.array([1.0, 2.0, 3.0], dtype=np.float32))
19
+
20
+
21
+ def _arr_2d():
22
+ return CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
23
+
24
+
25
+ def test_from_other_with_numpy_namespace():
26
+ out = CompatArray.from_other([1, 2, 3], xp="numpy")
27
+ assert isinstance(out, CompatArray)
28
+ assert out.xp_name == "NumPy"
29
+ assert out.to_list() == [1, 2, 3]
30
+
31
+
32
+ def test_new_from_compatarray_copy_and_no_copy():
33
+ a = _arr_1d()
34
+ b = CompatArray(a)
35
+ c = CompatArray(a, copy=True)
36
+
37
+ assert b is a
38
+ assert c is not a
39
+ assert c.to_list() == a.to_list()
40
+
41
+
42
+ def test_new_invalid_input_raises():
43
+ with pytest.raises(NotArrayAPIObjectError):
44
+ CompatArray("not-array") # type: ignore[arg-type]
45
+
46
+
47
+ def test_to_numpy_to_list_and_array_protocol():
48
+ a = _arr_1d()
49
+
50
+ n1 = a.to_numpy(copy=False)
51
+ n2 = np.asarray(a)
52
+
53
+ assert isinstance(n1, np.ndarray)
54
+ assert isinstance(n2, np.ndarray)
55
+ assert n1.tolist() == [1.0, 2.0, 3.0]
56
+ assert n2.tolist() == [1.0, 2.0, 3.0]
57
+
58
+
59
+ def test_to_device_on_numpy_backend():
60
+ a = _arr_1d()
61
+ moved = a.to_device("cpu")
62
+
63
+ assert isinstance(moved, np.ndarray)
64
+ assert moved.tolist() == a.to_list()
65
+
66
+
67
+ @pytest.mark.skipif(torch_xp is None, reason="PyTorch not available")
68
+ def test_to_tensor_available_backend():
69
+ a = _arr_1d()
70
+ t = a.to_tensor(device="cpu")
71
+
72
+ assert isinstance(t, torch_xp.Tensor)
73
+ assert tuple(t.shape) == (3,)
74
+
75
+
76
+ def test_unstack_and_nonzero():
77
+ a = _arr_2d()
78
+
79
+ pieces = a.unstack(axis=0)
80
+ nz = a.nonzero()
81
+
82
+ assert isinstance(pieces, tuple)
83
+ assert len(pieces) == 2
84
+ assert all(isinstance(x, CompatArray) for x in pieces)
85
+ assert pieces[0].to_list() == [1.0, 2.0]
86
+
87
+ assert isinstance(nz, tuple)
88
+ assert len(nz) == 2
89
+ assert all(isinstance(x, CompatArray) for x in nz)
90
+
91
+
92
+ def test_unique_all_unique_counts_unique_inverse():
93
+ a = CompatArray(np.array([1, 2, 2, 1, 3], dtype=np.int32))
94
+
95
+ print(type(a) is CompatArray)
96
+
97
+ start = time.process_time()
98
+ all_result = a.cxp.unique_all(a)
99
+ # counts_result = a.unique_counts()
100
+ # inverse_result = a.unique_inverse()
101
+
102
+ # print(all_result.values)
103
+
104
+ # assert isinstance(all_result.values, np.ndarray)
105
+ # assert isinstance(all_result.indices, np.ndarray)
106
+ # assert isinstance(all_result.inverse_indices, np.ndarray)
107
+ # assert isinstance(all_result.counts, np.ndarray)
108
+
109
+ print(f"(cxp1)Unique operations test took {time.process_time() - start:.5f} seconds")
110
+
111
+
112
+ start = time.process_time()
113
+ all_result = a.unique_all()
114
+ # counts_result = a.unique_counts()
115
+ # inverse_result = a.unique_inverse()
116
+
117
+ # assert isinstance(all_result.values, CompatArray)
118
+ # assert isinstance(all_result.indices, CompatArray)
119
+ # assert isinstance(all_result.inverse_indices, CompatArray)
120
+ # assert isinstance(all_result.counts, CompatArray)
121
+
122
+ print(f"Unique operations test took {time.process_time() - start:.5f} seconds")
123
+
124
+ start = time.process_time()
125
+ all_result = a.cxp.unique_all(a)
126
+ # counts_result = a.unique_counts()
127
+ # inverse_result = a.unique_inverse()
128
+
129
+ # print(all_result.values)
130
+
131
+ # assert isinstance(all_result.values, np.ndarray)
132
+ # assert isinstance(all_result.indices, np.ndarray)
133
+ # assert isinstance(all_result.inverse_indices, np.ndarray)
134
+ # assert isinstance(all_result.counts, np.ndarray)
135
+
136
+ print(f"(cxp)Unique operations test took {time.process_time() - start:.5f} seconds")
137
+
138
+ # assert isinstance(counts_result.values, CompatArray)
139
+ # assert isinstance(counts_result.counts, CompatArray)
140
+
141
+ # assert isinstance(inverse_result.values, CompatArray)
142
+ # assert isinstance(inverse_result.inverse_indices, CompatArray)
143
+
144
+ a = np.array([1, 2, 2, 1, 3], dtype=np.int32)
145
+
146
+ start = time.process_time()
147
+ all_result = np.unique(a, return_index=True, return_inverse=True, return_counts=True)
148
+
149
+ UniqueResult(
150
+ CompatArray(all_result[0], xp=np),
151
+ CompatArray(all_result[1], xp=np),
152
+ CompatArray(all_result[2], xp=np),
153
+ CompatArray(all_result[3], xp=np)
154
+ )
155
+
156
+ print(f"(RAW)Unique operations test took {time.process_time() - start:.5f} seconds")
157
+
158
+ # start = time.process_time()
159
+ # all_result = torch_xp.unique(torch.tensor(a), return_inverse=True, return_counts=True)
160
+
161
+ # print(f"API operations test took {time.process_time() - start:.4f} seconds")
162
+
163
+ from array_api_compat import get_namespace
164
+ xp = get_namespace(a)
165
+
166
+ start = time.process_time()
167
+ all_result = xp.unique_all(a)
168
+
169
+ UniqueResult(
170
+ CompatArray(all_result.values, xp=xp),
171
+ CompatArray(all_result.indices, xp=xp),
172
+ CompatArray(all_result.inverse_indices, xp=xp),
173
+ CompatArray(all_result.counts, xp=xp)
174
+ )
175
+ assert isinstance(all_result.values, xp.ndarray)
176
+ assert isinstance(all_result.indices, xp.ndarray)
177
+ assert isinstance(all_result.inverse_indices, xp.ndarray)
178
+ assert isinstance(all_result.counts, xp.ndarray)
179
+
180
+ print(f"API operations test took {time.process_time() - start:.5f} seconds")
181
+
182
+ print(type(xp) is ModuleType)
183
+
184
+
185
+ def test_add():
186
+
187
+ a = CompatArray(np.array([1, 2, 2, 1, 3], dtype=np.int32))
188
+ print(type(a) is CompatArray)
189
+
190
+ start = time.time()
191
+ all_result = a.add(1)
192
+ all_result = a.add(1)
193
+ # counts_result = a.unique_counts()
194
+ # inverse_result = a.unique_inverse()
195
+
196
+ print(f"add operations test took {time.time() - start:.5f} seconds")
197
+
198
+ start = time.time()
199
+ all_result = a.cxp.add(a, 1)
200
+ all_result = a.cxp.add(a, 1)
201
+
202
+ print(f"(cxp)add operations test took {time.time() - start:.5f} seconds")
203
+
204
+ a = np.array([1, 2, 2, 1, 3], dtype=np.int32)
205
+
206
+ start = time.time()
207
+ all_result = np.add(a, 1)
208
+
209
+ print(f"(RAW)add operations test took {time.time() - start:.5f} seconds")
210
+
211
+ from array_api_compat import get_namespace
212
+ xp = get_namespace(a)
213
+
214
+ start = time.time()
215
+ all_result = xp.add(a, 1)
216
+
217
+ print(f"API operations test took {time.time() - start:.5f} seconds")
218
+
219
+
220
+ def test_copy_and_basic_properties():
221
+ a = _arr_2d()
222
+ b = a.copy()
223
+
224
+ assert isinstance(a.arr, np.ndarray)
225
+ assert isinstance(b, CompatArray)
226
+ assert b is not a
227
+
228
+ assert str(a.device) == "cpu"
229
+ assert a.shape == (2, 2)
230
+ assert a.ndim == 2
231
+ assert a.size == 4
232
+
233
+
234
+ def test_transpose_properties():
235
+ a = _arr_2d()
236
+
237
+ t = a.T
238
+
239
+ assert isinstance(t, CompatArray)
240
+ assert t.shape == (2, 2)
241
+
242
+ with pytest.raises(AttributeError):
243
+ _ = a.mT
244
+
245
+
246
+ def test_len_and_repr():
247
+ a = _arr_1d()
248
+ z = CompatArray(np.array(5))
249
+
250
+ assert len(a) == 3
251
+ assert "NumPy_Array" in repr(a)
252
+
253
+ with pytest.raises(TypeError):
254
+ len(z)
255
+
256
+
257
+ def test_getitem_and_setitem():
258
+ a = CompatArray(np.array([10, 20, 30], dtype=np.int32))
259
+
260
+ assert int(a[1]) == 20
261
+
262
+ a[1] = 99
263
+ assert a.to_list() == [10, 99, 30]
264
+
265
+
266
+ def test_scalar_conversions():
267
+ b = CompatArray(np.array(True))
268
+ i = CompatArray(np.array(7))
269
+ f = CompatArray(np.array(1.5, dtype=np.float32))
270
+ c = CompatArray(np.array(2 + 3j, dtype=np.complex64))
271
+
272
+ assert bool(b) is True
273
+ assert int(i) == 7
274
+ assert i.__index__() == 7
275
+ assert float(f) == pytest.approx(1.5)
276
+ assert complex(c) == complex(2, 3)
277
+
278
+
279
+ def test_numpy_operator_overloads():
280
+
281
+
282
+ a = np.array([1, 2, 3], dtype=np.int32)
283
+ b = np.array([3, 2, 1], dtype=np.int32)
284
+
285
+ start = time.process_time()
286
+ assert a.__abs__().tolist() == [1, 2, 3]
287
+ assert (a + b).tolist() == [4, 4, 4]
288
+ assert (a - b).tolist() == [-2, 0, 2]
289
+ assert (a * b).tolist() == [3, 4, 3]
290
+ assert (a / 2).tolist() == [0.5, 1.0, 1.5]
291
+ assert (a // 2).tolist() == [0, 1, 1]
292
+ assert (a % 2).tolist() == [1, 0, 1]
293
+ assert (a & b).tolist() == [1, 2, 1]
294
+ assert (a | b).tolist() == [3, 2, 3]
295
+ assert (a ^ b).tolist() == [2, 0, 2]
296
+ assert (a << 1).tolist() == [2, 4, 6]
297
+ assert (a >> 1).tolist() == [0, 1, 1]
298
+
299
+ assert (-a).tolist() == [-1, -2, -3]
300
+ assert (+a).tolist() == [1, 2, 3]
301
+ assert (~a).tolist() == [-2, -3, -4]
302
+
303
+ assert (a == b).tolist() == [False, True, False]
304
+ assert (a != b).tolist() == [True, False, True]
305
+ assert (a > b).tolist() == [False, False, True]
306
+ assert (a >= b).tolist() == [False, True, True]
307
+ assert (a < b).tolist() == [True, False, False]
308
+ assert (a <= b).tolist() == [True, True, False]
309
+
310
+ print(f"NumPy operator overloads test took {time.process_time() - start:.4f} seconds")
311
+
312
+ a = CompatArray(np.array([1, 2, 3], dtype=np.int32))
313
+ b = CompatArray(np.array([3, 2, 1], dtype=np.int32))
314
+
315
+ start = time.process_time()
316
+ assert a.__abs__().to_list() == [1, 2, 3]
317
+ assert (a + b).to_list() == [4, 4, 4]
318
+ assert (a - b).to_list() == [-2, 0, 2]
319
+ assert (a * b).to_list() == [3, 4, 3]
320
+ assert (a / 2).to_list() == [0.5, 1.0, 1.5]
321
+ assert (a // 2).to_list() == [0, 1, 1]
322
+ assert (a % 2).to_list() == [1, 0, 1]
323
+ assert (a & b).to_list() == [1, 2, 1]
324
+ assert (a | b).to_list() == [3, 2, 3]
325
+ assert (a ^ b).to_list() == [2, 0, 2]
326
+ assert (a << 1).to_list() == [2, 4, 6]
327
+ assert (a >> 1).to_list() == [0, 1, 1]
328
+
329
+ assert (-a).to_list() == [-1, -2, -3]
330
+ assert (+a).to_list() == [1, 2, 3]
331
+ assert (~a).to_list() == [-2, -3, -4]
332
+
333
+ assert (a == b).to_list() == [False, True, False]
334
+ assert (a != b).to_list() == [True, False, True]
335
+ assert (a > b).to_list() == [False, False, True]
336
+ assert (a >= b).to_list() == [False, True, True]
337
+ assert (a < b).to_list() == [True, False, False]
338
+ assert (a <= b).to_list() == [True, True, False]
339
+
340
+ print(f"Operator overloads test took {time.process_time() - start:.4f} seconds")
341
+
342
+
343
+ # def test_torch_operator_overloads():
344
+ # a = CompatArray(torch.tensor([1, 2, 3], dtype=torch.int32))
345
+ # b = CompatArray(torch.tensor([3, 2, 1], dtype=torch.int32))
346
+
347
+ # start = time.process_time()
348
+ # assert a.__abs__().to_list() == [1, 2, 3]
349
+ # assert (a + b).to_list() == [4, 4, 4]
350
+ # assert (a - b).to_list() == [-2, 0, 2]
351
+ # assert (a * b).to_list() == [3, 4, 3]
352
+ # assert (a / 2).to_list() == [0.5, 1.0, 1.5]
353
+ # assert (a // 2).to_list() == [0, 1, 1]
354
+ # assert (a % 2).to_list() == [1, 0, 1]
355
+ # assert (a & b).to_list() == [1, 2, 1]
356
+ # assert (a | b).to_list() == [3, 2, 3]
357
+ # assert (a ^ b).to_list() == [2, 0, 2]
358
+ # assert (a << 1).to_list() == [2, 4, 6]
359
+ # assert (a >> 1).to_list() == [0, 1, 1]
360
+
361
+ # assert (-a).to_list() == [-1, -2, -3]
362
+ # assert (+a).to_list() == [1, 2, 3]
363
+ # assert (~a).to_list() == [-2, -3, -4]
364
+
365
+ # assert (a == b).to_list() == [False, True, False]
366
+ # assert (a != b).to_list() == [True, False, True]
367
+ # assert (a > b).to_list() == [False, False, True]
368
+ # assert (a >= b).to_list() == [False, True, True]
369
+ # assert (a < b).to_list() == [True, False, False]
370
+ # assert (a <= b).to_list() == [True, True, False]
371
+
372
+ # print(f"Operator overloads test took {time.process_time() - start:.4f} seconds")
373
+
374
+ # a = torch.tensor([1, 2, 3], dtype=torch.int32)
375
+ # b = torch.tensor([3, 2, 1], dtype=torch.int32)
376
+
377
+ # start = time.process_time()
378
+ # assert a.__abs__().tolist() == [1, 2, 3]
379
+ # assert (a + b).tolist() == [4, 4, 4]
380
+ # assert (a - b).tolist() == [-2, 0, 2]
381
+ # assert (a * b).tolist() == [3, 4, 3]
382
+ # assert (a / 2).tolist() == [0.5, 1.0, 1.5]
383
+ # assert (a // 2).tolist() == [0, 1, 1]
384
+ # assert (a % 2).tolist() == [1, 0, 1]
385
+ # assert (a & b).tolist() == [1, 2, 1]
386
+ # assert (a | b).tolist() == [3, 2, 3]
387
+ # assert (a ^ b).tolist() == [2, 0, 2]
388
+ # assert (a << 1).tolist() == [2, 4, 6]
389
+ # assert (a >> 1).tolist() == [0, 1, 1]
390
+
391
+ # assert (-a).tolist() == [-1, -2, -3]
392
+ # assert (+a).tolist() == [1, 2, 3]
393
+ # assert (~a).tolist() == [-2, -3, -4]
394
+
395
+ # assert (a == b).tolist() == [False, True, False]
396
+ # assert (a != b).tolist() == [True, False, True]
397
+ # assert (a > b).tolist() == [False, False, True]
398
+ # assert (a >= b).tolist() == [False, True, True]
399
+ # assert (a < b).tolist() == [True, False, False]
400
+ # assert (a <= b).tolist() == [True, True, False]
401
+
402
+ # print(f"Torch operator overloads test took {time.process_time() - start:.4f} seconds")
403
+
404
+
405
+ def test_pow_operator_matches_current_implementation_behavior():
406
+ a = _arr_1d()
407
+ out = a ** 2
408
+ assert isinstance(out, CompatArray)
409
+ assert out.to_list() == [1.0, 4.0, 9.0]
410
+
411
+
412
+ def test_matmul_operator():
413
+ x = CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
414
+ y = CompatArray(np.array([[2.0, 0.0], [1.0, 2.0]], dtype=np.float32))
415
+
416
+ out = x @ y
417
+
418
+ assert isinstance(out, CompatArray)
419
+ assert out.to_list() == [[4.0, 4.0], [10.0, 8.0]]
420
+
421
+
422
+ def test_getattr_rejects_non_callable_namespace_attrs():
423
+ a = _arr_1d()
424
+ with pytest.raises(CompatArrayAttributeError):
425
+ _ = getattr(a, "pi")
426
+
427
+
428
+ def test_module_helpers_unwrap_wrap_to_cxp():
429
+ raw = np.array([1, 2, 3])
430
+ wrapped = wrap_arraylike(raw)
431
+
432
+ assert isinstance(wrapped, CompatArray)
433
+ assert unwrap(wrapped) is raw
434
+ assert unwrap(raw) is raw
435
+
436
+ obj: Any = {"k": 1}
437
+ assert wrap_arraylike(obj) is obj
438
+
439
+ out = CompatArray(raw, xp=np)
440
+ assert out.xp_name == "NumPy"
441
+
442
+
443
+ def test_some_pyi_annotated_methods_via_dynamic_dispatch():
444
+ a = _arr_2d()
445
+
446
+ b = a.astype(np.float64)
447
+ c = a.abs()
448
+ d = a.add(1)
449
+ s = a.sum()
450
+ r = a.reshape((4,))
451
+ t = a.take(np.array([0], dtype=np.int32), axis=1)
452
+
453
+ assert isinstance(b, CompatArray)
454
+ assert str(b.dtype) == "float64"
455
+ assert isinstance(c, CompatArray)
456
+ assert isinstance(d, CompatArray)
457
+ assert isinstance(s, CompatArray)
458
+ assert isinstance(r, CompatArray)
459
+ assert isinstance(t, CompatArray)
460
+
461
+ assert r.shape == (4,)
462
+ assert t.shape == (2, 1)
463
+
464
+
465
+ def test_cxp_of_compatarray_matches_array_namespace():
466
+ a = _arr_1d()
467
+ cxp = a.cxp
468
+
469
+ assert cxp is not None
470
+ assert cxp.xp_name == a.xp_name
471
+
472
+
473
+ if __name__ == "__main__":
474
+ test_unique_all_unique_counts_unique_inverse()
475
+ print("=" * 40)
476
+
477
+ test_numpy_operator_overloads()
478
+ print("=" * 40)
479
+ # test_torch_operator_overloads()
480
+ print("=" * 40)
481
+
482
+ test_add()
@@ -1,261 +0,0 @@
1
- import sys
2
- from typing import Any
3
-
4
- import numpy as np
5
- import pytest
6
-
7
- sys.path.insert(0, "/data/tianzhen/my_packages/cobra-array/src")
8
-
9
- from cobra_array.array_api import torch_xp
10
- from cobra_array.compat import CompatArray, unwrap, wrap_arraylike
11
- from cobra_array.exceptions import CompatArrayAttributeError, NotArrayAPIObjectError
12
-
13
-
14
- def _arr_1d():
15
- return CompatArray(np.array([1.0, 2.0, 3.0], dtype=np.float32))
16
-
17
-
18
- def _arr_2d():
19
- return CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
20
-
21
-
22
- def test_from_other_with_numpy_namespace():
23
- out = CompatArray.from_other([1, 2, 3], xp="numpy")
24
- assert isinstance(out, CompatArray)
25
- assert out.xp_name == "NumPy"
26
- assert out.to_list() == [1, 2, 3]
27
-
28
-
29
- def test_new_from_compatarray_copy_and_no_copy():
30
- a = _arr_1d()
31
- b = CompatArray(a)
32
- c = CompatArray(a, copy=True)
33
-
34
- assert b is a
35
- assert c is not a
36
- assert c.to_list() == a.to_list()
37
-
38
-
39
- def test_new_invalid_input_raises():
40
- with pytest.raises(NotArrayAPIObjectError):
41
- CompatArray("not-array") # type: ignore[arg-type]
42
-
43
-
44
- def test_to_numpy_to_list_and_array_protocol():
45
- a = _arr_1d()
46
-
47
- n1 = a.to_numpy(copy=False)
48
- n2 = np.asarray(a)
49
-
50
- assert isinstance(n1, np.ndarray)
51
- assert isinstance(n2, np.ndarray)
52
- assert n1.tolist() == [1.0, 2.0, 3.0]
53
- assert n2.tolist() == [1.0, 2.0, 3.0]
54
-
55
-
56
- def test_to_device_on_numpy_backend():
57
- a = _arr_1d()
58
- moved = a.to_device("cpu")
59
-
60
- assert isinstance(moved, np.ndarray)
61
- assert moved.tolist() == a.to_list()
62
-
63
-
64
- @pytest.mark.skipif(torch_xp is None, reason="PyTorch not available")
65
- def test_to_tensor_available_backend():
66
- a = _arr_1d()
67
- t = a.to_tensor(device="cpu")
68
-
69
- assert isinstance(t, torch_xp.Tensor)
70
- assert tuple(t.shape) == (3,)
71
-
72
-
73
- def test_unstack_and_nonzero():
74
- a = _arr_2d()
75
-
76
- pieces = a.unstack(axis=0)
77
- nz = a.nonzero()
78
-
79
- assert isinstance(pieces, tuple)
80
- assert len(pieces) == 2
81
- assert all(isinstance(x, CompatArray) for x in pieces)
82
- assert pieces[0].to_list() == [1.0, 2.0]
83
-
84
- assert isinstance(nz, tuple)
85
- assert len(nz) == 2
86
- assert all(isinstance(x, CompatArray) for x in nz)
87
-
88
-
89
- def test_unique_all_unique_counts_unique_inverse():
90
- a = CompatArray(np.array([1, 2, 2, 1, 3], dtype=np.int32))
91
-
92
- all_result = a.unique_all()
93
- counts_result = a.unique_counts()
94
- inverse_result = a.unique_inverse()
95
-
96
- assert isinstance(all_result.values, CompatArray)
97
- assert isinstance(all_result.indices, CompatArray)
98
- assert isinstance(all_result.inverse_indices, CompatArray)
99
- assert isinstance(all_result.counts, CompatArray)
100
-
101
- assert isinstance(counts_result.values, CompatArray)
102
- assert isinstance(counts_result.counts, CompatArray)
103
-
104
- assert isinstance(inverse_result.values, CompatArray)
105
- assert isinstance(inverse_result.inverse_indices, CompatArray)
106
-
107
-
108
- def test_copy_and_basic_properties():
109
- a = _arr_2d()
110
- b = a.copy()
111
-
112
- assert isinstance(a.arr, np.ndarray)
113
- assert isinstance(b, CompatArray)
114
- assert b is not a
115
-
116
- assert str(a.device) == "cpu"
117
- assert a.shape == (2, 2)
118
- assert a.ndim == 2
119
- assert a.size == 4
120
-
121
-
122
- def test_transpose_properties():
123
- a = _arr_2d()
124
-
125
- t = a.T
126
-
127
- assert isinstance(t, CompatArray)
128
- assert t.shape == (2, 2)
129
-
130
- with pytest.raises(AttributeError):
131
- _ = a.mT
132
-
133
-
134
- def test_len_and_repr():
135
- a = _arr_1d()
136
- z = CompatArray(np.array(5))
137
-
138
- assert len(a) == 3
139
- assert "NumPy_Array" in repr(a)
140
-
141
- with pytest.raises(TypeError):
142
- len(z)
143
-
144
-
145
- def test_getitem_and_setitem():
146
- a = CompatArray(np.array([10, 20, 30], dtype=np.int32))
147
-
148
- assert int(a[1]) == 20
149
-
150
- a[1] = 99
151
- assert a.to_list() == [10, 99, 30]
152
-
153
-
154
- def test_scalar_conversions():
155
- b = CompatArray(np.array(True))
156
- i = CompatArray(np.array(7))
157
- f = CompatArray(np.array(1.5, dtype=np.float32))
158
- c = CompatArray(np.array(2 + 3j, dtype=np.complex64))
159
-
160
- assert bool(b) is True
161
- assert int(i) == 7
162
- assert i.__index__() == 7
163
- assert float(f) == pytest.approx(1.5)
164
- assert complex(c) == complex(2, 3)
165
-
166
-
167
- def test_operator_overloads():
168
- a = CompatArray(np.array([1, 2, 3], dtype=np.int32))
169
- b = CompatArray(np.array([3, 2, 1], dtype=np.int32))
170
-
171
- assert a.__abs__().to_list() == [1, 2, 3]
172
- assert (a + b).to_list() == [4, 4, 4]
173
- assert (a - b).to_list() == [-2, 0, 2]
174
- assert (a * b).to_list() == [3, 4, 3]
175
- assert (a / 2).to_list() == [0.5, 1.0, 1.5]
176
- assert (a // 2).to_list() == [0, 1, 1]
177
- assert (a % 2).to_list() == [1, 0, 1]
178
- assert (a & b).to_list() == [1, 2, 1]
179
- assert (a | b).to_list() == [3, 2, 3]
180
- assert (a ^ b).to_list() == [2, 0, 2]
181
- assert (a << 1).to_list() == [2, 4, 6]
182
- assert (a >> 1).to_list() == [0, 1, 1]
183
-
184
- assert (-a).to_list() == [-1, -2, -3]
185
- assert (+a).to_list() == [1, 2, 3]
186
- assert (~a).to_list() == [-2, -3, -4]
187
-
188
- assert (a == b).to_list() == [False, True, False]
189
- assert (a != b).to_list() == [True, False, True]
190
- assert (a > b).to_list() == [False, False, True]
191
- assert (a >= b).to_list() == [False, True, True]
192
- assert (a < b).to_list() == [True, False, False]
193
- assert (a <= b).to_list() == [True, True, False]
194
-
195
-
196
- def test_pow_operator_matches_current_implementation_behavior():
197
- a = _arr_1d()
198
- out = a ** 2
199
- assert isinstance(out, CompatArray)
200
- assert out.to_list() == [1.0, 4.0, 9.0]
201
-
202
-
203
- def test_matmul_operator():
204
- x = CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
205
- y = CompatArray(np.array([[2.0, 0.0], [1.0, 2.0]], dtype=np.float32))
206
-
207
- out = x @ y
208
-
209
- assert isinstance(out, CompatArray)
210
- assert out.to_list() == [[4.0, 4.0], [10.0, 8.0]]
211
-
212
-
213
- def test_getattr_rejects_non_callable_namespace_attrs():
214
- a = _arr_1d()
215
- with pytest.raises(CompatArrayAttributeError):
216
- _ = getattr(a, "pi")
217
-
218
-
219
- def test_module_helpers_unwrap_wrap_to_cxp():
220
- raw = np.array([1, 2, 3])
221
- wrapped = wrap_arraylike(raw)
222
-
223
- assert isinstance(wrapped, CompatArray)
224
- assert unwrap(wrapped) is raw
225
- assert unwrap(raw) is raw
226
-
227
- obj: Any = {"k": 1}
228
- assert wrap_arraylike(obj) is obj
229
-
230
- out = CompatArray(raw, xp=np)
231
- assert out.xp_name == "NumPy"
232
-
233
-
234
- def test_some_pyi_annotated_methods_via_dynamic_dispatch():
235
- a = _arr_2d()
236
-
237
- b = a.astype(np.float64)
238
- c = a.abs()
239
- d = a.add(1)
240
- s = a.sum()
241
- r = a.reshape((4,))
242
- t = a.take(np.array([0], dtype=np.int32), axis=1)
243
-
244
- assert isinstance(b, CompatArray)
245
- assert str(b.dtype) == "float64"
246
- assert isinstance(c, CompatArray)
247
- assert isinstance(d, CompatArray)
248
- assert isinstance(s, CompatArray)
249
- assert isinstance(r, CompatArray)
250
- assert isinstance(t, CompatArray)
251
-
252
- assert r.shape == (4,)
253
- assert t.shape == (2, 1)
254
-
255
-
256
- def test_cxp_of_compatarray_matches_array_namespace():
257
- a = _arr_1d()
258
- cxp = a.cxp
259
-
260
- assert cxp is not None
261
- assert cxp.xp_name == a.xp_name
File without changes
File without changes
File without changes
File without changes