expr-codegen 0.16.2__tar.gz → 0.16.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/PKG-INFO +2 -2
  2. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/README.md +1 -1
  3. expr_codegen-0.16.4/expr_codegen/_version.py +1 -0
  4. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/codes.py +2 -2
  5. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/model.py +4 -2
  6. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/polars/code.py +11 -6
  7. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/polars/template.py.j2 +1 -0
  8. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/rust/code.py +11 -7
  9. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/sql/code.py +10 -5
  10. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/tool.py +31 -21
  11. expr_codegen-0.16.2/expr_codegen/_version.py +0 -1
  12. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/.gitignore +0 -0
  13. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/LICENSE +0 -0
  14. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/__init__.py +0 -0
  15. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/dag.py +0 -0
  16. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/expr.py +0 -0
  17. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/latex/__init__.py +0 -0
  18. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/latex/printer.py +0 -0
  19. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/pandas/__init__.py +0 -0
  20. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/pandas/code.py +0 -0
  21. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/pandas/helper.py +0 -0
  22. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/pandas/printer.py +0 -0
  23. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/pandas/ta.py +0 -0
  24. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/pandas/template.py.j2 +0 -0
  25. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/polars/__init__.py +0 -0
  26. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/polars/printer.py +0 -0
  27. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/rust/__init__.py +0 -0
  28. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/rust/printer.py +0 -0
  29. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/rust/template.rs.j2 +0 -0
  30. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/sql/__init__.py +0 -0
  31. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/sql/printer.py +0 -0
  32. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/expr_codegen/sql/template.sql.j2 +0 -0
  33. {expr_codegen-0.16.2 → expr_codegen-0.16.4}/pyproject.toml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: expr_codegen
3
- Version: 0.16.2
3
+ Version: 0.16.4
4
4
  Summary: symbol expression to polars expression tool
5
5
  Author-email: wukan <wu-kan@163.com>
6
6
  License: BSD 3-Clause License
@@ -276,7 +276,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
276
276
  ```
277
277
  11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
278
278
  12. 代码块中,对`import`、`def`、`class`三种语法,自动原样插入到生成的代码中
279
- 13. 由于`sympy`的限制不支持关键字参数,但如果`from polars_ta.prefix.wq import *`,然后`codegen_exec(function_mapping=globals())`,底层会试着将关键字参数转成位置参数使用
279
+ 13. 虽然`sympy`的限制不支持关键字参数,但`codegen_exec`底层会试着将关键字参数转成位置参数使用
280
280
 
281
281
  ## 下划线开头的变量
282
282
 
@@ -226,7 +226,7 @@ X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
226
226
  ```
227
227
  11. 单行注释支持参数输入,如:`# --over_null`、`# --over_null=order_by`、`# --over_null=partition_by`
228
228
  12. 代码块中,对`import`、`def`、`class`三种语法,自动原样插入到生成的代码中
229
- 13. 由于`sympy`的限制不支持关键字参数,但如果`from polars_ta.prefix.wq import *`,然后`codegen_exec(function_mapping=globals())`,底层会试着将关键字参数转成位置参数使用
229
+ 13. 虽然`sympy`的限制不支持关键字参数,但`codegen_exec`底层会试着将关键字参数转成位置参数使用
230
230
 
231
231
  ## 下划线开头的变量
232
232
 
@@ -0,0 +1 @@
1
+ __version__ = "0.16.4"
@@ -489,9 +489,9 @@ def _add_default_type(globals_):
489
489
  return globals_
490
490
 
491
491
 
492
- def sources_to_exprs(globals_, *sources, convert_xor: bool, function_mapping):
492
+ def sources_to_exprs(globals_, *sources, convert_xor: bool):
493
493
  """将源代码转换成表达式"""
494
-
494
+ function_mapping = {k: v for k, v in globals_.items() if inspect.isfunction(v)}
495
495
  globals_ = _add_default_type(globals_)
496
496
 
497
497
  raw, assigns, funcs_new, args_new, targets_new = sources_to_asts(*sources, convert_xor=convert_xor, function_mapping=function_mapping)
@@ -365,12 +365,14 @@ def dag_start(exprs_list, func, func_kwargs, date, asset):
365
365
  return G
366
366
 
367
367
 
368
- def dag_middle(G, exprs_names, skip_columns, func, func_kwargs, date, asset):
368
+ def dag_middle(G, exprs_names, skip_columns, func, func_kwargs, date, asset, skip_simplify):
369
369
  """删除几个没有必要的节点"""
370
370
  # 以下划线开头的节点,不保留
371
371
  keep_nodes = [k for k in exprs_names if not k.startswith('_')]
372
372
 
