emu-mps 2.4.0__py3-none-any.whl → 2.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/__init__.py CHANGED
@@ -36,4 +36,4 @@ __all__ = [
     "EntanglementEntropy",
 ]
 
-__version__ = "2.4.0"
+__version__ = "2.4.1"
emu_mps/algebra.py CHANGED
@@ -1,11 +1,8 @@
 from __future__ import annotations
 
-from typing import Optional
 
 import torch
 import math
-
-from emu_mps import MPSConfig
 from emu_mps.utils import truncate_impl
 
 
@@ -119,7 +116,8 @@ def zip_right_step(
 def zip_right(
     top_factors: list[torch.Tensor],
     bottom_factors: list[torch.Tensor],
-    config: Optional[MPSConfig] = None,
+    precision: float,
+    max_bond_dim: int,
 ) -> list[torch.Tensor]:
     """
     Returns a new matrix product, resulting from applying `top` to `bottom`.
@@ -139,8 +137,6 @@ def zip_right(
     A final truncation sweep, from right to left,
     moves back the orthogonal center to the first element.
    """
-    config = config if config is not None else MPSConfig()
-
     if len(top_factors) != len(bottom_factors):
         raise ValueError("Cannot multiply two matrix products of different lengths.")
 
@@ -151,6 +147,6 @@ def zip_right(
         new_factors.append(res)
     new_factors[-1] @= slider[:, :, 0]
 
-    truncate_impl(new_factors, config=config)
+    truncate_impl(new_factors, precision=precision, max_bond_dim=max_bond_dim)
 
     return new_factors
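For readers tracking this API change: `zip_right` no longer builds a default `MPSConfig`; callers now pass the truncation parameters explicitly. The snippet below is an illustrative, self-contained sketch (not taken from the package) of the single contraction at the heart of a zip step, showing why the combined bond dimensions grow and need truncating.

```python
import torch

# One zip step, schematically: contract an MPO factor (left, out, in, right)
# with an MPS factor (left', in, right') over the shared physical leg.
mpo_factor = torch.randn(1, 2, 2, 3, dtype=torch.complex128)
mps_factor = torch.randn(1, 2, 4, dtype=torch.complex128)

# The result carries both bond pairs; the product bonds (1*1 and 3*4) are
# what precision/max_bond_dim keep under control in the truncation sweep.
combined = torch.tensordot(mpo_factor, mps_factor, dims=([2], [1]))
print(combined.shape)  # torch.Size([1, 2, 3, 1, 4])
```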
emu_mps/custom_callback_implementations.py CHANGED
@@ -25,9 +25,9 @@ def qubit_occupation_mps_impl(
     """
     Custom implementation of the occupation ❬ψ|nᵢ|ψ❭ for the EMU-MPS.
     """
-    op = torch.tensor(
-        [[[0.0, 0.0], [0.0, 1.0]]], dtype=torch.complex128, device=state.factors[0].device
-    )
+    dim = state.dim
+    op = torch.zeros(1, dim, dim, dtype=torch.complex128, device=state.factors[0].device)
+    op[0, 1, 1] = 1.0
     return state.expect_batch(op).real.view(-1).cpu()
 
 
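The occupation callback now builds its operator from `state.dim` instead of hard-coding a 2x2 matrix. A standalone sketch of the construction (plain PyTorch, no emu-mps imports):

```python
import torch

# n = |r><r| as a (1, dim, dim) batch of one operator: index 1 is the
# excited state for qubits (dim=2) and for qutrits with leakage (dim=3).
for dim in (2, 3):
    op = torch.zeros(1, dim, dim, dtype=torch.complex128)
    op[0, 1, 1] = 1.0
    print(dim, op[0].real)
```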
emu_mps/mpo.py CHANGED
@@ -7,7 +7,7 @@ from pulser.backend import State, Operator
 from emu_base import DEVICE_COUNT
 from emu_mps.algebra import add_factors, scale_factors, zip_right
 from pulser.backend.operator import FullOp, QuditOp
-from emu_mps.mps import MPS
+from emu_mps.mps import MPS, DEFAULT_MAX_BOND_DIM, DEFAULT_PRECISION
 from emu_mps.utils import new_left_bath, assign_devices
 
 
@@ -62,7 +62,8 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
         factors = zip_right(
             self.factors,
             other.factors,
-            config=other.config,
+            precision=other.precision,
+            max_bond_dim=other.max_bond_dim,
         )
         return MPS(factors, orthogonality_center=0, eigenstates=other.eigenstates)
 
@@ -107,7 +108,12 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
             the composed operator
         """
         assert isinstance(other, MPO), "MPO can only be applied to another MPO"
-        factors = zip_right(self.factors, other.factors)
+        factors = zip_right(
+            self.factors,
+            other.factors,
+            precision=DEFAULT_PRECISION,
+            max_bond_dim=DEFAULT_MAX_BOND_DIM,
+        )
         return MPO(factors)
 
     def expect(self, state: State) -> torch.Tensor:
emu_mps/mps.py CHANGED
@@ -1,5 +1,4 @@
 from __future__ import annotations
-
 import math
 from collections import Counter
 from typing import List, Optional, Sequence, TypeVar, Mapping
@@ -8,16 +7,19 @@ import torch
 
 from pulser.backend.state import State, Eigenstate
 from emu_base import DEVICE_COUNT, apply_measurement_errors
-from emu_mps import MPSConfig
 from emu_mps.algebra import add_factors, scale_factors
 from emu_mps.utils import (
     assign_devices,
     truncate_impl,
     tensor_trace,
-    n_operator,
 )
 
+
 ArgScalarType = TypeVar("ArgScalarType")
+dtype = torch.complex128
+
+DEFAULT_PRECISION = 1e-5
+DEFAULT_MAX_BOND_DIM = 1024
 
 
 class MPS(State[complex, torch.Tensor]):
@@ -35,12 +37,14 @@ class MPS(State[complex, torch.Tensor]):
         /,
         *,
         orthogonality_center: Optional[int] = None,
-        config: Optional[MPSConfig] = None,
+        precision: float = DEFAULT_PRECISION,
+        max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
         num_gpus_to_use: Optional[int] = DEVICE_COUNT,
         eigenstates: Sequence[Eigenstate] = ("r", "g"),
     ):
         """
-        This constructor creates a MPS directly from a list of tensors. It is for internal use only.
+        This constructor creates a MPS directly from a list of tensors. It is
+        for internal use only.
 
         Args:
             factors: the tensors for each site
@@ -52,12 +56,14 @@ class MPS(State[complex, torch.Tensor]):
                 of the data to this constructor, or some shared objects.
             orthogonality_center: the orthogonality center of the MPS, or None (in which case
                 it will be orthogonalized when needed)
-            config: the emu-mps config object passed to the run method
+            precision: the precision with which to keep this MPS
+            max_bond_dim: the maximum bond dimension to allow for this MPS
             num_gpus_to_use: distribute the factors over this many GPUs
                 0=all factors to cpu, None=keep the existing device assignment.
         """
         super().__init__(eigenstates=eigenstates)
-        self.config = config if config is not None else MPSConfig()
+        self.precision = precision
+        self.max_bond_dim = max_bond_dim
         assert all(
             factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
         ), "The dimensions of consecutive tensors should match"
@@ -69,6 +75,17 @@ class MPS(State[complex, torch.Tensor]):
         self.num_sites = len(factors)
         assert self.num_sites > 1  # otherwise, do state vector
 
+        self.dim = len(self.eigenstates)
+        assert all(factors[i].shape[1] == self.dim for i in range(self.num_sites)), (
+            "All tensors should have the same physical dimension as the number "
+            "of eigenstates"
+        )
+
+        self.n_operator = torch.zeros(
+            self.dim, self.dim, dtype=dtype, device=self.factors[0].device
+        )
+        self.n_operator[1, 1] = 1.0
+
         assert (orthogonality_center is None) or (
             0 <= orthogonality_center < self.num_sites
         ), "Invalid orthogonality center provided"
@@ -86,7 +103,8 @@ class MPS(State[complex, torch.Tensor]):
     def make(
         cls,
         num_sites: int,
-        config: Optional[MPSConfig] = None,
+        precision: float = DEFAULT_PRECISION,
+        max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
         num_gpus_to_use: int = DEVICE_COUNT,
         eigenstates: Sequence[Eigenstate] = ["0", "1"],
     ) -> MPS:
@@ -95,21 +113,35 @@ class MPS(State[complex, torch.Tensor]):
 
         Args:
             num_sites: the number of qubits
-            config: the MPSConfig
+            precision: the precision with which to keep this MPS
+            max_bond_dim: the maximum bond dimension to allow for this MPS
             num_gpus_to_use: distribute the factors over this many GPUs
                 0=all factors to cpu
         """
-        config = config if config is not None else MPSConfig()
-
         if num_sites <= 1:
             raise ValueError("For 1 qubit states, do state vector")
 
-        return cls(
-            [
-                torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)
+        if len(eigenstates) == 2:
+            ground_state = [
+                torch.tensor([[[1.0], [0.0]]], dtype=dtype) for _ in range(num_sites)
+            ]
+
+        elif len(eigenstates) == 3:  # (g,r,x)
+            ground_state = [
+                torch.tensor([[[1.0], [0.0], [0.0]]], dtype=dtype)
                 for _ in range(num_sites)
-            ],
-            config=config,
+            ]
+
+        else:
+            raise ValueError(
+                "Unsupported basis provided. The supported "
+                "bases are:{('0','1'),('r','g'),('r','g','x')}"
+            )
+
+        return cls(
+            ground_state,
+            precision=precision,
+            max_bond_dim=max_bond_dim,
             num_gpus_to_use=num_gpus_to_use,
             orthogonality_center=0,  # Arbitrary: every qubit is an orthogonality center.
             eigenstates=eigenstates,
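A hedged sketch of the new three-level branch in `MPS.make`, assuming the installed pulser version accepts `"x"` as an eigenstate label:

```python
from emu_mps.mps import MPS

# Ground state over (g, r, x); len(eigenstates) == 3 selects the qutrit
# branch added in 2.4.1, with "x" as the leakage level.
state = MPS.make(
    4,
    precision=1e-5,
    max_bond_dim=1024,
    eigenstates=("g", "r", "x"),
)
```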
@@ -140,7 +172,8 @@ class MPS(State[complex, torch.Tensor]):
 
         for i in range(lr_swipe_start, desired_orthogonality_center):
             q, r = torch.linalg.qr(self.factors[i].view(-1, self.factors[i].shape[2]))
-            self.factors[i] = q.view(self.factors[i].shape[0], 2, -1)
+
+            self.factors[i] = q.view(self.factors[i].shape[0], self.dim, -1)
             self.factors[i + 1] = torch.tensordot(
                 r.to(self.factors[i + 1].device), self.factors[i + 1], dims=1
             )
@@ -155,7 +188,8 @@ class MPS(State[complex, torch.Tensor]):
             q, r = torch.linalg.qr(
                 self.factors[i].contiguous().view(self.factors[i].shape[0], -1).mT,
             )
-            self.factors[i] = q.mT.view(-1, 2, self.factors[i].shape[2])
+
+            self.factors[i] = q.mT.view(-1, self.dim, self.factors[i].shape[2])
             self.factors[i - 1] = torch.tensordot(
                 self.factors[i - 1], r.to(self.factors[i - 1].device), ([2], [1])
             )
@@ -172,7 +206,9 @@ class MPS(State[complex, torch.Tensor]):
         An in-place operation.
         """
         self.orthogonalize(self.num_sites - 1)
-        truncate_impl(self.factors, config=self.config)
+        truncate_impl(
+            self.factors, precision=self.precision, max_bond_dim=self.max_bond_dim
+        )
         self.orthogonality_center = 0
 
     def get_max_bond_dim(self) -> int:
@@ -182,7 +218,7 @@ class MPS(State[complex, torch.Tensor]):
         Returns:
             the largest bond dimension in the state
         """
-        return max((x.shape[2] for x in self.factors), default=0)
+        return max((factor.shape[2] for factor in self.factors), default=0)
 
     def sample(
         self,
@@ -206,8 +242,6 @@ class MPS(State[complex, torch.Tensor]):
         assert one_state in {None, "r", "1"}
         self.orthogonalize(0)
 
-        rnd_matrix = torch.rand(num_shots, self.num_sites).to(self.factors[0].device)
-
         bitstrings: Counter[str] = Counter()
 
         # Shots are performed in batches.
@@ -215,55 +249,45 @@ class MPS(State[complex, torch.Tensor]):
         max_batch_size = 32
 
         shots_done = 0
+
         while shots_done < num_shots:
             batch_size = min(max_batch_size, num_shots - shots_done)
             batched_accumulator = torch.ones(
-                batch_size, 1, dtype=torch.complex128, device=self.factors[0].device
+                batch_size, 1, dtype=dtype, device=self.factors[0].device
             )
 
-            batch_outcomes = torch.empty(batch_size, self.num_sites, dtype=torch.bool)
-
+            batch_outcomes = torch.empty(batch_size, self.num_sites, dtype=torch.int)
+            rangebatch = torch.arange(batch_size)
             for qubit, factor in enumerate(self.factors):
                 batched_accumulator = torch.tensordot(
                     batched_accumulator.to(factor.device), factor, dims=1
                 )
 
-                # Probability of measuring qubit == 0 for each shot in the batch
-                probas = (
-                    torch.linalg.vector_norm(batched_accumulator[:, 0, :], dim=1) ** 2
-                )
+                # Probabilities for each state in the basis
+                probn = torch.linalg.vector_norm(batched_accumulator, dim=2) ** 2
 
-                outcomes = (
-                    rnd_matrix[shots_done : shots_done + batch_size, qubit].to(
-                        factor.device
-                    )
-                    > probas
-                )
-                batch_outcomes[:, qubit] = outcomes
+                # list of: 0,1 for |g>,|r> or 0,1,2 for |g>,|r>,|x>
+                outcomes = torch.multinomial(probn, num_samples=1).reshape(-1)
 
-                # Batch collapse qubit
-                tmp = torch.stack((~outcomes, outcomes), dim=1).to(dtype=torch.complex128)
+                batch_outcomes[:, qubit] = outcomes
 
-                batched_accumulator = (
-                    torch.tensordot(batched_accumulator, tmp, dims=([1], [1]))
-                    .diagonal(dim1=0, dim2=2)
-                    .transpose(1, 0)
-                )
-                batched_accumulator /= torch.sqrt(
-                    (~outcomes) * probas + outcomes * (1 - probas)
-                ).unsqueeze(1)
+                # expected shape (batch_size, bond_dim)
+                batched_accumulator = batched_accumulator[rangebatch, outcomes, :]
 
             shots_done += batch_size
 
             for outcome in batch_outcomes:
-                bitstrings.update(["".join("0" if x == 0 else "1" for x in outcome)])
+                bitstrings.update(["".join("1" if x == 1 else "0" for x in outcome)])
 
-        if p_false_neg > 0 or p_false_pos > 0:
+        if p_false_neg > 0 or p_false_pos > 0 and self.dim == 2:
             bitstrings = apply_measurement_errors(
                 bitstrings,
                 p_false_pos=p_false_pos,
                 p_false_neg=p_false_neg,
             )
+        if p_false_pos > 0 and self.dim > 2:
+            raise NotImplementedError("Not implemented for qudits > 2 levels")
+
         return bitstrings
 
     def norm(self) -> torch.Tensor:
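The sampling loop replaces the hand-rolled binary collapse with `torch.multinomial`, which generalizes to any local dimension. A standalone sketch (illustrative shapes, not package code) of one batched step:

```python
import torch

# Per-shot probabilities over the local basis come from the accumulator
# norms; multinomial accepts unnormalized weights and draws one outcome
# per shot (0/1 for qubits, 0/1/2 with a leakage level).
batch_size, dim, bond = 5, 3, 7
accumulator = torch.randn(batch_size, dim, bond, dtype=torch.complex128)

probn = torch.linalg.vector_norm(accumulator, dim=2) ** 2       # (batch, dim)
outcomes = torch.multinomial(probn, num_samples=1).reshape(-1)  # (batch,)

# Collapse: keep only the accumulator slice matching each shot's outcome.
collapsed = accumulator[torch.arange(batch_size), outcomes, :]  # (batch, bond)
print(outcomes.tolist(), collapsed.shape)
```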
@@ -312,8 +336,11 @@ class MPS(State[complex, torch.Tensor]):
     def entanglement_entropy(self, mps_site: int) -> torch.Tensor:
         """
         Returns
-            the Von Neumann entanglement entropy of the state `mps` at the bond between sites b and b+1
+            the Von Neumann entanglement entropy of the state `mps` at the bond
+            between sites b and b+1
+
             S = -Σᵢsᵢ² log(sᵢ²)),
+
             where sᵢ are the singular values at the chosen bond.
         """
         self.orthogonalize(mps_site)
@@ -358,7 +385,8 @@ class MPS(State[complex, torch.Tensor]):
         new_tt = add_factors(self.factors, other.factors)
         result = MPS(
             new_tt,
-            config=self.config,
+            precision=self.precision,
+            max_bond_dim=self.max_bond_dim,
             num_gpus_to_use=None,
             orthogonality_center=None,  # Orthogonality is lost.
             eigenstates=self.eigenstates,
@@ -384,7 +412,8 @@ class MPS(State[complex, torch.Tensor]):
         factors = scale_factors(self.factors, scalar, which=which)
         return MPS(
             factors,
-            config=self.config,
+            precision=self.precision,
+            max_bond_dim=self.max_bond_dim,
             num_gpus_to_use=None,
             orthogonality_center=self.orthogonality_center,
             eigenstates=self.eigenstates,
@@ -412,25 +441,43 @@ class MPS(State[complex, torch.Tensor]):
         Returns:
             The resulting MPS representation of the state.s
         """
+
+        leak = ""
+        one = "r"
         basis = set(eigenstates)
+
         if basis == {"r", "g"}:
-            one = "r"
+            pass
         elif basis == {"0", "1"}:
             one = "1"
+        elif basis == {"g", "r", "x"}:
+            leak = "x"
         else:
             raise ValueError("Unsupported basis provided")
+        dim = len(eigenstates)
+        if dim == 2:
+            basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=dtype)  # ground state
+            basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=dtype)  # excited state
 
-        basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)  # ground state
-        basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128)  # excited state
+        elif dim == 3:
+            basis_0 = torch.tensor([[[1.0], [0.0], [0.0]]], dtype=dtype)  # ground state
+            basis_1 = torch.tensor([[[0.0], [1.0], [0.0]]], dtype=dtype)  # excited state
+            basis_x = torch.tensor([[[0.0], [0.0], [1.0]]], dtype=dtype)  # leakage state
 
         accum_mps = MPS(
-            [torch.zeros((1, 2, 1), dtype=torch.complex128)] * n_qudits,
+            [torch.zeros((1, dim, 1), dtype=dtype)] * n_qudits,
             orthogonality_center=0,
             eigenstates=eigenstates,
         )
-
         for state, amplitude in amplitudes.items():
-            factors = [basis_1 if ch == one else basis_0 for ch in state]
+            factors = []
+            for ch in state:
+                if ch == one:
+                    factors.append(basis_1)
+                elif ch == leak:
+                    factors.append(basis_x)
+                else:
+                    factors.append(basis_0)
             accum_mps += amplitude * MPS(factors, eigenstates=eigenstates)
         norm = accum_mps.norm()
         if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
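A standalone sketch (plain PyTorch, not package code) of the new per-character factor selection, including the leakage state:

```python
import torch

dtype = torch.complex128

# Bond-dimension-1 site factors for the (g, r, x) basis.
basis_0 = torch.tensor([[[1.0], [0.0], [0.0]]], dtype=dtype)  # ground
basis_1 = torch.tensor([[[0.0], [1.0], [0.0]]], dtype=dtype)  # excited
basis_x = torch.tensor([[[0.0], [0.0], [1.0]]], dtype=dtype)  # leakage

one, leak = "r", "x"
factors = []
for ch in "rgx":  # each character of a basis-state string picks a factor
    if ch == one:
        factors.append(basis_1)
    elif ch == leak:
        factors.append(basis_x)
    else:
        factors.append(basis_0)
print([tuple(f.shape) for f in factors])  # three (1, 3, 1) factors
```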
@@ -453,9 +500,7 @@ class MPS(State[complex, torch.Tensor]):
             else self.orthogonalize(0)
         )
 
-        result = torch.zeros(
-            self.num_sites, single_qubit_operators.shape[0], dtype=torch.complex128
-        )
+        result = torch.zeros(self.num_sites, single_qubit_operators.shape[0], dtype=dtype)
 
         center_factor = self.factors[orthogonality_center]
         for qubit_index in range(orthogonality_center, self.num_sites):
@@ -503,7 +548,7 @@ class MPS(State[complex, torch.Tensor]):
         )
 
     def get_correlation_matrix(
-        self, *, operator: torch.Tensor = n_operator
+        self, operator: torch.Tensor | None = None
     ) -> torch.Tensor:
         """
         Efficiently compute the symmetric correlation matrix
@@ -511,14 +556,18 @@ class MPS(State[complex, torch.Tensor]):
         in basis ("r", "g").
 
         Args:
-            operator: a 2x2 Torch tensor to use
+            operator: a 2x2 (or 3x3) Torch tensor to use
 
         Returns:
             the corresponding correlation matrix
         """
-        assert operator.shape == (2, 2)
 
-        result = torch.zeros(self.num_sites, self.num_sites, dtype=torch.complex128)
+        if operator is None:
+            operator = self.n_operator
+
+        assert operator.shape == (self.dim, self.dim), "Operator has wrong shape"
+
+        result = torch.zeros(self.num_sites, self.num_sites, dtype=dtype)
 
         for left in range(0, self.num_sites):
             self.orthogonalize(left)
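A hedged usage sketch of the relaxed signature: `operator` is now an optional positional argument, defaulting to the state's own dimension-aware number operator. The custom operator below is an assumed example, not from the package.

```python
import torch
from emu_mps.mps import MPS

state = MPS.make(3)
corr = state.get_correlation_matrix()  # uses state.n_operator by default

# A custom operator must now match the state's physical dimension:
flip = torch.zeros(state.dim, state.dim, dtype=torch.complex128)
flip[0, 1] = flip[1, 0] = 1.0
corr_flip = state.get_correlation_matrix(flip)
```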
emu_mps/mps_backend_impl.py CHANGED
@@ -209,7 +209,8 @@ class MPSBackendImpl:
         if initial_state is None:
             self.state = MPS.make(
                 self.qubit_count,
-                config=self.config,
+                precision=self.config.precision,
+                max_bond_dim=self.config.max_bond_dim,
                 num_gpus_to_use=self.config.num_gpus_to_use,
             )
             return
@@ -236,7 +237,8 @@ class MPSBackendImpl:
         initial_state = MPS(
             # Deep copy of every tensor of the initial state.
             [f.detach().clone() for f in initial_state.factors],
-            config=self.config,
+            precision=self.config.precision,
+            max_bond_dim=self.config.max_bond_dim,
             num_gpus_to_use=self.config.num_gpus_to_use,
             eigenstates=initial_state.eigenstates,
         )
emu_mps/mps_config.py CHANGED
@@ -4,6 +4,8 @@ from types import MethodType
 import copy
 
 from emu_base import DEVICE_COUNT
+from emu_mps.mps import MPS, DEFAULT_MAX_BOND_DIM, DEFAULT_PRECISION
+from emu_mps.mpo import MPO
 from emu_mps.solver import Solver
 from emu_mps.custom_callback_implementations import (
     energy_mps_impl,
@@ -33,15 +35,17 @@ class MPSConfig(EmulationConfig):
     See the API for that class for a list of available options.
 
     Args:
-        dt: the timestep size that the solver uses. Note that observables are
+        dt: The timestep size that the solver uses. Note that observables are
             only calculated if the evaluation_times are divisible by dt.
-        precision: up to what precision the state is truncated
-        max_bond_dim: the maximum bond dimension that the state is allowed to have.
+        precision: Up to what precision the state is truncated.
+            Defaults to `1e-5`.
+        max_bond_dim: The maximum bond dimension that the state is allowed to have.
+            Defaults to `1024`.
         max_krylov_dim:
-            the size of the krylov subspace that the Lanczos algorithm maximally builds
+            The size of the krylov subspace that the Lanczos algorithm maximally builds
         extra_krylov_tolerance:
-            the Lanczos algorithm uses this*precision as the convergence tolerance
-        num_gpus_to_use: during the simulation, distribute the state over this many GPUs
+            The Lanczos algorithm uses this*precision as the convergence tolerance
+        num_gpus_to_use: During the simulation, distribute the state over this many GPUs
             0=all factors to cpu. As shown in the benchmarks, using multiple GPUs might
             alleviate memory pressure per GPU, but the runtime should be similar.
         optimize_qubit_ordering: Optimize the register ordering. Improves performance and
@@ -51,15 +55,15 @@ class MPSConfig(EmulationConfig):
         log_level: How much to log. Set to `logging.WARN` to get rid of the timestep info.
         log_file: If specified, log to this file rather than stout.
         autosave_prefix: filename prefix for autosaving simulation state to file
-        autosave_dt: minimum time interval in seconds between two autosaves.
+        autosave_dt: Minimum time interval in seconds between two autosaves.
            Saving the simulation state is only possible at specific times,
            therefore this interval is only a lower bound.
-        solver: chooses the solver algorithm to run a sequence.
+        solver: Chooses the solver algorithm to run a sequence.
            Two options are currently available:
            ``TDVP``, which performs ordinary time evolution,
            and ``DMRG``, which adiabatically follows the ground state
            of a given adiabatic pulse.
-        kwargs: arguments that are passed to the base class
+        kwargs: Arguments that are passed to the base class
 
     Examples:
         >>> num_gpus_to_use = 2 #use 2 gpus if available, otherwise 1 or cpu
@@ -71,13 +75,15 @@ class MPSConfig(EmulationConfig):
 
     # Whether to warn if unexpected kwargs are received
     _enforce_expected_kwargs: ClassVar[bool] = True
+    _state_type = MPS
+    _operator_type = MPO
 
     def __init__(
         self,
         *,
         dt: int = 10,
-        precision: float = 1e-5,
-        max_bond_dim: int = 1024,
+        precision: float = DEFAULT_PRECISION,
+        max_bond_dim: int = DEFAULT_MAX_BOND_DIM,
         max_krylov_dim: int = 100,
         extra_krylov_tolerance: float = 1e-3,
         num_gpus_to_use: int = DEVICE_COUNT,
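A hedged configuration sketch: the literal defaults are now the module-level constants `DEFAULT_PRECISION` (1e-5) and `DEFAULT_MAX_BOND_DIM` (1024) from `emu_mps.mps`, so the two calls below should be equivalent.

```python
from emu_mps import MPSConfig

config_explicit = MPSConfig(dt=10, precision=1e-5, max_bond_dim=1024)
config_default = MPSConfig()  # picks up the same module-level defaults
```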
emu_mps/observables.py CHANGED
@@ -6,7 +6,18 @@ import torch
 
 
 class EntanglementEntropy(Observable):
-    """Entanglement Entropy subclass used only in emu_mps"""
+    """Entanglement Entropy of the state partition at qubit `mps_site`.
+
+    Args:
+        mps_site: the qubit index at which the bipartition is made.
+            All qubits with index $\\leq$ `mps_site` are put in the left partition.
+        evaluation_times: The relative times at which to store the state.
+            If left as `None`, uses the ``default_evaluation_times`` of the
+            backend's ``EmulationConfig``.
+        tag_suffix: An optional suffix to append to the tag. Needed if
+            multiple instances of the same observable are given to the
+            same EmulationConfig.
+    """
 
     def __init__(
         self,
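A hedged usage sketch based on the new docstring, with keyword names taken from the Args section above:

```python
from emu_mps import EntanglementEntropy

# Bipartition after qubit 2, evaluated at the end of the sequence; the
# tag_suffix distinguishes multiple instances in one config.
entropy = EntanglementEntropy(
    mps_site=2,
    evaluation_times=[1.0],
    tag_suffix="mid_chain",
)
```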
emu_mps/solver_utils.py CHANGED
@@ -51,9 +51,10 @@ def make_op(
         .view(left_ham_factor.shape[0], 4, 4, -1)
     )
 
-    op = lambda x: time_step * apply_effective_Hamiltonian(
-        x, combined_hamiltonian_factors, left_bath, right_bath
-    )
+    def op(x: torch.Tensor) -> torch.Tensor:
+        return time_step * apply_effective_Hamiltonian(
+            x, combined_hamiltonian_factors, left_bath, right_bath
+        )
 
     return combined_state_factors, right_device, op
 
@@ -205,17 +206,18 @@ def evolve_single(
 
     left_bath, right_bath = baths
 
-    op = (
-        lambda x: -_TIME_CONVERSION_COEFF
-        * 1j
-        * dt
-        * apply_effective_Hamiltonian(
-            x,
-            ham_factor,
-            left_bath,
-            right_bath,
+    def op(x: torch.Tensor) -> torch.Tensor:
+        return (
+            -_TIME_CONVERSION_COEFF
+            * 1j
+            * dt
+            * apply_effective_Hamiltonian(
+                x,
+                ham_factor,
+                left_bath,
+                right_bath,
+            )
         )
-    )
 
     return krylov_exp(
         op,
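Both call sites replace a multi-line lambda with a named, annotated closure; behavior is unchanged. A standalone sketch of the pattern (the `make_scaled_op` helper is hypothetical, for illustration only):

```python
from typing import Callable

import torch

def make_scaled_op(scale: complex) -> Callable[[torch.Tensor], torch.Tensor]:
    # A named def can carry type annotations, which lambdas assigned to
    # names cannot (and which linters flag, e.g. flake8 E731).
    def op(x: torch.Tensor) -> torch.Tensor:
        return scale * x

    return op

op = make_scaled_op(-1j * 0.1)
print(op(torch.ones(2, dtype=torch.complex128)))
```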
emu_mps/utils.py CHANGED
@@ -1,8 +1,6 @@
 from typing import List, Optional
 import torch
 
-from emu_mps import MPSConfig
-
 
 def new_left_bath(
     bath: torch.Tensor, state: torch.Tensor, op: torch.Tensor
@@ -59,8 +57,7 @@ def split_tensor(
 
 
 def truncate_impl(
-    factors: list[torch.Tensor],
-    config: MPSConfig,
+    factors: list[torch.Tensor], precision: float, max_bond_dim: int
 ) -> None:
     """
     Eigenvalues-based truncation of a matrix product.
@@ -76,8 +73,8 @@ def truncate_impl(
 
         l, r = split_tensor(
             factors[i].view(factor_shape[0], -1),
-            max_error=config.precision,
-            max_rank=config.max_bond_dim,
+            max_error=precision,
+            max_rank=max_bond_dim,
             orth_center_right=False,
         )
 
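A hedged call sketch of the new `truncate_impl` signature, mirroring how `MPS.truncate` invokes it in this diff; the random factors are illustrative only, and real callers orthogonalize to the last site first.

```python
import torch
from emu_mps.utils import truncate_impl

# A two-site matrix product with matching bond dimensions; truncation
# parameters are passed directly instead of via an MPSConfig.
factors = [
    torch.randn(1, 2, 4, dtype=torch.complex128),
    torch.randn(4, 2, 1, dtype=torch.complex128),
]
truncate_impl(factors, precision=1e-5, max_bond_dim=2)
print([tuple(f.shape) for f in factors])  # shared bond capped at 2
```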
@@ -208,15 +205,6 @@ def get_extended_site_index(
     raise ValueError(f"Index {desired_index} does not exist")
 
 
-n_operator: torch.Tensor = torch.tensor(
-    [
-        [0, 0],
-        [0, 1],
-    ],
-    dtype=torch.complex128,
-)
-
-
 def tensor_trace(tensor: torch.Tensor, dim1: int, dim2: int) -> torch.Tensor:
     """
     Contract two legs of a single tensor.
emu_mps-2.4.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: emu-mps
-Version: 2.4.0
+Version: 2.4.1
 Summary: Pasqal MPS based pulse emulator built on PyTorch
 Project-URL: Documentation, https://pasqal-io.github.io/emulators/
 Project-URL: Repository, https://github.com/pasqal-io/emulators
@@ -25,7 +25,7 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
 Requires-Python: >=3.10
-Requires-Dist: emu-base==2.4.0
+Requires-Dist: emu-base==2.4.1
 Description-Content-Type: text/markdown
 
 <div align="center">
emu_mps-2.4.1.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+emu_mps/__init__.py,sha256=R4Fw8r-_GEeckaxyQ9leUKVfLoukcrwvaTQsTbzxhbs,708
+emu_mps/algebra.py,sha256=vi3d4xOBEsfQ7gsjYNns9AMLjJXXXppYVEThavNvAg0,5408
+emu_mps/custom_callback_implementations.py,sha256=WeczmO6qkvBIipvXLqX45i3D7M4ovOrepusIGs6d2Ts,2420
+emu_mps/hamiltonian.py,sha256=gOPxNOBmk6jRPPjevERuCP_scGv0EKYeAJ0uxooihes,15622
+emu_mps/mpo.py,sha256=2HNwN4Fz04QIVfPcPaMmt2q89ZBxN3K-vVeiFkOtqzs,8049
+emu_mps/mps.py,sha256=KEXrLdqhi5EvpdX-9J38ZfrRdt3oDmcGsVLghMPUfQw,21600
+emu_mps/mps_backend.py,sha256=bS83qFxvdoK-c12_1WaPw6O7xUc7vdWifZNHUzNP5sM,2091
+emu_mps/mps_backend_impl.py,sha256=Pcbn27lhg3n3Lzo3CGwhlWPqPi-8YV1ntkgl901AoUs,30400
+emu_mps/mps_config.py,sha256=QmwgU8INEnxrxZkhboYIsHGwml0UhgPkddT5zh8KVBU,8867
+emu_mps/observables.py,sha256=4C_ewkd3YkJP0xghTrGUTgXUGvJRCQcetb8cU0SjMl0,1900
+emu_mps/solver.py,sha256=M9xkHhlEouTBvoPw2UYVu6kij7CO4Z1FXw_SiGFtdgo,85
+emu_mps/solver_utils.py,sha256=EnNzEaUrtTMQbrWoqOy8vyDsQwlsfQCUc2HgOp4z8dk,8680
+emu_mps/utils.py,sha256=pW5N_EbbGiOviQpJCw1a0pVgEDObP_InceNaIqY5bHE,6982
+emu_mps/optimatrix/__init__.py,sha256=fBXQ7-rgDro4hcaBijCGhx3J69W96qcw5_3mWc7tND4,364
+emu_mps/optimatrix/optimiser.py,sha256=k9suYmKLKlaZ7ozFuIqvXHyCBoCtGgkX1mpen9GOdOo,6977
+emu_mps/optimatrix/permutations.py,sha256=9DDMZtrGGZ01b9F3GkzHR3paX4qNtZiPoI7Z_Kia3Lc,3727
+emu_mps-2.4.1.dist-info/METADATA,sha256=2SUt6GLqhz_6-2ewuVqDH9xvz3gx1ZROU975intvuwk,3587
+emu_mps-2.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+emu_mps-2.4.1.dist-info/RECORD,,
emu_mps-2.4.0.dist-info/RECORD DELETED
@@ -1,19 +0,0 @@
-emu_mps/__init__.py,sha256=ySFn8SLcpsp1ndtvv38qg_gTlO9jt_jmwZVZd21hsg0,708
-emu_mps/algebra.py,sha256=78XP9HEbV3wGNUzIulcLU5HizW4XAYmcFdkCe1T1x-k,5489
-emu_mps/custom_callback_implementations.py,sha256=SZGKVyS8U5hy07L-3SqpWlCAqGGKFTlSlWexZwSmjrM,2408
-emu_mps/hamiltonian.py,sha256=gOPxNOBmk6jRPPjevERuCP_scGv0EKYeAJ0uxooihes,15622
-emu_mps/mpo.py,sha256=aWSVuEzZM-_7ZD5Rz3-tSJWX22ARP0tMIl3gUu-_4V4,7834
-emu_mps/mps.py,sha256=I7LAxoOPKr4me5ZFdL5AZMerg7go9I1fu-MAI0ZBMgU,19973
-emu_mps/mps_backend.py,sha256=bS83qFxvdoK-c12_1WaPw6O7xUc7vdWifZNHUzNP5sM,2091
-emu_mps/mps_backend_impl.py,sha256=JG4MlEsImc9izF-iwpgyw1t8q4M_20DxbbD4HDObSz0,30268
-emu_mps/mps_config.py,sha256=58Y9HEExeHm8p7Zm5CVwku7Y1pQmkaBV7M_hKv4PQcA,8629
-emu_mps/observables.py,sha256=7GQDH5kyaVNrwckk2f8ZJRV9Ca4jKhWWDsOCqYWsoEk,1349
-emu_mps/solver.py,sha256=M9xkHhlEouTBvoPw2UYVu6kij7CO4Z1FXw_SiGFtdgo,85
-emu_mps/solver_utils.py,sha256=VQ02_RxvPcjyXippuIY4Swpx4EdqtoJTt8Ie70GgdqU,8550
-emu_mps/utils.py,sha256=hgtaRUtBAzk76ab-S_wTVkvqfVOmaUks38zWame9GRQ,7132
-emu_mps/optimatrix/__init__.py,sha256=fBXQ7-rgDro4hcaBijCGhx3J69W96qcw5_3mWc7tND4,364
-emu_mps/optimatrix/optimiser.py,sha256=k9suYmKLKlaZ7ozFuIqvXHyCBoCtGgkX1mpen9GOdOo,6977
-emu_mps/optimatrix/permutations.py,sha256=9DDMZtrGGZ01b9F3GkzHR3paX4qNtZiPoI7Z_Kia3Lc,3727
-emu_mps-2.4.0.dist-info/METADATA,sha256=6eUzR_h7HcAVtC9Sy3umvCIw-w_bICX98VD-uXjRnxM,3587
-emu_mps-2.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-emu_mps-2.4.0.dist-info/RECORD,,