akquant 0.1.0__cp39-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of akquant might be problematic. Click here for more details.

akquant/backtest.py ADDED
@@ -0,0 +1,414 @@
1
+ import pandas as pd
2
+ from typing import Union, List, Optional, Type, Callable, Dict, Any
3
+ from .akquant import Engine, ExecutionMode, Instrument, AssetType, DataFeed, Bar
4
+ from .strategy import Strategy
5
+ from .data import DataLoader
6
+ from .log import get_logger, register_logger
7
+ from .utils import prepare_dataframe, df_to_arrays
8
+
9
+
10
class FunctionalStrategy(Strategy):
    """Adapter that turns plain functions into a Strategy (Zipline-style API).

    ``initialize(ctx)`` is invoked once, at construction time, and
    ``on_bar(ctx, bar)`` is invoked for every bar. In both cases ``ctx`` is
    this strategy instance, so user code can keep state on it as attributes.
    """

    def __init__(
        self, initialize: Callable, on_bar: Callable, context: Dict[str, Any] = None
    ):
        super().__init__()
        self._initialize = initialize
        self._on_bar_func = on_bar
        self._context = context or {}

        # Expose every context entry as an attribute (mimics Zipline's
        # `context` object); user callbacks read them back as `self.<name>`.
        for attr_name, attr_value in self._context.items():
            setattr(self, attr_name, attr_value)

        # Run the user's one-time setup hook, if one was supplied.
        init_hook = self._initialize
        if init_hook:
            init_hook(self)

    def on_bar(self, bar: Bar):
        """Forward the bar to the user-supplied callback (no-op if absent)."""
        callback = self._on_bar_func
        if not callback:
            return
        callback(self, bar)
