emu-mps 2.0.1__py3-none-any.whl → 2.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emu_mps/__init__.py +1 -1
- emu_mps/hamiltonian.py +333 -347
- emu_mps/mpo.py +1 -1
- emu_mps/mps.py +14 -5
- emu_mps/mps_backend.py +3 -1
- emu_mps/mps_backend_impl.py +88 -18
- emu_mps/mps_config.py +29 -0
- emu_mps/optimatrix/__init__.py +14 -3
- emu_mps/optimatrix/optimiser.py +76 -66
- emu_mps/optimatrix/permutations.py +98 -43
- emu_mps/tdvp.py +11 -3
- {emu_mps-2.0.1.dist-info → emu_mps-2.0.3.dist-info}/METADATA +2 -2
- emu_mps-2.0.3.dist-info/RECORD +19 -0
- emu_mps-2.0.1.dist-info/RECORD +0 -19
- {emu_mps-2.0.1.dist-info → emu_mps-2.0.3.dist-info}/WHEEL +0 -0
emu_mps/mpo.py
CHANGED
|
@@ -83,7 +83,7 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
|
|
|
83
83
|
other.factors,
|
|
84
84
|
config=other.config,
|
|
85
85
|
)
|
|
86
|
-
return MPS(factors, orthogonality_center=0)
|
|
86
|
+
return MPS(factors, orthogonality_center=0, eigenstates=other.eigenstates)
|
|
87
87
|
|
|
88
88
|
def __add__(self, other: MPO) -> MPO:
|
|
89
89
|
"""
|
emu_mps/mps.py
CHANGED
|
@@ -38,6 +38,7 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
38
38
|
orthogonality_center: Optional[int] = None,
|
|
39
39
|
config: Optional[MPSConfig] = None,
|
|
40
40
|
num_gpus_to_use: Optional[int] = DEVICE_COUNT,
|
|
41
|
+
eigenstates: Sequence[Eigenstate] = ("r", "g"),
|
|
41
42
|
):
|
|
42
43
|
"""
|
|
43
44
|
This constructor creates a MPS directly from a list of tensors. It is for internal use only.
|
|
@@ -56,7 +57,7 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
56
57
|
num_gpus_to_use: distribute the factors over this many GPUs
|
|
57
58
|
0=all factors to cpu, None=keep the existing device assignment.
|
|
58
59
|
"""
|
|
59
|
-
|
|
60
|
+
super().__init__(eigenstates=eigenstates)
|
|
60
61
|
self.config = config if config is not None else MPSConfig()
|
|
61
62
|
assert all(
|
|
62
63
|
factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))
|
|
@@ -88,6 +89,7 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
88
89
|
num_sites: int,
|
|
89
90
|
config: Optional[MPSConfig] = None,
|
|
90
91
|
num_gpus_to_use: int = DEVICE_COUNT,
|
|
92
|
+
eigenstates: Sequence[Eigenstate] = ["0", "1"],
|
|
91
93
|
) -> MPS:
|
|
92
94
|
"""
|
|
93
95
|
Returns a MPS in ground state |000..0>.
|
|
@@ -111,6 +113,7 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
111
113
|
config=config,
|
|
112
114
|
num_gpus_to_use=num_gpus_to_use,
|
|
113
115
|
orthogonality_center=0, # Arbitrary: every qubit is an orthogonality center.
|
|
116
|
+
eigenstates=eigenstates,
|
|
114
117
|
)
|
|
115
118
|
|
|
116
119
|
def __repr__(self) -> str:
|
|
@@ -350,12 +353,16 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
350
353
|
the summed state
|
|
351
354
|
"""
|
|
352
355
|
assert isinstance(other, MPS), "Other state also needs to be an MPS"
|
|
356
|
+
assert (
|
|
357
|
+
self.eigenstates == other.eigenstates
|
|
358
|
+
), f"`Other` state has basis {other.eigenstates} != {self.eigenstates}"
|
|
353
359
|
new_tt = add_factors(self.factors, other.factors)
|
|
354
360
|
result = MPS(
|
|
355
361
|
new_tt,
|
|
356
362
|
config=self.config,
|
|
357
363
|
num_gpus_to_use=None,
|
|
358
364
|
orthogonality_center=None, # Orthogonality is lost.
|
|
365
|
+
eigenstates=self.eigenstates,
|
|
359
366
|
)
|
|
360
367
|
result.truncate()
|
|
361
368
|
return result
|
|
@@ -381,6 +388,7 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
381
388
|
config=self.config,
|
|
382
389
|
num_gpus_to_use=None,
|
|
383
390
|
orthogonality_center=self.orthogonality_center,
|
|
391
|
+
eigenstates=self.eigenstates,
|
|
384
392
|
)
|
|
385
393
|
|
|
386
394
|
def __imul__(self, scalar: complex) -> MPS:
|
|
@@ -390,7 +398,7 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
390
398
|
def _from_state_amplitudes(
|
|
391
399
|
cls,
|
|
392
400
|
*,
|
|
393
|
-
eigenstates: Sequence[
|
|
401
|
+
eigenstates: Sequence[Eigenstate],
|
|
394
402
|
amplitudes: Mapping[str, complex],
|
|
395
403
|
) -> tuple[MPS, Mapping[str, complex]]:
|
|
396
404
|
"""
|
|
@@ -404,8 +412,6 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
404
412
|
Returns:
|
|
405
413
|
The resulting MPS representation of the state.s
|
|
406
414
|
"""
|
|
407
|
-
|
|
408
|
-
nqubits = len(next(iter(amplitudes.keys())))
|
|
409
415
|
basis = set(eigenstates)
|
|
410
416
|
if basis == {"r", "g"}:
|
|
411
417
|
one = "r"
|
|
@@ -414,17 +420,20 @@ class MPS(State[complex, torch.Tensor]):
|
|
|
414
420
|
else:
|
|
415
421
|
raise ValueError("Unsupported basis provided")
|
|
416
422
|
|
|
423
|
+
nqubits = cls._validate_amplitudes(amplitudes, eigenstates)
|
|
424
|
+
|
|
417
425
|
basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128) # ground state
|
|
418
426
|
basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128) # excited state
|
|
419
427
|
|
|
420
428
|
accum_mps = MPS(
|
|
421
429
|
[torch.zeros((1, 2, 1), dtype=torch.complex128)] * nqubits,
|
|
422
430
|
orthogonality_center=0,
|
|
431
|
+
eigenstates=eigenstates,
|
|
423
432
|
)
|
|
424
433
|
|
|
425
434
|
for state, amplitude in amplitudes.items():
|
|
426
435
|
factors = [basis_1 if ch == one else basis_0 for ch in state]
|
|
427
|
-
accum_mps += amplitude * MPS(factors)
|
|
436
|
+
accum_mps += amplitude * MPS(factors, eigenstates=eigenstates)
|
|
428
437
|
norm = accum_mps.norm()
|
|
429
438
|
if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
|
|
430
439
|
print("\nThe state is not normalized, normalizing it for you.")
|
emu_mps/mps_backend.py
CHANGED
|
@@ -55,7 +55,9 @@ class MPSBackend(EmulatorBackend):
|
|
|
55
55
|
impl = create_impl(self._sequence, self._config)
|
|
56
56
|
impl.init() # This is separate from the constructor for testing purposes.
|
|
57
57
|
|
|
58
|
-
|
|
58
|
+
results = self._run(impl)
|
|
59
|
+
|
|
60
|
+
return impl.permute_results(results, self._config.optimize_qubit_ordering)
|
|
59
61
|
|
|
60
62
|
@staticmethod
|
|
61
63
|
def _run(impl: MPSBackendImpl) -> Results:
|
emu_mps/mps_backend_impl.py
CHANGED
|
@@ -1,29 +1,34 @@
|
|
|
1
1
|
import math
|
|
2
|
+
import os
|
|
2
3
|
import pathlib
|
|
4
|
+
import pickle
|
|
3
5
|
import random
|
|
6
|
+
import time
|
|
7
|
+
import typing
|
|
4
8
|
import uuid
|
|
5
9
|
|
|
10
|
+
from collections import Counter
|
|
11
|
+
from enum import Enum, auto
|
|
6
12
|
from resource import RUSAGE_SELF, getrusage
|
|
7
|
-
from
|
|
8
|
-
import
|
|
9
|
-
|
|
10
|
-
import os
|
|
13
|
+
from types import MethodType
|
|
14
|
+
from typing import Any, Optional
|
|
15
|
+
|
|
11
16
|
import torch
|
|
12
|
-
import time
|
|
13
17
|
from pulser import Sequence
|
|
14
|
-
from
|
|
18
|
+
from pulser.backend import EmulationConfig, Observable, Results, State
|
|
15
19
|
|
|
16
|
-
from
|
|
17
|
-
from emu_base import PulserData, DEVICE_COUNT
|
|
20
|
+
from emu_base import DEVICE_COUNT, PulserData
|
|
18
21
|
from emu_base.math.brents_root_finding import BrentsRootFinder
|
|
22
|
+
|
|
19
23
|
from emu_mps.hamiltonian import make_H, update_H
|
|
20
24
|
from emu_mps.mpo import MPO
|
|
21
25
|
from emu_mps.mps import MPS
|
|
22
26
|
from emu_mps.mps_config import MPSConfig
|
|
23
27
|
from emu_mps.noise import compute_noise_from_lindbladians, pick_well_prepared_qubits
|
|
28
|
+
import emu_mps.optimatrix as optimat
|
|
24
29
|
from emu_mps.tdvp import (
|
|
25
|
-
evolve_single,
|
|
26
30
|
evolve_pair,
|
|
31
|
+
evolve_single,
|
|
27
32
|
new_right_bath,
|
|
28
33
|
right_baths,
|
|
29
34
|
)
|
|
@@ -33,7 +38,6 @@ from emu_mps.utils import (
|
|
|
33
38
|
get_extended_site_index,
|
|
34
39
|
new_left_bath,
|
|
35
40
|
)
|
|
36
|
-
from enum import Enum, auto
|
|
37
41
|
|
|
38
42
|
|
|
39
43
|
class Statistics(Observable):
|
|
@@ -118,8 +122,17 @@ class MPSBackendImpl:
|
|
|
118
122
|
self.timestep_count: int = self.omega.shape[0]
|
|
119
123
|
self.has_lindblad_noise = pulser_data.has_lindblad_noise
|
|
120
124
|
self.lindblad_noise = torch.zeros(2, 2, dtype=torch.complex128)
|
|
121
|
-
self.
|
|
122
|
-
|
|
125
|
+
self.qubit_permutation = (
|
|
126
|
+
optimat.minimize_bandwidth(pulser_data.full_interaction_matrix)
|
|
127
|
+
if self.config.optimize_qubit_ordering
|
|
128
|
+
else optimat.eye_permutation(self.qubit_count)
|
|
129
|
+
)
|
|
130
|
+
self.full_interaction_matrix = optimat.permute_tensor(
|
|
131
|
+
pulser_data.full_interaction_matrix, self.qubit_permutation
|
|
132
|
+
)
|
|
133
|
+
self.masked_interaction_matrix = optimat.permute_tensor(
|
|
134
|
+
pulser_data.masked_interaction_matrix, self.qubit_permutation
|
|
135
|
+
)
|
|
123
136
|
self.hamiltonian_type = pulser_data.hamiltonian_type
|
|
124
137
|
self.slm_end_time = pulser_data.slm_end_time
|
|
125
138
|
self.is_masked = self.slm_end_time > 0.0
|
|
@@ -128,7 +141,12 @@ class MPSBackendImpl:
|
|
|
128
141
|
self.swipe_direction = SwipeDirection.LEFT_TO_RIGHT
|
|
129
142
|
self.tdvp_index = 0
|
|
130
143
|
self.timestep_index = 0
|
|
131
|
-
self.results = Results(
|
|
144
|
+
self.results = Results(
|
|
145
|
+
atom_order=optimat.permute_tuple(
|
|
146
|
+
pulser_data.qubit_ids, self.qubit_permutation
|
|
147
|
+
),
|
|
148
|
+
total_duration=self.target_times[-1],
|
|
149
|
+
)
|
|
132
150
|
self.statistics = Statistics(
|
|
133
151
|
evaluation_times=[t / self.target_times[-1] for t in self.target_times],
|
|
134
152
|
data=[],
|
|
@@ -203,11 +221,24 @@ class MPSBackendImpl:
|
|
|
203
221
|
)
|
|
204
222
|
|
|
205
223
|
assert isinstance(initial_state, MPS)
|
|
224
|
+
if not torch.equal(
|
|
225
|
+
self.qubit_permutation, optimat.eye_permutation(self.qubit_count)
|
|
226
|
+
):
|
|
227
|
+
# permute the initial state to match with permuted Hamiltonian
|
|
228
|
+
abstr_repr = initial_state._to_abstract_repr()
|
|
229
|
+
eigs = abstr_repr["eigenstates"]
|
|
230
|
+
ampl = {
|
|
231
|
+
optimat.permute_string(bstr, self.qubit_permutation): amp
|
|
232
|
+
for bstr, amp in abstr_repr["amplitudes"].items()
|
|
233
|
+
}
|
|
234
|
+
initial_state = MPS.from_state_amplitudes(eigenstates=eigs, amplitudes=ampl)
|
|
235
|
+
|
|
206
236
|
initial_state = MPS(
|
|
207
237
|
# Deep copy of every tensor of the initial state.
|
|
208
238
|
[f.clone().detach() for f in initial_state.factors],
|
|
209
239
|
config=self.config,
|
|
210
240
|
num_gpus_to_use=self.config.num_gpus_to_use,
|
|
241
|
+
eigenstates=initial_state.eigenstates,
|
|
211
242
|
)
|
|
212
243
|
initial_state.truncate()
|
|
213
244
|
initial_state *= 1 / initial_state.norm()
|
|
@@ -496,17 +527,56 @@ class MPSBackendImpl:
|
|
|
496
527
|
)
|
|
497
528
|
full_state = MPS(
|
|
498
529
|
extended_mps_factors(
|
|
499
|
-
normalized_state.factors,
|
|
530
|
+
normalized_state.factors,
|
|
531
|
+
self.well_prepared_qubits_filter,
|
|
500
532
|
),
|
|
501
533
|
num_gpus_to_use=None, # Keep the already assigned devices.
|
|
502
534
|
orthogonality_center=get_extended_site_index(
|
|
503
535
|
self.well_prepared_qubits_filter,
|
|
504
536
|
normalized_state.orthogonality_center,
|
|
505
537
|
),
|
|
538
|
+
eigenstates=normalized_state.eigenstates,
|
|
506
539
|
)
|
|
507
540
|
|
|
508
541
|
callback(self.config, fractional_time, full_state, full_mpo, self.results)
|
|
509
542
|
|
|
543
|
+
def permute_results(self, results: Results, permute: bool) -> Results:
|
|
544
|
+
if permute:
|
|
545
|
+
inv_perm = optimat.inv_permutation(self.qubit_permutation)
|
|
546
|
+
permute_bitstrings(results, inv_perm)
|
|
547
|
+
permute_occupations_and_correlations(results, inv_perm)
|
|
548
|
+
permute_atom_order(results, inv_perm)
|
|
549
|
+
return results
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
def permute_bitstrings(results: Results, perm: torch.Tensor) -> None:
|
|
553
|
+
if "bitstrings" not in results.get_result_tags():
|
|
554
|
+
return
|
|
555
|
+
uuid_bs = results._find_uuid("bitstrings")
|
|
556
|
+
|
|
557
|
+
results._results[uuid_bs] = [
|
|
558
|
+
Counter({optimat.permute_string(bstr, perm): c for bstr, c in bs_counter.items()})
|
|
559
|
+
for bs_counter in results._results[uuid_bs]
|
|
560
|
+
]
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def permute_occupations_and_correlations(results: Results, perm: torch.Tensor) -> None:
|
|
564
|
+
for corr in ["occupation", "correlation_matrix"]:
|
|
565
|
+
if corr not in results.get_result_tags():
|
|
566
|
+
return
|
|
567
|
+
|
|
568
|
+
uuid_corr = results._find_uuid(corr)
|
|
569
|
+
corrs = results._results[uuid_corr]
|
|
570
|
+
results._results[uuid_corr] = [
|
|
571
|
+
optimat.permute_tensor(corr, perm) for corr in corrs
|
|
572
|
+
]
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
def permute_atom_order(results: Results, perm: torch.Tensor) -> None:
|
|
576
|
+
at_ord = list(results.atom_order)
|
|
577
|
+
at_ord = optimat.permute_list(at_ord, perm)
|
|
578
|
+
results.atom_order = tuple(at_ord)
|
|
579
|
+
|
|
510
580
|
|
|
511
581
|
class NoisyMPSBackendImpl(MPSBackendImpl):
|
|
512
582
|
"""
|
|
@@ -535,7 +605,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
|
|
|
535
605
|
|
|
536
606
|
def set_jump_threshold(self, bound: float) -> None:
|
|
537
607
|
self.jump_threshold = random.uniform(0.0, bound)
|
|
538
|
-
self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
|
|
608
|
+
self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
|
|
539
609
|
|
|
540
610
|
def init(self) -> None:
|
|
541
611
|
self.init_lindblad_noise()
|
|
@@ -546,7 +616,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
|
|
|
546
616
|
previous_time = self.current_time
|
|
547
617
|
self.current_time = self.target_time
|
|
548
618
|
previous_norm_gap_before_jump = self.norm_gap_before_jump
|
|
549
|
-
self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
|
|
619
|
+
self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
|
|
550
620
|
|
|
551
621
|
if self.root_finder is None:
|
|
552
622
|
# No quantum jump location finding in progress
|
|
@@ -566,7 +636,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
|
|
|
566
636
|
|
|
567
637
|
return
|
|
568
638
|
|
|
569
|
-
self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
|
|
639
|
+
self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
|
|
570
640
|
self.root_finder.provide_ordinate(self.current_time, self.norm_gap_before_jump)
|
|
571
641
|
|
|
572
642
|
if self.root_finder.is_converged(tolerance=1):
|
|
@@ -592,7 +662,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
|
|
|
592
662
|
self.state *= 1 / self.state.norm()
|
|
593
663
|
self.init_baths()
|
|
594
664
|
|
|
595
|
-
norm_after_normalizing = self.state.norm()
|
|
665
|
+
norm_after_normalizing = self.state.norm().item()
|
|
596
666
|
assert math.isclose(norm_after_normalizing, 1, abs_tol=1e-10)
|
|
597
667
|
self.set_jump_threshold(norm_after_normalizing**2)
|
|
598
668
|
|
emu_mps/mps_config.py
CHANGED
|
@@ -69,6 +69,7 @@ class MPSConfig(EmulationConfig):
|
|
|
69
69
|
max_krylov_dim: int = 100,
|
|
70
70
|
extra_krylov_tolerance: float = 1e-3,
|
|
71
71
|
num_gpus_to_use: int = DEVICE_COUNT,
|
|
72
|
+
optimize_qubit_ordering: bool = False,
|
|
72
73
|
interaction_cutoff: float = 0.0,
|
|
73
74
|
log_level: int = logging.INFO,
|
|
74
75
|
log_file: pathlib.Path | None = None,
|
|
@@ -84,6 +85,7 @@ class MPSConfig(EmulationConfig):
|
|
|
84
85
|
max_krylov_dim=max_krylov_dim,
|
|
85
86
|
extra_krylov_tolerance=extra_krylov_tolerance,
|
|
86
87
|
num_gpus_to_use=num_gpus_to_use,
|
|
88
|
+
optimize_qubit_ordering=optimize_qubit_ordering,
|
|
87
89
|
interaction_cutoff=interaction_cutoff,
|
|
88
90
|
log_level=log_level,
|
|
89
91
|
log_file=log_file,
|
|
@@ -91,6 +93,8 @@ class MPSConfig(EmulationConfig):
|
|
|
91
93
|
autosave_dt=autosave_dt,
|
|
92
94
|
**kwargs,
|
|
93
95
|
)
|
|
96
|
+
if self.optimize_qubit_ordering:
|
|
97
|
+
self.check_permutable_observables()
|
|
94
98
|
|
|
95
99
|
if "doppler" in self.noise_model.noise_types:
|
|
96
100
|
raise NotImplementedError("Unsupported noise type: doppler")
|
|
@@ -136,6 +140,7 @@ class MPSConfig(EmulationConfig):
|
|
|
136
140
|
"max_krylov_dim",
|
|
137
141
|
"extra_krylov_tolerance",
|
|
138
142
|
"num_gpus_to_use",
|
|
143
|
+
"optimize_qubit_ordering",
|
|
139
144
|
"interaction_cutoff",
|
|
140
145
|
"log_level",
|
|
141
146
|
"log_file",
|
|
@@ -183,3 +188,27 @@ class MPSConfig(EmulationConfig):
|
|
|
183
188
|
filemode="w",
|
|
184
189
|
force=True,
|
|
185
190
|
)
|
|
191
|
+
|
|
192
|
+
def check_permutable_observables(self) -> None:
|
|
193
|
+
allowed_permutable_obs = set(
|
|
194
|
+
[
|
|
195
|
+
"bitstrings",
|
|
196
|
+
"occupation",
|
|
197
|
+
"correlation_matrix",
|
|
198
|
+
"statistics",
|
|
199
|
+
"energy",
|
|
200
|
+
"energy_variance",
|
|
201
|
+
"energy_second_moment",
|
|
202
|
+
]
|
|
203
|
+
)
|
|
204
|
+
|
|
205
|
+
actual_obs = set([obs._base_tag for obs in self.observables])
|
|
206
|
+
not_allowed = actual_obs.difference(allowed_permutable_obs)
|
|
207
|
+
if not_allowed:
|
|
208
|
+
raise ValueError(
|
|
209
|
+
f"emu-mp allows only {allowed_permutable_obs} observables with"
|
|
210
|
+
" `optimize_qubit_ordering = True`."
|
|
211
|
+
f" you provided unsupported {not_allowed}"
|
|
212
|
+
" To use other observables, please set"
|
|
213
|
+
" `optimize_qubit_ordering = False` in `MPSConfig()`."
|
|
214
|
+
)
|
emu_mps/optimatrix/__init__.py
CHANGED
|
@@ -1,9 +1,20 @@
|
|
|
1
1
|
from .optimiser import minimize_bandwidth
|
|
2
|
-
from .permutations import
|
|
2
|
+
from .permutations import (
|
|
3
|
+
permute_tensor,
|
|
4
|
+
inv_permutation,
|
|
5
|
+
permute_string,
|
|
6
|
+
eye_permutation,
|
|
7
|
+
permute_list,
|
|
8
|
+
permute_tuple,
|
|
9
|
+
)
|
|
10
|
+
|
|
3
11
|
|
|
4
12
|
__all__ = [
|
|
5
13
|
"minimize_bandwidth",
|
|
14
|
+
"eye_permutation",
|
|
15
|
+
"permute_string",
|
|
16
|
+
"permute_tensor",
|
|
17
|
+
"inv_permutation",
|
|
6
18
|
"permute_list",
|
|
7
|
-
"
|
|
8
|
-
"invert_permutation",
|
|
19
|
+
"permute_tuple",
|
|
9
20
|
]
|
emu_mps/optimatrix/optimiser.py
CHANGED
|
@@ -1,21 +1,18 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
|
|
3
|
+
import torch
|
|
1
4
|
from scipy.sparse import csr_matrix
|
|
2
5
|
from scipy.sparse.csgraph import reverse_cuthill_mckee
|
|
3
|
-
import numpy as np
|
|
4
|
-
from emu_mps.optimatrix.permutations import permute_matrix, permute_list
|
|
5
|
-
import itertools
|
|
6
6
|
|
|
7
|
+
from emu_mps.optimatrix.permutations import permute_tensor
|
|
7
8
|
|
|
8
|
-
def is_symmetric(mat: np.ndarray) -> bool:
|
|
9
|
-
if mat.shape[0] != mat.shape[1]:
|
|
10
|
-
return False
|
|
11
|
-
if not np.allclose(mat, mat.T, atol=1e-8):
|
|
12
|
-
return False
|
|
13
9
|
|
|
14
|
-
|
|
10
|
+
def is_symmetric(matrix: torch.Tensor, tol: float = 1e-8) -> bool:
|
|
11
|
+
return torch.allclose(matrix, matrix.T, atol=tol)
|
|
15
12
|
|
|
16
13
|
|
|
17
|
-
def matrix_bandwidth(mat:
|
|
18
|
-
"""matrix_bandwidth(matrix:
|
|
14
|
+
def matrix_bandwidth(mat: torch.Tensor) -> float:
|
|
15
|
+
"""matrix_bandwidth(matrix: torch.tensor) -> torch.Tensor
|
|
19
16
|
|
|
20
17
|
Computes bandwidth as max weighted distance between columns of
|
|
21
18
|
a square matrix as `max (abs(matrix[i, j] * (j - i))`.
|
|
@@ -45,19 +42,27 @@ def matrix_bandwidth(mat: np.ndarray) -> float:
|
|
|
45
42
|
|
|
46
43
|
Example:
|
|
47
44
|
-------
|
|
48
|
-
>>> matrix =
|
|
49
|
-
...
|
|
50
|
-
...
|
|
51
|
-
...
|
|
52
|
-
|
|
45
|
+
>>> matrix = torch.tensor([
|
|
46
|
+
... [1.0, -17.0, 2.4],
|
|
47
|
+
... [9.0, 1.0, -10.0],
|
|
48
|
+
... [-15.0, 20.0, 1.0]
|
|
49
|
+
... ])
|
|
50
|
+
>>> matrix_bandwidth(matrix) # because abs(-15 * (0 - 2)) = 30.0
|
|
53
51
|
30.0
|
|
54
52
|
"""
|
|
55
53
|
|
|
56
|
-
|
|
57
|
-
return float(bandwidth)
|
|
54
|
+
n = mat.shape[0]
|
|
58
55
|
|
|
56
|
+
i_arr = torch.arange(n).view(-1, 1) # shape (n, 1)
|
|
57
|
+
j_arr = torch.arange(n).view(1, -1) # shape (1, n)
|
|
59
58
|
|
|
60
|
-
|
|
59
|
+
weighted = torch.abs(mat * (j_arr - i_arr))
|
|
60
|
+
return torch.max(weighted).to(mat.dtype).item()
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def minimize_bandwidth_above_threshold(
|
|
64
|
+
mat: torch.Tensor, threshold: float
|
|
65
|
+
) -> torch.Tensor:
|
|
61
66
|
"""
|
|
62
67
|
minimize_bandwidth_above_threshold(matrix, trunc) -> permutation_lists
|
|
63
68
|
|
|
@@ -78,24 +83,25 @@ def minimize_bandwidth_above_threshold(mat: np.ndarray, threshold: float) -> np.
|
|
|
78
83
|
|
|
79
84
|
Example:
|
|
80
85
|
-------
|
|
81
|
-
>>> matrix =
|
|
82
|
-
...
|
|
83
|
-
...
|
|
84
|
-
...
|
|
86
|
+
>>> matrix = torch.tensor([
|
|
87
|
+
... [1, 2, 3],
|
|
88
|
+
... [2, 5, 6],
|
|
89
|
+
... [3, 6, 9]
|
|
90
|
+
... ], dtype=torch.float32)
|
|
85
91
|
>>> threshold = 3
|
|
86
92
|
>>> minimize_bandwidth_above_threshold(matrix, threshold)
|
|
87
|
-
|
|
93
|
+
tensor([1, 2, 0], dtype=torch.int32)
|
|
88
94
|
"""
|
|
89
95
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
rcm_permutation = reverse_cuthill_mckee(
|
|
93
|
-
csr_matrix(matrix_truncated), symmetric_mode=True
|
|
94
|
-
)
|
|
95
|
-
return np.array(rcm_permutation)
|
|
96
|
+
m_trunc = mat.clone()
|
|
97
|
+
m_trunc[mat < threshold] = 0.0
|
|
96
98
|
|
|
99
|
+
matrix_np = csr_matrix(m_trunc.numpy()) # SciPy's RCM compatibility
|
|
100
|
+
rcm_perm = reverse_cuthill_mckee(matrix_np, symmetric_mode=True)
|
|
101
|
+
return torch.from_numpy(rcm_perm.copy()) # translation requires copy
|
|
97
102
|
|
|
98
|
-
|
|
103
|
+
|
|
104
|
+
def minimize_bandwidth_global(mat: torch.Tensor) -> torch.Tensor:
|
|
99
105
|
"""
|
|
100
106
|
minimize_bandwidth_global(matrix) -> list
|
|
101
107
|
|
|
@@ -111,74 +117,78 @@ def minimize_bandwidth_global(mat: np.ndarray) -> list[int]:
|
|
|
111
117
|
-------
|
|
112
118
|
permutation order that minimizes matrix bandwidth
|
|
113
119
|
|
|
114
|
-
Example
|
|
120
|
+
Example
|
|
115
121
|
-------
|
|
116
|
-
>>> matrix =
|
|
117
|
-
...
|
|
118
|
-
...
|
|
119
|
-
...
|
|
122
|
+
>>> matrix = torch.tensor([
|
|
123
|
+
... [1, 2, 3],
|
|
124
|
+
... [2, 5, 6],
|
|
125
|
+
... [3, 6, 9]
|
|
126
|
+
... ], dtype=torch.float32)
|
|
120
127
|
>>> minimize_bandwidth_global(matrix)
|
|
121
|
-
[2, 1, 0]
|
|
128
|
+
tensor([2, 1, 0], dtype=torch.int32)
|
|
122
129
|
"""
|
|
123
|
-
mat_amplitude =
|
|
124
|
-
|
|
130
|
+
mat_amplitude = torch.max(torch.abs(mat))
|
|
131
|
+
|
|
125
132
|
permutations = (
|
|
126
|
-
minimize_bandwidth_above_threshold(mat, trunc * mat_amplitude)
|
|
127
|
-
for trunc in
|
|
133
|
+
minimize_bandwidth_above_threshold(mat, trunc.item() * mat_amplitude)
|
|
134
|
+
for trunc in torch.arange(0.1, 1.0, 0.01)
|
|
128
135
|
)
|
|
129
136
|
|
|
130
137
|
opt_permutation = min(
|
|
131
|
-
permutations, key=lambda perm: matrix_bandwidth(
|
|
138
|
+
permutations, key=lambda perm: matrix_bandwidth(permute_tensor(mat, perm))
|
|
132
139
|
)
|
|
133
|
-
|
|
140
|
+
|
|
141
|
+
return opt_permutation
|
|
134
142
|
|
|
135
143
|
|
|
136
144
|
def minimize_bandwidth_impl(
|
|
137
|
-
matrix:
|
|
138
|
-
) -> tuple[
|
|
145
|
+
matrix: torch.Tensor, initial_perm: torch.Tensor
|
|
146
|
+
) -> tuple[torch.Tensor, float]:
|
|
139
147
|
"""
|
|
140
|
-
minimize_bandwidth_impl(matrix, initial_perm) ->
|
|
148
|
+
minimize_bandwidth_impl(matrix, initial_perm) -> (optimal_perm, bandwidth)
|
|
141
149
|
|
|
142
150
|
Applies initial_perm to a matrix and
|
|
143
|
-
finds the permutation list for a symmetric matrix
|
|
151
|
+
finds the permutation list for a symmetric matrix
|
|
152
|
+
that iteratively minimizes matrix bandwidth.
|
|
144
153
|
|
|
145
154
|
Parameters
|
|
146
155
|
-------
|
|
147
156
|
matrix :
|
|
148
157
|
symmetric square matrix
|
|
149
|
-
initial_perm: list of integers
|
|
158
|
+
initial_perm: torch list of integers
|
|
150
159
|
|
|
151
160
|
|
|
152
161
|
Returns
|
|
153
162
|
-------
|
|
154
|
-
permutation
|
|
163
|
+
optimal permutation and optimal matrix bandwidth
|
|
155
164
|
|
|
156
165
|
Example:
|
|
157
166
|
-------
|
|
158
167
|
Periodic 1D chain
|
|
159
|
-
>>> matrix =
|
|
168
|
+
>>> matrix = torch.tensor([
|
|
160
169
|
... [0, 1, 0, 0, 1],
|
|
161
170
|
... [1, 0, 1, 0, 0],
|
|
162
171
|
... [0, 1, 0, 1, 0],
|
|
163
172
|
... [0, 0, 1, 0, 1],
|
|
164
|
-
... [1, 0, 0, 1, 0]])
|
|
165
|
-
>>> id_perm =
|
|
173
|
+
... [1, 0, 0, 1, 0]], dtype=torch.float32)
|
|
174
|
+
>>> id_perm = torch.arange(matrix.shape[0])
|
|
166
175
|
>>> minimize_bandwidth_impl(matrix, id_perm) # [3, 2, 4, 1, 0] does zig-zag
|
|
167
|
-
([3, 2, 4, 1, 0], 2.0)
|
|
176
|
+
(tensor([3, 2, 4, 1, 0]), 2.0)
|
|
168
177
|
|
|
169
178
|
Simple 1D chain. Cannot be optimised further
|
|
170
|
-
>>> matrix =
|
|
179
|
+
>>> matrix = torch.tensor([
|
|
171
180
|
... [0, 1, 0, 0, 0],
|
|
172
181
|
... [1, 0, 1, 0, 0],
|
|
173
182
|
... [0, 1, 0, 1, 0],
|
|
174
183
|
... [0, 0, 1, 0, 1],
|
|
175
|
-
... [0, 0, 0, 1, 0]])
|
|
176
|
-
>>> id_perm =
|
|
184
|
+
... [0, 0, 0, 1, 0]], dtype=torch.float32)
|
|
185
|
+
>>> id_perm = torch.arange(matrix.shape[0])
|
|
177
186
|
>>> minimize_bandwidth_impl(matrix, id_perm)
|
|
178
|
-
([0, 1, 2, 3, 4], 1.0)
|
|
187
|
+
(tensor([0, 1, 2, 3, 4]), 1.0)
|
|
179
188
|
"""
|
|
180
|
-
|
|
181
|
-
|
|
189
|
+
L = matrix.shape[0]
|
|
190
|
+
if not torch.equal(initial_perm, torch.arange(L)):
|
|
191
|
+
matrix = permute_tensor(matrix, initial_perm)
|
|
182
192
|
bandwidth = matrix_bandwidth(matrix)
|
|
183
193
|
acc_permutation = initial_perm
|
|
184
194
|
|
|
@@ -191,28 +201,28 @@ def minimize_bandwidth_impl(
|
|
|
191
201
|
)
|
|
192
202
|
|
|
193
203
|
optimal_perm = minimize_bandwidth_global(matrix)
|
|
194
|
-
test_mat =
|
|
204
|
+
test_mat = permute_tensor(matrix, optimal_perm)
|
|
195
205
|
new_bandwidth = matrix_bandwidth(test_mat)
|
|
196
206
|
|
|
197
207
|
if bandwidth <= new_bandwidth:
|
|
198
208
|
break
|
|
199
209
|
|
|
200
210
|
matrix = test_mat
|
|
201
|
-
acc_permutation =
|
|
211
|
+
acc_permutation = permute_tensor(acc_permutation, optimal_perm)
|
|
202
212
|
bandwidth = new_bandwidth
|
|
203
213
|
|
|
204
214
|
return acc_permutation, bandwidth
|
|
205
215
|
|
|
206
216
|
|
|
207
|
-
def minimize_bandwidth(input_matrix:
|
|
217
|
+
def minimize_bandwidth(input_matrix: torch.Tensor, samples: int = 100) -> torch.Tensor:
|
|
208
218
|
assert is_symmetric(input_matrix), "Input matrix is not symmetric"
|
|
209
|
-
input_mat = abs(input_matrix)
|
|
219
|
+
input_mat = torch.abs(input_matrix)
|
|
210
220
|
# We are interested in strength of the interaction, not sign
|
|
211
221
|
|
|
212
222
|
L = input_mat.shape[0]
|
|
213
|
-
rnd_permutations
|
|
214
|
-
[
|
|
215
|
-
|
|
223
|
+
rnd_permutations = itertools.chain(
|
|
224
|
+
[torch.arange(L)], # identity permutation
|
|
225
|
+
[torch.randperm(L) for _ in range(samples)], # list of random permutations
|
|
216
226
|
)
|
|
217
227
|
|
|
218
228
|
opt_permutations_and_opt_bandwidth = (
|