auto-chart-patterns 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,169 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Dict
3
+ from abc import abstractmethod
4
+ from .line import Pivot
5
+ from .zigzag import Zigzag
6
+ import numpy as np
7
+
8
+ import logging
9
+ logger = logging.getLogger(__name__)
10
+
11
@dataclass
class ChartPatternProperties:
    """Scan settings shared by every chart-pattern finder.

    Attributes:
        offset: offset setting for pivot lookup (not read by the code in
            this module; presumably the starting zigzag offset -- confirm).
        min_periods_lapsed: minimum bar span a pattern must cover.
        max_live_patterns: cap on the pattern list; ``process_pattern`` drops
            the oldest entries beyond it.
        avoid_overlap: when True, a pattern whose pivot times are a subset of
            an existing pattern is rejected and subsets of it are replaced.
        allowed_patterns: per-type on/off flags indexed by ``pattern_type - 1``;
            None allows every type.
        allowed_last_pivot_directions: required direction of the last pivot,
            indexed by ``pattern_type``; 0 (or None / out of range) means any.
    """
    offset: int = 0
    min_periods_lapsed: int = 21
    max_live_patterns: int = 50
    avoid_overlap: bool = True
    allowed_patterns: List[bool] = None  # None means "all allowed"
    allowed_last_pivot_directions: List[int] = None  # None means "no restriction"
19
+
20
@dataclass
class ChartPattern:
    """Base class for chart patterns.

    A pattern is a list of pivots plus an integer ``pattern_type`` (0 means
    "no pattern") and its display name. Subclasses supply
    ``get_pattern_name_by_id``.
    """
    pivots: List[Pivot]
    pattern_type: int = 0
    pattern_name: str = ""
    extra_props: Dict = None  # for adding extra properties, not serialized

    def dict(self):
        """Serialize pivots, type and name (``extra_props`` is omitted)."""
        return {
            "pivots": [p.dict() for p in self.pivots],
            "pattern_type": self.pattern_type,
            "pattern_name": self.pattern_name
        }

    def __eq__(self, other: 'ChartPattern') -> bool:
        """Patterns are equal when type and ordered pivot times match."""
        if not isinstance(other, ChartPattern):
            # let Python fall back to its default comparison instead of
            # crashing on a missing ``pivots`` attribute
            return NotImplemented
        other_pivot_indexes = [p.point.time for p in other.pivots]
        self_pivot_indexes = [p.point.time for p in self.pivots]
        return self.pattern_type == other.pattern_type and \
            self_pivot_indexes == other_pivot_indexes

    @classmethod
    def from_dict(cls, dict: dict):
        """Rebuild a pattern from the output of ``dict()``."""
        self = cls(pivots=[Pivot.from_dict(p) for p in dict["pivots"]],
                   pattern_type=dict["pattern_type"],
                   pattern_name=dict["pattern_name"])
        return self

    @abstractmethod
    def get_pattern_name_by_id(self, pattern_type: int) -> str:
        """Get pattern name from pattern type ID

        Args:
            pattern_type: Pattern type identifier

        Returns:
            str: Name of the pattern
        """
        pass

    def process_pattern(self, properties: ChartPatternProperties,
                        patterns: List['ChartPattern']) -> bool:
        """
        Process a new pattern: validate it, check if it's allowed, and manage pattern list

        The pattern is rejected when its type is 0, its type is disabled in
        ``properties.allowed_patterns``, its last pivot direction does not match
        ``properties.allowed_last_pivot_directions``, or an existing pattern has
        the same pivot times (or, with ``avoid_overlap``, a superset of them).
        Existing patterns whose pivot times are a subset of this one's are
        removed when ``avoid_overlap`` is set. After appending, the oldest
        patterns are dropped until at most ``properties.max_live_patterns`` remain.

        Args:
            properties: Scan properties
            patterns: List of existing patterns (modified in place)

        Returns:
            bool: True if pattern was successfully processed and added
        """
        if self.pattern_type == 0:
            return False

        last_dir = self.pivots[-1].direction

        # NOTE(review): indexed by pattern_type, while allowed_patterns below
        # uses pattern_type - 1; confirm which convention callers expect.
        allowed_last_pivot_direction = 0
        directions = properties.allowed_last_pivot_directions
        if directions is not None and self.pattern_type < len(directions):
            allowed_last_pivot_direction = directions[self.pattern_type]

        pattern_allowed = True
        if properties.allowed_patterns is not None:
            if self.pattern_type > len(properties.allowed_patterns):
                pattern_allowed = False
            else:
                pattern_allowed = (self.pattern_type > 0 and
                                   properties.allowed_patterns[self.pattern_type-1])

        direction_allowed = (allowed_last_pivot_direction == 0 or
                             allowed_last_pivot_direction == last_dir)

        if not (pattern_allowed and direction_allowed):
            return False

        # Compare patterns by the set of pivot times.
        self_indexes = {p.point.time for p in self.pivots}
        replacing_patterns = []
        for idx, existing in enumerate(patterns):
            existing_indexes = {p.point.time for p in existing.pivots}
            if self_indexes == existing_indexes:
                return False
            if properties.avoid_overlap:
                if self_indexes.issubset(existing_indexes):
                    # already covered by a larger pattern
                    return False
                if existing_indexes.issubset(self_indexes):
                    replacing_patterns.append(idx)

        # Pop from the highest index down; popping in ascending order would
        # shift the remaining indexes and remove the wrong patterns.
        for idx in reversed(replacing_patterns):
            patterns.pop(idx)

        self.pattern_name = self.get_pattern_name_by_id(self.pattern_type)

        patterns.append(self)
        while len(patterns) > properties.max_live_patterns:
            patterns.pop(0)

        return True
133
+
134
def get_pivots_from_zigzag(zigzag: Zigzag, pivots: List[Pivot], offset: int, min_pivots: int) -> int:
    """Prepend up to ``min_pivots`` zigzag pivots to ``pivots``.

    Reads ``zigzag.get_pivot(offset)``, ``get_pivot(offset + 1)``, ... and
    inserts a deep copy of each at the front of ``pivots``, so the pivot read
    last ends up at index 0. Stops early when ``get_pivot`` returns None.

    Returns:
        int: number of pivots added (0 when ``min_pivots`` is 0 or negative).
    """
    added = 0
    for k in range(min_pivots):
        pivot = zigzag.get_pivot(k + offset)
        if pivot is None:
            break
        pivots.insert(0, pivot.deep_copy())
        added += 1
    # Counting explicitly avoids the unbound loop variable the old
    # ``return i+1`` hit when the loop body never ran.
    return added
141
+
142
def is_same_height(pivot1: Pivot, pivot2: Pivot, ref_pivots: List[Pivot], flat_ratio: float) -> bool:
    """Check whether two same-direction pivots sit at roughly the same height.

    Heights are measured from a reference price over all ``ref_pivots``:
    the lowest price when ``pivot1.direction > 0``, otherwise the highest.
    The pivots are level when ``height1 / height2`` lies within
    ``[1 - flat_ratio, 1 / (1 - flat_ratio)]``.

    Returns False when ``pivot2`` sits exactly at the reference price.

    Raises:
        ValueError: if the pivots have different directions.
    """
    if np.sign(pivot1.direction) != np.sign(pivot2.direction):
        raise ValueError("Pivots must have the same direction")

    # Reference is the min/max over every reference pivot, not just the ends.
    prices = [p.point.price for p in ref_pivots]
    if pivot1.direction > 0:
        ref_prices = np.min(prices)
    else:
        ref_prices = np.max(prices)

    diff1 = pivot1.point.price - ref_prices
    diff2 = pivot2.point.price - ref_prices
    if diff2 == 0:
        return False

    ratio = diff1 / diff2
    fit_pct = 1 - flat_ratio
    if fit_pct <= 0:
        # A tolerance of 100% or more accepts any same-sided pair of heights;
        # the bounds below would otherwise divide by zero or flip sign.
        same_height = ratio > 0
    elif ratio < 1:
        same_height = ratio >= fit_pct
    else:
        same_height = ratio <= 1 / fit_pct
    logger.debug("Pivot %s (%.4f) and %s (%.4f), ref_prices: %.4f, ratio: %.4f",
                 pivot1.point.index, pivot1.point.price,
                 pivot2.point.index, pivot2.point.price, ref_prices, ratio)
    return same_height
169
+
@@ -0,0 +1,84 @@
1
+ from dataclasses import dataclass
2
+ import pandas as pd
3
@dataclass
class Point:
    """A price at a given bar.

    ``time`` is the bar label, ``index`` its positional bar number (used by
    ``Line`` for interpolation) and ``price`` the price value.
    """
    time: str  # from_dict stores a pandas Timestamp here
    index: int
    price: float

    def dict(self):
        """Serialize to plain JSON-friendly values, including the bar index."""
        return {
            "time": str(self.time),
            "index": int(self.index),
            "price": float(self.price)
        }

    @classmethod
    def from_dict(cls, dict: dict):
        """Rebuild a Point from ``dict()`` output.

        ``index`` defaults to 0 for payloads written before it was serialized.
        """
        self = cls(time=pd.to_datetime(dict["time"]),
                   index=dict.get("index", 0),
                   price=dict["price"])
        return self

    def copy(self):
        """Return a shallow copy."""
        return Point(self.time, self.index, self.price)
24
+
25
@dataclass
class Pivot:
    """A zigzag turning point.

    Attributes:
        point: location of the pivot (time, bar index, price).
        direction: 1 for a high, -1 for a low.
        diff: price difference from the previous pivot.
        index_diff: bar-index difference from the previous pivot.
        cross_diff: price difference from the previous pivot of the same direction.
    """
    point: Point
    direction: int  # 1 for high, -1 for low
    diff: float = 0.0
    index_diff: int = 0
    cross_diff: float = 0.0

    def dict(self):
        """Serialize point and direction; the diff fields are not included."""
        serialized = {"point": self.point.dict()}
        serialized["direction"] = self.direction
        return serialized

    @classmethod
    def from_dict(cls, dict: dict):
        """Rebuild a Pivot from ``dict()`` output (diff fields keep defaults)."""
        restored_point = Point.from_dict(dict["point"])
        return cls(point=restored_point, direction=dict["direction"])

    def deep_copy(self):
        """Return a copy that also copies the underlying Point."""
        return Pivot(self.point.copy(), self.direction,
                     self.diff, self.index_diff, self.cross_diff)
53
+
54
@dataclass
class Line:
    """Straight line through two points, with price as a function of bar index.

    ``p1`` and ``p2`` are declared as dataclass fields so the generated
    ``__eq__``/``__repr__`` use them. With only a hand-written ``__init__``,
    every Line compared equal to every other Line and printed as ``Line()``.
    """
    p1: Point
    p2: Point

    def dict(self):
        """Serialize both endpoints."""
        return {
            "p1": self.p1.dict(),
            "p2": self.p2.dict()
        }

    @classmethod
    def from_dict(cls, dict):
        """Rebuild a Line from ``dict()`` output."""
        return cls(p1=Point.from_dict(dict["p1"]), p2=Point.from_dict(dict["p2"]))

    def get_price(self, index: int) -> float:
        """Calculate price at given index using linear interpolation.

        A vertical line (both points on the same bar) returns ``p1.price``.
        """
        if self.p2.index == self.p1.index:
            return self.p1.price
        return self.p1.price + self.get_slope() * (index - self.p1.index)

    def get_slope(self) -> float:
        """Price change per bar; 0.0 when both points share a bar index."""
        if self.p2.index == self.p1.index:
            return 0.0
        return (self.p2.price - self.p1.price) / (self.p2.index - self.p1.index)

    def copy(self):
        """Return a copy with copied endpoints."""
        return Line(self.p1.copy(), self.p2.copy())
@@ -0,0 +1,202 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+ from .chart_pattern import ChartPattern, ChartPatternProperties, get_pivots_from_zigzag, \
4
+ is_same_height
5
+ from .line import Pivot, Line, Point
6
+ from .zigzag import Zigzag
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
@dataclass
class ReversalPatternProperties(ChartPatternProperties):
    """Settings for double/triple top-bottom and head-and-shoulders scans.

    Attributes:
        min_periods_lapsed: minimum bar span (first to last pivot, inclusive)
            of a reversal pattern; overrides the base default.
        flat_ratio: tolerance passed to ``is_same_height`` when deciding whether
            tops/bottoms are level.
        peak_symmetry_ratio: tolerance for the ratio of the bar spans on either
            side of the middle peak in a 7-pivot pattern.
    """
    min_periods_lapsed: int = 15
    flat_ratio: float = 0.15
    peak_symmetry_ratio: float = 0.5
19
+
20
class ReversalPattern(ChartPattern):
    """Reversal pattern built from 5 or 7 pivots and a support line.

    ``resolve`` assigns ``pattern_type``: 1/2 double top/bottom (5 pivots),
    3/4 triple top/bottom, 5/6 head and shoulders / inverted (7 pivots).
    """

    def __init__(self, pivots: List[Pivot], support_line: Line):
        self.pivots = pivots
        self.pivots_count = len(pivots)
        self.support_line = support_line
        self.extra_props = {}

    @classmethod
    def from_dict(cls, dict):
        """Rebuild a pattern from ``dict()`` output."""
        restored = cls(pivots=[Pivot.from_dict(p) for p in dict["pivots"]],
                       support_line=Line.from_dict(dict["support_line"]))
        restored.pattern_type = dict["pattern_type"]
        restored.pattern_name = dict["pattern_name"]
        return restored

    def dict(self):
        """Serialize the base pattern fields plus the support line."""
        return {**super().dict(), "support_line": self.support_line.dict()}

    def get_pattern_name_by_id(self, id: int) -> str:
        """Map a pattern type to its name; unknown ids raise KeyError."""
        return {
            1: "Double Tops",
            2: "Double Bottoms",
            3: "Triple Tops",
            4: "Triple Bottoms",
            5: "Head and Shoulders",
            6: "Inverted Head and Shoulders",
        }[id]

    def resolve(self, properties: ReversalPatternProperties) -> 'ReversalPattern':
        """Classify the pivots and store the result in ``pattern_type``.

        Leaves ``pattern_type`` at 0 when a 7-pivot shape matches nothing.

        Raises:
            ValueError: if the pattern does not have 5 or 7 pivots.
        """
        self.pattern_type = 0
        count = self.pivots_count
        if count not in (5, 7):
            raise ValueError("Invalid number of pivots")

        piv = self.pivots
        # a first pivot with negative direction means the pattern forms tops
        tops = piv[0].direction < 0

        if count == 5:
            self.pattern_type = 1 if tops else 2  # Double Tops / Double Bottoms
            return self

        ratio = properties.flat_ratio
        if not is_same_height(piv[1], piv[5], piv, ratio):
            # the two outer peaks are not level: no 7-pivot pattern
            return self

        if is_same_height(piv[1], piv[3], piv, ratio) and \
                is_same_height(piv[3], piv[5], piv, ratio):
            logger.debug(f"Pivots: {piv[1].point.index}, {piv[3].point.index}, "
                         f"{piv[5].point.index} are flat")
            self.pattern_type = 3 if tops else 4  # Triple Tops / Triple Bottoms
        elif tops and piv[3].cross_diff > 0 and piv[5].cross_diff < 0:
            self.pattern_type = 5  # Head and Shoulders: middle peak above both sides
        elif piv[0].direction > 0 and piv[3].cross_diff < 0 and piv[5].cross_diff > 0:
            self.pattern_type = 6  # Inverted Head and Shoulders
        return self
