emu-mps 1.2.6__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/__init__.py CHANGED
@@ -1,19 +1,19 @@
1
- from emu_base import (
2
- Callback,
1
+ from pulser.backend import (
3
2
  BitStrings,
4
3
  CorrelationMatrix,
5
4
  Energy,
6
5
  EnergyVariance,
7
6
  Expectation,
8
7
  Fidelity,
9
- QubitDensity,
8
+ Occupation,
10
9
  StateResult,
11
- SecondMomentOfEnergy,
10
+ EnergySecondMoment,
12
11
  )
13
12
  from .mps_config import MPSConfig
14
13
  from .mpo import MPO
15
14
  from .mps import MPS, inner
16
15
  from .mps_backend import MPSBackend
16
+ from emu_base import aggregate
17
17
 
18
18
 
19
19
  __all__ = [
@@ -23,16 +23,16 @@ __all__ = [
23
23
  "inner",
24
24
  "MPSConfig",
25
25
  "MPSBackend",
26
- "Callback",
27
26
  "StateResult",
28
27
  "BitStrings",
29
- "QubitDensity",
28
+ "Occupation",
30
29
  "CorrelationMatrix",
31
30
  "Expectation",
32
31
  "Fidelity",
33
32
  "Energy",
34
33
  "EnergyVariance",
35
- "SecondMomentOfEnergy",
34
+ "EnergySecondMoment",
35
+ "aggregate",
36
36
  ]
37
37
 
38
- __version__ = "1.2.6"
38
+ __version__ = "2.0.0"
@@ -0,0 +1,96 @@
1
+ import torch
2
+
3
+ from pulser.backend.default_observables import (
4
+ CorrelationMatrix,
5
+ EnergySecondMoment,
6
+ EnergyVariance,
7
+ Occupation,
8
+ Energy,
9
+ )
10
+ from typing import TYPE_CHECKING
11
+
12
+ if TYPE_CHECKING:
13
+ from emu_mps.mps_config import MPSConfig
14
+ from emu_mps.mps import MPS
15
+ from emu_mps.mpo import MPO
16
+
17
+
18
+ def qubit_occupation_mps_impl(
19
+ self: Occupation,
20
+ *,
21
+ config: "MPSConfig",
22
+ state: "MPS",
23
+ hamiltonian: "MPO",
24
+ ) -> torch.Tensor:
25
+ """
26
+ Custom implementation of the occupation ❬ψ|nᵢ|ψ❭ for the EMU-MPS.
27
+ """
28
+ op = torch.tensor(
29
+ [[[0.0, 0.0], [0.0, 1.0]]], dtype=torch.complex128, device=state.factors[0].device
30
+ )
31
+ return state.expect_batch(op).real.reshape(-1).cpu()
32
+
33
+
34
+ def correlation_matrix_mps_impl(
35
+ self: CorrelationMatrix,
36
+ *,
37
+ config: "MPSConfig",
38
+ state: "MPS",
39
+ hamiltonian: "MPO",
40
+ ) -> torch.Tensor:
41
+ """
42
+ Custom implementation of the density-density correlation ❬ψ|nᵢnⱼ|ψ❭ for the EMU-MPS.
43
+
44
+ TODO: extend to arbitrary two-point correlation ❬ψ|AᵢBⱼ|ψ❭
45
+ """
46
+ return state.get_correlation_matrix().cpu()
47
+
48
+
49
+ def energy_variance_mps_impl(
50
+ self: EnergyVariance,
51
+ *,
52
+ config: "MPSConfig",
53
+ state: "MPS",
54
+ hamiltonian: "MPO",
55
+ ) -> torch.Tensor:
56
+ """
57
+ Custom implementation of the energy variance ❬ψ|H²|ψ❭-❬ψ|H|ψ❭² for the EMU-MPS.
58
+ """
59
+ h_squared = hamiltonian @ hamiltonian
60
+ h_2 = h_squared.expect(state).cpu()
61
+ h = hamiltonian.expect(state).cpu()
62
+ en_var = h_2 - h**2
63
+ return en_var.real # type: ignore[no-any-return]
64
+
65
+
66
+ def energy_second_moment_mps_impl(
67
+ self: EnergySecondMoment,
68
+ *,
69
+ config: "MPSConfig",
70
+ state: "MPS",
71
+ hamiltonian: "MPO",
72
+ ) -> torch.Tensor:
73
+ """
74
+ Custom implementation of the second moment of energy ❬ψ|H²|ψ❭
75
+ for the EMU-MPS.
76
+ """
77
+ h_square = hamiltonian @ hamiltonian
78
+ h_2 = h_square.expect(state).cpu()
79
+ assert torch.allclose(h_2.imag, torch.zeros_like(h_2.imag), atol=1e-4)
80
+ return h_2.real
81
+
82
+
83
+ def energy_mps_impl(
84
+ self: Energy,
85
+ *,
86
+ config: "MPSConfig",
87
+ state: "MPS",
88
+ hamiltonian: "MPO",
89
+ ) -> torch.Tensor:
90
+ """
91
+ Custom implementation of the second moment of energy ❬ψ|H²|ψ❭
92
+ for the EMU-MPS.
93
+ """
94
+ h = hamiltonian.expect(state)
95
+ assert torch.allclose(h.imag, torch.zeros_like(h.imag), atol=1e-4)
96
+ return h.real
emu_mps/mpo.py CHANGED
@@ -1,12 +1,13 @@
1
1
  from __future__ import annotations
2
2
  import itertools
3
- from typing import Any, List, Iterable, Optional
3
+ from typing import Any, List, Sequence, Optional
4
4
 
5
5
  import torch
6
6
 
7
+ from pulser.backend import State, Operator
8
+ from emu_base import DEVICE_COUNT
7
9
  from emu_mps.algebra import add_factors, scale_factors, zip_right
8
- from emu_base.base_classes.operator import FullOp, QuditOp
9
- from emu_base import Operator, State, DEVICE_COUNT
10
+ from pulser.backend.operator import FullOp, QuditOp
10
11
  from emu_mps.mps import MPS
11
12
  from emu_mps.utils import new_left_bath, assign_devices
12
13
 
@@ -29,7 +30,7 @@ def _validate_operator_targets(operations: FullOp, nqubits: int) -> None:
29
30
  )
30
31
 
31
32
 
32
- class MPO(Operator):
33
+ class MPO(Operator[complex, torch.Tensor, MPS]):
33
34
  """
34
35
  Matrix Product Operator.
35
36
 
@@ -61,7 +62,7 @@ class MPO(Operator):
61
62
  def __repr__(self) -> str:
62
63
  return "[" + ", ".join(map(repr, self.factors)) + "]"
63
64
 
64
- def __mul__(self, other: State) -> MPS:
65
+ def apply_to(self, other: MPS) -> MPS:
65
66
  """
66
67
  Applies this MPO to the given MPS.
67
68
  The returned MPS is:
@@ -84,7 +85,7 @@ class MPO(Operator):
84
85
  )
85
86
  return MPS(factors, orthogonality_center=0)
86
87
 
87
- def __add__(self, other: Operator) -> MPO:
88
+ def __add__(self, other: MPO) -> MPO:
88
89
  """
89
90
  Returns the sum of two MPOs, computed with a direct algorithm.
90
91
  The result is currently not truncated
@@ -113,7 +114,7 @@ class MPO(Operator):
113
114
  factors = scale_factors(self.factors, scalar, which=0)
114
115
  return MPO(factors)
115
116
 
116
- def __matmul__(self, other: Operator) -> MPO:
117
+ def __matmul__(self, other: MPO) -> MPO:
117
118
  """
