akquant 0.1.0__cp39-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of akquant might be problematic. Click here for more details.
- akquant/__init__.py +81 -0
- akquant/__pycache__/__init__.cpython-312.pyc +0 -0
- akquant/__pycache__/__init__.cpython-313.pyc +0 -0
- akquant/__pycache__/analyzer.cpython-312.pyc +0 -0
- akquant/__pycache__/analyzer.cpython-313.pyc +0 -0
- akquant/__pycache__/backtest.cpython-312.pyc +0 -0
- akquant/__pycache__/backtest.cpython-313.pyc +0 -0
- akquant/__pycache__/config.cpython-312.pyc +0 -0
- akquant/__pycache__/config.cpython-313.pyc +0 -0
- akquant/__pycache__/data.cpython-312.pyc +0 -0
- akquant/__pycache__/data.cpython-313.pyc +0 -0
- akquant/__pycache__/indicator.cpython-312.pyc +0 -0
- akquant/__pycache__/indicator.cpython-313.pyc +0 -0
- akquant/__pycache__/log.cpython-312.pyc +0 -0
- akquant/__pycache__/log.cpython-313.pyc +0 -0
- akquant/__pycache__/sizer.cpython-312.pyc +0 -0
- akquant/__pycache__/sizer.cpython-313.pyc +0 -0
- akquant/__pycache__/strategy.cpython-312.pyc +0 -0
- akquant/__pycache__/strategy.cpython-313.pyc +0 -0
- akquant/__pycache__/utils.cpython-312.pyc +0 -0
- akquant/__pycache__/utils.cpython-313.pyc +0 -0
- akquant/akquant.pyd +0 -0
- akquant/akquant.pyi +518 -0
- akquant/backtest.py +414 -0
- akquant/config.py +36 -0
- akquant/data.py +122 -0
- akquant/indicator.py +56 -0
- akquant/log.py +135 -0
- akquant/sizer.py +82 -0
- akquant/strategy.py +516 -0
- akquant/utils.py +167 -0
- akquant-0.1.0.dist-info/METADATA +149 -0
- akquant-0.1.0.dist-info/RECORD +35 -0
- akquant-0.1.0.dist-info/WHEEL +4 -0
- akquant-0.1.0.dist-info/licenses/LICENSE +21 -0
akquant/backtest.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from typing import Union, List, Optional, Type, Callable, Dict, Any
|
|
3
|
+
from .akquant import Engine, ExecutionMode, Instrument, AssetType, DataFeed, Bar
|
|
4
|
+
from .strategy import Strategy
|
|
5
|
+
from .data import DataLoader
|
|
6
|
+
from .log import get_logger, register_logger
|
|
7
|
+
from .utils import prepare_dataframe, df_to_arrays
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FunctionalStrategy(Strategy):
    """Internal Strategy adapter behind the functional (Zipline-style) API.

    The user supplies plain functions instead of a Strategy subclass:
    ``initialize(self)`` is called once at construction time and
    ``on_bar(self, bar)`` is called from this class's ``on_bar``. The
    strategy instance itself plays the role of Zipline's ``context`` object.
    """

    def __init__(
        self,
        initialize: Optional[Callable[[Any], None]],
        on_bar: Optional[Callable[[Any, Bar], None]],
        context: Optional[Dict[str, Any]] = None,
    ):
        """
        :param initialize: Called once as ``initialize(self)`` after the context
            attributes have been set; skipped when None.
        :param on_bar: Called as ``on_bar(self, bar)``; skipped when None.
        :param context: Initial values; every key becomes an attribute of the
            instance (``self.<key>``).
        """
        super().__init__()
        self._initialize = initialize
        self._on_bar_func = on_bar
        self._context = context or {}

        # Expose context entries as attributes so user code can write self.xxx,
        # mimicking Zipline's context object.
        for key, value in self._context.items():
            setattr(self, key, value)

        # Run the user's initialization last, so it can read the context attributes.
        if self._initialize is not None:
            self._initialize(self)

    def on_bar(self, bar: Bar):
        """Forward *bar* to the user's on_bar callback, if one was given."""
        if self._on_bar_func is not None:
            self._on_bar_func(self, bar)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def run_backtest(
    data: Union[pd.DataFrame, Dict[str, pd.DataFrame], List[Bar]],
    strategy: Union[Type[Strategy], Strategy, Callable[[Any, Bar], None]],
    symbol: Union[str, List[str], None] = "BENCHMARK",
    cash: float = 1_000_000.0,
    commission: float = 0.0003,
    stamp_tax: float = 0.0005,
    transfer_fee: float = 0.00001,
    min_commission: float = 5.0,
    execution_mode: Union[ExecutionMode, str] = ExecutionMode.NextOpen,
    timezone: str = "Asia/Shanghai",
    initialize: Optional[Callable[[Any], None]] = None,
    context: Optional[Dict[str, Any]] = None,
    history_depth: int = 0,
    lot_size: Union[int, Dict[str, int], None] = None,
    show_progress: bool = True,
    **kwargs,
) -> Any:
    """Simplified one-call backtest entry point.

    :param data: Backtest data, one of:
        - a pandas DataFrame (single symbol, registered under the first entry of ``symbol``);
        - a ``Dict[str, DataFrame]`` mapping symbol -> DataFrame (the keys replace ``symbol``);
        - a ``List[Bar]``. The caller's list is not modified; a timestamp-sorted
          copy is fed to the engine.
    :param strategy: A Strategy subclass, a Strategy instance, or an ``on_bar(ctx, bar)``
        callback function.
    :param symbol: Symbol code or list of symbol codes; any other value falls back
        to ``["BENCHMARK"]``.
    :param cash: Initial cash.
    :param commission: Commission rate.
    :param stamp_tax: Stamp tax rate (sell side only).
    :param transfer_fee: Transfer fee rate.
    :param min_commission: Minimum commission.
    :param execution_mode: An ``ExecutionMode`` value or one of the strings
        ``"next_open"`` / ``"current_close"`` (case-insensitive). Unknown strings log a
        warning and fall back to NextOpen.
    :param timezone: Timezone name passed to the engine.
    :param initialize: Initialization callback (only used when ``strategy`` is a function).
    :param context: Initial context values, set as attributes on the strategy instance.
    :param history_depth: Length of automatically maintained history (0 disables it).
    :param lot_size: Minimum trading unit. An int applies to every symbol; a dict is
        looked up per symbol. None is passed through to ``Instrument`` unchanged
        (per the original documentation this means a lot size of 1).
    :param show_progress: Whether to show a progress bar (default True).
    :param kwargs: Extra options read here: ``fund_commission``, ``fund_transfer_fee``,
        ``fund_min_commission``, ``option_commission``, ``multiplier``, ``margin_ratio``,
        ``tick_size``, ``asset_type``, ``option_type``, ``strike_price``, ``expiry_date``.
        When ``strategy`` is a class, all kwargs are first tried as constructor arguments.
    :return: The result object returned by ``engine.get_results()``.
    :raises ValueError: If ``data`` or ``strategy`` has an unsupported type.
    """
    # 1. Make sure logging is configured; install a default console logger otherwise.
    logger = get_logger()
    if not logger.handlers:
        register_logger(console=True, level="INFO")
        logger = get_logger()

    # 2. Build the data feed.
    feed = DataFeed()

    if isinstance(symbol, str):
        symbols = [symbol]
    elif isinstance(symbol, list):
        symbols = list(symbol)
    else:
        symbols = ["BENCHMARK"]

    if isinstance(data, pd.DataFrame):
        # A single DataFrame feeds a single symbol: the first requested one.
        target_symbol = symbols[0] if symbols else "BENCHMARK"
        df = prepare_dataframe(data)
        # Fast path: hand column arrays to the engine instead of building Bar objects.
        feed.add_arrays(*df_to_arrays(df, symbol=target_symbol))
        feed.sort()
        if target_symbol not in symbols:
            symbols = [target_symbol]
    elif isinstance(data, dict):
        # One DataFrame per symbol; the dict keys define the instrument set.
        symbols = list(data.keys())
        for sym, sym_df in data.items():
            feed.add_arrays(*df_to_arrays(prepare_dataframe(sym_df), symbol=sym))
        if data:
            # Sort once after every symbol is loaded instead of after each insert.
            feed.sort()
    elif isinstance(data, list):
        # Sort a copy: reordering the caller's list would be a surprising side effect.
        feed.add_bars(sorted(data, key=lambda b: b.timestamp))
    else:
        raise ValueError("data must be a DataFrame, Dict[str, DataFrame], or List[Bar]")

    # 3. Configure the engine.
    engine = Engine()
    engine.set_timezone_name(timezone)
    engine.set_cash(cash)

    if isinstance(execution_mode, str):
        mode_map = {
            "next_open": ExecutionMode.NextOpen,
            "current_close": ExecutionMode.CurrentClose,
        }
        mode = mode_map.get(execution_mode.lower())
        # Compare with None explicitly: a valid enum member must never be
        # rejected just because it happens to be falsy.
        if mode is None:
            logger.warning(
                f"Unknown execution mode '{execution_mode}', defaulting to NextOpen"
            )
            mode = ExecutionMode.NextOpen
        engine.set_execution_mode(mode)
    else:
        engine.set_execution_mode(execution_mode)

    engine.set_t_plus_one(False)  # T+0 by default
    engine.set_force_session_continuous(True)
    engine.set_stock_fee_rules(commission, stamp_tax, transfer_fee, min_commission)

    # Fee rules for other asset classes, only when explicitly requested.
    if "fund_commission" in kwargs:
        engine.set_fund_fee_rules(
            kwargs["fund_commission"],
            kwargs.get("fund_transfer_fee", 0.0),
            kwargs.get("fund_min_commission", 0.0),
        )

    if "option_commission" in kwargs:
        engine.set_option_fee_rules(kwargs["option_commission"])

    # 4. Register one instrument per symbol.
    multiplier = kwargs.get("multiplier", 1.0)
    margin_ratio = kwargs.get("margin_ratio", 1.0)
    tick_size = kwargs.get("tick_size", 0.01)
    asset_type = kwargs.get("asset_type", AssetType.Stock)
    # Option-specific fields.
    option_type = kwargs.get("option_type", None)
    strike_price = kwargs.get("strike_price", None)
    expiry_date = kwargs.get("expiry_date", None)

    for sym in symbols:
        if isinstance(lot_size, int):
            current_lot_size = lot_size
        elif isinstance(lot_size, dict):
            current_lot_size = lot_size.get(sym)
        else:
            current_lot_size = None

        engine.add_instrument(
            Instrument(
                sym,
                asset_type,
                multiplier,
                margin_ratio,
                tick_size,
                option_type,
                strike_price,
                expiry_date,
                current_lot_size,
            )
        )

    # 5. Attach the data.
    engine.add_data(feed)

    # 6. Build the strategy instance.
    if isinstance(strategy, type) and issubclass(strategy, Strategy):
        if kwargs:
            # Best effort: forward kwargs to the constructor and fall back to a
            # no-argument constructor if it rejects them. NOTE: a TypeError raised
            # inside __init__ for an unrelated reason also triggers the fallback.
            try:
                strategy_instance = strategy(**kwargs)
            except TypeError:
                strategy_instance = strategy()
        else:
            strategy_instance = strategy()
    elif isinstance(strategy, Strategy):
        strategy_instance = strategy
    elif callable(strategy):
        # A bare function is treated as a Zipline-style on_bar callback.
        strategy_instance = FunctionalStrategy(initialize, strategy, context)
    else:
        raise ValueError("Invalid strategy type")

    # 7. Run the backtest.
    logger.info("Running backtest via run_backtest()...")

    # Instances carrying `_context` (FunctionalStrategy) already copied the context
    # in __init__; any other strategy gets the values set as attributes here.
    if context and not hasattr(strategy_instance, "_context"):
        for key, value in context.items():
            setattr(strategy_instance, key, value)

    # Automatic history maintenance.
    if history_depth > 0:
        strategy_instance.set_history_depth(history_depth)

    engine.run(strategy_instance, show_progress)

    return engine.get_results()
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def plot_result(
|
|
243
|
+
result: Any, show: bool = True, filename: str = None, benchmark: pd.Series = None
|
|
244
|
+
):
|
|
245
|
+
"""
|
|
246
|
+
绘制回测结果 (权益曲线、回撤、日收益率)
|
|
247
|
+
|
|
248
|
+
:param result: BacktestResult 对象
|
|
249
|
+
:param show: 是否调用 plt.show()
|
|
250
|
+
:param filename: 保存图片的文件名
|
|
251
|
+
:param benchmark: 基准收益率序列 (可选, Series with DatetimeIndex)
|
|
252
|
+
"""
|
|
253
|
+
try:
|
|
254
|
+
import matplotlib.pyplot as plt
|
|
255
|
+
import matplotlib.dates as mdates
|
|
256
|
+
from matplotlib.gridspec import GridSpec
|
|
257
|
+
from datetime import datetime
|
|
258
|
+
import numpy as np
|
|
259
|
+
except ImportError:
|
|
260
|
+
print(
|
|
261
|
+
"Error: matplotlib is required for plotting. Please install it via 'pip install matplotlib'."
|
|
262
|
+
)
|
|
263
|
+
return
|
|
264
|
+
|
|
265
|
+
# Extract data
|
|
266
|
+
equity_curve = result.equity_curve # List[Tuple[int, float]]
|
|
267
|
+
|
|
268
|
+
if not equity_curve:
|
|
269
|
+
print("No equity curve data to plot.")
|
|
270
|
+
return
|
|
271
|
+
|
|
272
|
+
# Check if timestamp is in nanoseconds (e.g. > 1e11)
|
|
273
|
+
# 1e11 seconds is roughly year 5138, so valid seconds are < 1e11
|
|
274
|
+
# 1e18 nanoseconds is roughly year 2001
|
|
275
|
+
first_ts = equity_curve[0][0]
|
|
276
|
+
scale = 1.0
|
|
277
|
+
if first_ts > 1e11:
|
|
278
|
+
scale = 1e-9
|
|
279
|
+
|
|
280
|
+
times = [datetime.fromtimestamp(t * scale) for t, _ in equity_curve]
|
|
281
|
+
equity = [e for _, e in equity_curve]
|
|
282
|
+
|
|
283
|
+
# Convert to DataFrame for easier calculation
|
|
284
|
+
df = pd.DataFrame({"equity": equity}, index=times)
|
|
285
|
+
df.index.name = "Date"
|
|
286
|
+
df["returns"] = df["equity"].pct_change().fillna(0)
|
|
287
|
+
|
|
288
|
+
# Calculate Drawdown
|
|
289
|
+
rolling_max = df["equity"].cummax()
|
|
290
|
+
drawdown = (df["equity"] - rolling_max) / rolling_max
|
|
291
|
+
|
|
292
|
+
# Create figure with GridSpec
|
|
293
|
+
fig = plt.figure(figsize=(14, 10))
|
|
294
|
+
# 3 rows: Equity (3), Drawdown (1), Daily Returns (1)
|
|
295
|
+
gs = GridSpec(3, 1, height_ratios=[3, 1, 1], hspace=0.05)
|
|
296
|
+
|
|
297
|
+
# 1. Equity Curve
|
|
298
|
+
ax1 = fig.add_subplot(gs[0])
|
|
299
|
+
ax1.plot(df.index, df["equity"], label="Strategy", color="#1f77b4", linewidth=1.5)
|
|
300
|
+
|
|
301
|
+
if benchmark is not None:
|
|
302
|
+
# Align benchmark to strategy dates
|
|
303
|
+
try:
|
|
304
|
+
# Ensure benchmark has DatetimeIndex
|
|
305
|
+
if not isinstance(benchmark.index, pd.DatetimeIndex):
|
|
306
|
+
benchmark.index = pd.to_datetime(benchmark.index)
|
|
307
|
+
|
|
308
|
+
# Reindex benchmark to match strategy dates (forward fill for missing days if any, but strict alignment preferred)
|
|
309
|
+
# Normalize dates to start of day for alignment if needed, but here we assume close match
|
|
310
|
+
# For simplicity, we just plot what overlaps
|
|
311
|
+
|
|
312
|
+
# Calculate cumulative return of benchmark
|
|
313
|
+
bench_cum = (1 + benchmark).cumprod()
|
|
314
|
+
|
|
315
|
+
# Rebase benchmark to match initial strategy equity
|
|
316
|
+
initial_equity = df["equity"].iloc[0]
|
|
317
|
+
if not bench_cum.empty:
|
|
318
|
+
# Align start
|
|
319
|
+
# Find the closest date in benchmark to start date
|
|
320
|
+
start_date = df.index[0]
|
|
321
|
+
if start_date in bench_cum.index:
|
|
322
|
+
base_val = bench_cum.loc[start_date]
|
|
323
|
+
else:
|
|
324
|
+
# Fallback: use first available
|
|
325
|
+
base_val = bench_cum.iloc[0]
|
|
326
|
+
|
|
327
|
+
bench_scaled = (bench_cum / base_val) * initial_equity
|
|
328
|
+
|
|
329
|
+
# Filter to strategy range
|
|
330
|
+
bench_plot = bench_scaled[df.index[0] : df.index[-1]]
|
|
331
|
+
ax1.plot(
|
|
332
|
+
bench_plot.index,
|
|
333
|
+
bench_plot,
|
|
334
|
+
label="Benchmark",
|
|
335
|
+
color="gray",
|
|
336
|
+
linestyle="--",
|
|
337
|
+
alpha=0.7,
|
|
338
|
+
)
|
|
339
|
+
except Exception as e:
|
|
340
|
+
print(f"Warning: Failed to plot benchmark: {e}")
|
|
341
|
+
|
|
342
|
+
ax1.set_title("Strategy Performance Analysis", fontsize=14, fontweight="bold")
|
|
343
|
+
ax1.set_ylabel("Equity", fontsize=10)
|
|
344
|
+
ax1.grid(True, linestyle="--", alpha=0.3)
|
|
345
|
+
ax1.legend(loc="upper left", frameon=True, fancybox=True, framealpha=0.8)
|
|
346
|
+
|
|
347
|
+
# Add Metrics Text Box
|
|
348
|
+
metrics = result.metrics
|
|
349
|
+
trade_metrics = result.trade_metrics
|
|
350
|
+
|
|
351
|
+
metrics_text = [
|
|
352
|
+
f"Total Return: {metrics.total_return_pct:>8.2f}%",
|
|
353
|
+
f"Annualized: {metrics.annualized_return:>8.2%}",
|
|
354
|
+
f"Sharpe Ratio: {metrics.sharpe_ratio:>8.2f}",
|
|
355
|
+
f"Max Drawdown: {metrics.max_drawdown_pct:>8.2f}%",
|
|
356
|
+
f"Win Rate: {metrics.win_rate:>8.2%}",
|
|
357
|
+
]
|
|
358
|
+
|
|
359
|
+
if hasattr(trade_metrics, "total_closed_trades"):
|
|
360
|
+
metrics_text.append(f"Trades: {trade_metrics.total_closed_trades:>8d}")
|
|
361
|
+
|
|
362
|
+
text_str = "\n".join(metrics_text)
|
|
363
|
+
|
|
364
|
+
props = dict(boxstyle="round", facecolor="white", alpha=0.8, edgecolor="lightgray")
|
|
365
|
+
ax1.text(
|
|
366
|
+
0.02,
|
|
367
|
+
0.05,
|
|
368
|
+
text_str,
|
|
369
|
+
transform=ax1.transAxes,
|
|
370
|
+
fontsize=9,
|
|
371
|
+
verticalalignment="bottom",
|
|
372
|
+
fontfamily="monospace",
|
|
373
|
+
bbox=props,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
# 2. Drawdown
|
|
377
|
+
ax2 = fig.add_subplot(gs[1], sharex=ax1)
|
|
378
|
+
ax2.fill_between(
|
|
379
|
+
df.index, drawdown, 0, color="#d62728", alpha=0.3, label="Drawdown"
|
|
380
|
+
)
|
|
381
|
+
ax2.plot(df.index, drawdown, color="#d62728", linewidth=0.8, alpha=0.8)
|
|
382
|
+
ax2.set_ylabel("Drawdown", fontsize=10)
|
|
383
|
+
ax2.grid(True, linestyle="--", alpha=0.3)
|
|
384
|
+
# ax2.legend(loc='lower right', fontsize=8)
|
|
385
|
+
|
|
386
|
+
# 3. Daily Returns
|
|
387
|
+
ax3 = fig.add_subplot(gs[2], sharex=ax1)
|
|
388
|
+
ax3.bar(
|
|
389
|
+
df.index,
|
|
390
|
+
df["returns"],
|
|
391
|
+
color="gray",
|
|
392
|
+
alpha=0.5,
|
|
393
|
+
label="Daily Returns",
|
|
394
|
+
width=1.0 if len(df) < 100 else 0.8,
|
|
395
|
+
)
|
|
396
|
+
# Highlight extreme returns? No, keep simple.
|
|
397
|
+
ax3.set_ylabel("Returns", fontsize=10)
|
|
398
|
+
ax3.grid(True, linestyle="--", alpha=0.3)
|
|
399
|
+
|
|
400
|
+
# Format X axis
|
|
401
|
+
ax3.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
|
|
402
|
+
plt.setp(ax1.get_xticklabels(), visible=False)
|
|
403
|
+
plt.setp(ax2.get_xticklabels(), visible=False)
|
|
404
|
+
plt.xticks(rotation=0)
|
|
405
|
+
|
|
406
|
+
# Adjust margins
|
|
407
|
+
plt.subplots_adjust(top=0.95, bottom=0.05, left=0.08, right=0.95)
|
|
408
|
+
|
|
409
|
+
if filename:
|
|
410
|
+
plt.savefig(filename, dpi=100, bbox_inches="tight")
|
|
411
|
+
print(f"Plot saved to {filename}")
|
|
412
|
+
|
|
413
|
+
if show:
|
|
414
|
+
plt.show()
|
akquant/config.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dataclass
class StrategyConfig:
    """Settings shared by strategies and backtests (modelled on PyBroker's config).

    All fields have defaults; positional construction follows the field order
    below, starting with ``initial_cash``.
    """

    # --- Capital ---------------------------------------------------------
    initial_cash: float = 100_000.0

    # --- Fees ------------------------------------------------------------
    # Intended values: 'per_order', 'per_share' or 'percent' (not validated here).
    fee_mode: str = "per_order"
    # A fixed amount or a percentage, depending on fee_mode.
    fee_amount: float = 0.0

    # --- Execution -------------------------------------------------------
    enable_fractional_shares: bool = False
    round_fill_price: bool = True

    # --- Position-count limits (None means no limit is configured) ------
    max_long_positions: Optional[int] = None
    max_short_positions: Optional[int] = None

    # --- Bootstrap metrics -----------------------------------------------
    bootstrap_samples: int = 1000
    bootstrap_sample_size: Optional[int] = None

    # --- Miscellaneous ---------------------------------------------------
    exit_on_last_bar: bool = True


# Shared module-level instance holding the default settings.
strategy_config = StrategyConfig()
|
akquant/data.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import hashlib
|
|
3
|
+
from typing import Optional, List
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
import akshare as ak
|
|
9
|
+
except ImportError:
|
|
10
|
+
ak = None
|
|
11
|
+
|
|
12
|
+
from .utils import load_akshare_bar
|
|
13
|
+
from .akquant import Bar
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)


class DataLoader:
    """Data loader with an on-disk cache, inspired by PyBroker.

    Fetched DataFrames are stored as pandas pickle files in ``cache_dir``, one
    file per distinct request (see ``_get_cache_path``).
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """Initialize the loader and make sure a usable cache directory exists.

        Args:
            cache_dir (str, optional): Directory to store cache files.
                Defaults to ``~/.akquant/cache``. If that directory cannot be
                created (permission denied, read-only filesystem, a file already
                at that path, ...), ``./.akquant_cache`` in the current working
                directory is used instead.
        """
        if cache_dir:
            self.cache_dir = Path(cache_dir)
        else:
            self.cache_dir = Path.home() / ".akquant" / "cache"

        try:
            # exist_ok only tolerates an existing *directory*; a regular file at
            # this path raises FileExistsError and takes the fallback below.
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            logger.warning(
                f"Cannot use cache directory {self.cache_dir} ({exc}), "
                "falling back to local .akquant_cache"
            )
            self.cache_dir = Path.cwd() / ".akquant_cache"
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_path(self, key: str) -> Path:
        """Return the cache file path for *key*.

        The key is hashed (MD5 hex digest) so arbitrary keys map to short,
        filesystem-safe file names.
        """
        # MD5 serves only as a filename fingerprint, not for security. Declaring
        # that keeps it usable on FIPS-restricted Python builds.
        digest = hashlib.md5(key.encode("utf-8"), usedforsecurity=False).hexdigest()
        return self.cache_dir / f"{digest}.pkl"

    def _write_cache(self, df: pd.DataFrame, cache_path: Path) -> None:
        """Write *df* to *cache_path*; failures are logged, never raised."""
        tmp_path = cache_path.with_name(cache_path.name + ".tmp")
        try:
            df.to_pickle(tmp_path)
            # Write-then-rename: the final file is replaced in one step (atomic on
            # POSIX), so an interrupted write never leaves a truncated cache file.
            tmp_path.replace(cache_path)
            logger.info(f"Data cached to {cache_path}")
        except Exception as e:
            logger.warning(f"Failed to write cache {cache_path}: {e}")
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup of the temporary file

    def load_akshare(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        adjust: str = "qfq",
        period: str = "daily",
        use_cache: bool = True,
    ) -> pd.DataFrame:
        """
        Load A-share history data from AKShare with caching.

        Args:
            symbol (str): Stock symbol (e.g., "600000").
            start_date (str): Start date (YYYYMMDD).
            end_date (str): End date (YYYYMMDD).
            adjust (str): Adjustment factor ("qfq", "hfq", ""). Default "qfq".
            period (str): Period ("daily", "weekly", "monthly"). Default "daily".
            use_cache (bool): Whether to read/write the cache. Default True.

        Returns:
            pd.DataFrame: Historical data. An empty result is returned as-is and
            is not cached.

        Raises:
            ImportError: If akshare is not installed.
            Exception: Errors from ``ak.stock_zh_a_hist`` are logged and re-raised.
                A failure to *write* the cache is only logged.
        """
        if ak is None:
            raise ImportError(
                "akshare is not installed. Please run `pip install akshare`."
            )

        cache_key = f"akshare_stock_zh_a_hist_{symbol}_{start_date}_{end_date}_{adjust}_{period}"
        cache_path = self._get_cache_path(cache_key)

        if use_cache and cache_path.exists():
            logger.info(f"Loading cached data for {symbol} from {cache_path}")
            try:
                # NOTE(security): read_pickle can execute arbitrary code; only use a
                # cache directory that untrusted users cannot write to.
                return pd.read_pickle(cache_path)
            except Exception as e:
                logger.warning(f"Failed to load cache: {e}. Reloading from source.")

        logger.info(f"Fetching data for {symbol} from AKShare...")
        try:
            df = ak.stock_zh_a_hist(
                symbol=symbol,
                period=period,
                start_date=start_date,
                end_date=end_date,
                adjust=adjust,
            )
        except Exception as e:
            logger.error(f"Error fetching data from AKShare: {e}")
            raise

        if df.empty:
            logger.warning(f"No data found for {symbol}")
            return df

        # Caching is best-effort: successfully fetched data is returned even if
        # the cache cannot be written.
        if use_cache:
            self._write_cache(df, cache_path)

        return df

    def df_to_bars(self, df: pd.DataFrame, symbol: Optional[str] = None) -> "List[Bar]":
        """
        Convert a DataFrame to a list of Bar objects.
        Thin wrapper around ``utils.load_akshare_bar``.
        """
        return load_akshare_bar(df, symbol)
|
akquant/indicator.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from typing import Callable, Dict
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Indicator:
    """A named indicator function whose results are cached per symbol.

    Inspired by PyBroker's indicator system. ``fn`` is called as
    ``fn(df, **kwargs)``. If kwargs were given and ``fn`` rejects them with a
    ``TypeError``, it is retried as ``fn(df)``. Any other exception propagates.

    Results are cached by symbol only: a later call with the same symbol
    returns the cached value even if a different DataFrame is passed.
    """

    def __init__(self, name: str, fn: Callable, **kwargs):
        self.name = name
        self.fn = fn
        self.kwargs = kwargs
        self._data: Dict[str, pd.Series] = {}  # symbol -> computed result

    def __call__(self, df: pd.DataFrame, symbol: str) -> pd.Series:
        """Compute (or return the cached) indicator value for *symbol* from *df*."""
        if symbol in self._data:
            return self._data[symbol]

        if self.kwargs:
            try:
                result = self.fn(df, **self.kwargs)
            except TypeError:
                # fn does not accept the configured keyword arguments (e.g. a
                # lambda taking only df); call it with the DataFrame alone. Other
                # errors are real failures and must not be retried with different
                # parameters, which would silently produce wrong values.
                result = self.fn(df)
        else:
            result = self.fn(df)

        self._data[symbol] = result
        return result
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class IndicatorSet:
    """Registry that groups Indicator objects by name."""

    def __init__(self):
        # name -> Indicator
        self._indicators: Dict[str, Indicator] = {}

    def add(self, name: str, fn: Callable, **kwargs):
        """Register an indicator under *name*, replacing any existing entry."""
        indicator = Indicator(name, fn, **kwargs)
        self._indicators[name] = indicator

    def get(self, name: str) -> Indicator:
        """Return the indicator registered as *name* (KeyError if absent)."""
        return self._indicators[name]

    def calculate_all(self, df: pd.DataFrame, symbol: str) -> Dict[str, pd.Series]:
        """Evaluate every registered indicator on *df* for *symbol*, keyed by name."""
        return {
            key: indicator(df, symbol)
            for key, indicator in self._indicators.items()
        }
|