mxlpy 0.24.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/__init__.py CHANGED
@@ -45,7 +45,6 @@ from . import (
45
45
  compare,
46
46
  distributions,
47
47
  experimental,
48
- fit,
49
48
  fns,
50
49
  mc,
51
50
  mca,
@@ -55,6 +54,8 @@ from . import (
55
54
  scan,
56
55
  units,
57
56
  )
57
+ from .fit import global_ as fit_global
58
+ from .fit import local_ as fit_local
58
59
  from .integrators import DefaultIntegrator, Diffrax, Scipy
59
60
  from .label_map import LabelMapper
60
61
  from .linear_label_map import LinearLabelMapper
@@ -106,7 +107,8 @@ __all__ = [
106
107
  "compare",
107
108
  "distributions",
108
109
  "experimental",
109
- "fit",
110
+ "fit_global",
111
+ "fit_local",
110
112
  "fns",
111
113
  "make_protocol",
112
114
  "mc",
mxlpy/fit/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """Fitting routines."""
2
+
3
+ from . import common, global_, local_
4
+
5
# Public submodules of the fitting package.
__all__ = ["common", "global_", "local_"]
mxlpy/fit/common.py ADDED
@@ -0,0 +1,298 @@
1
+ """Common types and utilities between local and global fitting."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Protocol
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from wadler_lindig import pformat
12
+
13
+ from mxlpy.model import Model
14
+ from mxlpy.simulator import Simulator
15
+ from mxlpy.types import Array, ArrayLike, Callable, IntegratorType, cast
16
+
17
+ if TYPE_CHECKING:
18
+ import pandas as pd
19
+
20
+ from mxlpy.model import Model
21
+
22
+ LOGGER = logging.getLogger(__name__)
23
+
24
+ type InitialGuess = dict[str, float]
25
+
26
+ type Bounds = dict[str, tuple[float | None, float | None]]
27
+ type ResidualFn = Callable[[Array], float]
28
+ type LossFn = Callable[
29
+ [
30
+ pd.DataFrame | pd.Series,
31
+ pd.DataFrame | pd.Series,
32
+ ],
33
+ float,
34
+ ]
35
+
36
+
37
+ __all__ = [
38
+ "Bounds",
39
+ "CarouselFit",
40
+ "FitResult",
41
+ "InitialGuess",
42
+ "LOGGER",
43
+ "LossFn",
44
+ "MinResult",
45
+ "ProtocolResidualFn",
46
+ "ResidualFn",
47
+ "SteadyStateResidualFn",
48
+ "TimeSeriesResidualFn",
49
+ "rmse",
50
+ ]
51
+
52
+
53
@dataclass(repr=False)
class MinResult:
    """Parameter values and residual reported by a minimization run."""

    parameters: dict[str, float]  # parameter name -> value
    residual: float  # objective value associated with ``parameters``

    def __repr__(self) -> str:
        """Pretty-print the result with ``wadler_lindig.pformat``."""
        return pformat(self)
63
+
64
+
65
@dataclass(repr=False)
class FitResult:
    """Outcome of fitting a model: the model, its best parameters and the loss."""

    model: Model  # model the fit was performed on
    best_pars: dict[str, float]  # parameter name -> fitted value
    loss: float  # loss value reported for ``best_pars``

    def __repr__(self) -> str:
        """Pretty-print the result with ``wadler_lindig.pformat``."""
        return pformat(self)
76
+
77
+
78
@dataclass
class CarouselFit:
    """Result of a carousel fit operation: one ``FitResult`` per fit performed."""

    fits: list[FitResult]

    def __repr__(self) -> str:
        """Return default representation."""
        return pformat(self)

    def get_best_fit(self) -> FitResult:
        """Return the fit with the lowest loss.

        Fits whose loss is NaN are ranked after every other fit, so a fit with
        an undefined loss is never preferred over one with a real loss. On
        ties, the earliest fit in ``fits`` wins.

        Raises:
            ValueError: If the carousel contains no fits.

        """
        if not self.fits:
            msg = "Cannot select best fit: carousel contains no fits"
            raise ValueError(msg)
        # ``min`` compares with ``<``, which is always False against NaN, so a
        # NaN loss in the first slot would otherwise be returned as "best".
        return min(
            self.fits,
            key=lambda fit: np.inf if np.isnan(fit.loss) else fit.loss,
        )
91
+
92
+
93
+ def rmse(
94
+ y_pred: pd.DataFrame | pd.Series,
95
+ y_true: pd.DataFrame | pd.Series,
96
+ ) -> float:
97
+ """Calculate root mean square error between model and data."""
98
+ return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
99
+
100
+
101
class SteadyStateResidualFn(Protocol):
    """Call signature shared by steady-state residual functions.

    Only ``par_values`` changes between calls; every later argument is
    intended to be bound beforehand with ``functools.partial``.
    """

    def __call__(
        self,
        par_values: Array,
        par_names: list[str],
        data: pd.Series,
        model: Model,
        y0: dict[str, float] | None,
        integrator: IntegratorType | None,
        loss_fn: LossFn,
    ) -> float:
        """Return the residual between the model's steady state and ``data``."""
117
+
118
+
119
class TimeSeriesResidualFn(Protocol):
    """Call signature shared by time-series residual functions.

    Only ``par_values`` changes between calls; every later argument is
    intended to be bound beforehand with ``functools.partial``.
    """

    def __call__(
        self,
        par_values: Array,
        par_names: list[str],
        data: pd.DataFrame,
        model: Model,
        y0: dict[str, float] | None,
        integrator: IntegratorType | None,
        loss_fn: LossFn,
    ) -> float:
        """Return the residual between the model's time course and ``data``."""
135
+
136
+
137
class ProtocolResidualFn(Protocol):
    """Protocol for residual functions of protocol-driven time courses.

    Like a time-series residual function, but the model is simulated under an
    experimental ``protocol``. Only ``par_values`` changes between calls; the
    remaining arguments are intended to be bound with ``functools.partial``.
    """

    def __call__(
        self,
        par_values: Array,
        # This will be filled out by partial
        par_names: list[str],
        data: pd.DataFrame,
        model: Model,
        y0: dict[str, float] | None,
        integrator: IntegratorType | None,
        loss_fn: LossFn,
        protocol: pd.DataFrame,
    ) -> float:
        """Calculate residual error between a protocol time course and experimental data."""
        ...
154
+
155
+
156
def _steady_state_residual(
    par_values: Array,
    # This will be filled out by partial
    par_names: list[str],
    data: pd.Series,
    model: Model,
    y0: dict[str, float] | None,
    integrator: IntegratorType | None,
    loss_fn: LossFn,
) -> float:
    """Calculate the loss between the model steady state and experimental data.

    Args:
        par_values: Parameter values to test, in the same order as ``par_names``
        par_names: Names of parameters being fit
        data: Experimental steady state data; its index selects which columns
            of the simulation result are compared
        model: Model instance to simulate
        y0: Initial conditions
        integrator: ODE integrator class to use
        loss_fn: Loss function, called as ``loss_fn(prediction, data)``

    Returns:
        float: Value of ``loss_fn`` for the simulated steady state, or ``inf``
        when the simulator returns no result.

    Raises:
        ValueError: If ``par_values`` and ``par_names`` differ in length.

    """
    parameters = dict(zip(par_names, par_values, strict=True))
    # NOTE(review): ``update_parameters`` is called on the caller's model;
    # confirm whether it mutates ``model`` in place between optimizer calls.
    res = (
        Simulator(model.update_parameters(parameters), y0=y0, integrator=integrator)
        .simulate_to_steady_state()
        .get_result()
    )
    if res is None:
        # No result (presumably a failed simulation): worst possible loss.
        return cast(float, np.inf)

    # Compare only the variables present in the data.
    prediction = res.get_combined().loc[:, cast(list, data.index)]
    return loss_fn(prediction, data)
205
+
206
+
207
def _time_course_residual(
    par_values: ArrayLike,
    # This will be filled out by partial
    par_names: list[str],
    data: pd.DataFrame,
    model: Model,
    y0: dict[str, float] | None,
    integrator: IntegratorType | None,
    loss_fn: LossFn,
) -> float:
    """Calculate the loss between a model time course and experimental data.

    The model is simulated over the time points in ``data.index``, and the
    columns of ``data`` select which simulated variables are compared.

    Args:
        par_values: Parameter values to test, in the same order as ``par_names``
        par_names: Names of parameters being fit
        data: Experimental time course data
        model: Model instance to simulate
        y0: Initial conditions
        integrator: ODE integrator class to use
        loss_fn: Loss function, called as ``loss_fn(prediction, data)``

    Returns:
        float: Value of ``loss_fn`` for the simulated time course, or ``inf``
        when the simulator returns no result.

    Raises:
        ValueError: If ``par_values`` and ``par_names`` differ in length.

    """
    parameters = dict(zip(par_names, par_values, strict=True))
    # NOTE(review): ``update_parameters`` is called on the caller's model;
    # confirm whether it mutates ``model`` in place between optimizer calls.
    res = (
        Simulator(model.update_parameters(parameters), y0=y0, integrator=integrator)
        .simulate_time_course(cast(list, data.index))
        .get_result()
    )
    if res is None:
        # No result (presumably a failed simulation): worst possible loss.
        return cast(float, np.inf)

    # Compare only the variables present in the data.
    prediction = res.get_combined().loc[:, cast(list, data.columns)]
    return loss_fn(prediction, data)
249
+
250
+
251
def _protocol_time_course_residual(
    par_values: ArrayLike,
    # This will be filled out by partial
    par_names: list[str],
    data: pd.DataFrame,
    model: Model,
    y0: dict[str, float] | None,
    integrator: IntegratorType | None,
    loss_fn: LossFn,
    protocol: pd.DataFrame,
) -> float:
    """Calculate the loss between a protocol time course and experimental data.

    The model is simulated under ``protocol``, sampled at the time points in
    ``data.index``, and the columns of ``data`` select which simulated
    variables are compared.

    Args:
        par_values: Parameter values to test, in the same order as ``par_names``
        par_names: Names of parameters being fit
        data: Experimental time course data
        model: Model instance to simulate
        y0: Initial conditions
        integrator: ODE integrator class to use
        loss_fn: Loss function, called as ``loss_fn(prediction, data)``
        protocol: Experimental protocol passed to
            ``simulate_protocol_time_course``

    Returns:
        float: Value of ``loss_fn`` for the simulated time course, or ``inf``
        when the simulator returns no result.

    Raises:
        ValueError: If ``par_values`` and ``par_names`` differ in length.

    """
    parameters = dict(zip(par_names, par_values, strict=True))
    # NOTE(review): ``update_parameters`` is called on the caller's model;
    # confirm whether it mutates ``model`` in place between optimizer calls.
    res = (
        Simulator(model.update_parameters(parameters), y0=y0, integrator=integrator)
        .simulate_protocol_time_course(protocol=protocol, time_points=data.index)
        .get_result()
    )
    if res is None:
        # No result (presumably a failed simulation): worst possible loss.
        return cast(float, np.inf)

    # Compare only the variables present in the data.
    prediction = res.get_combined().loc[:, cast(list, data.columns)]
    return loss_fn(prediction, data)