emu-mps 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/mpo.py CHANGED
@@ -83,7 +83,7 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
83
83
  other.factors,
84
84
  config=other.config,
85
85
  )
86
- return MPS(factors, orthogonality_center=0)
86
+ return MPS(factors, orthogonality_center=0, eigenstates=other.eigenstates)
87
87
 
88
88
  def __add__(self, other: MPO) -> MPO:
89
89
  """
emu_mps/mps.py CHANGED
@@ -38,6 +38,7 @@ class MPS(State[complex, torch.Tensor]):
38
38
  orthogonality_center: Optional[int] = None,
39
39
  config: Optional[MPSConfig] = None,
40
40
  num_gpus_to_use: Optional[int] = DEVICE_COUNT,
41
+ eigenstates: Sequence[Eigenstate] = ("r", "g"),
41
42
  ):
42
43
  """
43
44
  This constructor creates a MPS directly from a list of tensors. It is for internal use only.
@@ -56,7 +57,7 @@ class MPS(State[complex, torch.Tensor]):
56
57
  num_gpus_to_use: distribute the factors over this many GPUs
57
58
  0=all factors to cpu, None=keep the existing device assignment.
58
59
  """
59
- self._eigenstates = ["0", "1"]
60
+ super().__init__(eigenstates=eigenstates)
60
61
  self.config = config if config is not None else MPSConfig()
61
62
  assert all(
62
63
  factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
@@ -88,6 +89,7 @@ class MPS(State[complex, torch.Tensor]):
88
89
  num_sites: int,
89
90
  config: Optional[MPSConfig] = None,
90
91
  num_gpus_to_use: int = DEVICE_COUNT,
92
+ eigenstates: Sequence[Eigenstate] = ["0", "1"],
91
93
  ) -> MPS:
92
94
  """
93
95
  Returns a MPS in ground state |000..0>.
@@ -111,6 +113,7 @@ class MPS(State[complex, torch.Tensor]):
111
113
  config=config,
112
114
  num_gpus_to_use=num_gpus_to_use,
113
115
  orthogonality_center=0, # Arbitrary: every qubit is an orthogonality center.
116
+ eigenstates=eigenstates,
114
117
  )
115
118
 
116
119
  def __repr__(self) -> str:
@@ -307,6 +310,25 @@ class MPS(State[complex, torch.Tensor]):
307
310
  """
308
311
  return torch.abs(self.inner(other)) ** 2 # type: ignore[no-any-return]
309
312
 
313
+ def entanglement_entropy(self, mps_site: int) -> torch.Tensor:
314
+ """
315
+ Returns
316
+ the Von Neumann entanglement entropy of the state `mps` at the bond between sites b and b+1
317
+ S = -Σᵢsᵢ² log(sᵢ²)),
318
+ where sᵢ are the singular values at the chosen bond.
319
+ """
320
+ self.orthogonalize(mps_site)
321
+
322
+ # perform svd on reshaped matrix at site b
323
+ matrix = self.factors[mps_site].flatten(end_dim=1)
324
+ s = torch.linalg.svdvals(matrix)
325
+
326
+ s_e = torch.Tensor(torch.special.entr(s**2))
327
+ s_e = torch.sum(s_e)
328
+
329
+ self.orthogonalize(0)
330
+ return s_e.cpu()
331
+
310
332
  def get_memory_footprint(self) -> float:
311
333
  """
312
334
  Returns the number of MBs of memory occupied to store the state
@@ -331,12 +353,16 @@ class MPS(State[complex, torch.Tensor]):
331
353
  the summed state
332
354
  """
333
355
  assert isinstance(other, MPS), "Other state also needs to be an MPS"
356
+ assert (
357
+ self.eigenstates == other.eigenstates
358
+ ), f"`Other` state has basis {other.eigenstates} != {self.eigenstates}"
334
359
  new_tt = add_factors(self.factors, other.factors)
335
360
  result = MPS(
336
361
  new_tt,
337
362
  config=self.config,
338
363
  num_gpus_to_use=None,
339
364
  orthogonality_center=None, # Orthogonality is lost.
365
+ eigenstates=self.eigenstates,
340
366
  )
341
367
  result.truncate()
342
368
  return result
@@ -362,6 +388,7 @@ class MPS(State[complex, torch.Tensor]):
362
388
  config=self.config,
363
389
  num_gpus_to_use=None,
364
390
  orthogonality_center=self.orthogonality_center,
391
+ eigenstates=self.eigenstates,
365
392
  )
366
393
 
367
394
  def __imul__(self, scalar: complex) -> MPS:
@@ -371,7 +398,7 @@ class MPS(State[complex, torch.Tensor]):
371
398
  def _from_state_amplitudes(
372
399
  cls,
373
400
  *,
374
- eigenstates: Sequence[str],
401
+ eigenstates: Sequence[Eigenstate],
375
402
  amplitudes: Mapping[str, complex],
376
403
  ) -> tuple[MPS, Mapping[str, complex]]:
377
404
  """