373
- G = merge_nodes_1(G, keep_nodes, *keep_nodes)
373
+ if not skip_simplify:
374
+ # ts_rank(-RET - -RET, 20),防止替换成ts_rank(0, 20)
375
+ G = merge_nodes_1(G, keep_nodes, *keep_nodes)
374
376
  G = merge_nodes_2(G, keep_nodes, *keep_nodes)
375
377
 
376
378
  # 移除0出度的节点,但保留部分
@@ -14,7 +14,7 @@ def get_groupby_from_tuple(tup, func_name, drop_cols):
14
14
  """从传入的元组中生成分组运行代码"""
15
15
  prefix2, *_ = tup
16
16
 
17
- if len(drop_cols)>0:
17
+ if len(drop_cols) > 0:
18
18
  drop_str = f'.drop(*{drop_cols})'
19
19
  else:
20
20
  drop_str = ""
@@ -98,12 +98,17 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
98
98
  _sym = f"pl.all_horizontal({','.join(_sym)})"
99
99
  else:
100
100
  _sym = ','.join(_sym)
101
- if args.over_null == 'partition_by':
102
- func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),")
103
- elif args.over_null == 'order_by':
104
- func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),")
105
- else:
101
+
102
+ if len(_sym) == 0:
106
103
  func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),")
104
+ else:
105
+ if args.over_null == 'partition_by':
106
+ func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),")
107
+ elif args.over_null == 'order_by':
108
+ func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),")
109
+ else:
110
+ func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),")
111
+
107
112
  elif k[0] == CS:
108
113
  func_code.append(f"{va}=({s2}).over(_DATE_),")
109
114
  elif k[0] == GP:
@@ -19,6 +19,7 @@ from polars_ta.prefix.ta import * # noqa
19
19
  from polars_ta.prefix.wq import * # noqa
20
20
  from polars_ta.prefix.cdl import * # noqa
21
21
  from polars_ta.prefix.vec import * # noqa
22
+ from polars_ta.utils.functions import apply_const_to_expr # noqa
22
23
 
23
24
  DataFrame = TypeVar('DataFrame', _pl_LazyFrame, _pl_DataFrame)
24
25
  # ===================================
@@ -1,5 +1,4 @@
1
1
  import argparse
2
- import json
3
2
  import os
4
3
  from typing import Sequence, Literal
5
4
 
@@ -14,7 +13,7 @@ from expr_codegen.rust.printer import RustStrPrinter
14
13
  def get_groupby_from_tuple(tup, func_name, drop_cols):
15
14
  """从传入的元组中生成分组运行代码"""
16
15
  prefix2, *_ = tup
17
- if len(drop_cols)>0:
16
+ if len(drop_cols) > 0:
18
17
  drop_cols = [f'"{c}".into()' for c in drop_cols]
19
18
  drop_str = f'.drop(Selector::ByName {{ names: Arc::new([{','.join(drop_cols)}]), strict: true }})'
20
19
  else:
@@ -99,12 +98,17 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
99
98
  _sym = f"all_horizontal([{','.join(_sym)}]).unwrap()"
100
99
  else:
101
100
  _sym = ','.join(_sym)
