emu_mps-1.2.3-py3-none-any.whl → emu_mps-1.2.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/__init__.py CHANGED
@@ -10,10 +10,10 @@ from emu_base import (
     StateResult,
     SecondMomentOfEnergy,
 )
+from .mps_config import MPSConfig
 from .mpo import MPO
 from .mps import MPS, inner
 from .mps_backend import MPSBackend
-from .mps_config import MPSConfig
 
 
 __all__ = [
@@ -35,4 +35,4 @@ __all__ = [
     "SecondMomentOfEnergy",
 ]
 
-__version__ = "1.2.3"
+__version__ = "1.2.5"
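
Note on the reordering above: `algebra.py` and `mps.py` now import `MPSConfig` from the package root (see their diffs below), so `.mps_config` must be bound before `.mpo` and `.mps` are imported; keeping it last would trigger a circular-import failure during package initialization.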
emu_mps/algebra.py CHANGED
@@ -1,13 +1,17 @@
 from __future__ import annotations
 
+from typing import Optional
+
 import torch
 import math
+
+from emu_mps import MPSConfig
 from emu_mps.utils import truncate_impl
 
 
 def add_factors(
-    left: list[torch.tensor], right: list[torch.tensor]
-) -> list[torch.tensor]:
+    left: list[torch.Tensor], right: list[torch.Tensor]
+) -> list[torch.Tensor]:
     """
     Direct sum algorithm implementation to sum two tensor trains (MPS/MPO).
     It assumes the left and right bond are along the dimension 0 and -1 of each tensor.
@@ -48,8 +52,8 @@ def add_factors(
 
 
 def scale_factors(
-    factors: list[torch.tensor], scalar: complex, *, which: int
-) -> list[torch.tensor]:
+    factors: list[torch.Tensor], scalar: complex, *, which: int
+) -> list[torch.Tensor]:
     """
     Returns a new list of factors where the tensor at the given index is scaled by `scalar`.
     """
@@ -57,10 +61,10 @@ def scale_factors(
 
 
 def zip_right_step(
-    slider: torch.tensor,
-    top: torch.tensor,
-    bottom: torch.tensor,
-) -> torch.tensor:
+    slider: torch.Tensor,
+    top: torch.Tensor,
+    bottom: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Returns a new `MPS/O` factor of the result of the multiplication MPO @ MPS/O,
     and the updated slider, performing a single step of the
@@ -113,11 +117,10 @@ def zip_right_step(
 
 
 def zip_right(
-    top_factors: list[torch.tensor],
-    bottom_factors: list[torch.tensor],
-    max_error: float = 1e-5,
-    max_rank: int = 1024,
-) -> list[torch.tensor]:
+    top_factors: list[torch.Tensor],
+    bottom_factors: list[torch.Tensor],
+    config: Optional[MPSConfig] = None,
+) -> list[torch.Tensor]:
     """
     Returns a new matrix product, resulting from applying `top` to `bottom`.
     The resulting factors are:
@@ -136,6 +139,8 @@ def zip_right(
     A final truncation sweep, from right to left,
     moves back the orthogonal center to the first element.
     """
+    config = config if config is not None else MPSConfig()
+
    if len(top_factors) != len(bottom_factors):
        raise ValueError("Cannot multiply two matrix products of different lengths.")
 
@@ -146,6 +151,6 @@ def zip_right(
         new_factors.append(res)
     new_factors[-1] @= slider[:, :, 0]
 
-    truncate_impl(new_factors, max_error=max_error, max_rank=max_rank)
+    truncate_impl(new_factors, config=config)
 
     return new_factors
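
For context, a minimal sketch of the new calling convention: truncation thresholds now ride along inside an `MPSConfig` rather than as loose `max_error`/`max_rank` arguments. The `precision` and `max_bond_dim` keyword names below are assumptions inferred from the config attributes this release reads elsewhere:

# Sketch only: identity MPO applied to |000>, truncation driven by the config.
import torch

from emu_mps import MPSConfig
from emu_mps.algebra import zip_right

# MPO factors have shape (left_bond, out, in, right_bond).
mpo_factors = [
    torch.eye(2, dtype=torch.complex128).reshape(1, 2, 2, 1) for _ in range(3)
]
# MPS factors have shape (left_bond, phys, right_bond).
mps_factors = [
    torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128) for _ in range(3)
]

config = MPSConfig(precision=1e-6, max_bond_dim=512)  # assumed keyword names
result = zip_right(mpo_factors, mps_factors, config=config)  # omit config for defaults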
emu_mps/constants.py ADDED
@@ -0,0 +1,4 @@
+import torch
+
+
+DEVICE_COUNT = torch.cuda.device_count()
emu_mps/hamiltonian.py CHANGED
@@ -285,6 +285,7 @@ def _get_interactions_to_keep(interaction_matrix: torch.Tensor) -> list[torch.Tensor]:
     returns a list of bool valued tensors,
     indicating which interaction terms to keep for each bond in the MPO
     """
+    interaction_matrix = interaction_matrix.clone()
     nqubits = interaction_matrix.size(dim=1)
     middle = nqubits // 2
     interaction_matrix += torch.eye(
@@ -359,10 +360,9 @@ def make_H(
 
     nqubits = interaction_matrix.size(dim=1)
     middle = nqubits // 2
-
     interactions_to_keep = _get_interactions_to_keep(interaction_matrix)
 
-    cores = [_first_factor(interactions_to_keep[0].item())]
+    cores = [_first_factor(interactions_to_keep[0].item() != 0.0)]
 
     if nqubits > 2:
         for i in range(1, middle):
@@ -394,7 +394,7 @@ def make_H(
         )
     )
     if nqubits == 2:
-        scale = interaction_matrix[0, 1]
+        scale = interaction_matrix[0, 1].item()
     elif interactions_to_keep[-1][0]:
         scale = 1.0
     else:
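
The added `.clone()` fixes a genuine aliasing hazard: `+=` on a tensor parameter writes through to the caller's tensor. A self-contained illustration of the pitfall, independent of emu-mps:

import torch

def shift_diagonal_inplace(m: torch.Tensor) -> torch.Tensor:
    m += torch.eye(m.size(0))  # in-place add: the caller's tensor is modified
    return m

def shift_diagonal(m: torch.Tensor) -> torch.Tensor:
    m = m.clone()  # defensive copy, as _get_interactions_to_keep now makes
    m += torch.eye(m.size(0))
    return m

a = torch.zeros(2, 2)
shift_diagonal(a)
assert torch.equal(a, torch.zeros(2, 2))  # caller unaffected
shift_diagonal_inplace(a)
assert torch.equal(a, torch.eye(2))  # caller's matrix was silently changed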
emu_mps/mpo.py CHANGED
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import itertools
-from typing import Any, List, cast, Iterable, Optional
+from typing import Any, List, Iterable, Optional
 
 import torch
 
@@ -8,7 +8,8 @@ from emu_mps.algebra import add_factors, scale_factors, zip_right
 from emu_base.base_classes.operator import FullOp, QuditOp
 from emu_base import Operator, State
 from emu_mps.mps import MPS
-from emu_mps.utils import new_left_bath, assign_devices, DEVICE_COUNT
+from emu_mps.utils import new_left_bath, assign_devices
+from emu_mps.constants import DEVICE_COUNT
 
 
 def _validate_operator_targets(operations: FullOp, nqubits: int) -> None:
@@ -80,8 +81,7 @@ class MPO(Operator):
         factors = zip_right(
             self.factors,
             other.factors,
-            max_error=other.precision,
-            max_rank=other.max_bond_dim,
+            config=other.config,
         )
         return MPS(factors, orthogonality_center=0)
 
@@ -175,14 +175,15 @@ class MPO(Operator):
         Returns:
             the operator in MPO form.
         """
+        operators_with_tensors: dict[str, torch.Tensor | QuditOp] = dict(operators)
 
         _validate_operator_targets(operations, nqubits)
 
         basis = set(basis)
         if basis == {"r", "g"}:
-            # operators will now contain the basis for single qubit ops, and potentially
-            # user defined strings in terms of these
-            operators |= {
+            # operators_with_tensors will now contain the basis for single qubit ops,
+            # and potentially user defined strings in terms of these
+            operators_with_tensors |= {
                 "gg": torch.tensor(
                     [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
                 ).reshape(1, 2, 2, 1),
@@ -197,9 +198,9 @@ class MPO(Operator):
                 ).reshape(1, 2, 2, 1),
             }
         elif basis == {"0", "1"}:
-            # operators will now contain the basis for single qubit ops, and potentially
-            # user defined strings in terms of these
-            operators |= {
+            # operators_with_tensors will now contain the basis for single qubit ops,
+            # and potentially user defined strings in terms of these
+            operators_with_tensors |= {
                 "00": torch.tensor(
                     [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
                 ).reshape(1, 2, 2, 1),
@@ -218,26 +219,28 @@ class MPO(Operator):
 
         mpos = []
         for coeff, tensorop in operations:
-            # this function will recurse through the operators, and replace any definitions
+            # this function will recurse through the operators_with_tensors,
+            # and replace any definitions
             # in terms of strings by the computed tensor
             def replace_operator_string(op: QuditOp | torch.Tensor) -> torch.Tensor:
-                if isinstance(op, dict):
-                    for opstr, coeff in op.items():
-                        tensor = replace_operator_string(operators[opstr])
-                        operators[opstr] = tensor
-                        op[opstr] = tensor * coeff
-                    op = sum(cast(list[torch.Tensor], op.values()))
-                return op
+                if isinstance(op, torch.Tensor):
+                    return op
+
+                result = torch.zeros(1, 2, 2, 1, dtype=torch.complex128)
+                for opstr, coeff in op.items():
+                    tensor = replace_operator_string(operators_with_tensors[opstr])
+                    operators_with_tensors[opstr] = tensor
+                    result += tensor * coeff
+                return result
 
             factors = [
                 torch.eye(2, 2, dtype=torch.complex128).reshape(1, 2, 2, 1)
             ] * nqubits
 
-            for i, op in enumerate(tensorop):
-                tensorop[i] = (replace_operator_string(op[0]), op[1])
-
             for op in tensorop:
-                for i in op[1]:
-                    factors[i] = op[0]
+                factor = replace_operator_string(op[0])
+                for target_qubit in op[1]:
+                    factors[target_qubit] = factor
+
             mpos.append(coeff * MPO(factors, **kwargs))
         return sum(mpos[1:], start=mpos[0])
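
The rewritten `replace_operator_string` stops mutating the caller's data: the old version wrote resolved tensors back into the user-supplied `operators` dict and rebuilt `tensorop` in place, while the new one works on the `operators_with_tensors` copy and accumulates into a fresh tensor. A standalone sketch of the resolution logic, using a hypothetical alias `n` defined in terms of the basis strings:

import torch

# Basis tensors plus a user alias defined in terms of them (illustrative input).
operators_with_tensors = {
    "gg": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128).reshape(1, 2, 2, 1),
    "rr": torch.tensor([[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128).reshape(1, 2, 2, 1),
    "n": {"rr": 1.0},
}

def replace_operator_string(op):
    if isinstance(op, torch.Tensor):
        return op
    result = torch.zeros(1, 2, 2, 1, dtype=torch.complex128)
    for opstr, coeff in op.items():
        tensor = replace_operator_string(operators_with_tensors[opstr])
        operators_with_tensors[opstr] = tensor  # memoize resolved aliases
        result += tensor * coeff
    return result

factor = replace_operator_string({"gg": 1.0, "n": -0.5})  # tensor of shape (1, 2, 2, 1)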
emu_mps/mps.py CHANGED
@@ -7,15 +7,16 @@ from typing import Any, List, Optional, Iterable
 import torch
 
 from emu_base import State
+from emu_mps import MPSConfig
 from emu_mps.algebra import add_factors, scale_factors
 from emu_mps.utils import (
-    DEVICE_COUNT,
     apply_measurement_errors,
     assign_devices,
     truncate_impl,
     tensor_trace,
     n_operator,
 )
+from emu_mps.constants import DEVICE_COUNT
 
 
 class MPS(State):
@@ -27,17 +28,13 @@ class MPS(State):
     Only qubits are supported.
     """
 
-    DEFAULT_MAX_BOND_DIM: int = 1024
-    DEFAULT_PRECISION: float = 1e-5
-
     def __init__(
         self,
         factors: List[torch.Tensor],
         /,
         *,
         orthogonality_center: Optional[int] = None,
-        precision: float = DEFAULT_PRECISION,
-        max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
+        config: Optional[MPSConfig] = None,
         num_gpus_to_use: Optional[int] = DEVICE_COUNT,
     ):
         """
@@ -53,14 +50,11 @@ class MPS(State):
             of the data to this constructor, or some shared objects.
             orthogonality_center: the orthogonality center of the MPS, or None (in which case
                 it will be orthogonalized when needed)
-            precision: the precision with which to truncate here or in tdvp
-            max_bond_dim: the maximum bond dimension to allow
+            config: the emu-mps config object passed to the run method
             num_gpus_to_use: distribute the factors over this many GPUs
                 0=all factors to cpu, None=keep the existing device assignment.
         """
-        self.precision = precision
-        self.max_bond_dim = max_bond_dim
-
+        self.config = config if config is not None else MPSConfig()
         assert all(
             factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
         ), "The dimensions of consecutive tensors should match"
@@ -84,8 +78,7 @@ class MPS(State):
     def make(
         cls,
        num_sites: int,
-        precision: float = DEFAULT_PRECISION,
-        max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
+        config: Optional[MPSConfig] = None,
        num_gpus_to_use: int = DEVICE_COUNT,
    ) -> MPS:
        """
@@ -93,11 +86,12 @@ class MPS(State):
 
         Args:
             num_sites: the number of qubits
-            precision: the precision with which to truncate here or in tdvp
-            max_bond_dim: the maximum bond dimension to allow
+            config: the MPSConfig
             num_gpus_to_use: distribute the factors over this many GPUs
                 0=all factors to cpu
         """
+        config = config if config is not None else MPSConfig()
+
         if num_sites <= 1:
             raise ValueError("For 1 qubit states, do state vector")
 
@@ -106,8 +100,7 @@ class MPS(State):
                 torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)
                 for _ in range(num_sites)
             ],
-            precision=precision,
-            max_bond_dim=max_bond_dim,
+            config=config,
             num_gpus_to_use=num_gpus_to_use,
             orthogonality_center=0,  # Arbitrary: every qubit is an orthogonality center.
         )
@@ -165,15 +158,11 @@ class MPS(State):
         """
         SVD based truncation of the state. Puts the orthogonality center at the first qubit.
         Calls orthogonalize on the last qubit, and then sweeps a series of SVDs right-left.
-        Uses self.precision and self.max_bond_dim for determining accuracy.
+        Uses self.config for determining accuracy.
         An in-place operation.
         """
         self.orthogonalize(self.num_sites - 1)
-        truncate_impl(
-            self.factors,
-            max_error=self.precision,
-            max_rank=self.max_bond_dim,
-        )
+        truncate_impl(self.factors, config=self.config)
         self.orthogonality_center = 0
 
     def get_max_bond_dim(self) -> int:
@@ -236,7 +225,7 @@ class MPS(State):
         num_qubits = len(self.factors)
 
         bitstring = ""
-        acc_mps_j: torch.tensor = self.factors[0]
+        acc_mps_j: torch.Tensor = self.factors[0]
 
         for qubit in range(num_qubits):
             # comp_basis is a projector: 0 is for ket |0> and 1 for ket |1>
@@ -303,7 +292,7 @@ class MPS(State):
         """
         Returns the sum of two MPSs, computed with a direct algorithm.
         The resulting MPS is orthogonalized on the first site and truncated
-        up to `self.precision`.
+        up to `self.config.precision`.
 
         Args:
             other: the other state
@@ -315,8 +304,7 @@ class MPS(State):
         new_tt = add_factors(self.factors, other.factors)
         result = MPS(
             new_tt,
-            precision=self.precision,
-            max_bond_dim=self.max_bond_dim,
+            config=self.config,
             num_gpus_to_use=None,
             orthogonality_center=None,  # Orthogonality is lost.
         )
@@ -341,8 +329,7 @@ class MPS(State):
         factors = scale_factors(self.factors, scalar, which=which)
         return MPS(
             factors,
-            precision=self.precision,
-            max_bond_dim=self.max_bond_dim,
+            config=self.config,
             num_gpus_to_use=None,
             orthogonality_center=self.orthogonality_center,
         )
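
With `precision` and `max_bond_dim` folded into the config, constructing states now looks like the sketch below (the `MPSConfig` keyword names are assumed from the attributes read in this diff):

import torch

from emu_mps import MPS, MPSConfig

config = MPSConfig(precision=1e-8, max_bond_dim=256)  # assumed keyword names
ground = MPS.make(4, config=config)  # |0000>; truncation settings travel with the state

# Equivalent explicit construction; factors have shape (left_bond, phys, right_bond).
factors = [
    torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128) for _ in range(4)
]
excited = MPS(factors, config=config, orthogonality_center=0)

total = ground + excited  # the sum is truncated using ground.config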
emu_mps/mps_backend.py CHANGED
@@ -1,8 +1,12 @@
-from pulser import Sequence
-
 from emu_base import Backend, BackendConfig, Results
 from emu_mps.mps_config import MPSConfig
-from emu_mps.mps_backend_impl import create_impl
+from emu_mps.mps_backend_impl import create_impl, MPSBackendImpl
+from pulser import Sequence
+import pickle
+import os
+import time
+import logging
+import pathlib
 
 
 class MPSBackend(Backend):
@@ -11,6 +15,32 @@ class MPSBackend(Backend):
     aka tensor trains.
     """
 
+    def resume(self, autosave_file: str | pathlib.Path) -> Results:
+        """
+        Resume simulation from autosave file.
+        Only resume simulations from data you trust!
+        Unpickling of untrusted data is not safe.
+        """
+        if isinstance(autosave_file, str):
+            autosave_file = pathlib.Path(autosave_file)
+
+        if not autosave_file.is_file():
+            raise ValueError(f"Not a file: {autosave_file}")
+
+        with open(autosave_file, "rb") as f:
+            impl: MPSBackendImpl = pickle.load(f)
+
+        impl.autosave_file = autosave_file
+        impl.last_save_time = time.time()
+        impl.config.init_logging()  # FIXME: might be best to take logger object out of config.
+
+        logging.getLogger("global_logger").warning(
+            f"Resuming simulation from file {autosave_file}\n"
+            f"Saving simulation state every {impl.config.autosave_dt} seconds"
+        )
+
+        return self._run(impl)
+
     def run(self, sequence: Sequence, mps_config: BackendConfig) -> Results:
         """
         Emulates the given sequence.
@@ -29,7 +59,14 @@ class MPSBackend(Backend):
         impl = create_impl(sequence, mps_config)
         impl.init()  # This is separate from the constructor for testing purposes.
 
+        return self._run(impl)
+
+    @staticmethod
+    def _run(impl: MPSBackendImpl) -> Results:
         while not impl.is_finished():
             impl.progress()
 
+        if impl.autosave_file.is_file():
+            os.remove(impl.autosave_file)
+
         return impl.results
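
Usage sketch for the new `resume` entry point; the file name below is a placeholder (the real `<autosave_prefix><uuid>.dat` path is logged when the simulation starts):

from emu_mps import MPSBackend

# Resume an interrupted simulation from its autosave file.
# Only resume from files you trust: resume() unpickles their contents.
results = MPSBackend().resume("emu_mps_save_<uuid>.dat")  # placeholder name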
emu_mps/mps_backend_impl.py CHANGED
@@ -1,7 +1,11 @@
 import math
+import pathlib
 import random
+import uuid
 from resource import RUSAGE_SELF, getrusage
 from typing import Optional
+import pickle
+import os
 
 import torch
 import time
@@ -9,6 +13,7 @@ from pulser import Sequence
 
 from emu_base import Results, State, PulserData
 from emu_base.math.brents_root_finding import BrentsRootFinder
+from emu_mps.constants import DEVICE_COUNT
 from emu_mps.hamiltonian import make_H, update_H
 from emu_mps.mpo import MPO
 from emu_mps.mps import MPS
@@ -17,7 +22,6 @@ from emu_mps.noise import compute_noise_from_lindbladians, pick_well_prepared_qubits
 from emu_mps.tdvp import (
     evolve_single,
     evolve_pair,
-    EvolveConfig,
     new_right_bath,
     right_baths,
 )
@@ -52,6 +56,7 @@ class MPSBackendImpl:
     def __init__(self, mps_config: MPSConfig, pulser_data: PulserData):
         self.config = mps_config
         self.target_time = float(self.config.dt)
+        self.pulser_data = pulser_data
         self.qubit_count = pulser_data.qubit_count
         assert self.qubit_count >= 2
         self.omega = pulser_data.omega
@@ -71,15 +76,23 @@ class MPSBackendImpl:
         self.tdvp_index = 0
         self.timestep_index = 0
         self.results = Results()
-
-        self.evolve_config = EvolveConfig(
-            exp_tolerance=self.config.precision * self.config.extra_krylov_tolerance,
-            norm_tolerance=self.config.precision * self.config.extra_krylov_tolerance,
-            max_krylov_dim=self.config.max_krylov_dim,
-            is_hermitian=not self.has_lindblad_noise,
-            max_error=self.config.precision,
-            max_rank=self.config.max_bond_dim,
+        self.autosave_file = self._get_autosave_filepath(self.config.autosave_prefix)
+        self.config.logger.warning(
+            f"""Will save simulation state to file "{self.autosave_file.name}"
+            every {self.config.autosave_dt} seconds.\n"""
+            f"""To resume: `MPSBackend().resume("{self.autosave_file}")`"""
         )
+        self.last_save_time = time.time()
+
+        if self.config.num_gpus_to_use > DEVICE_COUNT:
+            self.config.logger.warning(
+                f"Requested to use {self.config.num_gpus_to_use} GPU(s) "
+                f"but only {DEVICE_COUNT if DEVICE_COUNT > 0 else 'cpu'} available"
+            )
+
+    @staticmethod
+    def _get_autosave_filepath(autosave_prefix: str) -> pathlib.Path:
+        return pathlib.Path(os.getcwd()) / (autosave_prefix + str(uuid.uuid1()) + ".dat")
 
     def init_dark_qubits(self) -> None:
         has_state_preparation_error: bool = (
@@ -112,8 +125,7 @@ class MPSBackendImpl:
         if initial_state is None:
             self.state = MPS.make(
                 self.qubit_count,
-                precision=self.config.precision,
-                max_bond_dim=self.config.max_bond_dim,
+                config=self.config,
                 num_gpus_to_use=self.config.num_gpus_to_use,
             )
             return
@@ -128,8 +140,7 @@ class MPSBackendImpl:
         initial_state = MPS(
             # Deep copy of every tensor of the initial state.
             [f.clone().detach() for f in initial_state.factors],
-            precision=self.config.precision,
-            max_bond_dim=self.config.max_bond_dim,
+            config=self.config,
             num_gpus_to_use=self.config.num_gpus_to_use,
         )
         initial_state.truncate()
@@ -198,7 +209,8 @@ class MPSBackendImpl:
                 ham_factor=self.hamiltonian.factors[index],
                 baths=baths,
                 dt=dt,
-                config=self.evolve_config,
+                config=self.config,
+                is_hermitian=not self.has_lindblad_noise,
             )
         else:
             assert orth_center_right is not None
@@ -213,8 +225,9 @@ class MPSBackendImpl:
                 ham_factors=self.hamiltonian.factors[l : r + 1],
                 baths=baths,
                 dt=dt,
-                config=self.evolve_config,
+                config=self.config,
                 orth_center_right=orth_center_right,
+                is_hermitian=not self.has_lindblad_noise,
             )
 
         self.state.orthogonality_center = r if orth_center_right else l
@@ -292,7 +305,7 @@ class MPSBackendImpl:
             )
         if not self.has_lindblad_noise:
             # Free memory because it won't be used anymore
-            self.right_baths[-2] = None
+            self.right_baths[-2] = torch.zeros(0)
 
         self._evolve(self.tdvp_index, dt=-delta_time / 2)
         self.left_baths.pop()
@@ -312,7 +325,7 @@ class MPSBackendImpl:
         else:
             raise Exception("Didn't expect this")
 
-        # TODO: checkpoint/autosave here
+        self.save_simulation()
 
     def tdvp_complete(self) -> None:
         self.current_time = self.target_time
@@ -343,14 +356,40 @@ class MPSBackendImpl:
             self.log_step_statistics(duration=time.time() - self.time)
         self.time = time.time()
 
+    def save_simulation(self) -> None:
+        if self.last_save_time > time.time() - self.config.autosave_dt:
+            return
+
+        basename = self.autosave_file
+        with open(basename.with_suffix(".new"), "wb") as file_handle:
+            pickle.dump(self, file_handle)
+
+        if basename.is_file():
+            os.rename(basename, basename.with_suffix(".bak"))
+
+        os.rename(basename.with_suffix(".new"), basename)
+        autosave_filesize = os.path.getsize(self.autosave_file) / 1e6
+
+        if basename.with_suffix(".bak").is_file():
+            os.remove(basename.with_suffix(".bak"))
+
+        self.last_save_time = time.time()
+
+        self.config.logger.debug(
+            f"Saved simulation state in file {self.autosave_file} ({autosave_filesize}MB)"
+        )
+
     def fill_results(self) -> None:
         normalized_state = 1 / self.state.norm() * self.state
 
+        current_time_int: int = round(self.current_time)
+        assert abs(self.current_time - current_time_int) < 1e-10
+
         if self.well_prepared_qubits_filter is None:
             for callback in self.config.callbacks:
                 callback(
                     self.config,
-                    self.current_time,
+                    current_time_int,
                     normalized_state,
                     self.hamiltonian,
                     self.results,
@@ -359,7 +398,7 @@ class MPSBackendImpl:
 
         full_mpo, full_state = None, None
         for callback in self.config.callbacks:
-            if self.current_time not in callback.evaluation_times:
+            if current_time_int not in callback.evaluation_times:
                 continue
 
             if full_mpo is None or full_state is None:
@@ -380,7 +419,7 @@ class MPSBackendImpl:
                 ),
             )
 
-            callback(self.config, self.current_time, full_state, full_mpo, self.results)
+            callback(self.config, current_time_int, full_state, full_mpo, self.results)
 
     def log_step_statistics(self, *, duration: float) -> None:
         if self.state.factors[0].is_cuda:
@@ -424,13 +463,12 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
     """
 
     jump_threshold: float
-    aggregated_lindblad_ops: Optional[torch.Tensor]
+    aggregated_lindblad_ops: torch.Tensor
     norm_gap_before_jump: float
     root_finder: Optional[BrentsRootFinder]
 
     def __init__(self, config: MPSConfig, pulser_data: PulserData):
         super().__init__(config, pulser_data)
-        self.aggregated_lindblad_ops = None
         self.lindblad_ops = pulser_data.lindblad_ops
         self.root_finder = None
@@ -491,7 +529,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
                 for qubit in range(self.state.num_sites)
                 for op in self.lindblad_ops
             ],
-            weights=jump_operator_weights.reshape(-1),
+            weights=jump_operator_weights.reshape(-1).tolist(),
        )[0]
 
        self.state.apply(jumped_qubit_index, jump_operator)
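
`save_simulation` uses a write-then-rename sequence, so an interruption mid-write never destroys the last good checkpoint. The same pattern in isolation (a generic sketch, not part of the package API):

import os
import pathlib
import pickle

def atomic_pickle(obj: object, path: pathlib.Path) -> None:
    tmp = path.with_suffix(".new")
    with open(tmp, "wb") as fh:
        pickle.dump(obj, fh)  # a crash here leaves the previous save untouched
    if path.is_file():
        os.rename(path, path.with_suffix(".bak"))  # keep a fallback during the swap
    os.rename(tmp, path)  # publish the new save
    if path.with_suffix(".bak").is_file():
        os.remove(path.with_suffix(".bak"))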
emu_mps/mps_config.py CHANGED
@@ -1,7 +1,7 @@
 from typing import Any
 
 from emu_base import BackendConfig, State
-from emu_mps.utils import DEVICE_COUNT
+from emu_mps.constants import DEVICE_COUNT
 
 
 class MPSConfig(BackendConfig):
@@ -23,6 +23,10 @@ class MPSConfig(BackendConfig):
         num_gpus_to_use: during the simulation, distribute the state over this many GPUs
             0=all factors to cpu. As shown in the benchmarks, using multiple GPUs might
             alleviate memory pressure per GPU, but the runtime should be similar.
+        autosave_prefix: filename prefix for autosaving simulation state to file
+        autosave_dt: minimum time interval in seconds between two autosaves
+            Saving the simulation state is only possible at specific times,
+            therefore this interval is only a lower bound.
         kwargs: arguments that are passed to the base class
 
     Examples:
@@ -43,6 +47,8 @@ class MPSConfig(BackendConfig):
         max_krylov_dim: int = 100,
         extra_krylov_tolerance: float = 1e-3,
         num_gpus_to_use: int = DEVICE_COUNT,
+        autosave_prefix: str = "emu_mps_save_",
+        autosave_dt: int = 600,  # 10 minutes
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
@@ -62,3 +68,11 @@ class MPSConfig(BackendConfig):
             and self.noise_model.amp_sigma != 0.0
         ):
             raise NotImplementedError("Unsupported noise type: amp_sigma")
+
+        self.autosave_prefix = autosave_prefix
+        self.autosave_dt = autosave_dt
+
+        MIN_AUTOSAVE_DT = 10
+        assert (
+            self.autosave_dt > MIN_AUTOSAVE_DT
+        ), f"autosave_dt must be larger than {MIN_AUTOSAVE_DT} seconds"