mxlpy 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/__init__.py CHANGED
@@ -52,17 +52,24 @@ from . import (
52
52
  plot,
53
53
  report,
54
54
  sbml,
55
+ scan,
55
56
  units,
56
57
  )
57
- from .integrators import DefaultIntegrator, Scipy
58
+ from .integrators import DefaultIntegrator, Diffrax, Scipy
58
59
  from .label_map import LabelMapper
59
60
  from .linear_label_map import LinearLabelMapper
60
61
  from .mc import Cache
61
62
  from .model import Model
62
- from .scan import steady_state, time_course, time_course_over_protocol
63
63
  from .simulator import Simulator
64
64
  from .symbolic import SymbolicModel, to_symbolic_model
65
- from .types import Derived, IntegratorProtocol, Parameter, Variable, unwrap
65
+ from .types import (
66
+ Derived,
67
+ InitialAssignment,
68
+ IntegratorProtocol,
69
+ Parameter,
70
+ Variable,
71
+ unwrap,
72
+ )
66
73
 
67
74
  with contextlib.suppress(ImportError):
68
75
  from .integrators import Assimulo
@@ -84,6 +91,8 @@ __all__ = [
84
91
  "Cache",
85
92
  "DefaultIntegrator",
86
93
  "Derived",
94
+ "Diffrax",
95
+ "InitialAssignment",
87
96
  "IntegratorProtocol",
88
97
  "LabelMapper",
89
98
  "LinearLabelMapper",
@@ -107,11 +116,9 @@ __all__ = [
107
116
  "plot",
108
117
  "report",
109
118
  "sbml",
110
- "steady_state",
119
+ "scan",
111
120
  "surrogates",
112
121
  "symbolic",
113
- "time_course",
114
- "time_course_over_protocol",
115
122
  "to_symbolic_model",
116
123
  "units",
117
124
  "unwrap",
mxlpy/carousel.py CHANGED
@@ -9,6 +9,7 @@ from functools import partial
9
9
  from typing import TYPE_CHECKING
10
10
 
11
11
  import pandas as pd
12
+ from wadler_lindig import pformat
12
13
 
13
14
  from mxlpy import parallel, scan
14
15
 
@@ -18,13 +19,17 @@ if TYPE_CHECKING:
18
19
  from collections.abc import Iterable, Mapping
19
20
 
20
21
  from mxlpy import Model
21
- from mxlpy.types import Array, IntegratorType, RateFn
22
+ from mxlpy.types import Array, IntegratorType, RateFn, Result
22
23
 
23
24
 
24
25
  @dataclass
25
26
  class ReactionTemplate:
26
27
  """Template for a reaction in a model."""
27
28
 
29
+ def __repr__(self) -> str:
30
+ """Return default representation."""
31
+ return pformat(self)
32
+
28
33
  fn: RateFn
29
34
  args: list[str]
30
35
  additional_parameters: dict[str, float] = field(default_factory=dict)
@@ -35,11 +40,17 @@ class CarouselSteadyState:
35
40
  """Time course of a carousel simulation."""
36
41
 
37
42
  carousel: list[Model]
38
- results: list[scan.TimePoint]
43
+ results: list[Result]
44
+
45
+ def __repr__(self) -> str:
46
+ """Return default representation."""
47
+ return pformat(self)
39
48
 
40
49
  def get_variables_by_model(self) -> pd.DataFrame:
41
50
  """Get the variables of the time course results, indexed by model."""
42
- return pd.DataFrame({i: r.variables for i, r in enumerate(self.results)}).T
51
+ return pd.DataFrame(
52
+ {i: r.variables.iloc[-1] for i, r in enumerate(self.results)}
53
+ ).T
43
54
 
44
55
 
45
56
  @dataclass
@@ -47,7 +58,11 @@ class CarouselTimeCourse:
47
58
  """Time course of a carousel simulation."""
48
59
 
49
60
  carousel: list[Model]
50
- results: list[scan.TimeCourse]
61
+ results: list[Result]
62
+
63
+ def __repr__(self) -> str:
64
+ """Return default representation."""
65
+ return pformat(self)
51
66
 
52
67
  def get_variables_by_model(self) -> pd.DataFrame:
53
68
  """Get the variables of the time course results, indexed by model."""
@@ -74,6 +89,10 @@ class Carousel:
74
89
 
75
90
  variants: list[Model]
76
91
 
92
+ def __repr__(self) -> str:
93
+ """Return default representation."""
94
+ return pformat(self)
95
+
77
96
  def __init__(
78
97
  self,
79
98
  model: Model,
@@ -113,7 +132,7 @@ class Carousel:
113
132
  results=results,
114
133
  )
115
134
 
116
- def protocol_time_course(
135
+ def protocol(
117
136
  self,
118
137
  protocol: pd.DataFrame,
119
138
  *,
@@ -139,6 +158,34 @@ class Carousel:
139
158
  results=results,
140
159
  )
141
160
 
161
def protocol_time_course(
    self,
    protocol: pd.DataFrame,
    time_points: Array,
    *,
    y0: dict[str, float] | None = None,
    integrator: IntegratorType | None = None,
) -> CarouselTimeCourse:
    """Run a protocol time course for every model variant of the carousel.

    Args:
        protocol: Protocol forwarded unchanged to the scan worker.
        time_points: Time points forwarded unchanged to the scan worker.
        y0: Optional initial conditions forwarded to the scan worker.
        integrator: Optional integrator forwarded to the scan worker.

    Returns:
        CarouselTimeCourse holding the variants and one result per variant.

    """
    # Bind the arguments shared by all variants once.
    worker = partial(
        scan._protocol_time_course_worker,  # noqa: SLF001
        protocol=protocol,
        integrator=integrator,
        time_points=time_points,
        y0=y0,
    )
    indexed_variants = list(enumerate(self.variants))

    # Each item returned by parallelise is indexed at position 1 to obtain
    # the worker's result; the first element is discarded.
    results = [
        item[1] for item in parallel.parallelise(worker, indexed_variants)
    ]

    return CarouselTimeCourse(carousel=self.variants, results=results)
188
+
142
189
  def steady_state(
143
190
  self,
144
191
  *,
mxlpy/compare.py CHANGED
@@ -6,10 +6,11 @@ from dataclasses import dataclass
6
6
  from typing import TYPE_CHECKING, cast
7
7
 
8
8
  import pandas as pd
9
+ from wadler_lindig import pformat
9
10
 
10
11
  from mxlpy import plot
11
- from mxlpy.simulator import Result, Simulator
12
- from mxlpy.types import unwrap
12
+ from mxlpy.simulator import Simulator
13
+ from mxlpy.types import Result, unwrap
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from mxlpy.model import Model
@@ -32,6 +33,10 @@ class SteadyStateComparison:
32
33
  res1: Result
33
34
  res2: Result
34
35
 
36
+ def __repr__(self) -> str:
37
+ """Return default representation."""
38
+ return pformat(self)
39
+
35
40
  @property
36
41
  def variables(self) -> pd.DataFrame:
37
42
  """Compare the steady state variables."""
@@ -93,6 +98,10 @@ class TimeCourseComparison:
93
98
  res1: Result
94
99
  res2: Result
95
100
 
101
+ def __repr__(self) -> str:
102
+ """Return default representation."""
103
+ return pformat(self)
104
+
96
105
  # @property
97
106
  # def variables(self) -> pd.DataFrame:
98
107
  # """Compare the steady state variables."""
@@ -144,6 +153,10 @@ class ProtocolComparison:
144
153
  res2: Result
145
154
  protocol: pd.DataFrame
146
155
 
156
+ def __repr__(self) -> str:
157
+ """Return default representation."""
158
+ return pformat(self)
159
+
147
160
  # @property
148
161
  # def variables(self) -> pd.DataFrame:
149
162
  # """Compare the steady state variables."""
@@ -5,6 +5,8 @@ from __future__ import annotations
5
5
  from dataclasses import dataclass, field
6
6
  from typing import TYPE_CHECKING
7
7
 
8
+ from wadler_lindig import pformat
9
+
8
10
  from mxlpy.types import Derived
9
11
 
10
12
  if TYPE_CHECKING:
@@ -28,6 +30,10 @@ class DerivedDiff:
28
30
  args1: list[str] = field(default_factory=list)
29
31
  args2: list[str] = field(default_factory=list)
30
32
 
33
+ def __repr__(self) -> str:
34
+ """Return default representation."""
35
+ return pformat(self)
36
+
31
37
 
32
38
  @dataclass
33
39
  class ReactionDiff:
@@ -38,6 +44,10 @@ class ReactionDiff:
38
44
  stoichiometry1: dict[str, float | Derived] = field(default_factory=dict)
39
45
  stoichiometry2: dict[str, float | Derived] = field(default_factory=dict)
40
46
 
47
+ def __repr__(self) -> str:
48
+ """Return default representation."""
49
+ return pformat(self)
50
+
41
51
 
42
52
  @dataclass
43
53
  class ModelDiff:
@@ -56,6 +66,10 @@ class ModelDiff:
56
66
  different_readouts: dict[str, DerivedDiff] = field(default_factory=dict)
57
67
  different_derived: dict[str, DerivedDiff] = field(default_factory=dict)
58
68
 
69
+ def __repr__(self) -> str:
70
+ """Return default representation."""
71
+ return pformat(self)
72
+
59
73
  def __str__(self) -> str:
60
74
  """Return a human-readable string representation of the diff."""
61
75
  content = ["Model Diff", "----------"]
mxlpy/fit.py CHANGED
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Protocol
18
18
 
19
19
  import numpy as np
20
20
  from scipy.optimize import minimize
21
+ from wadler_lindig import pformat
21
22
 
22
23
  from mxlpy import parallel
23
24
  from mxlpy.simulator import Simulator
@@ -61,6 +62,10 @@ class MinResult:
61
62
  parameters: dict[str, float]
62
63
  residual: float
63
64
 
65
+ def __repr__(self) -> str:
66
+ """Return default representation."""
67
+ return pformat(self)
68
+
64
69
 
65
70
  @dataclass
66
71
  class FitResult:
@@ -70,6 +75,10 @@ class FitResult:
70
75
  best_pars: dict[str, float]
71
76
  loss: float
72
77
 
78
+ def __repr__(self) -> str:
79
+ """Return default representation."""
80
+ return pformat(self)
81
+
73
82
 
74
83
  @dataclass
75
84
  class CarouselFit:
@@ -77,6 +86,10 @@ class CarouselFit:
77
86
 
78
87
  fits: list[FitResult]
79
88
 
89
+ def __repr__(self) -> str:
90
+ """Return default representation."""
91
+ return pformat(self)
92
+
80
93
  def get_best_fit(self) -> FitResult:
81
94
  """Get the best fit from the carousel."""
82
95
  return min(self.fits, key=lambda x: x.loss)
@@ -561,6 +574,8 @@ def protocol_time_course(
561
574
  ) -> FitResult | None:
562
575
  """Fit model parameters to time course of experimental data.
563
576
 
577
+ Time points of protocol time course are taken from the data.
578
+
564
579
  Examples:
565
580
  >>> time_course(model, p0, data)
566
581
  {'k1': 0.1, 'k2': 0.2}
@@ -688,8 +703,10 @@ def carousel_time_course(
688
703
  ) -> CarouselFit:
689
704
  """Fit model parameters to time course of experimental data over a carousel.
690
705
 
706
+ Time points are taken from the data.
707
+
691
708
  Examples:
692
- >>> carousel_steady_state(carousel, p0=p0, data=data)
709
+ >>> carousel_time_course(carousel, p0=p0, data=data)
693
710
 
694
711
  Args:
695
712
  carousel: Model carousel to fit
@@ -748,6 +765,8 @@ def carousel_protocol_time_course(
748
765
  ) -> CarouselFit:
749
766
  """Fit model parameters to time course of experimental data over a protocol.
750
767
 
768
+ Time points of protocol time course are taken from the data.
769
+
751
770
  Examples:
752
771
  >>> carousel_steady_state(carousel, p0=p0, data=data)
753
772
 
@@ -6,6 +6,7 @@ It includes support for both Assimulo and Scipy integrators, with Assimulo being
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ from .int_diffrax import Diffrax
9
10
  from .int_scipy import Scipy
10
11
 
11
12
  try:
@@ -16,5 +17,8 @@ except ImportError:
16
17
  DefaultIntegrator = Scipy
17
18
 
18
19
  __all__ = [
20
+ "Assimulo",
19
21
  "DefaultIntegrator",
22
+ "Diffrax",
23
+ "Scipy",
20
24
  ]
@@ -17,7 +17,7 @@ with contextlib.redirect_stderr(open(os.devnull, "w")): # noqa: PTH123
17
17
  if TYPE_CHECKING:
18
18
  from collections.abc import Callable
19
19
 
20
- from mxlpy.types import Array, ArrayLike
20
+ from mxlpy.types import Array, ArrayLike, Rhs
21
21
 
22
22
 
23
23
  __all__ = [
@@ -43,8 +43,8 @@ class Assimulo:
43
43
 
44
44
  """
45
45
 
46
- rhs: Callable
47
- y0: ArrayLike
46
+ rhs: Rhs
47
+ y0: tuple[float, ...]
48
48
  jacobian: Callable | None = None
49
49
  atol: float = 1e-8
50
50
  rtol: float = 1e-8
@@ -0,0 +1,119 @@
1
+ """Diffrax integrator for solving ODEs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING
7
+
8
+ import numpy as np
9
+ from diffrax import (
10
+ AbstractSolver,
11
+ AbstractStepSizeController,
12
+ Kvaerno5,
13
+ ODETerm,
14
+ PIDController,
15
+ SaveAt,
16
+ diffeqsolve,
17
+ )
18
+
19
+ __all__ = ["Diffrax"]
20
+
21
+ if TYPE_CHECKING:
22
+ from collections.abc import Callable
23
+
24
+ from mxlpy.types import Array, Rhs
25
+
26
+
27
@dataclass
class Diffrax:
    """Diffrax integrator for solving ODEs.

    The integrator is stateful: every integration moves ``t0`` and ``y0`` to
    the last computed time point, so consecutive calls continue from where the
    previous one stopped. ``reset`` restores the initial state.

    Attributes:
        rhs: Right-hand side of the ODE system, called as ``rhs(t, y)``.
        y0: Current state of the system.
        jac: Optional Jacobian; stored but not used by any method of this class.
        solver: Diffrax solver (default: ``Kvaerno5()``).
        stepsize_controller: Step size controller
            (default: ``PIDController(rtol=1e-8, atol=1e-8)``).
        t0: Current time (default: 0.0).

    """

    rhs: Rhs
    y0: tuple[float, ...]
    jac: Callable | None = None
    solver: AbstractSolver = field(default=Kvaerno5())
    stepsize_controller: AbstractStepSizeController = field(
        default=PIDController(rtol=1e-8, atol=1e-8)
    )
    t0: float = 0.0

    def __post_init__(self) -> None:
        """Remember the initial state.

        The initial state `y0` is stored in the `_y0_orig` attribute so that
        `reset` can restore it after an integration has advanced `y0`.

        """
        self._y0_orig = self.y0

    def reset(self) -> None:
        """Reset the integrator.

        Restores `y0` to the state given at construction and sets `t0` to 0.
        """
        self.t0 = 0
        self.y0 = self._y0_orig

    def integrate_time_course(
        self, *, time_points: Array
    ) -> tuple[Array | None, Array | None]:
        """Integrate the ODE system over a time course.

        Args:
            time_points: Time points for the integration.

        Returns:
            tuple[Array | None, Array | None]: Tuple containing the time points
            and the integrated values. If the first requested time point
            differs from the current `t0`, `t0` is prepended to the time points.

        """
        if time_points[0] != self.t0:
            # Integration starts at the current time, so it has to be the
            # first save point.
            time_points = np.insert(time_points, 0, self.t0)

        res = diffeqsolve(
            ODETerm(lambda t, y, _: self.rhs(t, y)),  # type: ignore
            solver=self.solver,
            t0=time_points[0],
            t1=time_points[-1],
            dt0=None,
            y0=self.y0,
            max_steps=None,
            saveat=SaveAt(ts=time_points),  # type: ignore
            stepsize_controller=self.stepsize_controller,
        )

        t = np.atleast_1d(np.array(res.ts, dtype=float))
        y = np.atleast_2d(np.array(res.ys, dtype=float).T)

        # Write the state back in the declared form (plain float / tuple of
        # floats) instead of NumPy objects, so a follow-up call hands the same
        # kind of y0 to diffeqsolve as the first call did.
        # NOTE(review): assumes rhs returns a tuple matching y0 - confirm
        # against mxlpy.types.Rhs.
        self.t0 = float(t[-1])
        self.y0 = tuple(float(value) for value in y[-1])
        return t, y

    def integrate(
        self,
        *,
        t_end: float,
        steps: int | None = None,
    ) -> tuple[Array | None, Array | None]:
        """Integrate the ODE system from the current time until `t_end`.

        Args:
            t_end: Final time point of the integration.
            steps: Number of evenly spaced time points between the current
                `t0` and `t_end` (default: 100).

        Returns:
            tuple[Array | None, Array | None]: Tuple containing the time points
            and the integrated values.

        """
        steps = 100 if steps is None else steps

        return self.integrate_time_course(
            time_points=np.linspace(self.t0, t_end, steps, dtype=float)
        )

    def integrate_to_steady_state(
        self,
        *,
        tolerance: float,
        rel_norm: bool,
        t_max: float = 1_000_000_000,
    ) -> tuple[float | None, Array | None]:
        """Integrate the ODE system to steady state.

        Args:
            tolerance: Tolerance for determining steady state.
            rel_norm: Whether to use relative normalization.
            t_max: Maximum time point for the integration (default: 1,000,000,000).

        Returns:
            tuple[float | None, Array | None]: Tuple containing the final time point and the integrated values at steady state.

        Raises:
            NotImplementedError: Always; steady-state integration is not
                implemented for this integrator.

        """
        raise NotImplementedError
@@ -14,6 +14,8 @@ from mxlpy.types import Array, ArrayLike
14
14
  if TYPE_CHECKING:
15
15
  from collections.abc import Callable
16
16
 
17
+ from mxlpy.types import Rhs
18
+
17
19
 
18
20
  __all__ = [
19
21
  "Scipy",
@@ -40,13 +42,13 @@ class Scipy:
40
42
 
41
43
  """
42
44
 
43
- rhs: Callable
44
- y0: ArrayLike
45
+ rhs: Rhs
46
+ y0: tuple[float, ...]
45
47
  jacobian: Callable | None = None
46
48
  atol: float = 1e-8
47
49
  rtol: float = 1e-8
48
50
  t0: float = 0.0
49
- _y0_orig: ArrayLike = field(default_factory=list)
51
+ _y0_orig: tuple[float, ...] = field(default_factory=tuple)
50
52
 
51
53
  def __post_init__(self) -> None:
52
54
  """Create copy of initial state.
@@ -55,12 +57,12 @@ class Scipy:
55
57
  This is useful for preserving the original initial state for future reference or reset operations.
56
58
 
57
59
  """
58
- self._y0_orig = self.y0.copy()
60
+ self._y0_orig = self.y0
59
61
 
60
62
  def reset(self) -> None:
61
63
  """Reset the integrator."""
62
64
  self.t0 = 0
63
- self.y0 = self._y0_orig.copy()
65
+ self.y0 = self._y0_orig
64
66
 
65
67
  def integrate(
66
68
  self,
@@ -143,9 +145,13 @@ class Scipy:
143
145
 
144
146
  """
145
147
  self.reset()
146
- integ = spi.ode(self.rhs, jac=self.jacobian)
148
+
149
+ # If rhs returns a tuple, we get weird errors, so we need
150
+ # to wrap this in a list for some reason
151
+ integ = spi.ode(lambda t, x: list(self.rhs(t, x)), jac=self.jacobian)
147
152
  integ.set_integrator(name="lsoda")
148
153
  integ.set_initial_value(self.y0)
154
+
149
155
  t = self.t0 + step_size
150
156
  y1 = copy.deepcopy(self.y0)
151
157
  for _ in range(max_steps):
mxlpy/label_map.py CHANGED
@@ -27,7 +27,6 @@ from mxlpy.model import Model
27
27
  if TYPE_CHECKING:
28
28
  from collections.abc import Callable, Mapping
29
29
 
30
- from mxlpy.types import Derived
31
30
 
32
31
  __all__ = [
33
32
  "LabelMapper",
@@ -556,7 +555,7 @@ class LabelMapper:
556
555
  for name, dp in self.model.get_derived_parameters().items():
557
556
  m.add_derived(name, fn=dp.fn, args=dp.args)
558
557
 
559
- variables: dict[str, float | Derived] = {}
558
+ variables: dict[str, float] = {}
560
559
  for k, v in self.model.get_initial_conditions().items():
561
560
  if (isos := isotopomers.get(k)) is None:
562
561
  variables[k] = v