cobra-array 0.1.3__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {cobra_array-0.1.3/src/cobra_array.egg-info → cobra_array-0.1.5}/PKG-INFO +1 -1
  2. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/__init__.py +1 -5
  3. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/_core.py +98 -52
  4. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/compat/__init__.py +1 -1
  5. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/compat/_array.py +8 -1
  6. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/compat/_array.pyi +5 -0
  7. {cobra_array-0.1.3 → cobra_array-0.1.5/src/cobra_array.egg-info}/PKG-INFO +1 -1
  8. {cobra_array-0.1.3 → cobra_array-0.1.5}/tests/test_backend.py +22 -18
  9. {cobra_array-0.1.3 → cobra_array-0.1.5}/tests/test_compat.py +8 -0
  10. {cobra_array-0.1.3 → cobra_array-0.1.5}/LICENSE +0 -0
  11. {cobra_array-0.1.3 → cobra_array-0.1.5}/README.md +0 -0
  12. {cobra_array-0.1.3 → cobra_array-0.1.5}/pyproject.toml +0 -0
  13. {cobra_array-0.1.3 → cobra_array-0.1.5}/setup.cfg +0 -0
  14. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/_utils.py +0 -0
  15. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/array_api.py +0 -0
  16. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/compat/_base.py +0 -0
  17. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/compat/_namespace.py +0 -0
  18. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/compat/_namespace.pyi +0 -0
  19. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/convert.py +0 -0
  20. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/convert.pyi +0 -0
  21. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/default.py +0 -0
  22. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/exceptions.py +0 -0
  23. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array/types.py +0 -0
  24. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array.egg-info/SOURCES.txt +0 -0
  25. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array.egg-info/dependency_links.txt +0 -0
  26. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array.egg-info/requires.txt +0 -0
  27. {cobra_array-0.1.3 → cobra_array-0.1.5}/src/cobra_array.egg-info/top_level.txt +0 -0
  28. {cobra_array-0.1.3 → cobra_array-0.1.5}/tests/test_compat_namespace.py +0 -0
  29. {cobra_array-0.1.3 → cobra_array-0.1.5}/tests/test_convert.py +0 -0
  30. {cobra_array-0.1.3 → cobra_array-0.1.5}/tests/test_default.py +0 -0
  31. {cobra_array-0.1.3 → cobra_array-0.1.5}/tests/test_wrap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cobra-array
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
5
5
  Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
@@ -32,11 +32,8 @@ Examples
32
32
  from cobra_array.convert import to_numpy, to_tensor, to_list
33
33
 
34
34
  data = [[1, 2], [3, 4]]
35
-
36
35
  arr_np = to_numpy(data, dtype=np.float32) # numpy.ndarray float32
37
-
38
36
  arr_torch = to_tensor(data, device="cpu")
39
-
40
37
  back_to_list = to_list(arr_np) # [[1.0, 2.0], [3.0, 4.0]]
41
38
 
42
39
  - Context-based conversion::
@@ -66,7 +63,6 @@ Examples
66
63
  from cobra_array.default import as_default, default_spec
67
64
 
68
65
  spec = default_spec()
69
-
70
66
  x = as_default([1, 2, 3], unify_dtype=True, unify_device=True)
