emu-mps 2.0.1__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
emu_mps/mpo.py CHANGED
@@ -83,7 +83,7 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
83
83
  other.factors,
84
84
  config=other.config,
85
85
  )
86
- return MPS(factors, orthogonality_center=0)
86
+ return MPS(factors, orthogonality_center=0, eigenstates=other.eigenstates)
87
87
 
88
88
  def __add__(self, other: MPO) -> MPO:
89
89
  """
emu_mps/mps.py CHANGED
@@ -38,6 +38,7 @@ class MPS(State[complex, torch.Tensor]):
38
38
  orthogonality_center: Optional[int] = None,
39
39
  config: Optional[MPSConfig] = None,
40
40
  num_gpus_to_use: Optional[int] = DEVICE_COUNT,
41
+ eigenstates: Sequence[Eigenstate] = ("r", "g"),
41
42
  ):
42
43
  """
43
44
  This constructor creates a MPS directly from a list of tensors. It is for internal use only.
@@ -56,7 +57,7 @@ class MPS(State[complex, torch.Tensor]):
56
57
  num_gpus_to_use: distribute the factors over this many GPUs
57
58
  0=all factors to cpu, None=keep the existing device assignment.
58
59
  """
59
- self._eigenstates = ["0", "1"]
60
+ super().__init__(eigenstates=eigenstates)
60
61
  self.config = config if config is not None else MPSConfig()
61
62
  assert all(
62
63
  factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
@@ -88,6 +89,7 @@ class MPS(State[complex, torch.Tensor]):
88
89
  num_sites: int,
89
90
  config: Optional[MPSConfig] = None,
90
91
  num_gpus_to_use: int = DEVICE_COUNT,
92
+ eigenstates: Sequence[Eigenstate] = ["0", "1"],
91
93
  ) -> MPS:
92
94
  """
93
95
  Returns a MPS in ground state |000..0>.
@@ -111,6 +113,7 @@ class MPS(State[complex, torch.Tensor]):
111
113
  config=config,
112
114
  num_gpus_to_use=num_gpus_to_use,
113
115
  orthogonality_center=0, # Arbitrary: every qubit is an orthogonality center.
116
+ eigenstates=eigenstates,
114
117
  )
115
118
 
116
119
  def __repr__(self) -> str:
@@ -350,12 +353,16 @@ class MPS(State[complex, torch.Tensor]):
350
353
  the summed state
351
354
  """
352
355
  assert isinstance(other, MPS), "Other state also needs to be an MPS"
356
+ assert (
357
+ self.eigenstates == other.eigenstates
358
+ ), f"`Other` state has basis {other.eigenstates} != {self.eigenstates}"
353
359
  new_tt = add_factors(self.factors, other.factors)
354
360
  result = MPS(
355
361
  new_tt,
356
362
  config=self.config,
357
363
  num_gpus_to_use=None,
358
364
  orthogonality_center=None, # Orthogonality is lost.
365
+ eigenstates=self.eigenstates,
359
366
  )
360
367
  result.truncate()
361
368
  return result
@@ -381,6 +388,7 @@ class MPS(State[complex, torch.Tensor]):
381
388
  config=self.config,
382
389
  num_gpus_to_use=None,
383
390
  orthogonality_center=self.orthogonality_center,
391
+ eigenstates=self.eigenstates,
384
392
  )
385
393
 
386
394
  def __imul__(self, scalar: complex) -> MPS:
@@ -390,7 +398,7 @@ class MPS(State[complex, torch.Tensor]):
390
398
  def _from_state_amplitudes(
391
399
  cls,
392
400
  *,
393
- eigenstates: Sequence[str],
401
+ eigenstates: Sequence[Eigenstate],
394
402
  amplitudes: Mapping[str, complex],
395
403
  ) -> tuple[MPS, Mapping[str, complex]]:
396
404
  """
@@ -404,8 +412,6 @@ class MPS(State[complex, torch.Tensor]):
404
412
  Returns:
405
413
  The resulting MPS representation of the state.s