35
+
36
+
37
def run_backtest(
    data: Union[pd.DataFrame, Dict[str, pd.DataFrame], List[Bar]],
    strategy: Union[Type[Strategy], Strategy, Callable[[Any, Bar], None]],
    symbol: Union[str, List[str]] = "BENCHMARK",
    cash: float = 1_000_000.0,
    commission: float = 0.0003,
    stamp_tax: float = 0.0005,
    transfer_fee: float = 0.00001,
    min_commission: float = 5.0,
    execution_mode: Union[ExecutionMode, str] = ExecutionMode.NextOpen,
    timezone: str = "Asia/Shanghai",
    initialize: Optional[Callable[[Any], None]] = None,
    context: Optional[Dict[str, Any]] = None,
    history_depth: int = 0,
    lot_size: Union[int, Dict[str, int], None] = None,
    show_progress: bool = True,
    **kwargs,
) -> Any:
    """Simplified one-call backtest entry point.

    :param data: market data. One of:
        a DataFrame (single symbol, named by the first entry of ``symbol``);
        a ``Dict[str, DataFrame]`` (keys are the symbols; ``symbol`` is ignored);
        a list of Bar objects (not modified; a sorted copy is used).
    :param strategy: a Strategy subclass, a Strategy instance, or an
        ``on_bar(ctx, bar)`` callback function (Zipline style).
    :param symbol: symbol code, or a list of symbol codes.
    :param cash: initial cash.
    :param commission: commission rate.
    :param stamp_tax: stamp-tax rate (sell side only).
    :param transfer_fee: transfer-fee rate.
    :param min_commission: minimum commission per order.
    :param execution_mode: an ExecutionMode, or a string such as
        ``"next_open"`` / ``"current_close"``. Case, ``_``, ``-`` and spaces
        are ignored. Unknown strings fall back to NextOpen with a warning.
    :param timezone: timezone name passed to the engine.
    :param initialize: init callback (only used when ``strategy`` is a function).
    :param context: initial context values, set as attributes on the strategy.
    :param history_depth: length of automatically maintained history (0 disables).
    :param lot_size: minimum trading unit. An int applies to every symbol; a
        dict is looked up by symbol. None leaves it unset, in which case the
        engine default (documented as 1) applies.
    :param show_progress: whether to show a progress bar.
    :param kwargs: extra engine/instrument settings: ``fund_commission``,
        ``fund_transfer_fee``, ``fund_min_commission``, ``option_commission``,
        ``multiplier``, ``margin_ratio``, ``tick_size``, ``asset_type``,
        ``option_type``, ``strike_price`` and ``expiry_date``. Also forwarded
        to the constructor of a Strategy subclass when it accepts them.
    :return: the engine's result object (``engine.get_results()``).
    :raises ValueError: if ``data`` or ``strategy`` has an unsupported type.
    """
    # 1. Make sure logging is configured; install a default console handler
    #    if the user has not configured one.
    logger = get_logger()
    if not logger.handlers:
        register_logger(console=True, level="INFO")
        logger = get_logger()

    # 2. Load the market data into a DataFeed.
    feed, symbols = _build_feed(data, _normalize_symbols(symbol))

    # 3. Configure the engine.
    engine = Engine()
    engine.set_timezone_name(timezone)
    engine.set_cash(cash)
    engine.set_execution_mode(_resolve_execution_mode(execution_mode, logger))
    engine.set_t_plus_one(False)  # T+0 by default
    engine.set_force_session_continuous(True)
    engine.set_stock_fee_rules(commission, stamp_tax, transfer_fee, min_commission)

    # Fee rules for other asset classes, only when explicitly provided.
    if "fund_commission" in kwargs:
        engine.set_fund_fee_rules(
            kwargs["fund_commission"],
            kwargs.get("fund_transfer_fee", 0.0),
            kwargs.get("fund_min_commission", 0.0),
        )
    if "option_commission" in kwargs:
        engine.set_option_fee_rules(kwargs["option_commission"])

    # 4. Register one Instrument per symbol.
    multiplier = kwargs.get("multiplier", 1.0)
    margin_ratio = kwargs.get("margin_ratio", 1.0)
    tick_size = kwargs.get("tick_size", 0.01)
    asset_type = kwargs.get("asset_type", AssetType.Stock)
    option_type = kwargs.get("option_type", None)
    strike_price = kwargs.get("strike_price", None)
    expiry_date = kwargs.get("expiry_date", None)

    for sym in symbols:
        engine.add_instrument(
            Instrument(
                sym,
                asset_type,
                multiplier,
                margin_ratio,
                tick_size,
                option_type,
                strike_price,
                expiry_date,
                _resolve_lot_size(lot_size, sym),
            )
        )

    # 5. Attach the data.
    engine.add_data(feed)

    # 6. Build the strategy instance (this also runs `initialize` for
    #    function-style strategies and injects `context`).
    strategy_instance = _build_strategy_instance(strategy, initialize, context, kwargs)

    # 7. Run.
    logger.info("Running backtest via run_backtest()...")
    if history_depth > 0:
        strategy_instance.set_history_depth(history_depth)

    engine.run(strategy_instance, show_progress)
    return engine.get_results()


def _normalize_symbols(symbol: Any) -> List[str]:
    """Turn the ``symbol`` argument into a list of symbol codes."""
    if isinstance(symbol, str):
        return [symbol]
    if isinstance(symbol, (list, tuple)):
        return list(symbol)
    return ["BENCHMARK"]


def _build_feed(data: Any, symbols: List[str]):
    """Load ``data`` into a new DataFeed; return ``(feed, symbols_with_data)``."""
    feed = DataFeed()

    if isinstance(data, pd.DataFrame):
        # A single DataFrame carries rows for exactly one symbol: the first one.
        target_symbol = symbols[0] if symbols else "BENCHMARK"
        # Fast path: hand columnar arrays to the engine instead of Bar objects.
        feed.add_arrays(*df_to_arrays(prepare_dataframe(data), symbol=target_symbol))
        feed.sort()
        return feed, [target_symbol]

    if isinstance(data, dict):
        # Dict[str, DataFrame]: the keys are the symbols.
        for sym, frame in data.items():
            feed.add_arrays(*df_to_arrays(prepare_dataframe(frame), symbol=sym))
        feed.sort()
        return feed, list(data.keys())

    if isinstance(data, list):
        # Sort a copy so the caller's list is left untouched.
        if data:
            feed.add_bars(sorted(data, key=lambda b: b.timestamp))
        return feed, symbols

    raise ValueError("data must be a DataFrame, Dict[str, DataFrame], or List[Bar]")