80
+
81
def check_peak_symmetry(diff1: int, diff2: int, threshold: float) -> bool:
    """Check that two spans are within ``threshold`` of each other.

    Used on the bar spans left and right of the middle peak. The spans are
    symmetric when ``diff1 / diff2`` lies in
    ``[1 - threshold, 1 / (1 - threshold)]``.

    Returns False when ``diff2`` is zero (no meaningful ratio). A threshold of
    1 or more accepts any positive ratio instead of dividing by zero or
    rejecting every ratio above 1.
    """
    if diff2 == 0:
        return False
    ratio = float(diff1) / float(diff2)
    fit_pct = 1 - threshold
    if fit_pct <= 0:
        return ratio > 0
    if ratio < 1:
        return ratio >= fit_pct
    return ratio <= 1 / fit_pct
90
+
91
def inspect_five_pivot_pattern(pivots: List[Pivot], properties: ReversalPatternProperties) -> bool:
    """Screen five pivots for a possible double top or double bottom.

    Pivots 1 and 3 must be level (``is_same_height``). The saddle (pivot 2)
    must then lie on the correct side of at least one of the outer pivots
    (0 or 4): below for a first pivot with positive direction (double-bottom
    candidate), above otherwise (double-top candidate).
    """
    if not is_same_height(pivots[1], pivots[3], pivots, properties.flat_ratio):
        return False

    saddle = pivots[2].point.price
    if pivots[0].direction > 0:
        # possible double bottom: saddle below either outer pivot
        return saddle < pivots[0].point.price or saddle < pivots[4].point.price
    # possible double top: saddle above either outer pivot
    return saddle > pivots[0].point.price or saddle > pivots[4].point.price
105
+
106
def inspect_seven_pivot_pattern(pivots: List[Pivot], properties: ReversalPatternProperties) -> bool:
    """Screen seven pivots for a triple top/bottom or head-and-shoulders shape.

    Both saddles (pivots 2 and 4) must lie strictly inside the range implied
    by pivot 0: below it when its direction is positive, above it otherwise.
    Then the bar spans pivot1->pivot3 and pivot3->pivot5 must be symmetric
    within ``properties.peak_symmetry_ratio``.
    """
    start_price = pivots[0].point.price
    if pivots[0].direction > 0:
        saddle_out_of_range = any(pivots[k].point.price >= start_price for k in (2, 4))
    else:
        saddle_out_of_range = any(pivots[k].point.price <= start_price for k in (2, 4))
    if saddle_out_of_range:
        return False

    left_span = pivots[3].point.index - pivots[1].point.index
    right_span = pivots[5].point.index - pivots[3].point.index
    return check_peak_symmetry(left_span, right_span, properties.peak_symmetry_ratio)
120
+
121
def find_cross_point(line: Line, start_index: int, end_index: int, zigzag: Zigzag) -> Optional[Point]:
    """Return the first bar in ``[start_index, end_index)`` whose range touches ``line``.

    A bar touches the line when its low <= line price <= its high. The
    returned Point uses the bar's label as time and the line price at that bar.
    Returns None when no bar in the range touches the line.
    """
    if start_index > end_index:
        return None
    for bar_index in range(start_index, end_index):
        bar = zigzag.get_df_data_by_index(bar_index)
        line_price = line.get_price(bar_index)
        if bar['low'] <= line_price <= bar['high']:
            return Point(bar.name, bar_index, line_price)
    return None
132
+
133
def get_support_line(pivots: List[Pivot], start_index: int, end_index: int, zigzag: Zigzag) -> Optional[Line]:
    """Extend the line through one or two pivots to where it meets price.

    With one pivot the line is horizontal at that pivot's price. The left end
    is the first bar in ``[start_index, pivots[0].index)`` touching the line.
    If there is none, the function returns None. The right end is the first
    touching bar after the last pivot up to ``end_index``, or the line's price
    at ``end_index`` when price has not come back to it yet.

    Raises:
        ValueError: if more than two pivots are given.
    """
    if len(pivots) > 2:
        raise ValueError("At most two points are required to form a line")

    first_point = pivots[0].point
    last_point = pivots[-1].point  # same as first_point for a single pivot
    line = Line(first_point, last_point)

    right_cross = find_cross_point(line, last_point.index + 1, end_index, zigzag)
    left_cross = find_cross_point(line, start_index, first_point.index, zigzag)
    if left_cross is None:
        return None

    if right_cross is None:
        right_cross = Point(zigzag.get_df_data_by_index(end_index).name,
                            end_index, line.get_price(end_index))
    return Line(left_cross, right_cross)
152
+
153
def find_reversal_patterns(zigzag: Zigzag, offset: int, properties: ReversalPatternProperties,
                           patterns: List[ReversalPattern]) -> bool:
    """
    Find reversal patterns using zigzag pivots

    First tries a 7-pivot pattern (triple top/bottom, head and shoulders)
    starting at ``offset``. Then every 5-pivot window inside the fetched
    pivots is tried as a double top/bottom. Candidates need a support line
    and a bar span of at least ``properties.min_periods_lapsed``. They are
    added through ``ChartPattern.process_pattern``.

    Args:
        zigzag: Zigzag instance
        offset: Offset to start searching for pivots
        properties: Reversal pattern properties
        patterns: List to store found patterns (modified in place)

    Returns:
        bool: True if at least one pattern was added to ``patterns``
    """
    found_7_pattern = False
    found_5_pattern = False

    pivots = []
    pivots_count = get_pivots_from_zigzag(zigzag, pivots, offset, 7)
    if pivots_count == 7 and inspect_seven_pivot_pattern(pivots, properties):
        # we may have a triple top or bottom or head and shoulders
        support_line = get_support_line(
            [pivots[2], pivots[4]], pivots[0].point.index, pivots[6].point.index, zigzag)
        index_delta = pivots[-1].point.index - pivots[0].point.index + 1
        if support_line is not None and index_delta >= properties.min_periods_lapsed:
            pattern = ReversalPattern(pivots, support_line).resolve(properties)
            found_7_pattern = pattern.process_pattern(properties, patterns)

    # continue to inspect 5 point pattern
    if pivots_count >= 5:
        for i in range(0, pivots_count - 5 + 1):
            window = []
            # check the last 5 pivots as the pivots are in reverse order
            get_pivots_from_zigzag(zigzag, window, offset + i, 5)
            if not inspect_five_pivot_pattern(window, properties):
                continue
            # use the saddle point to form a support line
            support_line = get_support_line(
                [window[2]], window[0].point.index, window[4].point.index, zigzag)
            index_delta = window[-1].point.index - window[0].point.index + 1
            if support_line is not None and index_delta >= properties.min_periods_lapsed:
                pattern = ReversalPattern(window, support_line).resolve(properties)
                # set the flag from this window only, never from a stale or unbound value
                if pattern.process_pattern(properties, patterns):
                    found_5_pattern = True

    return found_7_pattern or found_5_pattern
201
+
202
+
@@ -0,0 +1,146 @@
1
+ from dataclasses import dataclass
2
+ from auto_chart_patterns.chart_pattern import ChartPatternProperties
3
+ from auto_chart_patterns.line import Line, Point
4
+ from auto_chart_patterns.zigzag import window_peaks
5
+ from typing import List
6
+ import pandas as pd
7
+ import logging
8
+
9
+ log = logging.getLogger(__name__)
10
+
11
@dataclass
class RsiDivergenceProperties(ChartPatternProperties):
    """Settings for RSI divergence detection.

    Attributes:
        min_periods_lapsed: minimum bar span (inclusive) between the two RSI
            pivots being compared; overrides the base default.
        min_change_pct: relative change that price or RSI must exceed to count
            as rising or falling.
    """
    min_periods_lapsed: int = 5
    min_change_pct: float = 0.005
