mxlpy 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +4 -4
- mxlpy/fit.py +1414 -0
- mxlpy/fuzzy.py +139 -0
- mxlpy/identify.py +5 -5
- mxlpy/integrators/int_scipy.py +4 -3
- mxlpy/meta/codegen_latex.py +1 -0
- mxlpy/meta/source_tools.py +1 -1
- mxlpy/model.py +41 -24
- mxlpy/nn/__init__.py +5 -0
- mxlpy/nn/_equinox.py +293 -0
- mxlpy/nn/_torch.py +59 -2
- mxlpy/npe/__init__.py +5 -0
- mxlpy/npe/_equinox.py +344 -0
- mxlpy/npe/_torch.py +6 -22
- mxlpy/parallel.py +73 -4
- mxlpy/surrogates/__init__.py +5 -0
- mxlpy/surrogates/_equinox.py +195 -0
- mxlpy/surrogates/_torch.py +5 -20
- mxlpy/symbolic/symbolic_model.py +30 -3
- mxlpy/types.py +1 -0
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/METADATA +4 -1
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/RECORD +24 -23
- mxlpy/fit/__init__.py +0 -9
- mxlpy/fit/common.py +0 -298
- mxlpy/fit/global_.py +0 -534
- mxlpy/fit/local_.py +0 -591
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.25.0.dist-info → mxlpy-0.26.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/fit.py
ADDED
@@ -0,0 +1,1414 @@
|
|
1
|
+
"""Fitting routines.
|
2
|
+
|
3
|
+
Single model, single data routines
|
4
|
+
----------------------------------
|
5
|
+
- `steady_state`
|
6
|
+
- `time_course`
|
7
|
+
- `protocol_time_course`
|
8
|
+
|
9
|
+
Multiple model, single data routines
|
10
|
+
------------------------------------
|
11
|
+
- `ensemble_steady_state`
|
12
|
+
- `ensemble_time_course`
|
13
|
+
- `ensemble_protocol_time_course`
|
14
|
+
|
15
|
+
A carousel is a special case of an ensemble, where the general
|
16
|
+
structure (e.g. stoichiometries) is the same, while the reaction kinetics
|
17
|
+
can vary
|
18
|
+
- `carousel_steady_state`
|
19
|
+
- `carousel_time_course`
|
20
|
+
- `carousel_protocol_time_course`
|
21
|
+
|
22
|
+
Multiple model, multiple data
|
23
|
+
-----------------------------
|
24
|
+
- `joint_steady_state`
|
25
|
+
- `joint_time_course`
|
26
|
+
- `joint_protocol_time_course`
|
27
|
+
|
28
|
+
Multiple model, multiple data, multiple methods
|
29
|
+
-----------------------------------------------
|
30
|
+
Here we also allow running different methods (e.g. steady-state vs time courses)
|
31
|
+
for each combination of model:data.
|
32
|
+
|
33
|
+
- `joint_mixed`
|
34
|
+
|
35
|
+
Minimizers
|
36
|
+
----------
|
37
|
+
- LocalScipyMinimizer, including common methods such as Nelder-Mead or L-BFGS-B
|
38
|
+
- GlobalScipyMinimizer, including common methods such as basin hopping or dual annealing
|
39
|
+
|
40
|
+
Loss functions
|
41
|
+
--------------
|
42
|
+
- rmse
|
43
|
+
|
44
|
+
"""
|
45
|
+
|
46
|
+
from __future__ import annotations
|
47
|
+
|
48
|
+
import logging
|
49
|
+
import multiprocessing
|
50
|
+
from collections.abc import Callable
|
51
|
+
from copy import deepcopy
|
52
|
+
from dataclasses import dataclass
|
53
|
+
from functools import partial
|
54
|
+
from typing import TYPE_CHECKING, Literal, Protocol
|
55
|
+
|
56
|
+
import numpy as np
|
57
|
+
import pandas as pd
|
58
|
+
import pebble
|
59
|
+
from scipy.optimize import (
|
60
|
+
basinhopping,
|
61
|
+
differential_evolution,
|
62
|
+
direct,
|
63
|
+
dual_annealing,
|
64
|
+
minimize,
|
65
|
+
shgo,
|
66
|
+
)
|
67
|
+
from wadler_lindig import pformat
|
68
|
+
|
69
|
+
from mxlpy import parallel
|
70
|
+
from mxlpy.model import Model
|
71
|
+
from mxlpy.simulator import Simulator
|
72
|
+
from mxlpy.types import Array, IntegratorType, cast
|
73
|
+
|
74
|
+
if TYPE_CHECKING:
|
75
|
+
import pandas as pd
|
76
|
+
from scipy.optimize._optimize import OptimizeResult
|
77
|
+
|
78
|
+
from mxlpy.carousel import Carousel
|
79
|
+
from mxlpy.model import Model
|
80
|
+
|
81
|
+
LOGGER = logging.getLogger(__name__)
|
82
|
+
|
83
|
+
|
84
|
+
__all__ = [
|
85
|
+
"Bounds",
|
86
|
+
"EnsembleFitResult",
|
87
|
+
"FitResult",
|
88
|
+
"FitSettings",
|
89
|
+
"GlobalScipyMinimizer",
|
90
|
+
"InitialGuess",
|
91
|
+
"JointFitResult",
|
92
|
+
"LOGGER",
|
93
|
+
"LocalScipyMinimizer",
|
94
|
+
"LossFn",
|
95
|
+
"MinResult",
|
96
|
+
"Minimizer",
|
97
|
+
"MixedSettings",
|
98
|
+
"ResFn",
|
99
|
+
"ResidualFn",
|
100
|
+
"ResidualProtocol",
|
101
|
+
"carousel_protocol_time_course",
|
102
|
+
"carousel_steady_state",
|
103
|
+
"carousel_time_course",
|
104
|
+
"cosine_similarity",
|
105
|
+
"ensemble_protocol_time_course",
|
106
|
+
"ensemble_steady_state",
|
107
|
+
"ensemble_time_course",
|
108
|
+
"joint_mixed",
|
109
|
+
"joint_protocol_time_course",
|
110
|
+
"joint_steady_state",
|
111
|
+
"joint_time_course",
|
112
|
+
"mae",
|
113
|
+
"mean",
|
114
|
+
"mean_absolute_percentage",
|
115
|
+
"mean_squared",
|
116
|
+
"mean_squared_logarithmic",
|
117
|
+
"protocol_time_course",
|
118
|
+
"protocol_time_course_residual",
|
119
|
+
"rmse",
|
120
|
+
"steady_state",
|
121
|
+
"steady_state_residual",
|
122
|
+
"time_course",
|
123
|
+
"time_course_residual",
|
124
|
+
]
|
125
|
+
|
126
|
+
type InitialGuess = dict[str, float]
|
127
|
+
|
128
|
+
type Bounds = dict[str, tuple[float | None, float | None]]
|
129
|
+
type ResFn = Callable[[Array], float]
|
130
|
+
type LossFn = Callable[
|
131
|
+
[
|
132
|
+
pd.DataFrame | pd.Series,
|
133
|
+
pd.DataFrame | pd.Series,
|
134
|
+
],
|
135
|
+
float,
|
136
|
+
]
|
137
|
+
|
138
|
+
|
139
|
+
type Minimizer = Callable[
|
140
|
+
[
|
141
|
+
ResidualFn,
|
142
|
+
InitialGuess,
|
143
|
+
Bounds,
|
144
|
+
],
|
145
|
+
MinResult | None,
|
146
|
+
]
|
147
|
+
|
148
|
+
|
149
|
+
class ResidualProtocol(Protocol):
    """User-facing call signature of a residual function.

    Callables of this shape can be supplied as `residual_fn` to routines like
    - `fit.steady_state`
    - `fit.time_course`
    - `fit.protocol_time_course`

    Those routines bind `settings` by partial application, which turns the
    callable into a `ResidualFn`.
    """

    def __call__(
        self,
        updates: dict[str, float],
        settings: _Settings,
    ) -> float:
        """Return the residual error for `updates` under the given `settings`."""
        ...
|
167
|
+
|
168
|
+
|
169
|
+
class ResidualFn(Protocol):
    """Internal call signature of a residual function.

    This is what remains of a `ResidualProtocol` callable once its `settings`
    argument has been bound by partial application: it only takes the value
    updates and returns the residual.
    """

    def __call__(
        self,
        updates: dict[str, float],
    ) -> float:
        """Return the residual error for the given `updates`."""
        ...
|
182
|
+
|
183
|
+
|
184
|
+
@dataclass
class FitSettings:
    """User-facing settings describing one model/data combination of a fit.

    Only `model` and `data` are required; every other field defaults to None.
    """

    # Model to be fitted
    model: Model
    # Experimental data to fit against
    data: pd.Series | pd.DataFrame
    # Optional initial conditions as {name: value}
    y0: dict[str, float] | None = None
    # Optional integrator to use for simulations
    integrator: IntegratorType | None = None
    # Optional loss function
    loss_fn: LossFn | None = None
    # Optional protocol (used by protocol time courses)
    protocol: pd.DataFrame | None = None
|
194
|
+
|
195
|
+
|
196
|
+
@dataclass
class MixedSettings:
    """User-facing fit settings that also carry their own residual function.

    Same fields as `FitSettings`, plus a required `residual_fn`.
    NOTE(review): presumably consumed by `joint_mixed`, which is not visible in
    this part of the file - confirm.
    """

    # Model to be fitted
    model: Model
    # Experimental data to fit against
    data: pd.Series | pd.DataFrame
    # Residual function to use for this model/data combination
    residual_fn: ResidualFn
    # Optional initial conditions as {name: value}
    y0: dict[str, float] | None = None
    # Optional integrator to use for simulations
    integrator: IntegratorType | None = None
    # Optional loss function
    loss_fn: LossFn | None = None
    # Optional protocol (used by protocol time courses)
    protocol: pd.DataFrame | None = None
|
207
|
+
|
208
|
+
|
209
|
+
@dataclass
class _Settings:
    """Internal, fully resolved counterpart of `FitSettings`.

    The fit routines build this object and bind it to a residual function.
    Unlike `FitSettings`, `loss_fn` is mandatory and the names to update are
    already split into parameters and variables.
    """

    # Model that the residual functions update and simulate
    model: Model
    # Experimental data to compare the simulation against
    data: pd.Series | pd.DataFrame
    # Initial conditions applied before each evaluation, if given
    y0: dict[str, float] | None
    # Integrator handed to the Simulator, if given
    integrator: IntegratorType | None
    # Loss function, called as loss_fn(simulation, data)
    loss_fn: LossFn
    # Names in `updates` that are applied as parameter updates
    p_names: list[str]
    # Names in `updates` that are applied as variable updates
    v_names: list[str]
    # Protocol, required by the protocol time course residual
    protocol: pd.DataFrame | None = None
    # Optional residual function
    residual_fn: ResidualFn | None = None
|
222
|
+
|
223
|
+
|
224
|
+
@dataclass
class MinResult:
    """Outcome of a single minimization run."""

    # Best values found, as {name: value}
    parameters: dict[str, float]
    # Residual at those values
    residual: float

    def __repr__(self) -> str:
        """Return the pretty-printed representation produced by `pformat`."""
        return pformat(self)
|
234
|
+
|
235
|
+
|
236
|
+
@dataclass
class FitResult:
    """Outcome of fitting a single model."""

    # Model that was used for the fit
    model: Model
    # Best values found, as {name: value}
    best_pars: dict[str, float]
    # Loss at those values
    loss: float

    def __repr__(self) -> str:
        """Return the pretty-printed representation produced by `pformat`."""
        return pformat(self)
|
247
|
+
|
248
|
+
|
249
|
+
@dataclass
class EnsembleFitResult:
    """Collection of the individual fits of an ensemble / carousel."""

    # One entry per successfully fitted model
    fits: list[FitResult]

    def __repr__(self) -> str:
        """Return the pretty-printed representation produced by `pformat`."""
        return pformat(self)

    def get_best_fit(self) -> FitResult:
        """Return the fit with the smallest loss.

        Raises:
            ValueError: If `fits` is empty (raised by `min`).

        """

        def _loss(fit: FitResult) -> float:
            return fit.loss

        return min(self.fits, key=_loss)
|
262
|
+
|
263
|
+
|
264
|
+
@dataclass
class JointFitResult:
    """Outcome of a joint fit."""

    # Best values found, as {name: value}
    best_pars: dict[str, float]
    # Loss at those values
    loss: float

    def __repr__(self) -> str:
        """Return the pretty-printed representation produced by `pformat`."""
        return pformat(self)
|
274
|
+
|
275
|
+
|
276
|
+
###############################################################################
|
277
|
+
# loss fns
|
278
|
+
###############################################################################
|
279
|
+
|
280
|
+
|
281
|
+
def mean(
|
282
|
+
y_pred: pd.DataFrame | pd.Series,
|
283
|
+
y_true: pd.DataFrame | pd.Series,
|
284
|
+
) -> float:
|
285
|
+
"""Calculate root mean square error between model and data."""
|
286
|
+
return cast(float, np.mean(y_pred - y_true))
|
287
|
+
|
288
|
+
|
289
|
+
def mean_squared(
|
290
|
+
y_pred: pd.DataFrame | pd.Series,
|
291
|
+
y_true: pd.DataFrame | pd.Series,
|
292
|
+
) -> float:
|
293
|
+
"""Calculate mean square error between model and data."""
|
294
|
+
return cast(float, np.mean(np.square(y_pred - y_true)))
|
295
|
+
|
296
|
+
|
297
|
+
def rmse(
|
298
|
+
y_pred: pd.DataFrame | pd.Series,
|
299
|
+
y_true: pd.DataFrame | pd.Series,
|
300
|
+
) -> float:
|
301
|
+
"""Calculate root mean square error between model and data."""
|
302
|
+
return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
|
303
|
+
|
304
|
+
|
305
|
+
def mae(
|
306
|
+
y_pred: pd.DataFrame | pd.Series,
|
307
|
+
y_true: pd.DataFrame | pd.Series,
|
308
|
+
) -> float:
|
309
|
+
"""Calculate mean absolute error."""
|
310
|
+
return cast(float, np.mean(np.abs(y_true - y_pred)))
|
311
|
+
|
312
|
+
|
313
|
+
def mean_absolute_percentage(
|
314
|
+
y_pred: pd.DataFrame | pd.Series,
|
315
|
+
y_true: pd.DataFrame | pd.Series,
|
316
|
+
) -> float:
|
317
|
+
"""Calculate mean absolute error."""
|
318
|
+
return cast(float, 100 * np.mean(np.abs((y_true - y_pred) / y_pred)))
|
319
|
+
|
320
|
+
|
321
|
+
def mean_squared_logarithmic(
|
322
|
+
y_pred: pd.DataFrame | pd.Series,
|
323
|
+
y_true: pd.DataFrame | pd.Series,
|
324
|
+
) -> float:
|
325
|
+
"""Calculate root mean square error between model and data."""
|
326
|
+
return cast(float, np.mean(np.square(np.log(y_pred + 1) - np.log(y_true + 1))))
|
327
|
+
|
328
|
+
|
329
|
+
def cosine_similarity(
|
330
|
+
y_pred: pd.DataFrame | pd.Series,
|
331
|
+
y_true: pd.DataFrame | pd.Series,
|
332
|
+
) -> float:
|
333
|
+
"""Calculate root mean square error between model and data."""
|
334
|
+
norm = np.linalg.norm
|
335
|
+
return cast(float, -np.sum(norm(y_pred, 2) * norm(y_true, 2)))
|
336
|
+
|
337
|
+
|
338
|
+
###############################################################################
|
339
|
+
# Minimizers
|
340
|
+
###############################################################################
|
341
|
+
|
342
|
+
|
343
|
+
@dataclass
class LocalScipyMinimizer:
    """Local multivariate minimizer built on `scipy.optimize.minimize`.

    `tol` and `method` are forwarded unchanged to scipy.

    See Also
    --------
    https://docs.scipy.org/doc/scipy/reference/optimize.html#local-multivariate-optimization

    """

    tol: float = 1e-6
    method: Literal[
        "Nelder-Mead",
        "Powell",
        "CG",
        "BFGS",
        "Newton-CG",
        "L-BFGS-B",
        "TNC",
        "COBYLA",
        "COBYQA",
        "SLSQP",
        "trust-constr",
        "dogleg",
        "trust-ncg",
        "trust-exact",
        "trust-krylov",
    ] = "L-BFGS-B"

    def __call__(
        self,
        residual_fn: ResidualFn,
        p0: dict[str, float],
        bounds: Bounds,
    ) -> MinResult | None:
        """Minimize `residual_fn`, starting from `p0`.

        Args:
            residual_fn: Function mapping {name: value} to a residual
            p0: Initial guesses as {name: value}
            bounds: Bounds per name; names without an entry get (1e-6, 1e6)

        Returns:
            The best values and residual, or None if scipy reports failure

        """
        names = list(p0)

        def objective(values: Array) -> float:
            # scipy works on a flat array, the residual function on a mapping
            return residual_fn(_pack_updates(values, names))

        outcome = minimize(
            objective,
            x0=[p0[name] for name in names],
            bounds=[bounds.get(name, (1e-6, 1e6)) for name in names],
            method=self.method,
            tol=self.tol,
        )
        if not outcome.success:
            LOGGER.warning("Minimisation failed due to %s", outcome.message)
            return None

        return MinResult(
            parameters=dict(zip(names, outcome.x, strict=True)),
            residual=outcome.fun,
        )
|
402
|
+
|
403
|
+
|
404
|
+
@dataclass
class GlobalScipyMinimizer:
    """Global minimization using scipy.optimize.

    `method` selects the scipy solver. `tol` is stored on the instance but is
    not forwarded to any of the solvers here.

    See Also
    --------
    https://docs.scipy.org/doc/scipy/reference/optimize.html#global-optimization

    """

    tol: float = 1e-6
    method: Literal[
        "basinhopping",
        "differential_evolution",
        "shgo",
        "dual_annealing",
        "direct",
    ] = "basinhopping"

    def __call__(
        self,
        residual_fn: ResidualFn,
        p0: dict[str, float],
        bounds: Bounds,
    ) -> MinResult | None:
        """Minimize `residual_fn` with the configured global solver.

        Args:
            residual_fn: Function mapping {name: value} to a residual
            p0: Initial guesses as {name: value}. Only `basinhopping` uses the
                values as starting point; the names define the search space
                for all methods
            bounds: Bounds per name; names without an entry get (1e-6, 1e6).
                Not used by `basinhopping`

        Returns:
            The best values and residual, or None if scipy reports failure

        Raises:
            NotImplementedError: If `method` is not one of the supported names

        """
        res: OptimizeResult
        par_names = list(p0.keys())

        def res_fn(par_values: Array) -> float:
            # scipy works on a flat array, the residual function on a mapping
            return residual_fn(_pack_updates(par_values, par_names))

        # scipy expects an ordered sequence of (min, max) pairs, one per
        # optimized value, not the {name: (min, max)} mapping used by this
        # module. Missing names get the same default as in LocalScipyMinimizer.
        bounds_seq = [bounds.get(name, (1e-6, 1e6)) for name in par_names]

        if self.method == "basinhopping":
            res = basinhopping(
                res_fn,
                x0=list(p0.values()),
            )
        elif self.method == "differential_evolution":
            res = differential_evolution(res_fn, bounds_seq)
        elif self.method == "shgo":
            res = shgo(res_fn, bounds_seq)
        elif self.method == "dual_annealing":
            res = dual_annealing(res_fn, bounds_seq)
        elif self.method == "direct":
            res = direct(res_fn, bounds_seq)
        else:
            msg = f"Unknown method {self.method}"
            raise NotImplementedError(msg)

        if not res.success:
            LOGGER.warning("Minimisation failed.")
            return None

        return MinResult(
            parameters=dict(zip(par_names, res.x, strict=True)),
            residual=res.fun,
        )
|
466
|
+
|
467
|
+
|
468
|
+
###############################################################################
|
469
|
+
# Residual functions
|
470
|
+
###############################################################################
|
471
|
+
|
472
|
+
|
473
|
+
def _pack_updates(
|
474
|
+
par_values: Array,
|
475
|
+
par_names: list[str],
|
476
|
+
) -> dict[str, float]:
|
477
|
+
return dict(
|
478
|
+
zip(
|
479
|
+
par_names,
|
480
|
+
par_values,
|
481
|
+
strict=True,
|
482
|
+
)
|
483
|
+
)
|
484
|
+
|
485
|
+
|
486
|
+
def steady_state_residual(
    updates: dict[str, float],
    settings: _Settings,
) -> float:
    """Calculate the loss between the model steady state and the data.

    The model in `settings` is updated in place before it is simulated.

    Args:
        updates: New values as {name: value}; must contain every name listed
            in `settings.p_names` and `settings.v_names`
        settings: Model, data and simulation settings

    Returns:
        Value of `settings.loss_fn`, or infinity if the simulation yielded no
        result

    """
    model = settings.model
    y0 = settings.y0
    if y0 is not None:
        model.update_variables(y0)
    for name in settings.p_names:
        model.update_parameter(name, updates[name])
    # Applied after y0, so fitted initial conditions take precedence
    for name in settings.v_names:
        model.update_variable(name, updates[name])

    simulator = Simulator(
        model,
        integrator=settings.integrator,
    )
    res = simulator.simulate_to_steady_state().get_result()
    if res is None:
        return cast(float, np.inf)

    # Restrict the simulation to the quantities named by the data index
    prediction = res.get_combined().loc[:, cast(list, settings.data.index)]
    return settings.loss_fn(prediction, settings.data)
|
514
|
+
|
515
|
+
|
516
|
+
def time_course_residual(
    updates: dict[str, float],
    settings: _Settings,
) -> float:
    """Calculate the loss between a model time course and the data.

    The model in `settings` is updated in place before it is simulated. The
    index of the data is used as the time points of the simulation.

    Args:
        updates: New values as {name: value}; must contain every name listed
            in `settings.p_names` and `settings.v_names`
        settings: Model, data and simulation settings

    Returns:
        Value of `settings.loss_fn`, or infinity if the simulation yielded no
        result

    """
    model = settings.model
    y0 = settings.y0
    if y0 is not None:
        model.update_variables(y0)
    for name in settings.p_names:
        model.update_parameter(name, updates[name])
    # Applied after y0, so fitted initial conditions take precedence
    for name in settings.v_names:
        model.update_variable(name, updates[name])

    simulator = Simulator(
        model,
        integrator=settings.integrator,
    )
    time_points = cast(list, settings.data.index)
    res = simulator.simulate_time_course(time_points).get_result()
    if res is None:
        return cast(float, np.inf)

    # Restrict the simulation to the columns present in the data
    prediction = res.get_combined().loc[:, cast(list, settings.data.columns)]
    return settings.loss_fn(prediction, settings.data)
|
545
|
+
|
546
|
+
|
547
|
+
def protocol_time_course_residual(
    updates: dict[str, float],
    settings: _Settings,
) -> float:
    """Calculate the loss between a model protocol time course and the data.

    The model in `settings` is updated in place before it is simulated. The
    index of the data is used as the time points of the simulation.

    Args:
        updates: New values as {name: value}; must contain every name listed
            in `settings.p_names` and `settings.v_names`
        settings: Model, data, protocol and simulation settings

    Returns:
        Value of `settings.loss_fn`, or infinity if the simulation yielded no
        result

    Raises:
        ValueError: If `settings.protocol` is None

    """
    # Validate first, so a misconfigured call fails with a clear message and
    # before the model has been modified
    if (protocol := settings.protocol) is None:
        msg = "protocol_time_course_residual requires `settings.protocol` to be set"
        raise ValueError(msg)

    model = settings.model
    if (y0 := settings.y0) is not None:
        model.update_variables(y0)
    for p in settings.p_names:
        model.update_parameter(p, updates[p])
    # Applied after y0, so fitted initial conditions take precedence
    for p in settings.v_names:
        model.update_variable(p, updates[p])

    res = (
        Simulator(
            model,
            integrator=settings.integrator,
        )
        .simulate_protocol_time_course(
            protocol=protocol,
            time_points=settings.data.index,
        )
        .get_result()
    )
    if res is None:
        return cast(float, np.inf)
    results = res.get_combined()

    # Restrict the simulation to the columns present in the data
    return settings.loss_fn(
        results.loc[:, cast(list, settings.data.columns)],
        settings.data,
    )
