tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic. Click here for more details.

Files changed (76)
  1. tensorcircuit/__init__.py +18 -2
  2. tensorcircuit/about.py +46 -0
  3. tensorcircuit/abstractcircuit.py +4 -0
  4. tensorcircuit/analogcircuit.py +413 -0
  5. tensorcircuit/applications/layers.py +1 -1
  6. tensorcircuit/applications/van.py +1 -1
  7. tensorcircuit/backends/abstract_backend.py +320 -7
  8. tensorcircuit/backends/cupy_backend.py +3 -1
  9. tensorcircuit/backends/jax_backend.py +102 -4
  10. tensorcircuit/backends/jax_ops.py +110 -1
  11. tensorcircuit/backends/numpy_backend.py +49 -3
  12. tensorcircuit/backends/pytorch_backend.py +92 -3
  13. tensorcircuit/backends/tensorflow_backend.py +102 -3
  14. tensorcircuit/basecircuit.py +157 -98
  15. tensorcircuit/circuit.py +115 -57
  16. tensorcircuit/cloud/local.py +1 -1
  17. tensorcircuit/cloud/quafu_provider.py +1 -1
  18. tensorcircuit/cloud/tencent.py +1 -1
  19. tensorcircuit/compiler/simple_compiler.py +2 -2
  20. tensorcircuit/cons.py +142 -21
  21. tensorcircuit/densitymatrix.py +43 -14
  22. tensorcircuit/experimental.py +387 -129
  23. tensorcircuit/fgs.py +282 -81
  24. tensorcircuit/gates.py +66 -22
  25. tensorcircuit/interfaces/__init__.py +1 -3
  26. tensorcircuit/interfaces/jax.py +189 -0
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +868 -152
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +147 -20
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +479 -0
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
  45. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1031
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -365
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -429
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -277
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -526
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -347
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_templates.py +0 -218
  74. tests/test_torchnn.py +0 -99
  75. tests/test_van.py +0 -102
  76. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
@@ -9,6 +9,7 @@ from functools import reduce, partial
9
9
  from operator import mul
10
10
  from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
11
11
 
12
+ import math
12
13
  import numpy as np
13
14
  from ..utils import return_partial
14
15
 
@@ -46,17 +47,22 @@ class ExtendedBackend:
46
47
  "Backend '{}' has not implemented `expm`.".format(self.name)
47
48
  )
48
49
 
49
def sqrtmh(self: Any, a: Tensor, psd: bool = False) -> Tensor:
    """
    Compute the matrix square root of a Hermitian matrix ``a``.

    The root is assembled from the eigendecomposition ``a = V diag(e) V^dagger``
    as ``V diag(sqrt(e)) V^dagger``. This avoids relying on a native ``sqrtm``,
    which several backends lack, and keeps the computation AD friendly.

    :param a: Hermitian tensor in matrix form
    :type a: Tensor
    :param psd: set to True when ``a`` is known to be positive semidefinite;
        the eigenvalues are then passed through ``relu`` before the square
        root, defaults False
    :type psd: bool
    :return: the square root of ``a``
    :rtype: Tensor
    """
    eigvals, eigvecs = self.eigh(a)
    # only clamp negative round-off eigenvalues when the caller guarantees PSD
    root = self.sqrt(self.relu(eigvals) if psd else eigvals)
    return eigvecs @ self.diagflat(root) @ self.adjoint(eigvecs)
62
68
 
@@ -400,6 +406,31 @@ class ExtendedBackend:
400
406
  a = self.reshape(a, [2 for _ in range(nleg)])
401
407
  return a
402
408
 
409
def reshaped(self: Any, a: Tensor, d: int) -> Tensor:
    """
    Reshape a tensor to the [d, d, ...] shape.

    The number of legs ``n`` is chosen such that ``d**n`` equals the total
    number of elements of ``a``. A single-element tensor therefore becomes a
    0-dimensional tensor, and an empty tensor is reshaped to ``(0,)``.

    :param a: Input tensor
    :type a: Tensor
    :param d: edge length for each dimension, a positive integer
    :type d: int
    :raises ValueError: if ``d`` is not a positive integer, or if the size of
        ``a`` is not an integer power of ``d``
    :return: the reshaped tensor
    :rtype: Tensor
    """
    if not isinstance(d, int) or d <= 0:
        raise ValueError("d must be a positive integer.")

    size = int(self.sizen(a))
    if size == 0:
        return self.reshape(a, (0,))

    if d == 1:
        # log base 1 is undefined (math.log(size, 1) divides by zero);
        # the only power of 1 is 1 itself
        if size != 1:
            raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")
        return self.reshape(a, ())

    # exact integer logarithm: no float rounding issues for large sizes
    nleg, rest = 0, size
    while rest % d == 0:
        rest //= d
        nleg += 1
    if rest != 1:
        raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")

    return self.reshape(a, (d,) * nleg)
