emu-base 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emu_base/__init__.py +47 -0
- emu_base/base_classes/__init__.py +31 -0
- emu_base/base_classes/aggregators.py +59 -0
- emu_base/base_classes/backend.py +48 -0
- emu_base/base_classes/callback.py +90 -0
- emu_base/base_classes/config.py +81 -0
- emu_base/base_classes/default_callbacks.py +300 -0
- emu_base/base_classes/operator.py +126 -0
- emu_base/base_classes/results.py +174 -0
- emu_base/base_classes/state.py +97 -0
- emu_base/lindblad_operators.py +44 -0
- emu_base/math/__init__.py +3 -0
- emu_base/math/brents_root_finding.py +121 -0
- emu_base/math/krylov_exp.py +127 -0
- emu_base/pulser_adapter.py +248 -0
- emu_base/utils.py +9 -0
- emu_base-1.2.1.dist-info/METADATA +134 -0
- emu_base-1.2.1.dist-info/RECORD +19 -0
- emu_base-1.2.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Iterable
|
|
5
|
+
|
|
6
|
+
from emu_base.base_classes.state import State
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# A single-qudit operator written as a combination of operator symbols:
# symbol -> coefficient, e.g. {"rg": 1.0, "gr": 1.0} or {"X": 2.0}.
QuditOp = dict[str, complex]

# A product term: a list of (QuditOp, target qubit indices) pairs, each
# QuditOp being applied to every listed qubit.
TensorOp = list[tuple[QuditOp, list[int]]]

