cobra-array 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31):
  1. {cobra_array-0.2.0/src/cobra_array.egg-info → cobra_array-0.2.2}/PKG-INFO +1 -1
  2. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/__init__.py +1 -1
  3. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/compat/_array.py +2 -2
  4. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/compat/_array.pyi +59 -58
  5. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/compat/_namespace.py +131 -0
  6. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/compat/_namespace.pyi +46 -23
  7. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/convert.py +4 -4
  8. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/types.py +8 -7
  9. {cobra_array-0.2.0 → cobra_array-0.2.2/src/cobra_array.egg-info}/PKG-INFO +1 -1
  10. {cobra_array-0.2.0 → cobra_array-0.2.2}/tests/test_compat.py +56 -6
  11. {cobra_array-0.2.0 → cobra_array-0.2.2}/LICENSE +0 -0
  12. {cobra_array-0.2.0 → cobra_array-0.2.2}/README.md +0 -0
  13. {cobra_array-0.2.0 → cobra_array-0.2.2}/pyproject.toml +0 -0
  14. {cobra_array-0.2.0 → cobra_array-0.2.2}/setup.cfg +0 -0
  15. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/_core.py +0 -0
  16. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/_utils.py +0 -0
  17. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/array_api.py +0 -0
  18. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/compat/__init__.py +0 -0
  19. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/compat/_base.py +0 -0
  20. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/convert.pyi +0 -0
  21. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/default.py +0 -0
  22. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array/exceptions.py +0 -0
  23. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array.egg-info/SOURCES.txt +0 -0
  24. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array.egg-info/dependency_links.txt +0 -0
  25. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array.egg-info/requires.txt +0 -0
  26. {cobra_array-0.2.0 → cobra_array-0.2.2}/src/cobra_array.egg-info/top_level.txt +0 -0
  27. {cobra_array-0.2.0 → cobra_array-0.2.2}/tests/test_backend.py +0 -0
  28. {cobra_array-0.2.0 → cobra_array-0.2.2}/tests/test_compat_namespace.py +0 -0
  29. {cobra_array-0.2.0 → cobra_array-0.2.2}/tests/test_convert.py +0 -0
  30. {cobra_array-0.2.0 → cobra_array-0.2.2}/tests/test_default.py +0 -0
  31. {cobra_array-0.2.0 → cobra_array-0.2.2}/tests/test_wrap.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cobra-array
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
5
5
  Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
@@ -80,7 +80,7 @@ from ._utils import (
80
80
  )
81
81
 
82
82
  __author__ = "Zhen Tian"
83
- __version__ = "0.2.0"
83
+ __version__ = "0.2.2"
84
84
 
85
85
  __all__ = [
86
86
  "array_spec",
@@ -339,7 +339,7 @@ class CompatArray(Compat):
339
339
  @property
340
340
  def device(self):
341
341
  """
342
- DeviceT on which `self` is stored.
342
+ Device on which `self` is stored.
343
343
  """
344
344
  return api.device(self._arr)
345
345
 
@@ -402,7 +402,7 @@ class CompatArray(Compat):
402
402
  """Allow implicit NumPy conversion."""
403
403
  return self.to_numpy()
404
404
 
405
- def __getattr__(self, name: str):
405
+ def __getattr__(self, name):
406
406
  attr = self._get_xp_attr(name)
407
407
 
408
408
  if callable(attr) and not isinstance(attr, type):
@@ -10,7 +10,7 @@ from typing import (Union, List, Tuple, Optional, Any, Sequence, Generic, TypeVa
10
10
  from ._base import Compat
11
11
  from ._namespace import CompatNamespace
12
12
  from ..types import (
13
- T, DTypeT, DeviceT, dtypeT, deviceT, DType, AnyDevice,
13
+ T, dtypeT, DTypeT, deviceT, DeviceT, DType, AnyDevice,
14
14
  ArrayLike, ArrayLibraryName,
15
15
  ArrayOrAny, ArrayOrScalar, ArrayOrReal, ArrayOrIntLike, ArrayOrInt, ArrayOrbool,
16
16
  UniqueAllResult, UniqueCountsResult, UniqueInverseResult
@@ -64,7 +64,6 @@ class CompatArray(Compat, Generic[TT, DT]):
64
64
  def to_device(self, device: DeviceT, /, *, stream: Optional[Any] = None) -> CompatArray[TT, DeviceT]: ...
65
65
  @overload
66
66
  def to_device(self, device: AnyDevice, /, *, stream: Optional[Any] = None) -> CompatArray[TT, AnyDevice]: ...
67
-
68
67
  def to_device(self, device: AnyDevice, /, *, stream: Optional[Any] = None) -> CompatArray[Any, AnyDevice]: ...
69
68
 
70
69
  # === Data type functions ===
@@ -117,14 +116,14 @@ class CompatArray(Compat, Generic[TT, DT]):
117
116
  """
118
117
  ...
119
118
 
120
- def acos(self) -> CompatArray[float, DT]:
119
+ def acos(self) -> CompatArray[type[float], DT]:
121
120
  """
122
121
  Computes the element-wise `principal value of the inverse cosine` of `self`.
123
122
  - `self` should have a floating-point data type.
124
123
  """
125
124
  ...
126
125
 
127
- def acosh(self) -> CompatArray[float, DT]:
126
+ def acosh(self) -> CompatArray[type[float], DT]:
128
127
  """
129
128
  Computes the element-wise `inverse hyperbolic cosine` of `self`.
130
129
  - `self` should have a floating-point data type.
@@ -138,28 +137,28 @@ class CompatArray(Compat, Generic[TT, DT]):
138
137
  """
139
138
  ...
140
139
 
141
- def asin(self) -> CompatArray[float, DT]:
140
+ def asin(self) -> CompatArray[type[float], DT]:
142
141
  """
143
142
  Computes the element-wise `principal value of the inverse sine` of `self`.
144
143
  - `self` should have a floating-point data type.
145
144
  """
146
145
  ...
147
146
 
148
- def asinh(self) -> CompatArray[float, DT]:
147
+ def asinh(self) -> CompatArray[type[float], DT]:
149
148
  """
150
149
  Computes the element-wise `inverse hyperbolic sine` of `self`.
151
150
  - `self` should have a floating-point data type.
152
151
  """
153
152
  ...
154
153
 
155
- def atan(self) -> CompatArray[float, DT]:
154
+ def atan(self) -> CompatArray[type[float], DT]:
156
155
  """
157
156
  Computes the element-wise `principal value of the inverse tangent` of `self`.
158
157
  - `self` should have a floating-point data type.
159
158
  """
160
159
  ...
161
160
 
162
- def atan2(self, other: ArrayOrReal, /) -> CompatArray[float, DT]:
161
+ def atan2(self, other: ArrayOrReal, /) -> CompatArray[type[float], DT]:
163
162
  """
164
163
  Computes the element-wise `inverse tangent` of `self / other`, taking into account the signs of both inputs.
165
164
  - `self` should have a real-valued floating-point data type.
@@ -167,7 +166,7 @@ class CompatArray(Compat, Generic[TT, DT]):
167
166
  """
168
167
  ...
169
168
 
170
- def atanh(self) -> CompatArray[float, DT]:
169
+ def atanh(self) -> CompatArray[type[float], DT]:
171
170
  """
172
171
  Computes the element-wise `inverse hyperbolic tangent` of `self`.
173
172
  - `self` should have a floating-point data type.
@@ -241,7 +240,7 @@ class CompatArray(Compat, Generic[TT, DT]):
241
240
  """
242
241
  ...
243
242
 
244
- def copysign(self, other: ArrayOrReal, /) -> CompatArray[float, DT]:
243
+ def copysign(self, other: ArrayOrReal, /) -> CompatArray[type[float], DT]:
245
244
  """
246
245
  Computes the element-wise `copysign` of `self` with `other`.
247
246
  - `self` should have a real-valued floating-point data type.
@@ -249,41 +248,41 @@ class CompatArray(Compat, Generic[TT, DT]):
249
248
  """
250
249
  ...
251
250
 
252
- def cos(self) -> CompatArray[float, DT]:
251
+ def cos(self) -> CompatArray[type[float], DT]:
253
252
  """
254
253
  Computes the element-wise `cosine` of `self`.
255
254
  - `self` should have a floating-point data type.
256
255
  """
257
256
  ...
258
257
 
259
- def cosh(self) -> CompatArray[float, DT]:
258
+ def cosh(self) -> CompatArray[type[float], DT]:
260
259
  """
261
260
  Computes the element-wise `hyperbolic cosine` of `self`.
262
261
  - `self` should have a floating-point data type.
263
262
  """
264
263
  ...
265
264
 
266
- def divide(self, other: ArrayOrScalar, /) -> CompatArray[float, DT]:
265
+ def divide(self, other: ArrayOrScalar, /) -> CompatArray[type[float], DT]:
267
266
  """
268
267
  Computes the element-wise `division` of `self` by `other`.
269
268
  - `self` should have a numeric data type.
270
269
  """
271
270
  ...
272
271
 
273
- def equal(self, other: ArrayOrAny, /) -> CompatArray[bool, DT]:
272
+ def equal(self, other: ArrayOrAny, /) -> CompatArray[type[bool], DT]:
274
273
  """
275
274
  Computes the element-wise truth value of `self == other`.
276
275
  """
277
276
  ...
278
277
 
279
- def exp(self) -> CompatArray[float, DT]:
278
+ def exp(self) -> CompatArray[type[float], DT]:
280
279
  """
281
280
  Computes the element-wise `exponential` (`exp(x)`) of `self`.
282
281
  - `self` should have a floating-point data type.
283
282
  """
284
283
  ...
285
284
 
286
- def expm1(self) -> CompatArray[float, DT]:
285
+ def expm1(self) -> CompatArray[type[float], DT]:
287
286
  """
288
287
  Computes the element-wise `exp(x) - 1` of `self`.
289
288
  - `self` should have a floating-point data type.
@@ -306,21 +305,21 @@ class CompatArray(Compat, Generic[TT, DT]):
306
305
  """
307
306
  ...
308
307
 
309
- def greater(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]:
308
+ def greater(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]:
310
309
  """
311
310
  Computes the element-wise truth value of `self > other`.
312
311
  - `self` should have a real-valued data type.
313
312
  """
314
313
  ...
315
314
 
316
- def greater_equal(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]:
315
+ def greater_equal(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]:
317
316
  """
318
317
  Computes the element-wise truth value of `self >= other`.
319
318
  - `self` should have a real-valued data type.
320
319
  """
321
320
  ...
322
321
 
323
- def hypot(self, other: ArrayOrReal, /) -> CompatArray[float, DT]:
322
+ def hypot(self, other: ArrayOrReal, /) -> CompatArray[type[float], DT]:
324
323
  """
325
324
  Computes the element-wise `hypotenuse` of `self` and `other`.
326
325
  - `self` should have a real-valued floating-point data type.
@@ -328,77 +327,77 @@ class CompatArray(Compat, Generic[TT, DT]):
328
327
  """
329
328
  ...
330
329
 
331
- def imag(self) -> CompatArray[float, DT]:
330
+ def imag(self) -> CompatArray[type[float], DT]:
332
331
  """
333
332
  Computes the element-wise `imaginary component` of `self`.
334
333
  - `self` should have a complex floating-point data type.
335
334
  """
336
335
  ...
337
336
 
338
- def isfinite(self) -> CompatArray[bool, DT]:
337
+ def isfinite(self) -> CompatArray[type[bool], DT]:
339
338
  """
340
339
  Tests the element-wise `finiteness` of `self`.
341
340
  - `self` should have a numeric data type.
342
341
  """
343
342
  ...
344
343
 
345
- def isinf(self) -> CompatArray[bool, DT]:
344
+ def isinf(self) -> CompatArray[type[bool], DT]:
346
345
  """
347
346
  Tests the element-wise `infinity` of `self`.
348
347
  - - `self` should have a numeric data type.
349
348
  """
350
349
  ...
351
350
 
352
- def isnan(self) -> CompatArray[bool, DT]:
351
+ def isnan(self) -> CompatArray[type[bool], DT]:
353
352
  """
354
353
  Tests the element-wise `NaN` of `self`.
355
354
  - `self` should have a numeric data type.
356
355
  """
357
356
  ...
358
357
 
359
- def less(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]:
358
+ def less(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]:
360
359
  """
361
360
  Computes the element-wise truth value of `self < other`.
362
361
  - `self` should have a real-valued data type.
363
362
  """
364
363
  ...
365
364
 
366
- def less_equal(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]:
365
+ def less_equal(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]:
367
366
  """
368
367
  Computes the element-wise truth value of `self <= other`.
369
368
  - `self` should have a real-valued data type.
370
369
  """
371
370
  ...
372
371
 
373
- def log(self) -> CompatArray[float, DT]:
372
+ def log(self) -> CompatArray[type[float], DT]:
374
373
  """
375
374
  Computes the element-wise `natural logarithm` (base `e`) of `self`.
376
375
  - `self` should have a floating-point data type.
377
376
  """
378
377
  ...
379
378
 
380
- def log1p(self) -> CompatArray[float, DT]:
379
+ def log1p(self) -> CompatArray[type[float], DT]:
381
380
  """
382
381
  Computes the element-wise `log(1 + x)` (base `e`) of `self`.
383
382
  - `self` should have a floating-point data type.
384
383
  """
385
384
  ...
386
385
 
387
- def log2(self) -> CompatArray[float, DT]:
386
+ def log2(self) -> CompatArray[type[float], DT]:
388
387
  """
389
388
  Computes the element-wise `base-2 logarithm` of `self`.
390
389
  - `self` should have a floating-point data type.
391
390
  """
392
391
  ...
393
392
 
394
- def log10(self) -> CompatArray[float, DT]:
393
+ def log10(self) -> CompatArray[type[float], DT]:
395
394
  """
396
395
  Computes the element-wise `base-10 logarithm` of `self`.
397
396
  - `self` should have a floating-point data type.
398
397
  """
399
398
  ...
400
399
 
401
- def logaddexp(self, other: ArrayOrReal, /) -> CompatArray[float, DT]:
400
+ def logaddexp(self, other: ArrayOrReal, /) -> CompatArray[type[float], DT]:
402
401
  """
403
402
  Computes the element-wise `logaddexp` of `self` and `other`.
404
403
  - Equivalent to `log(exp(self) + exp(other))`.
@@ -406,28 +405,28 @@ class CompatArray(Compat, Generic[TT, DT]):
406
405
  """
407
406
  ...
408
407
 
409
- def logical_and(self, other: ArrayOrbool, /) -> CompatArray[bool, DT]:
408
+ def logical_and(self, other: ArrayOrbool, /) -> CompatArray[type[bool], DT]:
410
409
  """
411
410
  Computes the element-wise `logical AND` of `self` and `other`.
412
411
  - `self` should have a boolean data type.
413
412
  """
414
413
  ...
415
414
 
416
- def logical_not(self) -> CompatArray[bool, DT]:
415
+ def logical_not(self) -> CompatArray[type[bool], DT]:
417
416
  """
418
417
  Computes the element-wise `logical NOT` of `self`.
419
418
  - `self` should have a boolean data type.
420
419
  """
421
420
  ...
422
421
 
423
- def logical_or(self, other: ArrayOrbool, /) -> CompatArray[bool, DT]:
422
+ def logical_or(self, other: ArrayOrbool, /) -> CompatArray[type[bool], DT]:
424
423
  """
425
424
  Computes the element-wise `logical OR` of `self` and `other`.
426
425
  - `self` should have a boolean data type.
427
426
  """
428
427
  ...
429
428
 
430
- def logical_xor(self, other: ArrayOrbool, /) -> CompatArray[bool, DT]:
429
+ def logical_xor(self, other: ArrayOrbool, /) -> CompatArray[type[bool], DT]:
431
430
  """
432
431
  Computes the element-wise `logical XOR` of `self` and `other`.
433
432
  - `self` should have a boolean data type.
@@ -469,7 +468,7 @@ class CompatArray(Compat, Generic[TT, DT]):
469
468
  """
470
469
  ...
471
470
 
472
- def not_equal(self, other: ArrayOrAny, /) -> CompatArray[bool, DT]:
471
+ def not_equal(self, other: ArrayOrAny, /) -> CompatArray[type[bool], DT]:
473
472
  """
474
473
  Computes the element-wise truth value of `self != other`.
475
474
  """
@@ -490,14 +489,14 @@ class CompatArray(Compat, Generic[TT, DT]):
490
489
  """
491
490
  ...
492
491
 
493
- def real(self) -> CompatArray[float, DT]:
492
+ def real(self) -> CompatArray[type[float], DT]:
494
493
  """
495
494
  Computes the element-wise `real component` of `self`.
496
495
  - `self` should have a numeric data type.
497
496
  """
498
497
  ...
499
498
 
500
- def reciprocal(self) -> CompatArray[float, DT]:
499
+ def reciprocal(self) -> CompatArray[type[float], DT]:
501
500
  """
502
501
  Computes the element-wise `reciprocal` of `self`.
503
502
  - `self` should have a floating-point data type.
@@ -525,7 +524,7 @@ class CompatArray(Compat, Generic[TT, DT]):
525
524
  """
526
525
  ...
527
526
 
528
- def signbit(self) -> CompatArray[bool, DT]:
527
+ def signbit(self) -> CompatArray[type[bool], DT]:
529
528
  """
530
529
  Tests the element-wise `sign bit` of `self`.
531
530
  - Tests each element for whenever is either `-0`, `less than zero`, or a signed `NaN` (i.e., a NaN value whose sign bit is 1).
@@ -533,14 +532,14 @@ class CompatArray(Compat, Generic[TT, DT]):
533
532
  """
534
533
  ...
535
534
 
536
- def sin(self) -> CompatArray[float, DT]:
535
+ def sin(self) -> CompatArray[type[float], DT]:
537
536
  """
538
537
  Computes the element-wise `sine` of `self`.
539
538
  - `self` should have a floating-point data type.
540
539
  """
541
540
  ...
542
541
 
543
- def sinh(self) -> CompatArray[float, DT]:
542
+ def sinh(self) -> CompatArray[type[float], DT]:
544
543
  """
545
544
  Computes the element-wise `hyperbolic sine` of `self`.
546
545
  - `self` should have a floating-point data type.
@@ -554,7 +553,7 @@ class CompatArray(Compat, Generic[TT, DT]):
554
553
  """
555
554
  ...
556
555
 
557
- def sqrt(self) -> CompatArray[float, DT]:
556
+ def sqrt(self) -> CompatArray[type[float], DT]:
558
557
  """
559
558
  Computes the element-wise `principal square root` of `self`.
560
559
  - `self` should have a floating-point data type.
@@ -568,14 +567,14 @@ class CompatArray(Compat, Generic[TT, DT]):
568
567
  """
569
568
  ...
570
569
 
571
- def tan(self) -> CompatArray[float, DT]:
570
+ def tan(self) -> CompatArray[type[float], DT]:
572
571
  """
573
572
  Computes the element-wise `tangent` of `self`.
574
573
  - `self` should have a floating-point data type.
575
574
  """
576
575
  ...
577
576
 
578
- def tanh(self) -> CompatArray[float, DT]:
577
+ def tanh(self) -> CompatArray[type[float], DT]:
579
578
  """
580
579
  Computes the element-wise `hyperbolic tangent` of `self`.
581
580
  - `self` should have a floating-point data type.
@@ -994,7 +993,7 @@ class CompatArray(Compat, Generic[TT, DT]):
994
993
  self, *,
995
994
  axis: Optional[int] = None,
996
995
  keepdims: bool = False
997
- ) -> CompatArray[int, DT]:
996
+ ) -> CompatArray[type[int], DT]:
998
997
  """
999
998
  Returns the indices of the maximum values along a specified axis.
1000
999
 
@@ -1027,7 +1026,7 @@ class CompatArray(Compat, Generic[TT, DT]):
1027
1026
  self, *,
1028
1027
  axis: Optional[int] = None,
1029
1028
  keepdims: bool = False
1030
- ) -> CompatArray[int, DT]:
1029
+ ) -> CompatArray[type[int], DT]:
1031
1030
  """
1032
1031
  Returns the indices of the minimum values along a specified axis.
1033
1032
 
@@ -1056,13 +1055,13 @@ class CompatArray(Compat, Generic[TT, DT]):
1056
1055
  """
1057
1056
  ...
1058
1057
 
1059
- def nonzero(self) -> Tuple[CompatArray[int, DT], ...]: ...
1058
+ def nonzero(self) -> Tuple[CompatArray[type[int], DT], ...]: ...
1060
1059
 
1061
1060
  def count_nonzero(
1062
1061
  self, *,
1063
1062
  axis: Optional[Union[int, Tuple[int, ...]]] = None,
1064
1063
  keepdims: bool = False
1065
- ) -> CompatArray[int, DT]:
1064
+ ) -> CompatArray[type[int], DT]:
1066
1065
  """
1067
1066
  Counts the number of `self` elements which are non-zero.
1068
1067
 
@@ -1097,7 +1096,7 @@ class CompatArray(Compat, Generic[TT, DT]):
1097
1096
  /, *,
1098
1097
  side: Literal['left', 'right'] = "left",
1099
1098
  sorter: Optional[ArrayLike[Any]] = None
1100
- ) -> CompatArray[int, DT]:
1099
+ ) -> CompatArray[type[int], DT]:
1101
1100
  """
1102
1101
  Finds the indices into `self` such that, if the corresponding elements in `other` were inserted before the indices, the order of `self`, when sorted in ascending order, would be preserved.
1103
1102
  - `self` must be a one-dimensional array. Should have a real-valued data type.
@@ -1188,7 +1187,7 @@ class CompatArray(Compat, Generic[TT, DT]):
1188
1187
  axis: int = -1,
1189
1188
  descending: bool = False,
1190
1189
  stable: bool = True
1191
- ) -> CompatArray[int, DT]:
1190
+ ) -> CompatArray[type[int], DT]:
1192
1191
  """
1193
1192
  Returns the indices that sort `self` along a specified axis.
1194
1193
 
@@ -1634,7 +1633,7 @@ class CompatArray(Compat, Generic[TT, DT]):
1634
1633
  self, *,
1635
1634
  axis: Optional[Union[int, Tuple[int, ...]]] = None,
1636
1635
  keepdims: bool = False
1637
- ) -> CompatArray[bool, DT]:
1636
+ ) -> CompatArray[type[bool], DT]:
1638
1637
  """
1639
1638
  Tests whether all `self` elements evaluate to `True` along a specified axis.
1640
1639
  - `Positive infinity`, `negative infinity`, and `NaN` must evaluate to `True`;
@@ -1667,7 +1666,7 @@ class CompatArray(Compat, Generic[TT, DT]):
1667
1666
  self, *,
1668
1667
  axis: Optional[Union[int, Tuple[int, ...]]] = None,
1669
1668
  keepdims: bool = False
1670
- ) -> CompatArray[bool, DT]:
1669
+ ) -> CompatArray[type[bool], DT]:
1671
1670
  """
1672
1671
  Tests whether any `self` elements evaluate to `True` along a specified axis.
1673
1672
  - `Positive infinity`, `negative infinity`, and `NaN` must evaluate to `True`;
@@ -1771,28 +1770,30 @@ class CompatArray(Compat, Generic[TT, DT]):
1771
1770
 
1772
1771
  def __array__(self) -> NDArray[TT]: ...
1773
1772
 
1773
+ def __getattr__(self, name: str) -> Any: ...
1774
+
1774
1775
  def __len__(self) -> int: ...
1775
1776
  def __abs__(self) -> CompatArray[TT, DT]: ...
1776
1777
  def __add__(self, other: ArrayOrScalar, /) -> CompatArray[Any, DT]: ...
1777
1778
  def __and__(self, other: ArrayOrIntLike, /) -> CompatArray[TT, DT]: ...
1778
1779
  def __bool__(self) -> bool: ...
1779
1780
  def __complex__(self) -> complex: ...
1780
- def __eq__(self, other: ArrayOrAny, /) -> CompatArray[bool, DT]: ...
1781
+ def __eq__(self, other: ArrayOrAny, /) -> CompatArray[type[bool], DT]: ...
1781
1782
  def __float__(self) -> float: ...
1782
1783
  def __floordiv__(self, other: ArrayOrReal, /) -> CompatArray[TT, DT]: ...
1783
- def __ge__(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]: ...
1784
+ def __ge__(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]: ...
1784
1785
  def __getitem__(self, key: Any, /) -> CompatArray[TT, DT]: ...
1785
- def __gt__(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]: ...
1786
+ def __gt__(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]: ...
1786
1787
  def __index__(self) -> int: ...
1787
1788
  def __int__(self) -> int: ...
1788
1789
  def __invert__(self) -> CompatArray[TT, DT]: ...
1789
- def __le__(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]: ...
1790
+ def __le__(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]: ...
1790
1791
  def __lshift__(self, other: ArrayOrInt, /) -> CompatArray[TT, DT]: ...
1791
- def __lt__(self, other: ArrayOrReal, /) -> CompatArray[bool, DT]: ...
1792
+ def __lt__(self, other: ArrayOrReal, /) -> CompatArray[type[bool], DT]: ...
1792
1793
  def __matmul__(self, other: ArrayLike[Any], /) -> CompatArray[Any, DT]: ...
1793
1794
  def __mod__(self, other: ArrayOrReal, /) -> CompatArray[Any, DT]: ...
1794
1795
  def __mul__(self, other: ArrayOrScalar, /) -> CompatArray[Any, DT]: ...
1795
- def __ne__(self, other: ArrayOrAny, /) -> CompatArray[bool, DT]: ...
1796
+ def __ne__(self, other: ArrayOrAny, /) -> CompatArray[type[bool], DT]: ...
1796
1797
  def __neg__(self) -> CompatArray[TT, DT]: ...
1797
1798
  def __or__(self, other: ArrayOrIntLike, /) -> CompatArray[TT, DT]: ...
1798
1799
  def __pos__(self) -> CompatArray[TT, DT]: ...
@@ -1800,7 +1801,7 @@ class CompatArray(Compat, Generic[TT, DT]):
1800
1801
  def __rshift__(self, other: ArrayOrInt, /) -> CompatArray[TT, DT]: ...
1801
1802
  def __setitem__(self, key: Any, value: Any, /): ...
1802
1803
  def __sub__(self, other: ArrayOrScalar, /) -> CompatArray[Any, DT]: ...
1803
- def __truediv__(self, other: ArrayOrScalar, /) -> CompatArray[float, DT]: ...
1804
+ def __truediv__(self, other: ArrayOrScalar, /) -> CompatArray[type[float], DT]: ...
1804
1805
  def __xor__(self, other: ArrayOrIntLike, /) -> CompatArray[TT, DT]: ...
1805
1806
 
1806
1807
 
@@ -236,6 +236,137 @@ class CompatNamespace(Compat):
236
236
  result = self._get_xp_attr("broadcast_arrays")(*[unwrap(arr) for arr in arrays])
237
237
  return [CompatArray(arr, xp=self) for arr in result]
238
238
 
239
+ # === Linear Algebra Extension ===
240
+ def vector_norm(self, x, /, *, axis=None, keepdims=False, ord=2):
241
+ """
242
+ Computes the vector norm of a vector (or batch of vectors) :param:`x`.
243
+
244
+ Parameters
245
+ ----------
246
+ x : ArrayLike[Any]
247
+ The input array. Should have a floating-point data type.
248
+
249
+ axis : Optional[Union[int, Tuple[int, ...]]], default to `None`
250
+ - _int_: :param:`axis` specifies the axis (dimension) along which to compute vector norms;
251
+ - _tuple_: :param:`axis` specifies the axes (dimensions) along which to compute batched vector norms;
252
+ - `None`: The vector norm must be computed over all array values (i.e., equivalent to computing the vector norm of a flattened array).
253
+
254
+ Negative indices must be supported.
255
+
256
+ keepdims : bool, default to `False`
257
+ - `True`: The axes (dimensions) specified by :param:`axis` must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see Broadcasting);
258
+ - `False`: The axes (dimensions) specified by :param:`axis` must not be included in the result.
259
+
260
+ ord : Union[int, float, Literal['inf', '-inf']], default to `2`
261
+ Order of the norm.
262
+ The following mathematical norms must be supported:
263
+ - `1`: L1-norm (Manhattan);
264
+ - `2`: L2-norm (Euclidean);
265
+ - `"inf"`: infinity norm;
266
+ - _int_ or _float_ (>=1): p-norm.
267
+
268
+ The following non-mathematical “norms” must be supported:
269
+ - `0`: sum(a != 0);
270
+ - `-1`: 1./sum(1./abs(a));
271
+ - `-2`: 1./sqrt(sum(1./a**2));
272
+ - `"-inf"`: min(abs(a));
273
+ - _int_ or _float_ (<1): sum(abs(a)**ord)**(1./ord).
274
+
275
+ Returns
276
+ -------
277
+ CompatArray
278
+ A :class:`CompatArray` array containing the vector norms.
279
+ - :param:`axis` is `None`: The returned array must be a zero-dimensional array containing a vector norm;
280
+ - :param:`axis` is a scalar value (_int_ or _float_): The returned array must have a rank which is one less than the rank of :param:`x`;
281
+ - :param:`axis` is _tuple_ (`n` elements): The returned array must have a rank which is `n` less than the rank of :param:`x`;
282
+
283
+ - :param:`x` is real-valued data type: The returned array must have a real-valued floating-point data type determined by Type Promotion Rules;
284
+ - :param:`x` is complex-valued data type: The returned array must have a real-valued floating-point data type whose precision matches the precision of :param:`x` (e.g., if :param:`x` is complex128, then the returned array must have a float64 data type).
285
+ """
286
+ if ord == "inf":
287
+ ord = float("inf")
288
+ elif ord == "-inf":
289
+ ord = float("-inf")
290
+
291
+ result = getattr(self.linalg, "vector_norm")(unwrap(x), axis=axis, keepdims=keepdims, ord=ord)
292
+ return CompatArray(result, xp=self)
293
+
294
+ def matrix_norm(self, x, /, *, keepdims=False, ord="fro"):
295
+ """
296
+ Computes the matrix norm of a matrix (or a stack of matrices) :param:`x`.
297
+
298
+ Parameters
299
+ ----------
300
+ x : ArrayLike[Any]
301
+ Input array having shape (..., `M`, `N`) and whose innermost two dimensions form `MxN` matrices. Should have a floating-point data type.
302
+
303
+ keepdims : bool, default to `False`
304
+ - `True`: The last two axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see Broadcasting);
305
+ - `False`: The last two axes (dimensions) must not be included in the result.
306
+
307
+ ord : Optional[Union[int, float, Literal['inf', '-inf', 'fro', 'nuc']]], default to `"fro"`
308
+ order of the norm.
309
+ The following mathematical norms must be supported:
310
+ - `"fro"`: Frobenius norm;
311
+ - `"nuc"`: nuclear norm;
312
+ - `1`: max(sum(abs(x), axis=0)). The norm corresponds to the induced matrix norm where `p=1` (i.e., the maximum absolute value column sum);
313
+ - `2`: largest singular value. The norm corresponds to the induced matrix norm where `p=inf` (i.e., the maximum absolute value row sum);
314
+ - `"inf"`: max(sum(abs(x), axis=1)). The norm corresponds to the induced matrix norm where `p=2` (i.e., the largest singular value).
315
+
316
+ The following non-mathematical “norms” must be supported:
317
+ - `-1`: min(sum(abs(x), axis=0));
318
+ - `-2`: smallest singular value;
319
+ - `"-inf"`: min(sum(abs(x), axis=1)).
320
+
321
+ Returns
322
+ -------
323
+ CompatArray
324
+ A :class:`CompatArray` array containing the norms for each `MxN` matrix.
325
+ - :param:`keepdims` is `False`: The returned array must have a rank which is two less than the rank of :param:`x`;
326
+
327
+ - :param:`x` is real-valued data type: The returned array must have a real-valued floating-point data type determined by Type Promotion Rules;
328
+ - :param:`x` is complex-valued data type: The returned array must have a real-valued floating-point data type whose precision matches the precision of :param:`x` (e.g., if :param:`x` is complex128, then the returned array must have a float64 data type).
329
+
330
+ """
331
+ if ord == "inf":
332
+ ord = float("inf")
333
+ elif ord == "-inf":
334
+ ord = float("-inf")
335
+
336
+ result = getattr(self.linalg, "matrix_norm")(unwrap(x), keepdims=keepdims, ord=ord)
337
+ return CompatArray(result, xp=self)
338
+
339
+ @property
340
+ def linalg(self):
341
+ """
342
+ The `linalg` namespace for linear algebra functions.
343
+ The following functions must be supported in the `linalg` namespace:
344
+ - `cholesky`(x, /, *, upper=False): Returns the lower (upper) Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix x.
345
+ - `cross`(x1, x2, /, *, axis=-1): Returns the cross product of 3-element vectors.
346
+ - `det`(x, /): Returns the determinant of a square matrix (or a stack of square matrices) x.
347
+ - `diagonal`(x, /, *, offset=0): Returns the specified diagonals of a matrix (or a stack of matrices) x.
348
+ - `eigh`(x, /): Returns an eigenvalue decomposition of a complex Hermitian or real symmetric matrix (or a stack of matrices) x.
349
+ - `eigvalsh`(x, /): Returns the eigenvalues of a complex Hermitian or real symmetric matrix (or a stack of matrices) x.
350
+ - `inv`(x, /): Returns the multiplicative inverse of a square matrix (or a stack of square matrices) x.
351
+ - `matmul`(x1, x2, /): Alias for matmul().
352
+ - `matrix_norm`(x, /, *, keepdims=False, ord='fro'): Computes the matrix norm of a matrix (or a stack of matrices) x.
353
+ - `matrix_power`(x, n, /): Raises a square matrix (or a stack of square matrices) x to an integer power n.
354
+ - `matrix_rank`(x, /, *, rtol=None): Returns the rank (i.e., number of non-zero singular values) of a matrix (or a stack of matrices).
355
+ - `matrix_transpose`(x, /): Alias for matrix_transpose().
356
+ - `outer`(x1, x2, /): Returns the outer product of two vectors x1 and x2.
357
+ - `pinv`(x, /, *, rtol=None): Returns the (Moore-Penrose) pseudo-inverse of a matrix (or a stack of matrices) x.
358
+ - `qr`(x, /, *, mode='reduced'): Returns the QR decomposition of a full column rank matrix (or a stack of matrices).
359
+ - `slogdet`(x, /): Returns the sign and the natural logarithm of the absolute value of the determinant of a square matrix (or a stack of square matrices) x.
360
+ - `solve`(x1, x2, /): Returns the solution of a square system of linear equations with a unique solution.
361
+ - `svd`(x, /, *, full_matrices=True): Returns a singular value decomposition (SVD) of a matrix (or a stack of matrices) x.
362
+ - `svdvals`(x, /): Returns the singular values of a matrix (or a stack of matrices) x.
363
+ - `tensordot`(x1, x2, /, *, axes=2): Alias for tensordot().
364
+ - `trace`(x, /, *, offset=0, dtype=None): Returns the sum along the specified diagonals of a matrix (or a stack of matrices) x.
365
+ - `vecdot`(x1, x2, /, *, axis=-1): Alias for vecdot().
366
+ - `vector_norm`(x, /, *, axis=None, keepdims=False, ord=2): Computes the vector norm of a vector (or batch of vectors) x.
367
+ """
368
+ return self._get_xp_attr("linalg")
369
+
239
370
  # === Constants ===
240
371
  @property
241
372
  def e(self):
@@ -4,12 +4,13 @@
4
4
 
5
5
  from __future__ import annotations
6
6
  from numpy.typing import NDArray
7
+ from types import ModuleType
7
8
  from typing import (Union, List, Tuple, Optional, Any, Literal, overload)
8
9
 
9
10
  from ._base import Compat
10
11
  from ._array import CompatArray
11
12
  from ..types import (
12
- DTypeT, DeviceT, dtypeT, DType, AnyDevice,
13
+ dtypeT, DTypeT, deviceT, DeviceT, DType, AnyDevice,
13
14
  ValueT, Value, ArrayLike, ArrayOrAny
14
15
  )
15
16
 
@@ -81,17 +82,17 @@ class CompatNamespace(Compat):
81
82
  ...
82
83
 
83
84
  @overload
84
- def arange(self, start: int, /, stop: Optional[int] = ..., step: int = ..., *, dtype: None = ..., device: None = ...) -> CompatArray[int, Literal["cpu"]]: ...
85
+ def arange(self, start: int, /, stop: Optional[int] = ..., step: int = ..., *, dtype: None = ..., device: None = ...) -> CompatArray[type[int], Literal["cpu"]]: ...
85
86
  @overload
86
- def arange(self, start: Union[int, float], /, stop: Optional[Union[int, float]] = ..., step: Union[int, float] = ..., *, dtype: None = ..., device: None = ...) -> CompatArray[float, Literal["cpu"]]: ...
87
+ def arange(self, start: Union[int, float], /, stop: Optional[Union[int, float]] = ..., step: Union[int, float] = ..., *, dtype: None = ..., device: None = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
87
88
  @overload
88
- def arange(self, start: int, /, stop: Optional[int] = ..., step: int = ..., *, dtype: None = ..., device: DeviceT) -> CompatArray[int, DeviceT]: ...
89
+ def arange(self, start: int, /, stop: Optional[int] = ..., step: int = ..., *, dtype: None = ..., device: DeviceT) -> CompatArray[type[int], DeviceT]: ...
89
90
  @overload
90
- def arange(self, start: int, /, stop: Optional[int] = ..., step: int = ..., *, dtype: None = ..., device: AnyDevice) -> CompatArray[int, AnyDevice]: ...
91
+ def arange(self, start: int, /, stop: Optional[int] = ..., step: int = ..., *, dtype: None = ..., device: AnyDevice) -> CompatArray[type[int], AnyDevice]: ...
91
92
  @overload
92
- def arange(self, start: Union[int, float], /, stop: Optional[Union[int, float]] = ..., step: Union[int, float] = ..., *, dtype: None = ..., device: DeviceT) -> CompatArray[float, DeviceT]: ...
93
+ def arange(self, start: Union[int, float], /, stop: Optional[Union[int, float]] = ..., step: Union[int, float] = ..., *, dtype: None = ..., device: DeviceT) -> CompatArray[type[float], DeviceT]: ...
93
94
  @overload
94
- def arange(self, start: Union[int, float], /, stop: Optional[Union[int, float]] = ..., step: Union[int, float] = ..., *, dtype: None = ..., device: AnyDevice) -> CompatArray[float, AnyDevice]: ...
95
+ def arange(self, start: Union[int, float], /, stop: Optional[Union[int, float]] = ..., step: Union[int, float] = ..., *, dtype: None = ..., device: AnyDevice) -> CompatArray[type[float], AnyDevice]: ...
95
96
  @overload
96
97
  def arange(self, start: Union[int, float], /, stop: Optional[Union[int, float]] = ..., step: Union[int, float] = ..., *, dtype: DTypeT, device: None = ...) -> CompatArray[DTypeT, Literal["cpu"]]: ...
97
98
  @overload
@@ -142,11 +143,11 @@ class CompatNamespace(Compat):
142
143
  ...
143
144
 
144
145
  @overload
145
- def empty(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: None = ...) -> CompatArray[float, Literal["cpu"]]: ...
146
+ def empty(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: None = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
146
147
  @overload
147
- def empty(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: DeviceT) -> CompatArray[float, DeviceT]: ...
148
+ def empty(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: DeviceT) -> CompatArray[type[float], DeviceT]: ...
148
149
  @overload
149
- def empty(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: AnyDevice) -> CompatArray[float, AnyDevice]: ...
150
+ def empty(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: AnyDevice) -> CompatArray[type[float], AnyDevice]: ...
150
151
  @overload
151
152
  def empty(self, shape: Union[int, Tuple[int, ...]], *, dtype: DTypeT, device: None = ...) -> CompatArray[DTypeT, Literal["cpu"]]: ...
152
153
  @overload
@@ -229,11 +230,11 @@ class CompatNamespace(Compat):
229
230
  ...
230
231
 
231
232
  @overload
232
- def eye(self, n_rows: int, n_cols: Optional[int] = ..., /, *, k: int = ..., dtype: None = ..., device: None = ...) -> CompatArray[float, Literal["cpu"]]: ...
233
+ def eye(self, n_rows: int, n_cols: Optional[int] = ..., /, *, k: int = ..., dtype: None = ..., device: None = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
233
234
  @overload
234
- def eye(self, n_rows: int, n_cols: Optional[int] = ..., /, *, k: int = ..., dtype: None = ..., device: DeviceT) -> CompatArray[float, DeviceT]: ...
235
+ def eye(self, n_rows: int, n_cols: Optional[int] = ..., /, *, k: int = ..., dtype: None = ..., device: DeviceT) -> CompatArray[type[float], DeviceT]: ...
235
236
  @overload
236
- def eye(self, n_rows: int, n_cols: Optional[int] = ..., /, *, k: int = ..., dtype: None = ..., device: AnyDevice) -> CompatArray[float, AnyDevice]: ...
237
+ def eye(self, n_rows: int, n_cols: Optional[int] = ..., /, *, k: int = ..., dtype: None = ..., device: AnyDevice) -> CompatArray[type[float], AnyDevice]: ...
237
238
  @overload
238
239
  def eye(self, n_rows: int, n_cols: Optional[int] = ..., /, *, k: int = ..., dtype: DTypeT, device: None = ...) -> CompatArray[DTypeT, Literal["cpu"]]: ...
239
240
  @overload
@@ -431,11 +432,11 @@ class CompatNamespace(Compat):
431
432
  ...
432
433
 
433
434
  @overload
434
- def linspace(self, start: Union[int, float, complex], stop: Union[int, float, complex], /, num: int, *, dtype: None = ..., device: None = ..., endpoint: bool = ...) -> CompatArray[float, Literal["cpu"]]: ...
435
+ def linspace(self, start: Union[int, float, complex], stop: Union[int, float, complex], /, num: int, *, dtype: None = ..., device: None = ..., endpoint: bool = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
435
436
  @overload
436
- def linspace(self, start: Union[int, float, complex], stop: Union[int, float, complex], /, num: int, *, dtype: None = ..., device: DeviceT, endpoint: bool = ...) -> CompatArray[float, DeviceT]: ...
437
+ def linspace(self, start: Union[int, float, complex], stop: Union[int, float, complex], /, num: int, *, dtype: None = ..., device: DeviceT, endpoint: bool = ...) -> CompatArray[type[float], DeviceT]: ...
437
438
  @overload
438
- def linspace(self, start: Union[int, float, complex], stop: Union[int, float, complex], /, num: int, *, dtype: None = ..., device: AnyDevice, endpoint: bool = ...) -> CompatArray[float, AnyDevice]: ...
439
+ def linspace(self, start: Union[int, float, complex], stop: Union[int, float, complex], /, num: int, *, dtype: None = ..., device: AnyDevice, endpoint: bool = ...) -> CompatArray[type[float], AnyDevice]: ...
439
440
  @overload
440
441
  def linspace(self, start: Union[int, float, complex], stop: Union[int, float, complex], /, num: int, *, dtype: DTypeT, device: None = ..., endpoint: bool = ...) -> CompatArray[DTypeT, Literal["cpu"]]: ...
441
442
  @overload
@@ -496,11 +497,11 @@ class CompatNamespace(Compat):
496
497
  def meshgrid(self, *arrays: ArrayLike[Any], indexing: Literal["xy", "ij"] = "xy") -> List[CompatArray[Any, AnyDevice]]: ...
497
498
 
498
499
  @overload
499
- def ones(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: None = ...) -> CompatArray[float, Literal["cpu"]]: ...
500
+ def ones(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: None = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
500
501
  @overload
501
- def ones(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: DeviceT) -> CompatArray[float, DeviceT]: ...
502
+ def ones(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: DeviceT) -> CompatArray[type[float], DeviceT]: ...
502
503
  @overload
503
- def ones(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: AnyDevice) -> CompatArray[float, AnyDevice]: ...
504
+ def ones(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: AnyDevice) -> CompatArray[type[float], AnyDevice]: ...
504
505
  @overload
505
506
  def ones(self, shape: Union[int, Tuple[int, ...]], *, dtype: DTypeT, device: None = ...) -> CompatArray[DTypeT, Literal["cpu"]]: ...
506
507
  @overload
@@ -648,11 +649,11 @@ class CompatNamespace(Compat):
648
649
  ...
649
650
 
650
651
  @overload
651
- def zeros(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: None = ...) -> CompatArray[float, Literal["cpu"]]: ...
652
+ def zeros(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: None = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
652
653
  @overload
653
- def zeros(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: DeviceT) -> CompatArray[float, DeviceT]: ...
654
+ def zeros(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: DeviceT) -> CompatArray[type[float], DeviceT]: ...
654
655
  @overload
655
- def zeros(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: AnyDevice) -> CompatArray[float, AnyDevice]: ...
656
+ def zeros(self, shape: Union[int, Tuple[int, ...]], *, dtype: None = ..., device: AnyDevice) -> CompatArray[type[float], AnyDevice]: ...
656
657
  @overload
657
658
  def zeros(self, shape: Union[int, Tuple[int, ...]], *, dtype: DTypeT, device: None = ...) -> CompatArray[DTypeT, Literal["cpu"]]: ...
658
659
  @overload
@@ -802,7 +803,7 @@ class CompatNamespace(Compat):
802
803
  Returns
803
804
  -------
804
805
  CompatArray
805
- A :class:`CompatArray` output array containing the concatenated values.
806
+ A :class:`CompatArray` output array containing the concatenated values.
806
807
  """
