mxlpy 0.24.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/fit.py CHANGED
@@ -1,60 +1,226 @@
1
- """Parameter Fitting Module for Metabolic Models.
1
+ """Fitting routines.
2
+
3
+ Single model, single data routines
4
+ ----------------------------------
5
+ - `steady_state`
6
+ - `time_course`
7
+ - `protocol_time_course`
8
+
9
+ Multiple model, single data routines
10
+ ------------------------------------
11
+ - `ensemble_steady_state`
12
+ - `ensemble_time_course`
13
+ - `ensemble_protocol_time_course`
14
+
15
+ A carousel is a special case of an ensemble, where the general
16
+ structure (e.g. stoichiometries) is the same, while the reactions kinetics
17
+ can vary
18
+ - `carousel_steady_state`
19
+ - `carousel_time_course`
20
+ - `carousel_protocol_time_course`
21
+
22
+ Multiple model, multiple data
23
+ -----------------------------
24
+ - `joint_steady_state`
25
+ - `joint_time_course`
26
+ - `joint_protocol_time_course`
27
+
28
+ Multiple model, multiple data, multiple methods
29
+ -----------------------------------------------
30
+ Here we also allow to run different methods (e.g. steady-state vs time courses)
31
+ for each combination of model:data.
32
+
33
+ - `joint_mixed`
34
+
35
+ Minimizers
36
+ ----------
37
+ - LocalScipyMinimizer, including common methods such as Nelder-Mead or L-BFGS-B
38
+ - GlobalScipyMinimizer, including common methods such as basin hopping or dual annealing
39
+
40
+ Loss functions
41
+ --------------
42
+ - rmse
2
43
 
