tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic. Click here for more details.
- tensorcircuit/__init__.py +5 -1
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +312 -5
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +92 -3
- tensorcircuit/backends/jax_ops.py +108 -0
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +123 -82
- tensorcircuit/circuit.py +67 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +1 -0
- tensorcircuit/densitymatrix.py +16 -11
- tensorcircuit/experimental.py +7 -152
- tensorcircuit/fgs.py +5 -6
- tensorcircuit/gates.py +66 -22
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +109 -61
- tensorcircuit/quantum.py +697 -133
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +45 -31
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +4 -2
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +29 -8
- tensorcircuit/templates/lattice.py +676 -335
- tensorcircuit/timeevol.py +896 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1035
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -409
- tests/test_circuit.py +0 -1713
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -318
- tests/test_gates.py +0 -156
- tests/test_hamiltonians.py +0 -159
- tests/test_interfaces.py +0 -557
- tests/test_keras.py +0 -160
- tests/test_lattice.py +0 -1666
- tests/test_miscs.py +0 -334
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -549
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -379
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_stabilizer.py +0 -226
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
|
@@ -9,6 +9,7 @@ from functools import reduce, partial
|
|
|
9
9
|
from operator import mul
|
|
10
10
|
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
|
11
11
|
|
|
12
|
+
import math
|
|
12
13
|
import numpy as np
|
|
13
14
|
from ..utils import return_partial
|
|
14
15
|
|
|
@@ -405,6 +406,31 @@ class ExtendedBackend:
|
|
|
405
406
|
a = self.reshape(a, [2 for _ in range(nleg)])
|
|
406
407
|
return a
|
|
407
408
|
|
|
409
|
+
def reshaped(self: Any, a: Tensor, d: int) -> Tensor:
    """
    Reshape a tensor to the [d, d, ...] shape.

    The number of legs ``n`` is the exact integer with ``d**n == size(a)``.
    An empty tensor is returned with shape ``(0,)``, and a single-element
    tensor becomes rank-0 (``n == 0``).

    :param a: Input tensor
    :type a: Tensor
    :param d: edge length for each dimension (Python or numpy integer >= 1)
    :type d: int
    :raises ValueError: if ``d`` is not a positive integer, or the size of
        ``a`` is not an integer power of ``d``
    :return: the reshaped tensor
    :rtype: Tensor
    """
    # bool is an int subclass but never a meaningful edge length; numpy
    # integers (e.g. taken from a shape) are accepted.
    if isinstance(d, bool) or not isinstance(d, (int, np.integer)) or d <= 0:
        raise ValueError("d must be a positive integer.")
    d = int(d)

    size = self.sizen(a)
    if size == 0:
        return self.reshape(a, (0,))

    # d == 1 can only describe a single-element tensor; without this guard the
    # division loop below would never terminate (and math.log(size, 1) divides
    # by zero).
    if d == 1 and size != 1:
        raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")

    # Count legs with exact integer arithmetic instead of a float logarithm,
    # which can misround for large sizes.
    nleg, rest = 0, size
    while rest > 1:
        rest, remainder = divmod(rest, d)
        if remainder:
            raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")
        nleg += 1

    return self.reshape(a, (d,) * nleg)
|
|
433
|
+
|
|
408
434
|
def reshapem(self: Any, a: Tensor) -> Tensor:
|
|
409
435
|
"""
|
|
410
436
|
Reshape a tensor to the [l, l] shape.
|
|
@@ -581,6 +607,97 @@ class ExtendedBackend:
|
|
|
581
607
|
"Backend '{}' has not implemented `argmin`.".format(self.name)
|
|
582
608
|
)
|
|
583
609
|
|
|
610
|
+
def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
    """
    Return the permutation indices that sort ``a`` along ``axis``.

    This is an abstract hook; concrete backends override it.

    :param a: tensor whose sorting permutation is requested
    :type a: Tensor
    :param axis: axis to sort along, defaults to -1 (the last axis)
    :type axis: int
    :return: integer indices that would sort ``a``
    :rtype: Tensor
    """
    message = f"Backend '{self.name}' has not implemented `argsort`."
    raise NotImplementedError(message)
|
|
624
|
+
|
|
625
|
+
def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
    """
    Sort a tensor along the given axis.

    This is an abstract hook; concrete backends override it.

    :param a: tensor to be sorted
    :type a: Tensor
    :param axis: axis along which to sort, defaults to -1 (the last axis)
    :type axis: int, optional
    :return: a tensor with the same shape as ``a`` whose entries are sorted
        in ascending order along ``axis``
    :rtype: Tensor
    :raises NotImplementedError: if the backend does not implement ``sort``
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `sort`.".format(self.name)
    )
|
|
639
|
+
|
|
640
|
+
def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """
    Logical-AND reduction: check whether every element of ``a`` is truthy.

    This is an abstract hook; concrete backends override it.

    :param a: tensor to reduce
    :type a: Tensor
    :param axis: axis or axes to reduce over; ``None`` (the default) reduces
        over all elements
    :type axis: Optional[Sequence[int]], optional
    :return: boolean scalar or boolean tensor holding the reduction result
    :rtype: Tensor
    """
    message = f"Backend '{self.name}' has not implemented `all`."
    raise NotImplementedError(message)
|
|
655
|
+
|
|
656
|
+
def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Return coordinate matrices from coordinate vectors.

    :param args: coordinate vectors
    :type args: Any
    :param kwargs: keyword arguments forwarded to the backend's meshgrid,
        typically ``indexing``:

        - ``'xy'``: Cartesian indexing. The first output dimension follows
          the *second* vector. This is the default of ``numpy.meshgrid``
          and ``jax.numpy.meshgrid``.
        - ``'ij'``: matrix indexing. The first output dimension follows
          the *first* vector.

        When ``indexing`` is omitted, the underlying library's default
        applies. Pass it explicitly for backend-independent results.
    :type kwargs: Any
    :return: list of coordinate matrices
    :rtype: Any
    :raises NotImplementedError: if the backend does not implement ``meshgrid``

    Example (``indexing='xy'``)::

        x, y = backend.meshgrid([0, 1], [0, 2], indexing="xy")
        # x.shape == y.shape == (2, 2); rows follow the second vector's length
        # x == [[0, 1],
        #       [0, 1]]
        # y == [[0, 0],
        #       [2, 2]]
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `meshgrid`.".format(self.name)
    )
|
|
683
|
+
|
|
684
|
+
def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
    """
    Insert a new length-1 axis into ``a`` at position ``axis``.

    This is an abstract hook; concrete backends override it.

    :param a: tensor to expand
    :type a: Tensor
    :param axis: index the new axis occupies in the output shape
    :type axis: int
    :return: tensor whose rank is one higher than ``a``
    :rtype: Tensor
    """
    message = f"Backend '{self.name}' has not implemented `expand_dims`."
    raise NotImplementedError(message)
|
|
700
|
+
|
|
584
701
|
def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
|
|
585
702
|
"""
|
|
586
703
|
Find the unique elements and their corresponding counts of the given tensor ``a``.
|
|
@@ -700,6 +817,9 @@ class ExtendedBackend:
|
|
|
700
817
|
"Backend '{}' has not implemented `is_tensor`.".format(self.name)
|
|
701
818
|
)
|
|
702
819
|
|
|
820
|
+
def matvec(self: Any, A: Tensor, x: Tensor) -> Tensor:
    """
    Matrix-vector product ``A @ x``, expressed as a tensordot.

    :param A: matrix; its second index is contracted
    :type A: Tensor
    :param x: vector (or tensor); its first index is contracted
    :type x: Tensor
    :return: the contraction of ``A``'s index 1 with ``x``'s index 0
    :rtype: Tensor
    """
    # pair A's column index with x's leading index
    contracted_axes = [[1], [0]]
    return self.tensordot(A, x, axes=contracted_axes)
|
|
822
|
+
|
|
703
823
|
def cast(self: Any, a: Tensor, dtype: str) -> Tensor:
|
|
704
824
|
"""
|
|
705
825
|
Cast the tensor dtype of a ``a``.
|
|
@@ -715,6 +835,21 @@ class ExtendedBackend:
|
|
|
715
835
|
"Backend '{}' has not implemented `cast`.".format(self.name)
|
|
716
836
|
)
|
|
717
837
|
|
|
838
|
+
def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Turn ``a`` into a native tensor of this backend, optionally casting it.

    This is an abstract hook; concrete backends override it.

    :param a: data to convert (array-like or tensor)
    :type a: Tensor
    :param dtype: optional target dtype name
    :type dtype: Optional[str]
    :return: backend tensor built from ``a``
    :rtype: Tensor
    """
    message = f"Backend '{self.name}' has not implemented `convert_to_tensor`."
    raise NotImplementedError(message)
|
|
852
|
+
|
|
718
853
|
def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
|
|
719
854
|
"""
|
|
720
855
|
Compute y-mod of x (negative number behavior is not guaranteed to be consistent)
|
|
@@ -730,6 +865,82 @@ class ExtendedBackend:
|
|
|
730
865
|
"Backend '{}' has not implemented `mod`.".format(self.name)
|
|
731
866
|
)
|
|
732
867
|
|
|
868
|
+
def floor_divide(self: Any, x: Tensor, y: Tensor) -> Tensor:
    r"""
    Element-wise floor division, i.e. the tensor analogue of Python's ``//``.

    Each output entry is ``floor(x[i] / y[i])``, rounding toward negative
    infinity. ``x`` and ``y`` broadcast according to the backend's rules.

    :param x: dividend tensor
    :type x: Tensor
    :param y: divisor tensor, broadcastable against ``x``
    :type y: Tensor
    :return: tensor of floored quotients with the broadcast shape of ``x`` and ``y``
    :rtype: Tensor

    :raises NotImplementedError: when the backend provides no ``floor_divide``.
    """
    message = f"Backend '{self.name}' has not implemented `floor_divide`."
    raise NotImplementedError(message)
|
|
895
|
+
|
|
896
|
+
def floor(self: Any, a: Tensor) -> Tensor:
    """
    Element-wise floor: round every entry of ``a`` down toward negative infinity.

    :param a: numeric input tensor
    :type a: Tensor
    :return: tensor shaped like ``a`` holding the largest integers not greater
        than the corresponding entries
    :rtype: Tensor

    :raises NotImplementedError: when the backend provides no ``floor``.
    """
    message = f"Backend '{self.name}' has not implemented `floor`."
    raise NotImplementedError(message)
