expr-codegen 0.15.2__tar.gz → 0.16.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/PKG-INFO +2 -1
  2. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/README.md +1 -0
  3. expr_codegen-0.16.1/expr_codegen/_version.py +1 -0
  4. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/codes.py +63 -3
  5. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/code.py +9 -4
  6. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/code.py +9 -4
  7. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/template.py.j2 +2 -3
  8. expr_codegen-0.16.1/expr_codegen/rust/code.py +150 -0
  9. expr_codegen-0.16.1/expr_codegen/rust/printer.py +115 -0
  10. expr_codegen-0.16.1/expr_codegen/rust/template.rs.j2 +125 -0
  11. expr_codegen-0.16.1/expr_codegen/sql/__init__.py +0 -0
  12. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/tool.py +17 -7
  13. expr_codegen-0.15.2/expr_codegen/_version.py +0 -1
  14. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/.gitignore +0 -0
  15. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/LICENSE +0 -0
  16. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/__init__.py +0 -0
  17. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/dag.py +0 -0
  18. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/expr.py +0 -0
  19. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/latex/__init__.py +0 -0
  20. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/latex/printer.py +0 -0
  21. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/model.py +0 -0
  22. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/__init__.py +0 -0
  23. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/helper.py +0 -0
  24. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/printer.py +0 -0
  25. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/ta.py +0 -0
  26. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/pandas/template.py.j2 +0 -0
  27. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/__init__.py +0 -0
  28. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/polars/printer.py +0 -0
  29. {expr_codegen-0.15.2/expr_codegen/sql → expr_codegen-0.16.1/expr_codegen/rust}/__init__.py +0 -0
  30. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/sql/code.py +0 -0
  31. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/sql/printer.py +0 -0
  32. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/expr_codegen/sql/template.sql.j2 +0 -0
  33. {expr_codegen-0.15.2 → expr_codegen-0.16.1}/pyproject.toml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: expr_codegen
3
- Version: 0.15.2
3
+ Version: 0.16.1
4
4
  Summary: symbol expression to polars expression tool
5
5
  Author-email: wukan <wu-kan@163.com>
6
6
  License: BSD 3-Clause License
@@ -276,6 +276,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
276
276
  ```
277
277
  11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
278
278
  12. 代码块中,对`import`、`def`、`class`三种语法,自动原样插入到生成的代码中
279
+ 13. 由于`sympy`的限制不支持关键字参数,但如果`from polars_ta.prefix.wq import *`,然后`codegen_exec(function_mapping=globals())`,底层会试着将关键字参数转成位置参数使用
279
280
 
280
281
  ## 下划线开头的变量
281
282
 
@@ -226,6 +226,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
226
226
  ```
227
227
  11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
228
228
  12. 代码块中,对`import`、`def`、`class`三种语法,自动原样插入到生成的代码中
229
+ 13. 由于`sympy`的限制不支持关键字参数,但如果`from polars_ta.prefix.wq import *`,然后`codegen_exec(function_mapping=globals())`,底层会试着将关键字参数转成位置参数使用
229
230
 
230
231
  ## 下划线开头的变量
231
232
 
@@ -0,0 +1 @@
1
+ __version__ = "0.16.1"
@@ -1,4 +1,5 @@
1
1
  import ast
2
+ import inspect
2
3
  import re
3
4
  from ast import expr
4
5
 
@@ -333,6 +334,63 @@ class RenameTransformer(ast.NodeTransformer):
333
334
  return node
334
335
 
335
336
 
337
class KeywordToPositionalTransformer(ast.NodeTransformer):
    """Rewrite keyword arguments of known functions into positional arguments.

    sympy functions cannot take keyword arguments, so a call such as
    ``ts_mean(CLOSE, d=5)`` is rewritten to ``ts_mean(CLOSE, 5)`` using the
    signature of the real Python function found in ``function_mapping``.
    Any call that cannot be converted safely is returned unchanged.
    """

    # Exact types whose default values can be written back as a source literal.
    _LITERAL_TYPES = (bool, int, float, complex, str, bytes, type(None))

    def __init__(self, function_mapping):
        # Maps a function name (as written in the source) to the actual callable,
        # e.g. the caller's ``globals()``. ``None`` is treated as an empty mapping.
        self.function_mapping = function_mapping or {}

    def visit_Call(self, node):
        """Convert nested calls first, then this call if its function is known."""
        # Without this, calls inside arguments (``g(f(x, d=1), n=2)``) are never visited.
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.keywords:
            target_func = self.function_mapping.get(node.func.id)
            if target_func is not None:
                return self.transform_call(node, target_func)
        return node

    def transform_call(self, node, target_func):
        """Return a keyword-free copy of ``node``, or ``node`` itself if not convertible.

        Parameters
        ----------
        node : ast.Call
            Call with at least one keyword argument.
        target_func : callable
            Function whose signature defines the positional order.

        Returns
        -------
        ast.AST
            New ``ast.Call`` with only positional arguments, or the original node.
        """
        try:
            sig = inspect.signature(target_func)
        except (TypeError, ValueError):
            # Non-callables and some builtins have no introspectable signature.
            return node

        positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                            inspect.Parameter.POSITIONAL_OR_KEYWORD)
        positional = [p for p in sig.parameters.values() if p.kind in positional_kinds]
        names = [p.name for p in positional]

        # Starred arguments or more arguments than named positional parameters
        # cannot be mapped reliably.
        if len(node.args) > len(names) or any(isinstance(a, ast.Starred) for a in node.args):
            return node

        slots = dict(zip(names, node.args))
        for keyword in node.keywords:
            # ``**mapping`` unpacking, unknown names, keyword-only parameters,
            # or a parameter given twice: leave the call untouched.
            if keyword.arg is None or keyword.arg not in names or keyword.arg in slots:
                return node
            slots[keyword.arg] = keyword.value

        # Emit arguments only up to the last one the user actually supplied.
        # Gaps before it are filled with literal defaults so later arguments
        # keep their position. Trailing defaults (often ``None``) are not injected.
        last = max(names.index(name) for name in slots)
        new_args = []
        for param in positional[:last + 1]:
            if param.name in slots:
                new_args.append(slots[param.name])
            elif (param.default is not inspect.Parameter.empty
                  and type(param.default) in self._LITERAL_TYPES):
                new_args.append(ast.Constant(value=param.default))
            else:
                # Required argument missing, or the default has no literal form.
                return node

        new_node = ast.Call(func=node.func, args=new_args, keywords=[])
        # Give the new node (and injected constants) source locations so the
        # tree remains compilable.
        return ast.fix_missing_locations(ast.copy_location(new_node, node))
392
+
393
+
336
394
  def source_replace(source: str) -> str:
337
395
  # 三元表达式转换成 错误版if( )else,一定得在Transformer中修正
338
396
  num = 1
@@ -374,7 +432,7 @@ def raw_to_code(raw):
374
432
  return '\n'.join([ast_comments.unparse(a) for a in raw])
375
433
 
376
434
 
377
- def sources_to_asts(*sources, convert_xor: bool):
435
+ def sources_to_asts(*sources, convert_xor: bool, function_mapping):
378
436
  """输入多份源代码"""
379
437
 
380
438
  def _source_to_asts(source):
@@ -394,6 +452,8 @@ def sources_to_asts(*sources, convert_xor: bool):
394
452
 
395
453
  t1 = SyntaxTransformer(convert_xor)
396
454
  t1.visit(tree)
455
+ t2 = KeywordToPositionalTransformer(function_mapping)
456
+ t2.visit(tree)
397
457
  t = RenameTransformer({}, {})
398
458
  t.visit(tree)
399
459
 
@@ -429,12 +489,12 @@ def _add_default_type(globals_):
429
489
  return globals_
430
490
 
431
491
 
432
- def sources_to_exprs(globals_, *sources, convert_xor: bool):
492
+ def sources_to_exprs(globals_, *sources, convert_xor: bool, function_mapping):
433
493
  """将源代码转换成表达式"""
434
494
 
435
495
  globals_ = _add_default_type(globals_)
436
496
 
437
- raw, assigns, funcs_new, args_new, targets_new = sources_to_asts(*sources, convert_xor=convert_xor)
497
+ raw, assigns, funcs_new, args_new, targets_new = sources_to_asts(*sources, convert_xor=convert_xor, function_mapping=function_mapping)
438
498
  # 支持OPEN[1]转ts_delay(OPEN,1)
439
499
  funcs_new.add('ts_delay')
440
500
 
@@ -13,18 +13,23 @@ def get_groupby_from_tuple(tup, func_name, drop_cols):
13
13
  """从传入的元组中生成分组运行代码"""
14
14
  prefix2, *_ = tup
15
15
 
16
+ if len(drop_cols)>0:
17
+ drop_str = f'.drop(columns={drop_cols})'
18
+ else:
19
+ drop_str = ""
20
+
16
21
  if prefix2 == TS:
17
22
  # 组内需要按时间进行排序,需要维持顺序
18
23
  prefix2, asset = tup
19
- return f'df = df.groupby(by=[_ASSET_], group_keys=False).apply({func_name}).drop(columns={drop_cols})'
24
+ return f'df = df.groupby(by=[_ASSET_], group_keys=False).apply({func_name}){drop_str}'
20
25
  if prefix2 == CS:
21
26
  prefix2, date = tup
22
- return f'df = df.groupby(by=[_DATE_], group_keys=False).apply({func_name}).drop(columns={drop_cols})'
27
+ return f'df = df.groupby(by=[_DATE_], group_keys=False).apply({func_name}){drop_str}'
23
28
  if prefix2 == GP:
24
29
  prefix2, date, group = tup
25
- return f'df = df.groupby(by=[_DATE_, "{group}"], group_keys=False).apply({func_name}).drop(columns={drop_cols})'
30
+ return f'df = df.groupby(by=[_DATE_, "{group}"], group_keys=False).apply({func_name}){drop_str}'
26
31
 
27
- return f'df = {func_name}(df).drop(columns={drop_cols})'
32
+ return f'df = {func_name}(df){drop_str}'
28
33
 
29
34
 
30
35
  def symbols_to_code(syms):
@@ -14,18 +14,23 @@ def get_groupby_from_tuple(tup, func_name, drop_cols):
14
14
  """从传入的元组中生成分组运行代码"""
15
15
  prefix2, *_ = tup
16
16
 
17
+ if len(drop_cols)>0:
18
+ drop_str = f'.drop(*{drop_cols})'
19
+ else:
20
+ drop_str = ""
21
+
17
22
  if prefix2 == TS:
18
23
  # 组内需要按时间进行排序,需要维持顺序
19
24
  prefix2, asset = tup
20
- return f'df = {func_name}(df.sort(_ASSET_, _DATE_)).drop(*{drop_cols})'
25
+ return f'df = {func_name}(df.sort(_ASSET_, _DATE_)){drop_str}'
21
26
  if prefix2 == CS:
22
27
  prefix2, date = tup
23
- return f'df = {func_name}(df.sort(_DATE_)).drop(*{drop_cols})'
28
+ return f'df = {func_name}(df.sort(_DATE_)){drop_str}'
24
29
  if prefix2 == GP:
25
30
  prefix2, date, group = tup
26
- return f'df = {func_name}(df.sort(_DATE_, "{group}")).drop(*{drop_cols})'
31
+ return f'df = {func_name}(df.sort(_DATE_, "{group}")){drop_str}'
27
32
 
28
- return f'df = {func_name}(df).drop(*{drop_cols})'
33
+ return f'df = {func_name}(df){drop_str}'
29
34
 
30
35
 
31
36
  def symbols_to_code(syms):
@@ -63,12 +63,11 @@ def {{ key }}(df: DataFrame) -> DataFrame:
63
63
 
64
64
 
65
65
  def _filter_last(df: DataFrame, ge_date_idx: int) -> DataFrame:
66
- """过滤数据,只取最后几天。实盘时可用于减少计算量
67
- """
66
+ """过滤数据,只取最后几天。实盘时可用于减少计算量"""
68
67
  if ge_date_idx == 0:
69
68
  return df
70
69
  else:
71
- return df.filter(pl.col(_DATE_) >= df.select(pl.col(_DATE_).unique().sort())[ge_date_idx, 0])
70
+ return df.filter(pl.col(_DATE_) >= pl.col(_DATE_).unique().sort().slice(ge_date_idx, 1).first())
72
71
 
73
72
 
74
73
  def main(df: DataFrame, ge_date_idx: int) -> DataFrame:
@@ -0,0 +1,150 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ from typing import Sequence, Literal
5
+
6
+ import jinja2
7
+ from jinja2 import FileSystemLoader, TemplateNotFound
8
+
9
+ from expr_codegen.expr import TS, CS, GP
10
+ from expr_codegen.model import ListDictList
11
+ from expr_codegen.rust.printer import RustStrPrinter
12
+
13
+
14
def get_groupby_from_tuple(tup, func_name, drop_cols):
    """Generate the Rust statement that runs ``func_name`` for one group key.

    Parameters
    ----------
    tup : tuple
        Group key: ``(TS, asset)``, ``(CS, date)``, ``(GP, date, group)``;
        any other prefix produces an unsorted call.
    func_name : str
        Name of the generated Rust function to call.
    drop_cols : Sequence[str]
        Columns to drop after the call; nothing is dropped when empty.

    Returns
    -------
    str
        A single Rust statement ending with ``;``.
    """
    prefix2, *_ = tup
    if len(drop_cols) > 0:
        # Join the names outside the f-string: reusing the same quote character
        # inside a replacement field is only legal from Python 3.12 (PEP 701).
        names = ','.join(f'"{c}".into()' for c in drop_cols)
        drop_str = f'.drop(Selector::ByName {{ names: Arc::new([{names}]), strict: true }})'
    else:
        drop_str = ""

    if prefix2 == TS:
        # Sort by asset then time so each asset's rows are in time order.
        _prefix, _asset = tup  # also validates the tuple shape
        return f'df = {func_name}(df.sort([_ASSET_, _DATE_], SortMultipleOptions::default())){drop_str};'
    if prefix2 == CS:
        _prefix, _date = tup
        return f'df = {func_name}(df.sort([_DATE_], SortMultipleOptions::default())){drop_str};'
    if prefix2 == GP:
        _prefix, _date, group = tup
        return f'df = {func_name}(df.sort([_DATE_, "{group}"], SortMultipleOptions::default())){drop_str};'

    return f'df = {func_name}(df){drop_str};'
35
+
36
+
37
+ # def symbols_to_code(syms):
38
+ # a = [f"{s}" for s in syms]
39
+ # b = [f"'{s}'" for s in syms]
40
+ # return f"""_ = [{','.join(b)}]
41
+ # [{','.join(a)}] = [pl.col(i) for i in _]"""
42
+
43
+
44
def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
            filename,
            date='date', asset='asset',
            extra_codes: Sequence[str] = (),
            over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
            ge_date_idx: int = 0,
            **kwargs):
    """Render Rust (polars) source code from grouped expressions via a Jinja template.

    Parameters
    ----------
    exprs_ldl : ListDictList
        Rows of ``{group_key: [entry, ...]}``. Each entry is either ``None``
        (start a new ``with_columns`` batch) or ``(var, expr, symbols, comment)``.
    exprs_src : iterable
        Source expressions rendered into the template as ``a = b c`` triples.
    syms_dst : container
        Names that already exist; they are not reported as new outputs.
    filename : str or None
        Template name next to this module, or a path to a template file.
        Defaults to ``'template.rs.j2'``.
    date, asset : str
        Column names passed to the template.
    extra_codes : Sequence[str]
        Extra source lines inserted verbatim into the template.
    over_null : {'order_by', 'partition_by', None}
        Default null handling for time-series windows. It can be overridden per
        expression with ``# --over_null=...`` in the comment.
    ge_date_idx : int
        Kept for interface compatibility; the generated ``main_`` receives it at runtime.

    Returns
    -------
    str
        The rendered Rust source.
    """
    if filename is None:
        filename = 'template.rs.j2'

    # Per-expression options (e.g. ``# --over_null=order_by``) are parsed from comments.
    parser = argparse.ArgumentParser()
    parser.add_argument("--over_null", type=str, nargs="?", default=over_null)

    p = RustStrPrinter()

    # Generated Rust function bodies, keyed by function name.
    funcs = {}
    # Statements run by ``main_``; dict insertion order keeps 'sort' first.
    groupbys = {'sort': ''}
    # Processed expressions, emitted into a comment block of the output.
    exprs_dst = []
    syms_out = []
    ts_func_name = None
    drop_symbols = exprs_ldl.drop_symbols()
    j = -1
    for i, row in enumerate(exprs_ldl.values()):
        for k, vv in row.items():
            j += 1
            if len(vv) == 0:
                continue
            func_name = f'func_{i}_{"_".join(k)}'
            func_code = []
            for kv in vv:
                if kv is None:
                    # Close the previous batch and open a new ``with_columns`` call.
                    func_code.append(" ]);")
                    func_code.append("// " + '=' * 40)
                    func_code.append(" df = df.with_columns([")
                    exprs_dst.append("#" + '=' * 40 + func_name)
                    continue

                va, ex, sym, comment = kv
                # Options follow the first '#'; text after a second '#' is ignored.
                # A comment without '#' simply carries no options.
                comment_parts = comment.split("#")
                option_text = comment_parts[1] if len(comment_parts) > 1 else ""
                args, argv = parser.parse_known_args(args=option_text.split(" "))
                s1 = str(ex)
                s2 = p.doprint(ex)
                if s1 != s2:
                    # Show the original expression when the Rust form differs.
                    func_code.append(f"// {va} = {s1}")
                if k[0] == TS:
                    ts_func_name = func_name
                    # https://github.com/pola-rs/polars/issues/12925#issuecomment-2552764629
                    # dict.fromkeys de-duplicates in first-seen order, so the generated
                    # code is identical across runs (set order depends on the hash seed).
                    _sym = [f'col("{s}").is_not_null()' for s in dict.fromkeys(sym)]
                    if len(_sym) > 1:
                        _sym = f"all_horizontal([{','.join(_sym)}]).unwrap()"
                    else:
                        _sym = ','.join(_sym)
                    if args.over_null == 'partition_by':
                        func_code.append(f'({s2}).over_with_options(Some([{_sym}, col(_ASSET_)]), Some(([col(_DATE_), lit(1)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
                    elif args.over_null == 'order_by':
                        func_code.append(f'({s2}).over_with_options(Some([col(_ASSET_), lit(1)]), Some(([{_sym}, col(_DATE_)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
                    else:
                        func_code.append(f'({s2}).over_with_options(Some([_ASSET_]), Some(([_DATE_], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
                elif k[0] == CS:
                    func_code.append(f'({s2}).over([_DATE_]).alias("{va}"),')
                elif k[0] == GP:
                    func_code.append(f'({s2}).over([_DATE_, "{k[2]}"]).alias("{va}"),')
                else:
                    func_code.append(f'({s2}).alias("{va}"),')
                exprs_dst.append(f"{va} = {s1} {comment}")
                if va not in syms_dst:
                    syms_out.append(va)
            func_code.append(" ]);")
            # Drop the leading " ]);" emitted by the first batch separator.
            func_code = func_code[1:]

            funcs[func_name] = '\n'.join(func_code)
            # Only drop helper columns whose names start with '_'.
            ds = [x for x in drop_symbols[j] if x.startswith('_')]
            groupbys[func_name] = get_groupby_from_tuple(k, func_name, ds)

    # Insert ``_filter_last`` right after the last time-series step, or right
    # after sorting when there is no time-series step.
    _groupbys = {'sort': groupbys['sort']}
    if ts_func_name is None:
        _groupbys['_filter_last'] = "df = _filter_last(df, ge_date_idx);"
    for k, v in groupbys.items():
        _groupbys[k] = v
        if k == ts_func_name:
            _groupbys[k + '_filter_last'] = "df = _filter_last(df, ge_date_idx);"
    groupbys = _groupbys

    try:
        env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__)))
        template = env.get_template(filename)
    except TemplateNotFound:
        env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(filename)))
        template = env.get_template(os.path.basename(filename))

    return template.render(funcs=funcs, groupbys=groupbys,
                           exprs_src=exprs_src, exprs_dst=exprs_dst,
                           date=date, asset=asset,
                           extra_codes=extra_codes)