3
- This module provides functions foru fitting model parameters to experimental data,
4
- including both steadyd-state and time-series data fitting capabilities.e
5
-
6
- Functions:
7
- fit_steady_state: Fits parameters to steady-state experimental data
8
- fit_time_course: Fits parameters to time-series experimental data
9
44
  """
10
45
 
11
46
  from __future__ import annotations
12
47
 
13
48
  import logging
49
+ import multiprocessing
50
+ from collections.abc import Callable
14
51
  from copy import deepcopy
15
52
  from dataclasses import dataclass
16
53
  from functools import partial
17
- from typing import TYPE_CHECKING, Protocol
54
+ from typing import TYPE_CHECKING, Literal, Protocol
18
55
 
19
56
  import numpy as np
20
- from scipy.optimize import minimize
57
+ import pandas as pd
58
+ import pebble
59
+ from scipy.optimize import (
60
+ basinhopping,
61
+ differential_evolution,
62
+ direct,
63
+ dual_annealing,
64
+ minimize,
65
+ shgo,
66
+ )
21
67
  from wadler_lindig import pformat
22
68
 
23
69
  from mxlpy import parallel
70
+ from mxlpy.model import Model
24
71
  from mxlpy.simulator import Simulator
25
- from mxlpy.types import Array, ArrayLike, Callable, IntegratorType, cast
72
+ from mxlpy.types import Array, IntegratorType, cast
26
73
 
27
74
  if TYPE_CHECKING:
28
75
  import pandas as pd
76
+ from scipy.optimize._optimize import OptimizeResult
29
77
 
30
78
  from mxlpy.carousel import Carousel
31
79
  from mxlpy.model import Model
32
80
 
33
81
  LOGGER = logging.getLogger(__name__)
34
82
 
83
+
35
84
  __all__ = [
36
85
  "Bounds",
37
- "CarouselFit",
86
+ "EnsembleFitResult",
38
87
  "FitResult",
88
+ "FitSettings",
89
+ "GlobalScipyMinimizer",
39
90
  "InitialGuess",
91
+ "JointFitResult",
40
92
  "LOGGER",
93
+ "LocalScipyMinimizer",
41
94
  "LossFn",
42
95
  "MinResult",
43
- "MinimizeFn",
44
- "ProtocolResidualFn",
96
+ "Minimizer",
97
+ "MixedSettings",
98
+ "ResFn",
45
99
  "ResidualFn",
46
- "SteadyStateResidualFn",
47
- "TimeSeriesResidualFn",
100
+ "ResidualProtocol",
48
101
  "carousel_protocol_time_course",
49
102
  "carousel_steady_state",
50
103
  "carousel_time_course",
104
+ "cosine_similarity",
105
+ "ensemble_protocol_time_course",
106
+ "ensemble_steady_state",
107
+ "ensemble_time_course",
108
+ "joint_mixed",
109
+ "joint_protocol_time_course",
110
+ "joint_steady_state",
111
+ "joint_time_course",
112
+ "mae",
113
+ "mean",
114
+ "mean_absolute_percentage",
115
+ "mean_squared",
116
+ "mean_squared_logarithmic",
51
117
  "protocol_time_course",
118
+ "protocol_time_course_residual",
52
119
  "rmse",
53
120
  "steady_state",
121
+ "steady_state_residual",
54
122
  "time_course",
123
+ "time_course_residual",
124
+ ]
125
+
126
+ type InitialGuess = dict[str, float]
127
+
128
+ type Bounds = dict[str, tuple[float | None, float | None]]
129
+ type ResFn = Callable[[Array], float]
130
+ type LossFn = Callable[
131
+ [
132
+ pd.DataFrame | pd.Series,
133
+ pd.DataFrame | pd.Series,
134
+ ],
135
+ float,
136
+ ]
137
+
138
+
139
+ type Minimizer = Callable[
140
+ [
141
+ ResidualFn,
142
+ InitialGuess,
143
+ Bounds,
144
+ ],
145
+ MinResult | None,
55
146
  ]
56
147
 
57
148
 
149
+ class ResidualProtocol(Protocol):
150
+ """Protocol for steady state residual functions.
151
+
152
+ This is the user-facing variant, for stuff like
153
+ - `fit.steady_state`
154
+ - `fit.time_course`
155
+ - `fit.protocol_time_course`
156
+
157
+ The settings are later partially applied to yield ResidualFn
158
+ """
159
+
160
+ def __call__(
161
+ self,
162
+ updates: dict[str, float],
163
+ settings: _Settings,
164
+ ) -> float:
165
+ """Calculate residual error between model steady state and experimental data."""
166
+ ...
167
+
168
+
169
+ class ResidualFn(Protocol):
170
+ """Protocol for steady state residual functions.
171
+
172
+ This is the internal version, which is produced by partial
173
+ application of `settings` of `ResidualProtocol`
174
+ """
175
+
176
+ def __call__(
177
+ self,
178
+ updates: dict[str, float],
179
+ ) -> float:
180
+ """Calculate residual error between model steady state and experimental data."""
181
+ ...
182
+
183
+
184
+ @dataclass
185
+ class FitSettings:
186
+ """Settings for a fit."""
187
+
188
+ model: Model
189
+ data: pd.Series | pd.DataFrame
190
+ y0: dict[str, float] | None = None
191
+ integrator: IntegratorType | None = None
192
+ loss_fn: LossFn | None = None
193
+ protocol: pd.DataFrame | None = None
194
+
195
+
196
+ @dataclass
197
+ class MixedSettings:
198
+ """Settings for a fit."""
199
+
200
+ model: Model
201
+ data: pd.Series | pd.DataFrame
202
+ residual_fn: ResidualFn
203
+ y0: dict[str, float] | None = None
204
+ integrator: IntegratorType | None = None
205
+ loss_fn: LossFn | None = None
206
+ protocol: pd.DataFrame | None = None
207
+
208
+
209
+ @dataclass
210
+ class _Settings:
211
+ """Non user-facing version of FitSettings."""
212
+
213
+ model: Model
214
+ data: pd.Series | pd.DataFrame
215
+ y0: dict[str, float] | None
216
+ integrator: IntegratorType | None
217
+ loss_fn: LossFn
218
+ p_names: list[str]
219
+ v_names: list[str]
220
+ protocol: pd.DataFrame | None = None
221
+ residual_fn: ResidualFn | None = None
222
+
223
+
58
224
  @dataclass
59
225
  class MinResult:
60
226
  """Result of a minimization operation."""
@@ -81,7 +247,7 @@ class FitResult:
81
247
 
82
248
 
83
249
  @dataclass
84
- class CarouselFit:
250
+ class EnsembleFitResult:
85
251
  """Result of a carousel fit operation."""
86
252
 
87
253
  fits: list[FitResult]
@@ -95,24 +261,37 @@ class CarouselFit:
95
261
  return min(self.fits, key=lambda x: x.loss)
96
262
 
97
263
 
98
- type InitialGuess = dict[str, float]
99
- type ResidualFn = Callable[[Array], float]
100
- type Bounds = dict[str, tuple[float | None, float | None]]
101
- type MinimizeFn = Callable[
102
- [
103
- ResidualFn,
104
- InitialGuess,
105
- Bounds,
106
- ],
107
- MinResult | None,
108
- ]
109
- type LossFn = Callable[
110
- [
111
- pd.DataFrame | pd.Series,
112
- pd.DataFrame | pd.Series,
113
- ],
114
- float,
115
- ]
264
+ @dataclass
265
+ class JointFitResult:
266
+ """Result of joint fit operation."""
267
+
268
+ best_pars: dict[str, float]
269
+ loss: float
270
+
271
+ def __repr__(self) -> str:
272
+ """Return default representation."""
273
+ return pformat(self)
274
+
275
+
276
+ ###############################################################################
277
+ # loss fns
278
+ ###############################################################################
279
+
280
+
281
+ def mean(
282
+ y_pred: pd.DataFrame | pd.Series,
283
+ y_true: pd.DataFrame | pd.Series,
284
+ ) -> float:
285
+ """Calculate root mean square error between model and data."""
286
+ return cast(float, np.mean(y_pred - y_true))
287
+
288
+
289
+ def mean_squared(
290
+ y_pred: pd.DataFrame | pd.Series,
291
+ y_true: pd.DataFrame | pd.Series,
292
+ ) -> float:
293
+ """Calculate mean square error between model and data."""
294
+ return cast(float, np.mean(np.square(y_pred - y_true)))
116
295
 
117
296
 
118
297
  def rmse(
@@ -123,126 +302,204 @@ def rmse(
123
302
  return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
124
303
 
125
304
 
126
- class SteadyStateResidualFn(Protocol):
127
- """Protocol for steady state residual functions."""
305
+ def mae(
306
+ y_pred: pd.DataFrame | pd.Series,
307
+ y_true: pd.DataFrame | pd.Series,
308
+ ) -> float:
309
+ """Calculate mean absolute error."""
310
+ return cast(float, np.mean(np.abs(y_true - y_pred)))
128
311
 
129
- def __call__(
130
- self,
131
- par_values: Array,
132
- # This will be filled out by partial
133
- par_names: list[str],
134
- data: pd.Series,
135
- model: Model,
136
- y0: dict[str, float] | None,
137
- integrator: IntegratorType,
138
- loss_fn: LossFn,
139
- ) -> float:
140
- """Calculate residual error between model steady state and experimental data."""
141
- ...
312
+
313
+ def mean_absolute_percentage(
314
+ y_pred: pd.DataFrame | pd.Series,
315
+ y_true: pd.DataFrame | pd.Series,
316
+ ) -> float:
317
+ """Calculate mean absolute error."""
318
+ return cast(float, 100 * np.mean(np.abs((y_true - y_pred) / y_pred)))
319
+
320
+
321
+ def mean_squared_logarithmic(
322
+ y_pred: pd.DataFrame | pd.Series,
323
+ y_true: pd.DataFrame | pd.Series,
324
+ ) -> float:
325
+ """Calculate root mean square error between model and data."""
326
+ return cast(float, np.mean(np.square(np.log(y_pred + 1) - np.log(y_true + 1))))
327
+
328
+
329
+ def cosine_similarity(
330
+ y_pred: pd.DataFrame | pd.Series,
331
+ y_true: pd.DataFrame | pd.Series,
332
+ ) -> float:
333
+ """Calculate root mean square error between model and data."""
334
+ norm = np.linalg.norm
335
+ return cast(float, -np.sum(norm(y_pred, 2) * norm(y_true, 2)))
336
+
337
+
338
+ ###############################################################################
339
+ # Minimizers
340
+ ###############################################################################
142
341
 
143
342
 
144
- class TimeSeriesResidualFn(Protocol):
145
- """Protocol for time series residual functions."""
343
+ @dataclass
344
+ class LocalScipyMinimizer:
345
+ """Local multivariate minimization using scipy.optimize.
346
+
347
+ See Also
348
+ --------
349
+ https://docs.scipy.org/doc/scipy/reference/optimize.html#local-multivariate-optimization
350
+
351
+ """
352
+
353
+ tol: float = 1e-6
354
+ method: Literal[
355
+ "Nelder-Mead",
356
+ "Powell",
357
+ "CG",
358
+ "BFGS",
359
+ "Newton-CG",
360
+ "L-BFGS-B",
361
+ "TNC",
362
+ "COBYLA",
363
+ "COBYQA",
364
+ "SLSQP",
365
+ "trust-constr",
366
+ "dogleg",
367
+ "trust-ncg",
368
+ "trust-exact",
369
+ "trust-krylov",
370
+ ] = "L-BFGS-B"
146
371
 
147
372
  def __call__(
148
373
  self,
149
- par_values: Array,
150
- # This will be filled out by partial
151
- par_names: list[str],
152
- data: pd.DataFrame,
153
- model: Model,
154
- y0: dict[str, float] | None,
155
- integrator: IntegratorType,
156
- loss_fn: LossFn,
157
- ) -> float:
158
- """Calculate residual error between model time course and experimental data."""
159
- ...
374
+ residual_fn: ResidualFn,
375
+ p0: dict[str, float],
376
+ bounds: Bounds,
377
+ ) -> MinResult | None:
378
+ """Call minimzer."""
379
+ par_names = list(p0.keys())
380
+
381
+ res = minimize(
382
+ lambda par_values: residual_fn(_pack_updates(par_values, par_names)),
383
+ x0=list(p0.values()),
384
+ bounds=[bounds.get(name, (1e-6, 1e6)) for name in p0],
385
+ method=self.method,
386
+ tol=self.tol,
387
+ )
388
+ if res.success:
389
+ return MinResult(
390
+ parameters=dict(
391
+ zip(
392
+ p0,
393
+ res.x,
394
+ strict=True,
395
+ ),
396
+ ),
397
+ residual=res.fun,
398
+ )
399
+
400
+ LOGGER.warning("Minimisation failed due to %s", res.message)
401
+ return None
402
+
403
+
404
+ @dataclass
405
+ class GlobalScipyMinimizer:
406
+ """Global iate minimization using scipy.optimize.
407
+
408
+ See Also
409
+ --------
410
+ https://docs.scipy.org/doc/scipy/reference/optimize.html#global-optimization
160
411
 
