mxlpy 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/fit.py ADDED
@@ -0,0 +1,1414 @@
1
+ """Fitting routines.
2
+
3
+ Single model, single data routines
4
+ ----------------------------------
5
+ - `steady_state`
6
+ - `time_course`
7
+ - `protocol_time_course`
8
+
9
+ Multiple model, single data routines
10
+ ------------------------------------
11
+ - `ensemble_steady_state`
12
+ - `ensemble_time_course`
13
+ - `ensemble_protocol_time_course`
14
+
15
+ A carousel is a special case of an ensemble, where the general
16
+ structure (e.g. stoichiometries) is the same, while the reaction kinetics
17
+ can vary
18
+ - `carousel_steady_state`
19
+ - `carousel_time_course`
20
+ - `carousel_protocol_time_course`
21
+
22
+ Multiple model, multiple data
23
+ -----------------------------
24
+ - `joint_steady_state`
25
+ - `joint_time_course`
26
+ - `joint_protocol_time_course`
27
+
28
+ Multiple model, multiple data, multiple methods
29
+ -----------------------------------------------
30
+ Here we also allow running different methods (e.g. steady-state vs time courses)
31
+ for each combination of model:data.
32
+
33
+ - `joint_mixed`
34
+
35
+ Minimizers
36
+ ----------
37
+ - LocalScipyMinimizer, including common methods such as Nelder-Mead or L-BFGS-B
38
+ - GlobalScipyMinimizer, including common methods such as basin hopping or dual annealing
39
+
40
+ Loss functions
41
+ --------------
42
+ - rmse, mean, mean_squared, mae, mean_absolute_percentage, mean_squared_logarithmic, cosine_similarity
43
+
44
+ """
45
+
46
+ from __future__ import annotations
47
+
48
+ import logging
49
+ import multiprocessing
50
+ from collections.abc import Callable
51
+ from copy import deepcopy
52
+ from dataclasses import dataclass
53
+ from functools import partial
54
+ from typing import TYPE_CHECKING, Literal, Protocol
55
+
56
+ import numpy as np
57
+ import pandas as pd
58
+ import pebble
59
+ from scipy.optimize import (
60
+ basinhopping,
61
+ differential_evolution,
62
+ direct,
63
+ dual_annealing,
64
+ minimize,
65
+ shgo,
66
+ )
67
+ from wadler_lindig import pformat
68
+
69
+ from mxlpy import parallel
70
+ from mxlpy.model import Model
71
+ from mxlpy.simulator import Simulator
72
+ from mxlpy.types import Array, IntegratorType, cast
73
+
74
+ if TYPE_CHECKING:
75
+ import pandas as pd
76
+ from scipy.optimize._optimize import OptimizeResult
77
+
78
+ from mxlpy.carousel import Carousel
79
+ from mxlpy.model import Model
80
+
81
# Module-level logger, named after this module.
LOGGER = logging.getLogger(name=__name__)
82
+
83
+
84
# Public API of the fitting module (alphabetical, upper-case names first).
__all__ = [
    # result / settings types and aliases
    "Bounds", "EnsembleFitResult", "FitResult", "FitSettings",
    "GlobalScipyMinimizer", "InitialGuess", "JointFitResult", "LOGGER",
    "LocalScipyMinimizer", "LossFn", "MinResult", "Minimizer",
    "MixedSettings", "ResFn", "ResidualFn", "ResidualProtocol",
    # fitting routines and loss / residual functions
    "carousel_protocol_time_course", "carousel_steady_state",
    "carousel_time_course", "cosine_similarity",
    "ensemble_protocol_time_course", "ensemble_steady_state",
    "ensemble_time_course", "joint_mixed", "joint_protocol_time_course",
    "joint_steady_state", "joint_time_course", "mae", "mean",
    "mean_absolute_percentage", "mean_squared", "mean_squared_logarithmic",
    "protocol_time_course", "protocol_time_course_residual", "rmse",
    "steady_state", "steady_state_residual", "time_course",
    "time_course_residual",
]  # fmt: skip
125
+
126
+ type InitialGuess = dict[str, float]
127
+
128
+ type Bounds = dict[str, tuple[float | None, float | None]]
129
+ type ResFn = Callable[[Array], float]
130
+ type LossFn = Callable[
131
+ [
132
+ pd.DataFrame | pd.Series,
133
+ pd.DataFrame | pd.Series,
134
+ ],
135
+ float,
136
+ ]
137
+
138
+
139
+ type Minimizer = Callable[
140
+ [
141
+ ResidualFn,
142
+ InitialGuess,
143
+ Bounds,
144
+ ],
145
+ MinResult | None,
146
+ ]
147
+
148
+
149
class ResidualProtocol(Protocol):
    """Protocol for user-facing residual functions.

    This is the user-facing variant, accepted as `residual_fn` by e.g.
    - `fit.steady_state`
    - `fit.time_course`
    - `fit.protocol_time_course`

    The settings are later partially applied to yield a `ResidualFn`.
    """

    def __call__(
        self,
        updates: dict[str, float],
        settings: _Settings,
    ) -> float:
        """Calculate the residual error between model output and experimental data.

        Args:
            updates: Candidate values as {name: value}
            settings: Model, data and options of the fit

        Returns:
            Scalar residual to be minimised.

        """
        ...
167
+
168
+
169
class ResidualFn(Protocol):
    """Protocol for internal residual functions.

    This is the internal version, produced by partially applying `settings`
    to a `ResidualProtocol`. Minimizers only ever see this signature.
    """

    def __call__(
        self,
        updates: dict[str, float],
    ) -> float:
        """Calculate the residual error between model output and experimental data.

        Args:
            updates: Candidate values as {name: value}

        Returns:
            Scalar residual to be minimised.

        """
        ...
182
+
183
+
184
@dataclass
class FitSettings:
    """Settings describing one model/data combination of a fit.

    Only `model` and `data` are required; all other options default to None.
    """

    model: Model  # model whose parameters / variables are fitted
    data: pd.Series | pd.DataFrame  # experimental data to fit against
    y0: dict[str, float] | None = None  # optional initial conditions
    integrator: IntegratorType | None = None  # optional ODE integrator
    loss_fn: LossFn | None = None  # optional loss function
    protocol: pd.DataFrame | None = None  # optional experimental protocol
194
+
195
+
196
@dataclass
class MixedSettings:
    """Settings for a fit that also carries its own residual function.

    Unlike `FitSettings`, `residual_fn` is required here. This lets each
    model/data combination choose its own residual (e.g. steady state vs.
    time course).
    """

    model: Model  # model whose parameters / variables are fitted
    data: pd.Series | pd.DataFrame  # experimental data to fit against
    residual_fn: ResidualFn  # residual used for this combination
    y0: dict[str, float] | None = None  # optional initial conditions
    integrator: IntegratorType | None = None  # optional ODE integrator
    loss_fn: LossFn | None = None  # optional loss function
    protocol: pd.DataFrame | None = None  # optional experimental protocol
207
+
208
+
209
@dataclass
class _Settings:
    """Internal, fully resolved counterpart of `FitSettings`.

    Residual functions receive this object. `p_names` / `v_names` list which
    entries of the update dict are model parameters and which are model
    variables.
    """

    model: Model  # model that is updated in place during fitting
    data: pd.Series | pd.DataFrame  # experimental data to fit against
    y0: dict[str, float] | None  # initial conditions applied before each run
    integrator: IntegratorType | None  # ODE integrator, None for the default
    loss_fn: LossFn  # loss between model prediction and data
    p_names: list[str]  # fitted names that are model parameters
    v_names: list[str]  # fitted names that are model variables
    protocol: pd.DataFrame | None = None  # only needed for protocol fits
    residual_fn: ResidualFn | None = None  # only set for mixed fits
222
+
223
+
224
@dataclass
class MinResult:
    """Outcome of a single minimization: best values and their residual."""

    parameters: dict[str, float]  # best values found, as {name: value}
    residual: float  # residual at `parameters`

    def __repr__(self) -> str:
        """Return a pretty-printed representation."""
        return pformat(self)
234
+
235
+
236
@dataclass
class FitResult:
    """Outcome of fitting a single model."""

    model: Model  # the fitted model
    best_pars: dict[str, float]  # best values found, as {name: value}
    loss: float  # residual at `best_pars`

    def __repr__(self) -> str:
        """Return a pretty-printed representation."""
        return pformat(self)
247
+
248
+
249
@dataclass
class EnsembleFitResult:
    """Result of an ensemble (or carousel) fit operation."""

    fits: list[FitResult]  # one entry per successfully fitted model

    def __repr__(self) -> str:
        """Return default representation."""
        return pformat(self)

    def get_best_fit(self) -> FitResult:
        """Get the fit with the lowest loss.

        Raises:
            ValueError: if there are no fits, e.g. because every fit of the
                ensemble failed.

        """
        if not self.fits:
            msg = "No successful fits available to choose the best fit from."
            raise ValueError(msg)
        return min(self.fits, key=lambda x: x.loss)
262
+
263
+
264
@dataclass
class JointFitResult:
    """Outcome of a joint (multi-model, multi-data) fit."""

    best_pars: dict[str, float]  # best values found, as {name: value}
    loss: float  # residual at `best_pars`

    def __repr__(self) -> str:
        """Return a pretty-printed representation."""
        return pformat(self)
274
+
275
+
276
+ ###############################################################################
277
+ # loss fns
278
+ ###############################################################################
279
+
280
+
281
+ def mean(
282
+ y_pred: pd.DataFrame | pd.Series,
283
+ y_true: pd.DataFrame | pd.Series,
284
+ ) -> float:
285
+ """Calculate root mean square error between model and data."""
286
+ return cast(float, np.mean(y_pred - y_true))
287
+
288
+
289
+ def mean_squared(
290
+ y_pred: pd.DataFrame | pd.Series,
291
+ y_true: pd.DataFrame | pd.Series,
292
+ ) -> float:
293
+ """Calculate mean square error between model and data."""
294
+ return cast(float, np.mean(np.square(y_pred - y_true)))
295
+
296
+
297
+ def rmse(
298
+ y_pred: pd.DataFrame | pd.Series,
299
+ y_true: pd.DataFrame | pd.Series,
300
+ ) -> float:
301
+ """Calculate root mean square error between model and data."""
302
+ return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
303
+
304
+
305
+ def mae(
306
+ y_pred: pd.DataFrame | pd.Series,
307
+ y_true: pd.DataFrame | pd.Series,
308
+ ) -> float:
309
+ """Calculate mean absolute error."""
310
+ return cast(float, np.mean(np.abs(y_true - y_pred)))
311
+
312
+
313
+ def mean_absolute_percentage(
314
+ y_pred: pd.DataFrame | pd.Series,
315
+ y_true: pd.DataFrame | pd.Series,
316
+ ) -> float:
317
+ """Calculate mean absolute error."""
318
+ return cast(float, 100 * np.mean(np.abs((y_true - y_pred) / y_pred)))
319
+
320
+
321
+ def mean_squared_logarithmic(
322
+ y_pred: pd.DataFrame | pd.Series,
323
+ y_true: pd.DataFrame | pd.Series,
324
+ ) -> float:
325
+ """Calculate root mean square error between model and data."""
326
+ return cast(float, np.mean(np.square(np.log(y_pred + 1) - np.log(y_true + 1))))
327
+
328
+
329
+ def cosine_similarity(
330
+ y_pred: pd.DataFrame | pd.Series,
331
+ y_true: pd.DataFrame | pd.Series,
332
+ ) -> float:
333
+ """Calculate root mean square error between model and data."""
334
+ norm = np.linalg.norm
335
+ return cast(float, -np.sum(norm(y_pred, 2) * norm(y_true, 2)))
336
+
337
+
338
+ ###############################################################################
339
+ # Minimizers
340
+ ###############################################################################
341
+
342
+
343
@dataclass
class LocalScipyMinimizer:
    """Local multivariate minimization using scipy.optimize.

    Names without an entry in `bounds` are bounded to (1e-6, 1e6).

    See Also
    --------
    https://docs.scipy.org/doc/scipy/reference/optimize.html#local-multivariate-optimization

    """

    tol: float = 1e-6  # forwarded to `scipy.optimize.minimize`
    method: Literal[
        "Nelder-Mead",
        "Powell",
        "CG",
        "BFGS",
        "Newton-CG",
        "L-BFGS-B",
        "TNC",
        "COBYLA",
        "COBYQA",
        "SLSQP",
        "trust-constr",
        "dogleg",
        "trust-ncg",
        "trust-exact",
        "trust-krylov",
    ] = "L-BFGS-B"

    def __call__(
        self,
        residual_fn: ResidualFn,
        p0: dict[str, float],
        bounds: Bounds,
    ) -> MinResult | None:
        """Call minimizer.

        Args:
            residual_fn: Maps {name: value} to a scalar residual
            p0: Initial guess; its key order fixes the order of the value vector
            bounds: (lower, upper) per name; missing names get (1e-6, 1e6)

        Returns:
            MinResult with the best values and their residual, or None if
            scipy reports failure (a warning is logged in that case).

        """
        par_names = list(p0.keys())

        res = minimize(
            lambda par_values: residual_fn(_pack_updates(par_values, par_names)),
            x0=list(p0.values()),
            bounds=[bounds.get(name, (1e-6, 1e6)) for name in p0],
            method=self.method,
            tol=self.tol,
        )
        if res.success:
            return MinResult(
                parameters=dict(zip(p0, res.x, strict=True)),
                # normalise numpy scalars to the annotated plain float
                residual=float(res.fun),
            )

        LOGGER.warning("Minimisation failed due to %s", res.message)
        return None
402
+
403
+
404
@dataclass
class GlobalScipyMinimizer:
    """Global minimization using scipy.optimize.

    For all methods except `basinhopping`, every name is searched within
    finite bounds. Missing entries of `bounds`, or `None` limits, fall back to
    1e-6 (lower) / 1e6 (upper), the same defaults `LocalScipyMinimizer` uses.

    See Also
    --------
    https://docs.scipy.org/doc/scipy/reference/optimize.html#global-optimization

    """

    # NOTE(review): `tol` is currently not forwarded to the scipy routines.
    tol: float = 1e-6
    method: Literal[
        "basinhopping",
        "differential_evolution",
        "shgo",
        "dual_annealing",
        "direct",
    ] = "basinhopping"

    def __call__(
        self,
        residual_fn: ResidualFn,
        p0: dict[str, float],
        bounds: Bounds,
    ) -> MinResult | None:
        """Minimize residual fn.

        Args:
            residual_fn: Maps {name: value} to a scalar residual
            p0: Initial guess; its key order fixes the order of the value vector
            bounds: (lower, upper) per name

        Returns:
            MinResult with the best values and their residual, or None if
            scipy reports failure (a warning is logged in that case).

        Raises:
            NotImplementedError: for an unknown `method`.

        """
        res: OptimizeResult
        par_names = list(p0.keys())
        res_fn: ResFn = lambda par_values: residual_fn(  # noqa: E731
            _pack_updates(par_values, par_names)
        )

        # scipy's global routines expect one finite (min, max) pair per
        # entry of the value vector, in `par_names` order. They cannot take
        # a {name: bounds} mapping, and passing one made them crash.
        scipy_bounds: list[tuple[float, float]] = []
        for name in par_names:
            lower, upper = bounds.get(name, (1e-6, 1e6))
            scipy_bounds.append(
                (
                    1e-6 if lower is None else lower,
                    1e6 if upper is None else upper,
                )
            )

        if self.method == "basinhopping":
            # NOTE(review): bounds are not enforced for basinhopping.
            res = basinhopping(
                res_fn,
                x0=list(p0.values()),
            )
        elif self.method == "differential_evolution":
            res = differential_evolution(res_fn, scipy_bounds)
        elif self.method == "shgo":
            res = shgo(res_fn, scipy_bounds)
        elif self.method == "dual_annealing":
            res = dual_annealing(res_fn, scipy_bounds)
        elif self.method == "direct":
            res = direct(res_fn, scipy_bounds)
        else:
            msg = f"Unknown method {self.method}"
            raise NotImplementedError(msg)
        if res.success:
            return MinResult(
                parameters=dict(zip(p0, res.x, strict=True)),
                # normalise numpy scalars to the annotated plain float
                residual=float(res.fun),
            )

        LOGGER.warning("Minimisation failed due to %s", res.message)
        return None
466
+
467
+
468
+ ###############################################################################
469
+ # Residual functions
470
+ ###############################################################################
471
+
472
+
473
+ def _pack_updates(
474
+ par_values: Array,
475
+ par_names: list[str],
476
+ ) -> dict[str, float]:
477
+ return dict(
478
+ zip(
479
+ par_names,
480
+ par_values,
481
+ strict=True,
482
+ )
483
+ )
484
+
485
+
486
def steady_state_residual(
    updates: dict[str, float],
    settings: _Settings,
) -> float:
    """Calculate residual error between model steady state and experimental data.

    The model in `settings` is modified in place. First `settings.y0` is
    applied (if given), then the values in `updates` for the names in
    `settings.p_names` (parameters) and `settings.v_names` (variables).
    The index of `settings.data` selects which simulated columns are
    compared. An infinite residual is returned when the simulation yields
    no result.
    """
    mdl = settings.model
    initial = settings.y0
    if initial is not None:
        mdl.update_variables(initial)
    for name in settings.p_names:
        mdl.update_parameter(name, updates[name])
    for name in settings.v_names:
        mdl.update_variable(name, updates[name])

    simulator = Simulator(mdl, integrator=settings.integrator)
    result = simulator.simulate_to_steady_state().get_result()
    if result is None:
        return cast(float, np.inf)

    observed = settings.data
    predicted = result.get_combined().loc[:, cast(list, observed.index)]
    return settings.loss_fn(predicted, observed)
514
+
515
+
516
def time_course_residual(
    updates: dict[str, float],
    settings: _Settings,
) -> float:
    """Calculate residual error between model time course and experimental data.

    The model in `settings` is modified in place: `settings.y0` first (if
    given), then the fitted parameters and variables from `updates`. The
    simulation is evaluated at the index of `settings.data`, and its columns
    select which simulated values are compared. An infinite residual is
    returned when the simulation yields no result.
    """
    mdl = settings.model
    initial = settings.y0
    if initial is not None:
        mdl.update_variables(initial)
    for name in settings.p_names:
        mdl.update_parameter(name, updates[name])
    for name in settings.v_names:
        mdl.update_variable(name, updates[name])

    observed = settings.data
    simulator = Simulator(mdl, integrator=settings.integrator)
    result = simulator.simulate_time_course(cast(list, observed.index)).get_result()
    if result is None:
        return cast(float, np.inf)

    combined = result.get_combined()
    predicted = combined.loc[:, cast(list, observed.columns)]
    return settings.loss_fn(predicted, observed)
545
+
546
+
547
def protocol_time_course_residual(
    updates: dict[str, float],
    settings: _Settings,
) -> float:
    """Calculate residual error between model protocol time course and experimental data.

    The model in `settings` is modified in place: `settings.y0` first (if
    given), then the fitted parameters and variables from `updates`.
    The protocol is simulated with the index of `settings.data` as time
    points. The columns of the data select which simulated values are
    compared.

    Returns:
        The loss, or `inf` when the simulation yields no result.

    Raises:
        ValueError: if `settings.protocol` is None.

    """
    # Validate before touching the model so that a misconfigured call does
    # not leave the model half-updated.
    if (protocol := settings.protocol) is None:
        msg = "protocol_time_course_residual requires `settings.protocol` to be set"
        raise ValueError(msg)

    model = settings.model
    if (y0 := settings.y0) is not None:
        model.update_variables(y0)
    for p in settings.p_names:
        model.update_parameter(p, updates[p])
    for p in settings.v_names:
        model.update_variable(p, updates[p])

    res = (
        Simulator(
            model,
            integrator=settings.integrator,
        )
        .simulate_protocol_time_course(
            protocol=protocol,
            time_points=settings.data.index,
        )
        .get_result()
    )
    if res is None:
        return cast(float, np.inf)
    results = res.get_combined()

    return settings.loss_fn(
        results.loc[:, cast(list, settings.data.columns)],
        settings.data,
    )