@@ -385,8 +412,6 @@ class MPS(State[complex, torch.Tensor]):
385
412
  Returns:
386
413
  The resulting MPS representation of the state.s
387
414
  """
388
-
389
- nqubits = len(next(iter(amplitudes.keys())))
390
415
  basis = set(eigenstates)
391
416
  if basis == {"r", "g"}:
392
417
  one = "r"
@@ -395,17 +420,20 @@ class MPS(State[complex, torch.Tensor]):
395
420
  else:
396
421
  raise ValueError("Unsupported basis provided")
397
422
 
423
+ nqubits = cls._validate_amplitudes(amplitudes, eigenstates)
424
+
398
425
  basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128) # ground state
399
426
  basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128) # excited state
400
427
 
401
428
  accum_mps = MPS(
402
429
  [torch.zeros((1, 2, 1), dtype=torch.complex128)] * nqubits,
403
430
  orthogonality_center=0,
431
+ eigenstates=eigenstates,
404
432
  )
405
433
 
406
434
  for state, amplitude in amplitudes.items():
407
435
  factors = [basis_1 if ch == one else basis_0 for ch in state]
408
- accum_mps += amplitude * MPS(factors)
436
+ accum_mps += amplitude * MPS(factors, eigenstates=eigenstates)
409
437
  norm = accum_mps.norm()
410
438
  if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
411
439
  print("\nThe state is not normalized, normalizing it for you.")
@@ -151,14 +151,14 @@ class MPSBackendImpl:
151
151
  def __getstate__(self) -> dict:
152
152
  for obs in self.config.observables:
153
153
  obs.apply = MethodType(type(obs).apply, obs) # type: ignore[method-assign]
154
- d = self.__dict__
154
+ d = self.__dict__.copy()
155
155
  # mypy thinks the method below is an attribute, because of the __getattr__ override
156
156
  d["results"] = self.results._to_abstract_repr() # type: ignore[operator]
157
157
  return d
158
158
 
159
159
  def __setstate__(self, d: dict) -> None:
160
- d["results"] = Results._from_abstract_repr(d["results"]) # type: ignore [attr-defined]
161
160
  self.__dict__ = d
161
+ self.results = Results._from_abstract_repr(d["results"]) # type: ignore [attr-defined]
162
162
  self.config.monkeypatch_observables()
163
163
 
164
164
  @staticmethod
@@ -208,6 +208,7 @@ class MPSBackendImpl:
208
208
  [f.clone().detach() for f in initial_state.factors],
209
209
  config=self.config,
210
210
  num_gpus_to_use=self.config.num_gpus_to_use,
211
+ eigenstates=initial_state.eigenstates,
211
212
  )
212
213
  initial_state.truncate()
213
214
  initial_state *= 1 / initial_state.norm()
@@ -444,7 +445,6 @@ class MPSBackendImpl:
444
445
  basename = self.autosave_file
445
446
  with open(basename.with_suffix(".new"), "wb") as file_handle:
446
447
  pickle.dump(self, file_handle)
447
-
448
448
  if basename.is_file():
449
449
  os.rename(basename, basename.with_suffix(".bak"))
450
450
 
@@ -497,13 +497,15 @@ class MPSBackendImpl:
497
497
  )
498
498
  full_state = MPS(
499
499
  extended_mps_factors(
500
- normalized_state.factors, self.well_prepared_qubits_filter
500
+ normalized_state.factors,
501
+ self.well_prepared_qubits_filter,
501
502
  ),
502
503
  num_gpus_to_use=None, # Keep the already assigned devices.
503
504
  orthogonality_center=get_extended_site_index(
504
505
  self.well_prepared_qubits_filter,
505
506
  normalized_state.orthogonality_center,
506
507
  ),
508
+ eigenstates=normalized_state.eigenstates,
507
509
  )
508
510
 
509
511
  callback(self.config, fractional_time, full_state, full_mpo, self.results)
@@ -536,7 +538,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
536
538
 
537
539
  def set_jump_threshold(self, bound: float) -> None:
538
540
  self.jump_threshold = random.uniform(0.0, bound)
539
- self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
541
+ self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
540
542
 
541
543
  def init(self) -> None:
542
544
  self.init_lindblad_noise()
@@ -547,7 +549,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
547
549
  previous_time = self.current_time
548
550
  self.current_time = self.target_time
549
551
  previous_norm_gap_before_jump = self.norm_gap_before_jump
550
- self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
552
+ self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
551
553
 
552
554
  if self.root_finder is None:
553
555
  # No quantum jump location finding in progress
@@ -567,7 +569,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
567
569
 
568
570
  return
569
571
 
570
- self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
572
+ self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
571
573
  self.root_finder.provide_ordinate(self.current_time, self.norm_gap_before_jump)
572
574
 
573
575
  if self.root_finder.is_converged(tolerance=1):
@@ -593,7 +595,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
593
595
  self.state *= 1 / self.state.norm()
594
596
  self.init_baths()
595
597
 
596
- norm_after_normalizing = self.state.norm()
598
+ norm_after_normalizing = self.state.norm().item()
597
599
  assert math.isclose(norm_after_normalizing, 1, abs_tol=1e-10)
598
600
  self.set_jump_threshold(norm_after_normalizing**2)
599
601
 
emu_mps/observables.py ADDED
@@ -0,0 +1,40 @@
1
+ from pulser.backend.state import State
2
+ from pulser.backend.observable import Observable
3
+ from emu_mps.mps import MPS
4
+ from typing import Sequence, Any
5
+ import torch
6
+
7
+
8
+ class EntanglementEntropy(Observable):
9
+ """Entanglement Entropy subclass used only in emu_mps"""
10
+
11
+ def __init__(
12
+ self,
13
+ mps_site: int,
14
+ *,
15
+ evaluation_times: Sequence[float] | None = None,
16
+ tag_suffix: str | None = None,
17
+ ):
18
+ super().__init__(evaluation_times=evaluation_times, tag_suffix=tag_suffix)
19
+ self.mps_site = mps_site
20
+
21
+ @property
22
+ def _base_tag(self) -> str:
23
+ return "entanglement_entropy"
24
+
25
+ def _to_abstract_repr(self) -> dict[str, Any]:
26
+ repr = super()._to_abstract_repr()
27
+ repr["mps_site"] = self.mps_site
28
+ return repr
29
+
30
+ def apply(self, *, state: State, **kwargs: Any) -> torch.Tensor:
31
+ if not isinstance(state, MPS):
32
+ raise NotImplementedError(
33
+ "Entanglement entropy observable is only available for emu_mps emulator."
34
+ )
35
+ if not (0 <= self.mps_site <= len(state.factors) - 2):
36
+ raise ValueError(
37
+ f"Invalid bond index {self.mps_site}. "
38
+ f"Expected value in range 0 <= bond_index <= {len(state.factors)-2}."
39
+ )
40
+ return state.entanglement_entropy(self.mps_site)
@@ -1,9 +1,9 @@
1
1
  from .optimiser import minimize_bandwidth
2
- from .permutations import permute_list, permute_matrix, invert_permutation
2
+ from .permutations import permute_tensor, inv_permutation, permute_string
3
3
 
4
4
  __all__ = [
5
5
  "minimize_bandwidth",
6
- "permute_list",
7
- "permute_matrix",
8
- "invert_permutation",
6
+ "permute_string",
7
+ "permute_tensor",
8
+ "inv_permutation",
9
9
  ]
@@ -1,21 +1,18 @@
1
+ import itertools
2
+
3
+ import torch
1
4
  from scipy.sparse import csr_matrix
2
5
  from scipy.sparse.csgraph import reverse_cuthill_mckee
3
- import numpy as np
4
- from emu_mps.optimatrix.permutations import permute_matrix, permute_list
5
- import itertools
6
6
 
7
+ from emu_mps.optimatrix.permutations import permute_tensor
7
8
 
8
- def is_symmetric(mat: np.ndarray) -> bool:
9
- if mat.shape[0] != mat.shape[1]:
10
- return False
11
- if not np.allclose(mat, mat.T, atol=1e-8):
12
- return False
13
9
 
14
- return True
10
+ def is_symmetric(matrix: torch.Tensor, tol: float = 1e-8) -> bool:
11
+ return torch.allclose(matrix, matrix.T, atol=tol)
15
12
 
16
13
 
17
- def matrix_bandwidth(mat: np.ndarray) -> float:
18
- """matrix_bandwidth(matrix: np.ndarray) -> float
14
+ def matrix_bandwidth(mat: torch.Tensor) -> float:
15
+ """matrix_bandwidth(matrix: torch.tensor) -> torch.Tensor
19
16
 
20
17
  Computes bandwidth as max weighted distance between columns of
21
18
  a square matrix as `max (abs(matrix[i, j] * (j - i))`.
@@ -45,19 +42,27 @@ def matrix_bandwidth(mat: np.ndarray) -> float:
45
42
 
46
43
  Example:
47
44
  -------
48
- >>> matrix = np.array([
49
- ... [ 1, -17, 2.4],
50
- ... [ 9, 1, -10],
51
- ... [-15, 20, 1],])
52
- >>> matrix_bandwidth(matrix) # 30.0 because abs(-15 * (2-0) == 30)
45
+ >>> matrix = torch.tensor([
46
+ ... [1.0, -17.0, 2.4],
47
+ ... [9.0, 1.0, -10.0],
48
+ ... [-15.0, 20.0, 1.0]
49
+ ... ])
50
+ >>> matrix_bandwidth(matrix) # because abs(-15 * (0 - 2)) = 30.0
53
51
  30.0
54
52
  """
