tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic. Click here for more details.

Files changed (77)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +100 -4
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +157 -98
  14. tensorcircuit/circuit.py +115 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +105 -23
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +733 -153
  22. tensorcircuit/fgs.py +254 -73
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/interfaces/jax.py +5 -3
  25. tensorcircuit/interfaces/tensortrans.py +6 -2
  26. tensorcircuit/interfaces/torch.py +14 -4
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +698 -134
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +131 -18
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +29 -17
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
  45. tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1035
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -409
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -562
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -282
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -549
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -380
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_stabilizer.py +0 -217
  74. tests/test_templates.py +0 -218
  75. tests/test_torchnn.py +0 -99
  76. tests/test_van.py +0 -102
  77. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
@@ -9,6 +9,7 @@ from functools import reduce, partial
9
9
  from operator import mul
10
10
  from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
11
11
 
12
+ import math
12
13
  import numpy as np
13
14
  from ..utils import return_partial
14
15
 
@@ -405,6 +406,31 @@ class ExtendedBackend:
405
406
  a = self.reshape(a, [2 for _ in range(nleg)])
406
407
  return a
407
408
 
409
def reshaped(self: Any, a: Tensor, d: int) -> Tensor:
    """
    Reshape a tensor to the [d, d, ...] shape.

    The number of legs ``n`` is the integer with ``d**n == size(a)``. It is
    found with exact integer arithmetic instead of a floating-point logarithm,
    which also avoids the division by zero that ``math.log(size, 1)`` raises.

    :param a: Input tensor
    :type a: Tensor
    :param d: edge length for each dimension, a positive integer
    :type d: int
    :raises ValueError: if ``d`` is not a positive integer, or if the number of
        elements of ``a`` is not an integer power of ``d``
    :return: the reshaped tensor; an empty tensor becomes shape ``(0,)`` and a
        single-element tensor becomes the scalar shape ``()``
    :rtype: Tensor
    """
    if not isinstance(d, int) or d <= 0:
        raise ValueError("d must be a positive integer.")

    size = self.sizen(a)
    if size == 0:
        return self.reshape(a, (0,))

    if d == 1:
        # Only a single element is a power of 1; use zero legs, which matches
        # the d > 1 case where size 1 == d**0 also yields shape ().
        if size != 1:
            raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")
        return self.reshape(a, ())

    # Strip factors of d until none remain; size is a power of d iff the
    # remainder is exactly 1, and the number of divisions is the leg count.
    nleg, remaining = 0, size
    while remaining % d == 0:
        remaining //= d
        nleg += 1
    if remaining != 1:
        raise ValueError(f"cannot reshape: size {size} is not a power of d={d}")

    return self.reshape(a, (d,) * nleg)
433
+
408
434
  def reshapem(self: Any, a: Tensor) -> Tensor:
409
435
  """
410
436
  Reshape a tensor to the [l, l] shape.
@@ -581,6 +607,97 @@ class ExtendedBackend:
581
607
  "Backend '{}' has not implemented `argmin`.".format(self.name)
582
608
  )
583
609
 
610
def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
    """
    Return the indices that would sort ``a`` along ``axis``.

    This base class provides no implementation; concrete backends override it.

    :param a: the tensor whose sorting permutation is requested
    :type a: Tensor
    :param axis: axis to sort along, defaults to -1 (the last axis)
    :type axis: int
    :raises NotImplementedError: always, in this base class
    :return: the sorting indices
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `argsort`.".format(self.name)
    raise NotImplementedError(message)
624
+
625
def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
    """
    Sort a tensor along the given axis.

    This base class provides no implementation; concrete backends override it.

    :param a: the tensor to be sorted
    :type a: Tensor
    :param axis: the axis along which to sort, defaults to -1 (the last axis)
    :type axis: int, optional
    :raises NotImplementedError: always, in this base class
    :return: the values of ``a`` sorted along ``axis``
    :rtype: Tensor
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `sort`.".format(self.name)
    )
639
+
640
def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """
    Logical AND reduction: test whether every element of ``a`` is truthy.

    This base class provides no implementation; concrete backends override it.

    :param a: Input tensor
    :type a: Tensor
    :param axis: axis or axes to reduce over; ``None`` (default) reduces over
        all elements
    :type axis: Optional[Sequence[int]], optional
    :raises NotImplementedError: always, in this base class
    :return: the boolean result of the AND reduction
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `all`.".format(self.name)
    raise NotImplementedError(message)
655
+
656
def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
    """
    Return coordinate matrices from coordinate vectors.

    This base class provides no implementation; concrete backends override it.

    :param args: coordinate vectors
    :type args: Any
    :param kwargs: keyword arguments forwarded to the backend's meshgrid,
        typically ``indexing``, which is ``'xy'`` (Cartesian indexing) or
        ``'ij'`` (matrix indexing). When ``indexing`` is omitted, the default
        is whatever the underlying library uses (``'xy'`` for NumPy and JAX).
        Pass it explicitly for results that do not depend on the backend.

        For two vectors of lengths ``M`` and ``N``:

        - ``'xy'``: the outputs have shape ``(N, M)``
        - ``'ij'``: the outputs have shape ``(M, N)``

        Example (``indexing='xy'``)::

            x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy')
            # x == [[0, 1],
            #       [0, 1]]
            # y == [[0, 0],
            #       [2, 2]]

    :type kwargs: Any
    :raises NotImplementedError: always, in this base class
    :return: list of coordinate matrices
    :rtype: Any
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `meshgrid`.".format(self.name)
    )
683
+
684
def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
    """
    Insert a new length-1 axis into ``a`` at position ``axis``.

    This base class provides no implementation; concrete backends override it.

    :param a: Input tensor
    :type a: Tensor
    :param axis: index in the expanded shape where the new axis appears
    :type axis: int
    :raises NotImplementedError: always, in this base class
    :return: a tensor with one more dimension than ``a``
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `expand_dims`.".format(self.name)
    raise NotImplementedError(message)
700
+
584
701
  def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
585
702
  """
586
703
  Find the unique elements and their corresponding counts of the given tensor ``a``.
@@ -700,6 +817,9 @@ class ExtendedBackend:
700
817
  "Backend '{}' has not implemented `is_tensor`.".format(self.name)
701
818
  )
702
819
 
820
def matvec(self: Any, A: Tensor, x: Tensor) -> Tensor:
    """
    Matrix-vector product ``A @ x``.

    Implemented generically through ``self.tensordot`` by contracting axis 1
    of ``A`` with axis 0 of ``x``.

    :param A: the matrix
    :type A: Tensor
    :param x: the vector (or any tensor whose first axis matches ``A``'s second)
    :type x: Tensor
    :return: the contraction result
    :rtype: Tensor
    """
    contracted_axes = [[1], [0]]
    return self.tensordot(A, x, axes=contracted_axes)
822
+
703
823
  def cast(self: Any, a: Tensor, dtype: str) -> Tensor:
704
824
  """
705
825
  Cast the tensor dtype of a ``a``.
@@ -715,6 +835,21 @@ class ExtendedBackend:
715
835
  "Backend '{}' has not implemented `cast`.".format(self.name)
716
836
  )
717
837
 
838
def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Turn ``a`` into a tensor of this backend, optionally with dtype ``dtype``.

    This base class provides no implementation; concrete backends override it.

    :param a: data to convert
    :type a: Tensor
    :param dtype: requested dtype, or ``None`` to keep the inferred one
    :type dtype: Optional[str]
    :raises NotImplementedError: always, in this base class
    :return: the converted tensor
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `convert_to_tensor`.".format(
        self.name
    )
    raise NotImplementedError(message)
852
+
718
853
  def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
719
854
  """
720
855
  Compute y-mod of x (negative number behavior is not guaranteed to be consistent)
@@ -730,6 +865,82 @@ class ExtendedBackend:
730
865
  "Backend '{}' has not implemented `mod`.".format(self.name)
731
866
  )
732
867
 
868
def floor_divide(self: Any, x: Tensor, y: Tensor) -> Tensor:
    r"""
    Element-wise floor division ``x // y``.

    Each result element is ``floor(x[i] / y[i])``, i.e. the quotient rounded
    toward negative infinity, matching Python's ``//`` operator. Broadcasting
    follows the rules of the concrete backend.

    :param x: Dividend tensor.
    :type x: Tensor
    :param y: Divisor tensor, broadcastable with ``x``.
    :type y: Tensor
    :return: the floored quotients, with the broadcast shape of ``x`` and ``y``
    :rtype: Tensor
    :raises NotImplementedError: always, in this base class
    """
    message = "Backend '{}' has not implemented `floor_divide`.".format(self.name)
    raise NotImplementedError(message)
895
+
896
def floor(self: Any, a: Tensor) -> Tensor:
    """
    Element-wise floor: round every value of ``a`` toward negative infinity.

    :param a: Input tensor with numeric values.
    :type a: Tensor
    :return: a tensor of the same shape holding the largest integers that are
        less than or equal to the corresponding elements of ``a``
    :rtype: Tensor
    :raises NotImplementedError: always, in this base class
    """
    message = "Backend '{}' has not implemented `floor`.".format(self.name)
    raise NotImplementedError(message)
916
+
917
def clip(self: Any, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """
    Limit the values of ``a`` element-wise to the closed range [a_min, a_max].

    Values below ``a_min`` become ``a_min``, values above ``a_max`` become
    ``a_max``, and all other values are kept.

    :param a: Input tensor with the values to clip.
    :type a: Tensor
    :param a_min: lower bound; a scalar or a tensor broadcastable to ``a``
    :type a_min: Tensor
    :param a_max: upper bound; a scalar or a tensor broadcastable to ``a``
    :type a_max: Tensor
    :return: the clipped tensor, with the same shape as ``a``
    :rtype: Tensor
    :raises NotImplementedError: always, in this base class
    """
    message = "Backend '{}' has not implemented `clip`.".format(self.name)
    raise NotImplementedError(message)
943
+
733
944
  def reverse(self: Any, a: Tensor) -> Tensor:
734
945
  """
735
946
  return ``a[::-1]``, only 1D tensor is guaranteed for consistent behavior
@@ -805,6 +1016,21 @@ class ExtendedBackend:
805
1016
  "Backend '{}' has not implemented `solve`.".format(self.name)
806
1017
  )
807
1018
 
1019
def special_jv(self: Any, v: int, z: Tensor, M: int) -> Tensor:
    """
    Special function: Bessel functions of the first kind.

    Per this interface, the result collects the integer orders ``0`` through
    ``v - 1``, i.e. ``[J_0(z), J_1(z), ..., J_{v-1}(z)]``, rather than a
    single order ``v``.

    :param v: number of orders to evaluate; orders ``0, ..., v - 1`` are returned
    :type v: int
    :param z: The argument of the Bessel function.
    :type z: Tensor
    :param M: backend-specific integer parameter passed through to the
        implementation (the jax backend forwards it to
        ``bessel_jv_jax_rescaled``). NOTE(review): presumably controls the
        truncation/starting order of the numerical scheme; confirm before
        relying on it.
    :type M: int
    :return: the values ``[J_0(z), ..., J_{v-1}(z)]``
    :rtype: Tensor
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError(
        "Backend '{}' has not implemented `special_jv`.".format(self.name)
    )
1033
+
808
1034
  def searchsorted(self: Any, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
809
1035
  """
810
1036
  Find indices where elements should be inserted to maintain order.
@@ -813,8 +1039,8 @@ class ExtendedBackend:
813
1039
  :type a: Tensor
814
1040
  :param v: value to inserted
815
1041
  :type v: Tensor
816
- :param side: If left’, the index of the first suitable location found is given.
817
- If right’, return the last such index.
1042
+ :param side: If `left`, the index of the first suitable location found is given.
1043
+ If `right`, return the last such index.
818
1044
  If there is no suitable index, return either 0 or N (where N is the length of a),
819
1045
  defaults to "left"
820
1046
  :type side: str, optional
@@ -1248,6 +1474,26 @@ class ExtendedBackend:
1248
1474
  "Backend '{}' has not implemented `sparse_dense_matmul`.".format(self.name)
1249
1475
  )
1250
1476
 
