array-api-compat 1.11__tar.gz → 1.11.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {array_api_compat-1.11 → array_api_compat-1.11.2}/PKG-INFO +3 -2
  2. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/__init__.py +1 -1
  3. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/common/_aliases.py +18 -17
  4. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/cupy/_aliases.py +15 -1
  5. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/dask/array/_aliases.py +16 -1
  6. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/numpy/_aliases.py +14 -1
  7. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/torch/_aliases.py +48 -28
  8. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat.egg-info/PKG-INFO +3 -2
  9. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat.egg-info/SOURCES.txt +1 -0
  10. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_all.py +1 -0
  11. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_common.py +15 -0
  12. array_api_compat-1.11.2/tests/test_torch.py +98 -0
  13. {array_api_compat-1.11 → array_api_compat-1.11.2}/LICENSE +0 -0
  14. {array_api_compat-1.11 → array_api_compat-1.11.2}/README.md +0 -0
  15. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/_internal.py +0 -0
  16. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/common/__init__.py +0 -0
  17. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/common/_fft.py +0 -0
  18. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/common/_helpers.py +0 -0
  19. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/common/_linalg.py +0 -0
  20. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/common/_typing.py +0 -0
  21. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/cupy/__init__.py +0 -0
  22. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/cupy/_info.py +0 -0
  23. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/cupy/_typing.py +0 -0
  24. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/cupy/fft.py +0 -0
  25. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/cupy/linalg.py +0 -0
  26. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/dask/__init__.py +0 -0
  27. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/dask/array/__init__.py +0 -0
  28. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/dask/array/_info.py +0 -0
  29. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/dask/array/fft.py +0 -0
  30. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/dask/array/linalg.py +0 -0
  31. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/numpy/__init__.py +0 -0
  32. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/numpy/_info.py +0 -0
  33. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/numpy/_typing.py +0 -0
  34. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/numpy/fft.py +0 -0
  35. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/numpy/linalg.py +0 -0
  36. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/torch/__init__.py +0 -0
  37. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/torch/_info.py +0 -0
  38. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/torch/fft.py +0 -0
  39. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat/torch/linalg.py +0 -0
  40. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat.egg-info/dependency_links.txt +0 -0
  41. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat.egg-info/requires.txt +0 -0
  42. {array_api_compat-1.11 → array_api_compat-1.11.2}/array_api_compat.egg-info/top_level.txt +0 -0
  43. {array_api_compat-1.11 → array_api_compat-1.11.2}/setup.cfg +0 -0
  44. {array_api_compat-1.11 → array_api_compat-1.11.2}/setup.py +0 -0
  45. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_array_namespace.py +0 -0
  46. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_dask.py +0 -0
  47. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_isdtype.py +0 -0
  48. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_jax.py +0 -0
  49. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_no_dependencies.py +0 -0
  50. {array_api_compat-1.11 → array_api_compat-1.11.2}/tests/test_vendoring.py +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: array_api_compat
3
- Version: 1.11
3
+ Version: 1.11.2
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -34,6 +34,7 @@ Dynamic: description
34
34
  Dynamic: description-content-type
35
35
  Dynamic: home-page
36
36
  Dynamic: license
37
+ Dynamic: license-file
37
38
  Dynamic: provides-extra
38
39
  Dynamic: requires-python
39
40
  Dynamic: summary
@@ -17,6 +17,6 @@ to ensure they are not using functionality outside of the standard, but prefer
17
17
  this implementation for the default when working with NumPy arrays.
18
18
 
