bbstrader 0.3.6__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of bbstrader might be problematic; see the notice linked from the original diff page for more details (the link was not preserved in this copy).

@@ -1,5 +1,6 @@
1
1
  from abc import ABCMeta, abstractmethod
2
2
  from queue import Queue
3
+ from typing import Any, Union
3
4
 
4
5
  from loguru import logger
5
6
 
@@ -39,7 +40,7 @@ class ExecutionHandler(metaclass=ABCMeta):
39
40
  """
40
41
 
41
42
  @abstractmethod
42
- def execute_order(self, event: OrderEvent):
43
+ def execute_order(self, event: OrderEvent) -> None:
43
44
  """
44
45
  Takes an Order event and executes it, producing
45
46
  a Fill event that gets placed onto the Events queue.
@@ -47,7 +48,7 @@ class ExecutionHandler(metaclass=ABCMeta):
47
48
  Args:
48
49
  event (OrderEvent): Contains an Event object with order information.
49
50
  """
50
- pass
51
+ raise NotImplementedError("Should implement execute_order()")
51
52
 
52
53
 
53
54
  class SimExecutionHandler(ExecutionHandler):
@@ -61,7 +62,12 @@ class SimExecutionHandler(ExecutionHandler):
61
62
  handler.
62
63
  """
63
64
 
64
- def __init__(self, events: Queue, data: DataHandler, **kwargs):
65
+ def __init__(
66
+ self,
67
+ events: "Queue[Union[FillEvent, OrderEvent]]",
68
+ data: DataHandler,
69
+ **kwargs: Any,
70
+ ) -> None:
65
71
  """
66
72
  Initialises the handler, setting the event queues
67
73
  up internally.
@@ -75,7 +81,7 @@ class SimExecutionHandler(ExecutionHandler):
75
81
  self.commissions = kwargs.get("commission")
76
82
  self.exchange = kwargs.get("exchange", "ARCA")
77
83
 
78
- def execute_order(self, event: OrderEvent):
84
+ def execute_order(self, event: OrderEvent) -> None:
79
85
  """
80
86
  Simply converts Order objects into Fill objects naively,
81
87
  i.e. without any latency, slippage or fill ratio problems.
@@ -86,7 +92,7 @@ class SimExecutionHandler(ExecutionHandler):
86
92
  if event.type == Events.ORDER:
87
93
  dtime = self.bardata.get_latest_bar_datetime(event.symbol)
88
94
  fill_event = FillEvent(
89
- timeindex=dtime,
95
+ timeindex=dtime, # type: ignore
90
96
  symbol=event.symbol,
91
97
  exchange=self.exchange,
92
98
  quantity=event.quantity,
@@ -96,9 +102,10 @@ class SimExecutionHandler(ExecutionHandler):
96
102
  order=event.signal,
97
103
  )
98
104
  self.events.put(fill_event)
105
+ price = event.price or 0.0
99
106
  self.logger.info(
100
107
  f"{event.direction} ORDER FILLED: SYMBOL={event.symbol}, "
101
- f"QUANTITY={event.quantity}, PRICE @{round(event.price, 5)} EXCHANGE={fill_event.exchange}",
108
+ f"QUANTITY={event.quantity}, PRICE @{round(price, 5)} EXCHANGE={fill_event.exchange}",
102
109
  custom_time=fill_event.timeindex,
103
110
  )
104
111
 
@@ -128,7 +135,12 @@ class MT5ExecutionHandler(ExecutionHandler):
128
135
  This class only works with `bbstrader.metatrader.data.MT5DataHandler` class.
129
136
  """
130
137
 
131
- def __init__(self, events: Queue, data: DataHandler, **kwargs):
138
+ def __init__(
139
+ self,
140
+ events: "Queue[Union[FillEvent, OrderEvent]]",
141
+ data: DataHandler,
142
+ **kwargs: Any,
143
+ ) -> None:
132
144
  """
133
145
  Initialises the handler, setting the event queues up internally.
134
146
 
@@ -142,14 +154,16 @@ class MT5ExecutionHandler(ExecutionHandler):
142
154
  self.exchange = kwargs.get("exchange", "MT5")
143
155
  self.__account = Account(**kwargs)
144
156
 
145
- def _calculate_lot(self, symbol, quantity, price):
157
+ def _calculate_lot(
158
+ self, symbol: str, quantity: Union[int, float], price: Union[int, float]
159
+ ) -> float:
146
160
  symbol_type = self.__account.get_symbol_type(symbol)
147
161
  symbol_info = self.__account.get_symbol_info(symbol)
148
162
  contract_size = symbol_info.trade_contract_size
149
163
 
150
164
  lot = (quantity * price) / (contract_size * price)
151
165
  if contract_size == 1:
152
- lot = quantity
166
+ lot = float(quantity)
153
167
  if (
154
168
  symbol_type
155
169
  in (SymbolType.COMMODITIES, SymbolType.FUTURES, SymbolType.CRYPTO)
@@ -157,18 +171,24 @@ class MT5ExecutionHandler(ExecutionHandler):
157
171
  ):
158
172
  lot = quantity / contract_size
159
173
  if symbol_type == SymbolType.FOREX:
160
- lot = quantity * price / contract_size
174
+ lot = float(quantity * price / contract_size)
161
175
  return self._check_lot(symbol, lot)
162
176
 
163
- def _check_lot(self, symbol, lot):
177
+ def _check_lot(self, symbol: str, lot: float) -> float:
164
178
  symbol_info = self.__account.get_symbol_info(symbol)
165
179
  if lot < symbol_info.volume_min:
166
- return symbol_info.volume_min
180
+ return float(symbol_info.volume_min)
167
181
  elif lot > symbol_info.volume_max:
168
- return symbol_info.volume_max
182
+ return float(symbol_info.volume_max)
169
183
  return round(lot, 2)
170
184
 
171
- def _estimate_total_fees(self, symbol, lot, qty, price):
185
+ def _estimate_total_fees(
186
+ self,
187
+ symbol: str,
188
+ lot: float,
189
+ qty: Union[int, float],
190
+ price: Union[int, float],
191
+ ) -> float:
172
192
  symbol_type = self.__account.get_symbol_type(symbol)
173
193
  if symbol_type in (SymbolType.STOCKS, SymbolType.ETFs):
174
194
  return self._estimate_stock_commission(symbol, qty, price)
@@ -185,7 +205,9 @@ class MT5ExecutionHandler(ExecutionHandler):
185
205
  else:
186
206
  return 0.0
187
207
 
188
- def _estimate_stock_commission(self, symbol, qty, price):
208
+ def _estimate_stock_commission(
209
+ self, symbol: str, qty: Union[int, float], price: Union[int, float]
210
+ ) -> float:
189
211
  # https://admiralmarkets.com/start-trading/contract-specifications?regulator=jsc
190
212
  min_com = 1.0
191
213
  min_aud = 8.0
@@ -220,22 +242,22 @@ class MT5ExecutionHandler(ExecutionHandler):
220
242
  else:
221
243
  return max(min_com, qty * price * eu_asia_cm)
222
244
 
223
- def _estimate_forex_commission(self, lot):
245
+ def _estimate_forex_commission(self, lot: float) -> float:
224
246
  return 3.0 * lot
225
247
 
226
- def _estimate_commodity_commission(self, lot):
248
+ def _estimate_commodity_commission(self, lot: float) -> float:
227
249
  return 3.0 * lot
228
250
 
229
- def _estimate_index_commission(self, lot):
251
+ def _estimate_index_commission(self, lot: float) -> float:
230
252
  return 0.25 * lot
231
253
 
232
- def _estimate_futures_commission(self):
254
+ def _estimate_futures_commission(self) -> float:
233
255
  return 0.0
234
256
 
235
- def _estimate_crypto_commission(self):
257
+ def _estimate_crypto_commission(self) -> float:
236
258
  return 0.0
237
259
 
238
- def execute_order(self, event: OrderEvent):
260
+ def execute_order(self, event: OrderEvent) -> None:
239
261
  """
240
262
  Executes an Order event by converting it into a Fill event.
241
263
 
@@ -247,12 +269,14 @@ class MT5ExecutionHandler(ExecutionHandler):
247
269
  direction = event.direction
248
270
  quantity = event.quantity
249
271
  price = event.price
272
+ if price is None:
273
+ price = self.bardata.get_latest_bar_value(symbol, "close")
250
274
  lot = self._calculate_lot(symbol, quantity, price)
251
275
  fees = self._estimate_total_fees(symbol, lot, quantity, price)
252
276
  dtime = self.bardata.get_latest_bar_datetime(symbol)
253
277
  commission = self.commissions or fees
254
278
  fill_event = FillEvent(
255
- timeindex=dtime,
279
+ timeindex=dtime, # type: ignore
256
280
  symbol=symbol,
257
281
  exchange=self.exchange,
258
282
  quantity=quantity,
@@ -262,11 +286,14 @@ class MT5ExecutionHandler(ExecutionHandler):
262
286
  order=event.signal,
263
287
  )
264
288
  self.events.put(fill_event)
289
+ log_price = event.price or 0.0
265
290
  self.logger.info(
266
291
  f"{direction} ORDER FILLED: SYMBOL={symbol}, QUANTITY={quantity}, "
267
- f"PRICE @{round(event.price, 5)} EXCHANGE={fill_event.exchange}",
292
+ f"PRICE @{round(log_price, 5)} EXCHANGE={fill_event.exchange}",
268
293
  custom_time=fill_event.timeindex,
269
294
  )
270
295
 
271
296
 
272
- class IBExecutionHandler(ExecutionHandler): ...
297
+ class IBExecutionHandler(ExecutionHandler):
298
+ def execute_order(self, event: OrderEvent) -> None:
299
+ raise NotImplementedError("Should implement execute_order()")
@@ -1,4 +1,4 @@
1
- from typing import Dict, List
1
+ from typing import Dict, List, Optional, Tuple
2
2
  import warnings
3
3
 
4
4
  import matplotlib.pyplot as plt
@@ -24,8 +24,12 @@ __all__ = [
24
24
  "get_perfbased_weights",
25
25
  ]
26
26
 
27
+
27
28
  def get_asset_performances(
28
- portfolio: pd.DataFrame, assets: List[str], plot=True, strategy=""
29
+ portfolio: pd.DataFrame,
30
+ assets: List[str],
31
+ plot: bool = True,
32
+ strategy: str = "",
29
33
  ) -> pd.Series:
30
34
  """
31
35
  Calculate the performance of the assets in the portfolio.
@@ -48,12 +52,14 @@ def get_asset_performances(
48
52
  asset_returns.fillna(0, inplace=True)
49
53
  asset_cum_returns = (1.0 + asset_returns).cumprod()
50
54
  if plot:
51
- asset_cum_returns.plot(figsize=(12, 6), title=f"{strategy} Strategy Assets Performance")
55
+ asset_cum_returns.plot(
56
+ figsize=(12, 6), title=f"{strategy} Strategy Assets Performance"
57
+ )
52
58
  plt.show()
53
59
  return asset_cum_returns.iloc[-1] - 1
54
60
 
55
61
 
56
- def get_perfbased_weights(performances) -> Dict[str, float]:
62
+ def get_perfbased_weights(performances: pd.Series) -> Dict[str, float]:
57
63
  """
58
64
  Calculate the weights of the assets based on their performances.
59
65
 
@@ -71,7 +77,7 @@ def get_perfbased_weights(performances) -> Dict[str, float]:
71
77
  return weights
72
78
 
73
79
 
74
- def create_sharpe_ratio(returns, periods=252) -> float:
80
+ def create_sharpe_ratio(returns: pd.Series, periods: int = 252) -> float:
75
81
  """
76
82
  Create the Sharpe ratio for the strategy, based on a
77
83
  benchmark of zero (i.e. no risk-free rate information).
@@ -89,7 +95,7 @@ def create_sharpe_ratio(returns, periods=252) -> float:
89
95
  # Define a function to calculate the Sortino Ratio
90
96
 
91
97
 
92
- def create_sortino_ratio(returns, periods=252) -> float:
98
+ def create_sortino_ratio(returns: pd.Series, periods: int = 252) -> float:
93
99
  """
94
100
  Create the Sortino ratio for the strategy, based on a
95
101
  benchmark of zero (i.e. no risk-free rate information).
@@ -104,7 +110,7 @@ def create_sortino_ratio(returns, periods=252) -> float:
104
110
  return qs.stats.sortino(returns, periods=periods)
105
111
 
106
112
 
107
- def create_drawdowns(pnl):
113
+ def create_drawdowns(pnl: pd.Series) -> Tuple[pd.Series, float, float]:
108
114
  """
109
115
  Calculate the largest peak-to-trough drawdown of the PnL curve
110
116
  as well as the duration of the drawdown. Requires that the
@@ -135,7 +141,7 @@ def create_drawdowns(pnl):
135
141
  return drawdown, drawdown.max(), duration.max()
136
142
 
137
143
 
138
- def plot_performance(df, title):
144
+ def plot_performance(df: pd.DataFrame, title: str) -> None:
139
145
  """
140
146
  Plot the performance of the strategy:
141
147
  - (Portfolio value, %)
@@ -188,7 +194,7 @@ def plot_performance(df, title):
188
194
  plt.show()
189
195
 
190
196
 
191
- def plot_returns_and_dd(df: pd.DataFrame, benchmark: str, title):
197
+ def plot_returns_and_dd(df: pd.DataFrame, benchmark: str, title: str) -> None:
192
198
  """
193
199
  Plot the returns and drawdowns of the strategy
194
200
  compared to a benchmark.
@@ -271,7 +277,7 @@ def plot_returns_and_dd(df: pd.DataFrame, benchmark: str, title):
271
277
  plt.show()
272
278
 
273
279
 
274
- def plot_monthly_yearly_returns(df: pd.DataFrame, title):
280
+ def plot_monthly_yearly_returns(df: pd.DataFrame, title: str) -> None:
275
281
  """
276
282
  Plot the monthly and yearly returns of the strategy.
277
283
 
@@ -306,7 +312,7 @@ def plot_monthly_yearly_returns(df: pd.DataFrame, title):
306
312
  # Prepare monthly returns DataFrame
307
313
  monthly_returns_df = monthly_returns.unstack(level=-1) * 100
308
314
  monthly_returns_df.columns = monthly_returns_df.columns.map(
309
- lambda x: pd.to_datetime(x, format="%m").strftime("%b")
315
+ lambda x: pd.to_datetime(str(x), format="%m").strftime("%b")
310
316
  )
311
317
 
312
318
  # Calculate and prepare yearly returns DataFrame
@@ -371,7 +377,12 @@ def plot_monthly_yearly_returns(df: pd.DataFrame, title):
371
377
  plt.show()
372
378
 
373
379
 
374
- def show_qs_stats(returns, benchmark, strategy_name, save_dir=None):
380
+ def show_qs_stats(
381
+ returns: pd.Series,
382
+ benchmark: str,
383
+ strategy_name: str,
384
+ save_dir: Optional[str] = None,
385
+ ) -> None:
375
386
  """
376
387
  Generate the full quantstats report for the strategy.
377
388
 
@@ -1,12 +1,19 @@
1
1
  from datetime import datetime
2
2
  from pathlib import Path
3
3
  from queue import Queue
4
+ from typing import Any, Dict, List, Optional, Union
4
5
 
5
6
  import pandas as pd
6
7
  import quantstats as qs
7
8
 
8
9
  from bbstrader.btengine.data import DataHandler
9
- from bbstrader.btengine.event import Events, FillEvent, MarketEvent, OrderEvent, SignalEvent
10
+ from bbstrader.btengine.event import (
11
+ Events,
12
+ FillEvent,
13
+ MarketEvent,
14
+ OrderEvent,
15
+ SignalEvent,
16
+ )
10
17
  from bbstrader.btengine.performance import (
11
18
  create_drawdowns,
12
19
  create_sharpe_ratio,
@@ -22,7 +29,7 @@ __all__ = [
22
29
  ]
23
30
 
24
31
 
25
- class Portfolio(object):
32
+ class Portfolio:
26
33
  """
27
34
  This describes a `Portfolio()` object that keeps track of the positions
28
35
  within a portfolio and generates orders of a fixed quantity of stock based on signals.
@@ -72,11 +79,11 @@ class Portfolio(object):
72
79
  def __init__(
73
80
  self,
74
81
  bars: DataHandler,
75
- events: Queue,
82
+ events: "Queue[Union[OrderEvent, FillEvent, SignalEvent]]",
76
83
  start_date: datetime,
77
- initial_capital=100000.0,
78
- **kwargs,
79
- ):
84
+ initial_capital: float = 100000.0,
85
+ **kwargs: Any,
86
+ ) -> None:
80
87
  """
81
88
  Initialises the portfolio with bars and an event queue.
82
89
  Also includes a starting datetime index and initial capital
@@ -99,7 +106,7 @@ class Portfolio(object):
99
106
  """
100
107
  self.bars = bars
101
108
  self.events = events
102
- self.symbol_list = self.bars.symbol_list
109
+ self.symbol_list = self.bars.symbols
103
110
  self.start_date = start_date
104
111
  self.initial_capital = initial_capital
105
112
  self._leverage = kwargs.get("leverage", 1)
@@ -119,15 +126,15 @@ class Portfolio(object):
119
126
  else:
120
127
  self.tf = self._tf_mapping()[self.timeframe]
121
128
 
122
- self.all_positions = self.construct_all_positions()
123
- self.current_positions = dict(
129
+ self.all_positions: List[Dict[str, Any]] = self.construct_all_positions()
130
+ self.current_positions: Dict[str, Any] = dict(
124
131
  (k, v) for k, v in [(s, 0) for s in self.symbol_list]
125
132
  )
126
- self.all_holdings = self.construct_all_holdings()
127
- self.current_holdings = self.construct_current_holdings()
128
- self.equity_curve = None
133
+ self.all_holdings: List[Dict[str, Any]] = self.construct_all_holdings()
134
+ self.current_holdings: Dict[str, Any] = self.construct_current_holdings()
135
+ self.equity_curve: Optional[pd.DataFrame] = None
129
136
 
130
- def _tf_mapping(self):
137
+ def _tf_mapping(self) -> Dict[str, int]:
131
138
  """
132
139
  Returns a dictionary mapping the time frames
133
140
  to the number of bars in a year.
@@ -154,12 +161,12 @@ class Portfolio(object):
154
161
  480,
155
162
  720,
156
163
  ]:
157
- key = f"{minutes//60}h" if minutes >= 60 else f"{minutes}m"
164
+ key = f"{minutes // 60}h" if minutes >= 60 else f"{minutes}m"
158
165
  time_frame_mapping[key] = int(252 * (60 / minutes) * th)
159
166
  time_frame_mapping["D1"] = 252
160
167
  return time_frame_mapping
161
168
 
162
- def construct_all_positions(self):
169
+ def construct_all_positions(self) -> List[Dict[str, Any]]:
163
170
  """
164
171
  Constructs the positions list using the start_date
165
172
  to determine when the time index will begin.
@@ -168,7 +175,7 @@ class Portfolio(object):
168
175
  d["Datetime"] = self.start_date
169
176
  return [d]
170
177
 
171
- def construct_all_holdings(self):
178
+ def construct_all_holdings(self) -> List[Dict[str, Any]]:
172
179
  """
173
180
  Constructs the holdings list using the start_date
174
181
  to determine when the time index will begin.
@@ -180,7 +187,7 @@ class Portfolio(object):
180
187
  d["Total"] = self.initial_capital
181
188
  return [d]
182
189
 
183
- def construct_current_holdings(self):
190
+ def construct_current_holdings(self) -> Dict[str, float]:
184
191
  """
185
192
  This constructs the dictionary which will hold the instantaneous
186
193
  value of the portfolio across all symbols.
@@ -202,7 +209,7 @@ class Portfolio(object):
202
209
  except (AttributeError, KeyError, ValueError):
203
210
  return 0.0
204
211
 
205
- def update_timeindex(self, event: MarketEvent):
212
+ def update_timeindex(self, event: MarketEvent) -> None:
206
213
  """
207
214
  Adds a new record to the positions matrix for the current
208
215
  market data bar. This reflects the PREVIOUS bar, i.e. all
@@ -236,7 +243,7 @@ class Portfolio(object):
236
243
  # Append the current holdings
237
244
  self.all_holdings.append(dh)
238
245
 
239
- def update_positions_from_fill(self, fill: FillEvent):
246
+ def update_positions_from_fill(self, fill: FillEvent) -> None:
240
247
  """
241
248
  Takes a Fill object and updates the position matrix to
242
249
  reflect the new position.
@@ -254,7 +261,7 @@ class Portfolio(object):
254
261
  # Update positions list with new quantities
255
262
  self.current_positions[fill.symbol] += fill_dir * fill.quantity
256
263
 
257
- def update_holdings_from_fill(self, fill: FillEvent):
264
+ def update_holdings_from_fill(self, fill: FillEvent) -> None:
258
265
  """
259
266
  Takes a Fill object and updates the holdings matrix to
260
267
  reflect the holdings value.
@@ -277,7 +284,7 @@ class Portfolio(object):
277
284
  self.current_holdings["Cash"] -= cost + fill.commission
278
285
  self.current_holdings["Total"] -= cost + fill.commission
279
286
 
280
- def update_fill(self, event: FillEvent):
287
+ def update_fill(self, event: FillEvent) -> None:
281
288
  """
282
289
  Updates the portfolio current positions and holdings
283
290
  from a FillEvent.
@@ -286,7 +293,7 @@ class Portfolio(object):
286
293
  self.update_positions_from_fill(event)
287
294
  self.update_holdings_from_fill(event)
288
295
 
289
- def generate_order(self, signal: SignalEvent):
296
+ def generate_order(self, signal: SignalEvent) -> Optional[OrderEvent]:
290
297
  """
291
298
  Turns a SignalEvent into an OrderEvent.
292
299
 
@@ -304,7 +311,7 @@ class Portfolio(object):
304
311
  strength = signal.strength
305
312
  price = signal.price or self._get_price(symbol)
306
313
  cur_quantity = self.current_positions[symbol]
307
- mkt_quantity = round(quantity * strength, 2)
314
+ mkt_quantity = round(float(quantity) * float(strength), 2)
308
315
  new_quantity = mkt_quantity * self._leverage
309
316
 
310
317
  if direction in ["LONG", "SHORT", "EXIT"]:
@@ -332,7 +339,7 @@ class Portfolio(object):
332
339
 
333
340
  return order
334
341
 
335
- def update_signal(self, event: SignalEvent):
342
+ def update_signal(self, event: SignalEvent) -> None:
336
343
  """
337
344
  Acts on a SignalEvent to generate new orders
338
345
  based on the portfolio logic.
@@ -341,7 +348,7 @@ class Portfolio(object):
341
348
  order_event = self.generate_order(event)
342
349
  self.events.put(order_event)
343
350
 
344
- def create_equity_curve_dataframe(self):
351
+ def create_equity_curve_dataframe(self) -> None:
345
352
  """
346
353
  Creates a pandas DataFrame from the all_holdings
347
354
  list of dictionaries.
@@ -353,13 +360,16 @@ class Portfolio(object):
353
360
  curve["Equity Curve"] = (1.0 + curve["Returns"]).cumprod()
354
361
  self.equity_curve = curve
355
362
 
356
- def output_summary_stats(self):
363
+ def output_summary_stats(self) -> List[Any]:
357
364
  """
358
365
  Creates a list of summary statistics for the portfolio.
359
366
  """
360
- total_return = self.equity_curve["Equity Curve"].iloc[-1]
361
- returns = self.equity_curve["Returns"]
362
- pnl = self.equity_curve["Equity Curve"]
367
+ if self.equity_curve is None:
368
+ self.create_equity_curve_dataframe()
369
+
370
+ total_return = self.equity_curve["Equity Curve"].iloc[-1] # type: ignore
371
+ returns = self.equity_curve["Returns"] # type: ignore
372
+ pnl = self.equity_curve["Equity Curve"] # type: ignore
363
373
 
364
374
  sharpe_ratio = create_sharpe_ratio(returns, periods=self.tf)
365
375
  sortino_ratio = create_sortino_ratio(returns, periods=self.tf)
@@ -370,7 +380,7 @@ class Portfolio(object):
370
380
  self.equity_curve["Drawdown"] = drawdown
371
381
 
372
382
  stats = [
373
- ("Total Return", f"{(total_return-1.0) * 100.0:.2f}%"),
383
+ ("Total Return", f"{(total_return - 1.0) * 100.0:.2f}%"),
374
384
  ("Sharpe Ratio", f"{sharpe_ratio:.2f}"),
375
385
  ("Sortino Ratio", f"{sortino_ratio:.2f}"),
376
386
  ("Max Drawdown", f"{max_dd * 100.0:.2f}%"),
@@ -3,6 +3,8 @@ import json
3
3
  import os
4
4
  import sys
5
5
  from datetime import datetime
6
+ from types import ModuleType
7
+ from typing import Any, Dict, List, Type
6
8
 
7
9
  from bbstrader.btengine.backtest import run_backtest
8
10
  from bbstrader.btengine.data import (
@@ -18,12 +20,13 @@ from bbstrader.btengine.execution import (
18
20
  MT5ExecutionHandler,
19
21
  SimExecutionHandler,
20
22
  )
23
+ from bbstrader.btengine.strategy import MT5Strategy, Strategy
21
24
  from bbstrader.core.utils import load_class, load_module
22
25
 
23
26
  BACKTEST_PATH = os.path.expanduser("~/.bbstrader/backtest/backtest.py")
24
27
  CONFIG_PATH = os.path.expanduser("~/.bbstrader/backtest/backtest.json")
25
28
 
26
- DATA_HANDLER_MAP = {
29
+ DATA_HANDLER_MAP: Dict[str, Type[DataHandler]] = {
27
30
  "csv": CSVDataHandler,
28
31
  "mt5": MT5DataHandler,
29
32
  "yf": YFDataHandler,
@@ -31,27 +34,25 @@ DATA_HANDLER_MAP = {
31
34
  "fmp": FMPDataHandler,
32
35
  }
33
36
 
34
- EXECUTION_HANDLER_MAP = {
37
+ EXECUTION_HANDLER_MAP: Dict[str, Type[ExecutionHandler]] = {
35
38
  "sim": SimExecutionHandler,
36
39
  "mt5": MT5ExecutionHandler,
37
40
  }
38
41
 
39
42
 
40
- def load_exc_handler(module, handler_name):
41
- return load_class(module, handler_name, ExecutionHandler)
43
+ def load_exc_handler(module: ModuleType, handler_name: str) -> Type[ExecutionHandler]:
44
+ return load_class(module, handler_name, ExecutionHandler) # type: ignore
42
45
 
43
46
 
44
- def load_data_handler(module, handler_name):
45
- return load_class(module, handler_name, DataHandler)
47
+ def load_data_handler(module: ModuleType, handler_name: str) -> Type[DataHandler]:
48
+ return load_class(module, handler_name, DataHandler) # type: ignore
46
49
 
47
50
 
48
- def load_strategy(module, strategy_name):
49
- from bbstrader.btengine.strategy import MT5Strategy, Strategy
51
+ def load_strategy(module: ModuleType, strategy_name: str) -> Type[Strategy]:
52
+ return load_class(module, strategy_name, (Strategy, MT5Strategy)) # type: ignore
50
53
 
51
- return load_class(module, strategy_name, (Strategy, MT5Strategy))
52
54
 
53
-
54
- def load_config(config_path, strategy_name):
55
+ def load_config(config_path: str, strategy_name: str) -> Dict[str, Any]:
55
56
  if not os.path.exists(config_path):
56
57
  raise FileNotFoundError(
57
58
  f"Configuration file {config_path} not found. Please create it."
@@ -101,7 +102,7 @@ def load_config(config_path, strategy_name):
101
102
  return config
102
103
 
103
104
 
104
- def backtest(unknown):
105
+ def backtest(unknown: List[str]) -> None:
105
106
  HELP_MSG = """
106
107
  Usage:
107
108
  python -m bbstrader --run backtest [options]