qilisdk 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qilisdk/analog/__init__.py +1 -2
- qilisdk/analog/hamiltonian.py +1 -68
- qilisdk/analog/schedule.py +288 -313
- qilisdk/backends/backend.py +5 -1
- qilisdk/backends/cuda_backend.py +9 -5
- qilisdk/backends/qutip_backend.py +23 -12
- qilisdk/core/__init__.py +4 -0
- qilisdk/core/interpolator.py +406 -0
- qilisdk/core/parameterizable.py +66 -10
- qilisdk/core/variables.py +150 -7
- qilisdk/digital/circuit.py +1 -0
- qilisdk/digital/circuit_transpiler.py +46 -0
- qilisdk/digital/circuit_transpiler_passes/__init__.py +18 -0
- qilisdk/digital/circuit_transpiler_passes/circuit_transpiler_pass.py +36 -0
- qilisdk/digital/circuit_transpiler_passes/decompose_multi_controlled_gates_pass.py +216 -0
- qilisdk/digital/circuit_transpiler_passes/numeric_helpers.py +82 -0
- qilisdk/digital/gates.py +12 -2
- qilisdk/{speqtrum/experiments → experiments}/__init__.py +13 -2
- qilisdk/{speqtrum/experiments → experiments}/experiment_functional.py +90 -2
- qilisdk/{speqtrum/experiments → experiments}/experiment_result.py +16 -0
- qilisdk/functionals/sampling.py +8 -1
- qilisdk/functionals/time_evolution.py +6 -2
- qilisdk/functionals/variational_program.py +58 -0
- qilisdk/speqtrum/speqtrum.py +360 -130
- qilisdk/speqtrum/speqtrum_models.py +108 -19
- qilisdk/utils/openfermion/__init__.py +38 -0
- qilisdk/{core/algorithm.py → utils/openfermion/__init__.pyi} +2 -3
- qilisdk/utils/openfermion/openfermion.py +45 -0
- qilisdk/utils/visualization/schedule_renderers.py +16 -8
- {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/METADATA +74 -24
- {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/RECORD +33 -26
- {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/WHEEL +1 -1
- qilisdk/analog/linear_schedule.py +0 -121
- {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/licenses/LICENCE +0 -0
qilisdk/backends/backend.py
CHANGED
|
@@ -89,7 +89,11 @@ class Backend(ABC):
|
|
|
89
89
|
|
|
90
90
|
def evaluate_sample(parameters: list[float]) -> float:
|
|
91
91
|
param_names = functional.functional.get_parameter_names()
|
|
92
|
-
|
|
92
|
+
param_dict = {param_names[i]: param for i, param in enumerate(parameters)}
|
|
93
|
+
err = functional.check_parameter_constraints(param_dict)
|
|
94
|
+
if err > 0:
|
|
95
|
+
return err
|
|
96
|
+
functional.functional.set_parameters(param_dict)
|
|
93
97
|
results = self.execute(functional.functional)
|
|
94
98
|
final_results = functional.cost_function.compute_cost(results)
|
|
95
99
|
if isinstance(final_results, float):
|
qilisdk/backends/cuda_backend.py
CHANGED
|
@@ -25,6 +25,7 @@ from loguru import logger
|
|
|
25
25
|
from qilisdk.analog.hamiltonian import Hamiltonian, PauliI, PauliOperator, PauliX, PauliY, PauliZ
|
|
26
26
|
from qilisdk.backends.backend import Backend
|
|
27
27
|
from qilisdk.core.qtensor import QTensor
|
|
28
|
+
from qilisdk.digital.circuit_transpiler_passes import DecomposeMultiControlledGatesPass
|
|
28
29
|
from qilisdk.digital.exceptions import UnsupportedGateError
|
|
29
30
|
from qilisdk.digital.gates import RX, RY, RZ, SWAP, U1, U2, U3, Adjoint, BasicGate, Controlled, H, I, M, S, T, X, Y, Z
|
|
30
31
|
from qilisdk.functionals.sampling_result import SamplingResult
|
|
@@ -137,13 +138,14 @@ class CudaBackend(Backend):
|
|
|
137
138
|
kernel = cudaq.make_kernel()
|
|
138
139
|
qubits = kernel.qalloc(functional.circuit.nqubits)
|
|
139
140
|
|
|
140
|
-
|
|
141
|
+
transpiled_circuit = DecomposeMultiControlledGatesPass().run(functional.circuit)
|
|
142
|
+
for gate in transpiled_circuit.gates:
|
|
141
143
|
if isinstance(gate, Controlled):
|
|
142
144
|
self._handle_controlled(kernel, gate, qubits[gate.control_qubits[0]], qubits[gate.target_qubits[0]])
|
|
143
145
|
elif isinstance(gate, Adjoint):
|
|
144
146
|
self._handle_adjoint(kernel, gate, qubits[gate.target_qubits[0]])
|
|
145
147
|
elif isinstance(gate, M):
|
|
146
|
-
self._handle_M(kernel, gate,
|
|
148
|
+
self._handle_M(kernel, gate, transpiled_circuit, qubits)
|
|
147
149
|
else:
|
|
148
150
|
handler = self._basic_gate_handlers.get(type(gate), None)
|
|
149
151
|
if handler is None:
|
|
@@ -158,12 +160,14 @@ class CudaBackend(Backend):
|
|
|
158
160
|
logger.info("Executing TimeEvolution (T={}, dt={})", functional.schedule.T, functional.schedule.dt)
|
|
159
161
|
cudaq.set_target("dynamics")
|
|
160
162
|
|
|
161
|
-
steps = np.linspace(0, functional.schedule.T, (round(functional.schedule.T
|
|
163
|
+
steps = np.linspace(0, functional.schedule.T, (round(functional.schedule.T // functional.schedule.dt) + 1))
|
|
164
|
+
tlist = np.array(functional.schedule.tlist)
|
|
165
|
+
steps = np.union1d(steps, tlist)
|
|
162
166
|
|
|
163
167
|
cuda_schedule = CudaSchedule(steps, ["t"])
|
|
164
168
|
|
|
165
|
-
def get_schedule(key: str) -> Callable:
|
|
166
|
-
return lambda t: (functional.schedule.
|
|
169
|
+
def get_schedule(key: str) -> Callable[[complex], float]:
|
|
170
|
+
return lambda t: (functional.schedule.coefficients[key][t.real])
|
|
167
171
|
|
|
168
172
|
cuda_hamiltonian = sum(
|
|
169
173
|
ScalarOperator(get_schedule(key)) * self._hamiltonian_to_cuda(ham)
|
|
@@ -27,6 +27,7 @@ from qilisdk.analog.hamiltonian import Hamiltonian, PauliI, PauliOperator
|
|
|
27
27
|
from qilisdk.backends.backend import Backend
|
|
28
28
|
from qilisdk.core.qtensor import QTensor, tensor_prod
|
|
29
29
|
from qilisdk.digital import RX, RY, RZ, SWAP, U1, U2, U3, Circuit, H, I, M, S, T, X, Y, Z
|
|
30
|
+
from qilisdk.digital.circuit_transpiler_passes import DecomposeMultiControlledGatesPass
|
|
30
31
|
from qilisdk.digital.exceptions import UnsupportedGateError
|
|
31
32
|
from qilisdk.digital.gates import Adjoint, BasicGate, Controlled
|
|
32
33
|
from qilisdk.functionals.sampling_result import SamplingResult
|
|
@@ -117,18 +118,17 @@ class QutipBackend(Backend):
|
|
|
117
118
|
|
|
118
119
|
"""
|
|
119
120
|
logger.info("Executing Sampling (shots={})", functional.nshots)
|
|
120
|
-
qutip_circuit = self._get_qutip_circuit(functional.circuit)
|
|
121
121
|
|
|
122
|
-
counts: Counter[str] = Counter()
|
|
123
122
|
init_state = tensor(*[basis(2, 0) for _ in range(functional.circuit.nqubits)])
|
|
124
123
|
|
|
125
124
|
measurements_set = set()
|
|
126
125
|
for m in functional.circuit.gates:
|
|
127
126
|
if isinstance(m, M):
|
|
128
127
|
measurements_set.update(list(m.target_qubits))
|
|
129
|
-
|
|
130
128
|
measurements = sorted(measurements_set)
|
|
131
129
|
|
|
130
|
+
transpiled_circuit = DecomposeMultiControlledGatesPass().run(functional.circuit)
|
|
131
|
+
qutip_circuit = self._get_qutip_circuit(transpiled_circuit)
|
|
132
132
|
sim = CircuitSimulator(qutip_circuit)
|
|
133
133
|
|
|
134
134
|
res = sim.run_statistics(init_state) # runs the full circuit for one shot
|
|
@@ -172,7 +172,9 @@ class QutipBackend(Backend):
|
|
|
172
172
|
ValueError: if the initial state provided is invalid.
|
|
173
173
|
"""
|
|
174
174
|
logger.info("Executing TimeEvolution (T={}, dt={})", functional.schedule.T, functional.schedule.dt)
|
|
175
|
-
|
|
175
|
+
steps = np.linspace(0, functional.schedule.T, int(functional.schedule.T // functional.schedule.dt))
|
|
176
|
+
tlist = np.array(functional.schedule.tlist)
|
|
177
|
+
steps = np.union1d(steps, tlist)
|
|
176
178
|
|
|
177
179
|
qutip_hamiltonians = []
|
|
178
180
|
for hamiltonian in functional.schedule.hamiltonians.values():
|
|
@@ -185,7 +187,7 @@ class QutipBackend(Backend):
|
|
|
185
187
|
H_t = [
|
|
186
188
|
[
|
|
187
189
|
qutip_hamiltonians[i],
|
|
188
|
-
np.array([functional.schedule.
|
|
190
|
+
np.array([functional.schedule.coefficients[h][t] for t in tlist]),
|
|
189
191
|
]
|
|
190
192
|
for i, h in enumerate(functional.schedule.hamiltonians)
|
|
191
193
|
]
|
|
@@ -308,10 +310,10 @@ class QutipBackend(Backend):
|
|
|
308
310
|
|
|
309
311
|
def _handle_controlled(self, circuit: QubitCircuit, gate: Controlled) -> None: # noqa: PLR6301
|
|
310
312
|
"""
|
|
311
|
-
Handle a controlled gate operation.
|
|
313
|
+
Handle a controlled gate operation by registering a custom QuTiP gate.
|
|
312
314
|
|
|
313
|
-
|
|
314
|
-
|
|
315
|
+
For non-native controlled gates we construct the block-matrix explicitly, mirroring
|
|
316
|
+
the approach recommended in the QuTiP QIP documentation for custom controlled rotations.
|
|
315
317
|
|
|
316
318
|
Raises:
|
|
317
319
|
UnsupportedGateError: If the number of control qubits is not equal to one or if the basic gate is unsupported.
|
|
@@ -320,13 +322,22 @@ class QutipBackend(Backend):
|
|
|
320
322
|
logger.error("Controlled gate with {} control qubits not supported", len(gate.control_qubits))
|
|
321
323
|
raise UnsupportedGateError
|
|
322
324
|
|
|
323
|
-
def qutip_controlled_gate() -> Qobj:
|
|
324
|
-
return QutipGates.controlled_gate(Qobj(gate.basic_gate.matrix), controls=0, targets=1)
|
|
325
|
-
|
|
326
325
|
if gate.name == "CNOT":
|
|
327
326
|
circuit.add_gate("CNOT", targets=[*gate.target_qubits], controls=[*gate.control_qubits])
|
|
328
327
|
else:
|
|
329
|
-
|
|
328
|
+
base_matrix = gate.basic_gate.matrix
|
|
329
|
+
dim_target = base_matrix.shape[0]
|
|
330
|
+
dim_total = 2 * dim_target
|
|
331
|
+
dims = [[2] + [2] * len(gate.target_qubits), [2] + [2] * len(gate.target_qubits)]
|
|
332
|
+
|
|
333
|
+
def qutip_controlled_gate() -> Qobj:
|
|
334
|
+
mat = np.zeros((dim_total, dim_total), dtype=np.complex128)
|
|
335
|
+
mat[:dim_target, :dim_target] = np.eye(dim_target, dtype=np.complex128)
|
|
336
|
+
mat[dim_target:, dim_target:] = base_matrix
|
|
337
|
+
return Qobj(mat, dims=dims)
|
|
338
|
+
|
|
339
|
+
matrix_digest = base_matrix.tobytes().hex()[:16]
|
|
340
|
+
gate_name = f"{gate.name}_{matrix_digest}"
|
|
330
341
|
if gate_name not in circuit.user_gates:
|
|
331
342
|
circuit.user_gates[gate_name] = qutip_controlled_gate
|
|
332
343
|
circuit.add_gate(gate_name, targets=[*gate.control_qubits, *gate.target_qubits])
|
qilisdk/core/__init__.py
CHANGED
|
@@ -22,6 +22,7 @@ from .variables import (
|
|
|
22
22
|
LT,
|
|
23
23
|
NEQ,
|
|
24
24
|
BinaryVariable,
|
|
25
|
+
Domain,
|
|
25
26
|
Equal,
|
|
26
27
|
GreaterThan,
|
|
27
28
|
GreaterThanOrEqual,
|
|
@@ -30,6 +31,7 @@ from .variables import (
|
|
|
30
31
|
NotEqual,
|
|
31
32
|
Parameter,
|
|
32
33
|
SpinVariable,
|
|
34
|
+
Term,
|
|
33
35
|
Variable,
|
|
34
36
|
)
|
|
35
37
|
|
|
@@ -42,6 +44,7 @@ __all__ = [
|
|
|
42
44
|
"NEQ",
|
|
43
45
|
"BinaryVariable",
|
|
44
46
|
"Constraint",
|
|
47
|
+
"Domain",
|
|
45
48
|
"Equal",
|
|
46
49
|
"GreaterThan",
|
|
47
50
|
"GreaterThanOrEqual",
|
|
@@ -54,6 +57,7 @@ __all__ = [
|
|
|
54
57
|
"Parameter",
|
|
55
58
|
"QTensor",
|
|
56
59
|
"SpinVariable",
|
|
60
|
+
"Term",
|
|
57
61
|
"Variable",
|
|
58
62
|
"basis_state",
|
|
59
63
|
"bra",
|
|
@@ -0,0 +1,406 @@
|
|
|
1
|
+
# Copyright 2025 Qilimanjaro Quantum Tech
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import inspect
|
|
17
|
+
from bisect import bisect_right
|
|
18
|
+
from collections.abc import Callable
|
|
19
|
+
from copy import copy
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from typing import Any, Mapping
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
|
|
25
|
+
from qilisdk.core.parameterizable import Parameterizable
|
|
26
|
+
from qilisdk.core.variables import LEQ, BaseVariable, Number, Parameter, Term
|
|
27
|
+
from qilisdk.yaml import yaml
|
|
28
|
+
|
|
29
|
+
# Label of the symbolic variable that stands for time inside coefficient expressions and callables.
_TIME_PARAMETER_NAME = "t"

# A time point or a coefficient value: either a plain number or a symbolic Parameter / Term.
PARAMETERIZED_NUMBER = float | Parameter | Term

# Input accepted by ``Interpolator``: each key is a single time point or a ``(start, end)`` interval, and each
# value is a coefficient or a callable that produces one.
TimeDict = dict[
    PARAMETERIZED_NUMBER | tuple[float, float],
    PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER],
]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Interpolation(str, Enum):
    """Rule used to obtain a coefficient at a time that lies between two defined schedule points."""

    # Hold the coefficient of the latest defined point at or before the requested time.
    STEP = "Step function interpolation between schedule points"
    # Blend the two neighbouring coefficients proportionally to the distance from each point.
    LINEAR = "linear interpolation between schedule points"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _process_callable(
    function: Callable[..., PARAMETERIZED_NUMBER], current_time: Parameter, **kwargs: Any
) -> tuple[PARAMETERIZED_NUMBER, dict[str, Parameter]]:
    """Evaluate a coefficient callable and collect the ``Parameter`` objects it declares.

    Arguments annotated as ``Parameter`` (the class itself, or the string ``"Parameter"`` produced when the
    callable's module uses ``from __future__ import annotations``) are registered as follows:

    * with a default value: a copy of the default is registered under the default's label;
    * without a default: the value is taken from ``kwargs`` (``0`` when absent). A number is wrapped in a new
      ``Parameter`` named after the argument; a ``Parameter`` is registered under its own label. When the
      argument was not supplied at all, the newly created ``Parameter`` is passed to the callable (otherwise
      the call would fail with a missing required argument).

    If the callable has an argument named ``t`` it receives ``current_time``. Only the keyword arguments the
    callable declares are forwarded, unless it accepts ``**kwargs``, in which case all of them are.

    Args:
        function: Callable producing the coefficient.
        current_time: Symbolic time variable passed as the ``t`` argument.
        **kwargs: Candidate keyword arguments for ``function``.

    Returns:
        tuple: The value returned by ``function`` and the registered parameters keyed by label.

    Raises:
        ValueError: If the returned expression contains variables other than ``Parameter`` objects or time.
    """
    parameters: dict[str, Parameter] = {}

    signature_params = inspect.signature(function).parameters
    empty = inspect.Parameter.empty

    for arg_name, arg_info in signature_params.items():
        annotation = arg_info.annotation
        is_parameter_arg = annotation is Parameter or (isinstance(annotation, str) and annotation == "Parameter")
        if not is_parameter_arg:
            continue

        if arg_info.default is not empty:
            # NOTE(review): the callable is still invoked with its own default object, while the entry registered
            # here is a *copy* of it. Interpolator.add_time_point registers the term's parameters and then
            # overwrites them with these copies, so the registered object may not be the one inside the returned
            # expression — confirm how Parameterizable applies new values before relying on this.
            parameters[arg_info.default.label] = copy(arg_info.default)
            continue

        value = kwargs.get(arg_name, 0)
        if isinstance(value, (float, int)):
            parameters[arg_name] = Parameter(arg_name, value)
            if arg_name not in kwargs:
                # Without this the call below would raise "missing required argument".
                kwargs[arg_name] = parameters[arg_name]
        elif isinstance(value, Parameter):
            parameters[value.label] = value

    if _TIME_PARAMETER_NAME in signature_params:
        kwargs[_TIME_PARAMETER_NAME] = current_time

    # Forward only the keyword arguments the callable can accept, so that extra kwargs intended for other
    # callables in the same schedule do not raise TypeError.
    takes_var_keyword = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in signature_params.values())
    call_kwargs = kwargs if takes_var_keyword else {k: v for k, v in kwargs.items() if k in signature_params}
    term = function(**call_kwargs)

    def _is_allowed(variable: BaseVariable) -> bool:
        # Only symbolic parameters and the time variable may appear in a coefficient expression.
        return isinstance(variable, Parameter) or variable.label == _TIME_PARAMETER_NAME

    if isinstance(term, Term) and not all(_is_allowed(v) for v in term.variables()):
        raise ValueError("function contains variables that are not time. Only Parameters are allowed.")
    if isinstance(term, BaseVariable) and not _is_allowed(term):
        raise ValueError("function contains variables that are not time. Only Parameters are allowed.")
    return term, parameters
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@yaml.register_class
class Interpolator(Parameterizable):
    """Mapping from time points to coefficients that can interpolate between the defined points.

    Time points (keys) may be plain numbers or symbolic ``Parameter``/``Term`` objects; coefficients may be
    numbers, symbolic expressions, or callables that are evaluated once per time point. Looking up a time that
    is not one of the defined points uses step or linear interpolation (see ``Interpolation``).

    Keys are always stored on the original time axis. After ``set_max_time`` every time exposed by this object
    (``tlist``, ``items()``, lookups) is multiplied by a scale factor so that the latest point lands on the
    requested maximum time.
    """

    def __init__(
        self,
        time_dict: TimeDict,
        interpolation: Interpolation = Interpolation.LINEAR,
        nsamples: int = 100,
        **kwargs: Any,
    ) -> None:
        """Initialize an interpolator over discrete points or intervals.

        Args:
            time_dict (TimeDict): Mapping from time points or ``(start, end)`` intervals to coefficients or
                callables. Intervals are expanded into evenly spaced time points.
            interpolation (Interpolation): Interpolation rule between provided points (``LINEAR`` or ``STEP``).
            nsamples (int): Number of samples used to expand interval definitions (at least 2 are used).
            **kwargs: Extra arguments forwarded to callable coefficient processing.

        Raises:
            ValueError: If a time interval contains a number of points different than 2.
        """
        super(Interpolator, self).__init__()
        self._interpolation = interpolation
        self._time_dict: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER] = {}
        # Symbolic time variable handed to callable coefficients as their ``t`` argument.
        self._current_time = Parameter(_TIME_PARAMETER_NAME, 0)
        self._total_time: float | None = None
        self.iter_time_step = 0
        self._cached = False
        # Memoized lookups, keyed by the (scaled) time that was requested.
        self._cached_time: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER | Number] = {}
        # Sorted keys on the original (unscaled) time axis.
        self._tlist: list[PARAMETERIZED_NUMBER] | None = None
        self._fixed_tlist: list[float] | None = None
        self._max_time: PARAMETERIZED_NUMBER | None = None
        self._time_scale_cache: float | None = None

        for time, coefficient in time_dict.items():
            if not isinstance(time, tuple):
                self.add_time_point(time, coefficient, **kwargs)
                continue
            if len(time) != 2:  # noqa: PLR2004
                raise ValueError(
                    f"time intervals need to be defined by two points, but this interval was provided: {time}"
                )
            start, end = time
            for fraction in np.linspace(0, 1, max(2, nsamples)):
                alpha = float(fraction)
                # add_time_point also registers any Parameter contained in the interpolated time point.
                self.add_time_point((1 - alpha) * start + alpha * end, coefficient, **kwargs)
        # Build the sorted key list once, after all points have been inserted.
        self._tlist = self._generate_tlist()

        self._add_ordering_constraints(time_dict)

    def _add_ordering_constraints(self, time_dict: TimeDict) -> None:
        """Register ``LEQ`` constraints that keep every symbolic time point ordered relative to its neighbours."""
        ordered = sorted(
            (bound for key in time_dict for bound in (key if isinstance(key, tuple) else (key,))),
            key=self._get_value,
        )
        last = len(ordered) - 1
        for i, point in enumerate(ordered):
            if not isinstance(point, (Parameter, Term)):
                continue
            candidates = []
            if i > 0:
                candidates.append(LEQ(ordered[i - 1], point))
            if i < last:
                candidates.append(LEQ(point, ordered[i + 1]))
            for constraint in candidates:
                if constraint not in self._parameter_constraints:
                    self._parameter_constraints.append(constraint)

    def _generate_tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """Return the stored keys sorted by their current numeric value (original time axis)."""
        return sorted(self._time_dict.keys(), key=self._get_value)

    def _time_scale(self) -> float:
        """Return the factor mapping stored (unscaled) times to exposed times; ``1.0`` when no max time is set.

        The factor is ``max_time / latest_point`` and is cached until the next cache invalidation. A latest point
        evaluating to zero (or an empty interpolator) is treated as 1 to avoid dividing by zero.
        """
        if self._max_time is None:
            return 1.0
        if self._time_scale_cache is None:
            if self._tlist is None:
                self._tlist = self._generate_tlist()
            latest = self._get_value(max(self._tlist, key=self._get_value)) if self._tlist else 0
            self._time_scale_cache = self._get_value(self._max_time) / (latest or 1)
        return self._time_scale_cache

    @property
    def tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """Sorted time points, rescaled when a max time has been set."""
        if self._tlist is None:
            self._tlist = self._generate_tlist()
        if self._max_time is None:
            return self._tlist
        scale = self._time_scale()
        return [t * scale for t in self._tlist]

    @property
    def fixed_tlist(self) -> list[float]:
        """Numeric values of ``tlist`` (cached until the next cache invalidation)."""
        if self._fixed_tlist is None:
            self._fixed_tlist = [self._get_value(k) for k in self.tlist]
        return self._fixed_tlist

    @property
    def total_time(self) -> float:
        """Latest numeric time point (after rescaling)."""
        if self._total_time is None:
            self._total_time = max(self.fixed_tlist)
        return self._total_time

    def items(self) -> list[tuple[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]]:
        """Return ``(time, coefficient)`` pairs in insertion order, with times rescaled when a max time is set."""
        if self._max_time is None:
            return list(self._time_dict.items())
        scale = self._time_scale()
        return [(k * scale, v) for k, v in self._time_dict.items()]

    def fixed_items(self) -> list[tuple[float, float]]:
        """Return numeric ``(time, coefficient)`` pairs for every defined (rescaled) time point."""
        return [(t, self._get_value(self[t], t)) for t in self.fixed_tlist]

    @property
    def coefficients(self) -> list[PARAMETERIZED_NUMBER]:
        """Coefficients ordered by time."""
        # Index with the stored (unscaled) keys: the rescaled keys from ``tlist`` are not dictionary keys.
        if self._tlist is None:
            self._tlist = self._generate_tlist()
        return [self._time_dict[t] for t in self._tlist]

    @property
    def coefficients_dict(self) -> dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]:
        """Shallow copy of the stored ``{time: coefficient}`` mapping (original time axis)."""
        return copy(self._time_dict)

    @property
    def fixed_coefficeints(self) -> list[float]:
        """Numeric coefficients at every defined (rescaled) time point."""
        return [self._get_value(self[t]) for t in self.fixed_tlist]

    @property
    def fixed_coefficients(self) -> list[float]:
        """Correctly spelled alias of ``fixed_coefficeints``."""
        return self.fixed_coefficeints

    @property
    def parameters(self) -> dict[str, Parameter]:
        """Parameters registered from time points and coefficients, keyed by label."""
        return self._parameters

    def set_max_time(self, max_time: PARAMETERIZED_NUMBER) -> None:
        """
        Rescale all time points to a new maximum duration while keeping relative spacing.

        Raises:
            ValueError: If the max time is set to zero.
        """
        if self._get_value(max_time) == 0:
            raise ValueError("Setting the max time to zero.")
        self._delete_cache()
        self._max_time = max_time

    def _delete_cache(self) -> None:
        """Invalidate every derived/cached value so it is recomputed on next access."""
        self._cached = False
        self._total_time = None
        self._cached_time = {}
        self._tlist = None
        self._fixed_tlist = None
        self._time_scale_cache = None

    def _get_value(self, value: PARAMETERIZED_NUMBER | complex, t: float | None = None) -> float:
        """Evaluate ``value`` to a real number, binding the time variable to ``t`` when given.

        Raises:
            ValueError: If the time Parameter is evaluated without ``t``, or ``value`` has an unsupported type.
        """
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, complex):
            return value.real
        if isinstance(value, Parameter):
            if value.label == _TIME_PARAMETER_NAME:
                if t is None:
                    raise ValueError("Can't evaluate Parameter because time is not provided.")
                value.set_value(t)
            return float(value.evaluate())
        if isinstance(value, Term):
            ctx: Mapping[BaseVariable, list[int] | int | float] = {self._current_time: t} if t is not None else {}
            aux = value.evaluate(ctx)

            return aux.real if isinstance(aux, complex) else float(aux)
        raise ValueError(f"Invalid value of type {type(value)} is being evaluated.")

    def _extract_parameters(self, element: PARAMETERIZED_NUMBER) -> None:
        """Register every non-time ``Parameter`` contained in ``element``.

        Raises:
            ValueError: If ``element`` is a Term containing variables other than parameters.
        """
        if isinstance(element, Parameter) and element.label != _TIME_PARAMETER_NAME:
            self._parameters[element.label] = element
        elif isinstance(element, Term):
            if not element.is_parameterized_term():
                raise ValueError(
                    f"Tlist can only contain parameters and no variables, but the term {element} contains objects other than parameters."
                )
            for p in element.variables():
                if isinstance(p, Parameter) and p.label != _TIME_PARAMETER_NAME:
                    self._parameters[p.label] = p

    def add_time_point(
        self,
        time: PARAMETERIZED_NUMBER,
        coefficient: PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER],
        **kwargs: Any,
    ) -> None:
        """Insert (or overwrite) the coefficient at ``time``.

        When a max time is set, ``time`` is interpreted on the rescaled axis and stored on the original axis.
        Callable coefficients are evaluated once, with ``t`` bound to the stored (original-axis) time, which is
        the axis on which ``get_coefficient`` later evaluates symbolic expressions.

        Raises:
            ValueError: If ``coefficient`` is neither a number, a Parameter/Term, nor a callable.
        """
        self._extract_parameters(time)
        if self._max_time is not None:
            time /= self._time_scale()

        coeff = coefficient
        if callable(coeff):
            self._current_time.set_value(self._get_value(time))
            coeff, callable_params = _process_callable(coeff, self._current_time, **kwargs)
            self._extract_parameters(coeff)
            if callable_params:
                self._parameters.update(callable_params)
        elif isinstance(coeff, (int, float, Parameter, Term)):
            self._extract_parameters(coeff)
        else:
            raise ValueError(f"Unsupported coefficient of type {type(coeff)}.")
        self._time_dict[time] = coeff
        self._delete_cache()

    def set_parameter_values(self, values: list[float]) -> None:
        self._delete_cache()
        super().set_parameter_values(values)

    def set_parameters(self, parameters: dict[str, int | float]) -> None:
        self._delete_cache()
        super().set_parameters(parameters)

    def set_parameter_bounds(self, ranges: dict[str, tuple[float, float]]) -> None:
        self._delete_cache()
        super().set_parameter_bounds(ranges)

    def get_coefficient(self, time_step: float) -> float:
        """Return the numeric coefficient at ``time_step`` (a time on the rescaled axis)."""
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)
        expression = self.get_coefficient_expression(time_step=time_step)

        # Symbolic expressions in ``t`` were built on the original (unscaled) axis.
        if self._max_time is not None:
            time_step /= self._time_scale()

        return self._get_value(expression, time_step)

    def get_coefficient_expression(self, time_step: float) -> Number | Term | Parameter:
        """Return the (possibly symbolic) coefficient at ``time_step`` (a time on the rescaled axis).

        Raises:
            ValueError: If the configured interpolation is not supported.
        """
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)

        self._tlist = self._generate_tlist()

        fixed = self.fixed_tlist
        if time_step in fixed:
            return self._time_dict[self._tlist[fixed.index(time_step)]]
        if time_step in self._cached_time:
            return self._cached_time[time_step]

        requested_time = time_step
        if self._max_time is not None:
            time_step /= self._time_scale()

        if self._interpolation is Interpolation.STEP:
            result = self._get_coefficient_expression_step(time_step)
        elif self._interpolation is Interpolation.LINEAR:
            result = self._get_coefficient_expression_linear(time_step)
        else:
            raise ValueError(f"interpolation {self._interpolation.value} is not supported!")

        # Cache under the exact requested time so repeated lookups hit regardless of float round-off.
        self._cached_time[requested_time] = result
        return result

    def _get_coefficient_expression_step(self, time_step: float) -> Number | Term | Parameter:
        """Return the coefficient of the latest point at or before ``time_step`` (original time axis).

        Times before the first point hold the first coefficient; an empty interpolator yields ``0``.
        """
        self._tlist = self._generate_tlist()
        if not self._tlist:
            return 0
        # bisect_right - 1 is the last point <= time_step. It is -1 when time_step precedes every point, which
        # would silently wrap around to the *last* coefficient, so clamp it to the first point instead.
        prev_index = max(bisect_right(self._tlist, time_step, key=self._get_value) - 1, 0)
        return self._time_dict[self._tlist[prev_index]]

    def _get_coefficient_expression_linear(self, time_step: float) -> Number | Term | Parameter:
        """Linearly interpolate the coefficient at ``time_step`` (original time axis).

        Outside the defined range the line through the two nearest end points is extrapolated; a single point
        is held constant and an empty interpolator yields ``0``.

        Raises:
            ValueError: If two neighbouring time points evaluate to the same time.
        """
        self._tlist = self._generate_tlist()
        insert_pos = bisect_right(self._tlist, time_step, key=self._get_value)

        prev_idx = self._tlist[insert_pos - 1] if insert_pos else None
        next_idx = self._tlist[insert_pos] if insert_pos < len(self._tlist) else None
        prev_expr = self._time_dict[prev_idx] if prev_idx is not None else None
        next_expr = self._time_dict[next_idx] if next_idx is not None else None

        def _linear_value(
            t0: PARAMETERIZED_NUMBER, v0: PARAMETERIZED_NUMBER, t1: PARAMETERIZED_NUMBER, v1: PARAMETERIZED_NUMBER
        ) -> PARAMETERIZED_NUMBER:
            t0_val = self._get_value(t0)
            t1_val = self._get_value(t1)
            if t0_val == t1_val:
                raise ValueError(
                    f"Ambigous evaluation: The same time step {t0_val} has two different coefficient assignation ({v0} and {v1})."
                )
            alpha: float = (time_step - t0_val) / (t1_val - t0_val)
            next_is_term = isinstance(v1, (Term, Parameter))
            prev_is_term = isinstance(v0, (Term, Parameter))
            # Distinct symbolic end points are evaluated at their own times; identical ones stay symbolic.
            if next_is_term and prev_is_term and v1 != v0:
                v1 = self._get_value(v1, t1_val)
                v0 = self._get_value(v0, t0_val)
            elif next_is_term and not prev_is_term:
                v1 = self._get_value(v1, t1_val)
            elif prev_is_term and not next_is_term:
                v0 = self._get_value(v0, t0_val)

            return v1 * alpha + v0 * (1 - alpha)

        if prev_expr is None and next_expr is not None:
            if len(self._tlist) == 1:
                return next_expr
            first_idx, second_idx = self._tlist[0], self._tlist[1]
            return _linear_value(first_idx, self._time_dict[first_idx], second_idx, self._time_dict[second_idx])

        if next_expr is None and prev_expr is not None:
            if len(self._tlist) == 1:
                return prev_expr
            penultimate_idx, last_idx = self._tlist[-2], self._tlist[-1]
            return _linear_value(penultimate_idx, self._time_dict[penultimate_idx], last_idx, self._time_dict[last_idx])

        if prev_expr is None and next_expr is None:
            return 0

        if next_idx is None or prev_idx is None or prev_expr is None or next_expr is None:
            raise ValueError("Something unexpected happened while retrieving the coefficient.")
        return _linear_value(prev_idx, prev_expr, next_idx, next_expr)

    def __getitem__(self, time_step: float) -> float:
        return self.get_coefficient(time_step)

    def __len__(self) -> int:
        return len(self.tlist)

    def __iter__(self) -> "Interpolator":
        self.iter_time_step = 0
        return self

    def __next__(self) -> float:
        if self.iter_time_step < self.__len__():
            result = self[self.fixed_tlist[self.iter_time_step]]
            self.iter_time_step += 1
            return result
        raise StopIteration
|