pyfcstm 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyfcstm/__init__.py +0 -0
- pyfcstm/__main__.py +4 -0
- pyfcstm/config/__init__.py +0 -0
- pyfcstm/config/meta.py +20 -0
- pyfcstm/dsl/__init__.py +6 -0
- pyfcstm/dsl/error.py +226 -0
- pyfcstm/dsl/grammar/Grammar.g4 +190 -0
- pyfcstm/dsl/grammar/Grammar.interp +168 -0
- pyfcstm/dsl/grammar/Grammar.tokens +118 -0
- pyfcstm/dsl/grammar/GrammarLexer.interp +214 -0
- pyfcstm/dsl/grammar/GrammarLexer.py +523 -0
- pyfcstm/dsl/grammar/GrammarLexer.tokens +118 -0
- pyfcstm/dsl/grammar/GrammarListener.py +521 -0
- pyfcstm/dsl/grammar/GrammarParser.py +4373 -0
- pyfcstm/dsl/grammar/__init__.py +3 -0
- pyfcstm/dsl/listener.py +440 -0
- pyfcstm/dsl/node.py +1581 -0
- pyfcstm/dsl/parse.py +155 -0
- pyfcstm/entry/__init__.py +1 -0
- pyfcstm/entry/base.py +126 -0
- pyfcstm/entry/cli.py +12 -0
- pyfcstm/entry/dispatch.py +46 -0
- pyfcstm/entry/generate.py +83 -0
- pyfcstm/entry/plantuml.py +67 -0
- pyfcstm/model/__init__.py +3 -0
- pyfcstm/model/base.py +51 -0
- pyfcstm/model/expr.py +764 -0
- pyfcstm/model/model.py +1392 -0
- pyfcstm/render/__init__.py +3 -0
- pyfcstm/render/env.py +36 -0
- pyfcstm/render/expr.py +180 -0
- pyfcstm/render/func.py +77 -0
- pyfcstm/render/render.py +279 -0
- pyfcstm/utils/__init__.py +6 -0
- pyfcstm/utils/binary.py +38 -0
- pyfcstm/utils/doc.py +64 -0
- pyfcstm/utils/jinja2.py +121 -0
- pyfcstm/utils/json.py +125 -0
- pyfcstm/utils/text.py +91 -0
- pyfcstm/utils/validate.py +102 -0
- pyfcstm-0.0.1.dist-info/LICENSE +165 -0
- pyfcstm-0.0.1.dist-info/METADATA +205 -0
- pyfcstm-0.0.1.dist-info/RECORD +46 -0
- pyfcstm-0.0.1.dist-info/WHEEL +5 -0
- pyfcstm-0.0.1.dist-info/entry_points.txt +2 -0
- pyfcstm-0.0.1.dist-info/top_level.txt +1 -0
pyfcstm/model/expr.py
ADDED
@@ -0,0 +1,764 @@
|
|
1
|
+
"""
|
2
|
+
Expression handling module for mathematical expressions and operations.
|
3
|
+
|
4
|
+
This module provides a comprehensive set of classes for representing and evaluating
|
5
|
+
mathematical expressions. It includes support for basic data types (integers, floats, booleans),
|
6
|
+
various operators (unary, binary, conditional), mathematical functions, and variables.
|
7
|
+
|
8
|
+
The expression system allows for:
|
9
|
+
|
10
|
+
- Building complex mathematical expressions
|
11
|
+
- Evaluating expressions with variable substitution
|
12
|
+
- Converting expressions to AST nodes
|
13
|
+
- Analyzing expressions to extract variables
|
14
|
+
- Handling operator precedence correctly
|
15
|
+
|
16
|
+
All expression classes inherit from the base `Expr` class and implement the AstExportable interface.
|
17
|
+
"""
|
18
|
+
|
19
|
+
import math
|
20
|
+
import operator
|
21
|
+
from dataclasses import dataclass
|
22
|
+
from typing import Iterator
|
23
|
+
|
24
|
+
from .base import AstExportable
|
25
|
+
from ..dsl import node as dsl_nodes
|
26
|
+
|
27
|
+
__all__ = [
    # base class of the expression model
    'Expr',
    # literal leaf nodes
    'Integer', 'Float', 'Boolean',
    # operator nodes
    'Op', 'UnaryOp', 'BinaryOp', 'ConditionalOp',
    # function calls and variable references
    'UFunc', 'Variable',
    # DSL AST -> expression model conversion entry point
    'parse_expr_node_to_expr',
]
|
40
|
+
|
41
|
+
|
42
|
+
@dataclass
class Expr(AstExportable):
    """
    Common base for every node of the expression model.

    Subclasses represent literals, operators, function calls and variable
    references. This base supplies tree traversal, evaluation through
    ``__call__`` (delegating to ``_call``) and string rendering through
    :meth:`to_ast_node`.
    """

    def _iter_subs(self) -> Iterator['Expr']:
        """
        Return an iterator over the immediate children of this node.

        Leaf expressions have no children, so the default iterator is empty;
        composite expressions override this.

        :return: Iterator over the direct sub-expressions
        :rtype: Iterator[Expr]
        """
        return iter(())

    def _iter_all_subs(self):
        """
        Walk this node and all of its descendants in pre-order.

        The node itself comes first, followed by each child's subtree from
        left to right.

        :return: Iterator over this expression and every nested sub-expression
        :rtype: Iterator[Expr]
        """
        # Explicit stack instead of recursion. Children are pushed in reverse so
        # the left-most child is popped (visited) first, preserving pre-order.
        pending = [self]
        while pending:
            current = pending.pop()
            yield current
            pending.extend(reversed(list(current._iter_subs())))

    def list_variables(self):
        """
        Collect the variables referenced anywhere in this expression.

        Variables are de-duplicated by name: the first occurrence in pre-order
        is kept, and the result follows discovery order.

        :return: Unique Variable objects in order of first appearance
        :rtype: list[Variable]
        """
        first_by_name = {}
        for node in self._iter_all_subs():
            if isinstance(node, Variable):
                first_by_name.setdefault(node.name, node)
        return list(first_by_name.values())

    def _call(self, **kwargs):
        """
        Evaluation hook that concrete expression types must implement.

        :param kwargs: Mapping of variable name to value
        :return: The evaluated value
        :raises NotImplementedError: Always, on the base class
        """
        raise NotImplementedError  # pragma: no cover

    def __call__(self, **kwargs):
        """
        Evaluate the expression, binding variables from keyword arguments.

        :param kwargs: Mapping of variable name to value
        :return: The evaluated value
        """
        return self._call(**kwargs)

    def __str__(self):
        """
        Render the expression as text via its AST node.

        :return: Text form of :meth:`to_ast_node`'s result
        :rtype: str
        """
        node = self.to_ast_node()
        return str(node)

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Conversion hook to a DSL AST node; concrete types must implement it.

        :return: AST node representing this expression
        :rtype: dsl_nodes.Expr
        :raises NotImplementedError: Always, on the base class
        """
        raise NotImplementedError  # pragma: no cover
|
125
|
+
|
126
|
+
|
127
|
+
@dataclass
class Integer(Expr):
    """
    Literal integer constant.

    :param value: The integer held by this node
    :type value: int
    """
    value: int

    def _call(self, **kwargs):
        """
        Evaluate the literal; variable bindings play no part.

        :param kwargs: Unused variable bindings
        :return: The stored integer
        :rtype: int
        """
        return self.value

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Build the DSL ``Integer`` node for this literal.

        The value is passed through ``int()`` first, so the raw text is the
        decimal form of the integer (with a leading ``-`` when negative).

        :return: DSL integer literal node
        :rtype: dsl_nodes.Integer
        """
        literal_text = str(int(self.value))
        return dsl_nodes.Integer(raw=literal_text)
