emu-mps 2.1.1__py3-none-any.whl → 2.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/__init__.py CHANGED
@@ -37,4 +37,4 @@ __all__ = [
37
37
  "EntanglementEntropy",
38
38
  ]
39
39
 
40
- __version__ = "2.1.1"
40
+ __version__ = "2.2.1"
emu_mps/algebra.py CHANGED
@@ -106,14 +106,14 @@ def zip_right_step(
106
106
  # reshape slider as matrix
107
107
  left_inds = (slider.shape[0], *bottom.shape[1:-1])
108
108
  right_inds = (top.shape[-1], bottom.shape[-1])
109
- slider = slider.reshape(math.prod(left_inds), math.prod(right_inds))
109
+ slider = slider.contiguous().view(math.prod(left_inds), math.prod(right_inds))
110
110
 
111
111
  L, slider = torch.linalg.qr(slider)
112
112
 
113
113
  # reshape slider to its original shape
114
- slider = slider.reshape((-1, *right_inds))
114
+ slider = slider.view((-1, *right_inds))
115
115
  # reshape left as MPS/O factor and
116
- return L.reshape(*left_inds, -1), slider
116
+ return L.view(*left_inds, -1), slider
117
117
 
118
118
 
119
119
  def zip_right(
@@ -28,7 +28,7 @@ def qubit_occupation_mps_impl(
28
28
  op = torch.tensor(
29
29
  [[[0.0, 0.0], [0.0, 1.0]]], dtype=torch.complex128, device=state.factors[0].device
30
30
  )
31
- return state.expect_batch(op).real.reshape(-1).cpu()
31
+ return state.expect_batch(op).real.view(-1).cpu()
32
32
 
33
33
 
34
34
  def correlation_matrix_mps_impl(
emu_mps/mpo.py CHANGED
@@ -1,5 +1,4 @@
1
1
  from __future__ import annotations
2
- import itertools
3
2
  from typing import Any, List, Sequence, Optional
4
3
 
5
4
  import torch
@@ -12,24 +11,6 @@ from emu_mps.mps import MPS
12
11
  from emu_mps.utils import new_left_bath, assign_devices
13
12
 
14
13
 
15
- def _validate_operator_targets(operations: FullOp, nqubits: int) -> None:
16
- for tensorop in operations:
17
- target_qids = (factor[1] for factor in tensorop[1])
18
- target_qids_list = list(itertools.chain(*target_qids))
19
- target_qids_set = set(target_qids_list)
20
- if len(target_qids_set) < len(target_qids_list):
21
- # Either the qubit id has been defined twice in an operation:
22
- for qids in target_qids:
23
- if len(set(qids)) < len(qids):
24
- raise ValueError("Duplicate atom ids in argument list.")
25
- # Or it was defined in two different operations
26
- raise ValueError("Each qubit can be targeted by only one operation.")
27
- if max(target_qids_set) >= nqubits:
28
- raise ValueError(
29
- "The operation targets more qubits than there are in the register."
30
- )
31
-
32
-
33
14
  class MPO(Operator[complex, torch.Tensor, MPS]):
34
15
  """
35
16
  Matrix Product Operator.
@@ -152,7 +133,7 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
152
133
  state.factors[i + 1].device
153
134
  )
154
135
  acc = new_left_bath(acc, state.factors[n], self.factors[n])
155
- return acc.reshape(1)[0].cpu()
136
+ return acc.view(1)[0].cpu()
156
137
 
157
138
  @classmethod
158
139
  def _from_operator_repr(
@@ -176,8 +157,6 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
176
157
  the operator in MPO form.
177
158
  """
178
159
 
179
- _validate_operator_targets(operations, n_qudits)
180
-
181
160
  basis = set(eigenstates)
182
161
 
183
162
  operators_with_tensors: dict[str, torch.Tensor | QuditOp]
@@ -185,35 +164,35 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
185
164
  # operators_with_tensors will now contain the basis for single qubit ops,
186
165
  # and potentially user defined strings in terms of these
187
166
  operators_with_tensors = {
188
- "gg": torch.tensor(
189
- [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
190
- ).reshape(1, 2, 2, 1),
191
- "gr": torch.tensor(
192
- [[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128
193
- ).reshape(1, 2, 2, 1),
194
- "rg": torch.tensor(
195
- [[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128
196
- ).reshape(1, 2, 2, 1),
197
- "rr": torch.tensor(
198
- [[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128
199
- ).reshape(1, 2, 2, 1),
167
+ "gg": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128).view(
168
+ 1, 2, 2, 1
169
+ ),
170
+ "gr": torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128).view(
171
+ 1, 2, 2, 1
172
+ ),
173
+ "rg": torch.tensor([[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128).view(
174
+ 1, 2, 2, 1
175
+ ),
176
+ "rr": torch.tensor([[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128).view(
177
+ 1, 2, 2, 1
178
+ ),
200
179
  }
201
180
  elif basis == {"0", "1"}:
202
181
  # operators_with_tensors will now contain the basis for single qubit ops,
203
182
  # and potentially user defined strings in terms of these
204
183
  operators_with_tensors = {
205
- "00": torch.tensor(
206
- [[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128
207
- ).reshape(1, 2, 2, 1),
208
- "01": torch.tensor(
209
- [[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128
210
- ).reshape(1, 2, 2, 1),
211
- "10": torch.tensor(
212
- [[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128
213
- ).reshape(1, 2, 2, 1),
214
- "11": torch.tensor(
215
- [[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128
216
- ).reshape(1, 2, 2, 1),
184
+ "00": torch.tensor([[1.0, 0.0], [0.0, 0.0]], dtype=torch.complex128).view(
185
+ 1, 2, 2, 1
186
+ ),
187
+ "01": torch.tensor([[0.0, 0.0], [1.0, 0.0]], dtype=torch.complex128).view(
188
+ 1, 2, 2, 1
189
+ ),
190
+ "10": torch.tensor([[0.0, 1.0], [0.0, 0.0]], dtype=torch.complex128).view(
191
+ 1, 2, 2, 1
192
+ ),
193
+ "11": torch.tensor([[0.0, 0.0], [0.0, 1.0]], dtype=torch.complex128).view(
194
+ 1, 2, 2, 1
195
+ ),
217
196
  }
218
197
  else:
219
198
  raise ValueError("Unsupported basis provided")
@@ -235,7 +214,7 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
235
214
  return result
236
215
 
237
216
  factors = [
238
- torch.eye(2, 2, dtype=torch.complex128).reshape(1, 2, 2, 1)
217
+ torch.eye(2, 2, dtype=torch.complex128).view(1, 2, 2, 1)
239
218
  ] * n_qudits
240
219
 
241
220
  for op in tensorop:
emu_mps/mps.py CHANGED
@@ -140,8 +140,8 @@ class MPS(State[complex, torch.Tensor]):
140
140
  )
141
141
 
142
142
  for i in range(lr_swipe_start, desired_orthogonality_center):
143
- q, r = torch.linalg.qr(self.factors[i].reshape(-1, self.factors[i].shape[2]))
144
- self.factors[i] = q.reshape(self.factors[i].shape[0], 2, -1)
143
+ q, r = torch.linalg.qr(self.factors[i].view(-1, self.factors[i].shape[2]))
144
+ self.factors[i] = q.view(self.factors[i].shape[0], 2, -1)
145
145
  self.factors[i + 1] = torch.tensordot(
146
146
  r.to(self.factors[i + 1].device), self.factors[i + 1], dims=1
147
147
  )
@@ -154,9 +154,9 @@ class MPS(State[complex, torch.Tensor]):
154
154
 
155
155
  for i in range(rl_swipe_start, desired_orthogonality_center, -1):
156
156
  q, r = torch.linalg.qr(
157
- self.factors[i].reshape(self.factors[i].shape[0], -1).mT,
157
+ self.factors[i].view(self.factors[i].shape[0], -1).mT,
158
158
  )
159
- self.factors[i] = q.mT.reshape(-1, 2, self.factors[i].shape[2])
159
+ self.factors[i] = q.mT.view(-1, 2, self.factors[i].shape[2])
160
160
  self.factors[i - 1] = torch.tensordot(
161
161
  self.factors[i - 1], r.to(self.factors[i - 1].device), ([2], [1])
162
162
  )
@@ -301,7 +301,7 @@ class MPS(State[complex, torch.Tensor]):
301
301
  acc = torch.tensordot(acc, other.factors[i].to(acc.device), dims=1)
302
302
  acc = torch.tensordot(self.factors[i].conj(), acc, dims=([0, 1], [0, 1]))
303
303
 
304
- return acc.reshape(1)[0].cpu()
304
+ return acc.view(1)[0].cpu()
305
305
 
306
306
  def overlap(self, other: State, /) -> torch.Tensor:
307
307
  """
@@ -399,6 +399,7 @@ class MPS(State[complex, torch.Tensor]):
399
399
  cls,
400
400
  *,
401
401
  eigenstates: Sequence[Eigenstate],
402
+ n_qudits: int,
402
403
  amplitudes: Mapping[str, complex],
403
404
  ) -> tuple[MPS, Mapping[str, complex]]:
404
405
  """
@@ -420,13 +421,11 @@ class MPS(State[complex, torch.Tensor]):
420
421
  else:
421
422
  raise ValueError("Unsupported basis provided")
422
423
 
423
- nqubits = cls._validate_amplitudes(amplitudes, eigenstates)
424
-
425
424
  basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128) # ground state
426
425
  basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128) # excited state
427
426
 
428
427
  accum_mps = MPS(
429
- [torch.zeros((1, 2, 1), dtype=torch.complex128)] * nqubits,
428
+ [torch.zeros((1, 2, 1), dtype=torch.complex128)] * n_qudits,
430
429
  orthogonality_center=0,
431
430
  eigenstates=eigenstates,
432
431
  )
@@ -468,7 +467,7 @@ class MPS(State[complex, torch.Tensor]):
468
467
  )
469
468
 
470
469
  if qubit_index < self.num_sites - 1:
471
- _, r = torch.linalg.qr(center_factor.reshape(-1, center_factor.shape[2]))
470
+ _, r = torch.linalg.qr(center_factor.view(-1, center_factor.shape[2]))
472
471
  center_factor = torch.tensordot(
473
472
  r, self.factors[qubit_index + 1].to(r.device), dims=1
474
473
  )
@@ -476,7 +475,7 @@ class MPS(State[complex, torch.Tensor]):
476
475
  center_factor = self.factors[orthogonality_center]
477
476
  for qubit_index in range(orthogonality_center - 1, -1, -1):
478
477
  _, r = torch.linalg.qr(
479
- center_factor.reshape(center_factor.shape[0], -1).mT,
478
+ center_factor.view(center_factor.shape[0], -1).mT,
480
479
  )
481
480
  center_factor = torch.tensordot(
482
481
  self.factors[qubit_index],
@@ -499,11 +498,10 @@ class MPS(State[complex, torch.Tensor]):
499
498
  """
500
499
  self.orthogonalize(qubit_index)
501
500
 
502
- self.factors[qubit_index] = torch.tensordot(
503
- self.factors[qubit_index],
504
- single_qubit_operator.to(self.factors[qubit_index].device),
505
- ([1], [1]),
506
- ).transpose(1, 2)
501
+ self.factors[qubit_index] = (
502
+ single_qubit_operator.to(self.factors[qubit_index].device)
503
+ @ self.factors[qubit_index]
504
+ )
507
505
 
508
506
  def get_correlation_matrix(
509
507
  self, *, operator: torch.Tensor = n_operator
@@ -7,6 +7,7 @@ import time
7
7
  import typing
8
8
  import uuid
9
9
 
10
+ from copy import deepcopy
10
11
  from collections import Counter
11
12
  from enum import Enum, auto
12
13
  from resource import RUSAGE_SELF, getrusage
@@ -19,6 +20,7 @@ from pulser.backend import EmulationConfig, Observable, Results, State
19
20
 
20
21
  from emu_base import DEVICE_COUNT, PulserData
21
22
  from emu_base.math.brents_root_finding import BrentsRootFinder
23
+ from emu_base.utils import deallocate_tensor
22
24
 
23
25
  from emu_mps.hamiltonian import make_H, update_H
24
26
  from emu_mps.mpo import MPO
@@ -27,7 +29,7 @@ from emu_mps.mps_config import MPSConfig
27
29
  from emu_mps.noise import pick_well_prepared_qubits
28
30
  from emu_base.jump_lindblad_operators import compute_noise_from_lindbladians
29
31
  import emu_mps.optimatrix as optimat
30
- from emu_mps.tdvp import (
32
+ from emu_mps.solver_utils import (
31
33
  evolve_pair,
32
34
  evolve_single,
33
35
  new_right_bath,
@@ -168,9 +170,12 @@ class MPSBackendImpl:
168
170
  )
169
171
 
170
172
  def __getstate__(self) -> dict:
171
- for obs in self.config.observables:
172
- obs.apply = MethodType(type(obs).apply, obs) # type: ignore[method-assign]
173
173
  d = self.__dict__.copy()
174
+ cp = deepcopy(self.config)
175
+ d["config"] = cp
176
+ d["state"].config = cp
177
+ for obs in cp.observables:
178
+ obs.apply = MethodType(type(obs).apply, obs) # type: ignore[method-assign]
174
179
  # mypy thinks the method below is an attribute, because of the __getattr__ override
175
180
  d["results"] = self.results._to_abstract_repr() # type: ignore[operator]
176
181
  return d
@@ -411,7 +416,7 @@ class MPSBackendImpl:
411
416
  )
412
417
  if not self.has_lindblad_noise:
413
418
  # Free memory because it won't be used anymore
414
- self.right_baths[-2] = torch.zeros(0)
419
+ deallocate_tensor(self.right_baths[-2])
415
420
 
416
421
  self._evolve(self.tdvp_index, dt=-delta_time / 2)
417
422
  self.left_baths.pop()
@@ -660,7 +665,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
660
665
  for qubit in range(self.state.num_sites)
661
666
  for op in self.lindblad_ops
662
667
  ],
663
- weights=jump_operator_weights.reshape(-1).tolist(),
668
+ weights=jump_operator_weights.view(-1).tolist(),
664
669
  )[0]
665
670
 
666
671
  self.state.apply(jumped_qubit_index, jump_operator)
emu_mps/mps_config.py CHANGED
@@ -102,14 +102,6 @@ class MPSConfig(EmulationConfig):
102
102
  if self.optimize_qubit_ordering:
103
103
  self.check_permutable_observables()
104
104
 
105
- if "doppler" in self.noise_model.noise_types:
106
- raise NotImplementedError("Unsupported noise type: doppler")
107
- if (
108
- "amplitude" in self.noise_model.noise_types
109
- and self.noise_model.amp_sigma != 0.0
110
- ):
111
- raise NotImplementedError("Unsupported noise type: amp_sigma")
112
-
113
105
  MIN_AUTOSAVE_DT = 10
114
106
  assert (
115
107
  self.autosave_dt > MIN_AUTOSAVE_DT
@@ -1,11 +1,63 @@
1
1
  import torch
2
+ from typing import Callable, Sequence
2
3
 
3
4
  from emu_base import krylov_exp
5
+ from emu_base.math.krylov_energy_min import krylov_energy_minimization
6
+ from emu_base.utils import deallocate_tensor
4
7
  from emu_mps import MPS, MPO
5
8
  from emu_mps.utils import split_tensor
6
9
  from emu_mps.mps_config import MPSConfig
7
10
 
8
11
 
12
+ def make_op(
13
+ time_step: float | complex,
14
+ state_factors: Sequence[torch.Tensor],
15
+ baths: tuple[torch.Tensor, torch.Tensor],
16
+ ham_factors: Sequence[torch.Tensor],
17
+ ) -> tuple[torch.Tensor, torch.device, Callable[[torch.Tensor], torch.Tensor]]:
18
+ assert len(state_factors) == 2
19
+ assert len(baths) == 2
20
+ assert len(ham_factors) == 2
21
+
22
+ left_state_factor, right_state_factor = state_factors
23
+ left_bath, right_bath = baths
24
+ left_ham_factor, right_ham_factor = ham_factors
25
+
26
+ left_device = left_state_factor.device
27
+ right_device = right_state_factor.device
28
+
29
+ left_bond_dim = left_state_factor.shape[0]
30
+ right_bond_dim = right_state_factor.shape[-1]
31
+
32
+ # Computation is done on left_device (arbitrary)
33
+
34
+ right_state_factor = right_state_factor.to(left_device)
35
+
36
+ combined_state_factors = torch.tensordot(
37
+ left_state_factor, right_state_factor, dims=1
38
+ ).reshape(left_bond_dim, 4, right_bond_dim)
39
+
40
+ deallocate_tensor(left_state_factor)
41
+ deallocate_tensor(right_state_factor)
42
+
43
+ left_ham_factor = left_ham_factor.to(left_device)
44
+ right_ham_factor = right_ham_factor.to(left_device)
45
+ right_bath = right_bath.to(left_device)
46
+
47
+ combined_hamiltonian_factors = (
48
+ torch.tensordot(left_ham_factor, right_ham_factor, dims=1)
49
+ .transpose(2, 3)
50
+ .contiguous()
51
+ .view(left_ham_factor.shape[0], 4, 4, -1)
52
+ )
53
+
54
+ op = lambda x: time_step * apply_effective_Hamiltonian(
55
+ x, combined_hamiltonian_factors, left_bath, right_bath
56
+ )
57
+
58
+ return combined_state_factors, right_device, op
59
+
60
+
9
61
  def new_right_bath(
10
62
  bath: torch.Tensor, state: torch.Tensor, op: torch.Tensor
11
63
  ) -> torch.Tensor:
@@ -75,13 +127,13 @@ def apply_effective_Hamiltonian(
75
127
  state = torch.tensordot(left_bath, state, 1)
76
128
  state = state.permute(0, 3, 1, 2)
77
129
  ham = ham.permute(0, 2, 1, 3)
78
- state = state.reshape(state.shape[0], state.shape[1], -1).contiguous()
79
- ham = ham.reshape(-1, ham.shape[2], ham.shape[3]).contiguous()
130
+ state = state.view(state.shape[0], state.shape[1], -1).contiguous()
131
+ ham = ham.contiguous().view(-1, ham.shape[2], ham.shape[3])
80
132
  state = torch.tensordot(state, ham, 1)
81
133
  state = state.permute(0, 2, 1, 3)
82
- state = state.reshape(state.shape[0], state.shape[1], -1).contiguous()
134
+ state = state.contiguous().view(state.shape[0], state.shape[1], -1)
83
135
  right_bath = right_bath.permute(2, 1, 0)
84
- right_bath = right_bath.reshape(-1, right_bath.shape[2]).contiguous()
136
+ right_bath = right_bath.contiguous().view(-1, right_bath.shape[2])
85
137
  state = torch.tensordot(state, right_bath, 1)
86
138
  return state
87
139
 
@@ -102,45 +154,19 @@ def evolve_pair(
102
154
  """
103
155
  Time evolution of a pair of tensors of a tensor train using baths and truncated SVD.
104
156
  Returned state tensors are kept on their respective devices.
105
- """
106
- assert len(state_factors) == 2
107
- assert len(baths) == 2
108
- assert len(ham_factors) == 2
109
-
110
- left_state_factor, right_state_factor = state_factors
111
- left_bath, right_bath = baths
112
- left_ham_factor, right_ham_factor = ham_factors
113
-
114
- left_device = left_state_factor.device
115
- right_device = right_state_factor.device
116
-
117
- left_bond_dim = left_state_factor.shape[0]
118
- right_bond_dim = right_state_factor.shape[-1]
119
-
120
- # Computation is done on left_device (arbitrary)
121
-
122
- combined_state_factors = torch.tensordot(
123
- left_state_factor, right_state_factor.to(left_device), dims=1
124
- ).reshape(left_bond_dim, 4, right_bond_dim)
125
157
 
126
- left_ham_factor = left_ham_factor.to(left_device)
127
- right_ham_factor = right_ham_factor.to(left_device)
128
- right_bath = right_bath.to(left_device)
129
-
130
- combined_hamiltonian_factors = (
131
- torch.tensordot(left_ham_factor, right_ham_factor, dims=1)
132
- .transpose(2, 3)
133
- .reshape(left_ham_factor.shape[0], 4, 4, -1)
134
- )
158
+ The input state tensor objects become invalid after calling that function.
159
+ """
135
160
 
136
- op = (
137
- lambda x: -_TIME_CONVERSION_COEFF
138
- * 1j
139
- * dt
140
- * apply_effective_Hamiltonian(
141
- x, combined_hamiltonian_factors, left_bath, right_bath
142
- )
161
+ time_step = -1j * _TIME_CONVERSION_COEFF * dt
162
+ combined_state_factors, right_device, op = make_op(
163
+ time_step=time_step,
164
+ state_factors=state_factors,
165
+ baths=baths,
166
+ ham_factors=ham_factors,
143
167
  )
168
+ left_bond_dim = combined_state_factors.shape[0]
169
+ right_bond_dim = combined_state_factors.shape[-1]
144
170
 
145
171
  evol = krylov_exp(
146
172
  op,
@@ -149,7 +175,7 @@ def evolve_pair(
149
175
  norm_tolerance=config.precision * config.extra_krylov_tolerance,
150
176
  max_krylov_dim=config.max_krylov_dim,
151
177
  is_hermitian=is_hermitian,
152
- ).reshape(left_bond_dim * 2, 2 * right_bond_dim)
178
+ ).view(left_bond_dim * 2, 2 * right_bond_dim)
153
179
 
154
180
  l, r = split_tensor(
155
181
  evol,
@@ -158,9 +184,7 @@ def evolve_pair(
158
184
  orth_center_right=orth_center_right,
159
185
  )
160
186
 
161
- return l.reshape(left_bond_dim, 2, -1), r.reshape(-1, 2, right_bond_dim).to(
162
- right_device
163
- )
187
+ return l.view(left_bond_dim, 2, -1), r.view(-1, 2, right_bond_dim).to(right_device)
164
188
 
165
189
 
166
190
  def evolve_single(
@@ -174,6 +198,8 @@ def evolve_single(
174
198
  ) -> torch.Tensor:
175
199
  """
176
200
  Time evolution of a single tensor of a tensor train using baths.
201
+
202
+ The input state tensor object becomes invalid after calling that function.
177
203
  """
178
204
  assert len(baths) == 2
179
205
 
@@ -199,3 +225,50 @@ def evolve_single(
199
225
  max_krylov_dim=config.max_krylov_dim,
200
226
  is_hermitian=is_hermitian,
201
227
  )
228
+
229
+
230
+ def minimize_energy_pair(
231
+ *,
232
+ state_factors: tuple[torch.Tensor],
233
+ baths: tuple[torch.Tensor, torch.Tensor],
234
+ ham_factors: tuple[torch.Tensor],
235
+ orth_center_right: bool,
236
+ config: MPSConfig,
237
+ residual_tolerance: float,
238
+ ) -> tuple[torch.Tensor, torch.Tensor, float]:
239
+ """
240
+ Minimizes the state factors (ψ_i, ψ_{i+1}) using the Lanczos/Arnoldi method
241
+ """
242
+
243
+ time_step = 1.0
244
+ combined_state_factors, right_device, op = make_op(
245
+ time_step=time_step,
246
+ state_factors=state_factors,
247
+ baths=baths,
248
+ ham_factors=ham_factors,
249
+ )
250
+
251
+ left_bond_dim = combined_state_factors.shape[0]
252
+ right_bond_dim = combined_state_factors.shape[-1]
253
+
254
+ updated_state, updated_energy = krylov_energy_minimization(
255
+ op,
256
+ combined_state_factors,
257
+ norm_tolerance=config.precision * config.extra_krylov_tolerance,
258
+ residual_tolerance=residual_tolerance,
259
+ max_krylov_dim=config.max_krylov_dim,
260
+ )
261
+ updated_state = updated_state.view(left_bond_dim * 2, 2 * right_bond_dim)
262
+
263
+ l, r = split_tensor(
264
+ updated_state,
265
+ max_error=config.precision,
266
+ max_rank=config.max_bond_dim,
267
+ orth_center_right=orth_center_right,
268
+ )
269
+
270
+ return (
271
+ l.view(left_bond_dim, 2, -1),
272
+ r.view(-1, 2, right_bond_dim).to(right_device),
273
+ updated_energy,
274
+ )
emu_mps/utils.py CHANGED
@@ -53,7 +53,7 @@ def split_tensor(
53
53
  _determine_cutoff_index(d, max_error),
54
54
  d.shape[0] - max_rank,
55
55
  )
56
- right = q.T.conj()[max_bond:, :]
56
+ right = q[:, max_bond:].T.conj_physical()
57
57
  left = m @ q
58
58
  left = left[:, max_bond:]
59
59
 
@@ -77,13 +77,13 @@ def truncate_impl(
77
77
  factor_shape = factors[i].shape
78
78
 
79
79
  l, r = split_tensor(
80
- factors[i].reshape(factor_shape[0], -1),
80
+ factors[i].view(factor_shape[0], -1),
81
81
  max_error=config.precision,
82
82
  max_rank=config.max_bond_dim,
83
83
  orth_center_right=False,
84
84
  )
85
85
 
86
- factors[i] = r.reshape(-1, *factor_shape[1:])
86
+ factors[i] = r.view(-1, *factor_shape[1:])
87
87
  factors[i - 1] = torch.tensordot(
88
88
  factors[i - 1], l.to(factors[i - 1].device), dims=1
89
89
  )
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: emu-mps
3
- Version: 2.1.1
3
+ Version: 2.2.1
4
4
  Summary: Pasqal MPS based pulse emulator built on PyTorch
5
5
  Project-URL: Documentation, https://pasqal-io.github.io/emulators/
6
6
  Project-URL: Repository, https://github.com/pasqal-io/emulators
7
7
  Project-URL: Issues, https://github.com/pasqal-io/emulators/issues
8
- Author-email: Anton Quelle <anton.quelle@pasqal.com>, Mauro Mendizabal <mauro.mendizabal-pico@pasqal.com>, Stefano Grava <stefano.grava@pasqal.com>, Pablo Le Henaff <pablo.le-henaff@pasqal.com>
8
+ Author-email: Kemal Bidzhiev <kemal.bidzhiev@pasqal.com>, Stefano Grava <stefano.grava@pasqal.com>, Pablo Le Henaff <pablo.le-henaff@pasqal.com>, Mauro Mendizabal <mauro.mendizabal-pico@pasqal.com>, Elie Merhej <elie.merhej@pasqal.com>, Anton Quelle <anton.quelle@pasqal.com>
9
9
  License: PASQAL OPEN-SOURCE SOFTWARE LICENSE AGREEMENT (MIT-derived)
10
10
 
11
11
  The author of the License is:
@@ -25,7 +25,7 @@ Classifier: Programming Language :: Python :: 3.10
25
25
  Classifier: Programming Language :: Python :: Implementation :: CPython
26
26
  Classifier: Programming Language :: Python :: Implementation :: PyPy
27
27
  Requires-Python: >=3.10
28
- Requires-Dist: emu-base==2.1.1
28
+ Requires-Dist: emu-base==2.2.1
29
29
  Description-Content-Type: text/markdown
30
30
 
31
31
  <div align="center">
@@ -0,0 +1,19 @@
1
+ emu_mps/__init__.py,sha256=iXV15aC4QkDT-W_pEv2vLF1vcZAUtb7GP6D4kBEYxYk,734
2
+ emu_mps/algebra.py,sha256=78XP9HEbV3wGNUzIulcLU5HizW4XAYmcFdkCe1T1x-k,5489
3
+ emu_mps/custom_callback_implementations.py,sha256=SZGKVyS8U5hy07L-3SqpWlCAqGGKFTlSlWexZwSmjrM,2408
4
+ emu_mps/hamiltonian.py,sha256=gOPxNOBmk6jRPPjevERuCP_scGv0EKYeAJ0uxooihes,15622
5
+ emu_mps/mpo.py,sha256=aWSVuEzZM-_7ZD5Rz3-tSJWX22ARP0tMIl3gUu-_4V4,7834
6
+ emu_mps/mps.py,sha256=8i3Yz5C_cqHWeAJILOm7cz8P8cHDuNl6aBcLXtaO314,19964
7
+ emu_mps/mps_backend.py,sha256=bS83qFxvdoK-c12_1WaPw6O7xUc7vdWifZNHUzNP5sM,2091
8
+ emu_mps/mps_backend_impl.py,sha256=SGqL2KiICE1fLviycrRdVlw3LOYRjpF5B0XgQlElSBs,26047
9
+ emu_mps/mps_config.py,sha256=PoSKZxJMhG6zfzgEjj4tIvyiyYRQywxkRgidh8MRBsA,8222
10
+ emu_mps/noise.py,sha256=5BXthepWLKnuSTJfIFuPl2AcYPxUeTJdRc2b28ekkhg,208
11
+ emu_mps/observables.py,sha256=7GQDH5kyaVNrwckk2f8ZJRV9Ca4jKhWWDsOCqYWsoEk,1349
12
+ emu_mps/solver_utils.py,sha256=NWwg6AeCCOrx8a5_ysSojdAOmg73W1202FtYx2JEHH0,8544
13
+ emu_mps/utils.py,sha256=PRPIe9B8n-6caVcUYn3uTFSvb3jAMkXX-63f8KtX5-U,8196
14
+ emu_mps/optimatrix/__init__.py,sha256=fBXQ7-rgDro4hcaBijCGhx3J69W96qcw5_3mWc7tND4,364
15
+ emu_mps/optimatrix/optimiser.py,sha256=k9suYmKLKlaZ7ozFuIqvXHyCBoCtGgkX1mpen9GOdOo,6977
16
+ emu_mps/optimatrix/permutations.py,sha256=9DDMZtrGGZ01b9F3GkzHR3paX4qNtZiPoI7Z_Kia3Lc,3727
17
+ emu_mps-2.2.1.dist-info/METADATA,sha256=FLJwfRKhNCuaThXQ39P92HcV4Fpz5KxPqyJMddrV-K8,3587
18
+ emu_mps-2.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
+ emu_mps-2.2.1.dist-info/RECORD,,
@@ -1,19 +0,0 @@
1
- emu_mps/__init__.py,sha256=8GuHi09MFuRHfvIoaIJ_S9wMOgLPw9_I52OO7Eosols,734
2
- emu_mps/algebra.py,sha256=ngPtTH-j2ZCBWoaJZXlkUyIlug7dY7Q92gzfnRlpPMA,5485
3
- emu_mps/custom_callback_implementations.py,sha256=CUs0kW3HRaPE7UeFNQOFbeWJMsz4hS2q4rgS57BBp-A,2411
4
- emu_mps/hamiltonian.py,sha256=gOPxNOBmk6jRPPjevERuCP_scGv0EKYeAJ0uxooihes,15622
5
- emu_mps/mpo.py,sha256=1ogQ25GZCwMzZ_m449oGHcYyDKrofBCr1eyzzrIPMhQ,8824
6
- emu_mps/mps.py,sha256=GIiWxctNmHARgf-PgQc6IHKNCe5HYSnbtlXI6Hc-0wI,20085
7
- emu_mps/mps_backend.py,sha256=bS83qFxvdoK-c12_1WaPw6O7xUc7vdWifZNHUzNP5sM,2091
8
- emu_mps/mps_backend_impl.py,sha256=U-fNHVkmQFwi_Hfun0OdJ0vmqi9ncjyJ4gIbEM1mN0Y,25887
9
- emu_mps/mps_config.py,sha256=WA64iI4SxxKRM8-49mjvXUzrUv4miYolVYOhR0mVmtk,8555
10
- emu_mps/noise.py,sha256=5BXthepWLKnuSTJfIFuPl2AcYPxUeTJdRc2b28ekkhg,208
11
- emu_mps/observables.py,sha256=7GQDH5kyaVNrwckk2f8ZJRV9Ca4jKhWWDsOCqYWsoEk,1349
12
- emu_mps/tdvp.py,sha256=0qTw9qhg0WbaAyBgeTpULHrNL0ytj80ZUb1P6GKD7Ww,6172
13
- emu_mps/utils.py,sha256=BqRJYAcXqprtZVJ0V_j954ON2bhTdtZiaTojsYyrWrg,8193
14
- emu_mps/optimatrix/__init__.py,sha256=fBXQ7-rgDro4hcaBijCGhx3J69W96qcw5_3mWc7tND4,364
15
- emu_mps/optimatrix/optimiser.py,sha256=k9suYmKLKlaZ7ozFuIqvXHyCBoCtGgkX1mpen9GOdOo,6977
16
- emu_mps/optimatrix/permutations.py,sha256=9DDMZtrGGZ01b9F3GkzHR3paX4qNtZiPoI7Z_Kia3Lc,3727
17
- emu_mps-2.1.1.dist-info/METADATA,sha256=45NpjZK-7fd4dqVbo4Q76XqoMecHm_J2jiqX7oPUSKA,3505
18
- emu_mps-2.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
19
- emu_mps-2.1.1.dist-info/RECORD,,