emu-mps 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emu_mps/__init__.py +38 -0
- emu_mps/algebra.py +151 -0
- emu_mps/hamiltonian.py +449 -0
- emu_mps/mpo.py +243 -0
- emu_mps/mps.py +528 -0
- emu_mps/mps_backend.py +35 -0
- emu_mps/mps_backend_impl.py +525 -0
- emu_mps/mps_config.py +64 -0
- emu_mps/noise.py +29 -0
- emu_mps/tdvp.py +209 -0
- emu_mps/utils.py +258 -0
- emu_mps-1.2.1.dist-info/METADATA +133 -0
- emu_mps-1.2.1.dist-info/RECORD +14 -0
- emu_mps-1.2.1.dist-info/WHEEL +4 -0
emu_mps/mps.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from collections import Counter
|
|
5
|
+
from typing import Any, List, Optional, Iterable
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from emu_base import State
|
|
10
|
+
from emu_mps.algebra import add_factors, scale_factors
|
|
11
|
+
from emu_mps.utils import (
|
|
12
|
+
DEVICE_COUNT,
|
|
13
|
+
apply_measurement_errors,
|
|
14
|
+
assign_devices,
|
|
15
|
+
truncate_impl,
|
|
16
|
+
tensor_trace,
|
|
17
|
+
n_operator,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MPS(State):
    """
    Matrix Product State, aka tensor train.

    Each tensor has 3 dimensions ordered as such: (left bond, site, right bond).

    Only qubits are supported.
    """

    DEFAULT_MAX_BOND_DIM: int = 1024
    DEFAULT_PRECISION: float = 1e-5

    def __init__(
        self,
        factors: List[torch.Tensor],
        /,
        *,
        orthogonality_center: Optional[int] = None,
        precision: float = DEFAULT_PRECISION,
        max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
        num_gpus_to_use: Optional[int] = DEVICE_COUNT,
    ):
        """
        This constructor creates a MPS directly from a list of tensors. It is for internal use only.

        Args:
            factors: the tensors for each site
                WARNING: for efficiency in a lot of use cases, this list of tensors
                IS NOT DEEP-COPIED. Therefore, the new MPS object is not necessarily
                the exclusive owner of the list and its tensors. As a consequence,
                beware of potential external modifications affecting the list or the tensors.
                You are responsible for deciding whether to pass its own exclusive copy
                of the data to this constructor, or some shared objects.
            orthogonality_center: the orthogonality center of the MPS, or None (in which case
                it will be orthogonalized when needed)
            precision: the precision with which to truncate here or in tdvp
            max_bond_dim: the maximum bond dimension to allow
            num_gpus_to_use: distribute the factors over this many GPUs
                0=all factors to cpu, None=keep the existing device assignment.

        Raises:
            AssertionError: if there are fewer than 2 factors, if consecutive bond
                dimensions do not match, if the outermost bonds are not of dimension 1,
                or if the orthogonality center is out of range.
        """
        self.precision = precision
        self.max_bond_dim = max_bond_dim

        # Checked first: the shape checks below index factors[0] and factors[-1],
        # which would otherwise fail with an IndexError on an empty list.
        assert len(factors) > 1  # otherwise, do state vector

        assert all(
            factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
        ), "The dimensions of consecutive tensors should match"
        assert (
            factors[0].shape[0] == 1 and factors[-1].shape[2] == 1
        ), "The dimension of the left (right) link of the first (last) tensor should be 1"

        self.factors = factors
        self.num_sites = len(factors)

        assert (orthogonality_center is None) or (
            0 <= orthogonality_center < self.num_sites
        ), "Invalid orthogonality center provided"
        self.orthogonality_center = orthogonality_center

        if num_gpus_to_use is not None:
            assign_devices(self.factors, min(DEVICE_COUNT, num_gpus_to_use))

    @classmethod
    def make(
        cls,
        num_sites: int,
        precision: float = DEFAULT_PRECISION,
        max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
        num_gpus_to_use: int = DEVICE_COUNT,
    ) -> MPS:
        """
        Returns a MPS in ground state |000..0>.

        Args:
            num_sites: the number of qubits
            precision: the precision with which to truncate here or in tdvp
            max_bond_dim: the maximum bond dimension to allow
            num_gpus_to_use: distribute the factors over this many GPUs
                0=all factors to cpu

        Raises:
            ValueError: if num_sites <= 1.
        """
        if num_sites <= 1:
            raise ValueError("For 1 qubit states, do state vector")

        return cls(
            [
                torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)
                for _ in range(num_sites)
            ],
            precision=precision,
            max_bond_dim=max_bond_dim,
            num_gpus_to_use=num_gpus_to_use,
            orthogonality_center=0,  # Arbitrary: every qubit is an orthogonality center.
        )

    def __repr__(self) -> str:
        # Format: "[<factor0>, <factor1>, ]" (every factor is followed by ", ").
        result = "["
        for fac in self.factors:
            result += repr(fac)
            result += ", "
        result += "]"
        return result

    def orthogonalize(self, desired_orthogonality_center: int = 0) -> int:
        """
        Orthogonalize the state on the given orthogonality center.

        Factors left of the target are made left-orthonormal by a left-to-right
        sweep of QR decompositions. Factors right of the target are made
        right-orthonormal by a right-to-left sweep. When the current center is
        unknown (None), both sweeps start from the respective ends of the chain.

        Returns the new orthogonality center index as an integer,
        this is convenient for type-checking purposes.
        """
        assert (
            0 <= desired_orthogonality_center < self.num_sites
        ), f"Cannot move orthogonality center to nonexistent qubit #{desired_orthogonality_center}"

        lr_swipe_start = (
            self.orthogonality_center if self.orthogonality_center is not None else 0
        )

        for i in range(lr_swipe_start, desired_orthogonality_center):
            # QR on the (left bond * site, right bond) matrix: Q stays here as a
            # left-isometry, R is pushed into the next factor's left bond.
            q, r = torch.linalg.qr(self.factors[i].reshape(-1, self.factors[i].shape[2]))
            self.factors[i] = q.reshape(self.factors[i].shape[0], 2, -1)
            self.factors[i + 1] = torch.tensordot(
                r.to(self.factors[i + 1].device), self.factors[i + 1], dims=1
            )

        rl_swipe_start = (
            self.orthogonality_center
            if self.orthogonality_center is not None
            else (self.num_sites - 1)
        )

        for i in range(rl_swipe_start, desired_orthogonality_center, -1):
            # QR of the transposed (left bond, site * right bond) matrix: Q^T stays
            # here as a right-isometry, R^T is pushed into the previous factor.
            q, r = torch.linalg.qr(
                self.factors[i].reshape(self.factors[i].shape[0], -1).mT,
            )
            self.factors[i] = q.mT.reshape(-1, 2, self.factors[i].shape[2])
            self.factors[i - 1] = torch.tensordot(
                self.factors[i - 1], r.to(self.factors[i - 1].device), ([2], [1])
            )

        self.orthogonality_center = desired_orthogonality_center

        return desired_orthogonality_center

    def truncate(self) -> None:
        """
        SVD based truncation of the state. Puts the orthogonality center at the first qubit.
        Calls orthogonalize on the last qubit, and then sweeps a series of SVDs right-left.
        Uses self.precision and self.max_bond_dim for determining accuracy.
        An in-place operation.
        """
        self.orthogonalize(self.num_sites - 1)
        truncate_impl(
            self.factors,
            max_error=self.precision,
            max_rank=self.max_bond_dim,
        )
        self.orthogonality_center = 0

    def get_max_bond_dim(self) -> int:
        """
        Return the max bond dimension of this MPS.

        Returns:
            the largest bond dimension in the state
        """
        return max((x.shape[2] for x in self.factors), default=0)

    def sample(
        self, num_shots: int, p_false_pos: float = 0.0, p_false_neg: float = 0.0
    ) -> Counter[str]:
        """
        Samples bitstrings, taking into account the specified error rates.

        Moves the orthogonality center to qubit 0 (in place) before sampling.

        Args:
            num_shots: how many bitstrings to sample
            p_false_pos: the rate at which a 0 is read as a 1
            p_false_neg: the rate at which a 1 is read as a 0

        Returns:
            the measured bitstrings, by count
        """
        self.orthogonalize(0)

        num_qubits = len(self.factors)
        # One uniform [0, 1) number per (shot, qubit).
        rnd_matrix = torch.rand(num_shots, num_qubits)
        bitstrings = Counter(
            self._sample_implementation(rnd_matrix[x, :]) for x in range(num_shots)
        )
        if p_false_neg > 0 or p_false_pos > 0:
            bitstrings = apply_measurement_errors(
                bitstrings,
                p_false_pos=p_false_pos,
                p_false_neg=p_false_neg,
            )
        return bitstrings

    def norm(self) -> float:
        """
        Computes the norm of the MPS.

        Uses the existing orthogonality center if there is one; otherwise
        orthogonalizes on qubit 0 first (in place). The norm is then the
        Frobenius norm of the center factor.
        """
        orthogonality_center = (
            self.orthogonality_center
            if self.orthogonality_center is not None
            else self.orthogonalize(0)
        )

        return float(
            torch.linalg.norm(self.factors[orthogonality_center].to("cpu")).item()
        )

    def _sample_implementation(self, rnd_vector: torch.Tensor) -> str:
        """
        Samples this MPS once, returning the resulting bitstring.

        Qubits are measured from left to right. For each qubit, p is the squared
        norm of the accumulated tensor projected on |0>. The qubit reads 0 when
        rnd_vector[qubit] < p and 1 otherwise. The projected tensor is then
        absorbed into the next site and renormalized, so later qubits are
        sampled conditionally on earlier outcomes.

        Args:
            rnd_vector: one random number per qubit. This assumes the numbers are
                uniform in [0, 1), as produced by torch.rand in `sample`.
        """
        assert rnd_vector.shape == (self.num_sites,)
        assert self.orthogonality_center == 0

        num_qubits = len(self.factors)

        bitstring = ""
        acc_mps_j: torch.Tensor = self.factors[0]

        for qubit in range(num_qubits):
            # comp_basis is a projector: 0 is for ket |0> and 1 for ket |1>
            comp_basis = 0  # check if the qubit is in |0>
            # Measure the qubit j by applying the projector onto nth comp basis state
            tensorj_projected_n = acc_mps_j[:, comp_basis, :]
            probability_n = (tensorj_projected_n.norm() ** 2).item()

            # `>=` makes P(read 0) == p exactly for uniform [0, 1) samples.
            # It also never selects |0> when p == 0, which would divide by zero below.
            if rnd_vector[qubit] >= probability_n:
                # the qubit is in |1>
                comp_basis = 1
                tensorj_projected_n = acc_mps_j[:, comp_basis, :]
                probability_n = 1 - probability_n

            bitstring += str(comp_basis)
            if qubit < num_qubits - 1:
                acc_mps_j = torch.tensordot(
                    tensorj_projected_n.to(device=self.factors[qubit + 1].device),
                    self.factors[qubit + 1],
                    dims=1,
                )
                acc_mps_j /= math.sqrt(probability_n)

        return bitstring

    def inner(self, other: State) -> float | complex:
        """
        Compute the inner product between this state and other.
        Note that self is the left state in the inner product,
        so this function is linear in other, and anti-linear in self

        Args:
            other: the other state

        Returns:
            inner product

        Raises:
            AssertionError: if other is not an MPS or has a different number of sites.
        """
        assert isinstance(other, MPS), "Other state also needs to be an MPS"
        assert (
            self.num_sites == other.num_sites
        ), "States do not have the same number of sites"

        # acc[self bond, other bond]: the left environment, contracted site by site.
        acc = torch.ones(1, 1, dtype=self.factors[0].dtype, device=self.factors[0].device)

        for i in range(self.num_sites):
            acc = acc.to(self.factors[i].device)
            acc = torch.tensordot(acc, other.factors[i].to(acc.device), dims=1)
            acc = torch.tensordot(self.factors[i].conj(), acc, dims=([0, 1], [0, 1]))

        return acc.item()  # type: ignore[no-any-return]

    def get_memory_footprint(self) -> float:
        """
        Returns the number of MBs of memory occupied to store the state

        Returns:
            the memory in MBs
        """
        return (  # type: ignore[no-any-return]
            sum(factor.element_size() * factor.numel() for factor in self.factors) * 1e-6
        )

    def __add__(self, other: State) -> MPS:
        """
        Returns the sum of two MPSs, computed with a direct algorithm.
        The resulting MPS is orthogonalized on the first site and truncated
        up to `self.precision`.

        Args:
            other: the other state

        Returns:
            the summed state

        Raises:
            AssertionError: if other is not an MPS or has a different number of sites.
        """
        assert isinstance(other, MPS), "Other state also needs to be an MPS"
        assert (
            self.num_sites == other.num_sites
        ), "States do not have the same number of sites"
        new_tt = add_factors(self.factors, other.factors)
        result = MPS(
            new_tt,
            precision=self.precision,
            max_bond_dim=self.max_bond_dim,
            num_gpus_to_use=None,
            orthogonality_center=None,  # Orthogonality is lost.
        )
        result.truncate()
        return result

    def __rmul__(self, scalar: complex) -> MPS:
        """
        Multiply an MPS by a scalar.

        Args:
            scalar: the scale factor

        Returns:
            the scaled MPS
        """
        which = (
            self.orthogonality_center
            if self.orthogonality_center is not None
            else 0  # No need to orthogonalize for scaling.
        )
        factors = scale_factors(self.factors, scalar, which=which)
        return MPS(
            factors,
            precision=self.precision,
            max_bond_dim=self.max_bond_dim,
            num_gpus_to_use=None,
            orthogonality_center=self.orthogonality_center,
        )

    def __imul__(self, scalar: complex) -> MPS:
        # Returns a new scaled MPS (via __rmul__); `x *= s` rebinds x to it.
        return self.__rmul__(scalar)

    @staticmethod
    def from_state_string(
        *,
        basis: Iterable[str],
        nqubits: int,
        strings: dict[str, complex],
        **kwargs: Any,
    ) -> MPS:
        """
        See the base class.

        Args:
            basis: A tuple containing the basis states (e.g., ('r', 'g')).
            nqubits: the number of qubits.
            strings: A dictionary mapping state strings to complex or floats amplitudes.

        Returns:
            The resulting MPS representation of the state.

        Raises:
            ValueError: if the basis is unsupported, or if a state string does not
                have exactly nqubits characters all drawn from the basis.
        """

        basis = set(basis)
        if basis == {"r", "g"}:
            one = "r"
        elif basis == {"0", "1"}:
            one = "1"
        else:
            raise ValueError("Unsupported basis provided")

        basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)  # ground state
        basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128)  # excited state

        # Start from the zero vector and add each weighted product state to it.
        accum_mps = MPS(
            [torch.zeros((1, 2, 1), dtype=torch.complex128) for _ in range(nqubits)],
            orthogonality_center=0,
            **kwargs,
        )

        for state, amplitude in strings.items():
            # Validate up front: a wrong length would otherwise fail obscurely when
            # adding MPSs, and unknown characters would silently mean "ground".
            if len(state) != nqubits:
                raise ValueError(
                    f"State string {state!r} has {len(state)} characters, "
                    f"expected {nqubits}"
                )
            if not set(state) <= basis:
                raise ValueError(
                    f"State string {state!r} contains characters outside "
                    f"the basis {sorted(basis)}"
                )
            factors = [basis_1 if ch == one else basis_0 for ch in state]
            accum_mps += amplitude * MPS(factors, **kwargs)
        norm = accum_mps.norm()
        if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
            print("\nThe state is not normalized, normalizing it for you.")
            accum_mps *= 1 / norm

        return accum_mps

    def expect_batch(self, single_qubit_operators: torch.Tensor) -> torch.Tensor:
        """
        Computes expectation values for each qubit and each single qubit operator in
        the batched input tensor.

        Returns a tensor T such that T[q, i] is the expectation value for qubit #q
        and operator single_qubit_operators[i].

        The result is a complex128 tensor allocated on the default (CPU) device.
        Each operator is contracted with a 2x2 single-site reduced density matrix,
        so single_qubit_operators is expected to have shape (num_operators, 2, 2).
        Orthogonalizes on qubit 0 (in place) only if no orthogonality center is known.
        The center is then moved virtually (via R factors only), so self.factors
        are not modified.
        """
        orthogonality_center = (
            self.orthogonality_center
            if self.orthogonality_center is not None
            else self.orthogonalize(0)
        )

        result = torch.zeros(
            self.num_sites, single_qubit_operators.shape[0], dtype=torch.complex128
        )

        # Sweep right from the center, carrying the R factor of each QR.
        center_factor = self.factors[orthogonality_center]
        for qubit_index in range(orthogonality_center, self.num_sites):
            # temp[s, s'] = sum_{a,b} conj(C[a,s,b]) C[a,s',b]
            temp = torch.tensordot(center_factor.conj(), center_factor, ([0, 2], [0, 2]))

            result[qubit_index] = torch.tensordot(
                single_qubit_operators.to(temp.device), temp, dims=2
            )

            if qubit_index < self.num_sites - 1:
                _, r = torch.linalg.qr(center_factor.reshape(-1, center_factor.shape[2]))
                center_factor = torch.tensordot(
                    r, self.factors[qubit_index + 1].to(r.device), dims=1
                )

        # Sweep left from the center, carrying the R factor of each transposed QR.
        center_factor = self.factors[orthogonality_center]
        for qubit_index in range(orthogonality_center - 1, -1, -1):
            _, r = torch.linalg.qr(
                center_factor.reshape(center_factor.shape[0], -1).mT,
            )
            center_factor = torch.tensordot(
                self.factors[qubit_index],
                r.to(self.factors[qubit_index].device),
                ([2], [1]),
            )

            temp = torch.tensordot(center_factor.conj(), center_factor, ([0, 2], [0, 2]))

            result[qubit_index] = torch.tensordot(
                single_qubit_operators.to(temp.device), temp, dims=2
            )

        return result

    def apply(self, qubit_index: int, single_qubit_operator: torch.Tensor) -> None:
        """
        Apply given single qubit operator to qubit qubit_index, leaving the MPS
        orthogonalized on that qubit.
        """
        self.orthogonalize(qubit_index)

        # new[a, t, b] = sum_s op[t, s] * F[a, s, b]
        self.factors[qubit_index] = torch.tensordot(
            self.factors[qubit_index],
            single_qubit_operator.to(self.factors[qubit_index].device),
            ([1], [1]),
        ).transpose(1, 2)

    @staticmethod
    def _single_site_environment(
        factor: torch.Tensor, operator: torch.Tensor
    ) -> torch.Tensor:
        """
        Contract one site with an operator and with its own conjugate.

        Returns E[b, c] = sum_{a,s,t} factor[a,s,b] * operator[t,s] * conj(factor[a,t,c]),
        where operator[t, s] = <t|operator|s> (the convention used by `apply` and
        `expect_batch`). Both tensors are expected to be on the same device.
        """
        ket = torch.tensordot(factor, operator, dims=([1], [1]))  # (a, b, t)
        return torch.tensordot(ket, factor.conj(), dims=([0, 2], [0, 1]))  # (b, c)

    def get_correlation_matrix(
        self, *, operator: torch.Tensor = n_operator
    ) -> list[list[float]]:
        """
        Efficiently compute the symmetric correlation matrix
            C_ij = <self|operator_i operator_j|self>
        in basis ("r", "g").

        The diagonal is C_ii = <self|operator_i operator_i|self>. Only the real
        part of each entry is kept. Moves the orthogonality center (in place)
        while computing.

        Args:
            operator: a 2x2 Torch tensor to use

        Returns:
            the corresponding correlation matrix
        """
        assert operator.shape == (2, 2)

        result = [[0.0 for _ in range(self.num_sites)] for _ in range(self.num_sites)]

        for left in range(0, self.num_sites):
            self.orthogonalize(left)
            left_factor = self.factors[left]
            op = operator.to(left_factor.device)

            # With the center at `left`, the trace of the single-site environment
            # is the expectation value of the given single-site operator.
            result[left][left] = (
                self._single_site_environment(left_factor, op @ op).trace().item().real
            )

            # accumulator[b, c]: left environment carrying operator_left.
            accumulator = self._single_site_environment(left_factor, op)
            for right in range(left + 1, self.num_sites):
                right_factor = self.factors[right]
                partial = torch.tensordot(
                    accumulator.to(right_factor.device),
                    right_factor,
                    dims=([0], [0]),
                )  # (c, s, d)
                partial = torch.tensordot(
                    partial, right_factor.conj(), dims=([0], [0])
                )  # (s, d, t, e): s/d ket site/bond, t/e bra site/bond

                # Pair the bra index t with operator's row and the ket index s
                # with its column: sum_{s,t} partial[s,d,t,e] * operator[t,s].
                result[left][right] = (
                    torch.tensordot(
                        partial, operator.to(partial.device), dims=([2, 0], [0, 1])
                    )
                    .trace()
                    .item()
                    .real
                )
                result[right][left] = result[left][right]
                # Extend the environment past `right` without inserting the operator.
                accumulator = tensor_trace(partial, 0, 2)

        return result
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def inner(left: MPS, right: MPS) -> float | complex:
    """
    Compute <left|right> by delegating to ``left.inner(right)``.

    Args:
        left: the anti-linear argument (the bra)
        right: the linear argument (the ket)

    Returns:
        the inner product
    """
    overlap = left.inner(right)
    return overlap
|
emu_mps/mps_backend.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from pulser import Sequence
|
|
2
|
+
|
|
3
|
+
from emu_base import Backend, BackendConfig, Results
|
|
4
|
+
from emu_mps.mps_config import MPSConfig
|
|
5
|
+
from emu_mps.mps_backend_impl import create_impl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MPSBackend(Backend):
    """
    Backend that emulates Pulser sequences with Matrix Product States (MPS),
    also known as tensor trains.
    """

    def run(self, sequence: Sequence, mps_config: BackendConfig) -> Results:
        """
        Run the emulation of a Pulser sequence.

        Args:
            sequence: a Pulser sequence to simulate
            mps_config: the backend configuration; must be an MPSConfig instance

        Returns:
            the simulation results collected by the backend implementation
        """
        assert isinstance(mps_config, MPSConfig)

        self.validate_sequence(sequence)

        backend_impl = create_impl(sequence, mps_config)
        # init() is kept separate from the constructor for testing purposes.
        backend_impl.init()

        # Advance the implementation step by step until it reports completion.
        while True:
            if backend_impl.is_finished():
                break
            backend_impl.progress()

        return backend_impl.results
|