55
53
 
56
- bandwidth = max(abs(el * (index[0] - index[1])) for index, el in np.ndenumerate(mat))
57
- return float(bandwidth)
54
+ n = mat.shape[0]
58
55
 
56
+ i_arr = torch.arange(n).view(-1, 1) # shape (n, 1)
57
+ j_arr = torch.arange(n).view(1, -1) # shape (1, n)
59
58
 
60
- def minimize_bandwidth_above_threshold(mat: np.ndarray, threshold: float) -> np.ndarray:
59
+ weighted = torch.abs(mat * (j_arr - i_arr))
60
+ return torch.max(weighted).to(mat.dtype).item()
61
+
62
+
63
+ def minimize_bandwidth_above_threshold(
64
+ mat: torch.Tensor, threshold: float
65
+ ) -> torch.Tensor:
61
66
  """
62
67
  minimize_bandwidth_above_threshold(matrix, trunc) -> permutation_lists
63
68
 
@@ -78,24 +83,25 @@ def minimize_bandwidth_above_threshold(mat: np.ndarray, threshold: float) -> np.
78
83
 
79
84
  Example:
80
85
  -------
81
- >>> matrix = np.array([
82
- ... [1, 2, 3],
83
- ... [2, 5, 6],
84
- ... [3, 6, 9]])
86
+ >>> matrix = torch.tensor([
87
+ ... [1, 2, 3],
88
+ ... [2, 5, 6],
89
+ ... [3, 6, 9]
90
+ ... ], dtype=torch.float32)
85
91
  >>> threshold = 3
86
92
  >>> minimize_bandwidth_above_threshold(matrix, threshold)
87
- array([1, 2, 0], dtype=int32)
93
+ tensor([1, 2, 0], dtype=torch.int32)
88
94
  """
89
95
 
90
- matrix_truncated = mat.copy()
91
- matrix_truncated[mat < threshold] = 0
92
- rcm_permutation = reverse_cuthill_mckee(
93
- csr_matrix(matrix_truncated), symmetric_mode=True
94
- )
95
- return np.array(rcm_permutation)
96
+ m_trunc = mat.clone()
97
+ m_trunc[mat < threshold] = 0.0
96
98
 
99
+ matrix_np = csr_matrix(m_trunc.numpy()) # SciPy's RCM compatibility
100
+ rcm_perm = reverse_cuthill_mckee(matrix_np, symmetric_mode=True)
101
+ return torch.from_numpy(rcm_perm.copy()) # translation requires copy
97
102
 
98
- def minimize_bandwidth_global(mat: np.ndarray) -> list[int]:
103
+
104
+ def minimize_bandwidth_global(mat: torch.Tensor) -> torch.Tensor:
99
105
  """
100
106
  minimize_bandwidth_global(matrix) -> list
101
107
 
@@ -111,74 +117,78 @@ def minimize_bandwidth_global(mat: np.ndarray) -> list[int]:
111
117
  -------
112
118
  permutation order that minimizes matrix bandwidth
113
119
 
114
- Example:
120
+ Example
115
121
  -------
116
- >>> matrix = np.array([
117
- ... [1, 2, 3],
118
- ... [2, 5, 6],
119
- ... [3, 6, 9]])
122
+ >>> matrix = torch.tensor([
123
+ ... [1, 2, 3],
124
+ ... [2, 5, 6],
125
+ ... [3, 6, 9]
126
+ ... ], dtype=torch.float32)
120
127
  >>> minimize_bandwidth_global(matrix)
121
- [2, 1, 0]
128
+ tensor([2, 1, 0], dtype=torch.int32)
122
129
  """