433
+
403
434
  def reshapem(self: Any, a: Tensor) -> Tensor:
404
435
  """
405
436
  Reshape a tensor to the [l, l] shape.
@@ -576,6 +607,97 @@ class ExtendedBackend:
576
607
  "Backend '{}' has not implemented `argmin`.".format(self.name)
577
608
  )
578
609
 
610
def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
    """
    Return the permutation indices that sort ``a`` along ``axis``.

    Concrete backends are expected to override this; the base version raises.

    :param a: the tensor whose sorting permutation is requested
    :type a: Tensor
    :param axis: the axis to sort along, defaults to -1
    :type axis: int
    :raises NotImplementedError: always, in this base class
    :return: the indices that sort ``a``
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `argsort`.".format(self.name)
    raise NotImplementedError(message)
624
+
625
def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
    """
    Sort a tensor along the given axis.

    :param a: the tensor to be sorted
    :type a: Tensor
    :param axis: the axis along which to sort, defaults to -1 (the last axis)
    :type axis: int, optional
    :raises NotImplementedError: in this base class; concrete backends
        override this method
    :return: a tensor with the entries of ``a`` sorted along ``axis``
    :rtype: Tensor
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `sort`.".format(self.name)
    )
639
+
640
def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """
    Logical AND reduction: check whether every element of ``a`` is truthy.

    :param a: Input tensor
    :type a: Tensor
    :param axis: axis or axes to reduce over; None reduces over all axes,
        defaults to None
    :type axis: Optional[Sequence[int]], optional
    :raises NotImplementedError: always, in this base class
    :return: a boolean scalar or tensor holding the reduction result
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `all`.".format(self.name)
    raise NotImplementedError(message)
655
+
656
def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Return coordinate matrices from coordinate vectors.

    :param args: coordinate vectors
    :type args: Any
    :param kwargs: keyword arguments forwarded to the backend's meshgrid,
        typically ``indexing``. For two vectors of lengths ``m`` and ``n``:

        - ``'xy'``: Cartesian indexing, outputs have shape ``(n, m)``
        - ``'ij'``: matrix indexing, outputs have shape ``(m, n)``

        When ``indexing`` is omitted, the default of the backend's native
        meshgrid applies (``'xy'`` for NumPy/JAX). Pass it explicitly for
        backend-independent results.
    :type kwargs: Any
    :raises NotImplementedError: in this base class
    :return: list of coordinate matrices
    :rtype: Any

    Example with ``indexing="xy"``::

        x, y = backend.meshgrid([0, 1, 2], [0, 2], indexing="xy")
        # x.shape == y.shape == (2, 3)
        # x = [[0, 1, 2],
        #      [0, 1, 2]]
        # y = [[0, 0, 0],
        #      [2, 2, 2]]
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `meshgrid`.".format(self.name)
    )
683
+
684
def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
    """
    Insert a new length-1 axis into ``a`` at position ``axis``.

    :param a: Input tensor
    :type a: Tensor
    :param axis: index in the output shape where the new axis appears
    :type axis: int
    :raises NotImplementedError: always, in this base class
    :return: a tensor with one more dimension than ``a``
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `expand_dims`.".format(self.name)
    raise NotImplementedError(message)
700
+
579
701
  def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
580
702
  """
581
703
  Find the unique elements and their corresponding counts of the given tensor ``a``.
@@ -695,6 +817,9 @@ class ExtendedBackend:
695
817
  "Backend '{}' has not implemented `is_tensor`.".format(self.name)
696
818
  )
697
819
 
820
def matvec(self: Any, A: Tensor, x: Tensor) -> Tensor:
    """
    Matrix-vector product expressed with ``tensordot``.

    Axis 1 of ``A`` is contracted with axis 0 of ``x``.

    :param A: the matrix
    :type A: Tensor
    :param x: the vector (its axis 0 must match axis 1 of ``A``)
    :type x: Tensor
    :return: the contracted tensor
    :rtype: Tensor
    """
    contracted_axes = [[1], [0]]
    return self.tensordot(A, x, axes=contracted_axes)