|
155
|
+
|
156
|
+
|
157
|
+
@dataclass
class Float(Expr):
    """
    Literal floating-point constant.

    :param value: The float held by this node
    :type value: float
    """
    value: float

    def _call(self, **kwargs):
        """
        Evaluate the literal; variable bindings play no part.

        :param kwargs: Unused variable bindings
        :return: The stored float
        :rtype: float
        """
        return self.value

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Build a DSL node for this literal.

        Values within an absolute tolerance of 1e-10 of pi, e or tau (checked
        in that order) become a ``Constant`` named ``pi``, ``E`` or ``tau``.
        Any other value becomes a ``Float`` node holding ``str(float(value))``.

        :return: Constant or Float AST node
        :rtype: Union[dsl_nodes.Float, dsl_nodes.Constant]
        """
        known_constants = (
            (math.pi, 'pi'),
            (math.e, 'E'),
            (math.tau, 'tau'),
        )
        for reference, name in known_constants:
            if abs(self.value - reference) < 1e-10:
                return dsl_nodes.Constant(raw=name)
        return dsl_nodes.Float(raw=str(float(self.value)))
|
198
|
+
|
199
|
+
|
200
|
+
@dataclass
class Boolean(Expr):
    """
    Literal boolean constant.

    :param value: The truth value held by this node (coerced with ``bool()``)
    :type value: bool
    """
    value: bool

    def __post_init__(self):
        """
        Coerce whatever was passed as ``value`` to a real ``bool``.
        """
        self.value = bool(self.value)

    def _call(self, **kwargs):
        """
        Evaluate the literal; variable bindings play no part.

        :param kwargs: Unused variable bindings
        :return: The stored truth value
        :rtype: bool
        """
        return self.value

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Build the DSL ``Boolean`` node, spelled in lowercase (``true``/``false``).

        :return: DSL boolean literal node
        :rtype: dsl_nodes.Boolean
        """
        spelled = str(self.value).lower()
        return dsl_nodes.Boolean(raw=spelled)
