mxlpy 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/fit.py CHANGED
@@ -10,44 +10,80 @@ Functions:
10
10
 
11
11
  from __future__ import annotations
12
12
 
13
+ import logging
14
+ from copy import deepcopy
15
+ from dataclasses import dataclass
13
16
  from functools import partial
14
17
  from typing import TYPE_CHECKING, Protocol
15
18
 
16
19
  import numpy as np
17
20
  from scipy.optimize import minimize
18
21
 
19
- from mxlpy.integrators import DefaultIntegrator
22
+ from mxlpy import parallel
20
23
  from mxlpy.simulator import Simulator
21
- from mxlpy.types import (
22
- Array,
23
- ArrayLike,
24
- Callable,
25
- IntegratorType,
26
- cast,
27
- )
24
+ from mxlpy.types import Array, ArrayLike, Callable, IntegratorType, cast
25
+
26
+ if TYPE_CHECKING:
27
+ import pandas as pd
28
+
29
+ from mxlpy.carousel import Carousel
30
+ from mxlpy.model import Model
31
+
32
+ LOGGER = logging.getLogger(__name__)
28
33
 
29
34
  __all__ = [
35
+ "CarouselFit",
36
+ "FitResult",
30
37
  "InitialGuess",
38
+ "LOGGER",
31
39
  "LossFn",
40
+ "MinResult",
32
41
  "MinimizeFn",
33
42
  "ProtocolResidualFn",
34
43
  "ResidualFn",
35
44
  "SteadyStateResidualFn",
36
45
  "TimeSeriesResidualFn",
46
+ "carousel_steady_state",
47
+ "carousel_time_course",
48
+ "carousel_time_course_over_protocol",
37
49
  "rmse",
38
50
  "steady_state",
39
51
  "time_course",
40
52
  "time_course_over_protocol",
41
53
  ]
42
54
 
43
- if TYPE_CHECKING:
44
- import pandas as pd
45
55
 
46
- from mxlpy.model import Model
56
+ @dataclass
57
+ class MinResult:
58
+ """Result of a minimization operation."""
59
+
60
+ parameters: dict[str, float]
61
+ residual: float
62
+
63
+
64
+ @dataclass
65
+ class FitResult:
66
+ """Result of a fit operation."""
67
+
68
+ model: Model
69
+ best_pars: dict[str, float]
70
+ loss: float
71
+
72
+
73
+ @dataclass
74
+ class CarouselFit:
75
+ """Result of a carousel fit operation."""
76
+
77
+ fits: list[FitResult]
78
+
79
+ def get_best_fit(self) -> FitResult:
80
+ """Get the best fit from the carousel."""
81
+ return min(self.fits, key=lambda x: x.loss)
82
+
47
83
 
48
84
  type InitialGuess = dict[str, float]
49
85
  type ResidualFn = Callable[[Array], float]
50
- type MinimizeFn = Callable[[ResidualFn, InitialGuess], dict[str, float]]
86
+ type MinimizeFn = Callable[[ResidualFn, InitialGuess], MinResult | None]
51
87
  type LossFn = Callable[
52
88
  [
53
89
  pd.DataFrame | pd.Series,
@@ -75,7 +111,7 @@ class SteadyStateResidualFn(Protocol):
75
111
  par_names: list[str],
76
112
  data: pd.Series,
77
113
  model: Model,
78
- y0: dict[str, float],
114
+ y0: dict[str, float] | None,
79
115
  integrator: IntegratorType,
80
116
  loss_fn: LossFn,
81
117
  ) -> float:
@@ -93,7 +129,7 @@ class TimeSeriesResidualFn(Protocol):
93
129
  par_names: list[str],
94
130
  data: pd.DataFrame,
95
131
  model: Model,
96
- y0: dict[str, float],
132
+ y0: dict[str, float] | None,
97
133
  integrator: IntegratorType,
98
134
  loss_fn: LossFn,
99
135
  ) -> float:
@@ -111,7 +147,7 @@ class ProtocolResidualFn(Protocol):
111
147
  par_names: list[str],
112
148
  data: pd.DataFrame,
113
149
  model: Model,
114
- y0: dict[str, float],
150
+ y0: dict[str, float] | None,
115
151
  integrator: IntegratorType,
116
152
  loss_fn: LossFn,
117
153
  protocol: pd.DataFrame,
@@ -124,22 +160,27 @@ class ProtocolResidualFn(Protocol):
124
160
  def _default_minimize_fn(
125
161
  residual_fn: ResidualFn,
126
162
  p0: dict[str, float],
127
- ) -> dict[str, float]:
163
+ ) -> MinResult | None:
128
164
  res = minimize(
129
165
  residual_fn,
130
166
  x0=list(p0.values()),
131
- bounds=[(1e-12, 1e6) for _ in range(len(p0))],
167
+ bounds=[(0, None) for _ in range(len(p0))],
132
168
  method="L-BFGS-B",
133
169
  )
134
170
  if res.success:
135
- return dict(
136
- zip(
137
- p0,
138
- res.x,
139
- strict=True,
140
- )
171
+ return MinResult(
172
+ parameters=dict(
173
+ zip(
174
+ p0,
175
+ res.x,
176
+ strict=True,
177
+ ),
178
+ ),
179
+ residual=res.fun,
141
180
  )
142
- return dict(zip(p0, np.full(len(p0), np.nan, dtype=float), strict=True))
181
+
182
+ LOGGER.warning("Minimisation failed.")
183
+ return None
143
184
 
144
185
 
145
186
  def _steady_state_residual(
@@ -290,14 +331,15 @@ def _protocol_residual(
290
331
 
291
332
  def steady_state(
292
333
  model: Model,
334
+ *,
293
335
  p0: dict[str, float],
294
336
  data: pd.Series,
295
337
  y0: dict[str, float] | None = None,
296
338
  minimize_fn: MinimizeFn = _default_minimize_fn,
297
339
  residual_fn: SteadyStateResidualFn = _steady_state_residual,
298
- integrator: IntegratorType = DefaultIntegrator,
340
+ integrator: IntegratorType | None = None,
299
341
  loss_fn: LossFn = rmse,
300
- ) -> dict[str, float]:
342
+ ) -> FitResult | None:
301
343
  """Fit model parameters to steady-state experimental data.
302
344
 
303
345
  Examples:
@@ -338,23 +380,30 @@ def steady_state(
338
380
  loss_fn=loss_fn,
339
381
  ),
340
382
  )
341
- res = minimize_fn(fn, p0)
342
-
343
- # Restore
383
+ min_result = minimize_fn(fn, p0)
384
+ # Restore original model
344
385
  model.update_parameters(p_orig)
345
- return res
386
+ if min_result is None:
387
+ return min_result
388
+
389
+ return FitResult(
390
+ model=deepcopy(model).update_parameters(min_result.parameters),
391
+ best_pars=min_result.parameters,
392
+ loss=min_result.residual,
393
+ )
346
394
 
347
395
 
348
396
  def time_course(
349
397
  model: Model,
398
+ *,
350
399
  p0: dict[str, float],
351
400
  data: pd.DataFrame,
352
401
  y0: dict[str, float] | None = None,
353
402
  minimize_fn: MinimizeFn = _default_minimize_fn,
354
403
  residual_fn: TimeSeriesResidualFn = _time_course_residual,
355
- integrator: IntegratorType = DefaultIntegrator,
404
+ integrator: IntegratorType | None = None,
356
405
  loss_fn: LossFn = rmse,
357
- ) -> dict[str, float]:
406
+ ) -> FitResult | None:
358
407
  """Fit model parameters to time course of experimental data.
359
408
 
360
409
  Examples:
@@ -393,23 +442,33 @@ def time_course(
393
442
  loss_fn=loss_fn,
394
443
  ),
395
444
  )
396
- res = minimize_fn(fn, p0)
445
+
446
+ min_result = minimize_fn(fn, p0)
447
+ # Restore original model
397
448
  model.update_parameters(p_orig)
398
- return res
449
+ if min_result is None:
450
+ return min_result
451
+
452
+ return FitResult(
453
+ model=deepcopy(model).update_parameters(min_result.parameters),
454
+ best_pars=min_result.parameters,
455
+ loss=min_result.residual,
456
+ )
399
457
 
400
458
 
401
459
  def time_course_over_protocol(
402
460
  model: Model,
461
+ *,
403
462
  p0: dict[str, float],
404
463
  data: pd.DataFrame,
405
464
  protocol: pd.DataFrame,
406
465
  y0: dict[str, float] | None = None,
407
466
  minimize_fn: MinimizeFn = _default_minimize_fn,
408
467
  residual_fn: ProtocolResidualFn = _protocol_residual,
409
- integrator: IntegratorType = DefaultIntegrator,
468
+ integrator: IntegratorType | None = None,
410
469
  loss_fn: LossFn = rmse,
411
470
  time_points_per_step: int = 10,
412
- ) -> dict[str, float]:
471
+ ) -> FitResult | None:
413
472
  """Fit model parameters to time course of experimental data.
414
473
 
415
474
  Examples:
@@ -452,6 +511,188 @@ def time_course_over_protocol(
452
511
  time_points_per_step=time_points_per_step,
453
512
  ),
454
513
  )
455
- res = minimize_fn(fn, p0)
514
+
515
+ min_result = minimize_fn(fn, p0)
516
+ # Restore original model
456
517
  model.update_parameters(p_orig)
457
- return res
518
+ if min_result is None:
519
+ return min_result
520
+
521
+ return FitResult(
522
+ model=deepcopy(model).update_parameters(min_result.parameters),
523
+ best_pars=min_result.parameters,
524
+ loss=min_result.residual,
525
+ )
526
+
527
+
528
+ def _carousel_steady_state_worker(
529
+ model: Model,
530
+ p0: dict[str, float],
531
+ data: pd.Series,
532
+ y0: dict[str, float] | None,
533
+ integrator: IntegratorType | None,
534
+ loss_fn: LossFn,
535
+ minimize_fn: MinimizeFn,
536
+ residual_fn: SteadyStateResidualFn,
537
+ ) -> FitResult | None:
538
+ model_pars = model.parameters
539
+
540
+ return steady_state(
541
+ model,
542
+ p0={k: v for k, v in p0.items() if k in model_pars},
543
+ y0=y0,
544
+ data=data,
545
+ minimize_fn=minimize_fn,
546
+ residual_fn=residual_fn,
547
+ integrator=integrator,
548
+ loss_fn=loss_fn,
549
+ )
550
+
551
+
552
+ def _carousel_time_course_worker(
553
+ model: Model,
554
+ p0: dict[str, float],
555
+ data: pd.DataFrame,
556
+ y0: dict[str, float] | None,
557
+ integrator: IntegratorType | None,
558
+ loss_fn: LossFn,
559
+ minimize_fn: MinimizeFn,
560
+ residual_fn: TimeSeriesResidualFn,
561
+ ) -> FitResult | None:
562
+ model_pars = model.parameters
563
+ return time_course(
564
+ model,
565
+ p0={k: v for k, v in p0.items() if k in model_pars},
566
+ y0=y0,
567
+ data=data,
568
+ minimize_fn=minimize_fn,
569
+ residual_fn=residual_fn,
570
+ integrator=integrator,
571
+ loss_fn=loss_fn,
572
+ )
573
+
574
+
575
+ def _carousel_protocol_worker(
576
+ model: Model,
577
+ p0: dict[str, float],
578
+ data: pd.DataFrame,
579
+ protocol: pd.DataFrame,
580
+ y0: dict[str, float] | None,
581
+ integrator: IntegratorType | None,
582
+ loss_fn: LossFn,
583
+ minimize_fn: MinimizeFn,
584
+ residual_fn: ProtocolResidualFn,
585
+ ) -> FitResult | None:
586
+ model_pars = model.parameters
587
+ return time_course_over_protocol(
588
+ model,
589
+ p0={k: v for k, v in p0.items() if k in model_pars},
590
+ y0=y0,
591
+ protocol=protocol,
592
+ data=data,
593
+ minimize_fn=minimize_fn,
594
+ residual_fn=residual_fn,
595
+ integrator=integrator,
596
+ loss_fn=loss_fn,
597
+ )
598
+
599
+
600
+ def carousel_steady_state(
601
+ carousel: Carousel,
602
+ *,
603
+ p0: dict[str, float],
604
+ data: pd.Series,
605
+ y0: dict[str, float] | None = None,
606
+ minimize_fn: MinimizeFn = _default_minimize_fn,
607
+ residual_fn: SteadyStateResidualFn = _steady_state_residual,
608
+ integrator: IntegratorType | None = None,
609
+ loss_fn: LossFn = rmse,
610
+ ) -> CarouselFit:
611
+ """Fit model parameters to steady-state experimental data over a carousel."""
612
+ return CarouselFit(
613
+ [
614
+ fit
615
+ for i in parallel.parallelise(
616
+ partial(
617
+ _carousel_steady_state_worker,
618
+ p0=p0,
619
+ data=data,
620
+ y0=y0,
621
+ integrator=integrator,
622
+ loss_fn=loss_fn,
623
+ minimize_fn=minimize_fn,
624
+ residual_fn=residual_fn,
625
+ ),
626
+ inputs=list(enumerate(carousel.variants)),
627
+ )
628
+ if (fit := i[1]) is not None
629
+ ]
630
+ )
631
+
632
+
633
+ def carousel_time_course(
634
+ carousel: Carousel,
635
+ *,
636
+ p0: dict[str, float],
637
+ data: pd.DataFrame,
638
+ y0: dict[str, float] | None = None,
639
+ minimize_fn: MinimizeFn = _default_minimize_fn,
640
+ residual_fn: TimeSeriesResidualFn = _time_course_residual,
641
+ integrator: IntegratorType | None = None,
642
+ loss_fn: LossFn = rmse,
643
+ ) -> CarouselFit:
644
+ """Fit model parameters to time course of experimental data over a carousel."""
645
+ return CarouselFit(
646
+ [
647
+ fit
648
+ for i in parallel.parallelise(
649
+ partial(
650
+ _carousel_time_course_worker,
651
+ p0=p0,
652
+ data=data,
653
+ y0=y0,
654
+ integrator=integrator,
655
+ loss_fn=loss_fn,
656
+ minimize_fn=minimize_fn,
657
+ residual_fn=residual_fn,
658
+ ),
659
+ inputs=list(enumerate(carousel.variants)),
660
+ )
661
+ if (fit := i[1]) is not None
662
+ ]
663
+ )
664
+
665
+
666
+ def carousel_time_course_over_protocol(
667
+ carousel: Carousel,
668
+ *,
669
+ p0: dict[str, float],
670
+ data: pd.DataFrame,
671
+ protocol: pd.DataFrame,
672
+ y0: dict[str, float] | None = None,
673
+ minimize_fn: MinimizeFn = _default_minimize_fn,
674
+ residual_fn: ProtocolResidualFn = _protocol_residual,
675
+ integrator: IntegratorType | None = None,
676
+ loss_fn: LossFn = rmse,
677
+ ) -> CarouselFit:
678
+ """Fit model parameters to time course of experimental data over a protocol."""
679
+ return CarouselFit(
680
+ [
681
+ fit
682
+ for i in parallel.parallelise(
683
+ partial(
684
+ _carousel_protocol_worker,
685
+ p0=p0,
686
+ data=data,
687
+ protocol=protocol,
688
+ y0=y0,
689
+ integrator=integrator,
690
+ loss_fn=loss_fn,
691
+ minimize_fn=minimize_fn,
692
+ residual_fn=residual_fn,
693
+ ),
694
+ inputs=list(enumerate(carousel.variants)),
695
+ )
696
+ if (fit := i[1]) is not None
697
+ ]
698
+ )
mxlpy/fns.py CHANGED
@@ -2,11 +2,6 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING
6
-
7
- if TYPE_CHECKING:
8
- from mxlpy.types import Float
9
-
10
5
  __all__ = [
11
6
  "add",
12
7
  "constant",
@@ -36,7 +31,7 @@ __all__ = [
36
31
  ###############################################################################
37
32
 
38
33
 
39
- def constant(x: Float) -> Float:
34
+ def constant(x: float) -> float:
40
35
  """Return a constant value regardless of other model components.