19
19
  """
20
- __version__ = '1.11'
20
+ __version__ = '1.11.2'
21
21
 
22
22
  from .common import * # noqa: F401, F403
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
12
12
  from typing import NamedTuple
13
13
  import inspect
14
14
 
15
- from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
15
+ from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
16
16
 
17
17
  # These functions are modified from the NumPy versions.
18
18
 
@@ -363,28 +363,29 @@ def clip(
363
363
 
364
364
  # At least handle the case of Python integers correctly (see
365
365
  # https://github.com/numpy/numpy/pull/26892).
366
- if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
367
- min = None
368
- if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
369
- max = None
366
+ if wrapped_xp.isdtype(x.dtype, "integral"):
367
+ if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
368
+ min = None
369
+ if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
370
+ max = None
370
371
 
372
+ dev = device(x)
371
373
  if out is None:
372
- out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
373
- copy=True, device=device(x))
374
+ out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
375
+ out[()] = x
376
+
374
377
  if min is not None:
375
- if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
376
- # Avoid loss of precision due to torch defaulting to float32
377
- min = wrapped_xp.asarray(min, dtype=xp.float64)
378
- a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
378
+ a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev)
379
+ a = xp.broadcast_to(a, result_shape)
379
380
  ia = (out < a) | xp.isnan(a)
380
- # torch requires an explicit cast here
381
- out[ia] = wrapped_xp.astype(a[ia], out.dtype)
381
+ out[ia] = a[ia]
382
+
382
383
  if max is not None:
383
- if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
384
- max = wrapped_xp.asarray(max, dtype=xp.float64)
385
- b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
384
+ b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev)
385
+ b = xp.broadcast_to(b, result_shape)
386
386
  ib = (out > b) | xp.isnan(b)
387
- out[ib] = wrapped_xp.astype(b[ib], out.dtype)
387
+ out[ib] = b[ib]
388
+
388
389
  # Return a scalar for 0-D
389
390
  return out[()]
390
391
 
@@ -125,6 +125,20 @@ def astype(
125
125
  return out.copy() if copy and out is x else out
126
126
 
127
127
 
128
+ # cupy.count_nonzero does not have keepdims
129
+ def count_nonzero(
130
+ x: ndarray,
131
+ axis=None,
132
+ keepdims=False
133
+ ) -> ndarray:
134
+ result = cp.count_nonzero(x, axis)
135
+ if keepdims:
136
+ if axis is None:
137
+ return cp.reshape(result, [1]*x.ndim)
138
+ return cp.expand_dims(result, axis)
139
+ return result
140
+
141
+
128
142
  # These functions are completely new here. If the library already has them
129
143
  # (i.e., numpy 2.0), use the library version instead of our wrapper.
130
144
  if hasattr(cp, 'vecdot'):
@@ -146,6 +160,6 @@ __all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
146
160
  'acos', 'acosh', 'asin', 'asinh', 'atan',
147
161
  'atan2', 'atanh', 'bitwise_left_shift',
148
162
  'bitwise_invert', 'bitwise_right_shift',
149
- 'bool', 'concat', 'pow', 'sign']
163
+ 'bool', 'concat', 'count_nonzero', 'pow', 'sign']
150
164
 
151
165
  _all_ignore = ['cp', 'get_xp']
@@ -335,6 +335,21 @@ def argsort(
335
335
  return restore(x)
336
336
 
337
337
 
338
+ # dask.array.count_nonzero does not have keepdims
339
+ def count_nonzero(
340
+ x: Array,
341
+ axis=None,
342
+ keepdims=False
343
+ ) -> Array:
344
+ result = da.count_nonzero(x, axis)
345
+ if keepdims:
346
+ if axis is None:
347
+ return da.reshape(result, [1]*x.ndim)
348
+ return da.expand_dims(result, axis)
349
+ return result
350
+
351
+
352
+
338
353
  __all__ = _aliases.__all__ + [
339
354
  '__array_namespace_info__', 'asarray', 'astype', 'acos',
340
355
  'acosh', 'asin', 'asinh', 'atan', 'atan2',
@@ -343,6 +358,6 @@ __all__ = _aliases.__all__ + [
343
358
  'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
344
359
  'uint8', 'uint16', 'uint32', 'uint64',
345
360
  'complex64', 'complex128', 'iinfo', 'finfo',
346
- 'can_cast', 'result_type']
361
+ 'can_cast', 'count_nonzero', 'result_type']
347
362
 
348
363
  _all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
@@ -127,6 +127,19 @@ def astype(
127
127
  return x.astype(dtype=dtype, copy=copy)
128
128
 
129
129
 
130
+ # count_nonzero returns a python int for axis=None and keepdims=False
131
+ # https://github.com/numpy/numpy/issues/17562
132
+ def count_nonzero(
133
+ x : ndarray,
134
+ axis=None,
135
+ keepdims=False
136
+ ) -> ndarray:
137
+ result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
138
+ if axis is None and not keepdims:
139
+ return np.asarray(result)
140
+ return result
141
+
142
+
130
143
  # These functions are completely new here. If the library already has them
131
144
  # (i.e., numpy 2.0), use the library version instead of our wrapper.
132
145
  if hasattr(np, 'vecdot'):
@@ -148,6 +161,6 @@ __all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
148
161
  'acos', 'acosh', 'asin', 'asinh', 'atan',
149
162
  'atan2', 'atanh', 'bitwise_left_shift',
150
163
  'bitwise_invert', 'bitwise_right_shift',
151
- 'bool', 'concat', 'pow']
164
+ 'bool', 'concat', 'count_nonzero', 'pow']
152
165
 
153
166
  _all_ignore = ['np', 'get_xp']
@@ -1,15 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from functools import wraps as _wraps
3
+ from functools import reduce as _reduce, wraps as _wraps
4
4
  from builtins import all as _builtin_all, any as _builtin_any
5
5
 
6
- from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
7
- vecdot as _aliases_vecdot,
8
- clip as _aliases_clip,
9
- unstack as _aliases_unstack,
10
- cumulative_sum as _aliases_cumulative_sum,
11
- cumulative_prod as _aliases_cumulative_prod,
12
- )
6
+ from ..common import _aliases
13
7
  from .._internal import get_xp
14
8
 
15
9
  from ._info import __array_namespace_info__
@@ -130,25 +124,43 @@ _py_scalars = (bool, int, float, complex)
130
124
 
131
125
 
132
126
  def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype:
133
- if len(arrays_and_dtypes) == 0:
134
- raise TypeError("At least one array or dtype must be provided")
135
- if len(arrays_and_dtypes) == 1:
127
+ num = len(arrays_and_dtypes)
128
+
129
+ if num == 0:
130
+ raise ValueError("At least one array or dtype must be provided")
131
+
132
+ elif num == 1:
136
133
  x = arrays_and_dtypes[0]
137
134
  if isinstance(x, torch.dtype):
138
135
  return x
139
136
  return x.dtype
140
- if len(arrays_and_dtypes) > 2:
141
- return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:]))
142
137
 
143
- x, y = arrays_and_dtypes
144
- if isinstance(x, _py_scalars) or isinstance(y, _py_scalars):
145
- return torch.result_type(x, y)
138
+ if num == 2:
139
+ x, y = arrays_and_dtypes
140
+ return _result_type(x, y)
141
+
142
+ else:
143
+ # sort scalars so that they are treated last
144
+ scalars, others = [], []
145
+ for x in arrays_and_dtypes:
146
+ if isinstance(x, _py_scalars):
147
+ scalars.append(x)
148
+ else:
149
+ others.append(x)
150
+ if not others:
151
+ raise ValueError("At least one array or dtype must be provided")
152
+
153
+ # combine left-to-right
154
+ return _reduce(_result_type, others + scalars)
146
155
 
147
- xdt = x.dtype if not isinstance(x, torch.dtype) else x
148
- ydt = y.dtype if not isinstance(y, torch.dtype) else y
149
156
 
150
- if (xdt, ydt) in _promotion_table:
151
- return _promotion_table[xdt, ydt]
157
+ def _result_type(x, y):
158
+ if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
159
+ xdt = x.dtype if not isinstance(x, torch.dtype) else x
160
+ ydt = y.dtype if not isinstance(y, torch.dtype) else y
161
+
162
+ if (xdt, ydt) in _promotion_table:
163
+ return _promotion_table[xdt, ydt]
152
164
 
153
165
  # This doesn't result_type(dtype, dtype) for non-array API dtypes
154
166
  # because torch.result_type only accepts tensors. This does however, allow
@@ -157,6 +169,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
157
169
  y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y
158
170
  return torch.result_type(x, y)
159
171
 
172
+
160
173
  def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
161
174
  if not isinstance(from_, torch.dtype):
162
175
  from_ = from_.dtype
@@ -215,10 +228,10 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
215
228
  return torch.clone(x)
216
229
  return torch.amin(x, axis, keepdims=keepdims)
217
230
 
218
- clip = get_xp(torch)(_aliases_clip)
219
- unstack = get_xp(torch)(_aliases_unstack)
220
- cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
221
- cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
231
+ clip = get_xp(torch)(_aliases.clip)
232
+ unstack = get_xp(torch)(_aliases.unstack)
233
+ cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
234
+ cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
222
235
 
223
236
  # torch.sort also returns a tuple
224
237
  # https://github.com/pytorch/pytorch/issues/70921
@@ -527,7 +540,7 @@ def diff(
527
540
  return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
528
541
 
529
542
 
530
- # torch uses `dim` instead of `axis`
543
+ # torch uses `dim` instead of `axis`, does not have keepdims
531
544
  def count_nonzero(
532
545
  x: array,
533
546
  /,
@@ -535,7 +548,14 @@ def count_nonzero(
535
548
  axis: Optional[Union[int, Tuple[int, ...]]] = None,
536
549
  keepdims: bool = False,
537
550
  ) -> array:
538
- return torch.count_nonzero(x, dim=axis, keepdims=keepdims)
551
+ result = torch.count_nonzero(x, dim=axis)
552
+ if keepdims:
553
+ if axis is not None:
554
+ return result.unsqueeze(axis)
555
+ return _axis_none_keepdims(result, x.ndim, keepdims)
556
+ else:
557
+ return result
558
+
539
559
 
540
560
 
541
561
  def where(condition: array, x1: array, x2: array, /) -> array:
@@ -710,8 +730,8 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
710
730
  x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
711
731
  return torch.matmul(x1, x2, **kwargs)
712
732
 
713
- matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
714
- _vecdot = get_xp(torch)(_aliases_vecdot)
733
+ matrix_transpose = get_xp(torch)(_aliases.matrix_transpose)
734
+ _vecdot = get_xp(torch)(_aliases.vecdot)
715
735
 
716
736
  def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
717
737
  x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: array_api_compat
3
- Version: 1.11
3
+ Version: 1.11.2
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -34,6 +34,7 @@ Dynamic: description
34
34
  Dynamic: description-content-type
35
35
  Dynamic: home-page
36
36
  Dynamic: license
37
+ Dynamic: license-file
37
38
  Dynamic: provides-extra
38
39
  Dynamic: requires-python
39
40
  Dynamic: summary
@@ -44,4 +44,5 @@ tests/test_dask.py
44
44
  tests/test_isdtype.py
45
45
  tests/test_jax.py
46
46
  tests/test_no_dependencies.py
47
+ tests/test_torch.py
47
48
  tests/test_vendoring.py
@@ -16,6 +16,7 @@ from ._helpers import import_, wrapped_libraries
16
16
 
17
17
  import pytest
18
18
 
19
+ @pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277")
19
20
  @pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
20
21
  def test_all(library):
21
22
  if library == "common":
@@ -367,3 +367,18 @@ def test_asarray_copy(library):
367
367
  assert all(b[0] == 1.0)
368
368
  else:
369
369
  assert all(b[0] == 0.0)
370
+
371
+
372
+ @pytest.mark.parametrize("library", ["numpy", "cupy", "torch"])
373
+ def test_clip_out(library):
374
+ """Test non-standard out= parameter for clip()
375
+
376
+ (see "Avoid Restricting Behavior that is Outside the Scope of the Standard"
377
+ in https://data-apis.org/array-api-compat/dev/special-considerations.html)
378
+ """
379
+ xp = import_(library, wrapper=True)
380
+ x = xp.asarray([10, 20, 30])
381
+ out = xp.zeros_like(x)
382
+ xp.clip(x, 15, 25, out=out)
383
+ expect = xp.asarray([15, 20, 25])
384
+ assert xp.all(out == expect)
@@ -0,0 +1,98 @@
1
+ """Test "unspecified" behavior which we cannot easily test in the Array API test suite.
2
+ """
3
+ import itertools
4
+
5
+ import pytest
6
+ import torch
7
+
8
+ from array_api_compat import torch as xp
9
+
10
+
11
+ class TestResultType:
12
+ def test_empty(self):
13
+ with pytest.raises(ValueError):
14
+ xp.result_type()
15
+
16
+ def test_one_arg(self):
17
+ for x in [1, 1.0, 1j, '...', None]:
18
+ with pytest.raises((ValueError, AttributeError)):
19
+ xp.result_type(x)
20
+
21
+ for x in [xp.float32, xp.int64, torch.complex64]:
22
+ assert xp.result_type(x) == x
23
+
24
+ for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
25
+ assert xp.result_type(x) == x.dtype
26
+
27
+ def test_two_args(self):
28
+ # Only include here things "unspecified" in the spec
29
+
30
+ # scalar, tensor or tensor,tensor
31
+ for x, y in [
32
+ (1., 1j),
33
+ (1j, xp.arange(3)),
34
+ (True, xp.asarray(3.)),
35
+ (xp.ones(3) == 1, 1j*xp.ones(3)),
36
+ ]:
37
+ assert xp.result_type(x, y) == torch.result_type(x, y)
38
+
39
+ # dtype, scalar
40
+ for x, y in [
41
+ (1j, xp.int64),
42
+ (True, xp.float64),
43
+ ]:
44
+ assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))
45
+
46
+ # dtype, dtype
47
+ for x, y in [
48
+ (xp.bool, xp.complex64)
49
+ ]:
50
+ xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
51
+ assert xp.result_type(x, y) == torch.result_type(xt, yt)
52
+
53
+ def test_multi_arg(self):
54
+ torch.set_default_dtype(torch.float32)
55
+
56
+ args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
57
+ assert xp.result_type(*args) == torch.float16
58
+
59
+ args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
60
+ assert xp.result_type(*args) == xp.complex64
61
+
62
+ args = [1, 2, 3j, xp.float64, 4, 5, 6]
63
+ assert xp.result_type(*args) == xp.complex128
64
+
65
+ args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
66
+ assert xp.result_type(*args) == xp.complex128
67
+
68
+ i64 = xp.ones(1, dtype=xp.int64)
69
+ f16 = xp.ones(1, dtype=xp.float16)
70
+ for i in itertools.permutations([i64, f16, 1.0, 1.0]):
71
+ assert xp.result_type(*i) == xp.float16, f"{i}"
72
+
73
+ with pytest.raises(ValueError):
74
+ xp.result_type(1, 2, 3, 4)
75
+
76
+
77
+ @pytest.mark.parametrize("default_dt", ['float32', 'float64'])
78
+ @pytest.mark.parametrize("dtype_a",
79
+ (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
80
+ )
81
+ @pytest.mark.parametrize("dtype_b",
82
+ (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128)
83
+ )
84
+ def test_gh_273(self, default_dt, dtype_a, dtype_b):
85
+ # Regression test for https://github.com/data-apis/array-api-compat/issues/273
86
+
87
+ try:
88
+ prev_default = torch.get_default_dtype()
89
+ default_dtype = getattr(torch, default_dt)
90
+ torch.set_default_dtype(default_dtype)
91
+
92
+ a = xp.asarray([2, 1], dtype=dtype_a)
93
+ b = xp.asarray([1, -1], dtype=dtype_b)
94
+ dtype_1 = xp.result_type(a, b, 1.0)
95
+ dtype_2 = xp.result_type(b, a, 1.0)
96
+ assert dtype_1 == dtype_2
97
+ finally:
98
+ torch.set_default_dtype(prev_default)