123
- mat_amplitude = np.max(np.abs(mat))
124
- # Search from 1.0 to 0.1 doesn't change result
130
+ mat_amplitude = torch.max(torch.abs(mat))
131
+
125
132
  permutations = (
126
- minimize_bandwidth_above_threshold(mat, trunc * mat_amplitude)
127
- for trunc in np.arange(start=0.1, stop=1.0, step=0.01)
133
+ minimize_bandwidth_above_threshold(mat, trunc.item() * mat_amplitude)
134
+ for trunc in torch.arange(0.1, 1.0, 0.01)
128
135
  )
129
136
 
130
137
  opt_permutation = min(
131
- permutations, key=lambda perm: matrix_bandwidth(permute_matrix(mat, list(perm)))
138
+ permutations, key=lambda perm: matrix_bandwidth(permute_tensor(mat, perm))
132
139
  )
133
- return list(opt_permutation) # opt_permutation is np.ndarray
140
+
141
+ return opt_permutation
134
142
 
135
143
 
136
144
  def minimize_bandwidth_impl(
137
- matrix: np.ndarray, initial_perm: list[int]
138
- ) -> tuple[list[int], float]:
145
+ matrix: torch.Tensor, initial_perm: torch.Tensor
146
+ ) -> tuple[torch.Tensor, float]:
139
147
  """
140
- minimize_bandwidth_impl(matrix, initial_perm) -> list
148
+ minimize_bandwidth_impl(matrix, initial_perm) -> (optimal_perm, bandwidth)
141
149
 
142
150
  Applies initial_perm to a matrix and
143
- finds the permutation list for a symmetric matrix that iteratively minimizes matrix bandwidth.
151
+ finds the permutation list for a symmetric matrix
152
+ that iteratively minimizes matrix bandwidth.
144
153
 
145
154
  Parameters
146
155
  -------
147
156
  matrix :
148
157
  symmetric square matrix
149
- initial_perm: list of integers
158
+ initial_perm: torch list of integers
150
159
 
151
160
 
152
161
  Returns
153
162
  -------
154
- permutation order that minimizes matrix bandwidth
163
+ optimal permutation and optimal matrix bandwidth
155
164
 
156
165
  Example:
157
166
  -------
158
167
  Periodic 1D chain
159
- >>> matrix = np.array([
168
+ >>> matrix = torch.tensor([
160
169
  ... [0, 1, 0, 0, 1],
161
170
  ... [1, 0, 1, 0, 0],
162
171
  ... [0, 1, 0, 1, 0],
163
172
  ... [0, 0, 1, 0, 1],
164
- ... [1, 0, 0, 1, 0]])
165
- >>> id_perm = list(range(matrix.shape[0]))
173
+ ... [1, 0, 0, 1, 0]], dtype=torch.float32)
174
+ >>> id_perm = torch.arange(matrix.shape[0])
166
175
  >>> minimize_bandwidth_impl(matrix, id_perm) # [3, 2, 4, 1, 0] does zig-zag
167
- ([3, 2, 4, 1, 0], 2.0)
176
+ (tensor([3, 2, 4, 1, 0]), 2.0)
168
177
 
169
178
  Simple 1D chain. Cannot be optimised further
170
- >>> matrix = np.array([
179
+ >>> matrix = torch.tensor([
171
180
  ... [0, 1, 0, 0, 0],
172
181
  ... [1, 0, 1, 0, 0],
173
182
  ... [0, 1, 0, 1, 0],
174
183
  ... [0, 0, 1, 0, 1],
175
- ... [0, 0, 0, 1, 0]])
176
- >>> id_perm = list(range(matrix.shape[0]))
184
+ ... [0, 0, 0, 1, 0]], dtype=torch.float32)
185
+ >>> id_perm = torch.arange(matrix.shape[0])
177
186
  >>> minimize_bandwidth_impl(matrix, id_perm)
178
- ([0, 1, 2, 3, 4], 1.0)
187
+ (tensor([0, 1, 2, 3, 4]), 1.0)
179
188
  """
