expr-codegen 0.8.0.tar.gz → 0.8.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/PKG-INFO +1 -1
  2. expr_codegen-0.8.2/expr_codegen/_version.py +1 -0
  3. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/codes.py +61 -7
  4. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/model.py +2 -1
  5. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/template.py.j2 +3 -0
  6. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/printer.py +17 -1
  7. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/template.py.j2 +3 -0
  8. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/tool.py +11 -3
  9. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/PKG-INFO +1 -1
  10. expr_codegen-0.8.0/expr_codegen/_version.py +0 -1
  11. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/LICENSE +0 -0
  12. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/README.md +0 -0
  13. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/__init__.py +0 -0
  14. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/dag.py +0 -0
  15. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/expr.py +0 -0
  16. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/latex/__init__.py +0 -0
  17. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/latex/printer.py +0 -0
  18. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/__init__.py +0 -0
  19. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/code.py +0 -0
  20. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/pandas/printer.py +0 -0
  21. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/__init__.py +0 -0
  22. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen/polars/code.py +0 -0
  23. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/SOURCES.txt +0 -0
  24. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/dependency_links.txt +0 -0
  25. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/requires.txt +0 -0
  26. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/expr_codegen.egg-info/top_level.txt +0 -0
  27. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/pyproject.toml +0 -0
  28. {expr_codegen-0.8.0 → expr_codegen-0.8.2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: expr_codegen
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: symbol expression to polars expression tool
5
5
  Author-email: wukan <wu-kan@163.com>
6
6
  License: BSD 3-Clause License
@@ -0,0 +1 @@
1
+ __version__ = "0.8.2"
@@ -21,7 +21,9 @@ class SympyTransformer(ast.NodeTransformer):
21
21
 
22
22
  # 映射
23
23
  funcs_map = {}
24
- args_map = {}
24
+ # 由于None等常量无法在sympy中正确处理,只能改成Symbol变量
25
+ # !!!一定要在drop_symbols时排除
26
+ args_map = {'True': "_TRUE_", 'False': "_FALSE_", 'None': "_NONE_"}
25
27
  targets_map = {} # 只对非下划线开头的生效
26
28
 
27
29
  def config_map(self, funcs_map, args_map, targets_map):
@@ -35,11 +37,18 @@ class SympyTransformer(ast.NodeTransformer):
35
37
  node.func.id = self.funcs_map.get(node.func.id, node.func.id)
36
38
  self.funcs_new.add(node.func.id)
37
39
  # 提取参数名
38
- for arg in node.args:
40
+ for i, arg in enumerate(node.args):
39
41
  if isinstance(arg, ast.Name):
40
42
  self.args_old.add(arg.id)
41
43
  arg.id = self.args_map.get(arg.id, arg.id)
42
44
  self.args_new.add(arg.id)
45
+ if isinstance(arg, ast.Constant):
46
+ old_arg_value = str(arg.value)
47
+ if old_arg_value in self.args_map:
48
+ new_arg_value = self.args_map.get(old_arg_value, old_arg_value)
49
+ self.args_old.add(old_arg_value)
50
+ node.args[i] = ast.Name(new_arg_value, ctx=ast.Load())
51
+ self.args_new.add(new_arg_value)
43
52
 
44
53
  self.generic_visit(node)
45
54
  return node
@@ -60,24 +69,41 @@ class SympyTransformer(ast.NodeTransformer):
60
69
  # 记录修改的变量名,之后会使用到
61
70
  self.args_map[old_target_id] = new_target_id
62
71
 
72
+ if isinstance(target, ast.Constant):
73
+ old_target_value = str(target.value)
74
+ if old_target_value in self.args_map:
75
+ new_target_value = self.args_map.get(old_target_value, old_target_value)
76
+ self.args_old.add(old_target_value)
77
+ target = ast.Name(new_target_value, ctx=ast.Load())
78
+ self.args_new.add(new_target_value)
79
+
80
+ return target
81
+
63
82
  def visit_Assign(self, node):
64
83
  # 调整位置,支持循环赋值
65
84
  # _A = _A+1 调整成 _A_001 = _A_000 + 1
66
85
  self.generic_visit(node)
67
86
 
68
87
  # 提取输出变量名
69
- for target in node.targets:
88
+ for i, target in enumerate(node.targets):
70
89
  if isinstance(target, ast.Tuple):
71
- for t in target.elts:
72
- self.__visit_Assign(t)
90
+ for j, t in enumerate(target.elts):
91
+ target.elts[j] = self.__visit_Assign(t)
73
92
  else:
74
- self.__visit_Assign(target)
93
+ node.targets[i] = self.__visit_Assign(target)
75
94
 
76
95
  # 处理 alpha=close 这种情况
77
96
  if isinstance(node.value, ast.Name):
78
97
  self.args_old.add(node.value.id)
79
98
  node.value.id = self.args_map.get(node.value.id, node.value.id)
80
99
  self.args_new.add(node.value.id)
100
+ if isinstance(node.value, ast.Constant):
101
+ old_node_value = str(node.value.value)
102
+ if old_node_value in self.args_map:
103
+ new_node_value = self.args_map.get(old_node_value, old_node_value)
104
+ self.args_old.add(old_node_value)
105
+ node.value = ast.Name(new_node_value, ctx=ast.Load())
106
+ self.args_new.add(new_node_value)
81
107
 
82
108
  return node
83
109
 
@@ -87,11 +113,18 @@ class SympyTransformer(ast.NodeTransformer):
87
113
  self.args_old.add(node.left.id)
88
114
  node.left.id = self.args_map.get(node.left.id, node.left.id)
89
115
  self.args_new.add(node.left.id)
90
- for com in node.comparators:
116
+ for i, com in enumerate(node.comparators):
91
117
  if isinstance(com, ast.Name):
92
118
  self.args_old.add(com.id)
93
119
  com.id = self.args_map.get(com.id, com.id)
94
120
  self.args_new.add(com.id)
121
+ if isinstance(com, ast.Constant):
122
+ old_com_value = str(com.value)
123
+ if old_com_value in self.args_map:
124
+ new_com_value = self.args_map.get(old_com_value, old_com_value)
125
+ self.args_old.add(old_com_value)
126
+ node.comparators[i] = ast.Name(new_com_value, ctx=ast.Load())
127
+ self.args_new.add(new_com_value)
95
128
 
96
129
  # OPEN==CLOSE,要转成Eq
97
130
  if isinstance(node.ops[0], ast.Eq):
@@ -146,6 +179,20 @@ class SympyTransformer(ast.NodeTransformer):
146
179
  self.args_old.add(node.right.id)
147
180
  node.right.id = self.args_map.get(node.right.id, node.right.id)
148
181
  self.args_new.add(node.right.id)
182
+ if isinstance(node.left, ast.Constant):
183
+ old_node_value = str(node.left.value)
184
+ if old_node_value in self.args_map:
185
+ new_node_value = self.args_map.get(old_node_value, old_node_value)
186
+ self.args_old.add(old_node_value)
187
+ node.left = ast.Name(new_node_value, ctx=ast.Load())
188
+ self.args_new.add(new_node_value)
189
+ if isinstance(node.right, ast.Constant):
190
+ old_node_value = str(node.right.value)
191
+ if old_node_value in self.args_map:
192
+ new_node_value = self.args_map.get(old_node_value, old_node_value)
193
+ self.args_old.add(old_node_value)
194
+ node.right = ast.Name(new_node_value, ctx=ast.Load())
195
+ self.args_new.add(new_node_value)
149
196
 
150
197
  self.generic_visit(node)
151
198
  return node
@@ -156,6 +203,13 @@ class SympyTransformer(ast.NodeTransformer):
156
203
  self.args_old.add(node.operand.id)
157
204
  node.operand.id = self.args_map.get(node.operand.id, node.operand.id)
158
205
  self.args_new.add(node.operand.id)
206
+ if isinstance(node.operand, ast.Constant):
207
+ old_operand_value = str(node.operand.value)
208
+ if old_operand_value in self.args_map:
209
+ new_operand_value = self.args_map.get(old_operand_value, old_operand_value)
210
+ self.args_old.add(old_operand_value)
211
+ node.operand = ast.Name(new_operand_value, ctx=ast.Load())
212
+ self.args_new.add(new_operand_value)
159
213
 
160
214
  self.generic_visit(node)
161
215
  return node
@@ -109,7 +109,8 @@ class ListDictList:
109
109
  l2 = [set()]
110
110
  s = set()
111
111
  for i in reversed(l1):
112
- s = s | i
112
+ # 这三变量需要排除
113
+ s = s | i - {'_NONE_', '_TRUE_', '_FALSE_'}
113
114
  l2.append(s)
114
115
  l2 = list(reversed(l2))
115
116
 
@@ -16,6 +16,9 @@ from loguru import logger # noqa
16
16
 
17
17
  _DATE_ = '{{ date }}'
18
18
  _ASSET_ = '{{ asset }}'
19
+ _NONE_ = None
20
+ _TRUE_ = True
21
+ _FALSE_ = False
19
22
 
20
23
  {%-for row in extra_codes %}
21
24
  {{ row-}}
@@ -1,5 +1,5 @@
1
1
  from sympy import Basic, Function, StrPrinter
2
- from sympy.printing.precedence import precedence
2
+ from sympy.printing.precedence import precedence, PRECEDENCE
3
3
 
4
4
 
5
5
  # TODO: 如有新添加函数,但表达式有变更才需要在此补充对应的打印代码,否则可以省略
@@ -55,6 +55,22 @@ class PolarsStrPrinter(StrPrinter):
55
55
  PREC = precedence(expr)
56
56
  return "%s==%s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC))
57
57
 
58
def _print_Or(self, expr):
    """Print a sympy ``Or`` as a polars bitwise-OR chain: ``a | b | c``.

    sympy flattens nested ``Or`` into a single n-ary node. Every argument is
    therefore printed, not just the first two; dropping the rest would
    silently lose conditions.

    Operands are parenthesized relative to ``Mul`` precedence, so
    lower-precedence operands such as comparisons get wrapped. Python's ``|``
    binds tighter than comparison operators, which makes this necessary.
    """
    PREC = PRECEDENCE["Mul"]
    return " | ".join(self.parenthesize(arg, PREC) for arg in expr.args)
61
+
62
def _print_Xor(self, expr):
    """Print a sympy ``Xor`` as a polars bitwise-XOR chain: ``a ^ b ^ c``.

    sympy's ``Xor`` is n-ary: nested ``Xor`` flattens, and the result is the
    parity of all arguments. Every argument is printed; ``a ^ b ^ c`` computes
    the same parity.

    Operands are parenthesized relative to ``Mul`` precedence so that
    comparisons and other boolean operators are wrapped.
    """
    PREC = PRECEDENCE["Mul"]
    return " ^ ".join(self.parenthesize(arg, PREC) for arg in expr.args)
65
+
66
def _print_And(self, expr):
    """Print a sympy ``And`` as a polars bitwise-AND chain: ``a & b & c``.

    sympy flattens nested ``And`` into a single n-ary node. Every argument is
    therefore printed, not just the first two; dropping the rest would
    silently lose conditions.

    Operands are parenthesized relative to ``Mul`` precedence, so
    lower-precedence operands such as comparisons get wrapped. Python's ``&``
    binds tighter than comparison operators, which makes this necessary.
    """
    PREC = PRECEDENCE["Mul"]
    return " & ".join(self.parenthesize(arg, PREC) for arg in expr.args)
69
+
70
def _print_Not(self, expr):
    """Render a sympy ``Not`` as a polars bitwise negation, ``~operand``.

    The single operand is parenthesized relative to ``Mul`` precedence.
    """
    level = PRECEDENCE["Mul"]
    inner = self.parenthesize(expr.args[0], level)
    return f"~{inner}"
73
+
58
74
  def _print_gp_rank(self, expr):
59
75
  return "cs_rank(%s)" % self._print(expr.args[1])
60
76
 
@@ -26,6 +26,9 @@ from polars_ta.prefix.cdl import * # noqa
26
26
 
27
27
  _DATE_ = '{{ date }}'
28
28
  _ASSET_ = '{{ asset }}'
29
+ _NONE_ = None
30
+ _TRUE_ = True
31
+ _FALSE_ = False
29
32
 
30
33
  {%-for row in extra_codes %}
31
34
  {{ row-}}
@@ -11,6 +11,14 @@ from expr_codegen.expr import get_current_by_prefix, get_children, replace_exprs
11
11
  from expr_codegen.model import dag_start, dag_end, dag_middle
12
12
 
13
13
 
14
def simplify2(expr):
    """Run ``simplify`` (presumably ``sympy.simplify``) on *expr*, tolerantly.

    If ``simplify`` raises ``AttributeError`` for this expression, a message
    is printed and *expr* is returned unchanged. The message's Chinese text
    means "expression cannot be simplified". Otherwise the simplified result
    is returned. Any other exception propagates.
    """
    try:
        result = simplify(expr)
    except AttributeError as e:
        # Fall back to the unsimplified input instead of propagating the error.
        print(f'{expr} ,表达式无法简化, {e}')
        return expr
    return result
20
+
21
+
14
22
  class ExprTool:
15
23
 
16
24
  def __init__(self):
@@ -38,7 +46,7 @@ class ExprTool:
38
46
 
39
47
  """
40
48
  # 抽取前先化简
41
- expr = simplify(expr)
49
+ expr = simplify2(expr)
42
50
 
43
51
  exprs = []
44
52
  syms = []
@@ -87,9 +95,9 @@ class ExprTool:
87
95
 
88
96
  # 不做改动,直接生成
89
97
  for variable, expr in repl:
90
- exprs_dict[variable] = simplify(expr)
98
+ exprs_dict[variable] = simplify2(expr)
91
99
  for variable, expr in redu:
92
- exprs_dict[variable] = simplify(expr)
100
+ exprs_dict[variable] = simplify2(expr)
93
101
 
94
102
  return exprs_dict
95
103
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: expr_codegen
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: symbol expression to polars expression tool
5
5
  Author-email: wukan <wu-kan@163.com>
6
6
  License: BSD 3-Clause License
@@ -1 +0,0 @@
1
- __version__ = "0.8.0"
File without changes
File without changes
File without changes