expr-codegen 0.15.2__tar.gz → 0.16.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/PKG-INFO +2 -1
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/README.md +1 -0
- expr_codegen-0.16.1/expr_codegen/_version.py +1 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/codes.py +63 -3
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/code.py +9 -4
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/code.py +9 -4
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/template.py.j2 +2 -3
- expr_codegen-0.16.1/expr_codegen/rust/code.py +150 -0
- expr_codegen-0.16.1/expr_codegen/rust/printer.py +115 -0
- expr_codegen-0.16.1/expr_codegen/rust/template.rs.j2 +125 -0
- expr_codegen-0.16.1/expr_codegen/sql/__init__.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/tool.py +17 -7
- expr_codegen-0.15.2/expr_codegen/_version.py +0 -1
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/.gitignore +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/LICENSE +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/__init__.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/dag.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/expr.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/latex/__init__.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/latex/printer.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/model.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/__init__.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/helper.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/printer.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/ta.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/template.py.j2 +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/__init__.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/printer.py +0 -0
- {expr_codegen-0.15.2/expr_codegen/sql → expr_codegen-0.16.1/expr_codegen/rust}/__init__.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/sql/code.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/sql/printer.py +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/sql/template.sql.j2 +0 -0
- {expr_codegen-0.15.2 → expr_codegen-0.16.1}/pyproject.toml +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: expr_codegen
|
|
3
|
-
Version: 0.15.2
|
|
3
|
+
Version: 0.16.1
|
|
4
4
|
Summary: symbol expression to polars expression tool
|
|
5
5
|
Author-email: wukan <wu-kan@163.com>
|
|
6
6
|
License: BSD 3-Clause License
|
|
@@ -276,6 +276,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
|
|
|
276
276
|
```
|
|
277
277
|
11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
|
|
278
278
|
12. 代码块中,对`import`、`def`、`class`三种语法,自动原样插入到生成的代码中
|
|
279
|
+
13. 由于`sympy`的限制不支持关键字参数,但如果`from polars_ta.prefix.wq import *`,然后`codegen_exec(function_mapping=globals())`,底层会试着将关键字参数转成位置参数使用
|
|
279
280
|
|
|
280
281
|
## 下划线开头的变量
|
|
281
282
|
|
|
@@ -226,6 +226,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
|
|
|
226
226
|
```
|
|
227
227
|
11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
|
|
228
228
|
12. 代码块中,对`import`、`def`、`class`三种语法,自动原样插入到生成的代码中
|
|
229
|
+
13. 由于`sympy`的限制不支持关键字参数,但如果`from polars_ta.prefix.wq import *`,然后`codegen_exec(function_mapping=globals())`,底层会试着将关键字参数转成位置参数使用
|
|
229
230
|
|
|
230
231
|
## 下划线开头的变量
|
|
231
232
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.16.1"
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import ast
|
|
2
|
+
import inspect
|
|
2
3
|
import re
|
|
3
4
|
from ast import expr
|
|
4
5
|
|
|
@@ -333,6 +334,63 @@ class RenameTransformer(ast.NodeTransformer):
|
|
|
333
334
|
return node
|
|
334
335
|
|
|
335
336
|
|
|
337
|
+
class KeywordToPositionalTransformer(ast.NodeTransformer):
|
|
338
|
+
def __init__(self, function_mapping):
|
|
339
|
+
self.function_mapping = function_mapping # 函数名到实际函数的映射
|
|
340
|
+
|
|
341
|
+
def visit_Call(self, node):
|
|
342
|
+
if isinstance(node.func, ast.Name) and node.keywords:
|
|
343
|
+
func_name = node.func.id
|
|
344
|
+
if func_name in self.function_mapping:
|
|
345
|
+
return self.transform_call(node, self.function_mapping[func_name])
|
|
346
|
+
return node
|
|
347
|
+
|
|
348
|
+
def transform_call(self, node, target_func):
|
|
349
|
+
try:
|
|
350
|
+
# 获取函数参数签名
|
|
351
|
+
sig = inspect.signature(target_func)
|
|
352
|
+
param_names = list(sig.parameters.keys())
|
|
353
|
+
|
|
354
|
+
# 构建参数映射
|
|
355
|
+
arg_mapping = {}
|
|
356
|
+
# 处理现有位置参数
|
|
357
|
+
for i, arg in enumerate(node.args):
|
|
358
|
+
if i < len(param_names):
|
|
359
|
+
arg_mapping[param_names[i]] = arg
|
|
360
|
+
|
|
361
|
+
# 处理关键字参数
|
|
362
|
+
for keyword in node.keywords:
|
|
363
|
+
if keyword.arg in param_names:
|
|
364
|
+
arg_mapping[keyword.arg] = keyword.value
|
|
365
|
+
|
|
366
|
+
# 按参数顺序构建新的位置参数列表
|
|
367
|
+
new_args = []
|
|
368
|
+
for param_name in param_names:
|
|
369
|
+
if param_name in arg_mapping:
|
|
370
|
+
new_args.append(arg_mapping[param_name])
|
|
371
|
+
else:
|
|
372
|
+
# 对于没有提供的参数,需要处理默认值
|
|
373
|
+
param = sig.parameters[param_name]
|
|
374
|
+
if param.default != inspect.Parameter.empty:
|
|
375
|
+
# 使用默认值
|
|
376
|
+
new_args.append(ast.Constant(value=param.default))
|
|
377
|
+
else:
|
|
378
|
+
# 必需参数缺失,保持原样或报错
|
|
379
|
+
return node
|
|
380
|
+
|
|
381
|
+
# 创建新的调用节点
|
|
382
|
+
new_node = ast.Call(
|
|
383
|
+
func=node.func,
|
|
384
|
+
args=new_args,
|
|
385
|
+
keywords=[]
|
|
386
|
+
)
|
|
387
|
+
return new_node
|
|
388
|
+
|
|
389
|
+
except Exception as e:
|
|
390
|
+
# 转换失败时返回原节点
|
|
391
|
+
return node
|
|
392
|
+
|
|
393
|
+
|
|
336
394
|
def source_replace(source: str) -> str:
|
|
337
395
|
# 三元表达式转换成 错误版if( )else,一定得在Transformer中修正
|
|
338
396
|
num = 1
|
|
@@ -374,7 +432,7 @@ def raw_to_code(raw):
|
|
|
374
432
|
return '\n'.join([ast_comments.unparse(a) for a in raw])
|
|
375
433
|
|
|
376
434
|
|
|
377
|
-
def sources_to_asts(*sources, convert_xor: bool):
|
|
435
|
+
def sources_to_asts(*sources, convert_xor: bool, function_mapping):
|
|
378
436
|
"""输入多份源代码"""
|
|
379
437
|
|
|
380
438
|
def _source_to_asts(source):
|
|
@@ -394,6 +452,8 @@ def sources_to_asts(*sources, convert_xor: bool):
|
|
|
394
452
|
|
|
395
453
|
t1 = SyntaxTransformer(convert_xor)
|
|
396
454
|
t1.visit(tree)
|
|
455
|
+
t2 = KeywordToPositionalTransformer(function_mapping)
|
|
456
|
+
t2.visit(tree)
|
|
397
457
|
t = RenameTransformer({}, {})
|
|
398
458
|
t.visit(tree)
|
|
399
459
|
|
|
@@ -429,12 +489,12 @@ def _add_default_type(globals_):
|
|
|
429
489
|
return globals_
|
|
430
490
|
|
|
431
491
|
|
|
432
|
-
def sources_to_exprs(globals_, *sources, convert_xor: bool):
|
|
492
|
+
def sources_to_exprs(globals_, *sources, convert_xor: bool, function_mapping):
|
|
433
493
|
"""将源代码转换成表达式"""
|
|
434
494
|
|
|
435
495
|
globals_ = _add_default_type(globals_)
|
|
436
496
|
|
|
437
|
-
raw, assigns, funcs_new, args_new, targets_new = sources_to_asts(*sources, convert_xor=convert_xor)
|
|
497
|
+
raw, assigns, funcs_new, args_new, targets_new = sources_to_asts(*sources, convert_xor=convert_xor, function_mapping=function_mapping)
|
|
438
498
|
# 支持OPEN[1]转ts_delay(OPEN,1)
|
|
439
499
|
funcs_new.add('ts_delay')
|
|
440
500
|
|
|
@@ -13,18 +13,23 @@ def get_groupby_from_tuple(tup, func_name, drop_cols):
|
|
|
13
13
|
"""从传入的元组中生成分组运行代码"""
|
|
14
14
|
prefix2, *_ = tup
|
|
15
15
|
|
|
16
|
+
if len(drop_cols)>0:
|
|
17
|
+
drop_str = f'.drop(columns={drop_cols})'
|
|
18
|
+
else:
|
|
19
|
+
drop_str = ""
|
|
20
|
+
|
|
16
21
|
if prefix2 == TS:
|
|
17
22
|
# 组内需要按时间进行排序,需要维持顺序
|
|
18
23
|
prefix2, asset = tup
|
|
19
|
-
return f'df = df.groupby(by=[_ASSET_], group_keys=False).apply({func_name})
|
|
24
|
+
return f'df = df.groupby(by=[_ASSET_], group_keys=False).apply({func_name}){drop_str}'
|
|
20
25
|
if prefix2 == CS:
|
|
21
26
|
prefix2, date = tup
|
|
22
|
-
return f'df = df.groupby(by=[_DATE_], group_keys=False).apply({func_name})
|
|
27
|
+
return f'df = df.groupby(by=[_DATE_], group_keys=False).apply({func_name}){drop_str}'
|
|
23
28
|
if prefix2 == GP:
|
|
24
29
|
prefix2, date, group = tup
|
|
25
|
-
return f'df = df.groupby(by=[_DATE_, "{group}"], group_keys=False).apply({func_name})
|
|
30
|
+
return f'df = df.groupby(by=[_DATE_, "{group}"], group_keys=False).apply({func_name}){drop_str}'
|
|
26
31
|
|
|
27
|
-
return f'df = {func_name}(df)
|
|
32
|
+
return f'df = {func_name}(df){drop_str}'
|
|
28
33
|
|
|
29
34
|
|
|
30
35
|
def symbols_to_code(syms):
|
|
@@ -14,18 +14,23 @@ def get_groupby_from_tuple(tup, func_name, drop_cols):
|
|
|
14
14
|
"""从传入的元组中生成分组运行代码"""
|
|
15
15
|
prefix2, *_ = tup
|
|
16
16
|
|
|
17
|
+
if len(drop_cols)>0:
|
|
18
|
+
drop_str = f'.drop(*{drop_cols})'
|
|
19
|
+
else:
|
|
20
|
+
drop_str = ""
|
|
21
|
+
|
|
17
22
|
if prefix2 == TS:
|
|
18
23
|
# 组内需要按时间进行排序,需要维持顺序
|
|
19
24
|
prefix2, asset = tup
|
|
20
|
-
return f'df = {func_name}(df.sort(_ASSET_, _DATE_))
|
|
25
|
+
return f'df = {func_name}(df.sort(_ASSET_, _DATE_)){drop_str}'
|
|
21
26
|
if prefix2 == CS:
|
|
22
27
|
prefix2, date = tup
|
|
23
|
-
return f'df = {func_name}(df.sort(_DATE_))
|
|
28
|
+
return f'df = {func_name}(df.sort(_DATE_)){drop_str}'
|
|
24
29
|
if prefix2 == GP:
|
|
25
30
|
prefix2, date, group = tup
|
|
26
|
-
return f'df = {func_name}(df.sort(_DATE_, "{group}"))
|
|
31
|
+
return f'df = {func_name}(df.sort(_DATE_, "{group}")){drop_str}'
|
|
27
32
|
|
|
28
|
-
return f'df = {func_name}(df)
|
|
33
|
+
return f'df = {func_name}(df){drop_str}'
|
|
29
34
|
|
|
30
35
|
|
|
31
36
|
def symbols_to_code(syms):
|
|
@@ -63,12 +63,11 @@ def {{ key }}(df: DataFrame) -> DataFrame:
|
|
|
63
63
|
|
|
64
64
|
|
|
65
65
|
def _filter_last(df: DataFrame, ge_date_idx: int) -> DataFrame:
|
|
66
|
-
"""过滤数据,只取最后几天。实盘时可用于减少计算量
|
|
67
|
-
"""
|
|
66
|
+
"""过滤数据,只取最后几天。实盘时可用于减少计算量"""
|
|
68
67
|
if ge_date_idx == 0:
|
|
69
68
|
return df
|
|
70
69
|
else:
|
|
71
|
-
return df.filter(pl.col(_DATE_) >=
|
|
70
|
+
return df.filter(pl.col(_DATE_) >= pl.col(_DATE_).unique().sort().slice(ge_date_idx, 1).first())
|
|
72
71
|
|
|
73
72
|
|
|
74
73
|
def main(df: DataFrame, ge_date_idx: int) -> DataFrame:
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from typing import Sequence, Literal
|
|
5
|
+
|
|
6
|
+
import jinja2
|
|
7
|
+
from jinja2 import FileSystemLoader, TemplateNotFound
|
|
8
|
+
|
|
9
|
+
from expr_codegen.expr import TS, CS, GP
|
|
10
|
+
from expr_codegen.model import ListDictList
|
|
11
|
+
from expr_codegen.rust.printer import RustStrPrinter
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_groupby_from_tuple(tup, func_name, drop_cols):
|
|
15
|
+
"""从传入的元组中生成分组运行代码"""
|
|
16
|
+
prefix2, *_ = tup
|
|
17
|
+
if len(drop_cols)>0:
|
|
18
|
+
drop_cols = [f'"{c}".into()' for c in drop_cols]
|
|
19
|
+
drop_str = f'.drop(Selector::ByName {{ names: Arc::new([{','.join(drop_cols)}]), strict: true }})'
|
|
20
|
+
else:
|
|
21
|
+
drop_str = ""
|
|
22
|
+
|
|
23
|
+
if prefix2 == TS:
|
|
24
|
+
# 组内需要按时间进行排序,需要维持顺序
|
|
25
|
+
prefix2, asset = tup
|
|
26
|
+
return f'df = {func_name}(df.sort([_ASSET_, _DATE_], SortMultipleOptions::default())){drop_str};'
|
|
27
|
+
if prefix2 == CS:
|
|
28
|
+
prefix2, date = tup
|
|
29
|
+
return f'df = {func_name}(df.sort([_DATE_], SortMultipleOptions::default())){drop_str};'
|
|
30
|
+
if prefix2 == GP:
|
|
31
|
+
prefix2, date, group = tup
|
|
32
|
+
return f'df = {func_name}(df.sort([_DATE_, "{group}"], SortMultipleOptions::default())){drop_str};'
|
|
33
|
+
|
|
34
|
+
return f'df = {func_name}(df){drop_str};'
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# def symbols_to_code(syms):
|
|
38
|
+
# a = [f"{s}" for s in syms]
|
|
39
|
+
# b = [f"'{s}'" for s in syms]
|
|
40
|
+
# return f"""_ = [{','.join(b)}]
|
|
41
|
+
# [{','.join(a)}] = [pl.col(i) for i in _]"""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
|
|
45
|
+
filename,
|
|
46
|
+
date='date', asset='asset',
|
|
47
|
+
extra_codes: Sequence[str] = (),
|
|
48
|
+
over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
|
|
49
|
+
ge_date_idx: int = 0,
|
|
50
|
+
**kwargs):
|
|
51
|
+
"""基于模板的代码生成"""
|
|
52
|
+
if filename is None:
|
|
53
|
+
filename = 'template.rs.j2'
|
|
54
|
+
|
|
55
|
+
parser = argparse.ArgumentParser()
|
|
56
|
+
parser.add_argument("--over_null", type=str, nargs="?", default=over_null)
|
|
57
|
+
|
|
58
|
+
# 打印Polars风格代码
|
|
59
|
+
p = RustStrPrinter()
|
|
60
|
+
|
|
61
|
+
# polars风格代码
|
|
62
|
+
funcs = {}
|
|
63
|
+
# 分组应用代码。这里利用了字典按插入顺序排序的特点,将排序放在最前
|
|
64
|
+
groupbys = {'sort': ''}
|
|
65
|
+
# 处理过后的表达式
|
|
66
|
+
exprs_dst = []
|
|
67
|
+
syms_out = []
|
|
68
|
+
ts_func_name = None
|
|
69
|
+
drop_symbols = exprs_ldl.drop_symbols()
|
|
70
|
+
j = -1
|
|
71
|
+
for i, row in enumerate(exprs_ldl.values()):
|
|
72
|
+
for k, vv in row.items():
|
|
73
|
+
j += 1
|
|
74
|
+
if len(vv) == 0:
|
|
75
|
+
continue
|
|
76
|
+
# 函数名
|
|
77
|
+
func_name = f'func_{i}_{"_".join(k)}'
|
|
78
|
+
func_code = []
|
|
79
|
+
for kv in vv:
|
|
80
|
+
if kv is None:
|
|
81
|
+
func_code.append(f" ]);")
|
|
82
|
+
func_code.append(f"// " + '=' * 40)
|
|
83
|
+
func_code.append(f" df = df.with_columns([")
|
|
84
|
+
exprs_dst.append(f"#" + '=' * 40 + func_name)
|
|
85
|
+
else:
|
|
86
|
+
va, ex, sym, comment = kv
|
|
87
|
+
# 多个#时,只取第一个#后的参数
|
|
88
|
+
args, argv = parser.parse_known_args(args=comment.split("#")[1].split(" "))
|
|
89
|
+
s1 = str(ex)
|
|
90
|
+
s2 = p.doprint(ex)
|
|
91
|
+
if s1 != s2:
|
|
92
|
+
# 不想等,打印注释,显示会更直观察
|
|
93
|
+
func_code.append(f"// {va} = {s1}")
|
|
94
|
+
if k[0] == TS:
|
|
95
|
+
ts_func_name = func_name
|
|
96
|
+
# https://github.com/pola-rs/polars/issues/12925#issuecomment-2552764629
|
|
97
|
+
_sym = [f'col("{s}").is_not_null()' for s in set(sym)]
|
|
98
|
+
if len(_sym) > 1:
|
|
99
|
+
_sym = f"all_horizontal([{','.join(_sym)}]).unwrap()"
|
|
100
|
+
else:
|
|
101
|
+
_sym = ','.join(_sym)
|
|
102
|
+
if args.over_null == 'partition_by':
|
|
103
|
+
func_code.append(f'({s2}).over_with_options(Some([{_sym}, col(_ASSET_)]), Some(([col(_DATE_), lit(1)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
|
|
104
|
+
elif args.over_null == 'order_by':
|
|
105
|
+
func_code.append(f'({s2}).over_with_options(Some([col(_ASSET_), lit(1)]), Some(([{_sym}, col(_DATE_)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
|
|
106
|
+
else:
|
|
107
|
+
func_code.append(f'({s2}).over_with_options(Some([_ASSET_]), Some(([_DATE_], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
|
|
108
|
+
elif k[0] == CS:
|
|
109
|
+
func_code.append(f'({s2}).over([_DATE_]).alias("{va}"),')
|
|
110
|
+
elif k[0] == GP:
|
|
111
|
+
func_code.append(f'({s2}).over([_DATE_, "{k[2]}"]).alias("{va}"),')
|
|
112
|
+
else:
|
|
113
|
+
func_code.append(f'({s2}).alias("{va}"),')
|
|
114
|
+
exprs_dst.append(f"{va} = {s1} {comment}")
|
|
115
|
+
if va not in syms_dst:
|
|
116
|
+
syms_out.append(va)
|
|
117
|
+
func_code.append(f" ]);")
|
|
118
|
+
func_code = func_code[1:]
|
|
119
|
+
|
|
120
|
+
# polars风格代码列表
|
|
121
|
+
funcs[func_name] = '\n'.join(func_code)
|
|
122
|
+
# 只有下划线开头再删除
|
|
123
|
+
ds = [x for x in drop_symbols[j] if x.startswith('_')]
|
|
124
|
+
# 分组应用代码
|
|
125
|
+
groupbys[func_name] = get_groupby_from_tuple(k, func_name, ds)
|
|
126
|
+
|
|
127
|
+
# syms1 = symbols_to_code(syms_dst)
|
|
128
|
+
# syms2 = symbols_to_code(syms_out)
|
|
129
|
+
# filter_last处理
|
|
130
|
+
_groupbys = {'sort': groupbys['sort']}
|
|
131
|
+
if ts_func_name is None:
|
|
132
|
+
_groupbys['_filter_last'] = "df = _filter_last(df, ge_date_idx);"
|
|
133
|
+
for k, v in groupbys.items():
|
|
134
|
+
_groupbys[k] = v
|
|
135
|
+
if k == ts_func_name:
|
|
136
|
+
_groupbys[k + '_filter_last'] = "df = _filter_last(df, ge_date_idx);"
|
|
137
|
+
groupbys = _groupbys
|
|
138
|
+
|
|
139
|
+
try:
|
|
140
|
+
env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__)))
|
|
141
|
+
template = env.get_template(filename)
|
|
142
|
+
except TemplateNotFound:
|
|
143
|
+
env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(filename)))
|
|
144
|
+
template = env.get_template(os.path.basename(filename))
|
|
145
|
+
|
|
146
|
+
return template.render(funcs=funcs, groupbys=groupbys,
|
|
147
|
+
exprs_src=exprs_src, exprs_dst=exprs_dst,
|
|
148
|
+
# syms1=syms1, syms2=syms2,
|
|
149
|
+
date=date, asset=asset,
|
|
150
|
+
extra_codes=extra_codes)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
|
|
3
|
+
from sympy import Basic, Function, StrPrinter
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# TODO: 如有新添加函数,但表达式有变更才需要在此补充对应的打印代码,否则可以省略
|
|
7
|
+
|
|
8
|
+
class RustStrPrinter(StrPrinter):
|
|
9
|
+
def _print(self, expr, **kwargs) -> str:
|
|
10
|
+
"""Internal dispatcher
|
|
11
|
+
|
|
12
|
+
Tries the following concepts to print an expression:
|
|
13
|
+
1. Let the object print itself if it knows how.
|
|
14
|
+
2. Take the best fitting method defined in the printer.
|
|
15
|
+
3. As fall-back use the emptyPrinter method for the printer.
|
|
16
|
+
"""
|
|
17
|
+
self._print_level += 1
|
|
18
|
+
try:
|
|
19
|
+
# If the printer defines a name for a printing method
|
|
20
|
+
# (Printer.printmethod) and the object knows for itself how it
|
|
21
|
+
# should be printed, use that method.
|
|
22
|
+
if self.printmethod and hasattr(expr, self.printmethod):
|
|
23
|
+
if not (isinstance(expr, type) and issubclass(expr, Basic)):
|
|
24
|
+
return getattr(expr, self.printmethod)(self, **kwargs)
|
|
25
|
+
|
|
26
|
+
# See if the class of expr is known, or if one of its super
|
|
27
|
+
# classes is known, and use that print function
|
|
28
|
+
# Exception: ignore the subclasses of Undefined, so that, e.g.,
|
|
29
|
+
# Function('gamma') does not get dispatched to _print_gamma
|
|
30
|
+
classes = type(expr).__mro__
|
|
31
|
+
# if AppliedUndef in classes:
|
|
32
|
+
# classes = classes[classes.index(AppliedUndef):]
|
|
33
|
+
# if UndefinedFunction in classes:
|
|
34
|
+
# classes = classes[classes.index(UndefinedFunction):]
|
|
35
|
+
# Another exception: if someone subclasses a known function, e.g.,
|
|
36
|
+
# gamma, and changes the name, then ignore _print_gamma
|
|
37
|
+
if Function in classes:
|
|
38
|
+
i = classes.index(Function)
|
|
39
|
+
classes = tuple(c for c in classes[:i] if \
|
|
40
|
+
c.__name__ == classes[0].__name__ or \
|
|
41
|
+
c.__name__.endswith("Base")) + classes[i:]
|
|
42
|
+
for cls in classes:
|
|
43
|
+
printmethodname = '_print_' + cls.__name__
|
|
44
|
+
|
|
45
|
+
# 所有以gp_开头的函数都转换成cs_开头
|
|
46
|
+
if printmethodname.startswith('_print_gp_'):
|
|
47
|
+
printmethodname = "_print_gp_"
|
|
48
|
+
|
|
49
|
+
printmethod = getattr(self, printmethodname, None)
|
|
50
|
+
if printmethod is not None:
|
|
51
|
+
return printmethod(expr, **kwargs)
|
|
52
|
+
# Unknown object, fall back to the emptyPrinter.
|
|
53
|
+
return self.emptyPrinter(expr)
|
|
54
|
+
finally:
|
|
55
|
+
self._print_level -= 1
|
|
56
|
+
|
|
57
|
+
def _print_Symbol(self, expr):
|
|
58
|
+
if expr.name in ('_NONE_', '_TRUE_', '_FALSE_'):
|
|
59
|
+
return expr.name
|
|
60
|
+
return f'col("{expr.name}")'
|
|
61
|
+
|
|
62
|
+
def _print_Equality(self, expr):
|
|
63
|
+
new_args = [f"eq({self._print(arg)})" for arg in expr.args]
|
|
64
|
+
return ".".join(new_args)[2:]
|
|
65
|
+
|
|
66
|
+
def _print_Or(self, expr):
|
|
67
|
+
new_args = [f"or({self._print(arg)})" for arg in expr.args]
|
|
68
|
+
return ".".join(new_args)[2:]
|
|
69
|
+
|
|
70
|
+
def _print_Xor(self, expr):
|
|
71
|
+
new_args = [f"xor({self._print(arg)})" for arg in expr.args]
|
|
72
|
+
return ".".join(new_args)[3:]
|
|
73
|
+
|
|
74
|
+
def _print_And(self, expr):
|
|
75
|
+
new_args = [f"and({self._print(arg)})" for arg in expr.args]
|
|
76
|
+
return ".".join(new_args)[3:]
|
|
77
|
+
|
|
78
|
+
def _print_Not(self, expr):
|
|
79
|
+
return "(%s).not()" % self._print(expr.args[0])
|
|
80
|
+
|
|
81
|
+
def _print_gp_(self, expr):
|
|
82
|
+
"""gp_函数都转换成cs_函数,但要丢弃第一个参数"""
|
|
83
|
+
new_args = [self._print(arg) for arg in expr.args[1:]]
|
|
84
|
+
func_name = expr.func.__name__[3:]
|
|
85
|
+
return "cs_%s(%s)" % (func_name, ",".join(new_args))
|
|
86
|
+
|
|
87
|
+
def _print_Integer(self, expr):
|
|
88
|
+
caller_frame = inspect.stack()[2]
|
|
89
|
+
caller_name = caller_frame.function
|
|
90
|
+
if caller_name in ("_print_Pow", "_print_Add", "_print_Mul", "_print_Relational"):
|
|
91
|
+
return "lit(%s)" % super()._print_Integer(expr)
|
|
92
|
+
else:
|
|
93
|
+
return super()._print_Integer(expr)
|
|
94
|
+
|
|
95
|
+
def _print_Float(self, expr):
|
|
96
|
+
caller_frame = inspect.stack()[2]
|
|
97
|
+
caller_name = caller_frame.function
|
|
98
|
+
if caller_name in ("_print_Pow", "_print_Add", "_print_Mul", "_print_Relational"):
|
|
99
|
+
return "lit(%s)" % super()._print_Float(expr)
|
|
100
|
+
else:
|
|
101
|
+
return super()._print_Float(expr)
|
|
102
|
+
|
|
103
|
+
def _print_Relational(self, expr):
|
|
104
|
+
|
|
105
|
+
charmap = {
|
|
106
|
+
"<": "lt",
|
|
107
|
+
">": "gt",
|
|
108
|
+
">=": "gt_eq",
|
|
109
|
+
"<=": "lt_eq",
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
if expr.rel_op in charmap:
|
|
113
|
+
return '(%s).%s(%s)' % (self._print(expr.lhs), charmap[expr.rel_op], self._print(expr.rhs))
|
|
114
|
+
|
|
115
|
+
return super()._print_Relational(expr)
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
// this code is auto generated by the expr_codegen
|
|
2
|
+
// https://github.com/wukan1986/expr_codegen
|
|
3
|
+
// 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request
|
|
4
|
+
use polars::prelude::*;
|
|
5
|
+
|
|
6
|
+
pub fn ts_delay(x: Expr, n: i16) -> Expr {
|
|
7
|
+
x.shift(lit(n))
|
|
8
|
+
}
|
|
9
|
+
|
|
10
|
+
pub fn ts_mean(x: Expr, d: usize) -> Expr {
|
|
11
|
+
x.rolling_mean(RollingOptionsFixedWindow {
|
|
12
|
+
window_size: d,
|
|
13
|
+
min_periods: d,
|
|
14
|
+
weights: None,
|
|
15
|
+
center: false,
|
|
16
|
+
fn_params: None,
|
|
17
|
+
})
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
pub fn ts_sum(x: Expr, d: usize) -> Expr {
|
|
21
|
+
x.rolling_sum(RollingOptionsFixedWindow {
|
|
22
|
+
window_size: d,
|
|
23
|
+
min_periods: d,
|
|
24
|
+
weights: None,
|
|
25
|
+
center: false,
|
|
26
|
+
fn_params: None,
|
|
27
|
+
})
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
pub fn cs_zscore(x: Expr, ddof: u8) -> Expr {
|
|
31
|
+
(x.clone() - x.clone().mean()) / x.clone().std(ddof)
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
const _DATE_: &str = "date";
|
|
35
|
+
const _ASSET_: &str = "asset";
|
|
36
|
+
const _NONE_: Option<i32> = None;
|
|
37
|
+
const _TRUE_: bool = true;
|
|
38
|
+
const _FALSE_: bool = false;
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
{%-for row in extra_codes %}
|
|
43
|
+
{{ row-}}
|
|
44
|
+
{% endfor %}
|
|
45
|
+
|
|
46
|
+
{% for key, value in funcs.items() %}
|
|
47
|
+
|
|
48
|
+
fn {{ key }}(mut df: LazyFrame) -> LazyFrame{
|
|
49
|
+
{{ value }}
|
|
50
|
+
df}
|
|
51
|
+
|
|
52
|
+
{% endfor %}
|
|
53
|
+
|
|
54
|
+
/*
|
|
55
|
+
{%-for row in exprs_dst %}
|
|
56
|
+
{{ row-}}
|
|
57
|
+
{% endfor %}
|
|
58
|
+
*/
|
|
59
|
+
|
|
60
|
+
/*
|
|
61
|
+
{%-for a,b,c in exprs_src %}
|
|
62
|
+
{{ a }} = {{ b}} {{c-}}
|
|
63
|
+
{% endfor %}
|
|
64
|
+
*/
|
|
65
|
+
|
|
66
|
+
pub fn main_(mut df: LazyFrame, ge_date_idx: i64) -> LazyFrame {
|
|
67
|
+
{% for key, value in groupbys.items() %}
|
|
68
|
+
{{ value-}}
|
|
69
|
+
{% endfor %}
|
|
70
|
+
|
|
71
|
+
df
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
pub fn _filter_last(df: LazyFrame, ge_date_idx: i64) -> LazyFrame {
|
|
75
|
+
if ge_date_idx == 0 {
|
|
76
|
+
df
|
|
77
|
+
} else {
|
|
78
|
+
let date_expr = col(_DATE_)
|
|
79
|
+
.unique()
|
|
80
|
+
.sort(SortOptions::default())
|
|
81
|
+
.slice(ge_date_idx, 1)
|
|
82
|
+
.first();
|
|
83
|
+
|
|
84
|
+
df.filter(col(_DATE_).gt_eq(date_expr))
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
89
|
+
// 目前生成的代码还需要调整才能投入使用,还有更多的函数需要补充
|
|
90
|
+
let mut df = df! (
|
|
91
|
+
"date" => [1, 2, 3, 4, 5, 1, 2, 3, 4],
|
|
92
|
+
"asset" => [1, 2, 3, 1, 2, 3, 1, 2, 3],
|
|
93
|
+
"OPEN" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
|
94
|
+
"HIGH" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
|
95
|
+
"LOW" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
|
96
|
+
"CLOSE" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
|
97
|
+
)?;
|
|
98
|
+
|
|
99
|
+
df = main_(df.lazy(), 0)
|
|
100
|
+
.drop(Selector::Matches("^_.*$".into()))
|
|
101
|
+
.collect()?;
|
|
102
|
+
println!("{:?}", df);
|
|
103
|
+
|
|
104
|
+
Ok(())
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
/*
|
|
108
|
+
# Cargo.toml
|
|
109
|
+
# https://docs.pola.rs/user-guide/installation/
|
|
110
|
+
|
|
111
|
+
[package]
|
|
112
|
+
name = "expr_codegen_rs"
|
|
113
|
+
version = "0.1.0"
|
|
114
|
+
edition = "2024"
|
|
115
|
+
|
|
116
|
+
[dependencies]
|
|
117
|
+
polars = { version = "0.51.0", features = [
|
|
118
|
+
"lazy",
|
|
119
|
+
"round_series",
|
|
120
|
+
"strings",
|
|
121
|
+
"regex",
|
|
122
|
+
"rolling_window",
|
|
123
|
+
] }
|
|
124
|
+
|
|
125
|
+
*/
|
|
File without changes
|
|
@@ -2,7 +2,7 @@ import inspect
|
|
|
2
2
|
import pathlib
|
|
3
3
|
from functools import lru_cache
|
|
4
4
|
from io import TextIOBase
|
|
5
|
-
from typing import Sequence, Union, TypeVar, Optional, Literal, Iterable
|
|
5
|
+
from typing import Sequence, Union, TypeVar, Optional, Literal, Iterable, Dict
|
|
6
6
|
|
|
7
7
|
import polars as pl
|
|
8
8
|
from black import Mode, format_str
|
|
@@ -201,7 +201,7 @@ class ExprTool:
|
|
|
201
201
|
asset)
|
|
202
202
|
return dag_end(G)
|
|
203
203
|
|
|
204
|
-
def all(self, exprs_src, style: Literal['pandas', 'polars', 'sql'] = 'polars',
|
|
204
|
+
def all(self, exprs_src, style: Literal['pandas', 'polars', 'sql', 'rust'] = 'polars',
|
|
205
205
|
template_file: Optional[str] = None,
|
|
206
206
|
replace: bool = True, regroup: bool = False, format: bool = True,
|
|
207
207
|
date='date', asset='asset',
|
|
@@ -245,7 +245,7 @@ class ExprTool:
|
|
|
245
245
|
代码字符串
|
|
246
246
|
|
|
247
247
|
"""
|
|
248
|
-
assert style in ('pandas', 'polars', 'sql')
|
|
248
|
+
assert style in ('pandas', 'polars', 'sql', 'rust')
|
|
249
249
|
|
|
250
250
|
if replace:
|
|
251
251
|
exprs_src = replace_exprs(exprs_src)
|
|
@@ -269,6 +269,9 @@ class ExprTool:
|
|
|
269
269
|
elif style == 'sql':
|
|
270
270
|
from expr_codegen.sql.code import codegen
|
|
271
271
|
format = False
|
|
272
|
+
elif style == 'rust':
|
|
273
|
+
from expr_codegen.rust.code import codegen
|
|
274
|
+
format = False
|
|
272
275
|
else:
|
|
273
276
|
raise ValueError(f'unknown style {style}')
|
|
274
277
|
|
|
@@ -290,13 +293,13 @@ class ExprTool:
|
|
|
290
293
|
|
|
291
294
|
return codes, G
|
|
292
295
|
|
|
293
|
-
@lru_cache(maxsize=64)
|
|
296
|
+
# @lru_cache(maxsize=64)
|
|
294
297
|
def _get_code(self,
|
|
295
298
|
source: str, *more_sources: str,
|
|
296
299
|
extra_codes: str,
|
|
297
300
|
output_file: str,
|
|
298
301
|
convert_xor: bool,
|
|
299
|
-
style: Literal['pandas', 'polars', 'sql'] = 'polars',
|
|
302
|
+
style: Literal['pandas', 'polars', 'sql', 'rust'] = 'polars',
|
|
300
303
|
template_file: Optional[str] = None,
|
|
301
304
|
date: str = 'date', asset: str = 'asset',
|
|
302
305
|
over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
|
|
@@ -304,9 +307,10 @@ class ExprTool:
|
|
|
304
307
|
ge_date_idx: int = 0,
|
|
305
308
|
skip_simplify: bool = False,
|
|
306
309
|
skip_columns: Iterable[str] = (),
|
|
310
|
+
function_mapping={},
|
|
307
311
|
**kwargs) -> str:
|
|
308
312
|
"""通过字符串生成代码, 加了缓存,多次调用不重复生成"""
|
|
309
|
-
raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor)
|
|
313
|
+
raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor, function_mapping=function_mapping)
|
|
310
314
|
|
|
311
315
|
# 生成代码
|
|
312
316
|
code, G = _TOOL_.all(exprs_list, style=style, template_file=template_file,
|
|
@@ -380,13 +384,14 @@ def codegen_exec(df: Union[DataFrame, None],
|
|
|
380
384
|
output_file: Union[str, TextIOBase, None] = None,
|
|
381
385
|
run_file: Union[bool, str] = False,
|
|
382
386
|
convert_xor: bool = False,
|
|
383
|
-
style: Literal['pandas', 'polars', 'sql'] = 'polars',
|
|
387
|
+
style: Literal['pandas', 'polars', 'sql', 'rust'] = 'polars',
|
|
384
388
|
template_file: Optional[str] = None,
|
|
385
389
|
date: str = 'date', asset: str = 'asset',
|
|
386
390
|
table_name: str = 'self',
|
|
387
391
|
ge_date_idx: int = 0,
|
|
388
392
|
skip_simplify: bool = False,
|
|
389
393
|
skip_columns: Iterable[str] = (),
|
|
394
|
+
function_mapping: Dict = {},
|
|
390
395
|
**kwargs) -> Union[DataFrame, str]:
|
|
391
396
|
"""快速转换源代码并执行
|
|
392
397
|
|
|
@@ -437,6 +442,8 @@ def codegen_exec(df: Union[DataFrame, None],
|
|
|
437
442
|
已经存在的列不参与计算。可用于加快计算速度。只在计算耗时久时再用,否则没有必要
|
|
438
443
|
例如:在研发阶段,第一次计算100个因子,第二次,只改动了其中的5个,所以只要将这5个从df.columns中排除即可。
|
|
439
444
|
注意:生成的源代码有差异。
|
|
445
|
+
function_mapping:
|
|
446
|
+
传入函数定义,可直接传`globals()`。用于将所有的关键字参数转换成位置参数
|
|
440
447
|
|
|
441
448
|
Returns
|
|
442
449
|
-------
|
|
@@ -494,12 +501,15 @@ def codegen_exec(df: Union[DataFrame, None],
|
|
|
494
501
|
ge_date_idx=ge_date_idx,
|
|
495
502
|
skip_simplify=skip_simplify,
|
|
496
503
|
skip_columns=skip_columns,
|
|
504
|
+
function_mapping=function_mapping,
|
|
497
505
|
**kwargs
|
|
498
506
|
)
|
|
499
507
|
|
|
500
508
|
if df is None:
|
|
501
509
|
# 如果df为空,直接返回代码
|
|
502
510
|
return code
|
|
511
|
+
elif style == 'rust':
|
|
512
|
+
return code
|
|
503
513
|
elif style == 'sql':
|
|
504
514
|
with pl.SQLContext(frames={table_name: df}) as ctx:
|
|
505
515
|
return ctx.execute(code, eager=isinstance(df, _pl_DataFrame))
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "0.15.2"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|