71
67
  """
72
68
 
@@ -84,7 +80,7 @@ from ._utils import (
84
80
  )
85
81
 
86
82
  __author__ = "Zhen Tian"
87
- __version__ = "0.1.3"
83
+ __version__ = "0.1.5"
88
84
 
89
85
  __all__ = [
90
86
  "array_spec",
@@ -15,7 +15,6 @@ from .default import (ArraySpec, default_spec)
15
15
  from .exceptions import (
16
16
  NoArrayInputsError,
17
17
  GetArrayNamespaceError,
18
- MissingDependencyError,
19
18
  NotArrayAPIObjectError
20
19
  )
21
20
 
@@ -25,11 +24,19 @@ if TYPE_CHECKING:
25
24
  from .types import (T, dtypeT, DType, AnyDevice, ArrayLike, ArrayLibraryName)
26
25
 
27
26
 
27
+ def _fallback_default_spec(fallback: bool, exc: Exception, /, from_: Optional[BaseException] = None) -> ArraySpec:
28
+ """Return the default array specification when fallback is enabled, otherwise re-raise the original exception."""
29
+ if fallback:
30
+ return default_spec()
31
+ raise exc from from_
32
+
33
+
28
34
  def array_spec(
29
35
  *arrays: object,
30
36
  kw_arrays: Optional[Dict[str, object]] = None,
31
37
  ref: Optional[Union[str, int]] = None,
32
38
  filter_arraylike: bool = False,
39
+ fallback: bool = False,
33
40
  api_version: Optional[str] = None,
34
41
  use_compat: Optional[bool] = None
35
42
  ) -> ArraySpec:
@@ -53,6 +60,11 @@ def array_spec(
53
60
  filter_arraylike : bool, default to `False`
54
61
  Whether to filter the provided inputs to all array-likes via :func:`array_api_compat.is_array_api_obj` when determining the `compatibility namespace`.
55
62
 
63
+ fallback : bool, default to `False`
64
+ Whether to fall back to the default `compatibility namespace` when failing to determine the `compatibility namespace` from the provided inputs or the reference array.
65
+ - `True`: Return the default `compatibility namespace` from :func:`default.default_spec` when an error occurs;
66
+ - `False`: Raise the original exception when an error occurs.
67
+
56
68
  api_version : Optional[str], default to `None`
57
69
  The target array API version for the returned `compatibility namespace`. See also :param:`api_version` in :func:`array_api_compat.array_namespace`.
58
70
 
@@ -73,18 +85,22 @@ def array_spec(
73
85
 
74
86
  Raises
75
87
  ------
88
+ Refer to :func:`default.default_spec` for possible exceptions.
89
+
76
90
  NoArrayInputsError
77
- If no array inputs are provided in `arrays` and `kw_arrays`.
91
+ If no array inputs are provided in `arrays` and `kw_arrays`, and :param:`fallback` is `False`.
78
92
  GetArrayNamespaceError
79
- If an error occurs while determining the `compatibility namespace` from the provided inputs or the reference array.
93
+ If an error occurs while determining the `compatibility namespace` from the provided inputs or the reference array, and :param:`fallback` is `False`.
80
94
  KeyError
81
- If `ref` is a string but not a key in `kw_arrays`.
95
+ If `ref` is a string but not a key in `kw_arrays`, and :param:`fallback` is `False`.
82
96
  IndexError
83
- If `ref` is an integer but out of range for the array inputs.
97
+ If `ref` is an integer but out of range for the array inputs, and :param:`fallback` is `False`.
98
+ ValueError
99
+ If `ref` is an integer but negative.
84
100
  TypeError
85
101
  If `ref` is not `None`, a string, or an integer.
86
102
  NotArrayAPIObjectError
87
- If the reference array determined by `ref` is not an array API compatible array object.
103
+ If the reference array determined by `ref` is not an array API compatible array object, and :param:`fallback` is `False`.
88
104
 
89
105
  Examples
90
106
  --------
@@ -116,7 +132,10 @@ def array_spec(
116
132
 
117
133
  if len(arrays) == 0 and len(kw_arrays) == 0:
118
134
  # no arrays provided
119
- raise NoArrayInputsError("Expected at least one array input.")
135
+ return _fallback_default_spec(
136
+ fallback,
137
+ NoArrayInputsError("Expected at least one array input.")
138
+ )
120
139
 
121
140
  # iterator of all arrays
122
141
  all_arrays = chain(arrays, kw_arrays.values())
@@ -134,37 +153,48 @@ def array_spec(
134
153
  ), None, None
135
154
  )
136
155
  except Exception as e:
137
- raise GetArrayNamespaceError(
138
- "Failed to determine the `compatibility namespace` from the provided inputs."
139
- ) from e
156
+ return _fallback_default_spec(
157
+ fallback,
158
+ GetArrayNamespaceError(
159
+ "Failed to determine the `compatibility namespace` from "
160
+ "the provided inputs."
161
+ ), from_=e
162
+ )
140
163
 
141
164
  if isinstance(ref, str):
142
165
  # use the specified kw_array as reference array to determine the namespace
143
166
  try:
144
167
  ref_arr = kw_arrays[ref]
145
168
  except KeyError:
146
- raise KeyError(
147
- f"Parameter `ref` of `array_spec()` must be a key in `kw_arrays`, got {ref!r}."
148
- ) from None
169
+ return _fallback_default_spec(
170
+ fallback,
171
+ KeyError(
172
+ "Parameter `ref` of `array_spec()` must be a key in "
173
+ f"`kw_arrays`, got {ref!r}."
174
+ )
175
+ )
149
176
  elif isinstance(ref, int):
150
177
  # use the specified array as reference array to determine the namespace
151
178
  try:
152
179
  ref_arr = next(islice(all_arrays, ref, ref + 1))
153
180
  except ValueError:
154
- raise IndexError(
181
+ raise ValueError(
155
182
  "Parameter `ref` of `array_spec()` must be a "
156
183
  f"non-negative index for the array inputs, got {ref!r}."
157
184
  ) from None
158
185
  except StopIteration:
159
- if filter_arraylike:
186
+ try:
187
+ if filter_arraylike:
188
+ raise IndexError(
189
+ "Parameter `ref` of `array_spec()` is out of range "
190
+ f"for the array-like inputs, got {ref!r}."
191
+ ) from None
160
192
  raise IndexError(
161
- "Parameter `ref` of `array_spec()` is out of range "
162
- f"for the array-like inputs, got {ref!r}."
193
+ "Parameter `ref` of `array_spec()` must be in the range "
194
+ f"[0, {len(arrays) + len(kw_arrays)}) for the array inputs, got {ref!r}."
163
195
  ) from None
164
- raise IndexError(
165
- "Parameter `ref` of `array_spec()` must be in the range "
166
- f"[0, {len(arrays) + len(kw_arrays)}) for the array inputs, got {ref!r}."
167
- ) from None
196
+ except IndexError as e:
197
+ return _fallback_default_spec(fallback, e)
168
198
  else:
169
199
  raise TypeError(
170
200
  "Parameter `ref` of `array_spec()` must be `str`, "
@@ -172,8 +202,11 @@ def array_spec(
172
202
  )
173
203
 
174
204
  if not filter_arraylike and not api.is_array_api_obj(ref_arr):
175
- raise NotArrayAPIObjectError(
176
- f"Reference array must be an array API compatible array object, got {ref_arr!r}."
205
+ return _fallback_default_spec(
206
+ fallback,
207
+ NotArrayAPIObjectError(
208
+ f"Reference array must be an array API compatible array object, got {ref_arr!r}."
209
+ )
177
210
  )
178
211
 
179
212
  dtype = getattr(ref_arr, "dtype", None)
@@ -187,19 +220,28 @@ def array_spec(
187
220
  ), dtype, device
188
221
  )
189
222
  except Exception as e:
190
- raise GetArrayNamespaceError(
191
- "Failed to determine the `compatibility namespace` from the reference array."
192
- ) from e
223
+ return _fallback_default_spec(
224
+ fallback,
225
+ GetArrayNamespaceError(
226
+ f"Failed to determine the `compatibility namespace` from the reference array {ref_arr!r}."
227
+ ), from_=e
228
+ )
193
229
 
194
230
 
195
231
  # initialize the context variable for array specification
196
232
  _arr_spec_var = ContextVar("arr_spec")
197
233
 
198
234
 
199
- def context_spec() -> ArraySpec:
235
+ def context_spec(fallback: bool = True) -> ArraySpec:
200
236
  """
