quantification 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantification/__init__.py +3 -2
- quantification/api/__init__.py +3 -0
- quantification/api/akshare/__init__.py +1 -0
- quantification/api/akshare/akshare.py +17 -0
- quantification/api/akshare/delegate/__init__.py +6 -0
- quantification/api/akshare/delegate/macro_china_fdi.py +46 -0
- quantification/api/akshare/delegate/macro_china_lpr.py +43 -0
- quantification/api/akshare/delegate/macro_china_qyspjg.py +51 -0
- quantification/api/akshare/delegate/macro_china_shrzgm.py +47 -0
- quantification/api/akshare/delegate/macro_cnbs.py +47 -0
- quantification/api/akshare/delegate/stock_zh_a_hist.py +77 -0
- quantification/api/akshare/setting.py +5 -0
- quantification/api/api.py +11 -0
- quantification/api/api.pyi +21 -0
- quantification/api/tushare/__init__.py +1 -0
- quantification/api/tushare/delegate/__init__.py +7 -0
- quantification/api/tushare/delegate/balancesheet.py +66 -0
- quantification/api/tushare/delegate/cashflow.py +29 -0
- quantification/api/tushare/delegate/common.py +64 -0
- quantification/api/tushare/delegate/daily_basic.py +81 -0
- quantification/api/tushare/delegate/fina_indicator.py +20 -0
- quantification/api/tushare/delegate/income.py +34 -0
- quantification/api/tushare/delegate/index_daily.py +61 -0
- quantification/api/tushare/delegate/pro_bar.py +80 -0
- quantification/api/tushare/setting.py +5 -0
- quantification/api/tushare/tushare.py +17 -0
- quantification/core/__init__.py +9 -0
- quantification/core/asset/__init__.py +6 -0
- quantification/core/asset/base_asset.py +96 -0
- quantification/core/asset/base_broker.py +42 -0
- quantification/core/asset/broker.py +108 -0
- quantification/core/asset/cash.py +75 -0
- quantification/core/asset/stock.py +268 -0
- quantification/core/cache.py +93 -0
- quantification/core/configure.py +15 -0
- quantification/core/data/__init__.py +5 -0
- quantification/core/data/base_api.py +109 -0
- quantification/core/data/base_delegate.py +73 -0
- quantification/core/data/field.py +213 -0
- quantification/core/data/panel.py +42 -0
- quantification/core/env.py +25 -0
- quantification/core/logger.py +94 -0
- quantification/core/strategy/__init__.py +3 -0
- quantification/core/strategy/base_strategy.py +66 -0
- quantification/core/strategy/base_trigger.py +69 -0
- quantification/core/strategy/base_use.py +69 -0
- quantification/core/trader/__init__.py +7 -0
- quantification/core/trader/base_order.py +45 -0
- quantification/core/trader/base_stage.py +16 -0
- quantification/core/trader/base_trader.py +173 -0
- quantification/core/trader/collector.py +47 -0
- quantification/core/trader/order.py +23 -0
- quantification/core/trader/portfolio.py +72 -0
- quantification/core/trader/query.py +29 -0
- quantification/core/trader/report.py +76 -0
- quantification/core/util.py +181 -0
- quantification/default/__init__.py +5 -0
- quantification/default/stage/__init__.py +1 -0
- quantification/default/stage/cn_stock.py +23 -0
- quantification/default/strategy/__init__.py +1 -0
- quantification/default/strategy/simple/__init__.py +1 -0
- quantification/default/strategy/simple/strategy.py +8 -0
- quantification/default/trader/__init__.py +2 -0
- quantification/default/trader/a_factor/__init__.py +1 -0
- quantification/default/trader/a_factor/trader.py +27 -0
- quantification/default/trader/simple/__init__.py +1 -0
- quantification/default/trader/simple/trader.py +8 -0
- quantification/default/trigger/__init__.py +1 -0
- quantification/default/trigger/trigger.py +63 -0
- quantification/default/use/__init__.py +1 -0
- quantification/default/use/factors/__init__.py +2 -0
- quantification/default/use/factors/factor.py +205 -0
- quantification/default/use/factors/use.py +38 -0
- quantification-0.1.1.dist-info/METADATA +19 -0
- quantification-0.1.1.dist-info/RECORD +76 -0
- {quantification-0.1.0.dist-info → quantification-0.1.1.dist-info}/WHEEL +1 -1
- quantification-0.1.0.dist-info/METADATA +0 -13
- quantification-0.1.0.dist-info/RECORD +0 -4
@@ -0,0 +1,268 @@
|
|
1
|
+
from abc import ABCMeta, abstractmethod
|
2
|
+
from typing import overload, Type
|
3
|
+
from datetime import date
|
4
|
+
from itertools import chain
|
5
|
+
|
6
|
+
import akshare as ak
|
7
|
+
|
8
|
+
from .base_asset import BaseAsset
|
9
|
+
|
10
|
+
from ..cache import cache_query
|
11
|
+
|
12
|
+
|
13
|
+
class BaseStockExchangeMeta(ABCMeta):
    """Metaclass that makes an exchange *class* print as its display name.

    ``str(cls)`` and ``repr(cls)`` both delegate to the class's ``name()``
    classmethod, so exchange classes can be logged without instantiation.
    """

    def __str__(self: "BaseStockExchange"):
        # ``self`` is the exchange class itself, because this is a metaclass.
        return self.name()

    __repr__ = __str__