822
+
698
823
  def cast(self: Any, a: Tensor, dtype: str) -> Tensor:
699
824
  """
700
825
  Cast the tensor dtype of a ``a``.
@@ -710,6 +835,21 @@ class ExtendedBackend:
710
835
  "Backend '{}' has not implemented `cast`.".format(self.name)
711
836
  )
712
837
 
838
def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Turn ``a`` into a backend-native tensor.

    :param a: data to be converted
    :type a: Tensor
    :param dtype: optional target dtype
    :type dtype: Optional[str]
    :raises NotImplementedError: always, in this base class
    :return: the converted tensor
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `convert_to_tensor`.".format(
        self.name
    )
    raise NotImplementedError(message)
852
+
713
853
  def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
714
854
  """
715
855
  Compute y-mod of x (negative number behavior is not guaranteed to be consistent)
@@ -725,6 +865,82 @@ class ExtendedBackend:
725
865
  "Backend '{}' has not implemented `mod`.".format(self.name)
726
866
  )
727
867
 
868
def floor_divide(self: Any, x: Tensor, y: Tensor) -> Tensor:
    r"""
    Element-wise floor division, with the same semantics as Python's ``//``.

    Each output element is ``floor(x[i] / y[i])``, i.e. the quotient rounded
    towards negative infinity. ``x`` and ``y`` broadcast according to the
    backend's rules.

    :param x: Dividend tensor.
    :type x: Tensor
    :param y: Divisor tensor, broadcastable with ``x``.
    :type y: Tensor
    :return: the floored quotients, with the broadcast shape of ``x`` and ``y``.
    :rtype: Tensor
    :raises NotImplementedError: always, in this base class.
    """
    message = "Backend '{}' has not implemented `floor_divide`.".format(self.name)
    raise NotImplementedError(message)
895
+
896
def floor(self: Any, a: Tensor) -> Tensor:
    """
    Round every element of ``a`` down towards negative infinity.

    :param a: Input tensor containing numeric values.
    :type a: Tensor
    :return: a tensor shaped like ``a`` holding the floored values.
    :rtype: Tensor
    :raises NotImplementedError: always, in this base class.
    """
    message = "Backend '{}' has not implemented `floor`.".format(self.name)
    raise NotImplementedError(message)
916
+
917
def clip(self: Any, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """
    Limit the values of ``a`` element-wise to the interval ``[a_min, a_max]``.

    Values below ``a_min`` become ``a_min`` and values above ``a_max`` become
    ``a_max``; everything else is left unchanged. Implementations are expected
    to keep the dtype and device of ``a``.

    :param a: Input tensor containing values to be clipped.
    :type a: Tensor
    :param a_min: Lower bound, a scalar or a tensor broadcastable to ``a``.
    :type a_min: Tensor
    :param a_max: Upper bound, a scalar or a tensor broadcastable to ``a``.
    :type a_max: Tensor
    :return: a tensor shaped like ``a`` with all values inside the bounds.
    :rtype: Tensor
    :raises NotImplementedError: always, in this base class.
    """
    message = "Backend '{}' has not implemented `clip`.".format(self.name)
    raise NotImplementedError(message)
943
+
728
944
  def reverse(self: Any, a: Tensor) -> Tensor:
729
945
  """
730
946
  return ``a[::-1]``, only 1D tensor is guaranteed for consistent behavior
@@ -800,6 +1016,21 @@ class ExtendedBackend:
800
1016
  "Backend '{}' has not implemented `solve`.".format(self.name)
801
1017
  )
802
1018
 
1019
def special_jv(self: Any, v: int, z: Tensor, M: int) -> Tensor:
    """
    Special function: Bessel function of the first kind.

    NOTE(review): ``v`` is documented below as the number of orders because the
    return value is ``[J_0(z), ..., J_{v-1}(z)]``. Confirm this against the
    backend implementations (e.g. ``bessel_jv_jax_rescaled`` for JAX).

    :param v: the number of Bessel orders to evaluate (see note above)
    :type v: int
    :param z: The argument of the Bessel function.
    :type z: Tensor
    :param M: backend-specific integer forwarded to the implementation (the
        JAX backend passes it to ``bessel_jv_jax_rescaled``); presumably a
        truncation/recurrence size -- TODO confirm
    :type M: int
    :raises NotImplementedError: in this base class
    :return: The values ``[J_0(z), ..., J_{v-1}(z)]``.
    :rtype: Tensor
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `special_jv`.".format(self.name)
    )