102
- if args.over_null == 'partition_by':
103
- func_code.append(f'({s2}).over_with_options(Some([{_sym}, col(_ASSET_)]), Some(([col(_DATE_), lit(1)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
104
- elif args.over_null == 'order_by':
105
- func_code.append(f'({s2}).over_with_options(Some([col(_ASSET_), lit(1)]), Some(([{_sym}, col(_DATE_)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
106
- else:
101
+
102
+ if len(_sym) == 0:
107
103
  func_code.append(f'({s2}).over_with_options(Some([_ASSET_]), Some(([_DATE_], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
104
+ else:
105
+ if args.over_null == 'partition_by':
106
+ func_code.append(f'({s2}).over_with_options(Some([{_sym}, col(_ASSET_)]), Some(([col(_DATE_), lit(1)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
107
+ elif args.over_null == 'order_by':
108
+ func_code.append(f'({s2}).over_with_options(Some([col(_ASSET_), lit(1)]), Some(([{_sym}, col(_DATE_)], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
109
+ else:
110
+ func_code.append(f'({s2}).over_with_options(Some([_ASSET_]), Some(([_DATE_], SortOptions::default())), WindowMapping::default()).unwrap().alias("{va}"),')
111
+
108
112
  elif k[0] == CS:
109
113
  func_code.append(f'({s2}).over([_DATE_]).alias("{va}"),')
110
114
  elif k[0] == GP:
@@ -65,12 +65,17 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
65
65
  _sym = f"({' AND '.join(_sym)})"
66
66
  else:
67
67
  _sym = ','.join(_sym)
68
- if args.over_null == 'partition_by':
69
- func_code.append(f"{s2} OVER(PARTITION BY {_sym},`{asset}` ORDER BY `{date}`) AS {va},")
70
- elif args.over_null == 'order_by':
71
- func_code.append(f"{s2} OVER(PARTITION BY `{asset}` ORDER BY {_sym},`{date}`) AS {va},")
72
- else:
68
+
69
+ if len(_sym) == 0:
73
70
  func_code.append(f"{s2} OVER(PARTITION BY `{asset}` ORDER BY `{date}`) AS {va},")
71
+ else:
72
+ if args.over_null == 'partition_by':
73
+ func_code.append(f"{s2} OVER(PARTITION BY {_sym},`{asset}` ORDER BY `{date}`) AS {va},")
74
+ elif args.over_null == 'order_by':
75
+ func_code.append(f"{s2} OVER(PARTITION BY `{asset}` ORDER BY {_sym},`{date}`) AS {va},")
76
+ else:
77
+ func_code.append(f"{s2} OVER(PARTITION BY `{asset}` ORDER BY `{date}`) AS {va},")
78
+
74
79
  elif k[0] == CS:
75
80
  func_code.append(f"{s2} OVER(PARTITION BY `{date}`) AS {va},")
76
81
  elif k[0] == GP:
@@ -1,8 +1,9 @@
1
1
  import inspect
2
+ import os
2
3
  import pathlib
3
4
  from functools import lru_cache
4
5
  from io import TextIOBase
5
- from typing import Sequence, Union, TypeVar, Optional, Literal, Iterable, Dict
6
+ from typing import Sequence, Union, TypeVar, Optional, Literal, Iterable, Dict, Tuple
6
7
 
7
8
  import polars as pl
8
9
  from black import Mode, format_str
@@ -148,7 +149,7 @@ class ExprTool:
148
149
 
149
150
  return exprs_list
150
151
 
151
- def cse(self, exprs, symbols_repl=None, exprs_src=None):
152
+ def cse(self, exprs, symbols_repl=None, exprs_src=None, skip_simplify=False):
152
153
  """多个子公式+长公式,提取公共公式
153
154
 
154
155
  Parameters
@@ -175,7 +176,11 @@ class ExprTool:
175
176
  _exprs = [k for k, v in exprs]
176
177
 
177
178
  # 注意:对于表达式右边相同,左边不同的情况,会当成一个处理
178
- repl, redu = cse(_exprs, symbols_repl, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post), ])
179
+ if skip_simplify:
180
+ repl, redu = cse(_exprs, symbols_repl, optimizations=[])
181
+ else:
182
+ repl, redu = cse(_exprs, symbols_repl, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post), ])
183
+
179
184
  outputs_len = len(exprs_src)
180
185
 
181
186
  new_redu = []
@@ -193,12 +198,11 @@ class ExprTool:
193
198
 
194
199
  return self.exprs_list
195
200
 
196
- def dag(self, merge: bool, skip_columns, date, asset):
201
+ def dag(self, merge: bool, skip_columns, date, asset, skip_simplify):
197
202
  """生成DAG"""
198
203
  G = dag_start(self.exprs_list, self.get_current_func, self.get_current_func_kwargs, date, asset)
199
204
  if merge:
200
- G = dag_middle(G, self.exprs_names, skip_columns, self.get_current_func, self.get_current_func_kwargs, date,
201
- asset)
205
+ G = dag_middle(G, self.exprs_names, skip_columns, self.get_current_func, self.get_current_func_kwargs, date, asset, skip_simplify)
202
206
  return dag_end(G)
203
207
 
204
208
  def all(self, exprs_src, style: Literal['pandas', 'polars', 'sql', 'rust'] = 'polars',
@@ -251,13 +255,13 @@ class ExprTool:
251
255
  exprs_src = replace_exprs(exprs_src)
252
256
 
253
257
  # 子表达式在前,原表式在最后
254
- exprs_dst, syms_dst = self.merge(date, asset, exprs_src, skip_simplify)
258
+ exprs_dst, syms_dst = self.merge(date, asset, exprs_src, skip_simplify=skip_simplify)
255
259
  syms_dst = list(set(syms_dst) - _RESERVED_WORD_)
256
260
 
257
261
  # 提取公共表达式
258
- self.cse(exprs_dst, symbols_repl=numbered_symbols('_x_'), exprs_src=exprs_src)
262
+ self.cse(exprs_dst, symbols_repl=numbered_symbols('_x_'), exprs_src=exprs_src, skip_simplify=skip_simplify)
259
263
  # 有向无环图流转
260
- exprs_ldl, G = self.dag(True, skip_columns, date, asset)
264
+ exprs_ldl, G = self.dag(True, skip_columns, date, asset, skip_simplify=skip_simplify)
261
265
 
262
266
  if regroup:
263
267
  exprs_ldl.optimize(merge=style != 'sql')
@@ -293,7 +297,7 @@ class ExprTool:
293
297
 
294
298
  return codes, G
295
299
 
296
- # @lru_cache(maxsize=64)
300
+ @lru_cache(maxsize=64)
297
301
  def _get_code(self,
298
302
  source: str, *more_sources: str,
299
303
  extra_codes: str,
@@ -307,10 +311,9 @@ class ExprTool:
307
311
  ge_date_idx: int = 0,
308
312
  skip_simplify: bool = False,
309
313
  skip_columns: Iterable[str] = (),
310
- function_mapping={},
311
314
  **kwargs) -> str:
312
315
  """通过字符串生成代码, 加了缓存,多次调用不重复生成"""
313
- raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor, function_mapping=function_mapping)
316
+ raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor)
314
317
 
315
318
  # 生成代码
316
319
  code, G = _TOOL_.all(exprs_list, style=style, template_file=template_file,
@@ -391,9 +394,8 @@ def codegen_exec(df: Union[DataFrame, None],
391
394
  ge_date_idx: int = 0,
392
395
  skip_simplify: bool = False,
393
396
  skip_columns: Iterable[str] = (),
394
- function_mapping: Dict = {},
395
397
  **kwargs) -> Union[DataFrame, str]:
396
- """快速转换源代码并执行
398
+ r"""快速转换源代码并执行
397
399
 
398
400
  Parameters
399
401
  ----------
@@ -403,7 +405,12 @@ def codegen_exec(df: Union[DataFrame, None],
403
405
  codes:
404
406
  函数体。此部分中的表达式会被翻译成目标代码
405
407
  extra_codes: str
406
- 额外代码。不做处理,会被直接复制到目标代码中
408
+ 额外代码。不做处理,会被直接复制到目标代码中。例如:
409
+
410
+ r'CS_SW_L1 = r"^sw_l1_\d+$"'
411
+
412
+ apply_const_to_expr()
413
+
407
414
  output_file: str| TextIOBase
408
415
  保存生成的目标代码到文件中
409
416
  run_file: bool or str
@@ -412,6 +419,7 @@ def codegen_exec(df: Union[DataFrame, None],
412
419
  - 如果是字符串,会自动从run_file中读取代码
413
420
  - 如果是模块名,会自动从模块中读取代码(可调试)
414
421
  - 注意:可能调用到其他目录下的同名模块,所以保存的文件名要有辨识度
422
+ - 如果文件不存在,先生成文件。第二次从生成的文件中生成
415
423
  convert_xor: bool
416
424
  ^ 转成异或还是乘方
417
425
  style: str
@@ -438,12 +446,13 @@ def codegen_exec(df: Union[DataFrame, None],
438
446
  -2 表示最近两天 >=date[-2]
439
447
  skip_simplify:bool
440
448
  遗传算法时很有可能出现OPEN/OPEN,可以跳过化简步骤
449
+ 1. 跳过cse前的simplify
450
+ 2. 跳过cse时的optimizations
451
+ 3. 跳过DAG中的部分merge步骤
441
452
  skip_columns:
442
453
  已经存在的列不参与计算。可用于加快计算速度。只在计算耗时久时再用,否则没有必要
443
454
  例如:在研发阶段,第一次计算100个因子,第二次,只改动了其中的5个,所以只要将这5个从df.columns中排除即可。
444
455
  注意:生成的源代码有差异。
445
- function_mapping:
446
- 传入函数定义,可直接传`globals()`。用于将所有的关键字参数转换成位置参数
447
456
 
448
457
  Returns
449
458
  -------
@@ -473,10 +482,12 @@ def codegen_exec(df: Union[DataFrame, None],
473
482
 
474
483
  if input_file is not None:
475
484
  if input_file.endswith('.py'):
476
- return _get_func_from_file_py(input_file)(df, ge_date_idx)
485
+ if os.path.exists(input_file):
486
+ return _get_func_from_file_py(input_file)(df, ge_date_idx)
477
487
  elif input_file.endswith('.sql'):
478
- with pl.SQLContext(frames={table_name: df}) as ctx:
479
- return ctx.execute(_get_code_from_file(input_file), eager=isinstance(df, _pl_DataFrame))
488
+ if os.path.exists(input_file):
489
+ with pl.SQLContext(frames={table_name: df}) as ctx:
490
+ return ctx.execute(_get_code_from_file(input_file), eager=isinstance(df, _pl_DataFrame))
480
491
  else:
481
492
  return _get_func_from_module(input_file)(df, ge_date_idx) # 可断点调试
482
493
  else:
@@ -501,7 +512,6 @@ def codegen_exec(df: Union[DataFrame, None],
501
512
  ge_date_idx=ge_date_idx,
502
513
  skip_simplify=skip_simplify,
503
514
  skip_columns=skip_columns,
504
- function_mapping=function_mapping,
505
515
  **kwargs
506
516
  )
507
517
 
@@ -1 +0,0 @@
1
- __version__ = "0.16.2"
File without changes
File without changes