expr-codegen 0.10.15__tar.gz → 0.11.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/PKG-INFO +34 -3
  2. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/README.md +32 -2
  3. expr_codegen-0.11.0/expr_codegen/_version.py +1 -0
  4. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/codes.py +29 -8
  5. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/expr.py +44 -39
  6. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/model.py +12 -7
  7. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/pandas/code.py +2 -2
  8. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/pandas/template.py.j2 +2 -2
  9. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/polars_group/code.py +2 -2
  10. {expr_codegen-0.10.15/expr_codegen/polars_over → expr_codegen-0.11.0/expr_codegen/polars_group}/template.py.j2 +2 -2
  11. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/polars_over/code.py +10 -4
  12. {expr_codegen-0.10.15/expr_codegen/polars_group → expr_codegen-0.11.0/expr_codegen/polars_over}/template.py.j2 +2 -2
  13. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/tool.py +44 -35
  14. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen.egg-info/PKG-INFO +34 -3
  15. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen.egg-info/requires.txt +1 -0
  16. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/pyproject.toml +1 -0
  17. expr_codegen-0.10.15/expr_codegen/_version.py +0 -1
  18. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/LICENSE +0 -0
  19. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/__init__.py +0 -0
  20. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/dag.py +0 -0
  21. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/latex/__init__.py +0 -0
  22. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/latex/printer.py +0 -0
  23. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/pandas/__init__.py +0 -0
  24. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/pandas/helper.py +0 -0
  25. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/pandas/printer.py +0 -0
  26. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/pandas/ta.py +0 -0
  27. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/polars_group/__init__.py +0 -0
  28. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/polars_group/printer.py +0 -0
  29. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/polars_over/__init__.py +0 -0
  30. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen/polars_over/printer.py +0 -0
  31. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen.egg-info/SOURCES.txt +0 -0
  32. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen.egg-info/dependency_links.txt +0 -0
  33. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/expr_codegen.egg-info/top_level.txt +0 -0
  34. {expr_codegen-0.10.15 → expr_codegen-0.11.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: expr_codegen
3
- Version: 0.10.15
3
+ Version: 0.11.0
4
4
  Summary: symbol expression to polars expression tool
5
5
  Author-email: wukan <wu-kan@163.com>
6
6
  License: BSD 3-Clause License
@@ -43,6 +43,7 @@ Requires-Dist: Jinja2
43
43
  Requires-Dist: networkx
44
44
  Requires-Dist: loguru
45
45
  Requires-Dist: sympy
46
+ Requires-Dist: ast-comments
46
47
  Provides-Extra: streamlit
47
48
  Requires-Dist: streamlit; extra == "streamlit"
48
49
  Requires-Dist: streamlit-ace; extra == "streamlit"
@@ -211,6 +212,33 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
211
212
  2. `over_null='order_by'`。分到一个区域,`null`排在前面
212
213
  3. `over_null=None`。不处理,直接调用,速度更快。如果确信不会中段产生`null`建议使用此参数
213
214
 
215
+ `codegen_exec(over_null='partition_by')`为全局使用`partition_by`。但遇到`ts_count_nulls`这类`null`
216
+ 函数就得使用`over_null=None`,所以本工具还新添了注释功能来指定单行表达式参数
217
+
218
+ 1. `# --over_null partition_by`。单行`over_null='partition_by'`
219
+ 2. `# --over_null=order_by`。单行`over_null='order_by'`
220
+ 3. `# --over_null`。单行`over_null=None`
221
+ 4. `# `。取`codegen_exec`参数传入的`over_null`值
222
+
223
+ 注意:
224
+
225
+ 1. `# --over_null`传参注释只能写在单行表达式的后面,不能独立成一行,否则会被忽略
226
+ 2. `# --over_null # --over_null=order_by`多个`#`时,只取第一个有效
227
+ 3. 只对最外层`ts`函数有效。如果`ts`函数不在外层,需要人工提炼。例如:
228
+ ```python
229
+ X1 = cs_rank(ts_mean(CLOSE, 3)) # --over_null=order_by # 应用在cs_rank上,没有意义
230
+ X2 = ts_rank(ts_mean(CLOSE, 3), 5) # --over_null=order_by # 本以为应用在ts_rank(ts_mean)上,但由于出现了公共ts_mean,其实是应用在ts_rank(_x_0)上
231
+ ```
232
+
233
+ 需写成
234
+
235
+ ```python
236
+ _x_0 = ts_mean(CLOSE, 3) # --over_null=order_by
237
+ X1 = cs_rank(_x_0)
238
+ X2 = ts_rank(_x_0, 5)
239
+ ```
240
+ 4. 由于很容易搞错,强烈建议生成`output_file`,检查生成的代码是否正确。
241
+
214
242
  ## `expr_codegen`局限性
215
243
 
216
244
  1. `DAG`只能增加列无法删除。增加列时,遇到同名列会覆盖
@@ -220,7 +248,8 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
220
248
 
221
249
  ## 特别语法
222
250
 
223
- 1. 支持`C?T:F`三元表达式(仅可字符串中使用),底层会先转成`C or True if( T )else F`,然后修正成`T if C else F`,最后转成`if_else(C,T,F)`。支持与`if else`混用
251
+ 1. 支持`C?T:F`三元表达式(仅可字符串中使用),底层会先转成`C or True if( T )else F`,然后修正成`T if C else F`
252
+ ,最后转成`if_else(C,T,F)`。支持与`if else`混用
224
253
  2. `(A<B)*-1`,底层将转换成`int_(A<B)*-1`
225
254
  3. 为防止`A==B`被`sympy`替换成`False`,底层会换成`Eq(A,B)`
226
255
  4. `A^B`的含义与`convert_xor`参数有关,`convert_xor=True`底层会转换成`Pow(A,B)`,反之为`Xor(A,B)`。默认为`False`,用`**`表示乘方
@@ -230,6 +259,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
230
259
  8. 支持`~A`,底层会转换成`Not(A)`
231
260
  9. `gp_`开头的函数都会返回对应的`cs_`函数。如`gp_func(A,B,C)`会替换成`cs_func(B,C)`,其中`A`用在了`groupby([date, A])`
232
261
  10. 支持`A,B,C=MACD()`元组解包,在底层会替换成
262
+ 11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
233
263
 
234
264
  ```python
235
265
  _x_0 = MACD()
@@ -242,7 +272,8 @@ C = unpack(_x_0, 2)
242
272
 
243
273
  1. 输出的数据,所有以`_`开头的列,最后会被自动删除。所以需要保留的变量一定不要以`_`开头
244
274
  2. 为减少重复计算,自动添加了了中间变量,以`_x_`开头,如`_x_0`,`_x_1`等。最后会被自动删除
245
- 3. 单行表达式过长,或有重复计算,可以通过中间变量,将单行表达式改成多行。如果中间变量使用`_`开头,将会自动添加数字后缀,形成不同的变量,如`_A`会替换成`_A_0_`、`_A_1_`等。使用场景如下:
275
+ 3. 单行表达式过长,或有重复计算,可以通过中间变量,将单行表达式改成多行。如果中间变量使用`_`
276
+ 开头,将会自动添加数字后缀,形成不同的变量,如`_A`会替换成`_A_0_`、`_A_1_`等。使用场景如下:
246
277
  1. 同一变量名,重复使用。本质是不同的变量
247
278
  2. 循环赋值,但`DAG`不支持有环。`=`号左右的同名变量其实是不同变量
248
279
 
@@ -160,6 +160,33 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
160
160
  2. `over_null='order_by'`。分到一个区域,`null`排在前面
161
161
  3. `over_null=None`。不处理,直接调用,速度更快。如果确信不会中段产生`null`建议使用此参数
162
162
 
163
+ `codegen_exec(over_null='partition_by')`为全局使用`partition_by`。但遇到`ts_count_nulls`这类`null`
164
+ 函数就得使用`over_null=None`,所以本工具还新添了注释功能来指定单行表达式参数
165
+
166
+ 1. `# --over_null partition_by`。单行`over_null='partition_by'`
167
+ 2. `# --over_null=order_by`。单行`over_null='order_by'`
168
+ 3. `# --over_null`。单行`over_null=None`
169
+ 4. `# `。取`codegen_exec`参数传入的`over_null`值
170
+
171
+ 注意:
172
+
173
+ 1. `# --over_null`传参注释只能写在单行表达式的后面,不能独立成一行,否则会被忽略
174
+ 2. `# --over_null # --over_null=order_by`多个`#`时,只取第一个有效
175
+ 3. 只对最外层`ts`函数有效。如果`ts`函数不在外层,需要人工提炼。例如:
176
+ ```python
177
+ X1 = cs_rank(ts_mean(CLOSE, 3)) # --over_null=order_by # 应用在cs_rank上,没有意义
178
+ X2 = ts_rank(ts_mean(CLOSE, 3), 5) # --over_null=order_by # 本以为应用在ts_rank(ts_mean)上,但由于出现了公共ts_mean,其实是应用在ts_rank(_x_0)上
179
+ ```
180
+
181
+ 需写成
182
+
183
+ ```python
184
+ _x_0 = ts_mean(CLOSE, 3) # --over_null=order_by
185
+ X1 = cs_rank(_x_0)
186
+ X2 = ts_rank(_x_0, 5)
187
+ ```
188
+ 4. 由于很容易搞错,强烈建议生成`output_file`,检查生成的代码是否正确。
189
+
163
190
  ## `expr_codegen`局限性
164
191
 
165
192
  1. `DAG`只能增加列无法删除。增加列时,遇到同名列会覆盖
@@ -169,7 +196,8 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
169
196
 
170
197
  ## 特别语法
171
198
 
172
- 1. 支持`C?T:F`三元表达式(仅可字符串中使用),底层会先转成`C or True if( T )else F`,然后修正成`T if C else F`,最后转成`if_else(C,T,F)`。支持与`if else`混用
199
+ 1. 支持`C?T:F`三元表达式(仅可字符串中使用),底层会先转成`C or True if( T )else F`,然后修正成`T if C else F`
200
+ ,最后转成`if_else(C,T,F)`。支持与`if else`混用
173
201
  2. `(A<B)*-1`,底层将转换成`int_(A<B)*-1`
174
202
  3. 为防止`A==B`被`sympy`替换成`False`,底层会换成`Eq(A,B)`
175
203
  4. `A^B`的含义与`convert_xor`参数有关,`convert_xor=True`底层会转换成`Pow(A,B)`,反之为`Xor(A,B)`。默认为`False`,用`**`表示乘方
@@ -179,6 +207,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
179
207
  8. 支持`~A`,底层会转换成`Not(A)`
180
208
  9. `gp_`开头的函数都会返回对应的`cs_`函数。如`gp_func(A,B,C)`会替换成`cs_func(B,C)`,其中`A`用在了`groupby([date, A])`
181
209
  10. 支持`A,B,C=MACD()`元组解包,在底层会替换成
210
+ 11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
182
211
 
183
212
  ```python
184
213
  _x_0 = MACD()
@@ -191,7 +220,8 @@ C = unpack(_x_0, 2)
191
220
 
192
221
  1. 输出的数据,所有以`_`开头的列,最后会被自动删除。所以需要保留的变量一定不要以`_`开头
193
222
  2. 为减少重复计算,自动添加了了中间变量,以`_x_`开头,如`_x_0`,`_x_1`等。最后会被自动删除
194
- 3. 单行表达式过长,或有重复计算,可以通过中间变量,将单行表达式改成多行。如果中间变量使用`_`开头,将会自动添加数字后缀,形成不同的变量,如`_A`会替换成`_A_0_`、`_A_1_`等。使用场景如下:
223
+ 3. 单行表达式过长,或有重复计算,可以通过中间变量,将单行表达式改成多行。如果中间变量使用`_`
224
+ 开头,将会自动添加数字后缀,形成不同的变量,如`_A`会替换成`_A_0_`、`_A_1_`等。使用场景如下:
195
225
  1. 同一变量名,重复使用。本质是不同的变量
196
226
  2. 循环赋值,但`DAG`不支持有环。`=`号左右的同名变量其实是不同变量
197
227
 
@@ -0,0 +1 @@
1
+ __version__ = "0.11.0"
@@ -2,10 +2,11 @@ import ast
2
2
  import re
3
3
  from ast import expr
4
4
 
5
+ import ast_comments
5
6
  from black import Mode, format_str
6
7
  from sympy import Add, Mul, Pow, Eq, Not, Xor
7
8
 
8
- from expr_codegen.expr import register_symbols, dict_to_exprs
9
+ from expr_codegen.expr import register_symbols, list_to_exprs
9
10
 
10
11
 
11
12
  class SyntaxTransformer(ast.NodeTransformer):
@@ -108,7 +109,8 @@ class SyntaxTransformer(ast.NodeTransformer):
108
109
  def visit_Subscript(self, node):
109
110
  if isinstance(node.slice, ast.Constant) and node.slice.value == 0:
110
111
  node = node.value
111
- elif isinstance(node.slice, ast.UnaryOp) and isinstance(node.slice.operand, ast.Constant) and node.slice.operand.value == 0:
112
+ elif isinstance(node.slice, ast.UnaryOp) and isinstance(node.slice.operand,
113
+ ast.Constant) and node.slice.operand.value == 0:
112
114
  node = node.value
113
115
  else:
114
116
  node = ast.Call(
@@ -328,6 +330,21 @@ def assigns_to_dict(assigns):
328
330
  return {ast.unparse(a.targets): ast.unparse(a.value) for a in assigns}
329
331
 
330
332
 
333
+ def assigns_to_list(assigns):
334
+ """赋值表达式转成列表"""
335
+ outputs = []
336
+ for i, a in enumerate(assigns):
337
+ comment = "#"
338
+ if i + 1 < len(assigns):
339
+ b = assigns[i + 1]
340
+ if isinstance(b, ast_comments.Comment):
341
+ # comment = ast_comments.unparse(b)
342
+ comment = b.value
343
+ if isinstance(a, ast.Assign):
344
+ outputs.append((ast.unparse(a.targets), ast.unparse(a.value), comment))
345
+ return outputs
346
+
347
+
331
348
  def raw_to_code(raw):
332
349
  """导入语句转字符列表"""
333
350
  return '\n'.join([ast.unparse(a) for a in raw])
@@ -338,7 +355,7 @@ def sources_to_asts(*sources, convert_xor: bool):
338
355
 
339
356
  def _source_to_asts(source):
340
357
  """源代码"""
341
- tree = ast.parse(source_replace(source))
358
+ tree = ast_comments.parse(source_replace(source))
342
359
 
343
360
  if isinstance(tree.body[0], ast.FunctionDef):
344
361
  body = tree.body[0].body
@@ -347,7 +364,7 @@ def sources_to_asts(*sources, convert_xor: bool):
347
364
 
348
365
  return body
349
366
 
350
- tree = ast.parse("")
367
+ tree = ast_comments.parse("")
351
368
  for arg in sources:
352
369
  tree.body.extend(_source_to_asts(arg))
353
370
 
@@ -359,16 +376,21 @@ def sources_to_asts(*sources, convert_xor: bool):
359
376
  raw = []
360
377
  assigns = []
361
378
 
362
- for node in tree.body:
379
+ for i, node in enumerate(tree.body):
363
380
  # 特殊处理的节点
364
381
  if isinstance(node, ast.Assign):
365
382
  assigns.append(node)
366
383
  continue
384
+ if isinstance(node, ast_comments.Comment):
385
+ # 添加注释
386
+ if node.inline and isinstance(assigns[-1], ast.Assign):
387
+ assigns.append(node)
388
+ continue
367
389
  # TODO 是否要把其它语句也加入?是否有安全问题?
368
390
  if isinstance(node, (ast.Import, ast.ImportFrom)):
369
391
  raw.append(node)
370
392
  continue
371
- return raw_to_code(raw), assigns_to_dict(assigns), t.funcs_new, t.args_new, t.targets_new
393
+ return raw_to_code(raw), assigns_to_list(assigns), t.funcs_new, t.args_new, t.targets_new
372
394
 
373
395
 
374
396
  def _add_default_type(globals_):
@@ -394,5 +416,4 @@ def sources_to_exprs(globals_, *sources, convert_xor: bool):
394
416
  register_symbols(funcs_new, globals_, is_function=True)
395
417
  register_symbols(args_new, globals_, is_function=False)
396
418
  register_symbols(targets_new, globals_, is_function=False)
397
- exprs_dict = dict_to_exprs(assigns, globals_)
398
- return raw, exprs_dict
419
+ return raw, list_to_exprs(assigns, globals_)
@@ -46,9 +46,8 @@ def register_symbols(syms, globals_, is_function: bool):
46
46
  return globals_
47
47
 
48
48
 
49
- def dict_to_exprs(exprs_src, globals_):
50
- exprs_src = {k: sympify(v, globals_, evaluate=False) for k, v in exprs_src.items()}
51
- return exprs_src
49
+ def list_to_exprs(exprs_src, globals_):
50
+ return [(a, sympify(b, globals_, evaluate=False), c) for a, b, c in exprs_src]
52
51
 
53
52
 
54
53
  def append_node(node, output_exprs):
@@ -97,6 +96,12 @@ def get_symbols(expr, syms=None, return_str=True):
97
96
  syms.append(arg.name)
98
97
  else:
99
98
  syms.append(arg)
99
+ elif arg.is_Number:
100
+ # alpha_001 = log(1)+1
101
+ if return_str:
102
+ syms.append(str(arg))
103
+ else:
104
+ syms.append(arg)
100
105
  else:
101
106
  get_symbols(arg, syms, return_str)
102
107
  return syms
@@ -284,15 +289,15 @@ def get_key(children):
284
289
  def replace_exprs(exprs):
285
290
  """使用替换的方式简化表达式"""
286
291
  # Alpha101中大量ts_sum(x, 10)/10, 转成ts_mean(x, 10)
287
- exprs = {k: _replace__ts_sum__to__ts_mean(v) for k, v in exprs.items()}
292
+ exprs = [(a, _replace__ts_sum__to__ts_mean(b), c) for a, b, c in exprs]
288
293
  # alpha_031中大量cs_rank(cs_rank(x)) 转成cs_rank(x)
289
- exprs = {k: _replace__repeat(v) for k, v in exprs.items()}
294
+ exprs = [(a, _replace__repeat(b), c) for a, b, c in exprs]
290
295
  # 1.0*VWAP转VWAP
291
- exprs = {k: _replace__one_mul(v) for k, v in exprs.items()}
296
+ exprs = [(a, _replace__one_mul(b), c) for a, b, c in exprs]
292
297
  # 将部分参数为1的ts函数进行简化
293
- exprs = {k: _replace__ts_xxx_1(v) for k, v in exprs.items()}
298
+ exprs = [(a, _replace__ts_xxx_1(b), c) for a, b, c in exprs]
294
299
  # ts_delay转成ts_delta
295
- exprs = {k: _replace__ts_delay__to__ts_delta(v) for k, v in exprs.items()}
300
+ exprs = [(a, _replace__ts_delay__to__ts_delta(b), c) for a, b, c in exprs]
296
301
 
297
302
  return exprs
298
303
 
@@ -435,34 +440,34 @@ def _replace__ts_delay__to__ts_delta(e):
435
440
  e = e.xreplace({node: replacement})
436
441
  return e
437
442
 
438
-
439
- def is_meaningless(e):
440
- if _meaningless__ts_xxx_1(e):
441
- return True
442
- if _meaningless__xx_xx(e):
443
- return True
444
- return False
445
-
446
-
447
- def _meaningless__ts_xxx_1(e):
448
- """ts_xxx部分函数如果参数为1,可直接丢弃"""
449
- for node in preorder_traversal(e):
450
- if len(node.args) >= 2:
451
- if node.args[-1] == 1:
452
- node_name = get_node_name(node)
453
- if node_name in ('ts_delay', 'ts_delta', 'max_', 'min_'):
454
- return False
455
- else:
456
- # 其它算子,参数1都认为无意义
457
- return True
458
-
459
- return False
460
-
461
-
462
- def _meaningless__xx_xx(e):
463
- """部分函数如果两参数完全一样,可直接丢弃"""
464
- for node in preorder_traversal(e):
465
- if len(node.args) >= 2:
466
- if node.args[0] == node.args[1]:
467
- return True
468
- return False
443
+ # def is_meaningless(e):
444
+ # if _meaningless__ts_xxx_1(e):
445
+ # return True
446
+ # if _meaningless__xx_xx(e):
447
+ # return True
448
+ # return False
449
+ #
450
+ #
451
+ # def _meaningless__ts_xxx_1(e):
452
+ # """ts_xxx部分函数如果参数为1,可直接丢弃"""
453
+ # for node in preorder_traversal(e):
454
+ # if len(node.args) >= 2:
455
+ # node_name = get_node_name(node)
456
+ # if node_name in ('ts_delay', 'ts_delta'):
457
+ # if not node.args[1].is_Integer:
458
+ # return True
459
+ # if node_name.startswith('ts_'):
460
+ # if not node.args[-1].is_Number:
461
+ # return True
462
+ # if node.args[-1] <= 1:
463
+ # return True
464
+ # return False
465
+ #
466
+ #
467
+ # def _meaningless__xx_xx(e):
468
+ # """部分函数如果两参数完全一样,可直接丢弃"""
469
+ # for node in preorder_traversal(e):
470
+ # if len(node.args) >= 2:
471
+ # if node.args[0] == node.args[1]:
472
+ # return True
473
+ # return False
@@ -4,7 +4,7 @@ from itertools import product
4
4
  import networkx as nx
5
5
  from sympy import symbols
6
6
 
7
- from expr_codegen.dag import zero_indegree, hierarchy_pos, remove_paths_by_zero_outdegree
7
+ from expr_codegen.dag import zero_indegree, hierarchy_pos, remove_paths_by_zero_outdegree, zero_outdegree
8
8
  from expr_codegen.expr import CL, get_symbols, get_children, get_key, is_simple_expr
9
9
 
10
10
  _RESERVED_WORD_ = {'_NONE_', '_TRUE_', '_FALSE_'}
@@ -196,15 +196,15 @@ def create_dag_exprs(exprs):
196
196
  # 创建有向无环图
197
197
  G = nx.DiGraph()
198
198
 
199
- for symbol, expr in exprs.items():
199
+ for symbol, expr, comment in exprs:
200
200
  # if symbol.name == 'GP_0':
201
201
  # test = 1
202
202
  if expr.is_Symbol:
203
- G.add_node(symbol.name, symbol=symbol, expr=expr)
203
+ G.add_node(symbol.name, symbol=symbol, expr=expr, comment=comment)
204
204
  G.add_edge(expr.name, symbol.name)
205
205
  else:
206
206
  # 添加中间节点
207
- G.add_node(symbol.name, symbol=symbol, expr=expr)
207
+ G.add_node(symbol.name, symbol=symbol, expr=expr, comment=comment)
208
208
  syms = get_symbols(expr, return_str=True)
209
209
  for sym in syms:
210
210
  # 由于边的原因,这里会主动生成一些源节点
@@ -221,6 +221,10 @@ def create_dag_exprs(exprs):
221
221
  s = symbols(node)
222
222
  G.nodes[node]['symbol'] = s
223
223
  G.nodes[node]['expr'] = s
224
+ G.nodes[node]['comment'] = "#"
225
+ #
226
+ # for node in zero_outdegree(G):
227
+ # print(11, G.nodes[node]['comment'])
224
228
  return G
225
229
 
226
230
 
@@ -380,9 +384,9 @@ def skip_expr_node(G: nx.DiGraph, node, keep_nodes):
380
384
  return G
381
385
 
382
386
 
383
- def dag_start(exprs_dict, func, func_kwargs, date, asset):
387
+ def dag_start(exprs_list, func, func_kwargs, date, asset):
384
388
  """初始生成DAG"""
385
- G = create_dag_exprs(exprs_dict)
389
+ G = create_dag_exprs(exprs_list)
386
390
  G = init_dag_exprs(G, func, func_kwargs, date, asset)
387
391
 
388
392
  # 分层输出
@@ -413,11 +417,12 @@ def dag_end(G):
413
417
  for node in generation:
414
418
  key = G.nodes[node]['key']
415
419
  expr = G.nodes[node]['expr']
420
+ comment = G.nodes[node]['comment']
416
421
  symbols = G.nodes[node]['symbols']
417
422
  # 这几个特殊的不算成字段名
418
423
  symbols = list(set(symbols) - _RESERVED_WORD_)
419
424
 
420
- exprs_ldl.append(key, (node, expr, symbols))
425
+ exprs_ldl.append(key, (node, expr, symbols, comment))
421
426
 
422
427
  exprs_ldl._list = exprs_ldl.values()[1:]
423
428
 
@@ -67,9 +67,9 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
67
67
  func_code.append(f" # " + '=' * 40)
68
68
  exprs_dst.append(f"#" + '=' * 40 + func_name)
69
69
  else:
70
- va, ex, sym = kv
70
+ va, ex, sym, comment = kv
71
71
  func_code.append(f" # {va} = {ex}\n g[{va}] = {p.doprint(ex)}")
72
- exprs_dst.append(f"{va} = {ex}")
72
+ exprs_dst.append(f"{va} = {ex} {comment}")
73
73
  if va not in syms_dst:
74
74
  syms_out.append(va)
75
75
 
@@ -42,8 +42,8 @@ def {{ key }}(df: pd.DataFrame) -> pd.DataFrame:
42
42
  """
43
43
 
44
44
  """
45
- {%-for key, value in exprs_src.items() %}
46
- {{ key }} = {{ value-}}
45
+ {%-for a,b,c in exprs_src %}
46
+ {{ a }} = {{ b}} {{c-}}
47
47
  {% endfor %}
48
48
  """
49
49
 
@@ -70,7 +70,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
70
70
  func_code.append(f" df = df.with_columns(")
71
71
  exprs_dst.append(f"#" + '=' * 40 + func_name)
72
72
  else:
73
- va, ex, sym = kv
73
+ va, ex, sym, comment = kv
74
74
  s1 = str(ex)
75
75
  s2 = p.doprint(ex)
76
76
  if s1 != s2:
@@ -78,7 +78,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
78
78
  func_code.append(f"# {va} = {s1}")
79
79
 
80
80
  func_code.append(f"{va}={s2},")
81
- exprs_dst.append(f"{va} = {s1}")
81
+ exprs_dst.append(f"{va} = {s1} {comment}")
82
82
  if va not in syms_dst:
83
83
  syms_out.append(va)
84
84
  func_code.append(f" )")
@@ -51,8 +51,8 @@ def {{ key }}(df: DataFrame) -> DataFrame:
51
51
  """
52
52
 
53
53
  """
54
- {%-for key, value in exprs_src.items() %}
55
- {{ key }} = {{ value-}}
54
+ {%-for a,b,c in exprs_src %}
55
+ {{ a }} = {{ b}} {{c-}}
56
56
  {% endfor %}
57
57
  """
58
58
 
@@ -1,3 +1,4 @@
1
+ import argparse
1
2
  import os
2
3
  from typing import Sequence, Dict, Literal
3
4
 
@@ -43,6 +44,9 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
43
44
  over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
44
45
  **kwargs):
45
46
  """基于模板的代码生成"""
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("--over_null", type=str, nargs="?", default=over_null)
49
+
46
50
  # 打印Polars风格代码
47
51
  p = PolarsStrPrinter()
48
52
 
@@ -71,7 +75,9 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
71
75
  func_code.append(f" df = df.with_columns(")
72
76
  exprs_dst.append(f"#" + '=' * 40 + func_name)
73
77
  else:
74
- va, ex, sym = kv
78
+ va, ex, sym, comment = kv
79
+ # 多个#时,只取第一个#后的参数
80
+ args, argv = parser.parse_known_args(args=comment.split("#")[1].split(" "))
75
81
  s1 = str(ex)
76
82
  s2 = p.doprint(ex)
77
83
  if s1 != s2:
@@ -84,9 +90,9 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
84
90
  _sym = f"pl.all_horizontal({','.join(_sym)})"
85
91
  else:
86
92
  _sym = ','.join(_sym)
87
- if over_null == 'partition_by':
93
+ if args.over_null == 'partition_by':
88
94
  func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),")
89
- elif over_null == 'order_by':
95
+ elif args.over_null == 'order_by':
90
96
  func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),")
91
97
  else:
92
98
  func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),")
@@ -96,7 +102,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
96
102
  func_code.append(f"{va}=({s2}).over(_DATE_, '{k[2]}'),")
97
103
  else:
98
104
  func_code.append(f"{va}={s2},")
99
- exprs_dst.append(f"{va} = {s1}")
105
+ exprs_dst.append(f"{va} = {s1} {comment}")
100
106
  if va not in syms_dst:
101
107
  syms_out.append(va)
102
108
  func_code.append(f" )")
@@ -51,8 +51,8 @@ def {{ key }}(df: DataFrame) -> DataFrame:
51
51
  """
52
52
 
53
53
  """
54
- {%-for key, value in exprs_src.items() %}
55
- {{ key }} = {{ value-}}
54
+ {%-for a,b,c in exprs_src %}
55
+ {{ a }} = {{ b}} {{c-}}
56
56
  {% endfor %}
57
57
  """
58
58
 
@@ -61,7 +61,7 @@ class ExprTool:
61
61
  def __init__(self):
62
62
  self.get_current_func = get_current_by_prefix
63
63
  self.get_current_func_kwargs = {}
64
- self.exprs_dict = {}
64
+ self.exprs_list = {}
65
65
  self.exprs_names = []
66
66
  self.globals_ = {}
67
67
 
@@ -92,7 +92,7 @@ class ExprTool:
92
92
  # print(exprs)
93
93
  return exprs, syms
94
94
 
95
- def merge(self, date, asset, **kwargs):
95
+ def merge(self, date, asset, args):
96
96
  """合并多个表达式
97
97
 
98
98
  1. 先抽取分割子公式
@@ -100,28 +100,31 @@ class ExprTool:
100
100
 
101
101
  Parameters
102
102
  ----------
103
- kwargs
104
- 表达式字典
103
+ args
104
+ 表达式列表
105
105
 
106
106
  Returns
107
107
  -------
108
108
  表达式列表
109
109
  """
110
110
  # 抽取前先化简
111
- kwargs = {k: simplify2(v) for k, v in kwargs.items()}
111
+ args = [(a, simplify2(b), c) for a, b, c in args]
112
112
 
113
- exprs_syms = [self.extract(v, date, asset) for v in kwargs.values()]
113
+ # 保留了注释信息
114
+ exprs_syms = [(self.extract(b, date, asset), c) for a, b, c in args]
114
115
  exprs = []
115
116
  syms = []
116
- for e, s in exprs_syms:
117
- exprs.extend(e)
117
+ for (e, s), c in exprs_syms:
118
118
  syms.extend(s)
119
+ for _ in e:
120
+ # 抽取的表达式添加注释
121
+ exprs.append((_, c))
119
122
 
120
123
  syms = sorted(set(syms), key=syms.index)
121
124
  # 如果目标有重复表达式,这里会混乱
122
125
  exprs = sorted(set(exprs), key=exprs.index)
123
126
  # 这里不能合并简化与未简化的表达式,会导致cse时失败,需要简化表达式合并
124
- exprs = exprs + list(kwargs.values())
127
+ exprs = exprs + [(b, c) for a, b, c in args]
125
128
 
126
129
  # print(exprs)
127
130
  syms = [str(s) for s in syms]
@@ -130,18 +133,18 @@ class ExprTool:
130
133
  def reduce(self, repl, redu):
131
134
  """减少中间变量数量,有利用减少内存占用"""
132
135
 
133
- exprs_dict = {}
136
+ exprs_list = []
134
137
 
135
138
  # cse前简化一次,cse后不再简化
136
139
  # (~开盘涨停 & 昨收涨停) | (~收盘涨停 & 最高涨停)
137
- for variable, expr in repl:
138
- exprs_dict[variable] = expr
139
- for variable, expr in redu:
140
- exprs_dict[variable] = expr
140
+ for a, b in repl:
141
+ exprs_list.append((a, b, "#"))
142
+ for a, b, c in redu:
143
+ exprs_list.append((a, b, c))
141
144
 
142
- return exprs_dict
145
+ return exprs_list
143
146
 
144
- def cse(self, exprs, symbols_repl=None, symbols_redu=None):
147
+ def cse(self, exprs, symbols_repl=None, exprs_src=None):
145
148
  """多个子公式+长公式,提取公共公式
146
149
 
147
150
  Parameters
@@ -150,7 +153,7 @@ class ExprTool:
150
153
  表达式列表
151
154
  symbols_repl
152
155
  中间字段名迭代器
153
- symbols_redu
156
+ exprs_src
154
157
  最终字段名列表
155
158
 
156
159
  Returns
@@ -163,34 +166,38 @@ class ExprTool:
163
166
  表达式
164
167
 
165
168
  """
166
- self.exprs_names = list(symbols_redu)
169
+ self.exprs_names = [a for a, b, c in exprs_src]
170
+ # 包含了注释信息
171
+ _exprs = [a for a, b in exprs]
167
172
 
168
- repl, redu = cse(exprs, symbols_repl, optimizations="basic")
169
- outputs_len = len(symbols_redu)
173
+ # 注意:对于表达式右边相同,左边不同的情况,会当成一个处理
174
+ repl, redu = cse(_exprs, symbols_repl, optimizations="basic")
175
+ outputs_len = len(exprs_src)
170
176
 
171
177
  new_redu = []
172
- symbols_redu = iter(symbols_redu)
178
+ symbols_redu = iter(exprs_src)
173
179
  for expr in redu[-outputs_len:]:
174
180
  # 可能部分表达式只在之前出现过,后面完全用不到如,ts_rank(ts_decay_linear(x_147, 11.4157), 6.72611)
175
181
  variable = next(symbols_redu)
176
- variable = symbols(variable)
177
- new_redu.append((variable, expr))
182
+ a = symbols(variable[0])
183
+ new_redu.append((a, expr, variable[2]))
178
184
 
179
- self.exprs_dict = self.reduce(repl, new_redu)
185
+ self.exprs_list = self.reduce(repl, new_redu)
180
186
 
181
187
  # with open("exprs.pickle", "wb") as file:
182
188
  # pickle.dump(exprs_dict, file)
183
189
 
184
- return self.exprs_dict
190
+ return self.exprs_list
185
191
 
186
192
  def dag(self, merge: bool, date, asset):
187
193
  """生成DAG"""
188
- G = dag_start(self.exprs_dict, self.get_current_func, self.get_current_func_kwargs, date, asset)
194
+ G = dag_start(self.exprs_list, self.get_current_func, self.get_current_func_kwargs, date, asset)
189
195
  if merge:
190
196
  G = dag_middle(G, self.exprs_names, self.get_current_func, self.get_current_func_kwargs, date, asset)
191
197
  return dag_end(G)
192
198
 
193
- def all(self, exprs_src, style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over', template_file: str = 'template.py.j2',
199
+ def all(self, exprs_src, style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over',
200
+ template_file: str = 'template.py.j2',
194
201
  replace: bool = True, regroup: bool = False, format: bool = True,
195
202
  date='date', asset='asset',
196
203
  alias: Dict[str, str] = {},
@@ -200,8 +207,8 @@ class ExprTool:
200
207
 
201
208
  Parameters
202
209
  ----------
203
- exprs_src: dict
204
- 表达式字典
210
+ exprs_src: list
211
+ 表达式列表
205
212
  style: str
206
213
  代码风格。可选值 ('polars_group', 'polars_over', 'pandas')
207
214
  template_file: str
@@ -232,11 +239,11 @@ class ExprTool:
232
239
  exprs_src = replace_exprs(exprs_src)
233
240
 
234
241
  # 子表达式在前,原表式在最后
235
- exprs_dst, syms_dst = self.merge(date, asset, **exprs_src)
242
+ exprs_dst, syms_dst = self.merge(date, asset, exprs_src)
236
243
  syms_dst = list(set(syms_dst) - _RESERVED_WORD_)
237
244
 
238
245
  # 提取公共表达式
239
- self.cse(exprs_dst, symbols_repl=numbered_symbols('_x_'), symbols_redu=exprs_src.keys())
246
+ self.cse(exprs_dst, symbols_repl=numbered_symbols('_x_'), exprs_src=exprs_src)
240
247
  # 有向无环图流转
241
248
  exprs_ldl, G = self.dag(True, date, asset)
242
249
 
@@ -272,14 +279,15 @@ class ExprTool:
272
279
  extra_codes: str,
273
280
  output_file: str,
274
281
  convert_xor: bool,
275
- style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over', template_file: str = 'template.py.j2',
282
+ style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over',
283
+ template_file: str = 'template.py.j2',
276
284
  date: str = 'date', asset: str = 'asset',
277
285
  **kwargs) -> str:
278
286
  """通过字符串生成代码, 加了缓存,多次调用不重复生成"""
279
- raw, exprs_dict = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor)
287
+ raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor)
280
288
 
281
289
  # 生成代码
282
- code, G = _TOOL_.all(exprs_dict, style=style, template_file=template_file,
290
+ code, G = _TOOL_.all(exprs_list, style=style, template_file=template_file,
283
291
  replace=True, regroup=True, format=True,
284
292
  date=date, asset=asset,
285
293
  # 复制了需要使用的函数,还复制了最原始的表达式
@@ -333,6 +341,7 @@ _TOOL_ = ExprTool()
333
341
 
334
342
  def codegen_exec(df: Optional[DataFrame],
335
343
  *codes,
344
+ over_null: Literal['partition_by', 'order_by', None],
336
345
  extra_codes: str = r'CS_SW_L1 = r"^sw_l1_\d+$"',
337
346
  output_file: Union[str, TextIOBase, None] = None,
338
347
  run_file: Union[bool, str] = False,
@@ -340,7 +349,7 @@ def codegen_exec(df: Optional[DataFrame],
340
349
  style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over',
341
350
  template_file: str = 'template.py.j2',
342
351
  date: str = 'date', asset: str = 'asset',
343
- over_null: Literal['partition_by', 'order_by', None] = 'partition_by',
352
+
344
353
  **kwargs) -> Optional[DataFrame]:
345
354
  """快速转换源代码并执行
346
355
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: expr_codegen
3
- Version: 0.10.15
3
+ Version: 0.11.0
4
4
  Summary: symbol expression to polars expression tool
5
5
  Author-email: wukan <wu-kan@163.com>
6
6
  License: BSD 3-Clause License
@@ -43,6 +43,7 @@ Requires-Dist: Jinja2
43
43
  Requires-Dist: networkx
44
44
  Requires-Dist: loguru
45
45
  Requires-Dist: sympy
46
+ Requires-Dist: ast-comments
46
47
  Provides-Extra: streamlit
47
48
  Requires-Dist: streamlit; extra == "streamlit"
48
49
  Requires-Dist: streamlit-ace; extra == "streamlit"
@@ -211,6 +212,33 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
211
212
  2. `over_null='order_by'`。分到一个区域,`null`排在前面
212
213
  3. `over_null=None`。不处理,直接调用,速度更快。如果确信不会中段产生`null`建议使用此参数
213
214
 
215
+ `codegen_exec(over_null='partition_by')`为全局使用`partition_by`。但遇到`ts_count_nulls`这类`null`
216
+ 函数就得使用`over_null=None`,所以本工具还新添了注释功能来指定单行表达式参数
217
+
218
+ 1. `# --over_null partition_by`。单行`over_null='partition_by'`
219
+ 2. `# --over_null=order_by`。单行`over_null='order_by'`
220
+ 3. `# --over_null`。单行`over_null=None`
221
+ 4. `# `。取`codegen_exec`参数传入的`over_null`值
222
+
223
+ 注意:
224
+
225
+ 1. `# --over_null`传参注释只能写在单行表达式的后面,不能独立成一行,否则会被忽略
226
+ 2. `# --over_null # --over_null=order_by`多个`#`时,只取第一个有效
227
+ 3. 只对最外层`ts`函数有效。如果`ts`函数不在外层,需要人工提炼。例如:
228
+ ```python
229
+ X1 = cs_rank(ts_mean(CLOSE, 3)) # --over_null=order_by # 应用在cs_rank上,没有意义
230
+ X2 = ts_rank(ts_mean(CLOSE, 3), 5) # --over_null=order_by # 本以为应用在ts_rank(ts_mean)上,但由于出现了公共ts_mean,其实是应用在ts_rank(_x_0)上
231
+ ```
232
+
233
+ 需写成
234
+
235
+ ```python
236
+ _x_0 = ts_mean(CLOSE, 3) # --over_null=order_by
237
+ X1 = cs_rank(_x_0)
238
+ X2 = ts_rank(_x_0, 5)
239
+ ```
240
+ 4. 由于很容易搞错,强烈建议生成`output_file`,检查生成的代码是否正确。
241
+
214
242
  ## `expr_codegen`局限性
215
243
 
216
244
  1. `DAG`只能增加列无法删除。增加列时,遇到同名列会覆盖
@@ -220,7 +248,8 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
220
248
 
221
249
  ## 特别语法
222
250
 
223
- 1. 支持`C?T:F`三元表达式(仅可字符串中使用),底层会先转成`C or True if( T )else F`,然后修正成`T if C else F`,最后转成`if_else(C,T,F)`。支持与`if else`混用
251
+ 1. 支持`C?T:F`三元表达式(仅可字符串中使用),底层会先转成`C or True if( T )else F`,然后修正成`T if C else F`
252
+ ,最后转成`if_else(C,T,F)`。支持与`if else`混用
224
253
  2. `(A<B)*-1`,底层将转换成`int_(A<B)*-1`
225
254
  3. 为防止`A==B`被`sympy`替换成`False`,底层会换成`Eq(A,B)`
226
255
  4. `A^B`的含义与`convert_xor`参数有关,`convert_xor=True`底层会转换成`Pow(A,B)`,反之为`Xor(A,B)`。默认为`False`,用`**`表示乘方
@@ -230,6 +259,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
230
259
  8. 支持`~A`,底层会转换成`Not(A)`
231
260
  9. `gp_`开头的函数都会返回对应的`cs_`函数。如`gp_func(A,B,C)`会替换成`cs_func(B,C)`,其中`A`用在了`groupby([date, A])`
232
261
  10. 支持`A,B,C=MACD()`元组解包,在底层会替换成
262
+ 11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
233
263
 
234
264
  ```python
235
265
  _x_0 = MACD()
@@ -242,7 +272,8 @@ C = unpack(_x_0, 2)
242
272
 
243
273
  1. 输出的数据,所有以`_`开头的列,最后会被自动删除。所以需要保留的变量一定不要以`_`开头
244
274
  2. 为减少重复计算,自动添加了了中间变量,以`_x_`开头,如`_x_0`,`_x_1`等。最后会被自动删除
245
- 3. 单行表达式过长,或有重复计算,可以通过中间变量,将单行表达式改成多行。如果中间变量使用`_`开头,将会自动添加数字后缀,形成不同的变量,如`_A`会替换成`_A_0_`、`_A_1_`等。使用场景如下:
275
+ 3. 单行表达式过长,或有重复计算,可以通过中间变量,将单行表达式改成多行。如果中间变量使用`_`
276
+ 开头,将会自动添加数字后缀,形成不同的变量,如`_A`会替换成`_A_0_`、`_A_1_`等。使用场景如下:
246
277
  1. 同一变量名,重复使用。本质是不同的变量
247
278
  2. 循环赋值,但`DAG`不支持有环。`=`号左右的同名变量其实是不同变量
248
279
 
@@ -3,6 +3,7 @@ Jinja2
3
3
  networkx
4
4
  loguru
5
5
  sympy
6
+ ast-comments
6
7
 
7
8
  [streamlit]
8
9
  streamlit
@@ -22,6 +22,7 @@ dependencies = [
22
22
  'networkx',
23
23
  'loguru',
24
24
  'sympy',
25
+ 'ast-comments',
25
26
  # 'polars_ta',
26
27
  ]
27
28
  dynamic = ["version"]
@@ -1 +0,0 @@
1
- __version__ = "0.10.15"
File without changes
File without changes