|
234
|
+
|
235
|
+
|
236
|
+
# Binding strength of each operator mark (higher binds tighter). The
# to_ast_node() methods of the operator classes compare these values to decide
# where parentheses are required. Tiers are listed strongest first, and the
# resulting dict keeps exactly this key order.
_OP_PRECEDENCE = {
    mark: level
    for level, marks in (
        (100, ("()",)),                                # parentheses
        (90, ("function_call",)),                      # function calls
        (80, ("unary+", "unary-", "!", "not")),        # unary operators
        (70, ("**",)),                                 # exponentiation (right associative)
        (60, ("*", "/", "%")),                         # multiplicative
        (50, ("+", "-")),                              # additive
        (40, ("<<", ">>")),                            # bitwise shifts
        (35, ("&",)),                                  # bitwise AND
        (30, ("^",)),                                  # bitwise XOR
        (25, ("|",)),                                  # bitwise OR
        (20, ("<", ">", "<=", ">=", "==", "!=")),      # comparisons
        (15, ("&&", "and")),                           # logical AND
        (10, ("||", "or")),                            # logical OR
        (5, ("?:",)),                                  # C-style conditional
    )
    for mark in marks
}
|
291
|
+
|
292
|
+
# Python callables implementing each operator mark. Unary '+'/'-' are keyed as
# "unary+"/"unary-" (matching UnaryOp.op_mark); logical operators always return
# a bool; '?:' receives the already-evaluated condition and both branch values.
_OP_FUNCTIONS = {
    # --- unary ---
    "unary+": operator.pos,
    "unary-": operator.neg,
    "!": lambda value: not value,
    "not": lambda value: not value,

    # --- arithmetic ---
    "**": operator.pow,
    "*": operator.mul,
    "/": operator.truediv,
    "%": operator.mod,
    "+": operator.add,
    "-": operator.sub,

    # --- bitwise ---
    "<<": operator.lshift,
    ">>": operator.rshift,
    "&": operator.and_,
    "^": operator.xor,
    "|": operator.or_,

    # --- comparison ---
    "<": operator.lt,
    ">": operator.gt,
    "<=": operator.le,
    ">=": operator.ge,
    "==": operator.eq,
    "!=": operator.ne,

    # --- logical (bool-valued) ---
    "&&": lambda lhs, rhs: bool(lhs) and bool(rhs),
    "and": lambda lhs, rhs: bool(lhs) and bool(rhs),
    "||": lambda lhs, rhs: bool(lhs) or bool(rhs),
    "or": lambda lhs, rhs: bool(lhs) or bool(rhs),

    # --- conditional ---
    "?:": lambda cond, when_true, when_false: when_true if cond else when_false
}
|
325
|
+
|
326
|
+
|
327
|
+
@dataclass
class Op(Expr):
    """
    Abstract parent of all operator expressions.

    Concrete operators expose :attr:`op_mark`, the key under which their
    precedence (``_OP_PRECEDENCE``) and implementation (``_OP_FUNCTIONS``) are
    looked up.
    """

    @property
    def op_mark(self):
        """
        Key identifying this operator in the precedence/function tables.

        :return: Operator mark
        :rtype: str
        :raises NotImplementedError: Always, on the base class
        """
        raise NotImplementedError  # pragma: no cover
|
345
|
+
|
346
|
+
|
347
|
+
@dataclass
class BinaryOp(Op):
    """
    Binary operator expression ``x <op> y``.

    The word aliases ``and``/``or`` are normalized to ``&&``/``||`` on
    construction. The logical operators short-circuit during evaluation.

    :param x: Left operand
    :type x: Expr
    :param op: Operator symbol
    :type op: str
    :param y: Right operand
    :type y: Expr
    """
    __aliases__ = {
        'and': '&&',
        'or': '||',
    }

    x: Expr
    op: str
    y: Expr

    def __post_init__(self):
        """
        Normalize operator aliases (``and`` -> ``&&``, ``or`` -> ``||``).
        """
        self.op = self.__aliases__.get(self.op, self.op)

    @property
    def op_mark(self):
        """
        Get the operator mark for precedence lookup.

        :return: Operator mark (the normalized operator symbol)
        :rtype: str
        """
        return self.op

    def _iter_subs(self):
        """
        Iterate over operands, left first.

        :return: Iterator over operands
        :rtype: Iterator[Expr]
        """
        yield self.x
        yield self.y

    def _call(self, **kwargs):
        """
        Evaluate the binary operation.

        ``&&`` and ``||`` short-circuit: the right operand is evaluated only
        when the left operand does not already decide the result. This matches
        the lazy branch evaluation of ``ConditionalOp``. Both return a ``bool``.
        Every other operator evaluates both operands, left first, and applies
        the function registered in ``_OP_FUNCTIONS``.

        :param kwargs: Variable name to value mapping
        :return: Result of the operation
        :raises KeyError: If the operator is unknown or a variable is unbound
        """
        mark = self.op_mark
        if mark in ('&&', 'and'):
            # Guards such as ``x != 0 && 1 / x > 0`` must not evaluate the
            # right-hand side once the left-hand side is false.
            return bool(self.x._call(**kwargs)) and bool(self.y._call(**kwargs))
        if mark in ('||', 'or'):
            return bool(self.x._call(**kwargs)) or bool(self.y._call(**kwargs))
        return _OP_FUNCTIONS[mark](self.x._call(**kwargs), self.y._call(**kwargs))

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Convert to a BinaryOp AST node, inserting parentheses where needed.

        An operand that is itself an operator is wrapped in ``Paren`` when,
        once rendered, it could otherwise bind differently:

        * left operand: lower precedence, or equal precedence for the
          right-associative ``**``;
        * right operand: lower or equal precedence.

        :return: BinaryOp AST node
        :rtype: dsl_nodes.BinaryOp
        """
        my_pre = _OP_PRECEDENCE[self.op_mark]
        # '**' is right-associative, so "a ** b ** c" means a ** (b ** c).
        # A left operand of equal precedence, i.e. (a ** b) ** c, therefore needs
        # explicit parentheses to keep its meaning once rendered.
        right_assoc = self.op_mark == '**'

        left_need_paren = False
        if isinstance(self.x, Op):
            left_pre = _OP_PRECEDENCE[self.x.op_mark]
            if left_pre < my_pre or (right_assoc and left_pre == my_pre):
                left_need_paren = True

        right_need_paren = False
        if isinstance(self.y, Op):
            right_pre = _OP_PRECEDENCE[self.y.op_mark]
            # Kept at <= for '**' too. The parentheses are redundant under
            # right-associativity but stay correct if the grammar parses '**'
            # left-associatively.
            if right_pre <= my_pre:
                right_need_paren = True

        left_term = self.x.to_ast_node()
        if left_need_paren:
            left_term = dsl_nodes.Paren(left_term)
        right_term = self.y.to_ast_node()
        if right_need_paren:
            right_term = dsl_nodes.Paren(right_term)

        return dsl_nodes.BinaryOp(
            expr1=left_term,
            op=self.op,
            expr2=right_term,
        )
|
438
|
+
|
439
|
+
|
440
|
+
@dataclass
class UnaryOp(Op):
    """
    Prefix operator applied to a single operand, e.g. ``-x`` or ``!flag``.

    The word alias ``not`` is normalized to ``!`` on construction.

    :param op: Operator symbol
    :type op: str
    :param x: Operand
    :type x: Expr
    """
    __aliases__ = {'not': '!'}

    op: str
    x: Expr

    def __post_init__(self):
        """
        Replace an aliased operator (``not``) with its canonical symbol.
        """
        canonical = self.__aliases__.get(self.op)
        if canonical is not None:
            self.op = canonical

    @property
    def op_mark(self):
        """
        Table key for this operator.

        Sign operators are prefixed with ``unary`` so they do not collide with
        the binary ``+``/``-`` entries; other operators use their symbol as-is.

        :return: Operator mark
        :rtype: str
        """
        if self.op in ('+', '-'):
            return 'unary' + self.op
        return self.op

    def _iter_subs(self):
        """
        Yield the single operand.

        :return: Iterator over the operand
        :rtype: Iterator[Expr]
        """
        yield from (self.x,)

    def _call(self, **kwargs):
        """
        Apply the operator to the evaluated operand.

        :param kwargs: Variable name to value mapping
        :return: Result of the operation
        """
        apply = _OP_FUNCTIONS[self.op_mark]
        return apply(self.x._call(**kwargs))

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Build the DSL ``UnaryOp`` node.

        An operator operand whose precedence is lower than or equal to this
        one is wrapped in ``Paren`` (so ``-(-x)`` never renders as ``--x``).

        :return: UnaryOp AST node
        :rtype: dsl_nodes.UnaryOp
        """
        own_level = _OP_PRECEDENCE[self.op_mark]
        operand = self.x.to_ast_node()
        if isinstance(self.x, Op) and _OP_PRECEDENCE[self.x.op_mark] <= own_level:
            operand = dsl_nodes.Paren(expr=operand)
        return dsl_nodes.UnaryOp(op=self.op, expr=operand)
|
507
|
+
|
508
|
+
|
509
|
+
# Single-argument math functions available to UFunc, keyed by DSL function name.
_MATH_FUNCTIONS = {
    # Trigonometric functions
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,

    # Hyperbolic functions
    "sinh": math.sinh,
    "cosh": math.cosh,
    "tanh": math.tanh,
    "asinh": math.asinh,
    "acosh": math.acosh,
    "atanh": math.atanh,

    # Root and power functions
    "sqrt": math.sqrt,
    # Real cube root. math.pow() raises ValueError for a negative base with a
    # fractional exponent, so take the root of |x| and restore the sign of x.
    "cbrt": lambda x: math.copysign(math.pow(abs(x), 1 / 3), x),
    "exp": math.exp,

    # Logarithmic functions
    "log": math.log,  # Natural logarithm (base e)
    "log10": math.log10,
    "log2": math.log2,
    "log1p": math.log1p,  # log(1+x)

    # Rounding and absolute value functions
    "abs": abs,  # Python's built-in abs function
    "ceil": math.ceil,
    "floor": math.floor,
    "round": round,  # Python's built-in round (ties round to the even neighbour)
    "trunc": math.trunc,

    # Sign function
    "sign": lambda x: 0 if x == 0 else (1 if x > 0 else -1)  # Returns the sign of x
}
|
547
|
+
|
548
|
+
|
549
|
+
@dataclass
class UFunc(Expr):
    """
    Call of a single-argument math function such as ``sin(x)`` or ``sqrt(x)``.

    The function is resolved by name in ``_MATH_FUNCTIONS`` at evaluation time.

    :param func: Function name
    :type func: str
    :param x: Function argument
    :type x: Expr
    """
    func: str
    x: Expr

    def _iter_subs(self):
        """
        Yield the single argument expression.

        :return: Iterator over the argument
        :rtype: Iterator[Expr]
        """
        yield from (self.x,)

    def _call(self, **kwargs):
        """
        Look up the named function and apply it to the evaluated argument.

        :param kwargs: Variable name to value mapping
        :return: Result of the function call
        :raises KeyError: If ``func`` is not a known function name
        """
        implementation = _MATH_FUNCTIONS[self.func]
        return implementation(self.x._call(**kwargs))

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Build the DSL ``UFunc`` node for this call.

        :return: UFunc AST node
        :rtype: dsl_nodes.UFunc
        """
        argument = self.x.to_ast_node()
        return dsl_nodes.UFunc(func=self.func, expr=argument)
|
590
|
+
|
591
|
+
|
592
|
+
@dataclass
class ConditionalOp(Op):
    """
    C-style conditional ``cond ? if_true : if_false``.

    Only the branch selected by the condition is evaluated.

    :param cond: Condition expression
    :type cond: Expr
    :param if_true: Value when the condition is truthy
    :type if_true: Expr
    :param if_false: Value when the condition is falsy
    :type if_false: Expr
    """
    cond: Expr
    if_true: Expr
    if_false: Expr

    @property
    def op_mark(self):
        """
        Table key for the conditional operator.

        :return: Operator mark (``'?:'``)
        :rtype: str
        """
        return '?:'

    def _iter_subs(self):
        """
        Yield condition, true branch and false branch, in that order.

        :return: Iterator over sub-expressions
        :rtype: Iterator[Expr]
        """
        yield from (self.cond, self.if_true, self.if_false)

    def _call(self, **kwargs):
        """
        Evaluate the condition, then only the selected branch.

        :param kwargs: Variable name to value mapping
        :return: Value of ``if_true`` or ``if_false``
        """
        chosen = self.if_true if self.cond._call(**kwargs) else self.if_false
        return chosen._call(**kwargs)

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Build the DSL ``ConditionalOp`` node.

        A branch that is an operator with precedence lower than or equal to the
        conditional's is wrapped in ``Paren``; the condition is emitted as-is.

        :return: ConditionalOp AST node
        :rtype: dsl_nodes.ConditionalOp
        """
        own_level = _OP_PRECEDENCE[self.op_mark]

        def binds_loosely(branch):
            return isinstance(branch, Op) and _OP_PRECEDENCE[branch.op_mark] <= own_level

        wrap_true = binds_loosely(self.if_true)
        wrap_false = binds_loosely(self.if_false)

        cond_term = self.cond.to_ast_node()
        true_term = self.if_true.to_ast_node()
        if wrap_true:
            true_term = dsl_nodes.Paren(true_term)
        false_term = self.if_false.to_ast_node()
        if wrap_false:
            false_term = dsl_nodes.Paren(false_term)

        return dsl_nodes.ConditionalOp(
            cond=cond_term,
            value_true=true_term,
            value_false=false_term,
        )