def _resolve_execution_mode(execution_mode: Any, logger: Any) -> Any:
    """Map an ExecutionMode or a mode string to an ExecutionMode."""
    if not isinstance(execution_mode, str):
        return execution_mode

    # Accept "next_open", "NextOpen", "next-open", ... alike.
    key = execution_mode.lower().replace("_", "").replace("-", "").replace(" ", "")
    mode_map = {
        "nextopen": ExecutionMode.NextOpen,
        "currentclose": ExecutionMode.CurrentClose,
    }
    mode = mode_map.get(key)
    if mode is None:
        logger.warning(
            f"Unknown execution mode '{execution_mode}', defaulting to NextOpen"
        )
        mode = ExecutionMode.NextOpen
    return mode


def _resolve_lot_size(lot_size: Any, sym: str) -> Optional[int]:
    """Return the lot size configured for ``sym``, or None if unset."""
    if isinstance(lot_size, int):
        return lot_size
    if isinstance(lot_size, dict):
        return lot_size.get(sym)
    return None


def _build_strategy_instance(
    strategy: Any,
    initialize: Optional[Callable[[Any], None]],
    context: Optional[Dict[str, Any]],
    kwargs: Dict[str, Any],
) -> Any:
    """Instantiate or wrap ``strategy`` and inject ``context`` attributes."""
    if isinstance(strategy, type) and issubclass(strategy, Strategy):
        # Try passing the extra kwargs to the constructor. Fall back to a
        # no-arg constructor if it does not accept them. NOTE(review): this
        # also masks TypeErrors raised inside __init__ itself.
        try:
            instance = strategy(**kwargs)
        except TypeError:
            instance = strategy()
    elif isinstance(strategy, Strategy):
        instance = strategy
    elif callable(strategy):
        # Function-style on_bar callback; FunctionalStrategy injects the
        # context and runs `initialize` in its own __init__.
        return FunctionalStrategy(initialize, strategy, context)
    else:
        raise ValueError("Invalid strategy type")

    if context:
        for key, value in context.items():
            setattr(instance, key, value)
    return instance
