mxlpy 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +4 -2
- mxlpy/fit/__init__.py +9 -0
- mxlpy/fit/common.py +298 -0
- mxlpy/fit/global_.py +534 -0
- mxlpy/{fit.py → fit/local_.py} +98 -320
- mxlpy/identify.py +5 -4
- mxlpy/model.py +33 -9
- mxlpy/types.py +171 -0
- {mxlpy-0.24.0.dist-info → mxlpy-0.25.0.dist-info}/METADATA +8 -1
- {mxlpy-0.24.0.dist-info → mxlpy-0.25.0.dist-info}/RECORD +12 -9
- {mxlpy-0.24.0.dist-info → mxlpy-0.25.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.24.0.dist-info → mxlpy-0.25.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/__init__.py
CHANGED
@@ -45,7 +45,6 @@ from . import (
|
|
45
45
|
compare,
|
46
46
|
distributions,
|
47
47
|
experimental,
|
48
|
-
fit,
|
49
48
|
fns,
|
50
49
|
mc,
|
51
50
|
mca,
|
@@ -55,6 +54,8 @@ from . import (
|
|
55
54
|
scan,
|
56
55
|
units,
|
57
56
|
)
|
57
|
+
from .fit import global_ as fit_global
|
58
|
+
from .fit import local_ as fit_local
|
58
59
|
from .integrators import DefaultIntegrator, Diffrax, Scipy
|
59
60
|
from .label_map import LabelMapper
|
60
61
|
from .linear_label_map import LinearLabelMapper
|
@@ -106,7 +107,8 @@ __all__ = [
|
|
106
107
|
"compare",
|
107
108
|
"distributions",
|
108
109
|
"experimental",
|
109
|
-
"
|
110
|
+
"fit_global",
|
111
|
+
"fit_local",
|
110
112
|
"fns",
|
111
113
|
"make_protocol",
|
112
114
|
"mc",
|
mxlpy/fit/__init__.py
ADDED
mxlpy/fit/common.py
ADDED
@@ -0,0 +1,298 @@
|
|
1
|
+
"""Common types and utilities between local and global fitting."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from typing import TYPE_CHECKING, Protocol
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
import pandas as pd
|
11
|
+
from wadler_lindig import pformat
|
12
|
+
|
13
|
+
from mxlpy.model import Model
|
14
|
+
from mxlpy.simulator import Simulator
|
15
|
+
from mxlpy.types import Array, ArrayLike, Callable, IntegratorType, cast
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
import pandas as pd
|
19
|
+
|
20
|
+
from mxlpy.model import Model
|
21
|
+
|
22
|
+
LOGGER = logging.getLogger(__name__)
|
23
|
+
|
24
|
+
type InitialGuess = dict[str, float]
|
25
|
+
|
26
|
+
type Bounds = dict[str, tuple[float | None, float | None]]
|
27
|
+
type ResidualFn = Callable[[Array], float]
|
28
|
+
type LossFn = Callable[
|
29
|
+
[
|
30
|
+
pd.DataFrame | pd.Series,
|
31
|
+
pd.DataFrame | pd.Series,
|
32
|
+
],
|
33
|
+
float,
|
34
|
+
]
|
35
|
+
|
36
|
+
|
37
|
+
__all__ = [
|
38
|
+
"Bounds",
|
39
|
+
"CarouselFit",
|
40
|
+
"FitResult",
|
41
|
+
"InitialGuess",
|
42
|
+
"LOGGER",
|
43
|
+
"LossFn",
|
44
|
+
"MinResult",
|
45
|
+
"ProtocolResidualFn",
|
46
|
+
"ResidualFn",
|
47
|
+
"SteadyStateResidualFn",
|
48
|
+
"TimeSeriesResidualFn",
|
49
|
+
"rmse",
|
50
|
+
]
|
51
|
+
|
52
|
+
|
53
|
+
@dataclass
|
54
|
+
class MinResult:
|
55
|
+
"""Result of a minimization operation."""
|
56
|
+
|
57
|
+
parameters: dict[str, float]
|
58
|
+
residual: float
|
59
|
+
|
60
|
+
def __repr__(self) -> str:
|
61
|
+
"""Return default representation."""
|
62
|
+
return pformat(self)
|
63
|
+
|
64
|
+
|
65
|
+
@dataclass
|
66
|
+
class FitResult:
|
67
|
+
"""Result of a fit operation."""
|
68
|
+
|
69
|
+
model: Model
|
70
|
+
best_pars: dict[str, float]
|
71
|
+
loss: float
|
72
|
+
|
73
|
+
def __repr__(self) -> str:
|
74
|
+
"""Return default representation."""
|
75
|
+
return pformat(self)
|
76
|
+
|
77
|
+
|
78
|
+
@dataclass
|
79
|
+
class CarouselFit:
|
80
|
+
"""Result of a carousel fit operation."""
|
81
|
+
|
82
|
+
fits: list[FitResult]
|
83
|
+
|
84
|
+
def __repr__(self) -> str:
|
85
|
+
"""Return default representation."""
|
86
|
+
return pformat(self)
|
87
|
+
|
88
|
+
def get_best_fit(self) -> FitResult:
|
89
|
+
"""Get the best fit from the carousel."""
|
90
|
+
return min(self.fits, key=lambda x: x.loss)
|
91
|
+
|
92
|
+
|
93
|
+
def rmse(
|
94
|
+
y_pred: pd.DataFrame | pd.Series,
|
95
|
+
y_true: pd.DataFrame | pd.Series,
|
96
|
+
) -> float:
|
97
|
+
"""Calculate root mean square error between model and data."""
|
98
|
+
return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
|
99
|
+
|
100
|
+
|
101
|
+
class SteadyStateResidualFn(Protocol):
|
102
|
+
"""Protocol for steady state residual functions."""
|
103
|
+
|
104
|
+
def __call__(
|
105
|
+
self,
|
106
|
+
par_values: Array,
|
107
|
+
# This will be filled out by partial
|
108
|
+
par_names: list[str],
|
109
|
+
data: pd.Series,
|
110
|
+
model: Model,
|
111
|
+
y0: dict[str, float] | None,
|
112
|
+
integrator: IntegratorType | None,
|
113
|
+
loss_fn: LossFn,
|
114
|
+
) -> float:
|
115
|
+
"""Calculate residual error between model steady state and experimental data."""
|
116
|
+
...
|
117
|
+
|
118
|
+
|
119
|
+
class TimeSeriesResidualFn(Protocol):
|
120
|
+
"""Protocol for time series residual functions."""
|
121
|
+
|
122
|
+
def __call__(
|
123
|
+
self,
|
124
|
+
par_values: Array,
|
125
|
+
# This will be filled out by partial
|
126
|
+
par_names: list[str],
|
127
|
+
data: pd.DataFrame,
|
128
|
+
model: Model,
|
129
|
+
y0: dict[str, float] | None,
|
130
|
+
integrator: IntegratorType | None,
|
131
|
+
loss_fn: LossFn,
|
132
|
+
) -> float:
|
133
|
+
"""Calculate residual error between model time course and experimental data."""
|
134
|
+
...
|
135
|
+
|
136
|
+
|
137
|
+
class ProtocolResidualFn(Protocol):
|
138
|
+
"""Protocol for time series residual functions."""
|
139
|
+
|
140
|
+
def __call__(
|
141
|
+
self,
|
142
|
+
par_values: Array,
|
143
|
+
# This will be filled out by partial
|
144
|
+
par_names: list[str],
|
145
|
+
data: pd.DataFrame,
|
146
|
+
model: Model,
|
147
|
+
y0: dict[str, float] | None,
|
148
|
+
integrator: IntegratorType | None,
|
149
|
+
loss_fn: LossFn,
|
150
|
+
protocol: pd.DataFrame,
|
151
|
+
) -> float:
|
152
|
+
"""Calculate residual error between model time course and experimental data."""
|
153
|
+
...
|
154
|
+
|
155
|
+
|
156
|
+
def _steady_state_residual(
|
157
|
+
par_values: Array,
|
158
|
+
# This will be filled out by partial
|
159
|
+
par_names: list[str],
|
160
|
+
data: pd.Series,
|
161
|
+
model: Model,
|
162
|
+
y0: dict[str, float] | None,
|
163
|
+
integrator: IntegratorType | None,
|
164
|
+
loss_fn: LossFn,
|
165
|
+
) -> float:
|
166
|
+
"""Calculate residual error between model steady state and experimental data.
|
167
|
+
|
168
|
+
Args:
|
169
|
+
par_values: Parameter values to test
|
170
|
+
data: Experimental steady state data
|
171
|
+
model: Model instance to simulate
|
172
|
+
y0: Initial conditions
|
173
|
+
par_names: Names of parameters being fit
|
174
|
+
integrator: ODE integrator class to use
|
175
|
+
loss_fn: Loss function to use for residual calculation
|
176
|
+
|
177
|
+
Returns:
|
178
|
+
float: Root mean square error between model and data
|
179
|
+
|
180
|
+
"""
|
181
|
+
res = (
|
182
|
+
Simulator(
|
183
|
+
model.update_parameters(
|
184
|
+
dict(
|
185
|
+
zip(
|
186
|
+
par_names,
|
187
|
+
par_values,
|
188
|
+
strict=True,
|
189
|
+
)
|
190
|
+
)
|
191
|
+
),
|
192
|
+
y0=y0,
|
193
|
+
integrator=integrator,
|
194
|
+
)
|
195
|
+
.simulate_to_steady_state()
|
196
|
+
.get_result()
|
197
|
+
)
|
198
|
+
if res is None:
|
199
|
+
return cast(float, np.inf)
|
200
|
+
|
201
|
+
return loss_fn(
|
202
|
+
res.get_combined().loc[:, cast(list, data.index)],
|
203
|
+
data,
|
204
|
+
)
|
205
|
+
|
206
|
+
|
207
|
+
def _time_course_residual(
|
208
|
+
par_values: ArrayLike,
|
209
|
+
# This will be filled out by partial
|
210
|
+
par_names: list[str],
|
211
|
+
data: pd.DataFrame,
|
212
|
+
model: Model,
|
213
|
+
y0: dict[str, float] | None,
|
214
|
+
integrator: IntegratorType | None,
|
215
|
+
loss_fn: LossFn,
|
216
|
+
) -> float:
|
217
|
+
"""Calculate residual error between model time course and experimental data.
|
218
|
+
|
219
|
+
Args:
|
220
|
+
par_values: Parameter values to test
|
221
|
+
data: Experimental time course data
|
222
|
+
model: Model instance to simulate
|
223
|
+
y0: Initial conditions
|
224
|
+
par_names: Names of parameters being fit
|
225
|
+
integrator: ODE integrator class to use
|
226
|
+
loss_fn: Loss function to use for residual calculation
|
227
|
+
|
228
|
+
Returns:
|
229
|
+
float: Root mean square error between model and data
|
230
|
+
|
231
|
+
"""
|
232
|
+
res = (
|
233
|
+
Simulator(
|
234
|
+
model.update_parameters(dict(zip(par_names, par_values, strict=True))),
|
235
|
+
y0=y0,
|
236
|
+
integrator=integrator,
|
237
|
+
)
|
238
|
+
.simulate_time_course(cast(list, data.index))
|
239
|
+
.get_result()
|
240
|
+
)
|
241
|
+
if res is None:
|
242
|
+
return cast(float, np.inf)
|
243
|
+
results_ss = res.get_combined()
|
244
|
+
|
245
|
+
return loss_fn(
|
246
|
+
results_ss.loc[:, cast(list, data.columns)],
|
247
|
+
data,
|
248
|
+
)
|
249
|
+
|
250
|
+
|
251
|
+
def _protocol_time_course_residual(
|
252
|
+
par_values: ArrayLike,
|
253
|
+
# This will be filled out by partial
|
254
|
+
par_names: list[str],
|
255
|
+
data: pd.DataFrame,
|
256
|
+
model: Model,
|
257
|
+
y0: dict[str, float] | None,
|
258
|
+
integrator: IntegratorType | None,
|
259
|
+
loss_fn: LossFn,
|
260
|
+
protocol: pd.DataFrame,
|
261
|
+
) -> float:
|
262
|
+
"""Calculate residual error between model time course and experimental data.
|
263
|
+
|
264
|
+
Args:
|
265
|
+
par_values: Parameter values to test
|
266
|
+
data: Experimental time course data
|
267
|
+
model: Model instance to simulate
|
268
|
+
y0: Initial conditions
|
269
|
+
par_names: Names of parameters being fit
|
270
|
+
integrator: ODE integrator class to use
|
271
|
+
loss_fn: Loss function to use for residual calculation
|
272
|
+
protocol: Experimental protocol
|
273
|
+
time_points_per_step: Number of time points per step in the protocol
|
274
|
+
|
275
|
+
Returns:
|
276
|
+
float: Root mean square error between model and data
|
277
|
+
|
278
|
+
"""
|
279
|
+
res = (
|
280
|
+
Simulator(
|
281
|
+
model.update_parameters(dict(zip(par_names, par_values, strict=True))),
|
282
|
+
y0=y0,
|
283
|
+
integrator=integrator,
|
284
|
+
)
|
285
|
+
.simulate_protocol_time_course(
|
286
|
+
protocol=protocol,
|
287
|
+
time_points=data.index,
|
288
|
+
)
|
289
|
+
.get_result()
|
290
|
+
)
|
291
|
+
if res is None:
|
292
|
+
return cast(float, np.inf)
|
293
|
+
results_ss = res.get_combined()
|
294
|
+
|
295
|
+
return loss_fn(
|
296
|
+
results_ss.loc[:, cast(list, data.columns)],
|
297
|
+
data,
|
298
|
+
)
|