mathjs-to-func 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,90 @@
1
+ """Public API for mathjs-to-func."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Callable, Iterable, Mapping
10
+
11
+ from .compiler import CompilationResult, compile_to_callable
12
+ from .errors import (
13
+ CircularDependencyError,
14
+ ExpressionError,
15
+ InputValidationError,
16
+ InvalidNodeError,
17
+ MissingTargetError,
18
+ UnknownIdentifierError,
19
+ )
20
+
21
# Public API of the package: the exception types imported from ``.errors``
# plus the ``build_evaluator`` entry point.
__all__ = [
    "CircularDependencyError",
    "ExpressionError",
    "InputValidationError",
    "InvalidNodeError",
    "MissingTargetError",
    "UnknownIdentifierError",
    "build_evaluator",
]
30
+
31
+
32
+ def _extract_payload(
33
+ expressions: Mapping[str, Any] | None,
34
+ inputs: Iterable[str] | None,
35
+ target: str | None,
36
+ payload: Mapping[str, Any] | None,
37
+ ) -> tuple[Mapping[str, Any], Iterable[str], str]:
38
+ if payload is not None:
39
+ if expressions is not None or inputs is not None or target is not None:
40
+ raise ExpressionError(
41
+ "payload cannot be combined with direct arguments",
42
+ expression=None,
43
+ )
44
+ try:
45
+ expressions = payload["expressions"]
46
+ inputs = payload["inputs"]
47
+ target = payload["target"]
48
+ except KeyError as exc:
49
+ missing = exc.args[0]
50
+ raise ExpressionError(
51
+ f"Payload missing required key: {missing}",
52
+ expression=None,
53
+ ) from exc
54
+ if expressions is None or inputs is None or target is None:
55
+ raise ExpressionError("Expressions, inputs, and target are required")
56
+ return expressions, inputs, target
57
+
58
+
59
def build_evaluator(
    expressions: Mapping[str, Any] | None = None,
    inputs: Iterable[str] | None = None,
    target: str | None = None,
    *,
    payload: Mapping[str, Any] | None = None,
    include_source: bool = False,
) -> Callable[[Mapping[str, Any]], Any]:
    """Compile math.js expressions into a reusable callable.

    ``expressions``, ``inputs`` and ``target`` are supplied either as direct
    arguments or bundled in ``payload`` under keys of the same names; mixing
    the two styles is rejected. The returned function expects a single mapping
    argument containing the input values and returns the computed target value.

    The returned callable is tagged with metadata from the compilation result:

    * ``__mathjs_required_inputs__`` -- the result's ``required_inputs``.
    * ``__mathjs_evaluation_order__`` -- the result's ``evaluation_order``.
    * ``__mathjs_source__`` -- set only when ``include_source`` is true: the
      generated module rendered back to Python source with ``ast.unparse``.
    """
    resolved = _extract_payload(expressions, inputs, target, payload)
    resolved_expressions, resolved_inputs, resolved_target = resolved

    compiled: CompilationResult = compile_to_callable(
        expressions=resolved_expressions,
        inputs=resolved_inputs,
        target=resolved_target,
    )

    evaluator = compiled.function
    evaluator.__mathjs_required_inputs__ = compiled.required_inputs
    evaluator.__mathjs_evaluation_order__ = compiled.evaluation_order
    if include_source:
        evaluator.__mathjs_source__ = ast.unparse(compiled.module_ast)
    return evaluator
@@ -0,0 +1,418 @@
1
+ """Convert serialized math.js AST nodes into Python AST nodes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import ast
6
+ import math
7
+ import re
8
+ from collections.abc import Iterable as AbcIterable
9
+ from collections.abc import Mapping
10
+ from collections.abc import Mapping as AbcMapping
11
+ from typing import Any
12
+
13
+ from .errors import InvalidNodeError
14
+
15
# ASCII identifier: a letter or underscore followed by letters, digits or
# underscores. Anchored with \A and \Z rather than ^ and $, because ``$`` also
# matches just before a trailing newline, which would let ``"name\n"`` pass
# ``IDENTIFIER_PATTERN.match``.
IDENTIFIER_PATTERN = re.compile(r"\A[A-Za-z_][A-Za-z0-9_]*\Z")
16
+
17
+
18
def ensure_identifier(name: str, *, expression: str | None) -> str:
    """Validate and return a safe identifier for generated expressions.

    Args:
        name: Candidate identifier; must be a string that matches
            ``IDENTIFIER_PATTERN`` in full.
        expression: Expression context attached to the raised error.

    Returns:
        ``name`` unchanged.

    Raises:
        InvalidNodeError: if ``name`` is not a string or is not an acceptable
            identifier.
    """
    # ``fullmatch`` requires the entire string to match regardless of how the
    # pattern is anchored (a bare ``$`` anchor would tolerate a trailing
    # newline). The isinstance guard reports non-string names with the
    # library's own error instead of a TypeError from ``re``.
    if not isinstance(name, str) or IDENTIFIER_PATTERN.fullmatch(name) is None:
        raise InvalidNodeError(
            f"Unsupported identifier name: {name!r}",
            expression=expression,
            node=None,
        )
    return name
