tensorcircuit-nightly 1.3.0.dev20250815__py3-none-any.whl → 1.3.0.dev20250816__py3-none-any.whl


tensorcircuit/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "1.3.0.dev20250815"
+__version__ = "1.3.0.dev20250816"
 __author__ = "TensorCircuit Authors"
 __creator__ = "refraction-ray"
 

tensorcircuit/backends/abstract_backend.py CHANGED

@@ -596,6 +596,82 @@ class ExtendedBackend:
             "Backend '{}' has not implemented `argsort`.".format(self.name)
         )
 
+    def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
+        """
+        Sort a tensor along the given axis.
+
+        :param a: Input tensor
+        :type a: Tensor
+        :param axis: The axis along which to sort, defaults to -1
+        :type axis: int, optional
+        :return: The sorted tensor
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `sort`.".format(self.name)
+        )
+
+    def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        """
+        Test whether all tensor elements along a given axis evaluate to True.
+
+        :param a: Input tensor
+        :type a: Tensor
+        :param axis: Axis or axes along which the logical AND reduction is performed,
+            defaults to None (reduce over all axes)
+        :type axis: Optional[Sequence[int]], optional
+        :return: A new boolean scalar or tensor resulting from the AND reduction
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `all`.".format(self.name)
+        )
+
+    def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
+        """
+        Return coordinate matrices from coordinate vectors.
+
+        :param args: coordinate vectors
+        :type args: Any
+        :param kwargs: keyword arguments for meshgrid, typically including ``indexing``,
+            which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing);
+            the default depends on the underlying backend ('xy' for NumPy, JAX,
+            and TensorFlow, 'ij' for PyTorch).
+            - 'ij': matrix indexing, the first dimension corresponds to rows
+            - 'xy': Cartesian indexing, the first dimension corresponds to columns
+            Example:
+            >>> x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy')
+            Shapes:
+            - x.shape == (2, 2)  # the number of rows equals the length of the y vector
+            - y.shape == (2, 2)
+            Values:
+            x = [[0, 1],
+                 [0, 1]]
+            y = [[0, 0],
+                 [2, 2]]
+        :type kwargs: Any
+        :return: list of coordinate matrices
+        :rtype: Any
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `meshgrid`.".format(self.name)
+        )
+
+    def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
+        """
+        Expand the shape of a tensor.
+        Insert a new axis that will appear at the ``axis`` position in the expanded
+        tensor shape.
+
+        :param a: Input tensor
+        :type a: Tensor
+        :param axis: Position in the expanded shape where the new axis is placed
+        :type axis: int
+        :return: Output tensor with the number of dimensions increased by one
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `expand_dims`.".format(self.name)
+        )
+
     def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         """
         Find the unique elements and their corresponding counts of the given tensor ``a``.
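
As a quick orientation, here is a minimal usage sketch of the newly added backend methods; it assumes only the standard `tc.set_backend` entry point, and the other backends expose the same interface:

    import numpy as np
    import tensorcircuit as tc

    K = tc.set_backend("numpy")   # "jax", "tensorflow", and "pytorch" expose the same methods
    a = K.convert_to_tensor(np.array([[3, 1], [2, 4]]))
    K.sort(a, axis=-1)            # [[1, 3], [2, 4]]
    K.all(a > 0)                  # True
    x, y = K.meshgrid(np.array([0, 1]), np.array([0, 2]), indexing="xy")
    K.expand_dims(a, axis=0)      # shape (1, 2, 2)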
@@ -733,6 +809,21 @@ class ExtendedBackend:
             "Backend '{}' has not implemented `cast`.".format(self.name)
         )
 
+    def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor:
+        """
+        Convert the input to a tensor.
+
+        :param a: input data to be converted
+        :type a: Tensor
+        :param dtype: target dtype string, defaults to None (keep the input dtype)
+        :type dtype: Optional[str], optional
+        :return: the converted tensor
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `convert_to_tensor`.".format(self.name)
+        )
+
     def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
         """
         Compute y-mod of x (negative number behavior is not guaranteed to be consistent)
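
A short sketch of the intended call pattern for the extended signature; the dtype string is anything the backend's `cast` accepts, e.g. "float32" or "complex64":

    import numpy as np
    import tensorcircuit as tc

    K = tc.set_backend("numpy")
    v = K.convert_to_tensor(np.array([1, 2, 3]), dtype="complex64")
    # v now has complex64 dtype; without the dtype argument the input dtype is kept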
@@ -1404,6 +1495,28 @@ class ExtendedBackend:
             "Backend '{}' has not implemented `cond`.".format(self.name)
         )
 
+    def where(
+        self: Any,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Return a tensor of elements selected from either ``x`` or ``y``, depending on ``condition``.
+
+        :param condition: Where True, yield ``x``; otherwise, yield ``y``.
+        :type condition: Tensor (bool)
+        :param x: Values from which to choose where condition is True.
+        :type x: Optional[Tensor]
+        :param y: Values from which to choose where condition is False.
+        :type y: Optional[Tensor]
+        :return: A tensor with elements from ``x`` where condition is True, and from ``y``
+            otherwise; if both ``x`` and ``y`` are None, the indices of the True entries.
+        :rtype: Tensor
+        """
+        raise NotImplementedError(
+            "Backend '{}' has not implemented `where`.".format(self.name)
+        )
+
     def switch(
         self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]]
     ) -> Tensor:
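
Both call forms are exercised in the sketch below; with a single argument the new `where` returns the indices of the True entries, mirroring `np.where`:

    import numpy as np
    import tensorcircuit as tc

    K = tc.set_backend("numpy")
    cond = K.convert_to_tensor(np.array([True, False, True]))
    K.where(cond, K.ones([3]), K.zeros([3]))  # elements of x where cond is True, else y
    K.where(cond)                             # index arrays of the True entries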

tensorcircuit/backends/cupy_backend.py CHANGED

@@ -56,10 +56,12 @@ class CuPyBackend(tnbackend, ExtendedBackend):  # type: ignore
         cpx = cupyx
         self.name = "cupy"
 
-    def convert_to_tensor(self, a: Tensor) -> Tensor:
+    def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor:
         if not isinstance(a, cp.ndarray) and not cp.isscalar(a):
             a = cp.array(a)
         a = cp.asarray(a)
+        if dtype is not None:
+            a = self.cast(a, dtype)
         return a
 
     def sum(

tensorcircuit/backends/jax_backend.py CHANGED

@@ -50,12 +50,17 @@ class optax_optimizer:
         return params
 
 
-def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor:
+def _convert_to_tensor_jax(
+    self: Any, tensor: Tensor, dtype: Optional[str] = None
+) -> Tensor:
     if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor):
         raise TypeError(
             ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}")
         )
     result = jnp.asarray(tensor)
+    if dtype is not None:
+        # Use the backend's cast method to handle the dtype conversion
+        result = self.cast(result, dtype)
     return result
 
 
@@ -243,8 +248,10 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def copy(self, tensor: Tensor) -> Tensor:
         return jnp.array(tensor, copy=True)
 
-    def convert_to_tensor(self, tensor: Tensor) -> Tensor:
+    def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
         result = jnp.asarray(tensor)
+        if dtype is not None:
+            result = self.cast(result, dtype)
         return result
 
     def abs(self, a: Tensor) -> Tensor:
@@ -390,6 +397,9 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
         return jnp.argsort(a, axis=axis)
 
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return jnp.sort(a, axis=axis)
+
     def unique_with_counts(  # type: ignore
         self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
     ) -> Tuple[Tensor, Tensor]:
@@ -410,6 +420,9 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
     def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
         return jnp.cumsum(a, axis)
 
+    def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        return jnp.all(a, axis=axis)
+
     def is_tensor(self, a: Any) -> bool:
         if not isinstance(a, jnp.ndarray):
             return False
@@ -812,4 +825,23 @@ class JaxBackend(jax_backend.JaxBackend, ExtendedBackend):  # type: ignore
 
     vvag = vectorized_value_and_grad
 
+    def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Backend-agnostic meshgrid function.
+        """
+        return jnp.meshgrid(*args, **kwargs)
+
     optimizer = optax_optimizer
+
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return jnp.expand_dims(a, axis)
+
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            return jnp.where(condition)
+        return jnp.where(condition, x, y)
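
One JAX-specific caveat worth flagging: the single-argument form of `jnp.where` produces index arrays whose shape depends on the tensor's runtime values, so it generally cannot be traced through `jit` (JAX requires its static `size=` argument for that); the three-argument form is fully jittable:

    import jax
    import jax.numpy as jnp

    cond = jnp.array([True, False, True])
    jnp.where(cond)                            # (Array([0, 2]),) -- fine eagerly
    # jax.jit(lambda c: jnp.where(c))(cond)    # fails: data-dependent output shape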

tensorcircuit/backends/numpy_backend.py CHANGED

@@ -35,10 +35,14 @@ def _sum_numpy(
     # see https://github.com/google/TensorNetwork/issues/952
 
 
-def _convert_to_tensor_numpy(self: Any, a: Tensor) -> Tensor:
+def _convert_to_tensor_numpy(
+    self: Any, a: Tensor, dtype: Optional[str] = None
+) -> Tensor:
     if not isinstance(a, np.ndarray) and not np.isscalar(a):
         a = np.array(a)
     a = np.asarray(a)
+    if dtype is not None:
+        a = a.astype(getattr(np, dtype))
     return a
 
 
@@ -132,6 +136,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend):  # type: ignore
     def kron(self, a: Tensor, b: Tensor) -> Tensor:
         return np.kron(a, b)
 
+    def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
+        return np.meshgrid(*args, **kwargs)
+
     def dtype(self, a: Tensor) -> str:
         return a.dtype.__str__()  # type: ignore
 
@@ -151,6 +158,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend):  # type: ignore
         dtype = getattr(np, dtype)
         return np.array(1j, dtype=dtype)
 
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return np.expand_dims(a, axis)
+
     def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
         return np.stack(a, axis=axis)
 
@@ -173,6 +183,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend):  # type: ignore
     ) -> Tensor:
         return np.std(a, axis=axis, keepdims=keepdims)
 
+    def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        return np.all(a, axis=axis)
+
     def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         return np.unique(a, return_counts=True)  # type: ignore
 
@@ -188,6 +201,9 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend):  # type: ignore
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return np.argmin(a, axis=axis)
 
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return np.sort(a, axis=axis)
+
     def sigmoid(self, a: Tensor) -> Tensor:
         return expit(a)
 
@@ -345,6 +361,17 @@ class NumpyBackend(numpy_backend.NumPyBackend, ExtendedBackend):  # type: ignore
     def is_sparse(self, a: Tensor) -> bool:
         return issparse(a)  # type: ignore
 
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            return np.where(condition)
+        assert x is not None and y is not None
+        return np.where(condition, x, y)
+
     def cond(
         self,
         pred: bool,

tensorcircuit/backends/pytorch_backend.py CHANGED

@@ -238,6 +238,15 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend):  # type: ignore
     def copy(self, a: Tensor) -> Tensor:
         return a.clone()
 
+    def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
+        if self.is_tensor(tensor):
+            result = tensor
+        else:
+            result = torchlib.tensor(tensor)
+        if dtype is not None:
+            result = self.cast(result, dtype)
+        return result
+
     def expm(self, a: Tensor) -> Tensor:
         raise NotImplementedError("pytorch backend doesn't support expm")
         # in 2020, torch has no expm, hmmm. but that's ok,
@@ -369,6 +378,17 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend):  # type: ignore
     def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
         return torchlib.argmin(a, dim=axis)
 
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        # torch.sort returns a (values, indices) named tuple; keep only the sorted values
+        return torchlib.sort(a, dim=axis).values
+
+    def all(self, tensor: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        """
+        Corresponds to torch.all.
+        """
+        if axis is None:
+            return torchlib.all(tensor)
+        return torchlib.all(tensor, dim=axis)
+
     def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
         return torchlib.unique(a, return_counts=True)  # type: ignore
 
@@ -425,6 +445,16 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend):  # type: ignore
         v = self.convert_to_tensor(v)
         return torchlib.searchsorted(a, v, side=side)
 
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            return torchlib.where(condition)
+        return torchlib.where(condition, x, y)
+
     def reverse(self, a: Tensor) -> Tensor:
         return torchlib.flip(a, dims=(-1,))
 
@@ -706,6 +736,12 @@ class PyTorchBackend(pytorch_backend.PyTorchBackend, ExtendedBackend):  # type: ignore
 
         return wrapper
 
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return torchlib.unsqueeze(a, dim=axis)
+
     vvag = vectorized_value_and_grad
 
+    def meshgrid(self, *args: Any, **kws: Any) -> Tensor:
+        return torchlib.meshgrid(*args, **kws)
+
     optimizer = torch_optimizer
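
Note that `torch.meshgrid` defaults to 'ij' indexing, whereas NumPy, JAX, and TensorFlow default to 'xy', so cross-backend code should pass `indexing` explicitly:

    import torch

    x, y = torch.meshgrid(torch.tensor([0, 1]), torch.tensor([0, 2]), indexing="xy")
    x.shape, y.shape   # (torch.Size([2, 2]), torch.Size([2, 2])), matching np.meshgrid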

tensorcircuit/backends/tensorflow_backend.py CHANGED

@@ -75,6 +75,12 @@ class keras_optimizer:
 def _tensordot_tf(
     self: Any, a: Tensor, b: Tensor, axes: Union[int, Sequence[Sequence[int]]]
 ) -> Tensor:
+    # Use TensorFlow's dtype promotion rules by converting both operands to a common dtype
+    if a.dtype != b.dtype:
+        # Find the result dtype using TensorFlow's type promotion rules
+        common_dtype = tf.experimental.numpy.result_type(a.dtype, b.dtype)
+        a = tf.cast(a, common_dtype)
+        b = tf.cast(b, common_dtype)
     return tf.tensordot(a, b, axes)
 
 
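The motivation here is that `tf.tensordot` raises on mixed input dtypes rather than promoting them the way NumPy does. A quick check of the promotion helper:

    import tensorflow as tf

    a = tf.ones([2, 2], dtype=tf.float32)
    b = tf.ones([2, 2], dtype=tf.complex64)
    # tf.tensordot(a, b, 1) would raise; promote to the common dtype first
    common = tf.experimental.numpy.result_type(a.dtype, b.dtype)  # complex64
    tf.tensordot(tf.cast(a, common), tf.cast(b, common), 1)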
@@ -441,6 +447,12 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
     def copy(self, a: Tensor) -> Tensor:
         return tf.identity(a)
 
+    def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor:
+        result = tf.convert_to_tensor(tensor)
+        if dtype is not None:
+            result = self.cast(result, dtype)
+        return result
+
     def expm(self, a: Tensor) -> Tensor:
         return tf.linalg.expm(a)
 
@@ -524,6 +536,20 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
     def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
         return tf.reduce_max(a, axis=axis)
 
+    def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
+        return tf.reduce_all(tf.cast(a, tf.bool), axis=axis)
+
+    def where(
+        self,
+        condition: Tensor,
+        x: Optional[Tensor] = None,
+        y: Optional[Tensor] = None,
+    ) -> Tensor:
+        if x is None and y is None:
+            # Return a tuple of index tensors to be consistent with the other backends
+            return tuple(tf.unstack(tf.where(condition), axis=1))
+        return tf.where(condition, x, y)
+
     def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
         return tf.math.argmax(a, axis=axis)
 
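The unstack in the single-argument branch is what keeps the TensorFlow backend aligned with the others: bare `tf.where(condition)` returns one (n, d) tensor of coordinates, while `np.where(condition)` returns a d-tuple of index vectors. A small illustration:

    import tensorflow as tf

    cond = tf.constant([True, False, True])
    tf.where(cond)                             # shape (2, 1) coordinate tensor
    tuple(tf.unstack(tf.where(cond), axis=1))  # ([0, 2],) -- same layout as np.where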
@@ -533,6 +559,9 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
     def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
         return tf.argsort(a, axis=axis)
 
+    def sort(self, a: Tensor, axis: int = -1) -> Tensor:
+        return tf.sort(a, axis=axis)
+
     def shape_tuple(self, a: Tensor) -> Tuple[int, ...]:
         return tuple(a.shape)
 
@@ -1061,4 +1090,13 @@ class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend):
 
     vvag = vectorized_value_and_grad
 
+    def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Backend-agnostic meshgrid function.
+        """
+        return tf.meshgrid(*args, **kwargs)
+
     optimizer = keras_optimizer
+
+    def expand_dims(self, a: Tensor, axis: int) -> Tensor:
+        return tf.expand_dims(a, axis)

tensorcircuit/templates/hamiltonians.py CHANGED

@@ -17,13 +17,14 @@ def _create_empty_sparse_matrix(shape: Tuple[int, int]) -> Any:
 def heisenberg_hamiltonian(
     lattice: AbstractLattice,
     j_coupling: Union[float, List[float], Tuple[float, ...]] = 1.0,
+    interaction_scope: str = "neighbors",
 ) -> Any:
     """
     Generates the sparse matrix of the Heisenberg Hamiltonian for a given lattice.
 
     The Heisenberg Hamiltonian is defined as:
-    H = J * Σ_{<i,j>} (X_i X_j + Y_i Y_j + Z_i Z_j)
-    where the sum is over all unique nearest-neighbor pairs <i,j>.
+    H = J * Σ_{i,j} (X_i X_j + Y_i Y_j + Z_i Z_j)
+    where the sum is over a specified set of interacting pairs {i,j}.
 
     :param lattice: An instance of a class derived from AbstractLattice,
         which provides the geometric information of the system.
@@ -32,11 +33,23 @@ def heisenberg_hamiltonian(
         isotropic model (Jx=Jy=Jz) or a list/tuple of 3 floats for an
         anisotropic model (Jx, Jy, Jz). Defaults to 1.0.
     :type j_coupling: Union[float, List[float], Tuple[float, ...]], optional
+    :param interaction_scope: Defines the range of interactions.
+        - "neighbors": includes only nearest-neighbor pairs (default).
+        - "all": includes all unique pairs of sites.
+    :type interaction_scope: str, optional
     :return: The Hamiltonian as a backend-agnostic sparse matrix.
     :rtype: Any
     """
     num_sites = lattice.num_sites
-    neighbor_pairs = lattice.get_neighbor_pairs(k=1, unique=True)
+    if interaction_scope == "neighbors":
+        neighbor_pairs = lattice.get_neighbor_pairs(k=1, unique=True)
+    elif interaction_scope == "all":
+        neighbor_pairs = lattice.get_all_pairs()
+    else:
+        raise ValueError(
+            f"Invalid interaction_scope: '{interaction_scope}'. "
+            "Must be 'neighbors' or 'all'."
+        )
 
     if isinstance(j_coupling, (float, int)):
         js = [float(j_coupling)] * 3
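
A usage sketch of the new parameter follows; the lattice construction here is purely illustrative (the `ChainLattice` name and its import path are assumptions, not part of this diff):

    # Hypothetical lattice setup -- only heisenberg_hamiltonian's signature is from this diff
    from tensorcircuit.templates.lattice import ChainLattice
    from tensorcircuit.templates.hamiltonians import heisenberg_hamiltonian

    lattice = ChainLattice(4)                                         # assumed 4-site chain
    h_nn = heisenberg_hamiltonian(lattice)                            # nearest neighbors (default)
    h_all = heisenberg_hamiltonian(lattice, interaction_scope="all")  # all unique pairs
    # heisenberg_hamiltonian(lattice, interaction_scope="bogus")      # raises ValueError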