41
36
 
42
37
  Parameters
@@ -58,7 +53,7 @@ def constant(x: Float) -> Float:
58
53
  return x
59
54
 
60
55
 
61
- def neg(x: Float) -> Float:
56
+ def neg(x: float) -> float:
62
57
  """Calculate the negation of a value.
63
58
 
64
59
  Parameters
@@ -82,7 +77,7 @@ def neg(x: Float) -> Float:
82
77
  return -x
83
78
 
84
79
 
85
- def minus(x: Float, y: Float) -> Float:
80
+ def minus(x: float, y: float) -> float:
86
81
  """Calculate the difference between two values.
87
82
 
88
83
  Parameters
@@ -108,7 +103,7 @@ def minus(x: Float, y: Float) -> Float:
108
103
  return x - y
109
104
 
110
105
 
111
- def mul(x: Float, y: Float) -> Float:
106
+ def mul(x: float, y: float) -> float:
112
107
  """Calculate the product of two values.
113
108
 
114
109
  Parameters
@@ -134,7 +129,7 @@ def mul(x: Float, y: Float) -> Float:
134
129
  return x * y
135
130
 
136
131
 
137
- def div(x: Float, y: Float) -> Float:
132
+ def div(x: float, y: float) -> float:
138
133
  """Calculate the quotient of two values.
139
134
 
140
135
  Parameters
@@ -160,7 +155,7 @@ def div(x: Float, y: Float) -> Float:
160
155
  return x / y
161
156
 
162
157
 
163
- def one_div(x: Float) -> Float:
158
+ def one_div(x: float) -> float:
164
159
  """Calculate the reciprocal of a value.
165
160
 
166
161
  Parameters
@@ -184,7 +179,7 @@ def one_div(x: Float) -> Float:
184
179
  return 1.0 / x
185
180
 
186
181
 
187
- def neg_div(x: Float, y: Float) -> Float:
182
+ def neg_div(x: float, y: float) -> float:
188
183
  """Calculate the negative quotient of two values.
189
184
 
190
185
  Parameters
@@ -210,7 +205,7 @@ def neg_div(x: Float, y: Float) -> Float:
210
205
  return -x / y
211
206
 
212
207
 
213
- def twice(x: Float) -> Float:
208
+ def twice(x: float) -> float:
214
209
  """Calculate twice the value.
215
210
 
216
211
  Parameters
@@ -234,7 +229,7 @@ def twice(x: Float) -> Float:
234
229
  return x * 2
235
230
 
236
231
 
237
- def add(x: Float, y: Float) -> Float:
232
+ def add(x: float, y: float) -> float:
238
233
  """Calculate the sum of two values.
239
234
 
240
235
  Parameters
@@ -260,7 +255,7 @@ def add(x: Float, y: Float) -> Float:
260
255
  return x + y
261
256
 
262
257
 
263
- def proportional(x: Float, y: Float) -> Float:
258
+ def proportional(x: float, y: float) -> float:
264
259
  """Calculate the product of two values.
265
260
 
266
261
  Common in mass-action kinetics where x represents a rate constant
@@ -295,9 +290,9 @@ def proportional(x: Float, y: Float) -> Float:
295
290
 
296
291
 
297
292
  def moiety_1s(
298
- x: Float,
299
- x_total: Float,
300
- ) -> Float:
293
+ x: float,
294
+ x_total: float,
295
+ ) -> float:
301
296
  """Calculate conservation relationship for one substrate.
302
297
 
303
298
  Used for creating derived variables that represent moiety conservation,
@@ -328,10 +323,10 @@ def moiety_1s(
328
323
 
329
324
 
330
325
  def moiety_2s(
331
- x1: Float,
332
- x2: Float,
333
- x_total: Float,
334
- ) -> Float:
326
+ x1: float,
327
+ x2: float,
328
+ x_total: float,
329
+ ) -> float:
335
330
  """Calculate conservation relationship for two substrates.
336
331
 
337
332
  Used for creating derived variables that represent moiety conservation
@@ -369,7 +364,7 @@ def moiety_2s(
369
364
  ###############################################################################
370
365
 
371
366
 
372
- def mass_action_1s(s1: Float, k: Float) -> Float:
367
+ def mass_action_1s(s1: float, k: float) -> float:
373
368
  """Calculate irreversible mass action reaction rate with one substrate.
374
369
 
375
370
  Rate = k * [S]
@@ -398,7 +393,7 @@ def mass_action_1s(s1: Float, k: Float) -> Float:
398
393
  return k * s1
399
394
 
400
395
 
401
- def mass_action_1s_1p(s1: Float, p1: Float, kf: Float, kr: Float) -> Float:
396
+ def mass_action_1s_1p(s1: float, p1: float, kf: float, kr: float) -> float:
402
397
  """Calculate reversible mass action reaction rate with one substrate and one product.
403
398
 
404
399
  Rate = kf * [S] - kr * [P]
@@ -432,7 +427,7 @@ def mass_action_1s_1p(s1: Float, p1: Float, kf: Float, kr: Float) -> Float:
432
427
  return kf * s1 - kr * p1
433
428
 
434
429
 
435
- def mass_action_2s(s1: Float, s2: Float, k: Float) -> Float:
430
+ def mass_action_2s(s1: float, s2: float, k: float) -> float:
436
431
  """Calculate irreversible mass action reaction rate with two substrates.
437
432
 
438
433
  Rate = k * [S1] * [S2]
@@ -463,7 +458,7 @@ def mass_action_2s(s1: Float, s2: Float, k: Float) -> Float:
463
458
  return k * s1 * s2
464
459
 
465
460
 
466
- def mass_action_2s_1p(s1: Float, s2: Float, p1: Float, kf: Float, kr: Float) -> Float:
461
+ def mass_action_2s_1p(s1: float, s2: float, p1: float, kf: float, kr: float) -> float:
467
462
  """Calculate reversible mass action reaction rate with two substrates and one product.
468
463
 
469
464
  Rate = kf * [S1] * [S2] - kr * [P]
@@ -505,7 +500,7 @@ def mass_action_2s_1p(s1: Float, s2: Float, p1: Float, kf: Float, kr: Float) ->
505
500
  ###############################################################################
506
501
 
507
502
 
508
- def michaelis_menten_1s(s1: Float, vmax: Float, km1: Float) -> Float:
503
+ def michaelis_menten_1s(s1: float, vmax: float, km1: float) -> float:
509
504
  """Calculate irreversible Michaelis-Menten reaction rate for one substrate.
510
505
 
511
506
  Rate = Vmax * [S] / (Km + [S])