201
237
  Get the `compatibility namespace`, `dtype` and `device` associated with the most recent :func:`unify_array_args`-decorated function call in the current context.
202
- If there is no such function call in the current context, return the default `compatibility namespace`, `dtype` and `device` from :func:`default.default_spec`.
238
+
239
+ Parameters
240
+ ----------
241
+ fallback : bool, default to `True`
242
+ Whether to fall back to the default `compatibility namespace` when there is no such function call in the current context.
243
+ - `True`: Return the default `compatibility namespace` from :func:`default.default_spec` when the error occurs;
244
+ - `False`: Raise `LookupError` when the error occurs.
203
245
 
204
246
  Returns
205
247
  -------
@@ -209,23 +251,26 @@ def context_spec() -> ArraySpec:
209
251
  Raises
210
252
  ------
211
253
  Refer to :func:`default.default_spec` for possible exceptions.
254
+
255
+ LookupError
256
+ If there is no :func:`unify_array_args`-decorated function call in the current context, and :param:`fallback` is `False`.
212
257
  """
213
258
  try:
214
259
  return _arr_spec_var.get()
215
- except LookupError:
216
- return default_spec()
260
+ except LookupError as e:
261
+ return _fallback_default_spec(fallback, e)
217
262
 
218
263
 
219
264
  @overload
220
- def as_context(obj: NDArray[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, Literal["cpu"]]: ...
265
+ def as_context(obj: NDArray[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ..., fallback: bool = ...) -> CompatArray[dtypeT, Literal["cpu"]]: ...
221
266
  @overload
222
- def as_context(obj: ArrayLike[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[dtypeT, AnyDevice]: ...
267
+ def as_context(obj: ArrayLike[dtypeT], /, *, unify_dtype: Literal[False], unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ..., fallback: bool = ...) -> CompatArray[dtypeT, AnyDevice]: ...
223
268
  @overload
224
- def as_context(obj: ArrayLike[Any], /, *, unify_dtype: Literal[True] = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ...) -> CompatArray[Any, AnyDevice]: ...
269
+ def as_context(obj: ArrayLike[Any], /, *, unify_dtype: Literal[True] = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: bool = ..., fallback: bool = ...) -> CompatArray[Any, AnyDevice]: ...
225
270
  @overload
226
- def as_context(obj: object, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[False] = ...) -> CompatArray[Any, AnyDevice]: ...
271
+ def as_context(obj: object, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[False] = ..., fallback: bool = ...) -> CompatArray[Any, AnyDevice]: ...
227
272
  @overload
228
- def as_context(obj: T, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[True]) -> T: ...
273
+ def as_context(obj: T, /, *, unify_dtype: bool = ..., unify_device: bool = ..., copy: bool = ..., arraylike_only: Literal[True], fallback: bool = ...) -> T: ...
229
274
 
230
275
 
231
276
  def as_context(
@@ -234,7 +279,8 @@ def as_context(
234
279
  unify_dtype: bool = True,
235
280
  unify_device: bool = True,
236
281
  copy: bool = False,
237
- arraylike_only: bool = False
282
+ arraylike_only: bool = False,
283
+ fallback: bool = True
238
284
  ) -> Any:
239
285
  """