406
414
  """
407
-
408
- nqubits = len(next(iter(amplitudes.keys())))
409
415
  basis = set(eigenstates)
410
416
  if basis == {"r", "g"}:
411
417
  one = "r"
@@ -414,17 +420,20 @@ class MPS(State[complex, torch.Tensor]):
414
420
  else:
415
421
  raise ValueError("Unsupported basis provided")
416
422
 
423
+ nqubits = cls._validate_amplitudes(amplitudes, eigenstates)
424
+
417
425
  basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128) # ground state
418
426
  basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128) # excited state
419
427
 
420
428
  accum_mps = MPS(
421
429
  [torch.zeros((1, 2, 1), dtype=torch.complex128)] * nqubits,
422
430
  orthogonality_center=0,
431
+ eigenstates=eigenstates,
423
432
  )
424
433
 
425
434
  for state, amplitude in amplitudes.items():
426
435
  factors = [basis_1 if ch == one else basis_0 for ch in state]
427
- accum_mps += amplitude * MPS(factors)
436
+ accum_mps += amplitude * MPS(factors, eigenstates=eigenstates)
428
437
  norm = accum_mps.norm()
429
438
  if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
430
439
  print("\nThe state is not normalized, normalizing it for you.")
emu_mps/mps_backend.py CHANGED
@@ -55,7 +55,9 @@ class MPSBackend(EmulatorBackend):
55
55
  impl = create_impl(self._sequence, self._config)
56
56
  impl.init() # This is separate from the constructor for testing purposes.
57
57
 
58
- return self._run(impl)
58
+ results = self._run(impl)
59
+
60
+ return impl.permute_results(results, self._config.optimize_qubit_ordering)
59
61
 
60
62
  @staticmethod
61
63
  def _run(impl: MPSBackendImpl) -> Results:
@@ -1,29 +1,34 @@
1
1
  import math
2
+ import os
2
3
  import pathlib
4
+ import pickle
3
5
  import random
6
+ import time
7
+ import typing
4
8
  import uuid
5
9
 
10
+ from collections import Counter
11
+ from enum import Enum, auto
6
12
  from resource import RUSAGE_SELF, getrusage
7
- from typing import Optional, Any
8
- import typing
9
- import pickle
10
- import os
13
+ from types import MethodType
14
+ from typing import Any, Optional
15
+
11
16
  import torch
12
- import time
13
17
  from pulser import Sequence
14
- from types import MethodType
18
+ from pulser.backend import EmulationConfig, Observable, Results, State
15
19
 
16
- from pulser.backend import State, Observable, EmulationConfig, Results
17
- from emu_base import PulserData, DEVICE_COUNT
20
+ from emu_base import DEVICE_COUNT, PulserData
18
21
  from emu_base.math.brents_root_finding import BrentsRootFinder
22
+
19
23
  from emu_mps.hamiltonian import make_H, update_H
20
24
  from emu_mps.mpo import MPO
21
25
  from emu_mps.mps import MPS
22
26
  from emu_mps.mps_config import MPSConfig
23
27
  from emu_mps.noise import compute_noise_from_lindbladians, pick_well_prepared_qubits
28
+ import emu_mps.optimatrix as optimat
24
29
  from emu_mps.tdvp import (
25
- evolve_single,
26
30
  evolve_pair,
31
+ evolve_single,
27
32
  new_right_bath,
28
33
  right_baths,
29
34
  )
@@ -33,7 +38,6 @@ from emu_mps.utils import (
33
38
  get_extended_site_index,
34
39
  new_left_bath,
35
40
  )
36
- from enum import Enum, auto
37
41
 
38
42
 
39
43
  class Statistics(Observable):
@@ -118,8 +122,17 @@ class MPSBackendImpl:
118
122
  self.timestep_count: int = self.omega.shape[0]
119
123
  self.has_lindblad_noise = pulser_data.has_lindblad_noise
120
124
  self.lindblad_noise = torch.zeros(2, 2, dtype=torch.complex128)
121
- self.full_interaction_matrix = pulser_data.full_interaction_matrix
122
- self.masked_interaction_matrix = pulser_data.masked_interaction_matrix
125
+ self.qubit_permutation = (
126
+ optimat.minimize_bandwidth(pulser_data.full_interaction_matrix)
127
+ if self.config.optimize_qubit_ordering
128
+ else optimat.eye_permutation(self.qubit_count)
129
+ )
130
+ self.full_interaction_matrix = optimat.permute_tensor(
131
+ pulser_data.full_interaction_matrix, self.qubit_permutation
132
+ )
133
+ self.masked_interaction_matrix = optimat.permute_tensor(
134
+ pulser_data.masked_interaction_matrix, self.qubit_permutation
135
+ )
123
136
  self.hamiltonian_type = pulser_data.hamiltonian_type
124
137
  self.slm_end_time = pulser_data.slm_end_time
125
138
  self.is_masked = self.slm_end_time > 0.0
@@ -128,7 +141,12 @@ class MPSBackendImpl:
128
141
  self.swipe_direction = SwipeDirection.LEFT_TO_RIGHT
129
142
  self.tdvp_index = 0
130
143
  self.timestep_index = 0
131
- self.results = Results(atom_order=(), total_duration=self.target_times[-1])
144
+ self.results = Results(
145
+ atom_order=optimat.permute_tuple(
146
+ pulser_data.qubit_ids, self.qubit_permutation
147
+ ),
148
+ total_duration=self.target_times[-1],
149
+ )
132
150
  self.statistics = Statistics(
133
151
  evaluation_times=[t / self.target_times[-1] for t in self.target_times],
134
152
  data=[],
@@ -203,11 +221,24 @@ class MPSBackendImpl:
203
221
  )
204
222
 
205
223
  assert isinstance(initial_state, MPS)
224
+ if not torch.equal(
225
+ self.qubit_permutation, optimat.eye_permutation(self.qubit_count)
226
+ ):
227
+ # permute the initial state to match with permuted Hamiltonian
228
+ abstr_repr = initial_state._to_abstract_repr()
229
+ eigs = abstr_repr["eigenstates"]
230
+ ampl = {
231
+ optimat.permute_string(bstr, self.qubit_permutation): amp
232
+ for bstr, amp in abstr_repr["amplitudes"].items()
233
+ }
234
+ initial_state = MPS.from_state_amplitudes(eigenstates=eigs, amplitudes=ampl)
235
+
206
236
  initial_state = MPS(
207
237
  # Deep copy of every tensor of the initial state.
208
238
  [f.clone().detach() for f in initial_state.factors],
209
239
  config=self.config,
210
240
  num_gpus_to_use=self.config.num_gpus_to_use,
241
+ eigenstates=initial_state.eigenstates,
211
242
  )
212
243
  initial_state.truncate()
213
244
  initial_state *= 1 / initial_state.norm()
@@ -496,17 +527,56 @@ class MPSBackendImpl:
496
527
  )
497
528
  full_state = MPS(
498
529
  extended_mps_factors(
499
- normalized_state.factors, self.well_prepared_qubits_filter
530
+ normalized_state.factors,
531
+ self.well_prepared_qubits_filter,
500
532
  ),
501
533
  num_gpus_to_use=None, # Keep the already assigned devices.
502
534
  orthogonality_center=get_extended_site_index(
503
535
  self.well_prepared_qubits_filter,
504
536
  normalized_state.orthogonality_center,
505
537
  ),
538
+ eigenstates=normalized_state.eigenstates,
506
539
  )
507
540
 
508
541
  callback(self.config, fractional_time, full_state, full_mpo, self.results)
509
542
 
543
+ def permute_results(self, results: Results, permute: bool) -> Results:
544
+ if permute:
545
+ inv_perm = optimat.inv_permutation(self.qubit_permutation)
546
+ permute_bitstrings(results, inv_perm)
547
+ permute_occupations_and_correlations(results, inv_perm)
548
+ permute_atom_order(results, inv_perm)
549
+ return results
550
+
551
+
552
+ def permute_bitstrings(results: Results, perm: torch.Tensor) -> None:
553
+ if "bitstrings" not in results.get_result_tags():
554
+ return
555
+ uuid_bs = results._find_uuid("bitstrings")
556
+
557
+ results._results[uuid_bs] = [
558
+ Counter({optimat.permute_string(bstr, perm): c for bstr, c in bs_counter.items()})
559
+ for bs_counter in results._results[uuid_bs]
560
+ ]
561
+
562
+
563
+ def permute_occupations_and_correlations(results: Results, perm: torch.Tensor) -> None:
564
+ for corr in ["occupation", "correlation_matrix"]:
565
+ if corr not in results.get_result_tags():
566
+ return
567
+
568
+ uuid_corr = results._find_uuid(corr)
569
+ corrs = results._results[uuid_corr]
570
+ results._results[uuid_corr] = [
571
+ optimat.permute_tensor(corr, perm) for corr in corrs
572
+ ]
573
+
574
+
575
+ def permute_atom_order(results: Results, perm: torch.Tensor) -> None:
576
+ at_ord = list(results.atom_order)
577
+ at_ord = optimat.permute_list(at_ord, perm)
578
+ results.atom_order = tuple(at_ord)
579
+
510
580
 
511
581
  class NoisyMPSBackendImpl(MPSBackendImpl):
512
582
  """
@@ -535,7 +605,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
535
605
 
536
606
  def set_jump_threshold(self, bound: float) -> None:
537
607
  self.jump_threshold = random.uniform(0.0, bound)
538
- self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
608
+ self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
539
609
 
540
610
  def init(self) -> None:
541
611
  self.init_lindblad_noise()
@@ -546,7 +616,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
546
616
  previous_time = self.current_time
547
617
  self.current_time = self.target_time
548
618
  previous_norm_gap_before_jump = self.norm_gap_before_jump
549
- self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
619
+ self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
550
620
 
551
621
  if self.root_finder is None:
552
622
  # No quantum jump location finding in progress
@@ -566,7 +636,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
566
636
 
567
637
  return
568
638
 
569
- self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
639
+ self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
570
640
  self.root_finder.provide_ordinate(self.current_time, self.norm_gap_before_jump)
571
641
 
572
642
  if self.root_finder.is_converged(tolerance=1):
@@ -592,7 +662,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
592
662
  self.state *= 1 / self.state.norm()
593
663
  self.init_baths()
594
664
 
595
- norm_after_normalizing = self.state.norm()
665
+ norm_after_normalizing = self.state.norm().item()
596
666
  assert math.isclose(norm_after_normalizing, 1, abs_tol=1e-10)
597
667
  self.set_jump_threshold(norm_after_normalizing**2)
598
668
 
emu_mps/mps_config.py CHANGED
@@ -69,6 +69,7 @@ class MPSConfig(EmulationConfig):
69
69
  max_krylov_dim: int = 100,
70
70
  extra_krylov_tolerance: float = 1e-3,
71
71
  num_gpus_to_use: int = DEVICE_COUNT,
72
+ optimize_qubit_ordering: bool = False,
72
73
  interaction_cutoff: float = 0.0,
73
74
  log_level: int = logging.INFO,
74
75
  log_file: pathlib.Path | None = None,
@@ -84,6 +85,7 @@ class MPSConfig(EmulationConfig):
84
85
  max_krylov_dim=max_krylov_dim,
85
86
  extra_krylov_tolerance=extra_krylov_tolerance,
86
87
  num_gpus_to_use=num_gpus_to_use,
88
+ optimize_qubit_ordering=optimize_qubit_ordering,
87
89
  interaction_cutoff=interaction_cutoff,
88
90
  log_level=log_level,
89
91
  log_file=log_file,
@@ -91,6 +93,8 @@ class MPSConfig(EmulationConfig):
91
93
  autosave_dt=autosave_dt,
92
94
  **kwargs,
93
95
  )
96
+ if self.optimize_qubit_ordering:
97
+ self.check_permutable_observables()
94
98
 
95
99
  if "doppler" in self.noise_model.noise_types:
96
100
  raise NotImplementedError("Unsupported noise type: doppler")
@@ -136,6 +140,7 @@ class MPSConfig(EmulationConfig):
136
140
  "max_krylov_dim",
137
141
  "extra_krylov_tolerance",
138
142
  "num_gpus_to_use",
143
+ "optimize_qubit_ordering",
139
144
  "interaction_cutoff",
140
145
  "log_level",
141
146
  "log_file",
@@ -183,3 +188,27 @@ class MPSConfig(EmulationConfig):
183
188
  filemode="w",
184
189
  force=True,
185
190
  )
191
+
192
+ def check_permutable_observables(self) -> None:
193
+ allowed_permutable_obs = set(
194
+ [
195
+ "bitstrings",
196
+ "occupation",
197
+ "correlation_matrix",
198
+ "statistics",
199
+ "energy",
200
+ "energy_variance",
201
+ "energy_second_moment",
202
+ ]
203
+ )
204
+
205
+ actual_obs = set([obs._base_tag for obs in self.observables])
206
+ not_allowed = actual_obs.difference(allowed_permutable_obs)
207
+ if not_allowed:
208
+ raise ValueError(
209
+ f"emu-mp allows only {allowed_permutable_obs} observables with"
210
+ " `optimize_qubit_ordering = True`."
211
+ f" you provided unsupported {not_allowed}"
212
+ " To use other observables, please set"
213
+ " `optimize_qubit_ordering = False` in `MPSConfig()`."
214
+ )
@@ -1,9 +1,20 @@
1
1
  from .optimiser import minimize_bandwidth
2
- from .permutations import permute_list, permute_matrix, invert_permutation
2
+ from .permutations import (
3
+ permute_tensor,
4
+ inv_permutation,
5
+ permute_string,
6
+ eye_permutation,
7
+ permute_list,
8
+ permute_tuple,
9
+ )
10
+
3
11
 
4
12
  __all__ = [
5
13
  "minimize_bandwidth",
14
+ "eye_permutation",
15
+ "permute_string",
16
+ "permute_tensor",
17
+ "inv_permutation",
6
18
  "permute_list",
7
- "permute_matrix",
8
- "invert_permutation",
19
+ "permute_tuple",
9
20
  ]
@@ -1,21 +1,18 @@
1
+ import itertools
2
+
3
+ import torch
1
4
  from scipy.sparse import csr_matrix
2
5
  from scipy.sparse.csgraph import reverse_cuthill_mckee
3
- import numpy as np
4
- from emu_mps.optimatrix.permutations import permute_matrix, permute_list
5
- import itertools
6
6
 
7
+ from emu_mps.optimatrix.permutations import permute_tensor
7
8
 
8
- def is_symmetric(mat: np.ndarray) -> bool:
9
- if mat.shape[0] != mat.shape[1]:
10
- return False
11
- if not np.allclose(mat, mat.T, atol=1e-8):
12
- return False
13
9
 
14
- return True
10
+ def is_symmetric(matrix: torch.Tensor, tol: float = 1e-8) -> bool:
11
+ return torch.allclose(matrix, matrix.T, atol=tol)
15
12
 
16
13
 
17
- def matrix_bandwidth(mat: np.ndarray) -> float:
18
- """matrix_bandwidth(matrix: np.ndarray) -> float
14
+ def matrix_bandwidth(mat: torch.Tensor) -> float:
15
+ """matrix_bandwidth(matrix: torch.tensor) -> torch.Tensor
19
16
 
20
17
  Computes bandwidth as max weighted distance between columns of
21
18
  a square matrix as `max (abs(matrix[i, j] * (j - i))`.
@@ -45,19 +42,27 @@ def matrix_bandwidth(mat: np.ndarray) -> float:
45
42
 
46
43
  Example:
47
44
  -------
48
- >>> matrix = np.array([
49
- ... [ 1, -17, 2.4],
50
- ... [ 9, 1, -10],
51
- ... [-15, 20, 1],])
52
- >>> matrix_bandwidth(matrix) # 30.0 because abs(-15 * (2-0) == 30)
45
+ >>> matrix = torch.tensor([
46
+ ... [1.0, -17.0, 2.4],
47
+ ... [9.0, 1.0, -10.0],
48
+ ... [-15.0, 20.0, 1.0]
49
+ ... ])
50
+ >>> matrix_bandwidth(matrix) # because abs(-15 * (0 - 2)) = 30.0
53
51
  30.0
54
52
  """
55
53
 
56
- bandwidth = max(abs(el * (index[0] - index[1])) for index, el in np.ndenumerate(mat))
57
- return float(bandwidth)
54
+ n = mat.shape[0]
58
55
 
56
+ i_arr = torch.arange(n).view(-1, 1) # shape (n, 1)
57
+ j_arr = torch.arange(n).view(1, -1) # shape (1, n)
59
58
 
60
- def minimize_bandwidth_above_threshold(mat: np.ndarray, threshold: float) -> np.ndarray:
59
+ weighted = torch.abs(mat * (j_arr - i_arr))
60
+ return torch.max(weighted).to(mat.dtype).item()
61
+
62
+
63
+ def minimize_bandwidth_above_threshold(
64
+ mat: torch.Tensor, threshold: float
65
+ ) -> torch.Tensor:
61
66
  """
62
67
  minimize_bandwidth_above_threshold(matrix, trunc) -> permutation_lists
63
68
 
@@ -78,24 +83,25 @@ def minimize_bandwidth_above_threshold(mat: np.ndarray, threshold: float) -> np.
78
83
 
79
84
  Example:
80
85
  -------
81
- >>> matrix = np.array([
82
- ... [1, 2, 3],
83
- ... [2, 5, 6],
84
- ... [3, 6, 9]])
86
+ >>> matrix = torch.tensor([
87
+ ... [1, 2, 3],
88
+ ... [2, 5, 6],
89
+ ... [3, 6, 9]
90
+ ... ], dtype=torch.float32)
85
91
  >>> threshold = 3
86
92
  >>> minimize_bandwidth_above_threshold(matrix, threshold)
87
- array([1, 2, 0], dtype=int32)
93
+ tensor([1, 2, 0], dtype=torch.int32)
88
94
  """
89
95
 
90
- matrix_truncated = mat.copy()
91
- matrix_truncated[mat < threshold] = 0
92
- rcm_permutation = reverse_cuthill_mckee(
93
- csr_matrix(matrix_truncated), symmetric_mode=True
94
- )
95
- return np.array(rcm_permutation)
96
+ m_trunc = mat.clone()
97
+ m_trunc[mat < threshold] = 0.0
96
98
 
99
+ matrix_np = csr_matrix(m_trunc.numpy()) # SciPy's RCM compatibility
100
+ rcm_perm = reverse_cuthill_mckee(matrix_np, symmetric_mode=True)
101
+ return torch.from_numpy(rcm_perm.copy()) # translation requires copy
97
102
 
98
- def minimize_bandwidth_global(mat: np.ndarray) -> list[int]:
103
+
104
+ def minimize_bandwidth_global(mat: torch.Tensor) -> torch.Tensor:
99
105
  """
100
106
  minimize_bandwidth_global(matrix) -> list
101
107
 
@@ -111,74 +117,78 @@ def minimize_bandwidth_global(mat: np.ndarray) -> list[int]:
111
117
  -------
112
118
  permutation order that minimizes matrix bandwidth
113
119
 
114
- Example:
120
+ Example
115
121
  -------
116
- >>> matrix = np.array([
117
- ... [1, 2, 3],
118
- ... [2, 5, 6],
119
- ... [3, 6, 9]])
122
+ >>> matrix = torch.tensor([
123
+ ... [1, 2, 3],
124
+ ... [2, 5, 6],
125
+ ... [3, 6, 9]
126
+ ... ], dtype=torch.float32)
120
127
  >>> minimize_bandwidth_global(matrix)
121
- [2, 1, 0]
128
+ tensor([2, 1, 0], dtype=torch.int32)
122
129
  """
123
- mat_amplitude = np.max(np.abs(mat))
124
- # Search from 1.0 to 0.1 doesn't change result
130
+ mat_amplitude = torch.max(torch.abs(mat))
131
+
125
132
  permutations = (
126
- minimize_bandwidth_above_threshold(mat, trunc * mat_amplitude)
127
- for trunc in np.arange(start=0.1, stop=1.0, step=0.01)
133
+ minimize_bandwidth_above_threshold(mat, trunc.item() * mat_amplitude)
134
+ for trunc in torch.arange(0.1, 1.0, 0.01)
128
135
  )
129
136
 
130
137
  opt_permutation = min(
131
- permutations, key=lambda perm: matrix_bandwidth(permute_matrix(mat, list(perm)))
138
+ permutations, key=lambda perm: matrix_bandwidth(permute_tensor(mat, perm))
132
139
  )
133
- return list(opt_permutation) # opt_permutation is np.ndarray
140
+
141
+ return opt_permutation
134
142
 
135
143
 
136
144
  def minimize_bandwidth_impl(
137
- matrix: np.ndarray, initial_perm: list[int]
138
- ) -> tuple[list[int], float]:
145
+ matrix: torch.Tensor, initial_perm: torch.Tensor
146
+ ) -> tuple[torch.Tensor, float]:
139
147
  """
140
- minimize_bandwidth_impl(matrix, initial_perm) -> list
148
+ minimize_bandwidth_impl(matrix, initial_perm) -> (optimal_perm, bandwidth)
141
149
 
142
150
  Applies initial_perm to a matrix and
143
- finds the permutation list for a symmetric matrix that iteratively minimizes matrix bandwidth.
151
+ finds the permutation list for a symmetric matrix
152
+ that iteratively minimizes matrix bandwidth.
144
153
 
145
154
  Parameters
146
155
  -------
147
156
  matrix :
148
157
  symmetric square matrix
149
- initial_perm: list of integers
158
+ initial_perm: torch list of integers
150
159
 