@@ -549,12 +544,12 @@ def michaelis_menten_1s(s1: Float, vmax: Float, km1: Float) -> Float:
549
544
 
550
545
 
551
546
  def michaelis_menten_2s(
552
- s1: Float,
553
- s2: Float,
554
- vmax: Float,
555
- km1: Float,
556
- km2: Float,
557
- ) -> Float:
547
+ s1: float,
548
+ s2: float,
549
+ vmax: float,
550
+ km1: float,
551
+ km2: float,
552
+ ) -> float:
558
553
  """Calculate Michaelis-Menten reaction rate (ping-pong) for two substrates.
559
554
 
560
555
  Rate = Vmax * [S1] * [S2] / ([S1]*[S2] + km1*[S2] + km2*[S1])
@@ -594,14 +589,14 @@ def michaelis_menten_2s(
594
589
 
595
590
 
596
591
  def michaelis_menten_3s(
597
- s1: Float,
598
- s2: Float,
599
- s3: Float,
600
- vmax: Float,
601
- km1: Float,
602
- km2: Float,
603
- km3: Float,
604
- ) -> Float:
592
+ s1: float,
593
+ s2: float,
594
+ s3: float,
595
+ vmax: float,
596
+ km1: float,
597
+ km2: float,
598
+ km3: float,
599
+ ) -> float:
605
600
  """Calculate Michaelis-Menten reaction rate (ping-pong) for three substrates.
606
601
 
607
602
  Rate = Vmax * [S1] * [S2] * [S3] / ([S1]*[S2] + km1*[S2]*[S3] + km2*[S1]*[S3] + km3*[S1]*[S2])
@@ -649,7 +644,7 @@ def michaelis_menten_3s(
649
644
  ###############################################################################
650
645
 
651
646
 
652
- def diffusion_1s_1p(inside: Float, outside: Float, k: Float) -> Float:
647
+ def diffusion_1s_1p(inside: float, outside: float, k: float) -> float:
653
648
  """Calculate diffusion rate between two compartments.
654
649
 
655
650
  Rate = k * ([outside] - [inside])
mxlpy/identify.py CHANGED
@@ -1,6 +1,9 @@
1
1
  """Numerical parameter identification estimations."""
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  from functools import partial
6
+ from typing import TYPE_CHECKING
4
7
 
5
8
  import numpy as np
6
9
  import pandas as pd
@@ -8,11 +11,15 @@ from tqdm import tqdm
8
11
 
9
12
  from mxlpy import fit
10
13
  from mxlpy.distributions import LogNormal, sample
11
- from mxlpy.model import Model
12
14
  from mxlpy.parallel import parallelise
13
- from mxlpy.types import Array
14
15
 
15
- __all__ = ["profile_likelihood"]
16
+ if TYPE_CHECKING:
17
+ from mxlpy.model import Model
18
+ from mxlpy.types import Array
19
+
20
+ __all__ = [
21
+ "profile_likelihood",
22
+ ]
16
23
 
17
24
 
18
25
  def _mc_fit_time_course_worker(
@@ -21,16 +28,15 @@ def _mc_fit_time_course_worker(
21
28
  data: pd.DataFrame,
22
29
  loss_fn: fit.LossFn,
23
30
  ) -> float:
24
- p_fit = fit.time_course(model=model, p0=p0.to_dict(), data=data)
25
- return fit._time_course_residual( # noqa: SLF001
26
- par_values=list(p_fit.values()),
27
- par_names=list(p_fit.keys()),
28
- data=data,
31
+ fit_result = fit.time_course(
29
32
  model=model,
30
- y0=None,
31
- integrator=fit.DefaultIntegrator,
33
+ p0=p0.to_dict(),
34
+ data=data,
32
35
  loss_fn=loss_fn,
33
36
  )
37
+ if fit_result is None:
38
+ return np.inf
39
+ return fit_result.loss
34
40
 
35
41
 
36
42
  def profile_likelihood(