27
+
28
+
29
+ def _extract_type(node: Mapping[str, Any]) -> str:
30
+ node_type = node.get("type") or node.get("mathjs")
31
+ if not isinstance(node_type, str):
32
+ raise InvalidNodeError(
33
+ "Node is missing 'type' field",
34
+ expression=None,
35
+ node=node,
36
+ )
37
+ return node_type
38
+
39
+
40
+ class MathJsAstVisitor[T]:
41
+ """Generic visitor for math.js AST nodes."""
42
+
43
+ def __init__(self, *, expression_name: str) -> None:
44
+ self.expression_name = expression_name
45
+
46
+ def visit(self, node: Mapping[str, Any]) -> T:
47
+ node_type = _extract_type(node)
48
+ method = getattr(self, f"visit_{node_type}", None)
49
+ if method is None:
50
+ raise InvalidNodeError(
51
+ f"Unsupported node type {node_type!r}",
52
+ expression=self.expression_name,
53
+ node=node,
54
+ )
55
+ return method(node)
56
+
57
+ def _ensure_mapping(
58
+ self,
59
+ value: Any,
60
+ *,
61
+ node: Mapping[str, Any],
62
+ message: str,
63
+ ) -> Mapping[str, Any]:
64
+ if not isinstance(value, AbcMapping):
65
+ raise InvalidNodeError(
66
+ message,
67
+ expression=self.expression_name,
68
+ node=node,
69
+ )
70
+ return value
71
+
72
+ def _ensure_iterable(
73
+ self,
74
+ value: Any,
75
+ *,
76
+ node: Mapping[str, Any],
77
+ message: str,
78
+ ) -> list[Any]:
79
+ if not isinstance(value, AbcIterable):
80
+ raise InvalidNodeError(
81
+ message,
82
+ expression=self.expression_name,
83
+ node=node,
84
+ )
85
+ return list(value)
86
+
87
+
88
+ def _to_number(value: Any, *, expression: str | None) -> float | int:
89
+ if isinstance(value, (int, float)):
90
+ if isinstance(value, float) and not math.isfinite(value):
91
+ raise InvalidNodeError(
92
+ "Non-finite literal encountered",
93
+ expression=expression,
94
+ node=None,
95
+ )
96
+ return value
97
+ if isinstance(value, str):
98
+ lowered = value.lower()
99
+ if "." in value or "e" in lowered:
100
+ try:
101
+ parsed = float(value)
102
+ except ValueError as exc:
103
+ raise InvalidNodeError(
104
+ f"Invalid numeric literal: {value!r}",
105
+ expression=expression,
106
+ node=None,
107
+ ) from exc
108
+ if not math.isfinite(parsed):
109
+ raise InvalidNodeError(
110
+ "Non-finite literal encountered",
111
+ expression=expression,
112
+ node=None,
113
+ )
114
+ return parsed
115
+ try:
116
+ return int(value, 10)
117
+ except ValueError as exc:
118
+ raise InvalidNodeError(
119
+ f"Invalid numeric literal: {value!r}",
120
+ expression=expression,
121
+ node=None,
122
+ ) from exc
123
+ raise InvalidNodeError(
124
+ f"Unsupported literal type: {type(value).__name__}",
125
+ expression=expression,
126
+ node=None,
127
+ )
128
+
129
+
130
class MathJsAstBuilder(MathJsAstVisitor[ast.expr]):
    """Translate serialized math.js AST nodes into Python ``ast.expr`` nodes.

    Handles ConstantNode, SymbolNode, ParenthesisNode, OperatorNode (the
    operators listed in ``_UNARY_OPERATORS`` / ``_BINARY_OPERATORS``),
    FunctionNode (only function names present in ``helper_names``) and
    ArrayNode. Any other node type is rejected by ``visit`` with
    ``InvalidNodeError``.
    """

    # math.js OperatorNode ``fn`` names -> Python AST operator classes.
    _UNARY_OPERATORS: dict[str, type[ast.unaryop]] = {
        "unaryMinus": ast.USub,
        "unaryPlus": ast.UAdd,
    }
    _BINARY_OPERATORS: dict[str, type[ast.operator]] = {
        "add": ast.Add,
        "subtract": ast.Sub,
        "multiply": ast.Mult,
        "divide": ast.Div,
        "pow": ast.Pow,
        "mod": ast.Mod,
    }

    def __init__(
        self,
        *,
        expression_name: str,
        helper_names: Mapping[str, str],
    ) -> None:
        super().__init__(expression_name=expression_name)
        # math.js function name -> name of the Python callable emitted for it.
        self.helper_names = helper_names

    def build(self, node: Mapping[str, Any]) -> ast.expr:
        """Translate the math.js tree rooted at ``node``."""
        return self.visit(node)

    def visit_ConstantNode(self, node: Mapping[str, Any]) -> ast.expr:
        """Translate a literal.

        A missing or ``"number"`` ``valueType`` is converted with
        ``_to_number``. ``"boolean"`` yields a bool (strings compare
        case-insensitively to ``"true"``), and ``"null"`` yields ``None``.
        Other value types are rejected.
        """
        value_type = node.get("valueType")
        value = node.get("value")
        if value_type in {None, "number"}:
            number = _to_number(value, expression=self.expression_name)
            return ast.Constant(value=number)
        if value_type == "boolean":
            parsed = value.lower() == "true" if isinstance(value, str) else bool(value)
            return ast.Constant(value=parsed)
        if value_type == "null":
            return ast.Constant(value=None)
        raise InvalidNodeError(
            f"Unsupported constant value type: {value_type!r}",
            expression=self.expression_name,
            node=node,
        )

    def visit_SymbolNode(self, node: Mapping[str, Any]) -> ast.expr:
        """Translate a symbol into a ``Name`` load after validating its identifier."""
        name = node.get("name")
        if not isinstance(name, str):
            raise InvalidNodeError(
                "SymbolNode missing name",
                expression=self.expression_name,
                node=node,
            )
        safe_name = ensure_identifier(name, expression=self.expression_name)
        return ast.Name(id=safe_name, ctx=ast.Load())

    def visit_ParenthesisNode(self, node: Mapping[str, Any]) -> ast.expr:
        """Translate the wrapped child (``content``, falling back to ``expr``)."""
        content = node.get("content") or node.get("expr")
        child = self._ensure_mapping(
            content,
            node=node,
            message="ParenthesisNode missing child content",
        )
        return self.visit(child)

    def visit_OperatorNode(self, node: Mapping[str, Any]) -> ast.expr:
        """Translate a one- or two-argument operator application."""
        args_list = self._ensure_iterable(
            node.get("args"),
            node=node,
            message="OperatorNode missing args",
        )
        fn = node.get("fn")
        if len(args_list) == 1:
            child = self._ensure_mapping(
                args_list[0],
                node=node,
                message="OperatorNode child must be object",
            )
            return self._visit_unary_operator(fn, child, node=node)
        if len(args_list) == 2:
            left_node = self._ensure_mapping(
                args_list[0],
                node=node,
                message="OperatorNode children must be objects",
            )
            right_node = self._ensure_mapping(
                args_list[1],
                node=node,
                message="OperatorNode children must be objects",
            )
            return self._visit_binary_operator(fn, left_node, right_node, node=node)
        raise InvalidNodeError(
            "OperatorNode args must be unary or binary",
            expression=self.expression_name,
            node=node,
        )

    def _visit_unary_operator(
        self,
        fn: Any,
        child: Mapping[str, Any],
        *,
        node: Mapping[str, Any] | None = None,
    ) -> ast.expr:
        """Build a ``UnaryOp`` for ``fn`` applied to ``child``.

        ``node`` is the enclosing OperatorNode, attached to the error for
        context when ``fn`` is unsupported.
        """
        # The isinstance guard keeps an unhashable ``fn`` (e.g. a list from
        # malformed JSON) from raising TypeError in the dict lookup.
        op_cls = self._UNARY_OPERATORS.get(fn) if isinstance(fn, str) else None
        if op_cls is None:
            raise InvalidNodeError(
                f"Unsupported unary operator: {fn!r}",
                expression=self.expression_name,
                node=node,
            )
        return ast.UnaryOp(op=op_cls(), operand=self.visit(child))

    def _visit_binary_operator(
        self,
        fn: Any,
        left_node: Mapping[str, Any],
        right_node: Mapping[str, Any],
        *,
        node: Mapping[str, Any] | None = None,
    ) -> ast.expr:
        """Build a ``BinOp`` for ``fn`` applied to the two children.

        The operator is validated before either subtree is translated, so an
        unsupported operator is reported without first compiling its operands.
        ``node`` is the enclosing OperatorNode, attached to the error for context.
        """
        op_cls = self._BINARY_OPERATORS.get(fn) if isinstance(fn, str) else None
        if op_cls is None:
            raise InvalidNodeError(
                f"Unsupported binary operator: {fn!r}",
                expression=self.expression_name,
                node=node,
            )
        left = self.visit(left_node)
        right = self.visit(right_node)
        return ast.BinOp(left=left, op=op_cls(), right=right)

    def visit_FunctionNode(self, node: Mapping[str, Any]) -> ast.expr:
        """Translate a call to a supported helper function.

        ``fn`` may be a plain string or a mapping with a ``name`` key;
        surrounding whitespace in the name is ignored. The call targets the
        Python name mapped in ``helper_names``. ``ifnull`` requires exactly two
        arguments, and ``min``/``max``/``sum`` require at least one.
        """
        raw_fn = node.get("fn")
        fn_name = raw_fn.get("name") if isinstance(raw_fn, AbcMapping) else raw_fn
        if not isinstance(fn_name, str):
            raise InvalidNodeError(
                "FunctionNode missing function name",
                expression=self.expression_name,
                node=node,
            )
        normalized = fn_name.strip()
        helper_name = self.helper_names.get(normalized)
        if helper_name is None:
            raise InvalidNodeError(
                f"Unsupported function {normalized!r}",
                expression=self.expression_name,
                node=node,
            )
        args = node.get("args") or []
        args_list = self._ensure_iterable(
            args,
            node=node,
            message="FunctionNode args must be iterable",
        )
        call_args = []
        for arg in args_list:
            child = self._ensure_mapping(
                arg,
                node=node,
                message="FunctionNode argument must be object",
            )
            call_args.append(self.visit(child))

        if normalized == "ifnull" and len(call_args) != 2:
            raise InvalidNodeError(
                "ifnull expects exactly two arguments",
                expression=self.expression_name,
                node=node,
            )

        if normalized in {"min", "max", "sum"} and not call_args:
            raise InvalidNodeError(
                f"{normalized} requires at least one argument",
                expression=self.expression_name,
                node=node,
            )
        return ast.Call(
            func=ast.Name(id=helper_name, ctx=ast.Load()),
            args=call_args,
            keywords=[],
        )

    def visit_ArrayNode(self, node: Mapping[str, Any]) -> ast.expr:
        """Translate an array literal into a Python ``List`` expression."""
        items = node.get("items")
        items_list = self._ensure_iterable(
            items,
            node=node,
            message="ArrayNode items must be iterable",
        )
        elts: list[ast.expr] = []
        for item in items_list:
            element = self._ensure_mapping(
                item,
                node=node,
                message="ArrayNode element must be object",
            )
            elts.append(self.visit(element))
        return ast.List(elts=elts, ctx=ast.Load())