118
119
  Compose two operators. The ordering is that
119
120
  self is applied after other.
@@ -128,7 +129,7 @@ class MPO(Operator):
128
129
  factors = zip_right(self.factors, other.factors)
129
130
  return MPO(factors)
130
131
 
131
- def expect(self, state: State) -> float | complex:
132
+ def expect(self, state: State) -> torch.Tensor:
132
133
  """
133
134
  Compute the expectation value of self on the given state.
134
135
 
@@ -151,17 +152,17 @@ class MPO(Operator):
151
152
  state.factors[i + 1].device
152
153
  )
153
154
  acc = new_left_bath(acc, state.factors[n], self.factors[n])
154
- return acc.item() # type: ignore [no-any-return]
155
-
156
- @staticmethod
157
- def from_operator_string(
158
- basis: Iterable[str],
159
- nqubits: int,
160
- operations: FullOp,
161
- operators: dict[str, QuditOp] = {},
162
- /,
155
+ return acc.reshape(1)[0].cpu()
156
+
157
+ @classmethod
158
+ def _from_operator_repr(
159
+ cls,
160
+ *,
161
+ eigenstates: Sequence[str],
162
+ n_qudits: int,
163
+ operations: FullOp[complex],
163
164
  **kwargs: Any,
164
- ) -> MPO:
165
+ ) -> tuple[MPO, FullOp[complex]]:
165
166
  """
166
167
  See the base class
167
168
 
@@ -174,15 +175,16 @@ class MPO(Operator):
174
175
  Returns:
175
176
  the operator in MPO form.
176
177
  """
177
- operators_with_tensors: dict[str, torch.Tensor | QuditOp] = dict(operators)
178
178
 
179
- _validate_operator_targets(operations, nqubits)
179
+ _validate_operator_targets(operations, n_qudits)
180
180
 
181
- basis = set(basis)
181
+ basis = set(eigenstates)
182
+
183
+ operators_with_tensors: dict[str, torch.Tensor | QuditOp]
182
184
  if basis == {"r", "g"}:
183
185
  # operators_with_tensors will now contain the basis for single qubit ops,
184
186
  # and potentially user defined strings in terms of these
185
- operators_with_tensors |= {
187
+ operators_with_tensors = {
186
188
  "gg": torch.tensor(
187
189
  [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
188
190
  ).reshape(1, 2, 2, 1),
@@ -199,7 +201,7 @@ class MPO(Operator):
199
201
  elif basis == {"0", "1"}:
200
202
  # operators_with_tensors will now contain the basis for single qubit ops,
201
203
  # and potentially user defined strings in terms of these
202
- operators_with_tensors |= {
204
+ operators_with_tensors = {
203
205
  "00": torch.tensor(
204
206
  [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
205
207
  ).reshape(1, 2, 2, 1),
@@ -234,12 +236,12 @@ class MPO(Operator):
234
236
 
235
237
  factors = [
236
238
  torch.eye(2, 2, dtype=torch.complex128).reshape(1, 2, 2, 1)
237
- ] * nqubits
239
+ ] * n_qudits
238
240
 
239
241
  for op in tensorop:
240
242
  factor = replace_operator_string(op[0])
241
243
  for target_qubit in op[1]:
242
244
  factors[target_qubit] = factor
243
245
 
244
- mpos.append(coeff * MPO(factors, **kwargs))
245
- return sum(mpos[1:], start=mpos[0])
246
+ mpos.append(coeff * cls(factors, **kwargs))
247
+ return sum(mpos[1:], start=mpos[0]), operations # type: ignore[no-any-return]
emu_mps/mps.py CHANGED
@@ -2,11 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import math
4
4
  from collections import Counter
5
- from typing import Any, List, Optional, Iterable
5
+ from typing import List, Optional, Sequence, TypeVar, Mapping
6
6
 
7
7
  import torch
8
8
 
9
- from emu_base import State, DEVICE_COUNT
9
+ from pulser.backend.state import State, Eigenstate
10
+ from emu_base import DEVICE_COUNT
10
11
  from emu_mps import MPSConfig
11
12
  from emu_mps.algebra import add_factors, scale_factors
12
13
  from emu_mps.utils import (
@@ -17,8 +18,10 @@ from emu_mps.utils import (
17
18
  n_operator,
18
19
  )
19
20
 
21
+ ArgScalarType = TypeVar("ArgScalarType")
20
22
 
21
- class MPS(State):
23
+
24
+ class MPS(State[complex, torch.Tensor]):
22
25
  """
23
26
  Matrix Product State, aka tensor train.
24
27
 
@@ -53,6 +56,7 @@ class MPS(State):
53
56
  num_gpus_to_use: distribute the factors over this many GPUs
54
57
  0=all factors to cpu, None=keep the existing device assignment.
55
58
  """
59
+ self._eigenstates = ["0", "1"]
56
60
  self.config = config if config is not None else MPSConfig()
57
61
  assert all(
58
62
  factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
@@ -73,6 +77,11 @@ class MPS(State):
73
77
  if num_gpus_to_use is not None:
74
78
  assign_devices(self.factors, min(DEVICE_COUNT, num_gpus_to_use))
75
79
 
80
+ @property
81
+ def n_qudits(self) -> int:
82
+ """The number of qudits in the state."""
83
+ return self.num_sites
84
+
76
85
  @classmethod
77
86
  def make(
78
87
  cls,
@@ -174,7 +183,12 @@ class MPS(State):
174
183
  return max((x.shape[2] for x in self.factors), default=0)
175
184
 
176
185
  def sample(
177
- self, num_shots: int, p_false_pos: float = 0.0, p_false_neg: float = 0.0
186
+ self,
187
+ *,
188
+ num_shots: int,
189
+ one_state: Eigenstate | None = None,
190
+ p_false_pos: float = 0.0,
191
+ p_false_neg: float = 0.0,
178
192
  ) -> Counter[str]:
179
193
  """
180
194
  Samples bitstrings, taking into account the specified error rates.
@@ -182,18 +196,66 @@ class MPS(State):
182
196
  Args:
183
197
  num_shots: how many bitstrings to sample
184
198
  p_false_pos: the rate at which a 0 is read as a 1
185
- p_false_neg: teh rate at which a 1 is read as a 0
199
+ p_false_neg: the rate at which a 1 is read as a 0
186
200
 
187
201
  Returns:
188
202
  the measured bitstrings, by count
189
203
  """
204
+ assert one_state in {None, "r", "1"}
190
205
  self.orthogonalize(0)
191
206
 
192
- num_qubits = len(self.factors)
193
- rnd_matrix = torch.rand(num_shots, num_qubits)
194
- bitstrings = Counter(
195
- self._sample_implementation(rnd_matrix[x, :]) for x in range(num_shots)
196
- )
207
+ rnd_matrix = torch.rand(num_shots, self.num_sites).to(self.factors[0].device)
208
+
209
+ bitstrings: Counter[str] = Counter()
210
+
211
+ # Shots are performed in batches.
212
+ # Larger max_batch_size is faster but uses more memory.
213
+ max_batch_size = 32
214
+
215
+ shots_done = 0
216
+ while shots_done < num_shots:
217
+ batch_size = min(max_batch_size, num_shots - shots_done)
218
+ batched_accumulator = torch.ones(
219
+ batch_size, 1, dtype=torch.complex128, device=self.factors[0].device
220
+ )
221
+
222
+ batch_outcomes = torch.empty(batch_size, self.num_sites, dtype=torch.bool)
223
+
224
+ for qubit, factor in enumerate(self.factors):
225
+ batched_accumulator = torch.tensordot(
226
+ batched_accumulator.to(factor.device), factor, dims=1
227
+ )
228
+
229
+ # Probability of measuring qubit == 0 for each shot in the batch
230
+ probas = (
231
+ torch.linalg.vector_norm(batched_accumulator[:, 0, :], dim=1) ** 2
232
+ )
233
+
234
+ outcomes = (
235
+ rnd_matrix[shots_done : shots_done + batch_size, qubit].to(
236
+ factor.device
237
+ )
238
+ > probas
239
+ )
240
+ batch_outcomes[:, qubit] = outcomes
241
+
242
+ # Batch collapse qubit
243
+ tmp = torch.stack((~outcomes, outcomes), dim=1).to(dtype=torch.complex128)
244
+
245
+ batched_accumulator = (
246
+ torch.tensordot(batched_accumulator, tmp, dims=([1], [1]))
247
+ .diagonal(dim1=0, dim2=2)
248
+ .transpose(1, 0)
249
+ )
250
+ batched_accumulator /= torch.sqrt(
251
+ (~outcomes) * probas + outcomes * (1 - probas)
252
+ ).unsqueeze(1)
253
+
254
+ shots_done += batch_size
255
+
256
+ for outcome in batch_outcomes:
257
+ bitstrings.update(["".join("0" if x == 0 else "1" for x in outcome)])
258
+
197
259
  if p_false_neg > 0 or p_false_pos > 0:
198
260
  bitstrings = apply_measurement_errors(
199
261
  bitstrings,
@@ -202,55 +264,17 @@ class MPS(State):
202
264
  )
203
265
  return bitstrings
204
266
 
205
- def norm(self) -> float:
267
+ def norm(self) -> torch.Tensor:
206
268
  """Computes the norm of the MPS."""
207
269
  orthogonality_center = (
208
270
  self.orthogonality_center
209
271
  if self.orthogonality_center is not None
210
272
  else self.orthogonalize(0)
211
273
  )
274
+ # the torch.norm function is not properly typed.
275
+ return self.factors[orthogonality_center].norm().cpu() # type: ignore[no-any-return]
212
276
 
213
- return float(
214
- torch.linalg.norm(self.factors[orthogonality_center].to("cpu")).item()
215
- )
216
-
217
- def _sample_implementation(self, rnd_vector: torch.Tensor) -> str:
218
- """
219
- Samples this MPS once, returning the resulting bitstring.
220
- """
221
- assert rnd_vector.shape == (self.num_sites,)
222
- assert self.orthogonality_center == 0
223
-
224
- num_qubits = len(self.factors)
225
-
226
- bitstring = ""
227
- acc_mps_j: torch.Tensor = self.factors[0]
228
-
229
- for qubit in range(num_qubits):
230
- # comp_basis is a projector: 0 is for ket |0> and 1 for ket |1>
231
- comp_basis = 0 # check if the qubit is in |0>
232
- # Measure the qubit j by applying the projector onto nth comp basis state
233
- tensorj_projected_n = acc_mps_j[:, comp_basis, :]
234
- probability_n = (tensorj_projected_n.norm() ** 2).item()
235
-
236
- if rnd_vector[qubit] > probability_n:
237
- # the qubit is in |1>
238
- comp_basis = 1
239
- tensorj_projected_n = acc_mps_j[:, comp_basis, :]
240
- probability_n = 1 - probability_n
241
-
242
- bitstring += str(comp_basis)
243
- if qubit < num_qubits - 1:
244
- acc_mps_j = torch.tensordot(
245
- tensorj_projected_n.to(device=self.factors[qubit + 1].device),
246
- self.factors[qubit + 1],
247
- dims=1,
248
- )
249
- acc_mps_j /= math.sqrt(probability_n)
250
-
251
- return bitstring
252
-
253
- def inner(self, other: State) -> float | complex:
277
+ def inner(self, other: State) -> torch.Tensor:
254
278
  """
255
279
  Compute the inner product between this state and other.
256
280
  Note that self is the left state in the inner product,
@@ -274,7 +298,14 @@ class MPS(State):
274
298
  acc = torch.tensordot(acc, other.factors[i].to(acc.device), dims=1)
275
299
  acc = torch.tensordot(self.factors[i].conj(), acc, dims=([0, 1], [0, 1]))
276
300
 
277
- return acc.item() # type: ignore[no-any-return]
301
+ return acc.reshape(1)[0].cpu()
302
+
303
+ def overlap(self, other: State, /) -> torch.Tensor:
304
+ """
305
+ Compute the overlap of this state and other. This is defined as
306
+ $|\\langle self | other \\rangle |^2$
307
+ """
308
+ return torch.abs(self.inner(other)) ** 2 # type: ignore[no-any-return]
278
309
 
279
310
  def get_memory_footprint(self) -> float:
280
311
  """
@@ -336,14 +367,13 @@ class MPS(State):
336
367
  def __imul__(self, scalar: complex) -> MPS:
337
368
  return self.__rmul__(scalar)
338
369
 
339
- @staticmethod
340
- def from_state_string(
370
+ @classmethod
371
+ def _from_state_amplitudes(
372
+ cls,
341
373
  *,
342
- basis: Iterable[str],
343
- nqubits: int,
344
- strings: dict[str, complex],
345
- **kwargs: Any,
346
- ) -> MPS:
374
+ eigenstates: Sequence[str],
375
+ amplitudes: Mapping[str, complex],
376
+ ) -> tuple[MPS, Mapping[str, complex]]:
347
377
  """
348
378
  See the base class.
349
379
 
@@ -356,7 +386,8 @@ class MPS(State):
356
386
  The resulting MPS representation of the state.s
357
387
  """
358
388
 
359
- basis = set(basis)
389
+ nqubits = len(next(iter(amplitudes.keys())))
390
+ basis = set(eigenstates)
360
391
  if basis == {"r", "g"}:
361
392
  one = "r"
362
393
  elif basis == {"0", "1"}:
@@ -370,18 +401,17 @@ class MPS(State):
370
401
  accum_mps = MPS(
371
402
  [torch.zeros((1, 2, 1), dtype=torch.complex128)] * nqubits,
372
403
  orthogonality_center=0,
373
- **kwargs,
374
404
  )
375
405
 
376
- for state, amplitude in strings.items():
406
+ for state, amplitude in amplitudes.items():
377
407
  factors = [basis_1 if ch == one else basis_0 for ch in state]
378
- accum_mps += amplitude * MPS(factors, **kwargs)
408
+ accum_mps += amplitude * MPS(factors)
379
409
  norm = accum_mps.norm()
380
410
  if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
381
411
  print("\nThe state is not normalized, normalizing it for you.")
382
412
  accum_mps *= 1 / norm
383
413
 
384
- return accum_mps
414
+ return accum_mps, amplitudes
385
415
 
386
416
  def expect_batch(self, single_qubit_operators: torch.Tensor) -> torch.Tensor:
387
417
  """
@@ -449,7 +479,7 @@ class MPS(State):
449
479
 
450
480
  def get_correlation_matrix(
451
481
  self, *, operator: torch.Tensor = n_operator
452
- ) -> list[list[float]]:
482
+ ) -> torch.Tensor:
453
483
  """
454
484
  Efficiently compute the symmetric correlation matrix
455
485
  C_ij = <self|operator_i operator_j|self>
@@ -463,7 +493,7 @@ class MPS(State):
463
493
  """
464
494
  assert operator.shape == (2, 2)
465
495
 
466
- result = [[0.0 for _ in range(self.num_sites)] for _ in range(self.num_sites)]
496
+ result = torch.zeros(self.num_sites, self.num_sites, dtype=torch.complex128)
467
497
 
468
498
  for left in range(0, self.num_sites):
469
499
  self.orthogonalize(left)
@@ -475,7 +505,7 @@ class MPS(State):
475
505
  accumulator = torch.tensordot(
476
506
  accumulator, self.factors[left].conj(), dims=([0, 2], [0, 1])
477
507
  )
478
- result[left][left] = accumulator.trace().item().real
508
+ result[left, left] = accumulator.trace().item().real
479
509
  for right in range(left + 1, self.num_sites):
480
510
  partial = torch.tensordot(
481
511
  accumulator.to(self.factors[right].device),
@@ -486,7 +516,7 @@ class MPS(State):
486
516
  partial, self.factors[right].conj(), dims=([0], [0])
487
517
  )
488
518
 
489
- result[left][right] = (
519
+ result[left, right] = (
490
520
  torch.tensordot(
491
521
  partial, operator.to(partial.device), dims=([0, 2], [0, 1])
492
522
  )
@@ -494,13 +524,13 @@ class MPS(State):
494
524
  .item()
495
525
  .real
496
526
  )
497
- result[right][left] = result[left][right]
527
+ result[right, left] = result[left, right]
498
528
  accumulator = tensor_trace(partial, 0, 2)
499
529
 
500
530
  return result
501
531
 
502
532
 
503
- def inner(left: MPS, right: MPS) -> float | complex:
533
+ def inner(left: MPS, right: MPS) -> torch.Tensor:
504
534
  """
505
535
  Wrapper around MPS.inner.
506
536
 
emu_mps/mps_backend.py CHANGED
@@ -1,7 +1,6 @@
1
- from emu_base import Backend, BackendConfig, Results
1
+ from pulser.backend import EmulatorBackend, Results
2
2
  from emu_mps.mps_config import MPSConfig
3
3
  from emu_mps.mps_backend_impl import create_impl, MPSBackendImpl
4
- from pulser import Sequence
5
4
  import pickle
6
5
  import os
7
6
  import time
@@ -9,13 +8,16 @@ import logging
9
8
  import pathlib
10
9
 
11
10
 
12
- class MPSBackend(Backend):
11
+ class MPSBackend(EmulatorBackend):
13
12
  """
14
13
  A backend for emulating Pulser sequences using Matrix Product States (MPS),
15
14
  aka tensor trains.
16
15
  """
17
16
 
18
- def resume(self, autosave_file: str | pathlib.Path) -> Results:
17
+ default_config = MPSConfig()
18
+
19
+ @staticmethod
20
+ def resume(autosave_file: str | pathlib.Path) -> Results:
19
21
  """
20
22
  Resume simulation from autosave file.
21
23
  Only resume simulations from data you trust!
@@ -39,24 +41,18 @@ class MPSBackend(Backend):
39
41
  f"Saving simulation state every {impl.config.autosave_dt} seconds"
40
42
  )
41
43
 
42
- return self._run(impl)
44
+ return MPSBackend._run(impl)
43
45
 
44
- def run(self, sequence: Sequence, mps_config: BackendConfig) -> Results:
46
+ def run(self) -> Results:
45
47
  """
46
48
  Emulates the given sequence.
47
49
 
48
- Args:
49
- sequence: a Pulser sequence to simulate
50
- mps_config: the backends config. Should be of type MPSConfig
51
-
52
50
  Returns:
53
51
  the simulation results
54
52
  """
55
- assert isinstance(mps_config, MPSConfig)
56
-
57
- self.validate_sequence(sequence)
53
+ assert isinstance(self._config, MPSConfig)
58
54
 
59
- impl = create_impl(sequence, mps_config)
55
+ impl = create_impl(self._sequence, self._config)
60
56
  impl.init() # This is separate from the constructor for testing purposes.
61
57
 
62
58
  return self._run(impl)
@@ -2,16 +2,19 @@ import math
2
2
  import pathlib
3
3
  import random
4
4
  import uuid
5
+
5
6
  from resource import RUSAGE_SELF, getrusage
6
- from typing import Optional
7
+ from typing import Optional, Any
8
+ import typing
7
9
  import pickle
8
10
  import os
9
-
10
11
  import torch
11
12
  import time
12
13
  from pulser import Sequence
14
+ from types import MethodType
13
15
 
14
- from emu_base import Results, State, PulserData, DEVICE_COUNT
16
+ from pulser.backend import State, Observable, EmulationConfig, Results
17
+ from emu_base import PulserData, DEVICE_COUNT
15
18
  from emu_base.math.brents_root_finding import BrentsRootFinder
16
19
  from emu_mps.hamiltonian import make_H, update_H
17
20
  from emu_mps.mpo import MPO
@@ -33,6 +36,56 @@ from emu_mps.utils import (
33
36
  from enum import Enum, auto
34
37
 
35
38
 
39
+ class Statistics(Observable):
40
+ def __init__(
41
+ self,
42
+ evaluation_times: typing.Sequence[float] | None,
43
+ data: list[float],
44
+ timestep_count: int,
45
+ ):
46
+ super().__init__(evaluation_times=evaluation_times)
47
+ self.data = data
48
+ self.timestep_count = timestep_count
49
+
50
+ @property
51
+ def _base_tag(self) -> str:
52
+ return "statistics"
53
+
54
+ def apply(
55
+ self,
56
+ *,
57
+ config: EmulationConfig,
58
+ state: State,
59
+ **kwargs: Any,
60
+ ) -> dict:
61
+ """Calculates the observable to store in the Results."""
62
+ assert isinstance(state, MPS)
63
+ duration = self.data[-1]
64
+ if state.factors[0].is_cuda:
65
+ max_mem_per_device = (
66
+ torch.cuda.max_memory_allocated(device) * 1e-6
67
+ for device in range(torch.cuda.device_count())
68
+ )
69
+ max_mem = max(max_mem_per_device)
70
+ else:
71
+ max_mem = getrusage(RUSAGE_SELF).ru_maxrss * 1e-3
72
+
73
+ config.logger.info(
74
+ f"step = {len(self.data)}/{self.timestep_count}, "
75
+ + f"χ = {state.get_max_bond_dim()}, "
76
+ + f"|ψ| = {state.get_memory_footprint():.3f} MB, "
77
+ + f"RSS = {max_mem:.3f} MB, "
78
+ + f"Δt = {duration:.3f} s"
79
+ )
80
+
81
+ return {
82
+ "max_bond_dimension": state.get_max_bond_dim(),
83
+ "memory_footprint": state.get_memory_footprint(),
84
+ "RSS": max_mem,
85
+ "duration": duration,
86
+ }
87
+
88
+
36
89
  class SwipeDirection(Enum):
37
90
  LEFT_TO_RIGHT = auto()
38
91
  RIGHT_TO_LEFT = auto()
@@ -54,7 +107,8 @@ class MPSBackendImpl:
54
107
 
55
108
  def __init__(self, mps_config: MPSConfig, pulser_data: PulserData):
56
109
  self.config = mps_config
57
- self.target_time = float(self.config.dt)
110
+ self.target_times = pulser_data.target_times
111
+ self.target_time = self.target_times[1]
58
112
  self.pulser_data = pulser_data
59
113
  self.qubit_count = pulser_data.qubit_count
60
114
  assert self.qubit_count >= 2
@@ -74,9 +128,14 @@ class MPSBackendImpl:
74
128
  self.swipe_direction = SwipeDirection.LEFT_TO_RIGHT
75
129
  self.tdvp_index = 0
76
130
  self.timestep_index = 0
77
- self.results = Results()
131
+ self.results = Results(atom_order=(), total_duration=self.target_times[-1])
132
+ self.statistics = Statistics(
133
+ evaluation_times=[t / self.target_times[-1] for t in self.target_times],
134
+ data=[],
135
+ timestep_count=self.timestep_count,
136
+ )
78
137
  self.autosave_file = self._get_autosave_filepath(self.config.autosave_prefix)
79
- self.config.logger.warning(
138
+ self.config.logger.debug(
80
139
  f"""Will save simulation state to file "{self.autosave_file.name}"
81
140
  every {self.config.autosave_dt} seconds.\n"""
82
141
  f"""To resume: `MPSBackend().resume("{self.autosave_file}")`"""
@@ -89,23 +148,31 @@ class MPSBackendImpl:
89
148
  f"but only {DEVICE_COUNT if DEVICE_COUNT > 0 else 'cpu'} available"
90
149
  )
91
150
 
151
+ def __getstate__(self) -> dict:
152
+ for obs in self.config.observables:
153
+ obs.apply = MethodType(type(obs).apply, obs) # type: ignore[method-assign]
154
+ d = self.__dict__
155
+ # mypy thinks the method below is an attribute, because of the __getattr__ override
156
+ d["results"] = self.results._to_abstract_repr() # type: ignore[operator]
157
+ return d
158
+
159
+ def __setstate__(self, d: dict) -> None:
160
+ d["results"] = Results._from_abstract_repr(d["results"]) # type: ignore [attr-defined]
161
+ self.__dict__ = d
162
+ self.config.monkeypatch_observables()
163
+
92
164
  @staticmethod
93
165
  def _get_autosave_filepath(autosave_prefix: str) -> pathlib.Path:
94
166
  return pathlib.Path(os.getcwd()) / (autosave_prefix + str(uuid.uuid1()) + ".dat")
95
167
 
96
168
  def init_dark_qubits(self) -> None:
97
- has_state_preparation_error: bool = (
98
- self.config.noise_model is not None
99
- and self.config.noise_model.state_prep_error > 0.0
100
- )
101
-
102
- self.well_prepared_qubits_filter = (
103
- pick_well_prepared_qubits(
169
+ # has_state_preparation_error
170
+ if self.config.noise_model.state_prep_error > 0.0:
171
+ self.well_prepared_qubits_filter = pick_well_prepared_qubits(
104
172
  self.config.noise_model.state_prep_error, self.qubit_count
105
173
  )
106
- if has_state_preparation_error
107
- else None
108
- )
174
+ else:
175
+ self.well_prepared_qubits_filter = None
109
176
 
110
177
  if self.well_prepared_qubits_filter is not None:
111
178
  self.qubit_count = sum(1 for x in self.well_prepared_qubits_filter if x)
@@ -152,9 +219,11 @@ class MPSBackendImpl:
152
219
  too many factors are put in the Hamiltonian
153
220
  """
154
221
  self.hamiltonian = make_H(
155
- interaction_matrix=self.masked_interaction_matrix
156
- if self.is_masked
157
- else self.full_interaction_matrix,
222
+ interaction_matrix=(
223
+ self.masked_interaction_matrix
224
+ if self.is_masked
225
+ else self.full_interaction_matrix
226
+ ),
158
227
  hamiltonian_type=self.hamiltonian_type,
159
228
  num_gpus_to_use=self.config.num_gpus_to_use,
160
229
  )
@@ -176,6 +245,12 @@ class MPSBackendImpl:
176
245
  self.right_baths = right_baths(self.state, self.hamiltonian, final_qubit=2)
177
246
  assert len(self.right_baths) == self.qubit_count - 1
178
247
 
248
+ def get_current_right_bath(self) -> torch.Tensor:
249
+ return self.right_baths[-1]
250
+
251
+ def get_current_left_bath(self) -> torch.Tensor:
252
+ return self.left_baths[-1]
253
+
179
254
  def init(self) -> None:
180
255
  self.init_dark_qubits()
181
256
  self.init_initial_state(self.config.initial_state)
@@ -196,7 +271,7 @@ class MPSBackendImpl:
196
271
  """
197
272
  assert 1 <= len(indices) <= 2
198
273
 
199
- baths = (self.left_baths[-1], self.right_baths[-1])
274
+ baths = (self.get_current_left_bath(), self.get_current_right_bath())
200
275
 
201
276
  if len(indices) == 1:
202
277
  assert orth_center_right is None
@@ -268,10 +343,10 @@ class MPSBackendImpl:
268
343
  )
269
344
  self.left_baths.append(
270
345
  new_left_bath(
271
- self.left_baths[-1],
346
+ self.get_current_left_bath(),
272
347
  self.state.factors[self.tdvp_index],
273
348
  self.hamiltonian.factors[self.tdvp_index],
274
- )
349
+ ).to(self.state.factors[self.tdvp_index + 1].device)
275
350
  )
276
351
  self._evolve(self.tdvp_index + 1, dt=-delta_time / 2)
277
352
  self.right_baths.pop()
@@ -297,10 +372,10 @@ class MPSBackendImpl:
297
372
  assert self.tdvp_index <= self.qubit_count - 2
298
373
  self.right_baths.append(
299
374
  new_right_bath(
300
- self.right_baths[-1],
375
+ self.get_current_right_bath(),
301
376
  self.state.factors[self.tdvp_index + 1],
302
377
  self.hamiltonian.factors[self.tdvp_index + 1],
303
- )
378
+ ).to(self.state.factors[self.tdvp_index].device)
304
379
  )
305
380
  if not self.has_lindblad_noise:
306
381
  # Free memory because it won't be used anymore
@@ -333,7 +408,6 @@ class MPSBackendImpl:
333
408
  def timestep_complete(self) -> None:
334
409
  self.fill_results()
335
410
  self.timestep_index += 1
336
- self.target_time = float((self.timestep_index + 1) * self.config.dt)
337
411
  if self.is_masked and self.current_time >= self.slm_end_time:
338
412
  self.is_masked = False
339
413
  self.hamiltonian = make_H(
@@ -343,6 +417,7 @@ class MPSBackendImpl:
343
417
  )
344
418
 
345
419
  if not self.is_finished():
420
+ self.target_time = self.target_times[self.timestep_index + 1]
346
421
  update_H(
347
422
  hamiltonian=self.hamiltonian,
348
423
  omega=self.omega[self.timestep_index, :],
@@ -352,7 +427,14 @@ class MPSBackendImpl:
352
427
  )
353
428
  self.init_baths()
354
429
 
355
- self.log_step_statistics(duration=time.time() - self.time)
430
+ self.statistics.data.append(time.time() - self.time)
431
+ self.statistics(
432
+ self.config,
433
+ self.current_time / self.target_times[-1],
434
+ self.state,
435
+ self.hamiltonian,
436
+ self.results,
437
+ )
356
438
  self.time = time.time()
357
439
 
358
440
  def save_simulation(self) -> None:
@@ -382,13 +464,14 @@ class MPSBackendImpl:
382
464
  normalized_state = 1 / self.state.norm() * self.state
383
465
 
384
466
  current_time_int: int = round(self.current_time)
467
+ fractional_time = self.current_time / self.target_times[-1]
385
468
  assert abs(self.current_time - current_time_int) < 1e-10
386
469
 