1477
def sparse_csr_from_coo(self: Any, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a COO sparse matrix to CSR format.

    This base class has no conversion. With ``strict=False`` (the default) the
    input is handed back unchanged; with ``strict=True`` the missing
    implementation is reported instead.

    :param coo: a COO sparse matrix
    :type coo: Tensor
    :param strict: raise rather than fall back when no conversion exists,
        defaults to False
    :type strict: bool, optional
    :raises NotImplementedError: when ``strict`` is True
    :return: the CSR matrix, or ``coo`` itself in non-strict mode
    :rtype: Tensor
    """
    if not strict:
        return coo
    message = "Backend '{}' has not implemented `sparse_csr_from_coo`.".format(
        self.name
    )
    raise NotImplementedError(message)
1496
+
1251
1497
  def to_dense(self: Any, sp_a: Tensor) -> Tensor:
1252
1498
  """
1253
1499
  Convert a sparse matrix to dense tensor.
@@ -1351,6 +1597,28 @@ class ExtendedBackend:
1351
1597
  "Backend '{}' has not implemented `cond`.".format(self.name)
1352
1598
  )
1353
1599
 
1600
def where(
    self: Any,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    Pick elements from ``x`` where ``condition`` is True and from ``y`` elsewhere.

    This base class provides no implementation; concrete backends override it.

    :param condition: boolean mask selecting between ``x`` and ``y``
    :type condition: Tensor (bool)
    :param x: values used where ``condition`` is True
    :type x: Tensor
    :param y: values used where ``condition`` is False
    :type y: Tensor
    :raises NotImplementedError: always, in this base class
    :return: the element-wise selection
    :rtype: Tensor
    """
    message = "Backend '{}' has not implemented `where`.".format(self.name)
    raise NotImplementedError(message)
1621
+
1354
1622
  def switch(
1355
1623
  self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]]
1356
1624
  ) -> Tensor:
@@ -1386,10 +1654,49 @@ class ExtendedBackend:
1386
1654
  :rtype: Tensor
1387
1655
  """
1388
1656
  carry = init
1389
- for x in xs:
1390
- carry = f(carry, x)
1657
+ # Check if `xs` is a PyTree (tuple or list) of arrays.
1658
+ if isinstance(xs, (tuple, list)):
1659
+ for x_slice_tuple in zip(*xs):
1660
+ # x_slice_tuple will be (k_elems[i], j_elems[i]) at each step.
1661
+ carry = f(carry, x_slice_tuple)
1662
+ else:
1663
+ # If xs is a single array, iterate normally.
1664
+ for x in xs:
1665
+ carry = f(carry, x)
1666
+
1391
1667
  return carry
1392
1668
 
1669
def jaxy_scan(
    self: Any, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
) -> Tensor:
    """
    Scan ``f`` over the leading axis of ``xs`` in ``jax.lax.scan`` style,
    using a plain Python loop.

    At step ``i`` the call is ``f(carry, x_i)``, which must return a pair
    ``(new_carry, y_i)``. When ``xs`` is a tuple or list of tensors, ``x_i`` is
    a container of the same kind holding the ``i``-th slice of every member.
    The number of steps is then the shared leading length of the members, not
    the number of members.

    :param f: step function ``(carry, x) -> (new_carry, y)``
    :type f: Callable[[Tensor, Tensor], Tensor]
    :param init: initial carry
    :type init: Tensor
    :param xs: a tensor, or a tuple/list of tensors, scanned along axis 0
    :type xs: Tensor
    :raises ValueError: if ``xs`` is None or an empty tuple/list, or if the
        members of a tuple/list ``xs`` have different leading lengths
    :return: ``(final_carry, ys)`` where ``ys`` stacks the per-step outputs
        along a new leading axis via ``self.stack``
    :rtype: Tuple[Tensor, Tensor]
    """
    if xs is None:
        raise ValueError("xs must be provided.")

    is_tree = isinstance(xs, (tuple, list))
    if is_tree:
        if len(xs) == 0:
            raise ValueError("xs must contain at least one tensor.")
        lengths = {len(ele) for ele in xs}
        if len(lengths) != 1:
            raise ValueError(
                "all members of xs must share the same leading length, "
                f"got {sorted(lengths)}"
            )
        (length,) = lengths
        # Keep the container kind of xs for each slice, as jax.lax.scan does.
        container = tuple if isinstance(xs, tuple) else list
    else:
        length = len(xs)

    carry, outputs_to_stack = init, []
    for i in range(length):
        x = container(ele[i] for ele in xs) if is_tree else xs[i]
        carry, y = f(carry, x)
        outputs_to_stack.append(y)
    return carry, self.stack(outputs_to_stack)
1699
+
1393
1700
  def stop_gradient(self: Any, a: Tensor) -> Tensor:
1394
1701
  """
1395
1702
  Stop backpropagation from ``a``.
@@ -1668,7 +1975,7 @@ class ExtendedBackend:
1668
1975
  f: Callable[..., Any],
1669
1976
  static_argnums: Optional[Union[int, Sequence[int]]] = None,
1670
1977
  jit_compile: Optional[bool] = None,
1671
- **kws: Any
1978
+ **kws: Any,
1672
1979
  ) -> Callable[..., Any]:
1673
1980
  """
1674
1981
  Return the jitted version of function ``f``.
@@ -56,10 +56,12 @@ class CuPyBackend(tnbackend, ExtendedBackend): # type: ignore
56
56
  cpx = cupyx
57
57
  self.name = "cupy"
58
58
 
59
def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert ``a`` to a CuPy array, optionally casting it to ``dtype``.

    Inputs that are neither CuPy arrays nor scalars are first built with
    ``cp.array``; the value is then passed through ``cp.asarray``.

    :param a: data to convert
    :type a: Tensor
    :param dtype: target dtype applied with ``self.cast``; ``None`` keeps the
        inferred dtype
    :type dtype: Optional[str]
    :return: the converted CuPy array
    :rtype: Tensor
    """
    needs_array = not isinstance(a, cp.ndarray) and not cp.isscalar(a)
    if needs_array:
        a = cp.array(a)
    result = cp.asarray(a)
    if dtype is None:
        return result
    return self.cast(result, dtype)
64
66
 
65
67
  def sum(
@@ -50,12 +50,17 @@ class optax_optimizer:
50
50
  return params
51
51
 
52
52
 
53
- def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor:
53
+ def _convert_to_tensor_jax(
54
+ self: Any, tensor: Tensor, dtype: Optional[str] = None
55
+ ) -> Tensor:
54
56
  if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor):
55
57
  raise TypeError(
56
58
  ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}")
57
59
  )
58
60
  result = jnp.asarray(tensor)
61
+ if dtype is not None:
62
+ # Use the backend's cast method to handle dtype conversion
63
+ result = self.cast(result, dtype)
59
64
  return result
60
65
 
61
66
 
@@ -170,6 +175,27 @@ def _eigh_jax(self: Any, tensor: Tensor) -> Tensor:
170
175
  return adaware_eigh(tensor)
171
176
 
172
177
 
178
def bcsr_scalar_mul(self: Tensor, other: Tensor) -> Tensor:
    """
    Implements scalar multiplication for BCSR matrices (``self * scalar``, and
    ``scalar * self`` when installed as ``__rmul__``).

    Only the stored values are scaled. The sparsity pattern (``indices``,
    ``indptr``) and the ``indices_sorted`` / ``unique_indices`` flags of
    ``self`` are carried over unchanged to the result.

    :param self: the BCSR matrix
    :type self: jax.experimental.sparse.BCSR
    :param other: the multiplier; only values for which ``jnp.isscalar`` is
        true are handled
    :type other: Tensor
    :return: a new BCSR matrix with scaled data, or ``NotImplemented`` for
        non-scalar operands so that Python can try the reflected operation
    :rtype: Tensor
    """
    import jax.numpy as jnp
    from jax.experimental.sparse import BCSR

    if not jnp.isscalar(other):
        # E.g. element-wise product with another matrix: returning
        # NotImplemented lets Python try other.__rmul__(self).
        return NotImplemented

    # Scaling leaves the structure untouched, so keep what is known about the
    # index ordering/uniqueness instead of falling back to the constructor's
    # defaults.
    return BCSR(
        (self.data * other, self.indices, self.indptr),
        shape=self.shape,
        indices_sorted=getattr(self, "indices_sorted", False),
        unique_indices=getattr(self, "unique_indices", False),
    )
197
+
198
+
173
199
# Monkey-patch tensornetwork's own JaxBackend so its convert_to_tensor is
# replaced by _convert_to_tensor_jax, which accepts the optional ``dtype``.
tensornetwork.backends.jax.jax_backend.JaxBackend.convert_to_tensor = (
    _convert_to_tensor_jax
)
@@ -219,6 +245,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
219
245
 
220
246
  self.name = "jax"
221
247
 
248
+ # --- Monkey-patch the BCSR class ---
249
+
250
+ sparse.BCSR.__mul__ = bcsr_scalar_mul # type: ignore
251
+ sparse.BCSR.__rmul__ = bcsr_scalar_mul # type: ignore
252
+
222
253
  # it is already child of numpy backend, and self.np = self.jax.np
223
254
  def eye(
224
255
  self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
@@ -243,8 +274,10 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
243
274
def copy(self, tensor: Tensor) -> Tensor:
    """Return an independent copy of ``tensor`` built with ``jnp.array(..., copy=True)``."""
    duplicate = jnp.array(tensor, copy=True)
    return duplicate
245
276
 
246
def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
    """
    Convert ``tensor`` with ``jnp.asarray``, then optionally cast it.

    :param tensor: the value to convert
    :type tensor: Tensor
    :param dtype: target dtype applied with ``self.cast``; ``None`` keeps the
        inferred dtype
    :type dtype: Optional[str]
    :return: the converted JAX array
    :rtype: Tensor
    """
    as_array = jnp.asarray(tensor)
    if dtype is None:
        return as_array
    return self.cast(as_array, dtype)
249
282
 
250
283
  def abs(self, a: Tensor) -> Tensor:
@@ -342,6 +375,15 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
342
375
def mod(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise remainder of ``x`` divided by ``y`` via ``jnp.mod``."""
    remainder = jnp.mod(x, y)
    return remainder
344
377
 
378
def floor(self, a: Tensor) -> Tensor:
    """Element-wise floor of ``a`` via ``jnp.floor``."""
    floored = jnp.floor(a)
    return floored
380
+
381
def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise ``x // y`` via ``jnp.floor_divide``."""
    quotient = jnp.floor_divide(x, y)
    return quotient
383
+
384
def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
    """Clamp ``a`` element-wise to ``[a_min, a_max]`` via ``jnp.clip``."""
    clamped = jnp.clip(a, a_min, a_max)
    return clamped
386
+
345
387
def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
    """Element-wise bitwise right shift ``x >> y`` via ``jnp.right_shift``."""
    shifted = jnp.right_shift(x, y)
    return shifted
347
389
 
@@ -387,6 +429,12 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
387
429
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
    """Index of the minimum of ``a`` along ``axis`` (default 0) via ``jnp.argmin``."""
    position = jnp.argmin(a, axis=axis)
    return position
389
431
 
432
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Indices that sort ``a`` along ``axis`` (default: last) via ``jnp.argsort``."""
    order = jnp.argsort(a, axis=axis)
    return order
434
+
435
def sort(self, a: Tensor, axis: int = -1) -> Tensor:
    """Sorted copy of ``a`` along ``axis`` (default: last) via ``jnp.sort``."""
    ordered = jnp.sort(a, axis=axis)
    return ordered
437
+
390
438
  def unique_with_counts( # type: ignore
391
439
  self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
392
440
  ) -> Tuple[Tensor, Tensor]:
@@ -407,6 +455,9 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
407
455
def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
    """Cumulative sum of ``a`` along ``axis`` (``None`` flattens) via ``jnp.cumsum``."""
    running_total = jnp.cumsum(a, axis)
    return running_total
409
457
 
458
def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
    """Logical AND reduction of ``a`` over ``axis`` via ``jnp.all``."""
    reduced = jnp.all(a, axis=axis)
    return reduced
460
+
410
461
  def is_tensor(self, a: Any) -> bool:
411
462
  if not isinstance(a, jnp.ndarray):
412
463
  return False
@@ -418,6 +469,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
418
469
def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor:  # type: ignore
    """Solve ``A @ x = b`` for ``x`` with ``jax.scipy.linalg.solve``."""
    solution = jsp.linalg.solve(A, b, assume_a=assume_a)
    return solution
420
471
 
472
def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
    """
    Bessel functions of the first kind for the jax backend.

    Delegates to ``bessel_jv_jax_rescaled`` from the sibling ``jax_ops``
    module, which is imported lazily on the first call.
    """
    from .jax_ops import bessel_jv_jax_rescaled as _bessel_jv

    return _bessel_jv(v, z, M)
476
+
421
477
  def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
422
478
  if not self.is_tensor(a):
423
479
  a = self.convert_to_tensor(a)
@@ -442,7 +498,14 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
442
498
def to_dlpack(self, a: Tensor) -> Any:
    """
    Export a JAX array through the DLPack protocol.

    Uses ``jax.dlpack.to_dlpack`` when the installed JAX still provides it
    (deprecated in v0.6.0, removed in v0.7.0); otherwise falls back to the
    array's own ``__dlpack__`` method.

    :param a: the JAX array to export
    :type a: Tensor
    :return: a DLPack capsule for ``a``
    :rtype: Any
    """
    import jax.dlpack

    # Probe for the legacy helper instead of wrapping the call in
    # ``try/except AttributeError``. That pattern also swallowed unrelated
    # AttributeErrors raised inside ``to_dlpack`` and silently retried with
    # ``a.__dlpack__()``.
    legacy_to_dlpack = getattr(jax.dlpack, "to_dlpack", None)
    if legacy_to_dlpack is not None:
        return legacy_to_dlpack(a)  # type: ignore
    # With newer JAX a JAX array can usually be passed straight to another
    # framework's ``from_dlpack`` without going through ``to_dlpack`` at all.
    return a.__dlpack__()
446
509
 
447
510
  def set_random_state(
448
511
  self, seed: Optional[Union[int, PRNGKeyArray]] = None, get_only: bool = False
@@ -608,6 +671,11 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
608
671
  carry, _ = libjax.lax.scan(f_jax, init, xs)
609
672
  return carry
610
673
 
674
def jaxy_scan(
    self, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
) -> Tensor:
    """Run ``jax.lax.scan(f, init, xs)`` directly; returns ``(final_carry, stacked_ys)``."""
    lax_scan = libjax.lax.scan
    return lax_scan(f, init, xs)
678
+
611
679
  def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
612
680
  # updates = jnp.reshape(updates, indices.shape)
613
681
  # return operand.at[indices].set(updates)
@@ -632,11 +700,20 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
632
700
  ) -> Tensor:
633
701
  return sp_a @ b
634
702
 
703
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
    """
    Convert a BCOO matrix to BCSR with ``sparse.BCSR.from_bcoo``.

    If that conversion is unavailable (``AttributeError``), the input is
    returned unchanged unless ``strict`` is True, in which case the error
    propagates.
    """
    try:
        csr = sparse.BCSR.from_bcoo(coo)  # type: ignore
    except AttributeError:
        if strict:
            raise
        return coo
    return csr
711
+
635
712
def to_dense(self, sp_a: Tensor) -> Tensor:
    """Materialize the sparse matrix ``sp_a`` as a dense array via its ``todense``."""
    dense = sp_a.todense()
    return dense
637
714
 
638
715
  def is_sparse(self, a: Tensor) -> bool:
639
- return isinstance(a, sparse.BCOO) # type: ignore
716
+ return isinstance(a, sparse.JAXSparse) # type: ignore
640
717
 
641
718
  def device(self, a: Tensor) -> str:
642
719
  (dev,) = a.devices()
@@ -783,4 +860,23 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend): # type: ignore
783
860
 
784
861
# Short alias: ``vvag`` refers to ``vectorized_value_and_grad``.
vvag = vectorized_value_and_grad
785
862
 
863
def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
    """
    Coordinate matrices from coordinate vectors.

    All positional and keyword arguments (e.g. ``indexing``) are forwarded
    unchanged to ``jnp.meshgrid``.
    """
    grids = jnp.meshgrid(*args, **kwargs)
    return grids
868
+
786
869
# Expose the ``optax_optimizer`` class as this backend's ``optimizer`` attribute.
optimizer = optax_optimizer
870
+
871
def expand_dims(self, a: Tensor, axis: int) -> Tensor:
    """Insert a length-1 axis into ``a`` at ``axis`` via ``jnp.expand_dims``."""
    expanded = jnp.expand_dims(a, axis)
    return expanded
873
+
874
def where(
    self,
    condition: Tensor,
    x: Optional[Tensor] = None,
    y: Optional[Tensor] = None,
) -> Tensor:
    """
    Element-wise selection via ``jnp.where``.

    When either ``x`` or ``y`` is given, returns ``jnp.where(condition, x, y)``.
    When both are omitted, returns the single-argument form
    ``jnp.where(condition)``, i.e. the indices of the True entries.
    """
    if x is not None or y is not None:
        return jnp.where(condition, x, y)
    return jnp.where(condition)
+ return jnp.where(condition, x, y)