tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic; see the registry's release page for more details.

Files changed (77)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +100 -4
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +157 -98
  14. tensorcircuit/circuit.py +115 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +105 -23
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +733 -153
  22. tensorcircuit/fgs.py +254 -73
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/interfaces/jax.py +5 -3
  25. tensorcircuit/interfaces/tensortrans.py +6 -2
  26. tensorcircuit/interfaces/torch.py +14 -4
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +698 -134
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +131 -18
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +29 -17
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
  45. tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1035
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -409
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -562
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -282
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -549
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -380
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_stabilizer.py +0 -217
  74. tests/test_templates.py +0 -218
  75. tests/test_torchnn.py +0 -99
  76. tests/test_van.py +0 -102
  77. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
@@ -4,21 +4,25 @@ Quantum circuit: MPS state simulator
4
4
 
5
5
  # pylint: disable=invalid-name
6
6
 
7
- from functools import reduce
7
+ from functools import reduce, partial
8
8
  from typing import Any, List, Optional, Sequence, Tuple, Dict, Union
9
9
  from copy import copy
10
+ import logging
11
+ import types
10
12
 
11
13
  import numpy as np
12
14
  import tensornetwork as tn
13
- from tensorcircuit.quantum import QuOperator, QuVector
14
15
 
15
16
  from . import gates
16
17
  from .cons import backend, npdtype, contractor, rdtypestr, dtypestr
18
+ from .quantum import QuOperator, QuVector, extract_tensors_from_qop, _decode_basis_label
17
19
  from .mps_base import FiniteMPS
18
20
  from .abstractcircuit import AbstractCircuit
21
+ from .utils import arg_alias
19
22
 
20
23
  Gate = gates.Gate
21
24
  Tensor = Any
25
+ logger = logging.getLogger(__name__)
22
26
 
23
27
 
24
28
  def split_tensor(
@@ -77,6 +81,10 @@ class MPSCircuit(AbstractCircuit):
77
81
 
78
82
  is_mps = True
79
83
 
84
+ @partial(
85
+ arg_alias,
86
+ alias_dict={"wavefunction": ["inputs"]},
87
+ )
80
88
  def __init__(
81
89
  self,
82
90
  nqubits: int,
@@ -84,12 +92,16 @@ class MPSCircuit(AbstractCircuit):
84
92
  tensors: Optional[Sequence[Tensor]] = None,
85
93
  wavefunction: Optional[Union[QuVector, Tensor]] = None,
86
94
  split: Optional[Dict[str, Any]] = None,
95
+ dim: Optional[int] = None,
87
96
  ) -> None:
88
97
  """
89
98
  MPSCircuit object based on state simulator.
99
+ Do not use this class with d!=2 directly
90
100
 
91
101
  :param nqubits: The number of qubits in the circuit.
92
102
  :type nqubits: int
103
+ :param dim: The local Hilbert space dimension per site. Qudit is supported for 2 <= d <= 36.
104
+ :type dim: If None, the dimension of the circuit will be `2`, which is a qubit system.
93
105
  :param center_position: The center position of MPS, default to 0
94
106
  :type center_position: int, optional
95
107
  :param tensors: If not None, the initial state of the circuit is taken as ``tensors``
@@ -102,6 +114,7 @@ class MPSCircuit(AbstractCircuit):
102
114
  :param split: Split rules
103
115
  :type split: Any
104
116
  """
117
+ self._d = 2 if dim is None else dim
105
118
  self.circuit_param = {
106
119
  "nqubits": nqubits,
107
120
  "center_position": center_position,
@@ -118,8 +131,21 @@ class MPSCircuit(AbstractCircuit):
118
131
  ), "tensors and wavefunction cannot be used at input simutaneously"
119
132
  # TODO(@SUSYUSTC): find better way to address QuVector
120
133
  if isinstance(wavefunction, QuVector):
121
- wavefunction = wavefunction.eval()
122
- tensors = self.wavefunction_to_tensors(wavefunction, split=self.split)
134
+ try:
135
+ nodes, is_mps, _ = extract_tensors_from_qop(wavefunction)
136
+ if not is_mps:
137
+ raise ValueError("wavefunction is not a valid MPS")
138
+ tensors = [node.tensor for node in nodes]
139
+ except ValueError as e:
140
+ logger.warning(repr(e))
141
+ wavefunction = wavefunction.eval()
142
+ tensors = self.wavefunction_to_tensors(
143
+ wavefunction, split=self.split
144
+ )
145
+ else: # full wavefunction
146
+ tensors = self.wavefunction_to_tensors(
147
+ wavefunction, dim_phys=self._d, split=self.split
148
+ )
123
149
  assert len(tensors) == nqubits
124
150
  self._mps = FiniteMPS(tensors, canonicalize=False)
125
151
  self._mps.center_position = 0
@@ -133,8 +159,13 @@ class MPSCircuit(AbstractCircuit):
133
159
  self._mps = FiniteMPS(tensors, canonicalize=True, center_position=0)
134
160
  else:
135
161
  tensors = [
136
- np.array([1.0, 0.0], dtype=npdtype)[None, :, None]
137
- for i in range(nqubits)
162
+ np.concatenate(
163
+ [
164
+ np.array([1.0], dtype=npdtype),
165
+ np.zeros((self._d - 1,), dtype=npdtype),
166
+ ]
167
+ )[None, :, None]
168
+ for _ in range(nqubits)
138
169
  ]
139
170
  self._mps = FiniteMPS(tensors, canonicalize=False)
140
171
  if center_position is not None:
@@ -260,6 +291,17 @@ class MPSCircuit(AbstractCircuit):
260
291
  index_to: int,
261
292
  split: Optional[Dict[str, Any]] = None,
262
293
  ) -> None:
294
+ """
295
+ Apply a series of SWAP gates to move a qubit from ``index_from`` to ``index_to``.
296
+
297
+ :param index_from: The starting index of the qubit.
298
+ :type index_from: int
299
+ :param index_to: The destination index of the qubit.
300
+ :type index_to: int
301
+ :param split: Truncation options for the SWAP gates. Defaults to None.
302
+ consistent with the split option of the class.
303
+ :type split: Optional[Dict[str, Any]], optional
304
+ """
263
305
  if split is None:
264
306
  split = self.split
265
307
  self.position(index_from)
@@ -341,42 +383,51 @@ class MPSCircuit(AbstractCircuit):
341
383
  # b
342
384
 
343
385
  # index must be ordered
344
- assert np.all(np.diff(index) > 0)
345
- index_left = np.min(index)
386
+ if len(index) == 0:
387
+ raise ValueError("`index` must contain at least one site.")
388
+ if not all(index[i] < index[i + 1] for i in range(len(index) - 1)):
389
+ raise ValueError("`index` must be strictly increasing.")
390
+
391
+ index_left = int(np.min(index))
346
392
  if isinstance(gate, tn.Node):
347
393
  gate = backend.copy(gate.tensor)
348
- index = np.array(index) - index_left
394
+
349
395
  nindex = len(index)
396
+ in_dims = tuple(backend.shape_tuple(gate))[:nindex]
397
+ dim = int(in_dims[0])
398
+ dim_phys_mpo = dim * dim
399
+ gate = backend.reshape(gate, (dim,) * nindex + (dim,) * nindex)
350
400
  # transform gate from (in1, in2, ..., out1, out2 ...) to
351
401
  # (in1, out1, in2, out2, ...)
352
- order = tuple(np.arange(2 * nindex).reshape((2, nindex)).T.flatten())
353
- shape = (4,) * nindex
354
- gate = backend.reshape(backend.transpose(gate, order), shape)
355
- argsort = np.argsort(index)
402
+ order = tuple(np.arange(2 * nindex).reshape(2, nindex).T.flatten().tolist())
403
+ gate = backend.transpose(gate, order)
356
404
  # reorder the gate according to the site positions
357
- gate = backend.transpose(gate, tuple(argsort))
358
- index = index[argsort] # type: ignore
405
+ gate = backend.reshape(gate, (dim_phys_mpo,) * nindex)
359
406
  # split the gate into tensors assuming they are adjacent
360
- main_tensors = cls.wavefunction_to_tensors(gate, dim_phys=4, norm=False)
407
+ main_tensors = cls.wavefunction_to_tensors(
408
+ gate, dim_phys=dim_phys_mpo, norm=False
409
+ )
361
410
  # each tensor is in shape of (i, a, b, j)
