expr-codegen 0.12.1__tar.gz → 0.13.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. expr_codegen-0.13.0/.gitignore +160 -0
  2. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/PKG-INFO +22 -20
  3. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/README.md +12 -8
  4. expr_codegen-0.13.0/expr_codegen/_version.py +1 -0
  5. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/codes.py +6 -5
  6. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/model.py +27 -55
  7. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/pandas/code.py +12 -1
  8. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/pandas/printer.py +2 -0
  9. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/pandas/ta.py +3 -0
  10. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/pandas/template.py.j2 +8 -10
  11. {expr_codegen-0.12.1/expr_codegen/polars_over → expr_codegen-0.13.0/expr_codegen/polars}/code.py +13 -2
  12. {expr_codegen-0.12.1/expr_codegen/polars_over → expr_codegen-0.13.0/expr_codegen/polars}/template.py.j2 +14 -10
  13. expr_codegen-0.13.0/expr_codegen/sql/code.py +106 -0
  14. {expr_codegen-0.12.1/expr_codegen/polars_over → expr_codegen-0.13.0/expr_codegen/sql}/printer.py +47 -7
  15. expr_codegen-0.13.0/expr_codegen/sql/template.sql.j2 +29 -0
  16. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/tool.py +32 -21
  17. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/pyproject.toml +11 -19
  18. expr_codegen-0.12.1/expr_codegen/_version.py +0 -1
  19. expr_codegen-0.12.1/expr_codegen/polars_group/code.py +0 -115
  20. expr_codegen-0.12.1/expr_codegen/polars_group/template.py.j2 +0 -83
  21. expr_codegen-0.12.1/expr_codegen.egg-info/PKG-INFO +0 -319
  22. expr_codegen-0.12.1/expr_codegen.egg-info/SOURCES.txt +0 -31
  23. expr_codegen-0.12.1/expr_codegen.egg-info/dependency_links.txt +0 -1
  24. expr_codegen-0.12.1/expr_codegen.egg-info/requires.txt +0 -11
  25. expr_codegen-0.12.1/expr_codegen.egg-info/top_level.txt +0 -1
  26. expr_codegen-0.12.1/setup.cfg +0 -4
  27. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/LICENSE +0 -0
  28. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/__init__.py +0 -0
  29. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/dag.py +0 -0
  30. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/expr.py +0 -0
  31. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/latex/__init__.py +0 -0
  32. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/latex/printer.py +0 -0
  33. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/pandas/__init__.py +0 -0
  34. {expr_codegen-0.12.1 → expr_codegen-0.13.0}/expr_codegen/pandas/helper.py +0 -0
  35. {expr_codegen-0.12.1/expr_codegen/polars_group → expr_codegen-0.13.0/expr_codegen/polars}/__init__.py +0 -0
  36. {expr_codegen-0.12.1/expr_codegen/polars_group → expr_codegen-0.13.0/expr_codegen/polars}/printer.py +0 -0
  37. {expr_codegen-0.12.1/expr_codegen/polars_over → expr_codegen-0.13.0/expr_codegen/sql}/__init__.py +0 -0
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: expr_codegen
3
- Version: 0.12.1
3
+ Version: 0.13.0
4
4
  Summary: symbol expression to polars expression tool
5
5
  Author-email: wukan <wu-kan@163.com>
6
6
  License: BSD 3-Clause License
@@ -31,24 +31,22 @@ License: BSD 3-Clause License
31
31
  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32
32
  OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33
33
  OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
-
35
- Keywords: polars,expression,talib
34
+ License-File: LICENSE
35
+ Keywords: expression,polars,talib
36
36
  Classifier: Development Status :: 4 - Beta
37
37
  Classifier: Programming Language :: Python
38
38
  Requires-Python: >=3.9
39
- Description-Content-Type: text/markdown
40
- License-File: LICENSE
39
+ Requires-Dist: ast-comments
41
40
  Requires-Dist: black