180
- if initial_perm != list(range(matrix.shape[0])):
181
- matrix = permute_matrix(matrix, initial_perm)
189
+ L = matrix.shape[0]
190
+ if not torch.equal(initial_perm, torch.arange(L)):
191
+ matrix = permute_tensor(matrix, initial_perm)
182
192
  bandwidth = matrix_bandwidth(matrix)
183
193
  acc_permutation = initial_perm
184
194
 
@@ -191,28 +201,28 @@ def minimize_bandwidth_impl(
191
201
  )
192
202
 
193
203
  optimal_perm = minimize_bandwidth_global(matrix)
194
- test_mat = permute_matrix(matrix, optimal_perm)
204
+ test_mat = permute_tensor(matrix, optimal_perm)
195
205
  new_bandwidth = matrix_bandwidth(test_mat)
196
206
 
197
207
  if bandwidth <= new_bandwidth:
198
208
  break
199
209
 
200
210
  matrix = test_mat
201
- acc_permutation = permute_list(acc_permutation, optimal_perm)
211
+ acc_permutation = permute_tensor(acc_permutation, optimal_perm)
202
212
  bandwidth = new_bandwidth
203
213
 
204
214
  return acc_permutation, bandwidth
205
215
 
206
216
 
207
- def minimize_bandwidth(input_matrix: np.ndarray, samples: int = 100) -> list[int]:
217
+ def minimize_bandwidth(input_matrix: torch.Tensor, samples: int = 100) -> torch.Tensor:
208
218
  assert is_symmetric(input_matrix), "Input matrix is not symmetric"
209
- input_mat = abs(input_matrix)
219
+ input_mat = torch.abs(input_matrix)
210
220
  # We are interested in strength of the interaction, not sign
211
221
 
212
222
  L = input_mat.shape[0]
