mxlpy 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +2 -4
- mxlpy/carousel.py +46 -1
- mxlpy/compare.py +13 -0
- mxlpy/experimental/diff.py +14 -0
- mxlpy/fit.py +20 -1
- mxlpy/mc.py +63 -2
- mxlpy/meta/__init__.py +6 -1
- mxlpy/meta/codegen_latex.py +9 -0
- mxlpy/meta/codegen_model.py +53 -11
- mxlpy/meta/codegen_mxlpy.py +21 -0
- mxlpy/meta/source_tools.py +5 -0
- mxlpy/meta/sympy_tools.py +7 -1
- mxlpy/model.py +26 -5
- mxlpy/plot.py +44 -15
- mxlpy/sbml/_data.py +34 -0
- mxlpy/sbml/_export.py +4 -3
- mxlpy/scan.py +125 -7
- mxlpy/simulator.py +5 -0
- mxlpy/types.py +73 -2
- {mxlpy-0.23.0.dist-info → mxlpy-0.24.0.dist-info}/METADATA +5 -2
- {mxlpy-0.23.0.dist-info → mxlpy-0.24.0.dist-info}/RECORD +23 -23
- {mxlpy-0.23.0.dist-info → mxlpy-0.24.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.23.0.dist-info → mxlpy-0.24.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/__init__.py
CHANGED
@@ -52,6 +52,7 @@ from . import (
|
|
52
52
|
plot,
|
53
53
|
report,
|
54
54
|
sbml,
|
55
|
+
scan,
|
55
56
|
units,
|
56
57
|
)
|
57
58
|
from .integrators import DefaultIntegrator, Diffrax, Scipy
|
@@ -59,7 +60,6 @@ from .label_map import LabelMapper
|
|
59
60
|
from .linear_label_map import LinearLabelMapper
|
60
61
|
from .mc import Cache
|
61
62
|
from .model import Model
|
62
|
-
from .scan import steady_state, time_course, time_course_over_protocol
|
63
63
|
from .simulator import Simulator
|
64
64
|
from .symbolic import SymbolicModel, to_symbolic_model
|
65
65
|
from .types import (
|
@@ -116,11 +116,9 @@ __all__ = [
|
|
116
116
|
"plot",
|
117
117
|
"report",
|
118
118
|
"sbml",
|
119
|
-
"
|
119
|
+
"scan",
|
120
120
|
"surrogates",
|
121
121
|
"symbolic",
|
122
|
-
"time_course",
|
123
|
-
"time_course_over_protocol",
|
124
122
|
"to_symbolic_model",
|
125
123
|
"units",
|
126
124
|
"unwrap",
|
mxlpy/carousel.py
CHANGED
@@ -9,6 +9,7 @@ from functools import partial
|
|
9
9
|
from typing import TYPE_CHECKING
|
10
10
|
|
11
11
|
import pandas as pd
|
12
|
+
from wadler_lindig import pformat
|
12
13
|
|
13
14
|
from mxlpy import parallel, scan
|
14
15
|
|
@@ -25,6 +26,10 @@ if TYPE_CHECKING:
|
|
25
26
|
class ReactionTemplate:
|
26
27
|
"""Template for a reaction in a model."""
|
27
28
|
|
29
|
+
def __repr__(self) -> str:
|
30
|
+
"""Return default representation."""
|
31
|
+
return pformat(self)
|
32
|
+
|
28
33
|
fn: RateFn
|
29
34
|
args: list[str]
|
30
35
|
additional_parameters: dict[str, float] = field(default_factory=dict)
|
@@ -37,6 +42,10 @@ class CarouselSteadyState:
|
|
37
42
|
carousel: list[Model]
|
38
43
|
results: list[Result]
|
39
44
|
|
45
|
+
def __repr__(self) -> str:
|
46
|
+
"""Return default representation."""
|
47
|
+
return pformat(self)
|
48
|
+
|
40
49
|
def get_variables_by_model(self) -> pd.DataFrame:
|
41
50
|
"""Get the variables of the time course results, indexed by model."""
|
42
51
|
return pd.DataFrame(
|
@@ -51,6 +60,10 @@ class CarouselTimeCourse:
|
|
51
60
|
carousel: list[Model]
|
52
61
|
results: list[Result]
|
53
62
|
|
63
|
+
def __repr__(self) -> str:
|
64
|
+
"""Return default representation."""
|
65
|
+
return pformat(self)
|
66
|
+
|
54
67
|
def get_variables_by_model(self) -> pd.DataFrame:
|
55
68
|
"""Get the variables of the time course results, indexed by model."""
|
56
69
|
return pd.concat({i: r.variables for i, r in enumerate(self.results)})
|
@@ -76,6 +89,10 @@ class Carousel:
|
|
76
89
|
|
77
90
|
variants: list[Model]
|
78
91
|
|
92
|
+
def __repr__(self) -> str:
|
93
|
+
"""Return default representation."""
|
94
|
+
return pformat(self)
|
95
|
+
|
79
96
|
def __init__(
|
80
97
|
self,
|
81
98
|
model: Model,
|
@@ -115,7 +132,7 @@ class Carousel:
|
|
115
132
|
results=results,
|
116
133
|
)
|
117
134
|
|
118
|
-
def
|
135
|
+
def protocol(
|
119
136
|
self,
|
120
137
|
protocol: pd.DataFrame,
|
121
138
|
*,
|
@@ -141,6 +158,34 @@ class Carousel:
|
|
141
158
|
results=results,
|
142
159
|
)
|
143
160
|
|
161
|
+
def protocol_time_course(
|
162
|
+
self,
|
163
|
+
protocol: pd.DataFrame,
|
164
|
+
time_points: Array,
|
165
|
+
*,
|
166
|
+
y0: dict[str, float] | None = None,
|
167
|
+
integrator: IntegratorType | None = None,
|
168
|
+
) -> CarouselTimeCourse:
|
169
|
+
"""Simulate the carousel of models over a protocol time course."""
|
170
|
+
results = [
|
171
|
+
i[1]
|
172
|
+
for i in parallel.parallelise(
|
173
|
+
partial(
|
174
|
+
scan._protocol_time_course_worker, # noqa: SLF001
|
175
|
+
protocol=protocol,
|
176
|
+
integrator=integrator,
|
177
|
+
time_points=time_points,
|
178
|
+
y0=y0,
|
179
|
+
),
|
180
|
+
list(enumerate(self.variants)),
|
181
|
+
)
|
182
|
+
]
|
183
|
+
|
184
|
+
return CarouselTimeCourse(
|
185
|
+
carousel=self.variants,
|
186
|
+
results=results,
|
187
|
+
)
|
188
|
+
|
144
189
|
def steady_state(
|
145
190
|
self,
|
146
191
|
*,
|
mxlpy/compare.py
CHANGED
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|
6
6
|
from typing import TYPE_CHECKING, cast
|
7
7
|
|
8
8
|
import pandas as pd
|
9
|
+
from wadler_lindig import pformat
|
9
10
|
|
10
11
|
from mxlpy import plot
|
11
12
|
from mxlpy.simulator import Simulator
|
@@ -32,6 +33,10 @@ class SteadyStateComparison:
|
|
32
33
|
res1: Result
|
33
34
|
res2: Result
|
34
35
|
|
36
|
+
def __repr__(self) -> str:
|
37
|
+
"""Return default representation."""
|
38
|
+
return pformat(self)
|
39
|
+
|
35
40
|
@property
|
36
41
|
def variables(self) -> pd.DataFrame:
|
37
42
|
"""Compare the steady state variables."""
|
@@ -93,6 +98,10 @@ class TimeCourseComparison:
|
|
93
98
|
res1: Result
|
94
99
|
res2: Result
|
95
100
|
|
101
|
+
def __repr__(self) -> str:
|
102
|
+
"""Return default representation."""
|
103
|
+
return pformat(self)
|
104
|
+
|
96
105
|
# @property
|
97
106
|
# def variables(self) -> pd.DataFrame:
|
98
107
|
# """Compare the steady state variables."""
|
@@ -144,6 +153,10 @@ class ProtocolComparison:
|
|
144
153
|
res2: Result
|
145
154
|
protocol: pd.DataFrame
|
146
155
|
|
156
|
+
def __repr__(self) -> str:
|
157
|
+
"""Return default representation."""
|
158
|
+
return pformat(self)
|
159
|
+
|
147
160
|
# @property
|
148
161
|
# def variables(self) -> pd.DataFrame:
|
149
162
|
# """Compare the steady state variables."""
|
mxlpy/experimental/diff.py
CHANGED
@@ -5,6 +5,8 @@ from __future__ import annotations
|
|
5
5
|
from dataclasses import dataclass, field
|
6
6
|
from typing import TYPE_CHECKING
|
7
7
|
|
8
|
+
from wadler_lindig import pformat
|
9
|
+
|
8
10
|
from mxlpy.types import Derived
|
9
11
|
|
10
12
|
if TYPE_CHECKING:
|
@@ -28,6 +30,10 @@ class DerivedDiff:
|
|
28
30
|
args1: list[str] = field(default_factory=list)
|
29
31
|
args2: list[str] = field(default_factory=list)
|
30
32
|
|
33
|
+
def __repr__(self) -> str:
|
34
|
+
"""Return default representation."""
|
35
|
+
return pformat(self)
|
36
|
+
|
31
37
|
|
32
38
|
@dataclass
|
33
39
|
class ReactionDiff:
|
@@ -38,6 +44,10 @@ class ReactionDiff:
|
|
38
44
|
stoichiometry1: dict[str, float | Derived] = field(default_factory=dict)
|
39
45
|
stoichiometry2: dict[str, float | Derived] = field(default_factory=dict)
|
40
46
|
|
47
|
+
def __repr__(self) -> str:
|
48
|
+
"""Return default representation."""
|
49
|
+
return pformat(self)
|
50
|
+
|
41
51
|
|
42
52
|
@dataclass
|
43
53
|
class ModelDiff:
|
@@ -56,6 +66,10 @@ class ModelDiff:
|
|
56
66
|
different_readouts: dict[str, DerivedDiff] = field(default_factory=dict)
|
57
67
|
different_derived: dict[str, DerivedDiff] = field(default_factory=dict)
|
58
68
|
|
69
|
+
def __repr__(self) -> str:
|
70
|
+
"""Return default representation."""
|
71
|
+
return pformat(self)
|
72
|
+
|
59
73
|
def __str__(self) -> str:
|
60
74
|
"""Return a human-readable string representation of the diff."""
|
61
75
|
content = ["Model Diff", "----------"]
|
mxlpy/fit.py
CHANGED
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Protocol
|
|
18
18
|
|
19
19
|
import numpy as np
|
20
20
|
from scipy.optimize import minimize
|
21
|
+
from wadler_lindig import pformat
|
21
22
|
|
22
23
|
from mxlpy import parallel
|
23
24
|
from mxlpy.simulator import Simulator
|
@@ -61,6 +62,10 @@ class MinResult:
|
|
61
62
|
parameters: dict[str, float]
|
62
63
|
residual: float
|
63
64
|
|
65
|
+
def __repr__(self) -> str:
|
66
|
+
"""Return default representation."""
|
67
|
+
return pformat(self)
|
68
|
+
|
64
69
|
|
65
70
|
@dataclass
|
66
71
|
class FitResult:
|
@@ -70,6 +75,10 @@ class FitResult:
|
|
70
75
|
best_pars: dict[str, float]
|
71
76
|
loss: float
|
72
77
|
|
78
|
+
def __repr__(self) -> str:
|
79
|
+
"""Return default representation."""
|
80
|
+
return pformat(self)
|
81
|
+
|
73
82
|
|
74
83
|
@dataclass
|
75
84
|
class CarouselFit:
|
@@ -77,6 +86,10 @@ class CarouselFit:
|
|
77
86
|
|
78
87
|
fits: list[FitResult]
|
79
88
|
|
89
|
+
def __repr__(self) -> str:
|
90
|
+
"""Return default representation."""
|
91
|
+
return pformat(self)
|
92
|
+
|
80
93
|
def get_best_fit(self) -> FitResult:
|
81
94
|
"""Get the best fit from the carousel."""
|
82
95
|
return min(self.fits, key=lambda x: x.loss)
|
@@ -561,6 +574,8 @@ def protocol_time_course(
|
|
561
574
|
) -> FitResult | None:
|
562
575
|
"""Fit model parameters to time course of experimental data.
|
563
576
|
|
577
|
+
Time points of protocol time course are taken from the data.
|
578
|
+
|
564
579
|
Examples:
|
565
580
|
>>> time_course(model, p0, data)
|
566
581
|
{'k1': 0.1, 'k2': 0.2}
|
@@ -688,8 +703,10 @@ def carousel_time_course(
|
|
688
703
|
) -> CarouselFit:
|
689
704
|
"""Fit model parameters to time course of experimental data over a carousel.
|
690
705
|
|
706
|
+
Time points are taken from the data.
|
707
|
+
|
691
708
|
Examples:
|
692
|
-
>>>
|
709
|
+
>>> carousel_time_course(carousel, p0=p0, data=data)
|
693
710
|
|
694
711
|
Args:
|
695
712
|
carousel: Model carousel to fit
|
@@ -748,6 +765,8 @@ def carousel_protocol_time_course(
|
|
748
765
|
) -> CarouselFit:
|
749
766
|
"""Fit model parameters to time course of experimental data over a protocol.
|
750
767
|
|
768
|
+
Time points of protocol time course are taken from the data.
|
769
|
+
|
751
770
|
Examples:
|
752
771
|
>>> carousel_steady_state(carousel, p0=p0, data=data)
|
753
772
|
|
mxlpy/mc.py
CHANGED
@@ -24,9 +24,11 @@ import pandas as pd
|
|
24
24
|
from mxlpy import mca, scan
|
25
25
|
from mxlpy.parallel import Cache, parallelise
|
26
26
|
from mxlpy.scan import (
|
27
|
+
ProtocolTimeCourseWorker,
|
27
28
|
ProtocolWorker,
|
28
29
|
SteadyStateWorker,
|
29
30
|
TimeCourseWorker,
|
31
|
+
_protocol_time_course_worker,
|
30
32
|
_protocol_worker,
|
31
33
|
_steady_state_worker,
|
32
34
|
_time_course_worker,
|
@@ -49,11 +51,12 @@ if TYPE_CHECKING:
|
|
49
51
|
__all__ = [
|
50
52
|
"ParameterScanWorker",
|
51
53
|
"parameter_elasticities",
|
54
|
+
"protocol",
|
55
|
+
"protocol_time_course",
|
52
56
|
"response_coefficients",
|
53
57
|
"scan_steady_state",
|
54
58
|
"steady_state",
|
55
59
|
"time_course",
|
56
|
-
"time_course_over_protocol",
|
57
60
|
"variable_elasticities",
|
58
61
|
]
|
59
62
|
|
@@ -229,7 +232,7 @@ def time_course(
|
|
229
232
|
)
|
230
233
|
|
231
234
|
|
232
|
-
def
|
235
|
+
def protocol(
|
233
236
|
model: Model,
|
234
237
|
*,
|
235
238
|
protocol: pd.DataFrame,
|
@@ -287,6 +290,64 @@ def time_course_over_protocol(
|
|
287
290
|
)
|
288
291
|
|
289
292
|
|
293
|
+
def protocol_time_course(
|
294
|
+
model: Model,
|
295
|
+
*,
|
296
|
+
protocol: pd.DataFrame,
|
297
|
+
time_points: Array,
|
298
|
+
mc_to_scan: pd.DataFrame,
|
299
|
+
y0: dict[str, float] | None = None,
|
300
|
+
max_workers: int | None = None,
|
301
|
+
cache: Cache | None = None,
|
302
|
+
worker: ProtocolTimeCourseWorker = _protocol_time_course_worker,
|
303
|
+
integrator: IntegratorType | None = None,
|
304
|
+
) -> ProtocolScan:
|
305
|
+
"""MC time course.
|
306
|
+
|
307
|
+
Examples:
|
308
|
+
>>> protocol_time_course(model, protocol, time_points, mc_to_scan)
|
309
|
+
p t x y
|
310
|
+
0 0.0 0.1 0.00
|
311
|
+
1.0 0.2 0.01
|
312
|
+
2.0 0.3 0.02
|
313
|
+
3.0 0.4 0.03
|
314
|
+
... ... ...
|
315
|
+
1 0.0 0.1 0.00
|
316
|
+
1.0 0.2 0.01
|
317
|
+
2.0 0.3 0.02
|
318
|
+
3.0 0.4 0.03
|
319
|
+
|
320
|
+
Returns:
|
321
|
+
tuple[concentrations, fluxes] using pandas multiindex
|
322
|
+
Both dataframes are of shape (#time_points * #mc_to_scan, #variables)
|
323
|
+
|
324
|
+
"""
|
325
|
+
if y0 is not None:
|
326
|
+
model.update_variables(y0)
|
327
|
+
|
328
|
+
res = parallelise(
|
329
|
+
partial(
|
330
|
+
_update_parameters_and_initial_conditions,
|
331
|
+
fn=partial(
|
332
|
+
worker,
|
333
|
+
protocol=protocol,
|
334
|
+
time_points=time_points,
|
335
|
+
integrator=integrator,
|
336
|
+
y0=None,
|
337
|
+
),
|
338
|
+
model=model,
|
339
|
+
),
|
340
|
+
inputs=list(mc_to_scan.iterrows()),
|
341
|
+
max_workers=max_workers,
|
342
|
+
cache=cache,
|
343
|
+
)
|
344
|
+
return ProtocolScan(
|
345
|
+
to_scan=mc_to_scan,
|
346
|
+
protocol=protocol,
|
347
|
+
raw_results=dict(res),
|
348
|
+
)
|
349
|
+
|
350
|
+
|
290
351
|
def scan_steady_state(
|
291
352
|
model: Model,
|
292
353
|
*,
|
mxlpy/meta/__init__.py
CHANGED
@@ -3,13 +3,18 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
from .codegen_latex import generate_latex_code, to_tex_export
|
6
|
-
from .codegen_model import
|
6
|
+
from .codegen_model import (
|
7
|
+
generate_model_code_py,
|
8
|
+
generate_model_code_rs,
|
9
|
+
generate_model_code_ts,
|
10
|
+
)
|
7
11
|
from .codegen_mxlpy import generate_mxlpy_code
|
8
12
|
|
9
13
|
__all__ = [
|
10
14
|
"generate_latex_code",
|
11
15
|
"generate_model_code_py",
|
12
16
|
"generate_model_code_rs",
|
17
|
+
"generate_model_code_ts",
|
13
18
|
"generate_mxlpy_code",
|
14
19
|
"to_tex_export",
|
15
20
|
]
|
mxlpy/meta/codegen_latex.py
CHANGED
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|
6
6
|
from typing import TYPE_CHECKING
|
7
7
|
|
8
8
|
import sympy
|
9
|
+
from wadler_lindig import pformat
|
9
10
|
|
10
11
|
from mxlpy.meta.sympy_tools import fn_to_sympy, list_of_symbols
|
11
12
|
from mxlpy.types import Derived, RateFn
|
@@ -358,6 +359,10 @@ class TexReaction:
|
|
358
359
|
fn: RateFn
|
359
360
|
args: list[str]
|
360
361
|
|
362
|
+
def __repr__(self) -> str:
|
363
|
+
"""Return default representation."""
|
364
|
+
return pformat(self)
|
365
|
+
|
361
366
|
|
362
367
|
@dataclass
|
363
368
|
class TexExport:
|
@@ -397,6 +402,10 @@ class TexExport:
|
|
397
402
|
reactions: dict[str, TexReaction]
|
398
403
|
diff_eqs: dict[str, Mapping[str, float | Derived]]
|
399
404
|
|
405
|
+
def __repr__(self) -> str:
|
406
|
+
"""Return default representation."""
|
407
|
+
return pformat(self)
|
408
|
+
|
400
409
|
@staticmethod
|
401
410
|
def _diff_parameters(
|
402
411
|
p1: dict[str, float],
|
mxlpy/meta/codegen_model.py
CHANGED
@@ -9,6 +9,7 @@ from mxlpy.meta.sympy_tools import (
|
|
9
9
|
fn_to_sympy,
|
10
10
|
list_of_symbols,
|
11
11
|
stoichiometries_to_sympy,
|
12
|
+
sympy_to_inline_js,
|
12
13
|
sympy_to_inline_py,
|
13
14
|
sympy_to_inline_rust,
|
14
15
|
)
|
@@ -23,6 +24,7 @@ if TYPE_CHECKING:
|
|
23
24
|
__all__ = [
|
24
25
|
"generate_model_code_py",
|
25
26
|
"generate_model_code_rs",
|
27
|
+
"generate_model_code_ts",
|
26
28
|
]
|
27
29
|
|
28
30
|
_LOGGER = logging.getLogger(__name__)
|
@@ -37,6 +39,7 @@ def _generate_model_code(
|
|
37
39
|
assignment_template: str,
|
38
40
|
sympy_inline_fn: Callable[[sympy.Expr], str],
|
39
41
|
return_template: str,
|
42
|
+
custom_fns: dict[str, sympy.Expr],
|
40
43
|
imports: list[str] | None = None,
|
41
44
|
end: str | None = None,
|
42
45
|
free_parameters: list[str] | None = None,
|
@@ -70,11 +73,13 @@ def _generate_model_code(
|
|
70
73
|
|
71
74
|
# Derived
|
72
75
|
for name, derived in model.get_raw_derived().items():
|
73
|
-
expr =
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
76
|
+
expr = custom_fns.get(name)
|
77
|
+
if expr is None:
|
78
|
+
expr = fn_to_sympy(
|
79
|
+
derived.fn,
|
80
|
+
origin=name,
|
81
|
+
model_args=list_of_symbols(derived.args),
|
82
|
+
)
|
78
83
|
if expr is None:
|
79
84
|
msg = f"Unable to parse fn for derived value '{name}'"
|
80
85
|
raise ValueError(msg)
|
@@ -82,11 +87,16 @@ def _generate_model_code(
|
|
82
87
|
|
83
88
|
# Reactions
|
84
89
|
for name, rxn in model.get_raw_reactions().items():
|
85
|
-
expr =
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
+
expr = custom_fns.get(name)
|
91
|
+
if expr is None:
|
92
|
+
try:
|
93
|
+
expr = fn_to_sympy(
|
94
|
+
rxn.fn,
|
95
|
+
origin=name,
|
96
|
+
model_args=list_of_symbols(rxn.args),
|
97
|
+
)
|
98
|
+
except KeyError:
|
99
|
+
_LOGGER.warning("Failed to parse %s", name)
|
90
100
|
if expr is None:
|
91
101
|
msg = f"Unable to parse fn for reaction value '{name}'"
|
92
102
|
raise ValueError(msg)
|
@@ -123,6 +133,7 @@ def _generate_model_code(
|
|
123
133
|
|
124
134
|
def generate_model_code_py(
|
125
135
|
model: Model,
|
136
|
+
custom_fns: dict[str, sympy.Expr] | None = None,
|
126
137
|
free_parameters: list[str] | None = None,
|
127
138
|
) -> str:
|
128
139
|
"""Transform the model into a python function, inlining the function calls."""
|
@@ -137,21 +148,51 @@ def generate_model_code_py(
|
|
137
148
|
return _generate_model_code(
|
138
149
|
model,
|
139
150
|
imports=[
|
151
|
+
"import math\n",
|
140
152
|
"from collections.abc import Iterable\n",
|
141
153
|
],
|
142
154
|
sized=False,
|
143
155
|
model_fn=model_fn,
|
144
156
|
variables_template=" {} = variables",
|
145
|
-
assignment_template=" {k} = {v}",
|
157
|
+
assignment_template=" {k}: float = {v}",
|
146
158
|
sympy_inline_fn=sympy_to_inline_py,
|
147
159
|
return_template=" return {}",
|
148
160
|
end=None,
|
149
161
|
free_parameters=free_parameters,
|
162
|
+
custom_fns={} if custom_fns is None else custom_fns,
|
163
|
+
)
|
164
|
+
|
165
|
+
|
166
|
+
def generate_model_code_ts(
|
167
|
+
model: Model,
|
168
|
+
custom_fns: dict[str, sympy.Expr] | None = None,
|
169
|
+
free_parameters: list[str] | None = None,
|
170
|
+
) -> str:
|
171
|
+
"""Transform the model into a typescript function, inlining the function calls."""
|
172
|
+
if free_parameters is None:
|
173
|
+
model_fn = "function model(time: number, variables: number[]) {"
|
174
|
+
else:
|
175
|
+
args = ", ".join(f"{k}: number" for k in free_parameters)
|
176
|
+
model_fn = f"function model(time: number, variables: number[], {args}) {{"
|
177
|
+
|
178
|
+
return _generate_model_code(
|
179
|
+
model,
|
180
|
+
imports=[],
|
181
|
+
sized=False,
|
182
|
+
model_fn=model_fn,
|
183
|
+
variables_template=" let [{}] = variables;",
|
184
|
+
assignment_template=" let {k}: number = {v};",
|
185
|
+
sympy_inline_fn=sympy_to_inline_js,
|
186
|
+
return_template=" return [{}];",
|
187
|
+
end="};",
|
188
|
+
free_parameters=free_parameters,
|
189
|
+
custom_fns={} if custom_fns is None else custom_fns,
|
150
190
|
)
|
151
191
|
|
152
192
|
|
153
193
|
def generate_model_code_rs(
|
154
194
|
model: Model,
|
195
|
+
custom_fns: dict[str, sympy.Expr] | None = None,
|
155
196
|
free_parameters: list[str] | None = None,
|
156
197
|
) -> str:
|
157
198
|
"""Transform the model into a rust function, inlining the function calls."""
|
@@ -172,4 +213,5 @@ def generate_model_code_rs(
|
|
172
213
|
return_template=" return [{}]",
|
173
214
|
end="}",
|
174
215
|
free_parameters=free_parameters,
|
216
|
+
custom_fns={} if custom_fns is None else custom_fns,
|
175
217
|
)
|
mxlpy/meta/codegen_mxlpy.py
CHANGED
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
|
|
7
7
|
from typing import TYPE_CHECKING, cast
|
8
8
|
|
9
9
|
import sympy
|
10
|
+
from wadler_lindig import pformat
|
10
11
|
|
11
12
|
from mxlpy.meta.sympy_tools import (
|
12
13
|
fn_to_sympy,
|
@@ -43,6 +44,10 @@ class SymbolicFn:
|
|
43
44
|
expr: sympy.Expr
|
44
45
|
args: list[str]
|
45
46
|
|
47
|
+
def __repr__(self) -> str:
|
48
|
+
"""Return default representation."""
|
49
|
+
return pformat(self)
|
50
|
+
|
46
51
|
|
47
52
|
@dataclass
|
48
53
|
class SymbolicVariable:
|
@@ -51,6 +56,10 @@ class SymbolicVariable:
|
|
51
56
|
value: sympy.Float | SymbolicFn # initial assignment
|
52
57
|
unit: Quantity | None
|
53
58
|
|
59
|
+
def __repr__(self) -> str:
|
60
|
+
"""Return default representation."""
|
61
|
+
return pformat(self)
|
62
|
+
|
54
63
|
|
55
64
|
@dataclass
|
56
65
|
class SymbolicParameter:
|
@@ -59,6 +68,10 @@ class SymbolicParameter:
|
|
59
68
|
value: sympy.Float | SymbolicFn # initial assignment
|
60
69
|
unit: Quantity | None
|
61
70
|
|
71
|
+
def __repr__(self) -> str:
|
72
|
+
"""Return default representation."""
|
73
|
+
return pformat(self)
|
74
|
+
|
62
75
|
|
63
76
|
@dataclass
|
64
77
|
class SymbolicReaction:
|
@@ -67,6 +80,10 @@ class SymbolicReaction:
|
|
67
80
|
fn: SymbolicFn
|
68
81
|
stoichiometry: dict[str, sympy.Float | str | SymbolicFn]
|
69
82
|
|
83
|
+
def __repr__(self) -> str:
|
84
|
+
"""Return default representation."""
|
85
|
+
return pformat(self)
|
86
|
+
|
70
87
|
|
71
88
|
@dataclass
|
72
89
|
class SymbolicRepr:
|
@@ -77,6 +94,10 @@ class SymbolicRepr:
|
|
77
94
|
derived: dict[str, SymbolicFn] = field(default_factory=dict)
|
78
95
|
reactions: dict[str, SymbolicReaction] = field(default_factory=dict)
|
79
96
|
|
97
|
+
def __repr__(self) -> str:
|
98
|
+
"""Return default representation."""
|
99
|
+
return pformat(self)
|
100
|
+
|
80
101
|
|
81
102
|
def _fn_to_symbolic_repr(k: str, fn: Callable, model_args: list[str]) -> SymbolicFn:
|
82
103
|
fn_name = fn.__name__
|
mxlpy/meta/source_tools.py
CHANGED
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any, cast
|
|
15
15
|
import dill
|
16
16
|
import numpy as np
|
17
17
|
import sympy
|
18
|
+
from wadler_lindig import pformat
|
18
19
|
|
19
20
|
if TYPE_CHECKING:
|
20
21
|
from collections.abc import Callable
|
@@ -174,6 +175,10 @@ class Context:
|
|
174
175
|
modules: dict[str, ModuleType]
|
175
176
|
fns: dict[str, Callable]
|
176
177
|
|
178
|
+
def __repr__(self) -> str:
|
179
|
+
"""Return default representation."""
|
180
|
+
return pformat(self)
|
181
|
+
|
177
182
|
def updated(
|
178
183
|
self,
|
179
184
|
symbols: dict[str, sympy.Symbol | sympy.Expr] | None = None,
|
mxlpy/meta/sympy_tools.py
CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
5
5
|
from typing import TYPE_CHECKING, cast
|
6
6
|
|
7
7
|
import sympy
|
8
|
-
from sympy.printing import rust_code
|
8
|
+
from sympy.printing import jscode, rust_code
|
9
9
|
from sympy.printing.pycode import pycode
|
10
10
|
|
11
11
|
from mxlpy.meta.source_tools import fn_to_sympy
|
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|
17
17
|
__all__ = [
|
18
18
|
"list_of_symbols",
|
19
19
|
"stoichiometries_to_sympy",
|
20
|
+
"sympy_to_inline_js",
|
20
21
|
"sympy_to_inline_py",
|
21
22
|
"sympy_to_inline_rust",
|
22
23
|
"sympy_to_python_fn",
|
@@ -53,6 +54,11 @@ def sympy_to_inline_py(expr: sympy.Expr) -> str:
|
|
53
54
|
return cast(str, pycode(expr, fully_qualified_modules=True, full_prec=False))
|
54
55
|
|
55
56
|
|
57
|
+
def sympy_to_inline_js(expr: sympy.Expr) -> str:
|
58
|
+
"""Create rust code from sympy expression."""
|
59
|
+
return cast(str, jscode(expr, full_prec=False))
|
60
|
+
|
61
|
+
|
56
62
|
def sympy_to_inline_rust(expr: sympy.Expr) -> str:
|
57
63
|
"""Create rust code from sympy expression."""
|
58
64
|
return cast(str, rust_code(expr, full_prec=False))
|