mxlpy 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,28 +3,166 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import ast
6
+ import importlib
6
7
  import inspect
8
+ import logging
9
+ import math
7
10
  import textwrap
8
11
  from dataclasses import dataclass
12
+ from types import ModuleType
9
13
  from typing import TYPE_CHECKING, Any, cast
10
14
 
11
15
  import dill
16
+ import numpy as np
12
17
  import sympy
13
- from sympy.printing.pycode import pycode
14
18
 
15
19
  if TYPE_CHECKING:
16
20
  from collections.abc import Callable
17
- from types import ModuleType
18
21
 
19
22
  __all__ = [
20
23
  "Context",
24
+ "KNOWN_CONSTANTS",
25
+ "KNOWN_FNS",
26
+ "PARSE_ERROR",
21
27
  "fn_to_sympy",
22
28
  "get_fn_ast",
23
29
  "get_fn_source",
24
- "sympy_to_fn",
25
- "sympy_to_inline",
26
30
  ]
27
31
 
32
+ _LOGGER = logging.getLogger(__name__)
33
+ PARSE_ERROR = sympy.Symbol("ERROR")
34
+
35
+ KNOWN_CONSTANTS: dict[float, sympy.Float] = {
36
+ math.e: sympy.E,
37
+ math.pi: sympy.pi,
38
+ math.nan: sympy.nan,
39
+ math.tau: sympy.pi * 2,
40
+ math.inf: sympy.oo,
41
+ # numpy
42
+ np.e: sympy.E,
43
+ np.pi: sympy.pi,
44
+ np.nan: sympy.nan,
45
+ np.inf: sympy.oo,
46
+ }
47
+
48
+ KNOWN_FNS: dict[Callable, sympy.Expr] = {
49
+ # built-ins
50
+ abs: sympy.Abs, # type: ignore
51
+ min: sympy.Min,
52
+ max: sympy.Max,
53
+ pow: sympy.Pow,
54
+ # round: sympy
55
+ # divmod
56
+ # math module
57
+ math.acos: sympy.acos,
58
+ math.acosh: sympy.acosh,
59
+ math.asin: sympy.asin,
60
+ math.asinh: sympy.asinh,
61
+ math.atan: sympy.atan,
62
+ math.atan2: sympy.atan2,
63
+ math.atanh: sympy.atanh,
64
+ math.cbrt: sympy.cbrt,
65
+ math.ceil: sympy.ceiling,
66
+ # math.comb: sympy.comb,
67
+ # math.copysign: sympy.copysign,
68
+ math.cos: sympy.cos,
69
+ math.cosh: sympy.cosh,
70
+ # math.degrees: sympy.degrees,
71
+ # math.dist: sympy.dist,
72
+ math.erf: sympy.erf,
73
+ math.erfc: sympy.erfc,
74
+ math.exp: sympy.exp,
75
+ # math.exp2: sympy.exp2,
76
+ # math.expm1: sympy.expm1,
77
+ # math.fabs: sympy.fabs,
78
+ math.factorial: sympy.factorial,
79
+ math.floor: sympy.floor,
80
+ # math.fmod: sympy.fmod,
81
+ # math.frexp: sympy.frexp,
82
+ # math.fsum: sympy.fsum,
83
+ math.gamma: sympy.gamma,
84
+ math.gcd: sympy.gcd,
85
+ # math.hypot: sympy.hypot,
86
+ # math.isclose: sympy.isclose,
87
+ # math.isfinite: sympy.isfinite,
88
+ # math.isinf: sympy.isinf,
89
+ # math.isnan: sympy.isnan,
90
+ # math.isqrt: sympy.isqrt,
91
+ math.lcm: sympy.lcm,
92
+ # math.ldexp: sympy.ldexp,
93
+ # math.lgamma: sympy.lgamma,
94
+ math.log: sympy.log,
95
+ # math.log10: sympy.log10,
96
+ # math.log1p: sympy.log1p,
97
+ # math.log2: sympy.log2,
98
+ # math.modf: sympy.modf,
99
+ # math.nextafter: sympy.nextafter,
100
+ # math.perm: sympy.perm,
101
+ math.pow: sympy.Pow,
102
+ math.prod: sympy.prod,
103
+ math.radians: sympy.rad,
104
+ math.remainder: sympy.rem,
105
+ math.sin: sympy.sin,
106
+ math.sinh: sympy.sinh,
107
+ math.sqrt: sympy.sqrt,
108
+ # math.sumprod: sympy.sumprod,
109
+ math.tan: sympy.tan,
110
+ math.tanh: sympy.tanh,
111
+ math.trunc: sympy.trunc,
112
+ # math.ulp: sympy.ulp,
113
+ # numpy
114
+ np.exp: sympy.exp,
115
+ np.abs: sympy.Abs,
116
+ np.acos: sympy.acos,
117
+ np.acosh: sympy.acosh,
118
+ np.asin: sympy.asin,
119
+ np.asinh: sympy.asinh,
120
+ np.atan: sympy.atan,
121
+ np.atanh: sympy.atanh,
122
+ np.atan2: sympy.atan2,
123
+ np.pow: sympy.Pow,
124
+ np.absolute: sympy.Abs,
125
+ np.add: sympy.Add,
126
+ np.arccos: sympy.acos,
127
+ np.arccosh: sympy.acosh,
128
+ np.arcsin: sympy.asin,
129
+ np.arcsinh: sympy.asinh,
130
+ np.arctan2: sympy.atan2,
131
+ np.arctan: sympy.atan,
132
+ np.arctanh: sympy.atanh,
133
+ np.cbrt: sympy.cbrt,
134
+ np.ceil: sympy.ceiling,
135
+ np.conjugate: sympy.conjugate,
136
+ np.cos: sympy.cos,
137
+ np.cosh: sympy.cosh,
138
+ np.exp: sympy.exp,
139
+ np.floor: sympy.floor,
140
+ np.gcd: sympy.gcd,
141
+ np.greater: sympy.GreaterThan,
142
+ np.greater_equal: sympy.Ge,
143
+ np.invert: sympy.invert,
144
+ np.lcm: sympy.lcm,
145
+ np.less: sympy.LessThan,
146
+ np.less_equal: sympy.Le,
147
+ np.log: sympy.log,
148
+ np.maximum: sympy.maximum,
149
+ np.minimum: sympy.minimum,
150
+ np.mod: sympy.Mod,
151
+ np.positive: sympy.Abs,
152
+ np.power: sympy.Pow,
153
+ np.sign: sympy.sign,
154
+ np.sin: sympy.sin,
155
+ np.sinh: sympy.sinh,
156
+ np.sqrt: sympy.sqrt,
157
+ # np.square: sympy.square,
158
+ # np.subtract: sympy., # Add(x, -1 * y)
159
+ np.tan: sympy.tan,
160
+ np.tanh: sympy.tanh,
161
+ # np.true_divide: sympy.true_divide,
162
+ np.trunc: sympy.trunc,
163
+ # np.vecdot: sympy.vecdot,
164
+ }
165
+
28
166
 
29
167
  @dataclass
30
168
  class Context:
@@ -33,6 +171,9 @@ class Context:
33
171
  symbols: dict[str, sympy.Symbol | sympy.Expr]
34
172
  caller: Callable
35
173
  parent_module: ModuleType | None
174
+ origin: str
175
+ modules: dict[str, ModuleType]
176
+ fns: dict[str, Callable]
36
177
 
37
178
  def updated(
38
179
  self,
@@ -47,9 +188,23 @@ class Context:
47
188
  parent_module=self.parent_module
48
189
  if parent_module is None
49
190
  else parent_module,
191
+ origin=self.origin,
192
+ modules=self.modules,
193
+ fns=self.fns,
50
194
  )
51
195
 
52
196
 
197
+ def _find_root(value: ast.Attribute | ast.Name, levels: list) -> list[str]:
198
+ if isinstance(value, ast.Attribute):
199
+ return _find_root(
200
+ cast(ast.Attribute, value.value),
201
+ [value.attr, *levels],
202
+ )
203
+
204
+ root = str(value.id)
205
+ return [root, *levels]
206
+
207
+
53
208
  def get_fn_source(fn: Callable) -> str:
54
209
  """Get the string representation of a function.
55
210
 
@@ -110,121 +265,80 @@ def get_fn_ast(fn: Callable) -> ast.FunctionDef:
110
265
  return fn_def
111
266
 
112
267
 
113
- def sympy_to_inline(expr: sympy.Expr) -> str:
114
- """Convert a sympy expression to inline Python code.
115
-
116
- Parameters
117
- ----------
118
- expr
119
- The sympy expression to convert
120
-
121
- Returns
122
- -------
123
- str
124
- Python code string for the expression
125
-
126
- Examples
127
- --------
128
- >>> import sympy
129
- >>> x = sympy.Symbol('x')
130
- >>> expr = x**2 + 2*x + 1
131
- >>> sympy_to_inline(expr)
132
- 'x**2 + 2*x + 1'
133
-
134
- """
135
- return cast(str, pycode(expr, fully_qualified_modules=True))
136
-
137
-
138
- def sympy_to_fn(
139
- *,
140
- fn_name: str,
141
- args: list[str],
142
- expr: sympy.Expr,
143
- ) -> str:
144
- """Convert a sympy expression to a python function.
145
-
146
- Parameters
147
- ----------
148
- fn_name
149
- Name of the function to generate
150
- args
151
- List of argument names for the function
152
- expr
153
- Sympy expression to convert to a function body
154
-
155
- Returns
156
- -------
157
- str
158
- String representation of the generated function
159
-
160
- Examples
161
- --------
162
- >>> import sympy
163
- >>> x, y = sympy.symbols('x y')
164
- >>> expr = x**2 + y
165
- >>> print(sympy_to_fn(fn_name="square_plus_y", args=["x", "y"], expr=expr))
166
- def square_plus_y(x: float, y: float) -> float:
167
- return x**2 + y
168
-
169
- """
170
- fn_args = ", ".join(f"{i}: float" for i in args)
171
-
172
- return f"""def {fn_name}({fn_args}) -> float:
173
- return {pycode(expr)}
174
- """
175
-
176
-
177
268
  def fn_to_sympy(
178
269
  fn: Callable,
270
+ origin: str,
179
271
  model_args: list[sympy.Symbol | sympy.Expr] | None = None,
180
- ) -> sympy.Expr:
272
+ ) -> sympy.Expr | None:
181
273
  """Convert a python function to a sympy expression.
182
274
 
183
- Parameters
184
- ----------
185
- fn
186
- The function to convert
187
- model_args
188
- Optional list of sympy symbols to substitute for function arguments
275
+ Args:
276
+ fn: The function to convert
277
+ origin: Name of the original caller. Used for error messages.
278
+ model_args: Optional list of sympy symbols to substitute for function arguments
189
279
 
190
- Returns
191
- -------
192
- sympy.Expr
280
+ Returns:
193
281
  Sympy expression equivalent to the function
194
282
 
195
- Examples
196
- --------
197
- >>> def square_fn(x):
198
- ... return x**2
199
- >>> import sympy
200
- >>> fn_to_sympy(square_fn)
201
- x**2
202
- >>> # With model_args
203
- >>> y = sympy.Symbol('y')
204
- >>> fn_to_sympy(square_fn, [y])
205
- y**2
283
+ Examples:
284
+ >>> def square_fn(x):
285
+ ... return x**2
286
+ >>> import sympy
287
+ >>> fn_to_sympy(square_fn)
288
+ x**2
289
+ >>> # With model_args
290
+ >>> y = sympy.Symbol('y')
291
+ >>> fn_to_sympy(square_fn, [y])
292
+ y**2
206
293
 
207
294
  """
208
- fn_def = get_fn_ast(fn)
209
- fn_args = [str(arg.arg) for arg in fn_def.args.args]
210
- sympy_expr = _handle_fn_body(
211
- fn_def.body,
212
- ctx=Context(
213
- symbols={name: sympy.Symbol(name) for name in fn_args},
214
- caller=fn,
215
- parent_module=inspect.getmodule(fn),
216
- ),
217
- )
218
- if model_args is not None:
219
- sympy_expr = sympy_expr.subs(dict(zip(fn_args, model_args, strict=True)))
220
- return cast(sympy.Expr, sympy_expr)
295
+ try:
296
+ fn_def = get_fn_ast(fn)
297
+ fn_args = [str(arg.arg) for arg in fn_def.args.args]
298
+
299
+ sympy_expr = _handle_fn_body(
300
+ fn_def.body,
301
+ ctx=Context(
302
+ symbols={name: sympy.Symbol(name) for name in fn_args},
303
+ caller=fn,
304
+ parent_module=inspect.getmodule(fn),
305
+ origin=origin,
306
+ modules={},
307
+ fns={},
308
+ ),
309
+ )
310
+ if sympy_expr is None:
311
+ return None
312
+ # FIXME: we shouldn't end up here, where does this come from?
313
+ if isinstance(sympy_expr, float):
314
+ return sympy.Float(sympy_expr)
315
+ if model_args is not None:
316
+ sympy_expr = sympy_expr.subs(dict(zip(fn_args, model_args, strict=True)))
317
+ return cast(sympy.Expr, sympy_expr)
318
+
319
+ except (TypeError, ValueError, NotImplementedError) as e:
320
+ msg = f"Failed parsing function of {origin}"
321
+ _LOGGER.warning(msg)
322
+ _LOGGER.debug("", exc_info=e)
323
+ return None
221
324
 
222
325
 
223
326
  def _handle_name(node: ast.Name, ctx: Context) -> sympy.Symbol | sympy.Expr:
224
- return ctx.symbols[node.id]
327
+ value = ctx.symbols.get(node.id)
328
+ if value is None:
329
+ global_variables = dict(
330
+ inspect.getmembers(
331
+ ctx.parent_module,
332
+ predicate=lambda x: isinstance(x, float),
333
+ )
334
+ )
335
+ value = sympy.Float(global_variables[node.id])
336
+ return value
225
337
 
226
338
 
227
- def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
339
+ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr | None:
340
+ if isinstance(node, float):
341
+ return sympy.Float(node)
228
342
  if isinstance(node, ast.UnaryOp):
229
343
  return _handle_unaryop(node, ctx)
230
344
  if isinstance(node, ast.BinOp):
@@ -233,6 +347,11 @@ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
233
347
  return _handle_name(node, ctx)
234
348
  if isinstance(node, ast.Constant):
235
349
  return node.value
350
+ if isinstance(node, ast.Call):
351
+ return _handle_call(node, ctx=ctx)
352
+ if isinstance(node, ast.Attribute):
353
+ return _handle_attribute(node, ctx=ctx)
354
+
236
355
  if isinstance(node, ast.Compare):
237
356
  # Handle chained comparisons like 1 < a < 2
238
357
  left = cast(Any, _handle_expr(node.left, ctx))
@@ -263,8 +382,6 @@ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
263
382
  for comp in comparisons[1:]:
264
383
  result = sympy.And(result, comp)
265
384
  return cast(sympy.Expr, result)
266
- if isinstance(node, ast.Call):
267
- return _handle_call(node, ctx)
268
385
 
269
386
  # Handle conditional expressions (ternary operators)
270
387
  if isinstance(node, ast.IfExp):
@@ -277,7 +394,7 @@ def _handle_expr(node: ast.expr, ctx: Context) -> sympy.Expr:
277
394
  raise NotImplementedError(msg)
278
395
 
279
396
 
280
- def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
397
+ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr | None:
281
398
  pieces = []
282
399
  remaining_body = list(body)
283
400
 
@@ -333,7 +450,10 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
333
450
  target_elements, value_elements, strict=True
334
451
  ):
335
452
  if isinstance(target, ast.Name):
336
- ctx.symbols[target.id] = _handle_expr(value_expr, ctx)
453
+ expr = _handle_expr(value_expr, ctx)
454
+ if expr is None:
455
+ return None
456
+ ctx.symbols[target.id] = expr
337
457
  else:
338
458
  # Handle potential iterable unpacking
339
459
  value = _handle_expr(node.value, ctx)
@@ -344,8 +464,33 @@ def _handle_fn_body(body: list[ast.stmt], ctx: Context) -> sympy.Expr:
344
464
  raise TypeError(msg)
345
465
  target_name = target.id
346
466
  value = _handle_expr(node.value, ctx)
467
+ if value is None:
468
+ return None
347
469
  ctx.symbols[target_name] = value
348
470
 
471
+ elif isinstance(node, ast.Import):
472
+ for alias in node.names:
473
+ name = alias.name
474
+ ctx.modules[name] = importlib.import_module(name)
475
+
476
+ elif isinstance(node, ast.ImportFrom):
477
+ package = cast(str, node.module)
478
+ module = importlib.import_module(package)
479
+ contents = dict(inspect.getmembers(module))
480
+ for alias in node.names:
481
+ name = alias.name
482
+ el = contents[name]
483
+ if isinstance(el, float):
484
+ ctx.symbols[name] = sympy.Float(el)
485
+ elif callable(el):
486
+ ctx.fns[name] = el
487
+ elif isinstance(el, ModuleType):
488
+ ctx.modules[name] = el
489
+ else:
490
+ _LOGGER.debug("Skipping import %s", node)
491
+ else:
492
+ _LOGGER.debug("Skipping node of type %s", type(node))
493
+
349
494
  # If we have pieces to combine into a Piecewise
350
495
  if pieces:
351
496
  return sympy.Piecewise(*pieces)
@@ -364,13 +509,14 @@ def _handle_unaryop(node: ast.UnaryOp, ctx: Context) -> sympy.Expr:
364
509
  left = _handle_expr(node.operand, ctx)
365
510
  left = cast(Any, left) # stupid sympy types don't allow ops on symbols
366
511
 
367
- if isinstance(node.op, ast.UAdd):
368
- return +left
369
- if isinstance(node.op, ast.USub):
370
- return -left
371
-
372
- msg = f"Operation {type(node.op).__name__} not implemented"
373
- raise NotImplementedError(msg)
512
+ match node.op:
513
+ case ast.UAdd():
514
+ return +left
515
+ case ast.USub():
516
+ return -left
517
+ case _:
518
+ msg = f"Operation {type(node.op).__name__} not implemented"
519
+ raise NotImplementedError(msg)
374
520
 
375
521
 
376
522
  def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
@@ -380,63 +526,158 @@ def _handle_binop(node: ast.BinOp, ctx: Context) -> sympy.Expr:
380
526
  right = _handle_expr(node.right, ctx)
381
527
  right = cast(Any, right) # stupid sympy types don't allow ops on symbols
382
528
 
383
- if isinstance(node.op, ast.Add):
384
- return left + right
385
- if isinstance(node.op, ast.Sub):
386
- return left - right
387
- if isinstance(node.op, ast.Mult):
388
- return left * right
389
- if isinstance(node.op, ast.Div):
390
- return left / right
391
- if isinstance(node.op, ast.Pow):
392
- return left**right
393
- if isinstance(node.op, ast.Mod):
394
- return left % right
395
- if isinstance(node.op, ast.FloorDiv):
396
- return left // right
397
-
398
- msg = f"Operation {type(node.op).__name__} not implemented"
399
- raise NotImplementedError(msg)
529
+ match node.op:
530
+ case ast.Add():
531
+ return left + right
532
+ case ast.Sub():
533
+ return left - right
534
+ case ast.Mult():
535
+ return left * right
536
+ case ast.Div():
537
+ return left / right
538
+ case ast.Pow():
539
+ return left**right
540
+ case ast.Mod():
541
+ return left % right
542
+ case ast.FloorDiv():
543
+ return left // right
544
+ case _:
545
+ msg = f"Operation {type(node.op).__name__} not implemented"
546
+ raise NotImplementedError(msg)
547
+
548
+
549
+ # FIXME: check if target isn't an object or class
550
+ def _handle_attribute(node: ast.Attribute, ctx: Context) -> sympy.Expr | None:
551
+ """Handle an attribute.
552
+
553
+ Structures to expect:
554
+ Attribute(Name(id), attr) | direct
555
+ Attribute(Attribute(Name(id)), attr) | single layer of nesting
556
+ Attribute(Attribute(...), attr) | arbitrary nesting
557
+
558
+ Targets to expect:
559
+ - modules (both absolute and relative import)
560
+ - import a; a.attr
561
+ - import a; a.b.attr
562
+ - from a import b; b.attr
563
+ - objects, e.g. Parameters().a
564
+ - classes, e.g. Parameters.a
565
+
566
+ Watch out for relative imports and the different ways they can be called
567
+ import a
568
+ from a import b
569
+ from a.b import c
570
+
571
+ a.attr
572
+ b.attr
573
+ c.attr
574
+ a.b.attr
575
+ b.c.attr
576
+ a.b.c.attr
577
+ """
578
+ name = str(node.attr)
579
+ module: ModuleType | None = None
580
+ modules = (
581
+ dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
582
+ | ctx.modules
583
+ )
584
+ match node.value:
585
+ case ast.Name(l1):
586
+ module_name = l1
587
+ module = modules.get(module_name)
588
+ case ast.Attribute():
589
+ levels = _find_root(node.value, [])
590
+ module_name = ".".join(levels)
591
+ for level in levels[:-1]:
592
+ modules.update(
593
+ dict(inspect.getmembers(modules[level], predicate=inspect.ismodule))
594
+ )
595
+ module = modules.get(levels[-1])
596
+ case _:
597
+ raise NotImplementedError
598
+
599
+ # Fall-back to absolute import
600
+ if module is None:
601
+ module = importlib.import_module(module_name)
602
+
603
+ element = dict(
604
+ inspect.getmembers(
605
+ module,
606
+ predicate=lambda x: isinstance(x, float),
607
+ )
608
+ ).get(name)
400
609
 
610
+ if element is None:
611
+ return None
401
612
 
402
- def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr:
403
- # direct call, e.g. mass_action(x, k1)
404
- if isinstance(callee := node.func, ast.Name):
405
- fn_name = str(callee.id)
406
- fns = dict(inspect.getmembers(ctx.parent_module, predicate=callable))
613
+ if (value := KNOWN_CONSTANTS.get(element)) is not None:
614
+ return value
615
+ return sympy.Float(element)
407
616
 
408
- return fn_to_sympy(
409
- fns[fn_name],
410
- model_args=[_handle_expr(i, ctx) for i in node.args],
411
- )
412
617
 
413
- # search for fn in other namespace
414
- if isinstance(attr := node.func, ast.Attribute):
415
- imports = dict(inspect.getmembers(ctx.parent_module, inspect.ismodule))
618
+ # FIXME: check if target isn't an object or class
619
+ def _handle_call(node: ast.Call, ctx: Context) -> sympy.Expr | None:
620
+ """Handle call expression.
416
621
 