|
678
|
+
|
679
|
+
|
680
|
+
@dataclass
class Variable(Expr):
    """
    Reference to a named variable whose value is supplied at evaluation time.

    :param name: Identifier of the variable
    :type name: str
    """
    name: str

    def _call(self, **kwargs):
        """
        Resolve the variable against the keyword bindings.

        :param kwargs: Variable name to value mapping
        :return: The bound value
        :raises KeyError: If no binding named ``name`` was supplied
        """
        bindings = kwargs
        return bindings[self.name]

    def to_ast_node(self) -> dsl_nodes.Expr:
        """
        Build the DSL ``Name`` node for this reference.

        :return: Name AST node
        :rtype: dsl_nodes.Name
        """
        identifier = self.name
        return dsl_nodes.Name(name=identifier)
|
708
|
+
|
709
|
+
|
710
|
+
def parse_expr_node_to_expr(node: dsl_nodes.Expr) -> Expr:
    """
    Convert a DSL AST expression node into the matching expression model object.

    Mapping (checked in this order):

    * ``Name`` -> :class:`Variable`
    * ``Integer`` / ``HexInt`` -> :class:`Integer`
    * ``Constant`` / ``Float`` -> :class:`Float`
    * ``Boolean`` -> :class:`Boolean`
    * ``Paren`` -> its inner expression (parentheses are dropped)
    * ``UnaryOp`` / ``BinaryOp`` / ``ConditionalOp`` / ``UFunc`` -> the
      corresponding class, converting children recursively

    :param node: AST expression node
    :type node: dsl_nodes.Expr
    :return: Corresponding expression object
    :rtype: Expr
    :raises TypeError: If the node type is not recognized

    Example::

        >>> ast_node = dsl_nodes.Integer(raw="42")
        >>> expr = parse_expr_node_to_expr(ast_node)
        >>> isinstance(expr, Integer)
        True
        >>> expr.value
        42
    """
    # Leaf nodes
    if isinstance(node, dsl_nodes.Name):
        return Variable(name=node.name)
    if isinstance(node, (dsl_nodes.Integer, dsl_nodes.HexInt)):
        return Integer(value=node.value)
    if isinstance(node, (dsl_nodes.Constant, dsl_nodes.Float)):
        return Float(value=node.value)
    if isinstance(node, dsl_nodes.Boolean):
        return Boolean(value=node.value)

    # Grouping carries no meaning in the model; unwrap it.
    if isinstance(node, dsl_nodes.Paren):
        return parse_expr_node_to_expr(node.expr)

    # Composite nodes: convert children recursively.
    if isinstance(node, dsl_nodes.UnaryOp):
        return UnaryOp(
            op=node.op,
            x=parse_expr_node_to_expr(node.expr),
        )
    if isinstance(node, dsl_nodes.BinaryOp):
        return BinaryOp(
            x=parse_expr_node_to_expr(node.expr1),
            op=node.op,
            y=parse_expr_node_to_expr(node.expr2),
        )
    if isinstance(node, dsl_nodes.ConditionalOp):
        return ConditionalOp(
            cond=parse_expr_node_to_expr(node.cond),
            if_true=parse_expr_node_to_expr(node.value_true),
            if_false=parse_expr_node_to_expr(node.value_false),
        )
    if isinstance(node, dsl_nodes.UFunc):
        return UFunc(
            func=node.func,
            x=parse_expr_node_to_expr(node.expr),
        )

    raise TypeError(f'Unknown node type - {node!r}.')  # pragma: no cover
|