|
|
916
|
+
|
|
917
|
+
def clip(self: Any, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """
    Limit every entry of ``a`` to the closed interval ``[a_min, a_max]``.

    Values below ``a_min`` become ``a_min``, values above ``a_max`` become
    ``a_max``, and everything in between is kept as-is.

    :param a: tensor to clip
    :type a: Tensor
    :param a_min: lower bound, a scalar or broadcastable to ``a``
    :type a_min: Tensor
    :param a_max: upper bound, a scalar or broadcastable to ``a``
    :type a_max: Tensor
    :return: tensor shaped like ``a`` with all values inside the bounds
    :rtype: Tensor

    :raises NotImplementedError: when the backend provides no ``clip``.
    """
    message = f"Backend '{self.name}' has not implemented `clip`."
    raise NotImplementedError(message)
|
|
943
|
+
|
|
733
944
|
def reverse(self: Any, a: Tensor) -> Tensor:
|
|
734
945
|
"""
|
|
735
946
|
return ``a[::-1]``, only 1D tensor is guaranteed for consistent behavior
|
|
@@ -805,6 +1016,21 @@ class ExtendedBackend:
|
|
|
805
1016
|
"Backend '{}' has not implemented `solve`.".format(self.name)
|
|
806
1017
|
)
|
|
807
1018
|
|
|
1019
|
+
def special_jv(self: Any, v: int, z: Tensor, M: int) -> Tensor:
    """
    Special function: Bessel functions of the first kind, for orders ``0`` to ``v - 1``.

    :param v: number of orders to evaluate. The result holds
        ``[J_0(z), J_1(z), ..., J_{v-1}(z)]``.
    :type v: int
    :param z: argument of the Bessel functions
    :type z: Tensor
    :param M: backend-specific accuracy/truncation parameter, forwarded unchanged
        to the backend implementation. NOTE(review): presumably the starting
        order of a backward recurrence; confirm against the concrete backend.
    :type M: int
    :return: the Bessel function values ``[J_0(z), ..., J_{v-1}(z)]``
    :rtype: Tensor
    :raises NotImplementedError: if the backend does not implement ``special_jv``
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `special_jv`.".format(self.name)
    )
