akquant 0.1.4 (akquant-0.1.4-cp310-abi3-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of akquant might be problematic. Click here for more details.

akquant/backtest.py ADDED
@@ -0,0 +1,659 @@
1
+ from functools import cached_property
2
+ from typing import (
3
+ Any,
4
+ Callable,
5
+ Dict,
6
+ List,
7
+ Optional,
8
+ Type,
9
+ Union,
10
+ cast,
11
+ )
12
+
13
+ import pandas as pd
14
+
15
+ from .akquant import (
16
+ AssetType,
17
+ Bar,
18
+ DataFeed,
19
+ Engine,
20
+ ExecutionMode,
21
+ Instrument,
22
+ )
23
+ from .akquant import (
24
+ BacktestResult as RustBacktestResult,
25
+ )
26
+ from .config import BacktestConfig
27
+ from .data import ParquetDataCatalog
28
+ from .log import get_logger, register_logger
29
+ from .risk import apply_risk_config
30
+ from .strategy import Strategy
31
+ from .utils import df_to_arrays, prepare_dataframe
32
+
33
+
34
class BacktestResult:
    """
    Backtest Result Wrapper.

    Wraps the underlying Rust ``BacktestResult`` and exposes pandas views of it
    (``daily_positions_df``, ``metrics_df``, ``trades_df``) whose timestamps are
    converted to ``timezone``. Any attribute not defined on this wrapper is
    delegated to the raw Rust object.
    """

    def __init__(self, raw_result: RustBacktestResult, timezone: str = "Asia/Shanghai"):
        """
        Initialize the BacktestResult wrapper.

        :param raw_result: The raw Rust BacktestResult object.
        :param timezone: Timezone name used when converting the raw result's
            nanosecond UTC timestamps to datetimes.
        """
        self._raw = raw_result
        self._timezone = timezone

    @property
    def daily_positions_df(self) -> pd.DataFrame:
        """
        Get daily positions as a Pandas DataFrame.

        Built from the raw ``daily_positions`` list of
        ``(timestamp_ns, {symbol: quantity})`` pairs.

        Index: Datetime (timezone-aware, sorted ascending)
        Columns: Symbols
        Values: Quantity (0.0 where a symbol is absent from that day's map).
        Returns an empty DataFrame when there are no daily positions.
        """
        if not self._raw.daily_positions:
            return pd.DataFrame()

        # Unzip the list of tuples [(ts, {sym: qty}), ...]
        timestamps, positions = zip(*self._raw.daily_positions)

        df = pd.DataFrame(list(positions), index=timestamps)

        # Convert nanosecond UTC timestamps to timezone-aware datetimes.
        df.index = pd.to_datetime(df.index, unit="ns", utc=True).tz_convert(
            self._timezone
        )

        # The raw list is not guaranteed to be ordered.
        df = df.sort_index()

        # A symbol missing from a day's map means no position on that day.
        df = df.fillna(0.0)

        return df

    @property
    def metrics_df(self) -> pd.DataFrame:
        """
        Get performance metrics as a one-row Pandas DataFrame (index "Backtest").

        Fields are read explicitly from ``raw.metrics`` because PyO3 objects do
        not reliably expose a clean ``__dict__``.
        """
        metrics = self._raw.metrics

        data = {
            "total_return": metrics.total_return,
            "annualized_return": metrics.annualized_return,
            "max_drawdown": metrics.max_drawdown,
            "max_drawdown_pct": metrics.max_drawdown_pct,
            "sharpe_ratio": metrics.sharpe_ratio,
            "sortino_ratio": metrics.sortino_ratio,
            "volatility": metrics.volatility,
            "ulcer_index": metrics.ulcer_index,
            "upi": metrics.upi,
            "equity_r2": metrics.equity_r2,
            "std_error": metrics.std_error,
            "win_rate": metrics.win_rate,
            "initial_market_value": metrics.initial_market_value,
            "end_market_value": metrics.end_market_value,
            "total_return_pct": metrics.total_return_pct,
        }

        return pd.DataFrame([data], index=["Backtest"])

    @cached_property
    def trades_df(self) -> pd.DataFrame:
        """
        Get closed trades as a Pandas DataFrame (one row per trade).

        ``entry_time`` / ``exit_time`` are converted from nanosecond UTC
        timestamps to timezone-aware datetimes. The result is cached on first
        access. Returns an empty DataFrame when there are no trades.
        """
        if not self._raw.trades:
            return pd.DataFrame()

        data = [
            {
                "symbol": t.symbol,
                "entry_time": t.entry_time,
                "exit_time": t.exit_time,
                "entry_price": t.entry_price,
                "exit_price": t.exit_price,
                "quantity": t.quantity,
                "direction": t.direction,
                "pnl": t.pnl,
                "net_pnl": t.net_pnl,
                "return_pct": t.return_pct,
                "commission": t.commission,
                "duration_bars": t.duration_bars,
            }
            for t in self._raw.trades
        ]

        df = pd.DataFrame(data)

        for col in ("entry_time", "exit_time"):
            df[col] = pd.to_datetime(df[col], unit="ns", utc=True).dt.tz_convert(
                self._timezone
            )

        return df

    def __getattr__(self, name: str) -> Any:
        """
        Delegate attribute access to the raw result.

        Only called when normal lookup fails. If ``_raw`` itself is missing (an
        instance created via ``__new__`` by ``copy``/``pickle`` before its state
        is restored), raise AttributeError instead of recursing forever through
        ``self._raw``.
        """
        if name == "_raw":
            raise AttributeError(name)
        return getattr(self._raw, name)

    def __repr__(self) -> str:
        """Return the string representation of the raw result."""
        return repr(self._raw)

    def __dir__(self) -> List[str]:
        """
        Return the wrapper's own attributes plus those of the raw result.

        Includes every property/method defined on this class (e.g.
        ``metrics_df``, ``trades_df``), not just ``daily_positions_df``.
        """
        return sorted(set(dir(self._raw)) | set(object.__dir__(self)))
161
+
162
+
163
class FunctionalStrategy(Strategy):
    """
    Internal strategy wrapper backing the functional (Zipline-style) API.

    ``initialize(ctx)`` runs once during construction and ``on_bar(ctx, bar)``
    runs for every bar. In both cases ``ctx`` is this strategy instance, which
    also carries every entry of ``context`` as an attribute.
    """

    def __init__(
        self,
        initialize: Optional[Callable[[Any], None]],
        on_bar: Optional[Callable[[Any, Bar], None]],
        context: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the FunctionalStrategy."""
        super().__init__()
        self._initialize = initialize
        self._on_bar_func = on_bar
        self._context = {} if not context else context

        # Mirror the context dict onto the instance so user callbacks can read
        # and write ``ctx.xxx``, emulating Zipline's context object.
        for attr_name, attr_value in self._context.items():
            setattr(self, attr_name, attr_value)

        # Run the user's one-time setup hook, if any.
        setup_hook = self._initialize
        if setup_hook is None:
            return
        setup_hook(self)

    def on_bar(self, bar: Bar) -> None:
        """Forward each bar to the user-supplied callback, if one was given."""
        handler = self._on_bar_func
        if handler is None:
            return
        handler(self, bar)
191
+
192
+
193
def run_backtest(
    data: Optional[Union[pd.DataFrame, Dict[str, pd.DataFrame], List[Bar]]] = None,
    strategy: Union[Type[Strategy], Strategy, Callable[[Any, Bar], None], None] = None,
    symbol: Union[str, List[str]] = "BENCHMARK",
    cash: float = 1_000_000.0,
    commission: float = 0.0003,
    stamp_tax: float = 0.0005,
    transfer_fee: float = 0.00001,
    min_commission: float = 5.0,
    execution_mode: Union[ExecutionMode, str] = ExecutionMode.NextOpen,
    timezone: str = "Asia/Shanghai",
    initialize: Optional[Callable[[Any], None]] = None,
    context: Optional[Dict[str, Any]] = None,
    history_depth: int = 0,
    lot_size: Union[int, Dict[str, int], None] = None,
    show_progress: bool = True,
    config: Optional[BacktestConfig] = None,
    **kwargs: Any,
) -> BacktestResult:
    """
    Simplified backtest entry point.

    :param data: Backtest data. One of:

        - a pandas DataFrame, used for the first resolved symbol;
        - a dict mapping symbol to DataFrame;
        - a list of Bar objects.

        Optional. When omitted, data for every resolved symbol is read from
        ``ParquetDataCatalog`` (using ``start_date`` / ``end_date`` from kwargs
        or config).
    :param strategy: A Strategy subclass, a Strategy instance, or an
        ``on_bar(ctx, bar)`` callback function.
    :param symbol: Symbol or list (or tuple) of symbols. The caller's list is
        never modified.
    :param cash: Initial cash.
    :param commission: Commission rate.
    :param stamp_tax: Stamp tax rate (sell side only).
    :param transfer_fee: Transfer fee rate.
    :param min_commission: Minimum commission per trade.
    :param execution_mode: Execution mode (``ExecutionMode.NextOpen`` or a
        string such as ``"next_open"`` / ``"current_close"``).
    :param timezone: Timezone name.
    :param initialize: Initialization callback (only used when ``strategy`` is
        a function).
    :param context: Initial context data, set as attributes on the strategy.
    :param history_depth: Length of automatically maintained history
        (0 disables it).
    :param lot_size: Minimum trading unit. An int applies to every symbol; a
        ``Dict[str, int]`` is matched per symbol; None leaves it unset (the
        engine default, documented as 1).
    :param show_progress: Whether to show a progress bar (default True).
    :param config: Optional BacktestConfig; its set fields override the
        corresponding keyword arguments.
    :param kwargs: Extra options. The following are read here:

        - ``start_date``, ``end_date``;
        - ``fund_commission``, ``fund_transfer_fee``, ``fund_min_commission``;
        - ``option_commission``;
        - ``multiplier``, ``margin_ratio``, ``tick_size``, ``asset_type``;
        - ``option_type``, ``strike_price``, ``expiry_date``.

        All kwargs are also forwarded to a Strategy class constructor when it
        accepts them.
    :return: A BacktestResult wrapper around the engine's results.
    :raises ValueError: If no strategy or an unsupported strategy type is
        given, or if ``timezone`` yields no UTC offset.
    """
    # 1. Make sure logging is initialised.
    logger = get_logger()
    if not logger.handlers:
        register_logger(console=True, level="INFO")
        logger = get_logger()

    # 1.5 Let an explicit BacktestConfig override the keyword arguments.
    if config:
        if config.start_date:
            kwargs["start_date"] = config.start_date
        if config.end_date:
            kwargs["end_date"] = config.end_date
        if config.timezone:
            timezone = config.timezone
        if config.show_progress is not None:
            show_progress = config.show_progress
        if config.history_depth is not None:
            history_depth = config.history_depth

        if config.strategy_config:
            cash = config.strategy_config.initial_cash
            # Fee handling is simplified: only the commission rate is taken
            # from the strategy config, and a falsy fee_amount keeps the kwarg.
            commission = config.strategy_config.fee_amount or commission
        # Risk config is injected further down.

    # 2. Instantiate the strategy early so its subscriptions can be collected.
    strategy_instance = None

    if isinstance(strategy, type) and issubclass(strategy, Strategy):
        try:
            strategy_instance = strategy(**kwargs)
        except TypeError:
            # Constructor does not accept the extra kwargs; retry without them.
            # NOTE(review): this also swallows TypeErrors raised inside the
            # constructor body itself.
            strategy_instance = strategy()
    elif isinstance(strategy, Strategy):
        strategy_instance = strategy
    elif callable(strategy):
        strategy_instance = FunctionalStrategy(
            initialize, cast(Callable[[Any, Bar], None], strategy), context
        )
    elif strategy is None:
        raise ValueError("Strategy must be provided.")
    else:
        raise ValueError("Invalid strategy type")

    # Inject context as attributes. Instances that have `_context` (e.g.
    # FunctionalStrategy) are skipped here.
    if context and strategy_instance and not hasattr(strategy_instance, "_context"):
        for k, v in context.items():
            setattr(strategy_instance, k, v)

    # Inject the risk config from BacktestConfig into strategies that expose
    # a `risk_config` attribute.
    if config and config.strategy_config and config.strategy_config.risk:
        if hasattr(strategy_instance, "risk_config"):
            strategy_instance.risk_config = config.strategy_config.risk  # type: ignore

    # on_start lets the strategy register its subscriptions before data loads.
    if hasattr(strategy_instance, "on_start"):
        strategy_instance.on_start()

    # 3. Resolve the symbol universe and build the data feed.
    feed = DataFeed()
    data_map_for_indicators: Dict[str, pd.DataFrame] = {}

    # Always build a fresh list: it is extended below and must not alias
    # (and thereby mutate) the caller's `symbol` argument.
    if isinstance(symbol, str):
        symbols = [symbol]
    elif isinstance(symbol, (list, tuple)):
        symbols = list(symbol)
    else:
        symbols = ["BENCHMARK"]

    # Merge with config instruments.
    if config and config.instruments:
        for s in config.instruments:
            if s not in symbols:
                symbols.append(s)

    # Merge with strategy subscriptions.
    if hasattr(strategy_instance, "_subscriptions"):
        for s in strategy_instance._subscriptions:
            if s not in symbols:
                symbols.append(s)

    if data is not None:
        # Use the data provided by the caller.
        if isinstance(data, pd.DataFrame):
            # A single DataFrame is attributed to the first symbol.
            target_symbol = symbols[0] if symbols else "BENCHMARK"
            df = prepare_dataframe(data)
            data_map_for_indicators[target_symbol] = df
            arrays = df_to_arrays(df, symbol=target_symbol)
            feed.add_arrays(*arrays)  # type: ignore
            feed.sort()
            if target_symbol not in symbols:
                symbols = [target_symbol]
        elif isinstance(data, dict):
            for sym, df in data.items():
                df_prep = prepare_dataframe(df)
                data_map_for_indicators[sym] = df_prep
                arrays = df_to_arrays(df_prep, symbol=sym)
                feed.add_arrays(*arrays)  # type: ignore
                if sym not in symbols:
                    symbols.append(sym)
            feed.sort()
        elif isinstance(data, list):
            if data:
                # sorted() (not list.sort) so the caller's list is left untouched.
                feed.add_bars(sorted(data, key=lambda b: b.timestamp))
    else:
        # No data supplied: load every symbol from the Parquet catalog.
        if not symbols:
            logger.warning("No symbols specified and no data provided.")

        catalog = ParquetDataCatalog()
        start_date = kwargs.get("start_date")
        end_date = kwargs.get("end_date")

        loaded_count = 0
        for sym in symbols:
            df = catalog.read(sym, start_date=start_date, end_date=end_date)
            if df.empty:
                logger.warning(f"Data not found in catalog for {sym}")
                continue

            df = prepare_dataframe(df)
            data_map_for_indicators[sym] = df
            arrays = df_to_arrays(df, symbol=sym)
            feed.add_arrays(*arrays)  # type: ignore
            loaded_count += 1

        if loaded_count > 0:
            feed.sort()
        elif symbols:
            logger.warning("Failed to load data for all requested symbols.")

    # 4. Configure the engine.
    engine = Engine()
    # The engine takes a fixed UTC offset in seconds.
    # NOTE(review): the offset is computed for *now*, so for timezones with
    # DST it may differ from the offset in effect during the backtest period.
    offset_delta = pd.Timestamp.now(tz=timezone).utcoffset()
    if offset_delta is None:
        raise ValueError(f"Invalid timezone: {timezone}")
    engine.set_timezone(int(offset_delta.total_seconds()))
    engine.set_cash(cash)

    if isinstance(execution_mode, str):
        mode_map = {
            "next_open": ExecutionMode.NextOpen,
            "current_close": ExecutionMode.CurrentClose,
        }
        mode = mode_map.get(execution_mode.lower())
        if mode is None:
            logger.warning(
                f"Unknown execution mode '{execution_mode}', defaulting to NextOpen"
            )
            mode = ExecutionMode.NextOpen
        engine.set_execution_mode(mode)
    else:
        engine.set_execution_mode(execution_mode)

    engine.set_t_plus_one(False)  # T+0 by default
    engine.set_force_session_continuous(True)
    engine.set_stock_fee_rules(commission, stamp_tax, transfer_fee, min_commission)

    # Configure fees for other asset classes when provided.
    if "fund_commission" in kwargs:
        engine.set_fund_fee_rules(
            kwargs["fund_commission"],
            kwargs.get("fund_transfer_fee", 0.0),
            kwargs.get("fund_min_commission", 0.0),
        )

    if "option_commission" in kwargs:
        engine.set_option_fee_rules(kwargs["option_commission"])

    # Apply the risk config to the engine.
    if config and config.strategy_config:
        apply_risk_config(engine, config.strategy_config.risk)

    # 5. Register instruments (the same contract spec is used for every symbol).
    multiplier = kwargs.get("multiplier", 1.0)
    margin_ratio = kwargs.get("margin_ratio", 1.0)
    tick_size = kwargs.get("tick_size", 0.01)
    asset_type = kwargs.get("asset_type", AssetType.Stock)

    # Option-specific fields.
    option_type = kwargs.get("option_type", None)
    strike_price = kwargs.get("strike_price", None)
    expiry_date = kwargs.get("expiry_date", None)

    for sym in symbols:
        # Resolve this symbol's lot size: int applies globally, dict per symbol.
        current_lot_size = None
        if isinstance(lot_size, int):
            current_lot_size = lot_size
        elif isinstance(lot_size, dict):
            current_lot_size = lot_size.get(sym)

        instr = Instrument(
            sym,
            asset_type,
            multiplier,
            margin_ratio,
            tick_size,
            option_type,
            strike_price,
            expiry_date,
            current_lot_size,
        )
        engine.add_instrument(instr)

    # 6. Attach the data feed.
    engine.add_data(feed)

    # 7. Run the backtest.
    logger.info("Running backtest via run_backtest()...")

    # Enable automatic history maintenance.
    if history_depth > 0:
        strategy_instance.set_history_depth(history_depth)

    # 7.5 Vectorised indicator pre-calculation on the prepared DataFrames.
    if hasattr(strategy_instance, "_prepare_indicators") and data_map_for_indicators:
        strategy_instance._prepare_indicators(data_map_for_indicators)

    engine.run(strategy_instance, show_progress)

    return BacktestResult(engine.get_results(), timezone=timezone)
471
+
472
+
473
def plot_result(
    result: Any,
    show: bool = True,
    filename: Optional[str] = None,
    benchmark: Optional[pd.Series] = None,
) -> None:
    """
    Plot backtest results: equity curve, drawdown and per-bar returns.

    :param result: A BacktestResult, or any object exposing ``equity_curve``
        (a sequence of ``(timestamp, equity)`` pairs), ``metrics`` and
        ``trade_metrics``.
    :param show: Whether to call ``plt.show()``.
    :param filename: If given, save the figure to this path.
    :param benchmark: Optional benchmark *return* series. Its index should be a
        DatetimeIndex or be convertible by ``pd.to_datetime``. The series is
        compounded, rebased to the strategy's initial equity at the strategy's
        start date, and drawn over the strategy's date range. The caller's
        Series is never modified.
    """
    try:
        from datetime import datetime, timezone

        import matplotlib.dates as mdates
        import matplotlib.pyplot as plt
        from matplotlib.gridspec import GridSpec
    except ImportError:
        print(
            "Error: matplotlib is required for plotting. "
            "Please install it via 'pip install matplotlib'."
        )
        return

    equity_curve = result.equity_curve  # List[Tuple[int, float]]

    if not equity_curve:
        print("No equity curve data to plot.")
        return

    # Detect the timestamp unit. Plausible second-based timestamps are below
    # 1e11 (~year 5138); nanosecond timestamps are ~1e18 around year 2001.
    first_ts = equity_curve[0][0]
    scale = 1e-9 if first_ts > 1e11 else 1.0

    # Use naive UTC datetimes to avoid local-timezone effects and to align with
    # the (UTC-normalised) benchmark below.
    times = [
        datetime.fromtimestamp(t * scale, tz=timezone.utc).replace(tzinfo=None)
        for t, _ in equity_curve
    ]
    equity = [e for _, e in equity_curve]

    df = pd.DataFrame({"equity": equity}, index=times)
    df.index.name = "Date"
    df["returns"] = df["equity"].pct_change().fillna(0)

    # Drawdown relative to the running peak.
    rolling_max = df["equity"].cummax()
    drawdown = (df["equity"] - rolling_max) / rolling_max

    # 3 rows: Equity (3), Drawdown (1), Returns (1)
    fig = plt.figure(figsize=(14, 10))
    gs = GridSpec(3, 1, height_ratios=[3, 1, 1], hspace=0.05)

    # 1. Equity curve
    ax1 = fig.add_subplot(gs[0])
    ax1.plot(df.index, df["equity"], label="Strategy", color="#1f77b4", linewidth=1.5)

    if benchmark is not None:
        try:
            # Work on a copy so the caller's Series (and its index) is untouched.
            bench = benchmark.copy()
            if not isinstance(bench.index, pd.DatetimeIndex):
                bench.index = pd.to_datetime(bench.index)

            # Normalise to naive UTC to match the strategy timestamps.
            if bench.index.tz is not None:
                bench.index = bench.index.tz_convert("UTC").tz_localize(None)

            # Label slicing below needs a monotonic index.
            bench = bench.sort_index()

            bench_cum = (1 + bench).cumprod()

            start, end = df.index[0], df.index[-1]
            bench_window = bench_cum[start:end]  # type: ignore
            if not bench_window.empty:
                # Rebase on the strategy start date when present, otherwise on
                # the first benchmark value inside the strategy range, so both
                # curves start from the same equity.
                if start in bench_cum.index:
                    base_val = bench_cum.loc[start]
                else:
                    base_val = bench_window.iloc[0]

                initial_equity = df["equity"].iloc[0]
                bench_plot = (bench_window / base_val) * initial_equity
                ax1.plot(
                    bench_plot.index,
                    bench_plot,
                    label="Benchmark",
                    color="gray",
                    linestyle="--",
                    alpha=0.7,
                )
        except Exception as e:
            # Best-effort: a bad benchmark must not prevent the main plot.
            print(f"Warning: Failed to plot benchmark: {e}")

    ax1.set_title("Strategy Performance Analysis", fontsize=14, fontweight="bold")
    ax1.set_ylabel("Equity", fontsize=10)
    ax1.grid(True, linestyle="--", alpha=0.3)
    ax1.legend(loc="upper left", frameon=True, fancybox=True, framealpha=0.8)

    # Metrics text box
    metrics = result.metrics
    trade_metrics = result.trade_metrics

    metrics_text = [
        f"Total Return: {metrics.total_return_pct:>8.2f}%",
        f"Annualized:   {metrics.annualized_return:>8.2%}",
        f"Sharpe Ratio: {metrics.sharpe_ratio:>8.2f}",
        f"Max Drawdown: {metrics.max_drawdown_pct:>8.2f}%",
        f"Win Rate:     {metrics.win_rate:>8.2%}",
    ]

    if hasattr(trade_metrics, "total_closed_trades"):
        metrics_text.append(f"Trades:       {trade_metrics.total_closed_trades:>8d}")

    props = dict(boxstyle="round", facecolor="white", alpha=0.8, edgecolor="lightgray")
    ax1.text(
        0.02,
        0.05,
        "\n".join(metrics_text),
        transform=ax1.transAxes,
        fontsize=9,
        verticalalignment="bottom",
        fontfamily="monospace",
        bbox=props,
    )

    # 2. Drawdown
    ax2 = fig.add_subplot(gs[1], sharex=ax1)
    ax2.fill_between(
        df.index, drawdown, 0, color="#d62728", alpha=0.3, label="Drawdown"
    )
    ax2.plot(df.index, drawdown, color="#d62728", linewidth=0.8, alpha=0.8)
    ax2.set_ylabel("Drawdown", fontsize=10)
    ax2.grid(True, linestyle="--", alpha=0.3)

    # 3. Per-bar returns (bar width is in days on a date axis)
    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    ax3.bar(
        df.index,
        df["returns"],
        color="gray",
        alpha=0.5,
        label="Daily Returns",
        width=1.0 if len(df) < 100 else 0.8,
    )
    ax3.set_ylabel("Returns", fontsize=10)
    ax3.grid(True, linestyle="--", alpha=0.3)

    # Only the bottom axis shows date labels.
    ax3.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.xticks(rotation=0)

    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.08, right=0.95)

    if filename:
        plt.savefig(filename, dpi=100, bbox_inches="tight")
        print(f"Plot saved to {filename}")

    if show:
        plt.show()