mxlpy 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mxlpy/__init__.py +15 -2
- mxlpy/carousel.py +6 -4
- mxlpy/compare.py +4 -8
- mxlpy/experimental/diff.py +1 -1
- mxlpy/fit.py +195 -99
- mxlpy/identify.py +14 -9
- mxlpy/integrators/__init__.py +4 -0
- mxlpy/integrators/int_assimulo.py +3 -3
- mxlpy/integrators/int_diffrax.py +119 -0
- mxlpy/integrators/int_scipy.py +15 -6
- mxlpy/label_map.py +6 -7
- mxlpy/linear_label_map.py +3 -1
- mxlpy/mc.py +25 -22
- mxlpy/mca.py +10 -6
- mxlpy/meta/__init__.py +5 -3
- mxlpy/meta/codegen_latex.py +44 -30
- mxlpy/meta/codegen_model.py +175 -0
- mxlpy/meta/codegen_mxlpy.py +254 -0
- mxlpy/meta/source_tools.py +506 -221
- mxlpy/meta/sympy_tools.py +117 -0
- mxlpy/model.py +758 -257
- mxlpy/plot.py +16 -14
- mxlpy/report.py +153 -90
- mxlpy/sbml/_export.py +22 -11
- mxlpy/sbml/_import.py +68 -547
- mxlpy/scan.py +39 -243
- mxlpy/simulator.py +109 -283
- mxlpy/symbolic/symbolic_model.py +29 -17
- mxlpy/types.py +694 -97
- mxlpy/units.py +133 -0
- {mxlpy-0.21.0.dist-info → mxlpy-0.23.0.dist-info}/METADATA +4 -1
- mxlpy-0.23.0.dist-info/RECORD +57 -0
- mxlpy/meta/codegen_modebase.py +0 -112
- mxlpy/meta/codegen_py.py +0 -115
- mxlpy/sbml/_mathml.py +0 -692
- mxlpy/sbml/_unit_conversion.py +0 -74
- mxlpy-0.21.0.dist-info/RECORD +0 -56
- {mxlpy-0.21.0.dist-info → mxlpy-0.23.0.dist-info}/WHEEL +0 -0
- {mxlpy-0.21.0.dist-info → mxlpy-0.23.0.dist-info}/licenses/LICENSE +0 -0
mxlpy/meta/source_tools.py
CHANGED
@@ -3,28 +3,165 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import ast
|
6
|
+
import importlib
|
6
7
|
import inspect
|
8
|
+
import logging
|
9
|
+
import math
|
7
10
|
import textwrap
|
8
11
|
from dataclasses import dataclass
|
12
|
+
from types import ModuleType
|
9
13
|
from typing import TYPE_CHECKING, Any, cast
|
10
14
|
|
11
15
|
import dill
|
16
|
+
import numpy as np
|
12
17
|
import sympy
|
13
|
-
from sympy.printing.pycode import pycode
|
14
18
|
|
15
19
|
if TYPE_CHECKING:
|
16
20
|
from collections.abc import Callable
|
17
|
-
from types import ModuleType
|
18
21
|
|
19
22
|
__all__ = [
    "Context",
    "KNOWN_CONSTANTS",
    "KNOWN_FNS",
    "PARSE_ERROR",
    "fn_to_sympy",
    "get_fn_ast",
    "get_fn_source",
]

_LOGGER = logging.getLogger(__name__)

# Placeholder symbol exported for callers; presumably substituted where a
# function could not be parsed (NOTE(review): confirm at the call sites).
PARSE_ERROR = sympy.Symbol("ERROR")

# Maps well-known float constants to their exact sympy counterparts, so that
# e.g. ``math.pi`` becomes ``pi`` instead of ``3.14159...``.
# Lookups are by float equality, so the numpy aliases (``np.e == math.e``,
# ...) collapse onto the same entries as their ``math`` counterparts.
# NaN never compares equal to itself, so a NaN entry only matches when the
# very same float object is looked up (dicts check identity first).
# Values are general sympy expressions (E, pi, oo, nan, 2*pi), not Floats.
KNOWN_CONSTANTS: dict[float, sympy.Expr] = {
    math.e: sympy.E,
    math.pi: sympy.pi,
    math.nan: sympy.nan,
    math.tau: sympy.pi * 2,
    math.inf: sympy.oo,
    # numpy
    np.e: sympy.E,
    np.pi: sympy.pi,
    np.nan: sympy.nan,
    np.inf: sympy.oo,
}
|
47
|
+
|
48
|
+
def _sympy_trunc(x: Any) -> Any:
    """Truncate towards zero, like ``math.trunc`` / ``np.trunc``.

    ``sympy.trunc`` is *polynomial* truncation modulo an integer and requires
    a second argument, so it cannot stand in for the numeric function.
    """
    return sympy.sign(x) * sympy.floor(sympy.Abs(x))


def _sympy_positive(x: Any) -> Any:
    """Unary plus, like ``np.positive`` (which is *not* an absolute value)."""
    return +x