582
+
583
+
584
def steady_state(
    model: Model,
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> FitResult | None:
    """Fit model parameters to steady-state experimental data.

    Entries of `p0` may name model parameters or model variables. Both kinds
    are updated before each evaluation of the residual. The index of `data`
    selects which model outputs are compared.

    Examples:
        >>> fit = steady_state(model, p0={"k1": 0.1}, data=data, minimizer=LocalScipyMinimizer())

    Args:
        model: Model instance to fit
        p0: Initial guesses as {name: value}
        data: Experimental steady state data as pandas Series
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        FitResult with the model (set to the best values found), the best
        values and their loss. Returns None if the minimizer failed.

    """
    if as_deepcopy:
        model = deepcopy(model)

    p_names = model.get_parameter_names()
    v_names = model.get_variable_names()
    fit_p_names = [i for i in p0 if i in p_names]
    fit_v_names = [i for i in p0 if i in v_names]

    fn: ResidualFn = partial(
        residual_fn,
        settings=_Settings(
            model=model,
            data=data,
            y0=y0,
            integrator=integrator,
            loss_fn=loss_fn,
            p_names=fit_p_names,
            v_names=fit_v_names,
        ),
    )
    min_result = minimizer(fn, p0, {} if bounds is None else bounds)
    if min_result is None:
        return None

    # The residual function mutates `model` on every evaluation, so right now
    # it holds the *last* evaluated values, not necessarily the best ones.
    best = min_result.parameters
    for name in fit_p_names:
        model.update_parameter(name, best[name])
    for name in fit_v_names:
        model.update_variable(name, best[name])

    return FitResult(
        model=model,
        best_pars=best,
        loss=min_result.residual,
    )
649
+
650
+
651
def time_course(
    model: Model,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> FitResult | None:
    """Fit model parameters to time course of experimental data.

    Time points are taken from the index of `data`, and its columns select
    which model outputs are compared. Entries of `p0` may name model
    parameters or model variables.

    Examples:
        >>> fit = time_course(model, p0={"k1": 0.1}, data=data, minimizer=LocalScipyMinimizer())

    Args:
        model: Model instance to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        FitResult with the model (set to the best values found), the best
        values and their loss. Returns None if the minimizer failed.

    """
    if as_deepcopy:
        model = deepcopy(model)
    p_names = model.get_parameter_names()
    v_names = model.get_variable_names()
    fit_p_names = [i for i in p0 if i in p_names]
    fit_v_names = [i for i in p0 if i in v_names]

    fn: ResidualFn = partial(
        residual_fn,
        settings=_Settings(
            model=model,
            data=data,
            y0=y0,
            integrator=integrator,
            loss_fn=loss_fn,
            p_names=fit_p_names,
            v_names=fit_v_names,
        ),
    )

    min_result = minimizer(fn, p0, {} if bounds is None else bounds)
    if min_result is None:
        return None

    # The residual function mutates `model` on every evaluation, so right now
    # it holds the *last* evaluated values, not necessarily the best ones.
    best = min_result.parameters
    for name in fit_p_names:
        model.update_parameter(name, best[name])
    for name in fit_v_names:
        model.update_variable(name, best[name])

    return FitResult(
        model=model,
        best_pars=best,
        loss=min_result.residual,
    )
716
+
717
+
718
def protocol_time_course(
    model: Model,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    protocol: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> FitResult | None:
    """Fit model parameters to a protocol time course of experimental data.

    Time points of the protocol time course are taken from the index of
    `data`, and its columns select which model outputs are compared. Entries
    of `p0` may name model parameters or model variables.

    Examples:
        >>> fit = protocol_time_course(
        ...     model,
        ...     p0={"k1": 0.1},
        ...     data=data,
        ...     protocol=protocol,
        ...     minimizer=LocalScipyMinimizer(),
        ... )

    Args:
        model: Model instance to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        protocol: Experimental protocol
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy the model to avoid overwriting the state

    Returns:
        FitResult with the model (set to the best values found), the best
        values and their loss. Returns None if the minimizer failed.

    """
    if as_deepcopy:
        model = deepcopy(model)
    p_names = model.get_parameter_names()
    v_names = model.get_variable_names()
    fit_p_names = [i for i in p0 if i in p_names]
    fit_v_names = [i for i in p0 if i in v_names]

    fn: ResidualFn = partial(
        residual_fn,
        settings=_Settings(
            model=model,
            data=data,
            y0=y0,
            integrator=integrator,
            loss_fn=loss_fn,
            p_names=fit_p_names,
            v_names=fit_v_names,
            protocol=protocol,
        ),
    )

    min_result = minimizer(fn, p0, {} if bounds is None else bounds)
    if min_result is None:
        return None

    # The residual function mutates `model` on every evaluation, so right now
    # it holds the *last* evaluated values, not necessarily the best ones.
    best = min_result.parameters
    for name in fit_p_names:
        model.update_parameter(name, best[name])
    for name in fit_v_names:
        model.update_variable(name, best[name])

    return FitResult(
        model=model,
        best_pars=best,
        loss=min_result.residual,
    )
789
+
790
+
791
+ ###############################################################################
792
+ # Ensemble / carousel
793
+ # This is multi-model, single data fitting, where the models share parameters
794
+ ###############################################################################
795
+
796
+
797
def ensemble_steady_state(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model ensemble parameters to steady-state experimental data.

    Each model is fitted independently with `steady_state`, dispatched via
    `parallel.parallelise`. Fits that fail (return None) are dropped.

    Examples:
        >>> fits = ensemble_steady_state([model_a, model_b], p0=p0, data=data, minimizer=minimizer)

    Args:
        ensemble: Models to fit
        p0: Initial guesses as {name: value}
        data: Experimental steady state data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult with one FitResult per successful fit.

    """
    return EnsembleFitResult(
        [
            fit
            for i in parallel.parallelise(
                partial(
                    steady_state,
                    p0=p0,
                    data=data,
                    y0=y0,
                    integrator=integrator,
                    loss_fn=loss_fn,
                    minimizer=minimizer,
                    residual_fn=residual_fn,
                    bounds=bounds,
                    as_deepcopy=as_deepcopy,
                ),
                inputs=list(enumerate(ensemble)),
            )
            # second element of each returned pair is the fit (None = failed)
            if (fit := i[1]) is not None
        ]
    )
857
+
858
+
859
def carousel_steady_state(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.Series,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = steady_state_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model parameters to steady-state experimental data over a carousel.

    Thin wrapper around `ensemble_steady_state` applied to `carousel.variants`.

    Examples:
        >>> fits = carousel_steady_state(carousel, p0=p0, data=data, minimizer=minimizer)

    Args:
        carousel: Model carousel to fit
        p0: Initial guesses as {name: value}
        data: Experimental steady state data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult with one FitResult per successful fit.

    """
    return ensemble_steady_state(
        carousel.variants,
        p0=p0,
        data=data,
        minimizer=minimizer,
        y0=y0,
        residual_fn=residual_fn,
        integrator=integrator,
        loss_fn=loss_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
910
+
911
+
912
def ensemble_time_course(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model ensemble parameters to time course of experimental data.

    Time points are taken from the data. Each model is fitted independently
    with `time_course`, dispatched via `parallel.parallelise`. Fits that fail
    (return None) are dropped.

    Examples:
        >>> fits = ensemble_time_course([model_a, model_b], p0=p0, data=data, minimizer=minimizer)

    Args:
        ensemble: Models to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult with one FitResult per successful fit.

    """
    return EnsembleFitResult(
        [
            fit
            for i in parallel.parallelise(
                partial(
                    time_course,
                    p0=p0,
                    data=data,
                    y0=y0,
                    integrator=integrator,
                    loss_fn=loss_fn,
                    minimizer=minimizer,
                    residual_fn=residual_fn,
                    bounds=bounds,
                    as_deepcopy=as_deepcopy,
                ),
                inputs=list(enumerate(ensemble)),
            )
            # second element of each returned pair is the fit (None = failed)
            if (fit := i[1]) is not None
        ]
    )
974
+
975
+
976
def carousel_time_course(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model parameters to time course of experimental data over a carousel.

    Time points are taken from the data. Thin wrapper around
    `ensemble_time_course` applied to `carousel.variants`.

    Examples:
        >>> fits = carousel_time_course(carousel, p0=p0, data=data, minimizer=minimizer)

    Args:
        carousel: Model carousel to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult with one FitResult per successful fit.

    """
    return ensemble_time_course(
        carousel.variants,
        p0=p0,
        data=data,
        minimizer=minimizer,
        y0=y0,
        residual_fn=residual_fn,
        integrator=integrator,
        loss_fn=loss_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
1029
+
1030
+
1031
def ensemble_protocol_time_course(
    ensemble: list[Model],
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    protocol: pd.DataFrame,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model ensemble parameters to a protocol time course of experimental data.

    Time points of the protocol time course are taken from the data. Each
    model is fitted independently with `protocol_time_course`, dispatched via
    `parallel.parallelise`. Fits that fail (return None) are dropped.

    Examples:
        >>> fits = ensemble_protocol_time_course(
        ...     [model_a, model_b],
        ...     p0=p0,
        ...     data=data,
        ...     protocol=protocol,
        ...     minimizer=minimizer,
        ... )

    Args:
        ensemble: Models to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        protocol: Experimental protocol
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult with one FitResult per successful fit.

    """
    return EnsembleFitResult(
        [
            fit
            for i in parallel.parallelise(
                partial(
                    protocol_time_course,
                    p0=p0,
                    data=data,
                    protocol=protocol,
                    y0=y0,
                    integrator=integrator,
                    loss_fn=loss_fn,
                    minimizer=minimizer,
                    residual_fn=residual_fn,
                    bounds=bounds,
                    as_deepcopy=as_deepcopy,
                ),
                inputs=list(enumerate(ensemble)),
            )
            # second element of each returned pair is the fit (None = failed)
            if (fit := i[1]) is not None
        ]
    )
1095
+
1096
+
1097
def carousel_protocol_time_course(
    carousel: Carousel,
    *,
    p0: dict[str, float],
    data: pd.DataFrame,
    minimizer: Minimizer,
    protocol: pd.DataFrame,
    y0: dict[str, float] | None = None,
    residual_fn: ResidualProtocol = protocol_time_course_residual,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    as_deepcopy: bool = True,
) -> EnsembleFitResult:
    """Fit model parameters to a protocol time course over a carousel.

    Time points of the protocol time course are taken from the data. Thin
    wrapper around `ensemble_protocol_time_course` applied to
    `carousel.variants`.

    Examples:
        >>> fits = carousel_protocol_time_course(
        ...     carousel,
        ...     p0=p0,
        ...     data=data,
        ...     protocol=protocol,
        ...     minimizer=minimizer,
        ... )

    Args:
        carousel: Model carousel to fit
        p0: Initial guesses as {name: value}
        data: Experimental time course data
        minimizer: Function to minimize fitting error
        protocol: Experimental protocol
        y0: Initial conditions as {species_name: value}
        residual_fn: Function to calculate fitting error
        integrator: ODE integrator class
        loss_fn: Loss function to use for residual calculation
        bounds: Mapping of bounds per name; defaults are up to `minimizer`
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        EnsembleFitResult with one FitResult per successful fit.

    """
    return ensemble_protocol_time_course(
        carousel.variants,
        p0=p0,
        data=data,
        minimizer=minimizer,
        protocol=protocol,
        y0=y0,
        residual_fn=residual_fn,
        integrator=integrator,
        loss_fn=loss_fn,
        bounds=bounds,
        as_deepcopy=as_deepcopy,
    )
1152
+
1153
+
1154
+ ###############################################################################
1155
+ # Joint fitting
1156
+ # This is multi-model, multi-data fitting, where the models share some parameters
1157
+ ###############################################################################
1158
+
1159
+
1160
+ def _unpacked[T1, T2, Tout](inp: tuple[T1, T2], fn: Callable[[T1, T2], Tout]) -> Tout:
1161
+ return fn(*inp)
1162
+
1163
+
1164
def _sum_of_residuals(
    updates: dict[str, float],
    residual_fn: ResidualProtocol,
    fits: list[_Settings],
    pool: pebble.ProcessPool,
    timeout: float | None = None,
) -> float:
    """Evaluate ``residual_fn`` for every fit on ``pool`` and sum the results.

    This is the objective that the joint fitting routines hand to the
    minimizer. ``updates`` are the candidate values proposed by the minimizer.
    Each entry of ``fits`` is evaluated as ``residual_fn(updates, fit)`` on a
    worker of ``pool``.

    Args:
        updates: Candidate values as {name: value}
        residual_fn: Residual function applied to each fit
        fits: Per-model settings to evaluate
        pool: Process pool used to evaluate the residuals
        timeout: Optional timeout forwarded to ``pool.map``. ``None`` (the
            default) disables it; see pebble's ``ProcessPool.map`` for the
            exact semantics.

    Returns:
        The sum of all residuals (``0.0`` for an empty ``fits`` list), or
        ``np.inf`` if a task timed out.

    """
    future = pool.map(
        partial(_unpacked, fn=residual_fn),
        [(updates, fit) for fit in fits],
        timeout=timeout,
    )
    try:
        # Start at 0.0 so that an empty ``fits`` list still yields a float.
        return sum(future.result(), 0.0)
    except TimeoutError:
        # Abandon the remaining tasks so they do not keep occupying workers
        # during the next objective evaluation.
        future.cancel()
        return np.inf
1185
+
1186
+
1187
def joint_steady_state(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting of shared values to steady states.

    A single minimization is run over ``p0``. Its objective is the sum of
    `steady_state_residual` over all entries of ``to_fit``, evaluated in a
    process pool.

    Args:
        to_fit: Model/data pairs fitted jointly. Their ``y0``,
            ``integrator`` and ``loss_fn`` take precedence over the
            defaults below unless they are ``None``.
        p0: Initial guesses as {name: value}. For each model, keys naming
            one of its parameters become ``p_names`` and keys naming one of
            its variables become ``v_names`` of its settings. Other keys are
            not passed to that model.
        minimizer: Function to minimize the summed residual
        y0: Default initial conditions as {species_name: value}
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter (an empty mapping if ``None``)
        max_workers: Pool size, defaulting to ``multiprocessing.cpu_count()``
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The fitted values and loss, or ``None`` if the minimizer returned
        ``None``.

    """

    def settings_for(fit):
        param_names = fit.model.get_parameter_names()
        var_names = fit.model.get_variable_names()
        return _Settings(
            model=deepcopy(fit.model) if as_deepcopy else fit.model,
            data=fit.data,
            y0=y0 if fit.y0 is None else fit.y0,
            integrator=integrator if fit.integrator is None else fit.integrator,
            loss_fn=loss_fn if fit.loss_fn is None else fit.loss_fn,
            p_names=[name for name in p0 if name in param_names],
            v_names=[name for name in p0 if name in var_names],
        )

    full_settings = [settings_for(fit) for fit in to_fit]
    n_workers = multiprocessing.cpu_count() if max_workers is None else max_workers

    with pebble.ProcessPool(max_workers=n_workers) as pool:
        objective = partial(
            _sum_of_residuals,
            residual_fn=steady_state_residual,
            fits=full_settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, bounds if bounds is not None else {})

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
1235
+
1236
+
1237
def joint_time_course(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting of shared values to time courses.

    A single minimization is run over ``p0``. Its objective is the sum of
    `time_course_residual` over all entries of ``to_fit``, evaluated in a
    process pool.

    Args:
        to_fit: Model/data pairs fitted jointly. Their ``y0``,
            ``integrator`` and ``loss_fn`` take precedence over the
            defaults below unless they are ``None``.
        p0: Initial guesses as {name: value}. For each model, keys naming
            one of its parameters become ``p_names`` and keys naming one of
            its variables become ``v_names`` of its settings. Other keys are
            not passed to that model.
        minimizer: Function to minimize the summed residual
        y0: Default initial conditions as {species_name: value}
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter (an empty mapping if ``None``)
        max_workers: Pool size, defaulting to ``multiprocessing.cpu_count()``
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The fitted values and loss, or ``None`` if the minimizer returned
        ``None``.

    """
    full_settings = []
    for fit in to_fit:
        model = fit.model
        known_params = model.get_parameter_names()
        known_vars = model.get_variable_names()
        settings = _Settings(
            model=deepcopy(model) if as_deepcopy else model,
            data=fit.data,
            y0=y0 if fit.y0 is None else fit.y0,
            integrator=integrator if fit.integrator is None else fit.integrator,
            loss_fn=loss_fn if fit.loss_fn is None else fit.loss_fn,
            p_names=[key for key in p0 if key in known_params],
            v_names=[key for key in p0 if key in known_vars],
        )
        full_settings.append(settings)

    worker_count = max_workers if max_workers is not None else multiprocessing.cpu_count()
    effective_bounds = {} if bounds is None else bounds

    with pebble.ProcessPool(max_workers=worker_count) as pool:
        min_result = minimizer(
            partial(
                _sum_of_residuals,
                residual_fn=time_course_residual,
                fits=full_settings,
                pool=pool,
            ),
            p0,
            effective_bounds,
        )

    return (
        None
        if min_result is None
        else JointFitResult(min_result.parameters, loss=min_result.residual)
    )
1285
+
1286
+
1287
def joint_protocol_time_course(
    to_fit: list[FitSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data fitting of shared values to protocol time courses.

    A single minimization is run over ``p0``. Its objective is the sum of
    `protocol_time_course_residual` over all entries of ``to_fit``,
    evaluated in a process pool.

    Args:
        to_fit: Model/data pairs fitted jointly. Their ``y0``,
            ``integrator`` and ``loss_fn`` take precedence over the
            defaults below unless they are ``None``.
        p0: Initial guesses as {name: value}. For each model, keys naming
            one of its parameters become ``p_names`` and keys naming one of
            its variables become ``v_names`` of its settings. Other keys are
            not passed to that model.
        minimizer: Function to minimize the summed residual
        y0: Default initial conditions as {species_name: value}
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter (an empty mapping if ``None``)
        max_workers: Pool size, defaulting to ``multiprocessing.cpu_count()``
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The fitted values and loss, or ``None`` if the minimizer returned
        ``None``.

    """

    def build(entry):
        model_params = entry.model.get_parameter_names()
        model_vars = entry.model.get_variable_names()
        return _Settings(
            model=deepcopy(entry.model) if as_deepcopy else entry.model,
            data=entry.data,
            y0=y0 if entry.y0 is None else entry.y0,
            integrator=integrator if entry.integrator is None else entry.integrator,
            loss_fn=loss_fn if entry.loss_fn is None else entry.loss_fn,
            p_names=[k for k in p0 if k in model_params],
            v_names=[k for k in p0 if k in model_vars],
        )

    prepared = [build(entry) for entry in to_fit]
    if max_workers is None:
        max_workers = multiprocessing.cpu_count()

    with pebble.ProcessPool(max_workers=max_workers) as pool:
        objective = partial(
            _sum_of_residuals,
            residual_fn=protocol_time_course_residual,
            fits=prepared,
            pool=pool,
        )
        min_result = minimizer(objective, p0, bounds if bounds is not None else {})

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)
1335
+
1336
+
1337
+ ###############################################################################
1338
+ # Joint fitting
1339
+ # This is multi-model, multi-data, multi-simulation fitting
1340
+ # The models share some parameters here, everything else can be changed though
1341
+ ###############################################################################
1342
+
1343
+
1344
+ def _execute(inp: tuple[dict[str, float], ResidualProtocol, _Settings]) -> float:
1345
+ updates, residual_fn, settings = inp
1346
+ return residual_fn(updates, settings)
1347
+
1348
+
1349
def _mixed_sum_of_residuals(
    updates: dict[str, float],
    fits: list[_Settings],
    pool: pebble.ProcessPool,
    timeout: float | None = None,
) -> float:
    """Evaluate each fit's own residual function on ``pool`` and sum the results.

    This is the objective used by `joint_mixed`. Each entry of ``fits``
    carries its own ``residual_fn`` and is evaluated as
    ``fit.residual_fn(updates, fit)`` on a worker of ``pool``.

    Args:
        updates: Candidate values as {name: value} proposed by the minimizer
        fits: Per-model settings, each with a ``residual_fn``
        pool: Process pool used to evaluate the residuals
        timeout: Optional timeout forwarded to ``pool.map``. ``None`` (the
            default) disables it; see pebble's ``ProcessPool.map`` for the
            exact semantics.

    Returns:
        The sum of all residuals (``0.0`` for an empty ``fits`` list), or
        ``np.inf`` if a task timed out.

    """
    future = pool.map(
        _execute,
        [(updates, fit.residual_fn, fit) for fit in fits],
        timeout=timeout,
    )
    try:
        # Start at 0.0 so that an empty ``fits`` list still yields a float.
        return sum(future.result(), 0.0)
    except TimeoutError:
        # Abandon the remaining tasks so they do not keep occupying workers
        # during the next objective evaluation.
        future.cancel()
        return np.inf
1365
+
1366
+
1367
def joint_mixed(
    to_fit: list[MixedSettings],
    *,
    p0: dict[str, float],
    minimizer: Minimizer,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
    loss_fn: LossFn = rmse,
    bounds: Bounds | None = None,
    max_workers: int | None = None,
    as_deepcopy: bool = True,
) -> JointFitResult | None:
    """Multi-model, multi-data, multi-simulation fitting of shared values.

    This works like the other joint routines, but every entry of ``to_fit``
    brings its own ``residual_fn``. Steady-state, time-course and protocol
    fits can therefore be combined in one minimization. The objective is the
    sum of all per-entry residuals, evaluated in a process pool.

    Args:
        to_fit: Model/data/residual-function triples fitted jointly. Their
            ``y0``, ``integrator`` and ``loss_fn`` take precedence over the
            defaults below unless they are ``None``.
        p0: Initial guesses as {name: value}. For each model, keys naming
            one of its parameters become ``p_names`` and keys naming one of
            its variables become ``v_names`` of its settings. Other keys are
            not passed to that model.
        minimizer: Function to minimize the summed residual
        y0: Default initial conditions as {species_name: value}
        integrator: Default ODE integrator class
        loss_fn: Default loss function
        bounds: Mapping of bounds per parameter (an empty mapping if ``None``)
        max_workers: Pool size, defaulting to ``multiprocessing.cpu_count()``
        as_deepcopy: Whether to copy each model to avoid overwriting its state

    Returns:
        The fitted values and loss, or ``None`` if the minimizer returned
        ``None``.

    """

    def settings_for(fit):
        param_names = fit.model.get_parameter_names()
        var_names = fit.model.get_variable_names()
        return _Settings(
            model=deepcopy(fit.model) if as_deepcopy else fit.model,
            data=fit.data,
            y0=y0 if fit.y0 is None else fit.y0,
            integrator=integrator if fit.integrator is None else fit.integrator,
            loss_fn=loss_fn if fit.loss_fn is None else fit.loss_fn,
            p_names=[name for name in p0 if name in param_names],
            v_names=[name for name in p0 if name in var_names],
            residual_fn=fit.residual_fn,
        )

    full_settings = [settings_for(fit) for fit in to_fit]
    n_workers = multiprocessing.cpu_count() if max_workers is None else max_workers

    with pebble.ProcessPool(max_workers=n_workers) as pool:
        objective = partial(
            _mixed_sum_of_residuals,
            fits=full_settings,
            pool=pool,
        )
        min_result = minimizer(objective, p0, bounds if bounds is not None else {})

    if min_result is None:
        return None
    return JointFitResult(min_result.parameters, loss=min_result.residual)