qilisdk 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. qilisdk/analog/__init__.py +1 -2
  2. qilisdk/analog/hamiltonian.py +1 -68
  3. qilisdk/analog/schedule.py +288 -313
  4. qilisdk/backends/backend.py +5 -1
  5. qilisdk/backends/cuda_backend.py +9 -5
  6. qilisdk/backends/qutip_backend.py +23 -12
  7. qilisdk/core/__init__.py +4 -0
  8. qilisdk/core/interpolator.py +406 -0
  9. qilisdk/core/parameterizable.py +66 -10
  10. qilisdk/core/variables.py +150 -7
  11. qilisdk/digital/circuit.py +1 -0
  12. qilisdk/digital/circuit_transpiler.py +46 -0
  13. qilisdk/digital/circuit_transpiler_passes/__init__.py +18 -0
  14. qilisdk/digital/circuit_transpiler_passes/circuit_transpiler_pass.py +36 -0
  15. qilisdk/digital/circuit_transpiler_passes/decompose_multi_controlled_gates_pass.py +216 -0
  16. qilisdk/digital/circuit_transpiler_passes/numeric_helpers.py +82 -0
  17. qilisdk/digital/gates.py +12 -2
  18. qilisdk/{speqtrum/experiments → experiments}/__init__.py +13 -2
  19. qilisdk/{speqtrum/experiments → experiments}/experiment_functional.py +90 -2
  20. qilisdk/{speqtrum/experiments → experiments}/experiment_result.py +16 -0
  21. qilisdk/functionals/sampling.py +8 -1
  22. qilisdk/functionals/time_evolution.py +6 -2
  23. qilisdk/functionals/variational_program.py +58 -0
  24. qilisdk/speqtrum/speqtrum.py +360 -130
  25. qilisdk/speqtrum/speqtrum_models.py +108 -19
  26. qilisdk/utils/openfermion/__init__.py +38 -0
  27. qilisdk/{core/algorithm.py → utils/openfermion/__init__.pyi} +2 -3
  28. qilisdk/utils/openfermion/openfermion.py +45 -0
  29. qilisdk/utils/visualization/schedule_renderers.py +16 -8
  30. {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/METADATA +74 -24
  31. {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/RECORD +33 -26
  32. {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/WHEEL +1 -1
  33. qilisdk/analog/linear_schedule.py +0 -121
  34. {qilisdk-0.1.6.dist-info → qilisdk-0.1.7.dist-info}/licenses/LICENCE +0 -0
@@ -89,7 +89,11 @@ class Backend(ABC):
89
89
 
90
90
  def evaluate_sample(parameters: list[float]) -> float:
91
91
  param_names = functional.functional.get_parameter_names()
92
- functional.functional.set_parameters({param_names[i]: param for i, param in enumerate(parameters)})
92
+ param_dict = {param_names[i]: param for i, param in enumerate(parameters)}
93
+ err = functional.check_parameter_constraints(param_dict)
94
+ if err > 0:
95
+ return err
96
+ functional.functional.set_parameters(param_dict)
93
97
  results = self.execute(functional.functional)
94
98
  final_results = functional.cost_function.compute_cost(results)
95
99
  if isinstance(final_results, float):
@@ -25,6 +25,7 @@ from loguru import logger
25
25
  from qilisdk.analog.hamiltonian import Hamiltonian, PauliI, PauliOperator, PauliX, PauliY, PauliZ
26
26
  from qilisdk.backends.backend import Backend
27
27
  from qilisdk.core.qtensor import QTensor
28
+ from qilisdk.digital.circuit_transpiler_passes import DecomposeMultiControlledGatesPass
28
29
  from qilisdk.digital.exceptions import UnsupportedGateError
29
30
  from qilisdk.digital.gates import RX, RY, RZ, SWAP, U1, U2, U3, Adjoint, BasicGate, Controlled, H, I, M, S, T, X, Y, Z
30
31
  from qilisdk.functionals.sampling_result import SamplingResult
@@ -137,13 +138,14 @@ class CudaBackend(Backend):
137
138
  kernel = cudaq.make_kernel()
138
139
  qubits = kernel.qalloc(functional.circuit.nqubits)
139
140
 
140
- for gate in functional.circuit.gates:
141
+ transpiled_circuit = DecomposeMultiControlledGatesPass().run(functional.circuit)
142
+ for gate in transpiled_circuit.gates:
141
143
  if isinstance(gate, Controlled):
142
144
  self._handle_controlled(kernel, gate, qubits[gate.control_qubits[0]], qubits[gate.target_qubits[0]])
143
145
  elif isinstance(gate, Adjoint):
144
146
  self._handle_adjoint(kernel, gate, qubits[gate.target_qubits[0]])
145
147
  elif isinstance(gate, M):
146
- self._handle_M(kernel, gate, functional.circuit, qubits)
148
+ self._handle_M(kernel, gate, transpiled_circuit, qubits)
147
149
  else:
148
150
  handler = self._basic_gate_handlers.get(type(gate), None)
149
151
  if handler is None:
@@ -158,12 +160,14 @@ class CudaBackend(Backend):
158
160
  logger.info("Executing TimeEvolution (T={}, dt={})", functional.schedule.T, functional.schedule.dt)
159
161
  cudaq.set_target("dynamics")
160
162
 
161
- steps = np.linspace(0, functional.schedule.T, (round(functional.schedule.T / functional.schedule.dt) + 1))
163
+ steps = np.linspace(0, functional.schedule.T, (round(functional.schedule.T // functional.schedule.dt) + 1))
164
+ tlist = np.array(functional.schedule.tlist)
165
+ steps = np.union1d(steps, tlist)
162
166
 
163
167
  cuda_schedule = CudaSchedule(steps, ["t"])
164
168
 
165
- def get_schedule(key: str) -> Callable:
166
- return lambda t: (functional.schedule.get_coefficient(t.real, key))
169
+ def get_schedule(key: str) -> Callable[[complex], float]:
170
+ return lambda t: (functional.schedule.coefficients[key][t.real])
167
171
 
168
172
  cuda_hamiltonian = sum(
169
173
  ScalarOperator(get_schedule(key)) * self._hamiltonian_to_cuda(ham)
@@ -27,6 +27,7 @@ from qilisdk.analog.hamiltonian import Hamiltonian, PauliI, PauliOperator
27
27
  from qilisdk.backends.backend import Backend
28
28
  from qilisdk.core.qtensor import QTensor, tensor_prod
29
29
  from qilisdk.digital import RX, RY, RZ, SWAP, U1, U2, U3, Circuit, H, I, M, S, T, X, Y, Z
30
+ from qilisdk.digital.circuit_transpiler_passes import DecomposeMultiControlledGatesPass
30
31
  from qilisdk.digital.exceptions import UnsupportedGateError
31
32
  from qilisdk.digital.gates import Adjoint, BasicGate, Controlled
32
33
  from qilisdk.functionals.sampling_result import SamplingResult
@@ -117,18 +118,17 @@ class QutipBackend(Backend):
117
118
 
118
119
  """
119
120
  logger.info("Executing Sampling (shots={})", functional.nshots)
120
- qutip_circuit = self._get_qutip_circuit(functional.circuit)
121
121
 
122
- counts: Counter[str] = Counter()
123
122
  init_state = tensor(*[basis(2, 0) for _ in range(functional.circuit.nqubits)])
124
123
 
125
124
  measurements_set = set()
126
125
  for m in functional.circuit.gates:
127
126
  if isinstance(m, M):
128
127
  measurements_set.update(list(m.target_qubits))
129
-
130
128
  measurements = sorted(measurements_set)
131
129
 
130
+ transpiled_circuit = DecomposeMultiControlledGatesPass().run(functional.circuit)
131
+ qutip_circuit = self._get_qutip_circuit(transpiled_circuit)
132
132
  sim = CircuitSimulator(qutip_circuit)
133
133
 
134
134
  res = sim.run_statistics(init_state) # runs the full circuit for one shot
@@ -172,7 +172,9 @@ class QutipBackend(Backend):
172
172
  ValueError: if the initial state provided is invalid.
173
173
  """
174
174
  logger.info("Executing TimeEvolution (T={}, dt={})", functional.schedule.T, functional.schedule.dt)
175
- tlist = np.linspace(0, functional.schedule.T, int(functional.schedule.T / functional.schedule.dt))
175
+ steps = np.linspace(0, functional.schedule.T, int(functional.schedule.T // functional.schedule.dt))
176
+ tlist = np.array(functional.schedule.tlist)
177
+ steps = np.union1d(steps, tlist)
176
178
 
177
179
  qutip_hamiltonians = []
178
180
  for hamiltonian in functional.schedule.hamiltonians.values():
@@ -185,7 +187,7 @@ class QutipBackend(Backend):
185
187
  H_t = [
186
188
  [
187
189
  qutip_hamiltonians[i],
188
- np.array([functional.schedule.get_coefficient(t, h) for t in tlist]),
190
+ np.array([functional.schedule.coefficients[h][t] for t in tlist]),
189
191
  ]
190
192
  for i, h in enumerate(functional.schedule.hamiltonians)
191
193
  ]
@@ -308,10 +310,10 @@ class QutipBackend(Backend):
308
310
 
309
311
  def _handle_controlled(self, circuit: QubitCircuit, gate: Controlled) -> None: # noqa: PLR6301
310
312
  """
311
- Handle a controlled gate operation.
313
+ Handle a controlled gate operation by registering a custom QuTiP gate.
312
314
 
313
- This method processes a controlled gate by creating a temporary kernel for the basic gate,
314
- applying its handler, and then integrating it into the main kernel as a controlled operation.
315
+ For non-native controlled gates we construct the block-matrix explicitly, mirroring
316
+ the approach recommended in the QuTiP QIP documentation for custom controlled rotations.
315
317
 
316
318
  Raises:
317
319
  UnsupportedGateError: If the number of control qubits is not equal to one or if the basic gate is unsupported.
@@ -320,13 +322,22 @@ class QutipBackend(Backend):
320
322
  logger.error("Controlled gate with {} control qubits not supported", len(gate.control_qubits))
321
323
  raise UnsupportedGateError
322
324
 
323
- def qutip_controlled_gate() -> Qobj:
324
- return QutipGates.controlled_gate(Qobj(gate.basic_gate.matrix), controls=0, targets=1)
325
-
326
325
  if gate.name == "CNOT":
327
326
  circuit.add_gate("CNOT", targets=[*gate.target_qubits], controls=[*gate.control_qubits])
328
327
  else:
329
- gate_name = "Controlled_" + gate.name
328
+ base_matrix = gate.basic_gate.matrix
329
+ dim_target = base_matrix.shape[0]
330
+ dim_total = 2 * dim_target
331
+ dims = [[2] + [2] * len(gate.target_qubits), [2] + [2] * len(gate.target_qubits)]
332
+
333
+ def qutip_controlled_gate() -> Qobj:
334
+ mat = np.zeros((dim_total, dim_total), dtype=np.complex128)
335
+ mat[:dim_target, :dim_target] = np.eye(dim_target, dtype=np.complex128)
336
+ mat[dim_target:, dim_target:] = base_matrix
337
+ return Qobj(mat, dims=dims)
338
+
339
+ matrix_digest = base_matrix.tobytes().hex()[:16]
340
+ gate_name = f"{gate.name}_{matrix_digest}"
330
341
  if gate_name not in circuit.user_gates:
331
342
  circuit.user_gates[gate_name] = qutip_controlled_gate
332
343
  circuit.add_gate(gate_name, targets=[*gate.control_qubits, *gate.target_qubits])
qilisdk/core/__init__.py CHANGED
@@ -22,6 +22,7 @@ from .variables import (
22
22
  LT,
23
23
  NEQ,
24
24
  BinaryVariable,
25
+ Domain,
25
26
  Equal,
26
27
  GreaterThan,
27
28
  GreaterThanOrEqual,
@@ -30,6 +31,7 @@ from .variables import (
30
31
  NotEqual,
31
32
  Parameter,
32
33
  SpinVariable,
34
+ Term,
33
35
  Variable,
34
36
  )
35
37
 
@@ -42,6 +44,7 @@ __all__ = [
42
44
  "NEQ",
43
45
  "BinaryVariable",
44
46
  "Constraint",
47
+ "Domain",
45
48
  "Equal",
46
49
  "GreaterThan",
47
50
  "GreaterThanOrEqual",
@@ -54,6 +57,7 @@ __all__ = [
54
57
  "Parameter",
55
58
  "QTensor",
56
59
  "SpinVariable",
60
+ "Term",
57
61
  "Variable",
58
62
  "basis_state",
59
63
  "bra",
@@ -0,0 +1,406 @@
1
+ # Copyright 2025 Qilimanjaro Quantum Tech
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from __future__ import annotations
15
+
16
+ import inspect
17
+ from bisect import bisect_right
18
+ from collections.abc import Callable
19
+ from copy import copy
20
+ from enum import Enum
21
+ from typing import Any, Mapping
22
+
23
+ import numpy as np
24
+
25
+ from qilisdk.core.parameterizable import Parameterizable
26
+ from qilisdk.core.variables import LEQ, BaseVariable, Number, Parameter, Term
27
+ from qilisdk.yaml import yaml
28
+
29
# Label of the implicit time Parameter; callables that accept an argument with this name receive it.
_TIME_PARAMETER_NAME = "t"

# Anything that may stand in for a numeric value: a literal, a Parameter, or a parameterized Term.
PARAMETERIZED_NUMBER = float | Parameter | Term

# A key is a single time point or a (start, end) interval; a value is a coefficient or a
# callable that produces one.
_TimeKey = PARAMETERIZED_NUMBER | tuple[float, float]
_TimeValue = PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER]
TimeDict = dict[_TimeKey, _TimeValue]


class Interpolation(str, Enum):
    """Rule used to obtain a coefficient at a time lying between two defined time points."""

    STEP = "Step function interpolation between schedule points"
    LINEAR = "linear interpolation between schedule points"
39
+
40
+
41
def _process_callable(
    function: Callable[..., PARAMETERIZED_NUMBER], current_time: Parameter, **kwargs: Any
) -> tuple[PARAMETERIZED_NUMBER, dict[str, Parameter]]:
    """Call a coefficient callable and collect the ``Parameter`` objects its signature declares.

    For every argument of ``function`` annotated as ``Parameter``:

    * if its default is a ``Parameter``, a copy of that default is registered under the default's label;
    * otherwise the value passed in ``kwargs`` (or the argument's default, or ``0`` when neither exists)
      is registered, wrapped in a new ``Parameter`` named after the argument when it is numeric.

    If ``function`` has an argument named ``t`` it receives ``current_time``. ``function`` is then
    called with the (possibly extended) ``kwargs``.

    Args:
        function: Callable returning a number, ``Parameter`` or ``Term``.
        current_time: The time ``Parameter`` handed to ``function`` as ``t`` when requested.
        **kwargs: Keyword arguments forwarded to ``function``.

    Returns:
        A tuple ``(value returned by function, collected parameters keyed by label)``.

    Raises:
        ValueError: If the returned expression contains variables other than Parameters or time.
    """
    parameters: dict[str, Parameter] = {}
    empty = inspect.Parameter.empty

    signature_params = inspect.signature(function).parameters
    for param_name, param_info in signature_params.items():
        annotation = param_info.annotation
        # Under PEP 563 (``from __future__ import annotations``) annotations are stored as strings,
        # so the textual form must be accepted alongside the class itself.
        is_parameter = annotation is Parameter or (isinstance(annotation, str) and annotation == "Parameter")
        if not is_parameter:
            continue

        default = param_info.default
        if isinstance(default, Parameter):
            parameters[default.label] = copy(default)
            continue

        # No Parameter default: fall back to the caller-supplied value, then a plain default, then 0.
        value = kwargs.get(param_name, 0 if default is empty else default)
        if isinstance(value, (float, int)):
            parameters[param_name] = Parameter(param_name, value)
        elif isinstance(value, Parameter):
            parameters[value.label] = value

    if _TIME_PARAMETER_NAME in signature_params:
        kwargs[_TIME_PARAMETER_NAME] = current_time
    term = function(**kwargs)

    if isinstance(term, Term) and not all(
        (isinstance(v, Parameter) or v.label == _TIME_PARAMETER_NAME) for v in term.variables()
    ):
        raise ValueError("function contains variables that are not time. Only Parameters are allowed.")
    if isinstance(term, BaseVariable) and not (isinstance(term, Parameter) or term.label == _TIME_PARAMETER_NAME):
        raise ValueError("function contains variables that are not time. Only Parameters are allowed.")
    return term, parameters
73
+
74
+
75
@yaml.register_class
class Interpolator(Parameterizable):
    """Time-indexed mapping that interpolates coefficients between its defined time points.

    Time points (keys) and coefficients (values) may be plain numbers, ``Parameter`` objects or
    parameterized ``Term`` expressions. Coefficients at times between defined points are produced by
    the configured :class:`Interpolation` rule. After :meth:`set_max_time`, the stored time points are
    left untouched and every public-facing time is scaled by ``max_time / largest_stored_time``.
    """

    def __init__(
        self,
        time_dict: TimeDict,
        interpolation: Interpolation = Interpolation.LINEAR,
        nsamples: int = 100,
        **kwargs: Any,
    ) -> None:
        """Initialize an interpolator over discrete points or intervals.

        Args:
            time_dict (TimeDict): Mapping from time points or ``(start, end)`` intervals to coefficients
                or callables producing coefficients.
            interpolation (Interpolation): Interpolation rule between provided points (``LINEAR`` or ``STEP``).
            nsamples (int): Number of evenly spaced samples (at least 2) used to expand each interval.
            **kwargs: Extra arguments forwarded to callable coefficient processing.

        Raises:
            ValueError: If a time interval is not defined by exactly two points.
        """
        super(Interpolator, self).__init__()
        self._interpolation = interpolation
        self._time_dict: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER] = {}
        self._current_time = Parameter("t", 0)
        self._total_time: float | None = None
        self.iter_time_step = 0
        self._cached = False
        self._cached_time: dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER | Number] = {}
        self._tlist: list[PARAMETERIZED_NUMBER] | None = None
        self._fixed_tlist: list[float] | None = None
        self._max_time: PARAMETERIZED_NUMBER | None = None
        self._time_scale_cache: float | None = None

        for time, coefficient in time_dict.items():
            if isinstance(time, tuple):
                if len(time) != 2:  # noqa: PLR2004
                    raise ValueError(
                        f"time intervals need to be defined by two points, but this interval was provided: {time}"
                    )
                # Expand the interval into evenly spaced points sharing the same coefficient.
                for t in np.linspace(0, 1, max(2, nsamples)):
                    time_point = (1 - float(t)) * time[0] + float(t) * time[1]
                    if isinstance(time_point, (Parameter, Term)):
                        self._extract_parameters(time_point)
                    self.add_time_point(time_point, coefficient, **kwargs)
            else:
                self.add_time_point(time, coefficient, **kwargs)
        self._tlist = self._generate_tlist()

        # Keep every parameterized time point ordered with respect to its neighbours by registering
        # LEQ constraints (skipping ones that are already present).
        time_insertion_list = sorted(
            (k for item in time_dict for k in (item if isinstance(item, tuple) else (item,))),
            key=self._get_value,
        )
        last = len(time_insertion_list) - 1
        for i, t in enumerate(time_insertion_list):
            if not isinstance(t, (Parameter, Term)):
                continue
            neighbour_constraints = []
            if i > 0:
                neighbour_constraints.append(LEQ(time_insertion_list[i - 1], t))
            if i < last:
                neighbour_constraints.append(LEQ(t, time_insertion_list[i + 1]))
            for constraint in neighbour_constraints:
                if constraint not in self._parameter_constraints:
                    self._parameter_constraints.append(constraint)

    def _generate_tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """Return the stored (unscaled) time points sorted by their current numeric value."""
        return sorted(self._time_dict.keys(), key=self._get_value)

    def _get_time_scale(self) -> float | None:
        """Return the factor mapping stored times to public times, or ``None`` when no max time applies.

        The factor is ``max_time / largest_stored_time``; a largest stored time of zero is treated as 1
        to avoid dividing by zero. The result is cached until :meth:`_delete_cache` is called.
        """
        if self._max_time is None or not self._time_dict:
            return None
        if self._time_scale_cache is None:
            max_t = self._get_value(max(self._time_dict, key=self._get_value))
            if max_t == 0:
                max_t = 1
            self._time_scale_cache = self._get_value(self._max_time) / max_t
        return self._time_scale_cache

    @property
    def tlist(self) -> list[PARAMETERIZED_NUMBER]:
        """Sorted time points, rescaled to the maximum time when one has been set."""
        if self._tlist is None:
            self._tlist = self._generate_tlist()
        scale = self._get_time_scale()
        if scale is None:
            return self._tlist
        return [t * scale for t in self._tlist]

    @property
    def fixed_tlist(self) -> list[float]:
        """Numeric values of :attr:`tlist` (cached until the next cache invalidation)."""
        if self._fixed_tlist is None:
            self._fixed_tlist = [self._get_value(k) for k in self.tlist]
        return self._fixed_tlist

    @property
    def total_time(self) -> float:
        """Largest numeric time point (cached until the next cache invalidation)."""
        if self._total_time is None:
            self._total_time = max(self.fixed_tlist)
        return self._total_time

    def items(self) -> list[tuple[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]]:
        """Return ``(time, coefficient)`` pairs in insertion order, with times rescaled when a max time is set."""
        scale = self._get_time_scale()
        if scale is None:
            return list(self._time_dict.items())
        return [(k * scale, v) for k, v in self._time_dict.items()]

    def fixed_items(self) -> list[tuple[float, float]]:
        """Return ``(time, coefficient)`` pairs evaluated numerically at every point of :attr:`fixed_tlist`."""
        return [(t, self._get_value(self[t], t)) for t in self.fixed_tlist]

    @property
    def coefficients(self) -> list[PARAMETERIZED_NUMBER]:
        """Stored coefficients ordered by time."""
        # Index with the stored (unscaled) keys: the rescaled times returned by ``tlist`` are not keys
        # of ``_time_dict`` once a max time has been set.
        return [self._time_dict[t] for t in self._generate_tlist()]

    @property
    def coefficients_dict(self) -> dict[PARAMETERIZED_NUMBER, PARAMETERIZED_NUMBER]:
        """Shallow copy of the stored ``{time: coefficient}`` mapping (unscaled times)."""
        return copy(self._time_dict)

    @property
    def fixed_coefficeints(self) -> list[float]:
        """Numeric coefficients at every point of :attr:`fixed_tlist` (historical misspelled name)."""
        return [self._get_value(self[t]) for t in self.fixed_tlist]

    @property
    def fixed_coefficients(self) -> list[float]:
        """Correctly spelled alias of :attr:`fixed_coefficeints`."""
        return self.fixed_coefficeints

    @property
    def parameters(self) -> dict[str, Parameter]:
        """Parameters referenced by the time points and coefficients, keyed by label."""
        return self._parameters

    def set_max_time(self, max_time: PARAMETERIZED_NUMBER) -> None:
        """
        Rescale all time points to a new maximum duration while keeping relative spacing.

        Args:
            max_time: The new value that the largest time point should map to.

        Raises:
            ValueError: If the max time is set to zero.
        """
        if self._get_value(max_time) == 0:
            raise ValueError("Setting the max time to zero.")
        self._delete_cache()
        self._max_time = max_time

    def _delete_cache(self) -> None:
        """Invalidate every cached derived value (time lists, totals, scale, evaluated coefficients)."""
        self._cached = False
        self._total_time = None
        self._cached_time = {}
        self._tlist = None
        self._fixed_tlist = None
        self._time_scale_cache = None

    def _get_value(self, value: PARAMETERIZED_NUMBER | complex, t: float | None = None) -> float:
        """Evaluate ``value`` numerically, substituting ``t`` for the time parameter when given.

        Raises:
            ValueError: If ``value`` is the time parameter and ``t`` is missing, or its type is unsupported.
        """
        if isinstance(value, (int, float)):
            return value
        if isinstance(value, complex):
            return value.real
        if isinstance(value, Parameter):
            if value.label == _TIME_PARAMETER_NAME:
                if t is None:
                    raise ValueError("Can't evaluate Parameter because time is not provided.")
                value.set_value(t)
            return float(value.evaluate())
        if isinstance(value, Term):
            ctx: Mapping[BaseVariable, list[int] | int | float] = {self._current_time: t} if t is not None else {}
            aux = value.evaluate(ctx)

            return aux.real if isinstance(aux, complex) else float(aux)
        raise ValueError(f"Invalid value of type {type(value)} is being evaluated.")

    def _extract_parameters(self, element: PARAMETERIZED_NUMBER) -> None:
        """Register every non-time ``Parameter`` found in ``element``.

        Raises:
            ValueError: If ``element`` is a Term containing variables that are not Parameters.
        """
        if isinstance(element, Parameter) and element.label != _TIME_PARAMETER_NAME:
            self._parameters[element.label] = element
        elif isinstance(element, Term):
            if not element.is_parameterized_term():
                raise ValueError(
                    f"Tlist can only contain parameters and no variables, but the term {element} contains objects other than parameters."
                )
            for p in element.variables():
                if isinstance(p, Parameter) and p.label != _TIME_PARAMETER_NAME:
                    self._parameters[p.label] = p

    def add_time_point(
        self,
        time: PARAMETERIZED_NUMBER,
        coefficient: PARAMETERIZED_NUMBER | Callable[..., PARAMETERIZED_NUMBER],
        **kwargs: Any,
    ) -> None:
        """Insert (or overwrite) the coefficient at ``time``.

        Callable coefficients are evaluated once, with the time parameter set to ``time``. When a max
        time is set, ``time`` is interpreted in the rescaled frame and stored unscaled.

        Raises:
            ValueError: If ``coefficient`` is neither a number, ``Parameter``, ``Term`` nor a callable.
        """
        self._extract_parameters(time)
        coeff = coefficient
        if callable(coeff):
            self._current_time.set_value(self._get_value(time))
            coeff, callable_params = _process_callable(coeff, self._current_time, **kwargs)
            self._extract_parameters(coeff)
            self._parameters.update(callable_params)
        elif isinstance(coeff, (int, float, Parameter, Term)):
            self._extract_parameters(coeff)
        else:
            raise ValueError(f"Unsupported coefficient of type {type(coeff)} at time {time}.")

        # Convert from the public (rescaled) frame back to the stored frame using the scale of the
        # points that exist before this insertion.
        scale = self._get_time_scale()
        if scale is not None:
            time /= scale
        self._time_dict[time] = coeff
        self._delete_cache()

    def set_parameter_values(self, values: list[float]) -> None:
        """Set parameter values positionally, invalidating cached evaluations first."""
        self._delete_cache()
        super().set_parameter_values(values)

    def set_parameters(self, parameters: dict[str, int | float]) -> None:
        """Set parameter values by label, invalidating cached evaluations first."""
        self._delete_cache()
        super().set_parameters(parameters)

    def set_parameter_bounds(self, ranges: dict[str, tuple[float, float]]) -> None:
        """Set parameter bounds by label, invalidating cached evaluations first."""
        self._delete_cache()
        super().set_parameter_bounds(ranges)

    def get_coefficient(self, time_step: float) -> float:
        """Return the numeric coefficient at ``time_step`` (given in the public, possibly rescaled, frame)."""
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)
        val = self.get_coefficient_expression(time_step=time_step)

        scale = self._get_time_scale()
        if scale is not None:
            time_step /= scale

        return self._get_value(val, time_step)

    def get_coefficient_expression(self, time_step: float) -> Number | Term | Parameter:
        """Return the (possibly symbolic) coefficient at ``time_step`` in the public frame.

        Exact time points return their stored coefficient. Other times are interpolated and memoized
        until the next cache invalidation.

        Raises:
            ValueError: If the configured interpolation rule is not supported.
        """
        time_step = time_step.item() if isinstance(time_step, np.generic) else self._get_value(time_step)

        self._tlist = self._generate_tlist()

        fixed = self.fixed_tlist
        if time_step in fixed:
            return self._time_dict[self._tlist[fixed.index(time_step)]]
        if time_step in self._cached_time:
            return self._cached_time[time_step]

        # Memoize under the caller's time: re-multiplying a rescaled time can differ by round-off.
        query_time = time_step
        scale = self._get_time_scale()
        if scale is not None:
            time_step /= scale

        if self._interpolation is Interpolation.STEP:
            result = self._get_coefficient_expression_step(time_step)
        elif self._interpolation is Interpolation.LINEAR:
            result = self._get_coefficient_expression_linear(time_step)
        else:
            raise ValueError(f"interpolation {self._interpolation.value} is not supported!")
        self._cached_time[query_time] = result
        return result

    def _get_coefficient_expression_step(self, time_step: float) -> Number | Term | Parameter:
        """Step interpolation in the stored frame: the coefficient of the last point at or before ``time_step``.

        Times before the first point take the first coefficient; an empty interpolator yields 0.
        """
        self._tlist = self._generate_tlist()
        if not self._tlist:
            return 0
        # bisect_right - 1 is -1 for times before the first point; clamp so that case does not wrap
        # around to the last point.
        prev_indx = max(bisect_right(self._tlist, time_step, key=self._get_value) - 1, 0)
        return self._time_dict[self._tlist[prev_indx]]

    def _get_coefficient_expression_linear(self, time_step: float) -> Number | Term | Parameter:
        """Linear interpolation in the stored frame, extrapolating from the two nearest edge points.

        Raises:
            ValueError: If two neighbouring time points evaluate to the same time.
        """
        self._tlist = self._generate_tlist()
        insert_pos = bisect_right(self._tlist, time_step, key=self._get_value)

        prev_idx = self._tlist[insert_pos - 1] if insert_pos else None
        next_idx = self._tlist[insert_pos] if insert_pos < len(self._tlist) else None
        prev_expr = self._time_dict[prev_idx] if prev_idx is not None else None
        next_expr = self._time_dict[next_idx] if next_idx is not None else None

        def _linear_value(
            t0: PARAMETERIZED_NUMBER, v0: PARAMETERIZED_NUMBER, t1: PARAMETERIZED_NUMBER, v1: PARAMETERIZED_NUMBER
        ) -> PARAMETERIZED_NUMBER:
            t0_val = self._get_value(t0)
            t1_val = self._get_value(t1)
            if t0_val == t1_val:
                raise ValueError(
                    f"Ambigous evaluation: The same time step {t0_val} has two different coefficient assignation ({v0} and {v1})."
                )
            alpha: float = (time_step - t0_val) / (t1_val - t0_val)
            next_is_term = isinstance(v1, (Term, Parameter))
            prev_is_term = isinstance(v0, (Term, Parameter))
            # Evaluate symbolic endpoints at their own times unless both are the same expression,
            # in which case the result is kept symbolic.
            if next_is_term and prev_is_term and v1 != v0:
                v1 = self._get_value(v1, t1_val)
                v0 = self._get_value(v0, t0_val)
            elif next_is_term and not prev_is_term:
                v1 = self._get_value(v1, t1_val)
            elif prev_is_term and not next_is_term:
                v0 = self._get_value(v0, t0_val)

            return v1 * alpha + v0 * (1 - alpha)

        if prev_expr is None and next_expr is not None:
            if len(self._tlist) == 1:
                return next_expr
            first_idx = self._tlist[0]
            second_idx = self._tlist[1]
            return _linear_value(first_idx, self._time_dict[first_idx], second_idx, self._time_dict[second_idx])

        if next_expr is None and prev_expr is not None:
            if len(self._tlist) == 1:
                return prev_expr
            last_idx = self._tlist[-1]
            penultimate_idx = self._tlist[-2]
            return _linear_value(penultimate_idx, self._time_dict[penultimate_idx], last_idx, self._time_dict[last_idx])
        if prev_expr is None and next_expr is None:
            return 0

        if next_idx is None or prev_idx is None or prev_expr is None or next_expr is None:
            raise ValueError("Something unexpected happened while retrieving the coefficient.")
        return _linear_value(prev_idx, prev_expr, next_idx, next_expr)

    def __getitem__(self, time_step: float) -> float:
        """Return the numeric coefficient at ``time_step``."""
        return self.get_coefficient(time_step)

    def __len__(self) -> int:
        """Number of defined time points."""
        return len(self.tlist)

    def __iter__(self) -> "Interpolator":
        """Iterate the numeric coefficients at each time point.

        The instance is its own iterator (state lives in ``iter_time_step``), so nested or concurrent
        iterations over the same interpolator interfere with each other.
        """
        self.iter_time_step = 0
        return self

    def __next__(self) -> float:
        """Return the coefficient at the next time point of :attr:`fixed_tlist`."""
        if self.iter_time_step < self.__len__():
            result = self[self.fixed_tlist[self.iter_time_step]]
            self.iter_time_step += 1
            return result
        raise StopIteration