15
+
16
class RsiDivergencePattern:
    """Divergence between price and RSI across two consecutive pivots.

    ``points`` are the two price points, ``divergence_line`` joins the RSI
    values at the same bars, and ``is_high_pivots`` tells whether the pivots
    are highs (True) or lows (False). ``resolve`` sets ``pattern_type``
    (0 = none, 1-4 see ``get_pattern_name_by_id``) and ``pattern_name``.
    """

    def __init__(self, points: List[Point], divergence_line: Line, is_high_pivots: bool):
        self.points = points
        self.divergence_line = divergence_line
        self.is_high_pivots = is_high_pivots
        self.pattern_type = 0
        self.pattern_name = ""

    def get_pattern_name_by_id(self, id: int) -> str:
        """Map a pattern type to its name; unknown ids raise KeyError."""
        pattern_names = {
            1: "Bullish",
            2: "Bearish",
            3: "Hidden Bullish",
            4: "Hidden Bearish",
        }
        return pattern_names[id]

    def get_change_direction(self, value1: float, value2: float,
                             properties: RsiDivergenceProperties) -> int:
        """Classify the move from ``value1`` to ``value2``.

        Returns 1 when the relative change exceeds ``properties.min_change_pct``,
        -1 when it is below ``-min_change_pct``, else 0. A zero ``value1`` has
        no relative change, so any strict rise/fall counts instead of raising
        ZeroDivisionError.
        """
        delta = value2 - value1
        if value1 == 0:
            return 1 if delta > 0 else -1 if delta < 0 else 0
        change_pct = delta / value1
        if change_pct > properties.min_change_pct:
            return 1
        elif change_pct < -properties.min_change_pct:
            return -1
        return 0

    def resolve(self, properties: RsiDivergenceProperties) -> 'RsiDivergencePattern':
        """Compare price and RSI direction and set ``pattern_type``/``pattern_name``.

        Raises:
            ValueError: if ``points`` does not hold exactly two points.
        """
        if len(self.points) != 2:
            raise ValueError("Rsi Divergence must have 2 points")
        self.pattern_type = 0
        self.pattern_name = ""

        # direction of price and of RSI between the two pivots
        price_change_dir = self.get_change_direction(self.points[0].price,
                                                     self.points[1].price, properties)
        rsi_change_dir = self.get_change_direction(self.divergence_line.p1.price,
                                                   self.divergence_line.p2.price, properties)

        log.debug(f"points: {self.points[0].index}, {self.points[1].index}, "
                  f"rsi: {self.divergence_line.p1.price}, {self.divergence_line.p2.price}, "
                  f"price_change_dir: {price_change_dir}, rsi_change_dir: {rsi_change_dir}")

        if price_change_dir == 1 and rsi_change_dir == -1:
            if self.is_high_pivots:
                # higher high but lower RSI
                self.pattern_type = 2  # bearish
            else:
                # higher low but lower RSI
                self.pattern_type = 3  # hidden bullish
        elif price_change_dir == -1 and rsi_change_dir == 1:
            if self.is_high_pivots:
                # lower high but higher RSI
                self.pattern_type = 4  # hidden bearish
            else:
                # lower low but higher RSI
                self.pattern_type = 1  # bullish

        if self.pattern_type != 0:
            self.pattern_name = self.get_pattern_name_by_id(self.pattern_type)
        return self
73
+
74
def calc_rsi(prices: pd.DataFrame, period: int = 14) -> pd.Series:
    """Compute RSI of ``prices["close"]``.

    Gains and losses are averaged with an exponential moving average
    (alpha = 1/period, at least ``period`` observations). Bars before that
    are NaN. A series with no losses yields 100, and one with no gains yields 0.
    """
    delta = prices["close"].diff()
    smoothing = {"alpha": 1.0 / period, "min_periods": period,
                 "adjust": True, "ignore_na": True}
    avg_gain = delta.clip(lower=0).ewm(**smoothing).mean()
    avg_loss = delta.clip(upper=0).abs().ewm(**smoothing).mean()
    relative_strength = avg_gain / avg_loss
    return 100.0 - (100.0 / (1.0 + relative_strength))
83
+
84
def handle_rsi_pivots(rsi_pivots: pd.DataFrame, is_high_pivots: bool,
                      properties: RsiDivergenceProperties, patterns: List[RsiDivergencePattern]):
    """Compare each pair of consecutive RSI pivots and collect divergences.

    ``rsi_pivots`` must have a ``row_number`` column plus ``rsi_high``/``high``
    (when ``is_high_pivots``) or ``rsi_low``/``low``. Pairs closer than
    ``properties.min_periods_lapsed`` bars (inclusive) are skipped. Resolved
    patterns with a non-zero type are appended to ``patterns``.
    """
    rsi_col, price_col = ('rsi_high', 'high') if is_high_pivots else ('rsi_low', 'low')

    for pos in range(1, len(rsi_pivots)):
        prev_row = rsi_pivots.iloc[pos - 1]
        cur_row = rsi_pivots.iloc[pos]
        prev_bar = prev_row['row_number'].astype(int)
        cur_bar = cur_row['row_number'].astype(int)
        if cur_bar - prev_bar + 1 < properties.min_periods_lapsed:
            continue

        rsi_start = Point(prev_row.name, prev_bar, prev_row[rsi_col])
        rsi_end = Point(cur_row.name, cur_bar, cur_row[rsi_col])
        rsi_line = Line(rsi_start, rsi_end)
        price_points = [Point(prev_row.name, prev_bar, prev_row[price_col]),
                        Point(cur_row.name, cur_bar, cur_row[price_col])]

        candidate = RsiDivergencePattern(price_points, rsi_line, is_high_pivots).resolve(properties)
        if candidate.pattern_type != 0:
            patterns.append(candidate)
113
+
114
def find_rsi_divergences(backcandles: int, forwardcandles: int,
                         properties: RsiDivergenceProperties,
                         patterns: List[RsiDivergencePattern], df: pd.DataFrame):
    """
    Find RSI divergences between consecutive RSI peaks and troughs

    RSI pivots are the bars where RSI equals the window extremes returned by
    ``window_peaks`` (presumably rolling max/min over ``backcandles`` and
    ``forwardcandles`` -- confirm). Highs and lows are then paired
    separately by ``handle_rsi_pivots``.

    Args:
        backcandles: Number of backcandles
        forwardcandles: Number of forwardcandles
        properties: RSI divergence properties
        patterns: List to store found patterns (modified in place)
        df: DataFrame with ``close``, ``high`` and ``low`` columns; not modified
    """
    rsi = calc_rsi(df)
    rsi_highs, rsi_lows = window_peaks(rsi, backcandles, forwardcandles)
    rsi_high_pivots = rsi.where(rsi == rsi_highs)
    rsi_low_pivots = rsi.where(rsi == rsi_lows)

    # Positional bar numbers are added to a new frame so the caller's
    # DataFrame is not mutated (the old code wrote df['row_number'] in place).
    price_data = df[['high', 'low']].assign(row_number=pd.RangeIndex(len(df)))

    rsi_pivots = pd.merge(
        pd.DataFrame({'rsi_high': rsi_high_pivots, 'rsi_low': rsi_low_pivots}),
        price_data[['row_number', 'high', 'low']],
        left_index=True,
        right_index=True,
        how='inner'
    )
    handle_rsi_pivots(rsi_pivots[['rsi_high', 'high', 'row_number']].dropna(), True, properties, patterns)
    handle_rsi_pivots(rsi_pivots[['rsi_low', 'low', 'row_number']].dropna(), False, properties, patterns)
@@ -0,0 +1,446 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Tuple, Optional
3
+ import numpy as np
4
+ from .line import Point, Pivot, Line
5
+ from .zigzag import Zigzag
6
+ from .chart_pattern import ChartPattern, ChartPatternProperties, get_pivots_from_zigzag, \
7
+ is_same_height
8
+
9
+ import logging
10
+ logger = logging.getLogger(__name__)
11
+
12
@dataclass
class TrendLineProperties(ChartPatternProperties):
    """Settings for two-trend-line patterns (channels, wedges, triangles, flags).

    Attributes:
        number_of_pivots: minimum number of pivots to form a pattern; 4
            enables the flag/pennant checks in ``resolve_pattern_name``.
        flat_ratio: tolerance within which a trend line counts as flat.
        align_ratio: maximum allowed ratio between aligned diagonal pivots.
        flag_ratio: flag size times this must stay below the flag pole height.
        flag_pole_span: maximum bars a flag/pennant pole can span.
        flag_span: maximum bars a flag/pennant can span.
        max_candle_body_crosses: maximum allowed candle body crosses for a
            valid trend line.
    """
    number_of_pivots: int = 5
    flat_ratio: float = 0.2
    align_ratio: float = 0.4
    flag_ratio: float = 1.5
    flag_pole_span: int = 15
    flag_span: int = 15
    max_candle_body_crosses: int = 1
