mxlpy 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +2 -0
- mxlpy/fit.py +960 -359
- mxlpy/fuzzy.py +139 -0
- mxlpy/identify.py +1 -0
- mxlpy/integrators/int_scipy.py +4 -3
- mxlpy/meta/codegen_latex.py +1 -0
- mxlpy/meta/source_tools.py +1 -1
- mxlpy/model.py +74 -33
- mxlpy/nn/__init__.py +5 -0
- mxlpy/nn/_equinox.py +293 -0
- mxlpy/nn/_torch.py +59 -2
- mxlpy/npe/__init__.py +5 -0
- mxlpy/npe/_equinox.py +344 -0
- mxlpy/npe/_torch.py +6 -22
- mxlpy/parallel.py +73 -4
- mxlpy/surrogates/__init__.py +5 -0
- mxlpy/surrogates/_equinox.py +195 -0
- mxlpy/surrogates/_torch.py +5 -20
- mxlpy/symbolic/symbolic_model.py +30 -3
- mxlpy/types.py +172 -0
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/METADATA +11 -1
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/RECORD +24 -20
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.24.0.dist-info → mxlpy-0.26.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/fit.py
CHANGED
@@ -1,60 +1,226 @@
|
|
1
|
-
"""
|
1
|
+
"""Fitting routines.
|
2
|
+
|
3
|
+
Single model, single data routines
|
4
|
+
----------------------------------
|
5
|
+
- `steady_state`
|
6
|
+
- `time_course`
|
7
|
+
- `protocol_time_course`
|
8
|
+
|
9
|
+
Multiple model, single data routines
|
10
|
+
------------------------------------
|
11
|
+
- `ensemble_steady_state`
|
12
|
+
- `ensemble_time_course`
|
13
|
+
- `ensemble_protocol_time_course`
|
14
|
+
|
15
|
+
A carousel is a special case of an ensemble, where the general
|
16
|
+
structure (e.g. stoichiometries) is the same, while the reactions kinetics
|
17
|
+
can vary
|
18
|
+
- `carousel_steady_state`
|
19
|
+
- `carousel_time_course`
|
20
|
+
- `carousel_protocol_time_course`
|
21
|
+
|
22
|
+
Multiple model, multiple data
|
23
|
+
-----------------------------
|
24
|
+
- `joint_steady_state`
|
25
|
+
- `joint_time_course`
|
26
|
+
- `joint_protocol_time_course`
|
27
|
+
|
28
|
+
Multiple model, multiple data, multiple methods
|
29
|
+
-----------------------------------------------
|
30
|
+
Here we also allow to run different methods (e.g. steady-state vs time courses)
|
31
|
+
for each combination of model:data.
|
32
|
+
|
33
|
+
- `joint_mixed`
|
34
|
+
|
35
|
+
Minimizers
|
36
|
+
----------
|
37
|
+
- LocalScipyMinimizer, including common methods such as Nelder-Mead or L-BFGS-B
|
38
|
+
- GlobalScipyMinimizer, including common methods such as basin hopping or dual annealing
|
39
|
+
|
40
|
+
Loss functions
|
41
|
+
--------------
|
42
|
+
- rmse
|
2
43
|
|
3
|
-
This module provides functions foru fitting model parameters to experimental data,
|
4
|
-
including both steadyd-state and time-series data fitting capabilities.e
|
5
|
-
|
6
|
-
Functions:
|
7
|
-
fit_steady_state: Fits parameters to steady-state experimental data
|
8
|
-
fit_time_course: Fits parameters to time-series experimental data
|
9
44
|
"""
|
10
45
|
|
11
46
|
from __future__ import annotations
|
12
47
|
|
13
48
|
import logging
|
49
|
+
import multiprocessing
|
50
|
+
from collections.abc import Callable
|
14
51
|
from copy import deepcopy
|
15
52
|
from dataclasses import dataclass
|
16
53
|
from functools import partial
|
17
|
-
from typing import TYPE_CHECKING, Protocol
|
54
|
+
from typing import TYPE_CHECKING, Literal, Protocol
|
18
55
|
|
19
56
|
import numpy as np
|
20
|
-
|
57
|
+
import pandas as pd
|
58
|
+
import pebble
|
59
|
+
from scipy.optimize import (
|
60
|
+
basinhopping,
|
61
|
+
differential_evolution,
|
62
|
+
direct,
|
63
|
+
dual_annealing,
|
64
|
+
minimize,
|
65
|
+
shgo,
|
66
|
+
)
|
21
67
|
from wadler_lindig import pformat
|
22
68
|
|
23
69
|
from mxlpy import parallel
|
70
|
+
from mxlpy.model import Model
|
24
71
|
from mxlpy.simulator import Simulator
|
25
|
-
from mxlpy.types import Array,
|
72
|
+
from mxlpy.types import Array, IntegratorType, cast
|
26
73
|
|
27
74
|
if TYPE_CHECKING:
|
28
75
|
import pandas as pd
|
76
|
+
from scipy.optimize._optimize import OptimizeResult
|
29
77
|
|
30
78
|
from mxlpy.carousel import Carousel
|
31
79
|
from mxlpy.model import Model
|
32
80
|
|
33
81
|
LOGGER = logging.getLogger(__name__)
|
34
82
|
|
83
|
+
|
35
84
|
__all__ = [
|
36
85
|
"Bounds",
|
37
|
-
"
|
86
|
+
"EnsembleFitResult",
|
38
87
|
"FitResult",
|
88
|
+
"FitSettings",
|
89
|
+
"GlobalScipyMinimizer",
|
39
90
|
"InitialGuess",
|
91
|
+
"JointFitResult",
|
40
92
|
"LOGGER",
|
93
|
+
"LocalScipyMinimizer",
|
41
94
|
"LossFn",
|
42
95
|
"MinResult",
|
43
|
-
"
|
44
|
-
"
|
96
|
+
"Minimizer",
|
97
|
+
"MixedSettings",
|
98
|
+
"ResFn",
|
45
99
|
"ResidualFn",
|
46
|
-
"
|
47
|
-
"TimeSeriesResidualFn",
|
100
|
+
"ResidualProtocol",
|
48
101
|
"carousel_protocol_time_course",
|
49
102
|
"carousel_steady_state",
|
50
103
|
"carousel_time_course",
|
104
|
+
"cosine_similarity",
|
105
|
+
"ensemble_protocol_time_course",
|
106
|
+
"ensemble_steady_state",
|
107
|
+
"ensemble_time_course",
|
108
|
+
"joint_mixed",
|
109
|
+
"joint_protocol_time_course",
|
110
|
+
"joint_steady_state",
|
111
|
+
"joint_time_course",
|
112
|
+
"mae",
|
113
|
+
"mean",
|
114
|
+
"mean_absolute_percentage",
|
115
|
+
"mean_squared",
|
116
|
+
"mean_squared_logarithmic",
|
51
117
|
"protocol_time_course",
|
118
|
+
"protocol_time_course_residual",
|
52
119
|
"rmse",
|
53
120
|
"steady_state",
|
121
|
+
"steady_state_residual",
|
54
122
|
"time_course",
|
123
|
+
"time_course_residual",
|
124
|
+
]
|
125
|
+
|
126
|
+
type InitialGuess = dict[str, float]
|
127
|
+
|
128
|
+
type Bounds = dict[str, tuple[float | None, float | None]]
|
129
|
+
type ResFn = Callable[[Array], float]
|
130
|
+
type LossFn = Callable[
|
131
|
+
[
|
132
|
+
pd.DataFrame | pd.Series,
|
133
|
+
pd.DataFrame | pd.Series,
|
134
|
+
],
|
135
|
+
float,
|
136
|
+
]
|
137
|
+
|
138
|
+
|
139
|
+
type Minimizer = Callable[
|
140
|
+
[
|
141
|
+
ResidualFn,
|
142
|
+
InitialGuess,
|
143
|
+
Bounds,
|
144
|
+
],
|
145
|
+
MinResult | None,
|
55
146
|
]
|
56
147
|
|
57
148
|
|
149
|
+
class ResidualProtocol(Protocol):
|
150
|
+
"""Protocol for steady state residual functions.
|
151
|
+
|
152
|
+
This is the user-facing variant, for stuff like
|
153
|
+
- `fit.steady_state`
|
154
|
+
- `fit.time_course`
|
155
|
+
- `fit.protocol_time_course`
|
156
|
+
|
157
|
+
The settings are later partially applied to yield ResidualFn
|
158
|
+
"""
|
159
|
+
|
160
|
+
def __call__(
|
161
|
+
self,
|
162
|
+
updates: dict[str, float],
|
163
|
+
settings: _Settings,
|
164
|
+
) -> float:
|
165
|
+
"""Calculate residual error between model steady state and experimental data."""
|
166
|
+
...
|
167
|
+
|
168
|
+
|
169
|
+
class ResidualFn(Protocol):
|
170
|
+
"""Protocol for steady state residual functions.
|
171
|
+
|
172
|
+
This is the internal version, which is produced by partial
|
173
|
+
application of `settings` of `ResidualProtocol`
|
174
|
+
"""
|
175
|
+
|
176
|
+
def __call__(
|
177
|
+
self,
|
178
|
+
updates: dict[str, float],
|
179
|
+
) -> float:
|
180
|
+
"""Calculate residual error between model steady state and experimental data."""
|
181
|
+
...
|
182
|
+
|
183
|
+
|
184
|
+
@dataclass
|
185
|
+
class FitSettings:
|
186
|
+
"""Settings for a fit."""
|
187
|
+
|
188
|
+
model: Model
|
189
|
+
data: pd.Series | pd.DataFrame
|
190
|
+
y0: dict[str, float] | None = None
|
191
|
+
integrator: IntegratorType | None = None
|
192
|
+
loss_fn: LossFn | None = None
|
193
|
+
protocol: pd.DataFrame | None = None
|
194
|
+
|
195
|
+
|
196
|
+
@dataclass
|
197
|
+
class MixedSettings:
|
198
|
+
"""Settings for a fit."""
|
199
|
+
|
200
|
+
model: Model
|
201
|
+
data: pd.Series | pd.DataFrame
|
202
|
+
residual_fn: ResidualFn
|
203
|
+
y0: dict[str, float] | None = None
|
204
|
+
integrator: IntegratorType | None = None
|
205
|
+
loss_fn: LossFn | None = None
|
206
|
+
protocol: pd.DataFrame | None = None
|
207
|
+
|
208
|
+
|
209
|
+
@dataclass
|
210
|
+
class _Settings:
|
211
|
+
"""Non user-facing version of FitSettings."""
|
212
|
+
|
213
|
+
model: Model
|
214
|
+
data: pd.Series | pd.DataFrame
|
215
|
+
y0: dict[str, float] | None
|
216
|
+
integrator: IntegratorType | None
|
217
|
+
loss_fn: LossFn
|
218
|
+
p_names: list[str]
|
219
|
+
v_names: list[str]
|
220
|
+
protocol: pd.DataFrame | None = None
|
221
|
+
residual_fn: ResidualFn | None = None
|
222
|
+
|
223
|
+
|
58
224
|
@dataclass
|
59
225
|
class MinResult:
|
60
226
|
"""Result of a minimization operation."""
|
@@ -81,7 +247,7 @@ class FitResult:
|
|
81
247
|
|
82
248
|
|
83
249
|
@dataclass
|
84
|
-
class
|
250
|
+
class EnsembleFitResult:
|
85
251
|
"""Result of a carousel fit operation."""
|
86
252
|
|
87
253
|
fits: list[FitResult]
|
@@ -95,24 +261,37 @@ class CarouselFit:
|
|
95
261
|
return min(self.fits, key=lambda x: x.loss)
|
96
262
|
|
97
263
|
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
[
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
264
|
+
@dataclass
|
265
|
+
class JointFitResult:
|
266
|
+
"""Result of joint fit operation."""
|
267
|
+
|
268
|
+
best_pars: dict[str, float]
|
269
|
+
loss: float
|
270
|
+
|
271
|
+
def __repr__(self) -> str:
|
272
|
+
"""Return default representation."""
|
273
|
+
return pformat(self)
|
274
|
+
|
275
|
+
|
276
|
+
###############################################################################
|
277
|
+
# loss fns
|
278
|
+
###############################################################################
|
279
|
+
|
280
|
+
|
281
|
+
def mean(
|
282
|
+
y_pred: pd.DataFrame | pd.Series,
|
283
|
+
y_true: pd.DataFrame | pd.Series,
|
284
|
+
) -> float:
|
285
|
+
"""Calculate root mean square error between model and data."""
|
286
|
+
return cast(float, np.mean(y_pred - y_true))
|
287
|
+
|
288
|
+
|
289
|
+
def mean_squared(
|
290
|
+
y_pred: pd.DataFrame | pd.Series,
|
291
|
+
y_true: pd.DataFrame | pd.Series,
|
292
|
+
) -> float:
|
293
|
+
"""Calculate mean square error between model and data."""
|
294
|
+
return cast(float, np.mean(np.square(y_pred - y_true)))
|
116
295
|
|
117
296
|
|
118
297
|
def rmse(
|
@@ -123,126 +302,204 @@ def rmse(
|
|
123
302
|
return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
|
124
303
|
|
125
304
|
|
126
|
-
|
127
|
-
|
305
|
+
def mae(
|
306
|
+
y_pred: pd.DataFrame | pd.Series,
|
307
|
+
y_true: pd.DataFrame | pd.Series,
|
308
|
+
) -> float:
|
309
|
+
"""Calculate mean absolute error."""
|
310
|
+
return cast(float, np.mean(np.abs(y_true - y_pred)))
|
128
311
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
312
|
+
|
313
|
+
def mean_absolute_percentage(
|
314
|
+
y_pred: pd.DataFrame | pd.Series,
|
315
|
+
y_true: pd.DataFrame | pd.Series,
|
316
|
+
) -> float:
|
317
|
+
"""Calculate mean absolute error."""
|
318
|
+
return cast(float, 100 * np.mean(np.abs((y_true - y_pred) / y_pred)))
|
319
|
+
|
320
|
+
|
321
|
+
def mean_squared_logarithmic(
|
322
|
+
y_pred: pd.DataFrame | pd.Series,
|
323
|
+
y_true: pd.DataFrame | pd.Series,
|
324
|
+
) -> float:
|
325
|
+
"""Calculate root mean square error between model and data."""
|
326
|
+
return cast(float, np.mean(np.square(np.log(y_pred + 1) - np.log(y_true + 1))))
|
327
|
+
|
328
|
+
|
329
|
+
def cosine_similarity(
|
330
|
+
y_pred: pd.DataFrame | pd.Series,
|
331
|
+
y_true: pd.DataFrame | pd.Series,
|
332
|
+
) -> float:
|
333
|
+
"""Calculate root mean square error between model and data."""
|
334
|
+
norm = np.linalg.norm
|
335
|
+
return cast(float, -np.sum(norm(y_pred, 2) * norm(y_true, 2)))
|
336
|
+
|
337
|
+
|
338
|
+
###############################################################################
|
339
|
+
# Minimizers
|
340
|
+
###############################################################################
|
142
341
|
|
143
342
|
|
144
|
-
|
145
|
-
|
343
|
+
@dataclass
|
344
|
+
class LocalScipyMinimizer:
|
345
|
+
"""Local multivariate minimization using scipy.optimize.
|
346
|
+
|
347
|
+
See Also
|
348
|
+
--------
|
349
|
+
https://docs.scipy.org/doc/scipy/reference/optimize.html#local-multivariate-optimization
|
350
|
+
|
351
|
+
"""
|
352
|
+
|
353
|
+
tol: float = 1e-6
|
354
|
+
method: Literal[
|
355
|
+
"Nelder-Mead",
|
356
|
+
"Powell",
|
357
|
+
"CG",
|
358
|
+
"BFGS",
|
359
|
+
"Newton-CG",
|
360
|
+
"L-BFGS-B",
|
361
|
+
"TNC",
|
362
|
+
"COBYLA",
|
363
|
+
"COBYQA",
|
364
|
+
"SLSQP",
|
365
|
+
"trust-constr",
|
366
|
+
"dogleg",
|
367
|
+
"trust-ncg",
|
368
|
+
"trust-exact",
|
369
|
+
"trust-krylov",
|
370
|
+
] = "L-BFGS-B"
|
146
371
|
|
147
372
|
def __call__(
|
148
373
|
self,
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
374
|
+
residual_fn: ResidualFn,
|
375
|
+
p0: dict[str, float],
|
376
|
+
bounds: Bounds,
|
377
|
+
) -> MinResult | None:
|
378
|
+
"""Call minimzer."""
|
379
|
+
par_names = list(p0.keys())
|
380
|
+
|
381
|
+
res = minimize(
|
382
|
+
lambda par_values: residual_fn(_pack_updates(par_values, par_names)),
|
383
|
+
x0=list(p0.values()),
|
384
|
+
bounds=[bounds.get(name, (1e-6, 1e6)) for name in p0],
|
385
|
+
method=self.method,
|
386
|
+
tol=self.tol,
|
387
|
+
)
|
388
|
+
if res.success:
|
389
|
+
return MinResult(
|
390
|
+
parameters=dict(
|
391
|
+
zip(
|
392
|
+
p0,
|
393
|
+
res.x,
|
394
|
+
strict=True,
|
395
|
+
),
|
396
|
+
),
|
397
|
+
residual=res.fun,
|
398
|
+
)
|
399
|
+
|
400
|
+
LOGGER.warning("Minimisation failed due to %s", res.message)
|
401
|
+
return None
|
402
|
+
|
403
|
+
|
404
|
+
@dataclass
|
405
|
+
class GlobalScipyMinimizer:
|
406
|
+
"""Global iate minimization using scipy.optimize.
|
407
|
+
|
408
|
+
See Also
|
409
|
+
--------
|
410
|
+
https://docs.scipy.org/doc/scipy/reference/optimize.html#global-optimization
|
160
411
|
|
412
|
+
"""
|
161
413
|
|
162
|
-
|
163
|
-
|
414
|
+
tol: float = 1e-6
|
415
|
+
method: Literal[
|
416
|
+
"basinhopping",
|
417
|
+
"differential_evolution",
|
418
|
+
"shgo",
|
419
|
+
"dual_annealing",
|
420
|
+
"direct",
|
421
|
+
] = "basinhopping"
|
164
422
|
|
165
423
|
def __call__(
|
166
424
|
self,
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
"""Calculate residual error between model time course and experimental data."""
|
178
|
-
...
|
179
|
-
|
425
|
+
residual_fn: ResidualFn,
|
426
|
+
p0: dict[str, float],
|
427
|
+
bounds: Bounds,
|
428
|
+
) -> MinResult | None:
|
429
|
+
"""Minimize residual fn."""
|
430
|
+
res: OptimizeResult
|
431
|
+
par_names = list(p0.keys())
|
432
|
+
res_fn: ResFn = lambda par_values: residual_fn( # noqa: E731
|
433
|
+
_pack_updates(par_values, par_names)
|
434
|
+
)
|
180
435
|
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
)
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
method
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
436
|
+
if self.method == "basinhopping":
|
437
|
+
res = basinhopping(
|
438
|
+
res_fn,
|
439
|
+
x0=list(p0.values()),
|
440
|
+
)
|
441
|
+
elif self.method == "differential_evolution":
|
442
|
+
res = differential_evolution(res_fn, bounds)
|
443
|
+
elif self.method == "shgo":
|
444
|
+
res = shgo(res_fn, bounds)
|
445
|
+
elif self.method == "dual_annealing":
|
446
|
+
res = dual_annealing(res_fn, bounds)
|
447
|
+
elif self.method == "direct":
|
448
|
+
res = direct(res_fn, bounds)
|
449
|
+
else:
|
450
|
+
msg = f"Unknown method {self.method}"
|
451
|
+
raise NotImplementedError(msg)
|
452
|
+
if res.success:
|
453
|
+
return MinResult(
|
454
|
+
parameters=dict(
|
455
|
+
zip(
|
456
|
+
p0,
|
457
|
+
res.x,
|
458
|
+
strict=True,
|
459
|
+
),
|
199
460
|
),
|
200
|
-
|
201
|
-
|
202
|
-
|
461
|
+
residual=res.fun,
|
462
|
+
)
|
463
|
+
|
464
|
+
LOGGER.warning("Minimisation failed.")
|
465
|
+
return None
|
203
466
|
|
204
|
-
LOGGER.warning("Minimisation failed.")
|
205
|
-
return None
|
206
467
|
|
468
|
+
###############################################################################
|
469
|
+
# Residual functions
|
470
|
+
###############################################################################
|
207
471
|
|
208
|
-
|
472
|
+
|
473
|
+
def _pack_updates(
|
209
474
|
par_values: Array,
|
210
|
-
# This will be filled out by partial
|
211
475
|
par_names: list[str],
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
476
|
+
) -> dict[str, float]:
|
477
|
+
return dict(
|
478
|
+
zip(
|
479
|
+
par_names,
|
480
|
+
par_values,
|
481
|
+
strict=True,
|
482
|
+
)
|
483
|
+
)
|
219
484
|
|
220
|
-
Args:
|
221
|
-
par_values: Parameter values to test
|
222
|
-
data: Experimental steady state data
|
223
|
-
model: Model instance to simulate
|
224
|
-
y0: Initial conditions
|
225
|
-
par_names: Names of parameters being fit
|
226
|
-
integrator: ODE integrator class to use
|
227
|
-
loss_fn: Loss function to use for residual calculation
|
228
485
|
|
229
|
-
|
230
|
-
|
486
|
+
def steady_state_residual(
|
487
|
+
updates: dict[str, float],
|
488
|
+
settings: _Settings,
|
489
|
+
) -> float:
|
490
|
+
"""Calculate residual error between model steady state and experimental data."""
|
491
|
+
model = settings.model
|
492
|
+
if (y0 := settings.y0) is not None:
|
493
|
+
model.update_variables(y0)
|
494
|
+
for p in settings.p_names:
|
495
|
+
model.update_parameter(p, updates[p])
|
496
|
+
for p in settings.v_names:
|
497
|
+
model.update_variable(p, updates[p])
|
231
498
|
|
232
|
-
"""
|
233
499
|
res = (
|
234
500
|
Simulator(
|
235
|
-
model
|
236
|
-
|
237
|
-
zip(
|
238
|
-
par_names,
|
239
|
-
par_values,
|
240
|
-
strict=True,
|
241
|
-
)
|
242
|
-
)
|
243
|
-
),
|
244
|
-
y0=y0,
|
245
|
-
integrator=integrator,
|
501
|
+
model,
|
502
|
+
integrator=settings.integrator,
|
246
503
|
)
|
247
504
|
.simulate_to_steady_state()
|
248
505
|
.get_result()
|
@@ -250,93 +507,67 @@ def _steady_state_residual(
|
|
250
507
|
if res is None:
|
251
508
|
return cast(float, np.inf)
|
252
509
|
|
253
|
-
return loss_fn(
|
254
|
-
res.get_combined().loc[:, cast(list, data.index)],
|
255
|
-
data,
|
510
|
+
return settings.loss_fn(
|
511
|
+
res.get_combined().loc[:, cast(list, settings.data.index)],
|
512
|
+
settings.data,
|
256
513
|
)
|
257
514
|
|
258
515
|
|
259
|
-
def
|
260
|
-
|
261
|
-
|
262
|
-
par_names: list[str],
|
263
|
-
data: pd.DataFrame,
|
264
|
-
model: Model,
|
265
|
-
y0: dict[str, float] | None,
|
266
|
-
integrator: IntegratorType,
|
267
|
-
loss_fn: LossFn,
|
516
|
+
def time_course_residual(
|
517
|
+
updates: dict[str, float],
|
518
|
+
settings: _Settings,
|
268
519
|
) -> float:
|
269
|
-
"""Calculate residual error between model time course and experimental data.
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
model
|
275
|
-
|
276
|
-
|
277
|
-
integrator: ODE integrator class to use
|
278
|
-
loss_fn: Loss function to use for residual calculation
|
520
|
+
"""Calculate residual error between model time course and experimental data."""
|
521
|
+
model = settings.model
|
522
|
+
if (y0 := settings.y0) is not None:
|
523
|
+
model.update_variables(y0)
|
524
|
+
for p in settings.p_names:
|
525
|
+
model.update_parameter(p, updates[p])
|
526
|
+
for p in settings.v_names:
|
527
|
+
model.update_variable(p, updates[p])
|
279
528
|
|
280
|
-
Returns:
|
281
|
-
float: Root mean square error between model and data
|
282
|
-
|
283
|
-
"""
|
284
529
|
res = (
|
285
530
|
Simulator(
|
286
|
-
model
|
287
|
-
|
288
|
-
integrator=integrator,
|
531
|
+
model,
|
532
|
+
integrator=settings.integrator,
|
289
533
|
)
|
290
|
-
.simulate_time_course(cast(list, data.index))
|
534
|
+
.simulate_time_course(cast(list, settings.data.index))
|
291
535
|
.get_result()
|
292
536
|
)
|
293
537
|
if res is None:
|
294
538
|
return cast(float, np.inf)
|
295
539
|
results_ss = res.get_combined()
|
296
540
|
|
297
|
-
return loss_fn(
|
298
|
-
results_ss.loc[:, cast(list, data.columns)],
|
299
|
-
data,
|
541
|
+
return settings.loss_fn(
|
542
|
+
results_ss.loc[:, cast(list, settings.data.columns)],
|
543
|
+
settings.data,
|
300
544
|
)
|
301
545
|
|
302
546
|
|
303
|
-
def
|
304
|
-
|
305
|
-
|
306
|
-
par_names: list[str],
|
307
|
-
data: pd.DataFrame,
|
308
|
-
model: Model,
|
309
|
-
y0: dict[str, float] | None,
|
310
|
-
integrator: IntegratorType,
|
311
|
-
loss_fn: LossFn,
|
312
|
-
protocol: pd.DataFrame,
|
547
|
+
def protocol_time_course_residual(
|
548
|
+
updates: dict[str, float],
|
549
|
+
settings: _Settings,
|
313
550
|
) -> float:
|
314
|
-
"""Calculate residual error between model time course and experimental data.
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
model
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
time_points_per_step: Number of time points per step in the protocol
|
551
|
+
"""Calculate residual error between model time course and experimental data."""
|
552
|
+
model = settings.model
|
553
|
+
if (y0 := settings.y0) is not None:
|
554
|
+
model.update_variables(y0)
|
555
|
+
for p in settings.p_names:
|
556
|
+
model.update_parameter(p, updates[p])
|
557
|
+
for p in settings.v_names:
|
558
|
+
model.update_variable(p, updates[p])
|
559
|
+
|
560
|
+
if (protocol := settings.protocol) is None:
|
561
|
+
raise ValueError
|
326
562
|
|
327
|
-
Returns:
|
328
|
-
float: Root mean square error between model and data
|
329
|
-
|
330
|
-
"""
|
331
563
|
res = (
|
332
564
|
Simulator(
|
333
|
-
model
|
334
|
-
|
335
|
-
integrator=integrator,
|
565
|
+
model,
|
566
|
+
integrator=settings.integrator,
|
336
567
|
)
|
337
568
|
.simulate_protocol_time_course(
|
338
569
|
protocol=protocol,
|
339
|
-
time_points=data.index,
|
570
|
+
time_points=settings.data.index,
|
340
571
|
)
|
341
572
|
.get_result()
|
342
573
|
)
|
@@ -344,87 +575,9 @@ def _protocol_time_course_residual(
|
|
344
575
|
return cast(float, np.inf)
|
345
576
|
results_ss = res.get_combined()
|
346
577
|
|
347
|
-
return loss_fn(
|
348
|
-
results_ss.loc[:, cast(list, data.columns)],
|
349
|
-
data,
|
350
|
-
)
|
351
|
-
|
352
|
-
|
353
|
-
def _carousel_steady_state_worker(
|
354
|
-
model: Model,
|
355
|
-
p0: dict[str, float],
|
356
|
-
data: pd.Series,
|
357
|
-
y0: dict[str, float] | None,
|
358
|
-
integrator: IntegratorType | None,
|
359
|
-
loss_fn: LossFn,
|
360
|
-
minimize_fn: MinimizeFn,
|
361
|
-
residual_fn: SteadyStateResidualFn,
|
362
|
-
bounds: Bounds | None,
|
363
|
-
) -> FitResult | None:
|
364
|
-
model_pars = model.get_parameter_values()
|
365
|
-
|
366
|
-
return steady_state(
|
367
|
-
model,
|
368
|
-
p0={k: v for k, v in p0.items() if k in model_pars},
|
369
|
-
y0=y0,
|
370
|
-
data=data,
|
371
|
-
minimize_fn=minimize_fn,
|
372
|
-
residual_fn=residual_fn,
|
373
|
-
integrator=integrator,
|
374
|
-
loss_fn=loss_fn,
|
375
|
-
bounds=bounds,
|
376
|
-
)
|
377
|
-
|
378
|
-
|
379
|
-
def _carousel_time_course_worker(
|
380
|
-
model: Model,
|
381
|
-
p0: dict[str, float],
|
382
|
-
data: pd.DataFrame,
|
383
|
-
y0: dict[str, float] | None,
|
384
|
-
integrator: IntegratorType | None,
|
385
|
-
loss_fn: LossFn,
|
386
|
-
minimize_fn: MinimizeFn,
|
387
|
-
residual_fn: TimeSeriesResidualFn,
|
388
|
-
bounds: Bounds | None,
|
389
|
-
) -> FitResult | None:
|
390
|
-
model_pars = model.get_parameter_values()
|
391
|
-
return time_course(
|
392
|
-
model,
|
393
|
-
p0={k: v for k, v in p0.items() if k in model_pars},
|
394
|
-
y0=y0,
|
395
|
-
data=data,
|
396
|
-
minimize_fn=minimize_fn,
|
397
|
-
residual_fn=residual_fn,
|
398
|
-
integrator=integrator,
|
399
|
-
loss_fn=loss_fn,
|
400
|
-
bounds=bounds,
|
401
|
-
)
|
402
|
-
|
403
|
-
|
404
|
-
def _carousel_protocol_worker(
|
405
|
-
model: Model,
|
406
|
-
p0: dict[str, float],
|
407
|
-
data: pd.DataFrame,
|
408
|
-
protocol: pd.DataFrame,
|
409
|
-
y0: dict[str, float] | None,
|
410
|
-
integrator: IntegratorType | None,
|
411
|
-
loss_fn: LossFn,
|
412
|
-
minimize_fn: MinimizeFn,
|
413
|
-
residual_fn: ProtocolResidualFn,
|
414
|
-
bounds: Bounds | None,
|
415
|
-
) -> FitResult | None:
|
416
|
-
model_pars = model.get_parameter_values()
|
417
|
-
return protocol_time_course(
|
418
|
-
model,
|
419
|
-
p0={k: v for k, v in p0.items() if k in model_pars},
|
420
|
-
y0=y0,
|
421
|
-
protocol=protocol,
|
422
|
-
data=data,
|
423
|
-
minimize_fn=minimize_fn,
|
424
|
-
residual_fn=residual_fn,
|
425
|
-
integrator=integrator,
|
426
|
-
loss_fn=loss_fn,
|
427
|
-
bounds=bounds,
|
578
|
+
return settings.loss_fn(
|
579
|
+
results_ss.loc[:, cast(list, settings.data.columns)],
|
580
|
+
settings.data,
|
428
581
|
)
|
429
582
|
|
430
583
|
|
@@ -433,12 +586,13 @@ def steady_state(
|
|
433
586
|
*,
|
434
587
|
p0: dict[str, float],
|
435
588
|
data: pd.Series,
|
589
|
+
minimizer: Minimizer,
|
436
590
|
y0: dict[str, float] | None = None,
|
437
|
-
|
438
|
-
residual_fn: SteadyStateResidualFn = _steady_state_residual,
|
591
|
+
residual_fn: ResidualProtocol = steady_state_residual,
|
439
592
|
integrator: IntegratorType | None = None,
|
440
593
|
loss_fn: LossFn = rmse,
|
441
594
|
bounds: Bounds | None = None,
|
595
|
+
as_deepcopy: bool = True,
|
442
596
|
) -> FitResult | None:
|
443
597
|
"""Fit model parameters to steady-state experimental data.
|
444
598
|
|
@@ -449,13 +603,14 @@ def steady_state(
|
|
449
603
|
Args:
|
450
604
|
model: Model instance to fit
|
451
605
|
data: Experimental steady state data as pandas Series
|
452
|
-
p0: Initial
|
606
|
+
p0: Initial guesses as {name: value}
|
453
607
|
y0: Initial conditions as {species_name: value}
|
454
|
-
|
608
|
+
minimizer: Function to minimize fitting error
|
455
609
|
residual_fn: Function to calculate fitting error
|
456
610
|
integrator: ODE integrator class
|
457
611
|
loss_fn: Loss function to use for residual calculation
|
458
612
|
bounds: Mapping of bounds per parameter
|
613
|
+
as_deepcopy: Whether to copy the model to avoid overwriting the state
|
459
614
|
|
460
615
|
Returns:
|
461
616
|
dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
|
@@ -464,31 +619,30 @@ def steady_state(
|
|
464
619
|
Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
|
465
620
|
|
466
621
|
"""
|
467
|
-
|
622
|
+
if as_deepcopy:
|
623
|
+
model = deepcopy(model)
|
468
624
|
|
469
|
-
|
470
|
-
|
625
|
+
p_names = model.get_parameter_names()
|
626
|
+
v_names = model.get_variable_names()
|
471
627
|
|
472
|
-
fn =
|
473
|
-
|
474
|
-
|
475
|
-
residual_fn,
|
476
|
-
data=data,
|
628
|
+
fn: ResidualFn = partial(
|
629
|
+
residual_fn,
|
630
|
+
settings=_Settings(
|
477
631
|
model=model,
|
632
|
+
data=data,
|
478
633
|
y0=y0,
|
479
|
-
par_names=par_names,
|
480
634
|
integrator=integrator,
|
481
635
|
loss_fn=loss_fn,
|
636
|
+
p_names=[i for i in p0 if i in p_names],
|
637
|
+
v_names=[i for i in p0 if i in v_names],
|
482
638
|
),
|
483
639
|
)
|
484
|
-
min_result =
|
485
|
-
# Restore original model
|
486
|
-
model.update_parameters(p_orig)
|
640
|
+
min_result = minimizer(fn, p0, {} if bounds is None else bounds)
|
487
641
|
if min_result is None:
|
488
|
-
return
|
642
|
+
return None
|
489
643
|
|
490
644
|
return FitResult(
|
491
|
-
model=
|
645
|
+
model=model,
|
492
646
|
best_pars=min_result.parameters,
|
493
647
|
loss=min_result.residual,
|
494
648
|
)
|
@@ -499,12 +653,13 @@ def time_course(
|
|
499
653
|
*,
|
500
654
|
p0: dict[str, float],
|
501
655
|
data: pd.DataFrame,
|
656
|
+
minimizer: Minimizer,
|
502
657
|
y0: dict[str, float] | None = None,
|
503
|
-
|
504
|
-
residual_fn: TimeSeriesResidualFn = _time_course_residual,
|
658
|
+
residual_fn: ResidualProtocol = time_course_residual,
|
505
659
|
integrator: IntegratorType | None = None,
|
506
660
|
loss_fn: LossFn = rmse,
|
507
661
|
bounds: Bounds | None = None,
|
662
|
+
as_deepcopy: bool = True,
|
508
663
|
) -> FitResult | None:
|
509
664
|
"""Fit model parameters to time course of experimental data.
|
510
665
|
|
@@ -515,13 +670,14 @@ def time_course(
|
|
515
670
|
Args:
|
516
671
|
model: Model instance to fit
|
517
672
|
data: Experimental time course data
|
518
|
-
p0: Initial
|
673
|
+
p0: Initial guesses as {parameter_name: value}
|
519
674
|
y0: Initial conditions as {species_name: value}
|
520
|
-
|
675
|
+
minimizer: Function to minimize fitting error
|
521
676
|
residual_fn: Function to calculate fitting error
|
522
677
|
integrator: ODE integrator class
|
523
678
|
loss_fn: Loss function to use for residual calculation
|
524
679
|
bounds: Mapping of bounds per parameter
|
680
|
+
as_deepcopy: Whether to copy the model to avoid overwriting the state
|
525
681
|
|
526
682
|
Returns:
|
527
683
|
dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
|
@@ -530,30 +686,30 @@ def time_course(
|
|
530
686
|
Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
|
531
687
|
|
532
688
|
"""
|
533
|
-
|
534
|
-
|
689
|
+
if as_deepcopy:
|
690
|
+
model = deepcopy(model)
|
691
|
+
p_names = model.get_parameter_names()
|
692
|
+
v_names = model.get_variable_names()
|
535
693
|
|
536
|
-
fn =
|
537
|
-
|
538
|
-
|
539
|
-
residual_fn,
|
540
|
-
data=data,
|
694
|
+
fn: ResidualFn = partial(
|
695
|
+
residual_fn,
|
696
|
+
settings=_Settings(
|
541
697
|
model=model,
|
698
|
+
data=data,
|
542
699
|
y0=y0,
|
543
|
-
par_names=par_names,
|
544
700
|
integrator=integrator,
|
545
701
|
loss_fn=loss_fn,
|
702
|
+
p_names=[i for i in p0 if i in p_names],
|
703
|
+
v_names=[i for i in p0 if i in v_names],
|
546
704
|
),
|
547
705
|
)
|
548
706
|
|
549
|
-
min_result =
|
550
|
-
# Restore original model
|
551
|
-
model.update_parameters(p_orig)
|
707
|
+
min_result = minimizer(fn, p0, {} if bounds is None else bounds)
|
552
708
|
if min_result is None:
|
553
|
-
return
|
709
|
+
return None
|
554
710
|
|
555
711
|
return FitResult(
|
556
|
-
model=
|
712
|
+
model=model,
|
557
713
|
best_pars=min_result.parameters,
|
558
714
|
loss=min_result.residual,
|
559
715
|
)
|
@@ -565,12 +721,13 @@ def protocol_time_course(
|
|
565
721
|
p0: dict[str, float],
|
566
722
|
data: pd.DataFrame,
|
567
723
|
protocol: pd.DataFrame,
|
724
|
+
minimizer: Minimizer,
|
568
725
|
y0: dict[str, float] | None = None,
|
569
|
-
|
570
|
-
residual_fn: ProtocolResidualFn = _protocol_time_course_residual,
|
726
|
+
residual_fn: ResidualProtocol = protocol_time_course_residual,
|
571
727
|
integrator: IntegratorType | None = None,
|
572
728
|
loss_fn: LossFn = rmse,
|
573
729
|
bounds: Bounds | None = None,
|
730
|
+
as_deepcopy: bool = True,
|
574
731
|
) -> FitResult | None:
|
575
732
|
"""Fit model parameters to time course of experimental data.
|
576
733
|
|
@@ -586,12 +743,13 @@ def protocol_time_course(
|
|
586
743
|
data: Experimental time course data
|
587
744
|
protocol: Experimental protocol
|
588
745
|
y0: Initial conditions as {species_name: value}
|
589
|
-
|
746
|
+
minimizer: Function to minimize fitting error
|
590
747
|
residual_fn: Function to calculate fitting error
|
591
748
|
integrator: ODE integrator class
|
592
749
|
loss_fn: Loss function to use for residual calculation
|
593
750
|
time_points_per_step: Number of time points per step in the protocol
|
594
751
|
bounds: Mapping of bounds per parameter
|
752
|
+
as_deepcopy: Whether to copy the model to avoid overwriting the state
|
595
753
|
|
596
754
|
Returns:
|
597
755
|
dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
|
@@ -600,65 +758,73 @@ def protocol_time_course(
|
|
600
758
|
Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
|
601
759
|
|
602
760
|
"""
|
603
|
-
|
604
|
-
|
761
|
+
if as_deepcopy:
|
762
|
+
model = deepcopy(model)
|
763
|
+
p_names = model.get_parameter_names()
|
764
|
+
v_names = model.get_variable_names()
|
605
765
|
|
606
|
-
fn =
|
607
|
-
|
608
|
-
|
609
|
-
residual_fn,
|
610
|
-
data=data,
|
766
|
+
fn: ResidualFn = partial(
|
767
|
+
residual_fn,
|
768
|
+
settings=_Settings(
|
611
769
|
model=model,
|
770
|
+
data=data,
|
612
771
|
y0=y0,
|
613
|
-
par_names=par_names,
|
614
772
|
integrator=integrator,
|
615
773
|
loss_fn=loss_fn,
|
774
|
+
p_names=[i for i in p0 if i in p_names],
|
775
|
+
v_names=[i for i in p0 if i in v_names],
|
616
776
|
protocol=protocol,
|
617
777
|
),
|
618
778
|
)
|
619
779
|
|
620
|
-
min_result =
|
621
|
-
# Restore original model
|
622
|
-
model.update_parameters(p_orig)
|
780
|
+
min_result = minimizer(fn, p0, {} if bounds is None else bounds)
|
623
781
|
if min_result is None:
|
624
|
-
return
|
782
|
+
return None
|
625
783
|
|
626
784
|
return FitResult(
|
627
|
-
model=
|
785
|
+
model=model,
|
628
786
|
best_pars=min_result.parameters,
|
629
787
|
loss=min_result.residual,
|
630
788
|
)
|
631
789
|
|
632
790
|
|
633
|
-
|
634
|
-
|
791
|
+
###############################################################################
|
792
|
+
# Ensemble / carousel
|
793
|
+
# This is multi-model, single data fitting, where the models share parameters
|
794
|
+
###############################################################################
|
795
|
+
|
796
|
+
|
797
|
+
def ensemble_steady_state(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model ensemble parameters to steady-state experimental data.

    Every model of the ensemble is fitted independently with `steady_state`
    using the same settings; the fits are dispatched through
    `parallel.parallelise`. Models whose fit returns None are left out of
    the result.

    Examples:
        >>> ensemble_steady_state(ensemble, p0=p0, data=data, minimizer=minimizer)

    Args:
        ensemble: Models to fit
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental steady-state data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter, forwarded to `steady_state`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult: Wrapper around the fit results of all models whose
        fit did not return None.

    """
    return EnsembleFitResult(
        [
            fit
            for i in parallel.parallelise(
                partial(
                    steady_state,
                    p0=p0,
                    data=data,
                    y0=y0,
                    integrator=integrator,
                    loss_fn=loss_fn,
                    minimizer=minimizer,
                    residual_fn=residual_fn,
                    bounds=bounds,
                    as_deepcopy=as_deepcopy,
                ),
                inputs=list(enumerate(ensemble)),
            )
            # parallelise yields (key, value) pairs; failed fits are None
            if (fit := i[1]) is not None
        ]
    )
|
690
857
|
|
691
858
|
|
692
|
-
def
|
859
|
+
def carousel_steady_state(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model parameters to steady-state experimental data over a carousel.

    Thin wrapper that fits every model in ``carousel.variants`` with
    `ensemble_steady_state`.

    Examples:
        >>> carousel_steady_state(carousel, p0=p0, data=data, minimizer=minimizer)

    Args:
        carousel: Model carousel to fit
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental steady-state data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the models to avoid overwriting their state

    Returns:
        EnsembleFitResult: Fit results of the carousel variants, as returned
        by `ensemble_steady_state`.

    """
    return ensemble_steady_state(
        carousel.variants,
        p0=p0,
        data=data,
        minimizer=minimizer,
        y0=y0,
        residual_fn=residual_fn,
        integrator=integrator,
        loss_fn=loss_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
|
910
|
+
|
911
|
+
|
912
|
+
def ensemble_time_course(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model ensemble parameters to time course of experimental data.

    Time points are taken from the data. Every model of the ensemble is
    fitted independently with `time_course` using the same settings; the
    fits are dispatched through `parallel.parallelise`. Models whose fit
    returns None are left out of the result.

    Examples:
        >>> ensemble_time_course(ensemble, p0=p0, data=data, minimizer=minimizer)

    Args:
        ensemble: Model ensemble to fit
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter, forwarded to `time_course`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult: Wrapper around the fit results of all models whose
        fit did not return None.

    """
    return EnsembleFitResult(
        [
            fit
            for i in parallel.parallelise(
                partial(
                    time_course,
                    p0=p0,
                    data=data,
                    y0=y0,
                    integrator=integrator,
                    loss_fn=loss_fn,
                    minimizer=minimizer,
                    residual_fn=residual_fn,
                    bounds=bounds,
                    as_deepcopy=as_deepcopy,
                ),
                inputs=list(enumerate(ensemble)),
            )
            # parallelise yields (key, value) pairs; failed fits are None
            if (fit := i[1]) is not None
        ]
    )
|
751
974
|
|
752
975
|
|
753
|
-
def
|
976
|
+
def carousel_time_course(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model parameters to time course of experimental data over a carousel.

    Time points are taken from the data. Thin wrapper that fits every model
    in ``carousel.variants`` with `ensemble_time_course`.

    Examples:
        >>> carousel_time_course(carousel, p0=p0, data=data, minimizer=minimizer)

    Args:
        carousel: Model carousel to fit
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the models to avoid overwriting their state

    Returns:
        EnsembleFitResult: Fit results of the carousel variants, as returned
        by `ensemble_time_course`.

    """
    return ensemble_time_course(
        carousel.variants,
        p0=p0,
        data=data,
        minimizer=minimizer,
        y0=y0,
        residual_fn=residual_fn,
        integrator=integrator,
        loss_fn=loss_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
|
1029
|
+
|
1030
|
+
|
1031
|
+
def ensemble_protocol_time_course(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    protocol: pd.DataFrame,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model ensemble parameters to time course of experimental data over a protocol.

    Time points of protocol time course are taken from the data. Every model
    of the ensemble is fitted independently with `protocol_time_course` using
    the same settings; the fits are dispatched through
    `parallel.parallelise`. Models whose fit returns None are left out of
    the result.

    Examples:
        >>> ensemble_protocol_time_course(
        ...     ensemble, p0=p0, data=data, protocol=protocol, minimizer=minimizer
        ... )

    Args:
        ensemble: Model ensemble to fit
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        protocol: Experimental protocol
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter, forwarded to
            `protocol_time_course`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult: Wrapper around the fit results of all models whose
        fit did not return None.

    """
    return EnsembleFitResult(
        [
            fit
            for i in parallel.parallelise(
                partial(
                    protocol_time_course,
                    p0=p0,
                    data=data,
                    protocol=protocol,
                    y0=y0,
                    integrator=integrator,
                    loss_fn=loss_fn,
                    minimizer=minimizer,
                    residual_fn=residual_fn,
                    bounds=bounds,
                    as_deepcopy=as_deepcopy,
                ),
                inputs=list(enumerate(ensemble)),
            )
            # parallelise yields (key, value) pairs; failed fits are None
            if (fit := i[1]) is not None
        ]
    )
|
1095
|
+
|
1096
|
+
|
1097
|
+
def carousel_protocol_time_course(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    protocol: pd.DataFrame,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model parameters to time course of experimental data over a protocol.

    Time points of protocol time course are taken from the data. Thin
    wrapper that fits every model in ``carousel.variants`` with
    `ensemble_protocol_time_course`.

    Examples:
        >>> carousel_protocol_time_course(
        ...     carousel, p0=p0, data=data, protocol=protocol, minimizer=minimizer
        ... )

    Args:
        carousel: Model carousel to fit
        p0: Initial parameter guesses as {parameter_name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        protocol: Experimental protocol
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per parameter
        as_deepcopy: Whether to copy the models to avoid overwriting their state

    Returns:
        EnsembleFitResult: Fit results of the carousel variants, as returned
        by `ensemble_protocol_time_course`.

    """
    return ensemble_protocol_time_course(
        carousel.variants,
        p0=p0,
        data=data,
        minimizer=minimizer,
        protocol=protocol,
        y0=y0,
        residual_fn=residual_fn,
        integrator=integrator,
        loss_fn=loss_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
|
1152
|
+
|
1153
|
+
|
1154
|
+
###############################################################################
|
1155
|
+
# Joint fitting
|
1156
|
+
# This is multi-model, multi-data fitting, where the models share some parameters
|
1157
|
+
###############################################################################
|
1158
|
+
|
1159
|
+
|
1160
|
+
def _unpacked[T1, T2, Tout](inp: tuple[T1, T2], fn: Callable[[T1, T2], Tout]) -> Tout:
|
1161
|
+
return fn(*inp)
|
1162
|
+
|
1163
|
+
|
1164
|
+
def _sum_of_residuals(
    updates: dict[str, float],
    residual_fn: ResidualProtocol,
    fits: list[_Settings],
    pool: pebble.ProcessPool,
) -> float:
    """Sum ``residual_fn(updates, settings)`` over all ``fits`` using ``pool``.

    Args:
        updates: Candidate values proposed by the minimizer
        residual_fn: Residual function applied to every fit setting
        fits: Per-experiment settings to evaluate
        pool: Process pool the evaluations are mapped onto

    Returns:
        The summed residual, or ``np.inf`` if a ``TimeoutError`` surfaces
        while collecting the results.

    """
    jobs = [(updates, settings) for settings in fits]
    results = pool.map(
        partial(_unpacked, fn=residual_fn),
        jobs,
        timeout=None,
    ).result()

    total = 0.0
    try:
        for residual in results:
            total += residual
    except TimeoutError:
        return np.inf
    return total
|
1185
|
+
|
1186
|
+
|
1187
|
+
def joint_steady_state(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting.

    Fits one shared set of values ``p0`` against several (model, data)
    pairs by minimizing the sum of their steady-state residuals. The
    residuals are evaluated in a `pebble.ProcessPool`.

    Args:
        to_fit: Per-experiment settings; their y0, integrator and loss_fn
            override the function-level defaults when not None
        p0: Initial guesses as {name: value}; each name is matched against
            the parameter and variable names of every model
        minimizer: Function to minimize the summed fitting error
        y0: Default initial conditions
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter
        max_workers: Number of worker processes (CPU count if None)
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The joint fit result, or None if the minimizer returned None.

    """

    def _resolve(entry: FitSettings) -> _Settings:
        known_pars = entry.model.get_parameter_names()
        known_vars = entry.model.get_variable_names()
        return _Settings(
            model=deepcopy(entry.model) if as_deepcopy else entry.model,
            data=entry.data,
            y0=y0 if entry.y0 is None else entry.y0,
            integrator=integrator if entry.integrator is None else entry.integrator,
            loss_fn=loss_fn if entry.loss_fn is None else entry.loss_fn,
            p_names=[name for name in p0 if name in known_pars],
            v_names=[name for name in p0 if name in known_vars],
        )

    settings = [_resolve(entry) for entry in to_fit]
    workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    search_bounds = bounds if bounds is not None else {}

    with pebble.ProcessPool(max_workers=workers) as pool:
        objective = partial(
            _sum_of_residuals,
            residual_fn=steady_state_residual,
            fits=settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, search_bounds)

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
|
1235
|
+
|
1236
|
+
|
1237
|
+
def joint_time_course(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting.

    Fits one shared set of values ``p0`` against several (model, data)
    pairs by minimizing the sum of their time-course residuals. The
    residuals are evaluated in a `pebble.ProcessPool`.

    Args:
        to_fit: Per-experiment settings; their y0, integrator and loss_fn
            override the function-level defaults when not None
        p0: Initial guesses as {name: value}; each name is matched against
            the parameter and variable names of every model
        minimizer: Function to minimize the summed fitting error
        y0: Default initial conditions
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter
        max_workers: Number of worker processes (CPU count if None)
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The joint fit result, or None if the minimizer returned None.

    """

    def _resolve(entry: FitSettings) -> _Settings:
        known_pars = entry.model.get_parameter_names()
        known_vars = entry.model.get_variable_names()
        return _Settings(
            model=deepcopy(entry.model) if as_deepcopy else entry.model,
            data=entry.data,
            y0=y0 if entry.y0 is None else entry.y0,
            integrator=integrator if entry.integrator is None else entry.integrator,
            loss_fn=loss_fn if entry.loss_fn is None else entry.loss_fn,
            p_names=[name for name in p0 if name in known_pars],
            v_names=[name for name in p0 if name in known_vars],
        )

    settings = [_resolve(entry) for entry in to_fit]
    workers = multiprocessing.cpu_count() if max_workers is None else max_workers
    search_bounds = bounds if bounds is not None else {}

    with pebble.ProcessPool(max_workers=workers) as pool:
        objective = partial(
            _sum_of_residuals,
            residual_fn=time_course_residual,
            fits=settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, search_bounds)

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
|
1285
|
+
|
1286
|
+
|
1287
|
+
def joint_protocol_time_course(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting over experimental protocols.

    Fits one shared set of values ``p0`` against several (model, data,
    protocol) experiments by minimizing the sum of their protocol
    time-course residuals. The residuals are evaluated in a
    `pebble.ProcessPool`.

    Args:
        to_fit: Per-experiment settings; their y0, integrator and loss_fn
            override the function-level defaults when not None. Each entry
            is expected to carry its experimental ``protocol``.
        p0: Initial guesses as {name: value}; each name is matched against
            the parameter and variable names of every model
        minimizer: Function to minimize the summed fitting error
        y0: Default initial conditions
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter
        max_workers: Number of worker processes (CPU count if None)
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The joint fit result, or None if the minimizer returned None.

    """
    full_settings = []
    for i in to_fit:
        p_names = i.model.get_parameter_names()
        v_names = i.model.get_variable_names()
        full_settings.append(
            _Settings(
                model=deepcopy(i.model) if as_deepcopy else i.model,
                data=i.data,
                y0=i.y0 if i.y0 is not None else y0,
                integrator=i.integrator if i.integrator is not None else integrator,
                loss_fn=i.loss_fn if i.loss_fn is not None else loss_fn,
                p_names=[j for j in p0 if j in p_names],
                v_names=[j for j in p0 if j in v_names],
                # protocol_time_course_residual needs the protocol, just as the
                # single-model protocol_time_course passes it to _Settings.
                # NOTE(review): getattr keeps this working if a settings object
                # has no `protocol` field — confirm against FitSettings.
                protocol=getattr(i, "protocol", None),
            )
        )

    with pebble.ProcessPool(
        max_workers=(
            multiprocessing.cpu_count() if max_workers is None else max_workers
        )
    ) as pool:
        min_result = minimizer(
            partial(
                _sum_of_residuals,
                residual_fn=protocol_time_course_residual,
                fits=full_settings,
                pool=pool,
            ),
            p0,
            {} if bounds is None else bounds,
        )
    if min_result is None:
        return None

    return JointFitResult(min_result.parameters, loss=min_result.residual)
|
1335
|
+
|
1336
|
+
|
1337
|
+
###############################################################################
|
1338
|
+
# Mixed joint fitting
|
1339
|
+
# This is multi-model, multi-data, multi-simulation fitting
|
1340
|
+
# The models share some parameters here, everything else can be changed though
|
1341
|
+
###############################################################################
|
1342
|
+
|
1343
|
+
|
1344
|
+
def _execute(inp: tuple[dict[str, float], ResidualProtocol, _Settings]) -> float:
|
1345
|
+
updates, residual_fn, settings = inp
|
1346
|
+
return residual_fn(updates, settings)
|
1347
|
+
|
1348
|
+
|
1349
|
+
def _mixed_sum_of_residuals(
    updates: dict[str, float],
    fits: list[_Settings],
    pool: pebble.ProcessPool,
) -> float:
    """Sum the residuals of ``fits``, each using its own ``residual_fn``.

    Args:
        updates: Candidate values proposed by the minimizer
        fits: Per-experiment settings; each carries the residual function
            used to evaluate it
        pool: Process pool the evaluations are mapped onto

    Returns:
        The summed residual, or ``np.inf`` if a ``TimeoutError`` surfaces
        while collecting the results.

    """
    jobs = [(updates, settings.residual_fn, settings) for settings in fits]
    results = pool.map(_execute, jobs).result()

    total = 0.0
    try:
        for residual in results:
            total += residual
    except TimeoutError:
        return np.inf
    return total
|
1365
|
+
|
1366
|
+
|
1367
|
+
def joint_mixed(
    to_fit: list[MixedSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data, multi-simulation fitting.

    Fits one shared set of values ``p0`` against several experiments, each
    with its own ``residual_fn``. Steady-state, time-course and protocol
    residuals can therefore be mixed. The summed residual is minimized, and
    the individual residuals are evaluated in a `pebble.ProcessPool`.

    Args:
        to_fit: Per-experiment settings, each with its own residual_fn; their
            y0, integrator and loss_fn override the function-level defaults
            when not None
        p0: Initial guesses as {name: value}; each name is matched against
            the parameter and variable names of every model
        minimizer: Function to minimize the summed fitting error
        y0: Default initial conditions
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter
        max_workers: Number of worker processes (CPU count if None)
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The joint fit result, or None if the minimizer returned None.

    """
    full_settings = []
    for i in to_fit:
        p_names = i.model.get_parameter_names()
        v_names = i.model.get_variable_names()
        full_settings.append(
            _Settings(
                model=deepcopy(i.model) if as_deepcopy else i.model,
                data=i.data,
                y0=i.y0 if i.y0 is not None else y0,
                integrator=i.integrator if i.integrator is not None else integrator,
                loss_fn=i.loss_fn if i.loss_fn is not None else loss_fn,
                p_names=[j for j in p0 if j in p_names],
                v_names=[j for j in p0 if j in v_names],
                residual_fn=i.residual_fn,
                # Entries using a protocol residual need their protocol forwarded.
                # NOTE(review): getattr keeps this working if a settings object
                # has no `protocol` field — confirm against MixedSettings.
                protocol=getattr(i, "protocol", None),
            )
        )

    with pebble.ProcessPool(
        max_workers=(
            multiprocessing.cpu_count() if max_workers is None else max_workers
        )
    ) as pool:
        min_result = minimizer(
            partial(
                _mixed_sum_of_residuals,
                fits=full_settings,
                pool=pool,
            ),
            p0,
            {} if bounds is None else bounds,
        )
    if min_result is None:
        return None

    return JointFitResult(min_result.parameters, loss=min_result.residual)
|