42
- Requires-Dist: Jinja2
43
- Requires-Dist: networkx
41
+ Requires-Dist: jinja2
44
42
  Requires-Dist: loguru
43
+ Requires-Dist: networkx
45
44
  Requires-Dist: sympy
46
- Requires-Dist: ast-comments
47
45
  Provides-Extra: streamlit
48
- Requires-Dist: streamlit; extra == "streamlit"
49
- Requires-Dist: streamlit-ace; extra == "streamlit"
50
- Requires-Dist: more_itertools; extra == "streamlit"
51
- Dynamic: license-file
46
+ Requires-Dist: more-itertools; extra == 'streamlit'
47
+ Requires-Dist: streamlit; extra == 'streamlit'
48
+ Requires-Dist: streamlit-ace; extra == 'streamlit'
49
+ Description-Content-Type: text/markdown
52
50
 
53
51
  # expr_codegen 表达式转译器
54
52
 
@@ -81,6 +79,8 @@ https://exprcodegen.streamlit.app
81
79
  import sys
82
80
  from io import StringIO
83
81
 
82
+ import polars as pl
83
+
84
84
  from expr_codegen import codegen_exec
85
85
 
86
86
 
@@ -109,18 +109,20 @@ def _code_block_2():
109
109
  CPV = cs_zscore(_corr) + cs_zscore(_beta)
110
110
 
111
111
 
112
- code = StringIO()
112
+ code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=sys.stdout) # 打印代码
113
+ code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file="output.py") # 保存到文件
114
+ code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by') # 只执行,不保存代码
113
115
 
114
- df = None # 替换成真实的polars数据
115
- df = codegen_exec(df, _code_block_1, _code_block_2, output_file=sys.stdout) # 打印代码
116
- df = codegen_exec(df, _code_block_1, _code_block_2, output_file="output.py") # 保存到文件
117
- df = codegen_exec(df, _code_block_1, _code_block_2) # 只执行,不保存代码
118
- df = codegen_exec(df, _code_block_1, _code_block_2, output_file=code) # 保存到字符串
116
+ code = StringIO()
117
+ codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=code) # 保存到字符串
119
118
  code.seek(0)
120
119
  code.read() # 读取代码
121
120
 
122
- df = codegen_exec(df.lazy(), _code_block_1, _code_block_2).collect() # Lazy CPU
123
- df = codegen_exec(df.lazy(), _code_block_1, _code_block_2).collect(engine="gpu") # Lazy GPU
121
+ # TODO 替换成合适的数据
122
+ df = pl.DataFrame()
123
+ df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect() # Lazy CPU
124
+ df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect(engine="gpu") # Lazy GPU
125
+
124
126
  ```
125
127
 
126
128
  ## 目录结构
@@ -29,6 +29,8 @@ https://exprcodegen.streamlit.app
29
29
  import sys
30
30
  from io import StringIO
31
31
 
32
+ import polars as pl
33
+
32
34
  from expr_codegen import codegen_exec
33
35
 
34
36
 
@@ -57,18 +59,20 @@ def _code_block_2():
57
59
  CPV = cs_zscore(_corr) + cs_zscore(_beta)
58
60
 
59
61
 
60
- code = StringIO()
62
+ code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=sys.stdout) # 打印代码
63
+ code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file="output.py") # 保存到文件
64
+ code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by') # 只执行,不保存代码
61
65
 
62
- df = None # 替换成真实的polars数据
63
- df = codegen_exec(df, _code_block_1, _code_block_2, output_file=sys.stdout) # 打印代码
64
- df = codegen_exec(df, _code_block_1, _code_block_2, output_file="output.py") # 保存到文件
65
- df = codegen_exec(df, _code_block_1, _code_block_2) # 只执行,不保存代码
66
- df = codegen_exec(df, _code_block_1, _code_block_2, output_file=code) # 保存到字符串
66
+ code = StringIO()
67
+ codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=code) # 保存到字符串
67
68
  code.seek(0)
68
69
  code.read() # 读取代码
69
70
 
70
- df = codegen_exec(df.lazy(), _code_block_1, _code_block_2).collect() # Lazy CPU
71
- df = codegen_exec(df.lazy(), _code_block_1, _code_block_2).collect(engine="gpu") # Lazy GPU
71
+ # TODO 替换成合适的数据
72
+ df = pl.DataFrame()
73
+ df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect() # Lazy CPU
74
+ df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect(engine="gpu") # Lazy GPU
75
+
72
76
  ```
73
77
 
74
78
  ## 目录结构
@@ -0,0 +1 @@
1
+ __version__ = "0.13.0"
@@ -381,15 +381,16 @@ def sources_to_asts(*sources, convert_xor: bool):
381
381
  if isinstance(node, ast.Assign):
382
382
  assigns.append(node)
383
383
  continue
384
- if isinstance(node, ast_comments.Comment):
385
- # 添加注释
386
- if node.inline and isinstance(assigns[-1], ast.Assign):
387
- assigns.append(node)
388
- continue
389
384
  # TODO 是否要把其它语句也加入?是否有安全问题?
390
385
  if isinstance(node, (ast.Import, ast.ImportFrom)):
391
386
  raw.append(node)
392
387
  continue
388
+ if isinstance(node, ast_comments.Comment):
389
+ # 添加注释
390
+ if node.inline and isinstance(tree.body[i - 1], ast.Assign):
391
+ assigns.append(node)
392
+ continue
393
+
393
394
  return raw_to_code(raw), assigns_to_list(assigns), t.funcs_new, t.args_new, t.targets_new
394
395
 
395
396
 
@@ -1,10 +1,10 @@
1
1
  from functools import reduce
2
- from itertools import product
2
+ from itertools import product, permutations
3
3
 
4
4
  import networkx as nx
5
5
  from sympy import symbols
6
6
 
7
- from expr_codegen.dag import zero_indegree, hierarchy_pos, remove_paths_by_zero_outdegree, zero_outdegree
7
+ from expr_codegen.dag import zero_indegree, hierarchy_pos, remove_paths_by_zero_outdegree
8
8
  from expr_codegen.expr import CL, get_symbols, get_children, get_key, is_simple_expr
9
9
 
10
10
  _RESERVED_WORD_ = {'_NONE_', '_TRUE_', '_FALSE_'}
@@ -122,6 +122,18 @@ class ListDictList:
122
122
  return l3
123
123
 
124
124
 
125
+ def score1(row) -> int:
126
+ # 首尾相连打分加1
127
+ lst = [None] + [key for r in row for key in dict(r).keys()]
128
+ return sum([x == y for x, y in zip(lst[:-1], lst[1:])])
129
+
130
+
131
+ def score2(row) -> float:
132
+ # 最后一个ts越靠前,打分越高
133
+ lst = ['ts'] + [key[0] for r in row for key in dict(r).keys()]
134
+ return lst[::-1].index('ts') / len(lst)
135
+
136
+
125
137
  def chain_create(nested_list):
126
138
  """接龙。多个列表,头尾相连
127
139
 
@@ -131,63 +143,22 @@ def chain_create(nested_list):
131
143
  alpha_031 = ((cs_rank(cs_rank(cs_rank(ts_decay_linear((-1 * cs_rank(cs_rank(ts_delta(CLOSE, 10)))), 10))))))
132
144
 
133
145
  """
134
- # 两两取交集,交集为{}时,添加一个{None},防止product时出错
135
- neighbor_inter = [set(x) & set(y) or {None} for x, y in zip(nested_list[:-1], nested_list[1:])]
146
+ perms = []
147
+ for d in nested_list:
148
+ # 每一层生成排列
149
+ perms.append(permutations(d.items()))
136
150
 
137
- # 查找最小数字,表示两两不重复
138
- last_min = float('inf')
139
- # 最小不重复的一行记录
151
+ last_score = float('-inf')
140
152
  last_row = None
141
- last_rows = set()
142
- for row in product(*neighbor_inter):
143
- # 判断两两是否重复,重复为1,反之为0
144
- result = sum([x == y for x, y in zip(row[:-1], row[1:])])
145
- if last_min > result:
146
- last_min = result
153
+ # 生成笛卡尔积
154
+ for row in product(*perms):
155
+ result = score1(row) + score2(row)
156
+ # print(result, row)
157
+ if result > last_score:
158
+ last_score = result
147
159
  last_row = row
148
- if result == 0:
149
- last_rows.add(last_row)
150
- last_min = float('inf')
151
- continue
152
- last_rows.add(last_row)
153
- last_rows = list(last_rows)
154
-
155
- # last_rows中有多个满足条件的,优先保证最后一组ts在最前,ts后可提前filter减少计算量
156
- last_row = last_rows[0]
157
- for row in last_rows:
158
- if len(row) == 0:
159
- # 一行表达式
160
- continue
161
- if row[-1] is None:
162
- continue
163
- if row[-1][0] == 'ts':
164
- last_row = row
165
- break
166
-
167
- # 如何移动才是难点 如果两个连续 ts/ts,那么如何移动
168
-
169
- # 调整后的第0列
170
- head = [None] + list(last_row)
171
- # 调整后的第-1列
172
- tail = list(last_row) + [None]
173
-
174
- # 调整新列表
175
- arr = []
176
- for ll, hh, tt in zip(nested_list, head, tail):
177
- d = []
178
- for k, v in ll.items():
179
- if len(d) == 0:
180
- d.append((k, v))
181
- continue
182
- if k == hh:
183
- d.insert(0, (k, v))
184
- elif k == tt:
185
- d.append((k, v))
186
- else:
187
- d.insert(1, (k, v))
188
- arr.append(dict(d))
189
160
 
190
- return arr
161
+ return [dict(ro) for ro in last_row]
191
162
 
192
163
 
193
164
  # ==========================
@@ -425,6 +396,7 @@ def dag_end(G):
425
396
 
426
397
  exprs_ldl.append(key, (node, expr, symbols, comment))
427
398
 
399
+ # 第0层是CLOSE等基础因子,剔除
428
400
  exprs_ldl._list = exprs_ldl.values()[1:]
429
401
 
430
402
  return exprs_ldl, G
@@ -38,6 +38,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
38
38
  filename,
39
39
  date='date', asset='asset',
40
40
  extra_codes: Sequence[str] = (),
41
+ filter_last: bool = False,
41
42
  **kwargs):
42
43
  """基于模板的代码生成"""
43
44
  if filename is None:
@@ -53,7 +54,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
53
54
  # 处理过后的表达式
54
55
  exprs_dst = []
55
56
  syms_out = []
56
-
57
+ ts_func_name = None
57
58
  drop_symbols = exprs_ldl.drop_symbols()
58
59
  j = -1
59
60
  for i, row in enumerate(exprs_ldl.values()):
@@ -78,6 +79,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
78
79
  if len(groupbys['sort']) == 0:
79
80
  groupbys['sort'] = f'df = df.sort_values(by=[_ASSET_, _DATE_]).reset_index(drop=True)'
80
81
  if k[0] == TS:
82
+ ts_func_name = func_name
81
83
  # 时序需要排序
82
84
  func_code = [f' g.df = df.sort_values(by=[_DATE_])'] + func_code
83
85
  else:
@@ -93,6 +95,15 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
93
95
 
94
96
  syms1 = symbols_to_code(syms_dst)
95
97
  syms2 = symbols_to_code(syms_out)
98
+ if filter_last:
99
+ _groupbys = {'sort': groupbys['sort']}
100
+ if ts_func_name is None:
101
+ _groupbys['_filter_last'] = "df = filter_last(df.sort_values(by=[_DATE_]))"
102
+ for k, v in groupbys.items():
103
+ _groupbys[k] = v
104
+ if k == ts_func_name:
105
+ _groupbys[k + '_filter_last'] = "df = filter_last(df)"
106
+ groupbys = _groupbys
96
107
 
