emu-mps 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/mps.py ADDED
@@ -0,0 +1,528 @@
+ from __future__ import annotations
+
+ import math
+ from collections import Counter
+ from typing import Any, List, Optional, Iterable
+
+ import torch
+
+ from emu_base import State
+ from emu_mps.algebra import add_factors, scale_factors
+ from emu_mps.utils import (
+     DEVICE_COUNT,
+     apply_measurement_errors,
+     assign_devices,
+     truncate_impl,
+     tensor_trace,
+     n_operator,
+ )
+
+
+ class MPS(State):
+     """
+     Matrix Product State, aka tensor train.
+
+     Each tensor has 3 dimensions ordered as such: (left bond, site, right bond).
+
+     Only qubits are supported.
+     """
+
+     DEFAULT_MAX_BOND_DIM: int = 1024
+     DEFAULT_PRECISION: float = 1e-5
+
+     def __init__(
+         self,
+         factors: List[torch.Tensor],
+         /,
+         *,
+         orthogonality_center: Optional[int] = None,
+         precision: float = DEFAULT_PRECISION,
+         max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
+         num_gpus_to_use: Optional[int] = DEVICE_COUNT,
+     ):
+         """
+         This constructor creates an MPS directly from a list of tensors.
+         It is for internal use only.
+
+         Args:
+             factors: the tensors for each site.
+                 WARNING: for efficiency in many use cases, this list of tensors
+                 IS NOT DEEP-COPIED. Therefore, the new MPS object is not necessarily
+                 the exclusive owner of the list and its tensors. As a consequence,
+                 beware of potential external modifications affecting the list or the
+                 tensors. You are responsible for deciding whether to pass an exclusive
+                 copy of the data to this constructor, or some shared objects.
+             orthogonality_center: the orthogonality center of the MPS, or None
+                 (in which case it will be orthogonalized when needed)
+             precision: the precision with which to truncate, here or in TDVP
+             max_bond_dim: the maximum bond dimension to allow
+             num_gpus_to_use: distribute the factors over this many GPUs;
+                 0 = all factors on the CPU, None = keep the existing device assignment.
+         """
+         self.precision = precision
+         self.max_bond_dim = max_bond_dim
+
+         assert all(
+             factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
+         ), "The dimensions of consecutive tensors should match"
+         assert (
+             factors[0].shape[0] == 1 and factors[-1].shape[2] == 1
+         ), "The dimension of the left (right) link of the first (last) tensor should be 1"
+
+         self.factors = factors
+         self.num_sites = len(factors)
+         assert self.num_sites > 1  # otherwise, use a state vector
+
+         assert (orthogonality_center is None) or (
+             0 <= orthogonality_center < self.num_sites
+         ), "Invalid orthogonality center provided"
+         self.orthogonality_center = orthogonality_center
+
+         if num_gpus_to_use is not None:
+             assign_devices(self.factors, min(DEVICE_COUNT, num_gpus_to_use))
+
+     @classmethod
+     def make(
+         cls,
+         num_sites: int,
+         precision: float = DEFAULT_PRECISION,
+         max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
+         num_gpus_to_use: int = DEVICE_COUNT,
+     ) -> MPS:
+         """
+         Returns an MPS in the ground state |000...0>.
+
+         Args:
+             num_sites: the number of qubits
+             precision: the precision with which to truncate, here or in TDVP
+             max_bond_dim: the maximum bond dimension to allow
+             num_gpus_to_use: distribute the factors over this many GPUs;
+                 0 = all factors on the CPU
+         """
+         if num_sites <= 1:
+             raise ValueError("For 1-qubit states, use a state vector")
+
+         return cls(
+             [
+                 torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)
+                 for _ in range(num_sites)
+             ],
+             precision=precision,
+             max_bond_dim=max_bond_dim,
+             num_gpus_to_use=num_gpus_to_use,
+             orthogonality_center=0,  # Arbitrary: every qubit is an orthogonality center.
+         )
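
For illustration, a minimal usage sketch (not part of the diff; it assumes a CPU-only environment and this module's imports):

    state = MPS.make(num_sites=4)
    assert state.get_max_bond_dim() == 1  # a product state has bond dimension 1
    assert abs(state.norm() - 1.0) < 1e-12  # |0000> is normalized
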
+
+     def __repr__(self) -> str:
+         result = "["
+         for fac in self.factors:
+             result += repr(fac)
+             result += ", "
+         result += "]"
+         return result
+
+     def orthogonalize(self, desired_orthogonality_center: int = 0) -> int:
+         """
+         Orthogonalize the state on the given orthogonality center.
+
+         Returns the new orthogonality center index as an integer,
+         which is convenient for type-checking purposes.
+         """
+         assert (
+             0 <= desired_orthogonality_center < self.num_sites
+         ), f"Cannot move orthogonality center to nonexistent qubit #{desired_orthogonality_center}"
+
+         lr_swipe_start = (
+             self.orthogonality_center if self.orthogonality_center is not None else 0
+         )
+
+         for i in range(lr_swipe_start, desired_orthogonality_center):
+             q, r = torch.linalg.qr(self.factors[i].reshape(-1, self.factors[i].shape[2]))
+             self.factors[i] = q.reshape(self.factors[i].shape[0], 2, -1)
+             self.factors[i + 1] = torch.tensordot(
+                 r.to(self.factors[i + 1].device), self.factors[i + 1], dims=1
+             )
+
+         rl_swipe_start = (
+             self.orthogonality_center
+             if self.orthogonality_center is not None
+             else (self.num_sites - 1)
+         )
+
+         for i in range(rl_swipe_start, desired_orthogonality_center, -1):
+             q, r = torch.linalg.qr(
+                 self.factors[i].reshape(self.factors[i].shape[0], -1).mT,
+             )
+             self.factors[i] = q.mT.reshape(-1, 2, self.factors[i].shape[2])
+             self.factors[i - 1] = torch.tensordot(
+                 self.factors[i - 1], r.to(self.factors[i - 1].device), ([2], [1])
+             )
+
+         self.orthogonality_center = desired_orthogonality_center
+
+         return desired_orthogonality_center
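
After `orthogonalize(c)`, every factor strictly left of site `c` is a left isometry (and symmetrically on the right), so the state's norm is carried entirely by the center factor. A quick check of that property, as a sketch using only the methods above:

    psi = MPS.make(num_sites=4)
    psi.orthogonalize(2)
    a = psi.factors[1].reshape(-1, psi.factors[1].shape[2])  # fuse (left bond, site)
    assert torch.allclose(a.mH @ a, torch.eye(a.shape[1], dtype=a.dtype))  # A†A = I
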
+
+     def truncate(self) -> None:
+         """
+         SVD-based truncation of the state. Puts the orthogonality center at the first
+         qubit: calls orthogonalize on the last qubit, then sweeps a series of SVDs
+         right to left. Uses self.precision and self.max_bond_dim to determine the
+         accuracy. An in-place operation.
+         """
+         self.orthogonalize(self.num_sites - 1)
+         truncate_impl(
+             self.factors,
+             max_error=self.precision,
+             max_rank=self.max_bond_dim,
+         )
+         self.orthogonality_center = 0
+
+     def get_max_bond_dim(self) -> int:
+         """
+         Return the max bond dimension of this MPS.
+
+         Returns:
+             the largest bond dimension in the state
+         """
+         return max((x.shape[2] for x in self.factors), default=0)
+
+     def sample(
+         self, num_shots: int, p_false_pos: float = 0.0, p_false_neg: float = 0.0
+     ) -> Counter[str]:
+         """
+         Samples bitstrings, taking into account the specified error rates.
+
+         Args:
+             num_shots: how many bitstrings to sample
+             p_false_pos: the rate at which a 0 is read as a 1
+             p_false_neg: the rate at which a 1 is read as a 0
+
+         Returns:
+             the measured bitstrings, by count
+         """
+         self.orthogonalize(0)
+
+         num_qubits = len(self.factors)
+         rnd_matrix = torch.rand(num_shots, num_qubits)
+         bitstrings = Counter(
+             self._sample_implementation(rnd_matrix[x, :]) for x in range(num_shots)
+         )
+         if p_false_neg > 0 or p_false_pos > 0:
+             bitstrings = apply_measurement_errors(
+                 bitstrings,
+                 p_false_pos=p_false_pos,
+                 p_false_neg=p_false_neg,
+             )
+         return bitstrings
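
For example (a sketch; with zero error rates, sampling the ground state is deterministic):

    psi = MPS.make(num_sites=3)
    assert psi.sample(num_shots=100) == Counter({"000": 100})
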
+
+     def norm(self) -> float:
+         """Computes the norm of the MPS."""
+         orthogonality_center = (
+             self.orthogonality_center
+             if self.orthogonality_center is not None
+             else self.orthogonalize(0)
+         )
+
+         return float(
+             torch.linalg.norm(self.factors[orthogonality_center].to("cpu")).item()
+         )
+
+     def _sample_implementation(self, rnd_vector: torch.Tensor) -> str:
+         """
+         Samples this MPS once, returning the resulting bitstring.
+         """
+         assert rnd_vector.shape == (self.num_sites,)
+         assert self.orthogonality_center == 0
+
+         num_qubits = len(self.factors)
+
+         bitstring = ""
+         acc_mps_j: torch.Tensor = self.factors[0]
+
+         for qubit in range(num_qubits):
+             # comp_basis is a projector index: 0 projects onto ket |0>, 1 onto ket |1>
+             comp_basis = 0  # check whether the qubit is in |0>
+             # Measure qubit j by applying the projector onto the nth comp basis state
+             tensorj_projected_n = acc_mps_j[:, comp_basis, :]
+             probability_n = (tensorj_projected_n.norm() ** 2).item()
+
+             if rnd_vector[qubit] > probability_n:
+                 # the qubit is in |1>
+                 comp_basis = 1
+                 tensorj_projected_n = acc_mps_j[:, comp_basis, :]
+                 probability_n = 1 - probability_n
+
+             bitstring += str(comp_basis)
+             if qubit < num_qubits - 1:
+                 acc_mps_j = torch.tensordot(
+                     tensorj_projected_n.to(device=self.factors[qubit + 1].device),
+                     self.factors[qubit + 1],
+                     dims=1,
+                 )
+                 acc_mps_j /= math.sqrt(probability_n)
+
+         return bitstring
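
The loop above is chain-rule sampling: at each site, the squared norm of the projected accumulator is the conditional probability P(b_j = 0 | b_1, ..., b_{j-1}), and dividing by sqrt(probability_n) renormalizes the conditional state, so the product of the per-site probabilities reproduces the Born-rule probability |<b|psi>|^2 of the full bitstring.
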
+
+     def inner(self, other: State) -> float | complex:
+         """
+         Compute the inner product between this state and other.
+         Note that self is the left state in the inner product,
+         so this function is linear in other, and anti-linear in self.
+
+         Args:
+             other: the other state
+
+         Returns:
+             the inner product
+         """
+         assert isinstance(other, MPS), "Other state also needs to be an MPS"
+         assert (
+             self.num_sites == other.num_sites
+         ), "States do not have the same number of sites"
+
+         acc = torch.ones(1, 1, dtype=self.factors[0].dtype, device=self.factors[0].device)
+
+         for i in range(self.num_sites):
+             acc = acc.to(self.factors[i].device)
+             acc = torch.tensordot(acc, other.factors[i].to(acc.device), dims=1)
+             acc = torch.tensordot(self.factors[i].conj(), acc, dims=([0, 1], [0, 1]))
+
+         return acc.item()  # type: ignore[no-any-return]
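
For instance (a sketch; the inner product of a normalized state with itself is 1):

    psi = MPS.make(num_sites=3)
    assert abs(psi.inner(psi) - 1.0) < 1e-12  # <psi|psi> = ||psi||^2
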
+
+     def get_memory_footprint(self) -> float:
+         """
+         Returns the memory, in MB, occupied by the stored state.
+
+         Returns:
+             the memory in MB
+         """
+         return (  # type: ignore[no-any-return]
+             sum(factor.element_size() * factor.numel() for factor in self.factors) * 1e-6
+         )
+
+     def __add__(self, other: State) -> MPS:
+         """
+         Returns the sum of two MPSs, computed with a direct algorithm.
+         The resulting MPS is orthogonalized on the first site and truncated
+         up to `self.precision`.
+
+         Args:
+             other: the other state
+
+         Returns:
+             the summed state
+         """
+         assert isinstance(other, MPS), "Other state also needs to be an MPS"
+         new_tt = add_factors(self.factors, other.factors)
+         result = MPS(
+             new_tt,
+             precision=self.precision,
+             max_bond_dim=self.max_bond_dim,
+             num_gpus_to_use=None,
+             orthogonality_center=None,  # Orthogonality is lost.
+         )
+         result.truncate()
+         return result
+
+     def __rmul__(self, scalar: complex) -> MPS:
+         """
+         Multiply an MPS by a scalar.
+
+         Args:
+             scalar: the scale factor
+
+         Returns:
+             the scaled MPS
+         """
+         which = (
+             self.orthogonality_center
+             if self.orthogonality_center is not None
+             else 0  # No need to orthogonalize for scaling.
+         )
+         factors = scale_factors(self.factors, scalar, which=which)
+         return MPS(
+             factors,
+             precision=self.precision,
+             max_bond_dim=self.max_bond_dim,
+             num_gpus_to_use=None,
+             orthogonality_center=self.orthogonality_center,
+         )
+
+     def __imul__(self, scalar: complex) -> MPS:
+         return self.__rmul__(scalar)
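
Together, `__add__` and `__rmul__` allow building superpositions. A sketch constructing a Bell state (it uses `apply`, defined further below, and assumes this module's math/torch imports):

    pauli_x = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex128)
    zero_zero = MPS.make(num_sites=2)
    one_one = MPS.make(num_sites=2)
    one_one.apply(0, pauli_x)  # |00> -> |10>
    one_one.apply(1, pauli_x)  # |10> -> |11>
    bell = (1 / math.sqrt(2)) * (zero_zero + one_one)
    assert abs(bell.norm() - 1.0) < 1e-6
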
352
+
353
+ @staticmethod
354
+ def from_state_string(
355
+ *,
356
+ basis: Iterable[str],
357
+ nqubits: int,
358
+ strings: dict[str, complex],
359
+ **kwargs: Any,
360
+ ) -> MPS:
361
+ """
362
+ See the base class.
363
+
364
+ Args:
365
+ basis: A tuple containing the basis states (e.g., ('r', 'g')).
366
+ nqubits: the number of qubits.
367
+ strings: A dictionary mapping state strings to complex or floats amplitudes.
368
+
369
+ Returns:
370
+ The resulting MPS representation of the state.s
371
+ """
372
+
373
+ basis = set(basis)
374
+ if basis == {"r", "g"}:
375
+ one = "r"
376
+ elif basis == {"0", "1"}:
377
+ one = "1"
378
+ else:
379
+ raise ValueError("Unsupported basis provided")
380
+
381
+ basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128) # ground state
382
+ basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128) # excited state
383
+
384
+ accum_mps = MPS(
385
+ [torch.zeros((1, 2, 1), dtype=torch.complex128)] * nqubits,
386
+ orthogonality_center=0,
387
+ **kwargs,
388
+ )
389
+
390
+ for state, amplitude in strings.items():
391
+ factors = [basis_1 if ch == one else basis_0 for ch in state]
392
+ accum_mps += amplitude * MPS(factors, **kwargs)
393
+ norm = accum_mps.norm()
394
+ if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
395
+ print("\nThe state is not normalized, normalizing it for you.")
396
+ accum_mps *= 1 / norm
397
+
398
+ return accum_mps
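
For example, the same Bell state can be built directly from bitstring amplitudes (a sketch in the Rydberg basis):

    bell = MPS.from_state_string(
        basis=("r", "g"),
        nqubits=2,
        strings={"rr": 1 / math.sqrt(2), "gg": 1 / math.sqrt(2)},
    )
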
+
+     def expect_batch(self, single_qubit_operators: torch.Tensor) -> torch.Tensor:
+         """
+         Computes expectation values for each qubit and each single-qubit operator in
+         the batched input tensor.
+
+         Returns a tensor T such that T[q, i] is the expectation value for qubit #q
+         and operator single_qubit_operators[i].
+         """
+         orthogonality_center = (
+             self.orthogonality_center
+             if self.orthogonality_center is not None
+             else self.orthogonalize(0)
+         )
+
+         result = torch.zeros(
+             self.num_sites, single_qubit_operators.shape[0], dtype=torch.complex128
+         )
+
+         center_factor = self.factors[orthogonality_center]
+         for qubit_index in range(orthogonality_center, self.num_sites):
+             temp = torch.tensordot(center_factor.conj(), center_factor, ([0, 2], [0, 2]))
+
+             result[qubit_index] = torch.tensordot(
+                 single_qubit_operators.to(temp.device), temp, dims=2
+             )
+
+             if qubit_index < self.num_sites - 1:
+                 _, r = torch.linalg.qr(center_factor.reshape(-1, center_factor.shape[2]))
+                 center_factor = torch.tensordot(
+                     r, self.factors[qubit_index + 1].to(r.device), dims=1
+                 )
+
+         center_factor = self.factors[orthogonality_center]
+         for qubit_index in range(orthogonality_center - 1, -1, -1):
+             _, r = torch.linalg.qr(
+                 center_factor.reshape(center_factor.shape[0], -1).mT,
+             )
+             center_factor = torch.tensordot(
+                 self.factors[qubit_index],
+                 r.to(self.factors[qubit_index].device),
+                 ([2], [1]),
+             )
+
+             temp = torch.tensordot(center_factor.conj(), center_factor, ([0, 2], [0, 2]))
+
+             result[qubit_index] = torch.tensordot(
+                 single_qubit_operators.to(temp.device), temp, dims=2
+             )
+
+         return result
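
A sketch measuring <Z> on every qubit of the ground state (Pauli Z written out explicitly; only expect_batch itself is package API):

    pauli_z = torch.tensor([[1, 0], [0, -1]], dtype=torch.complex128)
    psi = MPS.make(num_sites=3)
    values = psi.expect_batch(torch.stack([pauli_z]))  # shape (3, 1)
    assert torch.allclose(values, torch.ones(3, 1, dtype=torch.complex128))  # <Z> = 1 on |0>
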
+
+     def apply(self, qubit_index: int, single_qubit_operator: torch.Tensor) -> None:
+         """
+         Apply the given single-qubit operator to qubit qubit_index, leaving the MPS
+         orthogonalized on that qubit.
+         """
+         self.orthogonalize(qubit_index)
+
+         self.factors[qubit_index] = torch.tensordot(
+             self.factors[qubit_index],
+             single_qubit_operator.to(self.factors[qubit_index].device),
+             ([1], [1]),
+         ).transpose(1, 2)
+
+     def get_correlation_matrix(
+         self, *, operator: torch.Tensor = n_operator
+     ) -> list[list[float]]:
+         """
+         Efficiently compute the symmetric correlation matrix
+         C_ij = <self|operator_i operator_j|self>
+         in the ("r", "g") basis.
+
+         Args:
+             operator: a 2x2 Torch tensor to use
+
+         Returns:
+             the corresponding correlation matrix
+         """
+         assert operator.shape == (2, 2)
+
+         result = [[0.0 for _ in range(self.num_sites)] for _ in range(self.num_sites)]
+
+         for left in range(0, self.num_sites):
+             self.orthogonalize(left)
+             accumulator = torch.tensordot(
+                 self.factors[left],
+                 operator.to(self.factors[left].device),
+                 dims=([1], [0]),
+             )
+             accumulator = torch.tensordot(
+                 accumulator, self.factors[left].conj(), dims=([0, 2], [0, 1])
+             )
+             result[left][left] = accumulator.trace().item().real
+             for right in range(left + 1, self.num_sites):
+                 partial = torch.tensordot(
+                     accumulator.to(self.factors[right].device),
+                     self.factors[right],
+                     dims=([0], [0]),
+                 )
+                 partial = torch.tensordot(
+                     partial, self.factors[right].conj(), dims=([0], [0])
+                 )
+
+                 result[left][right] = (
+                     torch.tensordot(
+                         partial, operator.to(partial.device), dims=([0, 2], [0, 1])
+                     )
+                     .trace()
+                     .item()
+                     .real
+                 )
+                 result[right][left] = result[left][right]
+                 accumulator = tensor_trace(partial, 0, 2)
+
+         return result
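
For instance (a sketch; it assumes n_operator is the Rydberg occupation |r><r|, so every entry vanishes on the all-ground state):

    psi = MPS.make(num_sites=3)
    corr = psi.get_correlation_matrix()
    assert all(abs(entry) < 1e-12 for row in corr for entry in row)
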
+
+
+ def inner(left: MPS, right: MPS) -> float | complex:
+     """
+     Wrapper around MPS.inner.
+
+     Args:
+         left: the anti-linear argument
+         right: the linear argument
+
+     Returns:
+         the inner product
+     """
+     return left.inner(right)
emu_mps/mps_backend.py ADDED
@@ -0,0 +1,35 @@
+ from pulser import Sequence
+
+ from emu_base import Backend, BackendConfig, Results
+ from emu_mps.mps_config import MPSConfig
+ from emu_mps.mps_backend_impl import create_impl
+
+
+ class MPSBackend(Backend):
+     """
+     A backend for emulating Pulser sequences using Matrix Product States (MPS),
+     aka tensor trains.
+     """
+
+     def run(self, sequence: Sequence, mps_config: BackendConfig) -> Results:
+         """
+         Emulates the given sequence.
+
+         Args:
+             sequence: a Pulser sequence to simulate
+             mps_config: the backend's config; should be of type MPSConfig
+
+         Returns:
+             the simulation results
+         """
+         assert isinstance(mps_config, MPSConfig)
+
+         self.validate_sequence(sequence)
+
+         impl = create_impl(sequence, mps_config)
+         impl.init()  # This is separate from the constructor for testing purposes.
+
+         while not impl.is_finished():
+             impl.progress()
+
+         return impl.results
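
End-to-end usage could look like the sketch below; the top-level re-exports, the MPSConfig() defaults, and the prebuilt Pulser sequence `seq` are assumptions, not shown in this diff:

    from emu_mps import MPSBackend, MPSConfig  # assumed re-exports

    backend = MPSBackend()
    config = MPSConfig()  # assumed: default construction is valid
    results = backend.run(seq, config)  # seq: a pulser.Sequence built elsewhere
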