mxlpy 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +5 -1
- mxlpy/compare.py +2 -6
- mxlpy/experimental/diff.py +1 -1
- mxlpy/fit.py +195 -99
- mxlpy/identify.py +14 -9
- mxlpy/integrators/int_scipy.py +3 -0
- mxlpy/label_map.py +5 -5
- mxlpy/linear_label_map.py +3 -1
- mxlpy/mc.py +3 -0
- mxlpy/mca.py +2 -2
- mxlpy/meta/__init__.py +5 -3
- mxlpy/meta/codegen_latex.py +44 -30
- mxlpy/meta/codegen_model.py +174 -0
- mxlpy/meta/{codegen_modebase.py → codegen_mxlpy.py} +35 -29
- mxlpy/meta/source_tools.py +408 -167
- mxlpy/meta/sympy_tools.py +117 -0
- mxlpy/model.py +528 -224
- mxlpy/report.py +153 -90
- mxlpy/sbml/_export.py +11 -8
- mxlpy/sbml/_import.py +7 -7
- mxlpy/scan.py +1 -1
- mxlpy/simulator.py +238 -57
- mxlpy/symbolic/symbolic_model.py +29 -17
- mxlpy/types.py +45 -20
- mxlpy/units.py +128 -0
- {mxlpy-0.21.0.dist-info → mxlpy-0.22.0.dist-info}/METADATA +1 -1
- {mxlpy-0.21.0.dist-info → mxlpy-0.22.0.dist-info}/RECORD +29 -27
- mxlpy/meta/codegen_py.py +0 -115
- {mxlpy-0.21.0.dist-info → mxlpy-0.22.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.21.0.dist-info → mxlpy-0.22.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/meta/source_tools.py
CHANGED
@@ -3,28 +3,166 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import ast
|
6
|
+
import importlib
|
6
7
|
import inspect
|
8
|
+
import logging
|
9
|
+
import math
|
7
10
|
import textwrap
|
8
11
|
from dataclasses import dataclass
|
12
|
+
from types import ModuleType
|
9
13
|
from typing import TYPE_CHECKING, Any, cast
|
10
14
|
|
11
15
|
import dill
|
16
|
+
import numpy as np
|
12
17
|
import sympy
|
13
|
-
from sympy.printing.pycode import pycode
|
14
18
|
|
15
19
|
if TYPE_CHECKING:
|
16
20
|
from collections.abc import Callable
|
17
|
-
from types import ModuleType
|
18
21
|
|
19
22
|
__all__ = [
    # public data structures and lookup tables
    "Context",
    "KNOWN_CONSTANTS",
    "KNOWN_FNS",
    "PARSE_ERROR",
    # public conversion helpers
    "fn_to_sympy",
    "get_fn_ast",
    "get_fn_source",
]
|
27
31
|
|
32
|
+
_LOGGER = logging.getLogger(__name__)

# Symbol named ``ERROR``; exported as part of the public API (see ``__all__``).
PARSE_ERROR = sympy.Symbol("ERROR")

# Float constants that have an exact symbolic counterpart. Lookups happen on the
# float objects found as module attributes (e.g. ``math.pi``).
# NOTE: NaN keys only match the *identical* object (NaN != NaN), which holds for
# attribute lookups such as ``math.nan`` / ``np.nan`` but not for a NaN created
# elsewhere. Equal keys (``math.e`` / ``np.e``, ...) collapse into one entry.
KNOWN_CONSTANTS: dict[float, sympy.Expr] = {
    math.e: sympy.E,
    math.pi: sympy.pi,
    math.nan: sympy.nan,
    math.tau: sympy.pi * 2,
    math.inf: sympy.oo,
    # numpy
    np.e: sympy.E,
    np.pi: sympy.pi,
    np.nan: sympy.nan,
    np.inf: sympy.oo,
}
|
47
|
+
|
48
|
+
# Numerical python callables mapped to the sympy callable with the same
# semantics. ``_handle_call`` consults this table before falling back to parsing
# the callee's source. Commented-out entries have no sympy equivalent with
# matching semantics; calls to them fail parsing gracefully.
KNOWN_FNS: dict[Callable, Callable] = {
    # built-ins
    abs: sympy.Abs,  # type: ignore
    min: sympy.Min,
    max: sympy.Max,
    pow: sympy.Pow,
    # round: sympy
    # divmod
    # math module
    math.acos: sympy.acos,
    math.acosh: sympy.acosh,
    math.asin: sympy.asin,
    math.asinh: sympy.asinh,
    math.atan: sympy.atan,
    math.atan2: sympy.atan2,
    math.atanh: sympy.atanh,
    math.cbrt: sympy.cbrt,  # math.cbrt requires Python >= 3.11
    math.ceil: sympy.ceiling,
    # math.comb: sympy.comb,
    # math.copysign: sympy.copysign,
    math.cos: sympy.cos,
    math.cosh: sympy.cosh,
    # math.degrees: sympy.degrees,
    # math.dist: sympy.dist,
    math.erf: sympy.erf,
    math.erfc: sympy.erfc,
    math.exp: sympy.exp,
    # math.exp2: sympy.exp2,
    # math.expm1: sympy.expm1,
    # math.fabs: sympy.fabs,
    math.factorial: sympy.factorial,
    math.floor: sympy.floor,
    # math.fmod: sympy.fmod,
    # math.frexp: sympy.frexp,
    # math.fsum: sympy.fsum,
    math.gamma: sympy.gamma,
    math.gcd: sympy.gcd,
    # math.hypot: sympy.hypot,
    # math.isclose: sympy.isclose,
    # math.isfinite: sympy.isfinite,
    # math.isinf: sympy.isinf,
    # math.isnan: sympy.isnan,
    # math.isqrt: sympy.isqrt,
    math.lcm: sympy.lcm,
    # math.ldexp: sympy.ldexp,
    # math.lgamma: sympy.lgamma,
    math.log: sympy.log,
    # math.log10: sympy.log10,
    # math.log1p: sympy.log1p,
    # math.log2: sympy.log2,
    # math.modf: sympy.modf,
    # math.nextafter: sympy.nextafter,
    # math.perm: sympy.perm,
    math.pow: sympy.Pow,
    math.prod: sympy.prod,
    math.radians: sympy.rad,
    # math.remainder: IEEE remainder (round-to-nearest quotient); sympy.rem is
    # polynomial division and therefore not a substitute.
    math.sin: sympy.sin,
    math.sinh: sympy.sinh,
    math.sqrt: sympy.sqrt,
    # math.sumprod: sympy.sumprod,
    math.tan: sympy.tan,
    math.tanh: sympy.tanh,
    # math.trunc: rounds towards zero; sympy.trunc reduces a polynomial modulo
    # a constant p and therefore not a substitute.
    # math.ulp: sympy.ulp,
    # numpy
    # NOTE: np.acos/np.acosh/np.asin/np.asinh/np.atan/np.atanh/np.atan2/np.pow
    # are array-API aliases that exist from NumPy 2.0 onwards.
    np.exp: sympy.exp,
    np.abs: sympy.Abs,
    np.acos: sympy.acos,
    np.acosh: sympy.acosh,
    np.asin: sympy.asin,
    np.asinh: sympy.asinh,
    np.atan: sympy.atan,
    np.atanh: sympy.atanh,
    np.atan2: sympy.atan2,
    np.pow: sympy.Pow,
    np.absolute: sympy.Abs,
    np.add: sympy.Add,
    np.arccos: sympy.acos,
    np.arccosh: sympy.acosh,
    np.arcsin: sympy.asin,
    np.arcsinh: sympy.asinh,
    np.arctan2: sympy.atan2,
    np.arctan: sympy.atan,
    np.arctanh: sympy.atanh,
    np.cbrt: sympy.cbrt,
    np.ceil: sympy.ceiling,
    np.conjugate: sympy.conjugate,
    np.cos: sympy.cos,
    np.cosh: sympy.cosh,
    np.floor: sympy.floor,
    np.gcd: sympy.gcd,
    np.greater: sympy.Gt,  # strict `>`; sympy.GreaterThan is `>=`
    np.greater_equal: sympy.Ge,
    # np.invert: bitwise NOT; sympy.invert is a modular inverse.
    np.lcm: sympy.lcm,
    np.less: sympy.Lt,  # strict `<`; sympy.LessThan is `<=`
    np.less_equal: sympy.Le,
    np.log: sympy.log,
    # Element-wise extrema; sympy.maximum/minimum optimise over a domain instead.
    np.maximum: sympy.Max,
    np.minimum: sympy.Min,
    np.mod: sympy.Mod,
    # np.positive: identity (+x), not the absolute value.
    np.power: sympy.Pow,
    np.sign: sympy.sign,
    np.sin: sympy.sin,
    np.sinh: sympy.sinh,
    np.sqrt: sympy.sqrt,
    # np.square: sympy.square,
    # np.subtract: sympy., # Add(x, -1 * y)
    np.tan: sympy.tan,
    np.tanh: sympy.tanh,
    # np.true_divide: sympy.true_divide,
    # np.trunc: see math.trunc above.
    # np.vecdot: sympy.vecdot,
}
|
165
|
+
|
28
166
|
|
29
167
|
@dataclass
|
30
168
|
class Context:
|
@@ -33,6 +171,9 @@ class Context:
|
|
33
171
|
symbols: dict[str, sympy.Symbol | sympy.Expr]
|
34
172
|
caller: Callable
|
35
173
|
parent_module: ModuleType | None
|
174
|
+
origin: str
|
175
|
+
modules: dict[str, ModuleType]
|
176
|
+
fns: dict[str, Callable]
|
36
177
|
|
37
178
|
def updated(
|
38
179
|
self,
|
@@ -47,9 +188,23 @@ class Context:
|
|
47
188
|
parent_module=self.parent_module
|
48
189
|
if parent_module is None
|
49
190
|
else parent_module,
|
191
|
+
origin=self.origin,
|
192
|
+
modules=self.modules,
|
193
|
+
fns=self.fns,
|
50
194
|
)
|
51
195
|
|
52
196
|
|
197
|
+
def _find_root(value: ast.Attribute | ast.Name, levels: list) -> list[str]:
|
198
|
+
if isinstance(value, ast.Attribute):
|
199
|
+
return _find_root(
|
200
|
+
cast(ast.Attribute, value.value),
|
201
|
+
[value.attr, *levels],
|
202
|
+
)
|
203
|
+
|
204
|
+
root = str(value.id)
|
205
|
+
return [root, *levels]
|
206
|
+
|
207
|
+
|
53
208
|
def get_fn_source(fn: Callable) -> str:
|
54
209
|
"""Get the string representation of a function.
|
55
210
|
|
@@ -110,121 +265,80 @@ def get_fn_ast(fn: Callable) -> ast.FunctionDef:
|
|
110
265
|
return fn_def
|
111
266
|
|
112
267
|
|
113
|
-
def sympy_to_inline(expr: sympy.Expr) -> str:
|
114
|
-
"""Convert a sympy expression to inline Python code.
|
115
|
-
|
116
|
-
Parameters
|
117
|
-
----------
|
118
|
-
expr
|
119
|
-
The sympy expression to convert
|
120
|
-
|
121
|
-
Returns
|
122
|
-
-------
|
123
|
-
str
|
124
|
-
Python code string for the expression
|
125
|
-
|
126
|
-
Examples
|
127
|
-
--------
|
128
|
-
>>> import sympy
|
129
|
-
>>> x = sympy.Symbol('x')
|
130
|
-
>>> expr = x**2 + 2*x + 1
|
131
|
-
>>> sympy_to_inline(expr)
|
132
|
-
'x**2 + 2*x + 1'
|
133
|
-
|
134
|
-
"""
|
135
|
-
return cast(str, pycode(expr, fully_qualified_modules=True))
|
136
|
-
|
137
|
-
|
138
|
-
def sympy_to_fn(
|
139
|
-
*,
|
140
|
-
fn_name: str,
|
141
|
-
args: list[str],
|
142
|
-
expr: sympy.Expr,
|
143
|
-
) -> str:
|
144
|
-
"""Convert a sympy expression to a python function.
|
145
|
-
|
146
|
-
Parameters
|
147
|
-
----------
|
148
|
-
fn_name
|
149
|
-
Name of the function to generate
|
150
|
-
args
|
151
|
-
List of argument names for the function
|
152
|
-
expr
|
153
|
-
Sympy expression to convert to a function body
|
154
|
-
|
155
|
-
Returns
|
156
|
-
-------
|
157
|
-
str
|
158
|
-
String representation of the generated function
|
159
|
-
|
160
|
-
Examples
|
161
|
-
--------
|
162
|
-
>>> import sympy
|
163
|
-
>>> x, y = sympy.symbols('x y')
|
164
|
-
>>> expr = x**2 + y
|
165
|
-
>>> print(sympy_to_fn(fn_name="square_plus_y", args=["x", "y"], expr=expr))
|
166
|
-
def square_plus_y(x: float, y: float) -> float:
|
167
|
-
return x**2 + y
|
168
|
-
|
169
|
-
"""
|
170
|
-
fn_args = ", ".join(f"{i}: float" for i in args)
|
171
|
-
|
172
|
-
return f"""def {fn_name}({fn_args}) -> float:
|
173
|
-
return {pycode(expr)}
|
174
|
-
"""
|
175
|
-
|
176
|
-
|
177
268
|
def fn_to_sympy(
    fn: Callable,
    origin: str,
    model_args: list[sympy.Symbol | sympy.Expr] | None = None,
) -> sympy.Expr | None:
    """Convert a python function to a sympy expression.

    Parsing is best-effort: if the function body contains constructs that
    cannot be translated, or references names/modules that cannot be
    resolved, a warning is logged and ``None`` is returned instead of raising.

    Args:
        fn: The function to convert
        origin: Name of the original caller. Used for error messages.
        model_args: Optional list of sympy symbols to substitute for function arguments

    Returns:
        Sympy expression equivalent to the function, or ``None`` if the
        function could not be parsed.

    Examples:
        >>> def square_fn(x):
        ...     return x**2
        >>> import sympy
        >>> fn_to_sympy(square_fn, origin="square_fn")
        x**2
        >>> # With model_args
        >>> y = sympy.Symbol('y')
        >>> fn_to_sympy(square_fn, origin="square_fn", model_args=[y])
        y**2

    """
    try:
        fn_def = get_fn_ast(fn)
        fn_args = [str(arg.arg) for arg in fn_def.args.args]

        sympy_expr = _handle_fn_body(
            fn_def.body,
            ctx=Context(
                symbols={name: sympy.Symbol(name) for name in fn_args},
                caller=fn,
                parent_module=inspect.getmodule(fn),
                origin=origin,
                modules={},
                fns={},
            ),
        )
        if sympy_expr is None:
            return None
        # Constant bodies (e.g. ``return 1.0``) can come back as raw Python
        # numbers, since ``_handle_expr`` returns ``ast.Constant`` values
        # unchanged. They contain no symbols, so no substitution is needed.
        if isinstance(sympy_expr, float):
            return sympy.Float(sympy_expr)
        if isinstance(sympy_expr, int):
            return sympy.Integer(sympy_expr)
        if model_args is not None:
            sympy_expr = sympy_expr.subs(dict(zip(fn_args, model_args, strict=True)))
        return cast(sympy.Expr, sympy_expr)

    # KeyError / AttributeError / ImportError arise from names, attributes or
    # modules that cannot be resolved; OSError presumably from source code
    # being unavailable for ``fn``. All of them mean "cannot parse", not a bug
    # in the caller, so they degrade to ``None`` like the other parse errors.
    except (
        TypeError,
        ValueError,
        NotImplementedError,
        KeyError,
        AttributeError,
        ImportError,
        OSError,
    ) as e:
        msg = f"Failed parsing function of {origin}"
        _LOGGER.warning(msg)
        _LOGGER.debug("", exc_info=e)
        return None
|
221
324
|
|
222
325
|
|
223
326
|
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
|
224
|
-
|
327
|
+
value = ctx.symbols.get(node.id)
|
328
|
+
if value is None:
|
329
|
+
global_variables = dict(
|
330
|
+
inspect.getmembers(
|
331
|
+
ctx.parent_module,
|
332
|
+
predicate=lambda x: isinstance(x, float),
|
333
|
+
)
|
334
|
+
)
|
335
|
+
value = sympy.Float(global_variables[node.id])
|
336
|
+
return value
|
225
337
|
|
226
338
|
|
227
|
-
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
|
339
|
+
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
|
340
|
+
if isinstance(node, float):
|
341
|
+
return sympy.Float(node)
|
228
342
|
if isinstance(node, ast.UnaryOp):
|
229
343
|
return _handle_unaryop(node, ctx)
|
230
344
|
if isinstance(node, ast.BinOp):
|
@@ -233,6 +347,11 @@ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
|
|
233
347
|
return _handle_name(node, ctx)
|
234
348
|
if isinstance(node, ast.Constant):
|
235
349
|
return node.value
|
350
|
+
if isinstance(node, ast.Call):
|
351
|
+
return _handle_call(node, ctx=ctx)
|
352
|
+
if isinstance(node, ast.Attribute):
|
353
|
+
return _handle_attribute(node, ctx=ctx)
|
354
|
+
|
236
355
|
if isinstance(node, ast.Compare):
|
237
356
|
# Handle chained comparisons like 1 < a < 2
|
238
357
|
left = cast(Any, _handle_expr(node.left, ctx))
|
@@ -263,8 +382,6 @@ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
|
|
263
382
|
for comp in comparisons[1:]:
|
264
383
|
result = sympy.And(result, comp)
|
265
384
|
return cast(sympy.Expr, result)
|
266
|
-
if isinstance(node, ast.Call):
|
267
|
-
return _handle_call(node, ctx)
|
268
385
|
|
269
386
|
# Handle conditional expressions (ternary operators)
|
270
387
|
if isinstance(node, ast.IfExp):
|
@@ -277,7 +394,7 @@ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
|
|
277
394
|
raise NotImplementedError(msg)
|
278
395
|
|
279
396
|
|
280
|
-
def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
|
397
|
+
def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
|
281
398
|
pieces = []
|
282
399
|
remaining_body = list(body)
|
283
400
|
|
@@ -333,7 +450,10 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
|
|
333
450
|
target_elements, value_elements, strict=True
|
334
451
|
):
|
335
452
|
if isinstance(target, ast.Name):
|
336
|
-
|
453
|
+
expr = _handle_expr(value_expr, ctx)
|
454
|
+
if expr is None:
|
455
|
+
return None
|
456
|
+
ctx.symbols[target.id] = expr
|
337
457
|
else:
|
338
458
|
# Handle potential iterable unpacking
|
339
459
|
value = _handle_expr(node.value, ctx)
|
@@ -344,8 +464,33 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
|
|
344
464
|
raise TypeError(msg)
|
345
465
|
target_name = target.id
|
346
466
|
value = _handle_expr(node.value, ctx)
|
467
|
+
if value is None:
|
468
|
+
return None
|
347
469
|
ctx.symbols[target_name] = value
|
348
470
|
|
471
|
+
elif isinstance(node, ast.Import):
|
472
|
+
for alias in node.names:
|
473
|
+
name = alias.name
|
474
|
+
ctx.modules[name] = importlib.import_module(name)
|
475
|
+
|
476
|
+
elif isinstance(node, ast.ImportFrom):
|
477
|
+
package = cast(str, node.module)
|
478
|
+
module = importlib.import_module(package)
|
479
|
+
contents = dict(inspect.getmembers(module))
|
480
|
+
for alias in node.names:
|
481
|
+
name = alias.name
|
482
|
+
el = contents[name]
|
483
|
+
if isinstance(el, float):
|
484
|
+
ctx.symbols[name] = sympy.Float(el)
|
485
|
+
elif callable(el):
|
486
|
+
ctx.fns[name] = el
|
487
|
+
elif isinstance(el, ModuleType):
|
488
|
+
ctx.modules[name] = el
|
489
|
+
else:
|
490
|
+
_LOGGER.debug("Skipping import %s", node)
|
491
|
+
else:
|
492
|
+
_LOGGER.debug("Skipping node of type %s", type(node))
|
493
|
+
|
349
494
|
# If we have pieces to combine into a Piecewise
|
350
495
|
if pieces:
|
351
496
|
return sympy.Piecewise(*pieces)
|
@@ -364,13 +509,14 @@ def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
|
|
364
509
|
left = _handle_expr(node.operand, ctx)
|
365
510
|
left = cast(Any, left) # stupid sympy types don't allow ops on symbols
|
366
511
|
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
512
|
+
match node.op:
|
513
|
+
case ast.UAdd():
|
514
|
+
return +left
|
515
|
+
case ast.USub():
|
516
|
+
return -left
|
517
|
+
case _:
|
518
|
+
msg = f"Operation {type(node.op).__name__} not implemented"
|
519
|
+
raise NotImplementedError(msg)
|
374
520
|
|
375
521
|
|
376
522
|
def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
@@ -380,63 +526,158 @@ def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
|
380
526
|
right = _handle_expr(node.right, ctx)
|
381
527
|
right = cast(Any, right) # stupid sympy types don't allow ops on symbols
|
382
528
|
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
529
|
+
match node.op:
|
530
|
+
case ast.Add():
|
531
|
+
return left + right
|
532
|
+
case ast.Sub():
|
533
|
+
return left - right
|
534
|
+
case ast.Mult():
|
535
|
+
return left * right
|
536
|
+
case ast.Div():
|
537
|
+
return left / right
|
538
|
+
case ast.Pow():
|
539
|
+
return left**right
|
540
|
+
case ast.Mod():
|
541
|
+
return left % right
|
542
|
+
case ast.FloorDiv():
|
543
|
+
return left // right
|
544
|
+
case _:
|
545
|
+
msg = f"Operation {type(node.op).__name__} not implemented"
|
546
|
+
raise NotImplementedError(msg)
|
547
|
+
|
548
|
+
|
549
|
+
# FIXME: check if target isn't an object or class
|
550
|
+
def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
|
551
|
+
"""Handle an attribute.
|
552
|
+
|
553
|
+
Structures to expect:
|
554
|
+
Attribute(Name(id), attr) | direct
|
555
|
+
Attribute(Attribute(Name(id)), attr) | single layer of nesting
|
556
|
+
Attribute(Attribute(...), attr) | arbitrary nesting
|
557
|
+
|
558
|
+
Targets to expect:
|
559
|
+
- modules (both absolute and relative import)
|
560
|
+
- import a; a.attr
|
561
|
+
- import a; a.b.attr
|
562
|
+
- from a import b; b.attr
|
563
|
+
- objects, e.g. Parameters().a
|
564
|
+
- classes, e.g. Parameters.a
|
565
|
+
|
566
|
+
Watch out for relative imports and the different ways they can be called
|
567
|
+
import a
|
568
|
+
from a import b
|
569
|
+
from a.b import c
|
570
|
+
|
571
|
+
a.attr
|
572
|
+
b.attr
|
573
|
+
c.attr
|
574
|
+
a.b.attr
|
575
|
+
b.c.attr
|
576
|
+
a.b.c.attr
|
577
|
+
"""
|
578
|
+
name = str(node.attr)
|
579
|
+
module: ModuleType | None = None
|
580
|
+
modules = (
|
581
|
+
dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
|
582
|
+
| ctx.modules
|
583
|
+
)
|
584
|
+
match node.value:
|
585
|
+
case ast.Name(l1):
|
586
|
+
module_name = l1
|
587
|
+
module = modules.get(module_name)
|
588
|
+
case ast.Attribute():
|
589
|
+
levels = _find_root(node.value, [])
|
590
|
+
module_name = ".".join(levels)
|
591
|
+
for level in levels[:-1]:
|
592
|
+
modules.update(
|
593
|
+
dict(inspect.getmembers(modules[level], predicate=inspect.ismodule))
|
594
|
+
)
|
595
|
+
module = modules.get(levels[-1])
|
596
|
+
case _:
|
597
|
+
raise NotImplementedError
|
598
|
+
|
599
|
+
# Fall-back to absolute import
|
600
|
+
if module is None:
|
601
|
+
module = importlib.import_module(module_name)
|
602
|
+
|
603
|
+
element = dict(
|
604
|
+
inspect.getmembers(
|
605
|
+
module,
|
606
|
+
predicate=lambda x: isinstance(x, float),
|
607
|
+
)
|
608
|
+
).get(name)
|
400
609
|
|
610
|
+
if element is None:
|
611
|
+
return None
|
401
612
|
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
fn_name = str(callee.id)
|
406
|
-
fns = dict(inspect.getmembers(ctx.parent_module, predicate=callable))
|
613
|
+
if (value := KNOWN_CONSTANTS.get(element)) is not None:
|
614
|
+
return value
|
615
|
+
return sympy.Float(element)
|
407
616
|
|
408
|
-
return fn_to_sympy(
|
409
|
-
fns[fn_name],
|
410
|
-
model_args=[_handle_expr(i, ctx) for i in node.args],
|
411
|
-
)
|
412
617
|
|
413
|
-
|
414
|
-
|
415
|
-
|
618
|
+
# FIXME: check if target isn't an object or class
|
619
|
+
def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
|
620
|
+
"""Handle call expression.
|
416
621
|
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
ctx=ctx.updated(parent_module=imports[module_name.id]),
|
422
|
-
)
|
622
|
+
Variants
|
623
|
+
- mass_action(x, k1)
|
624
|
+
- fns.mass_action(x, k1)
|
625
|
+
- mxlpy.fns.mass_action(x, k1)
|
423
626
|
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
ctx
|
627
|
+
In future think about?
|
628
|
+
- object.call
|
629
|
+
- Class.call
|
630
|
+
"""
|
631
|
+
model_args: list[sympy.Expr] = []
|
632
|
+
for i in node.args:
|
633
|
+
if (expr := _handle_expr(i, ctx)) is None:
|
634
|
+
return None
|
635
|
+
model_args.append(expr)
|
636
|
+
|
637
|
+
match node.func:
|
638
|
+
case ast.Name(id):
|
639
|
+
fn_name = str(id)
|
640
|
+
fns = (
|
641
|
+
dict(inspect.getmembers(ctx.parent_module, predicate=callable))
|
642
|
+
| ctx.fns
|
643
|
+
)
|
644
|
+
py_fn = fns.get(fn_name)
|
645
|
+
|
646
|
+
# FIXME: use _handle_attribute for this
|
647
|
+
case ast.Attribute(attr=fn_name):
|
648
|
+
module: ModuleType | None = None
|
649
|
+
modules = (
|
650
|
+
dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
|
651
|
+
| ctx.modules
|
439
652
|
)
|
440
653
|
|
441
|
-
|
442
|
-
|
654
|
+
levels = _find_root(node.func, [])
|
655
|
+
module_name = ".".join(levels[:-1])
|
656
|
+
|
657
|
+
_LOGGER.debug("Searching for module %s", module_name)
|
658
|
+
for level in levels[:-1]:
|
659
|
+
modules.update(
|
660
|
+
dict(inspect.getmembers(modules[level], predicate=inspect.ismodule))
|
661
|
+
)
|
662
|
+
module = modules.get(levels[-2])
|
663
|
+
|
664
|
+
# Fall-back to absolute import
|
665
|
+
if module is None:
|
666
|
+
module = importlib.import_module(module_name)
|
667
|
+
|
668
|
+
fns = dict(inspect.getmembers(module, predicate=callable))
|
669
|
+
py_fn = fns.get(fn_name)
|
670
|
+
case _:
|
671
|
+
raise NotImplementedError
|
672
|
+
|
673
|
+
if py_fn is None:
|
674
|
+
return None
|
675
|
+
|
676
|
+
if (fn := KNOWN_FNS.get(py_fn)) is not None:
|
677
|
+
return fn
|
678
|
+
|
679
|
+
return fn_to_sympy(
|
680
|
+
py_fn,
|
681
|
+
origin=ctx.origin,
|
682
|
+
model_args=model_args,
|
683
|
+
)
|