362
- tensors = []
363
- previous_i = None
364
- for i, main_tensor in zip(index, main_tensors):
365
- # insert identites in the middle
411
+ tensors: list[Tensor] = []
412
+ previous_i: Optional[int] = None
413
+ index_arr = np.array(index, dtype=int) - index_left
414
+
415
+ for i, main_tensor in zip(index_arr, main_tensors):
366
416
  if previous_i is not None:
367
- for _ in range(previous_i + 1, i):
368
- bond_dim = tensors[-1].shape[-1]
369
- I = (
370
- np.eye(bond_dim * 2)
371
- .reshape((bond_dim, 2, bond_dim, 2))
372
- .transpose((0, 1, 3, 2))
373
- .astype(dtypestr)
417
+ for _gap_site in range(int(previous_i) + 1, int(i)):
418
+ bond_dim = int(backend.shape_tuple(tensors[-1])[-1])
419
+ eye2d = backend.eye(
420
+ bond_dim * dim, dtype=backend.dtype(tensors[-1])
374
421
  )
375
- tensors.append(backend.convert_to_tensor(I))
376
- nleft, _, nright = main_tensor.shape
377
- tensor = backend.reshape(main_tensor, (nleft, 2, 2, nright))
422
+ I4 = backend.reshape(eye2d, (bond_dim, dim, bond_dim, dim))
423
+ I4 = backend.transpose(I4, (0, 1, 3, 2))
424
+ tensors.append(I4)
425
+
426
+ nleft, _, nright = backend.shape_tuple(main_tensor)
427
+ tensor = backend.reshape(main_tensor, (int(nleft), dim, dim, int(nright)))
378
428
  tensors.append(tensor)
379
- previous_i = i
429
+ previous_i = int(i)
430
+
380
431
  return tensors, index_left
381
432
 
382
433
  @classmethod
@@ -419,15 +470,15 @@ class MPSCircuit(AbstractCircuit):
419
470
  """
420
471
  if split is None:
421
472
  split = {}
422
- ni = tensor_left.shape[0]
423
- nk = tensor_right.shape[-1]
473
+ ni, di = tensor_left.shape[0], tensor_right.shape[1]
474
+ nk, dk = tensor_right.shape[-1], tensor_right.shape[-2]
424
475
  T = backend.einsum("iaj,jbk->iabk", tensor_left, tensor_right)
425
- T = backend.reshape(T, (ni * 2, nk * 2))
476
+ T = backend.reshape(T, (ni * di, nk * dk))
426
477
  new_tensor_left, new_tensor_right = split_tensor(
427
478
  T, center_left=center_left, split=split
428
479
  )
429
- new_tensor_left = backend.reshape(new_tensor_left, (ni, 2, -1))
430
- new_tensor_right = backend.reshape(new_tensor_right, (-1, 2, nk))
480
+ new_tensor_left = backend.reshape(new_tensor_left, (ni, di, -1))
481
+ new_tensor_right = backend.reshape(new_tensor_right, (-1, dk, nk))
431
482
  return new_tensor_left, new_tensor_right
432
483
 
433
484
  def reduce_dimension(
@@ -437,7 +488,15 @@ class MPSCircuit(AbstractCircuit):
437
488
  split: Optional[Dict[str, Any]] = None,
438
489
  ) -> None:
439
490
  """
440
- Reduce the bond dimension between two adjacent sites by SVD
491
+ Reduce the bond dimension between two adjacent sites using SVD.
492
+
493
+ :param index_left: The index of the left tensor of the bond to be truncated.
494
+ :type index_left: int
495
+ :param center_left: If True, the orthogonality center will be on the left tensor after truncation.
496
+ Otherwise, it will be on the right tensor. Defaults to True.
497
+ :type center_left: bool, optional
498
+ :param split: Truncation options for the SVD. Defaults to None.
499
+ :type split: Optional[Dict[str, Any]], optional
441
500
  """
442
501
  if split is None:
443
502
  split = self.split
@@ -463,7 +522,22 @@ class MPSCircuit(AbstractCircuit):
463
522
  split: Optional[Dict[str, Any]] = None,
464
523
  ) -> None:
465
524
  """
466
- Apply a MPO to the MPS
525
+ Apply a Matrix Product Operator (MPO) to the MPS.
526
+
527
+ The application involves three main steps:
528
+ 1. Contract the MPO tensors with the corresponding MPS tensors.
529
+ 2. Canonicalize the resulting tensors by moving the orthogonality center.
530
+ 3. Truncate the bond dimensions to control complexity.
531
+
532
+ :param tensors: A sequence of tensors representing the MPO.
533
+ :type tensors: Sequence[Tensor]
534
+ :param index_left: The starting index on the MPS where the MPO is applied.
535
+ :type index_left: int
536
+ :param center_left: If True, the final orthogonality center will be at the left end of the MPO.
537
+ Otherwise, it will be at the right end. Defaults to True.
538
+ :type center_left: bool, optional
539
+ :param split: Truncation options for bond dimension reduction. Defaults to None.
540
+ :type split: Optional[Dict[str, Any]], optional
467
541
  """
468
542
  # step 1:
469
543
  # contract tensor
@@ -498,10 +572,11 @@ class MPSCircuit(AbstractCircuit):
498
572
  for i, idx in zip(i_list, idx_list):
499
573
  O = tensors[i]
500
574
  T = self._mps.tensors[idx]
501
- ni, _, _, nj = O.shape
575
+ ni, d_in, _, nj = O.shape
502
576
  nk, _, nl = T.shape
503
577
  OT = backend.einsum("iabj,kbl->ikajl", O, T)
504
- OT = backend.reshape(OT, (ni * nk, 2, nj * nl))
578
+ OT = backend.reshape(OT, (ni * nk, d_in, nj * nl))
579
+
505
580
  self._mps.tensors[idx] = OT
506
581
 
507
582
  # canonicalize
@@ -521,10 +596,17 @@ class MPSCircuit(AbstractCircuit):
521
596
  *index: int,
522
597
  split: Optional[Dict[str, Any]] = None,
523
598
  ) -> None:
524
- # TODO(@SUSYUSTC): jax autograd is wrong on this function
525
599
  """
526
- Apply a n-qubit gate by transforming the gate to MPO
600
+ Apply an n-qubit gate to the MPS by converting it to an MPO.
601
+
602
+ :param gate: The n-qubit gate to apply.
603
+ :type gate: Gate
604
+ :param index: The indices of the qubits to apply the gate to.
605
+ :type index: int
606
+ :param split: Truncation options for the MPO application. Defaults to None.
607
+ :type split: Optional[Dict[str, Any]], optional
527
608
  """
609
+ # TODO(@SUSYUSTC): jax autograd is wrong on this function
528
610
  ordered = np.all(np.diff(index) > 0)
529
611
  if not ordered:
530
612
  order = np.argsort(index)
@@ -601,8 +683,7 @@ class MPSCircuit(AbstractCircuit):
601
683
  :type keep: int, optional
602
684
  """
603
685
  # normalization not guaranteed
604
- assert keep in [0, 1]
605
- gate = backend.zeros((2, 2), dtype=dtypestr)
686
+ gate = backend.zeros((self._d, self._d), dtype=dtypestr)
606
687
  gate = backend.scatter(
607
688
  gate,
608
689
  backend.convert_to_tensor([[keep, keep]]),
@@ -633,7 +714,7 @@ class MPSCircuit(AbstractCircuit):
633
714
  def wavefunction_to_tensors(
634
715
  cls,
635
716
  wavefunction: Tensor,
636
- dim_phys: int = 2,
717
+ dim_phys: Optional[int] = None,
637
718
  norm: bool = True,
638
719
  split: Optional[Dict[str, Any]] = None,
639
720
  ) -> List[Tensor]:
@@ -651,6 +732,7 @@ class MPSCircuit(AbstractCircuit):
651
732
  :return: The tensors
652
733
  :rtype: List[Tensor]
653
734
  """
735
+ dim_phys = dim_phys if dim_phys is not None else 2
654
736
  if split is None:
655
737
  split = {}
656
738
  wavefunction = backend.reshape(wavefunction, (-1, 1))
@@ -709,10 +791,16 @@ class MPSCircuit(AbstractCircuit):
709
791
  for key in vars(self):
710
792
  if key == "_mps":
711
793
  continue
712
- if backend.is_tensor(info[key]):
713
- copied_value = backend.copy(info[key])
794
+ val = info[key]
795
+ if backend.is_tensor(val):
796
+ copied_value = backend.copy(val)
797
+ elif isinstance(val, types.ModuleType):
798
+ copied_value = val
714
799
  else:
715
- copied_value = copy(info[key])
800
+ try:
801
+ copied_value = copy(val)
802
+ except TypeError:
803
+ copied_value = val
716
804
  setattr(result, key, copied_value)
717
805
  return result
718
806
 
@@ -756,7 +844,8 @@ class MPSCircuit(AbstractCircuit):
756
844
 
757
845
  def amplitude(self, l: str) -> Tensor:
758
846
  assert len(l) == self._nqubits
759
- tensors = [self._mps.tensors[i][:, int(s), :] for i, s in enumerate(l)]
847
+ idx_list = _decode_basis_label(l, n=self._nqubits, dim=self._d)
848
+ tensors = [self._mps.tensors[i][:, idx, :] for i, idx in enumerate(idx_list)]
760
849
  return reduce(backend.matmul, tensors)[0, 0]
761
850
 
762
851
  def proj_with_mps(self, other: "MPSCircuit", conj: bool = True) -> Tensor:
@@ -814,6 +903,7 @@ class MPSCircuit(AbstractCircuit):
814
903
 
815
904
  mps = self.__class__(
816
905
  nqubits,
906
+ dim=self._d,
817
907
  tensors=tensors,
818
908
  center_position=center_position,
819
909
  split=self.split.copy(),
@@ -941,36 +1031,35 @@ class MPSCircuit(AbstractCircuit):
941
1031
  # set the center to the left side, then gradually move to the right and do measurement at sites
942
1032
  """
943
1033
  mps = self.copy()
944
- up = backend.convert_to_tensor(np.array([1, 0]).astype(dtypestr))
945
- down = backend.convert_to_tensor(np.array([0, 1]).astype(dtypestr))
946
1034
 
947
1035
  p = 1.0
948
1036
  p = backend.convert_to_tensor(p)
949
1037
  p = backend.cast(p, dtype=rdtypestr)
950
- sample = []
1038
+ sample: Tensor = []
951
1039
  for k, site in enumerate(index):
952
1040
  mps.position(site)
953
- # do measurement
954
1041
  tensor = mps._mps.tensors[site]
955
1042
  ps = backend.real(
956
1043
  backend.einsum("iaj,iaj->a", tensor, backend.conj(tensor))
957
1044
  )
958
1045
  ps /= backend.sum(ps)
959
- pu = ps[0]
960
1046
  if status is None:
961
- r = backend.implicit_randu()[0]
1047
+ outcome = backend.implicit_randc(
1048
+ self._d, shape=1, p=backend.cast(ps, rdtypestr)
1049
+ )[0]
962
1050
  else:
963
- r = status[k]
964
- r = backend.real(backend.cast(r, dtypestr))
965
- eps = 0.31415926 * 1e-12
966
- sign = backend.sign(r - pu + eps) / 2 + 0.5 # in case status is exactly 0.5
967
- sign = backend.convert_to_tensor(sign)
968
- sign = backend.cast(sign, dtype=rdtypestr)
969
- sign_complex = backend.cast(sign, dtypestr)
970
- sample.append(sign_complex)
971
- p = p * (pu * (-1) ** sign + sign)
972
- m = (1 - sign_complex) * up + sign_complex * down
1051
+ one_r = backend.cast(backend.convert_to_tensor(1.0), rdtypestr)
1052
+ st = backend.cast(status[k : k + 1], rdtypestr)
1053
+ ind = backend.probability_sample(
1054
+ shots=1, p=backend.cast(ps, rdtypestr), status=one_r - st
1055
+ )
1056
+ outcome = backend.cast(ind[0], "int32")
1057
+
1058
+ p = p * ps[outcome]
1059
+ basis = backend.convert_to_tensor(np.eye(self._d).astype(dtypestr))
1060
+ m = basis[outcome]
973
1061
  mps._mps.tensors[site] = backend.einsum("iaj,a->ij", tensor, m)[:, None, :]
1062
+ sample.append(outcome)
974
1063
  sample = backend.stack(sample)
975
1064
  sample = backend.real(sample)
976
1065
  if with_prob: