mxlpy 0.21.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,28 +3,165 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import ast
6
+ import importlib
6
7
  import inspect
8
+ import logging
9
+ import math
7
10
  import textwrap
8
11
  from dataclasses import dataclass
12
+ from types import ModuleType
9
13
  from typing import TYPE_CHECKING, Any, cast
10
14
 
11
15
  import dill
16
+ import numpy as np
12
17
  import sympy
13
- from sympy.printing.pycode import pycode
14
18
 
15
19
  if TYPE_CHECKING:
16
20
  from collections.abc import Callable
17
- from types import ModuleType
18
21
 
19
22
__all__ = [
    "Context",
    "KNOWN_CONSTANTS",
    "KNOWN_FNS",
    "PARSE_ERROR",
    "fn_to_sympy",
    "get_fn_ast",
    "get_fn_source",
]

_LOGGER = logging.getLogger(__name__)

# Public placeholder symbol named "ERROR". It is not referenced in the visible
# part of this module; presumably callers use it to mark unparsable
# expressions (TODO confirm).
PARSE_ERROR = sympy.Symbol("ERROR")

# Float constants from `math` / `numpy` that have an exact symbolic counterpart.
# Lookup is by float equality. NaN never compares equal, so the NaN entries only
# match the very `math.nan` / `np.nan` objects (dict lookup checks identity
# before equality).
KNOWN_CONSTANTS: dict[float, sympy.Expr] = {
    math.e: sympy.E,
    math.pi: sympy.pi,
    math.nan: sympy.nan,
    math.tau: sympy.pi * 2,
    math.inf: sympy.oo,
    # numpy
    np.e: sympy.E,
    np.pi: sympy.pi,
    np.nan: sympy.nan,
    np.inf: sympy.oo,
}


def _sympy_trunc(x: sympy.Expr) -> sympy.Expr:
    """Round towards zero, like `math.trunc` / `np.trunc`.

    `sympy.trunc` is polynomial truncation modulo a constant, not this.
    """
    return sympy.sign(x) * sympy.floor(sympy.Abs(x))


def _sympy_positive(x: sympy.Expr) -> sympy.Expr:
    """Unary plus, like `np.positive` (identity for numbers and symbols)."""
    return +x  # type: ignore[operator]


# Python / numpy callables mapped to a sympy callable that is applied to the
# already-translated arguments. Commented-out entries have no direct
# sympy equivalent yet.
KNOWN_FNS: dict[Callable, Callable[..., sympy.Expr]] = {
    # built-ins
    abs: sympy.Abs,  # type: ignore
    min: sympy.Min,
    max: sympy.Max,
    pow: sympy.Pow,
    # round: sympy
    # divmod
    # math module
    math.acos: sympy.acos,
    math.acosh: sympy.acosh,
    math.asin: sympy.asin,
    math.asinh: sympy.asinh,
    math.atan: sympy.atan,
    math.atan2: sympy.atan2,
    math.atanh: sympy.atanh,
    math.cbrt: sympy.cbrt,
    math.ceil: sympy.ceiling,
    # math.comb: sympy.comb,
    # math.copysign: sympy.copysign,
    math.cos: sympy.cos,
    math.cosh: sympy.cosh,
    # math.degrees: sympy.degrees,
    # math.dist: sympy.dist,
    math.erf: sympy.erf,
    math.erfc: sympy.erfc,
    math.exp: sympy.exp,
    # math.exp2: sympy.exp2,
    # math.expm1: sympy.expm1,
    # math.fabs: sympy.fabs,
    math.factorial: sympy.factorial,
    math.floor: sympy.floor,
    # math.fmod: sympy.fmod,
    # math.frexp: sympy.frexp,
    # math.fsum: sympy.fsum,
    math.gamma: sympy.gamma,
    math.gcd: sympy.gcd,
    # math.hypot: sympy.hypot,
    # math.isclose: sympy.isclose,
    # math.isfinite: sympy.isfinite,
    # math.isinf: sympy.isinf,
    # math.isnan: sympy.isnan,
    # math.isqrt: sympy.isqrt,
    math.lcm: sympy.lcm,
    # math.ldexp: sympy.ldexp,
    # math.lgamma: sympy.lgamma,
    math.log: sympy.log,
    # math.log10: sympy.log10,
    # math.log1p: sympy.log1p,
    # math.log2: sympy.log2,
    # math.modf: sympy.modf,
    # math.nextafter: sympy.nextafter,
    # math.perm: sympy.perm,
    math.pow: sympy.Pow,
    math.prod: sympy.prod,
    math.radians: sympy.rad,
    # math.remainder: IEEE remainder; sympy.rem is polynomial remainder
    math.sin: sympy.sin,
    math.sinh: sympy.sinh,
    math.sqrt: sympy.sqrt,
    # math.sumprod: sympy.sumprod,
    math.tan: sympy.tan,
    math.tanh: sympy.tanh,
    math.trunc: _sympy_trunc,
    # math.ulp: sympy.ulp,
    # numpy
    np.abs: sympy.Abs,
    np.acos: sympy.acos,
    np.acosh: sympy.acosh,
    np.asin: sympy.asin,
    np.asinh: sympy.asinh,
    np.atan: sympy.atan,
    np.atanh: sympy.atanh,
    np.atan2: sympy.atan2,
    np.pow: sympy.Pow,
    np.absolute: sympy.Abs,
    np.add: sympy.Add,
    np.arccos: sympy.acos,
    np.arccosh: sympy.acosh,
    np.arcsin: sympy.asin,
    np.arcsinh: sympy.asinh,
    np.arctan2: sympy.atan2,
    np.arctan: sympy.atan,
    np.arctanh: sympy.atanh,
    np.cbrt: sympy.cbrt,
    np.ceil: sympy.ceiling,
    np.conjugate: sympy.conjugate,
    np.cos: sympy.cos,
    np.cosh: sympy.cosh,
    np.exp: sympy.exp,
    np.floor: sympy.floor,
    np.gcd: sympy.gcd,
    np.greater: sympy.Gt,  # strict, like `>`
    np.greater_equal: sympy.Ge,
    # np.invert: bitwise NOT; sympy.invert is the modular inverse
    np.lcm: sympy.lcm,
    np.less: sympy.Lt,  # strict, like `<`
    np.less_equal: sympy.Le,
    np.log: sympy.log,
    # element-wise max/min of the arguments (sympy.maximum/minimum compute the
    # extremum of a function over a domain instead)
    np.maximum: sympy.Max,
    np.minimum: sympy.Min,
    np.mod: sympy.Mod,
    np.positive: _sympy_positive,
    np.power: sympy.Pow,
    np.sign: sympy.sign,
    np.sin: sympy.sin,
    np.sinh: sympy.sinh,
    np.sqrt: sympy.sqrt,
    # np.square: sympy.square,
    # np.subtract: sympy., # Add(x, -1 * y)
    np.tan: sympy.tan,
    np.tanh: sympy.tanh,
    # np.true_divide: sympy.true_divide,
    np.trunc: _sympy_trunc,
    # np.vecdot: sympy.vecdot,
}
164
+
28
165
 
29
166
  @dataclass
30
167
  class Context:
@@ -33,6 +170,9 @@ class Context:
33
170
  symbols: dict[str, sympy.Symbol | sympy.Expr]
34
171
  caller: Callable
35
172
  parent_module: ModuleType | None
173
+ origin: str
174
+ modules: dict[str, ModuleType]
175
+ fns: dict[str, Callable]
36
176
 
37
177
  def updated(
38
178
  self,
@@ -47,8 +187,22 @@ class Context:
47
187
  parent_module=self.parent_module
48
188
  if parent_module is None
49
189
  else parent_module,
190
+ origin=self.origin,
191
+ modules=self.modules,
192
+ fns=self.fns,
193
+ )
194
+
195
+
196
def _find_root(value: ast.expr, levels: list) -> list[str]:
    """Flatten a dotted attribute chain into its name components.

    ``a.b.c`` (``Attribute(Attribute(Name("a"), "b"), "c")``) becomes
    ``["a", "b", "c"]``.

    Args:
        value: The ``ast.Attribute`` or ``ast.Name`` node to flatten.
        levels: Components already collected to the right of ``value``.

    Returns:
        All components, root name first.

    Raises:
        NotImplementedError: if the chain is rooted in something other than a
            plain name, e.g. ``f().attr`` or ``(a + b).attr``.
    """
    if isinstance(value, ast.Attribute):
        return _find_root(value.value, [value.attr, *levels])
    if isinstance(value, ast.Name):
        return [str(value.id), *levels]
    # Previously this crashed with an AttributeError on `.id`, which the
    # caller's error handling does not catch.
    msg = f"Attribute chain rooted in {type(value).__name__} not supported"
    raise NotImplementedError(msg)
205
+
52
206
 
53
207
  def get_fn_source(fn: Callable) -> str:
54
208
  """Get the string representation of a function.
@@ -110,174 +264,65 @@ def get_fn_ast(fn: Callable) -> ast.FunctionDef:
110
264
  return fn_def
111
265
 
112
266
 
113
def fn_to_sympy(
    fn: Callable,
    origin: str,
    model_args: list[sympy.Symbol | sympy.Expr] | None = None,
) -> sympy.Expr | None:
    """Convert a python function to a sympy expression.

    Parsing is best-effort: if the function uses a construct that cannot be
    translated, a warning is logged and ``None`` is returned instead of raising.

    Args:
        fn: The function to convert
        origin: Name of the original caller. Used for error messages.
        model_args: Optional list of sympy symbols/expressions substituted
            positionally for the function arguments

    Returns:
        Sympy expression equivalent to the function, or ``None`` if the
        function could not be parsed.

    Examples:
        >>> def square_fn(x):
        ...     return x**2
        >>> import sympy
        >>> fn_to_sympy(square_fn, origin="square_fn")
        x**2.0
        >>> # With model_args
        >>> y = sympy.Symbol('y')
        >>> fn_to_sympy(square_fn, origin="square_fn", model_args=[y])
        y**2.0

    """
    try:
        fn_def = get_fn_ast(fn)
        fn_args = [str(arg.arg) for arg in fn_def.args.args]

        sympy_expr = _handle_fn_body(
            fn_def.body,
            ctx=Context(
                symbols={name: sympy.Symbol(name) for name in fn_args},
                caller=fn,
                parent_module=inspect.getmodule(fn),
                origin=origin,
                modules={},
                fns={},
            ),
        )
        if sympy_expr is None:
            return None
        # Evaluated fns and floats from attributes
        if isinstance(sympy_expr, float):
            return sympy.Float(sympy_expr)
        if model_args:
            # Substitute simultaneously: sequential substitution corrupts
            # swapped arguments, e.g. calling `def f(a, b): return a - b`
            # as f(b, a) would first turn a - b into b - b == 0.
            sympy_expr = sympy_expr.subs(
                dict(zip(fn_args, model_args, strict=True)),
                simultaneous=True,
            )
        return cast(sympy.Expr, sympy_expr)

    # The helpers also signal "cannot translate" through lookup/import/source
    # errors; catch those too so the function keeps its soft-failure contract.
    except (
        TypeError,
        ValueError,
        NotImplementedError,
        KeyError,
        AttributeError,
        ImportError,
        OSError,
    ) as e:
        msg = f"Failed parsing function of {origin}"
        _LOGGER.warning(msg)
        _LOGGER.debug("Parsing failure details for %s", origin, exc_info=e)
        return None
323
+
324
+
325
+ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
281
326
  pieces = []
282
327
  remaining_body = list(body)
283
328
 
@@ -333,7 +378,10 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
333
378
  target_elements, value_elements, strict=True
334
379
  ):
335
380
  if isinstance(target, ast.Name):
336
- ctx.symbols[target.id] = _handle_expr(value_expr, ctx)
381
+ expr = _handle_expr(value_expr, ctx)
382
+ if expr is None:
383
+ return None
384
+ ctx.symbols[target.id] = expr
337
385
  else:
338
386
  # Handle potential iterable unpacking
339
387
  value = _handle_expr(node.value, ctx)
@@ -344,8 +392,33 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
344
392
  raise TypeError(msg)
345
393
  target_name = target.id
346
394
  value = _handle_expr(node.value, ctx)
395
+ if value is None:
396
+ return None
347
397
  ctx.symbols[target_name] = value
348
398
 
399
+ elif isinstance(node, ast.Import):
400
+ for alias in node.names:
401
+ name = alias.name
402
+ ctx.modules[name] = importlib.import_module(name)
403
+
404
+ elif isinstance(node, ast.ImportFrom):
405
+ package = cast(str, node.module)
406
+ module = importlib.import_module(package)
407
+ contents = dict(inspect.getmembers(module))
408
+ for alias in node.names:
409
+ name = alias.name
410
+ el = contents[name]
411
+ if isinstance(el, float):
412
+ ctx.symbols[name] = sympy.Float(el)
413
+ elif callable(el):
414
+ ctx.fns[name] = el
415
+ elif isinstance(el, ModuleType):
416
+ ctx.modules[name] = el
417
+ else:
418
+ _LOGGER.debug("Skipping import %s", node)
419
+ else:
420
+ _LOGGER.debug("Skipping node of type %s", type(node))
421
+
349
422
  # If we have pieces to combine into a Piecewise
350
423
  if pieces:
351
424
  return sympy.Piecewise(*pieces)
@@ -360,17 +433,93 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
360
433
  raise ValueError(msg)
361
434
 
362
435
 
436
def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
    """Translate a single expression node into sympy (key dispatch function).

    Returns:
        The sympy expression, or ``None`` when a sub-expression could not be
        resolved (e.g. an unknown function), so callers can abort.

    Raises:
        NotImplementedError: for unsupported node types, operators or constants.
    """
    if isinstance(node, float):
        # Defensive: tolerate an already-evaluated float in place of a node
        return sympy.Float(node)
    if isinstance(node, ast.UnaryOp):
        return _handle_unaryop(node, ctx)
    if isinstance(node, ast.BinOp):
        return _handle_binop(node, ctx)
    if isinstance(node, ast.Name):
        return _handle_name(node, ctx)
    if isinstance(node, ast.Constant):
        # ints (and bools, an int subclass) are deliberately turned into Floats
        if isinstance(val := node.value, (float, int)):
            return sympy.Float(val)
        msg = f"Can only use int or float constants, got {type(val).__name__}"
        raise NotImplementedError(msg)
    if isinstance(node, ast.Call):
        return _handle_call(node, ctx=ctx)
    if isinstance(node, ast.Attribute):
        return _handle_attribute(node, ctx=ctx)
    if isinstance(node, ast.Compare):
        return _handle_compare(node, ctx)

    # Handle conditional expressions (ternary operators)
    if isinstance(node, ast.IfExp):
        condition = _handle_expr(node.test, ctx)
        if_true = _handle_expr(node.body, ctx)
        if_false = _handle_expr(node.orelse, ctx)
        if condition is None or if_true is None or if_false is None:
            return None
        return sympy.Piecewise((if_true, condition), (if_false, True))

    msg = f"Expression type {type(node).__name__} not implemented"
    raise NotImplementedError(msg)


def _handle_compare(node: ast.Compare, ctx: Context) -> sympy.Expr | None:
    """Translate a (possibly chained) comparison such as ``1 < a < 2``.

    The chain is split into pairwise comparisons joined with ``sympy.And``.
    """
    prev_value = cast(Any, _handle_expr(node.left, ctx))
    if prev_value is None:
        return None

    comparisons = []
    for op, comparator in zip(node.ops, node.comparators, strict=True):
        right = cast(Any, _handle_expr(comparator, ctx))
        if right is None:
            return None

        if isinstance(op, ast.Gt):
            comparisons.append(prev_value > right)
        elif isinstance(op, ast.GtE):
            comparisons.append(prev_value >= right)
        elif isinstance(op, ast.Lt):
            comparisons.append(prev_value < right)
        elif isinstance(op, ast.LtE):
            comparisons.append(prev_value <= right)
        elif isinstance(op, ast.Eq):
            # `==` on sympy objects is structural equality returning a Python
            # bool; Eq builds the symbolic relation instead.
            comparisons.append(sympy.Eq(prev_value, right))
        elif isinstance(op, ast.NotEq):
            comparisons.append(sympy.Ne(prev_value, right))
        else:
            # e.g. `is`, `in`: silently skipping them produced wrong results
            msg = f"Comparison {type(op).__name__} not implemented"
            raise NotImplementedError(msg)

        prev_value = right

    # And() of a single argument returns that argument unchanged
    return cast(sympy.Expr, sympy.And(*comparisons))
496
+
497
+
498
def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
    """Resolve a bare name to a sympy value.

    Local symbols take precedence. These are the function arguments, assigned
    variables and floats imported inside the function. Otherwise the name is
    looked up among the numeric globals of the function's defining module.
    Known constants such as ``math.pi`` map to their exact sympy symbols.

    Raises:
        ValueError: if the name is neither a local symbol nor a numeric global.
    """
    if (value := ctx.symbols.get(node.id)) is not None:
        return value

    # Direct namespace lookup instead of rebuilding inspect.getmembers each time
    namespace = vars(ctx.parent_module) if ctx.parent_module is not None else {}
    global_value = namespace.get(node.id)
    # bool is an int subclass but not a meaningful numeric value here
    if isinstance(global_value, (int, float)) and not isinstance(global_value, bool):
        if (known := KNOWN_CONSTANTS.get(global_value)) is not None:
            return known
        return sympy.Float(global_value)

    # Previously a bare KeyError escaped fn_to_sympy's error handling
    msg = f"Could not resolve name '{node.id}'"
    raise ValueError(msg)
509
+
510
+
363
511
def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
    """Translate a unary ``+x`` or ``-x`` expression.

    Raises:
        NotImplementedError: for any other unary operator (``not``, ``~``).
    """
    # sympy's type stubs do not allow arithmetic on generic expressions
    operand = cast(Any, _handle_expr(node.operand, ctx))

    op = node.op
    if isinstance(op, ast.USub):
        return -operand
    if isinstance(op, ast.UAdd):
        return +operand

    msg = f"Operation {type(op).__name__} not implemented"
    raise NotImplementedError(msg)
374
523
 
375
524
 
376
525
  def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
@@ -380,63 +529,199 @@ def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
380
529
  right = _handle_expr(node.right, ctx)
381
530
  right = cast(Any, right) # stupid sympy types don't allow ops on symbols
382
531
 
383
- if isinstance(node.op, ast.Add):
384
- return left + right
385
- if isinstance(node.op, ast.Sub):
386
- return left - right
387
- if isinstance(node.op, ast.Mult):
388
- return left * right
389
- if isinstance(node.op, ast.Div):
390
- return left / right
391
- if isinstance(node.op, ast.Pow):
392
- return left**right
393
- if isinstance(node.op, ast.Mod):
394
- return left % right
395
- if isinstance(node.op, ast.FloorDiv):
396
- return left // right
397
-
398
- msg = f"Operation {type(node.op).__name__} not implemented"
399
- raise NotImplementedError(msg)
532
+ match node.op:
533
+ case ast.Add():
534
+ return left + right
535
+ case ast.Sub():
536
+ return left - right
537
+ case ast.Mult():
538
+ return left * right
539
+ case ast.Div():
540
+ return left / right
541
+ case ast.Pow():
542
+ return left**right
543
+ case ast.Mod():
544
+ return left % right
545
+ case ast.FloorDiv():
546
+ return left // right
547
+ case _:
548
+ msg = f"Operation {type(node.op).__name__} not implemented"
549
+ raise NotImplementedError(msg)
550
+
551
+
552
def _get_inner_object(obj: object, levels: list[str]) -> sympy.Expr | None:
    """Follow an attribute chain on an object and return its numeric leaf.

    Used for accesses like ``Parameters().a`` or ``Parameters.a``. A class is
    instantiated without arguments first, presumably so that attributes set
    on instances are reachable.

    Args:
        obj: Object or class the chain starts at.
        levels: Attribute names to follow, outermost first.

    Returns:
        The leaf as a sympy number (known constants such as ``math.pi`` map to
        their exact symbols), or ``None`` if the chain breaks or the leaf is
        not an int/float.
    """
    # Check if object is instantiated, otherwise instantiate first
    if isinstance(obj, type):
        obj = obj()

    for level in levels:
        _LOGGER.debug("obj %s, level %s", obj, level)
        obj = getattr(obj, level, None)
        if obj is None:
            # Chain broken; do not keep looking attributes up on None
            return None

    if obj is None:
        return None

    # ints are accepted like int literals are; bools are excluded
    if isinstance(obj, (int, float)) and not isinstance(obj, bool):
        if (value := KNOWN_CONSTANTS.get(obj)) is not None:
            return value
        return sympy.Float(obj)

    _LOGGER.debug("Inner object not float: %s", obj)
    return None
571
+
572
+
573
+ # FIXME: check if target isn't an object or class
574
+ def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
575
+ """Handle an attribute.
576
+
577
+ Structures to expect:
578
+ Attribute(Name(id), attr) | direct
579
+ Attribute(Attribute(Name(id)), attr) | single layer of nesting
580
+ Attribute(Attribute(...), attr) | arbitrary nesting
581
+
582
+ Targets to expect:
583
+ - modules (both absolute and relative import)
584
+ - import a; a.attr
585
+ - import a; a.b.attr
586
+ - from a import b; b.attr
587
+ - objects, e.g. Parameters().a
588
+ - classes, e.g. Parameters.a
589
+
590
+ Watch out for relative imports and the different ways they can be called
591
+ import a
592
+ from a import b
593
+ from a.b import c
594
+
595
+ a.attr
596
+ b.attr
597
+ c.attr
598
+ a.b.attr
599
+ b.c.attr
600
+ a.b.c.attr
601
+ """
602
+ name = str(node.attr)
603
+ module: ModuleType | None = None
604
+ modules = (
605
+ dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
606
+ | ctx.modules
607
+ )
608
+ variables = vars(ctx.parent_module)
609
+
610
+ match node.value:
611
+ case ast.Name(l1):
612
+ module_name = l1
613
+ module = modules.get(module_name)
614
+ if module is None and (var := variables.get(l1)) is not None:
615
+ return _get_inner_object(var, [node.attr])
616
+ case ast.Attribute():
617
+ levels = _find_root(node.value, levels=[])
618
+ _LOGGER.debug("Attribute levels %s", levels)
619
+ module_name = ".".join(levels)
620
+
621
+ for idx, level in enumerate(levels[:-1]):
622
+ if (module := modules.get(level)) is not None:
623
+ modules.update(
624
+ dict(
625
+ inspect.getmembers(
626
+ module,
627
+ predicate=inspect.ismodule,
628
+ )
629
+ )
630
+ )
631
+ elif (var := variables.get(level)) is not None:
632
+ _LOGGER.debug("var %s", var)
633
+ return _get_inner_object(var, levels[(idx + 1) :] + [node.attr])
400
634
 
635
+ else:
636
+ _LOGGER.debug("No target found")
637
+
638
+ module = modules.get(levels[-1])
639
+ case _:
640
+ raise NotImplementedError
401
641
 
402
- def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
403
- # direct call, e.g. mass_action(x, k1)
404
- if isinstance(callee := node.func, ast.Name):
405
- fn_name = str(callee.id)
406
- fns = dict(inspect.getmembers(ctx.parent_module, predicate=callable))
642
+ # Fall-back to absolute import
643
+ if module is None:
644
+ module = importlib.import_module(module_name)
407
645
 
408
- return fn_to_sympy(
409
- fns[fn_name],
410
- model_args=[_handle_expr(i, ctx) for i in node.args],
646
+ element = dict(
647
+ inspect.getmembers(
648
+ module,
649
+ predicate=lambda x: isinstance(x, float),
411
650
  )
651
+ ).get(name)
412
652
 
413
- # search for fn in other namespace
414
- if isinstance(attr := node.func, ast.Attribute):
415
- imports = dict(inspect.getmembers(ctx.parent_module, inspect.ismodule))
653
+ if element is None:
654
+ return None
416
655
 
417
- # Single level, e.g. fns.mass_action(x, k1)
418
- if isinstance(module_name := attr.value, ast.Name):
419
- return _handle_call(
420
- ast.Call(func=ast.Name(attr.attr), args=node.args, keywords=[]),
421
- ctx=ctx.updated(parent_module=imports[module_name.id]),
422
- )
656
+ if (value := KNOWN_CONSTANTS.get(element)) is not None:
657
+ return value
658
+ return sympy.Float(element)
659
+
660
+
661
+ # FIXME: check if target isn't an object or class
662
+ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
663
+ """Handle call expression.
423
664
 
424
- # Multiple levels, e.g. mxlpy.fns.mass_action(x, k1)
425
- if isinstance(inner_attr := attr.value, ast.Attribute):
426
- if not isinstance(module_name := inner_attr.value, ast.Name):
427
- msg = f"Unknown target kind {module_name}"
428
- raise NotImplementedError(msg)
429
- return _handle_call(
430
- ast.Call(
431
- func=ast.Attribute(
432
- value=ast.Name(inner_attr.attr),
433
- attr=attr.attr,
434
- ),
435
- args=node.args,
436
- keywords=[],
437
- ),
438
- ctx=ctx.updated(parent_module=imports[module_name.id]),
665
+ Variants
666
+ - mass_action(x, k1)
667
+ - fns.mass_action(x, k1)
668
+ - mxlpy.fns.mass_action(x, k1)
669
+
670
+ In future think about?
671
+ - object.call
672
+ - Class.call
673
+ """
674
+ model_args: list[sympy.Expr] = []
675
+ for i in node.args:
676
+ if (expr := _handle_expr(i, ctx)) is None:
677
+ return None
678
+ model_args.append(expr)
679
+ _LOGGER.debug("Fn args: %s", model_args)
680
+
681
+ match node.func:
682
+ case ast.Name(id):
683
+ fn_name = str(id)
684
+ fns = (
685
+ dict(inspect.getmembers(ctx.parent_module, predicate=callable))
686
+ | ctx.fns
687
+ )
688
+ py_fn = fns.get(fn_name)
689
+
690
+ # FIXME: use _handle_attribute for this
691
+ case ast.Attribute(attr=fn_name):
692
+ module: ModuleType | None = None
693
+ modules = (
694
+ dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
695
+ | ctx.modules
439
696
  )
440
697
 
441
- msg = f"Unsupported function type {node.func}"
442
- raise NotImplementedError(msg)
698
+ levels = _find_root(node.func, [])
699
+ module_name = ".".join(levels[:-1])
700
+
701
+ _LOGGER.debug("Searching for module %s", module_name)
702
+ for level in levels[:-1]:
703
+ modules.update(
704
+ dict(inspect.getmembers(modules[level], predicate=inspect.ismodule))
705
+ )
706
+ module = modules.get(levels[-2])
707
+
708
+ # Fall-back to absolute import
709
+ if module is None:
710
+ module = importlib.import_module(module_name)
711
+
712
+ fns = dict(inspect.getmembers(module, predicate=callable))
713
+ py_fn = fns.get(fn_name)
714
+ case _:
715
+ raise NotImplementedError
716
+
717
+ if py_fn is None:
718
+ return None
719
+
720
+ if (fn := KNOWN_FNS.get(py_fn)) is not None:
721
+ return sympy.Float(fn(*model_args)) # type: ignore
722
+
723
+ return fn_to_sympy(
724
+ py_fn,
725
+ origin=ctx.origin,
726
+ model_args=model_args,
727
+ )