@@ -0,0 +1,115 @@
1
+ import inspect
2
+
3
+ from sympy import Basic, Function, StrPrinter
4
+
5
+
6
+ # TODO: 如有新添加函数,但表达式有变更才需要在此补充对应的打印代码,否则可以省略
7
+
8
class RustStrPrinter(StrPrinter):
    """sympy printer that emits polars Rust expression code.

    Symbols become ``col("name")`` and boolean/relational operators become
    ``Expr`` method calls (``.and(..)``, ``.gt(..)``, ``.neq(..)`` ...).
    Numeric literals are wrapped in ``lit(..)`` when printed directly from the
    Pow/Add/Mul/Relational printers.
    """

    # Printer methods whose directly-printed numeric operands need ``lit(...)``.
    _LIT_CALLERS = ("_print_Pow", "_print_Add", "_print_Mul", "_print_Relational")

    # Relational operators mapped to polars ``Expr`` comparison methods.
    # ``==`` is handled by ``_print_Equality``.
    _REL_METHODS = {
        "<": "lt",
        ">": "gt",
        ">=": "gt_eq",
        "<=": "lt_eq",
        "!=": "neq",
    }

    def _print(self, expr, **kwargs) -> str:
        """Internal dispatcher (adapted from sympy ``Printer._print``).

        Tries, in order:
        1. the object's own ``printmethod`` if it has one;
        2. the best matching ``_print_<ClassName>`` along the MRO;
        3. ``emptyPrinter`` as fall-back.

        Any class whose name starts with ``gp_`` is routed to ``_print_gp_``.
        """
        self._print_level += 1
        try:
            # If the printer defines Printer.printmethod and the object knows how
            # to print itself, use that method.
            if self.printmethod and hasattr(expr, self.printmethod):
                if not (isinstance(expr, type) and issubclass(expr, Basic)):
                    return getattr(expr, self.printmethod)(self, **kwargs)

            classes = type(expr).__mro__
            # sympy skips AppliedUndef/UndefinedFunction subclasses here. That skip
            # is disabled (below), so the class of an undefined function is still
            # looked up by its own name (e.g. ``gp_*``).
            # if AppliedUndef in classes:
            #     classes = classes[classes.index(AppliedUndef):]
            # if UndefinedFunction in classes:
            #     classes = classes[classes.index(UndefinedFunction):]
            # If a known function is subclassed under a different name, ignore
            # the parent's specific printer.
            if Function in classes:
                i = classes.index(Function)
                classes = tuple(c for c in classes[:i] if
                                c.__name__ == classes[0].__name__ or
                                c.__name__.endswith("Base")) + classes[i:]
            for cls in classes:
                printmethodname = '_print_' + cls.__name__

                # Every function starting with gp_ is printed by _print_gp_.
                if printmethodname.startswith('_print_gp_'):
                    printmethodname = "_print_gp_"

                printmethod = getattr(self, printmethodname, None)
                if printmethod is not None:
                    return printmethod(expr, **kwargs)
            # Unknown object: fall back to emptyPrinter.
            return self.emptyPrinter(expr)
        finally:
            self._print_level -= 1

    @staticmethod
    def _caller_name(depth):
        """Return the code name ``depth`` frames above the caller of this helper.

        Same result as ``inspect.stack()[depth].function`` evaluated in the
        caller, without building FrameInfo/source context for the whole stack.
        """
        frame = inspect.currentframe()
        try:
            # Skip this helper's own frame, then walk up ``depth`` more frames.
            for _ in range(depth + 1):
                if frame is None:
                    return None
                frame = frame.f_back
            return None if frame is None else frame.f_code.co_name
        finally:
            del frame

    def _print_Symbol(self, expr):
        if expr.name in ('_NONE_', '_TRUE_', '_FALSE_'):
            # These names are Rust constants in the template, not columns.
            return expr.name
        return f'col("{expr.name}")'

    def _print_Equality(self, expr):
        # "eq(a).eq(b)"[2:] -> "(a).eq(b)"
        new_args = [f"eq({self._print(arg)})" for arg in expr.args]
        return ".".join(new_args)[2:]

    def _print_Or(self, expr):
        new_args = [f"or({self._print(arg)})" for arg in expr.args]
        return ".".join(new_args)[2:]

    def _print_Xor(self, expr):
        new_args = [f"xor({self._print(arg)})" for arg in expr.args]
        return ".".join(new_args)[3:]

    def _print_And(self, expr):
        new_args = [f"and({self._print(arg)})" for arg in expr.args]
        return ".".join(new_args)[3:]

    def _print_Not(self, expr):
        return "(%s).not()" % self._print(expr.args[0])

    def _print_gp_(self, expr):
        """Print a ``gp_*`` function as the matching ``cs_*`` function without its first argument."""
        new_args = [self._print(arg) for arg in expr.args[1:]]
        func_name = expr.func.__name__[3:]
        return "cs_%s(%s)" % (func_name, ",".join(new_args))

    def _print_Integer(self, expr):
        text = super()._print_Integer(expr)
        # Frame 0 is this method, 1 is ``_print``, 2 is the printer that asked.
        if self._caller_name(2) in self._LIT_CALLERS:
            return "lit(%s)" % text
        return text

    def _print_Float(self, expr):
        text = super()._print_Float(expr)
        # Frame 0 is this method, 1 is ``_print``, 2 is the printer that asked.
        if self._caller_name(2) in self._LIT_CALLERS:
            return "lit(%s)" % text
        return text

    def _print_Relational(self, expr):
        method = self._REL_METHODS.get(expr.rel_op)
        if method is not None:
            # Operands are printed directly from here so numeric literals get lit().
            return '(%s).%s(%s)' % (self._print(expr.lhs), method, self._print(expr.rhs))

        return super()._print_Relational(expr)
