mxlpy 0.22.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/__init__.py CHANGED
@@ -54,7 +54,7 @@ from . import (
54
54
  sbml,
55
55
  units,
56
56
  )
57
- from .integrators import DefaultIntegrator, Scipy
57
+ from .integrators import DefaultIntegrator, Diffrax, Scipy
58
58
  from .label_map import LabelMapper
59
59
  from .linear_label_map import LinearLabelMapper
60
60
  from .mc import Cache
@@ -62,7 +62,14 @@ from .model import Model
62
62
  from .scan import steady_state, time_course, time_course_over_protocol
63
63
  from .simulator import Simulator
64
64
  from .symbolic import SymbolicModel, to_symbolic_model
65
- from .types import Derived, IntegratorProtocol, Parameter, Variable, unwrap
65
+ from .types import (
66
+ Derived,
67
+ InitialAssignment,
68
+ IntegratorProtocol,
69
+ Parameter,
70
+ Variable,
71
+ unwrap,
72
+ )
66
73
 
67
74
  with contextlib.suppress(ImportError):
68
75
  from .integrators import Assimulo
@@ -84,6 +91,8 @@ __all__ = [
84
91
  "Cache",
85
92
  "DefaultIntegrator",
86
93
  "Derived",
94
+ "Diffrax",
95
+ "InitialAssignment",
87
96
  "IntegratorProtocol",
88
97
  "LabelMapper",
89
98
  "LinearLabelMapper",
mxlpy/carousel.py CHANGED
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
18
18
  from collections.abc import Iterable, Mapping
19
19
 
20
20
  from mxlpy import Model
21
- from mxlpy.types import Array, IntegratorType, RateFn
21
+ from mxlpy.types import Array, IntegratorType, RateFn, Result
22
22
 
23
23
 
24
24
  @dataclass
@@ -35,11 +35,13 @@ class CarouselSteadyState:
35
35
  """Time course of a carousel simulation."""
36
36
 
37
37
  carousel: list[Model]
38
- results: list[scan.TimePoint]
38
+ results: list[Result]
39
39
 
40
40
  def get_variables_by_model(self) -> pd.DataFrame:
41
41
  """Get the variables of the time course results, indexed by model."""
42
- return pd.DataFrame({i: r.variables for i, r in enumerate(self.results)}).T
42
+ return pd.DataFrame(
43
+ {i: r.variables.iloc[-1] for i, r in enumerate(self.results)}
44
+ ).T
43
45
 
44
46
 
45
47
  @dataclass
@@ -47,7 +49,7 @@ class CarouselTimeCourse:
47
49
  """Time course of a carousel simulation."""
48
50
 
49
51
  carousel: list[Model]
50
- results: list[scan.TimeCourse]
52
+ results: list[Result]
51
53
 
52
54
  def get_variables_by_model(self) -> pd.DataFrame:
53
55
  """Get the variables of the time course results, indexed by model."""
mxlpy/compare.py CHANGED
@@ -8,8 +8,8 @@ from typing import TYPE_CHECKING, cast
8
8
  import pandas as pd
9
9
 
10
10
  from mxlpy import plot
11
- from mxlpy.simulator import Result, Simulator
12
- from mxlpy.types import unwrap
11
+ from mxlpy.simulator import Simulator
12
+ from mxlpy.types import Result, unwrap
13
13
 
14
14
  if TYPE_CHECKING:
15
15
  from mxlpy.model import Model
@@ -6,6 +6,7 @@ It includes support for both Assimulo and Scipy integrators, with Assimulo being
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ from .int_diffrax import Diffrax
9
10
  from .int_scipy import Scipy
10
11
 
11
12
  try:
@@ -16,5 +17,8 @@ except ImportError:
16
17
  DefaultIntegrator = Scipy
17
18
 
18
19
  __all__ = [
20
+ "Assimulo",
19
21
  "DefaultIntegrator",
22
+ "Diffrax",
23
+ "Scipy",
20
24
  ]
@@ -17,7 +17,7 @@ with contextlib.redirect_stderr(open(os.devnull, "w")): # noqa: PTH123
17
17
  if TYPE_CHECKING:
18
18
  from collections.abc import Callable
19
19
 
20
- from mxlpy.types import Array, ArrayLike
20
+ from mxlpy.types import Array, ArrayLike, Rhs
21
21
 
22
22
 
23
23
  __all__ = [
@@ -43,8 +43,8 @@ class Assimulo:
43
43
 
44
44
  """
45
45
 
46
- rhs: Callable
47
- y0: ArrayLike
46
+ rhs: Rhs
47
+ y0: tuple[float, ...]
48
48
  jacobian: Callable | None = None
49
49
  atol: float = 1e-8
50
50
  rtol: float = 1e-8
@@ -0,0 +1,119 @@
1
+ """Diffrax integrator for solving ODEs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING
7
+
8
+ import numpy as np
9
+ from diffrax import (
10
+ AbstractSolver,
11
+ AbstractStepSizeController,
12
+ Kvaerno5,
13
+ ODETerm,
14
+ PIDController,
15
+ SaveAt,
16
+ diffeqsolve,
17
+ )
18
+
19
+ __all__ = ["Diffrax"]
20
+
21
+ if TYPE_CHECKING:
22
+ from collections.abc import Callable
23
+
24
+ from mxlpy.types import Array, Rhs
25
+
26
+
27
+ @dataclass
28
+ class Diffrax:
29
+ """Diffrax integrator for solving ODEs."""
30
+
31
+ rhs: Rhs
32
+ y0: tuple[float, ...]
33
+ jac: Callable | None = None
34
+ solver: AbstractSolver = field(default=Kvaerno5())
35
+ stepsize_controller: AbstractStepSizeController = field(
36
+ default=PIDController(rtol=1e-8, atol=1e-8)
37
+ )
38
+ t0: float = 0.0
39
+
40
+ def __post_init__(self) -> None:
41
+ """Create copy of initial state.
42
+
43
+ This method creates a copy of the initial state `y0` and stores it in the `_y0_orig` attribute.
44
+ This is useful for preserving the original initial state for future reference or reset operations.
45
+
46
+ """
47
+ self._y0_orig = self.y0
48
+
49
+ def reset(self) -> None:
50
+ """Reset the integrator."""
51
+ self.t0 = 0
52
+ self.y0 = self._y0_orig
53
+
54
+ def integrate_time_course(
55
+ self, *, time_points: Array
56
+ ) -> tuple[Array | None, Array | None]:
57
+ """Integrate the ODE system over a time course.
58
+
59
+ Args:
60
+ time_points: Time points for the integration.
61
+
62
+ Returns:
63
+ tuple[Array, Array]: Tuple containing the time points and the integrated values.
64
+
65
+ """
66
+ if time_points[0] != self.t0:
67
+ time_points = np.insert(time_points, 0, self.t0)
68
+
69
+ res = diffeqsolve(
70
+ ODETerm(lambda t, y, _: self.rhs(t, y)), # type: ignore
71
+ solver=self.solver,
72
+ t0=time_points[0],
73
+ t1=time_points[-1],
74
+ dt0=None,
75
+ y0=self.y0,
76
+ max_steps=None,
77
+ saveat=SaveAt(ts=time_points), # type: ignore
78
+ stepsize_controller=self.stepsize_controller,
79
+ )
80
+
81
+ t = np.atleast_1d(np.array(res.ts, dtype=float))
82
+ y = np.atleast_2d(np.array(res.ys, dtype=float).T)
83
+
84
+ self.t0 = t[-1]
85
+ self.y0 = y[-1]
86
+ return t, y
87
+
88
+ def integrate(
89
+ self,
90
+ *,
91
+ t_end: float,
92
+ steps: int | None = None,
93
+ ) -> tuple[Array | None, Array | None]:
94
+ """Integrate the ODE system over a time course."""
95
+ steps = 100 if steps is None else steps
96
+
97
+ return self.integrate_time_course(
98
+ time_points=np.linspace(self.t0, t_end, steps, dtype=float)
99
+ )
100
+
101
+ def integrate_to_steady_state(
102
+ self,
103
+ *,
104
+ tolerance: float,
105
+ rel_norm: bool,
106
+ t_max: float = 1_000_000_000,
107
+ ) -> tuple[float | None, Array | None]:
108
+ """Integrate the ODE system to steady state.
109
+
110
+ Args:
111
+ tolerance: Tolerance for determining steady state.
112
+ rel_norm: Whether to use relative normalization.
113
+ t_max: Maximum time point for the integration (default: 1,000,000,000).
114
+
115
+ Returns:
116
+ tuple[float | None, Array | None]: Tuple containing the final time point and the integrated values at steady state.
117
+
118
+ """
119
+ raise NotImplementedError
@@ -14,6 +14,8 @@ from mxlpy.types import Array, ArrayLike
14
14
  if TYPE_CHECKING:
15
15
  from collections.abc import Callable
16
16
 
17
+ from mxlpy.types import Rhs
18
+
17
19
 
18
20
  __all__ = [
19
21
  "Scipy",
@@ -40,13 +42,13 @@ class Scipy:
40
42
 
41
43
  """
42
44
 
43
- rhs: Callable
44
- y0: ArrayLike
45
+ rhs: Rhs
46
+ y0: tuple[float, ...]
45
47
  jacobian: Callable | None = None
46
48
  atol: float = 1e-8
47
49
  rtol: float = 1e-8
48
50
  t0: float = 0.0
49
- _y0_orig: ArrayLike = field(default_factory=list)
51
+ _y0_orig: tuple[float, ...] = field(default_factory=tuple)
50
52
 
51
53
  def __post_init__(self) -> None:
52
54
  """Create copy of initial state.
@@ -55,12 +57,12 @@ class Scipy:
55
57
  This is useful for preserving the original initial state for future reference or reset operations.
56
58
 
57
59
  """
58
- self._y0_orig = self.y0.copy()
60
+ self._y0_orig = self.y0
59
61
 
60
62
  def reset(self) -> None:
61
63
  """Reset the integrator."""
62
64
  self.t0 = 0
63
- self.y0 = self._y0_orig.copy()
65
+ self.y0 = self._y0_orig
64
66
 
65
67
  def integrate(
66
68
  self,
@@ -143,9 +145,13 @@ class Scipy:
143
145
 
144
146
  """
145
147
  self.reset()
146
- integ = spi.ode(self.rhs, jac=self.jacobian)
148
+
149
+ # If rhs returns a tuple, we get weird errors, so we need
150
+ # to wrap this in a list for some reason
151
+ integ = spi.ode(lambda t, x: list(self.rhs(t, x)), jac=self.jacobian)
147
152
  integ.set_integrator(name="lsoda")
148
153
  integ.set_initial_value(self.y0)
154
+
149
155
  t = self.t0 + step_size
150
156
  y1 = copy.deepcopy(self.y0)
151
157
  for _ in range(max_steps):
mxlpy/label_map.py CHANGED
@@ -27,7 +27,6 @@ from mxlpy.model import Model
27
27
  if TYPE_CHECKING:
28
28
  from collections.abc import Callable, Mapping
29
29
 
30
- from mxlpy.types import Derived
31
30
 
32
31
  __all__ = [
33
32
  "LabelMapper",
@@ -556,7 +555,7 @@ class LabelMapper:
556
555
  for name, dp in self.model.get_derived_parameters().items():
557
556
  m.add_derived(name, fn=dp.fn, args=dp.args)
558
557
 
559
- variables: dict[str, float | Derived] = {}
558
+ variables: dict[str, float] = {}
560
559
  for k, v in self.model.get_initial_conditions().items():
561
560
  if (isos := isotopomers.get(k)) is None:
562
561
  variables[k] = v
mxlpy/mc.py CHANGED
@@ -35,10 +35,10 @@ from mxlpy.scan import (
35
35
  from mxlpy.types import (
36
36
  IntegratorType,
37
37
  McSteadyStates,
38
- ProtocolByPars,
38
+ ProtocolScan,
39
39
  ResponseCoefficientsByPars,
40
- SteadyStates,
41
- TimeCourseByPars,
40
+ SteadyStateScan,
41
+ TimeCourseScan,
42
42
  )
43
43
 
44
44
  if TYPE_CHECKING:
@@ -69,7 +69,7 @@ class ParameterScanWorker(Protocol):
69
69
  y0: dict[str, float] | None,
70
70
  rel_norm: bool,
71
71
  integrator: IntegratorType,
72
- ) -> SteadyStates:
72
+ ) -> SteadyStateScan:
73
73
  """Call the worker function."""
74
74
  ...
75
75
 
@@ -81,7 +81,7 @@ def _parameter_scan_worker(
81
81
  y0: dict[str, float] | None,
82
82
  rel_norm: bool,
83
83
  integrator: IntegratorType,
84
- ) -> SteadyStates:
84
+ ) -> SteadyStateScan:
85
85
  """Worker function for parallel steady state scanning across parameter sets.
86
86
 
87
87
  This function executes a parameter scan for steady state solutions for a
@@ -125,7 +125,7 @@ def steady_state(
125
125
  rel_norm: bool = False,
126
126
  worker: SteadyStateWorker = _steady_state_worker,
127
127
  integrator: IntegratorType | None = None,
128
- ) -> SteadyStates:
128
+ ) -> SteadyStateScan:
129
129
  """Monte-carlo scan of steady states.
130
130
 
131
131
  Examples:
@@ -163,10 +163,14 @@ def steady_state(
163
163
  max_workers=max_workers,
164
164
  cache=cache,
165
165
  )
166
- return SteadyStates(
167
- variables=pd.concat({k: v.variables for k, v in res}, axis=1).T,
168
- fluxes=pd.concat({k: v.fluxes for k, v in res}, axis=1).T,
169
- parameters=mc_to_scan,
166
+ return SteadyStateScan(
167
+ raw_index=(
168
+ pd.Index(mc_to_scan.iloc[:, 0])
169
+ if mc_to_scan.shape[1] == 1
170
+ else pd.MultiIndex.from_frame(mc_to_scan)
171
+ ),
172
+ raw_results=[i[1] for i in res],
173
+ to_scan=mc_to_scan,
170
174
  )
171
175
 
172
176
 
@@ -180,7 +184,7 @@ def time_course(
180
184
  cache: Cache | None = None,
181
185
  worker: TimeCourseWorker = _time_course_worker,
182
186
  integrator: IntegratorType | None = None,
183
- ) -> TimeCourseByPars:
187
+ ) -> TimeCourseScan:
184
188
  """MC time course.
185
189
 
186
190
  Examples:
@@ -219,10 +223,9 @@ def time_course(
219
223
  cache=cache,
220
224
  )
221
225
 
222
- return TimeCourseByPars(
223
- parameters=mc_to_scan,
224
- variables=pd.concat({k: v.variables.T for k, v in res}, axis=1).T,
225
- fluxes=pd.concat({k: v.fluxes.T for k, v in res}, axis=1).T,
226
+ return TimeCourseScan(
227
+ to_scan=mc_to_scan,
228
+ raw_results=dict(res),
226
229
  )
227
230
 
228
231
 
@@ -237,7 +240,7 @@ def time_course_over_protocol(
237
240
  cache: Cache | None = None,
238
241
  worker: ProtocolWorker = _protocol_worker,
239
242
  integrator: IntegratorType | None = None,
240
- ) -> ProtocolByPars:
243
+ ) -> ProtocolScan:
241
244
  """MC time course.
242
245
 
243
246
  Examples:
@@ -277,13 +280,10 @@ def time_course_over_protocol(
277
280
  max_workers=max_workers,
278
281
  cache=cache,
279
282
  )
280
- concs = {k: v.variables.T for k, v in res}
281
- fluxes = {k: v.fluxes.T for k, v in res}
282
- return ProtocolByPars(
283
- variables=pd.concat(concs, axis=1).T,
284
- fluxes=pd.concat(fluxes, axis=1).T,
285
- parameters=mc_to_scan,
283
+ return ProtocolScan(
284
+ to_scan=mc_to_scan,
286
285
  protocol=protocol,
286
+ raw_results=dict(res),
287
287
  )
288
288
 
289
289
 
mxlpy/mca.py CHANGED
@@ -91,8 +91,12 @@ def _response_coefficient_worker(
91
91
  y0=None,
92
92
  )
93
93
 
94
- conc_resp = (upper.variables - lower.variables) / (2 * displacement * old)
95
- flux_resp = (upper.fluxes - lower.fluxes) / (2 * displacement * old)
94
+ conc_resp = (upper.variables.iloc[-1] - lower.variables.iloc[-1]) / (
95
+ 2 * displacement * old
96
+ )
97
+ flux_resp = (upper.fluxes.iloc[-1] - lower.fluxes.iloc[-1]) / (
98
+ 2 * displacement * old
99
+ )
96
100
  # Reset
97
101
  model.update_parameters({parameter: old})
98
102
  if normalized:
@@ -102,8 +106,8 @@ def _response_coefficient_worker(
102
106
  integrator=integrator,
103
107
  y0=None,
104
108
  )
105
- conc_resp *= old / norm.variables
106
- flux_resp *= old / norm.fluxes
109
+ conc_resp *= old / norm.variables.iloc[-1]
110
+ flux_resp *= old / norm.fluxes.iloc[-1]
107
111
  return conc_resp, flux_resp
108
112
 
109
113
 
@@ -110,7 +110,8 @@ def _generate_model_code(
110
110
  _LOGGER.warning(msg)
111
111
 
112
112
  # Return
113
- ret = ", ".join(f"d{i}dt" for i in diff_eqs) if len(diff_eqs) > 0 else "()"
113
+ ret_order = [i for i in variables if i in diff_eqs]
114
+ ret = ", ".join(f"d{i}dt" for i in ret_order) if len(diff_eqs) > 0 else "()"
114
115
  source.append(return_template.format(ret))
115
116
 
116
117
  if end is not None: