emu-mps 2.3.0__py3-none-any.whl → 2.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/__init__.py CHANGED
@@ -9,12 +9,11 @@ from pulser.backend import (
9
9
  StateResult,
10
10
  EnergySecondMoment,
11
11
  )
12
- from .mps_config import MPSConfig
12
+ from .mps_config import MPSConfig, Solver
13
13
  from .mpo import MPO
14
14
  from .mps import MPS, inner
15
15
  from .mps_backend import MPSBackend
16
16
  from .observables import EntanglementEntropy
17
- from emu_base import aggregate
18
17
 
19
18
 
20
19
  __all__ = [
@@ -23,6 +22,7 @@ __all__ = [
23
22
  "MPS",
24
23
  "inner",
25
24
  "MPSConfig",
25
+ "Solver",
26
26
  "MPSBackend",
27
27
  "StateResult",
28
28
  "BitStrings",
@@ -33,8 +33,7 @@ __all__ = [
33
33
  "Energy",
34
34
  "EnergyVariance",
35
35
  "EnergySecondMoment",
36
- "aggregate",
37
36
  "EntanglementEntropy",
38
37
  ]
39
38
 
40
- __version__ = "2.3.0"
39
+ __version__ = "2.4.1"
emu_mps/algebra.py CHANGED
@@ -1,11 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional
4
3
 
5
4
  import torch
6
5
  import math
7
-
8
- from emu_mps import MPSConfig
9
6
  from emu_mps.utils import truncate_impl
10
7
 
11
8
 
@@ -119,7 +116,8 @@ def zip_right_step(
119
116
  def zip_right(
120
117
  top_factors: list[torch.Tensor],
121
118
  bottom_factors: list[torch.Tensor],
122
- config: Optional[MPSConfig] = None,
119
+ precision: float,
120
+ max_bond_dim: int,
123
121
  ) -> list[torch.Tensor]:
124
122
  """
125
123
  Returns a new matrix product, resulting from applying `top` to `bottom`.
@@ -139,8 +137,6 @@ def zip_right(
139
137
  A final truncation sweep, from right to left,
140
138
  moves back the orthogonal center to the first element.
141
139
  """
142
- config = config if config is not None else MPSConfig()
143
-
144
140
  if len(top_factors) != len(bottom_factors):
145
141
  raise ValueError("Cannot multiply two matrix products of different lengths.")
146
142
 
@@ -151,6 +147,6 @@ def zip_right(
151
147
  new_factors.append(res)
152
148
  new_factors[-1] @= slider[:, :, 0]
153
149
 
154
- truncate_impl(new_factors, config=config)
150
+ truncate_impl(new_factors, precision=precision, max_bond_dim=max_bond_dim)
155
151
 
156
152
  return new_factors
@@ -25,9 +25,9 @@ def qubit_occupation_mps_impl(
25
25
  """
26
26
  Custom implementation of the occupation ❬ψ|nᵢ|ψ❭ for the EMU-MPS.
27
27
  """
28
- op = torch.tensor(
29
- [[[0.0, 0.0], [0.0, 1.0]]], dtype=torch.complex128, device=state.factors[0].device
30
- )
28
+ dim = state.dim
29
+ op = torch.zeros(1, dim, dim, dtype=torch.complex128, device=state.factors[0].device)
30
+ op[0, 1, 1] = 1.0
31
31
  return state.expect_batch(op).real.view(-1).cpu()
32
32
 
33
33
 
emu_mps/mpo.py CHANGED
@@ -7,7 +7,7 @@ from pulser.backend import State, Operator
7
7
  from emu_base import DEVICE_COUNT
8
8
  from emu_mps.algebra import add_factors, scale_factors, zip_right
9
9
  from pulser.backend.operator import FullOp, QuditOp
10
- from emu_mps.mps import MPS
10
+ from emu_mps.mps import MPS, DEFAULT_MAX_BOND_DIM, DEFAULT_PRECISION
11
11
  from emu_mps.utils import new_left_bath, assign_devices
12
12
 
13
13
 
@@ -62,7 +62,8 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
62
62
  factors = zip_right(
63
63
  self.factors,
64
64
  other.factors,
65
- config=other.config,
65
+ precision=other.precision,
66
+ max_bond_dim=other.max_bond_dim,
66
67
  )
67
68
  return MPS(factors, orthogonality_center=0, eigenstates=other.eigenstates)
68
69
 
@@ -107,7 +108,12 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
107
108
  the composed operator
108
109
  """
109
110
  assert isinstance(other, MPO), "MPO can only be applied to another MPO"
110
- factors = zip_right(self.factors, other.factors)
111
+ factors = zip_right(
112
+ self.factors,
113
+ other.factors,
114
+ precision=DEFAULT_PRECISION,
115
+ max_bond_dim=DEFAULT_MAX_BOND_DIM,
116
+ )
111
117
  return MPO(factors)
112
118
 
113
119
  def expect(self, state: State) -> torch.Tensor:
emu_mps/mps.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from __future__ import annotations
2
-
3
2
  import math
4
3
  from collections import Counter
5
4
  from typing import List, Optional, Sequence, TypeVar, Mapping
@@ -8,16 +7,19 @@ import torch
8
7
 
9
8
  from pulser.backend.state import State, Eigenstate
10
9
  from emu_base import DEVICE_COUNT, apply_measurement_errors
11
- from emu_mps import MPSConfig
12
10
  from emu_mps.algebra import add_factors, scale_factors
13
11
  from emu_mps.utils import (
14
12
  assign_devices,
15
13
  truncate_impl,
16
14
  tensor_trace,
17
- n_operator,
18
15
  )
19
16
 
17
+
20
18
  ArgScalarType = TypeVar("ArgScalarType")
19
+ dtype = torch.complex128
20
+
21
+ DEFAULT_PRECISION = 1e-5
22
+ DEFAULT_MAX_BOND_DIM = 1024
21
23
 
22
24
 
23
25
  class MPS(State[complex, torch.Tensor]):
@@ -35,12 +37,14 @@ class MPS(State[complex, torch.Tensor]):
35
37
  /,
36
38
  *,
37
39
  orthogonality_center: Optional[int] = None,
38
- config: Optional[MPSConfig] = None,
40
+ precision: float = DEFAULT_PRECISION,
41
+ max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
39
42
  num_gpus_to_use: Optional[int] = DEVICE_COUNT,
40
43
  eigenstates: Sequence[Eigenstate] = ("r", "g"),
41
44
  ):
42
45
  """
43
- This constructor creates a MPS directly from a list of tensors. It is for internal use only.
46
+ This constructor creates a MPS directly from a list of tensors. It is
47
+ for internal use only.
44
48
 
45
49
  Args:
46
50
  factors: the tensors for each site
@@ -52,12 +56,14 @@ class MPS(State[complex, torch.Tensor]):
52
56
  of the data to this constructor, or some shared objects.
53
57
  orthogonality_center: the orthogonality center of the MPS, or None (in which case
54
58
  it will be orthogonalized when needed)
55
- config: the emu-mps config object passed to the run method
59
+ precision: the precision with which to keep this MPS
60
+ max_bond_dim: the maximum bond dimension to allow for this MPS
56
61
  num_gpus_to_use: distribute the factors over this many GPUs
57
62
  0=all factors to cpu, None=keep the existing device assignment.
58
63
  """
59
64
  super().__init__(eigenstates=eigenstates)
60
- self.config = config if config is not None else MPSConfig()
65
+ self.precision = precision
66
+ self.max_bond_dim = max_bond_dim
61
67
  assert all(
62
68
  factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
63
69
  ), "The dimensions of consecutive tensors should match"
@@ -69,6 +75,17 @@ class MPS(State[complex, torch.Tensor]):
69
75
  self.num_sites = len(factors)
70
76
  assert self.num_sites > 1 # otherwise, do state vector
71
77
 
78
+ self.dim = len(self.eigenstates)
79
+ assert all(factors[i].shape[1] == self.dim for i in range(self.num_sites)), (
80
+ "All tensors should have the same physical dimension as the number "
81
+ "of eigenstates"
82
+ )
83
+
84
+ self.n_operator = torch.zeros(
85
+ self.dim, self.dim, dtype=dtype, device=self.factors[0].device
86
+ )
87
+ self.n_operator[1, 1] = 1.0
88
+
72
89
  assert (orthogonality_center is None) or (
73
90
  0 <= orthogonality_center < self.num_sites
74
91
  ), "Invalid orthogonality center provided"
@@ -86,7 +103,8 @@ class MPS(State[complex, torch.Tensor]):
86
103
  def make(
87
104
  cls,
88
105
  num_sites: int,
89
- config: Optional[MPSConfig] = None,
106
+ precision: float = DEFAULT_PRECISION,
107
+ max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
90
108
  num_gpus_to_use: int = DEVICE_COUNT,
91
109
  eigenstates: Sequence[Eigenstate] = ["0", "1"],
92
110
  ) -> MPS:
@@ -95,21 +113,35 @@ class MPS(State[complex, torch.Tensor]):
95
113
 
96
114
  Args:
97
115
  num_sites: the number of qubits
98
- config: the MPSConfig
116
+ precision: the precision with which to keep this MPS
117
+ max_bond_dim: the maximum bond dimension to allow for this MPS
99
118
  num_gpus_to_use: distribute the factors over this many GPUs
100
119
  0=all factors to cpu
101
120
  """
102
- config = config if config is not None else MPSConfig()
103
-
104
121
  if num_sites <= 1:
105
122
  raise ValueError("For 1 qubit states, do state vector")
106
123
 
107
- return cls(
108
- [
109
- torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)
124
+ if len(eigenstates) == 2:
125
+ ground_state = [
126
+ torch.tensor([[[1.0], [0.0]]], dtype=dtype) for _ in range(num_sites)
127
+ ]
128
+
129
+ elif len(eigenstates) == 3: # (g,r,x)
130
+ ground_state = [
131
+ torch.tensor([[[1.0], [0.0], [0.0]]], dtype=dtype)
110
132
  for _ in range(num_sites)
111
- ],
112
- config=config,
133
+ ]
134
+
135
+ else:
136
+ raise ValueError(
137
+ "Unsupported basis provided. The supported "
138
+ "bases are:{('0','1'),('r','g'),('r','g','x')}"
139
+ )
140
+
141
+ return cls(
142
+ ground_state,
143
+ precision=precision,
144
+ max_bond_dim=max_bond_dim,
113
145
  num_gpus_to_use=num_gpus_to_use,
114
146
  orthogonality_center=0, # Arbitrary: every qubit is an orthogonality center.
115
147
  eigenstates=eigenstates,
@@ -140,7 +172,8 @@ class MPS(State[complex, torch.Tensor]):
140
172
 
141
173
  for i in range(lr_swipe_start, desired_orthogonality_center):
142
174
  q, r = torch.linalg.qr(self.factors[i].view(-1, self.factors[i].shape[2]))
143
- self.factors[i] = q.view(self.factors[i].shape[0], 2, -1)
175
+
176
+ self.factors[i] = q.view(self.factors[i].shape[0], self.dim, -1)
144
177
  self.factors[i + 1] = torch.tensordot(
145
178
  r.to(self.factors[i + 1].device), self.factors[i + 1], dims=1
146
179
  )
@@ -153,9 +186,10 @@ class MPS(State[complex, torch.Tensor]):
153
186
 
154
187
  for i in range(rl_swipe_start, desired_orthogonality_center, -1):
155
188
  q, r = torch.linalg.qr(
156
- self.factors[i].view(self.factors[i].shape[0], -1).mT,
189
+ self.factors[i].contiguous().view(self.factors[i].shape[0], -1).mT,
157
190
  )
158
- self.factors[i] = q.mT.view(-1, 2, self.factors[i].shape[2])
191
+
192
+ self.factors[i] = q.mT.view(-1, self.dim, self.factors[i].shape[2])
159
193
  self.factors[i - 1] = torch.tensordot(
160
194
  self.factors[i - 1], r.to(self.factors[i - 1].device), ([2], [1])
161
195
  )
@@ -172,7 +206,9 @@ class MPS(State[complex, torch.Tensor]):
172
206
  An in-place operation.
173
207
  """
174
208
  self.orthogonalize(self.num_sites - 1)
175
- truncate_impl(self.factors, config=self.config)
209
+ truncate_impl(
210
+ self.factors, precision=self.precision, max_bond_dim=self.max_bond_dim
211
+ )
176
212
  self.orthogonality_center = 0
177
213
 
178
214
  def get_max_bond_dim(self) -> int:
@@ -182,7 +218,7 @@ class MPS(State[complex, torch.Tensor]):
182
218
  Returns:
183
219
  the largest bond dimension in the state
184
220
  """
185
- return max((x.shape[2] for x in self.factors), default=0)
221
+ return max((factor.shape[2] for factor in self.factors), default=0)
186
222
 
187
223
  def sample(
188
224
  self,
@@ -206,8 +242,6 @@ class MPS(State[complex, torch.Tensor]):
206
242
  assert one_state in {None, "r", "1"}
207
243
  self.orthogonalize(0)
208
244
 
209
- rnd_matrix = torch.rand(num_shots, self.num_sites).to(self.factors[0].device)
210
-
211
245
  bitstrings: Counter[str] = Counter()
212
246
 
213
247
  # Shots are performed in batches.
@@ -215,55 +249,45 @@ class MPS(State[complex, torch.Tensor]):
215
249
  max_batch_size = 32
216
250
 
217
251
  shots_done = 0
252
+
218
253
  while shots_done < num_shots:
219
254
  batch_size = min(max_batch_size, num_shots - shots_done)
220
255
  batched_accumulator = torch.ones(
221
- batch_size, 1, dtype=torch.complex128, device=self.factors[0].device
256
+ batch_size, 1, dtype=dtype, device=self.factors[0].device
222
257
  )
223
258
 
224
- batch_outcomes = torch.empty(batch_size, self.num_sites, dtype=torch.bool)
225
-
259
+ batch_outcomes = torch.empty(batch_size, self.num_sites, dtype=torch.int)
260
+ rangebatch = torch.arange(batch_size)
226
261
  for qubit, factor in enumerate(self.factors):
227
262
  batched_accumulator = torch.tensordot(
228
263
  batched_accumulator.to(factor.device), factor, dims=1
229
264
  )
230
265
 
231
- # Probability of measuring qubit == 0 for each shot in the batch
232
- probas = (
233
- torch.linalg.vector_norm(batched_accumulator[:, 0, :], dim=1) ** 2
234
- )
266
+ # Probabilities for each state in the basis
267
+ probn = torch.linalg.vector_norm(batched_accumulator, dim=2) ** 2
235
268
 
236
- outcomes = (
237
- rnd_matrix[shots_done : shots_done + batch_size, qubit].to(
238
- factor.device
239
- )
240
- > probas
241
- )
242
- batch_outcomes[:, qubit] = outcomes
269
+ # list of: 0,1 for |g>,|r> or 0,1,2 for |g>,|r>,|x>
270
+ outcomes = torch.multinomial(probn, num_samples=1).reshape(-1)
243
271
 
244
- # Batch collapse qubit
245
- tmp = torch.stack((~outcomes, outcomes), dim=1).to(dtype=torch.complex128)
272
+ batch_outcomes[:, qubit] = outcomes
246
273
 
247
- batched_accumulator = (
248
- torch.tensordot(batched_accumulator, tmp, dims=([1], [1]))
249
- .diagonal(dim1=0, dim2=2)
250
- .transpose(1, 0)
251
- )
252
- batched_accumulator /= torch.sqrt(
253
- (~outcomes) * probas + outcomes * (1 - probas)
254
- ).unsqueeze(1)
274
+ # expected shape (batch_size, bond_dim)
275
+ batched_accumulator = batched_accumulator[rangebatch, outcomes, :]
255
276
 
256
277
  shots_done += batch_size
257
278
 
258
279
  for outcome in batch_outcomes:
259
- bitstrings.update(["".join("0" if x == 0 else "1" for x in outcome)])
280
+ bitstrings.update(["".join("1" if x == 1 else "0" for x in outcome)])
260
281
 
261
- if p_false_neg > 0 or p_false_pos > 0:
282
+ if p_false_neg > 0 or p_false_pos > 0 and self.dim == 2:
262
283
  bitstrings = apply_measurement_errors(
263
284
  bitstrings,
264
285
  p_false_pos=p_false_pos,
265
286
  p_false_neg=p_false_neg,
266
287
  )
288
+ if p_false_pos > 0 and self.dim > 2:
289
+ raise NotImplementedError("Not implemented for qudits > 2 levels")
290
+
267
291
  return bitstrings
268
292
 
269
293
  def norm(self) -> torch.Tensor:
@@ -312,8 +336,11 @@ class MPS(State[complex, torch.Tensor]):
312
336
  def entanglement_entropy(self, mps_site: int) -> torch.Tensor:
313
337
  """
314
338
  Returns
315
- the Von Neumann entanglement entropy of the state `mps` at the bond between sites b and b+1
339
+ the Von Neumann entanglement entropy of the state `mps` at the bond
340
+ between sites b and b+1
341
+
316
342
  S = -Σᵢsᵢ² log(sᵢ²)),
343
+
317
344
  where sᵢ are the singular values at the chosen bond.
318
345
  """
319
346
  self.orthogonalize(mps_site)
@@ -358,7 +385,8 @@ class MPS(State[complex, torch.Tensor]):
358
385
  new_tt = add_factors(self.factors, other.factors)
359
386
  result = MPS(
360
387
  new_tt,
361
- config=self.config,
388
+ precision=self.precision,
389
+ max_bond_dim=self.max_bond_dim,
362
390
  num_gpus_to_use=None,
363
391
  orthogonality_center=None, # Orthogonality is lost.
364
392
  eigenstates=self.eigenstates,
@@ -384,7 +412,8 @@ class MPS(State[complex, torch.Tensor]):
384
412
  factors = scale_factors(self.factors, scalar, which=which)
385
413
  return MPS(
386
414
  factors,
387
- config=self.config,
415
+ precision=self.precision,
416
+ max_bond_dim=self.max_bond_dim,
388
417
  num_gpus_to_use=None,
389
418
  orthogonality_center=self.orthogonality_center,
390
419
  eigenstates=self.eigenstates,
@@ -412,25 +441,43 @@ class MPS(State[complex, torch.Tensor]):
412
441
  Returns:
413
442
  The resulting MPS representation of the state.s
414
443
  """
444
+
445
+ leak = ""
446
+ one = "r"
415
447
  basis = set(eigenstates)
448
+
416
449
  if basis == {"r", "g"}:
417
- one = "r"
450
+ pass
418
451
  elif basis == {"0", "1"}:
419
452
  one = "1"
453
+ elif basis == {"g", "r", "x"}:
454
+ leak = "x"
420
455
  else:
421
456
  raise ValueError("Unsupported basis provided")
457
+ dim = len(eigenstates)
458
+ if dim == 2:
459
+ basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=dtype) # ground state
460
+ basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=dtype) # excited state
422
461
 
423
- basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128) # ground state
424
- basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128) # excited state
462
+ elif dim == 3:
463
+ basis_0 = torch.tensor([[[1.0], [0.0], [0.0]]], dtype=dtype) # ground state
464
+ basis_1 = torch.tensor([[[0.0], [1.0], [0.0]]], dtype=dtype) # excited state
465
+ basis_x = torch.tensor([[[0.0], [0.0], [1.0]]], dtype=dtype) # leakage state
425
466
 
426
467
  accum_mps = MPS(
427
- [torch.zeros((1, 2, 1), dtype=torch.complex128)] * n_qudits,
468
+ [torch.zeros((1, dim, 1), dtype=dtype)] * n_qudits,
428
469
  orthogonality_center=0,
429
470
  eigenstates=eigenstates,
430
471
  )
431
-
432
472
  for state, amplitude in amplitudes.items():
433
- factors = [basis_1 if ch == one else basis_0 for ch in state]
473
+ factors = []
474
+ for ch in state:
475
+ if ch == one:
476
+ factors.append(basis_1)
477
+ elif ch == leak:
478
+ factors.append(basis_x)
479
+ else:
480
+ factors.append(basis_0)
434
481
  accum_mps += amplitude * MPS(factors, eigenstates=eigenstates)
435
482
  norm = accum_mps.norm()
436
483
  if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
@@ -453,9 +500,7 @@ class MPS(State[complex, torch.Tensor]):
453
500
  else self.orthogonalize(0)
454
501
  )
455
502
 
456
- result = torch.zeros(
457
- self.num_sites, single_qubit_operators.shape[0], dtype=torch.complex128
458
- )
503
+ result = torch.zeros(self.num_sites, single_qubit_operators.shape[0], dtype=dtype)
459
504
 
460
505
  center_factor = self.factors[orthogonality_center]
461
506
  for qubit_index in range(orthogonality_center, self.num_sites):
@@ -503,7 +548,7 @@ class MPS(State[complex, torch.Tensor]):
503
548
  )
504
549
 
505
550
  def get_correlation_matrix(
506
- self, *, operator: torch.Tensor = n_operator
551
+ self, operator: torch.Tensor | None = None
507
552
  ) -> torch.Tensor:
508
553
  """
509
554
  Efficiently compute the symmetric correlation matrix
@@ -511,14 +556,18 @@ class MPS(State[complex, torch.Tensor]):
511
556
  in basis ("r", "g").
512
557
 
513
558
  Args:
514
- operator: a 2x2 Torch tensor to use
559
+ operator: a 2x2 (or 3x3) Torch tensor to use
515
560
 
516
561
  Returns:
517
562
  the corresponding correlation matrix
518
563
  """
519
- assert operator.shape == (2, 2)
520
564
 
521
- result = torch.zeros(self.num_sites, self.num_sites, dtype=torch.complex128)
565
+ if operator is None:
566
+ operator = self.n_operator
567
+
568
+ assert operator.shape == (self.dim, self.dim), "Operator has wrong shape"
569
+
570
+ result = torch.zeros(self.num_sites, self.num_sites, dtype=dtype)
522
571
 
523
572
  for left in range(0, self.num_sites):
524
573
  self.orthogonalize(left)
@@ -10,7 +10,6 @@ import uuid
10
10
  from copy import deepcopy
11
11
  from collections import Counter
12
12
  from enum import Enum, auto
13
- from resource import RUSAGE_SELF, getrusage
14
13
  from types import MethodType
15
14
  from typing import Any, Optional
16
15
 
@@ -18,7 +17,7 @@ import torch
18
17
  from pulser import Sequence
19
18
  from pulser.backend import EmulationConfig, Observable, Results, State
20
19
 
21
- from emu_base import DEVICE_COUNT, PulserData
20
+ from emu_base import DEVICE_COUNT, PulserData, get_max_rss
22
21
  from emu_base.math.brents_root_finding import BrentsRootFinder
23
22
  from emu_base.utils import deallocate_tensor
24
23
 
@@ -26,9 +25,9 @@ from emu_mps.hamiltonian import make_H, update_H
26
25
  from emu_mps.mpo import MPO
27
26
  from emu_mps.mps import MPS
28
27
  from emu_mps.mps_config import MPSConfig
29
- from emu_base.noise import pick_dark_qubits
30
28
  from emu_base.jump_lindblad_operators import compute_noise_from_lindbladians
31
29
  import emu_mps.optimatrix as optimat
30
+ from emu_mps.solver import Solver
32
31
  from emu_mps.solver_utils import (
33
32
  evolve_pair,
34
33
  evolve_single,
@@ -69,14 +68,7 @@ class Statistics(Observable):
69
68
  """Calculates the observable to store in the Results."""
70
69
  assert isinstance(state, MPS)
71
70
  duration = self.data[-1]
72
- if state.factors[0].is_cuda:
73
- max_mem_per_device = (
74
- torch.cuda.max_memory_allocated(device) * 1e-6
75
- for device in range(torch.cuda.device_count())
76
- )
77
- max_mem = max(max_mem_per_device)
78
- else:
79
- max_mem = getrusage(RUSAGE_SELF).ru_maxrss * 1e-3
71
+ max_mem = get_max_rss(state.factors[0].is_cuda)
80
72
 
81
73
  config.logger.info(
82
74
  f"step = {len(self.data)}/{self.timestep_count}, "
@@ -193,10 +185,9 @@ class MPSBackendImpl:
193
185
  def init_dark_qubits(self) -> None:
194
186
  # has_state_preparation_error
195
187
  if self.config.noise_model.state_prep_error > 0.0:
188
+ bad_atoms = self.pulser_data.hamiltonian.bad_atoms
196
189
  self.well_prepared_qubits_filter = torch.logical_not(
197
- pick_dark_qubits(
198
- self.config.noise_model.state_prep_error, self.qubit_count
199
- )
190
+ torch.tensor(list(bool(x) for x in bad_atoms.values()))
200
191
  )
201
192
  else:
202
193
  self.well_prepared_qubits_filter = None
@@ -218,7 +209,8 @@ class MPSBackendImpl:
218
209
  if initial_state is None:
219
210
  self.state = MPS.make(
220
211
  self.qubit_count,
221
- config=self.config,
212
+ precision=self.config.precision,
213
+ max_bond_dim=self.config.max_bond_dim,
222
214
  num_gpus_to_use=self.config.num_gpus_to_use,
223
215
  )
224
216
  return
@@ -245,13 +237,15 @@ class MPSBackendImpl:
245
237
  initial_state = MPS(
246
238
  # Deep copy of every tensor of the initial state.
247
239
  [f.detach().clone() for f in initial_state.factors],
248
- config=self.config,
240
+ precision=self.config.precision,
241
+ max_bond_dim=self.config.max_bond_dim,
249
242
  num_gpus_to_use=self.config.num_gpus_to_use,
250
243
  eigenstates=initial_state.eigenstates,
251
244
  )
252
245
  initial_state.truncate()
253
246
  initial_state *= 1 / initial_state.norm()
254
247
  self.state = initial_state
248
+ self.state.orthogonalize(0)
255
249
 
256
250
  def init_hamiltonian(self) -> None:
257
251
  """
@@ -699,18 +693,20 @@ class DMRGBackendImpl(MPSBackendImpl):
699
693
  mps_config: MPSConfig,
700
694
  pulser_data: PulserData,
701
695
  energy_tolerance: float = 1e-5,
702
- max_sweeps: int = 999,
703
- residual_tolerance: float = 1e-7,
696
+ max_sweeps: int = 2000,
704
697
  ):
698
+
699
+ if mps_config.noise_model.noise_types != ():
700
+ raise NotImplementedError(
701
+ "DMRG solver does not currently support noise types"
702
+ f"you are using: {mps_config.noise_model.noise_types}"
703
+ )
705
704
  super().__init__(mps_config, pulser_data)
706
- self.init()
707
- self.state.orthogonalize(0)
708
705
  self.previous_energy: Optional[float] = None
709
706
  self.current_energy: Optional[float] = None
710
707
  self.sweep_count: int = 0
711
708
  self.energy_tolerance: float = energy_tolerance
712
709
  self.max_sweeps: int = max_sweeps
713
- self.residual_tolerance: float = residual_tolerance
714
710
 
715
711
  def convergence_check(self, energy_tolerance: float) -> bool:
716
712
  if self.previous_energy is None or self.current_energy is None:
@@ -728,63 +724,72 @@ class DMRGBackendImpl(MPSBackendImpl):
728
724
  SwipeDirection.RIGHT_TO_LEFT,
729
725
  ), "Unknown Swipe direction"
730
726
 
731
- if self.swipe_direction == SwipeDirection.LEFT_TO_RIGHT:
732
- left_idx, right_idx = idx, idx + 1
733
- elif self.swipe_direction == SwipeDirection.RIGHT_TO_LEFT:
734
- left_idx, right_idx = idx - 1, idx
735
-
736
727
  orth_center_right = self.swipe_direction == SwipeDirection.LEFT_TO_RIGHT
737
728
  new_L, new_R, energy = minimize_energy_pair(
738
- state_factors=self.state.factors[left_idx : right_idx + 1],
739
- ham_factors=self.hamiltonian.factors[left_idx : right_idx + 1],
729
+ state_factors=self.state.factors[idx : idx + 2],
730
+ ham_factors=self.hamiltonian.factors[idx : idx + 2],
740
731
  baths=(self.left_baths[-1], self.right_baths[-1]),
741
732
  orth_center_right=orth_center_right,
742
733
  config=self.config,
743
- residual_tolerance=self.residual_tolerance,
734
+ residual_tolerance=self.config.precision,
744
735
  )
745
- self.state.factors[left_idx], self.state.factors[right_idx] = new_L, new_R
746
- self.state.orthogonality_center = right_idx if orth_center_right else left_idx
736
+ self.state.factors[idx], self.state.factors[idx + 1] = new_L, new_R
737
+ self.state.orthogonality_center = idx + 1 if orth_center_right else idx
747
738
  self.current_energy = energy
748
739
 
749
740
  # updating baths and orthogonality center
750
741
  if self.swipe_direction == SwipeDirection.LEFT_TO_RIGHT:
742
+ self._left_to_right_update(idx)
743
+ elif self.swipe_direction == SwipeDirection.RIGHT_TO_LEFT:
744
+ self._right_to_left_update(idx)
745
+ else:
746
+ raise Exception("Did not expect this")
747
+
748
+ self.save_simulation()
749
+
750
+ def _left_to_right_update(self, idx: int) -> None:
751
+ if idx < self.qubit_count - 2:
751
752
  self.left_baths.append(
752
753
  new_left_bath(
753
754
  self.get_current_left_bath(),
754
- self.state.factors[left_idx],
755
- self.hamiltonian.factors[right_idx],
756
- ).to(self.state.factors[right_idx].device)
755
+ self.state.factors[idx],
756
+ self.hamiltonian.factors[idx],
757
+ ).to(self.state.factors[idx + 1].device)
757
758
  )
758
759
  self.right_baths.pop()
759
760
  self.sweep_index += 1
760
761
 
761
- if self.sweep_index == self.qubit_count - 1:
762
- self.swipe_direction = SwipeDirection.RIGHT_TO_LEFT
762
+ if self.sweep_index == self.qubit_count - 2:
763
+ self.swipe_direction = SwipeDirection.RIGHT_TO_LEFT
763
764
 
764
- elif self.swipe_direction == SwipeDirection.RIGHT_TO_LEFT:
765
+ def _right_to_left_update(self, idx: int) -> None:
766
+ if idx > 0:
765
767
  self.right_baths.append(
766
768
  new_right_bath(
767
769
  self.get_current_right_bath(),
768
- self.state.factors[right_idx],
769
- self.hamiltonian.factors[right_idx],
770
- ).to(self.state.factors[left_idx].device)
770
+ self.state.factors[idx + 1],
771
+ self.hamiltonian.factors[idx + 1],
772
+ ).to(self.state.factors[idx].device)
771
773
  )
772
774
  self.left_baths.pop()
773
775
  self.sweep_index -= 1
774
776
 
775
- if self.sweep_index == 0:
776
- self.swipe_direction = SwipeDirection.LEFT_TO_RIGHT
777
- self.sweep_count += 1
778
- self.sweep_complete()
777
+ if self.sweep_index == 0:
778
+ self.state.orthogonalize(0)
779
+ self.swipe_direction = SwipeDirection.LEFT_TO_RIGHT
780
+ self.sweep_count += 1
781
+ self.sweep_complete()
779
782
 
780
783
  def sweep_complete(self) -> None:
781
784
  # This marks the end of one full sweep: checking convergence
782
785
  if self.convergence_check(self.energy_tolerance):
786
+ self.current_time = self.target_time
783
787
  self.timestep_complete()
784
788
  elif self.sweep_count + 1 > self.max_sweeps:
785
- # not converged: restart a new sweep
789
+ # not converged
786
790
  raise RuntimeError(f"DMRG did not converge after {self.max_sweeps} sweeps")
787
791
  else:
792
+ # not converged for the current sweep. restart
788
793
  self.previous_energy = self.current_energy
789
794
 
790
795
  assert self.sweep_index == 0
@@ -798,5 +803,6 @@ def create_impl(sequence: Sequence, config: MPSConfig) -> MPSBackendImpl:
798
803
 
799
804
  if pulser_data.has_lindblad_noise:
800
805
  return NoisyMPSBackendImpl(config, pulser_data)
801
-
806
+ if config.solver == Solver.DMRG:
807
+ return DMRGBackendImpl(config, pulser_data)
802
808
  return MPSBackendImpl(config, pulser_data)
emu_mps/mps_config.py CHANGED
@@ -4,6 +4,9 @@ from types import MethodType
4
4
  import copy
5
5
 
6
6
  from emu_base import DEVICE_COUNT
7
+ from emu_mps.mps import MPS, DEFAULT_MAX_BOND_DIM, DEFAULT_PRECISION
8
+ from emu_mps.mpo import MPO
9
+ from emu_mps.solver import Solver
7
10
  from emu_mps.custom_callback_implementations import (
8
11
  energy_mps_impl,
9
12
  energy_second_moment_mps_impl,
@@ -32,15 +35,17 @@ class MPSConfig(EmulationConfig):
32
35
  See the API for that class for a list of available options.
33
36
 
34
37
  Args:
35
- dt: the timestep size that the solver uses. Note that observables are
38
+ dt: The timestep size that the solver uses. Note that observables are
36
39
  only calculated if the evaluation_times are divisible by dt.
37
- precision: up to what precision the state is truncated
38
- max_bond_dim: the maximum bond dimension that the state is allowed to have.
40
+ precision: Up to what precision the state is truncated.
41
+ Defaults to `1e-5`.
42
+ max_bond_dim: The maximum bond dimension that the state is allowed to have.
43
+ Defaults to `1024`.
39
44
  max_krylov_dim:
40
- the size of the krylov subspace that the Lanczos algorithm maximally builds
45
+ The size of the krylov subspace that the Lanczos algorithm maximally builds
41
46
  extra_krylov_tolerance:
42
- the Lanczos algorithm uses this*precision as the convergence tolerance
43
- num_gpus_to_use: during the simulation, distribute the state over this many GPUs
47
+ The Lanczos algorithm uses this*precision as the convergence tolerance
48
+ num_gpus_to_use: During the simulation, distribute the state over this many GPUs
44
49
  0=all factors to cpu. As shown in the benchmarks, using multiple GPUs might
45
50
  alleviate memory pressure per GPU, but the runtime should be similar.
46
51
  optimize_qubit_ordering: Optimize the register ordering. Improves performance and
@@ -50,10 +55,15 @@ class MPSConfig(EmulationConfig):
50
55
  log_level: How much to log. Set to `logging.WARN` to get rid of the timestep info.
51
56
  log_file: If specified, log to this file rather than stout.
52
57
  autosave_prefix: filename prefix for autosaving simulation state to file
53
- autosave_dt: minimum time interval in seconds between two autosaves.
58
+ autosave_dt: Minimum time interval in seconds between two autosaves.
54
59
  Saving the simulation state is only possible at specific times,
55
60
  therefore this interval is only a lower bound.
56
- kwargs: arguments that are passed to the base class
61
+ solver: Chooses the solver algorithm to run a sequence.
62
+ Two options are currently available:
63
+ ``TDVP``, which performs ordinary time evolution,
64
+ and ``DMRG``, which adiabatically follows the ground state
65
+ of a given adiabatic pulse.
66
+ kwargs: Arguments that are passed to the base class
57
67
 
58
68
  Examples:
59
69
  >>> num_gpus_to_use = 2 #use 2 gpus if available, otherwise 1 or cpu
@@ -65,13 +75,15 @@ class MPSConfig(EmulationConfig):
65
75
 
66
76
  # Whether to warn if unexpected kwargs are received
67
77
  _enforce_expected_kwargs: ClassVar[bool] = True
78
+ _state_type = MPS
79
+ _operator_type = MPO
68
80
 
69
81
  def __init__(
70
82
  self,
71
83
  *,
72
84
  dt: int = 10,
73
- precision: float = 1e-5,
74
- max_bond_dim: int = 1024,
85
+ precision: float = DEFAULT_PRECISION,
86
+ max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
75
87
  max_krylov_dim: int = 100,
76
88
  extra_krylov_tolerance: float = 1e-3,
77
89
  num_gpus_to_use: int = DEVICE_COUNT,
@@ -81,6 +93,7 @@ class MPSConfig(EmulationConfig):
81
93
  log_file: pathlib.Path | None = None,
82
94
  autosave_prefix: str = "emu_mps_save_",
83
95
  autosave_dt: int = 600, # 10 minutes
96
+ solver: Solver = Solver.TDVP,
84
97
  **kwargs: Any,
85
98
  ):
86
99
  kwargs.setdefault("observables", [BitStrings(evaluation_times=[1.0])])
@@ -97,6 +110,7 @@ class MPSConfig(EmulationConfig):
97
110
  log_file=log_file,
98
111
  autosave_prefix=autosave_prefix,
99
112
  autosave_dt=autosave_dt,
113
+ solver=solver,
100
114
  **kwargs,
101
115
  )
102
116
  if self.optimize_qubit_ordering:
@@ -144,6 +158,7 @@ class MPSConfig(EmulationConfig):
144
158
  "log_file",
145
159
  "autosave_prefix",
146
160
  "autosave_dt",
161
+ "solver",
147
162
  }
148
163
 
149
164
  def monkeypatch_observables(self) -> None:
emu_mps/observables.py CHANGED
@@ -6,7 +6,18 @@ import torch
6
6
 
7
7
 
8
8
  class EntanglementEntropy(Observable):
9
- """Entanglement Entropy subclass used only in emu_mps"""
9
+ """Entanglement Entropy of the state partition at qubit `mps_site`.
10
+
11
+ Args:
12
+ mps_site: the qubit index at which the bipartition is made.
13
+ All qubits with index $\\leq$ `mps_site` are put in the left partition.
14
+ evaluation_times: The relative times at which to store the state.
15
+ If left as `None`, uses the ``default_evaluation_times`` of the
16
+ backend's ``EmulationConfig``.
17
+ tag_suffix: An optional suffix to append to the tag. Needed if
18
+ multiple instances of the same observable are given to the
19
+ same EmulationConfig.
20
+ """
10
21
 
11
22
  def __init__(
12
23
  self,
emu_mps/solver.py ADDED
@@ -0,0 +1,6 @@
1
+ from enum import Enum
2
+
3
+
4
+ class Solver(str, Enum):
5
+ TDVP = "tdvp"
6
+ DMRG = "dmrg"
emu_mps/solver_utils.py CHANGED
@@ -51,9 +51,10 @@ def make_op(
51
51
  .view(left_ham_factor.shape[0], 4, 4, -1)
52
52
  )
53
53
 
54
- op = lambda x: time_step * apply_effective_Hamiltonian(
55
- x, combined_hamiltonian_factors, left_bath, right_bath
56
- )
54
+ def op(x: torch.Tensor) -> torch.Tensor:
55
+ return time_step * apply_effective_Hamiltonian(
56
+ x, combined_hamiltonian_factors, left_bath, right_bath
57
+ )
57
58
 
58
59
  return combined_state_factors, right_device, op
59
60
 
@@ -205,17 +206,18 @@ def evolve_single(
205
206
 
206
207
  left_bath, right_bath = baths
207
208
 
208
- op = (
209
- lambda x: -_TIME_CONVERSION_COEFF
210
- * 1j
211
- * dt
212
- * apply_effective_Hamiltonian(
213
- x,
214
- ham_factor,
215
- left_bath,
216
- right_bath,
209
+ def op(x: torch.Tensor) -> torch.Tensor:
210
+ return (
211
+ -_TIME_CONVERSION_COEFF
212
+ * 1j
213
+ * dt
214
+ * apply_effective_Hamiltonian(
215
+ x,
216
+ ham_factor,
217
+ left_bath,
218
+ right_bath,
219
+ )
217
220
  )
218
- )
219
221
 
220
222
  return krylov_exp(
221
223
  op,
emu_mps/utils.py CHANGED
@@ -1,8 +1,6 @@
1
1
  from typing import List, Optional
2
2
  import torch
3
3
 
4
- from emu_mps import MPSConfig
5
-
6
4
 
7
5
  def new_left_bath(
8
6
  bath: torch.Tensor, state: torch.Tensor, op: torch.Tensor
@@ -59,8 +57,7 @@ def split_tensor(
59
57
 
60
58
 
61
59
  def truncate_impl(
62
- factors: list[torch.Tensor],
63
- config: MPSConfig,
60
+ factors: list[torch.Tensor], precision: float, max_bond_dim: int
64
61
  ) -> None:
65
62
  """
66
63
  Eigenvalues-based truncation of a matrix product.
@@ -76,8 +73,8 @@ def truncate_impl(
76
73
 
77
74
  l, r = split_tensor(
78
75
  factors[i].view(factor_shape[0], -1),
79
- max_error=config.precision,
80
- max_rank=config.max_bond_dim,
76
+ max_error=precision,
77
+ max_rank=max_bond_dim,
81
78
  orth_center_right=False,
82
79
  )
83
80
 
@@ -208,15 +205,6 @@ def get_extended_site_index(
208
205
  raise ValueError(f"Index {desired_index} does not exist")
209
206
 
210
207
 
211
- n_operator: torch.Tensor = torch.tensor(
212
- [
213
- [0, 0],
214
- [0, 1],
215
- ],
216
- dtype=torch.complex128,
217
- )
218
-
219
-
220
208
  def tensor_trace(tensor: torch.Tensor, dim1: int, dim2: int) -> torch.Tensor:
221
209
  """
222
210
  Contract two legs of a single tensor.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: emu-mps
3
- Version: 2.3.0
3
+ Version: 2.4.1
4
4
  Summary: Pasqal MPS based pulse emulator built on PyTorch
5
5
  Project-URL: Documentation, https://pasqal-io.github.io/emulators/
6
6
  Project-URL: Repository, https://github.com/pasqal-io/emulators
@@ -25,7 +25,7 @@ Classifier: Programming Language :: Python :: 3.10
25
25
  Classifier: Programming Language :: Python :: Implementation :: CPython
26
26
  Classifier: Programming Language :: Python :: Implementation :: PyPy
27
27
  Requires-Python: >=3.10
28
- Requires-Dist: emu-base==2.3.0
28
+ Requires-Dist: emu-base==2.4.1
29
29
  Description-Content-Type: text/markdown
30
30
 
31
31
  <div align="center">
@@ -0,0 +1,19 @@
1
+ emu_mps/__init__.py,sha256=R4Fw8r-_GEeckaxyQ9leUKVfLoukcrwvaTQsTbzxhbs,708
2
+ emu_mps/algebra.py,sha256=vi3d4xOBEsfQ7gsjYNns9AMLjJXXXppYVEThavNvAg0,5408
3
+ emu_mps/custom_callback_implementations.py,sha256=WeczmO6qkvBIipvXLqX45i3D7M4ovOrepusIGs6d2Ts,2420
4
+ emu_mps/hamiltonian.py,sha256=gOPxNOBmk6jRPPjevERuCP_scGv0EKYeAJ0uxooihes,15622
5
+ emu_mps/mpo.py,sha256=2HNwN4Fz04QIVfPcPaMmt2q89ZBxN3K-vVeiFkOtqzs,8049
6
+ emu_mps/mps.py,sha256=KEXrLdqhi5EvpdX-9J38ZfrRdt3oDmcGsVLghMPUfQw,21600
7
+ emu_mps/mps_backend.py,sha256=bS83qFxvdoK-c12_1WaPw6O7xUc7vdWifZNHUzNP5sM,2091
8
+ emu_mps/mps_backend_impl.py,sha256=Pcbn27lhg3n3Lzo3CGwhlWPqPi-8YV1ntkgl901AoUs,30400
9
+ emu_mps/mps_config.py,sha256=QmwgU8INEnxrxZkhboYIsHGwml0UhgPkddT5zh8KVBU,8867
10
+ emu_mps/observables.py,sha256=4C_ewkd3YkJP0xghTrGUTgXUGvJRCQcetb8cU0SjMl0,1900
11
+ emu_mps/solver.py,sha256=M9xkHhlEouTBvoPw2UYVu6kij7CO4Z1FXw_SiGFtdgo,85
12
+ emu_mps/solver_utils.py,sha256=EnNzEaUrtTMQbrWoqOy8vyDsQwlsfQCUc2HgOp4z8dk,8680
13
+ emu_mps/utils.py,sha256=pW5N_EbbGiOviQpJCw1a0pVgEDObP_InceNaIqY5bHE,6982
14
+ emu_mps/optimatrix/__init__.py,sha256=fBXQ7-rgDro4hcaBijCGhx3J69W96qcw5_3mWc7tND4,364
15
+ emu_mps/optimatrix/optimiser.py,sha256=k9suYmKLKlaZ7ozFuIqvXHyCBoCtGgkX1mpen9GOdOo,6977
16
+ emu_mps/optimatrix/permutations.py,sha256=9DDMZtrGGZ01b9F3GkzHR3paX4qNtZiPoI7Z_Kia3Lc,3727
17
+ emu_mps-2.4.1.dist-info/METADATA,sha256=2SUt6GLqhz_6-2ewuVqDH9xvz3gx1ZROU975intvuwk,3587
18
+ emu_mps-2.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
+ emu_mps-2.4.1.dist-info/RECORD,,
@@ -1,18 +0,0 @@
1
- emu_mps/__init__.py,sha256=vvVRxFPudsVALA-_YyqWf8rlIRhtI1WzBWakYk2vfB8,734
2
- emu_mps/algebra.py,sha256=78XP9HEbV3wGNUzIulcLU5HizW4XAYmcFdkCe1T1x-k,5489
3
- emu_mps/custom_callback_implementations.py,sha256=SZGKVyS8U5hy07L-3SqpWlCAqGGKFTlSlWexZwSmjrM,2408
4
- emu_mps/hamiltonian.py,sha256=gOPxNOBmk6jRPPjevERuCP_scGv0EKYeAJ0uxooihes,15622
5
- emu_mps/mpo.py,sha256=aWSVuEzZM-_7ZD5Rz3-tSJWX22ARP0tMIl3gUu-_4V4,7834
6
- emu_mps/mps.py,sha256=n5KPuWr2qCNW4Q1OKpUGrSuhu4IbVp1v1BPLvLILbGg,19960
7
- emu_mps/mps_backend.py,sha256=bS83qFxvdoK-c12_1WaPw6O7xUc7vdWifZNHUzNP5sM,2091
8
- emu_mps/mps_backend_impl.py,sha256=O3vbjw3jlgyj0hOYv3nMmAP0lINqxbt96mQPgwPZN5w,30198
9
- emu_mps/mps_config.py,sha256=PoSKZxJMhG6zfzgEjj4tIvyiyYRQywxkRgidh8MRBsA,8222
10
- emu_mps/observables.py,sha256=7GQDH5kyaVNrwckk2f8ZJRV9Ca4jKhWWDsOCqYWsoEk,1349
11
- emu_mps/solver_utils.py,sha256=VQ02_RxvPcjyXippuIY4Swpx4EdqtoJTt8Ie70GgdqU,8550
12
- emu_mps/utils.py,sha256=hgtaRUtBAzk76ab-S_wTVkvqfVOmaUks38zWame9GRQ,7132
13
- emu_mps/optimatrix/__init__.py,sha256=fBXQ7-rgDro4hcaBijCGhx3J69W96qcw5_3mWc7tND4,364
14
- emu_mps/optimatrix/optimiser.py,sha256=k9suYmKLKlaZ7ozFuIqvXHyCBoCtGgkX1mpen9GOdOo,6977
15
- emu_mps/optimatrix/permutations.py,sha256=9DDMZtrGGZ01b9F3GkzHR3paX4qNtZiPoI7Z_Kia3Lc,3727
16
- emu_mps-2.3.0.dist-info/METADATA,sha256=WLLpx31eshWCwR95aB_B4yoM8BRAKT81UR2mjVR3998,3587
17
- emu_mps-2.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
18
- emu_mps-2.3.0.dist-info/RECORD,,