@@ -0,0 +1,125 @@
1
+ // this code is auto generated by the expr_codegen
2
+ // https://github.com/wukan1986/expr_codegen
3
+ // 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request
4
+ use polars::prelude::*;
5
+
6
+ pub fn ts_delay(x: Expr, n: i16) -> Expr {
7
+ x.shift(lit(n))
8
+ }
9
+
10
+ pub fn ts_mean(x: Expr, d: usize) -> Expr {
11
+ x.rolling_mean(RollingOptionsFixedWindow {
12
+ window_size: d,
13
+ min_periods: d,
14
+ weights: None,
15
+ center: false,
16
+ fn_params: None,
17
+ })
18
+ }
19
+
20
+ pub fn ts_sum(x: Expr, d: usize) -> Expr {
21
+ x.rolling_sum(RollingOptionsFixedWindow {
22
+ window_size: d,
23
+ min_periods: d,
24
+ weights: None,
25
+ center: false,
26
+ fn_params: None,
27
+ })
28
+ }
29
+
30
+ pub fn cs_zscore(x: Expr, ddof: u8) -> Expr {
31
+ (x.clone() - x.clone().mean()) / x.clone().std(ddof)
32
+ }
33
+
34
// Column names are rendered from the `date` / `asset` arguments passed to codegen.
const _DATE_: &str = "{{ date }}";
const _ASSET_: &str = "{{ asset }}";
// Rust stand-ins for the `_NONE_` / `_TRUE_` / `_FALSE_` symbols printed by RustStrPrinter.
const _NONE_: Option<i32> = None;
const _TRUE_: bool = true;
const _FALSE_: bool = false;
39
+
40
+
41
+
42
{%-for row in extra_codes %}
{{ row-}}
{% endfor %}