807
808
  ...
808
809
 
@@ -841,6 +842,26 @@ class CompatNamespace(Compat):
841
842
  """
842
843
  ...
843
844
 
845
+ # === Linear Algebra Extension ===
846
+ @overload
847
+ def vector_norm(self, x: NDArray[Any], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = ..., keepdims: bool = ..., ord: Union[int, float, Literal["inf", "-inf"]] = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
848
+ @overload
849
+ def vector_norm(self, x: CompatArray[Any, deviceT], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = ..., keepdims: bool = ..., ord: Union[int, float, Literal["inf", "-inf"]] = ...) -> CompatArray[type[float], deviceT]: ...
850
+ @overload
851
+ def vector_norm(self, x: ArrayLike[Any], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = ..., keepdims: bool = ..., ord: Union[int, float, Literal["inf", "-inf"]] = ...) -> CompatArray[type[float], AnyDevice]: ...
852
+ def vector_norm(self, x: ArrayLike[Any], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Union[int, float, Literal["inf", "-inf"]] = 2) -> CompatArray[type[float], AnyDevice]: ...
853
+
854
+ @overload
855
+ def matrix_norm(self, x: NDArray[Any], /, *, keepdims: bool = ..., ord: Optional[Union[int, float, Literal["inf", "-inf", "fro", "nuc"]]] = ...) -> CompatArray[type[float], Literal["cpu"]]: ...
856
+ @overload
857
+ def matrix_norm(self, x: CompatArray[Any, deviceT], /, *, keepdims: bool = ..., ord: Optional[Union[int, float, Literal["inf", "-inf", "fro", "nuc"]]] = ...) -> CompatArray[type[float], deviceT]: ...
858
+ @overload
859
+ def matrix_norm(self, x: ArrayLike[Any], /, *, keepdims: bool = ..., ord: Optional[Union[int, float, Literal["inf", "-inf", "fro", "nuc"]]] = ...) -> CompatArray[type[float], AnyDevice]: ...
860
+ def matrix_norm(self, x: ArrayLike[Any], /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal["inf", "-inf", "fro", "nuc"]]] = "fro") -> CompatArray[type[float], AnyDevice]: ...
861
+
862
+ @property
863
+ def linalg(self) -> ModuleType: ...
864
+
844
865
  # === Constants ===
845
866
  @property
846
867
  def e(self) -> float: ...
@@ -883,3 +904,5 @@ class CompatNamespace(Compat):
883
904
 
884
905
  @property
885
906
  def __name__(self) -> str: ...
907
+
908
+ def __getattr__(self, name: str) -> Any: ...
@@ -122,11 +122,11 @@ def to_tensor(obj, /, *, dtype=None, device=None, copy=True):
122
122
  - `None`: Raises `ConvertNoneTypeError`;
123
123
  - _others_: Converted to a `PyTorch tensor` directly.
124
124
 
125
- dtype : Optional[DTypeT], default to `None`
125
+ dtype : Optional[DType], default to `None`
126
126
  The data type of the resulting `PyTorch tensor`.
127
127
  - `None`: Use the default data type of the object.
128
128
 
129
- device : Optional[DeviceT], default to `None`
129
+ device : Optional[AnyDevice], default to `None`
130
130
  The device on which the resulting `PyTorch tensor` will be allocated.
131
131
  - `None`: Use the default device (usually `"cpu"`).
132
132
 
@@ -294,11 +294,11 @@ def as_array(obj, xp, /, *, dtype=None, device=None, copy=False, arraylike_only=
294
294
  - _ArrayLibraryName_ (`"numpy"` or `"torch"`): Converted to a `NumPy array` or `PyTorch tensor` respectively using the corresponding conversion functions;
295
295
  - _Namespace_ or _CompatNamespace_: Converted to an array using the `asarray()` function provided by the namespace module, which must be compatible with the array API standard.
296
296
 
297
- dtype : Optional[DTypeT], default to `None`
297
+ dtype : Optional[DType], default to `None`
298
298
  The data type of the resulting array.
299
299
  - `None`: Use the default data type of the object.
300
300
 
301
- device : Optional[DeviceT], default to `None`
301
+ device : Optional[AnyDevice], default to `None`
302
302
  The device on which the resulting array will be allocated (only if `array namespace` supports it).
303
303
 
304
304
  copy : bool, default to `False`
@@ -24,8 +24,9 @@ Device = Union[DeviceLiteral, torch.device]
24
24
  AnyDevice = Union[Device, str]
25
25
 
26
26
  dtypeT = TypeVar("dtypeT", bound=DType)
27
- deviceT = TypeVar("deviceT", bound=Device)
28
- anydeviceT = TypeVar("anydeviceT", bound=AnyDevice)
27
+ deviceT = TypeVar("deviceT", bound=AnyDevice)
28
+ # anydeviceT = TypeVar("anydeviceT", bound=AnyDevice)
29
+
29
30
  DTypeT = TypeVar("DTypeT", bound=DType)
30
31
  DeviceT = TypeVar("DeviceT", bound=Device)
31
32
  AnyDeviceT = TypeVar("AnyDeviceT", bound=AnyDevice)
@@ -55,16 +56,16 @@ ArrayOrInt = Union[ArrayLike[Any], int]
55
56
 
56
57
  class UniqueAllResult(NamedTuple, Generic[DTypeT_co, AnyDeviceT_co]):
57
58
  values: CompatArray[DTypeT_co, AnyDeviceT_co]
58
- indices: CompatArray[int, AnyDeviceT_co]
59
- inverse_indices: CompatArray[int, AnyDeviceT_co]
60
- counts: CompatArray[int, AnyDeviceT_co]
59
+ indices: CompatArray[type[int], AnyDeviceT_co]
60
+ inverse_indices: CompatArray[type[int], AnyDeviceT_co]
61
+ counts: CompatArray[type[int], AnyDeviceT_co]
61
62
 
62
63
 
63
64
  class UniqueCountsResult(NamedTuple, Generic[DTypeT_co, AnyDeviceT_co]):
64
65
  values: CompatArray[DTypeT_co, AnyDeviceT_co]
65
- counts: CompatArray[int, AnyDeviceT_co]
66
+ counts: CompatArray[type[int], AnyDeviceT_co]
66
67
 
67
68
 
68
69
  class UniqueInverseResult(NamedTuple, Generic[DTypeT_co, AnyDeviceT_co]):
69
70
  values: CompatArray[DTypeT_co, AnyDeviceT_co]
70
- inverse_indices: CompatArray[int, AnyDeviceT_co]
71
+ inverse_indices: CompatArray[type[int], AnyDeviceT_co]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cobra-array
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: A backend-agnostic array utility library that unifies array conversion, context control, and cross-library operations across `NumPy`/`PyTorch`-style ecosystems.
5
5
  Author-email: Zhen Tian <zhen.tian.cs@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/tinchen777/cobra-array.git
@@ -67,10 +67,17 @@ def test_to_device_on_numpy_backend():
67
67
  @pytest.mark.skipif(torch_xp is None, reason="PyTorch not available")
68
68
  def test_to_tensor_available_backend():
69
69
  a = _arr_1d()
70
+
71
+ aa = a.to_device("cpu")
72
+
73
+ aa = CompatArray(a)
74
+
70
75
  t = a.to_tensor(device="cpu")
71
76
 
72
77
  assert isinstance(t, torch_xp.Tensor)
73
78
  assert tuple(t.shape) == (3,)
79
+
80
+ aa = wrap_arraylike(a)
74
81
 
75
82
 
76
83
  def test_unstack_and_nonzero():
@@ -78,6 +85,15 @@ def test_unstack_and_nonzero():
78
85
 
79
86
  pieces = a.unstack(axis=0)
80
87
  nz = a.nonzero()
88
+
89
+ ff = CompatArray(nz[0])
90
+
91
+ fff = a.cxp.zeros((1,2))
92
+
93
+ ff = CompatArray(fff)
94
+
95
+ gg = a.argmax()
96
+
81
97
 
82
98
  assert isinstance(pieces, tuple)
83
99
  assert len(pieces) == 2
@@ -470,13 +486,47 @@ def test_cxp_of_compatarray_matches_array_namespace():
470
486
  assert cxp.xp_name == a.xp_name
471
487
 
472
488
 
489
+ def test_linalg():
490
+ a = CompatArray(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
491
+ cxp = a.cxp
492
+ print(a)
493
+
494
+ f = cxp.vector_norm(a, axis=0, ord="-inf")
495
+
496
+ print(f)
497
+
498
+ f = cxp.matrix_norm(a, ord="nuc")
499
+
500
+ print(f)
501
+
502
+ # def test_linalg2():
503
+ # import torch
504
+ # a = CompatArray(torch.tensor([[1.0, 2.0], [3.0, 4.0]], device="cuda:0"))
505
+ # cxp = a.cxp
506
+ # print(a)
507
+
508
+ # f = cxp.vector_norm(a, axis=0, ord="-inf")
509
+
510
+ # print(f.device)
511
+
512
+ # print(f)
513
+
514
+ # f = cxp.matrix_norm(a, ord="nuc")
515
+
516
+ # print(f)
517
+ # print(f.device)
518
+
519
+
473
520
  if __name__ == "__main__":
474
- test_unique_all_unique_counts_unique_inverse()
475
- print("=" * 40)
521
+ # test_unique_all_unique_counts_unique_inverse()
522
+ # print("=" * 40)
476
523
 
477
- test_numpy_operator_overloads()
478
- print("=" * 40)
524
+ # test_numpy_operator_overloads()
525
+ # print("=" * 40)
479
526
  # test_torch_operator_overloads()
480
- print("=" * 40)
527
+ # print("=" * 40)
481
528
 
482
- test_add()
529
+ # test_add()
530
+
531
+ # test_linalg2()
532
+ pass
File without changes
File without changes
File without changes
File without changes