mxlpy 0.23.0__py3-none-any.whl → 0.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/__init__.py CHANGED
@@ -45,21 +45,22 @@ from . import (
45
45
  compare,
46
46
  distributions,
47
47
  experimental,
48
- fit,
49
48
  fns,
50
49
  mc,
51
50
  mca,
52
51
  plot,
53
52
  report,
54
53
  sbml,
54
+ scan,
55
55
  units,
56
56
  )
57
+ from .fit import global_ as fit_global
58
+ from .fit import local_ as fit_local
57
59
  from .integrators import DefaultIntegrator, Diffrax, Scipy
58
60
  from .label_map import LabelMapper
59
61
  from .linear_label_map import LinearLabelMapper
60
62
  from .mc import Cache
61
63
  from .model import Model
62
- from .scan import steady_state, time_course, time_course_over_protocol
63
64
  from .simulator import Simulator
64
65
  from .symbolic import SymbolicModel, to_symbolic_model
65
66
  from .types import (
@@ -106,7 +107,8 @@ __all__ = [
106
107
  "compare",
107
108
  "distributions",
108
109
  "experimental",
109
- "fit",
110
+ "fit_global",
111
+ "fit_local",
110
112
  "fns",
111
113
  "make_protocol",
112
114
  "mc",
@@ -116,11 +118,9 @@ __all__ = [
116
118
  "plot",
117
119
  "report",
118
120
  "sbml",
119
- "steady_state",
121
+ "scan",
120
122
  "surrogates",
121
123
  "symbolic",
122
- "time_course",
123
- "time_course_over_protocol",
124
124
  "to_symbolic_model",
125
125
  "units",
126
126
  "unwrap",
mxlpy/carousel.py CHANGED
@@ -9,6 +9,7 @@ from functools import partial
9
9
  from typing import TYPE_CHECKING
10
10
 
11
11
  import pandas as pd
12
+ from wadler_lindig import pformat
12
13
 
13
14
  from mxlpy import parallel, scan
14
15
 
@@ -25,6 +26,10 @@ if TYPE_CHECKING:
25
26
  class ReactionTemplate:
26
27
  """Template for a reaction in a model."""
27
28
 
29
+ def __repr__(self) -> str:
30
+ """Return default representation."""
31
+ return pformat(self)
32
+
28
33
  fn: RateFn
29
34
  args: list[str]
30
35
  additional_parameters: dict[str, float] = field(default_factory=dict)
@@ -37,6 +42,10 @@ class CarouselSteadyState:
37
42
  carousel: list[Model]
38
43
  results: list[Result]
39
44
 
45
+ def __repr__(self) -> str:
46
+ """Return default representation."""
47
+ return pformat(self)
48
+
40
49
  def get_variables_by_model(self) -> pd.DataFrame:
41
50
  """Get the variables of the time course results, indexed by model."""
42
51
  return pd.DataFrame(
@@ -51,6 +60,10 @@ class CarouselTimeCourse:
51
60
  carousel: list[Model]
52
61
  results: list[Result]
53
62
 
63
+ def __repr__(self) -> str:
64
+ """Return default representation."""
65
+ return pformat(self)
66
+
54
67
  def get_variables_by_model(self) -> pd.DataFrame:
55
68
  """Get the variables of the time course results, indexed by model."""
56
69
  return pd.concat({i: r.variables for i, r in enumerate(self.results)})
@@ -76,6 +89,10 @@ class Carousel:
76
89
 
77
90
  variants: list[Model]
78
91
 
92
+ def __repr__(self) -> str:
93
+ """Return default representation."""
94
+ return pformat(self)
95
+
79
96
  def __init__(
80
97
  self,
81
98
  model: Model,
@@ -115,7 +132,7 @@ class Carousel:
115
132
  results=results,
116
133
  )
117
134
 
118
- def protocol_time_course(
135
+ def protocol(
119
136
  self,
120
137
  protocol: pd.DataFrame,
121
138
  *,
@@ -141,6 +158,34 @@ class Carousel:
141
158
  results=results,
142
159
  )
143
160
 
161
+ def protocol_time_course(
162
+ self,
163
+ protocol: pd.DataFrame,
164
+ time_points: Array,
165
+ *,
166
+ y0: dict[str, float] | None = None,
167
+ integrator: IntegratorType | None = None,
168
+ ) -> CarouselTimeCourse:
169
+ """Simulate the carousel of models over a protocol time course."""
170
+ results = [
171
+ i[1]
172
+ for i in parallel.parallelise(
173
+ partial(
174
+ scan._protocol_time_course_worker, # noqa: SLF001
175
+ protocol=protocol,
176
+ integrator=integrator,
177
+ time_points=time_points,
178
+ y0=y0,
179
+ ),
180
+ list(enumerate(self.variants)),
181
+ )
182
+ ]
183
+
184
+ return CarouselTimeCourse(
185
+ carousel=self.variants,
186
+ results=results,
187
+ )
188
+
144
189
  def steady_state(
145
190
  self,
146
191
  *,
mxlpy/compare.py CHANGED
@@ -6,6 +6,7 @@ from dataclasses import dataclass
6
6
  from typing import TYPE_CHECKING, cast
7
7
 
8
8
  import pandas as pd
9
+ from wadler_lindig import pformat
9
10
 
10
11
  from mxlpy import plot
11
12
  from mxlpy.simulator import Simulator
@@ -32,6 +33,10 @@ class SteadyStateComparison:
32
33
  res1: Result
33
34
  res2: Result
34
35
 
36
+ def __repr__(self) -> str:
37
+ """Return default representation."""
38
+ return pformat(self)
39
+
35
40
  @property
36
41
  def variables(self) -> pd.DataFrame:
37
42
  """Compare the steady state variables."""
@@ -93,6 +98,10 @@ class TimeCourseComparison:
93
98
  res1: Result
94
99
  res2: Result
95
100
 
101
+ def __repr__(self) -> str:
102
+ """Return default representation."""
103
+ return pformat(self)
104
+
96
105
  # @property
97
106
  # def variables(self) -> pd.DataFrame:
98
107
  # """Compare the steady state variables."""
@@ -144,6 +153,10 @@ class ProtocolComparison:
144
153
  res2: Result
145
154
  protocol: pd.DataFrame
146
155
 
156
+ def __repr__(self) -> str:
157
+ """Return default representation."""
158
+ return pformat(self)
159
+
147
160
  # @property
148
161
  # def variables(self) -> pd.DataFrame:
149
162
  # """Compare the steady state variables."""
@@ -5,6 +5,8 @@ from __future__ import annotations
5
5
  from dataclasses import dataclass, field
6
6
  from typing import TYPE_CHECKING
7
7
 
8
+ from wadler_lindig import pformat
9
+
8
10
  from mxlpy.types import Derived
9
11
 
10
12
  if TYPE_CHECKING:
@@ -28,6 +30,10 @@ class DerivedDiff:
28
30
  args1: list[str] = field(default_factory=list)
29
31
  args2: list[str] = field(default_factory=list)
30
32
 
33
+ def __repr__(self) -> str:
34
+ """Return default representation."""
35
+ return pformat(self)
36
+
31
37
 
32
38
  @dataclass
33
39
  class ReactionDiff:
@@ -38,6 +44,10 @@ class ReactionDiff:
38
44
  stoichiometry1: dict[str, float | Derived] = field(default_factory=dict)
39
45
  stoichiometry2: dict[str, float | Derived] = field(default_factory=dict)
40
46
 
47
+ def __repr__(self) -> str:
48
+ """Return default representation."""
49
+ return pformat(self)
50
+
41
51
 
42
52
  @dataclass
43
53
  class ModelDiff:
@@ -56,6 +66,10 @@ class ModelDiff:
56
66
  different_readouts: dict[str, DerivedDiff] = field(default_factory=dict)
57
67
  different_derived: dict[str, DerivedDiff] = field(default_factory=dict)
58
68
 
69
+ def __repr__(self) -> str:
70
+ """Return default representation."""
71
+ return pformat(self)
72
+
59
73
  def __str__(self) -> str:
60
74
  """Return a human-readable string representation of the diff."""
61
75
  content = ["Model Diff", "----------"]
mxlpy/fit/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """Fitting routines."""
2
+
3
+ from . import common, global_, local_
4
+
5
+ __all__ = [
6
+ "common",
7
+ "global_",
8
+ "local_",
9
+ ]
mxlpy/fit/common.py ADDED
@@ -0,0 +1,298 @@
1
+ """Common types and utilities between local and global fitting."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, Protocol
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from wadler_lindig import pformat
12
+
13
+ from mxlpy.model import Model
14
+ from mxlpy.simulator import Simulator
15
+ from mxlpy.types import Array, ArrayLike, Callable, IntegratorType, cast
16
+
17
+ if TYPE_CHECKING:
18
+ import pandas as pd
19
+
20
+ from mxlpy.model import Model
21
+
22
+ LOGGER = logging.getLogger(__name__)
23
+
24
+ type InitialGuess = dict[str, float]
25
+
26
+ type Bounds = dict[str, tuple[float | None, float | None]]
27
+ type ResidualFn = Callable[[Array], float]
28
+ type LossFn = Callable[
29
+ [
30
+ pd.DataFrame | pd.Series,
31
+ pd.DataFrame | pd.Series,
32
+ ],
33
+ float,
34
+ ]
35
+
36
+
37
+ __all__ = [
38
+ "Bounds",
39
+ "CarouselFit",
40
+ "FitResult",
41
+ "InitialGuess",
42
+ "LOGGER",
43
+ "LossFn",
44
+ "MinResult",
45
+ "ProtocolResidualFn",
46
+ "ResidualFn",
47
+ "SteadyStateResidualFn",
48
+ "TimeSeriesResidualFn",
49
+ "rmse",
50
+ ]
51
+
52
+
53
+ @dataclass
54
+ class MinResult:
55
+ """Result of a minimization operation."""
56
+
57
+ parameters: dict[str, float]
58
+ residual: float
59
+
60
+ def __repr__(self) -> str:
61
+ """Return default representation."""
62
+ return pformat(self)
63
+
64
+
65
+ @dataclass
66
+ class FitResult:
67
+ """Result of a fit operation."""
68
+
69
+ model: Model
70
+ best_pars: dict[str, float]
71
+ loss: float
72
+
73
+ def __repr__(self) -> str:
74
+ """Return default representation."""
75
+ return pformat(self)
76
+
77
+
78
+ @dataclass
79
+ class CarouselFit:
80
+ """Result of a carousel fit operation."""
81
+
82
+ fits: list[FitResult]
83
+
84
+ def __repr__(self) -> str:
85
+ """Return default representation."""
86
+ return pformat(self)
87
+
88
+ def get_best_fit(self) -> FitResult:
89
+ """Get the best fit from the carousel."""
90
+ return min(self.fits, key=lambda x: x.loss)
91
+
92
+
93
+ def rmse(
94
+ y_pred: pd.DataFrame | pd.Series,
95
+ y_true: pd.DataFrame | pd.Series,
96
+ ) -> float:
97
+ """Calculate root mean square error between model and data."""
98
+ return cast(float, np.sqrt(np.mean(np.square(y_pred - y_true))))
99
+
100
+
101
+ class SteadyStateResidualFn(Protocol):
102
+ """Protocol for steady state residual functions."""
103
+
104
+ def __call__(
105
+ self,
106
+ par_values: Array,
107
+ # This will be filled out by partial
108
+ par_names: list[str],
109
+ data: pd.Series,
110
+ model: Model,
111
+ y0: dict[str, float] | None,
112
+ integrator: IntegratorType | None,
113
+ loss_fn: LossFn,
114
+ ) -> float:
115
+ """Calculate residual error between model steady state and experimental data."""
116
+ ...
117
+
118
+
119
+ class TimeSeriesResidualFn(Protocol):
120
+ """Protocol for time series residual functions."""
121
+
122
+ def __call__(
123
+ self,
124
+ par_values: Array,
125
+ # This will be filled out by partial
126
+ par_names: list[str],
127
+ data: pd.DataFrame,
128
+ model: Model,
129
+ y0: dict[str, float] | None,
130
+ integrator: IntegratorType | None,
131
+ loss_fn: LossFn,
132
+ ) -> float:
133
+ """Calculate residual error between model time course and experimental data."""
134
+ ...
135
+
136
+
137
+ class ProtocolResidualFn(Protocol):
138
+ """Protocol for time series residual functions."""
139
+
140
+ def __call__(
141
+ self,
142
+ par_values: Array,
143
+ # This will be filled out by partial
144
+ par_names: list[str],
145
+ data: pd.DataFrame,
146
+ model: Model,
147
+ y0: dict[str, float] | None,
148
+ integrator: IntegratorType | None,
149
+ loss_fn: LossFn,
150
+ protocol: pd.DataFrame,
151
+ ) -> float:
152
+ """Calculate residual error between model time course and experimental data."""
153
+ ...
154
+
155
+
156
+ def _steady_state_residual(
157
+ par_values: Array,
158
+ # This will be filled out by partial
159
+ par_names: list[str],
160
+ data: pd.Series,
161
+ model: Model,
162
+ y0: dict[str, float] | None,
163
+ integrator: IntegratorType | None,
164
+ loss_fn: LossFn,
165
+ ) -> float:
166
+ """Calculate residual error between model steady state and experimental data.
167
+
168
+ Args:
169
+ par_values: Parameter values to test
170
+ data: Experimental steady state data
171
+ model: Model instance to simulate
172
+ y0: Initial conditions
173
+ par_names: Names of parameters being fit
174
+ integrator: ODE integrator class to use
175
+ loss_fn: Loss function to use for residual calculation
176
+
177
+ Returns:
178
+ float: Root mean square error between model and data
179
+
180
+ """
181
+ res = (
182
+ Simulator(
183
+ model.update_parameters(
184
+ dict(
185
+ zip(
186
+ par_names,
187
+ par_values,
188
+ strict=True,
189
+ )
190
+ )
191
+ ),
192
+ y0=y0,
193
+ integrator=integrator,
194
+ )
195
+ .simulate_to_steady_state()
196
+ .get_result()
197
+ )
198
+ if res is None:
199
+ return cast(float, np.inf)
200
+
201
+ return loss_fn(
202
+ res.get_combined().loc[:, cast(list, data.index)],
203
+ data,
204
+ )
205
+
206
+
207
+ def _time_course_residual(
208
+ par_values: ArrayLike,
209
+ # This will be filled out by partial
210
+ par_names: list[str],
211
+ data: pd.DataFrame,
212
+ model: Model,
213
+ y0: dict[str, float] | None,
214
+ integrator: IntegratorType | None,
215
+ loss_fn: LossFn,
216
+ ) -> float:
217
+ """Calculate residual error between model time course and experimental data.
218
+
219
+ Args:
220
+ par_values: Parameter values to test
221
+ data: Experimental time course data
222
+ model: Model instance to simulate
223
+ y0: Initial conditions
224
+ par_names: Names of parameters being fit
225
+ integrator: ODE integrator class to use
226
+ loss_fn: Loss function to use for residual calculation
227
+
228
+ Returns:
229
+ float: Root mean square error between model and data
230
+
231
+ """
232
+ res = (
233
+ Simulator(
234
+ model.update_parameters(dict(zip(par_names, par_values, strict=True))),
235
+ y0=y0,
236
+ integrator=integrator,
237
+ )
238
+ .simulate_time_course(cast(list, data.index))
239
+ .get_result()
240
+ )
241
+ if res is None:
242
+ return cast(float, np.inf)
243
+ results_ss = res.get_combined()
244
+
245
+ return loss_fn(
246
+ results_ss.loc[:, cast(list, data.columns)],
247
+ data,
248
+ )
249
+
250
+
251
+ def _protocol_time_course_residual(
252
+ par_values: ArrayLike,
253
+ # This will be filled out by partial
254
+ par_names: list[str],
255
+ data: pd.DataFrame,
256
+ model: Model,
257
+ y0: dict[str, float] | None,
258
+ integrator: IntegratorType | None,
259
+ loss_fn: LossFn,
260
+ protocol: pd.DataFrame,
261
+ ) -> float:
262
+ """Calculate residual error between model time course and experimental data.
263
+
264
+ Args:
265
+ par_values: Parameter values to test
266
+ data: Experimental time course data
267
+ model: Model instance to simulate
268
+ y0: Initial conditions
269
+ par_names: Names of parameters being fit
270
+ integrator: ODE integrator class to use
271
+ loss_fn: Loss function to use for residual calculation
272
+ protocol: Experimental protocol
273
+ time_points_per_step: Number of time points per step in the protocol
274
+
275
+ Returns:
276
+ float: Root mean square error between model and data
277
+
278
+ """
279
+ res = (
280
+ Simulator(
281
+ model.update_parameters(dict(zip(par_names, par_values, strict=True))),
282
+ y0=y0,
283
+ integrator=integrator,
284
+ )
285
+ .simulate_protocol_time_course(
286
+ protocol=protocol,
287
+ time_points=data.index,
288
+ )
289
+ .get_result()
290
+ )
291
+ if res is None:
292
+ return cast(float, np.inf)
293
+ results_ss = res.get_combined()
294
+
295
+ return loss_fn(
296
+ results_ss.loc[:, cast(list, data.columns)],
297
+ data,
298
+ )