expr-codegen 0.8.0__tar.gz → 0.8.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/PKG-INFO +1 -1
- expr_codegen-0.8.2/expr_codegen/_version.py +1 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/codes.py +61 -7
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/model.py +2 -1
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/template.py.j2 +3 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/printer.py +17 -1
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/template.py.j2 +3 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/tool.py +11 -3
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/PKG-INFO +1 -1
- expr_codegen-0.8.0/expr_codegen/_version.py +0 -1
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/LICENSE +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/README.md +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/__init__.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/dag.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/expr.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/latex/__init__.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/latex/printer.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/__init__.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/code.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/printer.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/__init__.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/code.py +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/SOURCES.txt +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/dependency_links.txt +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/requires.txt +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/top_level.txt +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/pyproject.toml +0 -0
- {expr_codegen-0.8.0 → expr_codegen-0.8.2}/setup.cfg +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.8.2"
|
|
@@ -21,7 +21,9 @@ class SympyTransformer(ast.NodeTransformer):
|
|
|
21
21
|
|
|
22
22
|
# 映射
|
|
23
23
|
funcs_map = {}
|
|
24
|
-
|
|
24
|
+
# 由于None等常量无法在sympy中正确处理,只能改成Symbol变量
|
|
25
|
+
# !!!一定要在drop_symbols时排除
|
|
26
|
+
args_map = {'True': "_TRUE_", 'False': "_FALSE_", 'None': "_NONE_"}
|
|
25
27
|
targets_map = {} # 只对非下划线开头的生效
|
|
26
28
|
|
|
27
29
|
def config_map(self, funcs_map, args_map, targets_map):
|
|
@@ -35,11 +37,18 @@ class SympyTransformer(ast.NodeTransformer):
|
|
|
35
37
|
node.func.id = self.funcs_map.get(node.func.id, node.func.id)
|
|
36
38
|
self.funcs_new.add(node.func.id)
|
|
37
39
|
# 提取参数名
|
|
38
|
-
for arg in node.args:
|
|
40
|
+
for i, arg in enumerate(node.args):
|
|
39
41
|
if isinstance(arg, ast.Name):
|
|
40
42
|
self.args_old.add(arg.id)
|
|
41
43
|
arg.id = self.args_map.get(arg.id, arg.id)
|
|
42
44
|
self.args_new.add(arg.id)
|
|
45
|
+
if isinstance(arg, ast.Constant):
|
|
46
|
+
old_arg_value = str(arg.value)
|
|
47
|
+
if old_arg_value in self.args_map:
|
|
48
|
+
new_arg_value = self.args_map.get(old_arg_value, old_arg_value)
|
|
49
|
+
self.args_old.add(old_arg_value)
|
|
50
|
+
node.args[i] = ast.Name(new_arg_value, ctx=ast.Load())
|
|
51
|
+
self.args_new.add(new_arg_value)
|
|
43
52
|
|
|
44
53
|
self.generic_visit(node)
|
|
45
54
|
return node
|
|
@@ -60,24 +69,41 @@ class SympyTransformer(ast.NodeTransformer):
|
|
|
60
69
|
# 记录修改的变量名,之后会使用到
|
|
61
70
|
self.args_map[old_target_id] = new_target_id
|
|
62
71
|
|
|
72
|
+
if isinstance(target, ast.Constant):
|
|
73
|
+
old_target_value = str(target.value)
|
|
74
|
+
if old_target_value in self.args_map:
|
|
75
|
+
new_target_value = self.args_map.get(old_target_value, old_target_value)
|
|
76
|
+
self.args_old.add(old_target_value)
|
|
77
|
+
target = ast.Name(new_target_value, ctx=ast.Load())
|
|
78
|
+
self.args_new.add(new_target_value)
|
|
79
|
+
|
|
80
|
+
return target
|
|
81
|
+
|
|
63
82
|
def visit_Assign(self, node):
|
|
64
83
|
# 调整位置,支持循环赋值
|
|
65
84
|
# _A = _A+1 调整成 _A_001 = _A_000 + 1
|
|
66
85
|
self.generic_visit(node)
|
|
67
86
|
|
|
68
87
|
# 提取输出变量名
|
|
69
|
-
for target in node.targets:
|
|
88
|
+
for i, target in enumerate(node.targets):
|
|
70
89
|
if isinstance(target, ast.Tuple):
|
|
71
|
-
for t in target.elts:
|
|
72
|
-
self.__visit_Assign(t)
|
|
90
|
+
for j, t in enumerate(target.elts):
|
|
91
|
+
target.elts[j] = self.__visit_Assign(t)
|
|
73
92
|
else:
|
|
74
|
-
self.__visit_Assign(target)
|
|
93
|
+
node.targets[i] = self.__visit_Assign(target)
|
|
75
94
|
|
|
76
95
|
# 处理 alpha=close 这种情况
|
|
77
96
|
if isinstance(node.value, ast.Name):
|
|
78
97
|
self.args_old.add(node.value.id)
|
|
79
98
|
node.value.id = self.args_map.get(node.value.id, node.value.id)
|
|
80
99
|
self.args_new.add(node.value.id)
|
|
100
|
+
if isinstance(node.value, ast.Constant):
|
|
101
|
+
old_node_value = str(node.value.value)
|
|
102
|
+
if old_node_value in self.args_map:
|
|
103
|
+
new_node_value = self.args_map.get(old_node_value, old_node_value)
|
|
104
|
+
self.args_old.add(old_node_value)
|
|
105
|
+
node.value = ast.Name(new_node_value, ctx=ast.Load())
|
|
106
|
+
self.args_new.add(new_node_value)
|
|
81
107
|
|
|
82
108
|
return node
|
|
83
109
|
|
|
@@ -87,11 +113,18 @@ class SympyTransformer(ast.NodeTransformer):
|
|
|
87
113
|
self.args_old.add(node.left.id)
|
|
88
114
|
node.left.id = self.args_map.get(node.left.id, node.left.id)
|
|
89
115
|
self.args_new.add(node.left.id)
|
|
90
|
-
for com in node.comparators:
|
|
116
|
+
for i, com in enumerate(node.comparators):
|
|
91
117
|
if isinstance(com, ast.Name):
|
|
92
118
|
self.args_old.add(com.id)
|
|
93
119
|
com.id = self.args_map.get(com.id, com.id)
|
|
94
120
|
self.args_new.add(com.id)
|
|
121
|
+
if isinstance(com, ast.Constant):
|
|
122
|
+
old_com_value = str(com.value)
|
|
123
|
+
if old_com_value in self.args_map:
|
|
124
|
+
new_com_value = self.args_map.get(old_com_value, old_com_value)
|
|
125
|
+
self.args_old.add(old_com_value)
|
|
126
|
+
node.comparators[i] = ast.Name(new_com_value, ctx=ast.Load())
|
|
127
|
+
self.args_new.add(new_com_value)
|
|
95
128
|
|
|
96
129
|
# OPEN==CLOSE,要转成Eq
|
|
97
130
|
if isinstance(node.ops[0], ast.Eq):
|
|
@@ -146,6 +179,20 @@ class SympyTransformer(ast.NodeTransformer):
|
|
|
146
179
|
self.args_old.add(node.right.id)
|
|
147
180
|
node.right.id = self.args_map.get(node.right.id, node.right.id)
|
|
148
181
|
self.args_new.add(node.right.id)
|
|
182
|
+
if isinstance(node.left, ast.Constant):
|
|
183
|
+
old_node_value = str(node.left.value)
|
|
184
|
+
if old_node_value in self.args_map:
|
|
185
|
+
new_node_value = self.args_map.get(old_node_value, old_node_value)
|
|
186
|
+
self.args_old.add(old_node_value)
|
|
187
|
+
node.left = ast.Name(new_node_value, ctx=ast.Load())
|
|
188
|
+
self.args_new.add(new_node_value)
|
|
189
|
+
if isinstance(node.right, ast.Constant):
|
|
190
|
+
old_node_value = str(node.right.value)
|
|
191
|
+
if old_node_value in self.args_map:
|
|
192
|
+
new_node_value = self.args_map.get(old_node_value, old_node_value)
|
|
193
|
+
self.args_old.add(old_node_value)
|
|
194
|
+
node.right = ast.Name(new_node_value, ctx=ast.Load())
|
|
195
|
+
self.args_new.add(new_node_value)
|
|
149
196
|
|
|
150
197
|
self.generic_visit(node)
|
|
151
198
|
return node
|
|
@@ -156,6 +203,13 @@ class SympyTransformer(ast.NodeTransformer):
|
|
|
156
203
|
self.args_old.add(node.operand.id)
|
|
157
204
|
node.operand.id = self.args_map.get(node.operand.id, node.operand.id)
|
|
158
205
|
self.args_new.add(node.operand.id)
|
|
206
|
+
if isinstance(node.operand, ast.Constant):
|
|
207
|
+
old_operand_value = str(node.operand.value)
|
|
208
|
+
if old_operand_value in self.args_map:
|
|
209
|
+
new_operand_value = self.args_map.get(old_operand_value, old_operand_value)
|
|
210
|
+
self.args_old.add(old_operand_value)
|
|
211
|
+
node.operand = ast.Name(new_operand_value, ctx=ast.Load())
|
|
212
|
+
self.args_new.add(new_operand_value)
|
|
159
213
|
|
|
160
214
|
self.generic_visit(node)
|
|
161
215
|
return node
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from sympy import Basic, Function, StrPrinter
|
|
2
|
-
from sympy.printing.precedence import precedence
|
|
2
|
+
from sympy.printing.precedence import precedence, PRECEDENCE
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
# TODO: 如有新添加函数,但表达式有变更才需要在此补充对应的打印代码,否则可以省略
|
|
@@ -55,6 +55,22 @@ class PolarsStrPrinter(StrPrinter):
|
|
|
55
55
|
PREC = precedence(expr)
|
|
56
56
|
return "%s==%s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC))
|
|
57
57
|
|
|
58
|
+
def _print_Or(self, expr):
|
|
59
|
+
PREC = PRECEDENCE["Mul"]
|
|
60
|
+
return "%s | %s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC))
|
|
61
|
+
|
|
62
|
+
def _print_Xor(self, expr):
|
|
63
|
+
PREC = PRECEDENCE["Mul"]
|
|
64
|
+
return "%s ^ %s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC))
|
|
65
|
+
|
|
66
|
+
def _print_And(self, expr):
|
|
67
|
+
PREC = PRECEDENCE["Mul"]
|
|
68
|
+
return "%s & %s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC))
|
|
69
|
+
|
|
70
|
+
def _print_Not(self, expr):
|
|
71
|
+
PREC = PRECEDENCE["Mul"]
|
|
72
|
+
return "~%s" % self.parenthesize(expr.args[0], PREC)
|
|
73
|
+
|
|
58
74
|
def _print_gp_rank(self, expr):
|
|
59
75
|
return "cs_rank(%s)" % self._print(expr.args[1])
|
|
60
76
|
|
|
@@ -11,6 +11,14 @@ from expr_codegen.expr import get_current_by_prefix, get_children, replace_exprs
|
|
|
11
11
|
from expr_codegen.model import dag_start, dag_end, dag_middle
|
|
12
12
|
|
|
13
13
|
|
|
14
|
+
def simplify2(expr):
|
|
15
|
+
try:
|
|
16
|
+
expr = simplify(expr)
|
|
17
|
+
except AttributeError as e:
|
|
18
|
+
print(f'{expr} ,表达式无法简化, {e}')
|
|
19
|
+
return expr
|
|
20
|
+
|
|
21
|
+
|
|
14
22
|
class ExprTool:
|
|
15
23
|
|
|
16
24
|
def __init__(self):
|
|
@@ -38,7 +46,7 @@ class ExprTool:
|
|
|
38
46
|
|
|
39
47
|
"""
|
|
40
48
|
# 抽取前先化简
|
|
41
|
-
expr =
|
|
49
|
+
expr = simplify2(expr)
|
|
42
50
|
|
|
43
51
|
exprs = []
|
|
44
52
|
syms = []
|
|
@@ -87,9 +95,9 @@ class ExprTool:
|
|
|
87
95
|
|
|
88
96
|
# 不做改动,直接生成
|
|
89
97
|
for variable, expr in repl:
|
|
90
|
-
exprs_dict[variable] =
|
|
98
|
+
exprs_dict[variable] = simplify2(expr)
|
|
91
99
|
for variable, expr in redu:
|
|
92
|
-
exprs_dict[variable] =
|
|
100
|
+
exprs_dict[variable] = simplify2(expr)
|
|
93
101
|
|
|
94
102
|
return exprs_dict
|
|
95
103
|
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "0.8.0"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|