412
+ """
161
413
 
162
- class ProtocolResidualFn(Protocol):
163
- """Protocol for time series residual functions."""
414
+ tol: float = 1e-6
415
+ method: Literal[
416
+ "basinhopping",
417
+ "differential_evolution",
418
+ "shgo",
419
+ "dual_annealing",
420
+ "direct",
421
+ ] = "basinhopping"
164
422
 
165
423
  def __call__(
166
424
  self,
167
- par_values: Array,
168
- # This will be filled out by partial
169
- par_names: list[str],
170
- data: pd.DataFrame,
171
- model: Model,
172
- y0: dict[str, float] | None,
173
- integrator: IntegratorType,
174
- loss_fn: LossFn,
175
- protocol: pd.DataFrame,
176
- ) -> float:
177
- """Calculate residual error between model time course and experimental data."""
178
- ...
179
-
425
+ residual_fn: ResidualFn,
426
+ p0: dict[str, float],
427
+ bounds: Bounds,
428
+ ) -> MinResult | None:
429
+ """Minimize residual fn."""
430
+ res: OptimizeResult
431
+ par_names = list(p0.keys())
432
+ res_fn: ResFn = lambda par_values: residual_fn( # noqa: E731
433
+ _pack_updates(par_values, par_names)
434
+ )
180
435
 
181
- def _default_minimize_fn(
182
- residual_fn: ResidualFn,
183
- p0: dict[str, float],
184
- bounds: Bounds,
185
- ) -> MinResult | None:
186
- res = minimize(
187
- residual_fn,
188
- x0=list(p0.values()),
189
- bounds=[bounds.get(name, (1e-6, 1e6)) for name in p0],
190
- method="L-BFGS-B",
191
- )
192
- if res.success:
193
- return MinResult(
194
- parameters=dict(
195
- zip(
196
- p0,
197
- res.x,
198
- strict=True,
436
+ if self.method == "basinhopping":
437
+ res = basinhopping(
438
+ res_fn,
439
+ x0=list(p0.values()),
440
+ )
441
+ elif self.method == "differential_evolution":
442
+ res = differential_evolution(res_fn, bounds)
443
+ elif self.method == "shgo":
444
+ res = shgo(res_fn, bounds)
445
+ elif self.method == "dual_annealing":
446
+ res = dual_annealing(res_fn, bounds)
447
+ elif self.method == "direct":
448
+ res = direct(res_fn, bounds)
449
+ else:
450
+ msg = f"Unknown method {self.method}"
451
+ raise NotImplementedError(msg)
452
+ if res.success:
453
+ return MinResult(
454
+ parameters=dict(
455
+ zip(
456
+ p0,
457
+ res.x,
458
+ strict=True,
459
+ ),
199
460
  ),
200
- ),
201
- residual=res.fun,
202
- )
461
+ residual=res.fun,
462
+ )
463
+
464
+ LOGGER.warning("Minimisation failed.")
465
+ return None
203
466
 
204
- LOGGER.warning("Minimisation failed.")
205
- return None
206
467
 
468
+ ###############################################################################
469
+ # Residual functions
470
+ ###############################################################################
207
471
 
208
- def _steady_state_residual(
472
+
473
+ def _pack_updates(
209
474
  par_values: Array,
210
- # This will be filled out by partial
211
475
  par_names: list[str],
212
- data: pd.Series,
213
- model: Model,
214
- y0: dict[str, float] | None,
215
- integrator: IntegratorType,
216
- loss_fn: LossFn,
217
- ) -> float:
218
- """Calculate residual error between model steady state and experimental data.
476
+ ) -> dict[str, float]:
477
+ return dict(
478
+ zip(
479
+ par_names,
480
+ par_values,
481
+ strict=True,
482
+ )
483
+ )
219
484
 
220
- Args:
221
- par_values: Parameter values to test
222
- data: Experimental steady state data
223
- model: Model instance to simulate
224
- y0: Initial conditions
225
- par_names: Names of parameters being fit
226
- integrator: ODE integrator class to use
227
- loss_fn: Loss function to use for residual calculation
228
485
 
229
- Returns:
230
- float: Root mean square error between model and data
486
+ def steady_state_residual(
487
+ updates: dict[str, float],
488
+ settings: _Settings,
489
+ ) -> float:
490
+ """Calculate residual error between model steady state and experimental data."""
491
+ model = settings.model
492
+ if (y0 := settings.y0) is not None:
493
+ model.update_variables(y0)
494
+ for p in settings.p_names:
495
+ model.update_parameter(p, updates[p])
496
+ for p in settings.v_names:
497
+ model.update_variable(p, updates[p])
231
498
 
232
- """
233
499
  res = (
234
500
  Simulator(
235
- model.update_parameters(
236
- dict(
237
- zip(
238
- par_names,
239
- par_values,
240
- strict=True,
241
- )
242
- )
243
- ),
244
- y0=y0,
245
- integrator=integrator,
501
+ model,
502
+ integrator=settings.integrator,
246
503
  )
247
504
  .simulate_to_steady_state()
248
505
  .get_result()
@@ -250,93 +507,67 @@ def _steady_state_residual(
250
507
  if res is None:
251
508
  return cast(float, np.inf)
252
509
 
253
- return loss_fn(
254
- res.get_combined().loc[:, cast(list, data.index)],
255
- data,
510
+ return settings.loss_fn(
511
+ res.get_combined().loc[:, cast(list, settings.data.index)],
512
+ settings.data,
256
513
  )
257
514
 
258
515
 
259
- def _time_course_residual(
260
- par_values: ArrayLike,
261
- # This will be filled out by partial
262
- par_names: list[str],
263
- data: pd.DataFrame,
264
- model: Model,
265
- y0: dict[str, float] | None,
266
- integrator: IntegratorType,
267
- loss_fn: LossFn,
516
+ def time_course_residual(
517
+ updates: dict[str, float],
518
+ settings: _Settings,
268
519
  ) -> float:
269
- """Calculate residual error between model time course and experimental data.
270
-
271
- Args:
272
- par_values: Parameter values to test
273
- data: Experimental time course data
274
- model: Model instance to simulate
275
- y0: Initial conditions
276
- par_names: Names of parameters being fit
277
- integrator: ODE integrator class to use
278
- loss_fn: Loss function to use for residual calculation
520
+ """Calculate residual error between model time course and experimental data."""
521
+ model = settings.model
522
+ if (y0 := settings.y0) is not None:
523
+ model.update_variables(y0)
524
+ for p in settings.p_names:
525
+ model.update_parameter(p, updates[p])
526
+ for p in settings.v_names:
527
+ model.update_variable(p, updates[p])
279
528
 
280
- Returns:
281
- float: Root mean square error between model and data
282
-
283
- """
284
529
  res = (
285
530
  Simulator(
286
- model.update_parameters(dict(zip(par_names, par_values, strict=True))),
287
- y0=y0,
288
- integrator=integrator,
531
+ model,
532
+ integrator=settings.integrator,
289
533
  )
290
- .simulate_time_course(cast(list, data.index))
534
+ .simulate_time_course(cast(list, settings.data.index))
291
535
  .get_result()
292
536
  )
293
537
  if res is None:
294
538
  return cast(float, np.inf)
295
539
  results_ss = res.get_combined()
296
540
 
297
- return loss_fn(
298
- results_ss.loc[:, cast(list, data.columns)],
299
- data,
541
+ return settings.loss_fn(
542
+ results_ss.loc[:, cast(list, settings.data.columns)],
543
+ settings.data,
300
544
  )
301
545
 
302
546
 
303
- def _protocol_time_course_residual(
304
- par_values: ArrayLike,
305
- # This will be filled out by partial
306
- par_names: list[str],
307
- data: pd.DataFrame,
308
- model: Model,
309
- y0: dict[str, float] | None,
310
- integrator: IntegratorType,
311
- loss_fn: LossFn,
312
- protocol: pd.DataFrame,
547
+ def protocol_time_course_residual(
548
+ updates: dict[str, float],
549
+ settings: _Settings,
313
550
  ) -> float:
314
- """Calculate residual error between model time course and experimental data.
315
-
316
- Args:
317
- par_values: Parameter values to test
318
- data: Experimental time course data
319
- model: Model instance to simulate
320
- y0: Initial conditions
321
- par_names: Names of parameters being fit
322
- integrator: ODE integrator class to use
323
- loss_fn: Loss function to use for residual calculation
324
- protocol: Experimental protocol
325
- time_points_per_step: Number of time points per step in the protocol
551
+ """Calculate residual error between model time course and experimental data."""
552
+ model = settings.model
553
+ if (y0 := settings.y0) is not None:
554
+ model.update_variables(y0)
555
+ for p in settings.p_names:
556
+ model.update_parameter(p, updates[p])
557
+ for p in settings.v_names:
558
+ model.update_variable(p, updates[p])
559
+
560
+ if (protocol := settings.protocol) is None:
561
+ raise ValueError
326
562
 
327
- Returns:
328
- float: Root mean square error between model and data
329
-
330
- """
331
563
  res = (
332
564
  Simulator(
333
- model.update_parameters(dict(zip(par_names, par_values, strict=True))),
334
- y0=y0,
335
- integrator=integrator,
565
+ model,
566
+ integrator=settings.integrator,
336
567
  )
337
568
  .simulate_protocol_time_course(
338
569
  protocol=protocol,
339
- time_points=data.index,
570
+ time_points=settings.data.index,
340
571
  )
341
572
  .get_result()
342
573
  )
@@ -344,87 +575,9 @@ def _protocol_time_course_residual(
344
575
  return cast(float, np.inf)
345
576
  results_ss = res.get_combined()
346
577
 
347
- return loss_fn(
348
- results_ss.loc[:, cast(list, data.columns)],
349
- data,
350
- )
351
-
352
-
353
- def _carousel_steady_state_worker(
354
- model: Model,
355
- p0: dict[str, float],
356
- data: pd.Series,
357
- y0: dict[str, float] | None,
358
- integrator: IntegratorType | None,
359
- loss_fn: LossFn,
360
- minimize_fn: MinimizeFn,
361
- residual_fn: SteadyStateResidualFn,
362
- bounds: Bounds | None,
363
- ) -> FitResult | None:
364
- model_pars = model.get_parameter_values()
365
-
366
- return steady_state(
367
- model,
368
- p0={k: v for k, v in p0.items() if k in model_pars},
369
- y0=y0,
370
- data=data,
371
- minimize_fn=minimize_fn,
372
- residual_fn=residual_fn,
373
- integrator=integrator,
374
- loss_fn=loss_fn,
375
- bounds=bounds,
376
- )
377
-
378
-
379
- def _carousel_time_course_worker(
380
- model: Model,
381
- p0: dict[str, float],
382
- data: pd.DataFrame,
383
- y0: dict[str, float] | None,
384
- integrator: IntegratorType | None,
385
- loss_fn: LossFn,
386
- minimize_fn: MinimizeFn,
387
- residual_fn: TimeSeriesResidualFn,
388
- bounds: Bounds | None,
389
- ) -> FitResult | None:
390
- model_pars = model.get_parameter_values()
391
- return time_course(
392
- model,
393
- p0={k: v for k, v in p0.items() if k in model_pars},
394
- y0=y0,
395
- data=data,
396
- minimize_fn=minimize_fn,
397
- residual_fn=residual_fn,
398
- integrator=integrator,
399
- loss_fn=loss_fn,
400
- bounds=bounds,
401
- )
402
-
403
-
404
- def _carousel_protocol_worker(
405
- model: Model,
406
- p0: dict[str, float],
407
- data: pd.DataFrame,
408
- protocol: pd.DataFrame,
409
- y0: dict[str, float] | None,
410
- integrator: IntegratorType | None,
411
- loss_fn: LossFn,
412
- minimize_fn: MinimizeFn,
413
- residual_fn: ProtocolResidualFn,
414
- bounds: Bounds | None,
415
- ) -> FitResult | None:
416
- model_pars = model.get_parameter_values()
417
- return protocol_time_course(
418
- model,
419
- p0={k: v for k, v in p0.items() if k in model_pars},
420
- y0=y0,
421
- protocol=protocol,
422
- data=data,
423
- minimize_fn=minimize_fn,
424
- residual_fn=residual_fn,
425
- integrator=integrator,
426
- loss_fn=loss_fn,
427
- bounds=bounds,
578
+ return settings.loss_fn(
579
+ results_ss.loc[:, cast(list, settings.data.columns)],
580
+ settings.data,
428
581
  )
429
582
 
430
583
 
@@ -433,12 +586,13 @@ def steady_state(
433
586
  *,
434
587
  p0: dict[str, float],
435
588
  data: pd.Series,
589
+ minimizer: Minimizer,
436
590
  y0: dict[str, float] | None = None,
437
- minimize_fn: MinimizeFn = _default_minimize_fn,
438
- residual_fn: SteadyStateResidualFn = _steady_state_residual,
591
+ residual_fn: ResidualProtocol = steady_state_residual,
439
592
  integrator: IntegratorType | None = None,
440
593
  loss_fn: LossFn = rmse,
441
594
  bounds: Bounds | None = None,
595
+ as_deepcopy: bool = True,
442
596
  ) -> FitResult | None:
443
597
  """Fit model parameters to steady-state experimental data.
444
598
 
