tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
- tensorcircuit/__init__.py +18 -2
- tensorcircuit/about.py +46 -0
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +320 -7
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +102 -4
- tensorcircuit/backends/jax_ops.py +110 -1
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +157 -98
- tensorcircuit/circuit.py +115 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +142 -21
- tensorcircuit/densitymatrix.py +43 -14
- tensorcircuit/experimental.py +387 -129
- tensorcircuit/fgs.py +282 -81
- tensorcircuit/gates.py +66 -22
- tensorcircuit/interfaces/__init__.py +1 -3
- tensorcircuit/interfaces/jax.py +189 -0
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +154 -65
- tensorcircuit/quantum.py +868 -152
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +147 -20
- tensorcircuit/results/readout_mitigation.py +4 -1
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +479 -0
- tensorcircuit/templates/__init__.py +2 -0
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +174 -0
- tensorcircuit/templates/lattice.py +1789 -0
- tensorcircuit/timeevol.py +896 -0
- tensorcircuit/translation.py +10 -3
- tensorcircuit/utils.py +7 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1031
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -365
- tests/test_circuit.py +0 -1699
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -310
- tests/test_gates.py +0 -156
- tests/test_interfaces.py +0 -429
- tests/test_keras.py +0 -160
- tests/test_miscs.py +0 -277
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -526
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -347
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
tensorcircuit/backends/abstract_backend.py

```diff
@@ -9,6 +9,7 @@ from functools import reduce, partial
 from operator import mul
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
+import math
 import numpy as np
 from ..utils import return_partial
 
```
```diff
@@ -46,17 +47,22 @@ class ExtendedBackend:
             "Backend '{}' has not implemented `expm`.".format(self.name)
         )
 
-    def sqrtmh(self: Any, a: Tensor) -> Tensor:
+    def sqrtmh(self: Any, a: Tensor, psd: bool = False) -> Tensor:
         """
         Return the sqrtm of a Hermitian matrix ``a``.
 
         :param a: tensor in matrix form
         :type a: Tensor
+        :param psd: whether the input ``a`` is guaranteed to be a positive semidefinite matrix,
+            defaults to False
+        :type psd: bool
         :return: sqrtm of ``a``
         :rtype: Tensor
         """
         # maybe friendly for AD, also considering that several backends have no native sqrtm support
         e, v = self.eigh(a)
+        if psd:
+            e = self.relu(e)
         e = self.sqrt(e)
         return v @ self.diagflat(e) @ self.adjoint(v)
 
```
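The new `psd` flag guards against tiny negative eigenvalues produced by floating-point noise when the input is known to be positive semidefinite. A minimal NumPy sketch of the same eigh-based construction (illustrative only, not the library API):

```python
import numpy as np

def sqrtmh_numpy(a, psd=False):
    # eigendecomposition-based matrix square root for Hermitian input,
    # mirroring the eigh -> (relu) -> sqrt -> reassemble pipeline above
    e, v = np.linalg.eigh(a)
    if psd:
        e = np.maximum(e, 0.0)  # clamp rounding-noise negatives before sqrt
    return v @ np.diag(np.sqrt(e)) @ v.conj().T

rho = np.array([[0.5, 0.5], [0.5, 0.5]])  # a rank-1 projector, exactly PSD
s = sqrtmh_numpy(rho, psd=True)
assert np.allclose(s @ s, rho)
```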
```diff
@@ -400,6 +406,31 @@ class ExtendedBackend:
         a = self.reshape(a, [2 for _ in range(nleg)])
         return a
 
+    def reshaped(self: Any, a: Tensor, d: int) -> Tensor:
+        """
+        Reshape a tensor to the [d, d, ...] shape.
+
+        :param a: Input tensor
+        :type a: Tensor
+        :param d: edge length for each dimension
+        :type d: int
+        :return: the reshaped tensor
+        :rtype: Tensor
+        """
+        if not isinstance(d, int) or d <= 0:
+            raise ValueError("d must be a positive integer.")
+
+        size = self.sizen(a)
+        if size == 0:
+            return self.reshape(a, (0,))
+
+        nleg_float = math.log(size, d)
+        nleg = int(round(nleg_float))
+        if d**nleg != size:
+            raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")
+
+        return self.reshape(a, (d,) * nleg)
+
     def reshapem(self: Any, a: Tensor) -> Tensor:
         """
         Reshape a tensor to the [l, l] shape.
```
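`reshaped` generalizes the qubit-only reshape above to an arbitrary local dimension `d`, which the new qudit modules (`quditcircuit.py`, `quditgates.py`) presumably rely on. A quick NumPy illustration of the size check (`backend` stands for any backend instance):

```python
import numpy as np

psi = np.arange(27.0)          # a 3-qutrit state vector: 27 == 3**3
legs = psi.reshape((3,) * 3)   # what backend.reshaped(psi, 3) produces
print(legs.shape)              # (3, 3, 3)

# a size that is not a power of d is rejected rather than silently mangled:
# backend.reshaped(np.arange(10.0), 3)  -> ValueError
```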
```diff
@@ -576,6 +607,97 @@ class ExtendedBackend:
             "Backend '{}' has not implemented `argmin`.".format(self.name)
         )
 
+    def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
+        """
+        Return the indices that would sort a tensor.
+
+        :param a: the tensor to be sorted
+        :type a: Tensor
+        :param axis: the axis to sort along, defaults to -1
+        :type axis: int
+        :return: the sorted indices
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `argsort`.".format(self.name)
+        )
+
+    def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
+        """
+        Sort a tensor along the given axis.
+
+        :param a: the tensor to be sorted
+        :type a: Tensor
+        :param axis: the axis to sort along, defaults to -1
+        :type axis: int, optional
+        :return: the sorted tensor
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `sort`.".format(self.name)
+        )
+
+    def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        """
+        Test whether all tensor elements along a given axis evaluate to True.
+
+        :param a: Input tensor
+        :type a: Tensor
+        :param axis: Axis or axes along which a logical AND reduction is performed,
+            defaults to None
+        :type axis: Optional[Sequence[int]], optional
+        :return: A new boolean or tensor resulting from the AND reduction
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `all`.".format(self.name)
+        )
+
+    def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
+        """
+        Return coordinate matrices from coordinate vectors.
+
+        :param args: coordinate vectors
+        :type args: Any
+        :param kwargs: keyword arguments for meshgrid, typically including 'indexing',
+            which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing, the NumPy default).
+            - 'ij': matrix indexing, first dimension corresponds to rows
+            - 'xy': Cartesian indexing, first dimension corresponds to columns
+            Example:
+                >>> x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy')
+            Shapes:
+                - x.shape == (2, 2)  # rows correspond to y vector length
+                - y.shape == (2, 2)
+            Values:
+                x = [[0, 1],
+                     [0, 1]]
+                y = [[0, 0],
+                     [2, 2]]
+        :type kwargs: Any
+        :return: list of coordinate matrices
+        :rtype: Any
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `meshgrid`.".format(self.name)
+        )
+
+    def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
+        """
+        Expand the shape of a tensor.
+        Insert a new axis that will appear at the `axis` position in the expanded
+        tensor shape.
+
+        :param a: Input tensor
+        :type a: Tensor
+        :param axis: Position in the expanded axes where the new axis is placed
+        :type axis: int
+        :return: Output tensor with the number of dimensions increased by one.
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `expand_dims`.".format(self.name)
+        )
+
     def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         """
         Find the unique elements and their corresponding counts of the given tensor ``a``.
```
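The new reduction and sorting primitives follow NumPy semantics; for instance, `argsort` returns indices such that gathering with them yields the sorted tensor:

```python
import numpy as np

a = np.array([3, 1, 2])
idx = np.argsort(a)   # [1, 2, 0]; backend.argsort mirrors this contract
print(a[idx])         # [1 2 3], identical to np.sort(a)
print(np.all(a > 0))  # True; backend.all is the same AND reduction
```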
```diff
@@ -695,6 +817,9 @@
             "Backend '{}' has not implemented `is_tensor`.".format(self.name)
         )
 
+    def matvec(self: Any, A: Tensor, x: Tensor) -> Tensor:
+        return self.tensordot(A, x, axes=[[1], [0]])
+
     def cast(self: Any, a: Tensor, dtype: str) -> Tensor:
         """
         Cast the tensor dtype of ``a``.
```
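`matvec` ships with a default implementation rather than per-backend overrides: contracting axis 1 of `A` against axis 0 of `x` is exactly a matrix-vector product. The equivalent NumPy identity:

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
x = np.array([1.0, 1.0])
# tensordot over (axis 1 of A, axis 0 of x) == A @ x
assert np.allclose(np.tensordot(A, x, axes=[[1], [0]]), A @ x)
```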
```diff
@@ -710,6 +835,21 @@
             "Backend '{}' has not implemented `cast`.".format(self.name)
         )
 
+    def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor:
+        """
+        Convert input to a tensor.
+
+        :param a: input data to be converted
+        :type a: Tensor
+        :param dtype: target dtype, optional
+        :type dtype: Optional[str]
+        :return: converted tensor
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `convert_to_tensor`.".format(self.name)
+        )
+
     def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
         """
         Compute y-mod of x (negative number behavior is not guaranteed to be consistent)
```
```diff
@@ -725,6 +865,82 @@
             "Backend '{}' has not implemented `mod`.".format(self.name)
         )
 
+    def floor_divide(self: Any, x: Tensor, y: Tensor) -> Tensor:
+        r"""
+        Compute the element-wise floor division of two tensors.
+
+        This operation returns a new tensor containing the result of
+        dividing `x` by `y` and rounding each element down towards
+        negative infinity. The semantics are equivalent to the Python
+        `//` operator:
+
+            result[i] = floor(x[i] / y[i])
+
+        Broadcasting is supported according to the backend's rules.
+
+        :param x: Dividend tensor.
+        :type x: Tensor
+        :param y: Divisor tensor, must be broadcastable with `x`.
+        :type y: Tensor
+        :return: A tensor with the broadcasted shape of `x` and `y`,
+            where each element is the floored result of the division.
+        :rtype: Tensor
+
+        :raises NotImplementedError: If the backend does not provide an
+            implementation for `floor_divide`.
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `floor_divide`.".format(self.name)
+        )
+
+    def floor(self: Any, a: Tensor) -> Tensor:
+        """
+        Compute the element-wise floor of the input tensor.
+
+        This operation returns a new tensor with the largest integers
+        less than or equal to each element of the input tensor,
+        i.e. it rounds each value down towards negative infinity.
+
+        :param a: Input tensor containing numeric values.
+        :type a: Tensor
+        :return: A tensor with the same shape as `a`, where each element
+            is the floored value of the corresponding element in `a`.
+        :rtype: Tensor
+
+        :raises NotImplementedError: If the backend does not provide an
+            implementation for `floor`.
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `floor`.".format(self.name)
+        )
+
+    def clip(self: Any, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
+        """
+        Clip (limit) the values of a tensor element-wise to the range [a_min, a_max].
+
+        Each element in the input tensor `a` is compared against the corresponding
+        bounds `a_min` and `a_max`. If a value in `a` is less than `a_min`, it is set
+        to `a_min`; if greater than `a_max`, it is set to `a_max`. Otherwise, the
+        value is left unchanged. The result preserves the dtype and device of the input.
+
+        :param a: Input tensor containing values to be clipped.
+        :type a: Tensor
+        :param a_min: Lower bound (minimum value) for clipping. Can be a scalar tensor
+            or broadcastable to the shape of `a`.
+        :type a_min: Tensor
+        :param a_max: Upper bound (maximum value) for clipping. Can be a scalar tensor
+            or broadcastable to the shape of `a`.
+        :type a_max: Tensor
+        :return: A tensor with the same shape as `a`, where all values are clipped
+            to lie within the interval [a_min, a_max].
+        :rtype: Tensor
+
+        :raises NotImplementedError: If the backend does not implement `clip`.
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `clip`.".format(self.name)
+        )
+
     def reverse(self: Any, a: Tensor) -> Tensor:
         """
         return ``a[::-1]``, only 1D tensor is guaranteed for consistent behavior
```
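Floor division rounds towards negative infinity rather than truncating towards zero, which matters for negative operands; a NumPy check of the documented semantics:

```python
import numpy as np

print(np.floor_divide(np.array([7, -7]), 2))          # [ 3 -4], not [3 -3]
print(np.floor(np.array([1.5, -1.5])))                # [ 1. -2.]
print(np.clip(np.array([-1.5, 0.3, 9.0]), 0.0, 1.0))  # [0.  0.3 1. ]
```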
```diff
@@ -800,6 +1016,21 @@
             "Backend '{}' has not implemented `solve`.".format(self.name)
         )
 
+    def special_jv(self: Any, v: int, z: Tensor, M: int) -> Tensor:
+        """
+        Special function: Bessel function of the first kind.
+
+        :param v: The order of the Bessel function.
+        :type v: int
+        :param z: The argument of the Bessel function.
+        :type z: Tensor
+        :return: The values of the Bessel functions [J_0(z), ..., J_{v-1}(z)].
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `special_jv`.".format(self.name)
+        )
+
     def searchsorted(self: Any, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
         """
         Find indices where elements should be inserted to maintain order.
```
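`special_jv` returns the first `v` orders at once. A hedged sanity check against SciPy; note that the role of `M` is not documented above, so reading it as the depth of the backward recurrence is an assumption:

```python
import numpy as np
from scipy import special

v, z = 4, 1.5
reference = np.array([special.jv(n, z) for n in range(v)])
# backend.special_jv(v, z, M) should agree with `reference`
# for sufficiently large M (assumption: M sets the recurrence depth)
print(reference)
```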
```diff
@@ -808,8 +1039,8 @@
         :type a: Tensor
         :param v: value to be inserted
         :type v: Tensor
-        :param side: If
-            If
+        :param side: If `left`, the index of the first suitable location found is given.
+            If `right`, return the last such index.
         If there is no suitable index, return either 0 or N (where N is the length of a),
             defaults to "left"
         :type side: str, optional
```
```diff
@@ -1243,6 +1474,26 @@
             "Backend '{}' has not implemented `sparse_dense_matmul`.".format(self.name)
         )
 
+    def sparse_csr_from_coo(self: Any, coo: Tensor, strict: bool = False) -> Tensor:
+        """
+        Transform a COO matrix into a CSR matrix.
+
+        :param coo: a COO matrix
+        :type coo: Tensor
+        :param strict: whether to enforce the transform, defaults to False,
+            in which case the COO matrix is returned unchanged when the
+            specific backend has no implementation.
+        :type strict: bool, optional
+        :return: a CSR matrix
+        :rtype: Tensor
+        """
+        if strict:
+            raise NotImplementedError(
+                "Backend '{}' has not implemented `sparse_csr_from_coo`.".format(
+                    self.name
+                )
+            )
+        return coo
+
     def to_dense(self: Any, sp_a: Tensor) -> Tensor:
         """
         Convert a sparse matrix to dense tensor.
```
```diff
@@ -1346,6 +1597,28 @@
             "Backend '{}' has not implemented `cond`.".format(self.name)
         )
 
+    def where(
+        self: Any,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Return a tensor of elements selected from either x or y, depending on condition.
+
+        :param condition: Where True, yield x, otherwise yield y.
+        :type condition: Tensor (bool)
+        :param x: Values from which to choose when condition is True.
+        :type x: Tensor
+        :param y: Values from which to choose when condition is False.
+        :type y: Tensor
+        :return: A tensor with elements from x where condition is True, and y otherwise.
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `where`.".format(self.name)
+        )
+
     def switch(
         self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]]
     ) -> Tensor:
```
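The three-argument form has the same contract as `np.where`; e.g. a branch-free ReLU:

```python
import numpy as np

x = np.array([-2.0, 0.5, 3.0])
print(np.where(x > 0, x, 0.0))  # [0.  0.5 3. ]; backend.where matches this
```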
```diff
@@ -1381,10 +1654,49 @@
         :rtype: Tensor
         """
         carry = init
-        for x in xs:
-            carry = f(carry, x)
+        # Check if `xs` is a PyTree (tuple or list) of arrays.
+        if isinstance(xs, (tuple, list)):
+            for x_slice_tuple in zip(*xs):
+                # x_slice_tuple will be (k_elems[i], j_elems[i]) at each step.
+                carry = f(carry, x_slice_tuple)
+        else:
+            # If xs is a single array, iterate normally.
+            for x in xs:
+                carry = f(carry, x)
+
         return carry
 
+    def jaxy_scan(
+        self: Any, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
+    ) -> Tensor:
+        """
+        This API follows the jax scan style. TF uses a plain for loop.
+
+        :param f: step function mapping ``(carry, x)`` to ``(new_carry, y)``
+        :type f: Callable[[Tensor, Tensor], Tensor]
+        :param init: the initial carry
+        :type init: Tensor
+        :param xs: the sequence scanned over along the leading axis
+        :type xs: Tensor
+        :raises ValueError: if ``xs`` is None
+        :return: the final carry and the stacked per-step outputs
+        :rtype: Tensor
+        """
+        if xs is None:
+            raise ValueError("Either xs or length must be provided.")
+        if xs is not None:
+            length = len(xs)
+        carry, outputs_to_stack = init, []
+        for i in range(length):
+            if isinstance(xs, (tuple, list)):
+                x = [ele[i] for ele in xs]
+            else:
+                x = xs[i]
+            new_carry, y = f(carry, x)
+            carry = new_carry
+            outputs_to_stack.append(y)
+        return carry, self.stack(outputs_to_stack)
+
     def stop_gradient(self: Any, a: Tensor) -> Tensor:
         """
         Stop backpropagation from ``a``.
```
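`jaxy_scan` mirrors the `(carry, ys)` contract of `jax.lax.scan`: the step function returns the next carry plus a per-step output, and the outputs are stacked along a new leading axis. A cumulative-sum sketch of the fallback loop above (`f` here is a hypothetical step function):

```python
import numpy as np

def f(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-step output)

carry, ys = 0, []
for x in np.array([1, 2, 3, 4]):
    carry, y = f(carry, x)
    ys.append(y)
print(carry, np.stack(ys))       # 10 [ 1  3  6 10]
```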
```diff
@@ -1548,7 +1860,7 @@
                     if i == argnum
                     else self.reshape(
                         self.zeros(
-                            [self.sizen(
+                            [self.sizen(args[argnum]), self.sizen(arg)],
                             dtype=arg.dtype,
                         ),
                         [-1] + list(self.shape_tuple(arg)),
```
```diff
@@ -1636,6 +1948,7 @@
                 ),
                 jj,
             )
+            jj = [jji for ind, jji in enumerate(jj) if ind in argnums]
             if len(jj) == 1:
                 jj = jj[0]
             jjs.append(jj)
```
```diff
@@ -1662,7 +1975,7 @@
         f: Callable[..., Any],
         static_argnums: Optional[Union[int, Sequence[int]]] = None,
         jit_compile: Optional[bool] = None,
-        **kws: Any
+        **kws: Any,
     ) -> Callable[..., Any]:
         """
         Return the jitted version of function ``f``.
```
tensorcircuit/backends/cupy_backend.py

```diff
@@ -56,10 +56,12 @@ class CuPyBackend(tnbackend, ExtendedBackend):  # type: ignore
         cpx = cupyx
         self.name = "cupy"
 
-    def convert_to_tensor(self, a: Tensor) -> Tensor:
+    def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
         if not isinstance(a, cp.ndarray) and not cp.isscalar(a):
             a = cp.array(a)
         a = cp.asarray(a)
+        if dtype is not None:
+            a = self.cast(a, dtype)
         return a
 
     def sum(
```
tensorcircuit/backends/jax_backend.py

```diff
@@ -50,12 +50,17 @@ class optax_optimizer:
         return params
 
 
-def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor:
+def _convert_to_tensor_jax(
+    self: Any, tensor: Tensor, dtype: Optional[str] = None
+) -> Tensor:
     if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor):
         raise TypeError(
             ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}")
         )
     result = jnp.asarray(tensor)
+    if dtype is not None:
+        # Use the backend's cast method to handle dtype conversion
+        result = self.cast(result, dtype)
     return result
 
 
```
```diff
@@ -170,6 +175,27 @@ def _eigh_jax(self: Any, tensor: Tensor) -> Tensor:
     return adaware_eigh(tensor)
 
 
+def bcsr_scalar_mul(self: Tensor, other: Tensor) -> Tensor:
+    """
+    Implements scalar multiplication for BCSR matrices (self * scalar).
+    """
+    import jax.numpy as jnp
+    from jax.experimental.sparse import BCSR
+
+    if jnp.isscalar(other):
+        # The core logic: only the data array is affected by scalar multiplication.
+        # The sparsity pattern (indices, indptr) remains the same.
+        new_data = self.data * other
+
+        # Return a new BCSR instance with the scaled data.
+        return BCSR((new_data, self.indices, self.indptr), shape=self.shape)
+
+    # For any other type of multiplication (e.g., element-wise with another matrix),
+    # return NotImplemented. This allows Python to try other operations,
+    # like other.__rmul__(self).
+    return NotImplemented
+
+
 tensornetwork.backends.jax.jax_backend.JaxBackend.convert_to_tensor = (
     _convert_to_tensor_jax
 )
```
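What the patched `__mul__` computes, by hand: only `.data` is scaled while `indices`/`indptr` are reused. A minimal sketch assuming a recent JAX with the experimental BCSR class:

```python
import jax.numpy as jnp
from jax.experimental import sparse

m = sparse.BCSR.fromdense(jnp.array([[0.0, 1.0], [2.0, 0.0]]))
m2 = sparse.BCSR((m.data * 2.0, m.indices, m.indptr), shape=m.shape)
print(m2.todense())  # [[0. 2.] [4. 0.]]
```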
```diff
@@ -219,6 +245,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
 
         self.name = "jax"
 
+        # --- Monkey-patch the BCSR class ---
+
+        sparse.BCSR.__mul__ = bcsr_scalar_mul  # type: ignore
+        sparse.BCSR.__rmul__ = bcsr_scalar_mul  # type: ignore
+
     # it is already child of numpy backend, and self.np = self.jax.np
     def eye(
         self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
```
```diff
@@ -243,8 +274,10 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def copy(self, tensor: Tensor) -> Tensor:
         return jnp.array(tensor, copy=True)
 
-    def convert_to_tensor(self, tensor: Tensor) -> Tensor:
+    def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
         result = jnp.asarray(tensor)
+        if dtype is not None:
+            result = self.cast(result, dtype)
         return result
 
     def abs(self, a: Tensor) -> Tensor:
```
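With the extra `dtype` argument, conversion and casting collapse into one call at the user level, e.g. (assuming the jax backend is installed and selected):

```python
import tensorcircuit as tc

K = tc.set_backend("jax")
a = K.convert_to_tensor([1, 2, 3], dtype="complex64")
print(a.dtype)  # complex64
```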
```diff
@@ -342,6 +375,15 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def mod(self, x: Tensor, y: Tensor) -> Tensor:
         return jnp.mod(x, y)
 
+    def floor(self, a: Tensor) -> Tensor:
+        return jnp.floor(a)
+
+    def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
+        return jnp.floor_divide(x, y)
+
+    def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
+        return jnp.clip(a, a_min, a_max)
+
     def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
         return jnp.right_shift(x, y)
 
```
```diff
@@ -387,6 +429,12 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return jnp.argmin(a, axis=axis)
 
+    def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return jnp.argsort(a, axis=axis)
+
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return jnp.sort(a, axis=axis)
+
     def unique_with_counts(  # type: ignore
         self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
     ) -> Tuple[Tensor, Tensor]:
```
```diff
@@ -407,6 +455,9 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
         return jnp.cumsum(a, axis)
 
+    def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        return jnp.all(a, axis=axis)
+
     def is_tensor(self, a: Any) -> bool:
         if not isinstance(a, jnp.ndarray):
             return False
```
```diff
@@ -418,6 +469,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor:  # type: ignore
         return jsp.linalg.solve(A, b, assume_a=assume_a)
 
+    def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
+        from .jax_ops import bessel_jv_jax_rescaled
+
+        return bessel_jv_jax_rescaled(v, z, M)
+
     def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
         if not self.is_tensor(a):
             a = self.convert_to_tensor(a)
```
```diff
@@ -442,7 +498,14 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def to_dlpack(self, a: Tensor) -> Any:
         import jax.dlpack
 
-        return jax.dlpack.to_dlpack(a)
+        try:
+            return jax.dlpack.to_dlpack(a)  # type: ignore
+        except AttributeError:  # jax >= v0.7
+            # jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.
+            # Use the newer DLPack API based on __dlpack__ and __dlpack_device__ instead:
+            # typically, a JAX array can be passed directly to the `from_dlpack` function
+            # of another framework without using `to_dlpack`.
+            return a.__dlpack__()
 
     def set_random_state(
         self, seed: Optional[Union[int, PRNGKeyArray]] = None, get_only: bool = False
```
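On newer JAX the fallback hands back the capsule from the array's own `__dlpack__`; consumers can usually skip `to_dlpack` entirely and pass the array straight to the other framework, e.g. (assuming PyTorch is installed):

```python
import jax.numpy as jnp
import torch

a = jnp.arange(4.0)
t = torch.from_dlpack(a)  # modern path: no explicit capsule needed
print(t)                  # tensor([0., 1., 2., 3.])
```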
```diff
@@ -608,7 +671,14 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
         carry, _ = libjax.lax.scan(f_jax, init, xs)
         return carry
 
+    def jaxy_scan(
+        self, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
+    ) -> Tensor:
+        return libjax.lax.scan(f, init, xs)
+
     def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
+        # updates = jnp.reshape(updates, indices.shape)
+        # return operand.at[indices].set(updates)
         rank = len(operand.shape)
         dnums = libjax.lax.ScatterDimensionNumbers(
             update_window_dims=(),
```
```diff
@@ -630,11 +700,20 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     ) -> Tensor:
         return sp_a @ b
 
+    def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
+        try:
+            return sparse.BCSR.from_bcoo(coo)  # type: ignore
+        except AttributeError as e:
+            if not strict:
+                return coo
+            else:
+                raise e
+
     def to_dense(self, sp_a: Tensor) -> Tensor:
         return sp_a.todense()
 
     def is_sparse(self, a: Tensor) -> bool:
-        return isinstance(a, sparse.BCOO)  # type: ignore
+        return isinstance(a, sparse.JAXSparse)  # type: ignore
 
     def device(self, a: Tensor) -> str:
         (dev,) = a.devices()
```
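Broadening `is_sparse` to the common `JAXSparse` base class is what lets the CSR matrices produced by `sparse_csr_from_coo` pass the sparsity check; assuming a JAX version that provides `BCSR.from_bcoo`:

```python
import jax.numpy as jnp
from jax.experimental import sparse

coo = sparse.BCOO.fromdense(jnp.eye(3))
csr = sparse.BCSR.from_bcoo(coo)
print(isinstance(coo, sparse.JAXSparse))  # True
print(isinstance(csr, sparse.JAXSparse))  # True
```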
```diff
@@ -781,4 +860,23 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
 
     vvag = vectorized_value_and_grad
 
+    def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Backend-agnostic meshgrid function.
+        """
+        return jnp.meshgrid(*args, **kwargs)
+
     optimizer = optax_optimizer
+
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return jnp.expand_dims(a, axis)
+
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            return jnp.where(condition)
+        return jnp.where(condition, x, y)
```