21
+
22
class TrendLinePattern(ChartPattern):
    """Pattern bounded by two trend lines (channel, wedge, triangle, flag, pennant).

    ``resolve`` re-anchors both lines to the first/last pivot bars, snaps
    even-position pivots onto ``trend_line1`` and odd-position pivots onto
    ``trend_line2``, then classifies the pattern via ``resolve_pattern_name``.
    """

    def __init__(self, pivots: List[Pivot], trend_line1: Line, trend_line2: Line):
        self.pivots = pivots
        self.trend_line1 = trend_line1
        self.trend_line2 = trend_line2
        self.extra_props = {}

    @classmethod
    def from_dict(cls, dict):
        """Rebuild a pattern from ``dict()`` output."""
        self = cls(pivots=[Pivot.from_dict(p) for p in dict["pivots"]],
                   trend_line1=Line.from_dict(dict["trend_line1"]),
                   trend_line2=Line.from_dict(dict["trend_line2"]))
        self.pattern_type = dict["pattern_type"]
        self.pattern_name = dict["pattern_name"]
        return self

    def dict(self):
        """Serialize the base pattern fields plus both trend lines."""
        obj = super().dict()
        obj["trend_line1"] = self.trend_line1.dict()
        obj["trend_line2"] = self.trend_line2.dict()
        return obj

    def get_pattern_name_by_id(self, id: int) -> str:
        """Map a pattern type to its name; unknown ids yield "Error"."""
        pattern_names = {
            1: "Ascending Channel",
            2: "Descending Channel",
            3: "Ranging Channel",
            4: "Rising Wedge (Expanding)",
            5: "Falling Wedge (Expanding)",
            6: "Diverging Triangle",
            7: "Ascending Triangle (Expanding)",
            8: "Descending Triangle (Expanding)",
            9: "Rising Wedge (Contracting)",
            10: "Falling Wedge (Contracting)",
            11: "Converging Triangle",
            12: "Descending Triangle (Contracting)",
            13: "Ascending Triangle (Contracting)",
            14: "Bull Pennant",
            15: "Bear Pennant",
            16: "Bull Flag",
            17: "Bear Flag"
        }
        return pattern_names.get(id, "Error")

    @staticmethod
    def _reanchor(line: Line, first_time, first_index: int, last_time, last_index: int) -> None:
        """Replace ``line``'s endpoints with points at ``first_index`` and ``last_index``.

        p1 is replaced first and p2 is then evaluated on the updated line,
        which keeps the original update order.
        """
        line.p1 = Point(time=first_time, index=first_index,
                        price=line.get_price(first_index))
        line.p2 = Point(time=last_time, index=last_index,
                        price=line.get_price(last_index))

    def resolve(self, properties: TrendLineProperties) -> 'TrendLinePattern':
        """
        Resolve pattern by updating trend lines, pivot points, and ratios

        Args:
            properties: TrendLineProperties containing pattern parameters

        Returns:
            self: Returns the pattern object for method chaining
        """
        first_index = self.pivots[0].point.index
        last_index = self.pivots[-1].point.index
        first_time = self.pivots[0].point.time
        last_time = self.pivots[-1].point.time

        self._reanchor(self.trend_line1, first_time, first_index, last_time, last_index)
        self._reanchor(self.trend_line2, first_time, first_index, last_time, last_index)

        # Snap each pivot's price onto its line: even positions -> line 1, odd -> line 2
        for i, pivot in enumerate(self.pivots):
            current_trend_line = self.trend_line2 if i % 2 == 1 else self.trend_line1
            pivot.point.price = current_trend_line.get_price(pivot.point.index)

        self.resolve_pattern_name(properties)
        return self

    def resolve_pattern_name(self, properties: TrendLineProperties) -> 'TrendLinePattern':
        """Determine the pattern type based on trend lines and angles.

        Sets ``pattern_type`` to 0 when nothing matches. With
        ``properties.number_of_pivots == 4`` only flags (16/17) and pennants
        (14/15) that pass the pole/size/span limits are kept.
        """
        # Start from "no pattern" so branches that match nothing cannot leave
        # a stale type from an earlier resolve() or from_dict().
        self.pattern_type = 0

        t1p1 = self.trend_line1.p1.price
        t1p2 = self.trend_line1.p2.price
        t2p1 = self.trend_line2.p1.price
        t2p2 = self.trend_line2.p2.price

        # End-vs-start distance of each line from the opposite line's
        # extreme; values above 1 + flat_ratio map to a rising upper line
        # or a falling lower line below.
        upper_angle = ((t1p2 - min(t2p1, t2p2)) / (t1p1 - min(t2p1, t2p2))
                       if t1p1 > t2p1 else
                       (t2p2 - min(t1p1, t1p2)) / (t2p1 - min(t1p1, t1p2)))

        lower_angle = ((t2p2 - max(t1p1, t1p2)) / (t2p1 - max(t1p1, t1p2))
                       if t1p1 > t2p1 else
                       (t1p2 - max(t2p1, t2p2)) / (t1p1 - max(t2p1, t2p2)))

        upper_line_dir = (1 if upper_angle > 1 + properties.flat_ratio else
                          -1 if upper_angle < 1 - properties.flat_ratio else 0)

        lower_line_dir = (-1 if lower_angle > 1 + properties.flat_ratio else
                          1 if lower_angle < 1 - properties.flat_ratio else 0)

        start_diff = abs(t1p1 - t2p1)
        end_diff = abs(t1p2 - t2p2)
        min_diff = min(start_diff, end_diff)
        bar_diff = self.trend_line1.p2.index - self.trend_line2.p1.index
        price_diff = abs(start_diff - end_diff) / bar_diff if bar_diff != 0 else 0

        probable_converging_bars = min_diff / price_diff if price_diff != 0 else float('inf')

        is_expanding = end_diff > start_diff
        is_contracting = start_diff > end_diff

        is_channel = (probable_converging_bars > 2 * bar_diff or
                      (not is_expanding and not is_contracting) or
                      (upper_line_dir == 0 and lower_line_dir == 0))

        # lines that swap sides between start and end do not form a pattern
        invalid = np.sign(t1p1 - t2p1) != np.sign(t1p2 - t2p2)

        if invalid:
            self.pattern_type = 0
        elif is_channel:
            if upper_line_dir > 0 and lower_line_dir > 0:
                self.pattern_type = 1  # Ascending Channel
            elif upper_line_dir < 0 and lower_line_dir < 0:
                self.pattern_type = 2  # Descending Channel
            else:
                self.pattern_type = 3  # Ranging Channel
        elif is_expanding:
            if upper_line_dir > 0 and lower_line_dir > 0:
                self.pattern_type = 4  # Rising Wedge (Expanding)
            elif upper_line_dir < 0 and lower_line_dir < 0:
                self.pattern_type = 5  # Falling Wedge (Expanding)
            elif upper_line_dir > 0 and lower_line_dir < 0:
                self.pattern_type = 6  # Diverging Triangle
            elif upper_line_dir > 0 and lower_line_dir == 0:
                self.pattern_type = 7  # Ascending Triangle (Expanding)
            elif upper_line_dir == 0 and lower_line_dir < 0:
                self.pattern_type = 8  # Descending Triangle (Expanding)
        elif is_contracting:
            if upper_line_dir > 0 and lower_line_dir > 0:
                self.pattern_type = 9  # Rising Wedge (Contracting)
            elif upper_line_dir < 0 and lower_line_dir < 0:
                self.pattern_type = 10  # Falling Wedge (Contracting)
            elif upper_line_dir < 0 and lower_line_dir > 0:
                self.pattern_type = 11  # Converging Triangle
            elif lower_line_dir == 0:
                self.pattern_type = 12 if upper_line_dir < 0 else 1  # Descending Triangle (Contracting)
            elif upper_line_dir == 0:
                self.pattern_type = 13 if lower_line_dir > 0 else 2  # Ascending Triangle (Contracting)

        if properties.number_of_pivots == 4:
            # check flag ratio and difference
            flag_pole_height = abs(self.pivots[0].diff)
            flag_pole_span = self.pivots[0].index_diff
            flag_span = self.pivots[-1].point.index - self.pivots[0].point.index
            flag_size = max(abs(self.trend_line1.p1.price - self.trend_line2.p1.price),
                            abs(self.trend_line1.p2.price - self.trend_line2.p2.price))
            # flag size must be smaller than its pole
            if flag_size * properties.flag_ratio < flag_pole_height and \
                    flag_pole_span <= properties.flag_pole_span and \
                    flag_span <= properties.flag_span:
                if self.pattern_type in (1, 2, 3):
                    # channel patterns
                    if self.pivots[0].direction > 0:
                        self.pattern_type = 16  # Bull Flag
                    elif self.pivots[0].direction < 0:
                        self.pattern_type = 17  # Bear Flag
                    else:
                        self.pattern_type = 0
                elif self.pattern_type in (9, 10, 11, 12, 13):
                    # pennant patterns
                    if self.pivots[0].direction > 0:
                        self.pattern_type = 14  # Bull Pennant
                    else:
                        self.pattern_type = 15  # Bear Pennant
                else:
                    self.pattern_type = 0
            else:
                # 4-pivot shapes are only kept as flags/pennants
                self.pattern_type = 0

        return self
225
+
226
def is_aligned(pivots: List[Pivot], ref_pivots: List[Pivot], align_ratio: float,
               flat_ratio: float) -> bool:
    """Check whether three pivots lie roughly on one straight line.

    Fewer than three pivots are trivially aligned. Three pivots are aligned
    when both neighbouring pairs pass ``is_same_height`` (a flat line).
    Otherwise, (second.cross_diff / third.cross_diff) divided by the ratio of
    their bar gaps must lie within ``[1 - align_ratio, 1 / (1 - align_ratio)]``.
    This presumes the pivots are consecutive same-direction pivots, so
    cross_diff is the rise between them -- confirm with callers.

    Raises:
        ValueError: if more than three pivots are given.
    """
    count = len(pivots)
    if count > 3:
        raise ValueError("Pivots can't be more than 3")
    if count < 3:
        return True

    first, second, third = pivots[0], pivots[1], pivots[2]
    level_left = is_same_height(first, second, ref_pivots, flat_ratio)
    if level_left and is_same_height(second, third, ref_pivots, flat_ratio):
        logger.debug(f"Pivots: {first.point.index}, {second.point.index}, {third.point.index} "
                     f"are aligned as a horizontal line")
        return True

    if third.cross_diff == 0:
        # the first and third pivots are level, but the second is not on that level
        return False

    price_ratio = second.cross_diff / third.cross_diff
    left_gap = float(second.point.index - first.point.index)
    right_gap = float(third.point.index - second.point.index)
    bar_ratio = left_gap / right_gap
    ratio = price_ratio / bar_ratio
    tolerance = 1 - align_ratio
    aligned = ratio >= tolerance if ratio < 1 else ratio <= 1 / tolerance
    logger.debug(f"Pivots: {first.point.index}, {second.point.index}, {third.point.index} "
                 f"price ratio: {price_ratio:.4f}, bar ratio: {bar_ratio:.4f}, ratio: {ratio:.4f}")
    return aligned
