bbstrader 0.2.991__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bbstrader might be problematic. Click here for more details.

@@ -33,12 +33,6 @@ __all__ = [
33
33
  "create_trade_instance",
34
34
  ]
35
35
 
36
- FILLING_TYPE = [
37
- Mt5.ORDER_FILLING_IOC,
38
- Mt5.ORDER_FILLING_RETURN,
39
- Mt5.ORDER_FILLING_BOC,
40
- ]
41
-
42
36
  log.add(
43
37
  f"{BBSTRADER_DIR}/logs/trade.log",
44
38
  enqueue=True,
@@ -50,6 +44,13 @@ global LOGGER
50
44
  LOGGER = log
51
45
 
52
46
 
47
+ FILLING_TYPE = [
48
+ Mt5.ORDER_FILLING_IOC,
49
+ Mt5.ORDER_FILLING_RETURN,
50
+ Mt5.ORDER_FILLING_BOC,
51
+ ]
52
+
53
+
53
54
  class TradeAction(Enum):
54
55
  """
55
56
  An enumeration class for trade actions.
@@ -136,6 +137,15 @@ class TradeSignal:
136
137
  f"price={self.price}, stoplimit={self.stoplimit}), comment='{self.comment}'"
137
138
  )
138
139
 
140
+ class TradingMode(Enum):
141
+ BACKTEST = "BACKTEST"
142
+ LIVE = "LIVE"
143
+
144
+ def isbacktest(self) -> bool:
145
+ return self == TradingMode.BACKTEST
146
+ def islive(self) -> bool:
147
+ return self == TradingMode.LIVE
148
+
139
149
 
140
150
  Buys = Literal["BMKT", "BLMT", "BSTP", "BSTPLMT"]
141
151
  Sells = Literal["SMKT", "SLMT", "SSTP", "SSTPLMT"]
@@ -152,6 +162,7 @@ Orders = Literal[
152
162
 
153
163
  EXPERT_ID = 98181105
154
164
 
165
+
155
166
  class Trade(RiskManagement):
156
167
  """
157
168
  Extends the `RiskManagement` class to include specific trading operations,
@@ -834,7 +845,7 @@ class Trade(RiskManagement):
834
845
  action (str): (`'BMKT'`, `'SMKT'`) for Market orders
835
846
  or (`'BLMT', 'SLMT', 'BSTP', 'SSTP', 'BSTPLMT', 'SSTPLMT'`) for pending orders
836
847
  price (float): The price at which to open an order
837
- stoplimit (float): A price a pending Limit order is set at
848
+ stoplimit (float): A price a pending Limit order is set at
838
849
  when the price reaches the 'price' value (this condition is mandatory).
839
850
  The pending order is not passed to the trading system until that moment
840
851
  id (int): The strategy id or expert Id
@@ -873,30 +884,30 @@ class Trade(RiskManagement):
873
884
  @property
874
885
  def orders(self):
875
886
  """Return all opened order's tickets"""
876
- if len(self.opened_orders) != 0:
877
- return self.opened_orders
878
- return None
887
+ current_orders = self.get_current_orders() or []
888
+ opened_orders = set(current_orders + self.opened_orders)
889
+ return list(opened_orders) if len(opened_orders) != 0 else None
879
890
 
880
891
  @property
881
892
  def positions(self):
882
893
  """Return all opened position's tickets"""
883
- if len(self.opened_positions) != 0:
884
- return self.opened_positions
885
- return None
894
+ current_positions = self.get_current_positions() or []
895
+ opened_positions = set(current_positions + self.opened_positions)
896
+ return list(opened_positions) if len(opened_positions) != 0 else None
886
897
 
887
898
  @property
888
899
  def buypos(self):
889
900
  """Return all buy opened position's tickets"""
890
- if len(self.buy_positions) != 0:
891
- return self.buy_positions
892
- return None
901
+ buy_positions = self.get_current_buys() or []
902
+ buy_positions = set(buy_positions + self.buy_positions)
903
+ return list(buy_positions) if len(buy_positions) != 0 else None
893
904
 
894
905
  @property
895
906
  def sellpos(self):
896
907
  """Return all sell opened position's tickets"""
897
- if len(self.sell_positions) != 0:
898
- return self.sell_positions
899
- return None
908
+ sell_positions = self.get_current_sells() or []
909
+ sell_positions = set(sell_positions + self.sell_positions)
910
+ return list(sell_positions) if len(sell_positions) != 0 else None
900
911
 
901
912
  @property
902
913
  def bepos(self):
@@ -1214,7 +1225,7 @@ class Trade(RiskManagement):
1214
1225
  Sets the break-even level for a given trading position.
1215
1226
 
1216
1227
  Args:
1217
- position (TradePosition): The trading position for which the break-even is to be set.
1228
+ position (TradePosition): The trading position for which the break-even is to be set.
1218
1229
  This is the value return by `mt5.positions_get()`.
1219
1230
  be (int): The break-even level in points.
1220
1231
  level (float): The break-even level in price, if set to None , it will be calated automaticaly.
@@ -1446,7 +1457,7 @@ class Trade(RiskManagement):
1446
1457
  Args:
1447
1458
  ticket (int): Order ticket to modify (e.g TradeOrder.ticket)
1448
1459
  price (float): The price at which to modify the order
1449
- stoplimit (float): A price a pending Limit order is set at
1460
+ stoplimit (float): A price a pending Limit order is set at
1450
1461
  when the price reaches the 'price' value (this condition is mandatory).
1451
1462
  The pending order is not passed to the trading system until that moment
1452
1463
  sl (float): The stop loss in points
@@ -1597,7 +1608,7 @@ class Trade(RiskManagement):
1597
1608
  ):
1598
1609
  """
1599
1610
  Args:
1600
- order_type (str): Type of orders to close
1611
+ order_type (str): Type of orders to close
1601
1612
  ('all', 'buy_stops', 'sell_stops', 'buy_limits', 'sell_limits', 'buy_stop_limits', 'sell_stop_limits')
1602
1613
  id (int): The unique ID of the Expert or Strategy
1603
1614
  comment (str): Comment for the closing position
@@ -70,27 +70,31 @@ class TimeFrame(Enum):
70
70
  Rrepresent a time frame object
71
71
  """
72
72
 
73
- M1 = "1m"
74
- M2 = "2m"
75
- M3 = "3m"
76
- M4 = "4m"
77
- M5 = "5m"
78
- M6 = "6m"
79
- M10 = "10m"
80
- M12 = "12m"
81
- M15 = "15m"
82
- M20 = "20m"
83
- M30 = "30m"
84
- H1 = "1h"
85
- H2 = "2h"
86
- H3 = "3h"
87
- H4 = "4h"
88
- H6 = "6h"
89
- H8 = "8h"
90
- H12 = "12h"
91
- D1 = "D1"
92
- W1 = "W1"
93
- MN1 = "MN1"
73
+ M1 = TIMEFRAMES["1m"]
74
+ M2 = TIMEFRAMES["2m"]
75
+ M3 = TIMEFRAMES["3m"]
76
+ M4 = TIMEFRAMES["4m"]
77
+ M5 = TIMEFRAMES["5m"]
78
+ M6 = TIMEFRAMES["6m"]
79
+ M10 = TIMEFRAMES["10m"]
80
+ M12 = TIMEFRAMES["12m"]
81
+ M15 = TIMEFRAMES["15m"]
82
+ M20 = TIMEFRAMES["20m"]
83
+ M30 = TIMEFRAMES["30m"]
84
+ H1 = TIMEFRAMES["1h"]
85
+ H2 = TIMEFRAMES["2h"]
86
+ H3 = TIMEFRAMES["3h"]
87
+ H4 = TIMEFRAMES["4h"]
88
+ H6 = TIMEFRAMES["6h"]
89
+ H8 = TIMEFRAMES["8h"]
90
+ H12 = TIMEFRAMES["12h"]
91
+ D1 = TIMEFRAMES["D1"]
92
+ W1 = TIMEFRAMES["W1"]
93
+ MN1 = TIMEFRAMES["MN1"]
94
+
95
+ def __str__(self):
96
+ """Return the string representation of the time frame."""
97
+ return self.name
94
98
 
95
99
 
96
100
  class TerminalInfo(NamedTuple):
@@ -263,6 +267,23 @@ class SymbolInfo(NamedTuple):
263
267
  path: str
264
268
 
265
269
 
270
+ class SymbolType(Enum):
271
+ """
272
+ Represents the type of a symbol.
273
+ """
274
+
275
+ FOREX = "FOREX" # Forex currency pairs
276
+ FUTURES = "FUTURES" # Futures contracts
277
+ STOCKS = "STOCKS" # Stocks and shares
278
+ BONDS = "BONDS" # Bonds
279
+ CRYPTO = "CRYPTO" # Cryptocurrencies
280
+ ETFs = "ETFs" # Exchange-Traded Funds
281
+ INDICES = "INDICES" # Market indices
282
+ COMMODITIES = "COMMODITIES" # Commodities
283
+ OPTIONS = "OPTIONS" # Options contracts
284
+ unknown = "UNKNOWN" # Unknown or unsupported type
285
+
286
+
266
287
  class TickInfo(NamedTuple):
267
288
  """
268
289
  Represents the last tick for the specified financial instrument.
@@ -465,10 +486,12 @@ class MT5TerminalError(Exception):
465
486
  self.message = message
466
487
 
467
488
  def __str__(self) -> str:
468
- if self.message is None:
469
- return f"{self.__class__.__name__}"
470
- else:
471
- return f"{self.__class__.__name__}, {self.message}"
489
+ # if self.message is None:
490
+ # return f"{self.__class__.__name__}"
491
+ # else:
492
+ # return f"{self.__class__.__name__}, {self.message}"
493
+ msg_str = str(self.message) if self.message is not None else ""
494
+ return f"{self.code} - {self.__class__.__name__}: {msg_str}"
472
495
 
473
496
 
474
497
  class GenericFail(MT5TerminalError):
@@ -561,6 +584,21 @@ class InternalFailTimeout(InternalFailError):
561
584
  super().__init__(MT5.RES_E_INTERNAL_FAIL_TIMEOUT, message)
562
585
 
563
586
 
587
+ RES_E_FAIL = 1 # Generic error
588
+ RES_E_INVALID_PARAMS = 2 # Invalid parameters
589
+ RES_E_NOT_FOUND = 3 # Not found
590
+ RES_E_INVALID_VERSION = 4 # Invalid version
591
+ RES_E_AUTH_FAILED = 5 # Authorization failed
592
+ RES_E_UNSUPPORTED = 6 # Unsupported method
593
+ RES_E_AUTO_TRADING_DISABLED = 7 # Autotrading disabled
594
+
595
+ # Actual internal error codes from MetaTrader5
596
+ RES_E_INTERNAL_FAIL_CONNECT = -10000
597
+ RES_E_INTERNAL_FAIL_INIT = -10001
598
+ RES_E_INTERNAL_FAIL_SEND = -10006
599
+ RES_E_INTERNAL_FAIL_RECEIVE = -10007
600
+ RES_E_INTERNAL_FAIL_TIMEOUT = -10008
601
+
564
602
  # Dictionary to map error codes to exception classes
565
603
  _ERROR_CODE_TO_EXCEPTION_ = {
566
604
  MT5.RES_E_FAIL: GenericFail,
@@ -575,6 +613,18 @@ _ERROR_CODE_TO_EXCEPTION_ = {
575
613
  MT5.RES_E_INTERNAL_FAIL_INIT: InternalFailInit,
576
614
  MT5.RES_E_INTERNAL_FAIL_CONNECT: InternalFailConnect,
577
615
  MT5.RES_E_INTERNAL_FAIL_TIMEOUT: InternalFailTimeout,
616
+ RES_E_FAIL: GenericFail,
617
+ RES_E_INVALID_PARAMS: InvalidParams,
618
+ RES_E_NOT_FOUND: HistoryNotFound,
619
+ RES_E_INVALID_VERSION: InvalidVersion,
620
+ RES_E_AUTH_FAILED: AuthFailed,
621
+ RES_E_UNSUPPORTED: UnsupportedMethod,
622
+ RES_E_AUTO_TRADING_DISABLED: AutoTradingDisabled,
623
+ RES_E_INTERNAL_FAIL_SEND: InternalFailSend,
624
+ RES_E_INTERNAL_FAIL_RECEIVE: InternalFailReceive,
625
+ RES_E_INTERNAL_FAIL_INIT: InternalFailInit,
626
+ RES_E_INTERNAL_FAIL_CONNECT: InternalFailConnect,
627
+ RES_E_INTERNAL_FAIL_TIMEOUT: InternalFailTimeout,
578
628
  }
579
629
 
580
630
 
@@ -588,7 +638,10 @@ def raise_mt5_error(message: Optional[str] = None):
588
638
  MT5TerminalError: A specific exception based on the error code.
589
639
  """
590
640
  error = _ERROR_CODE_TO_EXCEPTION_.get(MT5.last_error()[0])
591
- raise Exception(f"{error(None)} {message or MT5.last_error()[1]}")
641
+ if error is not None:
642
+ raise Exception(f"{error(None)} {message or MT5.last_error()[1]}")
643
+ else:
644
+ raise Exception(f"{message or MT5.last_error()[1]}")
592
645
 
593
646
 
594
647
  _ORDER_FILLING_TYPE_ = "https://www.mql5.com/en/docs/constants/tradingconstants/orderproperties#enum_order_type_filling"
@@ -28,8 +28,10 @@ def _download_and_process_data(source, tickers, start, end, tf, path, **kwargs):
28
28
  end=end,
29
29
  progress=False,
30
30
  multi_level_index=False,
31
+ auto_adjust=True,
31
32
  )
32
- data = data.drop(columns=["Adj Close"], axis=1)
33
+ if "Adj Close" in data.columns:
34
+ data = data.drop(columns=["Adj Close"], axis=1)
33
35
  elif source == "mt5":
34
36
  start, end = pd.Timestamp(start), pd.Timestamp(end)
35
37
  data = download_historical_data(
bbstrader/models/ml.py CHANGED
@@ -250,12 +250,13 @@ class LightGBModel(object):
250
250
  data = pd.concat(data)
251
251
  data = (
252
252
  data.rename(columns={s: s.lower().replace(" ", "_") for s in data.columns})
253
- .drop(columns=["adj_close"])
254
253
  .set_index("symbol", append=True)
255
254
  .swaplevel()
256
255
  .sort_index()
257
256
  .dropna()
258
257
  )
258
+ if "adj_close" in data.columns:
259
+ data = data.drop(columns=["adj_close"])
259
260
  return data
260
261
 
261
262
  def download_metadata(self, tickers):
bbstrader/models/nlp.py CHANGED
@@ -331,8 +331,18 @@ FINANCIAL_LEXICON = {
331
331
 
332
332
  class TopicModeler(object):
333
333
  def __init__(self):
334
- self.nlp = spacy.load("en_core_web_sm")
335
- self.nlp.disable_pipes("ner")
334
+ nltk.download("punkt", quiet=True)
335
+ nltk.download("stopwords", quiet=True)
336
+
337
+ try:
338
+ self.nlp = spacy.load("en_core_web_sm")
339
+ self.nlp.disable_pipes("ner")
340
+ except OSError:
341
+ raise RuntimeError(
342
+ "The SpaCy model 'en_core_web_sm' is not installed.\n"
343
+ "Please install it by running:\n"
344
+ " python -m spacy download en_core_web_sm"
345
+ )
336
346
 
337
347
  def preprocess_texts(self, texts: list[str]):
338
348
  def clean_doc(Doc):
@@ -9,7 +9,7 @@ from loguru import logger as log
9
9
  from bbstrader.btengine.strategy import MT5Strategy, Strategy
10
10
  from bbstrader.config import BBSTRADER_DIR
11
11
  from bbstrader.metatrader.account import Account, check_mt5_connection
12
- from bbstrader.metatrader.trade import Trade, TradeAction
12
+ from bbstrader.metatrader.trade import Trade, TradeAction, TradingMode
13
13
  from bbstrader.trading.utils import send_message
14
14
 
15
15
  try:
@@ -174,6 +174,7 @@ class Mt5ExecutionEngine:
174
174
  strategy_cls: Strategy,
175
175
  /,
176
176
  mm: bool = True,
177
+ auto_trade: bool = True,
177
178
  optimizer: str = "equal",
178
179
  trail: bool = True,
179
180
  stop_trail: Optional[int] = None,
@@ -187,7 +188,7 @@ class Mt5ExecutionEngine:
187
188
  closing_pnl: Optional[float] = None,
188
189
  trading_days: Optional[List[str]] = None,
189
190
  comment: Optional[str] = None,
190
- **kwargs
191
+ **kwargs,
191
192
  ):
192
193
  """
193
194
  Args:
@@ -197,6 +198,9 @@ class Mt5ExecutionEngine:
197
198
  mm : Enable Money Management. Defaults to True.
198
199
  optimizer : Risk management optimizer. Defaults to 'equal'.
199
200
  See `bbstrader.models.optimization` module for more information.
201
+ auto_trade : If set to true, when signal are generated by the strategy class,
202
+ the Execution engine will automaticaly open position in other whise it will prompt
203
+ the user for confimation.
200
204
  show_positions_orders : Print open positions and orders. Defaults to False.
201
205
  iter_time : Interval to check for signals and `mm`. Defaults to 5.
202
206
  use_trade_time : Open trades after the time is completed. Defaults to True.
@@ -239,6 +243,7 @@ class Mt5ExecutionEngine:
239
243
  self.trades_instances = trades_instances
240
244
  self.strategy_cls = strategy_cls
241
245
  self.mm = mm
246
+ self.auto_trade = auto_trade
242
247
  self.optimizer = optimizer
243
248
  self.trail = trail
244
249
  self.stop_trail = stop_trail
@@ -266,7 +271,7 @@ class Mt5ExecutionEngine:
266
271
  def __repr__(self):
267
272
  trades = self.trades_instances.keys()
268
273
  strategy = self.strategy_cls.__name__
269
- return f"Mt5ExecutionEngine(Symbols={list(trades)}, Strategy={strategy})"
274
+ return f"{self.__class__.__name__}(Symbols={list(trades)}, Strategy={strategy})"
270
275
 
271
276
  def _initialize_engine(self, **kwargs):
272
277
  global logger
@@ -300,14 +305,14 @@ class Mt5ExecutionEngine:
300
305
  )
301
306
  return
302
307
 
303
- def _print_exc(self, msg, e: Exception):
308
+ def _print_exc(self, msg: str, e: Exception):
304
309
  if isinstance(e, KeyboardInterrupt):
305
310
  logger.info("Stopping the Execution Engine ...")
306
311
  quit()
307
312
  if self.debug_mode:
308
313
  raise ValueError(msg).with_traceback(e.__traceback__)
309
314
  else:
310
- logger.error(msg)
315
+ logger.error(f"{msg, repr(e)}")
311
316
 
312
317
  def _max_trades(self, mtrades):
313
318
  max_trades = {
@@ -331,7 +336,7 @@ class Mt5ExecutionEngine:
331
336
  try:
332
337
  check_mt5_connection(**kwargs)
333
338
  strategy: MT5Strategy = self.strategy_cls(
334
- symbol_list=self.symbols, mode="live", **kwargs
339
+ symbol_list=self.symbols, mode=TradingMode.LIVE, **kwargs
335
340
  )
336
341
  except Exception as e:
337
342
  self._print_exc(
@@ -357,24 +362,24 @@ class Mt5ExecutionEngine:
357
362
  }
358
363
 
359
364
  info = (
360
- "SIGNAL = {signal}, SYMBOL={symbol}, STRATEGY={strategy}, "
365
+ "SIGNAL={signal}, SYMBOL={symbol}, STRATEGY={strategy}, "
361
366
  "TIMEFRAME={timeframe}, ACCOUNT={account}"
362
367
  ).format(**common_data)
363
368
 
364
369
  sigmsg = (
365
- "SIGNAL = {signal},\n"
366
- "SYMBOL = {symbol},\n"
367
- "TYPE = {symbol_type},\n"
368
- "DESCRIPTION = {description},\n"
369
- "PRICE = {price},\n"
370
- "STOPLIMIT = {stoplimit},\n"
371
- "STRATEGY = {strategy},\n"
372
- "TIMEFRAME = {timeframe},\n"
373
- "BROKER = {broker},\n"
374
- "TIMESTAMP = {timestamp}"
370
+ "SIGNAL={signal},\n"
371
+ "SYMBOL={symbol},\n"
372
+ "TYPE={symbol_type},\n"
373
+ "DESCRIPTION={description},\n"
374
+ "PRICE={price},\n"
375
+ "STOPLIMIT={stoplimit},\n"
376
+ "STRATEGY={strategy},\n"
377
+ "TIMEFRAME={timeframe},\n"
378
+ "BROKER={broker},\n"
379
+ "TIMESTAMP={timestamp}"
375
380
  ).format(
376
381
  **common_data,
377
- symbol_type=account.get_symbol_type(symbol),
382
+ symbol_type=account.get_symbol_type(symbol).value,
378
383
  description=symbol_info.description,
379
384
  price=price,
380
385
  stoplimit=stoplimit,
@@ -382,7 +387,7 @@ class Mt5ExecutionEngine:
382
387
  timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
383
388
  )
384
389
 
385
- msg_template = "SYMBOL = {symbol}, STRATEGY = {strategy}, ACCOUNT = {account}"
390
+ msg_template = "SYMBOL={symbol}, STRATEGY={strategy}, ACCOUNT={account}"
386
391
  msg = f"Sending {signal} Order ... " + msg_template.format(**common_data)
387
392
  tfmsg = "Time Frame Not completed !!! " + msg_template.format(**common_data)
388
393
  riskmsg = "Risk not allowed !!! " + msg_template.format(**common_data)
@@ -621,6 +626,12 @@ class Mt5ExecutionEngine:
621
626
  ):
622
627
  if self.notify:
623
628
  self._send_notification(sigmsg, symbol)
629
+ if not self.auto_trade:
630
+ auto_trade = input(
631
+ f"{sigmsg} \n\n Please enter Y/Yes to accept this order or N/No to reject it :"
632
+ )
633
+ if not auto_trade.upper().startswith("Y"):
634
+ return
624
635
  if not self._check_retcode(trade, "BMKT"):
625
636
  logger.info(msg)
626
637
  trade.open_buy_position(
@@ -638,6 +649,12 @@ class Mt5ExecutionEngine:
638
649
  ):
639
650
  if self.notify:
640
651
  self._send_notification(sigmsg, symbol)
652
+ if not self.auto_trade:
653
+ auto_trade = input(
654
+ f"{sigmsg} \n\n Please enter Y/Yes to accept this order or N/No to reject it :"
655
+ )
656
+ if not auto_trade.upper().startswith("Y"):
657
+ return
641
658
  if not self._check_retcode(trade, "SMKT"):
642
659
  logger.info(msg)
643
660
  trade.open_sell_position(
@@ -31,6 +31,7 @@ from bbstrader.btengine.execution import MT5ExecutionHandler, SimExecutionHandle
31
31
  from bbstrader.btengine.strategy import Strategy
32
32
  from bbstrader.metatrader.account import Account
33
33
  from bbstrader.metatrader.rates import Rates
34
+ from bbstrader.metatrader.trade import TradingMode
34
35
  from bbstrader.models.risk import build_hmm_models
35
36
  from bbstrader.tseries import ArimaGarchModel, KalmanFilterModel
36
37
 
@@ -73,7 +74,7 @@ class SMAStrategy(Strategy):
73
74
  bars: DataHandler = None,
74
75
  events: Queue = None,
75
76
  symbol_list: List[str] = None,
76
- mode: Literal["backtest", "live"] = "backtest",
77
+ mode: TradingMode = TradingMode.BACKTEST,
77
78
  **kwargs,
78
79
  ):
79
80
  """
@@ -81,7 +82,7 @@ class SMAStrategy(Strategy):
81
82
  bars (DataHandler): A data handler object that provides market data.
82
83
  events (Queue): An event queue object where generated signals are placed.
83
84
  symbol_list (List[str]): A list of symbols to consider for trading.
84
- mode (Literal['backtest', 'live']): The mode of operation for the strategy.
85
+ mode TradingMode: The mode of operation for the strategy.
85
86
  short_window (int, optional): The period for the short moving average.
86
87
  long_window (int, optional): The period for the long moving average.
87
88
  time_frame (str, optional): The time frame for the data.
@@ -196,13 +197,13 @@ class SMAStrategy(Strategy):
196
197
  return signals
197
198
 
198
199
  def calculate_signals(self, event=None):
199
- if self.mode == "backtest" and event is not None:
200
+ if self.mode == TradingMode.BACKTEST and event is not None:
200
201
  if event.type == Events.MARKET:
201
202
  signals = self.create_backtest_signals()
202
203
  for signal in signals.values():
203
204
  if signal is not None:
204
205
  self.events.put(signal)
205
- elif self.mode == "live":
206
+ elif self.mode == TradingMode.LIVE:
206
207
  signals = self.create_live_signals()
207
208
  return signals
208
209
 
@@ -238,7 +239,7 @@ class ArimaGarchStrategy(Strategy):
238
239
  bars: DataHandler = None,
239
240
  events: Queue = None,
240
241
  symbol_list: List[str] = None,
241
- mode: Literal["backtest", "live"] = "backtest",
242
+ mode: TradingMode = TradingMode.BACKTEST,
242
243
  **kwargs,
243
244
  ):
244
245
  """
@@ -384,13 +385,13 @@ class ArimaGarchStrategy(Strategy):
384
385
  return signals
385
386
 
386
387
  def calculate_signals(self, event=None):
387
- if self.mode == "backtest" and event is not None:
388
+ if self.mode == TradingMode.BACKTEST and event is not None:
388
389
  if event.type == Events.MARKET:
389
390
  signals = self.create_backtest_signal()
390
391
  for signal in signals.values():
391
392
  if signal is not None:
392
393
  self.events.put(signal)
393
- elif self.mode == "live":
394
+ elif self.mode == TradingMode.LIVE:
394
395
  return self.create_live_signals()
395
396
 
396
397
 
@@ -408,7 +409,7 @@ class KalmanFilterStrategy(Strategy):
408
409
  bars: DataHandler = None,
409
410
  events: Queue = None,
410
411
  symbol_list: List[str] = None,
411
- mode: Literal["backtest", "live"] = "backtest",
412
+ mode: TradingMode = TradingMode.BACKTEST,
412
413
  **kwargs,
413
414
  ):
414
415
  """
@@ -567,10 +568,10 @@ class KalmanFilterStrategy(Strategy):
567
568
  """
568
569
  Calculate the Kalman Filter strategy.
569
570
  """
570
- if self.mode == "backtest" and event is not None:
571
+ if self.mode == TradingMode.BACKTEST and event is not None:
571
572
  if event.type == Events.MARKET:
572
573
  self.calculate_backtest_signals()
573
- elif self.mode == "live":
574
+ elif self.mode == TradingMode.LIVE:
574
575
  return self.calculate_live_signals()
575
576
 
576
577
 
@@ -589,7 +590,7 @@ class StockIndexSTBOTrading(Strategy):
589
590
  bars: DataHandler = None,
590
591
  events: Queue = None,
591
592
  symbol_list: List[str] = None,
592
- mode: Literal["backtest", "live"] = "backtest",
593
+ mode: TradingMode = TradingMode.BACKTEST,
593
594
  **kwargs,
594
595
  ):
595
596
  """
@@ -632,7 +633,7 @@ class StockIndexSTBOTrading(Strategy):
632
633
  self.heightest_price = {index: None for index in symbols}
633
634
  self.lowerst_price = {index: None for index in symbols}
634
635
 
635
- if self.mode == "backtest":
636
+ if self.mode == TradingMode.BACKTEST:
636
637
  self.qty = get_quantities(quantities, symbols)
637
638
  self.num_buys = {index: 0 for index in symbols}
638
639
  self.buy_prices = {index: [] for index in symbols}
@@ -751,10 +752,10 @@ class StockIndexSTBOTrading(Strategy):
751
752
  self.buy_prices[index] = []
752
753
 
753
754
  def calculate_signals(self, event=None) -> Dict[str, Union[str, None]]:
754
- if self.mode == "backtest" and event is not None:
755
+ if self.mode == TradingMode.BACKTEST and event is not None:
755
756
  if event.type == Events.MARKET:
756
757
  self.calculate_backtest_signals()
757
- elif self.mode == "live":
758
+ elif self.mode == TradingMode.LIVE:
758
759
  return self.calculate_live_signals()
759
760
 
760
761
 
bbstrader/tseries.py CHANGED
@@ -514,9 +514,8 @@ def get_corr(tickers: Union[List[str], Tuple[str, ...]], start: str, end: str) -
514
514
  >>> get_corr(['AAPL', 'MSFT', 'GOOG'], '2023-01-01', '2023-12-31')
515
515
  """
516
516
  # Download historical data
517
- data = yf.download(tickers, start=start, end=end, multi_level_index=False)[
518
- "Adj Close"
519
- ]
517
+ data = yf.download(tickers, start=start, end=end, multi_level_index=False, auto_adjust=True)
518
+ data = data["Adj Close"] if "Adj Close" in data.columns else data["Close"]
520
519
 
521
520
  # Calculate correlation matrix
522
521
  correlation_matrix = data.corr()
@@ -685,8 +684,8 @@ def run_cadf_test(
685
684
  auto_adjust=True,
686
685
  )
687
686
  df = pd.DataFrame(index=_p0.index)
688
- df[p0] = _p0["Adj Close"]
689
- df[p1] = _p1["Adj Close"]
687
+ df[p0] = _p0["Close"]
688
+ df[p1] = _p1["Close"]
690
689
  df = df.dropna()
691
690
 
692
691
  # Calculate optimal hedge ratio "beta"
@@ -784,7 +783,7 @@ def run_hurst_test(symbol: str, start: str, end: str):
784
783
  print(f"\nHurst(GBM): {_hurst(gbm)}")
785
784
  print(f"Hurst(MR): {_hurst(mr)}")
786
785
  print(f"Hurst(TR): {_hurst(tr)}")
787
- print(f"Hurst({symbol}): {hurst(data['Adj Close'])}\n")
786
+ print(f"Hurst({symbol}): {hurst(data['Close'])}\n")
788
787
 
789
788
 
790
789
  def test_cointegration(ticker1, ticker2, start, end):
@@ -796,7 +795,7 @@ def test_cointegration(ticker1, ticker2, start, end):
796
795
  progress=False,
797
796
  multi_level_index=False,
798
797
  auto_adjust=True,
799
- )["Adj Close"].dropna()
798
+ )["Close"].dropna()
800
799
 
801
800
  # Perform Johansen cointegration test
802
801
  result = coint_johansen(stock_data_pair, det_order=0, k_ar_diff=1)
@@ -947,8 +946,8 @@ def run_kalman_filter(
947
946
  )
948
947
 
949
948
  prices = pd.DataFrame(index=etf_df1.index)
950
- prices[etfs[0]] = etf_df1["Adj Close"]
951
- prices[etfs[1]] = etf_df2["Adj Close"]
949
+ prices[etfs[0]] = etf_df1["Close"]
950
+ prices[etfs[1]] = etf_df2["Close"]
952
951
 
953
952
  draw_date_coloured_scatterplot(etfs, prices)
954
953
  state_means, state_covs = calc_slope_intercept_kalman(etfs, prices)