|
582
|
+
|
583
|
+
|
584
|
+
def steady_state(
    model: Model,
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> FitResult | None:
    """Fit model parameters and/or variables to steady-state data.

    Names in `p0` are split into model parameters and model variables by
    looking them up in the model; names that are neither are not updated.

    Examples:
        >>> steady_state(model, p0=p0, data=data, minimizer=LocalScipyMinimizer())

    Args:
        model: Model instance to fit
        p0: Initial guesses as {name: value}
        data: Experimental steady state data as pandas Series
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter; an empty mapping is passed to
            the minimizer if None
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        FitResult with the fitted model, best values and loss, or None if the
        minimizer returned None

    """
    fit_model = deepcopy(model) if as_deepcopy else model

    known_parameters = fit_model.get_parameter_names()
    known_variables = fit_model.get_variable_names()

    settings = _Settings(
        model=fit_model,
        data=data,
        y0=y0,
        integrator=integrator,
        loss_fn=loss_fn,
        p_names=[name for name in p0 if name in known_parameters],
        v_names=[name for name in p0 if name in known_variables],
    )
    objective: ResidualFn = partial(residual_fn, settings=settings)

    if bounds is None:
        bounds = {}
    outcome = minimizer(objective, p0, bounds)
    if outcome is None:
        return None

    return FitResult(
        model=fit_model,
        best_pars=outcome.parameters,
        loss=outcome.residual,
    )
|
649
|
+
|
650
|
+
|
651
|
+
def time_course(
    model: Model,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> FitResult | None:
    """Fit model parameters and/or variables to time course data.

    Names in `p0` are split into model parameters and model variables by
    looking them up in the model; names that are neither are not updated.

    Examples:
        >>> time_course(model, p0=p0, data=data, minimizer=LocalScipyMinimizer())

    Args:
        model: Model instance to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter; an empty mapping is passed to
            the minimizer if None
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        FitResult with the fitted model, best values and loss, or None if the
        minimizer returned None

    """
    fit_model = deepcopy(model) if as_deepcopy else model

    known_parameters = fit_model.get_parameter_names()
    known_variables = fit_model.get_variable_names()

    settings = _Settings(
        model=fit_model,
        data=data,
        y0=y0,
        integrator=integrator,
        loss_fn=loss_fn,
        p_names=[name for name in p0 if name in known_parameters],
        v_names=[name for name in p0 if name in known_variables],
    )
    objective: ResidualFn = partial(residual_fn, settings=settings)

    if bounds is None:
        bounds = {}
    outcome = minimizer(objective, p0, bounds)
    if outcome is None:
        return None

    return FitResult(
        model=fit_model,
        best_pars=outcome.parameters,
        loss=outcome.residual,
    )
|
716
|
+
|
717
|
+
|
718
|
+
def protocol_time_course(
    model: Model,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    protocol: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> FitResult | None:
    """Fit model parameters and/or variables to protocol time course data.

    Time points of the protocol time course are taken from the data. Names in
    `p0` are split into model parameters and model variables by looking them
    up in the model; names that are neither are not updated.

    Examples:
        >>> protocol_time_course(
        ...     model,
        ...     p0=p0,
        ...     data=data,
        ...     protocol=protocol,
        ...     minimizer=LocalScipyMinimizer(),
        ... )

    Args:
        model: Model instance to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        protocol: Experimental protocol
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter; an empty mapping is passed to
            the minimizer if None
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        FitResult with the fitted model, best values and loss, or None if the
        minimizer returned None

    """
    fit_model = deepcopy(model) if as_deepcopy else model

    known_parameters = fit_model.get_parameter_names()
    known_variables = fit_model.get_variable_names()

    settings = _Settings(
        model=fit_model,
        data=data,
        y0=y0,
        integrator=integrator,
        loss_fn=loss_fn,
        p_names=[name for name in p0 if name in known_parameters],
        v_names=[name for name in p0 if name in known_variables],
        protocol=protocol,
    )
    objective: ResidualFn = partial(residual_fn, settings=settings)

    if bounds is None:
        bounds = {}
    outcome = minimizer(objective, p0, bounds)
    if outcome is None:
        return None

    return FitResult(
        model=fit_model,
        best_pars=outcome.parameters,
        loss=outcome.residual,
    )
|
789
|
+
|
790
|
+
|
791
|
+
###############################################################################
|
792
|
+
# Ensemble / carousel
|
793
|
+
# This is multi-model, single data fitting, where the models share parameters
|
794
|
+
###############################################################################
|
795
|
+
|
796
|
+
|
797
|
+
def ensemble_steady_state(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit every model of an ensemble to the same steady-state data.

    Each model is fitted independently with `steady_state`, dispatched via
    `parallel.parallelise`. Fits that returned None are left out of the result.

    Examples:
        >>> ensemble_steady_state(
        ...     ensemble, p0=p0, data=data, minimizer=LocalScipyMinimizer()
        ... )

    Args:
        ensemble: Ensemble to fit
        p0: Initial guesses as {name: value}
        data: Experimental steady state data as pandas Series
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        EnsembleFitResult holding one FitResult per successfully fitted model

    """
    fit_single = partial(
        steady_state,
        p0=p0,
        data=data,
        y0=y0,
        integrator=integrator,
        loss_fn=loss_fn,
        minimizer=minimizer,
        residual_fn=residual_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
    outcomes = parallel.parallelise(
        fit_single,
        inputs=list(enumerate(ensemble)),
    )

    fits: list[FitResult] = []
    for item in outcomes:
        # Only the second element of each returned item is read here
        # (presumably (key, result) pairs - confirm against `parallelise`)
        fit = item[1]
        if fit is not None:
            fits.append(fit)
    return EnsembleFitResult(fits)
|
857
|
+
|
858
|
+
|
859
|
+
def carousel_steady_state(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit every variant of a carousel to the same steady-state data.

    Thin wrapper that forwards `carousel.variants` and all other arguments to
    `ensemble_steady_state`.

    Examples:
        >>> carousel_steady_state(
        ...     carousel, p0=p0, data=data, minimizer=LocalScipyMinimizer()
        ... )

    Args:
        carousel: Model carousel to fit
        p0: Initial guesses as {name: value}
        data: Experimental steady state data as pandas Series
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        EnsembleFitResult as returned by `ensemble_steady_state`

    """
    variants = carousel.variants
    return ensemble_steady_state(
        variants,
        p0=p0,
        data=data,
        minimizer=minimizer,
        y0=y0,
        residual_fn=residual_fn,
        integrator=integrator,
        loss_fn=loss_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
|
910
|
+
|
911
|
+
|
912
|
+
def ensemble_time_course(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit every model of an ensemble to the same time course data.

    Each model is fitted independently with `time_course`, dispatched via
    `parallel.parallelise`. Fits that returned None are left out of the result.
    Time points are taken from the data.

    Examples:
        >>> ensemble_time_course(
        ...     ensemble, p0=p0, data=data, minimizer=LocalScipyMinimizer()
        ... )

    Args:
        ensemble: Model ensemble to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        EnsembleFitResult holding one FitResult per successfully fitted model

    """
    fit_single = partial(
        time_course,
        p0=p0,
        data=data,
        y0=y0,
        integrator=integrator,
        loss_fn=loss_fn,
        minimizer=minimizer,
        residual_fn=residual_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
    outcomes = parallel.parallelise(
        fit_single,
        inputs=list(enumerate(ensemble)),
    )

    fits: list[FitResult] = []
    for item in outcomes:
        # Only the second element of each returned item is read here
        # (presumably (key, result) pairs - confirm against `parallelise`)
        fit = item[1]
        if fit is not None:
            fits.append(fit)
    return EnsembleFitResult(fits)
|
974
|
+
|
975
|
+
|
976
|
+
def carousel_time_course(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit all variants of a carousel to time course data.

    Thin wrapper: ``carousel.variants`` is handed to `ensemble_time_course`
    and every keyword argument is forwarded unchanged.

    Examples:
        >>> carousel_time_course(carousel, p0=p0, data=data, minimizer=minimizer)

    Args:
        carousel: Model carousel whose variants are fitted
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        EnsembleFitResult: The result of `ensemble_time_course` for the variants.

    """
    variants = carousel.variants
    return ensemble_time_course(
        variants,
        p0=p0,
        data=data,
        y0=y0,
        minimizer=minimizer,
        residual_fn=residual_fn,
        loss_fn=loss_fn,
        integrator=integrator,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
|
1029
|
+
|
1030
|
+
|
1031
|
+
def ensemble_protocol_time_course(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    protocol: pd.DataFrame,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit each model of an ensemble to time course data over a protocol.

    Every model is fitted on its own with `protocol_time_course`; the
    individual fits are dispatched through `parallel.parallelise`. Fits that
    come back as ``None`` are left out of the result.

    Examples:
        >>> ensemble_protocol_time_course(
        ...     ensemble, p0=p0, data=data, minimizer=minimizer, protocol=protocol
        ... )

    Args:
        ensemble: Model ensemble to fit
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        protocol: Experimental protocol
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        EnsembleFitResult: Collection of all fits that are not ``None``.

    """
    # Everything except the model itself is identical for all ensemble members.
    fit_single = partial(
        protocol_time_course,
        p0=p0,
        data=data,
        protocol=protocol,
        y0=y0,
        integrator=integrator,
        loss_fn=loss_fn,
        minimizer=minimizer,
        residual_fn=residual_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
    # Inputs are (index, model) pairs; each outcome carries the fit at position 1.
    outcomes = parallel.parallelise(fit_single, inputs=list(enumerate(ensemble)))
    successful = [entry[1] for entry in outcomes if entry[1] is not None]
    return EnsembleFitResult(successful)
|
1095
|
+
|
1096
|
+
|
1097
|
+
def carousel_protocol_time_course(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    protocol: pd.DataFrame,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit all variants of a carousel to time course data over a protocol.

    Thin wrapper: ``carousel.variants`` is handed to
    `ensemble_protocol_time_course` and every keyword argument is forwarded
    unchanged.

    Examples:
        >>> carousel_protocol_time_course(
        ...     carousel, p0=p0, data=data, minimizer=minimizer, protocol=protocol
        ... )

    Args:
        carousel: Model carousel whose variants are fitted
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        protocol: Experimental protocol
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        EnsembleFitResult: The result of `ensemble_protocol_time_course` for
            the variants.

    """
    variants = carousel.variants
    return ensemble_protocol_time_course(
        variants,
        p0=p0,
        data=data,
        protocol=protocol,
        y0=y0,
        minimizer=minimizer,
        residual_fn=residual_fn,
        loss_fn=loss_fn,
        integrator=integrator,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
|
1152
|
+
|
1153
|
+
|
1154
|
+
###############################################################################
|
1155
|
+
# Joint fitting
|
1156
|
+
# This is multi-model, multi-data fitting, where the models share some parameters
|
1157
|
+
###############################################################################
|
1158
|
+
|
1159
|
+
|
1160
|
+
def _unpacked[T1, T2, Tout](inp: tuple[T1, T2], fn: Callable[[T1, T2], Tout]) -> Tout:
|
1161
|
+
return fn(*inp)
|
1162
|
+
|
1163
|
+
|
1164
|
+
def _sum_of_residuals(
    updates: dict[str, float],
    residual_fn: ResidualProtocol,
    fits: list[_Settings],
    pool: pebble.ProcessPool,
) -> float:
    """Add up the residuals of all fits for one set of parameter updates.

    ``residual_fn`` is evaluated once per entry of ``fits`` as
    ``residual_fn(updates, entry)``; the calls are submitted via ``pool.map``.

    Args:
        updates: Candidate values as {name: value}, shared by all fits
        residual_fn: Residual function applied to every entry of ``fits``
        fits: Settings of the individual fits
        pool: Process pool used to evaluate the residuals

    Returns:
        Sum of all residuals, or ``np.inf`` as soon as retrieving one of them
        raises ``TimeoutError``.

    """
    jobs = [(updates, settings) for settings in fits]
    residuals = pool.map(
        partial(_unpacked, fn=residual_fn),
        jobs,
        timeout=None,
    ).result()

    total = 0.0
    try:
        # Accumulate one by one, in order; a timeout can surface on any item.
        for residual in residuals:
            total += residual
    except TimeoutError:
        return np.inf
    return total
|
1185
|
+
|
1186
|
+
|
1187
|
+
def joint_steady_state(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting against steady-state residuals.

    One candidate set of values is scored against every entry of ``to_fit``
    with `steady_state_residual`; the minimizer works on the sum of these
    residuals, which are evaluated in a process pool.

    Args:
        to_fit: Settings per model/data combination. An entry's own ``y0``,
            ``integrator`` and ``loss_fn`` are used unless they are ``None``.
        p0: Initial guesses as {name: value}. For each model only the names
            that are among its parameter / variable names are recorded.
        minimizer: Called as ``minimizer(objective, p0, bounds)``
        y0: Fallback initial conditions for entries without their own
        integrator: Fallback integrator for entries without their own
        loss_fn: Fallback loss function for entries without their own
        bounds: Mapping of bounds per parameter; ``{}`` is passed if ``None``
        max_workers: Size of the process pool; ``multiprocessing.cpu_count()``
            if ``None``
        as_deepcopy: Whether to deep-copy each model before it is used

    Returns:
        JointFitResult built from the minimizer result's ``parameters`` and
        ``residual``, or ``None`` if the minimizer returned ``None``.

    """
    full_settings = []
    for spec in to_fit:
        model_parameters = spec.model.get_parameter_names()
        model_variables = spec.model.get_variable_names()
        settings = _Settings(
            model=deepcopy(spec.model) if as_deepcopy else spec.model,
            data=spec.data,
            y0=spec.y0 if spec.y0 is not None else y0,
            integrator=spec.integrator if spec.integrator is not None else integrator,
            loss_fn=spec.loss_fn if spec.loss_fn is not None else loss_fn,
            p_names=[name for name in p0 if name in model_parameters],
            v_names=[name for name in p0 if name in model_variables],
        )
        full_settings.append(settings)

    n_workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    with pebble.ProcessPool(max_workers=n_workers) as pool:
        objective = partial(
            _sum_of_residuals,
            residual_fn=steady_state_residual,
            fits=full_settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, {} if bounds is None else bounds)

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
|
1235
|
+
|
1236
|
+
|
1237
|
+
def joint_time_course(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting against time course residuals.

    One candidate set of values is scored against every entry of ``to_fit``
    with `time_course_residual`; the minimizer works on the sum of these
    residuals, which are evaluated in a process pool.

    Args:
        to_fit: Settings per model/data combination. An entry's own ``y0``,
            ``integrator`` and ``loss_fn`` are used unless they are ``None``.
        p0: Initial guesses as {name: value}. For each model only the names
            that are among its parameter / variable names are recorded.
        minimizer: Called as ``minimizer(objective, p0, bounds)``
        y0: Fallback initial conditions for entries without their own
        integrator: Fallback integrator for entries without their own
        loss_fn: Fallback loss function for entries without their own
        bounds: Mapping of bounds per parameter; ``{}`` is passed if ``None``
        max_workers: Size of the process pool; ``multiprocessing.cpu_count()``
            if ``None``
        as_deepcopy: Whether to deep-copy each model before it is used

    Returns:
        JointFitResult built from the minimizer result's ``parameters`` and
        ``residual``, or ``None`` if the minimizer returned ``None``.

    """
    full_settings = []
    for spec in to_fit:
        model_parameters = spec.model.get_parameter_names()
        model_variables = spec.model.get_variable_names()
        settings = _Settings(
            model=deepcopy(spec.model) if as_deepcopy else spec.model,
            data=spec.data,
            y0=spec.y0 if spec.y0 is not None else y0,
            integrator=spec.integrator if spec.integrator is not None else integrator,
            loss_fn=spec.loss_fn if spec.loss_fn is not None else loss_fn,
            p_names=[name for name in p0 if name in model_parameters],
            v_names=[name for name in p0 if name in model_variables],
        )
        full_settings.append(settings)

    n_workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    with pebble.ProcessPool(max_workers=n_workers) as pool:
        objective = partial(
            _sum_of_residuals,
            residual_fn=time_course_residual,
            fits=full_settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, {} if bounds is None else bounds)

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
|
1285
|
+
|
1286
|
+
|
1287
|
+
def joint_protocol_time_course(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting against protocol time course residuals.

    One candidate set of values is scored against every entry of ``to_fit``
    with `protocol_time_course_residual`; the minimizer works on the sum of
    these residuals, which are evaluated in a process pool.

    Args:
        to_fit: Settings per model/data combination. An entry's own ``y0``,
            ``integrator`` and ``loss_fn`` are used unless they are ``None``.
        p0: Initial guesses as {name: value}. For each model only the names
            that are among its parameter / variable names are recorded.
        minimizer: Called as ``minimizer(objective, p0, bounds)``
        y0: Fallback initial conditions for entries without their own
        integrator: Fallback integrator for entries without their own
        loss_fn: Fallback loss function for entries without their own
        bounds: Mapping of bounds per parameter; ``{}`` is passed if ``None``
        max_workers: Size of the process pool; ``multiprocessing.cpu_count()``
            if ``None``
        as_deepcopy: Whether to deep-copy each model before it is used

    Returns:
        JointFitResult built from the minimizer result's ``parameters`` and
        ``residual``, or ``None`` if the minimizer returned ``None``.

    """
    full_settings = []
    for spec in to_fit:
        model_parameters = spec.model.get_parameter_names()
        model_variables = spec.model.get_variable_names()
        # NOTE(review): no protocol is attached to the settings here - confirm
        # that protocol_time_course_residual can obtain it from elsewhere.
        settings = _Settings(
            model=deepcopy(spec.model) if as_deepcopy else spec.model,
            data=spec.data,
            y0=spec.y0 if spec.y0 is not None else y0,
            integrator=spec.integrator if spec.integrator is not None else integrator,
            loss_fn=spec.loss_fn if spec.loss_fn is not None else loss_fn,
            p_names=[name for name in p0 if name in model_parameters],
            v_names=[name for name in p0 if name in model_variables],
        )
        full_settings.append(settings)

    n_workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    with pebble.ProcessPool(max_workers=n_workers) as pool:
        objective = partial(
            _sum_of_residuals,
            residual_fn=protocol_time_course_residual,
            fits=full_settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, {} if bounds is None else bounds)

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
|
1335
|
+
|
1336
|
+
|
1337
|
+
###############################################################################
|
1338
|
+
# Joint mixed fitting
|
1339
|
+
# This is multi-model, multi-data, multi-simulation fitting
|
1340
|
+
# The models share some parameters here, everything else can be changed though
|
1341
|
+
###############################################################################
|
1342
|
+
|
1343
|
+
|
1344
|
+
def _execute(inp: tuple[dict[str, float], ResidualProtocol, _Settings]) -> float:
|
1345
|
+
updates, residual_fn, settings = inp
|
1346
|
+
return residual_fn(updates, settings)
|
1347
|
+
|
1348
|
+
|
1349
|
+
def _mixed_sum_of_residuals(
    updates: dict[str, float],
    fits: list[_Settings],
    pool: pebble.ProcessPool,
) -> float:
    """Add up the residuals of fits that each carry their own residual function.

    For every entry of ``fits`` a job ``(updates, entry.residual_fn, entry)`` is
    submitted to ``pool.map`` and evaluated by `_execute`.

    Args:
        updates: Candidate values as {name: value}, shared by all fits
        fits: Settings of the individual fits, including their ``residual_fn``
        pool: Process pool used to evaluate the residuals

    Returns:
        Sum of all residuals, or ``np.inf`` as soon as retrieving one of them
        raises ``TimeoutError``.

    """
    jobs = [(updates, settings.residual_fn, settings) for settings in fits]
    residuals = pool.map(_execute, jobs).result()

    total = 0.0
    try:
        # Accumulate one by one, in order; a timeout can surface on any item.
        for residual in residuals:
            total += residual
    except TimeoutError:
        return np.inf
    return total
|
1365
|
+
|
1366
|
+
|
1367
|
+
def joint_mixed(
    to_fit: list[MixedSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data, multi-simulation fitting.

    One candidate set of values is scored against every entry of ``to_fit``
    using that entry's own ``residual_fn``; the minimizer works on the sum of
    these residuals, which are evaluated in a process pool.

    Args:
        to_fit: Settings per model/data/residual combination. An entry's own
            ``y0``, ``integrator`` and ``loss_fn`` are used unless they are
            ``None``.
        p0: Initial guesses as {name: value}. For each model only the names
            that are among its parameter / variable names are recorded.
        minimizer: Called as ``minimizer(objective, p0, bounds)``
        y0: Fallback initial conditions for entries without their own
        integrator: Fallback integrator for entries without their own
        loss_fn: Fallback loss function for entries without their own
        bounds: Mapping of bounds per parameter; ``{}`` is passed if ``None``
        max_workers: Size of the process pool; ``multiprocessing.cpu_count()``
            if ``None``
        as_deepcopy: Whether to deep-copy each model before it is used

    Returns:
        JointFitResult built from the minimizer result's ``parameters`` and
        ``residual``, or ``None`` if the minimizer returned ``None``.

    """
    full_settings = []
    for spec in to_fit:
        model_parameters = spec.model.get_parameter_names()
        model_variables = spec.model.get_variable_names()
        settings = _Settings(
            model=deepcopy(spec.model) if as_deepcopy else spec.model,
            data=spec.data,
            y0=spec.y0 if spec.y0 is not None else y0,
            integrator=spec.integrator if spec.integrator is not None else integrator,
            loss_fn=spec.loss_fn if spec.loss_fn is not None else loss_fn,
            p_names=[name for name in p0 if name in model_parameters],
            v_names=[name for name in p0 if name in model_variables],
            residual_fn=spec.residual_fn,
        )
        full_settings.append(settings)

    n_workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    with pebble.ProcessPool(max_workers=n_workers) as pool:
        objective = partial(
            _mixed_sum_of_residuals,
            fits=full_settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, {} if bounds is None else bounds)

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
|