1033
+
803
1034
  def searchsorted(self: Any, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
804
1035
  """
805
1036
  Find indices where elements should be inserted to maintain order.
@@ -808,8 +1039,8 @@ class ExtendedBackend:
808
1039
  :type a: Tensor
809
1040
  :param v: value to inserted
810
1041
  :type v: Tensor
811
- :param side: If left’, the index of the first suitable location found is given.
812
- If right’, return the last such index.
1042
+ :param side: If `left`, the index of the first suitable location found is given.
1043
+ If `right`, return the last such index.
813
1044
  If there is no suitable index, return either 0 or N (where N is the length of a),
814
1045
  defaults to "left"
815
1046
  :type side: str, optional
@@ -1243,6 +1474,26 @@ class ExtendedBackend:
1243
1474
  "Backend '{}' has not implemented `sparse_dense_matmul`.".format(self.name)
1244
1475
  )
1245
1476
 
1477
def sparse_csr_from_coo(self: Any, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a COO sparse matrix into CSR format.

    The base class has no conversion. By default it hands the COO input back
    unchanged; with ``strict=True`` it raises instead.

    :param coo: a coo matrix
    :type coo: Tensor
    :param strict: raise when the backend cannot convert instead of returning
        ``coo`` unchanged, defaults to False
    :type strict: bool, optional
    :raises NotImplementedError: when ``strict`` is True
    :return: a csr matrix (or ``coo`` itself when not strict)
    :rtype: Tensor
    """
    if not strict:
        # best effort: fall back to the unconverted COO matrix
        return coo
    raise NotImplementedError(
        "Backend '{}' has not implemented `sparse_csr_from_coo`.".format(self.name)
    )
1496
+
1246
1497
  def to_dense(self: Any, sp_a: Tensor) -> Tensor:
1247
1498
  """
1248
1499
  Convert a sparse matrix to dense tensor.
@@ -1346,6 +1597,28 @@ class ExtendedBackend:
1346
1597
  "Backend '{}' has not implemented `cond`.".format(self.name)
1347
1598
  )
1348
1599
 
1600
def where(
    self: Any,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    Select elements from ``x`` where ``condition`` holds and from ``y`` elsewhere.

    ``x`` and ``y`` default to None. How a backend treats the condition-only
    call is backend specific.

    :param condition: boolean selector tensor
    :type condition: Tensor (bool)
    :param x: values taken where ``condition`` is True
    :type x: Tensor
    :param y: values taken where ``condition`` is False
    :type y: Tensor
    :raises NotImplementedError: always, in this base class
    :return: the element-wise selection
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `where`.".format(self.name)
    raise NotImplementedError(message)
1621
+
1349
1622
  def switch(
1350
1623
  self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]]
1351
1624
  ) -> Tensor:
@@ -1381,10 +1654,49 @@ class ExtendedBackend:
1381
1654
  :rtype: Tensor
1382
1655
  """
1383
1656
  carry = init
1384
- for x in xs:
1385
- carry = f(carry, x)
1657
+ # Check if `xs` is a PyTree (tuple or list) of arrays.
1658
+ if isinstance(xs, (tuple, list)):
1659
+ for x_slice_tuple in zip(*xs):
1660
+ # x_slice_tuple will be (k_elems[i], j_elems[i]) at each step.
1661
+ carry = f(carry, x_slice_tuple)
1662
+ else:
1663
+ # If xs is a single array, iterate normally.
1664
+ for x in xs:
1665
+ carry = f(carry, x)
1666
+
1386
1667
  return carry
1387
1668
 
1669
def jaxy_scan(
    self: Any, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
) -> Tensor:
    """
    Scan ``f`` over the leading axis of ``xs`` following the ``jax.lax.scan`` API.

    This fallback uses a plain Python loop. At step ``i`` it evaluates
    ``carry, y = f(carry, x)``, where ``x`` is ``xs[i]``. When ``xs`` is a tuple
    or list of tensors, ``x`` is instead a list of the i-th slice of each of
    them, and all of them must share the same leading length.

    :param f: step function mapping ``(carry, x)`` to ``(new_carry, y)``
    :type f: Callable[[Tensor, Tensor], Tensor]
    :param init: initial carry
    :type init: Tensor
    :param xs: a tensor, or a tuple/list of tensors, scanned along axis 0
    :type xs: Tensor
    :raises ValueError: if ``xs`` is None, or if the tensors of a tuple/list
        ``xs`` have different leading lengths
    :return: the final carry and the per-step outputs stacked via ``self.stack``
    :rtype: Tensor
    """
    if xs is None:
        raise ValueError("xs must be provided.")

    multi = isinstance(xs, (tuple, list))
    if multi:
        # the number of steps is the leaves' common leading length,
        # not the number of leaves (len(xs))
        lengths = {len(leaf) for leaf in xs}
        if len(lengths) > 1:
            raise ValueError(
                "all tensors in xs must share the same leading length, "
                "got {}".format(sorted(lengths))
            )
        length = lengths.pop() if lengths else 0
    else:
        length = len(xs)

    carry, outputs_to_stack = init, []
    for i in range(length):
        x = [leaf[i] for leaf in xs] if multi else xs[i]
        carry, y = f(carry, x)
        outputs_to_stack.append(y)
    return carry, self.stack(outputs_to_stack)
1699
+
1388
1700
  def stop_gradient(self: Any, a: Tensor) -> Tensor:
1389
1701
  """
1390
1702
  Stop backpropagation from ``a``.
@@ -1548,7 +1860,7 @@ class ExtendedBackend:
1548
1860
  if i == argnum
1549
1861
  else self.reshape(
1550
1862
  self.zeros(
1551
- [self.sizen(arg), self.sizen(arg)],
1863
+ [self.sizen(args[argnum]), self.sizen(arg)],
1552
1864
  dtype=arg.dtype,
1553
1865
  ),
1554
1866
  [-1] + list(self.shape_tuple(arg)),
@@ -1636,6 +1948,7 @@ class ExtendedBackend:
1636
1948
  ),
1637
1949
  jj,
1638
1950
  )
1951
+ jj = [jji for ind, jji in enumerate(jj) if ind in argnums]
1639
1952
  if len(jj) == 1:
1640
1953
  jj = jj[0]
1641
1954
  jjs.append(jj)
@@ -1662,7 +1975,7 @@ class ExtendedBackend:
1662
1975
  f: Callable[..., Any],
1663
1976
  static_argnums: Optional[Union[int, Sequence[int]]] = None,
1664
1977
  jit_compile: Optional[bool] = None,
1665
- **kws: Any
1978
+ **kws: Any,
1666
1979
  ) -> Callable[..., Any]:
1667
1980
  """
1668
1981
  Return the jitted version of function ``f``.
@@ -56,10 +56,12 @@ class CuPyBackend(tnbackend, ExtendedBackend): # type: ignore
56
56
  cpx = cupyx
57
57
  self.name = "cupy"
58
58
 
59
def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert ``a`` into a CuPy array, optionally casting it.

    :param a: array-like input; CuPy arrays and scalars skip ``cp.array``
    :type a: Tensor
    :param dtype: target dtype handed to ``self.cast``; no cast when None
    :type dtype: Optional[str]
    :return: the converted CuPy array
    :rtype: Tensor
    """
    if not (isinstance(a, cp.ndarray) or cp.isscalar(a)):
        a = cp.array(a)
    a = cp.asarray(a)
    return a if dtype is None else self.cast(a, dtype)
64
66
 
65
67
  def sum(
@@ -50,12 +50,17 @@ class optax_optimizer:
50
50
  return params
51
51
 
52
52
 
53
- def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor:
53
+ def _convert_to_tensor_jax(
54
+ self: Any, tensor: Tensor, dtype: Optional[str] = None
55
+ ) -> Tensor:
54
56
  if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor):
55
57
  raise TypeError(
56
58
  ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}")
57
59
  )
58
60
  result = jnp.asarray(tensor)
61
+ if dtype is not None:
62
+ # Use the backend's cast method to handle dtype conversion
63
+ result = self.cast(result, dtype)
59
64
  return result
60
65
 
61
66
 
@@ -170,6 +175,27 @@ def _eigh_jax(self: Any, tensor: Tensor) -> Tensor:
170
175
  return adaware_eigh(tensor)
171
176
 
172
177
 
178
def bcsr_scalar_mul(self: Tensor, other: Tensor) -> Tensor:
    """
    Scalar multiplication for ``BCSR`` sparse matrices (``self * scalar``).

    It is installed as ``BCSR.__mul__`` and ``BCSR.__rmul__``. Only the stored
    values are scaled. The sparsity pattern (``indices``, ``indptr``) and the
    ``indices_sorted`` / ``unique_indices`` flags of ``self`` are carried over,
    since scaling cannot change them.

    :param self: the BCSR matrix
    :param other: the multiplier; only values accepted by ``jnp.isscalar`` are handled
    :return: a new BCSR matrix, or ``NotImplemented`` for non-scalar ``other``
    """
    # local imports: the patch must work independently of module-level globals
    import jax.numpy as jnp
    from jax.experimental.sparse import BCSR

    if not jnp.isscalar(other):
        # e.g. element-wise products with matrices: returning NotImplemented
        # lets Python try other.__rmul__(self)
        return NotImplemented

    return BCSR(
        (self.data * other, self.indices, self.indptr),
        shape=self.shape,
        indices_sorted=getattr(self, "indices_sorted", False),
        unique_indices=getattr(self, "unique_indices", False),
    )
197
+
198
+
173
199
  tensornetwork.backends.jax.jax_backend.JaxBackend.convert_to_tensor = (
174
200
  _convert_to_tensor_jax
175
201
  )
@@ -219,6 +245,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
219
245
 
220
246
  self.name = "jax"
221
247
 
248
+ # --- Monkey-patch the BCSR class ---
249
+
250
+ sparse.BCSR.__mul__ = bcsr_scalar_mul # type: ignore
251
+ sparse.BCSR.__rmul__ = bcsr_scalar_mul # type: ignore
252
+
222
253
  # it is already child of numpy backend, and self.np = self.jax.np
223
254
  def eye(
224
255
  self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
@@ -243,8 +274,10 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
243
274
  def copy(self, tensor: Tensor) -> Tensor:
244
275
  return jnp.array(tensor, copy=True)
245
276
 
246
def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert ``tensor`` into a JAX array, optionally casting it.

    :param tensor: array-like input
    :type tensor: Tensor
    :param dtype: target dtype handed to ``self.cast``; no cast when None
    :type dtype: Optional[str]
    :return: the converted JAX array
    :rtype: Tensor
    """
    converted = jnp.asarray(tensor)
    return converted if dtype is None else self.cast(converted, dtype)
249
282
 
250
283
  def abs(self, a: Tensor) -> Tensor:
@@ -342,6 +375,15 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
342
375
  def mod(self, x: Tensor, y: Tensor) -> Tensor:
343
376
  return jnp.mod(x, y)
344
377
 
378
def floor(self, a: Tensor) -> Tensor:
    """Round each element of ``a`` down towards negative infinity (``jnp.floor``)."""
    rounded = jnp.floor(a)
    return rounded
380
+
381
def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise ``x // y`` via ``jnp.floor_divide``."""
    quotient = jnp.floor_divide(x, y)
    return quotient
383
+
384
def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """Limit ``a`` element-wise to ``[a_min, a_max]`` via ``jnp.clip``."""
    bounded = jnp.clip(a, a_min, a_max)
    return bounded
386
+
345
387
  def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
346
388
  return jnp.right_shift(x, y)
347
389
 
@@ -387,6 +429,12 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
387
429
  def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
388
430
  return jnp.argmin(a, axis=axis)
389
431
 
432
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Indices that sort ``a`` along ``axis`` (``jnp.argsort``)."""
    order = jnp.argsort(a, axis=axis)
    return order
434
+
435
def sort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Sorted copy of ``a`` along ``axis`` (``jnp.sort``)."""
    ordered = jnp.sort(a, axis=axis)
    return ordered
437
+
390
438
  def unique_with_counts( # type: ignore
391
439
  self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
392
440
  ) -> Tuple[Tensor, Tensor]:
@@ -407,6 +455,9 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
407
455
  def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
408
456
  return jnp.cumsum(a, axis)
409
457
 
458
def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """Logical AND reduction of ``a`` over ``axis`` (``jnp.all``)."""
    reduced = jnp.all(a, axis=axis)
    return reduced
460
+
410
461
  def is_tensor(self, a: Any) -> bool:
411
462
  if not isinstance(a, jnp.ndarray):
412
463
  return False
@@ -418,6 +469,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
418
469
  def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor: # type: ignore
419
470
  return jsp.linalg.solve(A, b, assume_a=assume_a)
420
471
 
472
def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
    """
    Bessel functions of the first kind, computed by ``bessel_jv_jax_rescaled``.

    All three arguments are forwarded unchanged to the helper in ``jax_ops``.
    """
    from .jax_ops import bessel_jv_jax_rescaled as _jv

    values = _jv(v, z, M)
    return values
476
+
421
477
  def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
422
478
  if not self.is_tensor(a):
423
479
  a = self.convert_to_tensor(a)
@@ -442,7 +498,14 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
442
498
  def to_dlpack(self, a: Tensor) -> Any:
443
499
  import jax.dlpack
444
500
 
445
- return jax.dlpack.to_dlpack(a)
501
+ try:
502
+ return jax.dlpack.to_dlpack(a) # type: ignore
503
+ except AttributeError: # jax >v0.7
504
+ # jax.dlpack.to_dlpack was deprecated in JAX v0.6.0 and removed in JAX v0.7.0.
505
+ # Please use the newer DLPack API based on __dlpack__ and __dlpack_device__ instead.
506
+ # Typically, you can pass a JAX array directly to the `from_dlpack` function of
507
+ # another framework without using `to_dlpack`.
508
+ return a.__dlpack__()
446
509
 
447
510
  def set_random_state(
448
511
  self, seed: Optional[Union[int, PRNGKeyArray]] = None, get_only: bool = False
@@ -608,7 +671,14 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
608
671
  carry, _ = libjax.lax.scan(f_jax, init, xs)
609
672
  return carry
610
673
 
674
def jaxy_scan(
    self, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
) -> Tensor:
    """Run ``f`` over ``xs`` with ``lax.scan``, returning ``(carry, stacked_ys)``."""
    scan = libjax.lax.scan
    return scan(f, init, xs)
678
+
611
679
  def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
680
+ # updates = jnp.reshape(updates, indices.shape)
681
+ # return operand.at[indices].set(updates)
612
682
  rank = len(operand.shape)
613
683
  dnums = libjax.lax.ScatterDimensionNumbers(
614
684
  update_window_dims=(),
@@ -630,11 +700,20 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
630
700
  ) -> Tensor:
631
701
  return sp_a @ b
632
702
 
703
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a BCOO matrix to BCSR using ``sparse.BCSR.from_bcoo``.

    If the conversion raises AttributeError, the COO input is returned
    unchanged unless ``strict`` is True, in which case the error is re-raised.
    """
    try:
        return sparse.BCSR.from_bcoo(coo)  # type: ignore
    except AttributeError:
        if strict:
            raise
        return coo
711
+
633
712
  def to_dense(self, sp_a: Tensor) -> Tensor:
634
713
  return sp_a.todense()
635
714
 
636
715
  def is_sparse(self, a: Tensor) -> bool:
637
- return isinstance(a, sparse.BCOO) # type: ignore
716
+ return isinstance(a, sparse.JAXSparse) # type: ignore
638
717
 
639
718
  def device(self, a: Tensor) -> str:
640
719
  (dev,) = a.devices()
@@ -781,4 +860,23 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
781
860
 
782
861
  # short alias for ``vectorized_value_and_grad``
  vvag = vectorized_value_and_grad
783
862
 
863
def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
    """
    Coordinate matrices from coordinate vectors via ``jnp.meshgrid``.

    All positional and keyword arguments (e.g. ``indexing``) are forwarded
    unchanged, so ``jnp.meshgrid``'s defaults apply.
    """
    grids = jnp.meshgrid(*args, **kwargs)
    return grids
868
+
784
869
  # exposes the module's ``optax_optimizer`` wrapper class on the backend
  optimizer = optax_optimizer
870
+
871
def expand_dims(self, a: Tensor, axis: int) -> Tensor:
    """Insert a length-1 axis into ``a`` at ``axis`` (``jnp.expand_dims``)."""
    expanded = jnp.expand_dims(a, axis)
    return expanded
873
+
874
def where(
    self,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    Element-wise selection via ``jnp.where``.

    With both ``x`` and ``y`` omitted, this is the one-argument form
    ``jnp.where(condition)``. Otherwise it is ``jnp.where(condition, x, y)``.
    """
    condition_only = x is None and y is None
    if not condition_only:
        return jnp.where(condition, x, y)
    return jnp.where(condition)
+ return jnp.where(condition, x, y)