akquant 0.1.4__cp310-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of akquant might be problematic. Click here for more details.

akquant/strategy.py ADDED
@@ -0,0 +1,824 @@
1
+ from collections import defaultdict, deque
2
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ from .akquant import (
8
+ Bar,
9
+ ExecutionMode,
10
+ OrderStatus,
11
+ StrategyContext,
12
+ Tick,
13
+ TimeInForce,
14
+ )
15
+ from .sizer import FixedSize, Sizer
16
+ from .utils import parse_duration_to_bars
17
+
18
+ if TYPE_CHECKING:
19
+ from .indicator import Indicator
20
+ from .ml.model import QuantModel
21
+
22
+
23
class Strategy:
    """
    Base class for event-driven trading strategies.

    The design follows an event-driven model similar to NautilusTrader: the
    engine invokes the internal ``_on_*_event`` hooks, which record the
    context and the latest market data on the instance and then delegate to
    the user-overridable ``on_bar`` / ``on_tick`` / ``on_timer`` callbacks.

    All instance state is created in ``__new__`` so that subclasses may omit
    ``super().__init__()`` without breaking the base class.
    """

    ctx: Optional[StrategyContext]
    execution_mode: Optional[ExecutionMode]
    sizer: Sizer
    current_bar: Optional[Bar]
    current_tick: Optional[Tick]
    _history_depth: int
    _bars_history: "defaultdict[str, deque[Bar]]"
    _indicators: List["Indicator"]
    _subscriptions: List[str]
    _last_prices: Dict[str, float]
    _rolling_train_window: int
    _rolling_step: int
    _bar_count: int
    _model_configured: bool
    model: Optional["QuantModel"]

    def __new__(cls, *args: Any, **kwargs: Any) -> "Strategy":
        """Create a new Strategy instance with every attribute pre-initialised."""
        instance = super().__new__(cls)
        instance.ctx = None
        instance.execution_mode = None
        instance.sizer = FixedSize(100)
        instance.current_bar = None
        instance.current_tick = None
        instance._indicators = []
        instance._subscriptions = []
        instance._last_prices = {}

        # History configuration. ``_bars_history`` is declared on the class,
        # so give it a value here to avoid an AttributeError on first access.
        instance._history_depth = 0
        instance._bars_history = defaultdict(deque)

        # Rolling-training configuration.
        instance._rolling_train_window = 0
        instance._rolling_step = 0
        instance._bar_count = 0
        instance._model_configured = False

        # Attributes normally set in __init__ are set here instead, which
        # lets subclasses omit super().__init__().
        instance.model = None

        return instance

    def __init__(self) -> None:
        """Initialise the strategy (all state is already set up in ``__new__``)."""
        pass

    def set_history_depth(self, depth: int) -> None:
        """
        Set the history look-back length.

        :param depth: number of bars to retain (0 means history is disabled)
        """
        self._history_depth = depth

    def set_rolling_window(self, train_window: int, step: int) -> None:
        """
        Set the rolling-training window parameters.

        :param train_window: length of the training data, in bars
        :param step: rolling step (training is triggered every ``step`` bars)
        """
        self._rolling_train_window = train_window
        self._rolling_step = step
        # Grow the history depth automatically so it covers the training window.
        if self._history_depth < train_window:
            self._history_depth = train_window

    def get_history(
        self, count: int, symbol: Optional[str] = None, field: str = "close"
    ) -> np.ndarray:
        """
        Return historical values for one field (similar to Zipline data.history).

        :param count: number of values requested (expected <= history depth)
        :param symbol: instrument code (defaults to the current bar/tick symbol)
        :param field: field name (open, high, low, close, volume); lower-cased
            before it is passed to the context
        :return: NumPy array; left-padded with NaN when fewer than ``count``
            values are available, all-NaN when the context returns ``None``
        :raises RuntimeError: if history tracking is disabled or no context is set
        :raises ValueError: if no symbol is given and none can be inferred
        """
        if self._history_depth == 0:
            raise RuntimeError(
                "History tracking is not enabled. Call set_history_depth() first."
            )

        if self.ctx is None:
            raise RuntimeError("Context not ready")

        symbol = self._resolve_symbol(symbol)

        # Delegate to the context implementation.
        arr = self.ctx.history(symbol, field.lower(), count)

        if arr is None:
            return cast(np.ndarray, np.full(count, np.nan))

        missing = count - len(arr)
        if missing > 0:
            # Pad with NaN at the beginning so the newest value stays last.
            return cast(np.ndarray, np.concatenate((np.full(missing, np.nan), arr)))

        return cast(np.ndarray, arr)

    def get_history_df(self, count: int, symbol: Optional[str] = None) -> pd.DataFrame:
        """
        Return historical data as a DataFrame (open, high, low, close, volume).

        :param count: number of rows
        :param symbol: instrument code (defaults to the current bar/tick symbol)
        :return: DataFrame with one column per field
        """
        symbol = self._resolve_symbol(symbol)

        fields = ("open", "high", "low", "close", "volume")
        return pd.DataFrame(
            {name: self.get_history(count, symbol, name) for name in fields}
        )

    def get_rolling_data(
        self, length: Optional[int] = None, symbol: Optional[str] = None
    ) -> tuple[pd.DataFrame, Optional[pd.Series]]:
        """
        Return the data used for rolling training.

        :param length: data length (defaults to the train window configured
            through ``set_rolling_window``)
        :param symbol: instrument code
        :return: ``(X, y)``; by default the raw DataFrame and ``None``
        :raises ValueError: if the effective length is not positive
        """
        if length is None:
            length = self._rolling_train_window

        if length <= 0:
            raise ValueError("Invalid rolling window length")

        # The raw DataFrame is returned as X and y is None; users can
        # override this method or post-process the data themselves.
        return self.get_history_df(length, symbol), None

    def on_train_signal(self, context: Any) -> None:
        """
        Handle the rolling-training signal.

        Default implementation: when ``self.model`` is set, fetch the rolling
        data, build features via ``prepare_features`` and fit the model.
        Failures are reported with ``print`` and do not propagate.

        :param context: strategy context (the engine hook passes ``self``)
        """
        # Identity check: a model object may legitimately be falsy
        # (e.g. one that defines __len__), and must still be trained.
        if self.model is None:
            return

        try:
            X_df, _ = self.get_rolling_data()
            X, y = self.prepare_features(X_df)
            self.model.fit(X, y)
        except NotImplementedError:
            # User didn't implement prepare_features, assuming manual handling
            pass
        except Exception as e:
            print(f"Auto-training failed at bar {self._bar_count}: {e}")

    def prepare_features(self, df: pd.DataFrame) -> Tuple[Any, Any]:
        """
        Prepare features and labels for the ML model.

        Must be implemented by the user when auto-training is used.

        :param df: raw DataFrame from ``get_rolling_data``
        :return: ``(X, y)``
        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError(
            "You must implement prepare_features(self, df) for auto-training"
        )

    def _auto_configure_model(self) -> None:
        """Apply the model's validation configuration once, if present."""
        if self._model_configured:
            return

        if self.model is not None and self.model.validation_config:
            cfg = self.model.validation_config

            try:
                train_window = parse_duration_to_bars(cfg.train_window, cfg.frequency)
                step = parse_duration_to_bars(cfg.rolling_step, cfg.frequency)

                # Update settings
                self.set_rolling_window(train_window, step)
            except Exception as e:
                print(f"Failed to configure model validation: {e}")

        # Mark as done even on failure so configuration is attempted only once.
        self._model_configured = True

    def set_sizer(self, sizer: Sizer) -> None:
        """Set the position sizer."""
        self.sizer = sizer

    def register_indicator(self, name: str, indicator: "Indicator") -> None:
        """
        Register an indicator.

        The indicator is stored for ``_prepare_indicators`` and also exposed
        as the attribute ``self.<name>``.
        """
        self._indicators.append(indicator)
        setattr(self, name, indicator)

    def subscribe(self, instrument_id: str) -> None:
        """
        Subscribe to market data for an instrument (duplicates are ignored).

        :param instrument_id: the instrument identifier (e.g., '600000')
        """
        if instrument_id not in self._subscriptions:
            self._subscriptions.append(instrument_id)

    def on_start(self) -> None:
        """Run when the strategy starts; override as needed."""
        pass

    def _prepare_indicators(self, data: Dict[str, pd.DataFrame]) -> None:
        """Pre-calculate every registered indicator for every symbol in ``data``."""
        if not self._indicators:
            return

        for ind in self._indicators:
            for sym, df in data.items():
                # Calculate and cache inside indicator
                ind(df, sym)

    def _on_bar_event(self, bar: Bar, ctx: StrategyContext) -> None:
        """Bar callback invoked by the engine (internal)."""
        self.ctx = ctx

        # Lazy configuration
        if not self._model_configured:
            self._auto_configure_model()

        self.current_bar = bar
        self._last_prices[bar.symbol] = bar.close

        # Check whether a rolling-training signal is due.
        if self._rolling_step > 0:
            self._bar_count += 1
            if self._bar_count % self._rolling_step == 0:
                # Fire the training signal, passing self as the context.
                self.on_train_signal(self)

        self.on_bar(bar)

    def _on_tick_event(self, tick: Tick, ctx: StrategyContext) -> None:
        """Tick callback invoked by the engine (internal)."""
        self.ctx = ctx
        self.current_tick = tick
        self._last_prices[tick.symbol] = tick.price
        self.on_tick(tick)

    def _on_timer_event(self, payload: str, ctx: StrategyContext) -> None:
        """Timer callback invoked by the engine (internal)."""
        self.ctx = ctx
        self.on_timer(payload)

    def on_bar(self, bar: Bar) -> None:
        """
        Strategy entry point for bar data.

        Users should override this method.
        """
        pass

    def on_tick(self, tick: Tick) -> None:
        """
        Strategy entry point for tick data.

        Users should override this method.
        """
        pass

    def on_timer(self, payload: str) -> None:
        """
        Strategy entry point for timer events.

        :param payload: data carried by the timer
        """
        pass

    def _resolve_symbol(self, symbol: Optional[str] = None) -> str:
        """
        Return ``symbol``, or infer it from the current bar, then current tick.

        :raises ValueError: if no symbol is given and none can be inferred
        """
        if symbol is not None:
            return symbol
        if self.current_bar is not None:
            return self.current_bar.symbol
        if self.current_tick is not None:
            return self.current_tick.symbol
        raise ValueError("Symbol must be provided")

    def _reference_price(self, symbol: str) -> float:
        """
        Return the best known price for ``symbol``, or 0.0 when unknown.

        Prefers the last recorded price; when that is missing (or zero) it
        falls back to the current bar close, then the current tick price, but
        only if that bar/tick belongs to ``symbol``.
        """
        price = self._last_prices.get(symbol, 0.0)
        if price == 0.0:
            bar, tick = self.current_bar, self.current_tick
            if bar is not None and bar.symbol == symbol:
                price = bar.close
            elif tick is not None and tick.symbol == symbol:
                price = tick.price
        return price

    def get_open_orders(self, symbol: Optional[str] = None) -> list[Any]:
        """
        Return the orders that are still open.

        :param symbol: instrument code (``None`` returns orders for all symbols)
        :return: active orders whose status is ``New`` or ``Submitted``;
            an empty list when no context is set
        """
        if self.ctx is None:
            return []

        orders = [
            o
            for o in self.ctx.active_orders
            if o.status in (OrderStatus.New, OrderStatus.Submitted)
        ]
        if symbol:
            return [o for o in orders if o.symbol == symbol]
        return orders

    def buy(
        self,
        symbol: Optional[str] = None,
        quantity: Optional[float] = None,
        price: Optional[float] = None,
        time_in_force: Optional[TimeInForce] = None,
        trigger_price: Optional[float] = None,
    ) -> None:
        """
        Place a buy order.

        :param symbol: instrument code (defaults to the current bar/tick symbol)
        :param quantity: order quantity (computed by the sizer when omitted)
        :param price: limit price (``None`` means market order)
        :param time_in_force: order validity
        :param trigger_price: trigger price (stop-loss / take-profit)
        :raises RuntimeError: if no context is set
        """
        if self.ctx is None:
            raise RuntimeError("Context not ready")

        # 1. Determine Symbol
        symbol = self._resolve_symbol(symbol)

        # 2. Determine Quantity via Sizer, using the limit price when given
        #    and otherwise the best known price for this symbol.
        if quantity is None:
            ref_price = price if price is not None else self._reference_price(symbol)
            quantity = self.sizer.get_size(ref_price, self.ctx.cash, self.ctx, symbol)

        # 3. Execute Buy
        if quantity > 0:
            self.ctx.buy(symbol, quantity, price, time_in_force, trigger_price)

    def sell(
        self,
        symbol: Optional[str] = None,
        quantity: Optional[float] = None,
        price: Optional[float] = None,
        time_in_force: Optional[TimeInForce] = None,
        trigger_price: Optional[float] = None,
    ) -> None:
        """
        Place a sell order.

        :param symbol: instrument code (defaults to the current bar/tick symbol)
        :param quantity: order quantity (defaults to the whole long position of
            the symbol; nothing is sent when there is no long position)
        :param price: limit price (``None`` means market order)
        :param time_in_force: order validity
        :param trigger_price: trigger price (stop-loss / take-profit)
        :raises RuntimeError: if no context is set
        """
        if self.ctx is None:
            raise RuntimeError("Context not ready")

        # 1. Determine Symbol
        symbol = self._resolve_symbol(symbol)

        # 2. Determine Quantity (default: close the entire long position)
        if quantity is None:
            pos = self.ctx.get_position(symbol)
            if pos <= 0:
                # No position and no explicit quantity: nothing to sell.
                return
            quantity = pos

        # 3. Execute Sell
        if quantity > 0:
            self.ctx.sell(symbol, quantity, price, time_in_force, trigger_price)

    def stop_buy(
        self,
        symbol: Optional[str] = None,
        trigger_price: float = 0.0,
        quantity: Optional[float] = None,
        price: Optional[float] = None,
        time_in_force: Optional[TimeInForce] = None,
    ) -> None:
        """
        Send a stop buy order.

        Intended to trigger a buy when the market rises through
        ``trigger_price``.
        - If ``price`` is None the order becomes a market order (Stop Market).
        - Otherwise it becomes a limit order (Stop Limit).
        """
        self.buy(symbol, quantity, price, time_in_force, trigger_price=trigger_price)

    def stop_sell(
        self,
        symbol: Optional[str] = None,
        trigger_price: float = 0.0,
        quantity: Optional[float] = None,
        price: Optional[float] = None,
        time_in_force: Optional[TimeInForce] = None,
    ) -> None:
        """
        Send a stop sell order.

        Intended to trigger a sell when the market falls through
        ``trigger_price``.
        - If ``price`` is None the order becomes a market order (Stop Market).
        - Otherwise it becomes a limit order (Stop Limit).
        """
        self.sell(symbol, quantity, price, time_in_force, trigger_price=trigger_price)

    def get_portfolio_value(self) -> float:
        """Return total portfolio value (cash + position market value)."""
        if self.ctx is None:
            return 0.0

        total_value = float(self.ctx.cash)

        for symbol, qty in self.ctx.positions.items():
            if qty == 0:
                continue
            # Value the position at the best known price (0.0 when unknown).
            total_value += float(qty) * self._reference_price(symbol)

        return total_value

    def order_target(
        self,
        target: float,
        symbol: Optional[str] = None,
        price: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """
        Adjust the position to a target quantity.

        :param target: target position size (e.g. 100, -100)
        :param symbol: instrument code
        :param price: limit price (optional)
        :param kwargs: extra order arguments forwarded to ``buy`` / ``sell``
        """
        symbol = self._resolve_symbol(symbol)

        current_qty = 0.0
        if self.ctx:
            current_qty = float(self.ctx.get_position(symbol))

        delta_qty = target - current_qty

        if delta_qty > 0:
            self.buy(symbol, delta_qty, price, **kwargs)
        elif delta_qty < 0:
            self.sell(symbol, abs(delta_qty), price, **kwargs)

    def order_target_value(
        self,
        target_value: float,
        symbol: Optional[str] = None,
        price: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """
        Adjust the position to a target market value.

        When no positive price is known for the symbol a warning is printed
        and no order is placed.

        :param target_value: target position market value
        :param symbol: instrument code
        :param price: limit price (optional)
        :param kwargs: extra order arguments forwarded to ``buy`` / ``sell``
        """
        symbol = self._resolve_symbol(symbol)

        # A non-positive price cannot be converted into a quantity; bail out
        # instead of dividing by zero.
        current_price = self._reference_price(symbol)
        if current_price <= 0:
            print(
                f"Warning: Cannot determine price for {symbol}, "
                "skipping order_target_value"
            )
            return

        # Current position
        current_qty = 0.0
        if self.ctx:
            current_qty = float(self.ctx.get_position(symbol))

        # Target quantity and the difference to trade
        target_qty = target_value / current_price
        delta_qty = target_qty - current_qty

        if delta_qty > 0:
            self.buy(symbol, delta_qty, price, **kwargs)
        elif delta_qty < 0:
            self.sell(symbol, abs(delta_qty), price, **kwargs)

    def order_target_percent(
        self,
        target_percent: float,
        symbol: Optional[str] = None,
        price: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """
        Adjust the position to a target fraction of the portfolio value.

        :param target_percent: target fraction (0.5 = 50%)
        :param symbol: instrument code
        :param price: limit price (optional)
        :param kwargs: extra order arguments forwarded to ``buy`` / ``sell``
        """
        target_value = self.get_portfolio_value() * target_percent
        self.order_target_value(target_value, symbol, price, **kwargs)

    def cancel_order(self, order_or_id: Any) -> None:
        """
        Cancel an order.

        :param order_or_id: an order object (anything with an ``id``
            attribute) or an order id
        :raises RuntimeError: if no context is set
        """
        if self.ctx is None:
            raise RuntimeError("Context not ready")

        order_id = order_or_id
        if hasattr(order_or_id, "id"):
            order_id = order_or_id.id

        self.ctx.cancel_order(order_id)

    def cancel_all_orders(self, symbol: Optional[str] = None) -> None:
        """
        Cancel all open orders.

        :param symbol: instrument code (``None`` cancels orders for all symbols)
        """
        for order in self.get_open_orders(symbol):
            self.cancel_order(order)

    def buy_all(self, symbol: Optional[str] = None) -> None:
        """
        Buy with all available cash.

        Nothing is sent when no positive price is known for the symbol.

        :param symbol: instrument code (defaults to the current bar/tick symbol)
        :raises RuntimeError: if no context is set
        """
        if self.ctx is None:
            raise RuntimeError("Context not ready")

        symbol = self._resolve_symbol(symbol)

        price = self._reference_price(symbol)
        if price <= 0:
            # No usable price, so the quantity cannot be computed.
            return

        # Maximum affordable quantity, rounded down. Estimated fees are NOT
        # deducted here, so an order sized exactly at the cash limit may be
        # rejected; callers or the engine should keep a buffer.
        quantity = int(self.ctx.cash / price)

        if quantity > 0:
            self.buy(symbol=symbol, quantity=quantity)

    def close_position(self, symbol: Optional[str] = None) -> None:
        """
        Close the current position by selling (long) or buying (short).

        :param symbol: instrument code (defaults to the current bar/tick symbol)
        """
        symbol = self._resolve_symbol(symbol)
        position = self.get_position(symbol)

        if position > 0:
            self.sell(symbol=symbol, quantity=position)
        elif position < 0:
            self.buy(symbol=symbol, quantity=abs(position))

    def short(
        self,
        symbol: Optional[str] = None,
        quantity: Optional[float] = None,
        price: Optional[float] = None,
        time_in_force: Optional[TimeInForce] = None,
        trigger_price: Optional[float] = None,
    ) -> None:
        """
        Sell to open a short position.

        :param symbol: instrument code (defaults to the current bar/tick symbol)
        :param quantity: order quantity (computed by the sizer when omitted)
        :param price: limit price (``None`` means market order)
        :param time_in_force: order validity
        :param trigger_price: trigger price (stop-loss / take-profit)
        :raises RuntimeError: if no context is set
        """
        if self.ctx is None:
            raise RuntimeError("Context not ready")

        # 1. Determine Symbol
        symbol = self._resolve_symbol(symbol)

        # 2. Determine Quantity via Sizer. The reference price must belong to
        #    the symbol being traded, not to whichever bar is current.
        if quantity is None:
            ref_price = price if price is not None else self._reference_price(symbol)
            quantity = self.sizer.get_size(ref_price, self.ctx.cash, self.ctx, symbol)

        # 3. Execute Sell (Short)
        if quantity > 0:
            self.ctx.sell(symbol, quantity, price, time_in_force, trigger_price)

    def cover(
        self,
        symbol: Optional[str] = None,
        quantity: Optional[float] = None,
        price: Optional[float] = None,
        time_in_force: Optional[TimeInForce] = None,
        trigger_price: Optional[float] = None,
    ) -> None:
        """
        Buy to cover a short position.

        :param symbol: instrument code (defaults to the current bar/tick symbol)
        :param quantity: order quantity (defaults to the whole short position of
            the symbol; nothing is sent when there is no short position)
        :param price: limit price (``None`` means market order)
        :param time_in_force: order validity
        :param trigger_price: trigger price (stop-loss / take-profit)
        :raises RuntimeError: if no context is set
        """
        if self.ctx is None:
            raise RuntimeError("Context not ready")

        # 1. Determine Symbol
        symbol = self._resolve_symbol(symbol)

        # 2. Determine Quantity (default: close the entire short position)
        if quantity is None:
            pos = self.ctx.get_position(symbol)
            if pos >= 0:
                # No short position to cover
                return
            quantity = abs(pos)

        # 3. Execute Buy (Cover)
        if quantity > 0:
            self.ctx.buy(symbol, quantity, price, time_in_force, trigger_price)

    def schedule(self, timestamp: int, payload: str) -> None:
        """
        Register a timer event.

        :param timestamp: trigger timestamp (Unix nanoseconds)
        :param payload: data carried by the event
        :raises RuntimeError: if no context is set
        """
        if self.ctx is None:
            raise RuntimeError("Context not ready")
        self.ctx.schedule(timestamp, payload)

    def get_position(self, symbol: Optional[str] = None) -> float:
        """
        Return the current position size.

        Returns 0.0 when no context is set, or when no symbol is given and
        none can be inferred from the current bar/tick.
        """
        if self.ctx is None:
            return 0.0

        if symbol is None:
            if self.current_bar:
                symbol = self.current_bar.symbol
            elif self.current_tick:
                symbol = self.current_tick.symbol
            else:
                return 0.0
        return self.ctx.get_position(symbol)

    def get_cash(self) -> float:
        """Return the available cash (0.0 when no context is set)."""
        if self.ctx is None:
            return 0.0
        return self.ctx.cash
765
+
766
+
767
class VectorizedStrategy(Strategy):
    """
    Base class for vectorized strategies.

    Supports a fast backtest mode built on pre-computed indicators: compute
    all indicators with pandas/NumPy before the backtest, then read them in
    ``on_bar`` through the per-symbol cursor provided by this class.
    """

    def __init__(self, precalculated_data: Dict[str, Dict[str, np.ndarray]]) -> None:
        """
        Initialize VectorizedStrategy.

        :param precalculated_data: pre-computed data, structured as
            ``{symbol: {indicator_name: numpy_array}}``
        """
        super().__init__()
        self.precalc = precalculated_data
        # Cursor per symbol: {symbol: index of the bar currently processed}
        self.cursors: defaultdict[str, int] = defaultdict(int)

        # History tracking is disabled by default for performance.
        self.set_history_depth(0)

    def _on_bar_event(self, bar: Bar, ctx: StrategyContext) -> None:
        """
        Bar callback invoked by the engine (internal).

        Deliberately does not call ``super()._on_bar_event``: that would
        invoke ``on_bar`` itself, and the cursor must advance only after the
        user callback has run. The required setup is repeated here instead.
        """
        # 1. Standard setup: context, current bar and last known price.
        self.ctx = ctx
        self.current_bar = bar
        # Keep the price cache current; sizing and portfolio valuation in the
        # base class read prices from it.
        self._last_prices[bar.symbol] = bar.close

        # 2. Call User Strategy
        self.on_bar(bar)

        # 3. Increment Cursor
        self.cursors[bar.symbol] += 1

    def get_value(self, name: str, symbol: Optional[str] = None) -> float:
        """
        Return the pre-computed indicator value for the current bar.

        :param name: indicator name
        :param symbol: instrument code (defaults to the current bar's symbol)
        :return: the indicator value as ``float``; ``nan`` when the symbol or
            indicator is missing or the cursor is out of range
        """
        symbol = self._resolve_symbol(symbol)
        idx = self.cursors[symbol]

        try:
            return float(self.precalc[symbol][name][idx])
        except (KeyError, IndexError):
            return float("nan")