lidb 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lidb/__init__.py +30 -0
- lidb/database.py +234 -0
- lidb/dataset.py +442 -0
- lidb/init.py +42 -0
- lidb/parse.py +107 -0
- lidb/qdf/__init__.py +34 -0
- lidb/qdf/errors.py +65 -0
- lidb/qdf/expr.py +370 -0
- lidb/qdf/lazy.py +174 -0
- lidb/qdf/lazy2.py +161 -0
- lidb/qdf/qdf.py +161 -0
- lidb/qdf/udf/__init__.py +14 -0
- lidb/qdf/udf/base_udf.py +146 -0
- lidb/qdf/udf/cs_udf.py +115 -0
- lidb/qdf/udf/d_udf.py +183 -0
- lidb/qdf/udf/itd_udf.py +209 -0
- lidb/qdf/udf/ts_udf.py +182 -0
- lidb/svc/__init__.py +6 -0
- lidb/svc/data.py +138 -0
- lidb/table.py +129 -0
- lidb-1.2.0.dist-info/METADATA +18 -0
- lidb-1.2.0.dist-info/RECORD +24 -0
- lidb-1.2.0.dist-info/WHEEL +5 -0
- lidb-1.2.0.dist-info/top_level.txt +1 -0
lidb/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright (c) ZhangYundi.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
from .init import (
|
|
5
|
+
NAME,
|
|
6
|
+
DB_PATH,
|
|
7
|
+
CONFIG_PATH,
|
|
8
|
+
get_settings,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from .database import (
|
|
12
|
+
sql,
|
|
13
|
+
put,
|
|
14
|
+
has,
|
|
15
|
+
tb_path,
|
|
16
|
+
read_mysql,
|
|
17
|
+
write_mysql,
|
|
18
|
+
execute_mysql,
|
|
19
|
+
read_ck,
|
|
20
|
+
scan,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
from .table import Table, TableMode
|
|
24
|
+
from .dataset import Dataset, DataLoader
|
|
25
|
+
from .qdf import from_polars, Expr
|
|
26
|
+
from .svc import DataService, D
|
|
27
|
+
|
|
28
|
+
from .parse import parse_hive_partition_structure
|
|
29
|
+
|
|
30
|
+
# Package version string exposed as ``lidb.__version__`` (this wheel is lidb 1.2.0).
__version__ = "1.2.0"
|
lidb/database.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
---------------------------------------------
|
|
4
|
+
Copyright (c) 2025 ZhangYundi
|
|
5
|
+
Licensed under the MIT License.
|
|
6
|
+
Created on 2024/7/1 09:44
|
|
7
|
+
Email: yundi.xxii@outlook.com
|
|
8
|
+
---------------------------------------------
|
|
9
|
+
"""
|
|
10
|
+
import re
import urllib
import urllib.parse
from pathlib import Path
from typing import Literal

import polars as pl
import pymysql

from .parse import extract_table_names_from_sql
from .init import DB_PATH, logger, get_settings
|
|
20
|
+
|
|
21
|
+
# ======================== 本地数据库 catdb ========================
|
|
22
|
+
def tb_path(tb_name: str) -> Path:
    """
    Map a table name to its directory under the local database root.

    Parameters
    ----------
    tb_name: str
        Table name written as a relative path, e.g. ``a/b/c``.

    Returns
    -------
    pathlib.Path
        ``DB_PATH`` joined with ``tb_name``, i.e. ``$DB_PATH/a/b/c``.
    """
    root = Path(DB_PATH)
    return root / tb_name
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def put(df, tb_name: str, partitions: list[str] | None = None):
    """
    Persist a DataFrame into the directory that backs table ``tb_name``.

    The directory is created when missing. With ``partitions`` the frame is
    written as a hive-partitioned dataset (``write_parquet(partition_by=...)``);
    otherwise it is written to a single ``data.parquet`` file in that directory.
    ``None`` or empty input is skipped with a warning and nothing is written.

    Parameters
    ----------
    df: polars.DataFrame
        Data to write.
    tb_name: str
        Table name, resolved to a directory with :func:`tb_path`.
    partitions: list[str] | None
        Columns to partition by; ``None`` disables partitioning.
    """
    if df is None:
        logger.warning("put failed: input data is None.")
        return
    if df.is_empty():
        logger.warning("put failed: input data is empty.")
        return

    target_dir = tb_path(tb_name)
    if not target_dir.exists():
        target_dir.mkdir(parents=True, exist_ok=True)

    if partitions is None:
        df.write_parquet(target_dir / "data.parquet")
    else:
        df.write_parquet(target_dir, partition_by=partitions)
|
|
70
|
+
|
|
71
|
+
def has(tb_name: str) -> bool:
    """
    Tell whether table ``tb_name`` exists in the local database.

    Parameters
    ----------
    tb_name: str
        Table name, resolved with :func:`tb_path`.

    Returns
    -------
    bool
        True when the table's path exists on disk.
    """
    location = tb_path(tb_name)
    return location.exists()
|
|
83
|
+
|
|
84
|
+
def sql(query: str, ):
|
|
85
|
+
"""
|
|
86
|
+
sql 查询,从本地paquet文件中查询数据
|
|
87
|
+
|
|
88
|
+
Parameters
|
|
89
|
+
----------
|
|
90
|
+
query: str
|
|
91
|
+
sql查询语句
|
|
92
|
+
Returns
|
|
93
|
+
-------
|
|
94
|
+
|
|
95
|
+
"""
|
|
96
|
+
import polars as pl
|
|
97
|
+
|
|
98
|
+
tbs = extract_table_names_from_sql(query)
|
|
99
|
+
convertor = dict()
|
|
100
|
+
for tb in tbs:
|
|
101
|
+
db_path = tb_path(tb)
|
|
102
|
+
format_tb = f"read_parquet('{db_path}/**/*.parquet')"
|
|
103
|
+
convertor[tb] = format_tb
|
|
104
|
+
pattern = re.compile("|".join(re.escape(k) for k in convertor.keys()))
|
|
105
|
+
new_query = pattern.sub(lambda m: convertor[m.group(0)], query)
|
|
106
|
+
return pl.sql(new_query)
|
|
107
|
+
|
|
108
|
+
def scan(tb: str,) -> pl.LazyFrame:
    """
    Lazily scan a table's parquet data with ``polars.scan_parquet``.

    ``tb`` is resolved with :func:`tb_path`; because pathlib joining keeps an
    absolute right-hand side, an absolute path may be passed as well.
    """
    return pl.scan_parquet(tb_path(tb))
|
|
112
|
+
|
|
113
|
+
def read_mysql(query: str, db_conf: str = "DATABASES.mysql"):
    """
    Run ``query`` on MySQL and return the result as a polars DataFrame.

    Parameters
    ----------
    query: str
        SQL query text.
    db_conf: str
        Key of the connection section in ``$DB_PATH/conf/settings.toml``; the
        section must define ``user``, ``password``, ``url`` and ``db``.

    Returns
    -------
    polars.DataFrame

    Raises
    ------
    RuntimeError
        If the configuration is incomplete or the query fails.
    """
    try:
        conf = get_settings().get(db_conf, {})
        absent = [key for key in ("user", "password", "url", "db") if key not in conf]
        if absent:
            raise KeyError(f"Missing required keys in database config: {absent}")

        # credentials are embedded in a URI, so they must be percent-encoded
        quote = urllib.parse.quote_plus
        user, password = quote(conf["user"]), quote(conf["password"])
        uri = f"mysql://{user}:{password}@{conf['url']}/{conf['db']}"
        return pl.read_database_uri(query, uri)
    except KeyError as e:
        raise RuntimeError("Database configuration error: missing required fields.") from e
    except Exception as e:
        raise RuntimeError(f"Failed to execute MySQL query: {e}") from e
|
|
143
|
+
|
|
144
|
+
def write_mysql(df: pl.DataFrame,
                remote_tb: str,
                db_conf: str,
                if_table_exists: Literal["append", "replace", "fail"]="append"):
    """
    Write a polars DataFrame into the MySQL table ``remote_tb``.

    Parameters
    ----------
    df: polars.DataFrame
        Data to write.
    remote_tb: str
        Destination table name.
    db_conf: str
        Settings key of the connection section (``user``, ``password``, ``url``, ``db``).
    if_table_exists: {"append", "replace", "fail"}
        Forwarded to ``DataFrame.write_database``.

    Raises
    ------
    RuntimeError
        If the configuration is incomplete or the write fails.
    """
    try:
        conf = get_settings().get(db_conf, {})
        absent = [key for key in ("user", "password", "url", "db") if key not in conf]
        if absent:
            raise KeyError(f"Missing required keys in database config: {absent}")

        # percent-encode credentials because they are embedded in a URI
        user = urllib.parse.quote_plus(conf["user"])
        password = urllib.parse.quote_plus(conf["password"])
        connection_uri = f"mysql+pymysql://{user}:{password}@{conf['url']}/{conf['db']}"
        return df.write_database(
            remote_tb,
            connection=connection_uri,
            if_table_exists=if_table_exists,
        )
    except KeyError as e:
        raise RuntimeError("Database configuration error: missing required fields.") from e
    except Exception as e:
        raise RuntimeError(f"Failed to write MySQL: {e}") from e
|
|
167
|
+
|
|
168
|
+
def execute_mysql(sql: str, db_conf: str):
|
|
169
|
+
"""执行mysql语句"""
|
|
170
|
+
try:
|
|
171
|
+
db_setting = get_settings().get(db_conf, {})
|
|
172
|
+
required_keys = ['user', 'password', 'url', 'db']
|
|
173
|
+
missing_keys = [key for key in required_keys if key not in db_setting]
|
|
174
|
+
if missing_keys:
|
|
175
|
+
raise KeyError(f"Missing required keys in database config: {missing_keys}")
|
|
176
|
+
|
|
177
|
+
user = urllib.parse.quote_plus(db_setting['user'])
|
|
178
|
+
password = urllib.parse.quote_plus(db_setting['password'])
|
|
179
|
+
url = urllib.parse.quote_plus(db_setting["url"])
|
|
180
|
+
host, port = url.split(":")
|
|
181
|
+
|
|
182
|
+
except KeyError as e:
|
|
183
|
+
raise RuntimeError("Database configuration error: missing required fields.") from e
|
|
184
|
+
except Exception as e:
|
|
185
|
+
raise RuntimeError(f"Failed to parse config: {e}") from e
|
|
186
|
+
|
|
187
|
+
connection = pymysql.connect(
|
|
188
|
+
host=host,
|
|
189
|
+
port=port,
|
|
190
|
+
user=user,
|
|
191
|
+
password=password,
|
|
192
|
+
database=db_setting['db'] # or extract from connection string
|
|
193
|
+
)
|
|
194
|
+
try:
|
|
195
|
+
with connection.cursor() as cursor:
|
|
196
|
+
cursor.execute(sql)
|
|
197
|
+
connection.commit()
|
|
198
|
+
except Exception as e:
|
|
199
|
+
raise RuntimeError(f"Failed to execute MySQL: {e}") from e
|
|
200
|
+
finally:
|
|
201
|
+
connection.close()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def read_ck(query: str, db_conf: str = "DATABASES.ck"):
    """
    Run ``query`` on a ClickHouse cluster and return a polars DataFrame.

    Parameters
    ----------
    query: str
        SQL query text.
    db_conf: str
        Key of the connection section in ``$DB_PATH/conf/settings.toml``; the
        section must define ``user``, ``password`` and ``urls``.

    Returns
    -------
    polars.DataFrame

    Raises
    ------
    RuntimeError
        If the configuration is incomplete or the query fails.
    """
    import clickhouse_df

    try:
        conf = get_settings().get(db_conf, {})
        absent = [key for key in ("user", "password", "urls") if key not in conf]
        if absent:
            raise KeyError(f"Missing required keys in database config: {absent}")

        # NOTE(review): credentials are percent-encoded before being passed as keyword
        # arguments; whether clickhouse_df expects encoded values is not visible here — confirm.
        user = urllib.parse.quote_plus(conf["user"])
        password = urllib.parse.quote_plus(conf["password"])

        with clickhouse_df.connect(conf["urls"], user=user, password=password):
            return clickhouse_df.to_polars(query)
    except KeyError as e:
        raise RuntimeError("Database configuration error: missing required fields.") from e
    except Exception as e:
        raise RuntimeError(f"Failed to execute ClickHouse query: {e}") from e
|
lidb/dataset.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
1
|
+
# Copyright (c) ZhangYundi.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
# Created on 2025/10/27 14:13
|
|
4
|
+
# Description:
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from functools import partial
|
|
11
|
+
from typing import Callable, Literal
|
|
12
|
+
|
|
13
|
+
import logair
|
|
14
|
+
import polars as pl
|
|
15
|
+
import polars.selectors as cs
|
|
16
|
+
import xcals
|
|
17
|
+
import ygo
|
|
18
|
+
|
|
19
|
+
from .database import put, tb_path, scan, DB_PATH
|
|
20
|
+
from .parse import parse_hive_partition_structure
|
|
21
|
+
from .qdf import QDF, from_polars
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class InstrumentType(Enum):
    """Kinds of tradable instrument; each member's value is a string label."""
    STOCK = "Stock"  # stock (equity)
    ETF = "ETF"  # exchange-traded fund
    CB = "ConvertibleBond"  # convertible bond
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def complete_data(fn, date, save_path, partitions):
    """
    Compute one date of a dataset with ``fn`` and persist it with :func:`put`.

    Parameters
    ----------
    fn: Callable
        Called as ``fn(date=date)``. It may return a polars DataFrame or LazyFrame,
        or ``None`` when it has already saved its output itself.
    date: str
        Date to compute; added as the leading ``date`` column when the result
        has no such column.
    save_path:
        Destination table, passed to :func:`put`.
    partitions: list[str]
        Partition columns, passed to :func:`put`.

    Any exception is caught and logged instead of being propagated.
    """
    logger = logair.get_logger(__name__)
    try:
        data = fn(date=date)
        if data is None:
            # fn is responsible for saving the data itself
            return
        # Validate the type before touching the result: calling .select on a
        # non-polars object would raise AttributeError and hide this message.
        if not isinstance(data, (pl.DataFrame, pl.LazyFrame)):
            logger.error(f"{save_path}: Result of dataset.fn must be polars.DataFrame or polars.LazyFrame.")
            return
        # drop private columns (names starting with "_")
        data = data.select(~cs.starts_with("_"))
        if isinstance(data, pl.LazyFrame):
            data = data.collect()
        cols = data.columns
        if "date" not in cols:
            data = data.with_columns(pl.lit(date).alias("date")).select("date", *cols)

        put(data, save_path, partitions=partitions)
    except Exception as e:
        logger.error(f"{save_path}: Error when complete data for {date}")
        logger.warning(e)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Dataset:
    """
    A dataset computed by ``fn`` and cached as hive-partitioned parquet under ``tb``.

    Reads scan the saved files; dates that are missing are computed with ``fn``
    (through ``complete_data``) and saved before being read back.
    """

    def __init__(self,
                 fn: Callable[..., pl.DataFrame],
                 tb: str,
                 update_time: str = "",
                 partitions: list[str] | None = None,
                 by_asset: bool = True,
                 by_time: bool = False):
        """
        Parameters
        ----------
        fn: Callable[..., polars.DataFrame]
            Function computing the dataset; it is called with ``date=...``.
        tb: str
            Table the dataset is saved to (resolved with ``tb_path``).
        update_time: str
            Time of day after which the current date's value is available. Empty
            (default) means real-time: the current date's value can be fetched.
        partitions: list[str] | None
            Extra partition columns, placed before the automatic ones.
        by_asset: bool
            Whether to partition by ``asset``; default True.
        by_time: bool
            Whether to also partition by ``time``; default False.
        """
        self.fn = fn
        self.fn_params_sig = ygo.fn_signature_params(fn)
        self._by_asset = by_asset
        self._by_time = by_time
        # automatic trailing partition levels: [asset,] date [, time]
        self._append_partitions = ["asset", "date"] if by_asset else ["date", ]
        if by_time:
            self._append_partitions.append("time")
        if partitions is not None:
            # user partitions come first; drop any duplicating the automatic levels
            partitions = [k for k in partitions if k not in self._append_partitions]
            partitions = [*partitions, *self._append_partitions]
        else:
            partitions = list(self._append_partitions)
        self.partitions = partitions
        # whether fn accepts an ``asset`` argument (then asset values get bound on it)
        self._type_asset = "asset" in self.fn_params_sig
        self.update_time = update_time

        self.tb = tb
        self.save_path = tb_path(tb)
        # presumably the (name, value) pairs already bound on fn, e.g. by __call__'s
        # partial — TODO confirm against ygo.fn_params
        fn_params = ygo.fn_params(self.fn)
        self.fn_params = {k: v for (k, v) in fn_params}
        # custom partition values already bound on fn are fixed into save_path
        self.constraints = dict()
        for k in self._custom_partitions():
            if k in self.fn_params:
                v = self._partition_value(self.fn_params[k])
                self.constraints[k] = v
                self.save_path = self.save_path / f"{k}={v}"

    def _custom_partitions(self) -> list[str]:
        """User-specified partition keys: ``partitions`` without the automatic trailing levels."""
        return self.partitions[:-len(self._append_partitions)]

    @staticmethod
    def _partition_value(v):
        """Canonical partition value: list/tuple values are sorted so the directory
        name does not depend on the order the caller listed them in."""
        if isinstance(v, (list, tuple)):
            return sorted(v)
        return v

    def _partition_path(self, constraints: dict, rep_asset=None):
        """
        Append one ``key=value`` directory per entry of ``constraints`` to ``save_path``.

        When ``rep_asset`` is given, it replaces the value of an ``asset`` entry.
        """
        path = self.save_path
        for k, v in constraints.items():
            if k == "asset" and rep_asset is not None:
                v = rep_asset
            else:
                v = self._partition_value(v)
            path = path / f"{k}={v}"
        return path

    def _bind_fn(self, constraints: dict, partition_constraints: dict):
        """
        Bind call arguments onto ``fn`` and locate where its output must be saved.

        ``asset`` is bound when ``fn`` accepts it and the caller constrained it. Custom
        partition keys not fixed at construction are read from ``constraints``, bound
        as well, and appended to the save path as ``key=value`` directories. All
        bindings go through a single ``ygo.delay`` call so none overwrites another.

        Raises
        ------
        KeyError
            If a custom partition key is neither fixed at construction nor given in
            ``constraints``.
        """
        bindings = dict()
        save_path = self.save_path
        if self._type_asset and "asset" in partition_constraints:
            bindings["asset"] = partition_constraints["asset"]
        if len(self.constraints) < len(self._custom_partitions()):
            # partition fields not fixed by the Dataset definition must come from the call
            for k in self._custom_partitions():
                if k not in self.constraints:
                    v = constraints[k]
                    bindings[k] = v
                    save_path = save_path / f"{k}={v}"
        fn = ygo.delay(self.fn)(**bindings) if bindings else self.fn
        return fn, save_path

    @staticmethod
    def _read_date(path, date, eager: bool, constraints: dict):
        """Scan ``path``, keep rows of ``date`` and apply the constraints that exist as
        columns; return a DataFrame when ``eager`` else a LazyFrame."""
        lf = scan(path).cast({"date": pl.Utf8})
        schema = lf.collect_schema()
        limits = {k: v for k, v in constraints.items() if schema.get(k) is not None}
        lf = lf.filter(date=date, **limits)
        return lf.collect() if eager else lf

    def is_empty(self, path) -> bool:
        """Return True when no ``*.parquet`` file exists anywhere below ``path``."""
        return not any(path.rglob("*.parquet"))

    def __call__(self, *fn_args, **fn_kwargs):
        """Return a new Dataset whose ``fn`` has the given arguments pre-bound
        (``functools.partial``); every other setting is copied from this one."""
        return Dataset(fn=partial(self.fn, *fn_args, **fn_kwargs),
                       tb=self.tb,
                       partitions=self.partitions,
                       by_asset=self._by_asset,
                       by_time=self._by_time,
                       update_time=self.update_time)

    def get_value(self, date, eager: bool = True, **constraints):
        """
        Read the dataset for one date, computing and saving it first if absent.

        Not point-in-time: for past dates the stored value is returned as is (see
        :meth:`get_pit` for the ``update_time``-aware variant).

        Parameters
        ----------
        date: str
            Date to read.
        eager: bool
            Return a DataFrame when True, otherwise a LazyFrame.
        constraints: dict
            Partition values (used to locate the data) and column filters; filters
            on columns missing from the data are ignored.

        Returns
        -------
        polars.DataFrame | polars.LazyFrame | None
            ``None`` when ``date`` is after today, or is today and the current
            time is before ``update_time``.
        """
        _constraints = {k: v for k, v in constraints.items() if k in self.partitions}
        search_path = self._partition_path(_constraints) / f"date={date}"

        if not self.is_empty(search_path):
            data = self._read_date(search_path, date, eager, constraints)
            if not eager:
                return data
            if not data.is_empty():
                return data

        fn, save_path = self._bind_fn(constraints, _constraints)
        logger = logair.get_logger(__name__)

        today = xcals.today()
        now = xcals.now()
        if (date > today) or (date == today and now < self.update_time):
            logger.warning(f"{self.tb}: {date} is not ready, waiting for {self.update_time}")
            return
        complete_data(fn, date, save_path, self._append_partitions)

        return self._read_date(search_path, date, eager, constraints)

    def get_pit(self, date: str, query_time: str, eager: bool = True, **constraints):
        """
        Point-in-time read: when ``query_time`` is earlier than ``update_time``, the
        previous trading day's value is returned, relabelled with ``date``.

        Returns ``None`` when the underlying :meth:`get_value` call does.
        """
        if not self.update_time:
            return self.get_value(date, eager=eager, **constraints)
        val_date = date
        if query_time < self.update_time:
            val_date = xcals.shift_tradeday(date, -1)
        result = self.get_value(val_date, eager=eager, **constraints)
        if result is None:
            return None
        return result.with_columns(date=pl.lit(date), )

    def get_history(self,
                    dateList: list[str],
                    n_jobs: int = 5,
                    backend: Literal["threading", "multiprocessing", "loky"] = "loky",
                    eager: bool = True,
                    rep_asset: str = "000001",  # default 000001
                    **constraints):
        """
        Read the dataset for every date in ``dateList``, computing missing dates first.

        Not point-in-time: stored values are returned for the requested dates.

        Parameters
        ----------
        dateList: list[str]
            Dates to read.
        n_jobs: int
            Number of parallel workers computing missing dates.
        backend: str
            ``ygo.pool`` backend for those workers.
        eager: bool
            Return a DataFrame when True, otherwise a LazyFrame.
        rep_asset: str
            When ``fn`` does not accept ``asset``, existing dates are detected by
            probing this asset's partition instead of the constrained one.
        constraints:
            Partition values and column filters.

        Returns
        -------
        polars.DataFrame | polars.LazyFrame
            Rows for ``dateList``, sorted by ``date``.
        """
        _constraints = {k: v for k, v in constraints.items() if k in self.partitions}
        search_path = self._partition_path(_constraints)
        if self.is_empty(search_path):
            # nothing stored yet: every requested date must be computed
            missing_dates = dateList
        else:
            if not self._type_asset:
                hive_info = parse_hive_partition_structure(
                    self._partition_path(_constraints, rep_asset=rep_asset))
            else:
                hive_info = parse_hive_partition_structure(search_path)
            exist_dates = hive_info["date"].to_list()
            missing_dates = sorted(set(dateList).difference(exist_dates))

        if missing_dates:
            fn, save_path = self._bind_fn(constraints, _constraints)
            info_path = self.save_path
            try:
                info_path = info_path.relative_to(DB_PATH)
            except (ValueError, TypeError):
                # save_path is not under DB_PATH: show the full path in job names
                pass
            with ygo.pool(n_jobs=n_jobs, backend=backend) as go:
                for date in missing_dates:
                    go.submit(complete_data, job_name=f"Completing {info_path}")(
                        fn=fn,
                        date=date,
                        save_path=save_path,
                        partitions=self._append_partitions,
                    )
                go.do()

        data = scan(search_path, ).cast({"date": pl.Utf8}).filter(pl.col("date").is_in(dateList), **constraints)
        data = data.sort("date")
        if eager:
            return data.collect()
        return data
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def loader(data_name: str,
           ds: Dataset,
           date_list: list[str],
           prev_date_list: list[str],
           prev_date_mapping: dict[str, str],
           time: str,
           **constraints) -> tuple[str, pl.LazyFrame]:
    """
    Load one dataset lazily as it would be seen at intraday ``time``.

    When ``time`` is earlier than ``ds.update_time``, the previous trading days
    (``prev_date_list``) are read and their ``date`` values are remapped to the
    requested days with ``prev_date_mapping``. If the data has a ``time`` column it
    is filtered to ``time``; otherwise a literal ``time`` column is added.

    Returns
    -------
    tuple[str, polars.LazyFrame]
        ``(data_name, lazy_frame)``, so results of parallel jobs can be grouped by name.
    """
    use_prev = time < ds.update_time
    dates = prev_date_list if use_prev else date_list
    if len(dates) > 1:
        lf = ds.get_history(dates, eager=False, **constraints)
    else:
        lf = ds.get_value(dates[0], eager=False, **constraints)

    schema = lf.collect_schema()
    if schema.get("time") is not None:
        lf = lf.filter(time=time)
    else:
        lf = lf.with_columns(time=pl.lit(time))
    if use_prev:
        lf = lf.with_columns(date=pl.col("date").replace(prev_date_mapping))
    return data_name, lf
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def load_ds(ds_conf: dict[str, list[Dataset]],
            beg_date: str,
            end_date: str,
            time: str,
            n_jobs: int = 7,
            backend: Literal["threading", "multiprocessing", "loky"] = "threading",
            eager: bool = False,
            **constraints) -> dict[str, pl.DataFrame | pl.LazyFrame]:
    """
    Load several groups of datasets in parallel and merge each group.

    Parameters
    ----------
    ds_conf: dict[str, list[Dataset]]
        Dataset configuration: key is the data name, value the datasets to merge.
    beg_date: str
        First date (inclusive).
    end_date: str
        Last date (inclusive).
    time: str
        Intraday time at which values are taken.
    n_jobs: int
        Number of parallel loading jobs.
    backend: str
        ``ygo.pool`` backend.
    eager: bool
        Whether to return DataFrames:
        - True: return DataFrame
        - False: return LazyFrame
    constraints
        Filters, e.g. ``asset='000001'``.

    Returns
    -------
    dict[str, polars.DataFrame | polars.LazyFrame]
        - key: data_name
        - value: the group's datasets aligned on ``("date", "time", "asset")`` and
          sorted by it; a DataFrame when ``eager`` else a LazyFrame.

    Raises
    ------
    ValueError
        If ``beg_date`` is after ``end_date`` or the range has no trading day.
    """
    if beg_date > end_date:
        raise ValueError("beg_date must not be later than end_date")
    date_list = xcals.get_tradingdays(beg_date, end_date)
    # len() rather than truthiness, so array-like calendars are handled too
    if len(date_list) == 0:
        raise ValueError(f"No trading days between {beg_date} and {end_date}")
    beg_date, end_date = date_list[0], date_list[-1]
    # the same trading days shifted back by one; used for datasets whose
    # update_time is later than `time`
    prev_date_list = xcals.get_tradingdays(xcals.shift_tradeday(beg_date, -1), xcals.shift_tradeday(end_date, -1))
    prev_date_mapping = {prev_date: date_list[i] for i, prev_date in enumerate(prev_date_list)}

    results = defaultdict(list)
    with ygo.pool(n_jobs=n_jobs, backend=backend) as go:
        for data_name, ds_list in ds_conf.items():
            for ds in ds_list:
                go.submit(loader,
                          job_name="Loading",
                          postfix=data_name)(data_name=data_name,
                                             ds=ds,
                                             date_list=date_list,
                                             prev_date_list=prev_date_list,
                                             prev_date_mapping=prev_date_mapping,
                                             time=time,
                                             **constraints)
        for name, lf in go.do():
            results[name].append(lf)

    index = ("date", "time", "asset")
    LFs = {
        name: (pl.concat(lfList, how="align")
               .sort(index)
               .select(*index, cs.exclude(index)))
        for name, lfList in results.items()
    }
    if not eager:
        return LFs
    return {name: lf.collect() for name, lf in LFs.items()}
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class DataLoader:
    """
    Named handle on a group of datasets loaded with ``load_ds``.

    After :meth:`get`, the merged data is available lazily, eagerly through
    :meth:`collect` (cached), through SQL via :meth:`sql`, and as a one-day sample
    (:attr:`one_day`) that backs :attr:`schema` and :attr:`columns`.
    """

    def __init__(self, name: str):
        self._name = name
        self._lf: pl.LazyFrame | None = None
        self._df: pl.DataFrame | None = None
        # index columns, in the order load_ds places them first
        self._index: tuple[str, ...] = ("date", "time", "asset")
        self._db: QDF | None = None
        self._one: pl.DataFrame | None = None

    def get(self,
            ds_list: list[Dataset],
            beg_date: str,
            end_date: str,
            n_jobs: int = 11,
            backend: Literal["threading", "multiprocessing", "loky"] = "threading",
            *,
            time: str | None = None,
            **constraints):
        """
        Load ``ds_list`` over ``[beg_date, end_date]`` under this loader's name.

        Parameters
        ----------
        ds_list: list[Dataset]
            Datasets to merge.
        beg_date: str
            First date (inclusive).
        end_date: str
            Last date (inclusive).
        n_jobs: int
            Number of parallel loading jobs.
        backend: str
            ``ygo.pool`` backend.
        time: str
            Intraday time at which values are taken; required by ``load_ds``.
        constraints
            Extra filters forwarded to ``load_ds``.

        Returns
        -------
        None
            Results are stored on the instance.

        Raises
        ------
        TypeError
            If ``time`` is not given.
        """
        if time is None:
            raise TypeError("DataLoader.get() missing required keyword argument: 'time'")
        frames = load_ds(ds_conf={self._name: ds_list},
                         beg_date=beg_date,
                         end_date=end_date,
                         time=time,
                         n_jobs=n_jobs,
                         backend=backend,
                         eager=False,
                         **constraints)
        self._lf = frames.get(self._name)
        self._df = None  # invalidate the cached collect() result
        self._db = from_polars(self._lf, self._index, align=True)

        # load only the first trading day eagerly as a cheap sample of the data
        dateList = xcals.get_tradingdays(beg_date, end_date)
        _data_name = f"{self._name}(one_day)"
        self._one = load_ds(ds_conf={_data_name: ds_list},
                            beg_date=dateList[0],
                            end_date=dateList[0],
                            time=time,
                            n_jobs=n_jobs,
                            backend=backend,
                            eager=False,
                            **constraints).get(_data_name).collect()

    @property
    def name(self) -> str:
        """Name this loader stores its data under."""
        return self._name

    @property
    def one_day(self) -> pl.DataFrame:
        """Eager sample holding the first trading day of the last :meth:`get` call."""
        return self._one

    @property
    def schema(self) -> pl.Schema:
        """Schema of the loaded data, taken from the one-day sample."""
        return self._one.schema

    @property
    def columns(self) -> list[str]:
        """Column names of the loaded data, taken from the one-day sample."""
        return self._one.columns

    def collect(self) -> pl.DataFrame:
        """Collect the full lazy result once and cache it until the next :meth:`get`."""
        if self._df is None:
            self._df = self._lf.collect()
        return self._df

    def sql(self, *exprs: str) -> pl.DataFrame:
        """Evaluate expressions through the QDF built in :meth:`get`."""
        return self._db.sql(*exprs)
|