# Maps Python built-ins, ``math`` and ``numpy`` callables to the sympy callable
# that builds the equivalent expression from the already-converted arguments.
# Callables not listed here are converted by parsing their source instead.
KNOWN_FNS: dict[Callable, Callable] = {
    # built-ins
    abs: sympy.Abs,  # type: ignore
    min: sympy.Min,
    max: sympy.Max,
    pow: sympy.Pow,
    # Not mapped yet: round, divmod
    # math module
    math.acos: sympy.acos,
    math.acosh: sympy.acosh,
    math.asin: sympy.asin,
    math.asinh: sympy.asinh,
    math.atan: sympy.atan,
    math.atan2: sympy.atan2,
    math.atanh: sympy.atanh,
    math.cbrt: sympy.cbrt,
    math.ceil: sympy.ceiling,
    math.cos: sympy.cos,
    math.cosh: sympy.cosh,
    math.erf: sympy.erf,
    math.erfc: sympy.erfc,
    math.exp: sympy.exp,
    math.factorial: sympy.factorial,
    math.floor: sympy.floor,
    math.gamma: sympy.gamma,
    math.gcd: sympy.gcd,
    math.lcm: sympy.lcm,
    math.log: sympy.log,
    math.pow: sympy.Pow,
    math.prod: sympy.prod,
    math.radians: sympy.rad,
    math.sin: sympy.sin,
    math.sinh: sympy.sinh,
    math.sqrt: sympy.sqrt,
    math.tan: sympy.tan,
    math.tanh: sympy.tanh,
    math.trunc: _sympy_trunc,
    # Not mapped yet: comb, copysign, degrees, dist, exp2, expm1, fabs, fmod,
    # frexp, fsum, hypot, isclose, isfinite, isinf, isnan, isqrt, ldexp,
    # lgamma, log10, log1p, log2, modf, nextafter, perm, sumprod, ulp.
    # math.remainder is the IEEE-754 remainder (quotient rounded half-to-even);
    # sympy.rem is a polynomial remainder, so it is intentionally not mapped.
    # numpy
    np.abs: sympy.Abs,
    np.acos: sympy.acos,
    np.acosh: sympy.acosh,
    np.asin: sympy.asin,
    np.asinh: sympy.asinh,
    np.atan: sympy.atan,
    np.atanh: sympy.atanh,
    np.atan2: sympy.atan2,
    np.pow: sympy.Pow,
    np.absolute: sympy.Abs,
    np.add: sympy.Add,
    np.arccos: sympy.acos,
    np.arccosh: sympy.acosh,
    np.arcsin: sympy.asin,
    np.arcsinh: sympy.asinh,
    np.arctan2: sympy.atan2,
    np.arctan: sympy.atan,
    np.arctanh: sympy.atanh,
    np.cbrt: sympy.cbrt,
    np.ceil: sympy.ceiling,
    np.conjugate: sympy.conjugate,
    np.cos: sympy.cos,
    np.cosh: sympy.cosh,
    np.exp: sympy.exp,
    np.floor: sympy.floor,
    np.gcd: sympy.gcd,
    # np.greater is strict (>); sympy.GreaterThan would be >=
    np.greater: sympy.StrictGreaterThan,
    np.greater_equal: sympy.Ge,
    np.lcm: sympy.lcm,
    # np.less is strict (<); sympy.LessThan would be <=
    np.less: sympy.StrictLessThan,
    np.less_equal: sympy.Le,
    np.log: sympy.log,
    # Element-wise max/min; sympy.maximum/minimum compute extrema of a
    # function over a domain, which is a different operation.
    np.maximum: sympy.Max,
    np.minimum: sympy.Min,
    np.mod: sympy.Mod,
    np.positive: _sympy_positive,
    np.power: sympy.Pow,
    np.sign: sympy.sign,
    np.sin: sympy.sin,
    np.sinh: sympy.sinh,
    np.sqrt: sympy.sqrt,
    np.tan: sympy.tan,
    np.tanh: sympy.tanh,
    np.trunc: _sympy_trunc,
    # Not mapped yet: square, subtract, true_divide, vecdot.
    # np.invert is a bitwise NOT; sympy.invert is a modular inverse, so it is
    # intentionally not mapped.
}
|
164
|
+
|
28
165
|
|
29
166
|
@dataclass
|
30
167
|
class Context:
|
@@ -33,6 +170,9 @@ class Context:
|
|
33
170
|
symbols: dict[str, sympy.Symbol | sympy.Expr]
|
34
171
|
caller: Callable
|
35
172
|
parent_module: ModuleType | None
|
173
|
+
origin: str
|
174
|
+
modules: dict[str, ModuleType]
|
175
|
+
fns: dict[str, Callable]
|
36
176
|
|
37
177
|
def updated(
|
38
178
|
self,
|
@@ -47,8 +187,22 @@ class Context:
|
|
47
187
|
parent_module=self.parent_module
|
48
188
|
if parent_module is None
|
49
189
|
else parent_module,
|
190
|
+
origin=self.origin,
|
191
|
+
modules=self.modules,
|
192
|
+
fns=self.fns,
|
193
|
+
)
|
194
|
+
|
195
|
+
|
196
|
+
def _find_root(value: ast.Attribute | ast.Name, levels: list) -> list[str]:
|
197
|
+
if isinstance(value, ast.Attribute):
|
198
|
+
return _find_root(
|
199
|
+
cast(ast.Attribute, value.value),
|
200
|
+
[value.attr, *levels],
|
50
201
|
)
|
51
202
|
|
203
|
+
root = str(value.id)
|
204
|
+
return [root, *levels]
|
205
|
+
|
52
206
|
|
53
207
|
def get_fn_source(fn: Callable) -> str:
|
54
208
|
"""Get the string representation of a function.
|
@@ -110,174 +264,65 @@ def get_fn_ast(fn: Callable) -> ast.FunctionDef:
|
|
110
264
|
return fn_def
|
111
265
|
|
112
266
|
|
113
|
-
def sympy_to_inline(expr: sympy.Expr) -> str:
|
114
|
-
"""Convert a sympy expression to inline Python code.
|
115
|
-
|
116
|
-
Parameters
|
117
|
-
----------
|
118
|
-
expr
|
119
|
-
The sympy expression to convert
|
120
|
-
|
121
|
-
Returns
|
122
|
-
-------
|
123
|
-
str
|
124
|
-
Python code string for the expression
|
125
|
-
|
126
|
-
Examples
|
127
|
-
--------
|
128
|
-
>>> import sympy
|
129
|
-
>>> x = sympy.Symbol('x')
|
130
|
-
>>> expr = x**2 + 2*x + 1
|
131
|
-
>>> sympy_to_inline(expr)
|
132
|
-
'x**2 + 2*x + 1'
|
133
|
-
|
134
|
-
"""
|
135
|
-
return cast(str, pycode(expr, fully_qualified_modules=True))
|
136
|
-
|
137
|
-
|
138
|
-
def sympy_to_fn(
|
139
|
-
*,
|
140
|
-
fn_name: str,
|
141
|
-
args: list[str],
|
142
|
-
expr: sympy.Expr,
|
143
|
-
) -> str:
|
144
|
-
"""Convert a sympy expression to a python function.
|
145
|
-
|
146
|
-
Parameters
|
147
|
-
----------
|
148
|
-
fn_name
|
149
|
-
Name of the function to generate
|
150
|
-
args
|
151
|
-
List of argument names for the function
|
152
|
-
expr
|
153
|
-
Sympy expression to convert to a function body
|
154
|
-
|
155
|
-
Returns
|
156
|
-
-------
|
157
|
-
str
|
158
|
-
String representation of the generated function
|
159
|
-
|
160
|
-
Examples
|
161
|
-
--------
|
162
|
-
>>> import sympy
|
163
|
-
>>> x, y = sympy.symbols('x y')
|
164
|
-
>>> expr = x**2 + y
|
165
|
-
>>> print(sympy_to_fn(fn_name="square_plus_y", args=["x", "y"], expr=expr))
|
166
|
-
def square_plus_y(x: float, y: float) -> float:
|
167
|
-
return x**2 + y
|
168
|
-
|
169
|
-
"""
|
170
|
-
fn_args = ", ".join(f"{i}: float" for i in args)
|
171
|
-
|
172
|
-
return f"""def {fn_name}({fn_args}) -> float:
|
173
|
-
return {pycode(expr)}
|
174
|
-
"""
|
175
|
-
|
176
|
-
|
177
267
|
def fn_to_sympy(
    fn: Callable,
    origin: str,
    model_args: list[sympy.Symbol | sympy.Expr] | None = None,
) -> sympy.Expr | None:
    """Convert a python function to a sympy expression.

    Conversion is best-effort. If parsing raises one of the handled exception
    types, a warning naming ``origin`` is logged (with the traceback at debug
    level) and ``None`` is returned. Names, attributes or calls that cannot
    be resolved also make the result ``None``.

    Args:
        fn: The function to convert
        origin: Name of the original caller. Used for error messages.
        model_args: Optional list of sympy symbols to substitute for function
            arguments, in the order of ``fn``'s positional parameters. The
            substitution is simultaneous, so the arguments may reuse or swap
            the parameter names.

    Returns:
        Sympy expression equivalent to the function, or ``None`` if the
        function could not be converted.

    Examples:
        >>> def square_fn(x):
        ...     return x**2
        >>> import sympy
        >>> fn_to_sympy(square_fn, origin="square_fn")
        x**2.0
        >>> # With model_args
        >>> y = sympy.Symbol('y')
        >>> fn_to_sympy(square_fn, origin="square_fn", model_args=[y])
        y**2.0

    """
    try:
        fn_def = get_fn_ast(fn)
        fn_args = [str(arg.arg) for arg in fn_def.args.args]

        sympy_expr = _handle_fn_body(
            fn_def.body,
            ctx=Context(
                symbols={name: sympy.Symbol(name) for name in fn_args},
                caller=fn,
                parent_module=inspect.getmodule(fn),
                origin=origin,
                modules={},
                fns={},
            ),
        )
        if sympy_expr is None:
            _LOGGER.debug("Could not resolve function of %s", origin)
            return None
        # Evaluated fns and floats from attributes
        if isinstance(sympy_expr, float):
            return sympy.Float(sympy_expr)
        if model_args is not None and len(model_args):
            # Substitute simultaneously: sequential substitution breaks when
            # an argument reuses a parameter name, e.g. f(a, b) called as
            # f(b, c) would first turn a -> b and then every b -> c.
            sympy_expr = sympy_expr.subs(
                dict(zip(fn_args, model_args, strict=True)),
                simultaneous=True,
            )
        return cast(sympy.Expr, sympy_expr)

    except (
        TypeError,
        ValueError,
        NotImplementedError,
        # Unresolvable names / import members and unsupported node shapes
        KeyError,
        IndexError,
        AttributeError,
        # Modules that cannot be imported, or e.g. an OSError when the source
        # of a function cannot be retrieved
        ImportError,
        OSError,
    ) as e:
        msg = f"Failed parsing function of {origin}"
        _LOGGER.warning(msg)
        _LOGGER.debug("", exc_info=e)
        return None
|
323
|
+
|
324
|
+
|
325
|
+
def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
|
281
326
|
pieces = []
|
282
327
|
remaining_body = list(body)
|
283
328
|
|
@@ -333,7 +378,10 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
|
|
333
378
|
target_elements, value_elements, strict=True
|
334
379
|
):
|
335
380
|
if isinstance(target, ast.Name):
|
336
|
-
|
381
|
+
expr = _handle_expr(value_expr, ctx)
|
382
|
+
if expr is None:
|
383
|
+
return None
|
384
|
+
ctx.symbols[target.id] = expr
|
337
385
|
else:
|
338
386
|
# Handle potential iterable unpacking
|
339
387
|
value = _handle_expr(node.value, ctx)
|
@@ -344,8 +392,33 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
|
|
344
392
|
raise TypeError(msg)
|
345
393
|
target_name = target.id
|
346
394
|
value = _handle_expr(node.value, ctx)
|
395
|
+
if value is None:
|
396
|
+
return None
|
347
397
|
ctx.symbols[target_name] = value
|
348
398
|
|
399
|
+
elif isinstance(node, ast.Import):
|
400
|
+
for alias in node.names:
|
401
|
+
name = alias.name
|
402
|
+
ctx.modules[name] = importlib.import_module(name)
|
403
|
+
|
404
|
+
elif isinstance(node, ast.ImportFrom):
|
405
|
+
package = cast(str, node.module)
|
406
|
+
module = importlib.import_module(package)
|
407
|
+
contents = dict(inspect.getmembers(module))
|
408
|
+
for alias in node.names:
|
409
|
+
name = alias.name
|
410
|
+
el = contents[name]
|
411
|
+
if isinstance(el, float):
|
412
|
+
ctx.symbols[name] = sympy.Float(el)
|
413
|
+
elif callable(el):
|
414
|
+
ctx.fns[name] = el
|
415
|
+
elif isinstance(el, ModuleType):
|
416
|
+
ctx.modules[name] = el
|
417
|
+
else:
|
418
|
+
_LOGGER.debug("Skipping import %s", node)
|
419
|
+
else:
|
420
|
+
_LOGGER.debug("Skipping node of type %s", type(node))
|
421
|
+
|
349
422
|
# If we have pieces to combine into a Piecewise
|
350
423
|
if pieces:
|
351
424
|
return sympy.Piecewise(*pieces)
|
@@ -360,17 +433,93 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
|
|
360
433
|
raise ValueError(msg)
|
361
434
|
|
362
435
|
|
436
|
+
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
|
437
|
+
"""Key dispatch function."""
|
438
|
+
if isinstance(node, float):
|
439
|
+
return sympy.Float(node)
|
440
|
+
if isinstance(node, ast.UnaryOp):
|
441
|
+
return _handle_unaryop(node, ctx)
|
442
|
+
if isinstance(node, ast.BinOp):
|
443
|
+
return _handle_binop(node, ctx)
|
444
|
+
if isinstance(node, ast.Name):
|
445
|
+
return _handle_name(node, ctx)
|
446
|
+
if isinstance(node, ast.Constant):
|
447
|
+
if isinstance(val := node.value, (float, int)):
|
448
|
+
return sympy.Float(val)
|
449
|
+
msg = "Can only use float values"
|
450
|
+
raise NotImplementedError(msg)
|
451
|
+
if isinstance(node, ast.Call):
|
452
|
+
return _handle_call(node, ctx=ctx)
|
453
|
+
if isinstance(node, ast.Attribute):
|
454
|
+
return _handle_attribute(node, ctx=ctx)
|
455
|
+
|
456
|
+
if isinstance(node, ast.Compare):
|
457
|
+
# Handle chained comparisons like 1 < a < 2
|
458
|
+
left = cast(Any, _handle_expr(node.left, ctx))
|
459
|
+
comparisons = []
|
460
|
+
|
461
|
+
# Build all individual comparisons from the chain
|
462
|
+
prev_value = left
|
463
|
+
for op, comparator in zip(node.ops, node.comparators, strict=True):
|
464
|
+
right = cast(Any, _handle_expr(comparator, ctx))
|
465
|
+
|
466
|
+
if isinstance(op, ast.Gt):
|
467
|
+
comparisons.append(prev_value > right)
|
468
|
+
elif isinstance(op, ast.GtE):
|
469
|
+
comparisons.append(prev_value >= right)
|
470
|
+
elif isinstance(op, ast.Lt):
|
471
|
+
comparisons.append(prev_value < right)
|
472
|
+
elif isinstance(op, ast.LtE):
|
473
|
+
comparisons.append(prev_value <= right)
|
474
|
+
elif isinstance(op, ast.Eq):
|
475
|
+
comparisons.append(prev_value == right)
|
476
|
+
elif isinstance(op, ast.NotEq):
|
477
|
+
comparisons.append(prev_value != right)
|
478
|
+
|
479
|
+
prev_value = right
|
480
|
+
|
481
|
+
# Combine all comparisons with logical AND
|
482
|
+
result = comparisons[0]
|
483
|
+
for comp in comparisons[1:]:
|
484
|
+
result = sympy.And(result, comp)
|
485
|
+
return cast(sympy.Expr, result)
|
486
|
+
|
487
|
+
# Handle conditional expressions (ternary operators)
|
488
|
+
if isinstance(node, ast.IfExp):
|
489
|
+
condition = _handle_expr(node.test, ctx)
|
490
|
+
if_true = _handle_expr(node.body, ctx)
|
491
|
+
if_false = _handle_expr(node.orelse, ctx)
|
492
|
+
return sympy.Piecewise((if_true, condition), (if_false, True))
|
493
|
+
|
494
|
+
msg = f"Expression type {type(node).__name__} not implemented"
|
495
|
+
raise NotImplementedError(msg)
|
496
|
+
|
497
|
+
|
498
|
+
def _handle_name(
    node: ast.Name, ctx: Context
) -> sympy.Symbol | sympy.Expr | None:
    """Resolve a bare name to a sympy expression.

    Lookup order:
        1. ``ctx.symbols``: function arguments and locals assigned earlier in
           the function body.
        2. Numeric globals (``int`` or ``float``, but not ``bool``) of the
           caller's module, converted to ``sympy.Float``. Integers are
           accepted for consistency with integer literals.

    Returns:
        The resolved expression, or ``None`` if the name is undefined or is
        not numeric, so the conversion is aborted instead of raising.

    """
    value = ctx.symbols.get(node.id)
    if value is not None:
        return value

    if ctx.parent_module is None:
        return None

    candidate = getattr(ctx.parent_module, node.id, None)
    # bool is an int subclass, but True/False are not parameter values
    if isinstance(candidate, bool) or not isinstance(candidate, (int, float)):
        return None
    return sympy.Float(candidate)
|
509
|
+
|
510
|
+
|
363
511
|
def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
|
364
512
|
left = _handle_expr(node.operand, ctx)
|
365
513
|
left = cast(Any, left) # stupid sympy types don't allow ops on symbols
|
366
514
|
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
515
|
+
match node.op:
|
516
|
+
case ast.UAdd():
|
517
|
+
return +left
|
518
|
+
case ast.USub():
|
519
|
+
return -left
|
520
|
+
case _:
|
521
|
+
msg = f"Operation {type(node.op).__name__} not implemented"
|
522
|
+
raise NotImplementedError(msg)
|
374
523
|
|
375
524
|
|
376
525
|
def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
@@ -380,63 +529,199 @@ def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
|
|
380
529
|
right = _handle_expr(node.right, ctx)
|
381
530
|
right = cast(Any, right) # stupid sympy types don't allow ops on symbols
|
382
531
|
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
532
|
+
match node.op:
|
533
|
+
case ast.Add():
|
534
|
+
return left + right
|
535
|
+
case ast.Sub():
|
536
|
+
return left - right
|
537
|
+
case ast.Mult():
|
538
|
+
return left * right
|
539
|
+
case ast.Div():
|
540
|
+
return left / right
|
541
|
+
case ast.Pow():
|
542
|
+
return left**right
|
543
|
+
case ast.Mod():
|
544
|
+
return left % right
|
545
|
+
case ast.FloorDiv():
|
546
|
+
return left // right
|
547
|
+
case _:
|
548
|
+
msg = f"Operation {type(node.op).__name__} not implemented"
|
549
|
+
raise NotImplementedError(msg)
|
550
|
+
|
551
|
+
|
552
|
+
def _get_inner_object(obj: object, levels: list[str]) -> sympy.Expr | None:
    """Follow the attribute path ``levels`` starting at ``obj``.

    Used for accesses such as ``params.k1`` or ``Parameters.k1`` whose root
    is a variable (object or class) rather than a module.

    Args:
        obj: Root object. A class is instantiated without arguments first
            (presumably so that attributes set in ``__init__`` resolve).
        levels: Attribute names to follow, outermost first.

    Returns:
        The final attribute as a sympy expression: the exact sympy constant
        for well-known floats (see ``KNOWN_CONSTANTS``), ``sympy.Float``
        otherwise. Returns ``None`` if an attribute is missing or the value is
        not an ``int``/``float``.

    """
    # Check if object is instantiated, otherwise instantiate first
    if isinstance(obj, type):
        obj = obj()

    for level in levels:
        _LOGGER.debug("obj %s, level %s", obj, level)
        obj = getattr(obj, level, None)
        if obj is None:
            return None

    # bool is an int subclass, but True/False are not parameter values.
    # Integers are accepted for consistency with integer literals.
    if isinstance(obj, bool) or not isinstance(obj, (int, float)):
        _LOGGER.debug("Inner object not float: %s", obj)
        return None

    if (value := KNOWN_CONSTANTS.get(obj)) is not None:
        return value
    return sympy.Float(obj)
|
571
|
+
|
572
|
+
|
573
|
+
# FIXME: check if target isn't an object or class
|
574
|
+
def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
|
575
|
+
"""Handle an attribute.
|
576
|
+
|
577
|
+
Structures to expect:
|
578
|
+
Attribute(Name(id), attr) | direct
|
579
|
+
Attribute(Attribute(Name(id)), attr) | single layer of nesting
|
580
|
+
Attribute(Attribute(...), attr) | arbitrary nesting
|
581
|
+
|
582
|
+
Targets to expect:
|
583
|
+
- modules (both absolute and relative import)
|
584
|
+
- import a; a.attr
|
585
|
+
- import a; a.b.attr
|
586
|
+
- from a import b; b.attr
|
587
|
+
- objects, e.g. Parameters().a
|
588
|
+
- classes, e.g. Parameters.a
|
589
|
+
|
590
|
+
Watch out for relative imports and the different ways they can be called
|
591
|
+
import a
|
592
|
+
from a import b
|
593
|
+
from a.b import c
|
594
|
+
|
595
|
+
a.attr
|
596
|
+
b.attr
|
597
|
+
c.attr
|
598
|
+
a.b.attr
|
599
|
+
b.c.attr
|
600
|
+
a.b.c.attr
|
601
|
+
"""
|
602
|
+
name = str(node.attr)
|
603
|
+
module: ModuleType | None = None
|
604
|
+
modules = (
|
605
|
+
dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
|
606
|
+
| ctx.modules
|
607
|
+
)
|
608
|
+
variables = vars(ctx.parent_module)
|
609
|
+
|
610
|
+
match node.value:
|
611
|
+
case ast.Name(l1):
|
612
|
+
module_name = l1
|
613
|
+
module = modules.get(module_name)
|
614
|
+
if module is None and (var := variables.get(l1)) is not None:
|
615
|
+
return _get_inner_object(var, [node.attr])
|
616
|
+
case ast.Attribute():
|
617
|
+
levels = _find_root(node.value, levels=[])
|
618
|
+
_LOGGER.debug("Attribute levels %s", levels)
|
619
|
+
module_name = ".".join(levels)
|
620
|
+
|
621
|
+
for idx, level in enumerate(levels[:-1]):
|
622
|
+
if (module := modules.get(level)) is not None:
|
623
|
+
modules.update(
|
624
|
+
dict(
|
625
|
+
inspect.getmembers(
|
626
|
+
module,
|
627
|
+
predicate=inspect.ismodule,
|
628
|
+
)
|
629
|
+
)
|
630
|
+
)
|
631
|
+
elif (var := variables.get(level)) is not None:
|
632
|
+
_LOGGER.debug("var %s", var)
|
633
|
+
return _get_inner_object(var, levels[(idx + 1) :] + [node.attr])
|
400
634
|
|
635
|
+
else:
|
636
|
+
_LOGGER.debug("No target found")
|
637
|
+
|
638
|
+
module = modules.get(levels[-1])
|
639
|
+
case _:
|
640
|
+
raise NotImplementedError
|
401
641
|
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
fn_name = str(callee.id)
|
406
|
-
fns = dict(inspect.getmembers(ctx.parent_module, predicate=callable))
|
642
|
+
# Fall-back to absolute import
|
643
|
+
if module is None:
|
644
|
+
module = importlib.import_module(module_name)
|
407
645
|
|
408
|
-
|
409
|
-
|
410
|
-
|
646
|
+
element = dict(
|
647
|
+
inspect.getmembers(
|
648
|
+
module,
|
649
|
+
predicate=lambda x: isinstance(x, float),
|
411
650
|
)
|
651
|
+
).get(name)
|
412
652
|
|
413
|
-
|
414
|
-
|
415
|
-
imports = dict(inspect.getmembers(ctx.parent_module, inspect.ismodule))
|
653
|
+
if element is None:
|
654
|
+
return None
|
416
655
|
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
656
|
+
if (value := KNOWN_CONSTANTS.get(element)) is not None:
|
657
|
+
return value
|
658
|
+
return sympy.Float(element)
|
659
|
+
|
660
|
+
|
661
|
+
# FIXME: check if target isn't an object or class
|
662
|
+
def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
|
663
|
+
"""Handle call expression.
|
423
664
|
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
665
|
+
Variants
|
666
|
+
- mass_action(x, k1)
|
667
|
+
- fns.mass_action(x, k1)
|
668
|
+
- mxlpy.fns.mass_action(x, k1)
|
669
|
+
|
670
|
+
In future think about?
|
671
|
+
- object.call
|
672
|
+
- Class.call
|
673
|
+
"""
|
674
|
+
model_args: list[sympy.Expr] = []
|
675
|
+
for i in node.args:
|
676
|
+
if (expr := _handle_expr(i, ctx)) is None:
|
677
|
+
return None
|
678
|
+
model_args.append(expr)
|
679
|
+
_LOGGER.debug("Fn args: %s", model_args)
|
680
|
+
|
681
|
+
match node.func:
|
682
|
+
case ast.Name(id):
|
683
|
+
fn_name = str(id)
|
684
|
+
fns = (
|
685
|
+
dict(inspect.getmembers(ctx.parent_module, predicate=callable))
|
686
|
+
| ctx.fns
|
687
|
+
)
|
688
|
+
py_fn = fns.get(fn_name)
|
689
|
+
|
690
|
+
# FIXME: use _handle_attribute for this
|
691
|
+
case ast.Attribute(attr=fn_name):
|
692
|
+
module: ModuleType | None = None
|
693
|
+
modules = (
|
694
|
+
dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
|
695
|
+
| ctx.modules
|
439
696
|
)
|
440
697
|
|
441
|
-
|
442
|
-
|
698
|
+
levels = _find_root(node.func, [])
|
699
|
+
module_name = ".".join(levels[:-1])
|
700
|
+
|
701
|
+
_LOGGER.debug("Searching for module %s", module_name)
|
702
|
+
for level in levels[:-1]:
|
703
|
+
modules.update(
|
704
|
+
dict(inspect.getmembers(modules[level], predicate=inspect.ismodule))
|
705
|
+
)
|
706
|
+
module = modules.get(levels[-2])
|
707
|
+
|
708
|
+
# Fall-back to absolute import
|
709
|
+
if module is None:
|
710
|
+
module = importlib.import_module(module_name)
|
711
|
+
|
712
|
+
fns = dict(inspect.getmembers(module, predicate=callable))
|
713
|
+
py_fn = fns.get(fn_name)
|
714
|
+
case _:
|
715
|
+
raise NotImplementedError
|
716
|
+
|
717
|
+
if py_fn is None:
|
718
|
+
return None
|
719
|
+
|
720
|
+
if (fn := KNOWN_FNS.get(py_fn)) is not None:
|
721
|
+
return sympy.Float(fn(*model_args)) # type: ignore
|
722
|
+
|
723
|
+
return fn_to_sympy(
|
724
|
+
py_fn,
|
725
|
+
origin=ctx.origin,
|
726
|
+
model_args=model_args,
|
727
|
+
)
|