213
- rnd_permutations: itertools.chain[list[int]] = itertools.chain(
214
- [list(range(L))], # First element is always the identity list
215
- (np.random.permutation(L).tolist() for _ in range(samples)), # type: ignore[misc]
223
+ rnd_permutations = itertools.chain(
224
+ [torch.arange(L)], # identity permutation
225
+ [torch.randperm(L) for _ in range(samples)], # list of random permutations
216
226
  )
217
227
 
218
228
  opt_permutations_and_opt_bandwidth = (
@@ -1,36 +1,31 @@
1
- import numpy as np
1
+ import torch
2
2
 
3
3
 
4
- def permute_list(input_list: list, permutation: list[int]) -> list:
4
+ def permute_string(input_str: str, perm: torch.Tensor) -> str:
5
5
  """
6
- Permutes the input list according to the given permutation.
7
-
6
+ Permutes the input string according to the given permutation.
8
7
  Parameters
9
8
  -------
10
- input_list :
11
- A list to permute.
9
+ input_string :
10
+ A string to permute.
12
11
  permutation :
13
12
  A list of indices representing the new order.
14
-
15
13
  Returns
16
14
  -------
17
- The permuted list.
18
-
15
+ The permuted string.
19
16
  Example
20
17
  -------
21
- >>> permute_list(['a', 'b', 'c'], [2, 0, 1])
22
- ['c', 'a', 'b']
18
+ >>> permute_string("abc", torch.tensor([2, 0, 1]))
19
+ 'cab'
23
20
  """
21
+ char_list = list(input_str)
22
+ permuted = [char_list[i] for i in perm.tolist()]
23
+ return "".join(permuted)
24
24
 
25
- permuted_list = [None] * len(input_list)
26
- for i, p in enumerate(permutation):
27
- permuted_list[i] = input_list[p]
28
- return permuted_list
29
25
 
30
-
31
- def invert_permutation(permutation: list[int]) -> list[int]:
26
+ def inv_permutation(permutation: torch.Tensor) -> torch.Tensor:
32
27
  """
33
- invert_permutation(permutation) -> inv_permutation
28
+ inv_permutation(permutation) -> inverted_perm
34
29
 
35
30
  Inverts the input permutation list.
36
31
 
@@ -45,47 +40,60 @@ def invert_permutation(permutation: list[int]) -> list[int]:
45
40
 
46
41
  Example:
47
42
  -------
48
- >>> invert_permutation([2, 0, 1])
49
- [1, 2, 0]
43
+ >>> inv_permutation(torch.tensor([2, 0, 1]))
44
+ tensor([1, 2, 0])
50
45
  """
51
-
52
- inv_perm = np.empty_like(permutation)
53
- inv_perm[permutation] = np.arange(len(permutation))
54
- return list(inv_perm)
46
+ inv_perm = torch.empty_like(permutation)
47
+ inv_perm[permutation] = torch.arange(len(permutation))
48
+ return inv_perm
55
49
 
56
50
 
57
- def permute_matrix(mat: np.ndarray, permutation: list[int]) -> np.ndarray:
51
+ def permute_tensor(tensor: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
58
52
  """
59
- permute_matrix(matrix, permutation_list) -> permuted_matrix
60
-
61
- Simultaneously permutes columns and rows according to a permutation list.
53
+ Permute a 1D or square 2D torch tensor using the given permutation indices.
54
+ For 1D tensors, applies the permutation to the elements.
55
+ For 2D square tensors, applies the same permutation to both rows and columns.
62
56
 
63
57
  Parameters
64
- -------
65
- matrix :
66
- square matrix nxn
67
- permutation :
68
- permutation list
58
+ ----------
59
+ tensor : torch.Tensor
60
+ A 1D or 2D square tensor to be permuted.
61
+ perm : torch.Tensor
62
+ A 1D tensor of indices specifying the permutation order.
69
63
 
70
64
  Returns
71
65
  -------
72
- matrix with permuted columns and rows
73
-
74
- Example:
75
- -------
76
- >>> matrix = np.array([
77
- ... [1, 2, 3],
78
- ... [4, 5, 6],
79
- ... [7, 8, 9]])
80
- >>> permutation = [1, 0, 2]
81
- >>> permute_matrix(matrix, permutation)
82
- array([[5, 4, 6],
83
- [2, 1, 3],
84
- [8, 7, 9]])
66
+ torch.Tensor
67
+ A new tensor with elements (1D) or rows and columns (2D) permuted according to `perm`.
68
+
69
+ Raises
70
+ ------
71
+ ValueError
72
+ If tensor is not 1D or square 2D.
73
+
74
+ Examples
75
+ --------
76
+ >>> vector = torch.tensor([10, 20, 30])
77
+ >>> perm = torch.tensor([2, 0, 1])
78
+ >>> permute_tensor(vector, perm)
79
+ tensor([30, 10, 20])
80
+
81
+ >>> matrix = torch.tensor([
82
+ ... [1, 2, 3],
83
+ ... [4, 5, 6],
84
+ ... [7, 8, 9]])
85
+ >>> perm = torch.tensor([1, 0, 2])
86
+ >>> permute_tensor(matrix, perm)
87
+ tensor([[5, 4, 6],
88
+ [2, 1, 3],
89
+ [8, 7, 9]])
85
90
  """
86
-
87
- perm = np.array(permutation)
88
- return mat[perm, :][:, perm]
91
+ if tensor.ndim == 1:
92
+ return tensor[perm]
93
+ elif tensor.ndim == 2 and tensor.shape[0] == tensor.shape[1]:
94
+ return tensor[perm][:, perm]
95
+ else:
96
+ raise ValueError("Only 1D tensors or square 2D tensors are supported.")
89
97
 
90
98
 
91
99
  if __name__ == "__main__":
emu_mps/tdvp.py CHANGED
@@ -78,7 +78,7 @@ def apply_effective_Hamiltonian(
78
78
  return state
79
79
 
80
80
 
81
- _TIME_CONVERSION_COEFF = 0.001 # Omega and delta are given in rad/ms, dt in ns
81
+ _TIME_CONVERSION_COEFF = 0.001 # Omega and delta are given in rad/μs, dt in ns
82
82
 
83
83
 
84
84
  def evolve_pair(