emu-mps 2.3.0__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/__init__.py CHANGED
@@ -9,12 +9,11 @@ from pulser.backend import (
9
9
  StateResult,
10
10
  EnergySecondMoment,
11
11
  )
12
- from .mps_config import MPSConfig
12
+ from .mps_config import MPSConfig, Solver
13
13
  from .mpo import MPO
14
14
  from .mps import MPS, inner
15
15
  from .mps_backend import MPSBackend
16
16
  from .observables import EntanglementEntropy
17
- from emu_base import aggregate
18
17
 
19
18
 
20
19
  __all__ = [
@@ -23,6 +22,7 @@ __all__ = [
23
22
  "MPS",
24
23
  "inner",
25
24
  "MPSConfig",
25
+ "Solver",
26
26
  "MPSBackend",
27
27
  "StateResult",
28
28
  "BitStrings",
@@ -33,8 +33,7 @@ __all__ = [
33
33
  "Energy",
34
34
  "EnergyVariance",
35
35
  "EnergySecondMoment",
36
- "aggregate",
37
36
  "EntanglementEntropy",
38
37
  ]
39
38
 
40
- __version__ = "2.3.0"
39
+ __version__ = "2.4.0"
emu_mps/mps.py CHANGED
@@ -153,7 +153,7 @@ class MPS(State[complex, torch.Tensor]):
153
153
 
154
154
  for i in range(rl_swipe_start, desired_orthogonality_center, -1):
155
155
  q, r = torch.linalg.qr(
156
- self.factors[i].view(self.factors[i].shape[0], -1).mT,
156
+ self.factors[i].contiguous().view(self.factors[i].shape[0], -1).mT,
157
157
  )
158
158
  self.factors[i] = q.mT.view(-1, 2, self.factors[i].shape[2])
159
159
  self.factors[i - 1] = torch.tensordot(
@@ -10,7 +10,6 @@ import uuid
10
10
  from copy import deepcopy
11
11
  from collections import Counter
12
12
  from enum import Enum, auto
13
- from resource import RUSAGE_SELF, getrusage
14
13
  from types import MethodType
15
14
  from typing import Any, Optional
16
15
 
@@ -18,7 +17,7 @@ import torch
18
17
  from pulser import Sequence
19
18
  from pulser.backend import EmulationConfig, Observable, Results, State
20
19
 
21
- from emu_base import DEVICE_COUNT, PulserData
20
+ from emu_base import DEVICE_COUNT, PulserData, get_max_rss
22
21
  from emu_base.math.brents_root_finding import BrentsRootFinder
23
22
  from emu_base.utils import deallocate_tensor
24
23
 
@@ -26,9 +25,9 @@ from emu_mps.hamiltonian import make_H, update_H
26
25
  from emu_mps.mpo import MPO
27
26
  from emu_mps.mps import MPS
28
27
  from emu_mps.mps_config import MPSConfig
29
- from emu_base.noise import pick_dark_qubits
30
28
  from emu_base.jump_lindblad_operators import compute_noise_from_lindbladians
31
29
  import emu_mps.optimatrix as optimat
30
+ from emu_mps.solver import Solver
32
31
  from emu_mps.solver_utils import (
33
32
  evolve_pair,
34
33
  evolve_single,
@@ -69,14 +68,7 @@ class Statistics(Observable):
69
68
  """Calculates the observable to store in the Results."""
70
69
  assert isinstance(state, MPS)
71
70
  duration = self.data[-1]
72
- if state.factors[0].is_cuda:
73
- max_mem_per_device = (
74
- torch.cuda.max_memory_allocated(device) * 1e-6
75
- for device in range(torch.cuda.device_count())
76
- )
77
- max_mem = max(max_mem_per_device)
78
- else:
79
- max_mem = getrusage(RUSAGE_SELF).ru_maxrss * 1e-3
71
+ max_mem = get_max_rss(state.factors[0].is_cuda)
80
72
 
81
73
  config.logger.info(
82
74
  f"step = {len(self.data)}/{self.timestep_count}, "
@@ -193,10 +185,9 @@ class MPSBackendImpl:
193
185
  def init_dark_qubits(self) -> None:
194
186
  # has_state_preparation_error
195
187
  if self.config.noise_model.state_prep_error > 0.0:
188
+ bad_atoms = self.pulser_data.hamiltonian.bad_atoms
196
189
  self.well_prepared_qubits_filter = torch.logical_not(
197
- pick_dark_qubits(
198
- self.config.noise_model.state_prep_error, self.qubit_count
199
- )
190
+ torch.tensor(list(bool(x) for x in bad_atoms.values()))
200
191
  )
201
192
  else:
202
193
  self.well_prepared_qubits_filter = None
@@ -252,6 +243,7 @@ class MPSBackendImpl:
252
243
  initial_state.truncate()
253
244
  initial_state *= 1 / initial_state.norm()
254
245
  self.state = initial_state
246
+ self.state.orthogonalize(0)
255
247
 
256
248
  def init_hamiltonian(self) -> None:
257
249
  """
@@ -699,18 +691,20 @@ class DMRGBackendImpl(MPSBackendImpl):
699
691
  mps_config: MPSConfig,
700
692
  pulser_data: PulserData,
701
693
  energy_tolerance: float = 1e-5,
702
- max_sweeps: int = 999,
703
- residual_tolerance: float = 1e-7,
694
+ max_sweeps: int = 2000,
704
695
  ):
696
+
697
+ if mps_config.noise_model.noise_types != ():
698
+ raise NotImplementedError(
699
+ "DMRG solver does not currently support noise types"
700
+ f"you are using: {mps_config.noise_model.noise_types}"
701
+ )
705
702
  super().__init__(mps_config, pulser_data)
706
- self.init()
707
- self.state.orthogonalize(0)
708
703
  self.previous_energy: Optional[float] = None
709
704
  self.current_energy: Optional[float] = None
710
705
  self.sweep_count: int = 0
711
706
  self.energy_tolerance: float = energy_tolerance
712
707
  self.max_sweeps: int = max_sweeps
713
- self.residual_tolerance: float = residual_tolerance
714
708
 
715
709
  def convergence_check(self, energy_tolerance: float) -> bool:
716
710
  if self.previous_energy is None or self.current_energy is None:
@@ -728,63 +722,72 @@ class DMRGBackendImpl(MPSBackendImpl):
728
722
  SwipeDirection.RIGHT_TO_LEFT,
729
723
  ), "Unknown Swipe direction"
730
724
 
731
- if self.swipe_direction == SwipeDirection.LEFT_TO_RIGHT:
732
- left_idx, right_idx = idx, idx + 1
733
- elif self.swipe_direction == SwipeDirection.RIGHT_TO_LEFT:
734
- left_idx, right_idx = idx - 1, idx
735
-
736
725
  orth_center_right = self.swipe_direction == SwipeDirection.LEFT_TO_RIGHT
737
726
  new_L, new_R, energy = minimize_energy_pair(
738
- state_factors=self.state.factors[left_idx : right_idx + 1],
739
- ham_factors=self.hamiltonian.factors[left_idx : right_idx + 1],
727
+ state_factors=self.state.factors[idx : idx + 2],
728
+ ham_factors=self.hamiltonian.factors[idx : idx + 2],
740
729
  baths=(self.left_baths[-1], self.right_baths[-1]),
741
730
  orth_center_right=orth_center_right,
742
731
  config=self.config,
743
- residual_tolerance=self.residual_tolerance,
732
+ residual_tolerance=self.config.precision,
744
733
  )
745
- self.state.factors[left_idx], self.state.factors[right_idx] = new_L, new_R
746
- self.state.orthogonality_center = right_idx if orth_center_right else left_idx
734
+ self.state.factors[idx], self.state.factors[idx + 1] = new_L, new_R
735
+ self.state.orthogonality_center = idx + 1 if orth_center_right else idx
747
736
  self.current_energy = energy
748
737
 
749
738
  # updating baths and orthogonality center
750
739
  if self.swipe_direction == SwipeDirection.LEFT_TO_RIGHT:
740
+ self._left_to_right_update(idx)
741
+ elif self.swipe_direction == SwipeDirection.RIGHT_TO_LEFT:
742
+ self._right_to_left_update(idx)
743
+ else:
744
+ raise Exception("Did not expect this")
745
+
746
+ self.save_simulation()
747
+
748
+ def _left_to_right_update(self, idx: int) -> None:
749
+ if idx < self.qubit_count - 2:
751
750
  self.left_baths.append(
752
751
  new_left_bath(
753
752
  self.get_current_left_bath(),
754
- self.state.factors[left_idx],
755
- self.hamiltonian.factors[right_idx],
756
- ).to(self.state.factors[right_idx].device)
753
+ self.state.factors[idx],
754
+ self.hamiltonian.factors[idx],
755
+ ).to(self.state.factors[idx + 1].device)
757
756
  )
758
757
  self.right_baths.pop()
759
758
  self.sweep_index += 1
760
759
 
761
- if self.sweep_index == self.qubit_count - 1:
762
- self.swipe_direction = SwipeDirection.RIGHT_TO_LEFT
760
+ if self.sweep_index == self.qubit_count - 2:
761
+ self.swipe_direction = SwipeDirection.RIGHT_TO_LEFT
763
762
 
764
- elif self.swipe_direction == SwipeDirection.RIGHT_TO_LEFT:
763
+ def _right_to_left_update(self, idx: int) -> None:
764
+ if idx > 0:
765
765
  self.right_baths.append(
766
766
  new_right_bath(
767
767
  self.get_current_right_bath(),
768
- self.state.factors[right_idx],
769
- self.hamiltonian.factors[right_idx],
770
- ).to(self.state.factors[left_idx].device)
768
+ self.state.factors[idx + 1],
769
+ self.hamiltonian.factors[idx + 1],
770
+ ).to(self.state.factors[idx].device)
771
771
  )
772
772
  self.left_baths.pop()
773
773
  self.sweep_index -= 1
774
774
 
775
- if self.sweep_index == 0:
776
- self.swipe_direction = SwipeDirection.LEFT_TO_RIGHT
777
- self.sweep_count += 1
778
- self.sweep_complete()
775
+ if self.sweep_index == 0:
776
+ self.state.orthogonalize(0)
777
+ self.swipe_direction = SwipeDirection.LEFT_TO_RIGHT
778
+ self.sweep_count += 1
779
+ self.sweep_complete()
779
780
 
780
781
  def sweep_complete(self) -> None:
781
782
  # This marks the end of one full sweep: checking convergence
782
783
  if self.convergence_check(self.energy_tolerance):
784
+ self.current_time = self.target_time
783
785
  self.timestep_complete()
784
786
  elif self.sweep_count + 1 > self.max_sweeps:
785
- # not converged: restart a new sweep
787
+ # not converged
786
788
  raise RuntimeError(f"DMRG did not converge after {self.max_sweeps} sweeps")
787
789
  else:
790
+ # not converged for the current sweep. restart
788
791
  self.previous_energy = self.current_energy
789
792
 
790
793
  assert self.sweep_index == 0
@@ -798,5 +801,6 @@ def create_impl(sequence: Sequence, config: MPSConfig) -> MPSBackendImpl:
798
801
 
799
802
  if pulser_data.has_lindblad_noise:
800
803
  return NoisyMPSBackendImpl(config, pulser_data)
801
-
804
+ if config.solver == Solver.DMRG:
805
+ return DMRGBackendImpl(config, pulser_data)
802
806
  return MPSBackendImpl(config, pulser_data)
emu_mps/mps_config.py CHANGED
@@ -4,6 +4,7 @@ from types import MethodType
4
4
  import copy
5
5
 
6
6
  from emu_base import DEVICE_COUNT
7
+ from emu_mps.solver import Solver
7
8
  from emu_mps.custom_callback_implementations import (
8
9
  energy_mps_impl,
9
10
  energy_second_moment_mps_impl,
@@ -53,6 +54,11 @@ class MPSConfig(EmulationConfig):
53
54
  autosave_dt: minimum time interval in seconds between two autosaves.
54
55
  Saving the simulation state is only possible at specific times,
55
56
  therefore this interval is only a lower bound.
57
+ solver: chooses the solver algorithm to run a sequence.
58
+ Two options are currently available:
59
+ ``TDVP``, which performs ordinary time evolution,
60
+ and ``DMRG``, which adiabatically follows the ground state
61
+ of a given adiabatic pulse.
56
62
  kwargs: arguments that are passed to the base class
57
63
 
58
64
  Examples:
@@ -81,6 +87,7 @@ class MPSConfig(EmulationConfig):
81
87
  log_file: pathlib.Path | None = None,
82
88
  autosave_prefix: str = "emu_mps_save_",
83
89
  autosave_dt: int = 600, # 10 minutes
90
+ solver: Solver = Solver.TDVP,
84
91
  **kwargs: Any,
85
92
  ):
86
93
  kwargs.setdefault("observables", [BitStrings(evaluation_times=[1.0])])
@@ -97,6 +104,7 @@ class MPSConfig(EmulationConfig):
97
104
  log_file=log_file,
98
105
  autosave_prefix=autosave_prefix,
99
106
  autosave_dt=autosave_dt,
107
+ solver=solver,
100
108
  **kwargs,
101
109
  )
102
110
  if self.optimize_qubit_ordering:
@@ -144,6 +152,7 @@ class MPSConfig(EmulationConfig):
144
152
  "log_file",
145
153
  "autosave_prefix",
146
154
  "autosave_dt",
155
+ "solver",
147
156
  }
148
157
 
149
158
  def monkeypatch_observables(self) -> None:
emu_mps/solver.py ADDED
@@ -0,0 +1,6 @@
1
+ from enum import Enum
2
+
3
+
4
+ class Solver(str, Enum):
5
+ TDVP = "tdvp"
6
+ DMRG = "dmrg"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: emu-mps
3
- Version: 2.3.0
3
+ Version: 2.4.0
4
4
  Summary: Pasqal MPS based pulse emulator built on PyTorch
5
5
  Project-URL: Documentation, https://pasqal-io.github.io/emulators/
6
6
  Project-URL: Repository, https://github.com/pasqal-io/emulators
@@ -25,7 +25,7 @@ Classifier: Programming Language :: Python :: 3.10
25
25
  Classifier: Programming Language :: Python :: Implementation :: CPython
26
26
  Classifier: Programming Language :: Python :: Implementation :: PyPy
27
27
  Requires-Python: >=3.10
28
- Requires-Dist: emu-base==2.3.0
28
+ Requires-Dist: emu-base==2.4.0
29
29
  Description-Content-Type: text/markdown
30
30
 
31
31
  <div align="center">
@@ -1,18 +1,19 @@
1
- emu_mps/__init__.py,sha256=vvVRxFPudsVALA-_YyqWf8rlIRhtI1WzBWakYk2vfB8,734
1
+ emu_mps/__init__.py,sha256=ySFn8SLcpsp1ndtvv38qg_gTlO9jt_jmwZVZd21hsg0,708
2
2
  emu_mps/algebra.py,sha256=78XP9HEbV3wGNUzIulcLU5HizW4XAYmcFdkCe1T1x-k,5489
3
3
  emu_mps/custom_callback_implementations.py,sha256=SZGKVyS8U5hy07L-3SqpWlCAqGGKFTlSlWexZwSmjrM,2408
4
4
  emu_mps/hamiltonian.py,sha256=gOPxNOBmk6jRPPjevERuCP_scGv0EKYeAJ0uxooihes,15622
5
5
  emu_mps/mpo.py,sha256=aWSVuEzZM-_7ZD5Rz3-tSJWX22ARP0tMIl3gUu-_4V4,7834
6
- emu_mps/mps.py,sha256=n5KPuWr2qCNW4Q1OKpUGrSuhu4IbVp1v1BPLvLILbGg,19960
6
+ emu_mps/mps.py,sha256=I7LAxoOPKr4me5ZFdL5AZMerg7go9I1fu-MAI0ZBMgU,19973
7
7
  emu_mps/mps_backend.py,sha256=bS83qFxvdoK-c12_1WaPw6O7xUc7vdWifZNHUzNP5sM,2091
8
- emu_mps/mps_backend_impl.py,sha256=O3vbjw3jlgyj0hOYv3nMmAP0lINqxbt96mQPgwPZN5w,30198
9
- emu_mps/mps_config.py,sha256=PoSKZxJMhG6zfzgEjj4tIvyiyYRQywxkRgidh8MRBsA,8222
8
+ emu_mps/mps_backend_impl.py,sha256=JG4MlEsImc9izF-iwpgyw1t8q4M_20DxbbD4HDObSz0,30268
9
+ emu_mps/mps_config.py,sha256=58Y9HEExeHm8p7Zm5CVwku7Y1pQmkaBV7M_hKv4PQcA,8629
10
10
  emu_mps/observables.py,sha256=7GQDH5kyaVNrwckk2f8ZJRV9Ca4jKhWWDsOCqYWsoEk,1349
11
+ emu_mps/solver.py,sha256=M9xkHhlEouTBvoPw2UYVu6kij7CO4Z1FXw_SiGFtdgo,85
11
12
  emu_mps/solver_utils.py,sha256=VQ02_RxvPcjyXippuIY4Swpx4EdqtoJTt8Ie70GgdqU,8550
12
13
  emu_mps/utils.py,sha256=hgtaRUtBAzk76ab-S_wTVkvqfVOmaUks38zWame9GRQ,7132
13
14
  emu_mps/optimatrix/__init__.py,sha256=fBXQ7-rgDro4hcaBijCGhx3J69W96qcw5_3mWc7tND4,364
14
15
  emu_mps/optimatrix/optimiser.py,sha256=k9suYmKLKlaZ7ozFuIqvXHyCBoCtGgkX1mpen9GOdOo,6977
15
16
  emu_mps/optimatrix/permutations.py,sha256=9DDMZtrGGZ01b9F3GkzHR3paX4qNtZiPoI7Z_Kia3Lc,3727
16
- emu_mps-2.3.0.dist-info/METADATA,sha256=WLLpx31eshWCwR95aB_B4yoM8BRAKT81UR2mjVR3998,3587
17
- emu_mps-2.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
18
- emu_mps-2.3.0.dist-info/RECORD,,
17
+ emu_mps-2.4.0.dist-info/METADATA,sha256=6eUzR_h7HcAVtC9Sy3umvCIw-w_bICX98VD-uXjRnxM,3587
18
+ emu_mps-2.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
+ emu_mps-2.4.0.dist-info/RECORD,,