|
|
1033
|
+
|
|
808
1034
|
def searchsorted(self: Any, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
|
|
809
1035
|
"""
|
|
810
1036
|
Find indices where elements should be inserted to maintain order.
|
|
@@ -813,8 +1039,8 @@ class ExtendedBackend:
|
|
|
813
1039
|
:type a: Tensor
|
|
814
1040
|
:param v: value to inserted
|
|
815
1041
|
:type v: Tensor
|
|
816
|
-
:param side: If
|
|
817
|
-
If
|
|
1042
|
+
:param side: If `left`, the index of the first suitable location found is given.
|
|
1043
|
+
If `right`, return the last such index.
|
|
818
1044
|
If there is no suitable index, return either 0 or N (where N is the length of a),
|
|
819
1045
|
defaults to "left"
|
|
820
1046
|
:type side: str, optional
|
|
@@ -1248,6 +1474,26 @@ class ExtendedBackend:
|
|
|
1248
1474
|
"Backend '{}' has not implemented `sparse_dense_matmul`.".format(self.name)
|
|
1249
1475
|
)
|
|
1250
1476
|
|
|
1477
|
+
def sparse_csr_from_coo(self: Any, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a COO sparse matrix into CSR format.

    This generic fallback performs no conversion.

    :param coo: sparse matrix in COO format
    :type coo: Tensor
    :param strict: when True, a missing backend implementation is an error.
        When False (the default), ``coo`` is handed back unchanged.
    :type strict: bool, optional
    :return: the CSR matrix (here: ``coo`` itself in non-strict mode)
    :rtype: Tensor
    """
    if not strict:
        # best-effort mode: keep the COO matrix as-is
        return coo
    message = f"Backend '{self.name}' has not implemented `sparse_csr_from_coo`."
    raise NotImplementedError(message)
|
|
1496
|
+
|
|
1251
1497
|
def to_dense(self: Any, sp_a: Tensor) -> Tensor:
|
|
1252
1498
|
"""
|
|
1253
1499
|
Convert a sparse matrix to dense tensor.
|
|
@@ -1351,6 +1597,28 @@ class ExtendedBackend:
|
|
|
1351
1597
|
"Backend '{}' has not implemented `cond`.".format(self.name)
|
|
1352
1598
|
)
|
|
1353
1599
|
|
|
1600
|
+
def where(
    self: Any,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    Select entries from ``x`` where ``condition`` holds and from ``y`` elsewhere.

    This is an abstract hook; concrete backends override it. Backends may also
    accept the single-argument form (``x`` and ``y`` both omitted).

    :param condition: boolean mask choosing between ``x`` (True) and ``y`` (False)
    :type condition: Tensor (bool)
    :param x: values taken where ``condition`` is True
    :type x: Tensor
    :param y: values taken where ``condition`` is False
    :type y: Tensor
    :return: element-wise selection of ``x`` and ``y`` driven by ``condition``
    :rtype: Tensor
    """
    message = f"Backend '{self.name}' has not implemented `where`."
    raise NotImplementedError(message)
|
|
1621
|
+
|
|
1354
1622
|
def switch(
|
|
1355
1623
|
self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]]
|
|
1356
1624
|
) -> Tensor:
|
|
@@ -1386,10 +1654,49 @@ class ExtendedBackend:
|
|
|
1386
1654
|
:rtype: Tensor
|
|
1387
1655
|
"""
|
|
1388
1656
|
carry = init
|
|
1389
|
-
|
|
1390
|
-
|
|
1657
|
+
# Check if `xs` is a PyTree (tuple or list) of arrays.
|
|
1658
|
+
if isinstance(xs, (tuple, list)):
|
|
1659
|
+
for x_slice_tuple in zip(*xs):
|
|
1660
|
+
# x_slice_tuple will be (k_elems[i], j_elems[i]) at each step.
|
|
1661
|
+
carry = f(carry, x_slice_tuple)
|
|
1662
|
+
else:
|
|
1663
|
+
# If xs is a single array, iterate normally.
|
|
1664
|
+
for x in xs:
|
|
1665
|
+
carry = f(carry, x)
|
|
1666
|
+
|
|
1391
1667
|
return carry
|
|
1392
1668
|
|
|
1669
|
+
def jaxy_scan(
    self: Any, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
) -> Tensor:
    """
    Scan ``f`` over the leading axis of ``xs``, following ``jax.lax.scan`` conventions.

    This generic fallback runs a plain Python loop. ``f`` is called as
    ``f(carry, x)`` and must return ``(new_carry, y)``. The per-step outputs
    ``y`` are stacked along a new leading axis.

    :param f: step function ``(carry, x) -> (new_carry, y)``
    :type f: Callable[[Tensor, Tensor], Tensor]
    :param init: initial carry
    :type init: Tensor
    :param xs: tensor scanned along its leading axis, or a tuple/list of
        tensors sharing the same leading dimension. In the latter case ``f``
        receives a tuple (for tuple ``xs``) or list (for list ``xs``) holding
        the i-th slice of every tensor.
    :type xs: Tensor
    :raises ValueError: if ``xs`` is None or an empty tuple/list, or its
        tensors have different leading dimensions
    :return: ``(final_carry, ys)``. When ``y`` is a tuple/list, each component
        is stacked separately; when ``y`` is None, ``ys`` is None.
    :rtype: Tensor
    """
    if xs is None:
        raise ValueError("`xs` must be provided.")

    multi = isinstance(xs, (tuple, list))
    if multi:
        if len(xs) == 0:
            raise ValueError("`xs` must contain at least one tensor.")
        # The scan length is the leading dimension of the tensors, not the
        # number of tensors held in the container.
        length = len(xs[0])
        if any(len(ele) != length for ele in xs):
            raise ValueError(
                "all tensors in `xs` must share the same leading dimension."
            )
    else:
        length = len(xs)

    def _stack_outputs(outputs: List[Any]) -> Any:
        # Mirror jax.lax.scan: stack every component of a tuple/list output
        # separately, and propagate None outputs unchanged.
        if outputs and outputs[0] is None:
            return None
        if outputs and isinstance(outputs[0], (tuple, list)):
            parts = [
                _stack_outputs([out[k] for out in outputs])
                for k in range(len(outputs[0]))
            ]
            return tuple(parts) if isinstance(outputs[0], tuple) else parts
        return self.stack(outputs)

    carry, outputs = init, []
    for i in range(length):
        if not multi:
            x = xs[i]
        elif isinstance(xs, tuple):
            x = tuple(ele[i] for ele in xs)
        else:
            x = [ele[i] for ele in xs]
        carry, y = f(carry, x)
        outputs.append(y)
    return carry, _stack_outputs(outputs)
|
|
1699
|
+
|
|
1393
1700
|
def stop_gradient(self: Any, a: Tensor) -> Tensor:
|
|
1394
1701
|
"""
|
|
1395
1702
|
Stop backpropagation from ``a``.
|
|
@@ -1668,7 +1975,7 @@ class ExtendedBackend:
|
|
|
1668
1975
|
f: Callable[..., Any],
|
|
1669
1976
|
static_argnums: Optional[Union[int, Sequence[int]]] = None,
|
|
1670
1977
|
jit_compile: Optional[bool] = None,
|
|
1671
|
-
**kws: Any
|
|
1978
|
+
**kws: Any,
|
|
1672
1979
|
) -> Callable[..., Any]:
|
|
1673
1980
|
"""
|
|
1674
1981
|
Return the jitted version of function ``f``.
|
|
@@ -56,10 +56,12 @@ class CuPyBackend(tnbackend, ExtendedBackend): # type: ignore
|
|
|
56
56
|
cpx = cupyx
|
|
57
57
|
self.name = "cupy"
|
|
58
58
|
|
|
59
|
-
def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert ``a`` into a CuPy array, optionally casting it to ``dtype``.

    :param a: array-like, CuPy array or scalar
    :type a: Tensor
    :param dtype: optional target dtype, applied through ``self.cast``
    :type dtype: Optional[str]
    :return: the converted CuPy array
    :rtype: Tensor
    """
    # anything that is neither a CuPy array nor a scalar is materialised first
    needs_array = not (isinstance(a, cp.ndarray) or cp.isscalar(a))
    if needs_array:
        a = cp.array(a)
    converted = cp.asarray(a)
    return converted if dtype is None else self.cast(converted, dtype)