|
18
|
+
|
19
|
+
|
20
|
+
class BaseStockExchange(metaclass=BaseStockExchangeMeta):
    """Abstract description of a stock exchange.

    Concrete exchanges are used as classes (never instantiated in this
    module) and must provide both classmethods below.
    """

    @classmethod
    @abstractmethod
    def code(cls) -> str:
        """Return the short exchange code (e.g. the suffix used in symbols)."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def name(cls) -> str:
        """Return the human-readable exchange name."""
        raise NotImplementedError
|
30
|
+
|
31
|
+
|
32
|
+
# 上海证券交易所
|
33
|
+
class SSE(BaseStockExchange):
    """Shanghai Stock Exchange."""

    @classmethod
    def code(cls) -> str:
        """Return the exchange code ``"SH"``."""
        return "SH"

    @classmethod
    def name(cls) -> str:
        """Return the exchange's display name."""
        return "上海证券交易所"
|
41
|
+
|
42
|
+
|
43
|
+
# 深圳证券交易所
|
44
|
+
class SZSE(BaseStockExchange):
    """Shenzhen Stock Exchange."""

    @classmethod
    def code(cls) -> str:
        """Return the exchange code ``"SZ"``."""
        return "SZ"

    @classmethod
    def name(cls) -> str:
        """Return the exchange's display name."""
        return "深圳证券交易所"
|
52
|
+
|
53
|
+
|
54
|
+
# 北京证券交易所
|
55
|
+
class BSE(BaseStockExchange):
    """Beijing Stock Exchange."""

    @classmethod
    def code(cls) -> str:
        """Return the exchange code ``"BJ"``."""
        return "BJ"

    @classmethod
    def name(cls) -> str:
        """Return the exchange's display name."""
        return "北京证券交易所"
|
63
|
+
|
64
|
+
|
65
|
+
# 香港交易所
|
66
|
+
class HKEX(BaseStockExchange):
    """Hong Kong Exchange."""

    @classmethod
    def code(cls) -> str:
        """Return the exchange code ``"HK"``."""
        return "HK"

    @classmethod
    def name(cls) -> str:
        """Return the exchange's display name."""
        return "香港交易所"
|
74
|
+
|
75
|
+
|
76
|
+
# 纽约证券交易所
|
77
|
+
class NYSE(BaseStockExchange):
    """New York Stock Exchange."""

    @classmethod
    def code(cls) -> str:
        """Return the exchange code ``"N"``."""
        return "N"

    @classmethod
    def name(cls) -> str:
        """Return the exchange's display name."""
        return "纽约证券交易所"
|
85
|
+
|
86
|
+
|
87
|
+
# 纳斯达克
|
88
|
+
class NASDAQ(BaseStockExchange):
    """NASDAQ."""

    @classmethod
    def code(cls) -> str:
        """Return the exchange code ``"O"``."""
        return "O"

    @classmethod
    def name(cls) -> str:
        """Return the exchange's display name."""
        return "纳斯达克"
