bbstrader 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bbstrader might be problematic.

bbstrader/tseries.py CHANGED
@@ -6,7 +6,6 @@ tasks such as cointegration testing, volatility modeling,
  and filter-based estimation to assist in trading strategy development,
  market analysis, and financial data exploration.
  """
-
  import numpy as np
  import pandas as pd
  import yfinance as yf
@@ -19,10 +18,13 @@ import statsmodels.tsa.stattools as ts
  from numpy import cumsum, log, polyfit, sqrt, std, subtract
  from numpy.random import randn
  from hurst import compute_Hc
+ from scipy.optimize import minimize
  from filterpy.kalman import KalmanFilter
  from statsmodels.tsa.vector_ar.vecm import coint_johansen
+ from statsmodels.graphics.tsaplots import plot_acf
  from itertools import combinations
  from typing import Union, List, Tuple
+ from statsmodels.stats.diagnostic import acorr_ljungbox
  import pprint
  import warnings
  warnings.filterwarnings("ignore")
@@ -41,9 +43,13 @@ __all__ = [
      "run_cadf_test",
      "run_hurst_test",
      "run_coint_test",
-     "run_kalman_filter"
+     "run_kalman_filter",
+     "ArimaGarchModel",
+     "KalmanFilterModel",
+     "OrnsteinUhlenbeckModel"
  ]

+
  def load_and_prepare_data(df: pd.DataFrame):
      """
      Prepares financial time series data for analysis.
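
The hunks below change callers of `load_and_prepare_data` rather than the function itself; according to docstrings later in this diff, it returns the input frame with `log_return` and `diff_log_return` columns added. A minimal sketch of an equivalent preparation (the exact implementation is outside this diff, so the column derivations here are an assumption):

```python
import numpy as np
import pandas as pd

def prepare(df: pd.DataFrame) -> pd.DataFrame:
    # Hypothetical equivalent of load_and_prepare_data; the real
    # implementation is not shown in this diff.
    data = df.copy()
    data["log_return"] = np.log(data["Close"]).diff()    # log returns
    data["diff_log_return"] = data["log_return"].diff()  # differenced log returns
    return data.dropna()
```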
@@ -74,7 +80,7 @@ def load_and_prepare_data(df: pd.DataFrame):
      return data


- def fit_best_arima(window_data: Union[pd.Series , np.ndarray]):
+ def fit_best_arima(window_data: Union[pd.Series, np.ndarray]):
      """
      Identifies and fits the best `ARIMA` model
      based on the Akaike Information Criterion `(AIC)`.
@@ -91,6 +97,11 @@ def fit_best_arima(window_data: Union[pd.Series , np.ndarray]):
      Returns:
          ARIMA result object: The fitted `ARIMA` model with the lowest `AIC`.
      """
+     if isinstance(window_data, pd.Series):
+         window_data = window_data.values
+
+     window_data = window_data[~(np.isnan(window_data) | np.isinf(window_data))]
+     # Fit ARIMA model with best parameters
      model = pm.auto_arima(
          window_data,
          start_p=1,
@@ -101,15 +112,25 @@ def fit_best_arima(window_data: Union[pd.Series , np.ndarray]):
          stepwise=True
      )
      final_order = model.order
-     import warnings
-     from statsmodels.tools.sm_exceptions import ConvergenceWarning
-     warnings.filterwarnings("ignore", category=ConvergenceWarning)
-     best_arima_model = ARIMA(
-         window_data, order=final_order, missing='drop').fit()
-     return best_arima_model
-
-
- def fit_garch(window_data: Union[pd.Series , np.ndarray]):
+     from arch.utility.exceptions import ConvergenceWarning as ArchConvergenceWarning
+     from statsmodels.tools.sm_exceptions import ConvergenceWarning as StatsConvergenceWarning
+     warnings.filterwarnings("ignore", category=StatsConvergenceWarning)
+     warnings.filterwarnings("ignore", category=ArchConvergenceWarning)
+     try:
+         best_arima_model = ARIMA(
+             window_data + 1e-5, order=final_order, missing='drop').fit()
+         return best_arima_model
+     except np.linalg.LinAlgError:
+         # Catch specific linear algebra errors
+         print("LinAlgError occurred, skipping this data point.")
+         return None
+     except Exception as e:
+         # Catch any other unexpected errors and log them
+         print(f"An error occurred: {e}")
+         return None
+
+
+ def fit_garch(window_data: Union[pd.Series, np.ndarray]):
      """
      Fits an `ARIMA` model to the data to get residuals,
      then fits a `GARCH(1,1)` model on these residuals.
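
The rewritten `fit_best_arima` above now traps `LinAlgError` (and any other exception) and returns `None` instead of raising, so callers need a guard. A minimal sketch of a guarded call on synthetic data, assuming the function is imported directly from `bbstrader.tseries`:

```python
import numpy as np
from bbstrader.tseries import fit_best_arima

rng = np.random.default_rng(42)
window = rng.normal(0.0, 0.01, 252)  # synthetic daily returns
result = fit_best_arima(window)
if result is None:
    print("ARIMA fit failed; skip this window.")
else:
    print(f"Fitted ARIMA, AIC = {result.aic:.2f}")
```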
@@ -126,6 +147,8 @@ def fit_garch(window_data: Union[pd.Series , np.ndarray]):
          object and the `GARCH` result object.
      """
      arima_result = fit_best_arima(window_data)
+     if arima_result is None:
+         return None, None
      resid = np.asarray(arima_result.resid)
      resid = resid[~(np.isnan(resid) | np.isinf(resid))]
      garch_model = arch_model(resid, p=1, q=1, rescale=False)
@@ -148,6 +171,8 @@ def predict_next_return(arima_result, garch_result):
      Returns:
          float: The predicted next return, adjusted for predicted volatility.
      """
+     if arima_result is None or garch_result is None:
+         return 0
      # Predict next value with ARIMA
      arima_pred = arima_result.forecast(steps=1)
      # Predict next volatility with GARCH
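
With the new guards, `predict_next_return` yields 0 when either model is missing, which the trading logic later in this diff treats as a neutral signal. A short end-to-end sketch on synthetic data:

```python
import numpy as np
from bbstrader.tseries import fit_garch, predict_next_return

rng = np.random.default_rng(0)
window = rng.normal(0.0, 0.01, 252)       # synthetic return window
arima_res, garch_res = fit_garch(window)  # (None, None) on failure
print(predict_next_return(arima_res, garch_res))  # 0 if either fit failed
```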
@@ -155,11 +180,13 @@ def predict_next_return(arima_result, garch_result):
      next_volatility = garch_pred.variance.iloc[-1, 0]

      # Combine predictions (return + volatility)
-     next_return = arima_pred.values[0] + next_volatility
-     return next_return
-
+     if not isinstance(arima_pred, np.ndarray):
+         pred = arima_pred.values[0]
+     else:
+         pred = arima_pred[0]
+     return pred + next_volatility

- def get_prediction(window_data: Union[pd.Series , np.ndarray]):
+ def get_prediction(window_data: Union[pd.Series, np.ndarray]):
      """
      Orchestrator function to get the next period's return prediction.

@@ -182,7 +209,7 @@ def get_prediction(window_data: Union[pd.Series , np.ndarray]):
  # *********************************************
  # STATS TEST (Cointegration , Mean Reverting)*
  # *********************************************
- def get_corr(tickers: Union[List[str] , Tuple[str, ...]], start: str, end: str) -> None:
+ def get_corr(tickers: Union[List[str], Tuple[str, ...]], start: str, end: str) -> None:
      """
      Calculates and prints the correlation matrix of the adjusted closing prices
      for a given list of stock tickers within a specified date range.
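
`get_corr` only changed its type hint; for reference, a usage sketch with hypothetical tickers and dates (it prints the correlation matrix and returns `None`):

```python
from bbstrader.tseries import get_corr

# Hypothetical tickers and date range, for illustration only
get_corr(["SPY", "QQQ", "DIA"], start="2020-01-01", end="2021-01-01")
```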
@@ -275,7 +302,7 @@ def plot_residuals(df: pd.DataFrame):
      plt.show()


- def run_cadf_test(pair: Union[List[str] , Tuple[str, ...]], start: str, end: str) -> None:
+ def run_cadf_test(pair: Union[List[str], Tuple[str, ...]], start: str, end: str) -> None:
      """
      Performs the Cointegration Augmented Dickey-Fuller (CADF) test on a pair of stock tickers
      over a specified date range to check for cointegration.
@@ -560,7 +587,9 @@ def draw_slope_intercept_changes(prices, state_means):
      plt.show()


- def run_kalman_filter(etfs: Union[List[str] , Tuple[str, ...]], start: str, end: str) -> None:
+ def run_kalman_filter(
+         etfs: Union[List[str], Tuple[str, ...]],
+         start: str, end: str) -> None:
      """
      Applies a Kalman filter to a pair of ETF adjusted closing prices within a specified date range
      to estimate the slope and intercept over time.
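
`run_kalman_filter` was only reformatted onto multiple lines; a usage sketch with a hypothetical ETF pair:

```python
from bbstrader.tseries import run_kalman_filter

# Hypothetical ETF pair and date range, for illustration only
run_kalman_filter(("EWA", "EWC"), start="2015-01-01", end="2020-01-01")
```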
@@ -590,3 +619,558 @@ def run_kalman_filter(etfs: Union[List[str] , Tuple[str, ...]], start: str, end: str) -> None:
      draw_date_coloured_scatterplot(etfs, prices)
      state_means, state_covs = calc_slope_intercept_kalman(etfs, prices)
      draw_slope_intercept_changes(prices, state_means)
+
+
+ class ArimaGarchModel():
+     """
+     This class implements a time series model
+     that combines `ARIMA (AutoRegressive Integrated Moving Average)`
+     and `GARCH (Generalized Autoregressive Conditional Heteroskedasticity)` models
+     to predict future returns based on historical price data.
+
+     The model is implemented in the following steps:
+     1. Data Preparation: Load and prepare the historical price data.
+     2. Modeling: Fit the ARIMA model to the data and then fit the GARCH model to the residuals.
+     3. Prediction: Predict the next return using the ARIMA model and the next volatility using the GARCH model.
+     4. Trading Strategy: Execute the trading strategy based on the predictions.
+     5. Vectorized Backtesting: Backtest the trading strategy using the historical data.
+
+     Example:
+     >>> import yfinance as yf
+     >>> from bbstrader.strategies import ArimaGarchModel
+     >>> from bbstrader.tseries import load_and_prepare_data
+
+     >>> if __name__ == '__main__':
+     >>>     # ARCH SPY Vectorized Backtest
+     >>>     k = 252
+     >>>     data = yf.download("SPY", start="2004-01-02", end="2015-12-31")
+     >>>     arch = ArimaGarchModel("SPY", data, k=k)
+     >>>     df = load_and_prepare_data(data)
+     >>>     arch.show_arima_garch_results(df['diff_log_return'].values[-k:])
+     >>>     arch.backtest_strategy()
+     """
+
+     def __init__(self, symbol, data, k: int = 252):
+         """
+         Initializes the ArimaGarchModel class.
+
+         Args:
+             symbol (str): The ticker symbol for the financial instrument.
+             data (pd.DataFrame): The raw dataset containing at least the 'Close' prices.
+             k (int): The window size for rolling prediction in backtesting.
+         """
+         self.symbol = symbol
+         self.data = self.load_and_prepare_data(data)
+         self.k = k
+
+     # Step 1: Data Preparation
+     def load_and_prepare_data(self, df):
+         """
+         Prepares the dataset by calculating logarithmic returns
+         and differencing if necessary.
+
+         Args:
+             df (pd.DataFrame): The raw dataset containing at least the 'Close' prices.
+
+         Returns:
+             pd.DataFrame: The dataset with additional columns
+             for log returns and differenced log returns.
+         """
+         return load_and_prepare_data(df)
+
+     # Step 2: Modeling (ARIMA + GARCH)
+     def fit_best_arima(self, window_data):
+         """
+         Fits the ARIMA model to the provided window of data,
+         selecting the best model based on AIC.
+
+         Args:
+             window_data (np.array): The dataset for a specific window period.
+
+         Returns:
+             ARIMA model: The best fitted ARIMA model based on AIC.
+         """
+         return fit_best_arima(window_data)
+
+     def fit_garch(self, window_data):
+         """
+         Fits the GARCH model to the residuals of the best ARIMA model.
+
+         Args:
+             window_data (np.array): The dataset for a specific window period.
+
+         Returns:
+             tuple: Contains the ARIMA result and GARCH result.
+         """
+         return fit_garch(window_data)
+
+     def show_arima_garch_results(self, window_data, acf=True, test_resid=True):
+         """
+         Displays the ARIMA and GARCH model results, including plotting
+         the ACF of residuals and conducting Box-Pierce and Ljung-Box tests.
+
+         Args:
+             window_data (np.array): The dataset for a specific window period.
+             acf (bool, optional): If True, plot the ACF of residuals. Defaults to True.
+             test_resid (bool, optional):
+                 If True, conduct Box-Pierce and Ljung-Box tests on residuals. Defaults to True.
+         """
+         arima_result = self.fit_best_arima(window_data)
+         resid = np.asarray(arima_result.resid)
+         resid = resid[~(np.isnan(resid) | np.isinf(resid))]
+         garch_model = arch_model(resid, p=1, q=1, rescale=False)
+         garch_result = garch_model.fit(disp='off')
+         residuals = garch_result.resid
+
+         # Plot the ACF of the residuals
+         if acf:
+             fig = plt.figure(figsize=(12, 8))
+             # Plot the ACF of ARIMA residuals
+             ax1 = fig.add_subplot(211, ylabel='ACF')
+             plot_acf(resid, alpha=0.05, ax=ax1, title='ACF of ARIMA Residuals')
+             ax1.set_xlabel('Lags')
+             ax1.grid(True)
+
+             # Plot the ACF of GARCH residuals
+             ax2 = fig.add_subplot(212, ylabel='ACF')
+             plot_acf(residuals, alpha=0.05, ax=ax2,
+                      title='ACF of GARCH Residuals')
+             ax2.set_xlabel('Lags')
+             ax2.grid(True)
+
+             # Plot the figure
+             plt.tight_layout()
+             plt.show()
+
+         # Conduct Box-Pierce and Ljung-Box tests on the residuals
+         if test_resid:
+             print(arima_result.summary())
+             print(garch_result.summary())
+             bp_test = acorr_ljungbox(resid, return_df=True)
+             print("Box-Pierce and Ljung-Box Tests Results for ARIMA:\n", bp_test)
+
+     # Step 3: Prediction
+     def predict_next_return(self, arima_result, garch_result):
+         """
+         Predicts the next return using the ARIMA model
+         and the next volatility using the GARCH model.
+
+         Args:
+             arima_result (ARIMA model): The ARIMA model result.
+             garch_result (GARCH model): The GARCH model result.
+
+         Returns:
+             float: The predicted next return.
+         """
+         return predict_next_return(arima_result, garch_result)
+
+     def get_prediction(self, window_data):
+         """
+         Generates a prediction for the next return based on a window of data.
+
+         Args:
+             window_data (np.array): The dataset for a specific window period.
+
+         Returns:
+             float: The predicted next return.
+         """
+         return get_prediction(window_data)
+
+     def calculate_signals(self, window_data):
+         """
+         Calculates the trading signal based on the prediction.
+
+         Args:
+             window_data (np.array): The dataset for a specific window period.
+
+         Returns:
+             str: The trading signal ('LONG', 'SHORT', or None).
+         """
+         prediction = self.get_prediction(window_data)
+         if prediction > 0:
+             signal = "LONG"
+         elif prediction < 0:
+             signal = "SHORT"
+         else:
+             signal = None
+         return signal
+
+     # Step 4: Trading Strategy
+
+     def execute_trading_strategy(self, predictions):
+         """
+         Executes the trading strategy based on a list
+         of predictions, determining positions to take.
+
+         Args:
+             predictions (list): A list of predicted returns.
+
+         Returns:
+             list: A list of positions (1 for 'LONG', -1 for 'SHORT', 0 for 'HOLD').
+         """
+         positions = []  # Long if 1, Short if -1
+         previous_position = 0  # Initial position
+         for prediction in predictions:
+             if prediction > 0:
+                 current_position = 1  # Long
+             elif prediction < 0:
+                 current_position = -1  # Short
+             else:
+                 current_position = previous_position  # Hold previous position
+             positions.append(current_position)
+             previous_position = current_position
+
+         return positions
+
+     # Step 5: Vectorized Backtesting
+     def generate_predictions(self):
+         """
+         Generator that yields predictions one by one.
+         """
+         data = self.data
+         window_size = self.k
+         for i in range(window_size, len(data)):
+             print(
+                 f"Processing window {i - window_size + 1}/{len(data) - window_size}...")
+             window_data = data['diff_log_return'].iloc[i-window_size:i]
+             next_return = self.get_prediction(window_data)
+             yield next_return
+
+     def backtest_strategy(self):
+         """
+         Performs a backtest of the strategy over
+         the entire dataset, plotting cumulative returns.
+         """
+         data = self.data
+         window_size = self.k
+         print(
+             f"Starting backtesting for {self.symbol}\n"
+             f"Window size {window_size}.\n"
+             f"Total iterations: {len(data) - window_size}.\n")
+         predictions_generator = self.generate_predictions()
+
+         positions = self.execute_trading_strategy(predictions_generator)
+
+         strategy_returns = np.array(
+             positions[:-1]) * data['log_return'].iloc[window_size+1:].values
+         buy_and_hold = data['log_return'].iloc[window_size+1:].values
+         buy_and_hold_returns = np.cumsum(buy_and_hold)
+         cumulative_returns = np.cumsum(strategy_returns)
+         dates = data.index[window_size+1:]
+         self.plot_cumulative_returns(
+             cumulative_returns, buy_and_hold_returns, dates)
+
+         print("\nBacktesting completed!")
+
+     # Function to plot the cumulative returns
+     def plot_cumulative_returns(self, strategy_returns, buy_and_hold_returns, dates):
+         """
+         Plots the cumulative returns of the ARIMA+GARCH strategy against
+         a buy-and-hold strategy.
+
+         Args:
+             strategy_returns (np.array): Cumulative returns from the strategy.
+             buy_and_hold_returns (np.array): Cumulative returns from a buy-and-hold strategy.
+             dates (pd.Index): The dates corresponding to the returns.
+         """
+         plt.figure(figsize=(14, 7))
+         plt.plot(dates, strategy_returns, label='ARIMA+GARCH', color='blue')
+         plt.plot(dates, buy_and_hold_returns, label='Buy & Hold', color='red')
+         plt.xlabel('Time')
+         plt.ylabel('Cumulative Returns')
+         plt.title(f'ARIMA+GARCH Strategy vs. Buy & Hold on ({self.symbol})')
+         plt.legend()
+         plt.grid(True)
+         plt.show()
+
+
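
`backtest_strategy` pairs `positions[:-1]` with the log returns shifted one slot forward, so each position is applied to the following bar's return. A toy sketch of that array alignment (illustrative numbers only):

```python
import numpy as np

log_returns = np.array([0.010, -0.020, 0.015, 0.005])
positions = np.array([1, -1, 1, 1])  # one position per bar
# Pair each position (except the last) with the following bar's return
strategy_returns = positions[:-1] * log_returns[1:]
print(np.cumsum(strategy_returns))  # cumulative strategy log return
```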
+ class OrnsteinUhlenbeck():
+     """
+     The Ornstein-Uhlenbeck process is a mathematical model
+     used to describe the behavior of a mean-reverting stochastic process.
+     We use it to model the price dynamics of an asset that tends
+     to revert to a long-term mean.
+
+     We estimate the drift (θ), volatility (σ), and long-term mean (μ)
+     based on historical price data; then we simulate the OU process
+     using the estimated parameters.
+
+     https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process
+     """
+
+     def __init__(
+             self, prices: np.ndarray,
+             returns: bool = True, timeframe: str = "D1"
+     ):
+         """
+         Initializes the OrnsteinUhlenbeck instance.
+
+         Args:
+             prices (np.ndarray): Historical close prices.
+             returns (bool): Indicates whether to simulate
+                 the returns or the raw price data.
+             timeframe (str): The time frame for the historical prices
+                 (1m, 5m, 15m, 30m, 1h, 4h, D1).
+         """
+         self.prices = prices
+         if returns:
+             series = pd.Series(self.prices)
+             self.returns = series.pct_change().dropna().values
+         else:
+             self.returns = self.prices
+
+         time_frame_mapping = {
+             '1m': 1 / (24 * 60),    # 1 minute intervals
+             '5m': 5 / (24 * 60),    # 5 minute intervals
+             '15m': 15 / (24 * 60),  # 15 minute intervals
+             '30m': 30 / (24 * 60),  # 30 minute intervals
+             '1h': 1 / 24,           # 1 hour intervals
+             '4h': 4 / 24,           # 4 hour intervals
+             'D1': 1,                # Daily intervals
+         }
+         if timeframe not in time_frame_mapping:
+             raise ValueError("Unsupported time frame")
+         self.tf = time_frame_mapping[timeframe]
+
+         params = self.estimate_parameters()
+         self.mu_hat = params[0]     # Mean (μ)
+         self.theta_hat = params[1]  # Drift (θ)
+         self.sigma_hat = params[2]  # Volatility (σ)
+         print(f'Estimated μ: {self.mu_hat}')
+         print(f'Estimated θ: {self.theta_hat}')
+         print(f'Estimated σ: {self.sigma_hat}')
+
+     def ornstein_uhlenbeck(self, mu, theta, sigma, dt, X0, n):
+         """
+         Simulates the Ornstein-Uhlenbeck process.
+
+         Args:
+             mu (float): Estimated long-term mean.
+             theta (float): Estimated drift.
+             sigma (float): Estimated volatility.
+             dt (float): Time step.
+             X0 (float): Initial value.
+             n (int): Number of time steps.
+
+         Returns:
+             np.ndarray: Simulated process.
+         """
+         x = np.zeros(n)
+         x[0] = X0
+         for t in range(1, n):
+             dW = np.random.normal(loc=0, scale=np.sqrt(dt))
+             # O-U process differential equation
+             x[t] = x[t-1] + (theta * (mu - x[t-1]) * dt) + (sigma * dW)
+             # dW is a Wiener process
+             # (theta * (mu - x[t-1]) * dt) represents the mean-reverting tendency
+             # (sigma * dW) represents the random volatility
+         return x
+
+     def estimate_parameters(self):
+         """
+         Estimates the mean-reverting parameters (μ, θ, σ)
+         using the negative log-likelihood.
+
+         Returns:
+             Tuple: Estimated μ, θ, and σ.
+         """
+         initial_guess = [0, 0.1, np.std(self.returns)]
+         result = minimize(
+             self._neg_log_likelihood, initial_guess, args=(self.returns,)
+         )
+         mu, theta, sigma = result.x
+         return mu, theta, sigma
+
+     def _neg_log_likelihood(self, params, returns):
+         """
+         Calculates the negative
+         log-likelihood for parameter estimation.
+
+         Args:
+             params (list): List of parameters [mu, theta, sigma].
+             returns (np.ndarray): Historical returns.
+
+         Returns:
+             float: Negative log-likelihood.
+         """
+         mu, theta, sigma = params
+         dt = self.tf
+         n = len(returns)
+         ou_simulated = self.ornstein_uhlenbeck(
+             mu, theta, sigma, dt, 0, n + 1
+         )
+         residuals = ou_simulated[1:n + 1] - returns
+         neg_ll = 0.5 * np.sum(
+             residuals**2
+         ) / sigma**2 + 0.5 * n * np.log(2 * np.pi * sigma**2)
+         return neg_ll
+
+     def simulate_process(self, rts=None, n=100, p=None):
+         """
+         Simulates the OU process multiple times.
+
+         Args:
+             rts (np.ndarray): Historical returns.
+             n (int): Number of simulations to perform.
+             p (int): Number of time steps.
+
+         Returns:
+             np.ndarray: 2D array representing simulated processes.
+         """
+         if rts is not None:
+             returns = rts
+         else:
+             returns = self.returns
+         if p is not None:
+             T = p
+         else:
+             T = len(returns)
+         dt = self.tf
+
+         dW_matrix = np.random.normal(
+             loc=0, scale=np.sqrt(dt), size=(n, T)
+         )
+         simulations_matrix = np.zeros((n, T))
+         simulations_matrix[:, 0] = returns[-1]
+
+         for t in range(1, T):
+             simulations_matrix[:, t] = (
+                 simulations_matrix[:, t-1] +
+                 self.theta_hat * (
+                     self.mu_hat - simulations_matrix[:, t-1]) * dt +
+                 self.sigma_hat * dW_matrix[:, t]
+             )
+         return simulations_matrix
+
+
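
A usage sketch for the new `OrnsteinUhlenbeck` class on synthetic prices, assuming it is importable from `bbstrader.tseries` (note that `__all__` above exports the name `OrnsteinUhlenbeckModel`):

```python
import numpy as np
from bbstrader.tseries import OrnsteinUhlenbeck

rng = np.random.default_rng(7)
prices = 100 * np.exp(np.cumsum(rng.normal(0.0, 0.01, 500)))  # synthetic walk
ou = OrnsteinUhlenbeck(prices, returns=True, timeframe="D1")  # prints μ, θ, σ
paths = ou.simulate_process(n=50)  # 50 simulated paths
print(paths.shape)
```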
+ class KalmanFilterModel():
+     """
+     Implements a Kalman Filter model, a recursive algorithm used for estimating
+     the state of a linear dynamic system from a series of noisy measurements.
+     It's designed to process market data, estimate dynamic parameters such as
+     the slope and intercept of price relationships, and the forecast error
+     and standard deviation of the predictions.
+
+     You can learn more here https://en.wikipedia.org/wiki/Kalman_filter
+     """
+
+     def __init__(self, tickers: list | tuple, **kwargs):
+         """
+         Initializes the Kalman Filter strategy.
+
+         Args:
+             tickers:
+                 A list or tuple of ticker symbols representing financial instruments.
+             kwargs: Keyword arguments for additional parameters,
+                 specifically `delta` and `vt`.
+         """
+         self.tickers = tickers
+         assert self.tickers is not None
+         self.latest_prices = np.array([-1.0, -1.0])
+         self.delta = kwargs.get("delta", 1e-4)
+         self.wt = self.delta/(1-self.delta) * np.eye(2)
+         self.vt = kwargs.get("vt", 1e-3)
+         self.theta = np.zeros(2)
+         self.P = np.zeros((2, 2))
+         self.R = None
+         self.kf = self._init_kalman()
+
+     def _init_kalman(self):
+         """
+         Initializes and returns a Kalman Filter configured
+         for the trading strategy. The filter is set up with initial
+         state and covariance, state transition matrix, process noise
+         and measurement noise covariances.
+         """
+         kf = KalmanFilter(dim_x=2, dim_z=1)
+         kf.x = np.zeros((2, 1))  # Initial state
+         kf.P = self.P            # Initial covariance
+         kf.F = np.eye(2)         # State transition matrix
+         kf.Q = self.wt           # Process noise covariance
+         kf.R = 1.                # Scalar measurement noise covariance
+
+         return kf
+
+     def calc_slope_intercep(self, prices: np.ndarray):
+         """
+         Calculates and returns the slope and intercept
+         of the relationship between the provided prices using the Kalman Filter.
+         This method updates the filter with the latest price and returns
+         the estimated slope and intercept.
+
+         Args:
+             prices: A numpy array of prices for two financial instruments.
+
+         Returns:
+             A tuple containing the slope and intercept of the relationship.
+         """
+         kf = self.kf
+         kf.H = np.array([[prices[1], 1.0]])
+         kf.predict()
+         kf.update(prices[0])
+         slope = kf.x.copy().flatten()[0]
+         intercept = kf.x.copy().flatten()[1]
+
+         return slope, intercept
+
+     def calculate_etqt(self, prices: np.ndarray):
+         """
+         Calculates the forecast error and standard deviation of the predictions
+         using the Kalman Filter.
+
+         Args:
+             prices: A numpy array of prices for two financial instruments.
+
+         Returns:
+             A tuple containing the forecast error and standard deviation of the predictions.
+         """
+
+         self.latest_prices[0] = prices[0]
+         self.latest_prices[1] = prices[1]
+
+         if all(self.latest_prices > -1.0):
+             slope, intercept = self.calc_slope_intercep(self.latest_prices)
+
+             self.theta[0] = slope
+             self.theta[1] = intercept
+
+             # Create the observation matrix of the latest prices
+             # of Y and the intercept value (1.0) as well as the
+             # scalar value of the latest price from X
+             F = np.asarray([self.latest_prices[0], 1.0]).reshape((1, 2))
+             y = self.latest_prices[1]
+
+             # The prior value of the states {\theta_t} is
+             # distributed as a multivariate Gaussian with
+             # mean a_t and variance-covariance {R_t}
+             if self.R is not None:
+                 self.R = self.C + self.wt
+             else:
+                 self.R = np.zeros((2, 2))
+
+             # Calculate the Kalman Filter update
+             # ---------------------------------
+             # Calculate prediction of new observation
+             # as well as forecast error of that prediction
+             yhat = F.dot(self.theta)
+             et = y - yhat
+
+             # {Q_t} is the variance of the prediction of
+             # observations and hence sqrt_Qt is the
+             # standard deviation of the predictions
+             Qt = F.dot(self.R).dot(F.T) + self.vt
+             sqrt_Qt = np.sqrt(Qt)
+
+             # The posterior value of the states {\theta_t} is
+             # distributed as a multivariate Gaussian with mean
+             # {m_t} and variance-covariance {C_t}
+             At = self.R.dot(F.T) / Qt
+             self.theta = self.theta + At.flatten() * et
+             self.C = self.R - At * F.dot(self.R)
+             return (et, sqrt_Qt)
+         else:
+             return None
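
Finally, a sketch of how `KalmanFilterModel` might be driven bar by bar: feed price pairs to `calculate_etqt` and normalize the forecast error by its standard deviation (synthetic, roughly linearly related series for illustration):

```python
import numpy as np
from bbstrader.tseries import KalmanFilterModel

rng = np.random.default_rng(1)
x = 100 + np.cumsum(rng.normal(0.0, 1.0, 300))  # synthetic price series X
y = 1.5 * x + rng.normal(0.0, 1.0, 300)         # noisy linear function of X

kf = KalmanFilterModel(("X", "Y"))
for px, py in zip(x, y):
    out = kf.calculate_etqt(np.array([px, py]))
    if out is None:
        continue  # filter not primed yet
    et, sqrt_qt = out
    z = (et / sqrt_qt).item()  # forecast error in standard deviations
print(f"final z-score: {z:.2f}")
```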