|
|
64
66
|
|
|
65
67
|
def sum(
|
|
@@ -50,12 +50,17 @@ class optax_optimizer:
|
|
|
50
50
|
return params
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
def _convert_to_tensor_jax(
|
|
53
|
+
def _convert_to_tensor_jax(
|
|
54
|
+
self: Any, tensor: Tensor, dtype: Optional[str] = None
|
|
55
|
+
) -> Tensor:
|
|
54
56
|
if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor):
|
|
55
57
|
raise TypeError(
|
|
56
58
|
("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}")
|
|
57
59
|
)
|
|
58
60
|
result = jnp.asarray(tensor)
|
|
61
|
+
if dtype is not None:
|
|
62
|
+
# Use the backend's cast method to handle dtype conversion
|
|
63
|
+
result = self.cast(result, dtype)
|
|
59
64
|
return result
|
|
60
65
|
|
|
61
66
|
|
|
@@ -170,6 +175,27 @@ def _eigh_jax(self: Any, tensor: Tensor) -> Tensor:
|
|
|
170
175
|
return adaware_eigh(tensor)
|
|
171
176
|
|
|
172
177
|
|
|
178
|
+
def bcsr_scalar_mul(self: Tensor, other: Tensor) -> Tensor:
    """
    Multiply a BCSR matrix by a scalar (``self * scalar`` or ``scalar * self``).

    Non-scalar operands yield ``NotImplemented``, so Python falls back to the
    other operand's reflected method (e.g. ``other.__rmul__(self)``).
    """
    import jax.numpy as jnp
    from jax.experimental.sparse import BCSR

    if not jnp.isscalar(other):
        return NotImplemented

    # Scaling only touches the stored values; the sparsity pattern
    # (indices, indptr) and the shape carry over unchanged.
    scaled_values = self.data * other
    return BCSR((scaled_values, self.indices, self.indptr), shape=self.shape)
|
|
197
|
+
|
|
198
|
+
|
|
173
199
|
tensornetwork.backends.jax.jax_backend.JaxBackend.convert_to_tensor = (
|
|
174
200
|
_convert_to_tensor_jax
|
|
175
201
|
)
|
|
@@ -219,6 +245,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
219
245
|
|
|
220
246
|
self.name = "jax"
|
|
221
247
|
|
|
248
|
+
# --- Monkey-patch the BCSR class ---
|
|
249
|
+
|
|
250
|
+
sparse.BCSR.__mul__ = bcsr_scalar_mul # type: ignore
|
|
251
|
+
sparse.BCSR.__rmul__ = bcsr_scalar_mul # type: ignore
|
|
252
|
+
|
|
222
253
|
# it is already child of numpy backend, and self.np = self.jax.np
|
|
223
254
|
def eye(
|
|
224
255
|
self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
|
|
@@ -243,8 +274,10 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
243
274
|
def copy(self, tensor: Tensor) -> Tensor:
|
|
244
275
|
return jnp.array(tensor, copy=True)
|
|
245
276
|
|
|
246
|
-
def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Wrap ``tensor`` as a JAX array, casting it via ``self.cast`` when ``dtype`` is given.

    :param tensor: array-like input
    :type tensor: Tensor
    :param dtype: optional target dtype
    :type dtype: Optional[str]
    :return: the JAX array
    :rtype: Tensor
    """
    array = jnp.asarray(tensor)
    return array if dtype is None else self.cast(array, dtype)
|
|
249
282
|
|
|
250
283
|
def abs(self, a: Tensor) -> Tensor:
|
|
@@ -342,6 +375,15 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
342
375
|
def mod(self, x: Tensor, y: Tensor) -> Tensor:
|
|
343
376
|
return jnp.mod(x, y)
|
|
344
377
|
|
|
378
|
+
def floor(self, a: Tensor) -> Tensor:
    """Round each entry of ``a`` down toward negative infinity (``jnp.floor``)."""
    floored = jnp.floor(a)
    return floored
|
|
380
|
+
|
|
381
|
+
def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise ``x // y`` with rounding toward negative infinity (``jnp.floor_divide``)."""
    quotient = jnp.floor_divide(x, y)
    return quotient
|
|
383
|
+
|
|
384
|
+
def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """Limit the entries of ``a`` to ``[a_min, a_max]`` (``jnp.clip``)."""
    bounded = jnp.clip(a, a_min, a_max)
    return bounded
|
|
386
|
+
|
|
345
387
|
def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
|
|
346
388
|
return jnp.right_shift(x, y)
|
|
347
389
|
|
|
@@ -387,6 +429,12 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
387
429
|
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
|
|
388
430
|
return jnp.argmin(a, axis=axis)
|
|
389
431
|
|
|
432
|
+
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Indices that sort ``a`` along ``axis`` (``jnp.argsort``)."""
    order = jnp.argsort(a, axis=axis)
    return order
|
|
434
|
+
|
|
435
|
+
def sort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Sort ``a`` in ascending order along ``axis`` (``jnp.sort``)."""
    ordered = jnp.sort(a, axis=axis)
    return ordered
|
|
437
|
+
|
|
390
438
|
def unique_with_counts( # type: ignore
|
|
391
439
|
self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
|
|
392
440
|
) -> Tuple[Tensor, Tensor]:
|
|
@@ -407,6 +455,9 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
407
455
|
def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
|
|
408
456
|
return jnp.cumsum(a, axis)
|
|
409
457
|
|
|
458
|
+
def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """Logical-AND reduction of ``a`` over ``axis`` (all elements if None), via ``jnp.all``."""
    reduced = jnp.all(a, axis=axis)
    return reduced
|
|
460
|
+
|
|
410
461
|
def is_tensor(self, a: Any) -> bool:
|
|
411
462
|
if not isinstance(a, jnp.ndarray):
|
|
412
463
|
return False
|
|
@@ -418,6 +469,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
418
469
|
def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor: # type: ignore
|
|
419
470
|
return jsp.linalg.solve(A, b, assume_a=assume_a)
|
|
420
471
|
|
|
472
|
+
def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
    """
    Bessel functions of the first kind, delegated to ``jax_ops.bessel_jv_jax_rescaled``.

    ``v``, ``z`` and ``M`` are forwarded unchanged to that helper.
    """
    # imported at call time rather than at module load
    from .jax_ops import bessel_jv_jax_rescaled as _rescaled_jv

    return _rescaled_jv(v, z, M)
|
|
476
|
+
|
|
421
477
|
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
|
|
422
478
|
if not self.is_tensor(a):
|
|
423
479
|
a = self.convert_to_tensor(a)
|
|
@@ -615,6 +671,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
615
671
|
carry, _ = libjax.lax.scan(f_jax, init, xs)
|
|
616
672
|
return carry
|
|
617
673
|
|
|
674
|
+
def jaxy_scan(
    self, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
) -> Tensor:
    """
    Native ``jax.lax.scan``: returns ``(final_carry, stacked_outputs)``.

    ``f`` maps ``(carry, x)`` to ``(new_carry, y)``.
    """
    scan = libjax.lax.scan
    return scan(f, init, xs)
|
|
678
|
+
|
|
618
679
|
def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
|
|
619
680
|
# updates = jnp.reshape(updates, indices.shape)
|
|
620
681
|
# return operand.at[indices].set(updates)
|
|
@@ -639,11 +700,20 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
639
700
|
) -> Tensor:
|
|
640
701
|
return sp_a @ b
|
|
641
702
|
|
|
703
|
+
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a BCOO matrix to BCSR via ``sparse.BCSR.from_bcoo``.

    If that conversion raises ``AttributeError``, the error is re-raised in
    ``strict`` mode; otherwise ``coo`` is returned unchanged.
    """
    try:
        csr = sparse.BCSR.from_bcoo(coo)  # type: ignore
    except AttributeError:
        if strict:
            raise
        return coo
    return csr
|
|
711
|
+
|
|
642
712
|
def to_dense(self, sp_a: Tensor) -> Tensor:
|
|
643
713
|
return sp_a.todense()
|
|
644
714
|
|
|
645
715
|
def is_sparse(self, a: Tensor) -> bool:
    """Return True when ``a`` is any JAX sparse matrix (a ``sparse.JAXSparse`` instance)."""
    sparse_base = sparse.JAXSparse  # type: ignore
    return isinstance(a, sparse_base)
|
|
647
717
|
|
|
648
718
|
def device(self, a: Tensor) -> str:
|
|
649
719
|
(dev,) = a.devices()
|
|
@@ -790,4 +860,23 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
|
|
|
790
860
|
|
|
791
861
|
vvag = vectorized_value_and_grad
|
|
792
862
|
|
|
863
|
+
def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
    """
    Coordinate grids from 1-D coordinate vectors.

    All positional and keyword arguments (e.g. ``indexing``) are passed
    straight to ``jnp.meshgrid``.
    """
    grids = jnp.meshgrid(*args, **kwargs)
    return grids
|
|
868
|
+
|
|
793
869
|
optimizer = optax_optimizer
|
|
870
|
+
|
|
871
|
+
def expand_dims(self, a: Tensor, axis: int) -> Tensor:
    """Insert a length-1 axis into ``a`` at position ``axis`` (``jnp.expand_dims``)."""
    expanded = jnp.expand_dims(a, axis)
    return expanded
|
|
873
|
+
|
|
874
|
+
def where(
    self,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    ``jnp.where`` in its three-argument (select) or one-argument (nonzero) form.

    If both ``x`` and ``y`` are None, ``jnp.where(condition)`` is returned.
    Otherwise ``jnp.where(condition, x, y)`` is returned.
    """
    if x is not None or y is not None:
        return jnp.where(condition, x, y)
    # single-argument form: indices of the True entries
    return jnp.where(condition)
|