mxlpy 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/__init__.py CHANGED
@@ -42,6 +42,7 @@ from typing import TYPE_CHECKING
42
42
  import pandas as pd
43
43
 
44
44
  from . import (
45
+ compare,
45
46
  distributions,
46
47
  experimental,
47
48
  fit,
@@ -57,11 +58,7 @@ from .label_map import LabelMapper
57
58
  from .linear_label_map import LinearLabelMapper
58
59
  from .mc import Cache
59
60
  from .model import Model
60
- from .scan import (
61
- steady_state,
62
- time_course,
63
- time_course_over_protocol,
64
- )
61
+ from .scan import steady_state, time_course, time_course_over_protocol
65
62
  from .simulator import Simulator
66
63
  from .symbolic import SymbolicModel, to_symbolic_model
67
64
  from .types import Derived, IntegratorProtocol, unwrap
@@ -72,11 +69,7 @@ with contextlib.suppress(ImportError):
72
69
  if TYPE_CHECKING:
73
70
  from mxlpy.types import ArrayLike
74
71
 
75
- from . import (
76
- nn,
77
- npe,
78
- surrogates,
79
- )
72
+ from . import nn, npe, surrogates
80
73
  else:
81
74
  from lazy_import import lazy_module
82
75
 
@@ -98,6 +91,7 @@ __all__ = [
98
91
  "Simulator",
99
92
  "SymbolicModel",
100
93
  "cartesian_product",
94
+ "compare",
101
95
  "distributions",
102
96
  "experimental",
103
97
  "fit",
mxlpy/carousel.py ADDED
@@ -0,0 +1,166 @@
1
+ """Reaction carousel."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import itertools as it
6
+ from copy import deepcopy
7
+ from dataclasses import dataclass, field
8
+ from functools import partial
9
+ from typing import TYPE_CHECKING
10
+
11
+ import pandas as pd
12
+
13
+ from mxlpy import parallel, scan
14
+
15
+ __all__ = ["Carousel", "CarouselSteadyState", "CarouselTimeCourse", "ReactionTemplate"]
16
+
17
+ if TYPE_CHECKING:
18
+ from collections.abc import Iterable, Mapping
19
+
20
+ from mxlpy import Model
21
+ from mxlpy.types import Array, IntegratorType, RateFn
22
+
23
+
24
@dataclass
class ReactionTemplate:
    """Candidate rate law that can be swapped in for a reaction of a model.

    Attributes:
        fn: Rate function to install for the reaction.
        args: Names of the arguments handed to ``fn``.
        additional_parameters: Parameters that accompany this rate law;
            defaults to a fresh empty dict per instance.

    """

    fn: RateFn
    args: list[str]
    additional_parameters: dict[str, float] = field(default_factory=dict)
31
+
32
+
33
@dataclass
class CarouselSteadyState:
    """Steady-state results of a carousel simulation.

    Attributes:
        carousel: The model variants that were simulated.
        results: One steady-state result per simulated variant.

    """

    carousel: list[Model]
    results: list[scan.TimePoint]

    def get_variables_by_model(self) -> pd.DataFrame:
        """Get the steady-state variables, one row per model.

        Rows are labelled with the position of the result in ``results``;
        columns are the variable names.
        """
        per_model = {idx: res.variables for idx, res in enumerate(self.results)}
        # Building from the dict puts models in columns; transpose so that
        # every model becomes a row.
        return pd.DataFrame(per_model).transpose()
43
+
44
+
45
@dataclass
class CarouselTimeCourse:
    """Time-course results of a carousel simulation.

    Attributes:
        carousel: The model variants that were simulated.
        results: One time-course result per simulated variant.

    """

    carousel: list[Model]
    results: list[scan.TimeCourse]

    def get_variables_by_model(self) -> pd.DataFrame:
        """Get the time-course variables of all models in one frame.

        The per-model frames are concatenated with the position of the result
        in ``results`` as the outermost index level.
        """
        frames = {idx: res.variables for idx, res in enumerate(self.results)}
        return pd.concat(frames)
55
+
56
+
57
+ def _dict_product[T1, T2](d: Mapping[T1, Iterable[T2]]) -> Iterable[dict[T1, T2]]:
58
+ yield from (dict(zip(d.keys(), x, strict=True)) for x in it.product(*d.values()))
59
+
60
+
61
def _make_reaction_carousel(
    model: Model, rxns: dict[str, list[ReactionTemplate]]
) -> Iterable[Model]:
    """Yield one modified copy of ``model`` per combination of templates.

    Args:
        model: Base model; it is deep-copied for every variant and never
            modified itself.
        rxns: Candidate templates keyed by the name of the reaction they
            replace.

    Yields:
        A copy of ``model`` in which each reaction named in ``rxns`` uses one
        of its candidate templates.

    """
    for assignment in _dict_product(rxns):
        variant = deepcopy(model)
        for rxn_name, chosen in assignment.items():
            # Parameters go in before the reaction that refers to them.
            variant.add_parameters(chosen.additional_parameters)
            variant.update_reaction(name=rxn_name, fn=chosen.fn, args=chosen.args)
        yield variant
70
+
71
+
72
class Carousel:
    """A carousel of models with different reaction templates.

    Attributes:
        variants: The model variants generated from the base model, one per
            combination of reaction templates.

    """

    variants: list[Model]

    def __init__(
        self,
        model: Model,
        variants: dict[str, list[ReactionTemplate]],
    ) -> None:
        """Initialize the carousel with a model and reaction templates.

        Args:
            model: Base model from which every variant is derived.
            variants: Candidate templates keyed by reaction name.

        """
        self.variants = list(_make_reaction_carousel(model=model, rxns=variants))

    def _run(self, worker: partial) -> list:
        """Apply ``worker`` to every variant and collect the results.

        The inputs handed to ``parallel.parallelise`` are ``(index, model)``
        pairs; the second element of every returned item is kept.
        """
        return [
            item[1]
            for item in parallel.parallelise(
                worker,
                list(enumerate(self.variants)),
            )
        ]

    def time_course(
        self,
        time_points: Array,
        *,
        y0: dict[str, float] | None = None,
        integrator: IntegratorType | None = None,
    ) -> CarouselTimeCourse:
        """Simulate the carousel of models over a time course.

        Args:
            time_points: Time points forwarded to the time-course worker.
            y0: Optional initial values forwarded to the worker.
            integrator: Optional integrator forwarded to the worker.

        Returns:
            The variants together with one time-course result per variant.

        """
        results = self._run(
            partial(
                scan._time_course_worker,  # noqa: SLF001
                time_points=time_points,
                integrator=integrator,
                y0=y0,
            )
        )
        return CarouselTimeCourse(carousel=self.variants, results=results)

    def protocol_time_course(
        self,
        protocol: pd.DataFrame,
        *,
        y0: dict[str, float] | None = None,
        integrator: IntegratorType | None = None,
    ) -> CarouselTimeCourse:
        """Simulate the carousel of models over a protocol time course.

        Args:
            protocol: Protocol forwarded to the protocol worker.
            y0: Optional initial values forwarded to the worker.
            integrator: Optional integrator forwarded to the worker.

        Returns:
            The variants together with one time-course result per variant.

        """
        results = self._run(
            partial(
                scan._protocol_worker,  # noqa: SLF001
                protocol=protocol,
                integrator=integrator,
                y0=y0,
            )
        )
        return CarouselTimeCourse(carousel=self.variants, results=results)

    def steady_state(
        self,
        *,
        y0: dict[str, float] | None = None,
        integrator: IntegratorType | None = None,
        rel_norm: bool = False,
    ) -> CarouselSteadyState:
        """Simulate every model of the carousel to steady state.

        Args:
            y0: Optional initial values forwarded to the worker.
            integrator: Optional integrator forwarded to the worker.
            rel_norm: Forwarded unchanged to the steady-state worker.

        Returns:
            The variants together with one steady-state result per variant.

        """
        results = self._run(
            partial(
                scan._steady_state_worker,  # noqa: SLF001
                integrator=integrator,
                rel_norm=rel_norm,
                y0=y0,
            )
        )
        return CarouselSteadyState(carousel=self.variants, results=results)
mxlpy/compare.py ADDED
@@ -0,0 +1,240 @@
1
+ """Compare the simulation results of two models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import TYPE_CHECKING, cast
7
+
8
+ import pandas as pd
9
+
10
+ from mxlpy import plot
11
+ from mxlpy.simulator import Result, Simulator
12
+ from mxlpy.types import unwrap
13
+
14
+ if TYPE_CHECKING:
15
+ from mxlpy.model import Model
16
+ from mxlpy.types import ArrayLike
17
+
18
+ __all__ = [
19
+ "ProtocolComparison",
20
+ "SteadyStateComparison",
21
+ "TimeCourseComparison",
22
+ "protocol_time_courses",
23
+ "steady_states",
24
+ "time_courses",
25
+ ]
26
+
27
+
28
@dataclass
class SteadyStateComparison:
    """Compare two steady states.

    The last row of each result is taken as its steady state. Every comparison
    table has the columns ``m1``, ``m2``, ``diff`` (``m2 - m1``) and
    ``rel_diff`` (``diff / m1``). Zero entries in ``m1`` get no special
    treatment in ``rel_diff``.

    Attributes:
        res1: Result of the first model (reference for ``rel_diff``).
        res2: Result of the second model.

    """

    res1: Result
    res2: Result

    @staticmethod
    def _compare(ss1: pd.Series, ss2: pd.Series) -> pd.DataFrame:
        """Tabulate two steady states with absolute and relative difference."""
        diff = ss2 - ss1
        return pd.DataFrame(
            {"m1": ss1, "m2": ss2, "diff": diff, "rel_diff": diff / ss1}
        )

    @staticmethod
    def _plot_rel_diff(table: pd.DataFrame, title: str) -> plot.FigAxs:
        """Draw the ``rel_diff`` column of ``table`` as grouped bars."""
        fig, axs = plot.bars_autogrouped(table["rel_diff"], ylabel="")
        plot.grid_labels(axs, ylabel="Relative difference")
        fig.suptitle(title)
        return fig, axs

    @property
    def variables(self) -> pd.DataFrame:
        """Compare the steady state variables."""
        return self._compare(
            self.res1.get_variables().iloc[-1],
            self.res2.get_variables().iloc[-1],
        )

    @property
    def fluxes(self) -> pd.DataFrame:
        """Compare the steady state fluxes."""
        return self._compare(
            self.res1.get_fluxes().iloc[-1],
            self.res2.get_fluxes().iloc[-1],
        )

    @property
    def all(self) -> pd.DataFrame:
        """Compare both steady-state variables and fluxes."""
        return self._compare(
            self.res1.get_combined().iloc[-1],
            self.res2.get_combined().iloc[-1],
        )

    def plot_variables(self, title: str = "Variables") -> plot.FigAxs:
        """Plot the relative difference of steady-state variables."""
        return self._plot_rel_diff(self.variables, title)

    def plot_fluxes(self, title: str = "Fluxes") -> plot.FigAxs:
        """Plot the relative difference of steady-state fluxes."""
        return self._plot_rel_diff(self.fluxes, title)

    def plot_all(self, title: str = "Variables and Fluxes") -> plot.FigAxs:
        """Plot the relative difference of steady-state variables and fluxes."""
        return self._plot_rel_diff(self.all, title)
87
+
88
+
89
@dataclass
class TimeCourseComparison:
    """Compare two time courses.

    Attributes:
        res1: Result of the first model (reference for relative differences).
        res2: Result of the second model.

    """

    res1: Result
    res2: Result

    # TODO: tabular `variables` / `fluxes` comparisons are not implemented yet.

    @staticmethod
    def _plot_relative_difference(
        first: pd.DataFrame,
        second: pd.DataFrame,
        title: str,
    ) -> plot.FigAxs:
        """Plot ``(second - first) / first`` as grouped line plots.

        ``second`` is restricted to the columns of ``first`` before
        subtracting. NaN values in the quotient are replaced by 0.
        """
        aligned = second.loc[:, cast(list[str], first.columns)]
        rel_diff = ((aligned - first) / first).fillna(0)
        fig, axs = plot.line_autogrouped(rel_diff, ylabel="")
        plot.grid_labels(axs, ylabel="Relative difference")
        fig.suptitle(title)
        return fig, axs

    def plot_variables_relative_difference(self) -> plot.FigAxs:
        """Plot the relative difference of time course variables."""
        return self._plot_relative_difference(
            self.res1.variables,
            self.res2.variables,
            "Variables",
        )

    def plot_fluxes_relative_difference(self) -> plot.FigAxs:
        """Plot the relative difference of time course fluxes."""
        return self._plot_relative_difference(
            self.res1.fluxes,
            self.res2.fluxes,
            "Fluxes",
        )
137
+
138
+
139
@dataclass
class ProtocolComparison:
    """Compare two protocol time courses.

    Attributes:
        res1: Result of the first model (reference for relative differences).
        res2: Result of the second model.
        protocol: The protocol both results were simulated over; its columns
            can be used to shade the plots.

    """

    res1: Result
    res2: Result
    protocol: pd.DataFrame

    # TODO: tabular `variables` / `fluxes` comparisons are not implemented yet.

    def _plot_relative_difference(
        self,
        first: pd.DataFrame,
        second: pd.DataFrame,
        title: str,
        shade_protocol_variable: str | None,
    ) -> plot.FigAxs:
        """Plot ``(second - first) / first`` and optionally shade the protocol.

        ``second`` is restricted to the columns of ``first`` before
        subtracting. NaN values in the quotient are replaced by 0. When
        ``shade_protocol_variable`` is given, that protocol column is shaded
        onto every axis.
        """
        aligned = second.loc[:, cast(list[str], first.columns)]
        rel_diff = ((aligned - first) / first).fillna(0)
        fig, axs = plot.line_autogrouped(rel_diff, ylabel="")
        plot.grid_labels(axs, ylabel="Relative difference")
        fig.suptitle(title)

        if shade_protocol_variable is not None:
            protocol = self.protocol[shade_protocol_variable]
            for ax in axs:
                plot.shade_protocol(protocol=protocol, ax=ax)
        return fig, axs

    def plot_variables_relative_difference(
        self,
        shade_protocol_variable: str | None = None,
    ) -> plot.FigAxs:
        """Plot the relative difference of time course variables.

        Args:
            shade_protocol_variable: Optional protocol column to shade onto
                every axis.

        """
        return self._plot_relative_difference(
            self.res1.variables,
            self.res2.variables,
            "Variables",
            shade_protocol_variable,
        )

    def plot_fluxes_relative_difference(
        self,
        shade_protocol_variable: str | None = None,
    ) -> plot.FigAxs:
        """Plot the relative difference of time course fluxes.

        Args:
            shade_protocol_variable: Optional protocol column to shade onto
                every axis.

        """
        return self._plot_relative_difference(
            self.res1.fluxes,
            self.res2.fluxes,
            "Fluxes",
            shade_protocol_variable,
        )
204
+
205
+
206
def steady_states(m1: Model, m2: Model) -> SteadyStateComparison:
    """Simulate both models to steady state and pair up the results.

    Args:
        m1: First model; its result becomes ``res1``.
        m2: Second model; its result becomes ``res2``.

    Returns:
        A comparison object holding the unwrapped results of both runs.

    """
    first, second = (
        unwrap(Simulator(m).simulate_to_steady_state().get_result())
        for m in (m1, m2)
    )
    return SteadyStateComparison(res1=first, res2=second)
212
+
213
+
214
def time_courses(m1: Model, m2: Model, time_points: ArrayLike) -> TimeCourseComparison:
    """Simulate both models over the same time points and pair up the results.

    Args:
        m1: First model; its result becomes ``res1``.
        m2: Second model; its result becomes ``res2``.
        time_points: Time points used for both simulations.

    Returns:
        A comparison object holding the unwrapped results of both runs.

    """
    first, second = (
        unwrap(Simulator(m).simulate_time_course(time_points=time_points).get_result())
        for m in (m1, m2)
    )
    return TimeCourseComparison(res1=first, res2=second)
224
+
225
+
226
def protocol_time_courses(
    m1: Model,
    m2: Model,
    protocol: pd.DataFrame,
) -> ProtocolComparison:
    """Simulate both models over the same protocol and pair up the results.

    Args:
        m1: First model; its result becomes ``res1``.
        m2: Second model; its result becomes ``res2``.
        protocol: Protocol used for both simulations; also stored on the
            returned comparison.

    Returns:
        A comparison object holding the unwrapped results of both runs and
        the protocol.

    """
    first, second = (
        unwrap(Simulator(m).simulate_over_protocol(protocol=protocol).get_result())
        for m in (m1, m2)
    )
    return ProtocolComparison(res1=first, res2=second, protocol=protocol)
@@ -1,12 +1,24 @@
1
1
  """Diffing utilities for comparing models."""
2
2
 
3
- from collections.abc import Mapping
3
+ from __future__ import annotations
4
+
4
5
  from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING
5
7
 
6
- from mxlpy.model import Model
7
8
  from mxlpy.types import Derived
8
9
 
9
- __all__ = ["DerivedDiff", "ModelDiff", "ReactionDiff", "model_diff", "soft_eq"]
10
+ if TYPE_CHECKING:
11
+ from collections.abc import Mapping
12
+
13
+ from mxlpy.model import Model
14
+
15
+ __all__ = [
16
+ "DerivedDiff",
17
+ "ModelDiff",
18
+ "ReactionDiff",
19
+ "model_diff",
20
+ "soft_eq",
21
+ ]
10
22
 
11
23
 
12
24
  @dataclass