tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic; see the package registry page for more details.

Files changed (76)
  1. tensorcircuit/__init__.py +18 -2
  2. tensorcircuit/about.py +46 -0
  3. tensorcircuit/abstractcircuit.py +4 -0
  4. tensorcircuit/analogcircuit.py +413 -0
  5. tensorcircuit/applications/layers.py +1 -1
  6. tensorcircuit/applications/van.py +1 -1
  7. tensorcircuit/backends/abstract_backend.py +320 -7
  8. tensorcircuit/backends/cupy_backend.py +3 -1
  9. tensorcircuit/backends/jax_backend.py +102 -4
  10. tensorcircuit/backends/jax_ops.py +110 -1
  11. tensorcircuit/backends/numpy_backend.py +49 -3
  12. tensorcircuit/backends/pytorch_backend.py +92 -3
  13. tensorcircuit/backends/tensorflow_backend.py +102 -3
  14. tensorcircuit/basecircuit.py +157 -98
  15. tensorcircuit/circuit.py +115 -57
  16. tensorcircuit/cloud/local.py +1 -1
  17. tensorcircuit/cloud/quafu_provider.py +1 -1
  18. tensorcircuit/cloud/tencent.py +1 -1
  19. tensorcircuit/compiler/simple_compiler.py +2 -2
  20. tensorcircuit/cons.py +142 -21
  21. tensorcircuit/densitymatrix.py +43 -14
  22. tensorcircuit/experimental.py +387 -129
  23. tensorcircuit/fgs.py +282 -81
  24. tensorcircuit/gates.py +66 -22
  25. tensorcircuit/interfaces/__init__.py +1 -3
  26. tensorcircuit/interfaces/jax.py +189 -0
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +868 -152
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +147 -20
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +479 -0
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
  45. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1031
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -365
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -429
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -277
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -526
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -347
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_templates.py +0 -218
  74. tests/test_torchnn.py +0 -99
  75. tests/test_van.py +0 -102
  76. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
tensorcircuit/circuit.py CHANGED
@@ -1,5 +1,7 @@
1
1
  """
2
- Quantum circuit: the state simulator
2
+ Quantum circuit: the state simulator.
3
+ Supports qubit (dim=2) and qudit (3 <= dim <= 36) systems.
4
+ For string-encoded samples/counts, digits use 0-9A-Z where A=10, ..., Z=35.
3
5
  """
4
6
 
5
7
  # pylint: disable=invalid-name
@@ -13,8 +15,8 @@ import tensornetwork as tn
13
15
 
14
16
  from . import gates
15
17
  from . import channels
16
- from .cons import backend, contractor, dtypestr, npdtype
17
- from .quantum import QuOperator, identity
18
+ from .cons import backend, contractor, dtypestr, npdtype, _ALPHABET
19
+ from .quantum import QuOperator, identity, _infer_num_sites, onehot_d_tensor
18
20
  from .simplify import _full_light_cone_cancel
19
21
  from .basecircuit import BaseCircuit
20
22
 
@@ -23,7 +25,7 @@ Tensor = Any
23
25
 
24
26
 
25
27
  class Circuit(BaseCircuit):
26
- """
28
+ r"""
27
29
  ``Circuit`` class.
28
30
  Simple usage demo below.
29
31
 
@@ -45,14 +47,18 @@ class Circuit(BaseCircuit):
45
47
  inputs: Optional[Tensor] = None,
46
48
  mps_inputs: Optional[QuOperator] = None,
47
49
  split: Optional[Dict[str, Any]] = None,
50
+ dim: Optional[int] = None,
48
51
  ) -> None:
49
- """
52
+ r"""
50
53
  Circuit object based on state simulator.
54
+ Do not use this class with d!=2 directly, use tc.QuditCircuit instead for qudit systems.
51
55
 
52
56
  :param nqubits: The number of qubits in the circuit.
53
57
  :type nqubits: int
58
+ :param dim: The local Hilbert space dimension per site. Qudit is supported for 2 <= d <= 36.
59
+ :type dim: If None, the dimension of the circuit will be `2`, which is a qubit system.
54
60
  :param inputs: If not None, the initial state of the circuit is taken as ``inputs``
55
- instead of :math:`\\vert 0\\rangle^n` qubits, defaults to None.
61
+ instead of :math:`\vert 0 \rangle^n` qubits, defaults to None.
56
62
  :type inputs: Optional[Tensor], optional
57
63
  :param mps_inputs: QuVector for a MPS like initial wavefunction.
58
64
  :type mps_inputs: Optional[QuOperator]
@@ -60,6 +66,7 @@ class Circuit(BaseCircuit):
60
66
  ``max_singular_values`` and ``max_truncation_err``.
61
67
  :type split: Optional[Dict[str, Any]]
62
68
  """
69
+ self._d = 2 if dim is None else dim
63
70
  self.inputs = inputs
64
71
  self.mps_inputs = mps_inputs
65
72
  self.split = split
@@ -70,18 +77,19 @@ class Circuit(BaseCircuit):
70
77
  "inputs": inputs,
71
78
  "mps_inputs": mps_inputs,
72
79
  "split": split,
80
+ "dim": dim,
73
81
  }
74
82
  if (inputs is None) and (mps_inputs is None):
75
- nodes = self.all_zero_nodes(nqubits)
83
+ nodes = self.all_zero_nodes(nqubits, dim=self._d)
76
84
  self._front = [n.get_edge(0) for n in nodes]
77
85
  elif inputs is not None: # provide input function
78
86
  inputs = backend.convert_to_tensor(inputs)
79
87
  inputs = backend.cast(inputs, dtype=dtypestr)
80
88
  inputs = backend.reshape(inputs, [-1])
81
89
  N = inputs.shape[0]
82
- n = int(np.log(N) / np.log(2))
90
+ n = _infer_num_sites(N, dim=self._d)
83
91
  assert n == nqubits or n == 2 * nqubits
84
- inputs = backend.reshape(inputs, [2 for _ in range(n)])
92
+ inputs = backend.reshape(inputs, [self._d for _ in range(n)])
85
93
  inputs = Gate(inputs)
86
94
  nodes = [inputs]
87
95
  self._front = [inputs.get_edge(i) for i in range(n)]
@@ -178,27 +186,14 @@ class Circuit(BaseCircuit):
178
186
 
179
187
  :param index: The index of qubit that the Z direction postselection applied on.
180
188
  :type index: int
181
- :param keep: 0 for spin up, 1 for spin down, defaults to be 0.
189
+ :param keep: the post-selected digit in {0, ..., d-1}, defaults to be 0.
182
190
  :type keep: int, optional
183
191
  """
184
192
  # normalization not guaranteed
185
- # assert keep in [0, 1]
186
- if keep < 0.5:
187
- gate = np.array(
188
- [
189
- [1.0],
190
- [0.0],
191
- ],
192
- dtype=npdtype,
193
- )
194
- else:
195
- gate = np.array(
196
- [
197
- [0.0],
198
- [1.0],
199
- ],
200
- dtype=npdtype,
201
- )
193
+ gate = np.array(
194
+ [[0.0] if _idx != keep else [1.0] for _idx in range(self._d)],
195
+ dtype=npdtype,
196
+ )
202
197
 
203
198
  mg1 = tn.Node(gate)
204
199
  mg2 = tn.Node(gate)
@@ -231,6 +226,25 @@ class Circuit(BaseCircuit):
231
226
  pz: float,
232
227
  status: Optional[float] = None,
233
228
  ) -> float:
229
+ """
230
+ Apply a depolarizing channel to the circuit in a Monte Carlo way.
231
+ For each call, one of the Pauli gates (X, Y, Z) or an Identity gate is applied to the qubit
232
+ at the given index based on the probabilities `px`, `py`, and `pz`.
233
+
234
+ :param index: The index of the qubit to apply the depolarizing channel on.
235
+ :type index: int
236
+ :param px: The probability of applying an X gate.
237
+ :type px: float
238
+ :param py: The probability of applying a Y gate.
239
+ :type py: float
240
+ :param pz: The probability of applying a Z gate.
241
+ :type pz: float
242
+ :param status: A random number between 0 and 1 to determine which gate to apply. If None,
243
+ a random number is generated automatically. Defaults to None.
244
+ :type status: Optional[float], optional
245
+ :return: Returns 0.0. The function modifies the circuit in place.
246
+ :rtype: float
247
+ """
234
248
  if status is None:
235
249
  status = backend.implicit_randu()[0]
236
250
  g = backend.cond(
@@ -323,6 +337,35 @@ class Circuit(BaseCircuit):
323
337
  status: Optional[float] = None,
324
338
  name: Optional[str] = None,
325
339
  ) -> Tensor:
340
+ """
341
+ Apply a unitary Kraus channel to the circuit using a Monte Carlo approach. This method is functionally
342
+ similar to `unitary_kraus` but uses `backend.switch` for selecting the Kraus operator, which can have
343
+ different performance characteristics on some backends.
344
+
345
+ A random Kraus operator from the provided list is applied to the circuit based on the given probabilities.
346
+ This method is jittable and suitable for simulating noisy quantum circuits where the noise is represented
347
+ by unitary Kraus operators.
348
+
349
+ .. warning::
350
+ This method may have issues with `vmap` due to potential concurrent access locks, potentially related with
351
+ `backend.switch`. `unitary_kraus` is generally recommended.
352
+
353
+ :param kraus: A sequence of `Gate` objects representing the unitary Kraus operators.
354
+ :type kraus: Sequence[Gate]
355
+ :param index: The qubit indices on which to apply the Kraus channel.
356
+ :type index: int
357
+ :param prob: A sequence of probabilities corresponding to each Kraus operator. If None, probabilities
358
+ are derived from the operators themselves. Defaults to None.
359
+ :type prob: Optional[Sequence[float]], optional
360
+ :param status: A random number between 0 and 1 to determine which Kraus operator to apply. If None,
361
+ a random number is generated automatically. Defaults to None.
362
+ :type status: Optional[float], optional
363
+ :param name: An optional name for the operation. Defaults to None.
364
+ :type name: Optional[str], optional
365
+ :return: A tensor indicating which Kraus operator was applied.
366
+ :rtype: Tensor
367
+ """
368
+
326
369
  # dont use, has issue conflicting with vmap, concurrent access lock emerged
327
370
  # potential issue raised from switch
328
371
  # general impl from Monte Carlo trajectory depolarizing above
@@ -431,8 +474,8 @@ class Circuit(BaseCircuit):
431
474
  if get_gate_from_index is None:
432
475
  raise ValueError("no `get_gate_from_index` implementation is provided")
433
476
  g = get_gate_from_index(r, kraus)
434
- g = backend.reshape(g, [2 for _ in range(sites * 2)])
435
- self.any(*index, unitary=g, name=name) # type: ignore
477
+ g = backend.reshape(g, [self._d for _ in range(sites * 2)])
478
+ self.any(*index, unitary=g, name=name, dim=self._d) # type: ignore
436
479
  return r
437
480
 
438
481
  def _general_kraus_tf(
@@ -557,9 +600,13 @@ class Circuit(BaseCircuit):
557
600
  for w, k in zip(prob, kraus_tensor)
558
601
  ]
559
602
  pick = self.unitary_kraus(
560
- new_kraus, *index, prob=prob, status=status, name=name
603
+ new_kraus,
604
+ *index,
605
+ prob=prob,
606
+ status=status,
607
+ name=name,
561
608
  )
562
- if with_prob is False:
609
+ if not with_prob:
563
610
  return pick
564
611
  else:
565
612
  return pick, prob
@@ -590,7 +637,11 @@ class Circuit(BaseCircuit):
590
637
  :type status: Optional[float], optional
591
638
  """
592
639
  return self._general_kraus_2(
593
- kraus, *index, status=status, with_prob=with_prob, name=name
640
+ kraus,
641
+ *index,
642
+ status=status,
643
+ with_prob=with_prob,
644
+ name=name,
594
645
  )
595
646
 
596
647
  apply_general_kraus = general_kraus
@@ -632,7 +683,7 @@ class Circuit(BaseCircuit):
632
683
  Apply %s quantum channel on the circuit.
633
684
  See :py:meth:`tensorcircuit.channels.%schannel`
634
685
 
635
- :param index: Qubit number that the gate applies on.
686
+ :param index: Site index that the gate applies on.
636
687
  :type index: int.
637
688
  :param status: uniform external random number between 0 and 1
638
689
  :type status: Tensor
@@ -689,8 +740,8 @@ class Circuit(BaseCircuit):
689
740
  :return: ``QuOperator`` object for the circuit unitary (open indices for the input state)
690
741
  :rtype: QuOperator
691
742
  """
692
- mps = identity([2 for _ in range(self._nqubits)])
693
- c = Circuit(self._nqubits)
743
+ mps = identity([self._d for _ in range(self._nqubits)])
744
+ c = Circuit(self._nqubits, dim=self._d)
694
745
  ns, es = self._copy()
695
746
  c._nodes = ns
696
747
  c._front = es
@@ -710,8 +761,8 @@ class Circuit(BaseCircuit):
710
761
  :return: The circuit unitary matrix
711
762
  :rtype: Tensor
712
763
  """
713
- mps = identity([2 for _ in range(self._nqubits)])
714
- c = Circuit(self._nqubits)
764
+ mps = identity([self._d for _ in range(self._nqubits)])
765
+ c = Circuit(self._nqubits, dim=self._d)
715
766
  ns, es = self._copy()
716
767
  c._nodes = ns
717
768
  c._front = es
@@ -724,6 +775,9 @@ class Circuit(BaseCircuit):
724
775
  """
725
776
  Take measurement on the given quantum lines by ``index``.
726
777
 
778
+ Return format:
779
+ - For d <= 36, the sample is a base-d string using 0-9A-Z (A=10,...).
780
+
727
781
  :Example:
728
782
 
729
783
  >>> c = tc.Circuit(3)
@@ -752,10 +806,7 @@ class Circuit(BaseCircuit):
752
806
  if i != j:
753
807
  e ^ edge2[i]
754
808
  for i in range(len(sample)):
755
- if sample[i] == "0":
756
- m = np.array([1, 0], dtype=npdtype)
757
- else:
758
- m = np.array([0, 1], dtype=npdtype)
809
+ m = onehot_d_tensor(sample[i], d=self._d)
759
810
  nodes1.append(tn.Node(m))
760
811
  nodes1[-1].get_edge(0) ^ edge1[index[i]]
761
812
  nodes2.append(tn.Node(m))
@@ -766,15 +817,13 @@ class Circuit(BaseCircuit):
766
817
  / p
767
818
  * contractor(nodes1, output_edge_order=[edge1[j], edge2[j]]).tensor
768
819
  )
769
- pu = rho[0, 0]
770
- r = backend.random_uniform([])
771
- r = backend.real(backend.cast(r, dtypestr))
772
- if r < backend.real(pu):
773
- sample += "0"
774
- p = p * pu
775
- else:
776
- sample += "1"
777
- p = p * (1 - pu)
820
+ probs = backend.real(backend.diagonal(rho))
821
+ probs /= backend.sum(probs)
822
+ outcome = backend.implicit_randc(self._d, shape=1, p=probs)
823
+
824
+ sample += _ALPHABET[outcome]
825
+ p *= float(probs[outcome])
826
+
778
827
  if with_prob:
779
828
  return sample, p
780
829
  else:
@@ -794,6 +843,10 @@ class Circuit(BaseCircuit):
794
843
  ) -> Tensor:
795
844
  """
796
845
  Compute the expectation of corresponding operators.
846
+ For qudit (d > 2),
847
+ ensure that operator tensor shapes are consistent with d (each site contributes two axes of size d).
848
+
849
+ Noise shorthand (via noise_conf) is qubit-only; for d>2, use explicit operators.
797
850
 
798
851
  :Example:
799
852
 
@@ -829,14 +882,12 @@ class Circuit(BaseCircuit):
829
882
  :param nmc: repetition time for Monte Carlo sampling for noisfy calculation, defaults to 1000
830
883
  :type nmc: int, optional
831
884
  :param status: external randomness given by tensor uniformly from [0, 1], defaults to None,
832
- used for noisfy circuit sampling
885
+ used for noisy circuit sampling
833
886
  :type status: Optional[Tensor], optional
834
887
  :raises ValueError: "Cannot measure two operators in one index"
835
888
  :return: Tensor with one element
836
889
  :rtype: Tensor
837
890
  """
838
- from .noisemodel import expectation_noisfy
839
-
840
891
  if noise_conf is None:
841
892
  # if not reuse:
842
893
  # nodes1, edge1 = self._copy()
@@ -851,6 +902,8 @@ class Circuit(BaseCircuit):
851
902
  nodes1 = _full_light_cone_cancel(nodes1)
852
903
  return contractor(nodes1).tensor
853
904
  else:
905
+ from .noisemodel import expectation_noisfy
906
+
854
907
  return expectation_noisfy(
855
908
  self,
856
909
  *ops,
@@ -871,9 +924,11 @@ def expectation(
871
924
  bra: Optional[Tensor] = None,
872
925
  conj: bool = True,
873
926
  normalization: bool = False,
927
+ dim: Optional[int] = None,
874
928
  ) -> Tensor:
875
929
  """
876
930
  Compute :math:`\\langle bra\\vert ops \\vert ket\\rangle`.
931
+ For qudit systems (d>2), ops must be reshaped with per-site axes of length d.
877
932
 
878
933
  Example 1 (:math:`bra` is same as :math:`ket`)
879
934
 
@@ -918,6 +973,8 @@ def expectation(
918
973
  :type ket: Tensor
919
974
  :param bra: :math:`bra`, defaults to None, which is the same as ``ket``.
920
975
  :type bra: Optional[Tensor], optional
976
+ :param dim: dimension of the circuit (defaults to 2)
977
+ :type dim: int, optional
921
978
  :param conj: :math:`bra` changes to the adjoint matrix of :math:`bra`, defaults to True.
922
979
  :type conj: bool, optional
923
980
  :param normalization: Normalize the :math:`ket` and :math:`bra`, defaults to False.
@@ -926,6 +983,7 @@ def expectation(
926
983
  :return: The result of :math:`\\langle bra\\vert ops \\vert ket\\rangle`.
927
984
  :rtype: Tensor
928
985
  """
986
+ dim = 2 if dim is None else dim
929
987
  if bra is None:
930
988
  bra = ket
931
989
  if isinstance(ket, QuOperator):
@@ -939,7 +997,7 @@ def expectation(
939
997
  for op, index in ops:
940
998
  if not isinstance(op, tn.Node):
941
999
  # op is only a matrix
942
- op = backend.reshape2(op)
1000
+ op = backend.reshaped(op, dim)
943
1001
  op = gates.Gate(op)
944
1002
  if isinstance(index, int):
945
1003
  index = [index]
@@ -963,8 +1021,8 @@ def expectation(
963
1021
  if conj is True:
964
1022
  bra = backend.conj(bra)
965
1023
  ket = backend.reshape(ket, [-1])
966
- ket = backend.reshape2(ket)
967
- bra = backend.reshape2(bra)
1024
+ ket = backend.reshaped(ket, dim)
1025
+ bra = backend.reshaped(bra, dim)
968
1026
  n = len(backend.shape_tuple(ket))
969
1027
  ket = Gate(ket)
970
1028
  bra = Gate(bra)
@@ -976,7 +1034,7 @@ def expectation(
976
1034
  for op, index in ops:
977
1035
  if not isinstance(op, tn.Node):
978
1036
  # op is only a matrix
979
- op = backend.reshape2(op)
1037
+ op = backend.reshaped(op, dim)
980
1038
  op = gates.Gate(op)
981
1039
  if isinstance(index, int):
982
1040
  index = [index]
@@ -36,7 +36,7 @@ def submit_task(
36
36
  shots: Union[int, Sequence[int]] = 1024,
37
37
  version: str = "1",
38
38
  circuit: Optional[Union[AbstractCircuit, Sequence[AbstractCircuit]]] = None,
39
- **kws: Any
39
+ **kws: Any,
40
40
  ) -> List[Task]:
41
41
  def _circuit2result(c: AbstractCircuit) -> Dict[str, Any]:
42
42
  if device.name in ["testing", "default"]:
@@ -30,7 +30,7 @@ def submit_task(
30
30
  circuit: Optional[Union[AbstractCircuit, Sequence[AbstractCircuit]]] = None,
31
31
  source: Optional[Union[str, Sequence[str]]] = None,
32
32
  compile: bool = True,
33
- **kws: Any
33
+ **kws: Any,
34
34
  ) -> Task:
35
35
  if source is None:
36
36
 
@@ -133,7 +133,7 @@ def submit_task(
133
133
  enable_qos_gate_decomposition: bool = True,
134
134
  enable_qos_initial_mapping: bool = False,
135
135
  qos_dry_run: bool = False,
136
- **kws: Any
136
+ **kws: Any,
137
137
  ) -> List[Task]:
138
138
  """
139
139
  Submit task via tencent provider, we suggest to enable one of the compiling functionality:
@@ -109,7 +109,7 @@ def prune(
109
109
  circuit: Union[AbstractCircuit, List[Dict[str, Any]]],
110
110
  rtol: float = 1e-3,
111
111
  atol: float = 1e-3,
112
- **kws: Any
112
+ **kws: Any,
113
113
  ) -> Any:
114
114
  if isinstance(circuit, list):
115
115
  qir = circuit
@@ -251,7 +251,7 @@ def _merge(
251
251
  def merge(
252
252
  circuit: Union[AbstractCircuit, List[Dict[str, Any]]],
253
253
  rules: Optional[Dict[Tuple[str, ...], str]] = None,
254
- **kws: Any
254
+ **kws: Any,
255
255
  ) -> Any:
256
256
  merge_rules = copy(default_merge_rules)
257
257
  if rules is not None:
tensorcircuit/cons.py CHANGED
@@ -8,7 +8,7 @@ import logging
8
8
  import sys
9
9
  import time
10
10
  from contextlib import contextmanager
11
- from functools import partial, reduce, wraps
11
+ from functools import partial, reduce, wraps, lru_cache
12
12
  from operator import mul
13
13
  from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple
14
14
 
@@ -23,6 +23,39 @@ from .simplify import _multi_remove
23
23
 
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
+ ## monkey patch
27
+ _NODE_CREATION_COUNTER = 0
28
+ _original_node_init = tn.Node.__init__
29
+
30
+
31
+ @wraps(_original_node_init)
32
+ def _patched_node_init(self: Any, *args: Any, **kwargs: Any) -> None:
33
+ """Patched Node.__init__ to add a stable creation ID."""
34
+ global _NODE_CREATION_COUNTER
35
+ _original_node_init(self, *args, **kwargs)
36
+ self._stable_id_ = _NODE_CREATION_COUNTER
37
+ _NODE_CREATION_COUNTER += 1
38
+
39
+
40
+ tn.Node.__init__ = _patched_node_init
41
+
42
+
43
+ def _get_edge_stable_key(edge: tn.Edge) -> Tuple[int, int, int, int]:
44
+ n1, n2 = edge.node1, edge.node2
45
+ id1 = getattr(n1, "_stable_id_", -1)
46
+ id2 = getattr(n2, "_stable_id_", -1) if n2 is not None else -2 # -2 for dangling
47
+
48
+ if id1 > id2 or (id1 == id2 and edge.axis1 > edge.axis2):
49
+ id1, id2, ax1, ax2 = id2, id1, edge.axis2, edge.axis1
50
+ else:
51
+ ax1, ax2 = edge.axis1, edge.axis2
52
+ return (id1, ax1, id2, ax2)
53
+
54
+
55
+ def sorted_edges(edges: Iterator[tn.Edge]) -> List[tn.Edge]:
56
+ return sorted(edges, key=_get_edge_stable_key)
57
+
58
+
26
59
  package_name = "tensorcircuit"
27
60
  thismodule = sys.modules[__name__]
28
61
  dtypestr = "complex64"
@@ -30,6 +63,7 @@ rdtypestr = "float32"
30
63
  npdtype = np.complex64
31
64
  backend: NumpyBackend = get_backend("numpy")
32
65
  contractor = tn.contractors.auto
66
+ _ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
33
67
  # these above lines are just for mypy, it is not very good at evaluating runtime object
34
68
 
35
69
 
@@ -261,7 +295,7 @@ def _merge_single_gates(
261
295
  e0 = n0[0]
262
296
  njs = [i for i, n in enumerate(nodes) if id(n) in [id(e0.node1), id(e0.node2)]]
263
297
  qjs = [i for i, n in enumerate(queue) if id(n) in [id(e0.node1), id(e0.node2)]]
264
- new_node = tn.contract(e0)
298
+ new_node = tn.contract_parallel(e0)
265
299
  total_size += _sizen(new_node)
266
300
 
267
301
  logger.debug(
@@ -439,6 +473,28 @@ def tn_greedy_contractor(
439
473
  # base = tn.contractors.opt_einsum_paths.path_contractors.base
440
474
  # utils = tn.contractors.opt_einsum_paths.utils
441
475
 
476
+ _einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
477
+
478
+
479
+ @lru_cache(2**14)
480
+ def get_symbol(i: int) -> str:
481
+ """Get the symbol corresponding to int ``i`` - runs through the usual 52
482
+ letters before resorting to unicode characters, starting at ``chr(192)``
483
+ and skipping surrogates. From cotengra codebase
484
+ """
485
+ if i < 52:
486
+ # use a-z, A-Z first
487
+ return _einsum_symbols_base[i]
488
+
489
+ # then proceed from 'À'
490
+ i += 140
491
+
492
+ if i >= 55296:
493
+ # Skip chr(57343) - chr(55296) as surrogates
494
+ i += 2048
495
+
496
+ return chr(i)
497
+
442
498
 
443
499
  def _get_path(
444
500
  nodes: List[tn.Node], algorithm: Any
@@ -451,30 +507,33 @@ def _get_path(
451
507
  return algorithm(input_sets, output_set, size_dict), nodes
452
508
 
453
509
 
510
+ def _identity(*args: Any, **kws: Any) -> Any:
511
+ return args
512
+
513
+
454
514
  def _get_path_cache_friendly(
455
515
  nodes: List[tn.Node], algorithm: Any
456
516
  ) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
457
517
  nodes = list(nodes)
518
+
519
+ nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
520
+ # if isinstance(algorithm, list):
521
+ # return algorithm, [nodes_new]
522
+
523
+ all_edges = tn.get_all_edges(nodes_new)
524
+ all_edges_sorted = sorted_edges(all_edges)
458
525
  mapping_dict = {}
459
526
  i = 0
460
- for n in nodes:
461
- for e in n:
462
- if id(e) not in mapping_dict:
463
- mapping_dict[id(e)] = i
464
- i += 1
465
- # TODO(@refraction-ray): may be not that cache friendly, since the edge id correspondence is not that fixed?
466
- input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
467
- placeholder = [[1e20 for _ in range(100)]]
468
- order = np.argsort(np.array(list(map(sorted, input_sets)) + placeholder, dtype=object))[:-1] # type: ignore
469
- nodes_new = [nodes[i] for i in order]
470
- if isinstance(algorithm, list):
471
- return algorithm, nodes_new
472
-
473
- input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
474
- output_set = set([mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)])
475
- size_dict = {
476
- mapping_dict[id(edge)]: edge.dimension for edge in tn.get_all_edges(nodes_new)
477
- }
527
+ for edge in all_edges_sorted:
528
+ if id(edge) not in mapping_dict:
529
+ mapping_dict[id(edge)] = get_symbol(i)
530
+ i += 1
531
+
532
+ input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
533
+ output_set = list(
534
+ [mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))]
535
+ )
536
+ size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted}
478
537
  logger.debug("input_sets: %s" % input_sets)
479
538
  logger.debug("output_set: %s" % output_set)
480
539
  logger.debug("size_dict: %s" % size_dict)
@@ -483,6 +542,9 @@ def _get_path_cache_friendly(
483
542
  # directly get input_sets, output_set and size_dict by using identity function as algorithm
484
543
 
485
544
 
545
+ get_tn_info = partial(_get_path_cache_friendly, algorithm=_identity)
546
+
547
+
486
548
  # some contractor setup usages
487
549
  """
488
550
  import cotengra as ctg
@@ -513,7 +575,8 @@ opt = ctg.ReusableHyperOptimizer(
513
575
 
514
576
  def opt_reconf(inputs, output, size, **kws):
515
577
  tree = opt.search(inputs, output, size)
516
- tree_r = tree.subtree_reconfigure_forest(progbar=True, num_trees=10, num_restarts=20, subtree_weight_what=("size", ))
578
+ tree_r = tree.subtree_reconfigure_forest(progbar=True, num_trees=10,
579
+ num_restarts=20, subtree_weight_what=("size", ))
517
580
  return tree_r.get_path()
518
581
 
519
582
  tc.set_contractor("custom", optimizer=opt_reconf)
@@ -631,6 +694,51 @@ def _base(
631
694
  return final_node
632
695
 
633
696
 
697
+ class NodesReturn(Exception):
698
+ """
699
+ Intentionally stop execution to return a value.
700
+ """
701
+
702
+ def __init__(self, value_to_return: Any):
703
+ self.value = value_to_return
704
+ super().__init__(
705
+ f"Intentionally stopping execution to return: {value_to_return}"
706
+ )
707
+
708
+
709
+ def _get_sorted_nodes(nodes: List[Any], *args: Any, **kws: Any) -> Any:
710
+ nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
711
+ raise NodesReturn(nodes_new)
712
+
713
+
714
+ def function_nodes_capture(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
715
+ @wraps(func)
716
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
717
+ with runtime_contractor(method="before"):
718
+ try:
719
+ result = func(*args, **kwargs)
720
+ return result
721
+ except NodesReturn as e:
722
+ return e.value
723
+
724
+ return wrapper
725
+
726
+
727
+ @contextmanager
728
+ def runtime_nodes_capture(key: str = "nodes") -> Iterator[Any]:
729
+ old_contractor = getattr(thismodule, "contractor")
730
+ set_contractor(method="before")
731
+ captured_value: Dict[str, List[tn.Node]] = {}
732
+ try:
733
+ yield captured_value
734
+ except NodesReturn as e:
735
+ captured_value[key] = e.value
736
+ finally:
737
+ for module in sys.modules:
738
+ if module.startswith(package_name):
739
+ setattr(sys.modules[module], "contractor", old_contractor)
740
+
741
+
634
742
  def custom(
635
743
  nodes: List[Any],
636
744
  optimizer: Any,
@@ -701,6 +809,16 @@ def custom_stateful(
701
809
 
702
810
  # only work for custom
703
811
  def contraction_info_decorator(algorithm: Callable[..., Any]) -> Callable[..., Any]:
812
+ """Decorator to add contraction information logging to an optimizer.
813
+
814
+ This decorator wraps an optimization algorithm and prints detailed information
815
+ about the contraction cost (FLOPs, size, write) and path finding time.
816
+
817
+ :param algorithm: The optimization algorithm to decorate.
818
+ :type algorithm: Callable[..., Any]
819
+ :return: The decorated optimization algorithm.
820
+ :rtype: Callable[..., Any]
821
+ """
704
822
  from cotengra import ContractionTree
705
823
 
706
824
  def new_algorithm(
@@ -807,6 +925,9 @@ def set_contractor(
807
925
  **kws,
808
926
  )
809
927
 
928
+ elif method == "before": # a hack way to get the nodes
929
+ cf = _get_sorted_nodes
930
+
810
931
  else:
811
932
  # cf = getattr(tn.contractors, method, None)
812
933
  # if not cf: