emu-mps 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emu_mps/__init__.py +3 -1
- emu_mps/hamiltonian.py +333 -347
- emu_mps/mpo.py +1 -1
- emu_mps/mps.py +33 -5
- emu_mps/mps_backend_impl.py +10 -8
- emu_mps/observables.py +40 -0
- emu_mps/optimatrix/__init__.py +4 -4
- emu_mps/optimatrix/optimiser.py +76 -66
- emu_mps/optimatrix/permutations.py +57 -49
- emu_mps/tdvp.py +1 -1
- {emu_mps-2.0.0.dist-info → emu_mps-2.0.2.dist-info}/METADATA +2 -2
- emu_mps-2.0.2.dist-info/RECORD +19 -0
- emu_mps-2.0.0.dist-info/RECORD +0 -18
- {emu_mps-2.0.0.dist-info → emu_mps-2.0.2.dist-info}/WHEEL +0 -0
emu_mps/mpo.py
CHANGED
@@ -83,7 +83,7 @@ class MPO(Operator[complex, torch.Tensor, MPS]):
             other.factors,
             config=other.config,
         )
-        return MPS(factors, orthogonality_center=0)
+        return MPS(factors, orthogonality_center=0, eigenstates=other.eigenstates)

     def __add__(self, other: MPO) -> MPO:
         """
emu_mps/mps.py
CHANGED
@@ -38,6 +38,7 @@ class MPS(State[complex, torch.Tensor]):
         orthogonality_center: Optional[int] = None,
         config: Optional[MPSConfig] = None,
         num_gpus_to_use: Optional[int] = DEVICE_COUNT,
+        eigenstates: Sequence[Eigenstate] = ("r", "g"),
     ):
         """
         This constructor creates a MPS directly from a list of tensors. It is for internal use only.

@@ -56,7 +57,7 @@ class MPS(State[complex, torch.Tensor]):
         num_gpus_to_use: distribute the factors over this many GPUs
             0=all factors to cpu, None=keep the existing device assignment.
         """
-
+        super().__init__(eigenstates=eigenstates)
         self.config = config if config is not None else MPSConfig()
         assert all(
             factors[i - 1].shape[2] == factors[i].shape[0] for i in range(1, len(factors))

@@ -88,6 +89,7 @@ class MPS(State[complex, torch.Tensor]):
         num_sites: int,
         config: Optional[MPSConfig] = None,
         num_gpus_to_use: int = DEVICE_COUNT,
+        eigenstates: Sequence[Eigenstate] = ["0", "1"],
     ) -> MPS:
         """
         Returns a MPS in ground state |000..0>.

@@ -111,6 +113,7 @@ class MPS(State[complex, torch.Tensor]):
             config=config,
             num_gpus_to_use=num_gpus_to_use,
             orthogonality_center=0,  # Arbitrary: every qubit is an orthogonality center.
+            eigenstates=eigenstates,
         )

     def __repr__(self) -> str:
@@ -307,6 +310,25 @@ class MPS(State[complex, torch.Tensor]):
         """
         return torch.abs(self.inner(other)) ** 2  # type: ignore[no-any-return]

+    def entanglement_entropy(self, mps_site: int) -> torch.Tensor:
+        """
+        Returns the Von Neumann entanglement entropy of the state at the bond
+        between sites b and b+1,
+            S = -Σᵢ sᵢ² log(sᵢ²),
+        where sᵢ are the singular values at the chosen bond.
+        """
+        self.orthogonalize(mps_site)
+
+        # perform svd on reshaped matrix at site b
+        matrix = self.factors[mps_site].flatten(end_dim=1)
+        s = torch.linalg.svdvals(matrix)
+
+        s_e = torch.Tensor(torch.special.entr(s**2))
+        s_e = torch.sum(s_e)
+
+        self.orthogonalize(0)
+        return s_e.cpu()
+
     def get_memory_footprint(self) -> float:
         """
         Returns the number of MBs of memory occupied to store the state
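The entropy formula in the new `entanglement_entropy` method can be sanity-checked on known singular values; a minimal standalone sketch using only torch, independent of the MPS class:

    import torch

    # Singular values across the middle bond of a two-qubit Bell state.
    s = torch.tensor([2**-0.5, 2**-0.5])

    # torch.special.entr(x) computes -x * ln(x) elementwise, so summing it over
    # the squared singular values gives S = -Σᵢ sᵢ² log(sᵢ²).
    entropy = torch.sum(torch.special.entr(s**2))
    print(entropy)  # tensor(0.6931) = ln(2), i.e. one maximally entangled pair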
@@ -331,12 +353,16 @@ class MPS(State[complex, torch.Tensor]):
             the summed state
         """
         assert isinstance(other, MPS), "Other state also needs to be an MPS"
+        assert (
+            self.eigenstates == other.eigenstates
+        ), f"`Other` state has basis {other.eigenstates} != {self.eigenstates}"
         new_tt = add_factors(self.factors, other.factors)
         result = MPS(
             new_tt,
             config=self.config,
             num_gpus_to_use=None,
             orthogonality_center=None,  # Orthogonality is lost.
+            eigenstates=self.eigenstates,
         )
         result.truncate()
         return result
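This assertion turns silent basis mixups into immediate failures. A hypothetical sketch (the ground-state constructor shown above is assumed here to be exposed as `MPS.make`; the exact name is not visible in this diff):

    from emu_mps import MPS

    a = MPS.make(2, eigenstates=["0", "1"])
    b = MPS.make(2, eigenstates=["r", "g"])

    _ = a + a  # fine: both operands share the same eigenbasis
    _ = a + b  # AssertionError: `Other` state has basis ['r', 'g'] != ['0', '1']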
@@ -362,6 +388,7 @@ class MPS(State[complex, torch.Tensor]):
             config=self.config,
             num_gpus_to_use=None,
             orthogonality_center=self.orthogonality_center,
+            eigenstates=self.eigenstates,
         )

     def __imul__(self, scalar: complex) -> MPS:

@@ -371,7 +398,7 @@ class MPS(State[complex, torch.Tensor]):
     def _from_state_amplitudes(
         cls,
         *,
-        eigenstates: Sequence[
+        eigenstates: Sequence[Eigenstate],
         amplitudes: Mapping[str, complex],
     ) -> tuple[MPS, Mapping[str, complex]]:
         """

@@ -385,8 +412,6 @@ class MPS(State[complex, torch.Tensor]):
         Returns:
             The resulting MPS representation of the state.
         """
-
-        nqubits = len(next(iter(amplitudes.keys())))
         basis = set(eigenstates)
         if basis == {"r", "g"}:
             one = "r"

@@ -395,17 +420,20 @@ class MPS(State[complex, torch.Tensor]):
         else:
             raise ValueError("Unsupported basis provided")

+        nqubits = cls._validate_amplitudes(amplitudes, eigenstates)
+
         basis_0 = torch.tensor([[[1.0], [0.0]]], dtype=torch.complex128)  # ground state
         basis_1 = torch.tensor([[[0.0], [1.0]]], dtype=torch.complex128)  # excited state

         accum_mps = MPS(
             [torch.zeros((1, 2, 1), dtype=torch.complex128)] * nqubits,
             orthogonality_center=0,
+            eigenstates=eigenstates,
         )

         for state, amplitude in amplitudes.items():
             factors = [basis_1 if ch == one else basis_0 for ch in state]
-            accum_mps += amplitude * MPS(factors)
+            accum_mps += amplitude * MPS(factors, eigenstates=eigenstates)
         norm = accum_mps.norm()
         if not math.isclose(1.0, norm, rel_tol=1e-5, abs_tol=0.0):
             print("\nThe state is not normalized, normalizing it for you.")
emu_mps/mps_backend_impl.py
CHANGED
@@ -151,14 +151,14 @@ class MPSBackendImpl:
     def __getstate__(self) -> dict:
         for obs in self.config.observables:
             obs.apply = MethodType(type(obs).apply, obs)  # type: ignore[method-assign]
-        d = self.__dict__
+        d = self.__dict__.copy()
         # mypy thinks the method below is an attribute, because of the __getattr__ override
         d["results"] = self.results._to_abstract_repr()  # type: ignore[operator]
         return d

     def __setstate__(self, d: dict) -> None:
-        d["results"] = Results._from_abstract_repr(d["results"])  # type: ignore [attr-defined]
         self.__dict__ = d
+        self.results = Results._from_abstract_repr(d["results"])  # type: ignore [attr-defined]
         self.config.monkeypatch_observables()

     @staticmethod
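The `.copy()` matters because `__getstate__` previously handed out the live instance `__dict__` and then overwrote its `"results"` entry, so merely pickling the backend corrupted the running object. A minimal reproduction of that failure mode in plain Python:

    import pickle

    class Broken:
        def __init__(self) -> None:
            self.results = {"x": 1}

        def __getstate__(self) -> dict:
            d = self.__dict__            # aliases the live instance dict
            d["results"] = "serialized"  # side effect: mutates the object itself
            return d

    obj = Broken()
    pickle.dumps(obj)
    print(obj.results)  # 'serialized' -- the live object was changed by pickling

Returning `self.__dict__.copy()`, and restoring `results` after assigning `__dict__` in `__setstate__`, keeps serialization free of side effects.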
@@ -208,6 +208,7 @@ class MPSBackendImpl:
             [f.clone().detach() for f in initial_state.factors],
             config=self.config,
             num_gpus_to_use=self.config.num_gpus_to_use,
+            eigenstates=initial_state.eigenstates,
         )
         initial_state.truncate()
         initial_state *= 1 / initial_state.norm()

@@ -444,7 +445,6 @@ class MPSBackendImpl:
         basename = self.autosave_file
         with open(basename.with_suffix(".new"), "wb") as file_handle:
             pickle.dump(self, file_handle)
-
         if basename.is_file():
             os.rename(basename, basename.with_suffix(".bak"))
@@ -497,13 +497,15 @@ class MPSBackendImpl:
         )
         full_state = MPS(
             extended_mps_factors(
-                normalized_state.factors,
+                normalized_state.factors,
+                self.well_prepared_qubits_filter,
             ),
             num_gpus_to_use=None,  # Keep the already assigned devices.
             orthogonality_center=get_extended_site_index(
                 self.well_prepared_qubits_filter,
                 normalized_state.orthogonality_center,
             ),
+            eigenstates=normalized_state.eigenstates,
         )

         callback(self.config, fractional_time, full_state, full_mpo, self.results)
@@ -536,7 +538,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):

     def set_jump_threshold(self, bound: float) -> None:
         self.jump_threshold = random.uniform(0.0, bound)
-        self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
+        self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold

     def init(self) -> None:
         self.init_lindblad_noise()

@@ -547,7 +549,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
         previous_time = self.current_time
         self.current_time = self.target_time
         previous_norm_gap_before_jump = self.norm_gap_before_jump
-        self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
+        self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold

         if self.root_finder is None:
             # No quantum jump location finding in progress

@@ -567,7 +569,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):

             return

-        self.norm_gap_before_jump = self.state.norm() ** 2 - self.jump_threshold
+        self.norm_gap_before_jump = self.state.norm().item() ** 2 - self.jump_threshold
         self.root_finder.provide_ordinate(self.current_time, self.norm_gap_before_jump)

         if self.root_finder.is_converged(tolerance=1):

@@ -593,7 +595,7 @@ class NoisyMPSBackendImpl(MPSBackendImpl):
         self.state *= 1 / self.state.norm()
         self.init_baths()

-        norm_after_normalizing = self.state.norm()
+        norm_after_normalizing = self.state.norm().item()
         assert math.isclose(norm_after_normalizing, 1, abs_tol=1e-10)
         self.set_jump_threshold(norm_after_normalizing**2)
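The recurring `.norm()` → `.norm().item()` change converts the 0-dimensional tensor returned by `norm()` into a plain Python float before it feeds the jump-threshold bookkeeping, so comparisons, `random.uniform` bounds, and `math.isclose` all operate on ordinary floats instead of tensors:

    import torch

    norm = torch.tensor([0.6, 0.8]).norm()
    print(type(norm))         # <class 'torch.Tensor'> (0-dimensional)
    print(type(norm.item()))  # <class 'float'>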
emu_mps/observables.py
ADDED
@@ -0,0 +1,40 @@
+from pulser.backend.state import State
+from pulser.backend.observable import Observable
+from emu_mps.mps import MPS
+from typing import Sequence, Any
+import torch
+
+
+class EntanglementEntropy(Observable):
+    """Entanglement Entropy subclass used only in emu_mps"""
+
+    def __init__(
+        self,
+        mps_site: int,
+        *,
+        evaluation_times: Sequence[float] | None = None,
+        tag_suffix: str | None = None,
+    ):
+        super().__init__(evaluation_times=evaluation_times, tag_suffix=tag_suffix)
+        self.mps_site = mps_site
+
+    @property
+    def _base_tag(self) -> str:
+        return "entanglement_entropy"
+
+    def _to_abstract_repr(self) -> dict[str, Any]:
+        repr = super()._to_abstract_repr()
+        repr["mps_site"] = self.mps_site
+        return repr
+
+    def apply(self, *, state: State, **kwargs: Any) -> torch.Tensor:
+        if not isinstance(state, MPS):
+            raise NotImplementedError(
+                "Entanglement entropy observable is only available for emu_mps emulator."
+            )
+        if not (0 <= self.mps_site <= len(state.factors) - 2):
+            raise ValueError(
+                f"Invalid bond index {self.mps_site}. "
+                f"Expected value in range 0 <= bond_index <= {len(state.factors)-2}."
+            )
+        return state.entanglement_entropy(self.mps_site)
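A usage sketch for the new observable; the `MPSConfig(observables=...)` wiring is an assumption based on the `self.config.observables` loop in `mps_backend_impl.py` above, not something this diff shows directly:

    from emu_mps import MPSConfig
    from emu_mps.observables import EntanglementEntropy

    # Track the entropy across the bond between sites 3 and 4 at the end of the run.
    entropy_obs = EntanglementEntropy(mps_site=3, evaluation_times=[1.0])
    config = MPSConfig(observables=[entropy_obs])  # assumed constructor keyword

Non-MPS states raise NotImplementedError, and `mps_site` must name one of the `num_sites - 1` bonds.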
emu_mps/optimatrix/__init__.py
CHANGED
@@ -1,9 +1,9 @@
 from .optimiser import minimize_bandwidth
-from .permutations import
+from .permutations import permute_tensor, inv_permutation, permute_string

 __all__ = [
     "minimize_bandwidth",
-    "
-    "
-    "
+    "permute_string",
+    "permute_tensor",
+    "inv_permutation",
 ]
emu_mps/optimatrix/optimiser.py
CHANGED
@@ -1,21 +1,18 @@
+import itertools
+
+import torch
 from scipy.sparse import csr_matrix
 from scipy.sparse.csgraph import reverse_cuthill_mckee
-import numpy as np
-from emu_mps.optimatrix.permutations import permute_matrix, permute_list
-import itertools

+from emu_mps.optimatrix.permutations import permute_tensor

-def is_symmetric(mat: np.ndarray) -> bool:
-    if mat.shape[0] != mat.shape[1]:
-        return False
-    if not np.allclose(mat, mat.T, atol=1e-8):
-        return False
-
+def is_symmetric(matrix: torch.Tensor, tol: float = 1e-8) -> bool:
+    return torch.allclose(matrix, matrix.T, atol=tol)


-def matrix_bandwidth(mat:
-    """matrix_bandwidth(matrix:
+def matrix_bandwidth(mat: torch.Tensor) -> float:
+    """matrix_bandwidth(matrix: torch.tensor) -> torch.Tensor

     Computes bandwidth as max weighted distance between columns of
     a square matrix as `max (abs(matrix[i, j] * (j - i))`.

@@ -45,19 +42,27 @@ def matrix_bandwidth(mat: np.ndarray) -> float:

     Example:
     -------
-    >>> matrix =
-    ...
-    ...
-    ...
+    >>> matrix = torch.tensor([
+    ...     [1.0, -17.0, 2.4],
+    ...     [9.0, 1.0, -10.0],
+    ...     [-15.0, 20.0, 1.0]
+    ... ])
+    >>> matrix_bandwidth(matrix)  # because abs(-15 * (0 - 2)) = 30.0
     30.0
     """
-
-    return float(bandwidth)
+    n = mat.shape[0]

+    i_arr = torch.arange(n).view(-1, 1)  # shape (n, 1)
+    j_arr = torch.arange(n).view(1, -1)  # shape (1, n)

+    weighted = torch.abs(mat * (j_arr - i_arr))
+    return torch.max(weighted).to(mat.dtype).item()
+
+
+def minimize_bandwidth_above_threshold(
+    mat: torch.Tensor, threshold: float
+) -> torch.Tensor:
     """
     minimize_bandwidth_above_threshold(matrix, trunc) -> permutation_lists
@@ -78,24 +83,25 @@ def minimize_bandwidth_above_threshold(mat: np.ndarray, threshold: float) -> np.

     Example:
     -------
-    >>> matrix =
-    ...
-    ...
-    ...
+    >>> matrix = torch.tensor([
+    ...     [1, 2, 3],
+    ...     [2, 5, 6],
+    ...     [3, 6, 9]
+    ... ], dtype=torch.float32)
     >>> threshold = 3
     >>> minimize_bandwidth_above_threshold(matrix, threshold)
-
+    tensor([1, 2, 0], dtype=torch.int32)
     """
-
-    rcm_permutation = reverse_cuthill_mckee(
-        csr_matrix(matrix_truncated), symmetric_mode=True
-    )
-    return np.array(rcm_permutation)
+    m_trunc = mat.clone()
+    m_trunc[mat < threshold] = 0.0

+    matrix_np = csr_matrix(m_trunc.numpy())  # SciPy's RCM compatibility
+    rcm_perm = reverse_cuthill_mckee(matrix_np, symmetric_mode=True)
+    return torch.from_numpy(rcm_perm.copy())  # translation requires copy


+def minimize_bandwidth_global(mat: torch.Tensor) -> torch.Tensor:
     """
     minimize_bandwidth_global(matrix) -> list
@@ -111,74 +117,78 @@ def minimize_bandwidth_global(mat: np.ndarray) -> list[int]:
     -------
     permutation order that minimizes matrix bandwidth

     Example
     -------
-    >>> matrix =
-    ...
-    ...
-    ...
+    >>> matrix = torch.tensor([
+    ...     [1, 2, 3],
+    ...     [2, 5, 6],
+    ...     [3, 6, 9]
+    ... ], dtype=torch.float32)
     >>> minimize_bandwidth_global(matrix)
-    [2, 1, 0]
+    tensor([2, 1, 0], dtype=torch.int32)
     """
-    mat_amplitude =
-
+    mat_amplitude = torch.max(torch.abs(mat))
+
     permutations = (
-        minimize_bandwidth_above_threshold(mat, trunc * mat_amplitude)
-        for trunc in
+        minimize_bandwidth_above_threshold(mat, trunc.item() * mat_amplitude)
+        for trunc in torch.arange(0.1, 1.0, 0.01)
     )

     opt_permutation = min(
-        permutations, key=lambda perm: matrix_bandwidth(
+        permutations, key=lambda perm: matrix_bandwidth(permute_tensor(mat, perm))
     )
-
+
+    return opt_permutation


 def minimize_bandwidth_impl(
-    matrix:
-) -> tuple[
+    matrix: torch.Tensor, initial_perm: torch.Tensor
+) -> tuple[torch.Tensor, float]:
     """
-    minimize_bandwidth_impl(matrix, initial_perm) ->
+    minimize_bandwidth_impl(matrix, initial_perm) -> (optimal_perm, bandwidth)

     Applies initial_perm to a matrix and
-    finds the permutation list for a symmetric matrix
+    finds the permutation list for a symmetric matrix
+    that iteratively minimizes matrix bandwidth.

     Parameters
     -------
     matrix :
         symmetric square matrix
-    initial_perm: list of integers
+    initial_perm: torch list of integers

     Returns
     -------
-    permutation
+    optimal permutation and optimal matrix bandwidth

     Example:
     -------
     Periodic 1D chain
-    >>> matrix =
+    >>> matrix = torch.tensor([
     ...     [0, 1, 0, 0, 1],
     ...     [1, 0, 1, 0, 0],
     ...     [0, 1, 0, 1, 0],
     ...     [0, 0, 1, 0, 1],
-    ...     [1, 0, 0, 1, 0]])
-    >>> id_perm =
+    ...     [1, 0, 0, 1, 0]], dtype=torch.float32)
+    >>> id_perm = torch.arange(matrix.shape[0])
     >>> minimize_bandwidth_impl(matrix, id_perm)  # [3, 2, 4, 1, 0] does zig-zag
-    ([3, 2, 4, 1, 0], 2.0)
+    (tensor([3, 2, 4, 1, 0]), 2.0)

     Simple 1D chain. Cannot be optimised further
-    >>> matrix =
+    >>> matrix = torch.tensor([
     ...     [0, 1, 0, 0, 0],
     ...     [1, 0, 1, 0, 0],
     ...     [0, 1, 0, 1, 0],
     ...     [0, 0, 1, 0, 1],
-    ...     [0, 0, 0, 1, 0]])
-    >>> id_perm =
+    ...     [0, 0, 0, 1, 0]], dtype=torch.float32)
+    >>> id_perm = torch.arange(matrix.shape[0])
     >>> minimize_bandwidth_impl(matrix, id_perm)
-    ([0, 1, 2, 3, 4], 1.0)
+    (tensor([0, 1, 2, 3, 4]), 1.0)
     """
-
-
+    L = matrix.shape[0]
+    if not torch.equal(initial_perm, torch.arange(L)):
+        matrix = permute_tensor(matrix, initial_perm)
     bandwidth = matrix_bandwidth(matrix)
     acc_permutation = initial_perm
@@ -191,28 +201,28 @@ def minimize_bandwidth_impl(
         )

         optimal_perm = minimize_bandwidth_global(matrix)
-        test_mat =
+        test_mat = permute_tensor(matrix, optimal_perm)
         new_bandwidth = matrix_bandwidth(test_mat)

         if bandwidth <= new_bandwidth:
             break

         matrix = test_mat
-        acc_permutation =
+        acc_permutation = permute_tensor(acc_permutation, optimal_perm)
         bandwidth = new_bandwidth

     return acc_permutation, bandwidth


-def minimize_bandwidth(input_matrix:
+def minimize_bandwidth(input_matrix: torch.Tensor, samples: int = 100) -> torch.Tensor:
     assert is_symmetric(input_matrix), "Input matrix is not symmetric"
-    input_mat = abs(input_matrix)
+    input_mat = torch.abs(input_matrix)
     # We are interested in strength of the interaction, not sign

     L = input_mat.shape[0]
-    rnd_permutations
-    [
-
+    rnd_permutations = itertools.chain(
+        [torch.arange(L)],  # identity permutation
+        [torch.randperm(L) for _ in range(samples)],  # list of random permutations
     )

     opt_permutations_and_opt_bandwidth = (
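Putting the optimiser pieces together, a small end-to-end sketch on the periodic-chain matrix from the docstrings above (the returned permutation can vary between runs because random starting permutations are sampled):

    import torch
    from emu_mps.optimatrix import minimize_bandwidth
    from emu_mps.optimatrix.permutations import permute_tensor

    # Periodic 5-site chain: sites 0 and 4 interact, so the raw bandwidth is 4.
    chain = torch.tensor([
        [0, 1, 0, 0, 1],
        [1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0],
        [0, 0, 1, 0, 1],
        [1, 0, 0, 1, 0],
    ], dtype=torch.float32)

    perm = minimize_bandwidth(chain)         # e.g. tensor([3, 2, 4, 1, 0])
    reordered = permute_tensor(chain, perm)  # zig-zag ordering with bandwidth 2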
emu_mps/optimatrix/permutations.py
CHANGED

@@ -1,36 +1,31 @@
-import
+import torch


-def
+def permute_string(input_str: str, perm: torch.Tensor) -> str:
     """
-    Permutes the input
-
+    Permutes the input string according to the given permutation.
     Parameters
     -------
-
-    A
+    input_string :
+        A string to permute.
     permutation :
         A list of indices representing the new order.
-
     Returns
     -------
-    The permuted
-
+        The permuted string.
     Example
     -------
-    >>>
-
+    >>> permute_string("abc", torch.tensor([2, 0, 1]))
+    'cab'
     """
+    char_list = list(input_str)
+    permuted = [char_list[i] for i in perm.tolist()]
+    return "".join(permuted)

-    permuted_list = [None] * len(input_list)
-    for i, p in enumerate(permutation):
-        permuted_list[i] = input_list[p]
-    return permuted_list

-
-def invert_permutation(permutation: list[int]) -> list[int]:
+def inv_permutation(permutation: torch.Tensor) -> torch.Tensor:
     """
-
+    inv_permutation(permutation) -> inverted_perm

     Inverts the input permutation list.
@@ -45,47 +40,60 @@ def invert_permutation(permutation: list[int]) -> list[int]:

     Example:
     -------
-    >>>
-    [1, 2, 0]
+    >>> inv_permutation(torch.tensor([2, 0, 1]))
+    tensor([1, 2, 0])
     """
-
-    inv_perm =
-    inv_perm
-    return list(inv_perm)
+    inv_perm = torch.empty_like(permutation)
+    inv_perm[permutation] = torch.arange(len(permutation))
+    return inv_perm


-def
+def permute_tensor(tensor: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
     """
-
-
-
+    Permute a 1D or square 2D torch tensor using the given permutation indices.
+    For 1D tensors, applies the permutation to the elements.
+    For 2D square tensors, applies the same permutation to both rows and columns.

     Parameters
-
-
-    square
-
-    permutation
+    ----------
+    tensor : torch.Tensor
+        A 1D or 2D square tensor to be permuted.
+    perm : torch.Tensor
+        A 1D tensor of indices specifying the permutation order.

     Returns
     -------
-
-
-
+    torch.Tensor
+        A new tensor with elements (1D) or rows and columns (2D) permuted according to `perm`.
+
+    Raises
+    ------
+    ValueError
+        If tensor is not 1D or square 2D.
+
+    Examples
+    --------
+    >>> vector = torch.tensor([10, 20, 30])
+    >>> perm = torch.tensor([2, 0, 1])
+    >>> permute_tensor(vector, perm)
+    tensor([30, 10, 20])
+
+    >>> matrix = torch.tensor([
+    ...     [1, 2, 3],
+    ...     [4, 5, 6],
+    ...     [7, 8, 9]])
+    >>> perm = torch.tensor([1, 0, 2])
+    >>> permute_tensor(matrix, perm)
+    tensor([[5, 4, 6],
+            [2, 1, 3],
+            [8, 7, 9]])
     """
-
-
-
+    if tensor.ndim == 1:
+        return tensor[perm]
+    elif tensor.ndim == 2 and tensor.shape[0] == tensor.shape[1]:
+        return tensor[perm][:, perm]
+    else:
+        raise ValueError("Only 1D tensors or square 2D tensors are supported.")


 if __name__ == "__main__":
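The two helpers are mutual inverses in the expected sense; a quick round-trip check:

    import torch
    from emu_mps.optimatrix.permutations import inv_permutation, permute_tensor

    perm = torch.tensor([2, 0, 1])
    v = torch.tensor([10, 20, 30])

    # Applying a permutation and then its inverse restores the original order.
    assert torch.equal(permute_tensor(permute_tensor(v, perm), inv_permutation(perm)), v)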
emu_mps/tdvp.py
CHANGED