325
+
326
+
327
class SymbolDependencyCollector(MathJsAstVisitor[set[str]]):
    """Collect the symbol names a math.js expression tree refers to.

    Every reachable SymbolNode contributes its name, after it passes
    ``ensure_identifier``. For a FunctionNode, the ``fn`` sub-node is also
    visited when it is a mapping carrying a type name, so a function named via
    a SymbolNode contributes that name too. NOTE(review): presumably the caller
    filters out known function/helper names; verify.
    """

    def collect(self, node: Mapping[str, Any]) -> set[str]:
        """Return every symbol name reachable from ``node``."""
        return self.visit(node)

    def _union_children(
        self,
        children: list[Any],
        *,
        node: Mapping[str, Any],
        message: str,
    ) -> set[str]:
        """Validate and visit each child in order, merging their symbol sets."""
        names: set[str] = set()
        for raw_child in children:
            checked = self._ensure_mapping(raw_child, node=node, message=message)
            names.update(self.visit(checked))
        return names

    def visit_ConstantNode(self, node: Mapping[str, Any]) -> set[str]:
        """Literals reference no symbols."""
        return set()

    def visit_SymbolNode(self, node: Mapping[str, Any]) -> set[str]:
        """Return the node's name as a one-element set, after validating it."""
        name = node.get("name")
        if isinstance(name, str):
            ensure_identifier(name, expression=self.expression_name)
            return {name}
        raise InvalidNodeError(
            "SymbolNode missing name",
            expression=self.expression_name,
            node=node,
        )

    def visit_ParenthesisNode(self, node: Mapping[str, Any]) -> set[str]:
        """Collect from the wrapped child (``content``, falling back to ``expr``)."""
        inner = self._ensure_mapping(
            node.get("content") or node.get("expr"),
            node=node,
            message="ParenthesisNode missing child content",
        )
        return self.visit(inner)

    def visit_OperatorNode(self, node: Mapping[str, Any]) -> set[str]:
        """Collect from every operand, whatever the operator's arity."""
        operands = self._ensure_iterable(
            node.get("args"),
            node=node,
            message="OperatorNode missing args",
        )
        return self._union_children(
            operands,
            node=node,
            message="OperatorNode argument must be object",
        )

    def visit_FunctionNode(self, node: Mapping[str, Any]) -> set[str]:
        """Collect from the typed ``fn`` sub-node (if any), then from the arguments."""
        names: set[str] = set()
        callee = node.get("fn")
        if isinstance(callee, AbcMapping) and isinstance(
            callee.get("type") or callee.get("mathjs"), str
        ):
            names.update(self.visit(callee))
        arguments = self._ensure_iterable(
            node.get("args") or [],
            node=node,
            message="FunctionNode args must be iterable",
        )
        names.update(
            self._union_children(
                arguments,
                node=node,
                message="FunctionNode argument must be object",
            )
        )
        return names

    def visit_ArrayNode(self, node: Mapping[str, Any]) -> set[str]:
        """Collect from every array element."""
        elements = self._ensure_iterable(
            node.get("items"),
            node=node,
            message="ArrayNode items must be iterable",
        )
        return self._union_children(
            elements,
            node=node,
            message="ArrayNode element must be object",
        )
410
+
411
+
412
# Public names of this module; the underscore-prefixed helpers
# (``_extract_type``, ``_to_number``) are intentionally not exported.
__all__ = [
    "IDENTIFIER_PATTERN",
    "MathJsAstBuilder",
    "MathJsAstVisitor",
    "SymbolDependencyCollector",
    "ensure_identifier",
]