mxlpy 0.22.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +11 -2
- mxlpy/carousel.py +6 -4
- mxlpy/compare.py +2 -2
- mxlpy/integrators/__init__.py +4 -0
- mxlpy/integrators/int_assimulo.py +3 -3
- mxlpy/integrators/int_diffrax.py +119 -0
- mxlpy/integrators/int_scipy.py +12 -6
- mxlpy/label_map.py +1 -2
- mxlpy/mc.py +22 -22
- mxlpy/mca.py +8 -4
- mxlpy/meta/codegen_model.py +2 -1
- mxlpy/meta/codegen_mxlpy.py +194 -58
- mxlpy/meta/source_tools.py +124 -80
- mxlpy/meta/sympy_tools.py +5 -5
- mxlpy/model.py +288 -91
- mxlpy/plot.py +16 -14
- mxlpy/sbml/_export.py +13 -5
- mxlpy/sbml/_import.py +68 -547
- mxlpy/scan.py +38 -242
- mxlpy/simulator.py +4 -359
- mxlpy/types.py +655 -83
- mxlpy/units.py +5 -0
- {mxlpy-0.22.0.dist-info → mxlpy-0.23.0.dist-info}/METADATA +4 -1
- {mxlpy-0.22.0.dist-info → mxlpy-0.23.0.dist-info}/RECORD +26 -27
- mxlpy/sbml/_mathml.py +0 -692
- mxlpy/sbml/_unit_conversion.py +0 -74
- {mxlpy-0.22.0.dist-info → mxlpy-0.23.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.22.0.dist-info → mxlpy-0.23.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/meta/codegen_mxlpy.py
CHANGED
@@ -3,94 +3,225 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import logging
|
6
|
-
from
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
from typing import TYPE_CHECKING, cast
|
7
8
|
|
8
|
-
|
9
|
-
|
9
|
+
import sympy
|
10
|
+
|
11
|
+
from mxlpy.meta.sympy_tools import (
|
12
|
+
fn_to_sympy,
|
13
|
+
list_of_symbols,
|
14
|
+
sympy_to_inline_py,
|
15
|
+
sympy_to_python_fn,
|
16
|
+
)
|
17
|
+
from mxlpy.types import Derived, InitialAssignment
|
18
|
+
from mxlpy.units import Quantity
|
10
19
|
|
11
20
|
if TYPE_CHECKING:
|
12
|
-
import
|
21
|
+
from collections.abc import Callable
|
13
22
|
|
14
23
|
from mxlpy.model import Model
|
15
24
|
|
16
25
|
__all__ = [
|
26
|
+
"SymbolicFn",
|
27
|
+
"SymbolicParameter",
|
28
|
+
"SymbolicReaction",
|
29
|
+
"SymbolicRepr",
|
30
|
+
"SymbolicVariable",
|
17
31
|
"generate_mxlpy_code",
|
32
|
+
"generate_mxlpy_code_from_symbolic_repr",
|
18
33
|
]
|
19
34
|
|
20
35
|
_LOGGER = logging.getLogger()
|
21
36
|
|
22
37
|
|
23
|
-
|
24
|
-
|
25
|
-
|
38
|
+
@dataclass
|
39
|
+
class SymbolicFn:
|
40
|
+
"""Container for symbolic fn."""
|
26
41
|
|
27
|
-
|
28
|
-
|
29
|
-
|
42
|
+
fn_name: str
|
43
|
+
expr: sympy.Expr
|
44
|
+
args: list[str]
|
45
|
+
|
46
|
+
|
47
|
+
@dataclass
|
48
|
+
class SymbolicVariable:
|
49
|
+
"""Container for symbolic variable."""
|
50
|
+
|
51
|
+
value: sympy.Float | SymbolicFn # initial assignment
|
52
|
+
unit: Quantity | None
|
53
|
+
|
54
|
+
|
55
|
+
@dataclass
|
56
|
+
class SymbolicParameter:
|
57
|
+
"""Container for symbolic par."""
|
58
|
+
|
59
|
+
value: sympy.Float | SymbolicFn # initial assignment
|
60
|
+
unit: Quantity | None
|
61
|
+
|
62
|
+
|
63
|
+
@dataclass
|
64
|
+
class SymbolicReaction:
|
65
|
+
"""Container for symbolic rxn."""
|
66
|
+
|
67
|
+
fn: SymbolicFn
|
68
|
+
stoichiometry: dict[str, sympy.Float | str | SymbolicFn]
|
69
|
+
|
70
|
+
|
71
|
+
@dataclass
|
72
|
+
class SymbolicRepr:
|
73
|
+
"""Container for symbolic model."""
|
74
|
+
|
75
|
+
variables: dict[str, SymbolicVariable] = field(default_factory=dict)
|
76
|
+
parameters: dict[str, SymbolicParameter] = field(default_factory=dict)
|
77
|
+
derived: dict[str, SymbolicFn] = field(default_factory=dict)
|
78
|
+
reactions: dict[str, SymbolicReaction] = field(default_factory=dict)
|
79
|
+
|
80
|
+
|
81
|
+
def _fn_to_symbolic_repr(k: str, fn: Callable, model_args: list[str]) -> SymbolicFn:
|
82
|
+
fn_name = fn.__name__
|
83
|
+
args = cast(list, list_of_symbols(model_args))
|
84
|
+
if (expr := fn_to_sympy(fn, origin=k, model_args=args)) is None:
|
85
|
+
msg = f"Unable to parse fn for '{k}'"
|
86
|
+
raise ValueError(msg)
|
87
|
+
return SymbolicFn(fn_name=fn_name, expr=expr, args=model_args)
|
88
|
+
|
89
|
+
|
90
|
+
def _to_symbolic_repr(model: Model) -> SymbolicRepr:
|
91
|
+
sym = SymbolicRepr()
|
92
|
+
|
93
|
+
for k, variable in model.get_raw_variables().items():
|
94
|
+
sym.variables[k] = SymbolicVariable(
|
95
|
+
value=_fn_to_symbolic_repr(k, val.fn, val.args)
|
96
|
+
if isinstance(val := variable.initial_value, InitialAssignment)
|
97
|
+
else sympy.Float(val),
|
98
|
+
unit=cast(Quantity, variable.unit),
|
99
|
+
)
|
100
|
+
|
101
|
+
for k, parameter in model.get_raw_parameters().items():
|
102
|
+
sym.parameters[k] = SymbolicParameter(
|
103
|
+
value=_fn_to_symbolic_repr(k, val.fn, val.args)
|
104
|
+
if isinstance(val := parameter.value, InitialAssignment)
|
105
|
+
else sympy.Float(val),
|
106
|
+
unit=cast(Quantity, parameter.unit),
|
107
|
+
)
|
30
108
|
|
31
|
-
# Derived
|
32
|
-
derived_source = []
|
33
109
|
for k, der in model.get_raw_derived().items():
|
34
|
-
|
35
|
-
fn_name = fn.__name__
|
36
|
-
if (
|
37
|
-
expr := fn_to_sympy(fn, origin=k, model_args=list_of_symbols(der.args))
|
38
|
-
) is None:
|
39
|
-
msg = f"Unable to parse fn for derived value '{k}'"
|
40
|
-
raise ValueError(msg)
|
110
|
+
sym.derived[k] = _fn_to_symbolic_repr(k, der.fn, der.args)
|
41
111
|
|
42
|
-
|
112
|
+
for k, rxn in model.get_raw_reactions().items():
|
113
|
+
sym.reactions[k] = SymbolicReaction(
|
114
|
+
fn=_fn_to_symbolic_repr(k, rxn.fn, rxn.args),
|
115
|
+
stoichiometry={
|
116
|
+
k: _fn_to_symbolic_repr(k, v.fn, v.args)
|
117
|
+
if isinstance(v, Derived)
|
118
|
+
else sympy.Float(v)
|
119
|
+
for k, v in rxn.stoichiometry.items()
|
120
|
+
},
|
121
|
+
)
|
43
122
|
|
123
|
+
if len(model._surrogates) > 0: # noqa: SLF001
|
124
|
+
msg = "Generating code for Surrogates not yet supported."
|
125
|
+
_LOGGER.warning(msg)
|
126
|
+
return sym
|
127
|
+
|
128
|
+
|
129
|
+
def _codegen_variable(
|
130
|
+
k: str, var: SymbolicVariable, functions: dict[str, tuple[sympy.Expr, list[str]]]
|
131
|
+
) -> str:
|
132
|
+
if isinstance(init := var.value, SymbolicFn):
|
133
|
+
fn_name = f"init_{init.fn_name}"
|
134
|
+
functions[fn_name] = (init.expr, init.args)
|
135
|
+
return f""" .add_variable(
|
136
|
+
{k!r},
|
137
|
+
initial_value=InitialAssignment(fn={fn_name}, args={init.args!r}),
|
138
|
+
)"""
|
139
|
+
|
140
|
+
value = sympy_to_inline_py(init)
|
141
|
+
if (unit := var.unit) is not None:
|
142
|
+
return f" .add_variable({k!r}, value={value}, unit={sympy_to_inline_py(unit)})"
|
143
|
+
return f" .add_variable({k!r}, initial_value={value})"
|
144
|
+
|
145
|
+
|
146
|
+
def _codegen_parameter(
|
147
|
+
k: str, par: SymbolicParameter, functions: dict[str, tuple[sympy.Expr, list[str]]]
|
148
|
+
) -> str:
|
149
|
+
if isinstance(init := par.value, SymbolicFn):
|
150
|
+
fn_name = f"init_{init.fn_name}"
|
151
|
+
functions[fn_name] = (init.expr, init.args)
|
152
|
+
return f""" .add_parameter(
|
153
|
+
{k!r},
|
154
|
+
value=InitialAssignment(fn={fn_name}, args={init.args!r}),
|
155
|
+
)"""
|
156
|
+
|
157
|
+
value = sympy_to_inline_py(init)
|
158
|
+
if (unit := par.unit) is not None:
|
159
|
+
return f" .add_parameter({k!r}, value={value}, unit={sympy_to_inline_py(unit)})"
|
160
|
+
return f" .add_parameter({k!r}, value={value})"
|
161
|
+
|
162
|
+
|
163
|
+
def generate_mxlpy_code_from_symbolic_repr(
|
164
|
+
model: SymbolicRepr, imports: list[str] | None = None
|
165
|
+
) -> str:
|
166
|
+
"""Generate MxlPy source code from symbolic representation.
|
167
|
+
|
168
|
+
This is both used by MxlPy internally to codegen an existing model again and by the
|
169
|
+
SBML import to generate the file.
|
170
|
+
"""
|
171
|
+
imports = [] if imports is None else imports
|
172
|
+
|
173
|
+
functions: dict[str, tuple[sympy.Expr, list[str]]] = {}
|
174
|
+
|
175
|
+
# Variables
|
176
|
+
variable_source = []
|
177
|
+
for k, var in model.variables.items():
|
178
|
+
variable_source.append(_codegen_variable(k, var, functions=functions))
|
179
|
+
|
180
|
+
# Parameters
|
181
|
+
parameter_source = []
|
182
|
+
for k, par in model.parameters.items():
|
183
|
+
parameter_source.append(_codegen_parameter(k, par, functions=functions))
|
184
|
+
|
185
|
+
# Derived
|
186
|
+
derived_source = []
|
187
|
+
for k, fn in model.derived.items():
|
188
|
+
functions[fn.fn_name] = (fn.expr, fn.args)
|
44
189
|
derived_source.append(
|
45
190
|
f""" .add_derived(
|
46
|
-
|
47
|
-
fn={fn_name},
|
48
|
-
args={
|
191
|
+
{k!r},
|
192
|
+
fn={fn.fn_name},
|
193
|
+
args={fn.args},
|
49
194
|
)"""
|
50
195
|
)
|
51
196
|
|
52
197
|
# Reactions
|
53
198
|
reactions_source = []
|
54
|
-
for k, rxn in model.
|
199
|
+
for k, rxn in model.reactions.items():
|
55
200
|
fn = rxn.fn
|
56
|
-
fn_name = fn.
|
57
|
-
|
58
|
-
expr := fn_to_sympy(fn, origin=k, model_args=list_of_symbols(rxn.args))
|
59
|
-
) is None:
|
60
|
-
msg = f"Unable to parse fn for reaction '{k}'"
|
61
|
-
raise ValueError(msg)
|
62
|
-
|
63
|
-
functions[fn_name] = (expr, rxn.args)
|
201
|
+
functions[fn.fn_name] = (fn.expr, fn.args)
|
202
|
+
|
64
203
|
stoichiometry: list[str] = []
|
65
204
|
for var, stoich in rxn.stoichiometry.items():
|
66
|
-
if isinstance(stoich,
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
)
|
71
|
-
) is None:
|
72
|
-
msg = f"Unable to parse fn for stoichiometry '{var}'"
|
73
|
-
raise ValueError(msg)
|
74
|
-
functions[fn_name] = (expr, rxn.args)
|
75
|
-
args = ", ".join(f'"{k}"' for k in stoich.args)
|
76
|
-
stoich = ( # noqa: PLW2901
|
77
|
-
f"""Derived(fn={fn.__name__}, args=[{args}])"""
|
205
|
+
if isinstance(stoich, SymbolicFn):
|
206
|
+
fn_name = f"{k}_stoich_{stoich.fn_name}"
|
207
|
+
functions[fn_name] = (stoich.expr, stoich.args)
|
208
|
+
stoichiometry.append(
|
209
|
+
f""""{var}": Derived(fn={fn_name}, args={stoich.args!r})"""
|
78
210
|
)
|
79
|
-
|
80
|
-
|
211
|
+
elif isinstance(stoich, str):
|
212
|
+
stoichiometry.append(f""""{var}": {stoich!r}""")
|
213
|
+
else:
|
214
|
+
stoichiometry.append(f""""{var}": {sympy_to_inline_py(stoich)}""")
|
81
215
|
reactions_source.append(
|
82
216
|
f""" .add_reaction(
|
83
217
|
"{k}",
|
84
|
-
fn={fn_name},
|
85
|
-
args={
|
218
|
+
fn={fn.fn_name},
|
219
|
+
args={fn.args},
|
86
220
|
stoichiometry={{{",".join(stoichiometry)}}},
|
87
221
|
)"""
|
88
222
|
)
|
89
223
|
|
90
224
|
# Surrogates
|
91
|
-
if len(model._surrogates) > 0: # noqa: SLF001
|
92
|
-
msg = "Generating code for Surrogates not yet supported."
|
93
|
-
_LOGGER.warning(msg)
|
94
225
|
|
95
226
|
# Combine all the sources
|
96
227
|
functions_source = "\n\n".join(
|
@@ -98,21 +229,26 @@ def generate_mxlpy_code(model: Model) -> str:
|
|
98
229
|
for name, (expr, args) in functions.items()
|
99
230
|
)
|
100
231
|
source = [
|
101
|
-
|
232
|
+
*imports,
|
233
|
+
"from mxlpy import Model, Derived, InitialAssignment\n",
|
102
234
|
functions_source,
|
235
|
+
"",
|
103
236
|
"def create_model() -> Model:",
|
104
237
|
" return (",
|
105
238
|
" Model()",
|
106
239
|
]
|
107
|
-
if len(
|
108
|
-
source.append(
|
109
|
-
if len(
|
110
|
-
source.append(
|
240
|
+
if len(variable_source) > 0:
|
241
|
+
source.append("\n".join(variable_source))
|
242
|
+
if len(parameter_source) > 0:
|
243
|
+
source.append("\n".join(parameter_source))
|
111
244
|
if len(derived_source) > 0:
|
112
245
|
source.append("\n".join(derived_source))
|
113
246
|
if len(reactions_source) > 0:
|
114
247
|
source.append("\n".join(reactions_source))
|
115
|
-
|
116
248
|
source.append(" )")
|
117
|
-
|
118
249
|
return "\n".join(source)
|
250
|
+
|
251
|
+
|
252
|
+
def generate_mxlpy_code(model: Model) -> str:
|
253
|
+
"""Generate a mxlpy model from a model."""
|
254
|
+
return generate_mxlpy_code_from_symbolic_repr(_to_symbolic_repr(model))
|
mxlpy/meta/source_tools.py
CHANGED
@@ -111,7 +111,6 @@ KNOWN_FNS: dict[Callable, sympy.Expr] = {
|
|
111
111
|
math.trunc: sympy.trunc,
|
112
112
|
# math.ulp: sympy.ulp,
|
113
113
|
# numpy
|
114
|
-
np.exp: sympy.exp,
|
115
114
|
np.abs: sympy.Abs,
|
116
115
|
np.acos: sympy.acos,
|
117
116
|
np.acosh: sympy.acosh,
|
@@ -309,10 +308,10 @@ def fn_to_sympy(
|
|
309
308
|
)
|
310
309
|
if sympy_expr is None:
|
311
310
|
return None
|
312
|
-
#
|
311
|
+
# Evaluated fns and floats from attributes
|
313
312
|
if isinstance(sympy_expr, float):
|
314
313
|
return sympy.Float(sympy_expr)
|
315
|
-
if model_args is not None:
|
314
|
+
if model_args is not None and len(model_args):
|
316
315
|
sympy_expr = sympy_expr.subs(dict(zip(fn_args, model_args, strict=True)))
|
317
316
|
return cast(sympy.Expr, sympy_expr)
|
318
317
|
|
@@ -323,77 +322,6 @@ def fn_to_sympy(
|
|
323
322
|
return None
|
324
323
|
|
325
324
|
|
326
|
-
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
|
327
|
-
value = ctx.symbols.get(node.id)
|
328
|
-
if value is None:
|
329
|
-
global_variables = dict(
|
330
|
-
inspect.getmembers(
|
331
|
-
ctx.parent_module,
|
332
|
-
predicate=lambda x: isinstance(x, float),
|
333
|
-
)
|
334
|
-
)
|
335
|
-
value = sympy.Float(global_variables[node.id])
|
336
|
-
return value
|
337
|
-
|
338
|
-
|
339
|
-
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
|
340
|
-
if isinstance(node, float):
|
341
|
-
return sympy.Float(node)
|
342
|
-
if isinstance(node, ast.UnaryOp):
|
343
|
-
return _handle_unaryop(node, ctx)
|
344
|
-
if isinstance(node, ast.BinOp):
|
345
|
-
return _handle_binop(node, ctx)
|
346
|
-
if isinstance(node, ast.Name):
|
347
|
-
return _handle_name(node, ctx)
|
348
|
-
if isinstance(node, ast.Constant):
|
349
|
-
return node.value
|
350
|
-
if isinstance(node, ast.Call):
|
351
|
-
return _handle_call(node, ctx=ctx)
|
352
|
-
if isinstance(node, ast.Attribute):
|
353
|
-
return _handle_attribute(node, ctx=ctx)
|
354
|
-
|
355
|
-
if isinstance(node, ast.Compare):
|
356
|
-
# Handle chained comparisons like 1 < a < 2
|
357
|
-
left = cast(Any, _handle_expr(node.left, ctx))
|
358
|
-
comparisons = []
|
359
|
-
|
360
|
-
# Build all individual comparisons from the chain
|
361
|
-
prev_value = left
|
362
|
-
for op, comparator in zip(node.ops, node.comparators, strict=True):
|
363
|
-
right = cast(Any, _handle_expr(comparator, ctx))
|
364
|
-
|
365
|
-
if isinstance(op, ast.Gt):
|
366
|
-
comparisons.append(prev_value > right)
|
367
|
-
elif isinstance(op, ast.GtE):
|
368
|
-
comparisons.append(prev_value >= right)
|
369
|
-
elif isinstance(op, ast.Lt):
|
370
|
-
comparisons.append(prev_value < right)
|
371
|
-
elif isinstance(op, ast.LtE):
|
372
|
-
comparisons.append(prev_value <= right)
|
373
|
-
elif isinstance(op, ast.Eq):
|
374
|
-
comparisons.append(prev_value == right)
|
375
|
-
elif isinstance(op, ast.NotEq):
|
376
|
-
comparisons.append(prev_value != right)
|
377
|
-
|
378
|
-
prev_value = right
|
379
|
-
|
380
|
-
# Combine all comparisons with logical AND
|
381
|
-
result = comparisons[0]
|
382
|
-
for comp in comparisons[1:]:
|
383
|
-
result = sympy.And(result, comp)
|
384
|
-
return cast(sympy.Expr, result)
|
385
|
-
|
386
|
-
# Handle conditional expressions (ternary operators)
|
387
|
-
if isinstance(node, ast.IfExp):
|
388
|
-
condition = _handle_expr(node.test, ctx)
|
389
|
-
if_true = _handle_expr(node.body, ctx)
|
390
|
-
if_false = _handle_expr(node.orelse, ctx)
|
391
|
-
return sympy.Piecewise((if_true, condition), (if_false, True))
|
392
|
-
|
393
|
-
msg = f"Expression type {type(node).__name__} not implemented"
|
394
|
-
raise NotImplementedError(msg)
|
395
|
-
|
396
|
-
|
397
325
|
def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
|
398
326
|
pieces = []
|
399
327
|
remaining_body = list(body)
|
@@ -505,6 +433,81 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
|
|
505
433
|
raise ValueError(msg)
|
506
434
|
|
507
435
|
|
436
|
+
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
|
437
|
+
"""Key dispatch function."""
|
438
|
+
if isinstance(node, float):
|
439
|
+
return sympy.Float(node)
|
440
|
+
if isinstance(node, ast.UnaryOp):
|
441
|
+
return _handle_unaryop(node, ctx)
|
442
|
+
if isinstance(node, ast.BinOp):
|
443
|
+
return _handle_binop(node, ctx)
|
444
|
+
if isinstance(node, ast.Name):
|
445
|
+
return _handle_name(node, ctx)
|
446
|
+
if isinstance(node, ast.Constant):
|
447
|
+
if isinstance(val := node.value, (float, int)):
|
448
|
+
return sympy.Float(val)
|
449
|
+
msg = "Can only use float values"
|
450
|
+
raise NotImplementedError(msg)
|
451
|
+
if isinstance(node, ast.Call):
|
452
|
+
return _handle_call(node, ctx=ctx)
|
453
|
+
if isinstance(node, ast.Attribute):
|
454
|
+
return _handle_attribute(node, ctx=ctx)
|
455
|
+
|
456
|
+
if isinstance(node, ast.Compare):
|
457
|
+
# Handle chained comparisons like 1 < a < 2
|
458
|
+
left = cast(Any, _handle_expr(node.left, ctx))
|
459
|
+
comparisons = []
|
460
|
+
|
461
|
+
# Build all individual comparisons from the chain
|
462
|
+
prev_value = left
|
463
|
+
for op, comparator in zip(node.ops, node.comparators, strict=True):
|
464
|
+
right = cast(Any, _handle_expr(comparator, ctx))
|
465
|
+
|
466
|
+
if isinstance(op, ast.Gt):
|
467
|
+
comparisons.append(prev_value > right)
|
468
|
+
elif isinstance(op, ast.GtE):
|
469
|
+
comparisons.append(prev_value >= right)
|
470
|
+
elif isinstance(op, ast.Lt):
|
471
|
+
comparisons.append(prev_value < right)
|
472
|
+
elif isinstance(op, ast.LtE):
|
473
|
+
comparisons.append(prev_value <= right)
|
474
|
+
elif isinstance(op, ast.Eq):
|
475
|
+
comparisons.append(prev_value == right)
|
476
|
+
elif isinstance(op, ast.NotEq):
|
477
|
+
comparisons.append(prev_value != right)
|
478
|
+
|
479
|
+
prev_value = right
|
480
|
+
|
481
|
+
# Combine all comparisons with logical AND
|
482
|
+
result = comparisons[0]
|
483
|
+
for comp in comparisons[1:]:
|
484
|
+
result = sympy.And(result, comp)
|
485
|
+
return cast(sympy.Expr, result)
|
486
|
+
|
487
|
+
# Handle conditional expressions (ternary operators)
|
488
|
+
if isinstance(node, ast.IfExp):
|
489
|
+
condition = _handle_expr(node.test, ctx)
|
490
|
+
if_true = _handle_expr(node.body, ctx)
|
491
|
+
if_false = _handle_expr(node.orelse, ctx)
|
492
|
+
return sympy.Piecewise((if_true, condition), (if_false, True))
|
493
|
+
|
494
|
+
msg = f"Expression type {type(node).__name__} not implemented"
|
495
|
+
raise NotImplementedError(msg)
|
496
|
+
|
497
|
+
|
498
|
+
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
|
499
|
+
value = ctx.symbols.get(node.id)
|
500
|
+
if value is None:
|
501
|
+
global_variables = dict(
|
502
|
+
inspect.getmembers(
|
503
|
+
ctx.parent_module,
|
504
|
+
predicate=lambda x: isinstance(x, float),
|
505
|
+
)
|
506
|
+
)
|
507
|
+
value = sympy.Float(global_variables[node.id])
|
508
|
+
return value
|
509
|
+
|
510
|
+
|
508
511
|
def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
|
509
512
|
left = _handle_expr(node.operand, ctx)
|
510
513
|
left = cast(Any, left) # stupid sympy types don't allow ops on symbols
|
@@ -546,6 +549,27 @@ def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
|
546
549
|
raise NotImplementedError(msg)
|
547
550
|
|
548
551
|
|
552
|
+
def _get_inner_object(obj: object, levels: list[str]) -> sympy.Float | None:
|
553
|
+
# Check if object is instantiated, otherwise instantiate first
|
554
|
+
if isinstance(obj, type):
|
555
|
+
obj = obj()
|
556
|
+
|
557
|
+
for level in levels:
|
558
|
+
_LOGGER.debug("obj %s, level %s", obj, level)
|
559
|
+
obj = getattr(obj, level, None)
|
560
|
+
|
561
|
+
if obj is None:
|
562
|
+
return None
|
563
|
+
|
564
|
+
if isinstance(obj, float):
|
565
|
+
if (value := KNOWN_CONSTANTS.get(obj)) is not None:
|
566
|
+
return value
|
567
|
+
return sympy.Float(obj)
|
568
|
+
|
569
|
+
_LOGGER.debug("Inner object not float: %s", obj)
|
570
|
+
return None
|
571
|
+
|
572
|
+
|
549
573
|
# FIXME: check if target isn't an object or class
|
550
574
|
def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
|
551
575
|
"""Handle an attribute.
|
@@ -581,17 +605,36 @@ def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
|
|
581
605
|
dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
|
582
606
|
| ctx.modules
|
583
607
|
)
|
608
|
+
variables = vars(ctx.parent_module)
|
609
|
+
|
584
610
|
match node.value:
|
585
611
|
case ast.Name(l1):
|
586
612
|
module_name = l1
|
587
613
|
module = modules.get(module_name)
|
614
|
+
if module is None and (var := variables.get(l1)) is not None:
|
615
|
+
return _get_inner_object(var, [node.attr])
|
588
616
|
case ast.Attribute():
|
589
|
-
levels = _find_root(node.value, [])
|
617
|
+
levels = _find_root(node.value, levels=[])
|
618
|
+
_LOGGER.debug("Attribute levels %s", levels)
|
590
619
|
module_name = ".".join(levels)
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
620
|
+
|
621
|
+
for idx, level in enumerate(levels[:-1]):
|
622
|
+
if (module := modules.get(level)) is not None:
|
623
|
+
modules.update(
|
624
|
+
dict(
|
625
|
+
inspect.getmembers(
|
626
|
+
module,
|
627
|
+
predicate=inspect.ismodule,
|
628
|
+
)
|
629
|
+
)
|
630
|
+
)
|
631
|
+
elif (var := variables.get(level)) is not None:
|
632
|
+
_LOGGER.debug("var %s", var)
|
633
|
+
return _get_inner_object(var, levels[(idx + 1) :] + [node.attr])
|
634
|
+
|
635
|
+
else:
|
636
|
+
_LOGGER.debug("No target found")
|
637
|
+
|
595
638
|
module = modules.get(levels[-1])
|
596
639
|
case _:
|
597
640
|
raise NotImplementedError
|
@@ -633,6 +676,7 @@ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
|
|
633
676
|
if (expr := _handle_expr(i, ctx)) is None:
|
634
677
|
return None
|
635
678
|
model_args.append(expr)
|
679
|
+
_LOGGER.debug("Fn args: %s", model_args)
|
636
680
|
|
637
681
|
match node.func:
|
638
682
|
case ast.Name(id):
|
@@ -674,7 +718,7 @@ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
|
|
674
718
|
return None
|
675
719
|
|
676
720
|
if (fn := KNOWN_FNS.get(py_fn)) is not None:
|
677
|
-
return fn
|
721
|
+
return sympy.Float(fn(*model_args)) # type: ignore
|
678
722
|
|
679
723
|
return fn_to_sympy(
|
680
724
|
py_fn,
|
mxlpy/meta/sympy_tools.py
CHANGED
@@ -50,12 +50,12 @@ def sympy_to_inline_py(expr: sympy.Expr) -> str:
|
|
50
50
|
'x**2 + 2*x + 1'
|
51
51
|
|
52
52
|
"""
|
53
|
-
return cast(str, pycode(expr, fully_qualified_modules=True))
|
53
|
+
return cast(str, pycode(expr, fully_qualified_modules=True, full_prec=False))
|
54
54
|
|
55
55
|
|
56
56
|
def sympy_to_inline_rust(expr: sympy.Expr) -> str:
|
57
57
|
"""Create rust code from sympy expression."""
|
58
|
-
return cast(str, rust_code(expr))
|
58
|
+
return cast(str, rust_code(expr, full_prec=False))
|
59
59
|
|
60
60
|
|
61
61
|
def sympy_to_python_fn(
|
@@ -93,8 +93,8 @@ def sympy_to_python_fn(
|
|
93
93
|
fn_args = ", ".join(f"{i}: float" for i in args)
|
94
94
|
|
95
95
|
return f"""def {fn_name}({fn_args}) -> float:
|
96
|
-
return {pycode(expr)}
|
97
|
-
"""
|
96
|
+
return {pycode(expr, fully_qualified_modules=True, full_prec=False)}
|
97
|
+
""".replace("math.factorial", "scipy.special.factorial")
|
98
98
|
|
99
99
|
|
100
100
|
def stoichiometries_to_sympy(
|
@@ -114,4 +114,4 @@ def stoichiometries_to_sympy(
|
|
114
114
|
expr = expr + sympy_fn * sympy.Symbol(rxn_name) # type: ignore
|
115
115
|
else:
|
116
116
|
expr = expr + rxn_stoich * sympy.Symbol(rxn_name) # type: ignore
|
117
|
-
return expr
|
117
|
+
return expr.subs(1.0, 1) # type: ignore
|