maidr 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maidr/__init__.py +2 -1
- maidr/core/plot/candlestick.py +87 -203
- maidr/core/plot/maidr_plot.py +7 -2
- maidr/core/plot/maidr_plot_factory.py +12 -11
- maidr/core/plot/mplfinance_barplot.py +115 -0
- maidr/core/plot/mplfinance_lineplot.py +146 -0
- maidr/patch/__init__.py +15 -0
- maidr/patch/mplfinance.py +213 -0
- maidr/util/__init__.py +3 -0
- maidr/util/mplfinance_utils.py +409 -0
- maidr/util/plot_detection.py +136 -0
- {maidr-1.3.0.dist-info → maidr-1.4.0.dist-info}/METADATA +1 -1
- {maidr-1.3.0.dist-info → maidr-1.4.0.dist-info}/RECORD +15 -10
- {maidr-1.3.0.dist-info → maidr-1.4.0.dist-info}/LICENSE +0 -0
- {maidr-1.3.0.dist-info → maidr-1.4.0.dist-info}/WHEEL +0 -0
maidr/__init__.py
CHANGED
maidr/core/plot/candlestick.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import matplotlib.dates as mdates
|
|
2
2
|
import numpy as np
|
|
3
3
|
from matplotlib.axes import Axes
|
|
4
|
-
from matplotlib.collections import LineCollection, PatchCollection
|
|
5
4
|
from matplotlib.patches import Rectangle
|
|
6
5
|
|
|
7
6
|
from maidr.core.enum.plot_type import PlotType
|
|
8
7
|
from maidr.core.plot.maidr_plot import MaidrPlot
|
|
8
|
+
from maidr.core.enum.maidr_key import MaidrKey
|
|
9
|
+
from maidr.util.mplfinance_utils import MplfinanceDataExtractor
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class CandlestickPlot(MaidrPlot):
|
|
@@ -27,213 +28,96 @@ class CandlestickPlot(MaidrPlot):
|
|
|
27
28
|
raise ValueError("Axes list cannot be empty.")
|
|
28
29
|
super().__init__(axes[0], PlotType.CANDLESTICK)
|
|
29
30
|
|
|
31
|
+
# Store custom collections passed from mplfinance patch
|
|
32
|
+
self._maidr_wick_collection = kwargs.get("_maidr_wick_collection", None)
|
|
33
|
+
self._maidr_body_collection = kwargs.get("_maidr_body_collection", None)
|
|
34
|
+
self._maidr_date_nums = kwargs.get("_maidr_date_nums", None)
|
|
35
|
+
|
|
36
|
+
# Store the GID for proper selector generation
|
|
37
|
+
self._maidr_gid = None
|
|
38
|
+
if self._maidr_body_collection:
|
|
39
|
+
self._maidr_gid = self._maidr_body_collection.get_gid()
|
|
40
|
+
elif self._maidr_wick_collection:
|
|
41
|
+
self._maidr_gid = self._maidr_wick_collection.get_gid()
|
|
42
|
+
|
|
30
43
|
def _extract_plot_data(self) -> list[dict]:
|
|
31
|
-
"""
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
>>> data = plot_instance._extract_plot_data()
|
|
51
|
-
>>> print(data[0])
|
|
52
|
-
{
|
|
53
|
-
'value': '2021-01-01',
|
|
54
|
-
'open': 100.0,
|
|
55
|
-
'high': 100.9,
|
|
56
|
-
'low': 99.27,
|
|
57
|
-
'close': 100.75,
|
|
58
|
-
'volume': 171914,
|
|
59
|
-
}
|
|
60
|
-
"""
|
|
44
|
+
"""Extract candlestick data from the plot."""
|
|
45
|
+
|
|
46
|
+
# Get the custom collections from kwargs
|
|
47
|
+
body_collection = self._maidr_body_collection
|
|
48
|
+
wick_collection = self._maidr_wick_collection
|
|
49
|
+
|
|
50
|
+
if body_collection and wick_collection:
|
|
51
|
+
# Store the GID from the body collection for highlighting
|
|
52
|
+
self._maidr_gid = body_collection.get_gid()
|
|
53
|
+
|
|
54
|
+
# Use the original collections for highlighting
|
|
55
|
+
self._elements = [body_collection, wick_collection]
|
|
56
|
+
|
|
57
|
+
# Use the utility class to extract data
|
|
58
|
+
return MplfinanceDataExtractor.extract_candlestick_data(
|
|
59
|
+
body_collection, wick_collection, self._maidr_date_nums
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# Fallback to original detection method
|
|
61
63
|
if not self.axes:
|
|
62
64
|
return []
|
|
63
65
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
body_rectangles
|
|
90
|
-
): # If we found rectangles this way, assume this is the primary body collection
|
|
91
|
-
break
|
|
92
|
-
except Exception:
|
|
93
|
-
# Could fail if collection is not iterable in the expected way or patches are not Rectangles
|
|
94
|
-
pass
|
|
95
|
-
|
|
96
|
-
if not body_rectangles:
|
|
97
|
-
for patch in ax_ohlc.patches:
|
|
98
|
-
if isinstance(patch, Rectangle):
|
|
99
|
-
body_rectangles.append(patch)
|
|
100
|
-
|
|
101
|
-
if not body_rectangles:
|
|
102
|
-
pass
|
|
103
|
-
|
|
104
|
-
ax_for_wicks: Axes | None = None
|
|
105
|
-
if len(self.axes) > 1:
|
|
106
|
-
ax_for_wicks = self.axes[1]
|
|
107
|
-
|
|
108
|
-
if ax_for_wicks:
|
|
109
|
-
# Attempt 1: Find wicks in ax_for_wicks.collections (as a LineCollection)
|
|
110
|
-
for collection in ax_for_wicks.collections:
|
|
111
|
-
if isinstance(collection, LineCollection):
|
|
112
|
-
segments = collection.get_segments()
|
|
113
|
-
# Check if the collection contains segments and the first segment looks like a vertical line
|
|
114
|
-
if segments is not None and len(segments) > 0:
|
|
115
|
-
first_segment = segments[0]
|
|
116
|
-
if (
|
|
117
|
-
len(first_segment) == 2 # Segment consists of two points
|
|
118
|
-
and len(first_segment[0]) == 2 # First point has (x, y)
|
|
119
|
-
and len(first_segment[1]) == 2 # Second point has (x, y)
|
|
120
|
-
and np.isclose(
|
|
121
|
-
first_segment[0][0], first_segment[1][0]
|
|
122
|
-
) # X-coordinates are close (vertical)
|
|
123
|
-
):
|
|
124
|
-
wick_collection = collection
|
|
125
|
-
break # Found a suitable LineCollection
|
|
126
|
-
|
|
127
|
-
# Attempt 2: If no LineCollection found, try to find wicks from individual Line2D objects in ax_for_wicks.get_lines()
|
|
128
|
-
if not wick_collection and hasattr(ax_for_wicks, "get_lines"):
|
|
129
|
-
potential_wick_segments = []
|
|
130
|
-
for line in ax_for_wicks.get_lines(): # Iterate over Line2D objects
|
|
131
|
-
x_data, y_data = line.get_data()
|
|
132
|
-
# A wick is typically a vertical line defined by two points.
|
|
133
|
-
if len(x_data) == 2 and len(y_data) == 2:
|
|
134
|
-
if np.isclose(x_data[0], x_data[1]): # Check for verticality
|
|
135
|
-
# Create a segment in the format [[x1, y1], [x2, y2]]
|
|
136
|
-
segment = [
|
|
137
|
-
[x_data[0], y_data[0]],
|
|
138
|
-
[x_data[1], y_data[1]],
|
|
139
|
-
]
|
|
140
|
-
potential_wick_segments.append(segment)
|
|
141
|
-
|
|
142
|
-
if potential_wick_segments:
|
|
143
|
-
# If wick segments were found from individual lines,
|
|
144
|
-
# create a new LineCollection to hold them.
|
|
145
|
-
# This allows the downstream processing logic
|
|
146
|
-
# for wicks to remain consistent.
|
|
147
|
-
# Basic properties are set; color/linestyle
|
|
148
|
-
# are defaults and may not match
|
|
149
|
-
# the original plot's styling if that
|
|
150
|
-
# were relevant for segment extraction.
|
|
151
|
-
wick_collection = LineCollection(
|
|
152
|
-
potential_wick_segments,
|
|
153
|
-
colors="k", # Default color for the temporary collection
|
|
154
|
-
linestyles="solid", # Default linestyle
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
# Process wicks into a map: x_coordinate -> (low_price, high_price)
|
|
158
|
-
wick_segments_map: dict[float, tuple[float, float]] = {}
|
|
159
|
-
if wick_collection:
|
|
160
|
-
for seg in wick_collection.get_segments():
|
|
161
|
-
if len(seg) == 2 and len(seg[0]) == 2 and len(seg[1]) == 2:
|
|
162
|
-
# Ensure x-coordinates are (nearly) identical for a vertical wick line
|
|
163
|
-
if np.isclose(seg[0][0], seg[1][0]):
|
|
164
|
-
x_coord = seg[0][0] # Matplotlib date number
|
|
165
|
-
low_price = min(seg[0][1], seg[1][1])
|
|
166
|
-
high_price = max(seg[0][1], seg[1][1])
|
|
167
|
-
wick_segments_map[x_coord] = (low_price, high_price)
|
|
168
|
-
|
|
169
|
-
body_rectangles.sort(key=lambda r: r.get_x())
|
|
170
|
-
|
|
171
|
-
for rect in body_rectangles:
|
|
172
|
-
x_left = rect.get_x()
|
|
173
|
-
width = rect.get_width()
|
|
174
|
-
x_center_num = x_left + width / 2.0
|
|
175
|
-
|
|
176
|
-
try:
|
|
177
|
-
date_dt = mdates.num2date(x_center_num)
|
|
178
|
-
date_str = date_dt.strftime("%Y-%m-%d")
|
|
179
|
-
except ValueError:
|
|
180
|
-
date_str = f"raw_date_{x_center_num:.2f}"
|
|
181
|
-
|
|
182
|
-
y_bottom = rect.get_y()
|
|
183
|
-
height = rect.get_height()
|
|
184
|
-
face_color = rect.get_facecolor() # RGBA tuple
|
|
185
|
-
|
|
186
|
-
# Infer open and close prices
|
|
187
|
-
# Heuristic: Green component > Red component for an "up" candle (close > open)
|
|
188
|
-
# This assumes standard green for up, red for down.
|
|
189
|
-
# A more robust method would involve knowing the exact up/down colors used.
|
|
190
|
-
is_up_candle = (
|
|
191
|
-
face_color[1] > face_color[0]
|
|
192
|
-
) # Compare Green and Red components
|
|
193
|
-
|
|
194
|
-
if is_up_candle: # Typically green: price went up
|
|
195
|
-
open_price = y_bottom
|
|
196
|
-
close_price = y_bottom + height
|
|
197
|
-
else: # Typically red: price went down (or other color)
|
|
198
|
-
close_price = y_bottom
|
|
199
|
-
open_price = y_bottom + height
|
|
200
|
-
|
|
201
|
-
matched_wick_data = None
|
|
202
|
-
closest_wick_x = None
|
|
203
|
-
min_diff = float("inf")
|
|
204
|
-
|
|
205
|
-
for wick_x, prices in wick_segments_map.items():
|
|
206
|
-
diff = abs(wick_x - x_center_num)
|
|
207
|
-
if diff < min_diff:
|
|
208
|
-
min_diff = diff
|
|
209
|
-
closest_wick_x = wick_x
|
|
210
|
-
|
|
211
|
-
# Tolerance for matching wick x-coordinate (e.g., 10% of candle width)
|
|
212
|
-
if closest_wick_x is not None and min_diff < (width * 0.1):
|
|
213
|
-
matched_wick_data = wick_segments_map[closest_wick_x]
|
|
214
|
-
|
|
215
|
-
if matched_wick_data:
|
|
216
|
-
low_price, high_price = matched_wick_data
|
|
217
|
-
# Ensure high >= max(open,close) and low <= min(open,close)
|
|
218
|
-
high_price = max(high_price, open_price, close_price)
|
|
219
|
-
low_price = min(low_price, open_price, close_price)
|
|
220
|
-
else:
|
|
221
|
-
# Fallback if no wick found: high is max(open,close), low is min(open,close)
|
|
222
|
-
high_price = max(open_price, close_price)
|
|
223
|
-
low_price = min(open_price, close_price)
|
|
224
|
-
|
|
225
|
-
plot_data.append(
|
|
226
|
-
{
|
|
227
|
-
"value": date_str,
|
|
228
|
-
"open": open_price,
|
|
229
|
-
"high": high_price,
|
|
230
|
-
"low": low_price,
|
|
231
|
-
"close": close_price,
|
|
232
|
-
"volume": 0,
|
|
233
|
-
}
|
|
66
|
+
ax_ohlc = self.axes[0]
|
|
67
|
+
candles = []
|
|
68
|
+
|
|
69
|
+
# Look for Rectangle patches (original_flavor candlestick)
|
|
70
|
+
body_rectangles = []
|
|
71
|
+
for patch in ax_ohlc.patches:
|
|
72
|
+
if isinstance(patch, Rectangle):
|
|
73
|
+
body_rectangles.append(patch)
|
|
74
|
+
|
|
75
|
+
if body_rectangles:
|
|
76
|
+
# Set elements for highlighting
|
|
77
|
+
self._elements = body_rectangles
|
|
78
|
+
|
|
79
|
+
# Generate a GID for highlighting if none exists
|
|
80
|
+
if not self._maidr_gid:
|
|
81
|
+
import uuid
|
|
82
|
+
|
|
83
|
+
self._maidr_gid = f"maidr-{uuid.uuid4()}"
|
|
84
|
+
# Set GID on all rectangles
|
|
85
|
+
for rect in body_rectangles:
|
|
86
|
+
rect.set_gid(self._maidr_gid)
|
|
87
|
+
|
|
88
|
+
# Use the utility class to extract data
|
|
89
|
+
return MplfinanceDataExtractor.extract_rectangle_candlestick_data(
|
|
90
|
+
body_rectangles, self._maidr_date_nums
|
|
234
91
|
)
|
|
235
|
-
|
|
236
|
-
return
|
|
92
|
+
|
|
93
|
+
return []
|
|
237
94
|
|
|
238
95
|
def _extract_axes_data(self) -> dict:
|
|
239
96
|
return {}
|
|
97
|
+
|
|
98
|
+
def _get_selector(self) -> str:
|
|
99
|
+
"""Return the CSS selector for highlighting candlestick elements in the SVG output."""
|
|
100
|
+
# Use the stored GID if available, otherwise fall back to generic selector
|
|
101
|
+
if self._maidr_gid:
|
|
102
|
+
# Use the full GID as the id attribute (since that's what's in the SVG)
|
|
103
|
+
selector = (
|
|
104
|
+
f"g[id='{self._maidr_gid}'] > path, g[id='{self._maidr_gid}'] > rect"
|
|
105
|
+
)
|
|
106
|
+
else:
|
|
107
|
+
selector = "g[maidr='true'] > path, g[maidr='true'] > rect"
|
|
108
|
+
return selector
|
|
109
|
+
|
|
110
|
+
def render(self) -> dict:
|
|
111
|
+
"""Initialize the MAIDR schema dictionary with basic plot information."""
|
|
112
|
+
maidr_schema = {
|
|
113
|
+
MaidrKey.TYPE: self.type,
|
|
114
|
+
MaidrKey.TITLE: self.ax.get_title(),
|
|
115
|
+
MaidrKey.AXES: self._extract_axes_data(),
|
|
116
|
+
MaidrKey.DATA: self._extract_plot_data(),
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
# Include selector only if the plot supports highlighting.
|
|
120
|
+
if self._support_highlighting:
|
|
121
|
+
maidr_schema[MaidrKey.SELECTOR] = self._get_selector()
|
|
122
|
+
|
|
123
|
+
return maidr_schema
|
maidr/core/plot/maidr_plot.py
CHANGED
|
@@ -42,8 +42,13 @@ class MaidrPlot(ABC):
|
|
|
42
42
|
self._support_highlighting = True
|
|
43
43
|
self._elements = []
|
|
44
44
|
ss = self.ax.get_subplotspec()
|
|
45
|
-
|
|
46
|
-
|
|
45
|
+
# Handle cases where subplotspec is None (dynamically created axes)
|
|
46
|
+
if ss is not None:
|
|
47
|
+
self.row_index = ss.rowspan.start
|
|
48
|
+
self.col_index = ss.colspan.start
|
|
49
|
+
else:
|
|
50
|
+
self.row_index = 0
|
|
51
|
+
self.col_index = 0
|
|
47
52
|
|
|
48
53
|
# MAIDR data
|
|
49
54
|
self.type = plot_type
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from matplotlib.axes import Axes
|
|
4
|
-
|
|
5
4
|
from maidr.core.enum import PlotType
|
|
6
5
|
from maidr.core.plot.barplot import BarPlot
|
|
7
6
|
from maidr.core.plot.boxplot import BoxPlot
|
|
@@ -13,6 +12,9 @@ from maidr.core.plot.lineplot import MultiLinePlot
|
|
|
13
12
|
from maidr.core.plot.maidr_plot import MaidrPlot
|
|
14
13
|
from maidr.core.plot.scatterplot import ScatterPlot
|
|
15
14
|
from maidr.core.plot.regplot import SmoothPlot
|
|
15
|
+
from maidr.core.plot.mplfinance_barplot import MplfinanceBarPlot
|
|
16
|
+
from maidr.core.plot.mplfinance_lineplot import MplfinanceLinePlot
|
|
17
|
+
from maidr.util.plot_detection import PlotDetectionUtils
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
class MaidrPlotFactory:
|
|
@@ -38,17 +40,13 @@ class MaidrPlotFactory:
|
|
|
38
40
|
single_ax = ax
|
|
39
41
|
|
|
40
42
|
if plot_type == PlotType.CANDLESTICK:
|
|
41
|
-
|
|
42
|
-
# If ax is a list of lists, flatten it
|
|
43
|
-
if ax and isinstance(ax[0], list):
|
|
44
|
-
axes = ax[0] # Take the first inner list
|
|
45
|
-
else:
|
|
46
|
-
axes = ax # Use the list as-is
|
|
47
|
-
else:
|
|
48
|
-
axes = [ax] # Wrap single axes in list
|
|
43
|
+
axes = PlotDetectionUtils.get_candlestick_axes(ax)
|
|
49
44
|
return CandlestickPlot(axes, **kwargs)
|
|
50
45
|
elif PlotType.BAR == plot_type or PlotType.COUNT == plot_type:
|
|
51
|
-
|
|
46
|
+
if PlotDetectionUtils.is_mplfinance_bar_plot(**kwargs):
|
|
47
|
+
return MplfinanceBarPlot(single_ax, **kwargs)
|
|
48
|
+
else:
|
|
49
|
+
return BarPlot(single_ax)
|
|
52
50
|
elif PlotType.BOX == plot_type:
|
|
53
51
|
return BoxPlot(single_ax, **kwargs)
|
|
54
52
|
elif PlotType.HEAT == plot_type:
|
|
@@ -56,7 +54,10 @@ class MaidrPlotFactory:
|
|
|
56
54
|
elif PlotType.HIST == plot_type:
|
|
57
55
|
return HistPlot(single_ax)
|
|
58
56
|
elif PlotType.LINE == plot_type:
|
|
59
|
-
|
|
57
|
+
if PlotDetectionUtils.is_mplfinance_line_plot(single_ax, **kwargs):
|
|
58
|
+
return MplfinanceLinePlot(single_ax, **kwargs)
|
|
59
|
+
else:
|
|
60
|
+
return MultiLinePlot(single_ax, **kwargs)
|
|
60
61
|
elif PlotType.SCATTER == plot_type:
|
|
61
62
|
return ScatterPlot(single_ax)
|
|
62
63
|
elif PlotType.DODGED == plot_type or PlotType.STACKED == plot_type:
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from matplotlib.axes import Axes
|
|
4
|
+
from matplotlib.patches import Rectangle
|
|
5
|
+
|
|
6
|
+
from maidr.core.enum import PlotType
|
|
7
|
+
from maidr.core.plot import MaidrPlot
|
|
8
|
+
from maidr.exception import ExtractionError
|
|
9
|
+
from maidr.util.mixin import (
|
|
10
|
+
ContainerExtractorMixin,
|
|
11
|
+
DictMergerMixin,
|
|
12
|
+
LevelExtractorMixin,
|
|
13
|
+
)
|
|
14
|
+
from maidr.util.mplfinance_utils import MplfinanceDataExtractor
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MplfinanceBarPlot(
|
|
18
|
+
MaidrPlot, ContainerExtractorMixin, LevelExtractorMixin, DictMergerMixin
|
|
19
|
+
):
|
|
20
|
+
"""
|
|
21
|
+
Specialized bar plot class for mplfinance volume bars.
|
|
22
|
+
|
|
23
|
+
This class handles the extraction and processing of volume data from mplfinance
|
|
24
|
+
plots, including proper date conversion and data validation.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(self, ax: Axes, **kwargs) -> None:
|
|
28
|
+
super().__init__(ax, PlotType.BAR)
|
|
29
|
+
# Store custom patches passed from mplfinance patch
|
|
30
|
+
self._custom_patches = kwargs.get("_maidr_patches", None)
|
|
31
|
+
# Store date numbers for volume bars (from mplfinance)
|
|
32
|
+
self._maidr_date_nums = kwargs.get("_maidr_date_nums", None)
|
|
33
|
+
|
|
34
|
+
def _extract_plot_data(self) -> list:
|
|
35
|
+
"""Extract data from mplfinance volume patches."""
|
|
36
|
+
if self._custom_patches:
|
|
37
|
+
return self._extract_volume_patches_data(self._custom_patches)
|
|
38
|
+
|
|
39
|
+
# Fallback to original bar plot logic if no custom patches
|
|
40
|
+
plot = self.extract_container(
|
|
41
|
+
self.ax, ContainerExtractorMixin, include_all=True
|
|
42
|
+
)
|
|
43
|
+
data = self._extract_bar_container_data(plot)
|
|
44
|
+
levels = self.extract_level(self.ax)
|
|
45
|
+
formatted_data = []
|
|
46
|
+
combined_data = list(
|
|
47
|
+
zip(levels, data) if plot[0].orientation == "vertical" else zip(data, levels) # type: ignore
|
|
48
|
+
)
|
|
49
|
+
if combined_data: # type: ignore
|
|
50
|
+
for x, y in combined_data: # type: ignore
|
|
51
|
+
formatted_data.append({"x": x, "y": y})
|
|
52
|
+
return formatted_data
|
|
53
|
+
if len(formatted_data) == 0:
|
|
54
|
+
raise ExtractionError(self.type, plot)
|
|
55
|
+
if data is None:
|
|
56
|
+
raise ExtractionError(self.type, plot)
|
|
57
|
+
|
|
58
|
+
return data
|
|
59
|
+
|
|
60
|
+
def _extract_volume_patches_data(self, volume_patches: list[Rectangle]) -> list:
|
|
61
|
+
"""
|
|
62
|
+
Extract data from volume Rectangle patches (used by mplfinance volume bars).
|
|
63
|
+
"""
|
|
64
|
+
if not volume_patches:
|
|
65
|
+
return []
|
|
66
|
+
|
|
67
|
+
# Sort patches by x-coordinate to maintain order
|
|
68
|
+
sorted_patches = sorted(volume_patches, key=lambda p: p.get_x())
|
|
69
|
+
|
|
70
|
+
# Set elements for highlighting (use the patches directly)
|
|
71
|
+
self._elements = sorted_patches
|
|
72
|
+
|
|
73
|
+
# Use the utility class to extract data
|
|
74
|
+
return MplfinanceDataExtractor.extract_volume_data(
|
|
75
|
+
sorted_patches, self._maidr_date_nums
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
def _extract_bar_container_data(self, plot: list | None) -> list | None:
|
|
79
|
+
"""Fallback method for regular bar containers."""
|
|
80
|
+
if plot is None:
|
|
81
|
+
return None
|
|
82
|
+
|
|
83
|
+
# Since v0.13, Seaborn has transitioned from using `list[Patch]` to
|
|
84
|
+
# `list[BarContainers] for plotting bar plots.
|
|
85
|
+
# So, extract data correspondingly based on the level.
|
|
86
|
+
# Flatten all the `list[BarContainer]` to `list[Patch]`.
|
|
87
|
+
patches = [patch for container in plot for patch in container.patches]
|
|
88
|
+
level = self.extract_level(self.ax)
|
|
89
|
+
if level is None or len(level) == 0: # type: ignore
|
|
90
|
+
level = ["" for _ in range(len(patches))] # type: ignore
|
|
91
|
+
|
|
92
|
+
if len(patches) != len(level):
|
|
93
|
+
return None
|
|
94
|
+
|
|
95
|
+
self._elements.extend(patches)
|
|
96
|
+
|
|
97
|
+
return [float(patch.get_height()) for patch in patches]
|
|
98
|
+
|
|
99
|
+
def _extract_axes_data(self) -> dict:
|
|
100
|
+
"""Extract axes data with mplfinance-specific cleaning."""
|
|
101
|
+
ax_data = super()._extract_axes_data()
|
|
102
|
+
|
|
103
|
+
# For mplfinance volume plots, clean up the y-axis label
|
|
104
|
+
if self._custom_patches:
|
|
105
|
+
y_label = ax_data.get("y")
|
|
106
|
+
if y_label:
|
|
107
|
+
ax_data["y"] = MplfinanceDataExtractor.clean_axis_label(y_label)
|
|
108
|
+
|
|
109
|
+
return ax_data
|
|
110
|
+
|
|
111
|
+
def _get_selector(self) -> str:
|
|
112
|
+
"""Return the CSS selector for highlighting bar elements in the SVG output."""
|
|
113
|
+
# Use the standard working selector that gets replaced with UUID by Maidr class
|
|
114
|
+
# This works for both original bar plots and mplfinance volume bars
|
|
115
|
+
return "g[maidr='true'] > path"
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from typing import List, Union
|
|
2
|
+
|
|
3
|
+
from matplotlib.axes import Axes
|
|
4
|
+
from matplotlib.lines import Line2D
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from maidr.core.enum.maidr_key import MaidrKey
|
|
8
|
+
from maidr.core.enum.plot_type import PlotType
|
|
9
|
+
from maidr.core.plot.maidr_plot import MaidrPlot
|
|
10
|
+
from maidr.exception.extraction_error import ExtractionError
|
|
11
|
+
from maidr.util.mixin.extractor_mixin import LineExtractorMixin
|
|
12
|
+
from maidr.util.mplfinance_utils import MplfinanceDataExtractor
|
|
13
|
+
import uuid
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MplfinanceLinePlot(MaidrPlot, LineExtractorMixin):
|
|
17
|
+
"""
|
|
18
|
+
Specialized line plot class for mplfinance moving averages.
|
|
19
|
+
|
|
20
|
+
This class handles the extraction and processing of moving average data from mplfinance
|
|
21
|
+
plots, including proper date conversion, NaN filtering, and moving average period detection.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __init__(self, ax: Axes, **kwargs):
|
|
25
|
+
super().__init__(ax, PlotType.LINE)
|
|
26
|
+
|
|
27
|
+
def _get_selector(self) -> Union[str, List[str]]:
|
|
28
|
+
"""Return selectors for all lines that have data."""
|
|
29
|
+
all_lines = self.ax.get_lines()
|
|
30
|
+
if not all_lines:
|
|
31
|
+
return ["g[maidr='true'] > path"]
|
|
32
|
+
|
|
33
|
+
selectors = []
|
|
34
|
+
for line in all_lines:
|
|
35
|
+
# Only create selectors for lines that have data
|
|
36
|
+
xydata = line.get_xydata()
|
|
37
|
+
if xydata is None or not xydata.size: # type: ignore
|
|
38
|
+
continue
|
|
39
|
+
gid = line.get_gid()
|
|
40
|
+
if gid:
|
|
41
|
+
selectors.append(f"g[id='{gid}'] path")
|
|
42
|
+
else:
|
|
43
|
+
selectors.append("g[maidr='true'] > path")
|
|
44
|
+
|
|
45
|
+
if not selectors:
|
|
46
|
+
return ["g[maidr='true'] > path"]
|
|
47
|
+
|
|
48
|
+
return selectors
|
|
49
|
+
|
|
50
|
+
def _extract_plot_data(self) -> Union[List[List[dict]], None]:
|
|
51
|
+
"""Extract data from mplfinance moving average lines."""
|
|
52
|
+
data = self._extract_line_data()
|
|
53
|
+
|
|
54
|
+
if data is None:
|
|
55
|
+
raise ExtractionError(self.type, None)
|
|
56
|
+
|
|
57
|
+
return data
|
|
58
|
+
|
|
59
|
+
def _extract_line_data(self) -> Union[List[List[dict]], None]:
|
|
60
|
+
"""
|
|
61
|
+
Extract data from all line objects and return as separate arrays.
|
|
62
|
+
|
|
63
|
+
This method handles mplfinance-specific logic including:
|
|
64
|
+
- Date conversion from matplotlib date numbers
|
|
65
|
+
- NaN filtering for moving averages
|
|
66
|
+
- Moving average period detection
|
|
67
|
+
- Proper data validation
|
|
68
|
+
|
|
69
|
+
Returns
|
|
70
|
+
-------
|
|
71
|
+
list[list[dict]] | None
|
|
72
|
+
List of lists, where each inner list contains dictionaries with x,y coordinates
|
|
73
|
+
and line identifiers for one line, or None if the plot data is invalid.
|
|
74
|
+
"""
|
|
75
|
+
all_lines = self.ax.get_lines()
|
|
76
|
+
if not all_lines:
|
|
77
|
+
return None
|
|
78
|
+
|
|
79
|
+
all_lines_data = []
|
|
80
|
+
|
|
81
|
+
for line_idx, line in enumerate(all_lines):
|
|
82
|
+
xydata = line.get_xydata()
|
|
83
|
+
if xydata is None or not xydata.size: # type: ignore
|
|
84
|
+
continue
|
|
85
|
+
|
|
86
|
+
self._elements.append(line)
|
|
87
|
+
|
|
88
|
+
# Assign unique GID to each line if not already set
|
|
89
|
+
if line.get_gid() is None:
|
|
90
|
+
unique_gid = f"maidr-{uuid.uuid4()}"
|
|
91
|
+
line.set_gid(unique_gid)
|
|
92
|
+
|
|
93
|
+
label: str = line.get_label() # type: ignore
|
|
94
|
+
line_data = []
|
|
95
|
+
|
|
96
|
+
# Check if this line has date numbers from mplfinance
|
|
97
|
+
date_nums = getattr(line, "_maidr_date_nums", None)
|
|
98
|
+
|
|
99
|
+
# Convert xydata to list of points
|
|
100
|
+
for i, (x, y) in enumerate(line.get_xydata()): # type: ignore
|
|
101
|
+
# Skip points with NaN or inf values to prevent JSON parsing errors
|
|
102
|
+
if np.isnan(x) or np.isnan(y) or np.isinf(x) or np.isinf(y):
|
|
103
|
+
continue
|
|
104
|
+
|
|
105
|
+
# Handle x-value conversion - could be string (date) or numeric
|
|
106
|
+
if isinstance(x, str):
|
|
107
|
+
x_value = x # Keep string as-is (for dates)
|
|
108
|
+
else:
|
|
109
|
+
# Check if we have date numbers from mplfinance
|
|
110
|
+
if date_nums is not None and i < len(date_nums):
|
|
111
|
+
# Use the date number to convert to date string
|
|
112
|
+
date_num = float(date_nums[i])
|
|
113
|
+
x_value = self._convert_x_to_date(date_num)
|
|
114
|
+
else:
|
|
115
|
+
x_value = float(x) # Convert numeric to float
|
|
116
|
+
|
|
117
|
+
point_data = {
|
|
118
|
+
MaidrKey.X: x_value,
|
|
119
|
+
MaidrKey.Y: float(y),
|
|
120
|
+
MaidrKey.FILL: (label if not label.startswith("_child") else ""),
|
|
121
|
+
}
|
|
122
|
+
line_data.append(point_data)
|
|
123
|
+
|
|
124
|
+
if line_data:
|
|
125
|
+
all_lines_data.append(line_data)
|
|
126
|
+
|
|
127
|
+
return all_lines_data if all_lines_data else None
|
|
128
|
+
|
|
129
|
+
def _convert_x_to_date(self, x_value: float) -> str:
|
|
130
|
+
"""
|
|
131
|
+
Convert x-coordinate to date string for mplfinance plots.
|
|
132
|
+
|
|
133
|
+
This method uses the MplfinanceDataExtractor utility to convert
|
|
134
|
+
matplotlib date numbers to proper date strings.
|
|
135
|
+
|
|
136
|
+
Parameters
|
|
137
|
+
----------
|
|
138
|
+
x_value : float
|
|
139
|
+
The x-coordinate value (matplotlib date number)
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
str
|
|
144
|
+
Date string in YYYY-MM-DD format
|
|
145
|
+
"""
|
|
146
|
+
return MplfinanceDataExtractor._convert_date_num_to_string(x_value)
|
maidr/patch/__init__.py
CHANGED