|
96
|
+
|
97
|
+
|
98
|
+
# 伦敦证券交易所
|
99
|
+
class LSE(BaseStockExchange):
    """London Stock Exchange."""

    @classmethod
    def code(cls) -> str:
        """Return the exchange code ``"L"``."""
        return "L"

    @classmethod
    def name(cls) -> str:
        """Return the exchange's display name."""
        return "伦敦证券交易所"
|
107
|
+
|
108
|
+
|
109
|
+
class StockExchange:
    """Namespace exposing every known exchange class as an attribute."""

    # The right-hand sides resolve to the module-level classes of the same
    # name; the left-hand sides become attributes of this namespace.
    SSE, SZSE, BSE = SSE, SZSE, BSE
    HKEX = HKEX
    NYSE, NASDAQ = NYSE, NASDAQ
    LSE = LSE
|
117
|
+
|
118
|
+
|
119
|
+
def predict_exchange(symbol: str) -> type[BaseStockExchange]:
    """Infer the exchange class for a stock symbol.

    Two symbol forms are handled:

    * ``"<code>.<suffix>"`` -- the suffix (case-insensitive ``sh`` / ``sz`` /
      ``bj``) selects the exchange directly.
    * ``"<code>"`` without a dot -- the A-share spot list is fetched through
      the query cache and searched for an entry whose code, minus its first
      two characters, equals ``symbol``; those two characters are then used
      as the exchange prefix.

    Args:
        symbol: Stock symbol, with or without an exchange suffix.

    Returns:
        The matching exchange class (SSE, SZSE or BSE).

    Raises:
        ValueError: If no exchange can be determined, including when the
            symbol contains more than one ``.``.
    """
    # One table serves both the spot-list prefixes and the user-supplied suffixes.
    exchange_by_tag = {
        "sh": StockExchange.SSE,
        "sz": StockExchange.SZSE,
        "bj": StockExchange.BSE,
    }

    if "." not in symbol:
        # update=False: reuse the cached spot list whenever one exists.
        api = cache_query(update=False)(ak.stock_zh_a_spot)
        for code in api()["代码"].tolist():
            if code[2:] != symbol:
                continue

            exchange = exchange_by_tag.get(code[:2])
            if exchange is not None:
                return exchange
            # Unknown prefix: keep scanning, another row may still match.

        raise ValueError(f"无法根据{symbol}判断交易所")

    parts = symbol.split(".")
    # A symbol such as "a.b.sh" used to fail with an unrelated tuple-unpacking
    # error; report it with the same message as every other failure.
    if len(parts) != 2:
        raise ValueError(f"无法根据{symbol}判断交易所")

    exchange = exchange_by_tag.get(parts[1].lower())
    if exchange is None:
        raise ValueError(f"无法根据{symbol}判断交易所")

    return exchange
|
146
|
+
|
147
|
+
|
148
|
+
# Registry of the dynamically created per-symbol Stock subclasses.
# (The annotation previously used ``str:type[...]``, which is a slice, not a
# key/value type pair.)
stock_family: dict[str, type["Stock"]] = {}
# Lot ledger used by Stock: maps a date to the number of shares held for it.
SharePosition = dict[date, int]
|
150
|
+
|
151
|
+
|
152
|
+
class Stock(BaseAsset):
    """Share position in one stock, tracked as dated lots.

    The bare ``Stock`` class has no symbol and cannot be instantiated. Use
    ``Stock[symbol]`` (exchange inferred via ``predict_exchange``) or
    ``Stock[symbol, exchange]`` to get the per-symbol subclass, then call it
    with either a share count or a ``SharePosition`` mapping.

    Attributes:
        symbol: Stock symbol of the concrete subclass, ``None`` on the base.
        exchange: Exchange class of the concrete subclass, ``None`` on the base.
        share_position: Mapping of date -> number of shares held for that date.
    """

    symbol: str = None
    # NOTE: this annotation must stay above ``def type`` -- once the classmethod
    # below is defined, ``type`` in the class body no longer means the builtin.
    exchange: type[BaseStockExchange] = None

    @classmethod
    def type(cls, *args, **kwargs):
        """Return the asset type label."""
        return "Stock"

    @classmethod
    def name(cls, *args, **kwargs):
        """Return the display name, including symbol and exchange code when set."""
        if not cls.symbol:
            return "股票"
        return f"股票{cls.symbol}.{cls.exchange.code()}"

    @overload
    def __class_getitem__(cls, symbol: str):
        ...

    @overload
    def __class_getitem__(cls, symbol_and_exchange: tuple[str, Type[BaseStockExchange]]):
        ...

    def __class_getitem__(cls, arg: str | tuple[str, Type[BaseStockExchange]]):
        """Return the ``Stock`` subclass bound to a symbol (and exchange).

        Args:
            arg: Either a symbol string (exchange is predicted) or a
                ``(symbol, exchange_class)`` tuple.

        Returns:
            The subclass registered in ``stock_family`` for that symbol and
            exchange; it is created on first use and reused afterwards.
        """
        if isinstance(arg, str):
            symbol, exchange = arg, predict_exchange(arg)
        else:
            symbol, exchange = arg

        # The lookup and the store must use the same key. Previously the lookup
        # used "<symbol>.<exchange>" while the store used "<symbol>", so the
        # registry never hit and a new class was built on every subscription.
        family_key = f"{symbol}.{exchange.code()}"
        stock_class = stock_family.get(family_key)
        if stock_class is None:
            stock_class = type(
                f"Stock{symbol}",
                (Stock,),
                {"symbol": symbol, "exchange": exchange},
            )
            stock_family[family_key] = stock_class

        return stock_class

    def __init__(self, share: int | SharePosition):
        """Create a position.

        Args:
            share: A non-negative share count, recorded under ``self.env.date``,
                or a mapping of date -> non-negative share count.
        """
        assert self.symbol is not None, "未指定股票代码, 无法实例化, 是否应该使用Stock[symbol](...) ?"

        if isinstance(share, int):
            assert share >= 0, f"股票份数不能为负, 实际为{share=}"
            self.share_position: SharePosition = {self.env.date: share}
        else:
            assert all([v >= 0 for v in share.values()]), f"股票份数不能为负, 实际为{share.values()=}"
            # Store a private copy so this instance never aliases the caller's
            # mapping (``copy`` hands over its own mapping, for instance).
            self.share_position = dict(share)

    @property
    def amount(self, *args, **kwargs):
        """Total number of shares across all lots."""
        return sum(self.share_position.values())

    @property
    def extra(self, *args, **kwargs):
        """Serializable view of the lots, keyed by ISO-formatted date."""
        return {"position": {k.isoformat(): int(v) for k, v in self.share_position.items()}}

    @property
    def is_empty(self):
        """Whether the position holds no shares."""
        return self.amount == 0

    @property
    def copy(self):
        """A new instance of the same stock class with the same lots."""
        # Reuse the instance's own class: going through Stock[self.symbol]
        # re-runs exchange prediction and can lose an explicit exchange.
        return self.__class__(self.share_position)

    def available(self, day: date):
        """Return a position holding only the shares dated strictly before ``day``.

        The result is built from a plain share count, so it is recorded under
        ``self.env.date`` rather than under the original dates.
        """
        available_share = sum(
            share
            for brought_day, share in self.share_position.items()
            if brought_day < day
        )

        return self.__class__(available_share)

    def __add__(self, other: "Stock"):
        """Merge the lots of two positions in the same stock."""
        assert self == other, f"只有同种股票可以相加减, {self.__class__.__name__} != {other.__class__.__name__}"

        return self.__class__(self.add_shares(other.share_position))

    def __sub__(self, other: "Stock"):
        """Remove ``other.amount`` shares, consuming the oldest lots first.

        Only shares dated strictly before the first date in ``other`` count as
        available; subtracting more than that fails the assertion.
        """
        assert self == other, f"只有同种股票可以相加减, {self.__class__.__name__} != {other.__class__.__name__}"

        day = next(iter(other.share_position), None)
        if day is None:
            # ``other`` has no lots at all: nothing to remove (this used to
            # escape as a bare StopIteration).
            return self.__class__(self.sub_shares(0))

        available = self.available(day).amount
        assert available >= other.amount, f"可出售股份不足, {available} < {other.amount}"

        return self.__class__(self.sub_shares(other.amount))

    def __eq__(self, other):
        """Two stocks are equal when they share a symbol; amounts are ignored."""
        return isinstance(other, Stock) and self.symbol == other.symbol

    def __repr__(self):
        return f"{self.name()}({self.amount}股)"

    __str__ = __repr__

    def add_shares(self, share_position: SharePosition) -> SharePosition:
        """Return a new lot mapping with ``share_position`` added date by date."""
        merged = dict(self.share_position)
        for day, share in share_position.items():
            merged[day] = merged.get(day, 0) + share
        return merged

    def sub_shares(self, share: int) -> SharePosition:
        """Return a new lot mapping with ``share`` shares removed, oldest first.

        Lots that are consumed completely are dropped from the result.
        """
        remaining_share = share
        new_share_position = {}

        for lot_date in sorted(self.share_position):
            current_share = self.share_position[lot_date]

            if remaining_share <= 0:
                # Nothing left to remove: keep this lot untouched.
                new_share_position[lot_date] = current_share
            elif current_share <= remaining_share:
                # Lot consumed entirely.
                remaining_share -= current_share
            else:
                # Lot consumed partially; the removal is now complete.
                new_share_position[lot_date] = current_share - remaining_share
                remaining_share = 0

        return new_share_position