417
- # Single level, e.g. fns.mass_action(x, k1)
418
- if isinstance(module_name := attr.value, ast.Name):
419
- return _handle_call(
420
- ast.Call(func=ast.Name(attr.attr), args=node.args, keywords=[]),
421
- ctx=ctx.updated(parent_module=imports[module_name.id]),
422
- )
622
+ Variants
623
+ - mass_action(x, k1)
624
+ - fns.mass_action(x, k1)
625
+ - mxlpy.fns.mass_action(x, k1)
423
626
 
424
- # Multiple levels, e.g. mxlpy.fns.mass_action(x, k1)
425
- if isinstance(inner_attr := attr.value, ast.Attribute):
426
- if not isinstance(module_name := inner_attr.value, ast.Name):
427
- msg = f"Unknown target kind {module_name}"
428
- raise NotImplementedError(msg)
429
- return _handle_call(
430
- ast.Call(
431
- func=ast.Attribute(
432
- value=ast.Name(inner_attr.attr),
433
- attr=attr.attr,
434
- ),
435
- args=node.args,
436
- keywords=[],
437
- ),
438
- ctx=ctx.updated(parent_module=imports[module_name.id]),
627
+ In future think about?
628
+ - object.call
629
+ - Class.call
630
+ """
631
+ model_args: list[sympy.Expr] = []
632
+ for i in node.args:
633
+ if (expr := _handle_expr(i, ctx)) is None:
634
+ return None
635
+ model_args.append(expr)
636
+
637
+ match node.func:
638
+ case ast.Name(id):
639
+ fn_name = str(id)
640
+ fns = (
641
+ dict(inspect.getmembers(ctx.parent_module, predicate=callable))
642
+ | ctx.fns
643
+ )
644
+ py_fn = fns.get(fn_name)
645
+
646
+ # FIXME: use _handle_attribute for this
647
+ case ast.Attribute(attr=fn_name):
648
+ module: ModuleType | None = None
649
+ modules = (
650
+ dict(inspect.getmembers(ctx.parent_module, predicate=inspect.ismodule))
651
+ | ctx.modules
439
652
  )
440
653
 
441
- msg = f"Unsupported function type {node.func}"
442
- raise NotImplementedError(msg)
654
+ levels = _find_root(node.func, [])
655
+ module_name = ".".join(levels[:-1])
656
+
657
+ _LOGGER.debug("Searching for module %s", module_name)
658
+ for level in levels[:-1]:
659
+ modules.update(
660
+ dict(inspect.getmembers(modules[level], predicate=inspect.ismodule))
661
+ )
662
+ module = modules.get(levels[-2])
663
+
664
+ # Fall-back to absolute import
665
+ if module is None:
666
+ module = importlib.import_module(module_name)
667
+
668
+ fns = dict(inspect.getmembers(module, predicate=callable))
669
+ py_fn = fns.get(fn_name)
670
+ case _:
671
+ raise NotImplementedError
672
+
673
+ if py_fn is None:
674
+ return None
675
+
676
+ if (fn := KNOWN_FNS.get(py_fn)) is not None:
677
+ return fn
678
+
679
+ return fn_to_sympy(
680
+ py_fn,
681
+ origin=ctx.origin,
682
+ model_args=model_args,
683
+ )