// One Rust function per group key; each body (`value`) is a sequence of
// `df = df.with_columns([...]);` statements generated by codegen.
{% for key, value in funcs.items() %}

fn {{ key }}(mut df: LazyFrame) -> LazyFrame{
{{ value }}
df}

{% endfor %}

// `exprs_dst` rows, kept inside a block comment for reference only.
/*
{%-for row in exprs_dst %}
{{ row-}}
{% endfor %}
*/

// `exprs_src` rows rendered as `a = b c`, kept inside a block comment for reference only.
/*
{%-for a,b,c in exprs_src %}
{{ a }} = {{ b}} {{c-}}
{% endfor %}
*/

// Generated pipeline: runs the `groupbys` statements in insertion order;
// each statement reassigns `df`.
pub fn main_(mut df: LazyFrame, ge_date_idx: i64) -> LazyFrame {
{% for key, value in groupbys.items() %}
{{ value-}}
{% endfor %}

    df
}

// Keeps rows whose date is >= the `ge_date_idx`-th entry of the distinct dates
// (sorted with `SortOptions::default()`); `ge_date_idx == 0` returns `df` unchanged.
pub fn _filter_last(df: LazyFrame, ge_date_idx: i64) -> LazyFrame {
    if ge_date_idx == 0 {
        df
    } else {
        let date_expr = col(_DATE_)
            .unique()
            .sort(SortOptions::default())
            .slice(ge_date_idx, 1)
            .first();

        df.filter(col(_DATE_).gt_eq(date_expr))
    }
}
87
+
88
+ fn main() -> Result<(), Box<dyn std::error::Error>> {
89
+ // 目前生成的代码还需要调整才能投入使用,还有更多的函数需要补充
90
+ let mut df = df! (
91
+ "date" => [1, 2, 3, 4, 5, 1, 2, 3, 4],
92
+ "asset" => [1, 2, 3, 1, 2, 3, 1, 2, 3],
93
+ "OPEN" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
94
+ "HIGH" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
95
+ "LOW" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
96
+ "CLOSE" => [1, 2, 3, 4, 5, 6, 7, 8, 9],
97
+ )?;
98
+
99
+ df = main_(df.lazy(), 0)
100
+ .drop(Selector::Matches("^_.*$".into()))
101
+ .collect()?;
102
+ println!("{:?}", df);
103
+
104
+ Ok(())
105
+ }
106
+
107
+ /*
108
+ # Cargo.toml
109
+ # https://docs.pola.rs/user-guide/installation/
110
+
111
+ [package]
112
+ name = "expr_codegen_rs"
113
+ version = "0.1.0"
114
+ edition = "2024"
115
+
116
+ [dependencies]
117
+ polars = { version = "0.51.0", features = [
118
+ "lazy",
119
+ "round_series",
120
+ "strings",
121
+ "regex",
122
+ "rolling_window",
123
+ ] }
124
+
125
+ */
File without changes
@@ -2,7 +2,7 @@ import inspect
2
2
  import pathlib
3
3
  from functools import lru_cache
4
4
  from io import TextIOBase
5
- from typing import Sequence, Union, TypeVar, Optional, Literal, Iterable
5
+ from typing import Sequence, Union, TypeVar, Optional, Literal, Iterable, Dict
6
6
 
7
7
  import polars as pl
8
8
  from black import Mode, format_str
@@ -201,7 +201,7 @@ class ExprTool:
201
201
  asset)
202
202
  return dag_end(G)
203
203
 
