tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (77)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +100 -4
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +157 -98
  14. tensorcircuit/circuit.py +115 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +105 -23
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +733 -153
  22. tensorcircuit/fgs.py +254 -73
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/interfaces/jax.py +5 -3
  25. tensorcircuit/interfaces/tensortrans.py +6 -2
  26. tensorcircuit/interfaces/torch.py +14 -4
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +698 -134
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +131 -18
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +29 -17
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
  45. tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1035
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -409
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -562
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -282
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -549
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -380
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_stabilizer.py +0 -217
  74. tests/test_templates.py +0 -218
  75. tests/test_torchnn.py +0 -99
  76. tests/test_van.py +0 -102
  77. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
tensorcircuit/quantum.py CHANGED
@@ -9,6 +9,7 @@ import math
9
9
  import os
10
10
  from functools import partial, reduce
11
11
  from operator import matmul, mul, or_
12
+ from collections import Counter
12
13
  from typing import (
13
14
  Any,
14
15
  Callable,
@@ -23,6 +24,7 @@ from typing import (
23
24
  )
24
25
 
25
26
  import numpy as np
27
+ import tensornetwork as tn
26
28
  from tensornetwork.network_components import AbstractNode, CopyNode, Edge, Node, connect
27
29
  from tensornetwork.network_operations import (
28
30
  copy,
@@ -30,7 +32,7 @@ from tensornetwork.network_operations import (
30
32
  remove_node,
31
33
  )
32
34
 
33
- from .cons import backend, contractor, dtypestr, npdtype, rdtypestr
35
+ from .cons import backend, contractor, dtypestr, npdtype, rdtypestr, _ALPHABET
34
36
  from .gates import Gate, num_to_tensor
35
37
  from .utils import arg_alias
36
38
 
@@ -55,6 +57,91 @@ def get_all_nodes(edges: Iterable[Edge]) -> List[Node]:
55
57
  return nodes
56
58
 
57
59
 
60
+ def onehot_d_tensor(_k: Union[int, Tensor], d: int = 2) -> Tensor:
61
+ """
62
+ Construct a one-hot vector (or matrix) of local dimension ``d``.
63
+
64
+ :param _k: index or indices to set as 1. Can be an int or a backend Tensor.
65
+ :type _k: int or Tensor
66
+ :param d: local dimension (number of categories), defaults to 2
67
+ :type d: int, optional
68
+ :return: one-hot encoded vector (shape [d]) or matrix (shape [len(_k), d])
69
+ :rtype: Tensor
70
+ """
71
+ if isinstance(_k, int):
72
+ vec = backend.one_hot(_k, d)
73
+ else:
74
+ vec = backend.one_hot(backend.cast(_k, "int32"), d)
75
+ return backend.cast(vec, dtypestr)
76
+
77
+
78
+ def _decode_basis_label(label: str, n: int, dim: int) -> List[int]:
79
+ """
80
+ Decode a string basis label into a list of integer digits.
81
+
82
+ The label is interpreted in base-``dim`` using characters ``0-9A-Z``.
83
+ Only dimensions up to 36 are supported.
84
+
85
+ :param label: basis label string, e.g. "010" or "A9F"
86
+ :type label: str
87
+ :param n: number of sites (expected length of the label)
88
+ :type n: int
89
+ :param dim: local dimension (2 <= dim <= 36)
90
+ :type dim: int
91
+ :return: list of integer digits of length ``n``, each in ``[0, dim-1]``
92
+ :rtype: List[int]
93
+
94
+ :raises NotImplementedError: if ``dim > 36``
95
+ :raises ValueError: if the label length mismatches ``n``,
96
+ or contains invalid/out-of-range characters
97
+ """
98
+ if dim > 36:
99
+ raise NotImplementedError(
100
+ f"String basis label supports d<=36 (0-9A-Z). Got dim={dim}. "
101
+ "Use an integer array/tensor of length n instead."
102
+ )
103
+ s = label.upper()
104
+ if len(s) != n:
105
+ raise ValueError(f"Basis label length mismatch: expect {n}, got {len(s)}")
106
+ digits = []
107
+ for ch in s:
108
+ if ch not in _ALPHABET:
109
+ raise ValueError(
110
+ f"Invalid character '{ch}' in basis label (allowed 0-9A-Z)."
111
+ )
112
+ v = _ALPHABET.index(ch)
113
+ if v >= dim:
114
+ raise ValueError(
115
+ f"Digit '{ch}' (= {v}) out of range for base-d with dim={dim}."
116
+ )
117
+ digits.append(v)
118
+ return digits
119
+
120
+
121
+ def _infer_num_sites(D: int, dim: int) -> int:
122
+ """
123
+ Infer the number of sites (n) from a Hilbert space dimension D
124
+ and local dimension d, assuming D = d**n.
125
+
126
+ :param D: total Hilbert space dimension (int)
127
+ :param dim: local dimension per site (int)
128
+ :return: n such that D == d**n
129
+ :raises ValueError: if D is not an exact power of d
130
+ """
131
+ if not (isinstance(D, int) and D > 0):
132
+ raise ValueError(f"D must be a positive integer, got {D}")
133
+ if not (isinstance(dim, int) and dim >= 2):
134
+ raise ValueError(f"d must be an integer >= 2, got {dim}")
135
+
136
+ tmp, n = D, 0
137
+ while tmp % dim == 0 and tmp > 1:
138
+ tmp //= dim
139
+ n += 1
140
+ if tmp != 1:
141
+ raise ValueError(f"Dimension {D} is not a power of local dim {dim}")
142
+ return n
143
+
144
+
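The two helpers above are used by the qudit-aware routines later in this module; a minimal sketch of their behavior (hypothetical direct calls to the private helpers):

# hypothetical direct use of the private helpers sketched above
from tensorcircuit.quantum import _decode_basis_label, _infer_num_sites

digits = _decode_basis_label("1A", n=2, dim=12)   # base-12 label -> [1, 10]
assert digits == [1, 10]
assert _infer_num_sites(27, 3) == 3               # 27 = 3**3 -> 3 qutrit sites
try:
    _infer_num_sites(24, 3)                       # 24 is not a power of 3
except ValueError:
    pass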
58
145
  def _reachable(nodes: List[AbstractNode]) -> List[AbstractNode]:
59
146
  if not nodes:
60
147
  raise ValueError("Reachable requires at least 1 node.")
@@ -71,7 +158,7 @@ def _reachable(nodes: List[AbstractNode]) -> List[AbstractNode]:
71
158
  if n not in seen_nodes and n not in node_que[i + 1 :]:
72
159
  node_que.append(n)
73
160
  i += 1
74
- return seen_nodes
161
+ return sorted(seen_nodes, key=lambda node: getattr(node, "_stable_id_", -1))
75
162
 
76
163
 
77
164
  def reachable(
@@ -661,10 +748,10 @@ class QuOperator:
661
748
  return self.__mul__(other)
662
749
 
663
750
  def tensor_product(self, other: "QuOperator") -> "QuOperator":
664
- """
751
+ r"""
665
752
  Tensor product with another operator.
666
753
  Given two operators `A` and `B`, produces a new operator `AB` representing
667
- :math:`A B`. The `out_edges` (`in_edges`) of `AB` is simply the
754
+ :math:`A \otimes B`. The `out_edges` (`in_edges`) of `AB` is simply the
668
755
  concatenation of the `out_edges` (`in_edges`) of `A.copy()` with that of
669
756
  `B.copy()`:
670
757
  `new_out_edges = [*out_edges_A_copy, *out_edges_B_copy]`
@@ -1151,33 +1238,281 @@ def generate_local_hamiltonian(
1151
1238
  return hop
1152
1239
 
1153
1240
 
1154
- def tn2qop(tn_mpo: Any) -> QuOperator:
1241
+ # TODO(@Charlespkuer): Add more conversion functions for other packages
1242
+ def extract_tensors_from_qop(qop: QuOperator) -> Tuple[List[Node], bool, int]:
1155
1243
  """
1156
- Convert MPO in TensorNetwork package to QuOperator.
1244
+ Extract and sort tensors from QuOperator for conversion to other tensor network formats.
1157
1245
 
1158
- :param tn_mpo: MPO in the form of TensorNetwork package
1159
- :type tn_mpo: ``tn.matrixproductstates.mpo.*``
1160
- :return: MPO in the form of QuOperator
1246
+ :param qop: Input QuOperator to extract tensors from
1247
+ :type qop: QuOperator
1248
+ :return: Tuple containing (sorted_nodes, is_mps, nwires) where:
1249
+ - sorted_nodes: List of Node objects sorted in linear chain order
1250
+ - is_mps: Boolean flag indicating if the structure is MPS (True) or MPO (False)
1251
+ - nwires: Integer number of physical edges/qubits in the system
1252
+ :rtype: Tuple[List[Node], bool, int]
1253
+ """
1254
+ is_mps = len(qop.in_edges) == 0
1255
+ nwires = len(qop.out_edges)
1256
+ if not is_mps and len(qop.in_edges) != nwires:
1257
+ raise ValueError(
1258
+ "MPO must have the same number of input and output edges. "
1259
+ f"Got {len(qop.in_edges)} and {nwires}."
1260
+ )
1261
+
1262
+ # Collect all nodes from edges
1263
+ nodes_for_sorting = qop.nodes
1264
+ if len(nodes_for_sorting) != nwires:
1265
+ raise ValueError(f"Number of nodes does not match number of wires.")
1266
+
1267
+ # Find endpoint nodes
1268
+ endpoint_nodes = set()
1269
+ physical_edges = set(qop.out_edges) if is_mps else set(qop.in_edges + qop.out_edges)
1270
+ if is_mps:
1271
+ rank_2_nodes = {node for node in nodes_for_sorting if len(node.edges) == 2}
1272
+ if len(rank_2_nodes) == 2:
1273
+ endpoint_nodes = rank_2_nodes
1274
+
1275
+ if not endpoint_nodes:
1276
+ endpoint_nodes = {edge.node1 for edge in qop.ignore_edges if edge.node1}
1277
+
1278
+ if not endpoint_nodes and len(nodes_for_sorting) > 1:
1279
+ virtual_bond_counts = {}
1280
+ virtual_bond_dim_sums = {}
1281
+
1282
+ for node in nodes_for_sorting:
1283
+ virtual_bonds = 0
1284
+ virtual_dim_sum = 0
1285
+
1286
+ for edge in node.edges:
1287
+ if edge not in physical_edges and not edge.is_dangling():
1288
+ virtual_bonds += 1
1289
+ virtual_dim_sum += edge.dimension
1290
+
1291
+ virtual_bond_counts[node] = virtual_bonds
1292
+ virtual_bond_dim_sums[node] = virtual_dim_sum
1293
+
1294
+ min_dim_sum = min(virtual_bond_dim_sums.values())
1295
+ min_dim_nodes = {
1296
+ node
1297
+ for node, dim_sum in virtual_bond_dim_sums.items()
1298
+ if dim_sum == min_dim_sum
1299
+ }
1300
+
1301
+ if len(min_dim_nodes) == 2:
1302
+ endpoint_nodes = min_dim_nodes
1303
+
1304
+ if not endpoint_nodes:
1305
+ if len(nodes_for_sorting) == 1:
1306
+ raise ValueError("Cannot determine chain structure: only one node found.")
1307
+ elif len(nodes_for_sorting) >= 2:
1308
+ raise ValueError(f"Cannot identify endpoint nodes for your nodes.")
1309
+
1310
+ # Sort nodes along the chain
1311
+ sorted_nodes: list[Node] = []
1312
+ if endpoint_nodes and len(endpoint_nodes) >= 1:
1313
+ current = next(iter(endpoint_nodes))
1314
+ while current and len(sorted_nodes) < nwires:
1315
+ sorted_nodes.append(current)
1316
+ current = next(
1317
+ (
1318
+ e.node2 if e.node1 is current else e.node1
1319
+ for e in current.edges
1320
+ if not e.is_dangling()
1321
+ and e not in physical_edges
1322
+ and (e.node2 if e.node1 is current else e.node1) not in sorted_nodes
1323
+ ),
1324
+ None,
1325
+ )
1326
+
1327
+ if not sorted_nodes:
1328
+ raise ValueError("No valid chain structure found in the QuOperator. ")
1329
+ if len(sorted_nodes) > 0 and len(qop.ignore_edges) > 0:
1330
+ if sorted_nodes[0] is not qop.ignore_edges[0].node1:
1331
+ sorted_nodes = sorted_nodes[::-1]
1332
+
1333
+ return sorted_nodes, is_mps, nwires
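A minimal sketch of what this helper returns, assuming a small hand-built MPS wrapped with the ``tn2qop`` converter defined later in this module:

import numpy as np
import tensornetwork as tn
from tensorcircuit.quantum import tn2qop, extract_tensors_from_qop

# 3-site random MPS with explicit dummy boundary bonds of dimension 1
tensors = [np.random.rand(1, 2, 3), np.random.rand(3, 2, 3), np.random.rand(3, 2, 1)]
qop = tn2qop(tn.FiniteMPS(tensors, canonicalize=False))
nodes, is_mps, nwires = extract_tensors_from_qop(qop)
assert is_mps and nwires == 3 and len(nodes) == 3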
1334
+
1335
+
1336
+ def tenpy2qop(tenpy_obj: Any) -> QuOperator:
1337
+ """
1338
+ Convert a TeNPy MPO or MPS to a TensorCircuit QuOperator.
1339
+ Axis ordering and boundary conditions are handled so that the result
1340
+ is compatible with ``eval_matrix``.
1341
+
1342
+ :param tenpy_obj: An MPO or MPS object from the TeNPy package.
1343
+ :type tenpy_obj: Union[tenpy.networks.mpo.MPO, tenpy.networks.mps.MPS]
1344
+ :return: The corresponding state or operator as a QuOperator.
1161
1345
  :rtype: QuOperator
1162
1346
  """
1163
- tn_mpo = tn_mpo.tensors
1164
- nwires = len(tn_mpo)
1165
- mpo = []
1166
- for i in range(nwires):
1167
- mpo.append(Node(tn_mpo[i]))
1347
+ # MPO objects have _W attribute containing tensor list (documented in tenpy.networks.mpo.MPO)
1348
+ # MPS objects have _B attribute containing tensor list (documented in tenpy.networks.mps.MPS)
1349
+ # These are internal attributes that store the actual tensor data for each site
1350
+ # Reference: https://tenpy.readthedocs.io/en/latest/reference/tenpy.networks.mpo.html
1351
+ # https://tenpy.readthedocs.io/en/latest/reference/tenpy.networks.mps.html
1352
+ is_mpo = hasattr(tenpy_obj, "_W")
1353
+ tenpy_tensors = tenpy_obj._W if is_mpo else tenpy_obj._B
1354
+ nwires = len(tenpy_tensors)
1355
+ if nwires == 0:
1356
+ return quantum_constructor([], [], [])
1357
+
1358
+ nodes = []
1359
+ if is_mpo:
1360
+ original_tensors_obj = tenpy_tensors
1361
+
1362
+ for i, W_obj in enumerate(original_tensors_obj):
1363
+ arr = W_obj.to_ndarray()
1364
+ labels = W_obj.get_leg_labels()
1365
+ wL_idx = labels.index("wL")
1366
+ p_idx = labels.index("p")
1367
+ p_star_idx = labels.index("p*")
1368
+ wR_idx = labels.index("wR")
1369
+
1370
+ arr_reordered = arr.transpose((wL_idx, p_idx, p_star_idx, wR_idx))
1371
+ if nwires == 1:
1372
+ arr_reordered = arr_reordered[[0], :, :, :]
1373
+ arr_reordered = arr_reordered[:, :, :, [-1]]
1374
+ else:
1375
+ if i == 0:
1376
+ arr_reordered = arr_reordered[[0], :, :, :]
1377
+ elif i == nwires - 1:
1378
+ arr_reordered = arr_reordered[:, :, :, [-1]]
1379
+
1380
+ node = Node(
1381
+ arr_reordered, name=f"mpo_{i}", axis_names=["wL", "p", "p*", "wR"]
1382
+ )
1383
+ nodes.append(node)
1384
+
1385
+ if nwires > 1:
1386
+ for i in range(nwires - 1):
1387
+ nodes[i][3] ^ nodes[i + 1][0]
1388
+
1389
+ out_edges = [n[2] for n in nodes]
1390
+ in_edges = [n[1] for n in nodes]
1391
+ ignore_edges = [nodes[0][0], nodes[-1][3]]
1392
+ else: # MPS
1393
+ for i in range(nwires):
1394
+ B_obj = tenpy_obj.get_B(i)
1395
+ arr = B_obj.to_ndarray()
1396
+ labels = B_obj.get_leg_labels()
1397
+ vL_idx = labels.index("vL")
1398
+ p_idx = labels.index("p")
1399
+ vR_idx = labels.index("vR")
1400
+ arr_reordered = arr.transpose((vL_idx, p_idx, vR_idx))
1401
+ node = Node(arr_reordered, name=f"mps_{i}", axis_names=["vL", "p", "vR"])
1402
+ nodes.append(node)
1403
+
1404
+ if nwires > 1:
1405
+ for i in range(nwires - 1):
1406
+ nodes[i][2] ^ nodes[i + 1][0]
1407
+
1408
+ out_edges = [n[1] for n in nodes]
1409
+ in_edges = []
1410
+ ignore_edges = [nodes[0][0], nodes[-1][2]]
1411
+
1412
+ qop = quantum_constructor(out_edges, in_edges, [], ignore_edges)
1168
1413
 
1169
- for i in range(nwires - 1):
1170
- connect(mpo[i][1], mpo[i + 1][0])
1171
- # TODO(@refraction-ray): whether in and out edge is in the correct order require further check
1172
- qop = quantum_constructor(
1173
- [mpo[i][-1] for i in range(nwires)], # out_edges
1174
- [mpo[i][-2] for i in range(nwires)], # in_edges
1175
- [],
1176
- [mpo[0][0], mpo[-1][1]], # ignore_edges
1177
- )
1178
1414
  return qop
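A rough usage sketch (not taken from the package's tests), assuming TeNPy is installed and using its transverse-field Ising chain model to produce an MPO; ``TFIChain`` and ``H_MPO`` are TeNPy APIs:

from tenpy.models.tf_ising import TFIChain
from tensorcircuit.quantum import tenpy2qop

model = TFIChain({"L": 3, "J": 1.0, "g": 0.5, "bc_MPS": "finite", "conserve": None})
qop = tenpy2qop(model.H_MPO)     # TeNPy MPO -> QuOperator
h = qop.eval_matrix()            # dense 8 x 8 Hamiltonian matrix
print(h.shape)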
1179
1415
 
1180
1416
 
1417
+ def qop2tenpy(qop: QuOperator) -> Any:
1418
+ """
1419
+ Convert TensorCircuit QuOperator to MPO or MPS from TeNPy.
1420
+
1421
+ Requirements: QuOperator must represent valid MPS/MPO structure:
1422
+ - Linear chain topology with open boundaries only
1423
+ - MPS: no input edges, consistent virtual bonds, rank-3 tensors (or rank-4 with empty input edges)
1424
+ - MPO: equal input/output edges, rank-4 tensors
1425
+ - Cyclic boundary conditions NOT supported
1426
+
1427
+ :param qop: The corresponding state/operator as a QuOperator.
1428
+ :type qop: QuOperator
1429
+ :return: MPO or MPS object from the TeNPy package.
1430
+ :rtype: Union[tenpy.networks.mpo.MPO, tenpy.networks.mps.MPS]
1431
+ """
1432
+ try:
1433
+ from tenpy.networks import MPO, MPS, Site
1434
+ from tenpy.linalg import np_conserved as npc
1435
+ from tenpy.linalg import LegCharge
1436
+ except ImportError:
1437
+ raise ImportError("Please install TeNPy package to use this function.")
1438
+
1439
+ sorted_nodes, is_mps, nwires = extract_tensors_from_qop(qop)
1440
+
1441
+ physical_dim = qop.out_edges[0].dimension if is_mps else qop.in_edges[0].dimension
1442
+ sites = [Site(LegCharge.from_trivial(physical_dim), "q") for _ in range(nwires)]
1443
+
1444
+ # MPS Conversion
1445
+ if is_mps:
1446
+ tensors = []
1447
+ for i, node in enumerate(sorted_nodes):
1448
+ tensor = np.asarray(node.tensor)
1449
+ if tensor.ndim == 3:
1450
+ if i == 0:
1451
+ if tensor.shape[0] > 1:
1452
+ tensor = tensor[0:1, :, :]
1453
+ elif i == len(sorted_nodes) - 1:
1454
+ if tensor.shape[2] > 1:
1455
+ tensor = tensor[:, :, 0:1]
1456
+ tensors.append(
1457
+ npc.Array.from_ndarray(
1458
+ tensor,
1459
+ legcharges=[LegCharge.from_trivial(s) for s in tensor.shape],
1460
+ labels=["vL", "p", "vR"],
1461
+ )
1462
+ )
1463
+
1464
+ SVs = (
1465
+ [np.ones([1])]
1466
+ + [np.ones(tensors[i].get_leg("vR").ind_len) for i in range(nwires - 1)]
1467
+ + [np.ones([1])]
1468
+ )
1469
+ return MPS(sites, tensors, SVs, bc="finite")
1470
+
1471
+ # MPO Conversion
1472
+ raw_tensors = [np.asarray(node.tensor) for node in sorted_nodes]
1473
+
1474
+ if nwires == 1:
1475
+ chi = 1
1476
+ IdL = IdR = 0
1477
+ reconstructed_tensors = raw_tensors
1478
+ else:
1479
+ chi = max(
1480
+ raw_tensors[0].shape[3] if raw_tensors[0].ndim > 3 else 1,
1481
+ raw_tensors[-1].shape[0] if raw_tensors[-1].ndim > 3 else 1,
1482
+ )
1483
+ IdL = 0
1484
+ IdR = chi - 1 if chi > 1 else 0
1485
+
1486
+ reconstructed_tensors = []
1487
+ for i, tensor in enumerate(raw_tensors):
1488
+ if i == 0 and tensor.shape[0] < chi:
1489
+ new_shape = (chi,) + tensor.shape[1:]
1490
+ padded_tensor = np.zeros(new_shape, dtype=tensor.dtype)
1491
+ padded_tensor[IdL, ...] = tensor[0, ...]
1492
+ reconstructed_tensors.append(padded_tensor)
1493
+ elif i == nwires - 1 and len(tensor.shape) > 3 and tensor.shape[3] < chi:
1494
+ new_shape = tensor.shape[:3] + (chi,)
1495
+ padded_tensor = np.zeros(new_shape, dtype=tensor.dtype)
1496
+ padded_tensor[..., IdR] = tensor[..., 0]
1497
+ reconstructed_tensors.append(padded_tensor)
1498
+ else:
1499
+ reconstructed_tensors.append(tensor)
1500
+
1501
+ tenpy_Ws = []
1502
+ for tensor in reconstructed_tensors:
1503
+ labels = ["wL", "wR", "p", "p*"]
1504
+ tensor = np.transpose(tensor, (0, 3, 1, 2))
1505
+ tenpy_Ws.append(
1506
+ npc.Array.from_ndarray(
1507
+ tensor,
1508
+ legcharges=[LegCharge.from_trivial(s) for s in tensor.shape],
1509
+ labels=labels,
1510
+ )
1511
+ )
1512
+
1513
+ return MPO(sites, tenpy_Ws, bc="finite", IdL=IdL, IdR=IdR)
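A rough round-trip sketch under the same assumptions as above (TeNPy installed, ``TFIChain`` model); agreement of the dense matrices is the intent, not a tested guarantee:

import numpy as np
from tenpy.models.tf_ising import TFIChain
from tensorcircuit.quantum import tenpy2qop, qop2tenpy

model = TFIChain({"L": 3, "J": 1.0, "g": 0.5, "bc_MPS": "finite", "conserve": None})
qop = tenpy2qop(model.H_MPO)
mpo_back = qop2tenpy(qop)        # QuOperator -> TeNPy MPO
qop_back = tenpy2qop(mpo_back)
print(np.allclose(qop.eval_matrix(), qop_back.eval_matrix(), atol=1e-6))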
1514
+
1515
+
1181
1516
  def quimb2qop(qb_mpo: Any) -> QuOperator:
1182
1517
  """
1183
1518
  Convert MPO in Quimb package to QuOperator.
@@ -1219,6 +1554,152 @@ def quimb2qop(qb_mpo: Any) -> QuOperator:
1219
1554
  return qop
1220
1555
 
1221
1556
 
1557
+ def qop2quimb(qop: QuOperator) -> Any:
1558
+ """
1559
+ Convert QuOperator to MPO or MPS in Quimb package.
1560
+
1561
+ Requirements: QuOperator must represent valid MPS/MPO structure:
1562
+ - Linear chain topology with open boundaries only
1563
+ - MPS: no input edges, consistent virtual bonds between adjacent tensors
1564
+ - MPO: equal input/output edges, rank-4 tensors
1565
+ - Edge connectivity: each internal node connected to exactly 2 neighbors
1566
+ - Cyclic boundary conditions NOT supported
1567
+
1568
+ :param qop: MPO in the form of QuOperator
1569
+ :type qop: QuOperator
1570
+ :return: MPO in the form of Quimb package
1571
+ :rtype: quimb.tensor.tensor_gen.MatrixProductOperator
1572
+ """
1573
+ try:
1574
+ import quimb.tensor as qtn
1575
+ except ImportError:
1576
+ raise ImportError("Please install Quimb package to use this function.")
1577
+
1578
+ sorted_nodes, is_mps, _ = extract_tensors_from_qop(qop)
1579
+
1580
+ quimb_tensors = []
1581
+ node_map = {node: i for i, node in enumerate(sorted_nodes)}
1582
+
1583
+ for i, node in enumerate(sorted_nodes):
1584
+ tensor_data = node.tensor
1585
+ inds: List[str] = []
1586
+
1587
+ for axis, edge in enumerate(node.edges):
1588
+ if edge in qop.out_edges:
1589
+ site_index = qop.out_edges.index(edge)
1590
+ inds.append(f"k{site_index}")
1591
+ elif edge in qop.in_edges and not is_mps:
1592
+ site_index = qop.in_edges.index(edge)
1593
+ inds.append(f"b{site_index}")
1594
+ elif edge in qop.ignore_edges:
1595
+ if i == 0:
1596
+ inds.append("_left_dangling")
1597
+ elif i == len(sorted_nodes) - 1:
1598
+ inds.append("_right_dangling")
1599
+ else:
1600
+ inds.append(f"_ignore_{i}_{axis}")
1601
+ else:
1602
+ neighbor = edge.node1 if edge.node2 is node else edge.node2
1603
+ if neighbor in node_map:
1604
+ j = node_map[neighbor]
1605
+ left, right = min(i, j), max(i, j)
1606
+ inds.append(f"v{left}_{right}")
1607
+ else:
1608
+ inds.append(f"_unconnected_{i}_{axis}")
1609
+
1610
+ quimb_tensors.append(qtn.Tensor(tensor_data, inds=inds, tags=f"I{i}"))
1611
+
1612
+ tn = qtn.TensorNetwork(quimb_tensors)
1613
+
1614
+ if is_mps:
1615
+ return tn.as_network(qtn.MatrixProductState)
1616
+ else:
1617
+ return tn.as_network(qtn.MatrixProductOperator)
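A usage sketch assuming quimb is installed; ``MPO_ham_heis`` is a quimb constructor, and ``quimb2qop`` is the existing converter in the other direction defined above:

import quimb.tensor as qtn
from tensorcircuit.quantum import quimb2qop, qop2quimb

qb_mpo = qtn.MPO_ham_heis(4)     # 4-site Heisenberg MPO from quimb
qop = quimb2qop(qb_mpo)          # quimb MPO -> QuOperator
qb_back = qop2quimb(qop)         # QuOperator -> quimb MPO-like network
print(qb_back)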
1618
+
1619
+
1620
+ def tn2qop(tn_obj: Any) -> QuOperator:
1621
+ """
1622
+ Convert MPO or MPS in TensorNetwork package to QuOperator.
1623
+
1624
+ :param tn_obj: MPO or MPS in the form of TensorNetwork package
1625
+ :type tn_obj: ``tn.matrixproductstates.*``
1626
+ :return: MPO or MPS in the form of QuOperator
1627
+ :rtype: QuOperator
1628
+ """
1629
+ tn_tensors = tn_obj.tensors
1630
+ nwires = len(tn_tensors)
1631
+
1632
+ if nwires == 0:
1633
+ return quantum_constructor([], [], [])
1634
+
1635
+ is_mps = all(len(t.shape) <= 3 for t in tn_tensors)
1636
+
1637
+ nodes = []
1638
+ for i in range(nwires):
1639
+ nodes.append(Node(tn_tensors[i], name=f"tensor_{i}"))
1640
+
1641
+ if is_mps:
1642
+ for i in range(nwires - 1):
1643
+ connect(nodes[i][-1], nodes[i + 1][0])
1644
+
1645
+ out_edges = []
1646
+ for i, node in enumerate(nodes):
1647
+ if len(node.edges) == 2:
1648
+ physical_edge = next(e for e in node.edges if e.is_dangling())
1649
+ out_edges.append(physical_edge)
1650
+ else:
1651
+ out_edges.append(node[1])
1652
+
1653
+ in_edges = []
1654
+
1655
+ ignore_edges = []
1656
+ left_dangling = next(
1657
+ (e for e in nodes[0].edges if e.is_dangling() and e not in out_edges), None
1658
+ )
1659
+ right_dangling = next(
1660
+ (e for e in nodes[-1].edges if e.is_dangling() and e not in out_edges), None
1661
+ )
1662
+
1663
+ if left_dangling:
1664
+ ignore_edges.append(left_dangling)
1665
+ if right_dangling:
1666
+ ignore_edges.append(right_dangling)
1667
+
1668
+ else:
1669
+ for i in range(nwires - 1):
1670
+ connect(nodes[i][1], nodes[i + 1][0])
1671
+
1672
+ out_edges = [nodes[i][-1] for i in range(nwires)]
1673
+ in_edges = [nodes[i][-2] for i in range(nwires)]
1674
+ ignore_edges = [nodes[0][0], nodes[-1][1]]
1675
+
1676
+ qop = quantum_constructor(
1677
+ out_edges,
1678
+ in_edges,
1679
+ [],
1680
+ ignore_edges,
1681
+ )
1682
+ return qop
1683
+
1684
+
1685
+ def qop2tn(qop: QuOperator) -> Any:
1686
+ """
1687
+ Convert QuOperator back to MPO or MPS in TensorNetwork package.
1688
+
1689
+ :param qop: MPO or MPS in the form of QuOperator, param in docstring
1690
+ :return: MPO or MPS in the form of TensorNetwork
1691
+ :rtype: Union[tn.matrixproductstates.MPO, tn.matrixproductstates.MPS]
1692
+ """
1693
+ sorted_nodes, is_mps, _ = extract_tensors_from_qop(qop)
1694
+
1695
+ tensors = [node.tensor for node in sorted_nodes]
1696
+
1697
+ if is_mps:
1698
+ return tn.FiniteMPS(tensors, canonicalize=False)
1699
+ else:
1700
+ return tn.matrixproductstates.mpo.FiniteMPO(tensors)
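A round-trip sketch staying entirely within the TensorNetwork package, building on the same small random MPS used earlier (``FiniteMPS(..., canonicalize=False)`` is the same constructor ``qop2tn`` itself uses):

import numpy as np
import tensornetwork as tn
from tensorcircuit.quantum import tn2qop, qop2tn

tensors = [np.random.rand(1, 2, 3), np.random.rand(3, 2, 3), np.random.rand(3, 2, 1)]
mps = tn.FiniteMPS(tensors, canonicalize=False)
qop = tn2qop(mps)                # TensorNetwork MPS -> QuOperator
mps_back = qop2tn(qop)           # QuOperator -> tn.FiniteMPS
print([t.shape for t in mps_back.tensors])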
1701
+
1702
+
1222
1703
  # TODO(@refraction-ray): Z2 analogy or more general analogies for the following u1 functions
1223
1704
 
1224
1705
 
@@ -1433,7 +1914,7 @@ def PauliStringSum2Dense(
1433
1914
  return sparsem.todense()
1434
1915
  sparsem = backend.coo_sparse_matrix_from_numpy(sparsem)
1435
1916
  densem = backend.to_dense(sparsem)
1436
- return densem
1917
+ return backend.convert_to_tensor(densem)
1437
1918
 
1438
1919
 
1439
1920
  # already implemented as backend method
@@ -1755,7 +2236,10 @@ def entanglement_entropy(state: Tensor, cut: Union[int, List[int]]) -> Tensor:
1755
2236
 
1756
2237
 
1757
2238
  def reduced_wavefunction(
1758
- state: Tensor, cut: List[int], measure: Optional[List[int]] = None
2239
+ state: Tensor,
2240
+ cut: List[int],
2241
+ measure: Optional[List[int]] = None,
2242
+ dim: Optional[int] = None,
1759
2243
  ) -> Tensor:
1760
2244
  """
1761
2245
  Compute the reduced wavefunction from the quantum state ``state``.
@@ -1770,20 +2254,22 @@ def reduced_wavefunction(
1770
2254
  :type measure: List[int]
1771
2255
  :return: _description_
1772
2256
  :rtype: Tensor
2257
+ :param dim: dimension of qudit system
2258
+ :type dim: int
1773
2259
  """
2260
+ dim = 2 if dim is None else dim
1774
2261
  if measure is None:
1775
2262
  measure = [0 for _ in cut]
1776
- s = backend.reshape2(state)
2263
+ s = backend.reshaped(state, dim)
1777
2264
  n = len(backend.shape_tuple(s))
1778
2265
  s_node = Gate(s)
1779
2266
  end_nodes = []
1780
2267
  for c, m in zip(cut, measure):
1781
- rt = backend.cast(backend.convert_to_tensor(1 - m), dtypestr) * backend.cast(
1782
- backend.convert_to_tensor(np.array([1.0, 0.0])), dtypestr
1783
- ) + backend.cast(backend.convert_to_tensor(m), dtypestr) * backend.cast(
1784
- backend.convert_to_tensor(np.array([0.0, 1.0])), dtypestr
2268
+ oh = backend.cast(
2269
+ backend.one_hot(backend.cast(backend.convert_to_tensor(m), "int32"), dim),
2270
+ dtypestr,
1785
2271
  )
1786
- end_node = Gate(rt)
2272
+ end_node = Gate(backend.convert_to_tensor(oh))
1787
2273
  end_nodes.append(end_node)
1788
2274
  s_node[c] ^ end_node[0]
1789
2275
  new_node = contractor(
@@ -1798,8 +2284,9 @@ def reduced_density_matrix(
1798
2284
  cut: Union[int, List[int]],
1799
2285
  p: Optional[Tensor] = None,
1800
2286
  normalize: bool = True,
2287
+ dim: Optional[int] = None,
1801
2288
  ) -> Union[Tensor, QuOperator]:
1802
- """
2289
+ r"""
1803
2290
  Compute the reduced density matrix from the quantum state ``state``.
1804
2291
 
1805
2292
  :param state: The quantum state in form of Tensor or QuOperator.
@@ -1811,8 +2298,12 @@ def reduced_density_matrix(
1811
2298
  :type p: Optional[Tensor]
1812
2299
  :return: The reduced density matrix.
1813
2300
  :rtype: Union[Tensor, QuOperator]
1814
- :normalize: if True, returns a trace 1 density matrix. Otherwise does not normalize.
2301
+ :param normalize: if True, returns a trace 1 density matrix. Otherwise, does not normalize.
2302
+ :type normalize: bool
2303
+ :param dim: dimension of qudit system
2304
+ :type dim: int
1815
2305
  """
2306
+ dim = 2 if dim is None else dim
1816
2307
  if isinstance(cut, list) or isinstance(cut, tuple) or isinstance(cut, set):
1817
2308
  traceout = list(cut)
1818
2309
  else:
@@ -1825,21 +2316,19 @@ def reduced_density_matrix(
1825
2316
  return state.partial_trace(traceout)
1826
2317
  if len(state.shape) == 2 and state.shape[0] == state.shape[1]:
1827
2318
  # density operator
1828
- freedomexp = backend.sizen(state)
1829
- # traceout = sorted(traceout)[::-1]
1830
- freedom = int(np.log2(freedomexp) / 2)
1831
- # traceout2 = [i + freedom for i in traceout]
2319
+ freedom = _infer_num_sites(state.shape[0], dim)
1832
2320
  left = traceout + [i for i in range(freedom) if i not in traceout]
1833
2321
  right = [i + freedom for i in left]
1834
- rho = backend.reshape(state, [2 for _ in range(2 * freedom)])
2322
+
2323
+ rho = backend.reshape(state, [dim] * (2 * freedom))
1835
2324
  rho = backend.transpose(rho, perm=left + right)
1836
2325
  rho = backend.reshape(
1837
2326
  rho,
1838
2327
  [
1839
- 2 ** len(traceout),
1840
- 2 ** (freedom - len(traceout)),
1841
- 2 ** len(traceout),
1842
- 2 ** (freedom - len(traceout)),
2328
+ dim ** len(traceout),
2329
+ dim ** (freedom - len(traceout)),
2330
+ dim ** len(traceout),
2331
+ dim ** (freedom - len(traceout)),
1843
2332
  ],
1844
2333
  )
1845
2334
  if p is None:
@@ -1852,20 +2341,20 @@ def reduced_density_matrix(
1852
2341
  p = backend.reshape(p, [-1])
1853
2342
  rho = backend.einsum("a,aiaj->ij", p, rho)
1854
2343
  rho = backend.reshape(
1855
- rho, [2 ** (freedom - len(traceout)), 2 ** (freedom - len(traceout))]
2344
+ rho, [dim ** (freedom - len(traceout)), dim ** (freedom - len(traceout))]
1856
2345
  )
1857
2346
  if normalize:
1858
2347
  rho /= backend.trace(rho)
1859
2348
 
1860
2349
  else:
1861
2350
  w = state / backend.norm(state)
1862
- freedomexp = backend.sizen(state)
1863
- freedom = int(np.log(freedomexp) / np.log(2))
2351
+ size = int(backend.sizen(state))
2352
+ freedom = _infer_num_sites(size, dim)
1864
2353
  perm = [i for i in range(freedom) if i not in traceout]
1865
2354
  perm = perm + traceout
1866
- w = backend.reshape(w, [2 for _ in range(freedom)])
2355
+ w = backend.reshape(w, [dim for _ in range(freedom)])
1867
2356
  w = backend.transpose(w, perm=perm)
1868
- w = backend.reshape(w, [-1, 2 ** len(traceout)])
2357
+ w = backend.reshape(w, [-1, dim ** len(traceout)])
1869
2358
  if p is None:
1870
2359
  rho = w @ backend.adjoint(w)
1871
2360
  else:
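A qutrit sketch of the new ``dim`` argument on the numpy backend (maximally entangled pair); the same ``dim`` keyword applies to ``reduced_wavefunction`` above:

import numpy as np
from tensorcircuit.quantum import reduced_density_matrix

# (|00> + |11> + |22>) / sqrt(3) on two qutrits
state = np.zeros(9, dtype=np.complex64)
state[[0, 4, 8]] = 1.0 / np.sqrt(3.0)
rho_b = reduced_density_matrix(state, cut=[0], dim=3)   # trace out site 0
print(np.round(rho_b, 3))        # maximally mixed qutrit, eye(3) / 3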
@@ -1914,13 +2403,13 @@ def free_energy(
1914
2403
 
1915
2404
  def renyi_entropy(rho: Union[Tensor, QuOperator], k: int = 2) -> Tensor:
1916
2405
  """
1917
- Compute the Rényi entropy of order :math:`k` by given density matrix.
2406
+ Compute the Renyi entropy of order :math:`k` by given density matrix.
1918
2407
 
1919
2408
  :param rho: The density matrix in form of Tensor or QuOperator.
1920
2409
  :type rho: Union[Tensor, QuOperator]
1921
- :param k: The order of Rényi entropy, default is 2.
2410
+ :param k: The order of Renyi entropy, default is 2.
1922
2411
  :type k: int, optional
1923
- :return: The :math:`k` th order of Rényi entropy.
2412
+ :return: The :math:`k` th order of Renyi entropy.
1924
2413
  :rtype: Tensor
1925
2414
  """
1926
2415
  s = 1 / (1 - k) * backend.real(backend.log(trace_product(*[rho for _ in range(k)])))
@@ -1934,7 +2423,7 @@ def renyi_free_energy(
1934
2423
  k: int = 2,
1935
2424
  ) -> Tensor:
1936
2425
  """
1937
- Compute the Rényi free energy of the corresponding density matrix and Hamiltonian.
2426
+ Compute the Renyi free energy of the corresponding density matrix and Hamiltonian.
1938
2427
 
1939
2428
  :Example:
1940
2429
 
@@ -1951,9 +2440,9 @@ def renyi_free_energy(
1951
2440
  :type h: Union[Tensor, QuOperator]
1952
2441
  :param beta: Constant for the optimization, default is 1.
1953
2442
  :type beta: float, optional
1954
- :param k: The order of Rényi entropy, default is 2.
2443
+ :param k: The order of Renyi entropy, default is 2.
1955
2444
  :type k: int, optional
1956
- :return: The :math:`k` th order of Rényi entropy.
2445
+ :return: The :math:`k` th order of Renyi entropy.
1957
2446
  :rtype: Tensor
1958
2447
  """
1959
2448
  energy = backend.real(trace_product(rho, h))
@@ -2008,7 +2497,9 @@ def truncated_free_energy(
2008
2497
 
2009
2498
 
2010
2499
  @op2tensor
2011
- def partial_transpose(rho: Tensor, transposed_sites: List[int]) -> Tensor:
2500
+ def partial_transpose(
2501
+ rho: Tensor, transposed_sites: List[int], dim: Optional[int] = None
2502
+ ) -> Tensor:
2012
2503
  """
2013
2504
  _summary_
2014
2505
 
@@ -2016,10 +2507,13 @@ def partial_transpose(rho: Tensor, transposed_sites: List[int]) -> Tensor:
2016
2507
  :type rho: Tensor
2017
2508
  :param transposed_sites: sites int list to be transposed
2018
2509
  :type transposed_sites: List[int]
2510
+ :param dim: dimension of qudit system
2511
+ :type dim: int
2019
2512
  :return: _description_
2020
2513
  :rtype: Tensor
2021
2514
  """
2022
- rho = backend.reshape2(rho)
2515
+ dim = 2 if dim is None else dim
2516
+ rho = backend.reshaped(rho, dim)
2023
2517
  rho_node = Gate(rho)
2024
2518
  n = len(rho.shape) // 2
2025
2519
  left_edges = []
@@ -2037,7 +2531,9 @@ def partial_transpose(rho: Tensor, transposed_sites: List[int]) -> Tensor:
2037
2531
 
2038
2532
 
2039
2533
  @op2tensor
2040
- def entanglement_negativity(rho: Tensor, transposed_sites: List[int]) -> Tensor:
2534
+ def entanglement_negativity(
2535
+ rho: Tensor, transposed_sites: List[int], dim: Optional[int] = None
2536
+ ) -> Tensor:
2041
2537
  """
2042
2538
  _summary_
2043
2539
 
@@ -2045,17 +2541,21 @@ def entanglement_negativity(rho: Tensor, transposed_sites: List[int]) -> Tensor:
2045
2541
  :type rho: Tensor
2046
2542
  :param transposed_sites: _description_
2047
2543
  :type transposed_sites: List[int]
2544
+ :param dim: dimension of qudit system
2545
+ :type dim: int
2048
2546
  :return: _description_
2049
2547
  :rtype: Tensor
2050
2548
  """
2051
- rhot = partial_transpose(rho, transposed_sites)
2549
+ rhot = partial_transpose(rho, transposed_sites, dim=dim)
2052
2550
  es = backend.eigvalsh(rhot)
2053
2551
  rhot_m = backend.sum(backend.abs(es))
2054
2552
  return (rhot_m - 1.0) / 2.0
2055
2553
 
2056
2554
 
2057
2555
  @op2tensor
2058
- def log_negativity(rho: Tensor, transposed_sites: List[int], base: str = "e") -> Tensor:
2556
+ def log_negativity(
2557
+ rho: Tensor, transposed_sites: List[int], base: str = "e", dim: Optional[int] = None
2558
+ ) -> Tensor:
2059
2559
  """
2060
2560
  _summary_
2061
2561
 
@@ -2065,10 +2565,13 @@ def log_negativity(rho: Tensor, transposed_sites: List[int], base: str = "e") ->
2065
2565
  :type transposed_sites: List[int]
2066
2566
  :param base: whether use 2 based log or e based log, defaults to "e"
2067
2567
  :type base: str, optional
2568
+ :param dim: dimension of qudit system
2569
+ :type dim: int
2068
2570
  :return: _description_
2069
2571
  :rtype: Tensor
2070
2572
  """
2071
- rhot = partial_transpose(rho, transposed_sites)
2573
+ dim = 2 if dim is None else dim
2574
+ rhot = partial_transpose(rho, transposed_sites, dim)
2072
2575
  es = backend.eigvalsh(rhot)
2073
2576
  rhot_m = backend.sum(backend.abs(es))
2074
2577
  een = backend.log(rhot_m)
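A sketch of the qudit-aware negativity helpers (default natural-log base, numpy backend):

import numpy as np
from tensorcircuit.quantum import log_negativity

psi = np.zeros(9, dtype=np.complex64)
psi[[0, 4, 8]] = 1.0 / np.sqrt(3.0)          # maximally entangled qutrit pair
rho = np.outer(psi, np.conj(psi))
print(log_negativity(rho, [0], dim=3))        # ~ ln 3 ~ 1.0986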
@@ -2154,7 +2657,9 @@ def double_state(h: Tensor, beta: float = 1) -> Tensor:
2154
2657
 
2155
2658
 
2156
2659
  @op2tensor
2157
- def mutual_information(s: Tensor, cut: Union[int, List[int]]) -> Tensor:
2660
+ def mutual_information(
2661
+ s: Tensor, cut: Union[int, List[int]], dim: Optional[int] = None
2662
+ ) -> Tensor:
2158
2663
  """
2159
2664
  Mutual information between AB subsystem described by ``cut``.
2160
2665
 
@@ -2162,9 +2667,12 @@ def mutual_information(s: Tensor, cut: Union[int, List[int]]) -> Tensor:
2162
2667
  :type s: Tensor
2163
2668
  :param cut: The AB subsystem.
2164
2669
  :type cut: Union[int, List[int]]
2670
+ :param dim: local dimension of each site (2 for qubits), defaults to 2
2671
+ :type dim: int, optional
2165
2672
  :return: The mutual information between AB subsystem described by ``cut``.
2166
2673
  :rtype: Tensor
2167
2674
  """
2675
+ dim = 2 if dim is None else dim
2168
2676
  if isinstance(cut, list) or isinstance(cut, tuple) or isinstance(cut, set):
2169
2677
  traceout = list(cut)
2170
2678
  else:
@@ -2172,22 +2680,22 @@ def mutual_information(s: Tensor, cut: Union[int, List[int]]) -> Tensor:
2172
2680
 
2173
2681
  if len(s.shape) == 2 and s.shape[0] == s.shape[1]:
2174
2682
  # mixed state
2175
- n = int(np.log2(backend.sizen(s)) / 2)
2683
+ n = _infer_num_sites(s.shape[0], dim=dim)
2176
2684
  hab = entropy(s)
2177
2685
 
2178
2686
  # subsystem a
2179
- rhoa = reduced_density_matrix(s, traceout)
2687
+ rhoa = reduced_density_matrix(s, traceout, dim=dim)
2180
2688
  ha = entropy(rhoa)
2181
2689
 
2182
2690
  # need subsystem b as well
2183
2691
  other = tuple(i for i in range(n) if i not in traceout)
2184
- rhob = reduced_density_matrix(s, other) # type: ignore
2692
+ rhob = reduced_density_matrix(s, other, dim=dim) # type: ignore
2185
2693
  hb = entropy(rhob)
2186
2694
 
2187
2695
  # pure system
2188
2696
  else:
2189
2697
  hab = 0.0
2190
- rhoa = reduced_density_matrix(s, traceout)
2698
+ rhoa = reduced_density_matrix(s, traceout, dim=dim)
2191
2699
  ha = hb = entropy(rhoa)
2192
2700
 
2193
2701
  return ha + hb - hab
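A sketch of ``mutual_information`` with the new ``dim`` argument, reusing the pure two-qutrit state from the sketches above:

import numpy as np
from tensorcircuit.quantum import mutual_information

psi = np.zeros(9, dtype=np.complex64)
psi[[0, 4, 8]] = 1.0 / np.sqrt(3.0)           # (|00> + |11> + |22>) / sqrt(3)
print(mutual_information(psi, [0], dim=3))    # 2 * ln 3 for the pure entangled pair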
@@ -2196,7 +2704,9 @@ def mutual_information(s: Tensor, cut: Union[int, List[int]]) -> Tensor:
2196
2704
  # measurement results and transformations and correlations below
2197
2705
 
2198
2706
 
2199
- def count_s2d(srepr: Tuple[Tensor, Tensor], n: int) -> Tensor:
2707
+ def count_s2d(
2708
+ srepr: Tuple[Tensor, Tensor], n: int, dim: Optional[int] = None
2709
+ ) -> Tensor:
2200
2710
  """
2201
2711
  measurement shots results, sparse tuple representation to dense representation
2202
2712
  count_vector to count_tuple
@@ -2205,11 +2715,14 @@ def count_s2d(srepr: Tuple[Tensor, Tensor], n: int) -> Tensor:
2205
2715
  :type srepr: Tuple[Tensor, Tensor]
2206
2716
  :param n: number of qubits
2207
2717
  :type n: int
2718
+ :param dim: local dimension per site, defaults to 2 when None
2719
+ :type dim: int, optional
2208
2720
  :return: [description]
2209
2721
  :rtype: Tensor
2210
2722
  """
2723
+ dim = 2 if dim is None else dim
2211
2724
  return backend.scatter(
2212
- backend.cast(backend.zeros([2**n]), srepr[1].dtype),
2725
+ backend.cast(backend.zeros([dim**n]), srepr[1].dtype),
2213
2726
  backend.reshape(srepr[0], [-1, 1]),
2214
2727
  srepr[1],
2215
2728
  )
@@ -2252,117 +2765,146 @@ def count_d2s(drepr: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
2252
2765
  count_t2v = count_d2s
2253
2766
 
2254
2767
 
2255
- def sample_int2bin(sample: Tensor, n: int) -> Tensor:
2768
+ def sample_int2bin(sample: Tensor, n: int, dim: Optional[int] = None) -> Tensor:
2256
2769
  """
2257
- int sample to bin sample
2770
+ Convert linear-index samples to per-site digits (base-d).
2258
2771
 
2259
- :param sample: in shape [trials] of int elements in the range [0, 2**n)
2772
+ :param sample: shape [trials], integers in [0, d**n)
2260
2773
  :type sample: Tensor
2261
- :param n: number of qubits
2774
+ :param n: number of sites
2262
2775
  :type n: int
2263
- :return: in shape [trials, n] of element (0, 1)
2776
+ :param dim: local dimension, defaults to 2
2777
+ :type dim: int, optional
2778
+ :return: shape [trials, n], entries in [0, d-1]
2264
2779
  :rtype: Tensor
2265
2780
  """
2266
- confg = backend.mod(
2267
- backend.right_shift(sample[..., None], backend.reverse(backend.arange(n))),
2268
- 2,
2269
- )
2270
- return confg
2781
+ dim = 2 if dim is None else dim
2782
+ if dim == 2:
2783
+ return backend.mod(
2784
+ backend.right_shift(sample[..., None], backend.reverse(backend.arange(n))),
2785
+ 2,
2786
+ )
2787
+ else:
2788
+ pos = backend.reverse(backend.arange(n))
2789
+ base = backend.power(dim, pos)
2790
+ digits = backend.mod(
2791
+ backend.floor_divide(sample[..., None], base), # ⌊sample / d**pos⌋
2792
+ dim,
2793
+ )
2794
+ return backend.cast(digits, "int32")
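A worked base-d example for the sampling converters (numpy backend): 5 = 1*3 + 2, so the two-site base-3 digits of 5 are [1, 2]:

import numpy as np
import tensorcircuit as tc
from tensorcircuit.quantum import sample_int2bin, sample_bin2int

tc.set_backend("numpy")
sample = np.array([5, 0, 8])
digits = sample_int2bin(sample, n=2, dim=3)      # [[1, 2], [0, 0], [2, 2]]
print(sample_bin2int(digits, n=2, dim=3))        # back to [5, 0, 8]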
2271
2795
 
2272
2796
 
2273
- def sample_bin2int(sample: Tensor, n: int) -> Tensor:
2797
+ def sample_bin2int(sample: Tensor, n: int, dim: Optional[int] = None) -> Tensor:
2274
2798
  """
2275
2799
  bin sample to int sample
2276
2800
 
2277
2801
  :param sample: in shape [trials, n] of elements (0, 1)
2278
2802
  :type sample: Tensor
2279
- :param n: number of qubits
2803
+ :param n: number of sites
2280
2804
  :type n: int
2805
+ :param dim: local dimension, defaults to 2
2806
+ :type dim: int, optional
2281
2807
  :return: in shape [trials]
2282
2808
  :rtype: Tensor
2283
2809
  """
2284
- power = backend.convert_to_tensor([2**j for j in reversed(range(n))])
2810
+ dim = 2 if dim is None else dim
2811
+ power = backend.convert_to_tensor([dim**j for j in reversed(range(n))])
2285
2812
  return backend.sum(sample * power, axis=-1)
2286
2813
 
2287
2814
 
2288
2815
  def sample2count(
2289
- sample: Tensor, n: int, jittable: bool = True
2816
+ sample: Tensor,
2817
+ n: int,
2818
+ jittable: bool = True,
2819
+ dim: Optional[int] = None,
2290
2820
  ) -> Tuple[Tensor, Tensor]:
2291
2821
  """
2292
- sample_int to count_tuple
2822
+ sample_int to count_tuple (indices, counts), size = d**n
2293
2823
 
2294
- :param sample: _description_
2824
+ :param sample: linear-index samples, shape [shots]
2295
2825
  :type sample: Tensor
2296
- :param n: _description_
2826
+ :param n: number of sites
2297
2827
  :type n: int
2298
- :param jittable: _description_, defaults to True
2299
- :type jittable: bool, optional
2300
- :return: _description_
2828
+ :param jittable: whether to return fixed-size outputs (backend dependent)
2829
+ :type jittable: bool
2830
+ :param dim: local dimension per site, default 2 (qubit)
2831
+ :type dim: int, optional
2832
+ :return: (unique_indices, counts)
2301
2833
  :rtype: Tuple[Tensor, Tensor]
2302
2834
  """
2303
- d = 2**n
2835
+ dim = 2 if dim is None else dim
2836
+ size = dim**n
2304
2837
  if not jittable:
2305
2838
  results = backend.unique_with_counts(sample) # non-jittable
2306
- else: # jax specified
2307
- results = backend.unique_with_counts(sample, size=d, fill_value=-1)
2839
+ else: # jax specified / fixed-size
2840
+ results = backend.unique_with_counts(sample, size=size, fill_value=-1)
2308
2841
  return results
2309
2842
 
2310
2843
 
2311
- def count_vector2dict(count: Tensor, n: int, key: str = "bin") -> Dict[Any, int]:
2844
+ def count_vector2dict(
2845
+ count: Tensor, n: int, key: str = "bin", dim: Optional[int] = None
2846
+ ) -> Dict[Any, int]:
2312
2847
  """
2313
- convert_vector to count_dict_bin or count_dict_int
2848
+ Convert count_vector to count_dict_bin or count_dict_int.
2849
+ For d > 10, keys are base-d strings using the characters 0-9A-Z.
2314
2850
 
2315
- :param count: tensor in shape [2**n]
2851
+ :param count: tensor in shape [d**n]
2316
2852
  :type count: Tensor
2317
- :param n: number of qubits
2853
+ :param n: number of sites
2318
2854
  :type n: int
2319
2855
  :param key: can be "int" or "bin", defaults to "bin"
2320
2856
  :type key: str, optional
2321
- :return: _description_
2322
- :rtype: _type_
2857
+ :param dim: local dimension (default 2)
2858
+ :type dim: int, optional
2859
+ :return: mapping from configuration to count
2860
+ :rtype: Dict[Any, int]
2323
2861
  """
2324
2862
  from .interfaces import which_backend
2325
2863
 
2864
+ dim = 2 if dim is None else dim
2326
2865
  b = which_backend(count)
2327
- d = {i: b.numpy(count[i]).item() for i in range(2**n)}
2866
+ out_int = {i: b.numpy(count[i]).item() for i in range(dim**n)}
2328
2867
  if key == "int":
2329
- return d
2868
+ return out_int
2330
2869
  else:
2331
- dn = {}
2332
- for k, v in d.items():
2333
- kn = str(bin(k))[2:].zfill(n)
2334
- dn[kn] = v
2335
- return dn
2870
+ out_str = {}
2871
+ for k, v in out_int.items():
2872
+ kn = np.base_repr(k, base=dim).zfill(n)
2873
+ out_str[kn] = v
2874
+ return out_str
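A sketch of the base-d keyed count dictionary (dim=3, n=2): linear index 5 maps to the label "12":

import numpy as np
from tensorcircuit.quantum import count_vector2dict

counts = np.zeros(9, dtype=np.int32)
counts[0], counts[5] = 3, 7
d = count_vector2dict(counts, n=2, key="bin", dim=3)
print({k: v for k, v in d.items() if v > 0})     # {"00": 3, "12": 7}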
2336
2875
 
2337
2876
 
2338
2877
  def count_tuple2dict(
2339
- count: Tuple[Tensor, Tensor], n: int, key: str = "bin"
2878
+ count: Tuple[Tensor, Tensor], n: int, key: str = "bin", dim: Optional[int] = None
2340
2879
  ) -> Dict[Any, int]:
2341
2880
  """
2342
2881
  count_tuple to count_dict_bin or count_dict_int
2343
2882
 
2344
- :param count: count_tuple format
2883
+ :param count: count_tuple format (indices, counts)
2345
2884
  :type count: Tuple[Tensor, Tensor]
2346
- :param n: number of qubits
2885
+ :param n: number of sites (qubits or qudits)
2347
2886
  :type n: int
2348
2887
  :param key: can be "int" or "bin", defaults to "bin"
2349
2888
  :type key: str, optional
2889
+ :param dim: local dimension, defaults to 2
2890
+ :type dim: int, optional
2350
2891
  :return: count_dict
2351
- :rtype: _type_
2892
+ :rtype: Dict[Any, int]
2352
2893
  """
2353
- d = {
2894
+ dim = 2 if dim is None else dim
2895
+ out_int = {
2354
2896
  backend.numpy(i).item(): backend.numpy(j).item()
2355
2897
  for i, j in zip(count[0], count[1])
2356
2898
  if i >= 0
2357
2899
  }
2358
2900
  if key == "int":
2359
- return d
2901
+ return out_int
2360
2902
  else:
2361
- dn = {}
2362
- for k, v in d.items():
2363
- kn = str(bin(k))[2:].zfill(n)
2364
- dn[kn] = v
2365
- return dn
2903
+ out_str = {}
2904
+ for k, v in out_int.items():
2905
+ kn = np.base_repr(k, base=dim).zfill(n)
2906
+ out_str[kn] = v
2907
+ return out_str
2366
2908
 
2367
2909
 
2368
2910
  @partial(arg_alias, alias_dict={"counts": ["shots"], "format": ["format_"]})
@@ -2374,8 +2916,9 @@ def measurement_counts(
2374
2916
  random_generator: Optional[Any] = None,
2375
2917
  status: Optional[Tensor] = None,
2376
2918
  jittable: bool = False,
2919
+ dim: Optional[int] = None,
2377
2920
  ) -> Any:
2378
- """
2921
+ r"""
2379
2922
  Simulate the measuring of each qubit of ``p`` in the computational basis,
2380
2923
  thus producing output like that of ``qiskit``.
2381
2924
 
@@ -2390,6 +2933,7 @@ def measurement_counts(
2390
2933
  "count_tuple": # (np.array([0]), np.array([2]))
2391
2934
 
2392
2935
  "count_dict_bin": # {"00": 2, "01": 0, "10": 0, "11": 0}
2936
+ (for local dimension d between 11 and 36, digit values 10-35 are rendered as "A"-"Z")
2393
2937
 
2394
2938
  "count_dict_int": # {0: 2, 1: 0, 2: 0, 3: 0}
2395
2939
 
@@ -2441,21 +2985,22 @@ def measurement_counts(
2441
2985
  state /= backend.norm(state)
2442
2986
  pi = backend.real(backend.conj(state) * state)
2443
2987
  pi = backend.reshape(pi, [-1])
2444
- d = int(backend.shape_tuple(pi)[0])
2445
- n = int(np.log(d) / np.log(2) + 1e-8)
2988
+
2989
+ local_d = 2 if dim is None else dim
2990
+ total_dim = int(backend.shape_tuple(pi)[0])
2991
+ n = _infer_num_sites(total_dim, local_d)
2992
+
2446
2993
  if (counts is None) or counts <= 0:
2447
2994
  if format == "count_vector":
2448
2995
  return pi
2449
2996
  elif format == "count_tuple":
2450
2997
  return count_d2s(pi)
2451
2998
  elif format == "count_dict_bin":
2452
- return count_vector2dict(pi, n, key="bin")
2999
+ return count_vector2dict(pi, n, key="bin", dim=local_d)
2453
3000
  elif format == "count_dict_int":
2454
- return count_vector2dict(pi, n, key="int")
3001
+ return count_vector2dict(pi, n, key="int", dim=local_d)
2455
3002
  else:
2456
- raise ValueError(
2457
- "unsupported format %s for analytical measurement" % format
2458
- )
3003
+ raise ValueError(f"unsupported format {format} for analytical measurement")
2459
3004
  else:
2460
3005
  raw_counts = backend.probability_sample(
2461
3006
  counts, pi, status=status, g=random_generator
@@ -2466,7 +3011,7 @@ def measurement_counts(
2466
3011
  # raw_counts = backend.stateful_randc(
2467
3012
  # random_generator, a=drange, shape=counts, p=pi
2468
3013
  # )
2469
- return sample2all(raw_counts, n, format=format, jittable=jittable)
3014
+ return sample2all(raw_counts, n, format=format, jittable=jittable, dim=local_d)
2470
3015
 
2471
3016
 
2472
3017
  measurement_results = measurement_counts
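An end-to-end sketch of ``measurement_counts`` with the new ``dim`` argument on the numpy backend (a single qutrit prepared in |2>):

import numpy as np
import tensorcircuit as tc
from tensorcircuit.quantum import measurement_counts

tc.set_backend("numpy")
state = np.array([0.0, 0.0, 1.0], dtype=np.complex64)
print(measurement_counts(state, counts=16, format="count_dict_bin", dim=3))
# expected: {"2": 16}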
@@ -2474,45 +3019,64 @@ measurement_results = measurement_counts
2474
3019
 
2475
3020
  @partial(arg_alias, alias_dict={"format": ["format_"]})
2476
3021
  def sample2all(
2477
- sample: Tensor, n: int, format: str = "count_vector", jittable: bool = False
3022
+ sample: Tensor,
3023
+ n: int,
3024
+ format: str = "count_vector",
3025
+ jittable: bool = False,
3026
+ dim: Optional[int] = None,
2478
3027
  ) -> Any:
2479
3028
  """
2480
- transform ``sample_int`` or ``sample_bin`` form results to other forms specified by ``format``
3029
+ transform ``sample_int`` or ``sample_bin`` results to other forms specified by ``format``
2481
3030
 
2482
- :param sample: measurement shots results in ``sample_int`` or ``sample_bin`` format
3031
+ :param sample: measurement shots results in ``sample_int`` (shape [shots]) or ``sample_bin`` (shape [shots, n])
2483
3032
  :type sample: Tensor
2484
- :param n: number of qubits
3033
+ :param n: number of sites
2485
3034
  :type n: int
2486
- :param format: see the doc in the doc in :py:meth:`tensorcircuit.quantum.measurement_results`,
2487
- defaults to "count_vector"
3035
+ :param format: see :py:meth:`tensorcircuit.quantum.measurement_results`, defaults to "count_vector"
2488
3036
  :type format: str, optional
2489
3037
  :param jittable: only applicable to count transformation in jax backend, defaults to False
2490
3038
  :type jittable: bool, optional
3039
+ :param dim: local dimension (2 for qubit; >2 for qudit), defaults to 2
3040
+ :type dim: Optional[int]
2491
3041
  :return: measurement results specified as ``format``
2492
3042
  :rtype: Any
2493
3043
  """
3044
+ dim = 2 if dim is None else int(dim)
3045
+ n_max_d = int(32 / np.log2(dim))
3046
+ if n > n_max_d:
3047
+ assert (
3048
+ len(backend.shape_tuple(sample)) == 2
3049
+ ), f"n>{n_max_d} is only supported for ``sample_bin``"
3050
+ if format == "sample_bin":
3051
+ return sample
3052
+ if format == "count_dict_bin":
3053
+ binary_strings = ["".join(map(str, shot)) for shot in sample]
3054
+ return dict(Counter(binary_strings))
3055
+ raise ValueError(f"n={n} is too large for measurement representaion: {format}")
3056
+
2494
3057
  if len(backend.shape_tuple(sample)) == 1:
2495
3058
  sample_int = sample
2496
- sample_bin = sample_int2bin(sample, n)
3059
+ sample_bin = sample_int2bin(sample, n, dim=dim)
2497
3060
  elif len(backend.shape_tuple(sample)) == 2:
2498
- sample_int = sample_bin2int(sample, n)
3061
+ sample_int = sample_bin2int(sample, n, dim=dim)
2499
3062
  sample_bin = sample
2500
3063
  else:
2501
3064
  raise ValueError("unrecognized tensor shape for sample")
3065
+
2502
3066
  if format == "sample_int":
2503
3067
  return sample_int
2504
3068
  elif format == "sample_bin":
2505
3069
  return sample_bin
2506
3070
  else:
2507
- count_tuple = sample2count(sample_int, n, jittable)
3071
+ count_tuple = sample2count(sample_int, n, jittable=jittable, dim=dim)
2508
3072
  if format == "count_tuple":
2509
3073
  return count_tuple
2510
3074
  elif format == "count_vector":
2511
- return count_s2d(count_tuple, n)
3075
+ return count_s2d(count_tuple, n, dim=dim)
2512
3076
  elif format == "count_dict_bin":
2513
- return count_tuple2dict(count_tuple, n, key="bin")
3077
+ return count_tuple2dict(count_tuple, n, key="bin", dim=dim)
2514
3078
  elif format == "count_dict_int":
2515
- return count_tuple2dict(count_tuple, n, key="int")
3079
+ return count_tuple2dict(count_tuple, n, key="int", dim=dim)
2516
3080
  else:
2517
3081
  raise ValueError(
2518
3082
  "unsupported format %s for finite shots measurement" % format