240
+
241
+
242
def plot_result(
    result: Any,
    show: bool = True,
    filename: Optional[str] = None,
    benchmark: Optional[pd.Series] = None,
):
    """Plot a backtest result: equity curve, drawdown and per-bar returns.

    :param result: BacktestResult object. Reads ``equity_curve`` (a sequence of
        ``(timestamp, equity)`` pairs), ``metrics`` and ``trade_metrics``.
    :param show: whether to call ``plt.show()``.
    :param filename: if given, save the figure to this path.
    :param benchmark: optional benchmark *returns* series (DatetimeIndex, or an
        index ``pd.to_datetime`` can parse). The caller's Series is not modified.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.dates as mdates
        from matplotlib.gridspec import GridSpec
    except ImportError:
        print(
            "Error: matplotlib is required for plotting. Please install it via 'pip install matplotlib'."
        )
        return
    from datetime import datetime

    equity_curve = result.equity_curve  # sequence of (timestamp, equity)
    if not equity_curve:
        print("No equity curve data to plot.")
        return

    # Timestamps above 1e11 cannot be seconds (1e11 s is roughly year 5138),
    # so treat them as nanoseconds (1e18 ns is roughly year 2001).
    scale = 1e-9 if equity_curve[0][0] > 1e11 else 1.0
    times = [datetime.fromtimestamp(ts * scale) for ts, _ in equity_curve]
    equity = [value for _, value in equity_curve]

    df = pd.DataFrame({"equity": equity}, index=times)
    df.index.name = "Date"
    df["returns"] = df["equity"].pct_change().fillna(0)

    rolling_max = df["equity"].cummax()
    drawdown = (df["equity"] - rolling_max) / rolling_max

    fig = plt.figure(figsize=(14, 10))
    # 3 rows: Equity (3), Drawdown (1), Returns (1)
    gs = GridSpec(3, 1, height_ratios=[3, 1, 1], hspace=0.05)

    # 1. Equity curve
    ax1 = fig.add_subplot(gs[0])
    ax1.plot(df.index, df["equity"], label="Strategy", color="#1f77b4", linewidth=1.5)

    if benchmark is not None:
        try:
            # Work on a copy so the caller's Series (and its index) stay untouched.
            bench = benchmark.copy()
            if not isinstance(bench.index, pd.DatetimeIndex):
                bench.index = pd.to_datetime(bench.index)
            bench = bench.sort_index()  # cumprod must run in time order

            bench_cum = (1 + bench).cumprod()
            if not bench_cum.empty:
                # Rebase on the last benchmark value at or before the strategy's
                # first timestamp, so both curves start at the same equity.
                # Fall back to the first value if the benchmark starts later.
                start_date = df.index[0]
                base_val = bench_cum.asof(start_date)
                if pd.isna(base_val):
                    base_val = bench_cum.iloc[0]

                bench_scaled = (bench_cum / base_val) * df["equity"].iloc[0]
                bench_plot = bench_scaled[df.index[0] : df.index[-1]]
                ax1.plot(
                    bench_plot.index,
                    bench_plot,
                    label="Benchmark",
                    color="gray",
                    linestyle="--",
                    alpha=0.7,
                )
        except Exception as e:
            # Best effort: a bad benchmark must not prevent plotting the strategy.
            print(f"Warning: Failed to plot benchmark: {e}")

    ax1.set_title("Strategy Performance Analysis", fontsize=14, fontweight="bold")
    ax1.set_ylabel("Equity", fontsize=10)
    ax1.grid(True, linestyle="--", alpha=0.3)
    ax1.legend(loc="upper left", frameon=True, fancybox=True, framealpha=0.8)

    # Metrics text box
    metrics = result.metrics
    trade_metrics = result.trade_metrics

    metrics_text = [
        f"Total Return: {metrics.total_return_pct:>8.2f}%",
        f"Annualized: {metrics.annualized_return:>8.2%}",
        f"Sharpe Ratio: {metrics.sharpe_ratio:>8.2f}",
        f"Max Drawdown: {metrics.max_drawdown_pct:>8.2f}%",
        f"Win Rate: {metrics.win_rate:>8.2%}",
    ]
    if hasattr(trade_metrics, "total_closed_trades"):
        metrics_text.append(f"Trades: {trade_metrics.total_closed_trades:>8d}")

    props = dict(boxstyle="round", facecolor="white", alpha=0.8, edgecolor="lightgray")
    ax1.text(
        0.02,
        0.05,
        "\n".join(metrics_text),
        transform=ax1.transAxes,
        fontsize=9,
        verticalalignment="bottom",
        fontfamily="monospace",
        bbox=props,
    )

    # 2. Drawdown
    ax2 = fig.add_subplot(gs[1], sharex=ax1)
    ax2.fill_between(
        df.index, drawdown, 0, color="#d62728", alpha=0.3, label="Drawdown"
    )
    ax2.plot(df.index, drawdown, color="#d62728", linewidth=0.8, alpha=0.8)
    ax2.set_ylabel("Drawdown", fontsize=10)
    ax2.grid(True, linestyle="--", alpha=0.3)

    # 3. Per-bar returns
    ax3 = fig.add_subplot(gs[2], sharex=ax1)
    ax3.bar(
        df.index,
        df["returns"],
        color="gray",
        alpha=0.5,
        label="Daily Returns",
        width=1.0 if len(df) < 100 else 0.8,
    )
    ax3.set_ylabel("Returns", fontsize=10)
    ax3.grid(True, linestyle="--", alpha=0.3)

    # Only the bottom axis shows date labels (axes share x).
    ax3.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    plt.setp(ax1.get_xticklabels(), visible=False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.xticks(rotation=0)

    plt.subplots_adjust(top=0.95, bottom=0.05, left=0.08, right=0.95)

    if filename:
        plt.savefig(filename, dpi=100, bbox_inches="tight")
        print(f"Plot saved to {filename}")

    if show:
        plt.show()
akquant/config.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import Optional
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass
class StrategyConfig:
    """Settings shared by strategies and the backtester.

    Modelled on PyBroker's configuration object. Every field has a default,
    so ``StrategyConfig()`` is a complete baseline. Callers override
    individual fields by keyword (or positionally, in declaration order).
    """

    # -- capital ---------------------------------------------------------
    initial_cash: float = 100000.0

    # -- fees ------------------------------------------------------------
    # fee_mode: one of 'per_order', 'per_share', 'percent'.
    # fee_amount: the fixed amount or the percentage, depending on fee_mode.
    fee_mode: str = "per_order"
    fee_amount: float = 0.0

    # -- order execution -------------------------------------------------
    enable_fractional_shares: bool = False
    round_fill_price: bool = True

    # -- position-count limits (None presumably means "no limit"; confirm
    #    against the code that consumes these fields) ---------------------
    max_long_positions: Optional[int] = None
    max_short_positions: Optional[int] = None

    # -- bootstrap metrics -----------------------------------------------
    bootstrap_samples: int = 1000
    bootstrap_sample_size: Optional[int] = None

    # -- miscellaneous ---------------------------------------------------
    exit_on_last_bar: bool = True


# Module-wide default configuration object.
strategy_config = StrategyConfig()
akquant/data.py ADDED
@@ -0,0 +1,122 @@
1
+ import pandas as pd
2
+ import hashlib
3
+ from typing import Optional, List
4
+ from pathlib import Path
5
+ import logging
6
+
7
+ try:
8
+ import akshare as ak
9
+ except ImportError:
10
+ ak = None
11
+
12
+ from .utils import load_akshare_bar
13
+ from .akquant import Bar
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class DataLoader:
    """Data loader with an on-disk pickle cache, inspired by PyBroker."""

    def __init__(self, cache_dir: Optional[str] = None):
        """Initialize DataLoader.

        Args:
            cache_dir (str, optional): Directory to store cache files.
                Defaults to ~/.akquant/cache. If it cannot be created, falls
                back to ./.akquant_cache in the current working directory.
        """
        if cache_dir:
            self.cache_dir = Path(cache_dir)
        else:
            self.cache_dir = Path.home() / ".akquant" / "cache"

        try:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            # Covers permission errors, read-only filesystems and a
            # non-directory already occupying the path.
            logger.warning(
                f"Cannot create cache dir {self.cache_dir} ({exc}), "
                "falling back to local .akquant_cache"
            )
            self.cache_dir = Path.cwd() / ".akquant_cache"
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_path(self, key: str) -> Path:
        """Return the cache file path for ``key``.

        The key is hashed to avoid filesystem issues with long or invalid
        filenames. md5 is used only as a file name, not for security, hence
        ``usedforsecurity=False`` (required on FIPS-enabled Python builds).
        """
        hashed_key = hashlib.md5(key.encode("utf-8"), usedforsecurity=False).hexdigest()
        return self.cache_dir / f"{hashed_key}.pkl"

    def _write_cache(self, df: pd.DataFrame, cache_path: Path) -> None:
        """Best-effort atomic write of ``df`` to ``cache_path``.

        Writes to a temporary file and then renames it, so an interrupted
        write never leaves a truncated cache file behind. Failures are logged,
        never raised: losing the cache must not lose freshly fetched data.
        """
        tmp_path = cache_path.with_suffix(".tmp")
        try:
            df.to_pickle(tmp_path)
            tmp_path.replace(cache_path)
        except Exception as exc:
            logger.warning(f"Failed to write cache {cache_path}: {exc}")
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass
            return
        logger.info(f"Data cached to {cache_path}")

    def load_akshare(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        adjust: str = "qfq",
        period: str = "daily",
        use_cache: bool = True,
    ) -> pd.DataFrame:
        """Load A-share history data from AKShare with caching.

        Args:
            symbol (str): Stock symbol (e.g., "600000").
            start_date (str): Start date (YYYYMMDD).
            end_date (str): End date (YYYYMMDD).
            adjust (str): Adjustment factor ("qfq", "hfq", ""). Default "qfq".
            period (str): Period ("daily", "weekly", "monthly"). Default "daily".
            use_cache (bool): Whether to read/write the cache. Default True.

        Returns:
            pd.DataFrame: Historical data (empty results are returned but not cached).

        Raises:
            ImportError: If akshare is not installed.
            Exception: Whatever AKShare raises while fetching (logged, then re-raised).
        """
        if ak is None:
            raise ImportError(
                "akshare is not installed. Please run `pip install akshare`."
            )

        cache_key = f"akshare_stock_zh_a_hist_{symbol}_{start_date}_{end_date}_{adjust}_{period}"
        cache_path = self._get_cache_path(cache_key)

        if use_cache and cache_path.exists():
            logger.info(f"Loading cached data for {symbol} from {cache_path}")
            try:
                # SECURITY NOTE: unpickling can execute arbitrary code. This is
                # only acceptable because the cache directory is written by
                # this loader; never point cache_dir at untrusted files.
                return pd.read_pickle(cache_path)
            except Exception as e:
                logger.warning(f"Failed to load cache: {e}. Reloading from source.")

        logger.info(f"Fetching data for {symbol} from AKShare...")
        try:
            df = ak.stock_zh_a_hist(
                symbol=symbol,
                period=period,
                start_date=start_date,
                end_date=end_date,
                adjust=adjust,
            )
        except Exception as e:
            logger.error(f"Error fetching data from AKShare: {e}")
            raise

        if df.empty:
            logger.warning(f"No data found for {symbol}")
            return df

        if use_cache:
            self._write_cache(df, cache_path)

        return df

    def df_to_bars(self, df: pd.DataFrame, symbol: Optional[str] = None) -> "List[Bar]":
        """Convert a DataFrame to a list of Bar objects.

        Thin wrapper around ``utils.load_akshare_bar``.
        """
        return load_akshare_bar(df, symbol)
akquant/indicator.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import Callable, Dict
2
+ import pandas as pd
3
+
4
+
5
class Indicator:
    """Named indicator computed from a DataFrame and cached per symbol.

    Inspired by PyBroker's indicator system. ``fn`` receives the DataFrame as
    its first positional argument. Keyword arguments given at construction
    are forwarded to it when its signature accepts them.
    """

    def __init__(self, name: str, fn: Callable, **kwargs):
        self.name = name
        self.fn = fn
        self.kwargs = kwargs
        self._data: Dict[str, pd.Series] = {}  # symbol -> computed result

    def __call__(self, df: pd.DataFrame, symbol: str) -> pd.Series:
        """Return the indicator for ``symbol``, computing it on first use.

        Results are cached per symbol: later calls for the same symbol return
        the cached value even if ``df`` differs. Call :meth:`clear` to force
        recomputation. Exceptions raised by ``fn`` propagate unchanged.
        """
        if symbol in self._data:
            return self._data[symbol]

        result = self._compute(df)
        self._data[symbol] = result
        return result

    def _compute(self, df: pd.DataFrame):
        """Call ``fn`` on ``df``, forwarding kwargs only if ``fn`` accepts them."""
        import inspect

        if not self.kwargs:
            return self.fn(df)

        # Decide up front, from the signature, whether the kwargs can be
        # passed. Catching exceptions from the call itself would mask real
        # errors inside fn. It would also silently re-run fn without the
        # user's parameters.
        try:
            inspect.signature(self.fn).bind(df, **self.kwargs)
        except TypeError:
            # fn cannot take these keyword arguments: call it with df only.
            return self.fn(df)
        except ValueError:
            # No introspectable signature (some builtins); just try the call.
            pass
        return self.fn(df, **self.kwargs)

    def clear(self, symbol=None) -> None:
        """Drop the cached result for ``symbol``, or for all symbols if None."""
        if symbol is None:
            self._data.clear()
        else:
            self._data.pop(symbol, None)
36
+
37
+
38
class IndicatorSet:
    """Named registry of Indicator objects that are evaluated together."""

    def __init__(self):
        self._indicators: Dict[str, Indicator] = {}

    def add(self, name: str, fn: Callable, **kwargs):
        """Register (or replace) an indicator under ``name``."""
        indicator = Indicator(name, fn, **kwargs)
        self._indicators[name] = indicator

    def get(self, name: str) -> Indicator:
        """Return the indicator registered as ``name`` (KeyError if absent)."""
        return self._indicators[name]

    def calculate_all(self, df: pd.DataFrame, symbol: str) -> Dict[str, pd.Series]:
        """Evaluate every registered indicator for ``symbol``, keyed by name."""
        return {
            label: indicator(df, symbol)
            for label, indicator in self._indicators.items()
        }