maidr 1.3.0__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maidr/__init__.py +2 -1
- maidr/core/maidr.py +74 -7
- maidr/core/plot/candlestick.py +112 -204
- maidr/core/plot/maidr_plot.py +7 -2
- maidr/core/plot/maidr_plot_factory.py +12 -11
- maidr/core/plot/mplfinance_barplot.py +139 -0
- maidr/core/plot/mplfinance_lineplot.py +273 -0
- maidr/patch/__init__.py +15 -0
- maidr/patch/mplfinance.py +215 -0
- maidr/util/__init__.py +3 -0
- maidr/util/mplfinance_utils.py +409 -0
- maidr/util/plot_detection.py +136 -0
- {maidr-1.3.0.dist-info → maidr-1.4.1.dist-info}/METADATA +1 -1
- {maidr-1.3.0.dist-info → maidr-1.4.1.dist-info}/RECORD +16 -11
- {maidr-1.3.0.dist-info → maidr-1.4.1.dist-info}/LICENSE +0 -0
- {maidr-1.3.0.dist-info → maidr-1.4.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from matplotlib.axes import Axes
|
|
4
|
+
from matplotlib.patches import Rectangle
|
|
5
|
+
|
|
6
|
+
from maidr.core.enum import PlotType
|
|
7
|
+
from maidr.core.plot import MaidrPlot
|
|
8
|
+
from maidr.exception import ExtractionError
|
|
9
|
+
from maidr.util.mixin import (
|
|
10
|
+
ContainerExtractorMixin,
|
|
11
|
+
DictMergerMixin,
|
|
12
|
+
LevelExtractorMixin,
|
|
13
|
+
)
|
|
14
|
+
from maidr.util.mplfinance_utils import MplfinanceDataExtractor
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MplfinanceBarPlot(
    MaidrPlot, ContainerExtractorMixin, LevelExtractorMixin, DictMergerMixin
):
    """
    Specialized bar plot class for mplfinance volume bars.

    This class handles the extraction and processing of volume data from
    mplfinance plots, including proper date conversion and data validation.

    Parameters
    ----------
    ax : Axes
        The axes holding the volume bars.
    **kwargs
        ``_maidr_patches``: volume ``Rectangle`` patches supplied by the
        mplfinance patch. When absent or empty, bars are read from the axes'
        ``BarContainer`` objects instead.
        ``_maidr_date_nums``: matplotlib date numbers for the bars, forwarded
        to ``MplfinanceDataExtractor.extract_volume_data``.
    """

    # Title used by ``render`` when no non-empty custom title has been set.
    _DEFAULT_TITLE = "Volume Bar Plot"

    def __init__(self, ax: Axes, **kwargs) -> None:
        super().__init__(ax, PlotType.BAR)
        # Custom patches passed from the mplfinance patch (may be None).
        self._custom_patches = kwargs.get("_maidr_patches", None)
        # Date numbers for the volume bars, from mplfinance (may be None).
        self._maidr_date_nums = kwargs.get("_maidr_date_nums", None)
        # Title set through ``set_title``; None means ``render`` uses the default.
        self._custom_title = None

    def set_title(self, title: str) -> None:
        """
        Set a custom title for this volume bar plot.

        ``render`` uses this title when it is non-empty and falls back to
        ``"Volume Bar Plot"`` otherwise.
        """
        # ``_title`` is still assigned for backward compatibility.
        self._title = title
        self._custom_title = title

    def _extract_plot_data(self) -> list:
        """
        Extract the volume bar data.

        Returns
        -------
        list
            The result of ``MplfinanceDataExtractor.extract_volume_data`` when
            custom patches were supplied; otherwise one ``{"x": ..., "y": ...}``
            dict per bar read from the axes' bar containers.

        Raises
        ------
        ExtractionError
            If no custom patches were supplied and no usable bar data can be
            read from the axes.
        """
        if self._custom_patches:
            return self._extract_volume_patches_data(self._custom_patches)

        from matplotlib.container import BarContainer

        # Fallback to the generic bar plot logic: look for BarContainer objects
        # (a mixin class can never match a container, so it must not be used here).
        plot = self.extract_container(self.ax, BarContainer, include_all=True)
        data = self._extract_bar_container_data(plot)

        # Validate before indexing ``plot[0]`` or zipping ``data`` so failures
        # surface as ExtractionError rather than TypeError/IndexError.
        if not plot or not data:
            raise ExtractionError(self.type, plot)

        levels = self.extract_level(self.ax)
        if levels is None or len(levels) == 0:
            # Same fallback as ``_extract_bar_container_data``: unlabeled bars.
            levels = ["" for _ in data]

        if plot[0].orientation == "vertical":
            pairs = zip(levels, data)
        else:
            pairs = zip(data, levels)
        return [{"x": x, "y": y} for x, y in pairs]

    def _extract_volume_patches_data(self, volume_patches: list[Rectangle]) -> list:
        """
        Extract data from volume Rectangle patches (used by mplfinance volume bars).

        Patches are sorted by their x position and stored in ``self._elements``
        for highlighting. They are then passed, in that order and together with
        the stored date numbers, to ``MplfinanceDataExtractor.extract_volume_data``.
        Returns an empty list when ``volume_patches`` is empty.
        """
        if not volume_patches:
            return []

        # Sort patches by x-coordinate to maintain order.
        sorted_patches = sorted(volume_patches, key=lambda p: p.get_x())

        # Set elements for highlighting (use the patches directly).
        self._elements = sorted_patches

        return MplfinanceDataExtractor.extract_volume_data(
            sorted_patches, self._maidr_date_nums
        )

    def _extract_bar_container_data(self, plot: list | None) -> list | None:
        """
        Return bar heights from a list of bar containers (fallback path).

        Returns None when ``plot`` is None, or when the number of patches does
        not match the number of extracted levels. Otherwise the patches are
        appended to ``self._elements`` and their heights returned as floats.
        """
        if plot is None:
            return None

        # Since v0.13, Seaborn has transitioned from using `list[Patch]` to
        # `list[BarContainer]` for plotting bar plots.
        # So, extract data correspondingly based on the level.
        # Flatten all the `list[BarContainer]` to `list[Patch]`.
        patches = [patch for container in plot for patch in container.patches]
        level = self.extract_level(self.ax)
        if level is None or len(level) == 0:  # type: ignore
            level = ["" for _ in range(len(patches))]  # type: ignore

        if len(patches) != len(level):
            return None

        self._elements.extend(patches)

        return [float(patch.get_height()) for patch in patches]

    def _extract_axes_data(self) -> dict:
        """
        Extract axes data with mplfinance-specific cleaning.

        When custom (mplfinance) patches are present, a non-empty y label is
        passed through ``MplfinanceDataExtractor.clean_axis_label``.
        """
        ax_data = super()._extract_axes_data()

        if self._custom_patches:
            y_label = ax_data.get("y")
            if y_label:
                ax_data["y"] = MplfinanceDataExtractor.clean_axis_label(y_label)

        return ax_data

    def _get_selector(self) -> str:
        """Return the CSS selector for highlighting bar elements in the SVG output."""
        # Use the standard working selector that gets replaced with UUID by Maidr class
        # This works for both original bar plots and mplfinance volume bars
        return "g[maidr='true'] > path"

    def render(self) -> dict:
        """
        Build the MAIDR schema dictionary for this volume bar plot.

        The title is the one set through ``set_title`` when non-empty, and
        ``"Volume Bar Plot"`` otherwise. A selector is included only when the
        plot supports highlighting.
        """
        from maidr.core.enum.maidr_key import MaidrKey

        title = self._custom_title or self._DEFAULT_TITLE

        maidr_schema = {
            MaidrKey.TYPE: self.type,
            MaidrKey.TITLE: title,
            MaidrKey.AXES: self._extract_axes_data(),
            MaidrKey.DATA: self._extract_plot_data(),
        }

        # Include selector only if the plot supports highlighting.
        if self._support_highlighting:
            maidr_schema[MaidrKey.SELECTOR] = self._get_selector()

        return maidr_schema
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
from typing import List, Union
|
|
2
|
+
|
|
3
|
+
from matplotlib.axes import Axes
|
|
4
|
+
from matplotlib.lines import Line2D
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from maidr.core.enum.maidr_key import MaidrKey
|
|
8
|
+
from maidr.core.enum.plot_type import PlotType
|
|
9
|
+
from maidr.core.plot.maidr_plot import MaidrPlot
|
|
10
|
+
from maidr.exception.extraction_error import ExtractionError
|
|
11
|
+
from maidr.util.mixin.extractor_mixin import LineExtractorMixin
|
|
12
|
+
from maidr.util.mplfinance_utils import MplfinanceDataExtractor
|
|
13
|
+
import uuid
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MplfinanceLinePlot(MaidrPlot, LineExtractorMixin):
    """
    Specialized line plot class for mplfinance moving averages.

    This class handles the extraction and processing of moving average data
    from mplfinance plots, including proper date conversion, NaN filtering,
    and moving average period detection.

    Lines are read from ``ax.get_lines()``. Per-line metadata comes from the
    attributes attached by the mplfinance patch:

    - ``_maidr_ma_period``: the estimated moving-average window.
    - ``_maidr_date_nums``: matplotlib date numbers, indexed by point position.
    """

    # Selector used for lines without a gid, or when no line has data.
    _FALLBACK_SELECTOR = "g[maidr='true'] > path"

    def __init__(self, ax: Axes, **kwargs):
        super().__init__(ax, PlotType.LINE)
        self._line_titles = []  # Store line titles separately

    @staticmethod
    def _has_plottable_points(line: Line2D) -> bool:
        """
        Return True if ``line`` has at least one point with finite x and y.

        Both ``_get_selector`` and ``_extract_line_data`` use this single rule
        to decide which lines are emitted. The i-th selector therefore always
        refers to the i-th data series.
        """
        xydata = line.get_xydata()
        if xydata is None:
            return False
        xy = np.asarray(xydata, dtype=float)
        if xy.size == 0:
            return False
        return bool(np.isfinite(xy).all(axis=-1).any())

    @staticmethod
    def _format_line_title(ma_period) -> str:
        """Return the title for one moving-average line (generic if no period)."""
        if ma_period:
            return f"{ma_period}-Day Moving Average Line Plot"
        return "Moving Average Line Plot"

    @staticmethod
    def _format_y_label(ma_period) -> str:
        """Return the y-axis label for a moving-average period (generic if none)."""
        if ma_period:
            return f"{ma_period}-day mav price ($)"
        return "Moving Average Price ($)"

    def _get_selector(self) -> Union[str, List[str]]:
        """
        Return one selector per line that contributes data.

        Lines with a gid get ``g[id='<gid>'] path``; lines without one get the
        generic maidr selector. If no line contributes data, a single-element
        list holding the generic selector is returned.
        """
        selectors = []
        for line in self.ax.get_lines():
            # Same filter as ``_extract_line_data`` keeps selectors aligned with data.
            if not self._has_plottable_points(line):
                continue
            gid = line.get_gid()
            selectors.append(f"g[id='{gid}'] path" if gid else self._FALLBACK_SELECTOR)

        return selectors or [self._FALLBACK_SELECTOR]

    def _extract_axes_data(self) -> dict:
        """
        Extract axis labels for the plot.

        Returns
        -------
        dict
            Dictionary containing x and y axis labels with custom y-label for moving averages.
            The x label falls back to the shared x label, then to ``"Date"``.
        """
        x_labels = self.ax.get_xlabel()
        if not x_labels:
            x_labels = self.extract_shared_xlabel(self.ax)
        if not x_labels:
            x_labels = "Date"

        # The y label uses the smallest detected period.
        ma_period = self._extract_moving_average_period()
        return {MaidrKey.X: x_labels, MaidrKey.Y: self._format_y_label(ma_period)}

    def _extract_moving_average_periods(self) -> List[str]:
        """
        Extract all moving average periods from the _maidr_ma_period attributes set by the mplfinance patch.

        Returns
        -------
        List[str]
            Unique periods as strings, in ascending numeric order (e.g.
            ``["3", "6", "30"]``). Non-numeric values, if any, sort after the
            numeric ones.
        """
        unique_periods = set()
        for line in self.ax.get_lines():
            ma_period = getattr(line, "_maidr_ma_period", None)
            if ma_period is not None:
                unique_periods.add(str(ma_period))

        def numeric_order(period: str):
            # Plain string sorting would put "10" before "5".
            try:
                return (0, float(period), period)
            except ValueError:
                return (1, 0.0, period)

        return sorted(unique_periods, key=numeric_order)

    def _extract_moving_average_period(self) -> str:
        """
        Extract the moving average period from the _maidr_ma_period attribute set by the mplfinance patch.

        Returns
        -------
        str
            The smallest detected period (e.g. ``"3"``), or an empty string
            if no period was found.
        """
        periods = self._extract_moving_average_periods()
        return periods[0] if periods else ""

    def _extract_plot_data(self) -> Union[List[List[dict]], None]:
        """
        Extract data from mplfinance moving average lines.

        Raises
        ------
        ExtractionError
            If no line yields any finite point.
        """
        data = self._extract_line_data()

        if data is None:
            raise ExtractionError(self.type, None)

        return data

    def _extract_line_data(self) -> Union[List[List[dict]], None]:
        """
        Extract data from all line objects and return as separate arrays.

        This method handles mplfinance-specific logic including:
        - Date conversion from matplotlib date numbers
        - NaN filtering for moving averages
        - Moving average period detection
        - Proper data validation

        Side effects: each emitted line is appended to ``self._elements``.
        Each emitted line without a gid is assigned a unique ``maidr-<uuid>`` gid.

        Returns
        -------
        list[dict] | None
            One dict per line with at least one finite point, holding
            ``"title"``, ``"axes"`` and ``"points"`` (a list of x/y dicts).
            Returns None if no line yields data.
        """
        all_lines = self.ax.get_lines()
        if not all_lines:
            return None

        all_lines_data = []

        for line in all_lines:
            # Lines with no finite point produce no series (and no selector).
            if not self._has_plottable_points(line):
                continue

            self._elements.append(line)

            # Assign unique GID to each line if not already set.
            if line.get_gid() is None:
                line.set_gid(f"maidr-{uuid.uuid4()}")

            ma_period = getattr(line, "_maidr_ma_period", None)
            date_nums = getattr(line, "_maidr_date_nums", None)

            points = []
            # Line2D stores its data as floats, so x and y are numeric here.
            for index, (x, y) in enumerate(line.get_xydata()):  # type: ignore
                # Skip NaN/inf values: they are not valid JSON numbers.
                if not (np.isfinite(x) and np.isfinite(y)):
                    continue

                if date_nums is not None and index < len(date_nums):
                    # Date numbers are indexed by point position.
                    x_value = self._convert_x_to_date(float(date_nums[index]))
                else:
                    x_value = float(x)

                points.append({MaidrKey.X: x_value, MaidrKey.Y: float(y)})

            all_lines_data.append(
                {
                    "title": self._format_line_title(ma_period),
                    "axes": {
                        "x": "Date",
                        "y": self._format_y_label(ma_period),
                    },
                    "points": points,
                }
            )

        return all_lines_data or None

    def _convert_x_to_date(self, x_value: float) -> str:
        """
        Convert x-coordinate to date string for mplfinance plots.

        This method uses the MplfinanceDataExtractor utility to convert
        matplotlib date numbers to proper date strings.

        Parameters
        ----------
        x_value : float
            The x-coordinate value (matplotlib date number)

        Returns
        -------
        str
            Date string in YYYY-MM-DD format
        """
        return MplfinanceDataExtractor._convert_date_num_to_string(x_value)

    def _extract_line_titles(self) -> List[str]:
        """
        Extract titles for all moving average lines.

        Returns
        -------
        List[str]
            One title per line in ``ax.get_lines()`` (generic when no period).
        """
        return [
            self._format_line_title(getattr(line, "_maidr_ma_period", None))
            for line in self.ax.get_lines()
        ]

    def render(self) -> dict:
        """
        Build the MAIDR schema dictionary for the moving-average lines.

        A selector is included only when the plot supports highlighting.
        """
        # The main title uses the smallest detected period.
        ma_period = self._extract_moving_average_period()
        title = (
            f"{ma_period}-Day Moving Averages Line Plot"
            if ma_period
            else "Moving Averages Line Plot"
        )

        maidr_schema = {
            MaidrKey.TYPE: self.type,
            MaidrKey.TITLE: title,
            MaidrKey.AXES: self._extract_axes_data(),
            MaidrKey.DATA: self._extract_plot_data(),
        }

        # Include selector only if the plot supports highlighting.
        if self._support_highlighting:
            maidr_schema[MaidrKey.SELECTOR] = self._get_selector()

        return maidr_schema
|
maidr/patch/__init__.py
CHANGED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import wrapt
|
|
3
|
+
import mplfinance as mpf
|
|
4
|
+
import numpy as np
|
|
5
|
+
from matplotlib.collections import LineCollection, PolyCollection
|
|
6
|
+
from matplotlib.patches import Rectangle
|
|
7
|
+
from matplotlib.lines import Line2D
|
|
8
|
+
from maidr.core.enum import PlotType
|
|
9
|
+
from maidr.patch.common import common
|
|
10
|
+
from maidr.core.context_manager import ContextManager
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def mplfinance_plot_patch(wrapped, instance, args, kwargs):
    """
    Enhanced patch function for `mplfinance.plot` that registers separate layers:
    - CANDLESTICK: For OHLC data (candle bodies and wicks)
    - BAR: For volume data (volume bars)
    - LINE: For moving averages (lines)

    This function intercepts calls to `mplfinance.plot`, identifies the resulting
    candlestick, volume, and moving average components, and registers them with
    maidr using the common patching mechanism.

    Parameters
    ----------
    wrapped : callable
        The original ``mplfinance.plot`` (supplied by ``wrapt``).
    instance : object
        Value supplied by ``wrapt``; forwarded unchanged to ``common``.
    args : tuple
        Positional arguments of the ``mplfinance.plot`` call.
    kwargs : dict
        Keyword arguments of the ``mplfinance.plot`` call. ``returnfig`` is
        forced to True so the figure and axes can be inspected.

    Returns
    -------
    The value returned by ``mplfinance.plot`` if the caller passed
    ``returnfig=True``; otherwise None.
    """
    # Ensure `returnfig=True` to capture the figure and axes objects.
    original_returnfig = kwargs.get("returnfig", False)
    kwargs["returnfig"] = True

    with ContextManager.set_internal_context():
        result = wrapped(*args, **kwargs)

    # Validate that we received the expected (figure, axes) tuple.
    if not (isinstance(result, tuple) and len(result) >= 2):
        return result if original_returnfig else None

    axes = result[1]
    if isinstance(axes, (list, tuple)):
        ax_list = list(axes)
    elif isinstance(axes, np.ndarray):
        ax_list = list(axes.ravel())
    else:
        ax_list = [axes]

    price_ax, volume_ax = _identify_axes(ax_list)
    date_nums = _extract_date_nums(args, kwargs)

    # Registration order (candlestick, volume, moving averages) is preserved.
    if price_ax is not None:
        _register_candlestick(price_ax, instance, args, kwargs, date_nums)

    if volume_ax is not None:
        _register_volume(volume_ax, instance, args, kwargs, date_nums)

    if price_ax is not None:
        _register_moving_averages(price_ax, instance, args, kwargs, date_nums)

    return result if original_returnfig else None


def _identify_axes(ax_list):
    """
    Return ``(price_ax, volume_ax)`` identified by content, falling back to y-labels.

    - An axes with a LineCollection or PolyCollection is treated as the price axis.
    - An axes with Rectangle patches is treated as the volume axis.
    - Otherwise a y-label containing "price" / "volume" is used. This fallback
      only applies while that role is still unassigned.

    Either element may be None.
    """
    price_ax = None
    volume_ax = None
    for ax in ax_list:
        if any(isinstance(c, (LineCollection, PolyCollection)) for c in ax.collections):
            price_ax = ax
        elif any(isinstance(p, Rectangle) for p in ax.patches):
            volume_ax = ax
        elif price_ax is None and "price" in ax.get_ylabel().lower():
            price_ax = ax
        elif volume_ax is None and "volume" in ax.get_ylabel().lower():
            volume_ax = ax
    return price_ax, volume_ax


def _extract_date_nums(args, kwargs):
    """
    Return matplotlib date numbers for the plotted data, or None.

    The data is the first positional argument, or ``kwargs["data"]``. A
    ``Date_num`` column is used when present. Otherwise each index entry is
    converted with ``matplotlib.dates.date2num``.
    """
    data = args[0] if len(args) > 0 else kwargs.get("data")
    if data is None:
        return None

    if hasattr(data, "Date_num"):
        return list(data["Date_num"])

    if hasattr(data, "index"):
        try:
            import matplotlib.dates as mdates

            return [mdates.date2num(d) for d in data.index]
        except Exception:
            # Deliberate best effort: without dates, x values stay numeric.
            return None

    return None


def _register_candlestick(price_ax, instance, args, kwargs, date_nums):
    """
    Register the candle wicks and bodies on ``price_ax`` as a CANDLESTICK layer.

    Uses the first LineCollection as the wicks and the first PolyCollection as
    the bodies, and gives both the same ``maidr-<uuid>`` gid. Nothing is
    registered unless both collections are found.
    """
    wick_collection = next(
        (c for c in price_ax.collections if isinstance(c, LineCollection)), None
    )
    body_collection = next(
        (c for c in price_ax.collections if isinstance(c, PolyCollection)), None
    )
    if wick_collection is None or body_collection is None:
        return

    gid = f"maidr-{uuid.uuid4()}"
    wick_collection.set_gid(gid)
    body_collection.set_gid(gid)

    candlestick_kwargs = dict(
        kwargs,
        _maidr_wick_collection=wick_collection,
        _maidr_body_collection=body_collection,
        _maidr_date_nums=date_nums,
    )
    common(
        PlotType.CANDLESTICK,
        lambda *a, **k: price_ax,
        instance,
        args,
        candlestick_kwargs,
    )


def _register_volume(volume_ax, instance, args, kwargs, date_nums):
    """
    Register the volume Rectangle patches as a BAR layer.

    If ``volume_ax`` has no Rectangle patches, its shared-x siblings are
    searched instead. Patches without a gid receive a unique ``maidr-<uuid>``
    gid. Nothing is registered when no patches are found.
    """
    volume_patches = [p for p in volume_ax.patches if isinstance(p, Rectangle)]

    if not volume_patches:
        # Search in shared axes for volume patches.
        for twin_ax in volume_ax.get_shared_x_axes().get_siblings(volume_ax):
            if twin_ax is not volume_ax:
                volume_patches.extend(
                    p for p in twin_ax.patches if isinstance(p, Rectangle)
                )

    if not volume_patches:
        return

    # Set GID for volume patches for highlighting.
    for patch in volume_patches:
        if patch.get_gid() is None:
            patch.set_gid(f"maidr-{uuid.uuid4()}")

    bar_kwargs = dict(
        kwargs,
        _maidr_patches=volume_patches,
        _maidr_date_nums=date_nums,
    )
    common(PlotType.BAR, lambda *a, **k: volume_ax, instance, args, bar_kwargs)


def _estimate_ma_period(xy):
    """
    Estimate a moving-average window from the leading NaNs in the y column.

    Assumes the line is a trailing rolling mean over ``w`` samples, whose first
    ``w - 1`` values are NaN; the estimate is the length of the initial NaN
    run plus one. Only the leading run is counted, because NaNs later in the
    series (gaps in the input data) would otherwise inflate the estimate.
    """
    y_is_nan = np.isnan(xy[:, 1])
    if y_is_nan.all():
        leading_nans = int(y_is_nan.size)
    else:
        # argmin on a boolean array gives the index of the first False.
        leading_nans = int(np.argmin(y_is_nan))
    return leading_nans + 1


def _has_finite_points(xy):
    """Return True if ``xy`` has at least one row whose x and y are both finite."""
    if xy.size == 0:
        return False
    return bool(np.isfinite(xy).all(axis=-1).any())


def _line_identity(line, xy):
    """
    Build the key used to skip duplicate lines.

    The key is the line label, plus a hash of the data shape and first three rows
    when there is data.
    """
    label = line.get_label()
    if xy is None or xy.size == 0:
        return f"{label}"
    data_hash = hash(f"{xy.shape}_{str(xy[:3].flatten())}")
    return f"{label}_{data_hash}"


def _register_moving_averages(price_ax, instance, args, kwargs, date_nums):
    """
    Annotate the Line2D objects on ``price_ax`` and register them as one LINE layer.

    For each line, this:

    - stores the estimated period as ``_maidr_ma_period``;
    - relabels the line to include that period;
    - skips duplicates;
    - skips lines with no finite point;
    - attaches ``date_nums`` as ``_maidr_date_nums`` and ensures a gid.

    Nothing is registered when no line qualifies.
    """
    processed_lines = set()
    valid_lines = []

    for line in price_ax.get_lines():
        if not isinstance(line, Line2D):
            continue

        xydata = line.get_xydata()
        xy = None if xydata is None else np.asarray(xydata)

        if xy is not None:
            estimated_period = _estimate_ma_period(xy)
            # Store the period directly on the line for the line plot class.
            setattr(line, "_maidr_ma_period", estimated_period)

            label = str(line.get_label())
            if label.startswith("_child"):
                line.set_label(f"Moving Average {estimated_period} days")
            else:
                # If it's not a _child label, still add the period info.
                line.set_label(f"{label}_MA{estimated_period}")

        line_id = _line_identity(line, xy)
        if line_id in processed_lines:
            continue
        processed_lines.add(line_id)

        # A line without any finite point yields no data downstream; the line
        # plot's extraction raises when no registered line yields points.
        if xy is None or not _has_finite_points(xy):
            continue

        # Store date numbers on the line for the line plot class to use.
        if date_nums is not None:
            setattr(line, "_maidr_date_nums", date_nums)

        # Ensure GID is set for highlighting.
        if line.get_gid() is None:
            line.set_gid(f"maidr-{uuid.uuid4()}")

        valid_lines.append(line)

    # Register all valid lines as a single LINE plot.
    if valid_lines:
        line_kwargs = dict(kwargs)
        common(PlotType.LINE, lambda *a, **k: price_ax, instance, args, line_kwargs)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Apply the patch to mplfinance.plot
|
|
215
|
+
# Module-level side effect: executed once, when this module is imported.
# Afterwards every call to ``mplfinance.plot`` goes through
# ``mplfinance_plot_patch``, which registers the candlestick, volume and
# moving-average layers with maidr.
wrapt.wrap_function_wrapper(mpf, "plot", mplfinance_plot_patch)
|
maidr/util/__init__.py
CHANGED