@@ -449,13 +603,14 @@ def steady_state(
449
603
  Args:
450
604
  model: Model instance to fit
451
605
  data: Experimental steady state data as pandas Series
452
- p0: Initial parameter guesses as {parameter_name: value}
606
+ p0: Initial guesses as {name: value}
453
607
  y0: Initial conditions as {species_name: value}
454
- minimize_fn: Function to minimize fitting error
608
+ minimizer: Function to minimize fitting error
455
609
  residual_fn: Function to calculate fitting error
456
610
  integrator: ODE integrator class
457
611
  loss_fn: Loss function to use for residual calculation
458
612
  bounds: Mapping of bounds per parameter
613
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
459
614
 
460
615
  Returns:
461
616
  dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -464,31 +619,30 @@ def steady_state(
464
619
  Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
465
620
 
466
621
  """
467
- par_names = list(p0.keys())
622
+ if as_deepcopy:
623
+ model = deepcopy(model)
468
624
 
469
- # Copy to restore
470
- p_orig = model.get_parameter_values()
625
+ p_names = model.get_parameter_names()
626
+ v_names = model.get_variable_names()
471
627
 
472
- fn = cast(
473
- ResidualFn,
474
- partial(
475
- residual_fn,
476
- data=data,
628
+ fn: ResidualFn = partial(
629
+ residual_fn,
630
+ settings=_Settings(
477
631
  model=model,
632
+ data=data,
478
633
  y0=y0,
479
- par_names=par_names,
480
634
  integrator=integrator,
481
635
  loss_fn=loss_fn,
636
+ p_names=[i for i in p0 if i in p_names],
637
+ v_names=[i for i in p0 if i in v_names],
482
638
  ),
483
639
  )
484
- min_result = minimize_fn(fn, p0, {} if bounds is None else bounds)
485
- # Restore original model
486
- model.update_parameters(p_orig)
640
+ min_result = minimizer(fn, p0, {} if bounds is None else bounds)
487
641
  if min_result is None:
488
- return min_result
642
+ return None
489
643
 
490
644
  return FitResult(
491
- model=deepcopy(model).update_parameters(min_result.parameters),
645
+ model=model,
492
646
  best_pars=min_result.parameters,
493
647
  loss=min_result.residual,
494
648
  )
@@ -499,12 +653,13 @@ def time_course(
499
653
  *,
500
654
  p0: dict[str, float],
501
655
  data: pd.DataFrame,
656
+ minimizer: Minimizer,
502
657
  y0: dict[str, float] | None = None,
503
- minimize_fn: MinimizeFn = _default_minimize_fn,
504
- residual_fn: TimeSeriesResidualFn = _time_course_residual,
658
+ residual_fn: ResidualProtocol = time_course_residual,
505
659
  integrator: IntegratorType | None = None,
506
660
  loss_fn: LossFn = rmse,
507
661
  bounds: Bounds | None = None,
662
+ as_deepcopy: bool = True,
508
663
  ) -> FitResult | None:
509
664
  """Fit model parameters to time course of experimental data.
510
665
 
@@ -515,13 +670,14 @@ def time_course(
515
670
  Args:
516
671
  model: Model instance to fit
517
672
  data: Experimental time course data
518
- p0: Initial parameter guesses as {parameter_name: value}
673
+ p0: Initial guesses as {parameter_name: value}
519
674
  y0: Initial conditions as {species_name: value}
520
- minimize_fn: Function to minimize fitting error
675
+ minimizer: Function to minimize fitting error
521
676
  residual_fn: Function to calculate fitting error
522
677
  integrator: ODE integrator class
523
678
  loss_fn: Loss function to use for residual calculation
524
679
  bounds: Mapping of bounds per parameter
680
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
525
681
 
526
682
  Returns:
527
683
  dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -530,30 +686,30 @@ def time_course(
530
686
  Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
531
687
 
532
688
  """
533
- par_names = list(p0.keys())
534
- p_orig = model.get_parameter_values()
689
+ if as_deepcopy:
690
+ model = deepcopy(model)
691
+ p_names = model.get_parameter_names()
692
+ v_names = model.get_variable_names()
535
693
 
536
- fn = cast(
537
- ResidualFn,
538
- partial(
539
- residual_fn,
540
- data=data,
694
+ fn: ResidualFn = partial(
695
+ residual_fn,
696
+ settings=_Settings(
541
697
  model=model,
698
+ data=data,
542
699
  y0=y0,
543
- par_names=par_names,
544
700
  integrator=integrator,
545
701
  loss_fn=loss_fn,
702
+ p_names=[i for i in p0 if i in p_names],
703
+ v_names=[i for i in p0 if i in v_names],
546
704
  ),
547
705
  )
548
706
 
549
- min_result = minimize_fn(fn, p0, {} if bounds is None else bounds)
550
- # Restore original model
551
- model.update_parameters(p_orig)
707
+ min_result = minimizer(fn, p0, {} if bounds is None else bounds)
552
708
  if min_result is None:
553
- return min_result
709
+ return None
554
710
 
555
711
  return FitResult(
556
- model=deepcopy(model).update_parameters(min_result.parameters),
712
+ model=model,
557
713
  best_pars=min_result.parameters,
558
714
  loss=min_result.residual,
559
715
  )
@@ -565,12 +721,13 @@ def protocol_time_course(
565
721
  p0: dict[str, float],
566
722
  data: pd.DataFrame,
567
723
  protocol: pd.DataFrame,
724
+ minimizer: Minimizer,
568
725
  y0: dict[str, float] | None = None,
569
- minimize_fn: MinimizeFn = _default_minimize_fn,
570
- residual_fn: ProtocolResidualFn = _protocol_time_course_residual,
726
+ residual_fn: ResidualProtocol = protocol_time_course_residual,
571
727
  integrator: IntegratorType | None = None,
572
728
  loss_fn: LossFn = rmse,
573
729
  bounds: Bounds | None = None,
730
+ as_deepcopy: bool = True,
574
731
  ) -> FitResult | None:
575
732
  """Fit model parameters to time course of experimental data.
576
733
 
@@ -586,12 +743,13 @@ def protocol_time_course(
586
743
  data: Experimental time course data
587
744
  protocol: Experimental protocol
588
745
  y0: Initial conditions as {species_name: value}
589
- minimize_fn: Function to minimize fitting error
746
+ minimizer: Function to minimize fitting error
590
747
  residual_fn: Function to calculate fitting error
591
748
  integrator: ODE integrator class
592
749
  loss_fn: Loss function to use for residual calculation
593
750
  time_points_per_step: Number of time points per step in the protocol
594
751
  bounds: Mapping of bounds per parameter
752
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
595
753
 
596
754
  Returns:
597
755
  dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -600,65 +758,73 @@ def protocol_time_course(
600
758
  Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
601
759
 
602
760
  """
603
- par_names = list(p0.keys())
604
- p_orig = model.get_parameter_values()
761
+ if as_deepcopy:
762
+ model = deepcopy(model)
763
+ p_names = model.get_parameter_names()
764
+ v_names = model.get_variable_names()
605
765
 
606
- fn = cast(
607
- ResidualFn,
608
- partial(
609
- residual_fn,
610
- data=data,
766
+ fn: ResidualFn = partial(
767
+ residual_fn,
768
+ settings=_Settings(
611
769
  model=model,
770
+ data=data,
612
771
  y0=y0,
613
- par_names=par_names,
614
772
  integrator=integrator,
615
773
  loss_fn=loss_fn,
774
+ p_names=[i for i in p0 if i in p_names],
775
+ v_names=[i for i in p0 if i in v_names],
616
776
  protocol=protocol,
617
777
  ),
618
778
  )
619
779
 
620
- min_result = minimize_fn(fn, p0, {} if bounds is None else bounds)
621
- # Restore original model
622
- model.update_parameters(p_orig)
780
+ min_result = minimizer(fn, p0, {} if bounds is None else bounds)
623
781
  if min_result is None:
624
- return min_result
782
+ return None
625
783
 
626
784
  return FitResult(
627
- model=deepcopy(model).update_parameters(min_result.parameters),
785
+ model=model,
628
786
  best_pars=min_result.parameters,
629
787
  loss=min_result.residual,
630
788
  )
631
789
 
632
790
 
633
- def carousel_steady_state(
634
- carousel: Carousel,
791
+ ###############################################################################
792
+ # Ensemble / carousel
793
+ # This is multi-model, single data fitting, where the models share parameters
794
+ ###############################################################################
795
+
796
+
797
+ def ensemble_steady_state(
798
+ ensemble: list[Model],
635
799
  *,
636
800
  p0: dict[str, float],
637
801
  data: pd.Series,
802
+ minimizer: Minimizer,
638
803
  y0: dict[str, float] | None = None,
639
- minimize_fn: MinimizeFn = _default_minimize_fn,
640
- residual_fn: SteadyStateResidualFn = _steady_state_residual,
804
+ residual_fn: ResidualProtocol = steady_state_residual,
641
805
  integrator: IntegratorType | None = None,
642
806
  loss_fn: LossFn = rmse,
643
807
  bounds: Bounds | None = None,
644
- ) -> CarouselFit:
645
- """Fit model parameters to steady-state experimental data over a carousel.
808
+ as_deepcopy: bool = True,
809
+ ) -> EnsembleFitResult:
810
+ """Fit model ensemble parameters to steady-state experimental data.
646
811
 
647
812
  Examples:
648
813
  >>> carousel_steady_state(carousel, p0=p0, data=data)
649
814
 
650
815
  Args:
651
- carousel: Model carousel to fit
816
+ ensemble: Ensemble to fit
652
817
  p0: Initial parameter guesses as {parameter_name: value}
653
818
  data: Experimental time course data
654
819
  protocol: Experimental protocol
655
820
  y0: Initial conditions as {species_name: value}
656
- minimize_fn: Function to minimize fitting error
821
+ minimizer: Function to minimize fitting error
657
822
  residual_fn: Function to calculate fitting error
658
823
  integrator: ODE integrator class
659
824
  loss_fn: Loss function to use for residual calculation
660
825
  time_points_per_step: Number of time points per step in the protocol
661
826
  bounds: Mapping of bounds per parameter
827
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
662
828
 
663
829
  Returns:
664
830
  dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -667,40 +833,95 @@ def carousel_steady_state(
667
833
  Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
668
834
 
669
835
  """
670
- return CarouselFit(
836
+ return EnsembleFitResult(
671
837
  [
672
838
  fit
673
839
  for i in parallel.parallelise(
674
840
  partial(
675
- _carousel_steady_state_worker,
841
+ steady_state,
676
842
  p0=p0,
677
843
  data=data,
678
844
  y0=y0,
679
845
  integrator=integrator,
680
846
  loss_fn=loss_fn,
681
- minimize_fn=minimize_fn,
847
+ minimizer=minimizer,
682
848
  residual_fn=residual_fn,
683
849
  bounds=bounds,
850
+ as_deepcopy=as_deepcopy,
684
851
  ),
685
- inputs=list(enumerate(carousel.variants)),
852
+ inputs=list(enumerate(ensemble)),
686
853
  )
687
854
  if (fit := i[1]) is not None
688
855
  ]
689
856
  )
690
857
 
691
858
 
692
- def carousel_time_course(
859
+ def carousel_steady_state(
693
860
  carousel: Carousel,
694
861
  *,
695
862
  p0: dict[str, float],
863
+ data: pd.Series,
864
+ minimizer: Minimizer,
865
+ y0: dict[str, float] | None = None,
866
+ residual_fn: ResidualProtocol = steady_state_residual,
867
+ integrator: IntegratorType | None = None,
868
+ loss_fn: LossFn = rmse,
869
+ bounds: Bounds | None = None,
870
+ as_deepcopy: bool = True,
871
+ ) -> EnsembleFitResult:
872
+ """Fit model parameters to steady-state experimental data over a carousel.
873
+
874
+ Examples:
875
+ >>> carousel_steady_state(carousel, p0=p0, data=data)
876
+
877
+ Args:
878
+ carousel: Model carousel to fit
879
+ p0: Initial parameter guesses as {parameter_name: value}
880
+ data: Experimental time course data
881
+ protocol: Experimental protocol
882
+ y0: Initial conditions as {species_name: value}
883
+ minimizer: Function to minimize fitting error
884
+ residual_fn: Function to calculate fitting error
885
+ integrator: ODE integrator class
886
+ loss_fn: Loss function to use for residual calculation
887
+ time_points_per_step: Number of time points per step in the protocol
888
+ bounds: Mapping of bounds per parameter
889
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
890
+
891
+ Returns:
892
+ dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
893
+
894
+ Note:
895
+ Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
896
+
897
+ """
898
+ return ensemble_steady_state(
899
+ carousel.variants,
900
+ p0=p0,
901
+ data=data,
902
+ minimizer=minimizer,
903
+ y0=y0,
904
+ residual_fn=residual_fn,
905
+ integrator=integrator,
906
+ loss_fn=loss_fn,
907
+ bounds=bounds,
908
+ as_deepcopy=as_deepcopy,
909
+ )
910
+
911
+
912
+ def ensemble_time_course(
913
+ ensemble: list[Model],
914
+ *,
915
+ p0: dict[str, float],
696
916
  data: pd.DataFrame,
917
+ minimizer: Minimizer,
697
918
  y0: dict[str, float] | None = None,
698
- minimize_fn: MinimizeFn = _default_minimize_fn,
699
- residual_fn: TimeSeriesResidualFn = _time_course_residual,
919
+ residual_fn: ResidualProtocol = time_course_residual,
700
920
  integrator: IntegratorType | None = None,
701
921
  loss_fn: LossFn = rmse,
702
922
  bounds: Bounds | None = None,
703
- ) -> CarouselFit:
923
+ as_deepcopy: bool = True,
924
+ ) -> EnsembleFitResult:
704
925
  """Fit model parameters to time course of experimental data over a carousel.
705
926
 
706
927
  Time points are taken from the data.
@@ -709,17 +930,18 @@ def carousel_time_course(
709
930
  >>> carousel_time_course(carousel, p0=p0, data=data)
710
931
 
711
932
  Args:
712
- carousel: Model carousel to fit
933
+ ensemble: Model ensemble to fit
713
934
  p0: Initial parameter guesses as {parameter_name: value}
714
935
  data: Experimental time course data
715
936
  protocol: Experimental protocol
716
937
  y0: Initial conditions as {species_name: value}
717
- minimize_fn: Function to minimize fitting error
938
+ minimizer: Function to minimize fitting error
718
939
  residual_fn: Function to calculate fitting error
719
940
  integrator: ODE integrator class
720
941
  loss_fn: Loss function to use for residual calculation
721
942
  time_points_per_step: Number of time points per step in the protocol
722
943
  bounds: Mapping of bounds per parameter
944
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
723
945
 
724
946
  Returns:
725
947
  dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -728,41 +950,98 @@ def carousel_time_course(
728
950
  Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
729
951
 
730
952
  """
731
- return CarouselFit(
953
+ return EnsembleFitResult(
732
954
  [
733
955
  fit
734
956
  for i in parallel.parallelise(
735
957
  partial(
736
- _carousel_time_course_worker,
958
+ time_course,
737
959
  p0=p0,
738
960
  data=data,
739
961
  y0=y0,
740
962
  integrator=integrator,
741
963
  loss_fn=loss_fn,
742
- minimize_fn=minimize_fn,
964
+ minimizer=minimizer,
743
965
  residual_fn=residual_fn,
744
966
  bounds=bounds,
967
+ as_deepcopy=as_deepcopy,
745
968
  ),
746
- inputs=list(enumerate(carousel.variants)),
969
+ inputs=list(enumerate(ensemble)),
747
970
  )
748
971
  if (fit := i[1]) is not None
749
972
  ]
750
973
  )
751
974
 
752
975
 
753
- def carousel_protocol_time_course(
976
+ def carousel_time_course(
754
977
  carousel: Carousel,
755
978
  *,
756
979
  p0: dict[str, float],
757
980
  data: pd.DataFrame,
981
+ minimizer: Minimizer,
982
+ y0: dict[str, float] | None = None,
983
+ residual_fn: ResidualProtocol = time_course_residual,
984
+ integrator: IntegratorType | None = None,
985
+ loss_fn: LossFn = rmse,
986
+ bounds: Bounds | None = None,
987
+ as_deepcopy: bool = True,
988
+ ) -> EnsembleFitResult:
989
+ """Fit model parameters to time course of experimental data over a carousel.
990
+
991
+ Time points are taken from the data.
992
+
993
+ Examples:
994
+ >>> carousel_time_course(carousel, p0=p0, data=data)
995
+
996
+ Args:
997
+ carousel: Model carousel to fit
998
+ p0: Initial parameter guesses as {parameter_name: value}
999
+ data: Experimental time course data
1000
+ protocol: Experimental protocol
1001
+ y0: Initial conditions as {species_name: value}
1002
+ minimizer: Function to minimize fitting error
1003
+ residual_fn: Function to calculate fitting error
1004
+ integrator: ODE integrator class
1005
+ loss_fn: Loss function to use for residual calculation
1006
+ time_points_per_step: Number of time points per step in the protocol
1007
+ bounds: Mapping of bounds per parameter
1008
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
1009
+
1010
+ Returns:
1011
+ dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
1012
+
1013
+ Note:
1014
+ Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
1015
+
1016
+ """
1017
+ return ensemble_time_course(
1018
+ carousel.variants,
1019
+ p0=p0,
1020
+ data=data,
1021
+ minimizer=minimizer,
1022
+ y0=y0,
1023
+ residual_fn=residual_fn,
1024
+ integrator=integrator,
1025
+ loss_fn=loss_fn,
1026
+ bounds=bounds,
1027
+ as_deepcopy=as_deepcopy,
1028
+ )
1029
+
1030
+
1031
+ def ensemble_protocol_time_course(
1032
+ ensemble: list[Model],
1033
+ *,
1034
+ p0: dict[str, float],
1035
+ data: pd.DataFrame,
1036
+ minimizer: Minimizer,
758
1037
  protocol: pd.DataFrame,
759
1038
  y0: dict[str, float] | None = None,
760
- minimize_fn: MinimizeFn = _default_minimize_fn,
761
- residual_fn: ProtocolResidualFn = _protocol_time_course_residual,
1039
+ residual_fn: ResidualProtocol = protocol_time_course_residual,
762
1040
  integrator: IntegratorType | None = None,
763
1041
  loss_fn: LossFn = rmse,
764
1042
  bounds: Bounds | None = None,
765
- ) -> CarouselFit:
1043
+ as_deepcopy: bool = True,
1044
+ ) -> EnsembleFitResult:
766
1045
  """Fit model parameters to time course of experimental data over a protocol.
767
1046
 
768
1047
  Time points of protocol time course are taken from the data.
@@ -771,17 +1050,18 @@ def carousel_protocol_time_course(
771
1050
  >>> carousel_steady_state(carousel, p0=p0, data=data)
772
1051
 
773
1052
  Args:
774
- carousel: Model carousel to fit
775
- p0: Initial parameter guesses as {parameter_name: value}
1053
+ ensemble: Model ensemble: value}
1054
+ p0: initial parameter guess
776
1055
  data: Experimental time course data
777
1056
  protocol: Experimental protocol
778
1057
  y0: Initial conditions as {species_name: value}
779
- minimize_fn: Function to minimize fitting error
1058
+ minimizer: Function to minimize fitting error
780
1059
  residual_fn: Function to calculate fitting error
781
1060
  integrator: ODE integrator class
782
1061
  loss_fn: Loss function to use for residual calculation
783
1062
  time_points_per_step: Number of time points per step in the protocol
784
1063
  bounds: Mapping of bounds per parameter
1064
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
785
1065
 
786
1066
  Returns:
787
1067
  dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
@@ -790,24 +1070,345 @@ def carousel_protocol_time_course(
790
1070
  Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
791
1071
 
792
1072
  """
793
- return CarouselFit(
1073
+ return EnsembleFitResult(
794
1074
  [
795
1075
  fit
796
1076
  for i in parallel.parallelise(
797
1077
  partial(
798
- _carousel_protocol_worker,
1078
+ protocol_time_course,
799
1079
  p0=p0,
800
1080
  data=data,
801
1081
  protocol=protocol,
802
1082
  y0=y0,
803
1083
  integrator=integrator,
804
1084
  loss_fn=loss_fn,
805
- minimize_fn=minimize_fn,
1085
+ minimizer=minimizer,
806
1086
  residual_fn=residual_fn,
807
1087
  bounds=bounds,
1088
+ as_deepcopy=as_deepcopy,
808
1089
  ),
809
- inputs=list(enumerate(carousel.variants)),
1090
+ inputs=list(enumerate(ensemble)),
810
1091
  )
811
1092
  if (fit := i[1]) is not None
812
1093
  ]
813
1094
  )
1095
+
1096
+
1097
+ def carousel_protocol_time_course(
1098
+ carousel: Carousel,
1099
+ *,
1100
+ p0: dict[str, float],
1101
+ data: pd.DataFrame,
1102
+ minimizer: Minimizer,
1103
+ protocol: pd.DataFrame,
1104
+ y0: dict[str, float] | None = None,
1105
+ residual_fn: ResidualProtocol = protocol_time_course_residual,
1106
+ integrator: IntegratorType | None = None,
1107
+ loss_fn: LossFn = rmse,
1108
+ bounds: Bounds | None = None,
1109
+ as_deepcopy: bool = True,
1110
+ ) -> EnsembleFitResult:
1111
+ """Fit model parameters to time course of experimental data over a protocol.
1112
+
1113
+ Time points of protocol time course are taken from the data.
1114
+
1115
+ Examples:
1116
+ >>> carousel_steady_state(carousel, p0=p0, data=data)
1117
+
1118
+ Args:
1119
+ carousel: Model carousel to fit
1120
+ p0: Initial parameter guesses as {parameter_name: value}
1121
+ data: Experimental time course data
1122
+ protocol: Experimental protocol
1123
+ y0: Initial conditions as {species_name: value}
1124
+ minimizer: Function to minimize fitting error
1125
+ residual_fn: Function to calculate fitting error
1126
+ integrator: ODE integrator class
1127
+ loss_fn: Loss function to use for residual calculation
1128
+ time_points_per_step: Number of time points per step in the protocol
1129
+ bounds: Mapping of bounds per parameter
1130
+ as_deepcopy: Whether to copy the model to avoid overwriting the state
1131
+
1132
+ Returns:
1133
+ dict[str, float]: Fitted parameters as {parameter_name: fitted_value}
1134
+
1135
+ Note:
1136
+ Uses L-BFGS-B optimization with bounds [1e-6, 1e6] for all parameters
1137
+
1138
+ """
1139
+ return ensemble_protocol_time_course(
1140
+ carousel.variants,
1141
+ p0=p0,
1142
+ data=data,
1143
+ minimizer=minimizer,
1144
+ protocol=protocol,
1145
+ y0=y0,
1146
+ residual_fn=residual_fn,
1147
+ integrator=integrator,
1148
+ loss_fn=loss_fn,
1149
+ bounds=bounds,
1150
+ as_deepcopy=as_deepcopy,
1151
+ )
1152
+
1153
+
1154
+ ###############################################################################
1155
+ # Joint fitting
1156
+ # This is multi-model, multi-data fitting, where the models share some parameters
1157
+ ###############################################################################
1158
+
1159
+
1160
+ def _unpacked[T1, T2, Tout](inp: tuple[T1, T2], fn: Callable[[T1, T2], Tout]) -> Tout:
1161
+ return fn(*inp)
1162
+
1163
+
1164
+ def _sum_of_residuals(
1165
+ updates: dict[str, float],
1166
+ residual_fn: ResidualProtocol,
1167
+ fits: list[_Settings],
1168
+ pool: pebble.ProcessPool,
1169
+ ) -> float:
1170
+ future = pool.map(
1171
+ partial(_unpacked, fn=residual_fn),
1172
+ [(updates, i) for i in fits],
1173
+ timeout=None,
1174
+ )
1175
+ error = 0.0
1176
+ it = future.result()
1177
+ while True:
1178
+ try:
1179
+ error += next(it)
1180
+ except StopIteration:
1181
+ break
1182
+ except TimeoutError:
1183
+ return np.inf
1184
+ return error
1185
+
1186
+
1187
+ def joint_steady_state(
1188
+ to_fit: list[FitSettings],
1189
+ *,
1190
+ p0: dict[str, float],
1191
+ minimizer: Minimizer,
1192
+ y0: dict[str, float] | None = None,
1193
+ integrator: IntegratorType | None = None,
1194
+ loss_fn: LossFn = rmse,
1195
+ bounds: Bounds | None = None,
1196
+ max_workers: int | None = None,
1197
+ as_deepcopy: bool = True,
1198
+ ) -> JointFitResult | None:
1199
+ """Multi-model, multi-data fitting."""
1200
+ full_settings = []
1201
+ for i in to_fit:
1202
+ p_names = i.model.get_parameter_names()
1203
+ v_names = i.model.get_variable_names()
1204
+ full_settings.append(
1205
+ _Settings(
1206
+ model=deepcopy(i.model) if as_deepcopy else i.model,
1207
+ data=i.data,
1208
+ y0=i.y0 if i.y0 is not None else y0,
1209
+ integrator=i.integrator if i.integrator is not None else integrator,
1210
+ loss_fn=i.loss_fn if i.loss_fn is not None else loss_fn,
1211
+ p_names=[j for j in p0 if j in p_names],
1212
+ v_names=[j for j in p0 if j in v_names],
1213
+ )
1214
+ )
1215
+
1216
+ with pebble.ProcessPool(
1217
+ max_workers=(
1218
+ multiprocessing.cpu_count() if max_workers is None else max_workers
1219
+ )
1220
+ ) as pool:
1221
+ min_result = minimizer(
1222
+ partial(
1223
+ _sum_of_residuals,
1224
+ residual_fn=steady_state_residual,
1225
+ fits=full_settings,
1226
+ pool=pool,
1227
+ ),
1228
+ p0,
1229
+ {} if bounds is None else bounds,
1230
+ )
1231
+ if min_result is None:
1232
+ return None
1233
+
1234
+ return JointFitResult(min_result.parameters, loss=min_result.residual)
1235
+
1236
+
1237
+ def joint_time_course(
1238
+ to_fit: list[FitSettings],
1239
+ *,
1240
+ p0: dict[str, float],
1241
+ minimizer: Minimizer,
1242
+ y0: dict[str, float] | None = None,
1243
+ integrator: IntegratorType | None = None,
1244
+ loss_fn: LossFn = rmse,
1245
+ bounds: Bounds | None = None,
1246
+ max_workers: int | None = None,
1247
+ as_deepcopy: bool = True,
1248
+ ) -> JointFitResult | None:
1249
+ """Multi-model, multi-data fitting."""
1250
+ full_settings = []
1251
+ for i in to_fit:
1252
+ p_names = i.model.get_parameter_names()
1253
+ v_names = i.model.get_variable_names()
1254
+ full_settings.append(
1255
+ _Settings(
1256
+ model=deepcopy(i.model) if as_deepcopy else i.model,
1257
+ data=i.data,
1258
+ y0=i.y0 if i.y0 is not None else y0,
1259
+ integrator=i.integrator if i.integrator is not None else integrator,
1260
+ loss_fn=i.loss_fn if i.loss_fn is not None else loss_fn,
1261
+ p_names=[j for j in p0 if j in p_names],
1262
+ v_names=[j for j in p0 if j in v_names],
1263
+ )
1264
+ )
1265
+
1266
+ with pebble.ProcessPool(
1267
+ max_workers=(
1268
+ multiprocessing.cpu_count() if max_workers is None else max_workers
1269
+ )
1270
+ ) as pool:
1271
+ min_result = minimizer(
1272
+ partial(
1273
+ _sum_of_residuals,
1274
+ residual_fn=time_course_residual,
1275
+ fits=full_settings,
1276
+ pool=pool,
1277
+ ),
1278
+ p0,
1279
+ {} if bounds is None else bounds,
1280
+ )
1281
+ if min_result is None:
1282
+ return None
1283
+
1284
+ return JointFitResult(min_result.parameters, loss=min_result.residual)
1285
+
1286
+
1287
+ def joint_protocol_time_course(
1288
+ to_fit: list[FitSettings],
1289
+ *,
1290
+ p0: dict[str, float],
1291
+ minimizer: Minimizer,
1292
+ y0: dict[str, float] | None = None,
1293
+ integrator: IntegratorType | None = None,
1294
+ loss_fn: LossFn = rmse,
1295
+ bounds: Bounds | None = None,
1296
+ max_workers: int | None = None,
1297
+ as_deepcopy: bool = True,
1298
+ ) -> JointFitResult | None:
1299
+ """Multi-model, multi-data fitting."""
1300
+ full_settings = []
1301
+ for i in to_fit:
1302
+ p_names = i.model.get_parameter_names()
1303
+ v_names = i.model.get_variable_names()
1304
+ full_settings.append(
1305
+ _Settings(
1306
+ model=deepcopy(i.model) if as_deepcopy else i.model,
1307
+ data=i.data,
1308
+ y0=i.y0 if i.y0 is not None else y0,
1309
+ integrator=i.integrator if i.integrator is not None else integrator,
1310
+ loss_fn=i.loss_fn if i.loss_fn is not None else loss_fn,
1311
+ p_names=[j for j in p0 if j in p_names],
1312
+ v_names=[j for j in p0 if j in v_names],
1313
+ )
1314
+ )
1315
+
1316
+ with pebble.ProcessPool(
1317
+ max_workers=(
1318
+ multiprocessing.cpu_count() if max_workers is None else max_workers
1319
+ )
1320
+ ) as pool:
1321
+ min_result = minimizer(
1322
+ partial(
1323
+ _sum_of_residuals,
1324
+ residual_fn=protocol_time_course_residual,
1325
+ fits=full_settings,
1326
+ pool=pool,
1327
+ ),
1328
+ p0,
1329
+ {} if bounds is None else bounds,
1330
+ )
1331
+ if min_result is None:
1332
+ return None
1333
+
1334
+ return JointFitResult(min_result.parameters, loss=min_result.residual)
1335
+
1336
+
1337
+ ###############################################################################
1338
+ # Joint fitting
1339
+ # This is multi-model, multi-data, multi-simulation fitting
1340
+ # The models share some parameters here, everything else can be changed though
1341
+ ###############################################################################
1342
+
1343
+
1344
+ def _execute(inp: tuple[dict[str, float], ResidualProtocol, _Settings]) -> float:
1345
+ updates, residual_fn, settings = inp
1346
+ return residual_fn(updates, settings)
1347
+
1348
+
1349
+ def _mixed_sum_of_residuals(
1350
+ updates: dict[str, float],
1351
+ fits: list[_Settings],
1352
+ pool: pebble.ProcessPool,
1353
+ ) -> float:
1354
+ future = pool.map(_execute, [(updates, i.residual_fn, i) for i in fits])
1355
+ error = 0.0
1356
+ it = future.result()
1357
+ while True:
1358
+ try:
1359
+ error += next(it)
1360
+ except StopIteration:
1361
+ break
1362
+ except TimeoutError:
1363
+ return np.inf
1364
+ return error
1365
+
1366
+
1367
+ def joint_mixed(
1368
+ to_fit: list[MixedSettings],
1369
+ *,
1370
+ p0: dict[str, float],
1371
+ minimizer: Minimizer,
1372
+ y0: dict[str, float] | None = None,
1373
+ integrator: IntegratorType | None = None,
1374
+ loss_fn: LossFn = rmse,
1375
+ bounds: Bounds | None = None,
1376
+ max_workers: int | None = None,
1377
+ as_deepcopy: bool = True,
1378
+ ) -> JointFitResult | None:
1379
+ """Multi-model, multi-data, multi-simulation fitting."""
1380
+ full_settings = []
1381
+ for i in to_fit:
1382
+ p_names = i.model.get_parameter_names()
1383
+ v_names = i.model.get_variable_names()
1384
+ full_settings.append(
1385
+ _Settings(
1386
+ model=deepcopy(i.model) if as_deepcopy else i.model,
1387
+ data=i.data,
1388
+ y0=i.y0 if i.y0 is not None else y0,
1389
+ integrator=i.integrator if i.integrator is not None else integrator,
1390
+ loss_fn=i.loss_fn if i.loss_fn is not None else loss_fn,
1391
+ p_names=[j for j in p0 if j in p_names],
1392
+ v_names=[j for j in p0 if j in v_names],
1393
+ residual_fn=i.residual_fn,
1394
+ )
1395
+ )
1396
+
1397
+ with pebble.ProcessPool(
1398
+ max_workers=(
1399
+ multiprocessing.cpu_count() if max_workers is None else max_workers
1400
+ )
1401
+ ) as pool:
1402
+ min_result = minimizer(
1403
+ partial(
1404
+ _mixed_sum_of_residuals,
1405
+ fits=full_settings,
1406
+ pool=pool,
1407
+ ),
1408
+ p0,
1409
+ {} if bounds is None else bounds,
1410
+ )
1411
+ if min_result is None:
1412
+ return None
1413
+
1414
+ return JointFitResult(min_result.parameters, loss=min_result.residual)