stratpy-lib 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stratpy_lib-0.1.0/LICENSE +21 -0
- stratpy_lib-0.1.0/PKG-INFO +23 -0
- stratpy_lib-0.1.0/README.md +2 -0
- stratpy_lib-0.1.0/pyproject.toml +35 -0
- stratpy_lib-0.1.0/setup.cfg +4 -0
- stratpy_lib-0.1.0/src/stratpy/__init__.py +37 -0
- stratpy_lib-0.1.0/src/stratpy/backtest.py +129 -0
- stratpy_lib-0.1.0/src/stratpy/data.py +76 -0
- stratpy_lib-0.1.0/src/stratpy/indicators.py +179 -0
- stratpy_lib-0.1.0/src/stratpy/strategies.py +161 -0
- stratpy_lib-0.1.0/src/stratpy/utils.py +32 -0
- stratpy_lib-0.1.0/src/stratpy_lib.egg-info/PKG-INFO +23 -0
- stratpy_lib-0.1.0/src/stratpy_lib.egg-info/SOURCES.txt +15 -0
- stratpy_lib-0.1.0/src/stratpy_lib.egg-info/dependency_links.txt +1 -0
- stratpy_lib-0.1.0/src/stratpy_lib.egg-info/requires.txt +9 -0
- stratpy_lib-0.1.0/src/stratpy_lib.egg-info/top_level.txt +1 -0
- stratpy_lib-0.1.0/tests/test_indicators.py +77 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 quantbirrd
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: stratpy-lib
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A modular python library for building algorithmic trading strategies.
|
|
5
|
+
Author-email: Albert Akinola <albert.akinola@outlook.com>
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.8
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: pandas
|
|
13
|
+
Requires-Dist: numpy
|
|
14
|
+
Requires-Dist: yfinance
|
|
15
|
+
Requires-Dist: matplotlib
|
|
16
|
+
Provides-Extra: dev
|
|
17
|
+
Requires-Dist: pytest; extra == "dev"
|
|
18
|
+
Requires-Dist: build; extra == "dev"
|
|
19
|
+
Requires-Dist: twine; extra == "dev"
|
|
20
|
+
Dynamic: license-file
|
|
21
|
+
|
|
22
|
+
# stratpy
|
|
23
|
+
A modular Python library for building and testing algorithmic trading strategies quickly, without writing lots of boilerplate code from scratch.
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
[project]
|
|
7
|
+
name = "stratpy-lib"
|
|
8
|
+
version = "0.1.0"
|
|
9
|
+
authors = [
|
|
10
|
+
{ name="Albert Akinola", email="albert.akinola@outlook.com" },
|
|
11
|
+
]
|
|
12
|
+
description = "A modular python library for building algorithmic trading strategies."
|
|
13
|
+
readme = "README.md"
|
|
14
|
+
requires-python = ">=3.8"
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
"License :: OSI Approved :: MIT License",
|
|
18
|
+
"Operating System :: OS Independent",
|
|
19
|
+
]
|
|
20
|
+
dependencies = [
|
|
21
|
+
"pandas",
|
|
22
|
+
"numpy",
|
|
23
|
+
"yfinance",
|
|
24
|
+
"matplotlib"
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[project.optional-dependencies]
|
|
28
|
+
dev = [
|
|
29
|
+
"pytest",
|
|
30
|
+
"build",
|
|
31
|
+
"twine"
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[tool.setuptools.packages.find]
|
|
35
|
+
where = ["src"]
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Stratpy: A modular Python library for building and testing algorithmic trading strategies.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
__version__ = "0.1.0"
|
|
6
|
+
|
|
7
|
+
# Import data functions
|
|
8
|
+
from .data import data, clean
|
|
9
|
+
|
|
10
|
+
# Import indicators
|
|
11
|
+
from .indicators import sma, ema, macd, rsi, bb, atr, vwap
|
|
12
|
+
|
|
13
|
+
# Import strategies (classes and legacy wrappers)
|
|
14
|
+
from .strategies import BaseStrategy, MACrossover, RSIReversion, BollingerBreakout, mac_strategy, rsi_reversion
|
|
15
|
+
|
|
16
|
+
# Import backtesting engine
|
|
17
|
+
from .backtest import runstrat
|
|
18
|
+
|
|
19
|
+
__all__ = [
    # Data acquisition and cleaning
    "data",
    "clean",
    # Technical indicators
    "sma", "ema", "macd", "rsi", "bb", "atr", "vwap",
    # Strategies: class-based API plus legacy functional wrappers
    "BaseStrategy", "MACrossover", "RSIReversion", "BollingerBreakout",
    "mac_strategy", "rsi_reversion",
    # Backtesting engine
    "runstrat",
    # Package metadata
    "__version__",
]
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import matplotlib.dates as mdates
|
|
5
|
+
from typing import Optional
|
|
6
|
+
from .utils import validate_dataframe
|
|
7
|
+
|
|
8
|
+
def _annualized_sharpe(returns: pd.Series, periods_per_year: int = 252) -> float:
    """Return the annualized Sharpe ratio of per-bar ``returns`` (risk-free rate 0).

    NaN returns are ignored. Returns 0.0 when the ratio is undefined, i.e. when the
    standard deviation is zero or NaN (for example, fewer than two valid returns).
    """
    valid = returns.dropna()
    volatility = valid.std()
    if volatility == 0 or np.isnan(volatility):
        return 0.0
    return float(valid.mean() / volatility * np.sqrt(periods_per_year))


def _max_drawdown(cumulative: pd.Series) -> float:
    """Return the largest peak-to-trough decline of a growth curve as a fraction (<= 0)."""
    running_peak = cumulative.cummax()
    with np.errstate(divide='ignore', invalid='ignore'):
        # Guard against a non-positive running peak, where the ratio would be meaningless.
        drawdown = np.where(running_peak <= 0, 0.0, (cumulative / running_peak) - 1.0)
    return float(drawdown.min())


def _print_report(
    market_return: float,
    strategy_return: float,
    market_sharpe: float,
    strategy_sharpe: float,
    market_dd: float,
    strategy_dd: float,
) -> None:
    """Print the market-vs-strategy metrics table (returns in %, drawdowns as fractions)."""
    rows = [
        ('Total Return', f"{market_return:+.2f}%", f"{strategy_return:+.2f}%"),
        ('Sharpe Ratio', f"{market_sharpe:.2f}", f"{strategy_sharpe:.2f}"),
        ('Max Drawdown', f"{market_dd * 100:+.2f}%", f"{strategy_dd * 100:+.2f}%"),
    ]
    print("\n" + "=" * 50)
    print(" STRATPY BACKTEST RESULTS")
    print("=" * 50)
    print(f"{'Metric':<20}{'Market Buy & Hold':<18}{'Strategy':<12}")
    print("-" * 50)
    for label, market_cell, strategy_cell in rows:
        print(f"{label:<20}{market_cell:<18}{strategy_cell:<12}")
    print("=" * 50 + "\n")


def _plot_performance(df: pd.DataFrame, plot_path: Optional[str], show: bool) -> None:
    """Plot cumulative strategy vs. market curves, then optionally save and/or display them."""
    fig, ax = plt.subplots(figsize=(12, 6), dpi=100)
    ax.set_facecolor('#f8fafc')

    # Market: muted slate line; strategy: bolder teal line.
    ax.plot(df.index, df['Cumulative_Market'], label='Market Buy & Hold', color='#94a3b8', alpha=0.8, linewidth=1.5)
    ax.plot(df.index, df['Cumulative_Strategy'], label='Stratpy Strategy', color='#0f766e', linewidth=2.5)

    ax.set_title('Stratpy: Cumulative Performance vs. Market Buy & Hold', fontsize=13, fontweight='bold', pad=15, color='#1e293b')
    ax.set_xlabel('Date', fontsize=10, labelpad=10, color='#334155')
    ax.set_ylabel('Growth of $1.00 (Cumulative Return)', fontsize=10, labelpad=10, color='#334155')

    # Date tick labels only make sense for a DatetimeIndex.
    if isinstance(df.index, pd.DatetimeIndex):
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        fig.autofmt_xdate(rotation=30)

    ax.grid(True, linestyle='--', alpha=0.5, color='#cbd5e1')
    for spine in ('top', 'right'):
        ax.spines[spine].set_visible(False)
    ax.spines['left'].set_color('#cbd5e1')
    ax.spines['bottom'].set_color('#cbd5e1')
    ax.legend(loc='upper left', frameon=True, facecolor='white', edgecolor='#e2e8f0', fontsize=9)

    fig.tight_layout()

    if plot_path:
        fig.savefig(plot_path, dpi=150)
        print(f"Stratpy: Performance plot saved successfully to: {plot_path}")

    if show:
        plt.show()
    else:
        # Release the figure so repeated headless runs do not accumulate open figures.
        plt.close(fig)


def runstrat(
    df: pd.DataFrame,
    column: str = 'Close',
    plot_path: Optional[str] = None,
    *,
    show: bool = True,
    periods_per_year: int = 252,
) -> pd.DataFrame:
    """
    Simulates the strategy execution, prints performance metrics (total return,
    annualized Sharpe Ratio and Maximum Drawdown), and renders a Matplotlib
    comparison graph of the strategy against buy-and-hold.

    Parameters:
        df (pd.DataFrame): DataFrame with historical price data and a generated 'Signal' column.
            The signal on row t is applied to the price change from row t to row t+1.
        column (str): The column containing asset prices. Defaults to 'Close'.
        plot_path (Optional[str]): If provided, saves the performance plot to this file path.
        show (bool): Keyword-only. If True (default), calls ``plt.show()`` to display the plot.
            If False, the figure is only saved (when ``plot_path`` is given) and then closed;
            with neither ``show`` nor ``plot_path`` no figure is created.
        periods_per_year (int): Keyword-only. Number of bars per year used to annualize the
            Sharpe Ratio. Defaults to 252 (daily bars).

    Returns:
        pd.DataFrame: A copy of the input DataFrame with 'Market_Returns', 'Strategy_Returns',
        'Cumulative_Market' and 'Cumulative_Strategy' columns added.

    Raises:
        TypeError: If ``df`` is not a pandas DataFrame.
        ValueError: If 'Signal' or the specified price column is missing, if ``df`` is empty,
            or if ``periods_per_year`` is not positive.
    """
    validate_dataframe(df, [column, 'Signal'], "runstrat")
    if df.empty:
        raise ValueError("Stratpy Error: 'runstrat' requires a non-empty DataFrame.")
    if periods_per_year <= 0:
        raise ValueError("Stratpy Error: 'periods_per_year' must be a positive number.")

    df = df.copy()

    # Per-bar market returns; the strategy earns them using the *previous* bar's
    # signal, so a signal only pays off after it is known (no look-ahead bias).
    df['Market_Returns'] = df[column].pct_change()
    df['Strategy_Returns'] = df['Signal'].shift(1) * df['Market_Returns']

    # Growth of 1.0; undefined (NaN) returns count as 0% for that bar.
    df['Cumulative_Market'] = (1 + df['Market_Returns'].fillna(0)).cumprod()
    df['Cumulative_Strategy'] = (1 + df['Strategy_Returns'].fillna(0)).cumprod()

    total_market_return = (df['Cumulative_Market'].iloc[-1] - 1) * 100
    total_strategy_return = (df['Cumulative_Strategy'].iloc[-1] - 1) * 100

    market_sharpe = _annualized_sharpe(df['Market_Returns'], periods_per_year)
    strat_sharpe = _annualized_sharpe(df['Strategy_Returns'], periods_per_year)

    max_dd_mkt = _max_drawdown(df['Cumulative_Market'])
    max_dd_strat = _max_drawdown(df['Cumulative_Strategy'])

    _print_report(
        total_market_return,
        total_strategy_return,
        market_sharpe,
        strat_sharpe,
        max_dd_mkt,
        max_dd_strat,
    )

    if show or plot_path:
        _plot_performance(df, plot_path, show)

    return df
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import yfinance as yf
|
|
3
|
+
import os
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
def data(ticker: str, period: str = '1y', interval: str = '1d') -> pd.DataFrame:
    """
    Pulls historical market data from Yahoo Finance.

    Parameters:
        ticker (str): The stock or crypto ticker symbol (e.g., 'AAPL', 'BTC-USD').
            Surrounding whitespace is stripped before the request.
        period (str): The time period to pull (e.g., '1mo', '1y', 'max').
        interval (str): The timeframe of the candles (e.g., '1d', '1h', '15m').

    Returns:
        pd.DataFrame: The DataFrame returned by ``yfinance.Ticker(...).history(...)``
        (expected to contain Open, High, Low, Close, Volume).

    Raises:
        ValueError: If ticker is not a string, is empty or whitespace-only, or if no
            data was found for the ticker symbol.
    """
    # Reject non-strings and blank symbols before any network request is made.
    if not isinstance(ticker, str) or not ticker.strip():
        raise ValueError("Stratpy Error: Ticker symbol must be a non-empty string.")
    ticker = ticker.strip()

    print(f"Stratpy: Fetching {period} of {interval} data for {ticker}...")

    df = yf.Ticker(ticker).history(period=period, interval=interval)

    # An empty result means nothing was returned for this symbol/period/interval,
    # typically because the symbol is invalid.
    if df.empty:
        raise ValueError(
            f"No data found for ticker '{ticker}'. Please check the symbol, period, "
            "and interval, and try again."
        )

    return df
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def clean(file_path: Union[str, os.PathLike]) -> pd.DataFrame:
    """
    Loads a local CSV file, sets a chronologically sorted Date index, and fills missing data.

    The first column found among 'Date', 'date', 'Datetime', 'datetime', 'Timestamp'
    and 'timestamp' is parsed with ``pd.to_datetime``. Rows without a timestamp are
    dropped, and the rows are stable-sorted by date. Only then are gaps forward-filled,
    so a missing value is filled from an earlier row, never a later one. Rows that still
    contain NaN (e.g. before the first observation) are dropped. If no date column
    exists, the rows keep their file order.

    Parameters:
        file_path (str or PathLike): The relative or absolute path to the CSV file.

    Returns:
        pd.DataFrame: A cleaned pandas DataFrame ready for Stratpy indicators.

    Raises:
        FileNotFoundError: If the file path does not point to an existing regular file.
        ValueError: If the file path is empty.
    """
    if not file_path:
        raise ValueError("Stratpy Error: File path must not be empty.")

    # isfile (not exists) so directories are rejected here instead of failing inside read_csv.
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"Stratpy Error: The file '{file_path}' was not found.")

    print(f"Stratpy: Cleaning {file_path}...")
    df = pd.read_csv(file_path)

    date_columns = ['Date', 'date', 'Datetime', 'datetime', 'Timestamp', 'timestamp']
    date_col = next((col for col in date_columns if col in df.columns), None)
    if date_col is not None:
        df[date_col] = pd.to_datetime(df[date_col])
        # Order chronologically *before* forward-filling so that prices are carried
        # forward in time only (no look-ahead) and missing dates are not duplicated.
        df = (
            df.dropna(subset=[date_col])
            .set_index(date_col)
            .sort_index(kind='mergesort')
        )

    # Carry the last known value forward, then drop rows that still have gaps.
    return df.ffill().dropna()
|
|
76
|
+
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
from .utils import validate_dataframe
|
|
4
|
+
|
|
5
|
+
def sma(df: pd.DataFrame, window: int = 20, column: str = 'Close') -> pd.DataFrame:
    """
    Adds a Simple Moving Average (SMA) column to a copy of ``df``.

    Parameters:
        df (pd.DataFrame): DataFrame containing historical market data.
        window (int): Number of periods averaged over.
        column (str): The column to average. Defaults to 'Close'.

    Returns:
        pd.DataFrame: A new DataFrame copy with an added ``SMA_{window}`` column. Rows
        before a full window of data is available are NaN.
    """
    validate_dataframe(df, [column], "sma")
    result = df.copy()
    result[f'SMA_{window}'] = result[column].rolling(window).mean()
    return result
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def ema(df: pd.DataFrame, window: int = 20, column: str = 'Close') -> pd.DataFrame:
    """
    Adds an Exponential Moving Average (EMA) column to a copy of ``df``.

    The average is the recursive form (``adjust=False``) with smoothing factor
    ``2 / (window + 1)``, i.e. ``span=window``.

    Parameters:
        df (pd.DataFrame): DataFrame containing historical market data.
        window (int): The span of the exponential average.
        column (str): The column to average. Defaults to 'Close'.

    Returns:
        pd.DataFrame: A new DataFrame copy with an added ``EMA_{window}`` column.
    """
    validate_dataframe(df, [column], "ema")
    result = df.copy()
    smoother = result[column].ewm(span=window, adjust=False)
    result[f'EMA_{window}'] = smoother.mean()
    return result
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def macd(df: pd.DataFrame, fast: int = 12, slow: int = 26, signal: int = 9, column: str = 'Close') -> pd.DataFrame:
    """
    Adds Moving Average Convergence Divergence (MACD) columns to a copy of ``df``.

    * ``MACD``        = EMA(fast) - EMA(slow) of ``column``
    * ``MACD_Signal`` = EMA(signal) of ``MACD``
    * ``MACD_Hist``   = ``MACD`` - ``MACD_Signal``

    All EMAs use ``adjust=False``.

    Parameters:
        df (pd.DataFrame): DataFrame containing historical market data.
        fast (int): Span of the fast EMA. Defaults to 12.
        slow (int): Span of the slow EMA. Defaults to 26.
        signal (int): Span of the signal-line EMA. Defaults to 9.
        column (str): The column to run MACD on. Defaults to 'Close'.

    Returns:
        pd.DataFrame: A new DataFrame copy containing 'MACD', 'MACD_Signal', and 'MACD_Hist'.
    """
    validate_dataframe(df, [column], "macd")
    out = df.copy()
    price = out[column]

    def _ema(series: pd.Series, span: int) -> pd.Series:
        return series.ewm(span=span, adjust=False).mean()

    macd_line = _ema(price, fast) - _ema(price, slow)
    signal_line = _ema(macd_line, signal)

    out['MACD'] = macd_line
    out['MACD_Signal'] = signal_line
    out['MACD_Hist'] = macd_line - signal_line
    return out
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def rsi(df: pd.DataFrame, window: int = 14, column: str = 'Close') -> pd.DataFrame:
    """
    Calculates the Relative Strength Index (RSI) using vectorized pandas/numpy operations.

    Average gain and average loss are simple rolling means over ``window`` price
    changes (no Wilder smoothing). RSI = 100 * avg_gain / (avg_gain + avg_loss),
    which avoids dividing by a zero average loss. When both averages are zero
    (flat prices), RSI is defined as 50.0.

    The first valid RSI value appears once ``window`` real price changes are
    available (row index ``window``). Windows that contain a missing price change
    produce NaN rather than treating the gap as a zero move.

    Parameters:
        df (pd.DataFrame): DataFrame containing historical market data.
        window (int): The number of price changes averaged for gains/losses.
        column (str): The column to run RSI on. Defaults to 'Close'.

    Returns:
        pd.DataFrame: A new DataFrame copy with the added ``RSI_{window}`` column.
    """
    validate_dataframe(df, [column], "rsi")
    df = df.copy()
    delta = df[column].diff()

    # Split each change into its gain and loss parts. clip() keeps NaN as NaN, so the
    # undefined first diff is not counted as a real zero-size move in the first window.
    gain = delta.clip(lower=0).rolling(window=window).mean()
    loss = (-delta).clip(lower=0).rolling(window=window).mean()

    total_change = gain + loss
    with np.errstate(divide='ignore', invalid='ignore'):
        df[f'RSI_{window}'] = np.where(
            total_change == 0,
            50.0,
            100.0 * gain / total_change
        )

    return df
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def bb(df: pd.DataFrame, window: int = 20, num_std: int = 2, column: str = 'Close') -> pd.DataFrame:
    """
    Adds Bollinger Bands (middle, upper and lower) to a copy of ``df``.

    The middle band is the rolling mean of ``column``. The outer bands sit
    ``num_std`` rolling standard deviations (pandas ``rolling().std()``, ddof=1)
    above and below it.

    Parameters:
        df (pd.DataFrame): DataFrame containing historical market data.
        window (int): Number of periods for the rolling mean and standard deviation.
        num_std (int): Standard-deviation multiplier for the upper and lower bands.
        column (str): The column to run Bollinger Bands on. Defaults to 'Close'.

    Returns:
        pd.DataFrame: A new DataFrame copy with 'BB_Mid', 'BB_Upper', and 'BB_Lower'.
    """
    validate_dataframe(df, [column], "bb")
    out = df.copy()

    rolling_view = out[column].rolling(window=window)
    middle = rolling_view.mean()
    half_width = rolling_view.std() * num_std

    out['BB_Mid'] = middle
    out['BB_Upper'] = middle + half_width
    out['BB_Lower'] = middle - half_width
    return out
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def atr(df: pd.DataFrame, window: int = 14) -> pd.DataFrame:
    """
    Adds the Average True Range (ATR), a volatility measure, to a copy of ``df``.

    The True Range of each bar is the largest of High - Low, |High - previous Close|
    and |Low - previous Close| (missing terms are skipped, so the first bar uses
    High - Low). The ATR is a simple rolling mean of the True Range.

    Parameters:
        df (pd.DataFrame): DataFrame with 'High', 'Low' and 'Close' columns.
        window (int): The number of bars averaged.

    Returns:
        pd.DataFrame: A new DataFrame copy with the added ``ATR_{window}`` column.
    """
    validate_dataframe(df, ['High', 'Low', 'Close'], "atr")
    out = df.copy()

    prev_close = out['Close'].shift()
    candidates = pd.concat(
        [
            out['High'] - out['Low'],
            (out['High'] - prev_close).abs(),
            (out['Low'] - prev_close).abs(),
        ],
        axis=1,
    )
    true_range = candidates.max(axis=1)

    out[f'ATR_{window}'] = true_range.rolling(window=window).mean()
    return out
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def vwap(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds the Volume Weighted Average Price (VWAP) to a copy of ``df``.

    VWAP = cumulative(typical_price * Volume) / cumulative(Volume), where
    typical_price = (High + Low + Close) / 3. The accumulation runs from the first
    row of ``df``; it is not reset per trading session. Rows where cumulative volume
    is still zero get NaN instead of a division-by-zero result.

    Parameters:
        df (pd.DataFrame): DataFrame with 'High', 'Low', 'Close' and 'Volume' columns.

    Returns:
        pd.DataFrame: A new DataFrame copy with the added 'VWAP' column.
    """
    validate_dataframe(df, ['High', 'Low', 'Close', 'Volume'], "vwap")
    out = df.copy()

    volume = out['Volume']
    typical = (out['High'] + out['Low'] + out['Close']) / 3
    running_volume = volume.cumsum()
    running_value = (typical * volume).cumsum()

    with np.errstate(divide='ignore', invalid='ignore'):
        out['VWAP'] = np.where(running_volume == 0, np.nan, running_value / running_volume)

    return out
|
|
179
|
+
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from .utils import validate_dataframe
|
|
5
|
+
from .indicators import bb, rsi
|
|
6
|
+
|
|
7
|
+
class BaseStrategy(ABC):
    """
    Abstract interface shared by every Stratpy trading strategy.

    Subclasses must implement ``generate_signals``. Instances are also callable:
    ``strategy(df)`` is shorthand for ``strategy.generate_signals(df)``.
    """

    @abstractmethod
    def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Computes the strategy's indicators and position signals.

        Parameters:
            df (pd.DataFrame): DataFrame containing historical market data.

        Returns:
            pd.DataFrame: A copy of the DataFrame with a 'Signal' column:
            1 = buy/long, -1 = sell/short, 0 = hold/flat.
        """
        ...

    def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Runs the strategy on ``df``; identical to ``self.generate_signals(df)``.

        Example:
            strategy = MACrossover()
            df = strategy(df)
        """
        signals = self.generate_signals(df)
        return signals
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class MACrossover(BaseStrategy):
    """
    Moving Average Crossover Strategy.

    Signal = 1 (long) while the short simple moving average is above the long one,
    and Signal = -1 (short) otherwise (including ties), so the position flips at each
    crossover. Rows where either average is still warming up (NaN) get Signal = 0.

    Parameters:
        short_window (int): Look-back of the fast SMA. Defaults to 20.
        long_window (int): Look-back of the slow SMA. Defaults to 50.
        column (str): Price column the averages are computed on. Defaults to 'Close'.
    """

    def __init__(self, short_window: int = 20, long_window: int = 50, column: str = 'Close'):
        self.short_window = short_window
        self.long_window = long_window
        self.column = column

    def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
        validate_dataframe(df, [self.column], self.__class__.__name__)
        out = df.copy()

        prices = out[self.column]
        fast_avg = prices.rolling(window=self.short_window).mean()
        slow_avg = prices.rolling(window=self.long_window).mean()
        warming_up = fast_avg.isna() | slow_avg.isna()

        out['Signal'] = np.where(fast_avg > slow_avg, 1, -1)
        out.loc[warming_up, 'Signal'] = 0
        return out
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class RSIReversion(BaseStrategy):
    """
    RSI Mean Reversion Strategy.

    Enters long (Signal = 1) when RSI falls below ``lower_bound`` (oversold) and
    short (Signal = -1) when RSI rises above ``upper_bound`` (overbought). The most
    recent entry is held until the opposite one fires. Rows where RSI is NaN
    (warm-up) are flat (Signal = 0).

    Parameters:
        lower_bound (int): Oversold threshold. Defaults to 30.
        upper_bound (int): Overbought threshold. Defaults to 70.
        column (str): Price column used when RSI has to be computed. Defaults to 'Close'.
        window (int): RSI look-back; selects (or creates) the ``RSI_{window}`` column. Defaults to 14.
    """

    def __init__(self, lower_bound: int = 30, upper_bound: int = 70, column: str = 'Close', window: int = 14):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.column = column
        self.window = window

    def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
        validate_dataframe(df, [self.column], self.__class__.__name__)
        out = df.copy()

        rsi_col = f'RSI_{self.window}'
        # An existing RSI column for this window is reused as-is; otherwise compute it.
        if rsi_col not in out.columns:
            out = rsi(out, window=self.window, column=self.column)
        rsi_values = out[rsi_col]

        # The short condition is listed first so it wins if both hold.
        entries = np.select(
            [rsi_values > self.upper_bound, rsi_values < self.lower_bound],
            [-1, 1],
            default=0,
        )
        # Zeros mean "no new entry": turn them into gaps and carry the last entry forward.
        held = (
            pd.Series(entries, index=out.index)
            .replace(0, np.nan)
            .ffill()
            .fillna(0)
            .astype(int)
        )
        out['Signal'] = held.mask(rsi_values.isna().to_numpy(), 0)
        return out
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class BollingerBreakout(BaseStrategy):
    """
    Bollinger Bands Breakout Strategy.

    Buys (Signal = 1) when the price closes ABOVE the upper Bollinger Band.
    Sells/Shorts (Signal = -1) when the price closes BELOW the lower Bollinger Band.
    The latest breakout signal is held until the opposite breakout occurs. Rows where
    the bands are still warming up (NaN) get Signal = 0.

    The bands are always computed from this strategy's ``window``, ``num_std`` and
    ``column``. Any existing 'BB_Mid'/'BB_Upper'/'BB_Lower' columns are replaced in
    the returned copy, because those names do not record the settings they were
    built with.

    Parameters:
        window (int): Rolling window for the bands. Defaults to 20.
        num_std (int): Standard-deviation multiplier for the bands. Defaults to 2.
        column (str): Price column to test against the bands. Defaults to 'Close'.
    """
    def __init__(self, window: int = 20, num_std: int = 2, column: str = 'Close'):
        self.window = window
        self.num_std = num_std
        self.column = column

    def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
        validate_dataframe(df, [self.column], self.__class__.__name__)

        # Always (re)compute the bands so this strategy's parameters take effect;
        # bb() works on and returns a copy, leaving the caller's DataFrame untouched.
        df = bb(df, window=self.window, num_std=self.num_std, column=self.column)

        df['Signal'] = 0
        df.loc[df[self.column] > df['BB_Upper'], 'Signal'] = 1
        df.loc[df[self.column] < df['BB_Lower'], 'Signal'] = -1

        # Hold each breakout signal until the opposite breakout occurs.
        df['Signal'] = df['Signal'].replace(0, np.nan).ffill().fillna(0).astype(int)

        # Stay flat during the warm-up period, where the bands are NaN.
        df.loc[df['BB_Upper'].isna() | df['BB_Lower'].isna(), 'Signal'] = 0

        return df
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# --- Legacy Functional Wrappers (Backward Compatibility) ---
|
|
135
|
+
|
|
136
|
+
def mac_strategy(df: pd.DataFrame, short_window: int = 20, long_window: int = 50, column: str = 'Close') -> pd.DataFrame:
    """
    Legacy functional wrapper for the Moving Average Crossover Strategy.

    Equivalent to ``MACrossover(short_window, long_window, column).generate_signals(df)``.
    """
    strategy = MACrossover(short_window=short_window, long_window=long_window, column=column)
    signals = strategy.generate_signals(df)
    return signals
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def rsi_reversion(df: pd.DataFrame, rsi_col: str, lower_bound: int = 30, upper_bound: int = 70) -> pd.DataFrame:
    """
    Legacy functional wrapper for the RSI Mean Reversion Strategy.

    Unlike ``RSIReversion``, this function never computes RSI: the values must
    already be present in ``rsi_col``.

    Parameters:
        df (pd.DataFrame): DataFrame containing the precomputed RSI column.
        rsi_col (str): Name of the RSI column to read.
        lower_bound (int): Oversold threshold (long entry). Defaults to 30.
        upper_bound (int): Overbought threshold (short entry). Defaults to 70.

    Returns:
        pd.DataFrame: A copy of ``df`` with a 'Signal' column (1 long, -1 short, 0 flat).
        Entries are held until the opposite entry; rows with NaN RSI are flat.
    """
    validate_dataframe(df, [rsi_col], "rsi_reversion")
    out = df.copy()
    level = out[rsi_col]

    signal = pd.Series(0, index=out.index)
    signal[(level < lower_bound).to_numpy()] = 1
    # Applied second, so a short entry overrides a long one if both conditions hold.
    signal[(level > upper_bound).to_numpy()] = -1

    # Carry the last non-zero entry forward.
    signal = signal.replace(0, np.nan).ffill().fillna(0).astype(int)
    # No position while RSI is undefined (warm-up).
    signal[level.isna().to_numpy()] = 0

    out['Signal'] = signal
    return out
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from typing import List
|
|
3
|
+
|
|
4
|
+
class MissingColumnsError(ValueError):
    """
    Raised when a DataFrame lacks one or more columns that an indicator or
    strategy needs. It subclasses ValueError, so ``except ValueError``
    handlers also catch it.
    """


def validate_dataframe(df: pd.DataFrame, required_columns: List[str], caller_name: str) -> None:
    """
    Checks that ``df`` is a pandas DataFrame containing every column in ``required_columns``.

    Parameters:
        df (pd.DataFrame): The object to validate.
        required_columns (List[str]): Column names that must be present.
        caller_name (str): Name of the calling function or class, used in error messages.

    Raises:
        TypeError: If ``df`` is not a pandas DataFrame.
        MissingColumnsError: If any required column is absent. The message lists the
            missing columns (in the order requested) and the available ones.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError(
            f"Stratpy Error: Input to '{caller_name}' must be a pandas DataFrame, got {type(df).__name__}."
        )

    present = df.columns
    missing_cols = [name for name in required_columns if name not in present]
    if not missing_cols:
        return

    raise MissingColumnsError(
        f"Stratpy Error: '{caller_name}' requires the following column(s): {missing_cols}. "
        f"Please verify your DataFrame contains these columns. Available columns: {list(df.columns)}"
    )
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: stratpy-lib
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A modular Python library for building algorithmic trading strategies.
|
|
5
|
+
Author-email: Albert Akinola <albert.akinola@outlook.com>
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
8
|
+
Classifier: Operating System :: OS Independent
|
|
9
|
+
Requires-Python: >=3.8
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: pandas
|
|
13
|
+
Requires-Dist: numpy
|
|
14
|
+
Requires-Dist: yfinance
|
|
15
|
+
Requires-Dist: matplotlib
|
|
16
|
+
Provides-Extra: dev
|
|
17
|
+
Requires-Dist: pytest; extra == "dev"
|
|
18
|
+
Requires-Dist: build; extra == "dev"
|
|
19
|
+
Requires-Dist: twine; extra == "dev"
|
|
20
|
+
Dynamic: license-file
|
|
21
|
+
|
|
22
|
+
# stratpy
|
|
23
|
+
A modular Python library for building and testing algorithmic trading strategies quickly, without writing tons of code from scratch.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/stratpy/__init__.py
|
|
5
|
+
src/stratpy/backtest.py
|
|
6
|
+
src/stratpy/data.py
|
|
7
|
+
src/stratpy/indicators.py
|
|
8
|
+
src/stratpy/strategies.py
|
|
9
|
+
src/stratpy/utils.py
|
|
10
|
+
src/stratpy_lib.egg-info/PKG-INFO
|
|
11
|
+
src/stratpy_lib.egg-info/SOURCES.txt
|
|
12
|
+
src/stratpy_lib.egg-info/dependency_links.txt
|
|
13
|
+
src/stratpy_lib.egg-info/requires.txt
|
|
14
|
+
src/stratpy_lib.egg-info/top_level.txt
|
|
15
|
+
tests/test_indicators.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
stratpy
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import unittest
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
# Ensure the local 'src' directory is prioritized in the Python path
|
|
8
|
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
|
9
|
+
|
|
10
|
+
import stratpy as sp
|
|
11
|
+
from stratpy.utils import MissingColumnsError
|
|
12
|
+
|
|
13
|
+
class TestIndicators(unittest.TestCase):
    """Unit tests for the indicator functions exposed by the ``stratpy`` package."""

    def setUp(self):
        # Ten consecutive daily bars with steadily rising prices and volume:
        # Open 100..109, High = Open + 2, Low = Open - 2, Close = Open + 1,
        # Volume 1000..1900 in steps of 100.
        n_bars = 10
        bar_index = pd.date_range(start="2026-01-01", periods=n_bars, freq="D")
        offsets = range(n_bars)
        self.df = pd.DataFrame(
            {
                "Open": [100.0 + k for k in offsets],
                "High": [102.0 + k for k in offsets],
                "Low": [98.0 + k for k in offsets],
                "Close": [101.0 + k for k in offsets],
                "Volume": [1000 + 100 * k for k in offsets],
            },
            index=bar_index,
        )

    def test_sma_calculation(self):
        # 3-period simple moving average over the Close column.
        sma_frame = sp.sma(self.df, window=3, column='Close')
        self.assertIn("SMA_3", sma_frame.columns)
        sma_series = sma_frame["SMA_3"]
        # The first two rows are expected to be NaN (not enough history yet).
        for position in (0, 1):
            self.assertTrue(np.isnan(sma_series.iloc[position]))
        # Third row: mean of 101, 102 and 103 is 102.
        self.assertAlmostEqual(sma_series.iloc[2], 102.0)

    def test_missing_column_validation(self):
        # Each case: (column to drop, indicator call, expected message fragment).
        # SMA on 'Close' needs Close; ATR needs High, Low and Close.
        cases = [
            (
                "Close",
                lambda frame: sp.sma(frame, window=3, column='Close'),
                "requires the following column(s): ['Close']",
            ),
            (
                "High",
                lambda frame: sp.atr(frame, window=3),
                "requires the following column(s): ['High']",
            ),
        ]
        for dropped, call_indicator, expected_fragment in cases:
            trimmed_df = self.df.drop(columns=[dropped])
            with self.assertRaises(MissingColumnsError) as ctx:
                call_indicator(trimmed_df)
            self.assertIn(expected_fragment, str(ctx.exception))

    def test_rsi_division_by_zero_flat_price(self):
        # A constant price has zero gains and zero losses; RSI is expected to
        # report the neutral value 50.0 instead of failing on a zero division.
        flat_df = pd.DataFrame(
            {"Close": [100.0] * 10},
            index=pd.date_range(start="2026-01-01", periods=10, freq="D"),
        )
        rsi_frame = sp.rsi(flat_df, window=3, column='Close')
        self.assertIn("RSI_3", rsi_frame.columns)
        self.assertAlmostEqual(rsi_frame["RSI_3"].iloc[2], 50.0)

    def test_vwap_division_by_zero_zero_volume(self):
        # With every Volume set to 0 the VWAP denominator is zero; the indicator
        # is expected to yield NaN rather than raise.
        zero_vol_df = self.df.copy()
        zero_vol_df["Volume"] = 0
        vwap_frame = sp.vwap(zero_vol_df)
        self.assertIn("VWAP", vwap_frame.columns)
        vwap_series = vwap_frame["VWAP"]
        for position in (0, -1):
            self.assertTrue(np.isnan(vwap_series.iloc[position]))
|
|
75
|
+
|
|
76
|
+
# Allow running this module directly (e.g. `python tests/test_indicators.py`)
# in addition to discovery by a test runner.
if __name__ == '__main__':
    unittest.main()
|