# A weighted sum of product terms: [(coefficient, TensorOp), ...].
FullOp = list[tuple[complex, TensorOp]]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Operator(ABC):
    """
    Base class enforcing an API for quantum operators.

    Each backend implements its own type of operator, together with the
    methods below.
    """

    @abstractmethod
    def __mul__(self, other: State) -> State:
        """
        Apply the operator to a state.

        Args:
            other: the state to apply this operator to

        Returns:
            the resulting state
        """
        pass

    @abstractmethod
    def __add__(self, other: Operator) -> Operator:
        """
        Computes the sum of two operators.

        Args:
            other: the other operator

        Returns:
            the summed operator
        """
        pass

    @abstractmethod
    def expect(self, state: State) -> float | complex:
        """
        Compute the expectation value of self on the given state.

        Args:
            state: the state on which to compute the expectation value

        Returns:
            the expectation value
        """

    @staticmethod
    @abstractmethod
    def from_operator_string(
        basis: Iterable[str],
        nqubits: int,
        operations: FullOp,
        # NOTE: mutable default kept for signature compatibility with existing
        # implementations; implementations must not mutate it.
        operators: dict[str, QuditOp] = {},
        /,
        **kwargs: Any,
    ) -> Operator:
        """
        Create an operator in the backend-specific format from the
        pulser abstract representation
        <https://www.notion.so/pasqal/Abstract-State-and-Operator-Definition>
        By default it supports strings 'ij', where i and j are in basis,
        to denote |i><j|, but additional symbols can be defined in operators.
        For a list of existing bases, see
        <https://pulser.readthedocs.io/en/stable/conventions.html>

        Args:
            basis: the eigenstates in the basis to use
            nqubits: how many qubits there are in the operator's register
            operations: the weighted sum of product terms defining the
                operator, as a list of (coefficient, TensorOp) pairs
            operators: additional symbols to be used in operations, mapping
                each symbol to its QuditOp
            **kwargs: additional, backend-specific keyword arguments

        Returns:
            the operator in whatever format the backend provides.

        Examples:
            >>> basis = {"r", "g"}  # rydberg basis
            >>> nqubits = 3  # or whatever
            >>> x = {"rg": 1.0, "gr": 1.0}
            >>> z = {"gg": 1.0, "rr": -1.0}
            >>> operators = {"X": x, "Z": z}  # define X and Z as conveniences
            >>>
            >>> # a single product term: 1.0 * (2X on qubits 0 and 2) (3Z on qubit 1),
            >>> # i.e. 12 * X_0 Z_1 X_2
            >>> operations = [
            ...     (
            ...         1.0,
            ...         [
            ...             ({"X": 2.0}, [0, 2]),
            ...             ({"Z": 3.0}, [1]),
            ...         ],
            ...     )
            ... ]
            >>> op = Operator.from_operator_string(basis, nqubits, operations, operators)
        """
        pass

    @abstractmethod
    def __rmul__(self, scalar: complex) -> Operator:
        """
        Scale the operator by a scale factor (``scalar * operator``).

        Args:
            scalar: the scale factor

        Returns:
            the scaled operator
        """
        pass

    @abstractmethod
    def __matmul__(self, other: Operator) -> Operator:
        """
        Compose two operators. The ordering is that
        self is applied after other.

        Args:
            other: the operator to compose with self

        Returns:
            the composed operator
        """
        pass
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Any, Callable, Optional
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from emu_base.base_classes.callback import Callback, AggregationType
|
|
8
|
+
from emu_base.base_classes.aggregators import aggregation_types_definitions
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class Results:
    """
    This class contains emulation results. Since the results written by
    an emulator are defined through callbacks, the contents of this class
    are not known a-priori.

    Values are stored per callback name and evaluation time and can be read
    back with ``results[name]`` (a ``{time: value}`` dict) or
    ``results[name, time]``.
    """

    statistics: Any = None  # Backend-specific data

    # callback name -> {evaluation time -> stored value}
    _results: dict[str, dict[int, Any]] = field(default_factory=dict)
    # callback name -> default aggregation used by `aggregate`
    # (None means the observable is skipped when aggregating)
    _default_aggregation_types: dict[str, Optional[AggregationType]] = field(
        default_factory=dict
    )

    @classmethod
    def aggregate(
        cls,
        results_to_aggregate: list["Results"],
        **aggregator_functions: Callable[[Any], Any],
    ) -> "Results":
        """
        Combine several Results (e.g. Monte-Carlo runs) into a single one.

        Args:
            results_to_aggregate: the results to combine; they must all contain
                the same observables, stored at the same times.
            aggregator_functions: per-observable overrides keyed by callback
                name. Each value is either a callable receiving the list of
                values stored at one time, or a key of
                `aggregation_types_definitions`.

        Returns:
            A new Results holding, for every observable whose aggregation is
            not None, the aggregated value at each evaluation time. The
            aggregated Results keeps the default aggregation types of its
            inputs, so it can itself be aggregated again. A single input is
            returned as-is.

        Raises:
            ValueError: if no results are given, or if the inputs do not
                contain the same observables or evaluation times.
        """
        if not results_to_aggregate:
            raise ValueError("no results to aggregate")

        if len(results_to_aggregate) == 1:
            return results_to_aggregate[0]

        reference = results_to_aggregate[0]
        # Iterate in the reference's insertion order rather than over a set,
        # so the aggregated result has a deterministic key order.
        result_names = reference.get_result_names()
        stored_callbacks = set(result_names)

        if not all(
            set(results.get_result_names()) == stored_callbacks
            for results in results_to_aggregate
        ):
            raise ValueError(
                "Monte-Carlo results seem to provide from incompatible simulations: "
                "they do not all contain the same observables"
            )

        aggregated: Results = cls()

        for stored_callback in result_names:
            default_aggregation_type = reference.get_aggregation_type(stored_callback)
            aggregation_type = aggregator_functions.get(
                stored_callback,
                default_aggregation_type,
            )

            if aggregation_type is None:
                logging.getLogger("global_logger").warning(
                    f"Skipping aggregation of `{stored_callback}`"
                )
                continue

            # Either a user-provided callable, or a key into the predefined table.
            aggregation_function: Any = (
                aggregation_type
                if callable(aggregation_type)
                else aggregation_types_definitions[aggregation_type]
            )

            evaluation_times = reference.get_result_times(stored_callback)
            if not all(
                results.get_result_times(stored_callback) == evaluation_times
                for results in results_to_aggregate
            ):
                raise ValueError(
                    "Monte-Carlo results seem to provide from incompatible simulations: "
                    "the callbacks are not stored at the same times"
                )

            aggregated._results[stored_callback] = {
                t: aggregation_function(
                    [result[stored_callback, t] for result in results_to_aggregate]
                )
                for t in evaluation_times
            }
            # Without this, get_aggregation_type() on the aggregated Results
            # raises KeyError and it cannot be aggregated again.
            aggregated._default_aggregation_types[stored_callback] = (
                default_aggregation_type
            )

        return aggregated

    def store(self, *, callback: Callback, time: Any, value: Any) -> None:
        """
        Store the value produced by `callback` at `time`.

        Also records `callback.default_aggregation_type`, which `aggregate`
        uses when no override is given.

        Raises:
            ValueError: if a value is already stored for this callback at `time`.
        """
        self._results.setdefault(callback.name, {})

        if time in self._results[callback.name]:
            raise ValueError(
                f"A value is already stored for observable '{callback.name}' at time {time}"
            )

        self._results[callback.name][time] = value
        self._default_aggregation_types[callback.name] = callback.default_aggregation_type

    def __getitem__(self, key: Any) -> Any:
        """
        Read stored values.

        ``results[name, time]`` returns the value stored at that time;
        ``results[name]`` returns the ``{time: value}`` dict for that name.

        Raises:
            ValueError: if nothing is stored for the name (or at the time).
        """
        if isinstance(key, tuple):
            # results["energy", t]
            callback_name, time = key

            if callback_name not in self._results:
                raise ValueError(
                    f"No value for observable '{callback_name}' has been stored"
                )

            if time not in self._results[callback_name]:
                raise ValueError(
                    f"No value stored at time {time} for observable '{callback_name}'"
                )

            return self._results[callback_name][time]

        # results["energy"][t]
        assert isinstance(key, str)
        callback_name = key
        if callback_name not in self._results:
            raise ValueError(f"No value for observable '{callback_name}' has been stored")

        return self._results[key]

    def get_result_names(self) -> list[str]:
        """
        get a list of results present in this object

        Returns:
            list of results by name, in insertion order
        """
        return list(self._results.keys())

    def get_result_times(self, name: str) -> list[int]:
        """
        get a list of times for which the given result has been stored

        Args:
            name: name of the result to get times of

        Returns:
            list of times in ns
        """
        return list(self._results[name].keys())

    def get_result(self, name: str, time: int) -> Any:
        """
        get the given result at the given time

        Args:
            name: name of the result to get
            time: time in ns at which to get the result

        Returns:
            the result
        """
        return self._results[name][time]

    def get_aggregation_type(self, name: str) -> Optional[AggregationType]:
        """Return the default aggregation recorded for result `name`."""
        return self._default_aggregation_types[name]

    def dump(self, file_path: Path) -> None:
        """Write the stored observables and statistics to `file_path` as JSON."""
        with file_path.open("w") as file_handle:
            json.dump(
                {
                    "observables": self._results,
                    "statistics": self.statistics,
                },
                file_handle,
            )
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any, Iterable
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections import Counter
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class State(ABC):
    """
    Base class enforcing an API for quantum states.

    Each backend implements its own type of state, together with the
    methods below.
    """

    @abstractmethod
    def inner(self, other: State) -> float | complex:
        """
        Compute the inner product between this state and other.
        Note that self is the left state in the inner product,
        so this function is linear in other, and anti-linear in self.

        Args:
            other: the other state

        Returns:
            inner product
        """
        pass

    @abstractmethod
    def sample(
        self, num_shots: int, p_false_pos: float = 0.0, p_false_neg: float = 0.0
    ) -> Counter[str]:
        """
        Sample bitstrings from the state, taking into account error rates.

        Args:
            num_shots: how many bitstrings to sample
            p_false_pos: the rate at which a 0 is read as a 1
            p_false_neg: the rate at which a 1 is read as a 0

        Returns:
            the measured bitstrings, by count
        """
        pass

    @abstractmethod
    def __add__(self, other: State) -> State:
        """
        Computes the sum of two states.

        Args:
            other: the other state

        Returns:
            the summed state
        """
        pass

    @abstractmethod
    def __rmul__(self, scalar: complex) -> State:
        """
        Scale the state by a scale factor (``scalar * state``).

        Args:
            scalar: the scale factor

        Returns:
            the scaled state
        """
        pass

    @staticmethod
    @abstractmethod
    def from_state_string(
        *, basis: Iterable[str], nqubits: int, strings: dict[str, complex], **kwargs: Any
    ) -> State:
        """
        Construct a state from the pulser abstract representation
        <https://www.notion.so/pasqal/Abstract-State-and-Operator-Definition>
        For a list of existing bases, see
        <https://pulser.readthedocs.io/en/stable/conventions.html>

        Args:
            basis: the eigenstates of the basis to use, e.g. ("r", "g").
            nqubits: the number of qubits.
            strings: a dictionary mapping basis-state strings (one character
                per qubit) to their complex or float amplitudes.
            **kwargs: additional, backend-specific keyword arguments.

        Returns:
            the state in whatever format the backend provides.

        Examples:
            >>> ghz_string_state = {"rrr": 1.0 / math.sqrt(2), "ggg": 1.0 / math.sqrt(2)}
            >>> ghz_state = State.from_state_string(
            ...     basis=("r", "g"), nqubits=3, strings=ghz_string_state
            ... )
        """
        pass
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from pulser.noise_model import NoiseModel
|
|
2
|
+
import torch
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_lindblad_operators(
    *, noise_type: str, noise_model: NoiseModel
) -> list[torch.Tensor]:
    """
    Build the single-qubit Lindblad (jump) operators for one noise channel.

    All returned operators are 2x2 complex128 tensors, already scaled by the
    square root of the corresponding rate.

    Args:
        noise_type: one of "relaxation", "dephasing", "depolarizing" or
            "eff_noise"; it must be listed in `noise_model.noise_types`.
        noise_model: the pulser noise model providing rates and operators.

    Returns:
        The list of Lindblad operators for that channel.

    Raises:
        ValueError: if `noise_type` is not enabled in `noise_model`, is
            unknown, or an effective noise operator is not 2x2.
        NotImplementedError: if a non-zero hyperfine dephasing rate is set.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if noise_type not in noise_model.noise_types:
        raise ValueError(
            f"Noise type {noise_type} is not among the noise model's noise types"
        )

    if noise_type == "relaxation":
        # Single operator mapping basis index 1 to basis index 0.
        c = math.sqrt(noise_model.relaxation_rate)
        return [
            torch.tensor([[0, c], [0, 0]], dtype=torch.complex128),
        ]

    if noise_type == "dephasing":
        if noise_model.hyperfine_dephasing_rate != 0.0:
            raise NotImplementedError("hyperfine_dephasing_rate is unsupported")

        c = math.sqrt(noise_model.dephasing_rate / 2)
        return [
            torch.tensor([[-c, 0], [0, c]], dtype=torch.complex128),
        ]

    if noise_type == "depolarizing":
        # Pauli X, Y, Z (up to an overall sign, irrelevant for Lindblad terms),
        # each scaled by sqrt(rate / 4).
        c = math.sqrt(noise_model.depolarizing_rate / 4)
        return [
            torch.tensor([[0, c], [c, 0]], dtype=torch.complex128),
            torch.tensor([[0, 1j * c], [-1j * c, 0]], dtype=torch.complex128),
            torch.tensor([[-c, 0], [0, c]], dtype=torch.complex128),
        ]

    if noise_type == "eff_noise":
        if not all(op.shape == (2, 2) for op in noise_model.eff_noise_opers):
            raise ValueError("Only 2 * 2 effective noise operator matrices are supported")

        lindblad_ops = []
        for rate, op in zip(noise_model.eff_noise_rates, noise_model.eff_noise_opers):
            matrix = op if isinstance(op, torch.Tensor) else torch.tensor(op)
            # Cast so user-supplied (real or complex64) matrices match the
            # complex128 dtype of the built-in channels above.
            matrix = matrix.to(torch.complex128)
            # Reverse both axes. NOTE(review): presumably converts pulser's
            # basis ordering to the one used here — confirm.
            lindblad_ops.append(math.sqrt(rate) * torch.flip(matrix, (0, 1)))
        return lindblad_ops

    raise ValueError(f"Unknown noise type: {noise_type}")
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from typing import Callable, Optional
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class BrentsRootFinder:
    """
    Caller-driven implementation of Brent's root-finding method.

    The finder never evaluates the function itself. Repeatedly:
        x = finder.get_next_abscissa()
        finder.provide_ordinate(x, f(x))
    until finder.is_converged(tolerance); `current_guess` is the best estimate.

    State (names as in the code):
        a, fa: contrapoint; the update rule keeps fa and fb of opposite signs
        b, fb: current best estimate (the swaps keep |fb| <= |fa|)
        c, fc: value of b before the latest update (initially a)
        d: value of c before that (initially a)
        bisection: whether the previous proposed step was a bisection step
    """

    def __init__(
        self,
        *,
        start: float,
        end: float,
        f_start: float,
        f_end: float,
        epsilon: float = 1e-6,
    ):
        # Used both as an absolute threshold when comparing function values
        # and as a relative step-size threshold (see `delta` below).
        self.epsilon = epsilon

        assert start <= end
        self.a, self.b = start, end
        self.fa = f_start
        self.fb = f_end

        assert self.fa * self.fb < 0, "Function root needs to be between a and b"

        # b has to be the better guess
        if abs(self.fa) < abs(self.fb):
            self.a, self.b = self.b, self.a
            self.fa, self.fb = self.fb, self.fa

        self.c = self.a
        self.d = self.c
        self.fc = self.fa

        self.bisection = True
        self.current_guess = self.b
        # Last abscissa handed out; provide_ordinate checks it is answered.
        self.next_abscissa: Optional[float] = None

    def get_next_abscissa(self) -> float:
        """
        Propose the next point at which the caller must evaluate the function.

        Returns:
            the abscissa; its function value must be passed to provide_ordinate.
        """
        # Inverse quadratic interpolation needs three distinct function values;
        # fall back to the secant step when fc is (nearly) equal to fa or fb.
        if abs(self.fc - self.fa) < self.epsilon or abs(self.fc - self.fb) < self.epsilon:
            # Secant method
            dx = self.fb * (self.b - self.a) / (self.fa - self.fb)
        else:
            # Inverse quadratic interpolation
            # (step dx = P/Q with S = fb/fa, R = fb/fc, T = fa/fc)
            s = self.fb / self.fa
            r = self.fb / self.fc
            t = self.fa / self.fc
            q = (t - 1) * (s - 1) * (r - 1)
            p = s * (t * (r - t) * (self.c - self.b) + (r - 1) * (self.b - self.a))
            dx = p / q

        # Use bisection instead of interpolation
        # if the interpolation is not within bounds.
        # delta: step-size threshold, relative to |b|.
        delta = abs(2 * self.epsilon * self.b)
        adx = abs(dx)
        delta_bc = abs(self.b - self.c)  # how far b moved in the last update
        delta_cd = abs(self.c - self.d)  # how far b moved in the update before
        delta_ab = self.a - self.b
        if (
            # new point not between (3a + b) / 4 and b
            (adx >= abs(3 * delta_ab / 4) or dx * delta_ab < 0)
            # step not shrinking fast enough compared to earlier moves of b
            or (self.bisection and adx >= delta_bc / 2)
            or (not self.bisection and adx >= delta_cd / 2)
            # earlier moves of b already below the threshold delta
            or (self.bisection and delta_bc < delta)
            or (not self.bisection and delta_cd < delta)
        ):
            dx = (self.a - self.b) / 2
            self.bisection = True
        else:
            self.bisection = False

        self.next_abscissa = self.b + dx
        # Shift history: d <- c, c <- b (b itself is updated in provide_ordinate).
        self.d = self.c
        self.c, self.fc = self.b, self.fb

        return self.next_abscissa

    def provide_ordinate(self, abscissa: float, ordinate: float) -> None:
        """
        Feed back f(abscissa) for the abscissa returned by get_next_abscissa.

        Args:
            abscissa: must equal the last value returned by get_next_abscissa
            ordinate: the function value at abscissa
        """
        # First argument is just a safety
        assert (
            self.next_abscissa is not None and abscissa == self.next_abscissa
        ), "Something went wrong"

        # Update interval: replace whichever endpoint keeps a sign change
        # between fa and fb.
        if self.fa * ordinate < 0:
            self.b, self.fb = abscissa, ordinate
        else:
            self.a, self.fa = abscissa, ordinate

        # b has to be the better guess
        if abs(self.fa) < abs(self.fb):
            self.a, self.b = self.b, self.a
            self.fa, self.fb = self.fb, self.fa

        self.current_guess = self.b

    def is_converged(self, tolerance: float) -> bool:
        """Return True once the bracket [a, b] is narrower than `tolerance`."""
        return abs(self.b - self.a) < tolerance
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def find_root_brents(
    f: Callable[[float], float],
    *,
    start: float,
    end: float,
    f_start: Optional[float] = None,
    f_end: Optional[float] = None,
    tolerance: float = 1e-6,
    epsilon: float = 1e-6,
) -> float:
    """
    Approximates and returns the zero of a scalar function using Brent's method.

    Args:
        f: the scalar function whose root is sought.
        start: left end of the bracketing interval.
        end: right end of the bracketing interval; f must change sign on it.
        f_start: optional precomputed f(start), to skip evaluating f there.
        f_end: optional precomputed f(end), to skip evaluating f there.
        tolerance: iteration stops once the bracket is narrower than this.
        epsilon: forwarded to BrentsRootFinder.

    Returns:
        The root finder's best estimate once converged.
    """
    if f_start is None:
        f_start = f(start)
    if f_end is None:
        f_end = f(end)

    finder = BrentsRootFinder(
        start=start,
        end=end,
        f_start=f_start,
        f_end=f_end,
        epsilon=epsilon,
    )

    while True:
        if finder.is_converged(tolerance):
            return finder.current_guess
        abscissa = finder.get_next_abscissa()
        finder.provide_ordinate(abscissa, f(abscissa))
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
import torch
|
|
3
|
+
|
|
4
|
+
# Default cap on the number of Krylov iterations (the Krylov subspace
# dimension) when callers do not pass `max_krylov_dim`.
DEFAULT_MAX_KRYLOV_DIM: int = 100
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class KrylovExpResult:
    """
    Plain record describing the outcome of a Krylov exponentiation.

    Attributes:
        result: the computed vector
        converged: whether the computation is reported as converged
        happy_breakdown: whether it stopped on a happy breakdown; a happy
            breakdown is only allowed together with convergence
        iteration_count: number of iterations performed
    """

    def __init__(
        self,
        result: torch.Tensor,
        converged: bool,
        happy_breakdown: bool,
        iteration_count: int,
    ):
        # happy_breakdown implies converged
        assert converged or not happy_breakdown

        self.result = result
        self.converged = converged
        self.happy_breakdown = happy_breakdown
        self.iteration_count = iteration_count
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def krylov_exp_impl(
    op: Callable,
    v: torch.Tensor,
    is_hermitian: bool,  # note: complex-proportional to its adjoint is enough
    exp_tolerance: float,
    norm_tolerance: float,
    max_krylov_dim: int = DEFAULT_MAX_KRYLOV_DIM,
) -> KrylovExpResult:
    """
    Computes exp(op).v using either the Lanczos or Arnoldi algorithm,
    based on the `is_hermitian` flag.
    All inputs must be on the same device.

    Convergence is checked using the exponential of the "extended T matrix", a criterion
    described in "Expokit: A Software Package for Computing Matrix Exponentials"
    (https://www.maths.uq.edu.au/expokit/paper.pdf).

    Args:
        op: callable applying the operator to a tensor shaped like `v`.
        v: the vector to which exp(op) is applied.
        is_hermitian: if True, only orthogonalize against the two most recent
            Krylov vectors (Lanczos); otherwise against all of them (Arnoldi).
        exp_tolerance: threshold on the local error estimate for convergence.
        norm_tolerance: if the newly orthogonalized vector has a norm below
            this, stop early (happy breakdown).
        max_krylov_dim: maximum number of iterations; must be at least 1.

    Returns:
        A KrylovExpResult; `converged` is False if `max_krylov_dim`
        iterations did not bring the error estimate below `exp_tolerance`.

    Raises:
        ValueError: if `max_krylov_dim` is smaller than 1.
    """
    # With zero iterations `expd` below would never be defined.
    if max_krylov_dim < 1:
        raise ValueError(f"max_krylov_dim must be at least 1, got {max_krylov_dim}")

    initial_norm = v.norm()

    if initial_norm == 0:
        # exp(op) maps the zero vector to itself; without this early exit the
        # normalization below would divide by zero and return NaNs.
        return KrylovExpResult(
            result=torch.zeros_like(v),
            converged=True,
            happy_breakdown=True,
            iteration_count=0,
        )

    lanczos_vectors = [v / initial_norm]
    # Projected matrix, with two spare rows/columns for the extended matrix
    # used by the error estimate. It is created without a `device=` argument,
    # i.e. on PyTorch's default device. NOTE(review): confirm this is intended
    # when `v` lives on a GPU.
    T = torch.zeros(max_krylov_dim + 2, max_krylov_dim + 2, dtype=v.dtype)

    def _combine(coefficients: torch.Tensor) -> torch.Tensor:
        # initial_norm * sum_k coefficients[k] * lanczos_vectors[k]
        return initial_norm * sum(a * b for a, b in zip(coefficients, lanczos_vectors))

    for j in range(max_krylov_dim):
        w = op(lanczos_vectors[-1])

        n = w.norm()

        # Gram-Schmidt against the previous Krylov vectors (only the last two
        # in the Lanczos case).
        k_start = max(0, j - 1) if is_hermitian else 0
        for k in range(k_start, j + 1):
            overlap = torch.tensordot(lanczos_vectors[k].conj(), w, dims=w.dim())
            T[k, j] = overlap
            w = w - overlap * lanczos_vectors[k]

        n2 = w.norm()
        T[j + 1, j] = n2

        if n2 < norm_tolerance:
            # Happy breakdown
            expd = torch.linalg.matrix_exp(T[: j + 1, : j + 1])
            return KrylovExpResult(
                result=_combine(expd[:, 0]),
                converged=True,
                happy_breakdown=True,
                iteration_count=j + 1,
            )

        lanczos_vectors.append(w / n2)

        # Compute exponential of extended T matrix
        T[j + 2, j + 1] = 1
        expd = torch.linalg.matrix_exp(T[: j + 3, : j + 3])

        # Local truncation error estimation
        err1 = abs(expd[j + 1, 0])
        err2 = abs(expd[j + 2, 0] * n)

        # `<=` so that err1 == err2 (in particular both 0) does not lead to a
        # 0/0 = NaN estimate that could never pass the tolerance test.
        err = err1 if err1 <= err2 else (err1 * err2 / (err1 - err2))

        if err < exp_tolerance:
            # Converged
            return KrylovExpResult(
                result=_combine(expd[: len(lanczos_vectors), 0]),
                converged=True,
                happy_breakdown=False,
                iteration_count=j + 1,
            )

    # Not converged within max_krylov_dim iterations: return the last estimate.
    return KrylovExpResult(
        result=_combine(expd[: len(lanczos_vectors), 0]),
        converged=False,
        happy_breakdown=False,
        iteration_count=max_krylov_dim,
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def krylov_exp(
    op: Callable[[torch.Tensor], torch.Tensor],
    v: torch.Tensor,
    exp_tolerance: float,
    norm_tolerance: float,
    is_hermitian: bool = True,  # note: complex-proportional to its adjoint is enough
    max_krylov_dim: int = DEFAULT_MAX_KRYLOV_DIM,
) -> torch.Tensor:
    """
    Compute exp(op).v with `krylov_exp_impl`, requiring convergence.

    `op` is applied as a function (it is called on Krylov vectors by
    `krylov_exp_impl`), hence the Callable annotation.

    Args:
        op: callable applying the operator to a tensor shaped like `v`.
        v: the vector to which exp(op) is applied.
        exp_tolerance: threshold on the local error estimate.
        norm_tolerance: norm threshold triggering a happy breakdown.
        is_hermitian: select Lanczos (True) or Arnoldi (False).
        max_krylov_dim: maximum number of Krylov iterations.

    Returns:
        The approximation of exp(op).v.

    Raises:
        RecursionError: if the algorithm does not reach `exp_tolerance`
            within `max_krylov_dim` iterations.
    """
    krylov_result = krylov_exp_impl(
        op,
        v,
        is_hermitian=is_hermitian,
        exp_tolerance=exp_tolerance,
        norm_tolerance=norm_tolerance,
        max_krylov_dim=max_krylov_dim,
    )

    if not krylov_result.converged:
        raise RecursionError(
            "exponentiation algorithm did not converge to precision in allotted number of steps."
        )

    return krylov_result.result
|