mxlpy 0.22.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +13 -6
- mxlpy/carousel.py +52 -5
- mxlpy/compare.py +15 -2
- mxlpy/experimental/diff.py +14 -0
- mxlpy/fit.py +20 -1
- mxlpy/integrators/__init__.py +4 -0
- mxlpy/integrators/int_assimulo.py +3 -3
- mxlpy/integrators/int_diffrax.py +119 -0
- mxlpy/integrators/int_scipy.py +12 -6
- mxlpy/label_map.py +1 -2
- mxlpy/mc.py +85 -24
- mxlpy/mca.py +8 -4
- mxlpy/meta/__init__.py +6 -1
- mxlpy/meta/codegen_latex.py +9 -0
- mxlpy/meta/codegen_model.py +55 -12
- mxlpy/meta/codegen_mxlpy.py +215 -58
- mxlpy/meta/source_tools.py +129 -80
- mxlpy/meta/sympy_tools.py +12 -6
- mxlpy/model.py +314 -96
- mxlpy/plot.py +60 -29
- mxlpy/sbml/_data.py +34 -0
- mxlpy/sbml/_export.py +17 -8
- mxlpy/sbml/_import.py +68 -547
- mxlpy/scan.py +163 -249
- mxlpy/simulator.py +9 -359
- mxlpy/types.py +723 -80
- mxlpy/units.py +5 -0
- {mxlpy-0.22.0.dist-info → mxlpy-0.24.0.dist-info}/METADATA +7 -1
- mxlpy-0.24.0.dist-info/RECORD +57 -0
- mxlpy/sbml/_mathml.py +0 -692
- mxlpy/sbml/_unit_conversion.py +0 -74
- mxlpy-0.22.0.dist-info/RECORD +0 -58
- {mxlpy-0.22.0.dist-info → mxlpy-0.24.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.22.0.dist-info → mxlpy-0.24.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/meta/codegen_mxlpy.py
CHANGED
@@ -3,94 +3,246 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import logging
|
6
|
-
from
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
from typing import TYPE_CHECKING, cast
|
7
8
|
|
8
|
-
|
9
|
-
from
|
9
|
+
import sympy
|
10
|
+
from wadler_lindig import pformat
|
11
|
+
|
12
|
+
from mxlpy.meta.sympy_tools import (
|
13
|
+
fn_to_sympy,
|
14
|
+
list_of_symbols,
|
15
|
+
sympy_to_inline_py,
|
16
|
+
sympy_to_python_fn,
|
17
|
+
)
|
18
|
+
from mxlpy.types import Derived, InitialAssignment
|
19
|
+
from mxlpy.units import Quantity
|
10
20
|
|
11
21
|
if TYPE_CHECKING:
|
12
|
-
import
|
22
|
+
from collections.abc import Callable
|
13
23
|
|
14
24
|
from mxlpy.model import Model
|
15
25
|
|
16
26
|
__all__ = [
|
27
|
+
"SymbolicFn",
|
28
|
+
"SymbolicParameter",
|
29
|
+
"SymbolicReaction",
|
30
|
+
"SymbolicRepr",
|
31
|
+
"SymbolicVariable",
|
17
32
|
"generate_mxlpy_code",
|
33
|
+
"generate_mxlpy_code_from_symbolic_repr",
|
18
34
|
]
|
19
35
|
|
20
36
|
_LOGGER = logging.getLogger()
|
21
37
|
|
22
38
|
|
23
|
-
|
24
|
-
|
25
|
-
|
39
|
+
@dataclass
|
40
|
+
class SymbolicFn:
|
41
|
+
"""Container for symbolic fn."""
|
26
42
|
|
27
|
-
|
28
|
-
|
29
|
-
|
43
|
+
fn_name: str
|
44
|
+
expr: sympy.Expr
|
45
|
+
args: list[str]
|
46
|
+
|
47
|
+
def __repr__(self) -> str:
|
48
|
+
"""Return default representation."""
|
49
|
+
return pformat(self)
|
50
|
+
|
51
|
+
|
52
|
+
@dataclass
|
53
|
+
class SymbolicVariable:
|
54
|
+
"""Container for symbolic variable."""
|
55
|
+
|
56
|
+
value: sympy.Float | SymbolicFn # initial assignment
|
57
|
+
unit: Quantity | None
|
58
|
+
|
59
|
+
def __repr__(self) -> str:
|
60
|
+
"""Return default representation."""
|
61
|
+
return pformat(self)
|
62
|
+
|
63
|
+
|
64
|
+
@dataclass
|
65
|
+
class SymbolicParameter:
|
66
|
+
"""Container for symbolic par."""
|
67
|
+
|
68
|
+
value: sympy.Float | SymbolicFn # initial assignment
|
69
|
+
unit: Quantity | None
|
70
|
+
|
71
|
+
def __repr__(self) -> str:
|
72
|
+
"""Return default representation."""
|
73
|
+
return pformat(self)
|
74
|
+
|
75
|
+
|
76
|
+
@dataclass
|
77
|
+
class SymbolicReaction:
|
78
|
+
"""Container for symbolic rxn."""
|
79
|
+
|
80
|
+
fn: SymbolicFn
|
81
|
+
stoichiometry: dict[str, sympy.Float | str | SymbolicFn]
|
82
|
+
|
83
|
+
def __repr__(self) -> str:
|
84
|
+
"""Return default representation."""
|
85
|
+
return pformat(self)
|
86
|
+
|
87
|
+
|
88
|
+
@dataclass
|
89
|
+
class SymbolicRepr:
|
90
|
+
"""Container for symbolic model."""
|
91
|
+
|
92
|
+
variables: dict[str, SymbolicVariable] = field(default_factory=dict)
|
93
|
+
parameters: dict[str, SymbolicParameter] = field(default_factory=dict)
|
94
|
+
derived: dict[str, SymbolicFn] = field(default_factory=dict)
|
95
|
+
reactions: dict[str, SymbolicReaction] = field(default_factory=dict)
|
96
|
+
|
97
|
+
def __repr__(self) -> str:
|
98
|
+
"""Return default representation."""
|
99
|
+
return pformat(self)
|
100
|
+
|
101
|
+
|
102
|
+
def _fn_to_symbolic_repr(k: str, fn: Callable, model_args: list[str]) -> SymbolicFn:
|
103
|
+
fn_name = fn.__name__
|
104
|
+
args = cast(list, list_of_symbols(model_args))
|
105
|
+
if (expr := fn_to_sympy(fn, origin=k, model_args=args)) is None:
|
106
|
+
msg = f"Unable to parse fn for '{k}'"
|
107
|
+
raise ValueError(msg)
|
108
|
+
return SymbolicFn(fn_name=fn_name, expr=expr, args=model_args)
|
109
|
+
|
110
|
+
|
111
|
+
def _to_symbolic_repr(model: Model) -> SymbolicRepr:
|
112
|
+
sym = SymbolicRepr()
|
113
|
+
|
114
|
+
for k, variable in model.get_raw_variables().items():
|
115
|
+
sym.variables[k] = SymbolicVariable(
|
116
|
+
value=_fn_to_symbolic_repr(k, val.fn, val.args)
|
117
|
+
if isinstance(val := variable.initial_value, InitialAssignment)
|
118
|
+
else sympy.Float(val),
|
119
|
+
unit=cast(Quantity, variable.unit),
|
120
|
+
)
|
121
|
+
|
122
|
+
for k, parameter in model.get_raw_parameters().items():
|
123
|
+
sym.parameters[k] = SymbolicParameter(
|
124
|
+
value=_fn_to_symbolic_repr(k, val.fn, val.args)
|
125
|
+
if isinstance(val := parameter.value, InitialAssignment)
|
126
|
+
else sympy.Float(val),
|
127
|
+
unit=cast(Quantity, parameter.unit),
|
128
|
+
)
|
30
129
|
|
31
|
-
# Derived
|
32
|
-
derived_source = []
|
33
130
|
for k, der in model.get_raw_derived().items():
|
34
|
-
|
35
|
-
fn_name = fn.__name__
|
36
|
-
if (
|
37
|
-
expr := fn_to_sympy(fn, origin=k, model_args=list_of_symbols(der.args))
|
38
|
-
) is None:
|
39
|
-
msg = f"Unable to parse fn for derived value '{k}'"
|
40
|
-
raise ValueError(msg)
|
131
|
+
sym.derived[k] = _fn_to_symbolic_repr(k, der.fn, der.args)
|
41
132
|
|
42
|
-
|
133
|
+
for k, rxn in model.get_raw_reactions().items():
|
134
|
+
sym.reactions[k] = SymbolicReaction(
|
135
|
+
fn=_fn_to_symbolic_repr(k, rxn.fn, rxn.args),
|
136
|
+
stoichiometry={
|
137
|
+
k: _fn_to_symbolic_repr(k, v.fn, v.args)
|
138
|
+
if isinstance(v, Derived)
|
139
|
+
else sympy.Float(v)
|
140
|
+
for k, v in rxn.stoichiometry.items()
|
141
|
+
},
|
142
|
+
)
|
43
143
|
|
144
|
+
if len(model._surrogates) > 0: # noqa: SLF001
|
145
|
+
msg = "Generating code for Surrogates not yet supported."
|
146
|
+
_LOGGER.warning(msg)
|
147
|
+
return sym
|
148
|
+
|
149
|
+
|
150
|
+
def _codegen_variable(
|
151
|
+
k: str, var: SymbolicVariable, functions: dict[str, tuple[sympy.Expr, list[str]]]
|
152
|
+
) -> str:
|
153
|
+
if isinstance(init := var.value, SymbolicFn):
|
154
|
+
fn_name = f"init_{init.fn_name}"
|
155
|
+
functions[fn_name] = (init.expr, init.args)
|
156
|
+
return f""" .add_variable(
|
157
|
+
{k!r},
|
158
|
+
initial_value=InitialAssignment(fn={fn_name}, args={init.args!r}),
|
159
|
+
)"""
|
160
|
+
|
161
|
+
value = sympy_to_inline_py(init)
|
162
|
+
if (unit := var.unit) is not None:
|
163
|
+
return f" .add_variable({k!r}, value={value}, unit={sympy_to_inline_py(unit)})"
|
164
|
+
return f" .add_variable({k!r}, initial_value={value})"
|
165
|
+
|
166
|
+
|
167
|
+
def _codegen_parameter(
|
168
|
+
k: str, par: SymbolicParameter, functions: dict[str, tuple[sympy.Expr, list[str]]]
|
169
|
+
) -> str:
|
170
|
+
if isinstance(init := par.value, SymbolicFn):
|
171
|
+
fn_name = f"init_{init.fn_name}"
|
172
|
+
functions[fn_name] = (init.expr, init.args)
|
173
|
+
return f""" .add_parameter(
|
174
|
+
{k!r},
|
175
|
+
value=InitialAssignment(fn={fn_name}, args={init.args!r}),
|
176
|
+
)"""
|
177
|
+
|
178
|
+
value = sympy_to_inline_py(init)
|
179
|
+
if (unit := par.unit) is not None:
|
180
|
+
return f" .add_parameter({k!r}, value={value}, unit={sympy_to_inline_py(unit)})"
|
181
|
+
return f" .add_parameter({k!r}, value={value})"
|
182
|
+
|
183
|
+
|
184
|
+
def generate_mxlpy_code_from_symbolic_repr(
|
185
|
+
model: SymbolicRepr, imports: list[str] | None = None
|
186
|
+
) -> str:
|
187
|
+
"""Generate MxlPy source code from symbolic representation.
|
188
|
+
|
189
|
+
This is both used by MxlPy internally to codegen an existing model again and by the
|
190
|
+
SBML import to generate the file.
|
191
|
+
"""
|
192
|
+
imports = [] if imports is None else imports
|
193
|
+
|
194
|
+
functions: dict[str, tuple[sympy.Expr, list[str]]] = {}
|
195
|
+
|
196
|
+
# Variables
|
197
|
+
variable_source = []
|
198
|
+
for k, var in model.variables.items():
|
199
|
+
variable_source.append(_codegen_variable(k, var, functions=functions))
|
200
|
+
|
201
|
+
# Parameters
|
202
|
+
parameter_source = []
|
203
|
+
for k, par in model.parameters.items():
|
204
|
+
parameter_source.append(_codegen_parameter(k, par, functions=functions))
|
205
|
+
|
206
|
+
# Derived
|
207
|
+
derived_source = []
|
208
|
+
for k, fn in model.derived.items():
|
209
|
+
functions[fn.fn_name] = (fn.expr, fn.args)
|
44
210
|
derived_source.append(
|
45
211
|
f""" .add_derived(
|
46
|
-
|
47
|
-
fn={fn_name},
|
48
|
-
args={
|
212
|
+
{k!r},
|
213
|
+
fn={fn.fn_name},
|
214
|
+
args={fn.args},
|
49
215
|
)"""
|
50
216
|
)
|
51
217
|
|
52
218
|
# Reactions
|
53
219
|
reactions_source = []
|
54
|
-
for k, rxn in model.
|
220
|
+
for k, rxn in model.reactions.items():
|
55
221
|
fn = rxn.fn
|
56
|
-
fn_name = fn.
|
57
|
-
|
58
|
-
expr := fn_to_sympy(fn, origin=k, model_args=list_of_symbols(rxn.args))
|
59
|
-
) is None:
|
60
|
-
msg = f"Unable to parse fn for reaction '{k}'"
|
61
|
-
raise ValueError(msg)
|
62
|
-
|
63
|
-
functions[fn_name] = (expr, rxn.args)
|
222
|
+
functions[fn.fn_name] = (fn.expr, fn.args)
|
223
|
+
|
64
224
|
stoichiometry: list[str] = []
|
65
225
|
for var, stoich in rxn.stoichiometry.items():
|
66
|
-
if isinstance(stoich,
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
)
|
71
|
-
) is None:
|
72
|
-
msg = f"Unable to parse fn for stoichiometry '{var}'"
|
73
|
-
raise ValueError(msg)
|
74
|
-
functions[fn_name] = (expr, rxn.args)
|
75
|
-
args = ", ".join(f'"{k}"' for k in stoich.args)
|
76
|
-
stoich = ( # noqa: PLW2901
|
77
|
-
f"""Derived(fn={fn.__name__}, args=[{args}])"""
|
226
|
+
if isinstance(stoich, SymbolicFn):
|
227
|
+
fn_name = f"{k}_stoich_{stoich.fn_name}"
|
228
|
+
functions[fn_name] = (stoich.expr, stoich.args)
|
229
|
+
stoichiometry.append(
|
230
|
+
f""""{var}": Derived(fn={fn_name}, args={stoich.args!r})"""
|
78
231
|
)
|
79
|
-
|
80
|
-
|
232
|
+
elif isinstance(stoich, str):
|
233
|
+
stoichiometry.append(f""""{var}": {stoich!r}""")
|
234
|
+
else:
|
235
|
+
stoichiometry.append(f""""{var}": {sympy_to_inline_py(stoich)}""")
|
81
236
|
reactions_source.append(
|
82
237
|
f""" .add_reaction(
|
83
238
|
"{k}",
|
84
|
-
fn={fn_name},
|
85
|
-
args={
|
239
|
+
fn={fn.fn_name},
|
240
|
+
args={fn.args},
|
86
241
|
stoichiometry={{{",".join(stoichiometry)}}},
|
87
242
|
)"""
|
88
243
|
)
|
89
244
|
|
90
245
|
# Surrogates
|
91
|
-
if len(model._surrogates) > 0: # noqa: SLF001
|
92
|
-
msg = "Generating code for Surrogates not yet supported."
|
93
|
-
_LOGGER.warning(msg)
|
94
246
|
|
95
247
|
# Combine all the sources
|
96
248
|
functions_source = "\n\n".join(
|
@@ -98,21 +250,26 @@ def generate_mxlpy_code(model: Model) -> str:
|
|
98
250
|
for name, (expr, args) in functions.items()
|
99
251
|
)
|
100
252
|
source = [
|
101
|
-
|
253
|
+
*imports,
|
254
|
+
"from mxlpy import Model, Derived, InitialAssignment\n",
|
102
255
|
functions_source,
|
256
|
+
"",
|
103
257
|
"def create_model() -> Model:",
|
104
258
|
" return (",
|
105
259
|
" Model()",
|
106
260
|
]
|
107
|
-
if len(
|
108
|
-
source.append(
|
109
|
-
if len(
|
110
|
-
source.append(
|
261
|
+
if len(variable_source) > 0:
|
262
|
+
source.append("\n".join(variable_source))
|
263
|
+
if len(parameter_source) > 0:
|
264
|
+
source.append("\n".join(parameter_source))
|
111
265
|
if len(derived_source) > 0:
|
112
266
|
source.append("\n".join(derived_source))
|
113
267
|
if len(reactions_source) > 0:
|
114
268
|
source.append("\n".join(reactions_source))
|
115
|
-
|
116
269
|
source.append(" )")
|
117
|
-
|
118
270
|
return "\n".join(source)
|
271
|
+
|
272
|
+
|
273
|
+
def generate_mxlpy_code(model: Model) -> str:
|
274
|
+
"""Generate a mxlpy model from a model."""
|
275
|
+
return generate_mxlpy_code_from_symbolic_repr(_to_symbolic_repr(model))
|
mxlpy/meta/source_tools.py
CHANGED
@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any, cast
|
|
15
15
|
import dill
|
16
16
|
import numpy as np
|
17
17
|
import sympy
|
18
|
+
from wadler_lindig import pformat
|
18
19
|
|
19
20
|
if TYPE_CHECKING:
|
20
21
|
from collections.abc import Callable
|
@@ -111,7 +112,6 @@ KNOWN_FNS: dict[Callable, sympy.Expr] = {
|
|
111
112
|
math.trunc: sympy.trunc,
|
112
113
|
# math.ulp: sympy.ulp,
|
113
114
|
# numpy
|
114
|
-
np.exp: sympy.exp,
|
115
115
|
np.abs: sympy.Abs,
|
116
116
|
np.acos: sympy.acos,
|
117
117
|
np.acosh: sympy.acosh,
|
@@ -175,6 +175,10 @@ class Context:
|
|
175
175
|
modules: dict[str, ModuleType]
|
176
176
|
fns: dict[str, Callable]
|
177
177
|
|
178
|
+
def __repr__(self) -> str:
|
179
|
+
"""Return default representation."""
|
180
|
+
return pformat(self)
|
181
|
+
|
178
182
|
def updated(
|
179
183
|
self,
|
180
184
|
symbols: dict[str, sympy.Symbol | sympy.Expr] | None = None,
|
@@ -309,10 +313,10 @@ def fn_to_sympy(
|
|
309
313
|
)
|
310
314
|
if sympy_expr is None:
|
311
315
|
return None
|
312
|
-
#
|
316
|
+
# Evaluated fns and floats from attributes
|
313
317
|
if isinstance(sympy_expr, float):
|
314
318
|
return sympy.Float(sympy_expr)
|
315
|
-
if model_args is not None:
|
319
|
+
if model_args is not None and len(model_args):
|
316
320
|
sympy_expr = sympy_expr.subs(dict(zip(fn_args, model_args, strict=True)))
|
317
321
|
return cast(sympy.Expr, sympy_expr)
|
318
322
|
|
@@ -323,77 +327,6 @@ def fn_to_sympy(
|
|
323
327
|
return None
|
324
328
|
|
325
329
|
|
326
|
-
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
|
327
|
-
value = ctx.symbols.get(node.id)
|
328
|
-
if value is None:
|
329
|
-
global_variables = dict(
|
330
|
-
inspect.getmembers(
|
331
|
-
ctx.parent_module,
|
332
|
-
predicate=lambda x: isinstance(x, float),
|
333
|
-
)
|
334
|
-
)
|
335
|
-
value = sympy.Float(global_variables[node.id])
|
336
|
-
return value
|
337
|
-
|
338
|
-
|
339
|
-
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
|
340
|
-
if isinstance(node, float):
|
341
|
-
return sympy.Float(node)
|
342
|
-
if isinstance(node, ast.UnaryOp):
|
343
|
-
return _handle_unaryop(node, ctx)
|
344
|
-
if isinstance(node, ast.BinOp):
|
345
|
-
return _handle_binop(node, ctx)
|
346
|
-
if isinstance(node, ast.Name):
|
347
|
-
return _handle_name(node, ctx)
|
348
|
-
if isinstance(node, ast.Constant):
|
349
|
-
return node.value
|
350
|
-
if isinstance(node, ast.Call):
|
351
|
-
return _handle_call(node, ctx=ctx)
|
352
|
-
if isinstance(node, ast.Attribute):
|
353
|
-
return _handle_attribute(node, ctx=ctx)
|
354
|
-
|
355
|
-
if isinstance(node, ast.Compare):
|
356
|
-
# Handle chained comparisons like 1 < a < 2
|
357
|
-
left = cast(Any, _handle_expr(node.left, ctx))
|
358
|
-
comparisons = []
|
359
|
-
|
360
|
-
# Build all individual comparisons from the chain
|
361
|
-
prev_value = left
|
362
|
-
for op, comparator in zip(node.ops, node.comparators, strict=True):
|
363
|
-
right = cast(Any, _handle_expr(comparator, ctx))
|
364
|
-
|
365
|
-
if isinstance(op, ast.Gt):
|
366
|
-
comparisons.append(prev_value > right)
|
367
|
-
elif isinstance(op, ast.GtE):
|
368
|
-
comparisons.append(prev_value >= right)
|
369
|
-
elif isinstance(op, ast.Lt):
|
370
|
-
comparisons.append(prev_value < right)
|
371
|
-
elif isinstance(op, ast.LtE):
|
372
|
-
comparisons.append(prev_value <= right)
|
373
|
-
elif isinstance(op, ast.Eq):
|
374
|
-
comparisons.append(prev_value == right)
|
375
|
-
elif isinstance(op, ast.NotEq):
|
376
|
-
comparisons.append(prev_value != right)
|
377
|
-
|
378
|
-
prev_value = right
|
379
|
-
|
380
|
-
# Combine all comparisons with logical AND
|
381
|
-
result = comparisons[0]
|
382
|
-
for comp in comparisons[1:]:
|
383
|
-
result = sympy.And(result, comp)
|
384
|
-
return cast(sympy.Expr, result)
|
385
|
-
|
386
|
-
# Handle conditional expressions (ternary operators)
|
387
|
-
if isinstance(node, ast.IfExp):
|
388
|
-
condition = _handle_expr(node.test, ctx)
|
389
|
-
if_true = _handle_expr(node.body, ctx)
|
390
|
-
if_false = _handle_expr(node.orelse, ctx)
|
391
|
-
return sympy.Piecewise((if_true, condition), (if_false, True))
|
392
|
-
|
393
|
-
msg = f"Expression type {type(node).__name__} not implemented"
|
394
|
-
raise NotImplementedError(msg)
|
395
|
-
|
396
|
-
|
397
330
|
def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
|
398
331
|
pieces = []
|
399
332
|
remaining_body = list(body)
|
@@ -505,6 +438,81 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
|
|
505
438
|
raise ValueError(msg)
|
506
439
|
|
507
440
|
|
441
|
+
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
|
442
|
+
"""Key dispatch function."""
|
443
|
+
if isinstance(node, float):
|
444
|
+
return sympy.Float(node)
|
445
|
+
if isinstance(node, ast.UnaryOp):
|
446
|
+
return _handle_unaryop(node, ctx)
|
447
|
+
if isinstance(node, ast.BinOp):
|
448
|
+
return _handle_binop(node, ctx)
|
449
|
+
if isinstance(node, ast.Name):
|
450
|
+
return _handle_name(node, ctx)
|
451
|
+
if isinstance(node, ast.Constant):
|
452
|
+
if isinstance(val := node.value, (float, int)):
|
453
|
+
return sympy.Float(val)
|
454
|
+
msg = "Can only use float values"
|
455
|
+
raise NotImplementedError(msg)
|
456
|
+
if isinstance(node, ast.Call):
|
457
|
+
return _handle_call(node, ctx=ctx)
|
458
|
+
if isinstance(node, ast.Attribute):
|
459
|
+
return _handle_attribute(node, ctx=ctx)
|
460
|
+
|
461
|
+
if isinstance(node, ast.Compare):
|
462
|
+
# Handle chained comparisons like 1 < a < 2
|
463
|
+
left = cast(Any, _handle_expr(node.left, ctx))
|
464
|
+
comparisons = []
|
465
|
+
|
466
|
+
# Build all individual comparisons from the chain
|
467
|
+
prev_value = left
|
468
|
+
for op, comparator in zip(node.ops, node.comparators, strict=True):
|
469
|
+
right = cast(Any, _handle_expr(comparator, ctx))
|
470
|
+
|
471
|
+
if isinstance(op, ast.Gt):
|
472
|
+
comparisons.append(prev_value > right)
|
473
|
+
elif isinstance(op, ast.GtE):
|
474
|
+
comparisons.append(prev_value >= right)
|
475
|
+
elif isinstance(op, ast.Lt):
|
476
|
+
comparisons.append(prev_value < right)
|
477
|
+
elif isinstance(op, ast.LtE):
|
478
|
+
comparisons.append(prev_value <= right)
|
479
|
+
elif isinstance(op, ast.Eq):
|
480
|
+
comparisons.append(prev_value == right)
|
481
|
+
elif isinstance(op, ast.NotEq):
|
482
|
+
comparisons.append(prev_value != right)
|
483
|
+
|
484
|
+
prev_value = right
|
485
|
+
|
486
|
+
# Combine all comparisons with logical AND
|
487
|
+
result = comparisons[0]
|
488
|
+
for comp in comparisons[1:]:
|
489
|
+
result = sympy.And(result, comp)
|
490
|
+
return cast(sympy.Expr, result)
|
491
|
+
|
492
|
+
# Handle conditional expressions (ternary operators)
|
493
|
+
if isinstance(node, ast.IfExp):
|
494
|
+
condition = _handle_expr(node.test, ctx)
|
495
|
+
if_true = _handle_expr(node.body, ctx)
|
496
|
+
if_false = _handle_expr(node.orelse, ctx)
|
497
|
+
return sympy.Piecewise((if_true, condition), (if_false, True))
|
498
|
+
|
499
|
+
msg = f"Expression type {type(node).__name__} not implemented"
|
500
|
+
raise NotImplementedError(msg)
|
501
|
+
|
502
|
+
|
503
|
+
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
|
504
|
+
value = ctx.symbols.get(node.id)
|
505
|
+
if value is None:
|
506
|
+
global_variables = dict(
|
507
|
+
inspect.getmembers(
|
508
|
+
ctx.parent_module,
|
509
|
+
predicate=lambda x: isinstance(x, float),
|
510
|
+
)
|
511
|
+
)
|
512
|
+
value = sympy.Float(global_variables[node.id])
|
513
|
+
return value
|
514
|
+
|
515
|
+
|
508
516
|
def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
|
509
517
|
left = _handle_expr(node.operand, ctx)
|
510
518
|
left = cast(Any, left) # stupid sympy types don't allow ops on symbols
|
@@ -546,6 +554,27 @@ def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
|
546
554
|
raise NotImplementedError(msg)
|
547
555
|
|
548
556
|
|
557
|
+
def _get_inner_object(obj: object, levels: list[str]) -> sympy.Float | None:
|
558
|
+
# Check if object is instantiated, otherwise instantiate first
|
559
|
+
if isinstance(obj, type):
|
560
|
+
obj = obj()
|
561
|
+
|
562
|
+
for level in levels:
|
563
|
+
_LOGGER.debug("obj %s, level %s", obj, level)
|
564
|
+
obj = getattr(obj, level, None)
|
565
|
+
|
566
|
+
if obj is None:
|
567
|
+
return None
|
568
|
+
|
569
|
+
if isinstance(obj, float):
|
570
|
+
if (value := KNOWN_CONSTANTS.get(obj)) is not None:
|
571
|
+
return value
|
572
|
+
return sympy.Float(obj)
|
573
|
+
|
574
|
+
_LOGGER.debug("Inner object not float: %s", obj)
|
575
|
+
return None
|
576
|
+
|
577
|
+
|
549
578
|
# FIXME: check if target isn't an object or class
|
550
579
|
def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
|
551
580
|
"""Handle an attribute.
|
@@ -581,17 +610,36 @@ def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
|
|
581
610
|
dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
|
582
611
|
| ctx.modules
|
583
612
|
)
|
613
|
+
variables = vars(ctx.parent_module)
|
614
|
+
|
584
615
|
match node.value:
|
585
616
|
case ast.Name(l1):
|
586
617
|
module_name = l1
|
587
618
|
module = modules.get(module_name)
|
619
|
+
if module is None and (var := variables.get(l1)) is not None:
|
620
|
+
return _get_inner_object(var, [node.attr])
|
588
621
|
case ast.Attribute():
|
589
|
-
levels = _find_root(node.value, [])
|
622
|
+
levels = _find_root(node.value, levels=[])
|
623
|
+
_LOGGER.debug("Attribute levels %s", levels)
|
590
624
|
module_name = ".".join(levels)
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
625
|
+
|
626
|
+
for idx, level in enumerate(levels[:-1]):
|
627
|
+
if (module := modules.get(level)) is not None:
|
628
|
+
modules.update(
|
629
|
+
dict(
|
630
|
+
inspect.getmembers(
|
631
|
+
module,
|
632
|
+
predicate=inspect.ismodule,
|
633
|
+
)
|
634
|
+
)
|
635
|
+
)
|
636
|
+
elif (var := variables.get(level)) is not None:
|
637
|
+
_LOGGER.debug("var %s", var)
|
638
|
+
return _get_inner_object(var, levels[(idx + 1) :] + [node.attr])
|
639
|
+
|
640
|
+
else:
|
641
|
+
_LOGGER.debug("No target found")
|
642
|
+
|
595
643
|
module = modules.get(levels[-1])
|
596
644
|
case _:
|
597
645
|
raise NotImplementedError
|
@@ -633,6 +681,7 @@ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
|
|
633
681
|
if (expr := _handle_expr(i, ctx)) is None:
|
634
682
|
return None
|
635
683
|
model_args.append(expr)
|
684
|
+
_LOGGER.debug("Fn args: %s", model_args)
|
636
685
|
|
637
686
|
match node.func:
|
638
687
|
case ast.Name(id):
|
@@ -674,7 +723,7 @@ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
|
|
674
723
|
return None
|
675
724
|
|
676
725
|
if (fn := KNOWN_FNS.get(py_fn)) is not None:
|
677
|
-
return fn
|
726
|
+
return sympy.Float(fn(*model_args)) # type: ignore
|
678
727
|
|
679
728
|
return fn_to_sympy(
|
680
729
|
py_fn,
|
mxlpy/meta/sympy_tools.py
CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
5
5
|
from typing import TYPE_CHECKING, cast
|
6
6
|
|
7
7
|
import sympy
|
8
|
-
from sympy.printing import rust_code
|
8
|
+
from sympy.printing import jscode, rust_code
|
9
9
|
from sympy.printing.pycode import pycode
|
10
10
|
|
11
11
|
from mxlpy.meta.source_tools import fn_to_sympy
|
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|
17
17
|
__all__ = [
|
18
18
|
"list_of_symbols",
|
19
19
|
"stoichiometries_to_sympy",
|
20
|
+
"sympy_to_inline_js",
|
20
21
|
"sympy_to_inline_py",
|
21
22
|
"sympy_to_inline_rust",
|
22
23
|
"sympy_to_python_fn",
|
@@ -50,12 +51,17 @@ def sympy_to_inline_py(expr: sympy.Expr) -> str:
|
|
50
51
|
'x**2 + 2*x + 1'
|
51
52
|
|
52
53
|
"""
|
53
|
-
return cast(str, pycode(expr, fully_qualified_modules=True))
|
54
|
+
return cast(str, pycode(expr, fully_qualified_modules=True, full_prec=False))
|
55
|
+
|
56
|
+
|
57
|
+
def sympy_to_inline_js(expr: sympy.Expr) -> str:
|
58
|
+
"""Create rust code from sympy expression."""
|
59
|
+
return cast(str, jscode(expr, full_prec=False))
|
54
60
|
|
55
61
|
|
56
62
|
def sympy_to_inline_rust(expr: sympy.Expr) -> str:
|
57
63
|
"""Create rust code from sympy expression."""
|
58
|
-
return cast(str, rust_code(expr))
|
64
|
+
return cast(str, rust_code(expr, full_prec=False))
|
59
65
|
|
60
66
|
|
61
67
|
def sympy_to_python_fn(
|
@@ -93,8 +99,8 @@ def sympy_to_python_fn(
|
|
93
99
|
fn_args = ", ".join(f"{i}: float" for i in args)
|
94
100
|
|
95
101
|
return f"""def {fn_name}({fn_args}) -> float:
|
96
|
-
return {pycode(expr)}
|
97
|
-
"""
|
102
|
+
return {pycode(expr, fully_qualified_modules=True, full_prec=False)}
|
103
|
+
""".replace("math.factorial", "scipy.special.factorial")
|
98
104
|
|
99
105
|
|
100
106
|
def stoichiometries_to_sympy(
|
@@ -114,4 +120,4 @@ def stoichiometries_to_sympy(
|
|
114
120
|
expr = expr + sympy_fn * sympy.Symbol(rxn_name) # type: ignore
|
115
121
|
else:
|
116
122
|
expr = expr + rxn_stoich * sympy.Symbol(rxn_name) # type: ignore
|
117
|
-
return expr
|
123
|
+
return expr.subs(1.0, 1) # type: ignore
|