|
266
|
+
|
267
|
+
|
268
|
+
__all__ = ["Stock", "StockExchange"]
|
@@ -0,0 +1,93 @@
|
|
1
|
+
import time
|
2
|
+
import pickle
|
3
|
+
import hashlib
|
4
|
+
import inspect
|
5
|
+
from pathlib import Path
|
6
|
+
from functools import partial
|
7
|
+
|
8
|
+
from .configure import config
|
9
|
+
|
10
|
+
# Root cache directory, taken from the configuration.
CACHE_DIR = Path(config.cache_dir)
# Sub-directory that holds the pickled query results.
CACHE_QUERY = CACHE_DIR.joinpath('query')

# Created at import time so the cache helpers can write without checking first.
CACHE_QUERY.mkdir(exist_ok=True, parents=True)
|
14
|
+
|
15
|
+
|
16
|
+
def flush(cache_file, func, *args, **kwargs):
    """Call ``func`` and store its pickled result in ``cache_file``.

    The result is written to a sibling ``*.tmp`` file and then moved into
    place, so a failure while pickling neither destroys an existing cache
    entry nor leaves a truncated one behind.

    Args:
        cache_file: Destination path (``str`` or ``Path``).
        func: Callable producing the value to cache.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        Whatever ``func(*args, **kwargs)`` returned.

    Raises:
        Any exception raised by ``func``, by pickling, or by the file system.
    """
    result = func(*args, **kwargs)

    target = Path(cache_file)
    scratch = target.with_name(target.name + '.tmp')
    try:
        with open(scratch, 'wb') as file:
            pickle.dump(result, file)
        # Replace the destination only after the pickle was written in full.
        scratch.replace(target)
    except BaseException:
        scratch.unlink(missing_ok=True)
        raise

    return result
|
21
|
+
|
22
|
+
|
23
|
+
def get_or_flush(cache_file, func, *args, **kwargs):
    """Return the cached value in ``cache_file``, regenerating it if unreadable.

    Args:
        cache_file: Path of the pickled cache entry.
        func: Callable used to rebuild the value when the entry is missing
            or cannot be unpickled.
        *args, **kwargs: Forwarded to ``func`` on a rebuild.

    Returns:
        The unpickled cached value, or the fresh result of ``func``.
    """
    # NOTE(review): pickle.load executes code embedded in the file; this is
    # only safe while the cache directory is written by this process alone.
    try:
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    except (
        pickle.UnpicklingError,
        EOFError,
        FileNotFoundError,
        # Stale entries (e.g. pickled against a different library version) can
        # also fail with these; treat them as a cache miss, not a crash.
        AttributeError,
        ImportError,
        IndexError,
    ):
        return flush(cache_file, func, *args, **kwargs)
|
29
|
+
|
30
|
+
|
31
|
+
def cache_query(update=None, expire_seconds=86400):
    """Decorator factory that caches a function's results on disk.

    Each distinct call (function name plus bound arguments) is stored as its
    own pickle file under ``CACHE_QUERY``.

    Args:
        update (bool | None):
            - True: always call the function and overwrite the cache.
            - False: use the cache whenever it exists (expiry is ignored).
            - None: use the cache unless it is older than ``expire_seconds``.
        expire_seconds (int | None): Cache lifetime in seconds; only used when
            ``update`` is None. ``None`` means the cache never expires.

    Returns:
        A decorator wrapping the target callable (plain functions and
        ``functools.partial`` objects are both handled).

    Raises:
        ValueError: At call time, if ``update`` is not True, False or None.
    """
    # Imported here because the file-level import only brings in `partial`.
    from functools import wraps

    def decorator(func):
        @wraps(func)  # keep the wrapped callable's name and docstring
        def wrapper(*args, **kwargs):
            # Bind the call against the signature so that positional and
            # keyword spellings of the same call share one cache key.
            bound_args = inspect.signature(func).bind(*args, **kwargs)
            bound_args.apply_defaults()
            sorted_args = sorted(bound_args.arguments.items(), key=lambda x: x[0])
            name = func.func.__name__ if isinstance(func, partial) else func.__name__
            key_data = (name, sorted_args)

            # The hash of the key becomes the file name.
            hash_key = hashlib.sha256(pickle.dumps(key_data)).hexdigest()
            cache_file = CACHE_QUERY / f"{hash_key}.pkl"

            if update is True:
                return flush(cache_file, func, *args, **kwargs)

            if update is False:
                # get_or_flush already rebuilds a missing or unreadable entry.
                return get_or_flush(cache_file, func, *args, **kwargs)

            if update is None:
                if expire_seconds is not None:
                    try:
                        age = time.time() - cache_file.stat().st_mtime
                    except FileNotFoundError:
                        # No entry yet, or it vanished since the last call.
                        return flush(cache_file, func, *args, **kwargs)

                    if age > expire_seconds:
                        return flush(cache_file, func, *args, **kwargs)

                return get_or_flush(cache_file, func, *args, **kwargs)

            raise ValueError("Invalid update parameter value")

        return wrapper

    return decorator
|
91
|
+
|
92
|
+
|
93
|
+
__all__ = ['cache_query']
|
@@ -0,0 +1,15 @@
|
|
1
|
+
from pydantic import Field
|
2
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
3
|
+
|
4
|
+
|
5
|
+
class Config(BaseSettings):
    """Package-wide settings, loaded by pydantic-settings.

    Values are read from the environment and from a UTF-8 ``.env`` file;
    keys that are not declared below are accepted (``extra="allow"``).
    """

    model_config = SettingsConfigDict(
        env_file='.env',
        env_file_encoding='utf-8',
        extra="allow",
    )

    # Log level.
    log_level: str = Field(default="INFO", description="日志等级")
    # Location of the cache directory.
    cache_dir: str = Field(default="./cache", description="缓存文件夹位置")
    # Price-adjustment mode. Declared without a default, so presumably it must
    # be supplied through the environment or .env -- confirm.
    adjust: str = Field(description='复权方式')


# Module-level singleton, built once when this module is imported.
config = Config()
|
14
|
+
|
15
|
+
__all__ = ["Config", "config"]
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from typing import TypeVar, Generic
|
2
|
+
from datetime import date
|
3
|
+
from itertools import product
|
4
|
+
|
5
|
+
from pydantic import BaseModel
|
6
|
+
|
7
|
+
from .field import Field
|
8
|
+
from .panel import DataPanel
|
9
|
+
from .base_delegate import BaseDelegate
|
10
|
+
|
11
|
+
from ..logger import logger
|
12
|
+
from ..configure import config, Config
|
13
|
+
|
14
|
+
# Type of the per-API settings model; always a pydantic BaseModel subclass.
SettingType = TypeVar('SettingType', bound=BaseModel)


class BaseAPI(Generic[SettingType]):
    """Base class of a data API that routes field queries to its delegates.

    Subclasses must define two class attributes:

    * ``setting_class`` -- a ``BaseModel`` subclass validated from the global
      configuration when the API is instantiated.
    * ``delegate_classes`` -- a list of ``BaseDelegate`` subclasses; each is
      instantiated with ``(config, setting)``.
    """

    setting_class: type[BaseModel]
    delegate_classes: list[type[BaseDelegate[SettingType]]]

    def __init_subclass__(cls, **kwargs):
        """Validate the required class attributes when a subclass is defined."""
        # Cooperate with Generic (and any other base) instead of bypassing
        # their subclass hooks.
        super().__init_subclass__(**kwargs)

        assert hasattr(cls, "setting_class"), \
            f"{cls.__name__}必须实现类属性setting_class"

        assert issubclass(cls.setting_class, BaseModel), \
            f"{cls.__name__}类属性setting_class必须为BaseModel子类, 实际为{cls.setting_class}"

        assert hasattr(cls, "delegate_classes"), \
            f"{cls.__name__}必须实现类属性delegate_classes"

        assert isinstance(cls.delegate_classes, list), \
            f"{cls.__name__}类属性delegate_classes必须为列表, 实际为{type(cls.delegate_classes)}"

        for delegate_class in cls.delegate_classes:
            assert issubclass(delegate_class, BaseDelegate), \
                f"{cls.__name__}类属性delegate_classes的元素必须为BaseDelegate子类, 实际为{delegate_class}"

    def __init__(self):
        self.config = config
        # The API-specific settings are validated from a dump of the global config.
        self.setting: SettingType = self.setting_class.model_validate(config.model_dump())
        self.delegates = self.initialize_delegates()

    def initialize_delegates(self):
        """Instantiate every delegate class with the shared config and setting."""
        return [delegate_class(self.config, self.setting) for delegate_class in self.delegate_classes]

    def query(self, start_date: date, end_date: date, fields: list[Field], **kwargs):
        """Query ``fields`` between two dates and merge the delegates' results.

        Each field is assigned to the first delegate (in ``self.delegates``
        order) whose ``has_field`` accepts it. Fields that no delegate accepts
        are logged as a warning and left out.

        Args:
            start_date: First date of the requested range.
            end_date: Last date of the requested range.
            fields: Fields to fetch.
            **kwargs: Forwarded to ``has_field`` and ``execute`` of each delegate.

        Returns:
            A ``DataPanel`` combining every delegate's panel via ``<<=``.
        """
        fields_delegation: dict[BaseDelegate[SettingType], list[Field]] = {}

        picked = []
        for delegate in self.delegates:
            for field in fields:
                # A field claimed by an earlier delegate is not offered again.
                if field in picked:
                    continue

                if not delegate.has_field(field, **kwargs):
                    continue

                fields_delegation.setdefault(delegate, []).append(field)
                picked.append(field)

        omitted = set(fields) - set(picked)
        if omitted:
            logger.warning(f"无法查询的字段: {omitted}")

        res = DataPanel()

        # A separate loop name keeps the ``fields`` parameter intact.
        for delegate, delegated_fields in fields_delegation.items():
            panel = delegate.execute(start_date, end_date, delegated_fields, **kwargs)
            res <<= panel

        return res

    def __repr__(self):
        return f"<API {self.__class__.__name__}>"

    __str__ = __repr__
|
79
|
+
|
80
|
+
|
81
|
+
class BaseCombinedAPI(BaseAPI):
    """API that pools the delegates of several other APIs.

    Subclasses must define ``api_classes``, a list of ``BaseAPI`` subclasses.
    Each one is instantiated and its delegates are concatenated in order.
    """

    setting_class = Config
    delegate_classes = []

    api_classes: list[type[BaseAPI]]

    def __init_subclass__(cls, **kwargs):
        """Validate ``api_classes`` when a subclass is defined."""
        super().__init_subclass__(**kwargs)

        assert hasattr(cls, "api_classes"), \
            f"{cls.__name__}必须实现类属性api_classes"

        for api_class in cls.api_classes:
            # The message used to name delegate_classes although the check is
            # on api_classes.
            assert issubclass(api_class, BaseAPI), \
                f"{cls.__name__}类属性api_classes元素必须为BaseAPI子类, 实际为{api_class}"

    def initialize_delegates(self):
        """Instantiate every API in ``api_classes`` and collect their delegates."""
        delegates = []

        for api_class in self.api_classes:
            delegates.extend(api_class().delegates)

        return delegates

    def __repr__(self):
        return f"<CombinedAPI {'|'.join(api_class.__name__ for api_class in self.api_classes)}>"

    __str__ = __repr__
|
107
|
+
|
108
|
+
|
109
|
+
__all__ = ["BaseAPI", "BaseCombinedAPI"]
|
@@ -0,0 +1,73 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import TypeVar, Generic, Callable
|
3
|
+
from datetime import date
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
from pydantic import BaseModel
|
7
|
+
|
8
|
+
from .field import Field
|
9
|
+
from .panel import DataPanel
|
10
|
+
|
11
|
+
from ..configure import Config
|
12
|
+
|
13
|
+
# Type of the settings model handed to a delegate; a pydantic BaseModel subclass.
SettingType = TypeVar('SettingType', bound=BaseModel)


class BaseDelegate(ABC, Generic[SettingType]):
    """Base class of a data-source delegate that serves a set of fields.

    Concrete subclasses define ``pair``, a list of ``(Field, column_name)``
    tuples. From it the two lookup tables ``field2str`` and ``str2field`` are
    derived when the subclass is created. Subclasses that list ``ABC`` among
    their direct bases are treated as abstract and skip that step.
    """

    pair: list[tuple[Field, str]]
    field2str: dict[Field, str]
    str2field: dict[str, Field]

    # Column name the date column is normalised to before indexing.
    DATE_FIELD = "日期"

    def __init_subclass__(cls, **kwargs):
        """Build the field/column lookup tables for concrete subclasses."""
        # Cooperate with Generic (and any other base) instead of bypassing
        # their subclass hooks.
        super().__init_subclass__(**kwargs)

        if ABC in cls.__bases__:
            return

        assert hasattr(cls, 'pair'), f'{cls.__name__}必须实现类属性pair'

        cls.field2str = {field: column for field, column in cls.pair}
        cls.str2field = {column: field for field, column in cls.pair}

    def __init__(self, config: Config, setting: SettingType) -> None:
        self.config = config
        self.setting = setting

    def rename_columns(self, data: pd.DataFrame, date_field: str) -> pd.DataFrame:
        """Return ``data`` with source columns renamed to fields and the date column normalised.

        The two renames stay sequential on purpose: if ``date_field`` is itself
        mapped in ``str2field``, the first rename takes precedence.
        """
        renamed = data.rename(columns=self.str2field)
        return renamed.rename(columns={date_field: self.DATE_FIELD})

    def use_date_index(self, data: pd.DataFrame, formatter: Callable | None = None) -> pd.DataFrame:
        """Return a copy of ``data`` indexed by its date column, sorted by date.

        Args:
            data: Frame containing a ``DATE_FIELD`` column. It is not modified.
            formatter: Optional callable applied to every date value before
                conversion with ``pd.to_datetime``.

        Returns:
            A new frame with ``DATE_FIELD`` as a sorted datetime index.
        """
        # Work on a copy: assigning the converted column to ``data`` itself
        # would change the caller's frame as a side effect.
        data = data.copy()

        if formatter:
            data[self.DATE_FIELD] = data[self.DATE_FIELD].apply(formatter)

        data[self.DATE_FIELD] = pd.to_datetime(data[self.DATE_FIELD])

        return data.set_index(self.DATE_FIELD).sort_index()

    def execute(self, start_date: date, end_date: date, fields: list[Field], **kwargs) -> DataPanel:
        """Run ``query`` and ``mask``, then trim both to ``fields`` and the date range.

        Returns:
            A ``DataPanel`` built from the trimmed data and mask frames.
        """
        data = self.query(start_date, end_date, fields, **kwargs)
        # The mask is computed from the untrimmed query result.
        mask = self.mask(data, start_date, end_date, fields, **kwargs)

        data = data[fields][start_date:end_date]
        mask = mask[fields][start_date:end_date]

        return DataPanel(data, mask)

    @abstractmethod
    def has_field(self, field: Field, **kwargs):
        """Tell whether this delegate can serve ``field`` for the given kwargs."""
        raise NotImplementedError

    @abstractmethod
    def query(self, start_date: date, end_date: date, fields: list[Field], **kwargs) -> pd.DataFrame:
        """Fetch the data for ``fields``; ``execute`` slices the result by field and by date."""
        raise NotImplementedError

    @abstractmethod
    def mask(self, data: pd.DataFrame, start_date: date, end_date: date, fields: list[Field], **kwargs) -> pd.DataFrame:
        """Build the mask frame for ``data``; ``execute`` slices it exactly like the data."""
        raise NotImplementedError
|
71
|
+
|
72
|
+
|
73
|
+
__all__ = ["BaseDelegate"]
|