240
286
  Convert the given object to a :class:`CompatArray` array in the current context `compatibility namespace`, with the `dtype` and `device` unified to the current context if specified.
@@ -256,6 +302,11 @@ def as_context(
256
302
  arraylike_only : bool, default to `False`
257
303
  Whether to only convert array-like objects to arrays in the current context namespace, and return the object itself if it is not array-like.
258
304
 
305
+ fallback : bool, default to `True`
306
+ Whether to fall back to the default `compatibility namespace` when there is no such function call in the current context.
307
+ - `True`: Return the default `compatibility namespace` from :func:`default.default_spec` when the error occurs;
308
+ - `False`: Raise `LookupError` when the error occurs.
309
+
259
310
  Returns
260
311
  -------
261
312
  CompatArray[Any, AnyDevice]
@@ -273,7 +324,7 @@ def as_context(
273
324
  >>> as_context([1, 2, 3])
274
325
  PyTorch_Array(tensor([1., 2., 3.], dtype=torch.float64))
275
326
  """
276
- spec = context_spec()
327
+ spec = context_spec(fallback=fallback)
277
328
  return wrap_arraylike(as_array(
278
329
  obj, spec.cxp,
279
330
  dtype=spec.dtype if unify_dtype else None,
@@ -356,7 +407,7 @@ class array_context:
356
407
  The target `device` for the context.
357
408
  - `None`: Use the `device` from the context.
358
409
  """
359
- spec = context_spec()
410
+ spec = context_spec(fallback=True)
360
411
  self.cur_spec = ArraySpec.create(
361
412
  to_xp(xp) if xp is not None else spec.cxp,
362
413
  dtype if dtype is not None else spec.dtype,
@@ -382,7 +433,7 @@ def unify_args(
382
433
  unify_dtype: bool = False,
383
434
  unify_device: bool = True,
384
435
  arraylike_only: bool = True,
385
- strict: bool = True
436
+ fallback: bool = True
386
437
  ):
387
438
  """
388
439
  **Decorator** to unify arguments of a function to the same `compatibility namespace`, `dtype` and `device` determined by the provided array arguments and the reference array.
@@ -411,10 +462,10 @@ def unify_args(
411
462
  arraylike_only : bool, default to `True`
412
463
  Whether to only convert array-like objects to arrays in the determined namespace, and return the object itself if it is not array-like.
413
464
 
414
- strict : bool, default to `True`
415
- Whether to raise exceptions when failing to determine the `compatibility namespace`.
416
- - `True`: Raise exceptions when failing to determine the `compatibility namespace` from the provided inputs or the reference array;
417
- - `False`: Fall back to the default `compatibility namespace` if an error occurs. If all default array libraries are missing, just run the function without conversion.
465
+ fallback : bool, default to `True`
466
+ Whether to raise exceptions when failing to determine the `compatibility namespace` from the provided inputs or the reference array.
467
+ - `True`: Fall back to the default `compatibility namespace` from :func:`default.default_spec` when an error occurs;
468
+ - `False`: Just run the function without conversion when an error occurs.
418
469
 
419
470
  Raises
420
471
  ------
@@ -440,18 +491,13 @@ def unify_args(
440
491
  spec = array_spec(
441
492
  *args, kw_arrays=kwargs, ref=ref,
442
493
  filter_arraylike=filter_arraylike,
494
+ fallback=fallback,
443
495
  api_version=api_version,
444
496
  use_compat=use_compat
445
497
  )
446
- except Exception as e:
447
- if strict:
448
- raise e
449
- try:
450
- # fall back to the default namespace
451
- spec = default_spec()
452
- except MissingDependencyError:
453
- # just run the function without conversion
454
- return func(*args, **kwargs)
498
+ except Exception:
499
+ # just run the function without conversion
500
+ return func(*args, **kwargs)
455
501
 
456
502
  with array_context.from_array_spec(spec):
457
503
  out_args = tuple(
@@ -1,6 +1,6 @@
1
1
  # src/cobra_array/compat/__init__.py
2
2
  """
3
- Compatibility utilities for :pkg:`cobra_color`.
3
+ Compatibility utilities for :pkg:`cobra_array`.
4
4
 
5
5
  Functions
6
6
  ---------
@@ -310,10 +310,17 @@ class CompatArray(Compat):
310
310
  @property
311
311
  def arr(self):
312
312
  """
313
- The backend-specific array instance managed by :class:`CompatArray`.
313
+ The backend-specific array instance wrapped by `self`.
314
314
  """
315
315
  return self._arr
316
316
 
317
+ @property
318
+ def cxp(self):
319
+ """
320
+ The `compatibility namespace` associated with `self`.
321
+ """
322
+ return self._cxp
323
+
317
324
  @property
318
325
  def dtype(self):
319
326
  """
@@ -8,6 +8,7 @@ from numpy.typing import NDArray
8
8
  from typing import (Union, List, Tuple, Optional, Any, Sequence, Generic, TypeVar, Literal, overload)
9
9
 
10
10
  from ._base import Compat
11
+ from ._namespace import CompatNamespace
11
12
  from ..types import (
12
13
  T, DTypeT, DeviceT, dtypeT, deviceT, DType, AnyDevice,
13
14
  ArrayLike, ArrayLibraryName,
@@ -49,6 +50,8 @@ class CompatArray(Compat, Generic[TT, DT]):
49
50
  def __new__(cls, arr: CompatArray[dtypeT, deviceT], /, **kwargs) -> CompatArray[dtypeT, deviceT]: ...
50
51
  @overload
51
52
  def __new__(cls, arr: ArrayLike[dtypeT], /, **kwargs) -> CompatArray[dtypeT, AnyDevice]: ...
53
+ @overload
54
+ def __new__(cls, arr: object, /, **kwargs) -> CompatArray[Any, AnyDevice]: ...
52
55
  def __new__(cls, arr: ArrayLike[Any], /, **kwargs) -> CompatArray[Any, AnyDevice]: ...
53
56
 
54
57
  # === Conversion functions ===
@@ -1750,6 +1753,8 @@ class CompatArray(Compat, Generic[TT, DT]):
1750
1753
  @property
1751
1754
  def arr(self) -> ArrayLike[TT]: ...
1752
1755
  @property
1756
+ def cxp(self) -> CompatNamespace: ...
1757
+ @property
1753
1758
  def dtype(self) -> TT: ...
1754
1759
  @property
1755
1760
  def device(self) -> DT: ...
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cobra-array
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
5
5
  Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
@@ -86,21 +86,24 @@ def test_as_context_arraylike_only_passthrough():
86
86
 
87
87
 
88
88
  def test_unify_args_strict_true_raises_for_non_array_inputs():
89
- @unify_args(filter_arraylike=True, strict=True)
89
+ @unify_args(filter_arraylike=True, fallback=False)
90
90
  def fn(a, b):
91
91
  return a, b
92
92
 
93
- with pytest.raises(IndexError):
94
- fn("x", {"y": 1})
93
+ a, b = fn("x", {"y": 1})
94
+ assert a == "x"
95
+ assert b == {"y": 1}
95
96
 
96
97
 
97
98
  def test_unify_args_strict_false_fallback_and_convert():
98
- @unify_args(filter_arraylike=True, strict=False, arraylike_only=False)
99
+ @unify_args(filter_arraylike=True, fallback=True, arraylike_only=False)
99
100
  def fn(a):
100
- return a, context_spec().cxp.xp_name
101
+ return a, default_spec().cxp.xp_name
101
102
 
102
103
  out, xp_name = fn([1, 2, 3])
104
+
103
105
  assert isinstance(out, CompatArray)
106
+ assert out.cxp.xp_name == xp_name
104
107
  assert xp_name in ("NumPy", "PyTorch")
105
108
 
106
109
 
@@ -123,17 +126,18 @@ def test_resolve_device_torch_checks():
123
126
 
124
127
 
125
128
  if __name__ == "__main__":
126
- test_array_context_override_and_restore()
127
- test_as_context_converts_under_current_context()
128
- test_as_context_arraylike_only_passthrough()
129
- test_unify_args_strict_true_raises_for_non_array_inputs()
130
- test_unify_args_strict_false_fallback_and_convert()
131
- test_resolve_device_basic_and_numpy_constraints()
132
- test_resolve_device_torch_checks()
133
- test_array_context_override_and_restore()
134
- test_as_context_converts_under_current_context()
135
- test_as_context_arraylike_only_passthrough()
136
- test_unify_args_strict_true_raises_for_non_array_inputs()
129
+ # test_array_context_override_and_restore()
130
+ # test_as_context_converts_under_current_context()
131
+ # test_as_context_arraylike_only_passthrough()
132
+ # test_unify_args_strict_true_raises_for_non_array_inputs()
133
+ # test_unify_args_strict_false_fallback_and_convert()
134
+ # test_resolve_device_basic_and_numpy_constraints()
135
+ # test_resolve_device_torch_checks()
136
+ # test_array_context_override_and_restore()
137
+ # test_as_context_converts_under_current_context()
138
+ # test_as_context_arraylike_only_passthrough()
139
+ # test_unify_args_strict_true_raises_for_non_array_inputs()
140
+ # test_unify_args_strict_false_fallback_and_convert()
141
+ # test_resolve_device_basic_and_numpy_constraints()
142
+ # test_resolve_device_torch_checks()
137
143
  test_unify_args_strict_false_fallback_and_convert()
138
- test_resolve_device_basic_and_numpy_constraints()
139
- test_resolve_device_torch_checks()
@@ -251,3 +251,11 @@ def test_some_pyi_annotated_methods_via_dynamic_dispatch():
251
251
 
252
252
  assert r.shape == (4,)
253
253
  assert t.shape == (2, 1)
254
+
255
+
256
+ def test_cxp_of_compatarray_matches_array_namespace():
257
+ a = _arr_1d()
258
+ cxp = a.cxp
259
+
260
+ assert cxp is not None
261
+ assert cxp.xp_name == a.xp_name
File without changes
File without changes
File without changes
File without changes