97
108
  try:
98
109
  env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__)))
@@ -54,6 +54,8 @@ class PandasStrPrinter(StrPrinter):
54
54
  self._print_level -= 1
55
55
 
56
56
  def _print_Symbol(self, expr):
57
+ if expr.name in ('_NONE_', '_TRUE_', '_FALSE_'):
58
+ return expr.name
57
59
  return f"g[{expr.name}]"
58
60
 
59
61
  def _print_Equality(self, expr):
@@ -73,6 +73,9 @@ def ts_delay(x: pd.Series, d: int = 1) -> pd.Series:
73
73
  def ts_delta(x: pd.Series, d: int = 1) -> pd.Series:
74
74
  return x.diff(d)
75
75
 
76
+ def ts_returns(x: pd.Series, d: int = 1) -> pd.Series:
77
+ return x.pct_change(d)
78
+
76
79
 
77
80
  def ts_max(x: pd.Series, d: int = 5) -> pd.Series:
78
81
  return x.rolling(d).max()
@@ -22,6 +22,7 @@ _FALSE_ = False
22
22
 
23
23
  g = GlobalVariable()
24
24
 
25
+
25
26
  def unpack(x: Tuple, idx: int = 0) -> pd.Series:
26
27
  return x[idx]
27
28
 
@@ -30,9 +31,11 @@ def unpack(x: Tuple, idx: int = 0) -> pd.Series:
30
31
  {% endfor %}
31
32
 
32
33
  {% for key, value in funcs.items() %}
34
+
33
35
  def {{ key }}(df: pd.DataFrame) -> pd.DataFrame:
34
36
  {{ value }}
35
37
  return g.df
38
+
36
39
  {% endfor %}
37
40
 
38
41
  """
@@ -48,8 +51,12 @@ def {{ key }}(df: pd.DataFrame) -> pd.DataFrame:
48
51
  """
49
52
 
50
53
 
54
+ def filter_last(df: pd.DataFrame) -> pd.DataFrame:
55
+ """过滤数据,只取最后一天。实盘时可用于减少计算量"""
56
+ return df[df[_DATE_] >= df[_DATE_].iloc[-1]]
57
+
58
+
51
59
  def main(df: pd.DataFrame) -> pd.DataFrame:
52
- # logger.info("start...")
53
60
  {% for key, value in groupbys.items() %}
54
61
  {{ value-}}
55
62
  {% endfor %}
@@ -57,13 +64,4 @@ def main(df: pd.DataFrame) -> pd.DataFrame:
57
64
  # drop intermediate columns
58
65
  df = df.drop(columns=list(filter(lambda x: x.startswith("_"), df.columns)))
59
66
 
60
- # logger.info('done')
61
-
62
- # save
63
- # df.to_parquet('output.parquet', compression='zstd')
64
-
65
67
  return df
66
-
67
- # if __name__ in ("__main__", "builtins"):
68
- # # TODO: 数据加载或外部传入
69
- # df_output = main(df_input)
@@ -7,7 +7,7 @@ from jinja2 import FileSystemLoader, TemplateNotFound
7
7
 
8
8
  from expr_codegen.expr import TS, CS, GP
9
9
  from expr_codegen.model import ListDictList
10
- from expr_codegen.polars_over.printer import PolarsStrPrinter
10
+ from expr_codegen.polars.printer import PolarsStrPrinter
11
11
 
12
12
 
13
13
  def get_groupby_from_tuple(tup, func_name, drop_cols):
@@ -40,6 +40,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
40
40
  date='date', asset='asset',
41
41
  extra_codes: Sequence[str] = (),
42
42
  over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
43
+ filter_last: bool = False,
43
44
  **kwargs):
44
45
  """基于模板的代码生成"""
45
46
  if filename is None:
@@ -58,7 +59,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
58
59
  # 处理过后的表达式
59
60
  exprs_dst = []
60
61
  syms_out = []
61
-
62
+ ts_func_name = None
62
63
  drop_symbols = exprs_ldl.drop_symbols()
63
64
  j = -1
64
65
  for i, row in enumerate(exprs_ldl.values()):
@@ -85,6 +86,7 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
85
86
  # 不想等,打印注释,显示会更直观察
86
87
  func_code.append(f"# {va} = {s1}")
87
88
  if k[0] == TS:
89
+ ts_func_name = func_name
88
90
  # https://github.com/pola-rs/polars/issues/12925#issuecomment-2552764629
89
91
  _sym = [f"{s}.is_not_null()" for s in set(sym)]
90
92
  if len(_sym) > 1:
@@ -118,6 +120,15 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
118
120
 
119
121
  syms1 = symbols_to_code(syms_dst)
120
122
  syms2 = symbols_to_code(syms_out)
123
+ if filter_last:
124
+ _groupbys = {'sort': groupbys['sort']}
125
+ if ts_func_name is None:
126
+ _groupbys['_filter_last'] = "df = filter_last(df.sort(_DATE_))"
127
+ for k, v in groupbys.items():
128
+ _groupbys[k] = v
129
+ if k == ts_func_name:
130
+ _groupbys[k + '_filter_last'] = "df = filter_last(df)"
131
+ groupbys = _groupbys
121
132
 
122
133
  try:
123
134
  env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__)))
@@ -8,6 +8,7 @@ import polars.selectors as cs # noqa
8
8
  # from loguru import logger # noqa
9
9
  from polars import DataFrame as _pl_DataFrame
10
10
  from polars import LazyFrame as _pl_LazyFrame
11
+
11
12
  # ===================================
12
13
  # 导入优先级,例如:ts_RSI在ta与talib中都出现了,优先使用ta
13
14
  # 运行时,后导入覆盖前导入,但IDE智能提示是显示先导入的
@@ -31,6 +32,7 @@ _NONE_ = None
31
32
  _TRUE_ = True
32
33
  _FALSE_ = False
33
34
 
35
+
34
36
  def unpack(x: pl.Expr, idx: int = 0) -> pl.Expr:
35
37
  return x.struct[idx]
36
38
 
@@ -39,9 +41,11 @@ def unpack(x: pl.Expr, idx: int = 0) -> pl.Expr:
39
41
  {% endfor %}
40
42
 
41
43
  {% for key, value in funcs.items() %}
44
+
42
45
  def {{ key }}(df: DataFrame) -> DataFrame:
43
46
  {{ value }}
44
47
  return df
48
+
45
49
  {% endfor %}
46
50
 
47
51
  """
@@ -57,8 +61,17 @@ def {{ key }}(df: DataFrame) -> DataFrame:
57
61
  """
58
62
 
59
63
 
64
+ def filter_last(df: DataFrame) -> DataFrame:
65
+ """过滤数据,只取最后一天。实盘时可用于减少计算量
66
+ 前一个调用的ts,这里可以直接调用,可以认为已经排序好
67
+ `df = filter_last(df)`
68
+ 反之
69
+ `df = filter_last(df.sort(_DATE_))`
70
+ """
71
+ return df.filter(pl.col(_DATE_) >= df.select(pl.last(_DATE_))[0, 0])
72
+
73
+
60
74
  def main(df: DataFrame) -> DataFrame:
61
- # logger.info("start...")
62
75
  {% for key, value in groupbys.items() %}
63
76
  {{ value-}}
64
77
  {% endfor %}
@@ -69,15 +82,6 @@ def main(df: DataFrame) -> DataFrame:
69
82
 
70
83
  # shrink
71
84
  df = df.select(cs.all().shrink_dtype())
72
- # df = df.shrink_to_fit()
73
-
74
- # logger.info('done')
75
-
76
- # save
77
- # df.write_parquet('output.parquet')
78
85
 
79
86
  return df
80
87
 
81
- # if __name__ in ("__main__", "builtins"):
82
- # # TODO: 数据加载或外部传入
83
- # df_output = main(df_input)