151
160
 
152
161
  Returns
153
162
  -------
154
- permutation order that minimizes matrix bandwidth
163
+ optimal permutation and optimal matrix bandwidth
155
164
 
156
165
  Example:
157
166
  -------
158
167
  Periodic 1D chain
159
- >>> matrix = np.array([
168
+ >>> matrix = torch.tensor([
160
169
  ... [0, 1, 0, 0, 1],
161
170
  ... [1, 0, 1, 0, 0],
162
171
  ... [0, 1, 0, 1, 0],
163
172
  ... [0, 0, 1, 0, 1],
164
- ... [1, 0, 0, 1, 0]])
165
- >>> id_perm = list(range(matrix.shape[0]))
173
+ ... [1, 0, 0, 1, 0]], dtype=torch.float32)
174
+ >>> id_perm = torch.arange(matrix.shape[0])
166
175
  >>> minimize_bandwidth_impl(matrix, id_perm) # [3, 2, 4, 1, 0] does zig-zag
167
- ([3, 2, 4, 1, 0], 2.0)
176
+ (tensor([3, 2, 4, 1, 0]), 2.0)
168
177
 
169
178
  Simple 1D chain. Cannot be optimised further
170
- >>> matrix = np.array([
179
+ >>> matrix = torch.tensor([
171
180
  ... [0, 1, 0, 0, 0],
172
181
  ... [1, 0, 1, 0, 0],
173
182
  ... [0, 1, 0, 1, 0],
174
183
  ... [0, 0, 1, 0, 1],
175
- ... [0, 0, 0, 1, 0]])
176
- >>> id_perm = list(range(matrix.shape[0]))
184
+ ... [0, 0, 0, 1, 0]], dtype=torch.float32)
185
+ >>> id_perm = torch.arange(matrix.shape[0])
177
186
  >>> minimize_bandwidth_impl(matrix, id_perm)
178
- ([0, 1, 2, 3, 4], 1.0)
187
+ (tensor([0, 1, 2, 3, 4]), 1.0)
179
188
  """
180
- if initial_perm != list(range(matrix.shape[0])):
181
- matrix = permute_matrix(matrix, initial_perm)
189
+ L = matrix.shape[0]
190
+ if not torch.equal(initial_perm, torch.arange(L)):
191
+ matrix = permute_tensor(matrix, initial_perm)
182
192
  bandwidth = matrix_bandwidth(matrix)
183
193
  acc_permutation = initial_perm
184
194
 
@@ -191,28 +201,28 @@ def minimize_bandwidth_impl(
191
201
  )
192
202
 
193
203
  optimal_perm = minimize_bandwidth_global(matrix)
194
- test_mat = permute_matrix(matrix, optimal_perm)
204
+ test_mat = permute_tensor(matrix, optimal_perm)
195
205
  new_bandwidth = matrix_bandwidth(test_mat)
196
206
 
197
207
  if bandwidth <= new_bandwidth:
198
208
  break
199
209
 
200
210
  matrix = test_mat
201
- acc_permutation = permute_list(acc_permutation, optimal_perm)
211
+ acc_permutation = permute_tensor(acc_permutation, optimal_perm)
202
212
  bandwidth = new_bandwidth
203
213
 
204
214
  return acc_permutation, bandwidth
205
215
 
206
216
 
207
- def minimize_bandwidth(input_matrix: np.ndarray, samples: int = 100) -> list[int]:
217
+ def minimize_bandwidth(input_matrix: torch.Tensor, samples: int = 100) -> torch.Tensor:
208
218
  assert is_symmetric(input_matrix), "Input matrix is not symmetric"
209
- input_mat = abs(input_matrix)
219
+ input_mat = torch.abs(input_matrix)
210
220
  # We are interested in strength of the interaction, not sign
211
221
 
212
222
  L = input_mat.shape[0]
213
- rnd_permutations: itertools.chain[list[int]] = itertools.chain(
214
- [list(range(L))], # First element is always the identity list
215
- (np.random.permutation(L).tolist() for _ in range(samples)), # type: ignore[misc]
223
+ rnd_permutations = itertools.chain(
224
+ [torch.arange(L)], # identity permutation
225
+ [torch.randperm(L) for _ in range(samples)], # list of random permutations
216
226
  )
217
227
 
218
228
  opt_permutations_and_opt_bandwidth = (