387
470
  if self.well_prepared_qubits_filter is None:
388
- for callback in self.config.callbacks:
471
+ for callback in self.config.observables:
389
472
  callback(
390
473
  self.config,
391
- current_time_int,
474
+ fractional_time,
392
475
  normalized_state,
393
476
  self.hamiltonian,
394
477
  self.results,
@@ -396,63 +479,34 @@ class MPSBackendImpl:
396
479
  return
397
480
 
398
481
  full_mpo, full_state = None, None
399
- for callback in self.config.callbacks:
400
- if current_time_int not in callback.evaluation_times:
401
- continue
402
-
403
- if full_mpo is None or full_state is None:
404
- # Only do this potentially expensive step once and when needed.
405
- full_mpo = MPO(
406
- extended_mpo_factors(
407
- self.hamiltonian.factors, self.well_prepared_qubits_filter
408
- )
409
- )
410
- full_state = MPS(
411
- extended_mps_factors(
412
- normalized_state.factors, self.well_prepared_qubits_filter
413
- ),
414
- num_gpus_to_use=None, # Keep the already assigned devices.
415
- orthogonality_center=get_extended_site_index(
416
- self.well_prepared_qubits_filter,
417
- normalized_state.orthogonality_center,
418
- ),
482
+ for callback in self.config.observables:
483
+ time_tol = 0.5 / self.target_times[-1] + 1e-10
484
+ if (
485
+ callback.evaluation_times is not None
486
+ and self.config.is_time_in_evaluation_times(
487
+ fractional_time, callback.evaluation_times, tol=time_tol
419
488
  )
489
+ ) or self.config.is_evaluation_time(fractional_time, tol=time_tol):
490
+
491
+ if full_mpo is None or full_state is None:
492
+ # Only do this potentially expensive step once and when needed.
493
+ full_mpo = MPO(
494
+ extended_mpo_factors(
495
+ self.hamiltonian.factors, self.well_prepared_qubits_filter
496
+ )
497
+ )
498
+ full_state = MPS(
499
+ extended_mps_factors(
500
+ normalized_state.factors, self.well_prepared_qubits_filter
501
+ ),
502
+ num_gpus_to_use=None, # Keep the already assigned devices.
503
+ orthogonality_center=get_extended_site_index(
504
+ self.well_prepared_qubits_filter,
505
+ normalized_state.orthogonality_center,
506
+ ),
507
+ )
420
508
 
421
- callback(self.config, current_time_int, full_state, full_mpo, self.results)
422
-
423
- def log_step_statistics(self, *, duration: float) -> None:
424
- if self.state.factors[0].is_cuda:
425
- max_mem_per_device = (
426
- torch.cuda.max_memory_allocated(device) * 1e-6
427
- for device in range(torch.cuda.device_count())
428
- )
429
- max_mem = max(max_mem_per_device)
430
- else:
431
- max_mem = getrusage(RUSAGE_SELF).ru_maxrss * 1e-3
432
-
433
- self.config.logger.info(
434
- f"step = {self.timestep_index}/{self.timestep_count}, "
435
- + f"χ = {self.state.get_max_bond_dim()}, "
436
- + f"|ψ| = {self.state.get_memory_footprint():.3f} MB, "
437
- + f"RSS = {max_mem:.3f} MB, "
438
- + f"Δt = {duration:.3f} s"
439
- )
440
-
441
- if self.results.statistics is None:
442
- assert self.timestep_index == 1
443
- self.results.statistics = {"steps": []}
444
-
445
- assert "steps" in self.results.statistics
446
- assert len(self.results.statistics["steps"]) == self.timestep_index - 1
447
-
448
- self.results.statistics["steps"].append(
449
- {
450
- "max_bond_dimension": self.state.get_max_bond_dim(),
451
- "memory_footprint": self.state.get_memory_footprint(),
452
- "RSS": max_mem,
453
- "duration": duration,
454
- }
455
- )
509
+ callback(self.config, fractional_time, full_state, full_mpo, self.results)
456
510
 
457
511
 
458
512
  class NoisyMPSBackendImpl(MPSBackendImpl):
@@ -479,12 +533,15 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
479
533
  self.aggregated_lindblad_ops = stacked.conj().transpose(1, 2) @ stacked
480
534
 
481
535
  self.lindblad_noise = compute_noise_from_lindbladians(self.lindblad_ops)
482
- self.jump_threshold = random.random()
536
+
537
+ def set_jump_threshold(self, bound: float) -> None:
538
+ self.jump_threshold = random.uniform(0.0, bound)
483
539
  self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
484
540
 
485
541
  def init(self) -> None:
486
- super().init()
487
542
  self.init_lindblad_noise()
543
+ super().init()
544
+ self.set_jump_threshold(1.0)
488
545
 
489
546
  def tdvp_complete(self) -> None:
490
547
  previous_time = self.current_time
@@ -515,7 +572,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
515
572
 
516
573
  if self.root_finder.is_converged(tolerance=1):
517
574
  self.do_random_quantum_jump()
518
- self.target_time = (self.timestep_index + 1) * self.config.dt
575
+ self.target_time = self.target_times[self.timestep_index + 1]
519
576
  self.root_finder = None
520
577
  else:
521
578
  self.target_time = self.root_finder.get_next_abscissa()
@@ -534,11 +591,11 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
534
591
  self.state.apply(jumped_qubit_index, jump_operator)
535
592
  self.state.orthogonalize(0)
536
593
  self.state *= 1 / self.state.norm()
594
+ self.init_baths()
537
595
 
538
596
  norm_after_normalizing = self.state.norm()
539
597
  assert math.isclose(norm_after_normalizing, 1, abs_tol=1e-10)
540
- self.jump_threshold = random.uniform(0.0, norm_after_normalizing**2)
541
- self.norm_gap_before_jump = norm_after_normalizing**2 - self.jump_threshold
598
+ self.set_jump_threshold(norm_after_normalizing**2)
542
599
 
543
600
  def fill_results(self) -> None:
544
601
  # Remove the noise from self.hamiltonian for the callbacks.
emu_mps/mps_config.py CHANGED
@@ -1,16 +1,37 @@
1
- from typing import Any
1
+ from typing import Any, ClassVar
2
+ from types import MethodType
2
3
 
3
- from emu_base import BackendConfig, State, DEVICE_COUNT
4
+ import copy
4
5
 
6
+ from emu_base import DEVICE_COUNT
7
+ from emu_mps.custom_callback_implementations import (
8
+ energy_mps_impl,
9
+ energy_second_moment_mps_impl,
10
+ energy_variance_mps_impl,
11
+ correlation_matrix_mps_impl,
12
+ qubit_occupation_mps_impl,
13
+ )
14
+ from pulser.backend import (
15
+ Occupation,
16
+ CorrelationMatrix,
17
+ Energy,
18
+ EnergySecondMoment,
19
+ EnergyVariance,
20
+ BitStrings,
21
+ EmulationConfig,
22
+ )
23
+ import logging
24
+ import pathlib
25
+ import sys
5
26
 
6
- class MPSConfig(BackendConfig):
27
+
28
+ class MPSConfig(EmulationConfig):
7
29
  """
8
- The configuration of the emu-ct MPSBackend. The kwargs passed to this class
30
+ The configuration of the emu-mps MPSBackend. The kwargs passed to this class
9
31
  are passed on to the base class.
10
32
  See the API for that class for a list of available options.
11
33
 
12
34
  Args:
13
- initial_state: the initial state to use in the simulation
14
35
  dt: the timestep size that the solver uses. Note that observables are
15
36
  only calculated if the evaluation_times are divisible by dt.
16
37
  precision: up to what precision the state is truncated
@@ -36,42 +57,129 @@ class MPSConfig(BackendConfig):
36
57
  >>> with_modulation=True) #the last arg is taken from the base class
37
58
  """
38
59
 
60
+ # Whether to warn if unexpected kwargs are received
61
+ _enforce_expected_kwargs: ClassVar[bool] = True
62
+
39
63
  def __init__(
40
64
  self,
41
65
  *,
42
- initial_state: State | None = None,
43
66
  dt: int = 10,
44
67
  precision: float = 1e-5,
45
68
  max_bond_dim: int = 1024,
46
69
  max_krylov_dim: int = 100,
47
70
  extra_krylov_tolerance: float = 1e-3,
48
71
  num_gpus_to_use: int = DEVICE_COUNT,
72
+ interaction_cutoff: float = 0.0,
73
+ log_level: int = logging.INFO,
74
+ log_file: pathlib.Path | None = None,
49
75
  autosave_prefix: str = "emu_mps_save_",
50
76
  autosave_dt: int = 600, # 10 minutes
51
77
  **kwargs: Any,
52
78
  ):
53
- super().__init__(**kwargs)
54
- self.initial_state = initial_state
55
- self.dt = dt
56
- self.precision = precision
57
- self.max_bond_dim = max_bond_dim
58
- self.max_krylov_dim = max_krylov_dim
59
- self.num_gpus_to_use = num_gpus_to_use
60
- self.extra_krylov_tolerance = extra_krylov_tolerance
79
+ kwargs.setdefault("observables", [BitStrings(evaluation_times=[1.0])])
80
+ super().__init__(
81
+ dt=dt,
82
+ precision=precision,
83
+ max_bond_dim=max_bond_dim,
84
+ max_krylov_dim=max_krylov_dim,
85
+ extra_krylov_tolerance=extra_krylov_tolerance,
86
+ num_gpus_to_use=num_gpus_to_use,
87
+ interaction_cutoff=interaction_cutoff,
88
+ log_level=log_level,
89
+ log_file=log_file,
90
+ autosave_prefix=autosave_prefix,
91
+ autosave_dt=autosave_dt,
92
+ **kwargs,
93
+ )
61
94
 
62
- if self.noise_model is not None:
63
- if "doppler" in self.noise_model.noise_types:
64
- raise NotImplementedError("Unsupported noise type: doppler")
65
- if (
66
- "amplitude" in self.noise_model.noise_types
67
- and self.noise_model.amp_sigma != 0.0
68
- ):
69
- raise NotImplementedError("Unsupported noise type: amp_sigma")
70
-
71
- self.autosave_prefix = autosave_prefix
72
- self.autosave_dt = autosave_dt
95
+ if "doppler" in self.noise_model.noise_types:
96
+ raise NotImplementedError("Unsupported noise type: doppler")
97
+ if (
98
+ "amplitude" in self.noise_model.noise_types
99
+ and self.noise_model.amp_sigma != 0.0
100
+ ):
101
+ raise NotImplementedError("Unsupported noise type: amp_sigma")
73
102
 
74
103
  MIN_AUTOSAVE_DT = 10
75
104
  assert (
76
105
  self.autosave_dt > MIN_AUTOSAVE_DT
77
106
  ), f"autosave_dt must be larger than {MIN_AUTOSAVE_DT} seconds"
107
+
108
+ self.monkeypatch_observables()
109
+
110
+ self.logger = logging.getLogger("global_logger")
111
+ if log_file is None:
112
+ logging.basicConfig(
113
+ level=log_level, format="%(message)s", stream=sys.stdout, force=True
114
+ ) # default to stream = sys.stderr
115
+ else:
116
+ logging.basicConfig(
117
+ level=log_level,
118
+ format="%(message)s",
119
+ filename=str(log_file),
120
+ filemode="w",
121
+ force=True,
122
+ )
123
+ if (self.noise_model.runs != 1 and self.noise_model.runs is not None) or (
124
+ self.noise_model.samples_per_run != 1
125
+ and self.noise_model.samples_per_run is not None
126
+ ):
127
+ self.logger.warning(
128
+ "Warning: The runs and samples_per_run values of the NoiseModel are ignored!"
129
+ )
130
+
131
+ def _expected_kwargs(self) -> set[str]:
132
+ return super()._expected_kwargs() | {
133
+ "dt",
134
+ "precision",
135
+ "max_bond_dim",
136
+ "max_krylov_dim",
137
+ "extra_krylov_tolerance",
138
+ "num_gpus_to_use",
139
+ "interaction_cutoff",
140
+ "log_level",
141
+ "log_file",
142
+ "autosave_prefix",
143
+ "autosave_dt",
144
+ }
145
+
146
+ def monkeypatch_observables(self) -> None:
147
+ obs_list = []
148
+ for _, obs in enumerate(self.observables): # monkey patch
149
+ obs_copy = copy.deepcopy(obs)
150
+ if isinstance(obs, Occupation):
151
+ obs_copy.apply = MethodType( # type: ignore[method-assign]
152
+ qubit_occupation_mps_impl, obs_copy
153
+ )
154
+ elif isinstance(obs, EnergyVariance):
155
+ obs_copy.apply = MethodType( # type: ignore[method-assign]
156
+ energy_variance_mps_impl, obs_copy
157
+ )
158
+ elif isinstance(obs, EnergySecondMoment):
159
+ obs_copy.apply = MethodType( # type: ignore[method-assign]
160
+ energy_second_moment_mps_impl, obs_copy
161
+ )
162
+ elif isinstance(obs, CorrelationMatrix):
163
+ obs_copy.apply = MethodType( # type: ignore[method-assign]
164
+ correlation_matrix_mps_impl, obs_copy
165
+ )
166
+ elif isinstance(obs, Energy):
167
+ obs_copy.apply = MethodType( # type: ignore[method-assign]
168
+ energy_mps_impl, obs_copy
169
+ )
170
+ obs_list.append(obs_copy)
171
+ self.observables = tuple(obs_list)
172
+
173
+ def init_logging(self) -> None:
174
+ if self.log_file is None:
175
+ logging.basicConfig(
176
+ level=self.log_level, format="%(message)s", stream=sys.stdout, force=True
177
+ ) # default to stream = sys.stderr
178
+ else:
179
+ logging.basicConfig(
180
+ level=self.log_level,
181
+ format="%(message)s",
182
+ filename=str(self.log_file),
183
+ filemode="w",
184
+ force=True,
185
+ )
@@ -210,9 +210,9 @@ def minimize_bandwidth(input_matrix: np.ndarray, samples: int = 100) -> list[int
210
210
  # We are interested in strength of the interaction, not sign
211
211
 
212
212
  L = input_mat.shape[0]
213
- rnd_permutations = itertools.chain(
213
+ rnd_permutations: itertools.chain[list[int]] = itertools.chain(
214
214
  [list(range(L))], # First element is always the identity list
215
- (np.random.permutation(L).tolist() for _ in range(samples)),
215
+ (np.random.permutation(L).tolist() for _ in range(samples)), # type: ignore[misc]
216
216
  )
217
217
 
218
218
  opt_permutations_and_opt_bandwidth = (
emu_mps/tdvp.py CHANGED
@@ -117,6 +117,7 @@ def evolve_pair(
117
117
 
118
118
  left_ham_factor = left_ham_factor.to(left_device)
119
119
  right_ham_factor = right_ham_factor.to(left_device)
120
+ right_bath = right_bath.to(left_device)
120
121
 
121
122
  combined_hamiltonian_factors = (
122
123
  torch.tensordot(left_ham_factor, right_ham_factor, dims=1)
emu_mps/utils.py CHANGED
@@ -130,13 +130,18 @@ def extended_mps_factors(
130
130
  bond_dimension = mps_factors[factor_index].shape[2]
131
131
  factor_index += 1
132
132
  elif factor_index == len(mps_factors):
133
- factor = torch.zeros(bond_dimension, 2, 1, dtype=torch.complex128)
133
+ factor = torch.zeros(
134
+ bond_dimension, 2, 1, dtype=torch.complex128
135
+ ) # FIXME: assign device
134
136
  factor[:, 0, :] = torch.eye(bond_dimension, 1)
135
137
  bond_dimension = 1
136
138
  result.append(factor)
137
139
  else:
138
140
  factor = torch.zeros(
139
- bond_dimension, 2, bond_dimension, dtype=torch.complex128
141
+ bond_dimension,
142
+ 2,
143
+ bond_dimension,
144
+ dtype=torch.complex128, # FIXME: assign device
140
145
  )
141
146
  factor[:, 0, :] = torch.eye(bond_dimension, bond_dimension)
142
147
  result.append(factor)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: emu-mps
3
- Version: 1.2.6
3
+ Version: 2.0.0
4
4
  Summary: Pasqal MPS based pulse emulator built on PyTorch
5
5
  Project-URL: Documentation, https://pasqal-io.github.io/emulators/
6
6
  Project-URL: Repository, https://github.com/pasqal-io/emulators
@@ -25,7 +25,7 @@ Classifier: Programming Language :: Python :: 3.10
25
25
  Classifier: Programming Language :: Python :: Implementation :: CPython
26
26
  Classifier: Programming Language :: Python :: Implementation :: PyPy
27
27
  Requires-Python: >=3.10
28
- Requires-Dist: emu-base==1.2.6
28
+ Requires-Dist: emu-base==2.0.0
29
29
  Description-Content-Type: text/markdown
30
30
 
31
31
  <div align="center">
@@ -0,0 +1,18 @@
1
+ emu_mps/__init__.py,sha256=N7maf5x7U6ewsK1hMebl0OBga9dSB0NoesSbtKtgN0c,662
2
+ emu_mps/algebra.py,sha256=ngPtTH-j2ZCBWoaJZXlkUyIlug7dY7Q92gzfnRlpPMA,5485
3
+ emu_mps/custom_callback_implementations.py,sha256=CUs0kW3HRaPE7UeFNQOFbeWJMsz4hS2q4rgS57BBp-A,2411
4
+ emu_mps/hamiltonian.py,sha256=LcBs6CKBb643a1e9AAVtQoUfa4L_0dIhLOKecx5OOWs,15864
5
+ emu_mps/mpo.py,sha256=wSonS6i3zEt3yRTgyZ7F6vT51pUKavFcLOxFFphBv8k,8793
6
+ emu_mps/mps.py,sha256=3g-iL8ZpLd2ueMFXlFgY2qk3CEf8_cd4UEm8TDOAxRI,18972
7
+ emu_mps/mps_backend.py,sha256=_3rlg6XeI4fHaDiJRfPL6pDkX9k48hAHKXd8fkvkOFs,2004
8
+ emu_mps/mps_backend_impl.py,sha256=tg-ZMXmQNU5CeT4fi9yVhuHrJ35Nt4qi2l-UtEgAFOE,22952
9
+ emu_mps/mps_config.py,sha256=89nu5OhNUX31eAeeYvvKnAHegpPVD43jH5Nmp635HVU,6984
10
+ emu_mps/noise.py,sha256=h4X2EFjoC_Ok0gZ8I9wN77RANXaVehTBbjkcbY_GAmY,784
11
+ emu_mps/tdvp.py,sha256=pIQ2NXA2Mrkp3elhqQbX3pdJVbtKkG3c5r9fFlJo7pI,5755
12
+ emu_mps/utils.py,sha256=BqRJYAcXqprtZVJ0V_j954ON2bhTdtZiaTojsYyrWrg,8193
13
+ emu_mps/optimatrix/__init__.py,sha256=lHWYNeThHp57ZrwTwXd0p8bNvcCv0w_AZ31iCWflBUo,226
14
+ emu_mps/optimatrix/optimiser.py,sha256=7j9_jMQC-Uh2DzdIVB44InRzZO6AbbGhvmm7lC6N3tk,6737
15
+ emu_mps/optimatrix/permutations.py,sha256=JRXGont8B4QgbkV9CzrA0w7uzLgBrmZ1J9aa0G52hPo,1979
16
+ emu_mps-2.0.0.dist-info/METADATA,sha256=Is92e2JsE9NzvRFbpsI44W16pmVm-bJq-pQd7g736Ss,3505
17
+ emu_mps-2.0.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
18
+ emu_mps-2.0.0.dist-info/RECORD,,
@@ -1,17 +0,0 @@
1
- emu_mps/__init__.py,sha256=KphiJUrBt3vjS0e2_tQejeB12X3YwqrEhxwfF1QRXQE,646
2
- emu_mps/algebra.py,sha256=ngPtTH-j2ZCBWoaJZXlkUyIlug7dY7Q92gzfnRlpPMA,5485
3
- emu_mps/hamiltonian.py,sha256=LcBs6CKBb643a1e9AAVtQoUfa4L_0dIhLOKecx5OOWs,15864
4
- emu_mps/mpo.py,sha256=H5vkJvz4AfXfnPbvgWznBWpMUO8LnGL3_NAP3IhxZzQ,8740
5
- emu_mps/mps.py,sha256=CcduX2BC4ArBwwF41w_FQCa6wqmynegQQC9zkK0EmgQ,17826
6
- emu_mps/mps_backend.py,sha256=6fVaq-D4xyicYRjGjhqMEieC7---90LpfpbV7ZD7zkQ,2192
7
- emu_mps/mps_backend_impl.py,sha256=Rp7WbT0Dto1G4ArqSLEzSHkJAuMIEZfUqUMP9Dyz31M,20838
8
- emu_mps/mps_config.py,sha256=ydKN0OOaWCBcNd9V-4CU5ZZ4w1FRT-bbKyZQD2WCaME,3317
9
- emu_mps/noise.py,sha256=h4X2EFjoC_Ok0gZ8I9wN77RANXaVehTBbjkcbY_GAmY,784
10
- emu_mps/tdvp.py,sha256=TH4CcBNczRURXYGPXndWKDs0jWXz_x9ozM961uGiSOw,5711
11
- emu_mps/utils.py,sha256=n9BcpuIz4Kl6EYlATaK8TKsyF-T7FTwbBo6KSAQYzl8,8066
12
- emu_mps/optimatrix/__init__.py,sha256=lHWYNeThHp57ZrwTwXd0p8bNvcCv0w_AZ31iCWflBUo,226
13
- emu_mps/optimatrix/optimiser.py,sha256=cVMdm2r_4OpbthcQuFMrJ9rNR9WEJRga9c_lHrJFkhw,6687
14
- emu_mps/optimatrix/permutations.py,sha256=JRXGont8B4QgbkV9CzrA0w7uzLgBrmZ1J9aa0G52hPo,1979
15
- emu_mps-1.2.6.dist-info/METADATA,sha256=iUObGpVmQN3Y6GM_ScEFmFlY80htsWrIjo53ER2Flvw,3505
16
- emu_mps-1.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- emu_mps-1.2.6.dist-info/RECORD,,