204
- def all(self, exprs_src, style: Literal['pandas', 'polars', 'sql'] = 'polars',
204
+ def all(self, exprs_src, style: Literal['pandas', 'polars', 'sql', 'rust'] = 'polars',
205
205
  template_file: Optional[str] = None,
206
206
  replace: bool = True, regroup: bool = False, format: bool = True,
207
207
  date='date', asset='asset',
@@ -245,7 +245,7 @@ class ExprTool:
245
245
  代码字符串
246
246
 
247
247
  """
248
- assert style in ('pandas', 'polars', 'sql')
248
+ assert style in ('pandas', 'polars', 'sql', 'rust')
249
249
 
250
250
  if replace:
251
251
  exprs_src = replace_exprs(exprs_src)
@@ -269,6 +269,9 @@ class ExprTool:
269
269
  elif style == 'sql':
270
270
  from expr_codegen.sql.code import codegen
271
271
  format = False
272
+ elif style == 'rust':
273
+ from expr_codegen.rust.code import codegen
274
+ format = False
272
275
  else:
273
276
  raise ValueError(f'unknown style {style}')
274
277
 
@@ -290,13 +293,13 @@ class ExprTool:
290
293
 
291
294
  return codes, G
292
295
 
293
- @lru_cache(maxsize=64)
296
+ # @lru_cache(maxsize=64)
294
297
  def _get_code(self,
295
298
  source: str, *more_sources: str,
296
299
  extra_codes: str,
297
300
  output_file: str,
298
301
  convert_xor: bool,
299
- style: Literal['pandas', 'polars', 'sql'] = 'polars',
302
+ style: Literal['pandas', 'polars', 'sql', 'rust'] = 'polars',
300
303
  template_file: Optional[str] = None,
301
304
  date: str = 'date', asset: str = 'asset',
302
305
  over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
@@ -304,9 +307,10 @@ class ExprTool:
304
307
  ge_date_idx: int = 0,
305
308
  skip_simplify: bool = False,
306
309
  skip_columns: Iterable[str] = (),
310
+ function_mapping={},
307
311
  **kwargs) -> str:
308
312
  """通过字符串生成代码, 加了缓存,多次调用不重复生成"""
309
- raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor)
313
+ raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor, function_mapping=function_mapping)
310
314
 
311
315
  # 生成代码
312
316
  code, G = _TOOL_.all(exprs_list, style=style, template_file=template_file,
@@ -380,13 +384,14 @@ def codegen_exec(df: Union[DataFrame, None],
380
384
  output_file: Union[str, TextIOBase, None] = None,
381
385
  run_file: Union[bool, str] = False,
382
386
  convert_xor: bool = False,
383
- style: Literal['pandas', 'polars', 'sql'] = 'polars',
387
+ style: Literal['pandas', 'polars', 'sql', 'rust'] = 'polars',
384
388
  template_file: Optional[str] = None,
385
389
  date: str = 'date', asset: str = 'asset',
386
390
  table_name: str = 'self',
387
391
  ge_date_idx: int = 0,
388
392
  skip_simplify: bool = False,
389
393
  skip_columns: Iterable[str] = (),
394
+ function_mapping: Dict = {},
390
395
  **kwargs) -> Union[DataFrame, str]:
391
396
  """快速转换源代码并执行
392
397
 
@@ -437,6 +442,8 @@ def codegen_exec(df: Union[DataFrame, None],
437
442
  已经存在的列不参与计算。可用于加快计算速度。只在计算耗时久时再用,否则没有必要
438
443
  例如:在研发阶段,第一次计算100个因子,第二次,只改动了其中的5个,所以只要将这5个从df.columns中排除即可。
439
444
  注意:生成的源代码有差异。
445
+ function_mapping:
446
+ 传入函数定义,可直接传`globals()`。用于将所有的关键字参数转换成位置参数
440
447
 
441
448
  Returns
442
449
  -------
@@ -494,12 +501,15 @@ def codegen_exec(df: Union[DataFrame, None],
494
501
  ge_date_idx=ge_date_idx,
495
502
  skip_simplify=skip_simplify,
496
503
  skip_columns=skip_columns,
504
+ function_mapping=function_mapping,
497
505
  **kwargs
498
506
  )
499
507
 
500
508
  if df is None:
501
509
  # 如果df为空,直接返回代码
502
510
  return code
511
+ elif style == 'rust':
512
+ return code
503
513
  elif style == 'sql':
504
514
  with pl.SQLContext(frames={table_name: df}) as ctx:
505
515
  return ctx.execute(code, eager=isinstance(df, _pl_DataFrame))
@@ -1 +0,0 @@
1
- __version__ = "0.15.2"
File without changes
File without changes