tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package, as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release: this version of tensorcircuit-nightly might be problematic.

Files changed (77)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +100 -4
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +157 -98
  14. tensorcircuit/circuit.py +115 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +105 -23
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +733 -153
  22. tensorcircuit/fgs.py +254 -73
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/interfaces/jax.py +5 -3
  25. tensorcircuit/interfaces/tensortrans.py +6 -2
  26. tensorcircuit/interfaces/torch.py +14 -4
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +698 -134
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +131 -18
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +29 -17
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
  45. tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1035
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -409
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -562
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -282
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -549
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -380
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_stabilizer.py +0 -217
  74. tests/test_templates.py +0 -218
  75. tests/test_torchnn.py +0 -99
  76. tests/test_van.py +0 -102
  77. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
tensorcircuit/circuit.py CHANGED
@@ -1,5 +1,7 @@
 """
-Quantum circuit: the state simulator
+Quantum circuit: the state simulator.
+Supports qubit (dim=2) and qudit (3 <= dim <= 36) systems.
+For string-encoded samples/counts, digits use 0-9A-Z where A=10, ..., Z=35.
 """
 
 # pylint: disable=invalid-name
@@ -13,8 +15,8 @@ import tensornetwork as tn
 
 from . import gates
 from . import channels
-from .cons import backend, contractor, dtypestr, npdtype
-from .quantum import QuOperator, identity
+from .cons import backend, contractor, dtypestr, npdtype, _ALPHABET
+from .quantum import QuOperator, identity, _infer_num_sites, onehot_d_tensor
 from .simplify import _full_light_cone_cancel
 from .basecircuit import BaseCircuit
 
@@ -23,7 +25,7 @@ Tensor = Any
 
 
 class Circuit(BaseCircuit):
-    """
+    r"""
     ``Circuit`` class.
     Simple usage demo below.
 
@@ -45,14 +47,18 @@ class Circuit(BaseCircuit):
         inputs: Optional[Tensor] = None,
         mps_inputs: Optional[QuOperator] = None,
         split: Optional[Dict[str, Any]] = None,
+        dim: Optional[int] = None,
     ) -> None:
-        """
+        r"""
        Circuit object based on state simulator.
+        Do not use this class with d!=2 directly, use tc.QuditCircuit instead for qudit systems.
 
        :param nqubits: The number of qubits in the circuit.
        :type nqubits: int
+        :param dim: The local Hilbert space dimension per site. Qudit is supported for 2 <= d <= 36.
+        :type dim: If None, the dimension of the circuit will be `2`, which is a qubit system.
        :param inputs: If not None, the initial state of the circuit is taken as ``inputs``
-            instead of :math:`\\vert 0\\rangle^n` qubits, defaults to None.
+            instead of :math:`\vert 0 \rangle^n` qubits, defaults to None.
        :type inputs: Optional[Tensor], optional
        :param mps_inputs: QuVector for a MPS like initial wavefunction.
        :type mps_inputs: Optional[QuOperator]
@@ -60,6 +66,7 @@ class Circuit(BaseCircuit):
            ``max_singular_values`` and ``max_truncation_err``.
        :type split: Optional[Dict[str, Any]]
        """
+        self._d = 2 if dim is None else dim
         self.inputs = inputs
         self.mps_inputs = mps_inputs
         self.split = split
@@ -70,18 +77,19 @@ class Circuit(BaseCircuit):
             "inputs": inputs,
             "mps_inputs": mps_inputs,
             "split": split,
+            "dim": dim,
         }
         if (inputs is None) and (mps_inputs is None):
-            nodes = self.all_zero_nodes(nqubits)
+            nodes = self.all_zero_nodes(nqubits, dim=self._d)
             self._front = [n.get_edge(0) for n in nodes]
         elif inputs is not None:  # provide input function
             inputs = backend.convert_to_tensor(inputs)
             inputs = backend.cast(inputs, dtype=dtypestr)
             inputs = backend.reshape(inputs, [-1])
             N = inputs.shape[0]
-            n = int(np.log(N) / np.log(2))
+            n = _infer_num_sites(N, dim=self._d)
             assert n == nqubits or n == 2 * nqubits
-            inputs = backend.reshape(inputs, [2 for _ in range(n)])
+            inputs = backend.reshape(inputs, [self._d for _ in range(n)])
             inputs = Gate(inputs)
             nodes = [inputs]
             self._front = [inputs.get_edge(i) for i in range(n)]
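
A minimal usage sketch of the new `dim` handling above. The qubit path is unchanged; the `tc.QuditCircuit` entry point is inferred from the added `quditcircuit.py` file in this release, and its exact constructor signature is an assumption:

    import tensorcircuit as tc

    c = tc.Circuit(3)        # dim defaults to 2, so self._d == 2 (qubit register)
    c.h(0)
    c.cnot(0, 1)
    # For d > 2 the docstring above recommends the dedicated class rather than
    # Circuit(dim=...); the line below shows an assumed signature only:
    # qc = tc.QuditCircuit(3, dim=3)
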
@@ -178,27 +186,14 @@ class Circuit(BaseCircuit):
 
         :param index: The index of qubit that the Z direction postselection applied on.
         :type index: int
-        :param keep: 0 for spin up, 1 for spin down, defaults to be 0.
+        :param keep: the post-selected digit in {0, ..., d-1}, defaults to be 0.
         :type keep: int, optional
         """
         # normalization not guaranteed
-        # assert keep in [0, 1]
-        if keep < 0.5:
-            gate = np.array(
-                [
-                    [1.0],
-                    [0.0],
-                ],
-                dtype=npdtype,
-            )
-        else:
-            gate = np.array(
-                [
-                    [0.0],
-                    [1.0],
-                ],
-                dtype=npdtype,
-            )
+        gate = np.array(
+            [[0.0] if _idx != keep else [1.0] for _idx in range(self._d)],
+            dtype=npdtype,
+        )
 
         mg1 = tn.Node(gate)
         mg2 = tn.Node(gate)
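
For reference, a small standalone NumPy illustration of the generalized post-selection projector built above: a d x 1 one-hot column selecting the kept digit (the values of `d` and `keep` here are illustrative):

    import numpy as np

    d, keep = 3, 2
    gate = np.array(
        [[0.0] if _idx != keep else [1.0] for _idx in range(d)], dtype=np.complex64
    )
    # gate.shape == (3, 1); for d = 2 and keep = 0 this reduces to [[1.0], [0.0]]
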
@@ -231,6 +226,25 @@ class Circuit(BaseCircuit):
         pz: float,
         status: Optional[float] = None,
     ) -> float:
+        """
+        Apply a depolarizing channel to the circuit in a Monte Carlo way.
+        For each call, one of the Pauli gates (X, Y, Z) or an Identity gate is applied to the qubit
+        at the given index based on the probabilities `px`, `py`, and `pz`.
+
+        :param index: The index of the qubit to apply the depolarizing channel on.
+        :type index: int
+        :param px: The probability of applying an X gate.
+        :type px: float
+        :param py: The probability of applying a Y gate.
+        :type py: float
+        :param pz: The probability of applying a Z gate.
+        :type pz: float
+        :param status: A random number between 0 and 1 to determine which gate to apply. If None,
+            a random number is generated automatically. Defaults to None.
+        :type status: Optional[float], optional
+        :return: Returns 0.0. The function modifies the circuit in place.
+        :rtype: float
+        """
         if status is None:
             status = backend.implicit_randu()[0]
         g = backend.cond(
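
A hedged usage sketch of the Monte Carlo depolarizing channel documented above (gate names follow the standard qubit API; the probabilities are illustrative):

    import tensorcircuit as tc

    c = tc.Circuit(1)
    c.h(0)
    c.depolarizing(0, px=0.05, py=0.05, pz=0.05)  # samples one of I, X, Y, Z per call
    # with status=None an internal uniform random number decides the branch
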
@@ -323,6 +337,35 @@ class Circuit(BaseCircuit):
         status: Optional[float] = None,
         name: Optional[str] = None,
     ) -> Tensor:
+        """
+        Apply a unitary Kraus channel to the circuit using a Monte Carlo approach. This method is functionally
+        similar to `unitary_kraus` but uses `backend.switch` for selecting the Kraus operator, which can have
+        different performance characteristics on some backends.
+
+        A random Kraus operator from the provided list is applied to the circuit based on the given probabilities.
+        This method is jittable and suitable for simulating noisy quantum circuits where the noise is represented
+        by unitary Kraus operators.
+
+        .. warning::
+            This method may have issues with `vmap` due to potential concurrent access locks, potentially related with
+            `backend.switch`. `unitary_kraus` is generally recommended.
+
+        :param kraus: A sequence of `Gate` objects representing the unitary Kraus operators.
+        :type kraus: Sequence[Gate]
+        :param index: The qubit indices on which to apply the Kraus channel.
+        :type index: int
+        :param prob: A sequence of probabilities corresponding to each Kraus operator. If None, probabilities
+            are derived from the operators themselves. Defaults to None.
+        :type prob: Optional[Sequence[float]], optional
+        :param status: A random number between 0 and 1 to determine which Kraus operator to apply. If None,
+            a random number is generated automatically. Defaults to None.
+        :type status: Optional[float], optional
+        :param name: An optional name for the operation. Defaults to None.
+        :type name: Optional[str], optional
+        :return: A tensor indicating which Kraus operator was applied.
+        :rtype: Tensor
+        """
+
         # dont use, has issue conflicting with vmap, concurrent access lock emerged
         # potential issue raised from switch
         # general impl from Monte Carlo trajectory depolarizing above
@@ -431,8 +474,8 @@ class Circuit(BaseCircuit):
         if get_gate_from_index is None:
             raise ValueError("no `get_gate_from_index` implementation is provided")
         g = get_gate_from_index(r, kraus)
-        g = backend.reshape(g, [2 for _ in range(sites * 2)])
-        self.any(*index, unitary=g, name=name)  # type: ignore
+        g = backend.reshape(g, [self._d for _ in range(sites * 2)])
+        self.any(*index, unitary=g, name=name, dim=self._d)  # type: ignore
         return r
 
     def _general_kraus_tf(
@@ -557,9 +600,13 @@ class Circuit(BaseCircuit):
             for w, k in zip(prob, kraus_tensor)
         ]
         pick = self.unitary_kraus(
-            new_kraus, *index, prob=prob, status=status, name=name
+            new_kraus,
+            *index,
+            prob=prob,
+            status=status,
+            name=name,
         )
-        if with_prob is False:
+        if not with_prob:
             return pick
         else:
             return pick, prob
@@ -590,7 +637,11 @@ class Circuit(BaseCircuit):
         :type status: Optional[float], optional
         """
         return self._general_kraus_2(
-            kraus, *index, status=status, with_prob=with_prob, name=name
+            kraus,
+            *index,
+            status=status,
+            with_prob=with_prob,
+            name=name,
         )
 
     apply_general_kraus = general_kraus
@@ -632,7 +683,7 @@ class Circuit(BaseCircuit):
         Apply %s quantum channel on the circuit.
         See :py:meth:`tensorcircuit.channels.%schannel`
 
-        :param index: Qubit number that the gate applies on.
+        :param index: Site index that the gate applies on.
         :type index: int.
         :param status: uniform external random number between 0 and 1
         :type status: Tensor
@@ -689,8 +740,8 @@ class Circuit(BaseCircuit):
         :return: ``QuOperator`` object for the circuit unitary (open indices for the input state)
         :rtype: QuOperator
         """
-        mps = identity([2 for _ in range(self._nqubits)])
-        c = Circuit(self._nqubits)
+        mps = identity([self._d for _ in range(self._nqubits)])
+        c = Circuit(self._nqubits, dim=self._d)
         ns, es = self._copy()
         c._nodes = ns
         c._front = es
@@ -710,8 +761,8 @@ class Circuit(BaseCircuit):
         :return: The circuit unitary matrix
         :rtype: Tensor
         """
-        mps = identity([2 for _ in range(self._nqubits)])
-        c = Circuit(self._nqubits)
+        mps = identity([self._d for _ in range(self._nqubits)])
+        c = Circuit(self._nqubits, dim=self._d)
         ns, es = self._copy()
         c._nodes = ns
         c._front = es
@@ -724,6 +775,9 @@ class Circuit(BaseCircuit):
         """
         Take measurement on the given quantum lines by ``index``.
 
+        Return format:
+        - For d <= 36, the sample is a base-d string using 0-9A-Z (A=10,...).
+
         :Example:
 
         >>> c = tc.Circuit(3)
@@ -752,10 +806,7 @@ class Circuit(BaseCircuit):
                 if i != j:
                     e ^ edge2[i]
             for i in range(len(sample)):
-                if sample[i] == "0":
-                    m = np.array([1, 0], dtype=npdtype)
-                else:
-                    m = np.array([0, 1], dtype=npdtype)
+                m = onehot_d_tensor(sample[i], d=self._d)
                 nodes1.append(tn.Node(m))
                 nodes1[-1].get_edge(0) ^ edge1[index[i]]
                 nodes2.append(tn.Node(m))
@@ -766,15 +817,13 @@ class Circuit(BaseCircuit):
                 / p
                 * contractor(nodes1, output_edge_order=[edge1[j], edge2[j]]).tensor
             )
-            pu = rho[0, 0]
-            r = backend.random_uniform([])
-            r = backend.real(backend.cast(r, dtypestr))
-            if r < backend.real(pu):
-                sample += "0"
-                p = p * pu
-            else:
-                sample += "1"
-                p = p * (1 - pu)
+            probs = backend.real(backend.diagonal(rho))
+            probs /= backend.sum(probs)
+            outcome = backend.implicit_randc(self._d, shape=1, p=probs)
+
+            sample += _ALPHABET[outcome]
+            p *= float(probs[outcome])
+
         if with_prob:
             return sample, p
         else:
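
The base-d digit encoding used for the sample strings above mirrors the `_ALPHABET` constant added in cons.py (see below); a tiny standalone illustration:

    _ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    assert _ALPHABET[0] == "0" and _ALPHABET[10] == "A" and _ALPHABET[35] == "Z"
    # e.g. a three-site qutrit (d=3) measurement can read "021"
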
@@ -794,6 +843,10 @@ class Circuit(BaseCircuit):
     ) -> Tensor:
         """
         Compute the expectation of corresponding operators.
+        For qudit (d > 2),
+        ensure that operator tensor shapes are consistent with d (each site contributes two axes of size d).
+
+        Noise shorthand (via noise_conf) is qubit-only; for d>2, use explicit operators.
 
         :Example:
 
@@ -829,14 +882,12 @@ class Circuit(BaseCircuit):
         :param nmc: repetition time for Monte Carlo sampling for noisfy calculation, defaults to 1000
         :type nmc: int, optional
         :param status: external randomness given by tensor uniformly from [0, 1], defaults to None,
-            used for noisfy circuit sampling
+            used for noisy circuit sampling
         :type status: Optional[Tensor], optional
         :raises ValueError: "Cannot measure two operators in one index"
         :return: Tensor with one element
         :rtype: Tensor
         """
-        from .noisemodel import expectation_noisfy
-
         if noise_conf is None:
             # if not reuse:
             #     nodes1, edge1 = self._copy()
@@ -851,6 +902,8 @@ class Circuit(BaseCircuit):
             nodes1 = _full_light_cone_cancel(nodes1)
             return contractor(nodes1).tensor
         else:
+            from .noisemodel import expectation_noisfy
+
             return expectation_noisfy(
                 self,
                 *ops,
@@ -871,9 +924,11 @@ def expectation(
     bra: Optional[Tensor] = None,
     conj: bool = True,
     normalization: bool = False,
+    dim: Optional[int] = None,
 ) -> Tensor:
     """
     Compute :math:`\\langle bra\\vert ops \\vert ket\\rangle`.
+    For qudit systems (d>2), ops must be reshaped with per-site axes of length d.
 
     Example 1 (:math:`bra` is same as :math:`ket`)
 
@@ -918,6 +973,8 @@ def expectation(
     :type ket: Tensor
     :param bra: :math:`bra`, defaults to None, which is the same as ``ket``.
     :type bra: Optional[Tensor], optional
+    :param dim: dimension of the circuit (defaults to 2)
+    :type dim: int, optional
     :param conj: :math:`bra` changes to the adjoint matrix of :math:`bra`, defaults to True.
     :type conj: bool, optional
     :param normalization: Normalize the :math:`ket` and :math:`bra`, defaults to False.
@@ -926,6 +983,7 @@
     :return: The result of :math:`\\langle bra\\vert ops \\vert ket\\rangle`.
     :rtype: Tensor
     """
+    dim = 2 if dim is None else dim
     if bra is None:
         bra = ket
     if isinstance(ket, QuOperator):
@@ -939,7 +997,7 @@
         for op, index in ops:
             if not isinstance(op, tn.Node):
                 # op is only a matrix
-                op = backend.reshape2(op)
+                op = backend.reshaped(op, dim)
                 op = gates.Gate(op)
             if isinstance(index, int):
                 index = [index]
@@ -963,8 +1021,8 @@
     if conj is True:
         bra = backend.conj(bra)
     ket = backend.reshape(ket, [-1])
-    ket = backend.reshape2(ket)
-    bra = backend.reshape2(bra)
+    ket = backend.reshaped(ket, dim)
+    bra = backend.reshaped(bra, dim)
     n = len(backend.shape_tuple(ket))
     ket = Gate(ket)
     bra = Gate(bra)
@@ -976,7 +1034,7 @@
     for op, index in ops:
         if not isinstance(op, tn.Node):
            # op is only a matrix
-            op = backend.reshape2(op)
+            op = backend.reshaped(op, dim)
            op = gates.Gate(op)
        if isinstance(index, int):
            index = [index]
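
A sketch of the per-site reshaping convention that `backend.reshaped(op, dim)` relies on, as described in the updated `expectation` docstring (plain NumPy stand-in, not the backend call itself):

    import numpy as np

    d = 3
    op_matrix = np.eye(d * d)               # a two-site operator as a (d**2, d**2) matrix
    op_tensor = op_matrix.reshape([d] * 4)  # two axes of length d per site
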
@@ -36,7 +36,7 @@ def submit_task(
     shots: Union[int, Sequence[int]] = 1024,
     version: str = "1",
     circuit: Optional[Union[AbstractCircuit, Sequence[AbstractCircuit]]] = None,
-    **kws: Any
+    **kws: Any,
 ) -> List[Task]:
     def _circuit2result(c: AbstractCircuit) -> Dict[str, Any]:
         if device.name in ["testing", "default"]:
@@ -30,7 +30,7 @@ def submit_task(
     circuit: Optional[Union[AbstractCircuit, Sequence[AbstractCircuit]]] = None,
     source: Optional[Union[str, Sequence[str]]] = None,
     compile: bool = True,
-    **kws: Any
+    **kws: Any,
 ) -> Task:
     if source is None:
 
@@ -133,7 +133,7 @@ def submit_task(
     enable_qos_gate_decomposition: bool = True,
     enable_qos_initial_mapping: bool = False,
     qos_dry_run: bool = False,
-    **kws: Any
+    **kws: Any,
 ) -> List[Task]:
     """
     Submit task via tencent provider, we suggest to enable one of the compiling functionality:
@@ -109,7 +109,7 @@ def prune(
     circuit: Union[AbstractCircuit, List[Dict[str, Any]]],
     rtol: float = 1e-3,
     atol: float = 1e-3,
-    **kws: Any
+    **kws: Any,
 ) -> Any:
     if isinstance(circuit, list):
         qir = circuit
@@ -251,7 +251,7 @@ def _merge(
 def merge(
     circuit: Union[AbstractCircuit, List[Dict[str, Any]]],
     rules: Optional[Dict[Tuple[str, ...], str]] = None,
-    **kws: Any
+    **kws: Any,
 ) -> Any:
     merge_rules = copy(default_merge_rules)
     if rules is not None:
tensorcircuit/cons.py CHANGED
@@ -23,6 +23,39 @@ from .simplify import _multi_remove
 
 logger = logging.getLogger(__name__)
 
+## monkey patch
+_NODE_CREATION_COUNTER = 0
+_original_node_init = tn.Node.__init__
+
+
+@wraps(_original_node_init)
+def _patched_node_init(self: Any, *args: Any, **kwargs: Any) -> None:
+    """Patched Node.__init__ to add a stable creation ID."""
+    global _NODE_CREATION_COUNTER
+    _original_node_init(self, *args, **kwargs)
+    self._stable_id_ = _NODE_CREATION_COUNTER
+    _NODE_CREATION_COUNTER += 1
+
+
+tn.Node.__init__ = _patched_node_init
+
+
+def _get_edge_stable_key(edge: tn.Edge) -> Tuple[int, int, int, int]:
+    n1, n2 = edge.node1, edge.node2
+    id1 = getattr(n1, "_stable_id_", -1)
+    id2 = getattr(n2, "_stable_id_", -1) if n2 is not None else -2  # -2 for dangling
+
+    if id1 > id2 or (id1 == id2 and edge.axis1 > edge.axis2):
+        id1, id2, ax1, ax2 = id2, id1, edge.axis2, edge.axis1
+    else:
+        ax1, ax2 = edge.axis1, edge.axis2
+    return (id1, ax1, id2, ax2)
+
+
+def sorted_edges(edges: Iterator[tn.Edge]) -> List[tn.Edge]:
+    return sorted(edges, key=_get_edge_stable_key)
+
+
 package_name = "tensorcircuit"
 thismodule = sys.modules[__name__]
 dtypestr = "complex64"
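
A toy check of the stable creation-ID idea introduced by the monkey patch above (assumes the patch is active, i.e. tensorcircuit has been imported first):

    import numpy as np
    import tensorcircuit as tc  # importing applies the Node.__init__ patch
    import tensornetwork as tn

    a = tn.Node(np.ones((2, 2)))
    b = tn.Node(np.ones((2, 2)))
    a[1] ^ b[0]
    # creation order gives a deterministic sort key, unlike id()-based ordering
    assert getattr(a, "_stable_id_", -1) < getattr(b, "_stable_id_", -1)
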
@@ -30,6 +63,7 @@ rdtypestr = "float32"
 npdtype = np.complex64
 backend: NumpyBackend = get_backend("numpy")
 contractor = tn.contractors.auto
+_ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
 # these above lines are just for mypy, it is not very good at evaluating runtime object
 
 
@@ -477,39 +511,29 @@ def _identity(*args: Any, **kws: Any) -> Any:
     return args
 
 
-def _sort_tuple_list(input_list: List[Any], output_list: List[Any]) -> List[Any]:
-    sorted_elements = [(tuple(sorted(t)), i) for i, t in enumerate(input_list)]
-    sorted_elements.sort()
-    return [output_list[i] for _, i in sorted_elements]
-
-
 def _get_path_cache_friendly(
     nodes: List[tn.Node], algorithm: Any
 ) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
     nodes = list(nodes)
+
+    nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
+    # if isinstance(algorithm, list):
+    #     return algorithm, [nodes_new]
+
+    all_edges = tn.get_all_edges(nodes_new)
+    all_edges_sorted = sorted_edges(all_edges)
     mapping_dict = {}
     i = 0
-    for n in nodes:
-        for e in n:
-            if id(e) not in mapping_dict:
-                mapping_dict[id(e)] = get_symbol(i)
-                i += 1
-    # TODO(@refraction-ray): may be not that cache friendly, since the edge id correspondence is not that fixed?
-    input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
-    # placeholder = [[1e20 for _ in range(100)]]
-    # order = np.argsort(np.array(list(map(sorted, input_sets)), dtype=object))  # type: ignore
-    # nodes_new = [nodes[i] for i in order]
-    nodes_new = _sort_tuple_list(input_sets, nodes)
-    if isinstance(algorithm, list):
-        return algorithm, nodes_new
+    for edge in all_edges_sorted:
+        if id(edge) not in mapping_dict:
+            mapping_dict[id(edge)] = get_symbol(i)
+            i += 1
 
     input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
     output_set = list(
-        [mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)]
+        [mapping_dict[id(e)] for e in sorted_edges(tn.get_subgraph_dangling(nodes_new))]
     )
-    size_dict = {
-        mapping_dict[id(edge)]: edge.dimension for edge in tn.get_all_edges(nodes_new)
-    }
+    size_dict = {mapping_dict[id(edge)]: edge.dimension for edge in all_edges_sorted}
     logger.debug("input_sets: %s" % input_sets)
     logger.debug("output_set: %s" % output_set)
     logger.debug("size_dict: %s" % size_dict)
@@ -670,6 +694,51 @@ def _base(
     return final_node
 
 
+class NodesReturn(Exception):
+    """
+    Intentionally stop execution to return a value.
+    """
+
+    def __init__(self, value_to_return: Any):
+        self.value = value_to_return
+        super().__init__(
+            f"Intentionally stopping execution to return: {value_to_return}"
+        )
+
+
+def _get_sorted_nodes(nodes: List[Any], *args: Any, **kws: Any) -> Any:
+    nodes_new = sorted(nodes, key=lambda node: getattr(node, "_stable_id_", -1))
+    raise NodesReturn(nodes_new)
+
+
+def function_nodes_capture(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        with runtime_contractor(method="before"):
+            try:
+                result = func(*args, **kwargs)
+                return result
+            except NodesReturn as e:
+                return e.value
+
+    return wrapper
+
+
+@contextmanager
+def runtime_nodes_capture(key: str = "nodes") -> Iterator[Any]:
+    old_contractor = getattr(thismodule, "contractor")
+    set_contractor(method="before")
+    captured_value: Dict[str, List[tn.Node]] = {}
+    try:
+        yield captured_value
+    except NodesReturn as e:
+        captured_value[key] = e.value
+    finally:
+        for module in sys.modules:
+            if module.startswith(package_name):
+                setattr(sys.modules[module], "contractor", old_contractor)
+
+
 def custom(
     nodes: List[Any],
     optimizer: Any,
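
A hedged usage sketch of the capture hook added above: the "before" contractor raises NodesReturn with the sorted node list instead of contracting, and the context manager stores it under the given key (behavior inferred from the added code, not verified):

    import tensorcircuit as tc

    with tc.cons.runtime_nodes_capture() as captured:
        c = tc.Circuit(2)
        c.h(0)
        c.state()  # contraction attempt is intercepted via NodesReturn
    nodes = captured.get("nodes", [])
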
@@ -740,6 +809,16 @@ def custom_stateful(
 
 # only work for custom
 def contraction_info_decorator(algorithm: Callable[..., Any]) -> Callable[..., Any]:
+    """Decorator to add contraction information logging to an optimizer.
+
+    This decorator wraps an optimization algorithm and prints detailed information
+    about the contraction cost (FLOPs, size, write) and path finding time.
+
+    :param algorithm: The optimization algorithm to decorate.
+    :type algorithm: Callable[..., Any]
+    :return: The decorated optimization algorithm.
+    :rtype: Callable[..., Any]
+    """
     from cotengra import ContractionTree
 
     def new_algorithm(
@@ -846,6 +925,9 @@ def set_contractor(
         **kws,
     )
 
+    elif method == "before":  # a hack way to get the nodes
+        cf = _get_sorted_nodes
+
     else:
         # cf = getattr(tn.contractors, method, None)
         # if not cf: