akquant 0.1.4__cp310-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of akquant might be problematic. Click here for more details.
- akquant/__init__.py +98 -0
- akquant/akquant.pyd +0 -0
- akquant/akquant.pyi +683 -0
- akquant/backtest.py +659 -0
- akquant/config.py +65 -0
- akquant/data.py +136 -0
- akquant/indicator.py +81 -0
- akquant/log.py +135 -0
- akquant/ml/__init__.py +3 -0
- akquant/ml/model.py +234 -0
- akquant/py.typed +0 -0
- akquant/risk.py +40 -0
- akquant/sizer.py +96 -0
- akquant/strategy.py +824 -0
- akquant/utils.py +386 -0
- akquant-0.1.4.dist-info/METADATA +219 -0
- akquant-0.1.4.dist-info/RECORD +19 -0
- akquant-0.1.4.dist-info/WHEEL +4 -0
- akquant-0.1.4.dist-info/licenses/LICENSE +21 -0
akquant/backtest.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
1
|
+
from functools import cached_property
|
|
2
|
+
from typing import (
|
|
3
|
+
Any,
|
|
4
|
+
Callable,
|
|
5
|
+
Dict,
|
|
6
|
+
List,
|
|
7
|
+
Optional,
|
|
8
|
+
Type,
|
|
9
|
+
Union,
|
|
10
|
+
cast,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
from .akquant import (
|
|
16
|
+
AssetType,
|
|
17
|
+
Bar,
|
|
18
|
+
DataFeed,
|
|
19
|
+
Engine,
|
|
20
|
+
ExecutionMode,
|
|
21
|
+
Instrument,
|
|
22
|
+
)
|
|
23
|
+
from .akquant import (
|
|
24
|
+
BacktestResult as RustBacktestResult,
|
|
25
|
+
)
|
|
26
|
+
from .config import BacktestConfig
|
|
27
|
+
from .data import ParquetDataCatalog
|
|
28
|
+
from .log import get_logger, register_logger
|
|
29
|
+
from .risk import apply_risk_config
|
|
30
|
+
from .strategy import Strategy
|
|
31
|
+
from .utils import df_to_arrays, prepare_dataframe
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class BacktestResult:
    """
    Backtest Result Wrapper.

    Wraps the underlying Rust BacktestResult to provide Python-friendly
    properties like DataFrames. Attribute access that is not defined on
    this wrapper is delegated to the raw Rust object via ``__getattr__``.
    """

    def __init__(self, raw_result: RustBacktestResult, timezone: str = "Asia/Shanghai"):
        """
        Initialize the BacktestResult wrapper.

        :param raw_result: The raw Rust BacktestResult object.
        :param timezone: The timezone string for datetime conversion.
        """
        self._raw = raw_result
        self._timezone = timezone

    @property
    def daily_positions_df(self) -> pd.DataFrame:
        """
        Get daily positions as a Pandas DataFrame.

        Index: Datetime (Timezone-aware)
        Columns: Symbols
        Values: Quantity.
        """
        if not self._raw.daily_positions:
            return pd.DataFrame()

        # Unzip the list of tuples [(ts, {sym: qty}), ...]
        timestamps, positions = zip(*self._raw.daily_positions)

        df = pd.DataFrame(list(positions), index=timestamps)

        # Convert nanosecond timestamp to datetime with timezone
        df.index = pd.to_datetime(df.index, unit="ns", utc=True).tz_convert(
            self._timezone
        )

        # Sort index just in case
        df = df.sort_index()

        # Fill missing values with 0 (assuming 0 position if not present in map)
        df = df.fillna(0.0)

        return df

    @property
    def metrics_df(self) -> pd.DataFrame:
        """Get performance metrics as a Pandas DataFrame (one row)."""
        metrics = self._raw.metrics

        # Manually construct dictionary from known fields since PyO3 objects
        # might not expose __dict__ directly in a clean way or might have extra fields.
        # We use the fields defined in PerformanceMetrics (see akquant.pyi)
        data = {
            "total_return": metrics.total_return,
            "annualized_return": metrics.annualized_return,
            "max_drawdown": metrics.max_drawdown,
            "max_drawdown_pct": metrics.max_drawdown_pct,
            "sharpe_ratio": metrics.sharpe_ratio,
            "sortino_ratio": metrics.sortino_ratio,
            "volatility": metrics.volatility,
            "ulcer_index": metrics.ulcer_index,
            "upi": metrics.upi,
            "equity_r2": metrics.equity_r2,
            "std_error": metrics.std_error,
            "win_rate": metrics.win_rate,
            "initial_market_value": metrics.initial_market_value,
            "end_market_value": metrics.end_market_value,
            "total_return_pct": metrics.total_return_pct,
        }

        # Return as a DataFrame with one row
        return pd.DataFrame([data], index=["Backtest"])

    @cached_property
    def trades_df(self) -> pd.DataFrame:
        """Get closed trades as a Pandas DataFrame (cached after first access)."""
        if not self._raw.trades:
            return pd.DataFrame()

        data = []
        for t in self._raw.trades:
            data.append(
                {
                    "symbol": t.symbol,
                    "entry_time": t.entry_time,
                    "exit_time": t.exit_time,
                    "entry_price": t.entry_price,
                    "exit_price": t.exit_price,
                    "quantity": t.quantity,
                    "direction": t.direction,
                    "pnl": t.pnl,
                    "net_pnl": t.net_pnl,
                    "return_pct": t.return_pct,
                    "commission": t.commission,
                    "duration_bars": t.duration_bars,
                }
            )

        df = pd.DataFrame(data)

        # Convert nanosecond timestamps to timezone-aware datetimes
        df["entry_time"] = pd.to_datetime(
            df["entry_time"], unit="ns", utc=True
        ).dt.tz_convert(self._timezone)
        df["exit_time"] = pd.to_datetime(
            df["exit_time"], unit="ns", utc=True
        ).dt.tz_convert(self._timezone)

        return df

    def __getattr__(self, name: str) -> Any:
        """Delegate attribute access to the raw result."""
        return getattr(self._raw, name)

    def __repr__(self) -> str:
        """Return the string representation of the raw result."""
        return repr(self._raw)

    def __dir__(self) -> List[str]:
        """Return the list of attributes including raw result attributes."""
        # BUGFIX: previously only "daily_positions_df" was advertised here;
        # include every DataFrame accessor defined on this wrapper so that
        # introspection (dir/tab-completion) reflects the full API.
        return list(
            set(
                dir(self._raw)
                + list(self.__dict__.keys())
                + ["daily_positions_df", "metrics_df", "trades_df"]
            )
        )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class FunctionalStrategy(Strategy):
    """Internal strategy wrapper enabling the functional (Zipline-style) API."""

    def __init__(
        self,
        initialize: Optional[Callable[[Any], None]],
        on_bar: Optional[Callable[[Any, Bar], None]],
        context: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the FunctionalStrategy."""
        super().__init__()
        self._initialize = initialize
        self._on_bar_func = on_bar
        self._context = context or {}

        # Mirror every context entry onto ``self`` so user callbacks can
        # read them as plain attributes, emulating Zipline's context object.
        for attr_name, attr_value in self._context.items():
            setattr(self, attr_name, attr_value)

        # Run the user-supplied setup hook, when one was given.
        if self._initialize is not None:
            self._initialize(self)

    def on_bar(self, bar: Bar) -> None:
        """Delegate on_bar event to the user-provided function."""
        handler = self._on_bar_func
        if handler is not None:
            handler(self, bar)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def run_backtest(
    data: Optional[Union[pd.DataFrame, Dict[str, pd.DataFrame], List[Bar]]] = None,
    strategy: Union[Type[Strategy], Strategy, Callable[[Any, Bar], None], None] = None,
    symbol: Union[str, List[str]] = "BENCHMARK",
    cash: float = 1_000_000.0,
    commission: float = 0.0003,
    stamp_tax: float = 0.0005,
    transfer_fee: float = 0.00001,
    min_commission: float = 5.0,
    execution_mode: Union[ExecutionMode, str] = ExecutionMode.NextOpen,
    timezone: str = "Asia/Shanghai",
    initialize: Optional[Callable[[Any], None]] = None,
    context: Optional[Dict[str, Any]] = None,
    history_depth: int = 0,
    lot_size: Union[int, Dict[str, int], None] = None,
    show_progress: bool = True,
    config: Optional[BacktestConfig] = None,
    **kwargs: Any,
) -> BacktestResult:
    """
    Simplified backtest entry point.

    :param data: Backtest data: a Pandas DataFrame, a mapping of symbol ->
        DataFrame, or a list of Bars. Optional (when a config or strategy
        subscriptions supply the instruments, data is loaded from the catalog).
    :param strategy: Strategy class, strategy instance, or an on_bar callback
    :param symbol: Instrument symbol(s)
    :param cash: Initial cash
    :param commission: Commission rate
    :param stamp_tax: Stamp tax rate (sell side only)
    :param transfer_fee: Transfer fee rate
    :param min_commission: Minimum commission
    :param execution_mode: Execution mode (ExecutionMode.NextOpen or "next_open")
    :param timezone: Timezone name
    :param initialize: Initialization callback (only used when strategy is a function)
    :param context: Initial context data (only used when strategy is a function)
    :param history_depth: Length of automatically maintained history (0 disables)
    :param lot_size: Minimum trade unit. An int applies to all symbols;
        a Dict[str, int] matches per symbol; if omitted (None), defaults to 1.
    :param show_progress: Whether to show a progress bar (default True)
    :param config: Optional BacktestConfig object
    :return: Backtest result object
    """
    # 1. Make sure logging is initialized (attach a console handler on first use)
    logger = get_logger()
    if not logger.handlers:
        register_logger(console=True, level="INFO")
        logger = get_logger()

    # 1.5 Apply overrides coming from the config object
    if config:
        if config.start_date:
            kwargs["start_date"] = config.start_date
        if config.end_date:
            kwargs["end_date"] = config.end_date
        if config.timezone:
            timezone = config.timezone
        if config.show_progress is not None:
            show_progress = config.show_progress
        if config.history_depth is not None:
            history_depth = config.history_depth

        if config.strategy_config:
            cash = config.strategy_config.initial_cash
            # Fee handling could be more complex, simplifying here
            commission = config.strategy_config.fee_amount or commission

        # Risk Config injection handled later

    # 2. Instantiate the strategy (done early so subscriptions can be collected)
    strategy_instance = None

    if isinstance(strategy, type) and issubclass(strategy, Strategy):
        # Try forwarding **kwargs to the constructor; fall back to no-arg init
        try:
            strategy_instance = strategy(**kwargs)
        except TypeError:
            strategy_instance = strategy()
    elif isinstance(strategy, Strategy):
        strategy_instance = strategy
    elif callable(strategy):
        # Bare callable: wrap it in the functional (Zipline-style) adapter
        strategy_instance = FunctionalStrategy(
            initialize, cast(Callable[[Any, Bar], None], strategy), context
        )
    elif strategy is None:
        raise ValueError("Strategy must be provided.")
    else:
        raise ValueError("Invalid strategy type")

    # Inject context attributes (FunctionalStrategy already consumed the
    # context in its constructor, hence the no-op first branch)
    if context and hasattr(strategy_instance, "_context"):
        pass
    elif context and strategy_instance:
        for k, v in context.items():
            setattr(strategy_instance, k, v)

    # Inject the risk config carried by the BacktestConfig
    if config and config.strategy_config and config.strategy_config.risk:
        # Only assign when the strategy exposes a risk_config attribute
        if hasattr(strategy_instance, "risk_config"):
            strategy_instance.risk_config = config.strategy_config.risk  # type: ignore

    # Call on_start so the strategy can register its subscriptions
    if hasattr(strategy_instance, "on_start"):
        strategy_instance.on_start()

    # 3. Prepare the data feed and the symbol list
    feed = DataFeed()
    symbols = []
    data_map_for_indicators = {}

    # Normalize symbol arg to list
    if isinstance(symbol, str):
        symbols = [symbol]
    elif isinstance(symbol, list):
        symbols = symbol
    else:
        symbols = ["BENCHMARK"]

    # Merge with Config instruments
    if config and config.instruments:
        for s in config.instruments:
            if s not in symbols:
                symbols.append(s)

    # Merge with Strategy subscriptions
    if hasattr(strategy_instance, "_subscriptions"):
        for s in strategy_instance._subscriptions:
            if s not in symbols:
                symbols.append(s)

    # Determine Data Loading Strategy
    if data is not None:
        # Use provided data
        if isinstance(data, pd.DataFrame):
            # Single DataFrame: attribute it to the first requested symbol
            target_symbol = symbols[0] if symbols else "BENCHMARK"
            df = prepare_dataframe(data)
            data_map_for_indicators[target_symbol] = df
            arrays = df_to_arrays(df, symbol=target_symbol)
            feed.add_arrays(*arrays)  # type: ignore
            feed.sort()
            if target_symbol not in symbols:
                symbols = [target_symbol]
        elif isinstance(data, dict):
            # Mapping of symbol -> DataFrame
            for sym, df in data.items():
                df_prep = prepare_dataframe(df)
                data_map_for_indicators[sym] = df_prep
                arrays = df_to_arrays(df_prep, symbol=sym)
                feed.add_arrays(*arrays)  # type: ignore
                if sym not in symbols:
                    symbols.append(sym)
            feed.sort()
        elif isinstance(data, list):
            # List of Bars; NOTE: sorts the caller's list in place
            if data:
                data.sort(key=lambda b: b.timestamp)
                feed.add_bars(data)
    else:
        # Load from Catalog / Akshare
        if not symbols:
            logger.warning("No symbols specified and no data provided.")

        catalog = ParquetDataCatalog()
        start_date = kwargs.get("start_date")
        end_date = kwargs.get("end_date")

        loaded_count = 0
        for sym in symbols:
            # Try Catalog
            df = catalog.read(sym, start_date=start_date, end_date=end_date)
            if df.empty:
                logger.warning(f"Data not found in catalog for {sym}")
                continue

            # NOTE(review): redundant check — df.empty was already handled
            # by the `continue` above
            if not df.empty:
                df = prepare_dataframe(df)
                data_map_for_indicators[sym] = df
                arrays = df_to_arrays(df, symbol=sym)
                feed.add_arrays(*arrays)  # type: ignore
                loaded_count += 1

        if loaded_count > 0:
            feed.sort()
        else:
            if symbols:
                logger.warning("Failed to load data for all requested symbols.")

    # 4. Configure the engine
    engine = Engine()
    # engine.set_timezone_name(timezone)
    # Derive the UTC offset (seconds) for the requested timezone; the engine
    # takes a fixed offset rather than a timezone name
    offset_delta = pd.Timestamp.now(tz=timezone).utcoffset()
    if offset_delta is None:
        raise ValueError(f"Invalid timezone: {timezone}")
    offset = int(offset_delta.total_seconds())
    engine.set_timezone(offset)
    engine.set_cash(cash)

    # ... (ExecutionMode logic)
    if isinstance(execution_mode, str):
        mode_map = {
            "next_open": ExecutionMode.NextOpen,
            "current_close": ExecutionMode.CurrentClose,
        }
        mode = mode_map.get(execution_mode.lower())
        if not mode:
            logger.warning(
                f"Unknown execution mode '{execution_mode}', defaulting to NextOpen"
            )
            mode = ExecutionMode.NextOpen
        engine.set_execution_mode(mode)
    else:
        engine.set_execution_mode(execution_mode)

    engine.set_t_plus_one(False)  # T+0 by default; configurable
    engine.set_force_session_continuous(True)
    engine.set_stock_fee_rules(commission, stamp_tax, transfer_fee, min_commission)

    # Configure other asset fees if provided
    if "fund_commission" in kwargs:
        engine.set_fund_fee_rules(
            kwargs["fund_commission"],
            kwargs.get("fund_transfer_fee", 0.0),
            kwargs.get("fund_min_commission", 0.0),
        )

    if "option_commission" in kwargs:
        engine.set_option_fee_rules(kwargs["option_commission"])

    # Apply Risk Config
    if config and config.strategy_config:
        apply_risk_config(engine, config.strategy_config.risk)

    # 5. Register instruments (shared contract settings come from kwargs)
    multiplier = kwargs.get("multiplier", 1.0)
    margin_ratio = kwargs.get("margin_ratio", 1.0)
    tick_size = kwargs.get("tick_size", 0.01)
    asset_type = kwargs.get("asset_type", AssetType.Stock)

    # Option specific fields
    option_type = kwargs.get("option_type", None)
    strike_price = kwargs.get("strike_price", None)
    expiry_date = kwargs.get("expiry_date", None)

    for sym in symbols:
        # Determine lot_size for this symbol
        current_lot_size = None
        if isinstance(lot_size, int):
            current_lot_size = lot_size
        elif isinstance(lot_size, dict):
            current_lot_size = lot_size.get(sym)

        instr = Instrument(
            sym,
            asset_type,
            multiplier,
            margin_ratio,
            tick_size,
            option_type,
            strike_price,
            expiry_date,
            current_lot_size,
        )
        engine.add_instrument(instr)

    # 6. Attach the data feed
    engine.add_data(feed)

    # 7. Run the backtest
    logger.info("Running backtest via run_backtest()...")

    # Enable automatic history maintenance when requested
    if history_depth > 0:
        strategy_instance.set_history_depth(history_depth)

    # 7.5 Prepare Indicators (Vectorized Pre-calculation)
    if hasattr(strategy_instance, "_prepare_indicators") and data_map_for_indicators:
        strategy_instance._prepare_indicators(data_map_for_indicators)

    engine.run(strategy_instance, show_progress)

    return BacktestResult(engine.get_results(), timezone=timezone)
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def plot_result(
    result: Any,
    show: bool = True,
    filename: Optional[str] = None,
    benchmark: Optional[pd.Series] = None,
) -> None:
    """
    Plot backtest results (equity curve, drawdown, daily returns).

    :param result: BacktestResult object
    :param show: Whether to call plt.show()
    :param filename: Filename to save the figure to
    :param benchmark: Optional benchmark return series (Series with DatetimeIndex)
    """
    # matplotlib is an optional dependency; bail out gracefully if missing
    try:
        from datetime import datetime

        import matplotlib.dates as mdates
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec
    except ImportError:
        print(
            "Error: matplotlib is required for plotting. "
            "Please install it via 'pip install matplotlib'."
        )
        return

    # Extract data
    equity_curve = result.equity_curve  # List[Tuple[int, float]]

    if not equity_curve:
        print("No equity curve data to plot.")
        return

    # Check if timestamp is in nanoseconds (e.g. > 1e11)
    # 1e11 seconds is roughly year 5138, so valid seconds are < 1e11
    first_ts = equity_curve[0][0]
    scale = 1.0
    if first_ts > 1e11:
        scale = 1e-9

    from datetime import timezone

    # Use UTC to avoid local timezone issues and align with benchmark data
    times = [
        datetime.fromtimestamp(t * scale, tz=timezone.utc).replace(tzinfo=None)
        for t, _ in equity_curve
    ]
    equity = [e for _, e in equity_curve]

    # Convert to DataFrame for easier calculation
    df = pd.DataFrame({"equity": equity}, index=times)
    df.index.name = "Date"
    df["returns"] = df["equity"].pct_change().fillna(0)

    # Calculate Drawdown
    rolling_max = df["equity"].cummax()
    drawdown = (df["equity"] - rolling_max) / rolling_max

    # Create figure with GridSpec
    fig = plt.figure(figsize=(14, 10))
    # 3 rows: Equity (3), Drawdown (1), Daily Returns (1)
    gs = GridSpec(3, 1, height_ratios=[3, 1, 1], hspace=0.05)

    # 1. Equity Curve
    ax1 = fig.add_subplot(gs[0])
    ax1.plot(df.index, df["equity"], label="Strategy", color="#1f77b4", linewidth=1.5)

    if benchmark is not None:
        # Align benchmark to strategy dates
        try:
            # BUGFIX: work on a copy — the original code reassigned
            # benchmark.index in place, mutating the caller's Series.
            benchmark = benchmark.copy()

            # Ensure benchmark has DatetimeIndex
            if not isinstance(benchmark.index, pd.DatetimeIndex):
                benchmark.index = pd.to_datetime(benchmark.index)

            # Normalize timezones: ensure benchmark is tz-naive UTC
            if benchmark.index.tz is not None:
                benchmark.index = benchmark.index.tz_convert("UTC").tz_localize(None)

            # Calculate cumulative return of benchmark
            bench_cum = (1 + benchmark).cumprod()

            # Rebase benchmark to match initial strategy equity
            initial_equity = df["equity"].iloc[0]
            if not bench_cum.empty:
                # Align start: use the strategy start date when present in
                # the benchmark, otherwise fall back to its first value
                start_date = df.index[0]
                if start_date in bench_cum.index:
                    base_val = bench_cum.loc[start_date]
                else:
                    # Fallback: use first available
                    base_val = bench_cum.iloc[0]

                bench_scaled = (bench_cum / base_val) * initial_equity

                # Filter to strategy range
                bench_plot = bench_scaled[df.index[0] : df.index[-1]]  # type: ignore
                ax1.plot(
                    bench_plot.index,
                    bench_plot,
                    label="Benchmark",
                    color="gray",
                    linestyle="--",
                    alpha=0.7,
                )
        except Exception as e:
            print(f"Warning: Failed to plot benchmark: {e}")

    ax1.set_title("Strategy Performance Analysis", fontsize=14, fontweight="bold")
    ax1.set_ylabel("Equity", fontsize=10)
    ax1.grid(True, linestyle="--", alpha=0.3)
    ax1.legend(loc="upper left", frameon=True, fancybox=True, framealpha=0.8)

    # Add Metrics Text Box
    metrics = result.metrics
    trade_metrics = result.trade_metrics

    metrics_text = [
        f"Total Return: {metrics.total_return_pct:>8.2f}%",
        f"Annualized: {metrics.annualized_return:>8.2%}",
        f"Sharpe Ratio: {metrics.sharpe_ratio:>8.2f}",
        f"Max Drawdown: {metrics.max_drawdown_pct:>8.2f}%",
        f"Win Rate: {metrics.win_rate:>8.2%}",
    ]

    if hasattr(trade_metrics, "total_closed_trades"):
        metrics_text.append(f"Trades: {trade_metrics.total_closed_trades:>8d}")

    text_str = "\n".join(metrics_text)

    props = dict(boxstyle="round", facecolor="white", alpha=0.8, edgecolor="lightgray")
    ax1.text(
        0.02,
        0.05,
        text_str,
        transform=ax1.transAxes,
        fontsize=9,
        verticalalignment="bottom",
        fontfamily="monospace",
        bbox=props,
    )

    # 2. Drawdown
    ax2 = fig.add_subplot(gs[1], sharex=ax1)
    ax2.fill_between(
        df.index, drawdown, 0, color="#d62728", alpha=0.3, label="Drawdown"
    )
    ax2.plot(df.index, drawdown, color="#d62728", linewidth=0.8, alpha=0.8)
    ax2.set_ylabel("Drawdown", fontsize=10)
    ax2.grid(True, linestyle="--", alpha=0.3)

    # 3. Daily Returns
    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    ax3.bar(
        df.index,
        df["returns"],
        color="gray",
        alpha=0.5,
        label="Daily Returns",
        width=1.0 if len(df) < 100 else 0.8,
    )
    ax3.set_ylabel("Returns", fontsize=10)
    ax3.grid(True, linestyle="--", alpha=0.3)

    # Format X axis: only the bottom subplot shows tick labels
    ax3.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.xticks(rotation=0)

    # Adjust margins
    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.08, right=0.95)

    if filename:
        plt.savefig(filename, dpi=100, bbox_inches="tight")
        # BUGFIX: the message previously did not interpolate the filename
        print(f"Plot saved to {filename}")

    if show:
        plt.show()
|