259
+
260
+ def check_if_line_cross_candle_body(line: Line, direction: float, zigzag: Zigzag,
261
+ line_end_index: Optional[int] = None) -> int:
262
+ """
263
+ Check if a line crosses the candle body
264
+ """
265
+ if line_end_index is not None and line_end_index > line.p2.index:
266
+ end_index = line_end_index
267
+ else:
268
+ end_index = line.p2.index
269
+
270
+ crosses = 0
271
+ for i in range(line.p1.index + 1, end_index):
272
+ bar_data = zigzag.get_df_data_by_index(i)
273
+ if direction > 0 and line.get_price(i) < max(bar_data['open'], bar_data['close']):
274
+ crosses += 1
275
+ elif direction < 0 and line.get_price(i) > min(bar_data['open'], bar_data['close']):
276
+ crosses += 1
277
+ return crosses
278
+
279
+ def inspect_line_by_point(line: Line, point_bar: int, direction: float,
280
+ zigzag: Zigzag) -> Tuple[bool, float]:
281
+ """
282
+ Inspect a single line against price data from a pandas DataFrame
283
+
284
+ Args:
285
+ line: Line object to inspect
286
+ point_bar: Index of the point to inspect
287
+ direction: Direction of the trend (1 for up, -1 for down)
288
+ zigzag: Zigzag calculator instance
289
+
290
+ Returns:
291
+ Tuple of (valid: bool, diff: float)
292
+ """
293
+ # Get price data from DataFrame
294
+ bar_data = zigzag.get_df_data_by_index(point_bar)
295
+
296
+ # Determine prices based on direction
297
+ line_price = line.get_price(point_bar)
298
+ line_price_diff = abs(line.p1.price - line.p2.price)
299
+ if direction > 0:
300
+ # upper line
301
+ body_high_price = max(bar_data['open'], bar_data['close'])
302
+ if line_price < body_high_price:
303
+ # invalid if line is crossing the candle body
304
+ return False, float('inf') # make the difference as large as possible
305
+ elif line_price > bar_data['high']:
306
+ # line is above the candle wick
307
+ diff = line_price - bar_data['high']
308
+ return diff < line_price_diff, diff
309
+ else:
310
+ # line is crossing the candle wick
311
+ return True, 0
312
+ else:
313
+ # lower line
314
+ body_low_price = min(bar_data['open'], bar_data['close'])
315
+ if line_price > body_low_price:
316
+ # invalid if line is crossing the candle body
317
+ return False, float('inf') # make the difference as large as possible
318
+ elif line_price < bar_data['low']:
319
+ # line is below the candle wick
320
+ diff = bar_data['low'] - line_price
321
+ return diff < line_price_diff, diff
322
+ else:
323
+ # line is crossing the candle wick
324
+ return True, 0
325
+
326
+ def inspect_pivots(pivots: List[Pivot], direction: float, properties: TrendLineProperties,
327
+ last_pivot: Pivot, zigzag: Zigzag) -> Tuple[bool, Line]:
328
+ """
329
+ Inspect multiple pivots to find the best trend line using DataFrame price data
330
+
331
+ Args:
332
+ pivots: List of pivots to create trend lines
333
+ direction: Direction of the trend
334
+ properties: TrendLineProperties object containing pattern parameters
335
+ last_pivot: The last pivot to inspect
336
+ zigzag: Zigzag calculator instance
337
+
338
+ Returns:
339
+ Tuple of (valid: bool, best_trend_line: Line)
340
+ """
341
+ if len(pivots) == 3:
342
+ # Create three possible trend lines
343
+ trend_line1 = Line(pivots[0].point, pivots[2].point) # First to last
344
+ # check if the line consisting of the first and last points crosses the candle body
345
+ if check_if_line_cross_candle_body(trend_line1, direction, zigzag) > \
346
+ properties.max_candle_body_crosses:
347
+ return False, None
348
+ # inspect line by middle point
349
+ valid1, diff1 = inspect_line_by_point(trend_line1, pivots[1].point.index,
350
+ direction, zigzag)
351
+ if valid1 and diff1 == 0:
352
+ # prefer the line connecting the first and last points
353
+ return True, trend_line1
354
+
355
+ trend_line2 = Line(pivots[0].point, pivots[1].point) # First to middle
356
+ valid2, diff2 = inspect_line_by_point(trend_line2, pivots[2].point.index,
357
+ direction, zigzag)
358
+
359
+ trend_line3 = Line(pivots[1].point, pivots[2].point) # Middle to last
360
+ valid3, diff3 = inspect_line_by_point(trend_line3, pivots[0].point.index,
361
+ direction, zigzag)
362
+
363
+ if not valid1 and not valid2 and not valid3:
364
+ return False, None
365
+
366
+ # Find the best line
367
+ if valid1:
368
+ trendline = trend_line1
369
+ elif valid2 and diff2 < diff1:
370
+ trendline = trend_line2
371
+ elif valid3 and diff3 < min(diff1, diff2):
372
+ trendline = trend_line3
373
+
374
+ return True, trendline
375
+ else:
376
+ # For 2 points, simply create one trend line
377
+ trend_line = Line(pivots[0].point, pivots[1].point)
378
+ valid = check_if_line_cross_candle_body(
379
+ trend_line, direction, zigzag, last_pivot.point.index) <= \
380
+ properties.max_candle_body_crosses
381
+ return valid, trend_line
382
+
383
+ def find_trend_lines(zigzag: Zigzag, offset: int, properties: TrendLineProperties,
384
+ patterns: List[TrendLinePattern]) -> bool:
385
+ """
386
+ Find patterns using DataFrame price data
387
+
388
+ Args:
389
+ zigzag: ZigZag calculator instance
390
+ offset: Offset to start searching for pivots
391
+ properties: Scan properties
392
+ patterns: List to store found patterns
393
+
394
+ Returns:
395
+ int: Index of the pivot that was used to find the pattern
396
+ """
397
+ # Get pivots
398
+ if properties.number_of_pivots < 4 or properties.number_of_pivots > 6:
399
+ raise ValueError("Number of pivots must be between 4 and 6")
400
+
401
+ pivots = []
402
+ min_pivots = get_pivots_from_zigzag(zigzag, pivots, offset, properties.number_of_pivots)
403
+ if min_pivots != properties.number_of_pivots:
404
+ return False
405
+
406
+ # Validate pattern
407
+ # Create point arrays for trend lines
408
+ trend_pivots1 = ([pivots[0], pivots[2]]
409
+ if properties.number_of_pivots == 4
410
+ else [pivots[0], pivots[2], pivots[4]])
411
+ trend_pivots2 = ([pivots[1], pivots[3], pivots[5]]
412
+ if properties.number_of_pivots == 6
413
+ else [pivots[1], pivots[3]])
414
+
415
+ if not is_aligned(trend_pivots1, pivots, properties.align_ratio,
416
+ properties.flat_ratio) or \
417
+ not is_aligned(trend_pivots2, pivots, properties.align_ratio,
418
+ properties.flat_ratio):
419
+ return False
420
+
421
+ # Validate trend lines using DataFrame
422
+ valid1, trend_line1 = inspect_pivots(trend_pivots1,
423
+ np.sign(trend_pivots1[0].direction),
424
+ properties, pivots[-1], zigzag)
425
+ valid2, trend_line2 = inspect_pivots(trend_pivots2,
426
+ np.sign(trend_pivots2[0].direction),
427
+ properties, pivots[-1], zigzag)
428
+
429
+ if valid1 and valid2:
430
+ index_delta = pivots[-1].point.index - pivots[0].point.index + 1
431
+ if index_delta < properties.min_periods_lapsed and \
432
+ properties.number_of_pivots >= 5:
433
+ # only consider patterns with enough time lapsed
434
+ return False
435
+
436
+ # Create pattern
437
+ pattern = TrendLinePattern(
438
+ pivots=pivots,
439
+ trend_line1=trend_line1,
440
+ trend_line2=trend_line2,
441
+ ).resolve(properties)
442
+
443
+ # Process pattern (resolve type, check if allowed, etc.)
444
+ return pattern.process_pattern(properties, patterns)
445
+ else:
446
+ return False
@@ -0,0 +1,228 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ from .line import Pivot, Point
7
+
8
+ import logging
9
+ logger = logging.getLogger(__name__)
10
+
11
+ @dataclass
12
+ class Zigzag:
13
+ def __init__(self, backcandles: int = 5, forwardcandles: int = 5,
14
+ pivot_limit: int = 20, offset: int = 0, level: int = 0):
15
+ self.backcandles = backcandles
16
+ self.forwardcandles = forwardcandles
17
+ self.pivot_limit = pivot_limit
18
+ self.offset = offset
19
+ self.level = level
20
+ self.zigzag_pivots: List[Pivot] = []
21
+ self.df = None
22
+
23
+ def update_pivot_properties(self, pivot: Pivot) -> 'Zigzag':
24
+ """
25
+ Update the properties of the pivot
26
+ """
27
+ if len(self.zigzag_pivots) > 1:
28
+ dir = np.sign(pivot.direction)
29
+ value = pivot.point.price
30
+ last_pivot = self.zigzag_pivots[1]
31
+ last_value = last_pivot.point.price
32
+ if last_pivot.point.index == pivot.point.index:
33
+ raise ValueError(f"Last pivot index {last_pivot.point.index} "
34
+ f"is the same as current pivot index {pivot.point.index}")
35
+ pivot.index_diff = pivot.point.index - last_pivot.point.index
36
+
37
+ # Calculate difference between last and current pivot
38
+ pivot.diff = value - last_value
39
+ if len(self.zigzag_pivots) > 2:
40
+ llast_pivot = self.zigzag_pivots[2]
41
+ llast_value = llast_pivot.point.price
42
+ # Calculate slope between last and current pivot
43
+ pivot.cross_diff = value - llast_value
44
+ # Determine if trend is strong (2) or weak (1)
45
+ new_dir = dir * 2 if dir * value > dir * llast_value else dir
46
+ pivot.direction = int(new_dir)
47
+
48
+ def add_new_pivot(self, pivot: Pivot) -> 'Zigzag':
49
+ """
50
+ Add a new pivot to the zigzag
51
+
52
+ Args:
53
+ pivot: Pivot object to add
54
+
55
+ Returns:
56
+ self: Returns zigzag object for method chaining
57
+
58
+ Raises:
59
+ ValueError: If direction mismatch with last pivot
60
+ """
61
+ if len(self.zigzag_pivots) >= 1:
62
+ # Check direction mismatch
63
+ if np.sign(self.zigzag_pivots[0].direction) == np.sign(pivot.direction):
64
+ raise ValueError('Direction mismatch')
65
+
66
+ # Insert at beginning and maintain max size
67
+ self.zigzag_pivots.insert(0, pivot)
68
+ self.update_pivot_properties(pivot)
69
+
70
+ if len(self.zigzag_pivots) > self.pivot_limit:
71
+ logger.warning(f"Warning: pivots exceeded limit {self.pivot_limit}, "
72
+ f"popping pivot {self.zigzag_pivots[-1].point.index}")
73
+ self.zigzag_pivots.pop()
74
+
75
+ return self
76
+
77
+ def calculate(self, df: pd.DataFrame, offset: Optional[int] = None) -> 'Zigzag':
78
+ """
79
+ Calculate zigzag pivots from DataFrame
80
+
81
+ Args:
82
+ df: DataFrame with 'high' and 'low' columns
83
+ offset: Offset to apply to the dataframe index
84
+ Returns:
85
+ self: Returns zigzag object for method chaining
86
+ """
87
+ if offset is not None:
88
+ self.offset = offset
89
+
90
+ # rescale the dataframe using the max and low prices in the range
91
+ if df.get('high') is None or df.get('low') is None:
92
+ raise ValueError("High and low prices not found in dataframe")
93
+
94
+ self.zigzag_pivots = []
95
+
96
+ highs, lows = window_peaks(df, self.backcandles, self.forwardcandles)
97
+
98
+ # Calculate pivot highs
99
+ pivot_highs = df['high'].where((df['high'] == highs))
100
+
101
+ # Calculate pivot lows
102
+ pivot_lows = df['low'].where((df['low'] == lows))
103
+
104
+ # Process pivot points into zigzag
105
+ last_pivot_price = None
106
+ last_pivot_direction = 0
107
+
108
+ for i in range(len(df)):
109
+ if not (pd.isna(pivot_highs.iloc[i]) and pd.isna(pivot_lows.iloc[i])):
110
+ current_index = i + self.offset
111
+ current_time = df.index[i]
112
+ take_high = True
113
+ if not pd.isna(pivot_highs.iloc[i]) and not pd.isna(pivot_lows.iloc[i]):
114
+ # both high and low pivot, take the more extreme one
115
+ if last_pivot_price is not None:
116
+ assert last_pivot_direction != 0
117
+ if last_pivot_direction == 1:
118
+ if pivot_highs.iloc[i] <= last_pivot_price:
119
+ # the current pivot high is lower than the last pivot high, take low instead
120
+ take_high = False
121
+ else:
122
+ if pivot_lows.iloc[i] < last_pivot_price:
123
+ # the current pivot low is lower than the last pivot low, take low instead
124
+ take_high = False
125
+ elif pd.isna(pivot_highs.iloc[i]):
126
+ take_high = False
127
+
128
+ if take_high:
129
+ current_price = pivot_highs.iloc[i]
130
+ current_direction = 1 # bullish
131
+ else:
132
+ current_price = pivot_lows.iloc[i]
133
+ current_direction = -1 # bearish
134
+
135
+ # Create and add pivot if valid
136
+ if last_pivot_price is None or last_pivot_direction != current_direction:
137
+ new_pivot = Pivot(
138
+ point=Point(
139
+ price=current_price,
140
+ index=current_index,
141
+ time=current_time
142
+ ),
143
+ direction=current_direction
144
+ )
145
+
146
+ self.add_new_pivot(new_pivot)
147
+ last_pivot_price = current_price
148
+ last_pivot_direction = current_direction
149
+
150
+ # Update last pivot if same direction but more extreme
151
+ elif ((current_direction == 1 and current_price > last_pivot_price) or
152
+ (current_direction == -1 and current_price < last_pivot_price)):
153
+ # Update the last pivot
154
+ last_pivot = self.zigzag_pivots[0]
155
+ last_pivot.point.price = current_price
156
+ last_pivot.point.index = current_index
157
+ last_pivot.point.time = current_time
158
+ self.update_pivot_properties(last_pivot)
159
+ last_pivot_price = current_price
160
+
161
+ # record the dataframe
162
+ self.df = df
163
+
164
+ return self
165
+
166
+ def get_pivot_by_index(self, index: int) -> Optional[Pivot]:
167
+ """Get pivot at specific index"""
168
+ for i in range(len(self.zigzag_pivots)):
169
+ current_pivot = self.zigzag_pivots[len(self.zigzag_pivots) - i - 1]
170
+ if current_pivot.point.index == index:
171
+ return current_pivot
172
+ return None
173
+
174
+ def get_pivot(self, offset: int) -> Optional[Pivot]:
175
+ """Get pivot at specific index"""
176
+ if 0 <= offset < len(self.zigzag_pivots):
177
+ return self.zigzag_pivots[offset]
178
+ return None
179
+
180
+ def get_last_pivot(self) -> Optional[Pivot]:
181
+ """Get the most recent pivot"""
182
+ return self.zigzag_pivots[0] if self.zigzag_pivots else None
183
+
184
+ def get_df_data_by_index(self, index: int) -> pd.Series:
185
+ """Get the dataframe data at a specific index"""
186
+ if self.df is not None:
187
+ if index < self.offset or index - self.offset >= len(self.df):
188
+ raise ValueError(f"Index {index} is out of bounds")
189
+ return self.df.iloc[index - self.offset]
190
+ raise ValueError("DataFrame not calculated")
191
+
192
+ def window_peaks(data, before: int, after: int) -> tuple[pd.Series, pd.Series]:
193
+ """
194
+ Faster version using numpy's stride tricks
195
+
196
+ Args:
197
+ df: DataFrame with 'high' and 'low' columns
198
+ before: Number of bars before the current bar
199
+ after: Number of bars after the current bar
200
+
201
+ Returns:
202
+ pd.Series: Series of highs and lows
203
+ """
204
+ if isinstance(data, pd.DataFrame):
205
+ values_high = data["high"].values
206
+ elif isinstance(data, pd.Series):
207
+ values_high = data.values
208
+ else:
209
+ raise ValueError("Unsupported dataframe type")
210
+
211
+ if isinstance(data, pd.DataFrame):
212
+ values_low = data["low"].values
213
+ elif isinstance(data, pd.Series):
214
+ values_low = data.values
215
+ result_high = np.zeros(len(values_high))
216
+ result_low = np.zeros(len(values_low))
217
+
218
+ # Handle edges with padding
219
+ padded_high = np.pad(values_high, (before, after), mode='edge')
220
+ padded_low = np.pad(values_low, (before, after), mode='edge')
221
+
222
+ # Create rolling window view
223
+ windows_high = np.lib.stride_tricks.sliding_window_view(padded_high, before + after + 1)
224
+ windows_low = np.lib.stride_tricks.sliding_window_view(padded_low, before + after + 1)
225
+ result_high = np.max(windows_high, axis=1)
226
+ result_low = np.min(windows_low, axis=1)
227
+
228
+ return pd.Series(result_high, index=data.index), pd.Series(result_low, index=data.index)
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.3
2
+ Name: auto-chart-patterns
3
+ Version: 0.1.0
4
+ Summary: Automatically identify chart patterns from OHLC data
5
+ Project-URL: Homepage, https://github.com/FanM/auto-chart-patterns
6
+ Project-URL: Bug Tracker, https://github.com/FanM/auto-chart-patterns/issues
7
+ Author-email: Fan Mao <maofan@xsmail.com>
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: Programming Language :: Python :: 3
11
+ Requires-Python: >=3.9
12
+ Requires-Dist: numpy>=1.19.0
13
+ Requires-Dist: pandas>=1.0.0
14
+ Description-Content-Type: text/markdown
15
+
16
+ # auto-chart-patterns
17
+ Automatically identify chart patterns from OHLC data
@@ -0,0 +1,11 @@
1
+ auto_chart_patterns/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ auto_chart_patterns/chart_pattern.py,sha256=Lt2Et8bqR6xsenCu0lLD2me0RF4UAUNPsRLSmtWVH30,6276
3
+ auto_chart_patterns/line.py,sha256=6Vftr-Ok3_UESoRUqjr8JDKbxyDjQztcn-_CkFHm5Fo,2383
4
+ auto_chart_patterns/reversal_patterns.py,sha256=8LX6Adz6xOsn9-3cx3dHbIW7DenDxwt4V2QSA7zp3ww,9017
5
+ auto_chart_patterns/rsi_div_patterns.py,sha256=gkDYcHysypfLaPSEpg9-yldxYqKy-kDjj9g2RFGhwWo,5876
6
+ auto_chart_patterns/trendline_patterns.py,sha256=Dmnmx0dXcsBN1kUizyiKEtwbOVp_Qb5sQJIF_L5LmuM,18811
7
+ auto_chart_patterns/zigzag.py,sha256=i9hjKyY9omeIU5DL1Mzpm08L2YM9k9bSJH6J04r6kRA,9107
8
+ auto_chart_patterns-0.1.0.dist-info/METADATA,sha256=jZ4weSPNIqxe0RDgK-1Z6JNlmf876GRXnPhpjvHzU9c,654
9
+ auto_chart_patterns-0.1.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
10
+ auto_chart_patterns-0.1.0.dist-info/licenses/LICENSE,sha256=1YoFL1mg5DVYt9jUfKKxbE98PX8YjvrPoIGhHx4RV8w,1061
11
+ auto_chart_patterns-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.26.3
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 FanM
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.