bbstrader 0.2.93__py3-none-any.whl → 0.2.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bbstrader might be problematic. Click here for more details.

Files changed (35)
  1. bbstrader/__ini__.py +20 -20
  2. bbstrader/__main__.py +50 -50
  3. bbstrader/btengine/__init__.py +54 -54
  4. bbstrader/btengine/scripts.py +157 -157
  5. bbstrader/compat.py +19 -19
  6. bbstrader/config.py +137 -137
  7. bbstrader/core/data.py +22 -22
  8. bbstrader/core/utils.py +146 -146
  9. bbstrader/metatrader/__init__.py +6 -6
  10. bbstrader/metatrader/account.py +1516 -1516
  11. bbstrader/metatrader/copier.py +750 -745
  12. bbstrader/metatrader/rates.py +584 -584
  13. bbstrader/metatrader/risk.py +749 -748
  14. bbstrader/metatrader/scripts.py +81 -81
  15. bbstrader/metatrader/trade.py +1836 -1836
  16. bbstrader/metatrader/utils.py +645 -645
  17. bbstrader/models/__init__.py +10 -10
  18. bbstrader/models/factors.py +312 -312
  19. bbstrader/models/ml.py +1272 -1272
  20. bbstrader/models/optimization.py +182 -182
  21. bbstrader/models/portfolio.py +223 -223
  22. bbstrader/models/risk.py +398 -398
  23. bbstrader/trading/__init__.py +11 -11
  24. bbstrader/trading/execution.py +846 -846
  25. bbstrader/trading/script.py +155 -155
  26. bbstrader/trading/scripts.py +69 -69
  27. bbstrader/trading/strategies.py +860 -860
  28. bbstrader/tseries.py +1842 -1842
  29. {bbstrader-0.2.93.dist-info → bbstrader-0.2.95.dist-info}/LICENSE +21 -21
  30. {bbstrader-0.2.93.dist-info → bbstrader-0.2.95.dist-info}/METADATA +188 -187
  31. bbstrader-0.2.95.dist-info/RECORD +44 -0
  32. bbstrader-0.2.93.dist-info/RECORD +0 -44
  33. {bbstrader-0.2.93.dist-info → bbstrader-0.2.95.dist-info}/WHEEL +0 -0
  34. {bbstrader-0.2.93.dist-info → bbstrader-0.2.95.dist-info}/entry_points.txt +0 -0
  35. {bbstrader-0.2.93.dist-info → bbstrader-0.2.95.dist-info}/top_level.txt +0 -0
@@ -1,860 +1,860 @@
1
- """
2
- Strategies module for trading strategies backtesting and execution.
3
- """
4
-
5
- from datetime import datetime
6
- from queue import Queue
7
- from typing import Dict, List, Literal, Optional, Union
8
-
9
- import numpy as np
10
- import pandas as pd
11
- import yfinance as yf
12
-
13
- from bbstrader.btengine.backtest import BacktestEngine
14
- from bbstrader.btengine.data import DataHandler, MT5DataHandler, YFDataHandler
15
- from bbstrader.btengine.event import SignalEvent
16
- from bbstrader.btengine.execution import MT5ExecutionHandler, SimExecutionHandler
17
- from bbstrader.btengine.strategy import Strategy
18
- from bbstrader.metatrader.account import Account
19
- from bbstrader.metatrader.rates import Rates
20
- from bbstrader.models.risk import build_hmm_models
21
- from bbstrader.tseries import ArimaGarchModel, KalmanFilterModel
22
-
23
# Public API of this module: the bundled strategies plus helper functions.
__all__ = [
    "SMAStrategy", "ArimaGarchStrategy", "KalmanFilterStrategy",
    "StockIndexSTBOTrading", "test_strategy", "get_quantities",
]
- ]
31
-
32
-
33
def get_quantities(quantities, symbol_list):
    """
    Normalize a quantity specification into a per-symbol mapping.

    Args:
        quantities: Either a dict mapping each symbol to its quantity (returned
            unchanged), a single real number (int, float or NumPy scalar)
            applied to every symbol in ``symbol_list``, or ``None``.
        symbol_list: Symbols used to build the mapping when ``quantities``
            is a single number.

    Returns:
        dict | None: ``{symbol: quantity}``, or ``None`` when ``quantities``
        is ``None``.

    Raises:
        TypeError: If ``quantities`` is not a dict, a real number or ``None``.
    """
    from numbers import Real

    if quantities is None:
        return None
    if isinstance(quantities, dict):
        return quantities
    # numbers.Real covers int, float and NumPy numeric scalars (np.int64, ...),
    # which a plain ``isinstance(x, int)`` check silently rejected.
    if isinstance(quantities, Real):
        return {symbol: quantities for symbol in symbol_list}
    raise TypeError(
        f"quantities must be a dict, a number or None, got {type(quantities).__name__}"
    )
38
-
39
-
40
class SMAStrategy(Strategy):
    """
    Carries out a basic Moving Average Crossover strategy backtest with a
    short/long simple moving average. Default short/long windows are 50/200
    periods respectively, and a Hidden Markov Model (HMM) is used as the risk
    management system for filtering signals.

    The trading logic is deliberately simple so that the important part, the
    risk management aspect (the HMM regime filter), is easy to follow.

    The long-term trend-following strategy is of the classic moving average
    crossover type. The rules are:
        - At every bar, calculate the short (default 50) and long (default 200)
          period simple moving averages (SMA).
        - In a ``LONG`` regime: go long when the short SMA exceeds the long SMA
          and the strategy is flat; exit when the long SMA exceeds the short
          SMA and the strategy is long.
        - In a ``SHORT`` regime: go short when the long SMA exceeds the short
          SMA and the strategy is flat; exit when the short SMA exceeds the
          long SMA and the strategy is short.
    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            bars (DataHandler): A data handler object that provides market data.
            events (Queue): An event queue object where generated signals are placed.
            symbol_list (List[str]): A list of symbols to consider for trading.
                Falls back to ``bars.symbol_list`` when not given.
            mode (Literal['backtest', 'live']): The mode of operation for the strategy.
            short_window (int, optional): The period for the short moving average (default 50).
            long_window (int, optional): The period for the long moving average (default 200).
            time_frame (str, optional): The time frame for the data (default ``"D1"``).
            session_duration (float, optional): The duration of the trading session (default 23.0).
            risk_window (int, optional): The window size for the risk model
                (defaults to ``long_window``).
            quantities (int, dict | optional): The default quantities of each asset
                to trade (default 100).
        """
        self.bars = bars
        self.events = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.kwargs = kwargs
        self.short_window = kwargs.get("short_window", 50)
        self.long_window = kwargs.get("long_window", 200)
        self.tf = kwargs.get("time_frame", "D1")
        self.qty = get_quantities(kwargs.get("quantities", 100), self.symbol_list)
        self.sd = kwargs.get("session_duration", 23.0)
        # Every extra keyword argument is forwarded to the HMM builder.
        self.risk_models = build_hmm_models(self.symbol_list, **self.kwargs)
        self.risk_window = kwargs.get("risk_window", self.long_window)
        self.bought = self._calculate_initial_bought()

    def _calculate_initial_bought(self):
        """Return the initial position state: every symbol starts ``"OUT"``."""
        bought = {}
        for s in self.symbol_list:
            bought[s] = "OUT"
        return bought

    def get_backtest_data(self):
        """
        Compute, for each symbol, the SMA pair and HMM regime from the bars
        currently available in the data handler.

        Returns:
            dict: ``{symbol: (short_sma, long_sma, regime, bar_date)}``; the value
            is ``None`` while not enough history has accumulated.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        for s in self.symbol_list:
            bar_date = self.bars.get_latest_bar_datetime(s)
            bars = self.bars.get_latest_bars_values(s, "adj_close", N=self.long_window)
            returns_val = self.bars.get_latest_bars_values(
                s, "returns", N=self.risk_window
            )
            if len(bars) >= self.long_window and len(returns_val) >= self.risk_window:
                regime = self.risk_models[s].which_trade_allowed(returns_val)

                short_sma = np.mean(bars[-self.short_window :])
                long_sma = np.mean(bars[-self.long_window :])

                symbol_data[s] = (short_sma, long_sma, regime, bar_date)
        return symbol_data

    def create_backtest_signals(self):
        """
        Turn the latest SMA/regime data into ``SignalEvent`` objects and
        update the per-symbol position state.

        Returns:
            dict: ``{symbol: SignalEvent | None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        symbol_data = self.get_backtest_data()
        for s, data in symbol_data.items():
            signal = None
            if data is not None:
                price = self.bars.get_latest_bar_value(s, "adj_close")
                short_sma, long_sma, regime, bar_date = data
                dt = bar_date
                if regime == "LONG":
                    # Bullish regime
                    if short_sma < long_sma and self.bought[s] == "LONG":
                        print(f"EXIT: {bar_date}")
                        signal = SignalEvent(1, s, dt, "EXIT", price=price)
                        self.bought[s] = "OUT"

                    elif short_sma > long_sma and self.bought[s] == "OUT":
                        print(f"LONG: {bar_date}")
                        signal = SignalEvent(
                            1, s, dt, "LONG", quantity=self.qty[s], price=price
                        )
                        self.bought[s] = "LONG"

                elif regime == "SHORT":
                    # Bearish regime
                    if short_sma > long_sma and self.bought[s] == "SHORT":
                        print(f"EXIT: {bar_date}")
                        signal = SignalEvent(1, s, dt, "EXIT", price=price)
                        self.bought[s] = "OUT"

                    elif short_sma < long_sma and self.bought[s] == "OUT":
                        print(f"SHORT: {bar_date}")
                        signal = SignalEvent(
                            1, s, dt, "SHORT", quantity=self.qty[s], price=price
                        )
                        self.bought[s] = "SHORT"
            signals[s] = signal
        return signals

    def get_live_data(self):
        """
        Fetch live rates for every symbol and compute the SMA pair and regime.

        Returns:
            dict: ``{symbol: (short_sma, long_sma, regime)}``.

        Raises:
            ValueError: If fewer bars than ``long_window`` prices or
                ``risk_window`` returns were received.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        # The request must cover both the long SMA window and the HMM window
        # (plus two bars of margin); sizing it from ``risk_window`` alone
        # fails whenever ``risk_window < long_window``.
        count = max(self.long_window, self.risk_window) + 2
        for symbol in self.symbol_list:
            sig_rate = Rates(symbol, self.tf, 0, count, **self.kwargs)
            hmm_data = sig_rate.returns.values
            prices = sig_rate.close.values
            # Explicit check instead of ``assert`` (which is stripped under -O).
            if len(prices) < self.long_window or len(hmm_data) < self.risk_window:
                raise ValueError(
                    f"Not enough live data for {symbol}: got {len(prices)} prices "
                    f"and {len(hmm_data)} returns, need {self.long_window} and "
                    f"{self.risk_window}"
                )
            current_regime = self.risk_models[symbol].which_trade_allowed(hmm_data)
            short_sma = np.mean(prices[-self.short_window :])
            long_sma = np.mean(prices[-self.long_window :])
            symbol_data[symbol] = (short_sma, long_sma, current_regime)
        return symbol_data

    def create_live_signals(self):
        """
        Return ``{symbol: "LONG" | "SHORT" | None}`` for live trading.

        A direction is only emitted when the SMA crossover agrees with the
        HMM regime.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        symbol_data = self.get_live_data()
        for symbol, data in symbol_data.items():
            signal = None
            short_sma, long_sma, regime = data
            if regime == "LONG":
                if short_sma > long_sma:
                    signal = "LONG"
            elif regime == "SHORT":
                if short_sma < long_sma:
                    signal = "SHORT"
            signals[symbol] = signal
        return signals

    def calculate_signals(self, event=None):
        """
        In backtest mode, push signals for a ``MARKET`` event onto the event
        queue (returns ``None``). In live mode, return the live signal dict.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                signals = self.create_backtest_signals()
                for signal in signals.values():
                    if signal is not None:
                        self.events.put(signal)
        elif self.mode == "live":
            signals = self.create_live_signals()
            return signals
194
-
195
-
196
class ArimaGarchStrategy(Strategy):
    """
    The `ArimaGarchStrategy` class extends the `Strategy`
    class to implement a backtesting framework for trading strategies based on
    ARIMA-GARCH models, incorporating a Hidden Markov Model (HMM) for risk management.

    Features
    ========
    - **ARIMA-GARCH Model**: Utilizes ARIMA for time series forecasting and GARCH for volatility forecasting, aimed at predicting market movements.

    - **HMM Risk Management**: Employs a Hidden Markov Model to manage risks, determining safe trading regimes.

    - **Event-Driven Backtesting**: Capable of simulating real-time trading conditions by processing market data and signals sequentially.

    - **Live Trading**: Supports real-time trading by generating signals based on live ARIMA-GARCH predictions and HMM risk management.

    Key Methods
    ===========
    - `get_backtest_data()`: Retrieves historical data for backtesting.
    - `create_backtest_signal()`: Generates trading signals based on ARIMA-GARCH predictions and HMM risk management.
    - `get_live_data()`: Retrieves live data for real-time trading.
    - `create_live_signals()`: Generates trading signals based on live ARIMA-GARCH predictions and HMM risk management.
    - `calculate_signals()`: Determines the trading signals based on the mode of operation (backtest or live).

    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            `bars`: A data handler object that provides market data.
            `events`: An event queue object where generated signals are placed.
            `symbol_list`: A list of symbols to consider for trading.
            `mode`: The mode of operation for the strategy.
            `arima_window`: The window size for rolling prediction in backtesting (default 252).
            `time_frame`: The time frame for the data (default ``"D1"``).
            `quantities`: Quantity of each asset to trade (int or dict, default 100).
            `hmm_window`: Lookback period for the HMM (default 50).
            `yf_start`: Start date used for the Yahoo Finance fallback download.
        """
        self.bars = bars
        self.events = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.qty = get_quantities(kwargs.get("quantities", 100), self.symbol_list)
        self.arima_window = kwargs.get("arima_window", 252)
        self.tf = kwargs.get("time_frame", "D1")
        self.sd = kwargs.get("session_duration", 23.0)
        self.risk_window = kwargs.get("hmm_window", 50)
        self.risk_models = build_hmm_models(self.symbol_list, **kwargs)
        self.arima_models = self._build_arch_models(**kwargs)

        self.long_market = {s: False for s in self.symbol_list}
        self.short_market = {s: False for s in self.symbol_list}

    def _build_arch_models(self, **kwargs) -> Dict[str, ArimaGarchModel]:
        """
        Build one ``ArimaGarchModel`` per symbol.

        Broker rates are tried first; when they come back empty, the data is
        downloaded from Yahoo Finance starting at ``kwargs["yf_start"]``.
        """
        arch_models = {symbol: None for symbol in self.symbol_list}
        for symbol in self.symbol_list:
            data = Rates(symbol, self.tf, 0).get_rates_from_pos()
            # Explicit check rather than ``assert`` + ``except AssertionError``:
            # asserts are stripped under ``python -O``, which would let ``None``
            # reach the model instead of triggering the fallback.
            if data is None or len(data) == 0:
                data = yf.download(symbol, start=kwargs.get("yf_start"))
            arch_models[symbol] = ArimaGarchModel(symbol, data, k=self.arima_window)
        return arch_models

    def get_backtest_data(self):
        """
        Collect the ARIMA input window and HMM returns window for each symbol.

        Returns:
            dict: ``{symbol: (arima_window_data, hmm_returns, datetime)}``; the
            value is ``None`` until enough history is available.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        for symbol in self.symbol_list:
            M = self.arima_window
            N = self.risk_window
            dt = self.bars.get_latest_bar_datetime(symbol)
            bars = self.bars.get_latest_bars_values(
                symbol, "close", N=self.arima_window
            )
            returns = self.bars.get_latest_bars_values(
                symbol, "returns", N=self.risk_window
            )
            df = pd.DataFrame()
            df["Close"] = bars[-M:]
            df = df.dropna()
            arch_returns = self.arima_models[symbol].load_and_prepare_data(df)
            data = arch_returns["diff_log_return"].iloc[-self.arima_window :]
            if len(data) >= M and len(returns) >= N:
                symbol_data[symbol] = (data, returns[-N:], dt)
        return symbol_data

    def create_backtest_signal(self):
        """
        Generate at most one ``SignalEvent`` per symbol for the current bar.

        Exits take priority: when an open position is closed on this bar, no
        new entry is issued for that symbol until the next bar, so the EXIT
        signal is never silently replaced by an entry.

        Returns:
            dict: ``{symbol: SignalEvent | None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        # Computed once per bar; previously this was recomputed for every symbol.
        backtest_data = self.get_backtest_data()
        for symbol in self.symbol_list:
            symbol_data = backtest_data[symbol]
            if symbol_data is None:
                continue
            data, returns, dt = symbol_data
            signal = None
            prediction = self.arima_models[symbol].get_prediction(data)
            regime = self.risk_models[symbol].which_trade_allowed(returns)
            price = self.bars.get_latest_bar_value(symbol, "adj_close")

            # If we are short the market, check for an exit
            if prediction > 0 and self.short_market[symbol]:
                signal = SignalEvent(1, symbol, dt, "EXIT", price=price)
                print(f"{dt}: EXIT SHORT")
                self.short_market[symbol] = False

            # If we are long the market, check for an exit
            elif prediction < 0 and self.long_market[symbol]:
                signal = SignalEvent(1, symbol, dt, "EXIT", price=price)
                print(f"{dt}: EXIT LONG")
                self.long_market[symbol] = False

            # Entries are only considered when no exit was produced on this bar;
            # otherwise the entry would overwrite the exit while the position
            # flags already claim the old position is closed.
            if signal is None:
                if regime == "LONG":
                    # If we are not in the market, go long
                    if prediction > 0 and not self.long_market[symbol]:
                        signal = SignalEvent(
                            1,
                            symbol,
                            dt,
                            "LONG",
                            quantity=self.qty[symbol],
                            price=price,
                        )
                        print(f"{dt}: LONG")
                        self.long_market[symbol] = True

                elif regime == "SHORT":
                    # If we are not in the market, go short
                    if prediction < 0 and not self.short_market[symbol]:
                        signal = SignalEvent(
                            1,
                            symbol,
                            dt,
                            "SHORT",
                            quantity=self.qty[symbol],
                            price=price,
                        )
                        print(f"{dt}: SHORT")
                        self.short_market[symbol] = True
            signals[symbol] = signal
        return signals

    def get_live_data(self):
        """
        Fetch live rates and build the ARIMA window and HMM returns per symbol.

        Returns:
            dict: ``{symbol: (arima_window_data, hmm_returns)}``.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        # Request enough bars for both the ARIMA window and the HMM window.
        count = max(self.arima_window, self.risk_window)
        for symbol in self.symbol_list:
            arch_data = Rates(symbol, self.tf, 0, count)
            rates = arch_data.get_rates_from_pos()
            arch_returns = self.arima_models[symbol].load_and_prepare_data(rates)
            window_data = arch_returns["diff_log_return"].iloc[-self.arima_window :]
            hmm_returns = arch_data.returns.values[-self.risk_window :]
            symbol_data[symbol] = (window_data, hmm_returns)
        return symbol_data

    def create_live_signals(self):
        """
        Return ``{symbol: "LONG" | "SHORT" | None}``; a direction is emitted
        only when the ARIMA-GARCH prediction agrees with the HMM regime.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        data = self.get_live_data()
        for symbol in self.symbol_list:
            symbol_data = data[symbol]
            if symbol_data is not None:
                window_data, hmm_returns = symbol_data
                prediction = self.arima_models[symbol].get_prediction(window_data)
                regime = self.risk_models[symbol].which_trade_allowed(hmm_returns)
                if regime == "LONG":
                    if prediction > 0:
                        signals[symbol] = "LONG"
                elif regime == "SHORT":
                    if prediction < 0:
                        signals[symbol] = "SHORT"
        return signals

    def calculate_signals(self, event=None):
        """
        In backtest mode, push signals for a ``MARKET`` event onto the event
        queue (returns ``None``). In live mode, return the live signal dict.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                signals = self.create_backtest_signal()
                for signal in signals.values():
                    if signal is not None:
                        self.events.put(signal)
        elif self.mode == "live":
            return self.create_live_signals()
381
-
382
-
383
class KalmanFilterStrategy(Strategy):
    """
    The `KalmanFilterStrategy` class implements a backtesting framework for a
    [pairs trading](https://en.wikipedia.org/wiki/Pairs_trade) strategy using
    Kalman Filter for signals and Hidden Markov Models (HMM) for risk management.
    This document outlines the structure and usage of the `KalmanFilterStrategy`,
    including initialization parameters and main functions.

    The pair is ``symbol_list = [x, y]``: a ``LONG`` pair trade is long ``y``
    and short ``x``; a ``SHORT`` pair trade is short ``y`` and long ``x``.
    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            `bars`: `DataHandler` for market data handling.
            `events`: A queue for managing events.
            `symbol_list`: List of the 2 ticker symbols for the pairs trading strategy.
            `mode`: Mode of operation for the strategy.
            kwargs : Additional keyword arguments including
                - `quantity`: Quantity of assets to trade. Default is 100.
                - `hmm_window`: Window size for calculating returns for the HMM. Default is 50.
                - `hmm_tiker`: Ticker symbol used by the HMM for risk management.
                - `time_frame`: Time frame for the data. Default is 'D1'.
                - `session_duration`: Duration of the trading session. Default is 6.5.

        Raises:
            ValueError: If ``symbol_list`` does not hold exactly 2 tickers or
                ``hmm_tiker`` is missing.
        """
        self.bars = bars
        self.events_queue = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.hmm_tiker = kwargs.get("hmm_tiker")
        self._assert_tikers()
        self.account = Account(**kwargs)
        self.hmm_window = kwargs.get("hmm_window", 50)
        self.qty = kwargs.get("quantity", 100)
        self.tf = kwargs.get("time_frame", "D1")
        self.sd = kwargs.get("session_duration", 6.5)

        self.risk_model = build_hmm_models(self.symbol_list, **kwargs)
        self.kl_model = KalmanFilterModel(self.tickers, **kwargs)

        self.long_market = False
        self.short_market = False

    def _assert_tikers(self):
        """Validate the pair and HMM ticker; also sets ``self.tickers``."""
        if self.symbol_list is None or len(self.symbol_list) != 2:
            raise ValueError("A list of 2 Tickers must be provide for this strategy")
        self.tickers = self.symbol_list
        if self.hmm_tiker is None:
            raise ValueError(
                "You need to provide a ticker used by the HMM for risk management"
            )

    def calculate_btxy(self, etqt, regime, dt):
        """
        Emit backtest exit/entry events for both legs of the pair.

        Args:
            etqt: ``(et, sqrt_Qt)`` from the Kalman filter, or ``None`` when the
                filter has no estimate yet (nothing is done in that case).
            regime: HMM regime, ``"LONG"`` or ``"SHORT"``.
            dt: Timestamp of the current bar.
        """
        if etqt is None:
            return
        et, sqrt_Qt = etqt
        theta = self.kl_model.theta
        p1 = self.bars.get_latest_bar_value(self.tickers[1], "adj_close")
        p0 = self.bars.get_latest_bar_value(self.tickers[0], "adj_close")
        if et >= -sqrt_Qt and self.long_market:
            print("CLOSING LONG: %s" % dt)
            y_signal = SignalEvent(1, self.tickers[1], dt, "EXIT", price=p1)
            x_signal = SignalEvent(1, self.tickers[0], dt, "EXIT", price=p0)
            self.events_queue.put(y_signal)
            self.events_queue.put(x_signal)
            self.long_market = False

        elif et <= sqrt_Qt and self.short_market:
            print("CLOSING SHORT: %s" % dt)
            y_signal = SignalEvent(1, self.tickers[1], dt, "EXIT", price=p1)
            x_signal = SignalEvent(1, self.tickers[0], dt, "EXIT", price=p0)
            self.events_queue.put(y_signal)
            self.events_queue.put(x_signal)
            self.short_market = False

        # Long Entry
        if regime == "LONG":
            if et <= -sqrt_Qt and not self.long_market:
                print("LONG: %s" % dt)
                y_signal = SignalEvent(
                    1, self.tickers[1], dt, "LONG", self.qty, 1.0, price=p1
                )
                x_signal = SignalEvent(
                    1, self.tickers[0], dt, "SHORT", self.qty, theta[0], price=p0
                )
                self.events_queue.put(y_signal)
                self.events_queue.put(x_signal)
                self.long_market = True

        # Short Entry
        elif regime == "SHORT":
            if et >= sqrt_Qt and not self.short_market:
                print("SHORT: %s" % dt)
                y_signal = SignalEvent(
                    1, self.tickers[1], dt, "SHORT", self.qty, 1.0, price=p1
                )
                # ``dt`` was previously missing here, shifting every positional
                # argument (signal_type became the quantity, and so on).
                x_signal = SignalEvent(
                    1, self.tickers[0], dt, "LONG", self.qty, theta[0], price=p0
                )
                self.events_queue.put(y_signal)
                self.events_queue.put(x_signal)
                self.short_market = True

    def calculate_livexy(self):
        """
        Compute raw (un-filtered) live pair signals from current ask prices.

        Returns:
            dict: ``{ticker: "LONG" | "SHORT" | "EXIT" | None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        p0_price = self.account.get_tick_info(self.tickers[0]).ask
        p1_price = self.account.get_tick_info(self.tickers[1]).ask
        prices = np.array([p0_price, p1_price])
        et_std = self.kl_model.calculate_etqt(prices)
        if et_std is not None:
            et, std = et_std
            y_signal = None
            x_signal = None

            if et >= std:
                # Spread above the upper band: short y, long x.
                y_signal = "SHORT"
                x_signal = "LONG"
            elif et <= -std:
                # Spread below the lower band: long y, short x.
                y_signal = "LONG"
                x_signal = "SHORT"
            elif -std < et < std:
                # Inside the band: close any open pair.
                y_signal = "EXIT"
                x_signal = "EXIT"

            signals[self.tickers[0]] = x_signal
            signals[self.tickers[1]] = y_signal
        return signals

    def calculate_backtest_signals(self):
        """Update the Kalman filter with the latest closes and emit backtest events."""
        p0, p1 = self.tickers[0], self.tickers[1]
        dt = self.bars.get_latest_bar_datetime(p0)
        x = self.bars.get_latest_bar_value(p0, "close")
        y = self.bars.get_latest_bar_value(p1, "close")
        returns = self.bars.get_latest_bars_values(
            self.hmm_tiker, "returns", N=self.hmm_window
        )
        latest_prices = np.array([-1.0, -1.0])
        if len(returns) >= self.hmm_window:
            latest_prices[0] = x
            latest_prices[1] = y
            et_qt = self.kl_model.calculate_etqt(latest_prices)
            regime = self.risk_model[self.hmm_tiker].which_trade_allowed(returns)
            self.calculate_btxy(et_qt, regime, dt)

    def calculate_live_signals(self):
        """
        Return regime-filtered live signals for both legs of the pair.

        The pair direction is the y-leg (``tickers[1]``) signal. When it
        matches the HMM regime, both legs are released together (e.g. regime
        ``LONG``: y ``LONG`` and x ``SHORT``), mirroring the backtest entries.
        Otherwise both legs are ``None``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        initial_signals = self.calculate_livexy()
        # Attribute is ``hmm_tiker``; ``self.hmm_ticker`` raised AttributeError.
        hmm_data = Rates(self.hmm_tiker, self.tf, 0, self.hmm_window)
        returns = hmm_data.returns.values
        current_regime = self.risk_model[self.hmm_tiker].which_trade_allowed(returns)
        pair_direction = initial_signals.get(self.tickers[1])
        if pair_direction in ("LONG", "SHORT") and pair_direction == current_regime:
            # Filtering each leg separately kept only one leg (the x leg always
            # has the opposite direction), leaving the pair half-hedged.
            for symbol in self.tickers:
                signals[symbol] = initial_signals.get(symbol)
        return signals

    def calculate_signals(self, event=None):
        """
        Calculate the Kalman Filter strategy.

        In backtest mode events are put on the queue for ``MARKET`` events
        (returns ``None``); in live mode the signal dict is returned.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                self.calculate_backtest_signals()
        elif self.mode == "live":
            return self.calculate_live_signals()
561
-
562
-
563
class StockIndexSTBOTrading(Strategy):
    """
    The StockIndexSTBOTrading class implements a stock index Contract for Difference (CFD)
    Buy-Only trading strategy. This strategy is based on the assumption that stock markets
    typically follow a long-term uptrend. The strategy is designed to capitalize on market
    corrections and price dips, where stocks or indices temporarily drop but are expected
    to recover. It operates in two modes: backtest and live, and it is particularly
    tailored to index trading.
    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            `bars`: `DataHandler` for market data handling.
            `events`: A queue for managing events.
            `symbol_list`: List of index symbols to trade.
            `mode`: Mode of operation for the strategy.
            kwargs : Additional keyword arguments including
                - rr (float, default: 3.0): The risk-reward ratio used to determine entry points
                  (a buy triggers after a drop of ``expected_return / rr`` percent).
                - epsilon (float, default: 0.1): The percentage threshold for price changes when considering new highs or lows.
                - expected_returns (dict): Expected return percentages for each symbol in the symbol list.
                - quantities (int | dict, default: 100): The number of units to trade.
                - max_trades (dict): The maximum number of buys allowed per symbol before an exit.
                - logger: Optional logger object for tracking operations in live mode.
                - expert_id (int, default: 5134): Unique identifier for trade positions created by this strategy.
        """
        self.bars = bars
        self.events = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.account = Account()

        self.rr = kwargs.get("rr", 3.0)
        self.epsilon = kwargs.get("epsilon", 0.1)
        self._initialize(**kwargs)
        self.logger = kwargs.get("logger")
        self.ID = kwargs.get("expert_id", 5134)

    def _initialize(self, **kwargs):
        """Set up per-symbol targets, limits, price trackers and backtest counters."""
        symbols = self.symbol_list.copy()
        returns = kwargs.get("expected_returns")
        quantities = kwargs.get("quantities", 100)
        max_trades = kwargs.get("max_trades")

        self.expeted_return = {index: returns[index] for index in symbols}
        self.max_trades = {index: max_trades[index] for index in symbols}
        self.last_price = {index: None for index in symbols}
        self.heightest_price = {index: None for index in symbols}
        self.lowerst_price = {index: None for index in symbols}

        if self.mode == "backtest":
            self.qty = get_quantities(quantities, symbols)
            self.num_buys = {index: 0 for index in symbols}
            self.buy_prices = {index: [] for index in symbols}

    def _calculate_pct_change(self, current_price, lh_price):
        """Percentage change from ``lh_price`` to ``current_price``."""
        return ((current_price - lh_price) / lh_price) * 100

    def _seed_prices(self, index, current_price):
        """Initialise all tracked reference prices of ``index`` on its first observation."""
        self.last_price[index] = current_price
        self.heightest_price[index] = current_price
        self.lowerst_price[index] = current_price

    def _update_extremes(self, index, current_price):
        """Move the tracked high (or low) once price moves at least ``epsilon`` percent past it."""
        if (
            self._calculate_pct_change(current_price, self.heightest_price[index])
            >= self.epsilon
        ):
            self.heightest_price[index] = current_price
        elif (
            self._calculate_pct_change(current_price, self.lowerst_price[index])
            <= -self.epsilon
        ):
            self.lowerst_price[index] = current_price

    def calculate_live_signals(self):
        """
        Return ``{symbol: "LONG" | "EXIT" | None}`` from the current ask prices
        and this expert's open buy positions.
        """
        signals = {index: None for index in self.symbol_list}
        for index in self.symbol_list:
            current_price = self.account.get_tick_info(index).ask
            if self.last_price[index] is None:
                self._seed_prices(index, current_price)
                continue

            self._update_extremes(index, current_price)
            down_change = self._calculate_pct_change(
                current_price, self.heightest_price[index]
            )

            if down_change <= -(self.expeted_return[index] / self.rr):
                signals[index] = "LONG"

            positions = self.account.get_positions(symbol=index)
            if positions is not None:
                buy_prices = [
                    position.price_open
                    for position in positions
                    if position.type == 0 and position.magic == self.ID
                ]
                if len(buy_prices) == 0:
                    continue
                avg_price = sum(buy_prices) / len(buy_prices)
                if (
                    self._calculate_pct_change(current_price, avg_price)
                    >= (self.expeted_return[index])
                ):
                    signals[index] = "EXIT"
            # ``logger`` is optional (defaults to None); guard to avoid AttributeError.
            if self.logger is not None:
                self.logger.info(
                    f"SYMBOL={index} - Hp={self.heightest_price[index]} - "
                    f"Lp={self.lowerst_price[index]} - Cp={current_price} - %chg={round(down_change, 2)}"
                )
        return signals

    def calculate_backtest_signals(self):
        """Push LONG/EXIT ``SignalEvent`` objects for the latest bar onto the event queue."""
        for index in self.symbol_list.copy():
            dt = self.bars.get_latest_bar_datetime(index)
            last_price = self.bars.get_latest_bars_values(index, "close", N=1)

            current_price = last_price[-1]
            if self.last_price[index] is None:
                self._seed_prices(index, current_price)
                continue

            self._update_extremes(index, current_price)
            down_change = self._calculate_pct_change(
                current_price, self.heightest_price[index]
            )

            # ``<`` (not ``<=``) so that at most ``max_trades`` buys are opened.
            if (
                down_change <= -(self.expeted_return[index] / self.rr)
                and self.num_buys[index] < self.max_trades[index]
            ):
                signal = SignalEvent(
                    100,
                    index,
                    dt,
                    "LONG",
                    quantity=self.qty[index],
                    price=current_price,
                )
                self.events.put(signal)
                self.num_buys[index] += 1
                self.buy_prices[index].append(current_price)

            elif self.num_buys[index] > 0:
                av_price = sum(self.buy_prices[index]) / len(self.buy_prices[index])
                qty = self.qty[index] * self.num_buys[index]
                if (
                    self._calculate_pct_change(current_price, av_price)
                    >= (self.expeted_return[index])
                ):
                    signal = SignalEvent(
                        100, index, dt, "EXIT", quantity=qty, price=current_price
                    )
                    self.events.put(signal)
                    self.num_buys[index] = 0
                    self.buy_prices[index] = []

    def calculate_signals(self, event=None) -> Dict[str, Union[str, None]]:
        """
        In backtest mode, push signals for a ``MARKET`` event onto the event
        queue (returns ``None``). In live mode, return the live signal dict.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                self.calculate_backtest_signals()
        elif self.mode == "live":
            return self.calculate_live_signals()
745
-
746
-
747
def _run_backtest(strategy_name: str, capital: float, symbol_list: list, kwargs: dict):
    """
    Build a ``BacktestEngine`` for the given strategy and run the simulation.

    Args:
        strategy_name: Label stored as ``kwargs["strategy_name"]``.
        capital: Initial capital.
        symbol_list: Symbols to backtest.
        kwargs: Settings. It must contain ``"yf_start"`` (``YYYY-MM-DD``) and
            ``"backtester_class"``. It may contain ``"data_handler"`` (default
            ``YFDataHandler``) and ``"exc_handler"`` (default
            ``SimExecutionHandler``). All keys except ``backtester_class`` are
            forwarded to the engine.

    Raises:
        KeyError: If ``"yf_start"`` or ``"backtester_class"`` is missing.
    """
    # Work on a shallow copy so the caller's dict is not mutated by the
    # assignment and ``pop`` below.
    kwargs = dict(kwargs)
    kwargs["strategy_name"] = strategy_name
    engine = BacktestEngine(
        symbol_list,
        capital,
        0.0,
        datetime.strptime(kwargs["yf_start"], "%Y-%m-%d"),
        kwargs.get("data_handler", YFDataHandler),
        kwargs.get("exc_handler", SimExecutionHandler),
        kwargs.pop("backtester_class"),
        **kwargs,
    )
    engine.simulate_trading()
764
-
765
-
766
def _run_arch_backtest(capital: float = 100000.0, quantity: int = 1000):
    """
    Run the ARIMA+GARCH & HMM strategy on ``^GSPC`` with Yahoo Finance data.

    HMM data covers 1990-01-01 to 2009-12-31; the backtest starts 2010-01-04.
    ``quantity`` is the per-trade size.
    """
    hmm_data = yf.download("^GSPC", start="1990-01-01", end="2009-12-31")
    kwargs = {
        # ArimaGarchStrategy reads its trade size from the "quantities" key.
        "quantities": quantity,
        "yf_start": "2010-01-04",
        "hmm_data": hmm_data,
        "backtester_class": ArimaGarchStrategy,
        "data_handler": YFDataHandler,
    }
    _run_backtest("ARIMA+GARCH & HMM", capital, ["^GSPC"], kwargs)
776
-
777
-
778
def _run_kf_backtest(capital: float = 100000.0, quantity: int = 2000):
    """
    Run the Kalman Filter & HMM pairs strategy on IEI/TLT with Yahoo Finance data.

    HMM data is downloaded up to 2008-07-09; the backtest starts 2009-08-03
    and TLT drives the HMM regime.
    """
    pair = ["IEI", "TLT"]
    hmm_end = "2008-07-09"
    tlt_history = yf.download("TLT", end=hmm_end)
    iei_history = yf.download("IEI", end=hmm_end)
    settings = dict(
        quantity=quantity,
        yf_start="2009-08-03",
        hmm_data={"IEI": iei_history, "TLT": tlt_history},
        hmm_tiker="TLT",
        session_duration=6.5,
        backtester_class=KalmanFilterStrategy,
        data_handler=YFDataHandler,
    )
    _run_backtest("Kalman Filter & HMM", capital, pair, settings)
792
-
793
-
794
def _run_sma_backtest(capital: float = 100000.0, quantity: int = 1):
    """
    Backtest ``SMAStrategy`` on the ``[SP500]`` MetaTrader 5 symbol.

    Daily ``^GSPC`` data for 1990-01-01..2009-12-31 from Yahoo Finance is
    supplied as ``hmm_data``; MT5 data is requested for 2010-01-01..2023-01-01
    and orders go through ``MT5ExecutionHandler``.

    Args:
        capital: Initial capital of the simulated account.
        quantity: Units traded per signal.
    """
    spx_history = yf.download("^GSPC", start="1990-01-01", end="2009-12-31")
    settings = dict(
        quantities=quantity,
        hmm_end="2009-12-31",
        yf_start="2010-01-04",
        hmm_data=spx_history,
        mt5_start=datetime(2010, 1, 1),
        mt5_end=datetime(2023, 1, 1),
        backtester_class=SMAStrategy,
        data_handler=MT5DataHandler,
        exc_handler=MT5ExecutionHandler,
    )
    _run_backtest("SMA & HMM", capital, ["[SP500]"], settings)
808
-
809
-
810
def _run_sistbo_backtest(capital: float = 100000.0, quantity: int = None):
    """
    Backtest ``StockIndexSTBOTrading`` on four MT5 index CFDs on 15-minute bars.

    Note: ``quantity`` is accepted for a uniform call signature but is not
    used; fixed per-index quantities are applied instead.

    Args:
        capital: Initial capital of the simulated account.
        quantity: Ignored.
    """
    ndx, spx, dji, dax = "[NQ100]", "[SP500]", "[DJI30]", "GERMANY40"
    start = datetime(2010, 6, 1, 2, 0, 0)
    settings = {
        "expected_returns": {ndx: 1.5, spx: 1.5, dji: 1.0, dax: 1.0},
        "quantities": {ndx: 15, spx: 30, dji: 5, dax: 10},
        "max_trades": dict.fromkeys((ndx, spx, dji, dax), 3),
        "mt5_start": start,
        "yf_start": start.strftime("%Y-%m-%d"),
        "time_frame": "15m",
        "backtester_class": StockIndexSTBOTrading,
        "data_handler": MT5DataHandler,
        "exc_handler": MT5ExecutionHandler,
    }
    _run_backtest(
        "Stock Index Short Term Buy Only ", capital, [spx, dax, dji, ndx], settings
    )
831
-
832
-
833
# Registry mapping the short names accepted by ``test_strategy`` to the
# function that runs the corresponding demo backtest.
_BACKTESTS = dict(
    sma=_run_sma_backtest,
    klf=_run_kf_backtest,
    arch=_run_arch_backtest,
    sistbo=_run_sistbo_backtest,
)
839
-
840
-
841
def test_strategy(
    strategy: Literal["sma", "klf", "arch", "sistbo"] = "sma",
    quantity: Optional[int] = 100,
):
    """
    Executes a backtest of the specified strategy.

    Args:
        strategy : The strategy to use in test mode. Default is `sma`.
            - `sma` Execute `SMAStrategy`, for more detail see this class documentation.
            - `klf` Execute `KalmanFilterStrategy`, for more detail see this class documentation.
            - `arch` Execute `ArimaGarchStrategy`, for more detail see this class documentation.
            - `sistbo` Execute `StockIndexSTBOTrading`, for more detail see this class documentation.
        quantity : The quantity of assets to be used in the test backtest.
            Default is 100. Ignored by `sistbo`, which uses fixed
            per-index quantities.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """
    # Single lookup (EAFP); only the registry lookup is guarded so errors
    # raised by the backtest itself are not mistaken for an unknown name.
    try:
        run_backtest = _BACKTESTS[strategy]
    except KeyError:
        raise ValueError(f"Unknown strategy: {strategy}") from None
    run_backtest(quantity=quantity)
1
+ """
2
+ Strategies module for trading strategies backtesting and execution.
3
+ """
4
+
5
+ from datetime import datetime
6
+ from queue import Queue
7
+ from typing import Dict, List, Literal, Optional, Union
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import yfinance as yf
12
+
13
+ from bbstrader.btengine.backtest import BacktestEngine
14
+ from bbstrader.btengine.data import DataHandler, MT5DataHandler, YFDataHandler
15
+ from bbstrader.btengine.event import SignalEvent
16
+ from bbstrader.btengine.execution import MT5ExecutionHandler, SimExecutionHandler
17
+ from bbstrader.btengine.strategy import Strategy
18
+ from bbstrader.metatrader.account import Account
19
+ from bbstrader.metatrader.rates import Rates
20
+ from bbstrader.models.risk import build_hmm_models
21
+ from bbstrader.tseries import ArimaGarchModel, KalmanFilterModel
22
+
23
# Public names exported by ``from <this module> import *``: the four demo
# strategy classes, the backtest entry point and the quantity helper.
__all__ = [
    "SMAStrategy",
    "ArimaGarchStrategy",
    "KalmanFilterStrategy",
    "StockIndexSTBOTrading",
    "test_strategy",
    "get_quantities",
]
31
+
32
+
33
def get_quantities(quantities, symbol_list):
    """
    Normalize a quantity specification into a per-symbol mapping.

    Args:
        quantities: Either a ``dict`` mapping symbol -> quantity (returned
            unchanged), or a single number applied to every symbol.
        symbol_list: Symbols to build the mapping for when ``quantities`` is
            a single number.

    Returns:
        dict: ``{symbol: quantity}``.

    Raises:
        TypeError: If ``quantities`` is neither a dict nor a real number.
            Previously such input silently produced ``None``, which only
            failed later when a strategy indexed ``self.qty[symbol]``.
    """
    import numbers

    if isinstance(quantities, dict):
        return quantities
    # numbers.Real covers int, float and NumPy numeric scalars.
    if isinstance(quantities, numbers.Real):
        return {symbol: quantities for symbol in symbol_list}
    raise TypeError(
        "quantities must be a dict or a number, "
        f"got {type(quantities).__name__}"
    )
38
+
39
+
40
class SMAStrategy(Strategy):
    """
    Carries out a basic moving-average crossover strategy with a short/long
    simple moving average (SMA). The default short/long windows are 50/200
    periods respectively, and a Hidden Markov Model (HMM) is used as the risk
    management system for filtering signals.

    The trading rules are deliberately simple so the strategy is easy to
    understand; the important part is the risk-management layer (the HMM).

    The long-term trend-following strategy is of the classic moving average
    crossover type. The rules are:
        - At every bar, calculate the short-window and long-window SMAs
          (50 and 200 periods by default).
        - In a ``LONG`` regime: if the short SMA exceeds the long SMA and the
          strategy is not invested, go long; if the long SMA exceeds the short
          SMA while long, close the position.
        - In a ``SHORT`` regime: if the long SMA exceeds the short SMA and the
          strategy is not invested, go short; if the short SMA exceeds the long
          SMA while short, close the position.
    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            bars (DataHandler): A data handler object that provides market data.
            events (Queue): An event queue object where generated signals are placed.
            symbol_list (List[str]): A list of symbols to consider for trading.
                Defaults to ``bars.symbol_list``.
            mode (Literal['backtest', 'live']): The mode of operation for the strategy.
            short_window (int, optional): The period for the short moving average. Default 50.
            long_window (int, optional): The period for the long moving average. Default 200.
            time_frame (str, optional): The time frame for the data. Default "D1".
            session_duration (float, optional): The duration of the trading session. Default 23.0.
            risk_window (int, optional): The window size for the risk model.
                Defaults to ``long_window``.
            quantities (int, dict | optional): The default quantities of each asset to trade.
                Default 100.
        """
        self.bars = bars
        self.events = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.kwargs = kwargs
        self.short_window = kwargs.get("short_window", 50)
        self.long_window = kwargs.get("long_window", 200)
        self.tf = kwargs.get("time_frame", "D1")
        self.qty = get_quantities(kwargs.get("quantities", 100), self.symbol_list)
        self.sd = kwargs.get("session_duration", 23.0)
        self.risk_models = build_hmm_models(self.symbol_list, **self.kwargs)
        self.risk_window = kwargs.get("risk_window", self.long_window)
        self.bought = self._calculate_initial_bought()

    def _calculate_initial_bought(self):
        """Return the initial position state: ``"OUT"`` for every symbol."""
        bought = {}
        for s in self.symbol_list:
            bought[s] = "OUT"
        return bought

    def get_backtest_data(self):
        """
        Compute the SMAs and the HMM regime for each symbol from the bars
        available so far in the backtest.

        Returns:
            dict: ``{symbol: (short_sma, long_sma, regime, bar_datetime)}``, or
            ``None`` for a symbol that does not yet have ``long_window`` prices
            and ``risk_window`` returns.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        for s in self.symbol_list:
            bar_date = self.bars.get_latest_bar_datetime(s)
            bars = self.bars.get_latest_bars_values(s, "adj_close", N=self.long_window)
            returns_val = self.bars.get_latest_bars_values(
                s, "returns", N=self.risk_window
            )
            if len(bars) >= self.long_window and len(returns_val) >= self.risk_window:
                regime = self.risk_models[s].which_trade_allowed(returns_val)

                short_sma = np.mean(bars[-self.short_window :])
                long_sma = np.mean(bars[-self.long_window :])

                symbol_data[s] = (short_sma, long_sma, regime, bar_date)
        return symbol_data

    def create_backtest_signals(self):
        """
        Build at most one ``SignalEvent`` per symbol for the latest bar and
        update ``self.bought`` to track the simulated position.

        Returns:
            dict: ``{symbol: SignalEvent or None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        symbol_data = self.get_backtest_data()
        for s, data in symbol_data.items():
            signal = None
            if data is not None:
                price = self.bars.get_latest_bar_value(s, "adj_close")
                short_sma, long_sma, regime, dt = data
                if regime == "LONG":
                    # Bullish regime
                    if short_sma < long_sma and self.bought[s] == "LONG":
                        print(f"EXIT: {dt}")
                        signal = SignalEvent(1, s, dt, "EXIT", price=price)
                        self.bought[s] = "OUT"

                    elif short_sma > long_sma and self.bought[s] == "OUT":
                        print(f"LONG: {dt}")
                        signal = SignalEvent(
                            1, s, dt, "LONG", quantity=self.qty[s], price=price
                        )
                        self.bought[s] = "LONG"

                elif regime == "SHORT":
                    # Bearish regime
                    if short_sma > long_sma and self.bought[s] == "SHORT":
                        print(f"EXIT: {dt}")
                        signal = SignalEvent(1, s, dt, "EXIT", price=price)
                        self.bought[s] = "OUT"

                    elif short_sma < long_sma and self.bought[s] == "OUT":
                        print(f"SHORT: {dt}")
                        signal = SignalEvent(
                            1, s, dt, "SHORT", quantity=self.qty[s], price=price
                        )
                        self.bought[s] = "SHORT"
            signals[s] = signal
        return signals

    def get_live_data(self):
        """
        Fetch recent bars via ``Rates`` and compute, for each symbol, the live
        SMAs and the HMM regime.

        Returns:
            dict: ``{symbol: (short_sma, long_sma, regime)}``.

        Raises:
            ValueError: If fewer than ``long_window`` prices or fewer than
                ``risk_window`` returns are available for a symbol.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        # Request enough bars for BOTH the long SMA and the HMM lookback.
        # Requesting only ``risk_window + 2`` bars made the long SMA
        # impossible whenever ``risk_window < long_window``.
        n_bars = max(self.long_window, self.risk_window) + 2
        for symbol in self.symbol_list:
            sig_rate = Rates(symbol, self.tf, 0, n_bars, **self.kwargs)
            # Same lookback as the backtest path (last ``risk_window`` returns).
            hmm_data = sig_rate.returns.values[-self.risk_window :]
            prices = sig_rate.close.values
            # Explicit check rather than ``assert`` (asserts vanish under -O).
            if len(prices) < self.long_window or len(hmm_data) < self.risk_window:
                raise ValueError(
                    f"Not enough data for {symbol}: got {len(prices)} prices "
                    f"(need {self.long_window}) and {len(hmm_data)} returns "
                    f"(need {self.risk_window})"
                )
            current_regime = self.risk_models[symbol].which_trade_allowed(hmm_data)
            short_sma = np.mean(prices[-self.short_window :])
            long_sma = np.mean(prices[-self.long_window :])
            symbol_data[symbol] = (short_sma, long_sma, current_regime)
        return symbol_data

    def create_live_signals(self):
        """
        Map live SMA/regime data to signal strings.

        Returns:
            dict: ``{symbol: "LONG" | "SHORT" | None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        symbol_data = self.get_live_data()
        for symbol, data in symbol_data.items():
            signal = None
            short_sma, long_sma, regime = data
            if regime == "LONG":
                if short_sma > long_sma:
                    signal = "LONG"
            elif regime == "SHORT":
                if short_sma < long_sma:
                    signal = "SHORT"
            signals[symbol] = signal
        return signals

    def calculate_signals(self, event=None):
        """
        In backtest mode, queue the signals produced for a ``MARKET`` event
        (returns ``None``). In live mode, return the live signal dict.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                for signal in self.create_backtest_signals().values():
                    if signal is not None:
                        self.events.put(signal)
        elif self.mode == "live":
            return self.create_live_signals()
194
+
195
+
196
class ArimaGarchStrategy(Strategy):
    """
    The `ArimaGarchStrategy` class extends the `Strategy`
    class to implement a backtesting framework for trading strategies based on
    ARIMA-GARCH models, incorporating a Hidden Markov Model (HMM) for risk management.

    Features
    ========
    - **ARIMA-GARCH Model**: Utilizes ARIMA for time series forecasting and GARCH for volatility forecasting, aimed at predicting market movements.

    - **HMM Risk Management**: Employs a Hidden Markov Model to manage risks, determining safe trading regimes.

    - **Event-Driven Backtesting**: Capable of simulating real-time trading conditions by processing market data and signals sequentially.

    - **Live Trading**: Supports real-time trading by generating signals based on live ARIMA-GARCH predictions and HMM risk management.

    Key Methods
    ===========
    - `get_backtest_data()`: Retrieves historical data for backtesting.
    - `create_backtest_signal()`: Generates trading signals based on ARIMA-GARCH predictions and HMM risk management.
    - `get_live_data()`: Retrieves live data for real-time trading.
    - `create_live_signals()`: Generates trading signals based on live ARIMA-GARCH predictions and HMM risk management.
    - `calculate_signals()`: Determines the trading signals based on the mode of operation (backtest or live).

    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            `bars`: A data handler object that provides market data.
            `events`: An event queue object where generated signals are placed.
            `symbol_list`: A list of symbols to consider for trading.
                Defaults to ``bars.symbol_list``.
            `mode`: The mode of operation for the strategy.
            `arima_window`: The window size for rolling prediction in backtesting. Default 252.
            `time_frame`: The time frame for the data. Default "D1".
            `session_duration`: Duration of the trading session. Default 23.0.
            `quantities`: Quantity of each asset to trade (int or dict). Default 100.
            `hmm_window`: Lookback period for HMM. Default 50.
            `yf_start`: Start date used when model data falls back to Yahoo Finance.
        """
        self.bars = bars
        self.events = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.qty = get_quantities(kwargs.get("quantities", 100), self.symbol_list)
        self.arima_window = kwargs.get("arima_window", 252)
        self.tf = kwargs.get("time_frame", "D1")
        self.sd = kwargs.get("session_duration", 23.0)
        self.risk_window = kwargs.get("hmm_window", 50)
        self.risk_models = build_hmm_models(self.symbol_list, **kwargs)
        self.arima_models = self._build_arch_models(**kwargs)

        self.long_market = {s: False for s in self.symbol_list}
        self.short_market = {s: False for s in self.symbol_list}

    def _build_arch_models(self, **kwargs) -> Dict[str, ArimaGarchModel]:
        """
        Build one ``ArimaGarchModel`` per symbol.

        Data is first requested through ``Rates(...).get_rates_from_pos()``; if
        that yields ``None`` (or raises ``AssertionError``), Yahoo Finance data
        starting at ``kwargs["yf_start"]`` is used instead.
        """
        arch_models = {symbol: None for symbol in self.symbol_list}
        for symbol in self.symbol_list:
            data = None
            try:
                data = Rates(symbol, self.tf, 0).get_rates_from_pos()
            except AssertionError:
                # The original code fell back to Yahoo Finance on
                # AssertionError; kept in case Rates signals missing data
                # that way (unconfirmed).
                data = None
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``, which would pass ``None`` into the model.
            if data is None:
                data = yf.download(symbol, start=kwargs.get("yf_start"))
            arch_models[symbol] = ArimaGarchModel(symbol, data, k=self.arima_window)
        return arch_models

    def get_backtest_data(self):
        """
        Prepare model inputs for each symbol from the bars seen so far.

        Returns:
            dict: ``{symbol: (diff_log_returns, hmm_returns, bar_datetime)}``,
            or ``None`` for a symbol without ``arima_window`` prepared values
            and ``hmm_window`` returns.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        M = self.arima_window
        N = self.risk_window
        for symbol in self.symbol_list:
            dt = self.bars.get_latest_bar_datetime(symbol)
            bars = self.bars.get_latest_bars_values(symbol, "close", N=M)
            returns = self.bars.get_latest_bars_values(symbol, "returns", N=N)
            df = pd.DataFrame()
            df["Close"] = bars[-M:]
            df = df.dropna()
            arch_returns = self.arima_models[symbol].load_and_prepare_data(df)
            data = arch_returns["diff_log_return"].iloc[-M:]
            if len(data) >= M and len(returns) >= N:
                symbol_data[symbol] = (data, returns[-N:], dt)
        return symbol_data

    def create_backtest_signal(self):
        """
        Generate at most one ``SignalEvent`` per symbol for the latest bar.

        Returns:
            dict: ``{symbol: SignalEvent or None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        # Prepare every symbol's data ONCE per bar. Calling
        # get_backtest_data() inside the loop re-prepared all symbols on each
        # iteration (O(n^2) model data preparation per bar).
        backtest_data = self.get_backtest_data()
        for symbol in self.symbol_list:
            symbol_data = backtest_data[symbol]
            if symbol_data is None:
                continue
            data, returns, dt = symbol_data
            signal = None
            prediction = self.arima_models[symbol].get_prediction(data)
            regime = self.risk_models[symbol].which_trade_allowed(returns)
            price = self.bars.get_latest_bar_value(symbol, "adj_close")

            # If we are short the market, check for an exit
            if prediction > 0 and self.short_market[symbol]:
                signal = SignalEvent(1, symbol, dt, "EXIT", price=price)
                print(f"{dt}: EXIT SHORT")
                self.short_market[symbol] = False

            # If we are long the market, check for an exit
            elif prediction < 0 and self.long_market[symbol]:
                signal = SignalEvent(1, symbol, dt, "EXIT", price=price)
                print(f"{dt}: EXIT LONG")
                self.long_market[symbol] = False

            # NOTE(review): an entry below replaces an EXIT created above on
            # the same bar, so only one of the two signals is returned even
            # though the position flag was already reset. Confirm this is
            # intended before changing, as it alters the return shape.
            if regime == "LONG":
                # If we are not in the market, go long
                if prediction > 0 and not self.long_market[symbol]:
                    signal = SignalEvent(
                        1,
                        symbol,
                        dt,
                        "LONG",
                        quantity=self.qty[symbol],
                        price=price,
                    )
                    print(f"{dt}: LONG")
                    self.long_market[symbol] = True

            elif regime == "SHORT":
                # If we are not in the market, go short
                if prediction < 0 and not self.short_market[symbol]:
                    signal = SignalEvent(
                        1,
                        symbol,
                        dt,
                        "SHORT",
                        quantity=self.qty[symbol],
                        price=price,
                    )
                    print(f"{dt}: SHORT")
                    self.short_market[symbol] = True
            signals[symbol] = signal
        return signals

    def get_live_data(self):
        """
        Fetch the latest ``arima_window`` bars for each symbol via ``Rates``.

        Returns:
            dict: ``{symbol: (diff_log_returns, hmm_returns)}``.
        """
        symbol_data = {symbol: None for symbol in self.symbol_list}
        for symbol in self.symbol_list:
            arch_data = Rates(symbol, self.tf, 0, self.arima_window)
            rates = arch_data.get_rates_from_pos()
            arch_returns = self.arima_models[symbol].load_and_prepare_data(rates)
            window_data = arch_returns["diff_log_return"].iloc[-self.arima_window :]
            hmm_returns = arch_data.returns.values[-self.risk_window :]
            symbol_data[symbol] = (window_data, hmm_returns)
        return symbol_data

    def create_live_signals(self):
        """
        Combine the live ARIMA-GARCH prediction with the HMM regime.

        Returns:
            dict: ``{symbol: "LONG" | "SHORT" | None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        data = self.get_live_data()
        for symbol in self.symbol_list:
            symbol_data = data[symbol]
            if symbol_data is not None:
                window_data, hmm_returns = symbol_data
                prediction = self.arima_models[symbol].get_prediction(window_data)
                regime = self.risk_models[symbol].which_trade_allowed(hmm_returns)
                if regime == "LONG":
                    if prediction > 0:
                        signals[symbol] = "LONG"
                elif regime == "SHORT":
                    if prediction < 0:
                        signals[symbol] = "SHORT"
        return signals

    def calculate_signals(self, event=None):
        """
        In backtest mode, queue the signals produced for a ``MARKET`` event.
        In live mode, return the live signal dict.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                signals = self.create_backtest_signal()
                for signal in signals.values():
                    if signal is not None:
                        self.events.put(signal)
        elif self.mode == "live":
            return self.create_live_signals()
381
+
382
+
383
class KalmanFilterStrategy(Strategy):
    """
    The `KalmanFilterStrategy` class implements a backtesting framework for a
    [pairs trading](https://en.wikipedia.org/wiki/Pairs_trade) strategy using
    Kalman Filter for signals and Hidden Markov Models (HMM) for risk management.
    This document outlines the structure and usage of the `KalmanFilterStrategy`,
    including its initialization parameters and main functions.
    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            `bars`: `DataHandler` for market data handling.
            `events`: A queue for managing events.
            `symbol_list`: List of exactly two ticker symbols for the pairs trading
                strategy (``[x, y]``). Defaults to ``bars.symbol_list``.
            `mode`: Mode of operation for the strategy.
            kwargs : Additional keyword arguments including
                - `quantity`: Quantity of assets to trade. Default is 100.
                - `hmm_window`: Window size for calculating returns for the HMM. Default is 50.
                - `hmm_tiker`: Ticker symbol used by the HMM for risk management (required).
                - `time_frame`: Time frame for the data. Default is 'D1'.
                - `session_duration`: Duration of the trading session. Default is 6.5.

        Raises:
            ValueError: If ``symbol_list`` does not hold exactly two tickers or
                ``hmm_tiker`` is missing.
        """
        self.bars = bars
        self.events_queue = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.hmm_tiker = kwargs.get("hmm_tiker")
        self._assert_tikers()
        self.account = Account(**kwargs)
        self.hmm_window = kwargs.get("hmm_window", 50)
        self.qty = kwargs.get("quantity", 100)
        self.tf = kwargs.get("time_frame", "D1")
        self.sd = kwargs.get("session_duration", 6.5)

        self.risk_model = build_hmm_models(self.symbol_list, **kwargs)
        self.kl_model = KalmanFilterModel(self.tickers, **kwargs)

        self.long_market = False
        self.short_market = False

    def _assert_tikers(self):
        """Validate the pair and the HMM ticker; sets ``self.tickers``."""
        if self.symbol_list is None or len(self.symbol_list) != 2:
            raise ValueError("A list of 2 Tickers must be provide for this strategy")
        self.tickers = self.symbol_list
        if self.hmm_tiker is None:
            raise ValueError(
                "You need to provide a ticker used by the HMM for risk management"
            )

    def calculate_btxy(self, etqt, regime, dt):
        """
        Queue backtest pair signals from the Kalman filter forecast error.

        Args:
            etqt: ``(et, sqrt_Qt)`` from ``kl_model.calculate_etqt``, or
                ``None`` when no estimate is available (nothing is done).
            regime: HMM regime, ``"LONG"`` or ``"SHORT"``; gates new entries.
            dt: Datetime of the current bar.
        """
        # No filter estimate yet: nothing to trade on.
        if etqt is None:
            return
        et, sqrt_Qt = etqt
        theta = self.kl_model.theta
        p1 = self.bars.get_latest_bar_value(self.tickers[1], "adj_close")
        p0 = self.bars.get_latest_bar_value(self.tickers[0], "adj_close")
        if et >= -sqrt_Qt and self.long_market:
            print("CLOSING LONG: %s" % dt)
            y_signal = SignalEvent(1, self.tickers[1], dt, "EXIT", price=p1)
            x_signal = SignalEvent(1, self.tickers[0], dt, "EXIT", price=p0)
            self.events_queue.put(y_signal)
            self.events_queue.put(x_signal)
            self.long_market = False

        elif et <= sqrt_Qt and self.short_market:
            print("CLOSING SHORT: %s" % dt)
            y_signal = SignalEvent(1, self.tickers[1], dt, "EXIT", price=p1)
            x_signal = SignalEvent(1, self.tickers[0], dt, "EXIT", price=p0)
            self.events_queue.put(y_signal)
            self.events_queue.put(x_signal)
            self.short_market = False

        # Long Entry
        if regime == "LONG":
            if et <= -sqrt_Qt and not self.long_market:
                print("LONG: %s" % dt)
                y_signal = SignalEvent(
                    1, self.tickers[1], dt, "LONG", self.qty, 1.0, price=p1
                )
                x_signal = SignalEvent(
                    1, self.tickers[0], dt, "SHORT", self.qty, theta[0], price=p0
                )
                self.events_queue.put(y_signal)
                self.events_queue.put(x_signal)
                self.long_market = True

        # Short Entry
        elif regime == "SHORT":
            if et >= sqrt_Qt and not self.short_market:
                print("SHORT: %s" % dt)
                y_signal = SignalEvent(
                    1, self.tickers[1], dt, "SHORT", self.qty, 1.0, price=p1
                )
                # ``dt`` was missing here, which shifted every positional
                # argument (signal type became the datetime, etc.).
                x_signal = SignalEvent(
                    1, self.tickers[0], dt, "LONG", self.qty, theta[0], price=p0
                )
                self.events_queue.put(y_signal)
                self.events_queue.put(x_signal)
                self.short_market = True

    def calculate_livexy(self):
        """
        Raw live pair signals from the Kalman filter, before regime filtering.

        Inside the band ``-std < et < std`` both legs get ``"EXIT"``; at or
        beyond ``+std`` the pair is shorted (y SHORT / x LONG); at or below
        ``-std`` it is bought (y LONG / x SHORT).

        Returns:
            dict: ``{symbol: "LONG" | "SHORT" | "EXIT" | None}``.
        """
        signals = {symbol: None for symbol in self.symbol_list}
        p0_price = self.account.get_tick_info(self.tickers[0]).ask
        p1_price = self.account.get_tick_info(self.tickers[1]).ask
        prices = np.array([p0_price, p1_price])
        et_std = self.kl_model.calculate_etqt(prices)
        if et_std is not None:
            et, std = et_std
            y_signal = None
            x_signal = None

            # Equivalent to the previous sequence of overriding ``if``s
            # (whose first test, ``et >= -std or et <= std``, was always true
            # for non-NaN values), written as explicit mutually exclusive cases.
            if -std < et < std:
                y_signal = "EXIT"
                x_signal = "EXIT"
            elif et >= std:
                y_signal = "SHORT"
                x_signal = "LONG"
            elif et <= -std:
                y_signal = "LONG"
                x_signal = "SHORT"

            signals[self.tickers[0]] = x_signal
            signals[self.tickers[1]] = y_signal
        return signals

    def calculate_backtest_signals(self):
        """Run the filter on the latest closes and queue backtest signals."""
        p0, p1 = self.tickers[0], self.tickers[1]
        dt = self.bars.get_latest_bar_datetime(p0)
        x = self.bars.get_latest_bar_value(p0, "close")
        y = self.bars.get_latest_bar_value(p1, "close")
        returns = self.bars.get_latest_bars_values(
            self.hmm_tiker, "returns", N=self.hmm_window
        )
        latest_prices = np.array([-1.0, -1.0])
        if len(returns) >= self.hmm_window:
            latest_prices[0] = x
            latest_prices[1] = y
            et_qt = self.kl_model.calculate_etqt(latest_prices)
            regime = self.risk_model[self.hmm_tiker].which_trade_allowed(returns)
            self.calculate_btxy(et_qt, regime, dt)

    def calculate_live_signals(self):
        """
        Live signals: raw Kalman signals kept only when they agree with the
        HMM regime of ``hmm_tiker``.

        Returns:
            dict: ``{symbol: "LONG" | "SHORT" | None}``.
        """
        # Data Retrieval
        signals = {symbol: None for symbol in self.symbol_list}
        initial_signals = self.calculate_livexy()
        # The attribute is ``hmm_tiker``; ``hmm_ticker`` raised AttributeError.
        hmm_data = Rates(self.hmm_tiker, self.tf, 0, self.hmm_window)
        returns = hmm_data.returns.values
        current_regime = self.risk_model[self.hmm_tiker].which_trade_allowed(returns)
        for symbol in self.symbol_list:
            if symbol in initial_signals:
                signal = initial_signals[symbol]
                if signal == "LONG" and current_regime == "LONG":
                    signals[symbol] = "LONG"
                elif signal == "SHORT" and current_regime == "SHORT":
                    signals[symbol] = "SHORT"
        return signals

    def calculate_signals(self, event=None):
        """
        Calculate the Kalman Filter strategy.

        Backtest mode queues events on ``MARKET`` events; live mode returns
        the live signal dict.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                self.calculate_backtest_signals()
        elif self.mode == "live":
            return self.calculate_live_signals()
561
+
562
+
563
class StockIndexSTBOTrading(Strategy):
    """
    The StockIndexSTBOTrading class implements a stock index Contract for Difference (CFD)
    Buy-Only trading strategy. This strategy is based on the assumption that stock markets
    typically follow a long-term uptrend. The strategy is designed to capitalize on market
    corrections and price dips, where stocks or indices temporarily drop but are expected
    to recover. It operates in two modes: backtest and live, and it is particularly
    tailored to index trading.
    """

    def __init__(
        self,
        bars: DataHandler = None,
        events: Queue = None,
        symbol_list: List[str] = None,
        mode: Literal["backtest", "live"] = "backtest",
        **kwargs,
    ):
        """
        Args:
            `bars`: `DataHandler` for market data handling.
            `events`: A queue for managing events.
            `symbol_list`: List of index symbols to trade. Defaults to ``bars.symbol_list``.
            `mode`: Mode of operation for the strategy.
            kwargs : Additional keyword arguments including
                - rr (float, default: 3.0): The risk-reward ratio used to determine entry points.
                - epsilon (float, default: 0.1): The percentage threshold for price changes when considering new highs or lows.
                - expected_returns (dict): Expected return percentages for each symbol in the symbol list.
                - quantities (int | dict, default: 100): The number of units to trade (backtest mode).
                - max_trades (dict): The maximum number of buys per symbol (enforced in backtest mode).
                - logger (optional): A logger object for tracking operations; logging is skipped if omitted.
                - expert_id (int, default: 5134): Magic number identifying positions created by this strategy.
        """
        self.bars = bars
        self.events = events
        self.symbol_list = symbol_list or self.bars.symbol_list
        self.mode = mode

        self.account = Account()

        self.rr = kwargs.get("rr", 3.0)
        self.epsilon = kwargs.get("epsilon", 0.1)
        self._initialize(**kwargs)
        self.logger = kwargs.get("logger")
        self.ID = kwargs.get("expert_id", 5134)

    def _initialize(self, **kwargs):
        """Set up per-symbol targets, price trackers and backtest counters."""
        symbols = self.symbol_list.copy()
        returns = kwargs.get("expected_returns")
        quantities = kwargs.get("quantities", 100)
        max_trades = kwargs.get("max_trades")

        self.expeted_return = {index: returns[index] for index in symbols}
        self.max_trades = {index: max_trades[index] for index in symbols}
        self.last_price = {index: None for index in symbols}
        self.heightest_price = {index: None for index in symbols}
        self.lowerst_price = {index: None for index in symbols}

        if self.mode == "backtest":
            self.qty = get_quantities(quantities, symbols)
            self.num_buys = {index: 0 for index in symbols}
            self.buy_prices = {index: [] for index in symbols}

    def _calculate_pct_change(self, current_price, lh_price):
        """Percentage change from ``lh_price`` to ``current_price``."""
        return ((current_price - lh_price) / lh_price) * 100

    def _update_extremes(self, index, current_price):
        """Move the tracked high/low when price moves by at least ``epsilon`` %."""
        if (
            self._calculate_pct_change(current_price, self.heightest_price[index])
            >= self.epsilon
        ):
            self.heightest_price[index] = current_price
        elif (
            self._calculate_pct_change(current_price, self.lowerst_price[index])
            <= -self.epsilon
        ):
            self.lowerst_price[index] = current_price

    def calculate_live_signals(self):
        """
        Live signals from the current ask price and this strategy's open buys.

        Returns:
            dict: ``{symbol: "LONG" | "EXIT" | None}``. ``EXIT`` takes
            precedence when the average buy price has reached the expected
            return.
        """
        signals = {index: None for index in self.symbol_list}
        for index in self.symbol_list:
            current_price = self.account.get_tick_info(index).ask
            if self.last_price[index] is None:
                # First observation only seeds the trackers.
                self.last_price[index] = current_price
                self.heightest_price[index] = current_price
                self.lowerst_price[index] = current_price
                continue

            self._update_extremes(index, current_price)
            down_change = self._calculate_pct_change(
                current_price, self.heightest_price[index]
            )

            if down_change <= -(self.expeted_return[index] / self.rr):
                signals[index] = "LONG"

            positions = self.account.get_positions(symbol=index)
            if positions is not None:
                # Buy positions (type 0) opened by this strategy's magic number.
                buy_prices = [
                    position.price_open
                    for position in positions
                    if position.type == 0 and position.magic == self.ID
                ]
                if len(buy_prices) == 0:
                    continue
                avg_price = sum(buy_prices) / len(buy_prices)
                if (
                    self._calculate_pct_change(current_price, avg_price)
                    >= (self.expeted_return[index])
                ):
                    signals[index] = "EXIT"
            # ``logger`` defaults to None; calling .info on it crashed.
            if self.logger is not None:
                self.logger.info(
                    f"SYMBOL={index} - Hp={self.heightest_price[index]} - "
                    f"Lp={self.lowerst_price[index]} - Cp={current_price} - %chg={round(down_change, 2)}"
                )
        return signals

    def calculate_backtest_signals(self):
        """
        Queue LONG/EXIT ``SignalEvent``s for the latest backtest bar.

        A buy is placed when price has fallen ``expected_return / rr`` percent
        from the tracked high and fewer than ``max_trades`` buys are open; all
        buys are exited together once price is ``expected_return`` percent
        above their average price.
        """
        for index in self.symbol_list:
            dt = self.bars.get_latest_bar_datetime(index)
            last_price = self.bars.get_latest_bars_values(index, "close", N=1)

            current_price = last_price[-1]
            if self.last_price[index] is None:
                # First observation only seeds the trackers.
                self.last_price[index] = current_price
                self.heightest_price[index] = current_price
                self.lowerst_price[index] = current_price
                continue

            self._update_extremes(index, current_price)
            down_change = self._calculate_pct_change(
                current_price, self.heightest_price[index]
            )

            # ``<`` (not ``<=``): with ``<=`` up to max_trades + 1 buys could
            # be opened, exceeding the configured maximum.
            if (
                down_change <= -(self.expeted_return[index] / self.rr)
                and self.num_buys[index] < self.max_trades[index]
            ):
                signal = SignalEvent(
                    100,
                    index,
                    dt,
                    "LONG",
                    quantity=self.qty[index],
                    price=current_price,
                )
                self.events.put(signal)
                self.num_buys[index] += 1
                self.buy_prices[index].append(current_price)

            elif self.num_buys[index] > 0:
                av_price = sum(self.buy_prices[index]) / len(self.buy_prices[index])
                qty = self.qty[index] * self.num_buys[index]
                if (
                    self._calculate_pct_change(current_price, av_price)
                    >= (self.expeted_return[index])
                ):
                    signal = SignalEvent(
                        100, index, dt, "EXIT", quantity=qty, price=current_price
                    )
                    self.events.put(signal)
                    self.num_buys[index] = 0
                    self.buy_prices[index] = []

    def calculate_signals(self, event=None) -> Dict[str, Union[str, None]]:
        """
        Backtest mode queues events on ``MARKET`` events (returns ``None``);
        live mode returns the live signal dict.
        """
        if self.mode == "backtest" and event is not None:
            if event.type == "MARKET":
                self.calculate_backtest_signals()
        elif self.mode == "live":
            return self.calculate_live_signals()
745
+
746
+
747
def _run_backtest(strategy_name: str, capital: float, symbol_list: list, kwargs: dict):
    """
    Build a ``BacktestEngine`` for one of the bundled demo strategies and run it.

    Note that ``kwargs`` is modified in place: ``"strategy_name"`` is added and
    ``"backtester_class"`` is removed before the remaining entries are
    forwarded to the engine as keyword arguments.

    Args:
        strategy_name: Label stored under ``kwargs["strategy_name"]``.
        capital: Initial capital passed to the engine.
        symbol_list: Symbols to backtest.
        kwargs: Settings; must contain ``"yf_start"`` (``YYYY-MM-DD``) and
            ``"backtester_class"``. ``"data_handler"`` / ``"exc_handler"``
            default to ``YFDataHandler`` / ``SimExecutionHandler``.
    """
    kwargs["strategy_name"] = strategy_name
    # Resolve the positional engine arguments in the same order as before;
    # the strategy class is popped so it is not forwarded again via **kwargs.
    start_date = datetime.strptime(kwargs["yf_start"], "%Y-%m-%d")
    data_handler_cls = kwargs.get("data_handler", YFDataHandler)
    execution_handler_cls = kwargs.get("exc_handler", SimExecutionHandler)
    strategy_cls = kwargs.pop("backtester_class")
    BacktestEngine(
        symbol_list,
        capital,
        0.0,
        start_date,
        data_handler_cls,
        execution_handler_cls,
        strategy_cls,
        **kwargs,
    ).simulate_trading()
764
+
765
+
766
+ def _run_arch_backtest(capital: float = 100000.0, quantity: int = 1000):
767
+ hmm_data = yf.download("^GSPC", start="1990-01-01", end="2009-12-31")
768
+ kwargs = {
769
+ "quantity": quantity,
770
+ "yf_start": "2010-01-04",
771
+ "hmm_data": hmm_data,
772
+ "backtester_class": ArimaGarchStrategy,
773
+ "data_handler": YFDataHandler,
774
+ }
775
+ _run_backtest("ARIMA+GARCH & HMM", capital, ["^GSPC"], kwargs)
776
+
777
+
778
+ def _run_kf_backtest(capital: float = 100000.0, quantity: int = 2000):
779
+ symbol_list = ["IEI", "TLT"]
780
+ tlt = yf.download("TLT", end="2008-07-09")
781
+ iei = yf.download("IEI", end="2008-07-09")
782
+ kwargs = {
783
+ "quantity": quantity,
784
+ "yf_start": "2009-08-03",
785
+ "hmm_data": {"IEI": iei, "TLT": tlt},
786
+ "hmm_tiker": "TLT",
787
+ "session_duration": 6.5,
788
+ "backtester_class": KalmanFilterStrategy,
789
+ "data_handler": YFDataHandler,
790
+ }
791
+ _run_backtest("Kalman Filter & HMM", capital, symbol_list, kwargs)
792
+
793
+
794
+ def _run_sma_backtest(capital: float = 100000.0, quantity: int = 1):
795
+ spx_data = yf.download("^GSPC", start="1990-01-01", end="2009-12-31")
796
+ kwargs = {
797
+ "quantities": quantity,
798
+ "hmm_end": "2009-12-31",
799
+ "yf_start": "2010-01-04",
800
+ "hmm_data": spx_data,
801
+ "mt5_start": datetime(2010, 1, 1),
802
+ "mt5_end": datetime(2023, 1, 1),
803
+ "backtester_class": SMAStrategy,
804
+ "data_handler": MT5DataHandler,
805
+ "exc_handler": MT5ExecutionHandler,
806
+ }
807
+ _run_backtest("SMA & HMM", capital, ["[SP500]"], kwargs)
808
+
809
+
810
+ def _run_sistbo_backtest(capital: float = 100000.0, quantity: int = None):
811
+ ndx = "[NQ100]"
812
+ spx = "[SP500]"
813
+ dji = "[DJI30]"
814
+ dax = "GERMANY40"
815
+
816
+ symbol_list = [spx, dax, dji, ndx]
817
+ start = datetime(2010, 6, 1, 2, 0, 0)
818
+ quantity = {ndx: 15, spx: 30, dji: 5, dax: 10}
819
+ kwargs = {
820
+ "expected_returns": {ndx: 1.5, spx: 1.5, dji: 1.0, dax: 1.0},
821
+ "quantities": quantity,
822
+ "max_trades": {ndx: 3, spx: 3, dji: 3, dax: 3},
823
+ "mt5_start": start,
824
+ "yf_start": start.strftime("%Y-%m-%d"),
825
+ "time_frame": "15m",
826
+ "backtester_class": StockIndexSTBOTrading,
827
+ "data_handler": MT5DataHandler,
828
+ "exc_handler": MT5ExecutionHandler,
829
+ }
830
+ _run_backtest("Stock Index Short Term Buy Only ", capital, symbol_list, kwargs)
831
+
832
+
833
+ _BACKTESTS = {
834
+ "sma": _run_sma_backtest,
835
+ "klf": _run_kf_backtest,
836
+ "arch": _run_arch_backtest,
837
+ "sistbo": _run_sistbo_backtest,
838
+ }
839
+
840
+
841
+ def test_strategy(
842
+ strategy: Literal["sma", "klf", "arch", "sistbo"] = "sma",
843
+ quantity: Optional[int] = 100,
844
+ ):
845
+ """
846
+ Executes a backtest of the specified strategy
847
+
848
+ Args:
849
+ strategy : The strategy to use in test mode. Default is `sma`.
850
+ - `sma` Execute `SMAStrategy`, for more detail see this class documentation.
851
+ - `klf` Execute `KalmanFilterStrategy`, for more detail see this class documentation.
852
+ - `arch` Execute `ArimaGarchStrategy`, for more detail see this class documentation.
853
+ - `sistbo` Execute `StockIndexSTBOTrading`, for more detail see this class documentation.
854
+ quantity : The quantity of assets to be used in the test backtest. Default is 1000.
855
+
856
+ """
857
+ if strategy in _BACKTESTS:
858
+ _BACKTESTS[strategy](quantity=quantity)
859
+ else:
860
+ raise ValueError(f"Unknown strategy: {strategy}")