maidr 1.2.2__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maidr/__init__.py +2 -1
- maidr/core/maidr.py +12 -6
- maidr/core/plot/candlestick.py +87 -203
- maidr/core/plot/maidr_plot.py +7 -2
- maidr/core/plot/maidr_plot_factory.py +12 -11
- maidr/core/plot/mplfinance_barplot.py +115 -0
- maidr/core/plot/mplfinance_lineplot.py +146 -0
- maidr/patch/__init__.py +15 -0
- maidr/patch/mplfinance.py +213 -0
- maidr/util/__init__.py +3 -0
- maidr/util/environment.py +26 -7
- maidr/util/mplfinance_utils.py +409 -0
- maidr/util/plot_detection.py +136 -0
- {maidr-1.2.2.dist-info → maidr-1.4.0.dist-info}/METADATA +1 -4
- {maidr-1.2.2.dist-info → maidr-1.4.0.dist-info}/RECORD +17 -12
- {maidr-1.2.2.dist-info → maidr-1.4.0.dist-info}/LICENSE +0 -0
- {maidr-1.2.2.dist-info → maidr-1.4.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from typing import List, Union
|
|
2
|
+
|
|
3
|
+
from matplotlib.axes import Axes
|
|
4
|
+
from matplotlib.lines import Line2D
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from maidr.core.enum.maidr_key import MaidrKey
|
|
8
|
+
from maidr.core.enum.plot_type import PlotType
|
|
9
|
+
from maidr.core.plot.maidr_plot import MaidrPlot
|
|
10
|
+
from maidr.exception.extraction_error import ExtractionError
|
|
11
|
+
from maidr.util.mixin.extractor_mixin import LineExtractorMixin
|
|
12
|
+
from maidr.util.mplfinance_utils import MplfinanceDataExtractor
|
|
13
|
+
import uuid
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MplfinanceLinePlot(MaidrPlot, LineExtractorMixin):
    """
    Specialized line plot class for mplfinance moving averages.

    Produces one data series per ``Line2D`` on the axes that has at least one
    finite point. X-values are converted to date strings when the line carries
    mplfinance date numbers in a ``_maidr_date_nums`` attribute; otherwise the
    raw x-value is emitted as a float. NaN/inf points are dropped so the output
    stays valid JSON. Series labels are read from ``Line2D.get_label()``; this
    class does not compute moving-average periods itself.
    """

    # Generic selector used for lines without a gid, or when no line qualifies.
    _FALLBACK_SELECTOR = "g[maidr='true'] > path"

    def __init__(self, ax: Axes, **kwargs):
        # Extra keyword arguments are accepted but not used by this class.
        super().__init__(ax, PlotType.LINE)

    @staticmethod
    def _line_has_finite_point(line: "Line2D") -> bool:
        """
        Return True if ``line`` has at least one point whose x and y are finite.

        This is the single inclusion rule shared by ``_get_selector`` and
        ``_extract_line_data``, so both produce entries for the same lines in
        the same order.
        """
        xydata = line.get_xydata()
        if xydata is None:
            return False
        xy = np.asarray(xydata, dtype=float)
        if xy.size == 0:
            return False
        # Row-wise "both coordinates finite", then "any such row exists".
        return bool(np.isfinite(xy).all(axis=-1).any())

    def _get_selector(self) -> Union[str, List[str]]:
        """
        Return one CSS selector per line that yields a data series.

        Lines are filtered with the same rule as ``_extract_line_data`` so the
        i-th selector and the i-th data series refer to the same line
        (NOTE(review): assumes the consumer pairs them by position — confirm
        against the maidr frontend). Lines without a gid fall back to the
        generic maidr path selector.
        """
        selectors: List[str] = []
        for line in self.ax.get_lines():
            if not self._line_has_finite_point(line):
                continue
            gid = line.get_gid()
            if gid:
                selectors.append(f"g[id='{gid}'] path")
            else:
                selectors.append(self._FALLBACK_SELECTOR)

        return selectors or [self._FALLBACK_SELECTOR]

    def _extract_plot_data(self) -> Union[List[List[dict]], None]:
        """
        Extract data from mplfinance moving average lines.

        Raises
        ------
        ExtractionError
            If no line on the axes yields any finite point.
        """
        data = self._extract_line_data()

        if data is None:
            raise ExtractionError(self.type, None)

        return data

    def _extract_line_data(self) -> Union[List[List[dict]], None]:
        """
        Extract data from all line objects and return as separate arrays.

        This method handles mplfinance-specific logic including:
        - Date conversion from matplotlib date numbers (``_maidr_date_nums``)
        - NaN/inf filtering (e.g. the leading NaNs of a moving average)
        - Assigning a unique gid to lines that lack one, for highlighting

        Only lines with at least one finite point are emitted, and only those
        lines are recorded in ``self._elements``, so series, elements and
        selectors stay aligned.

        Returns
        -------
        list[list[dict]] | None
            List of lists, where each inner list contains dictionaries with x,y
            coordinates and the line's label for one line, or None if no line
            has any finite point.
        """
        all_lines_data: List[List[dict]] = []

        for line in self.ax.get_lines():
            if not self._line_has_finite_point(line):
                continue

            # Assign unique GID to each line if not already set
            if line.get_gid() is None:
                line.set_gid(f"maidr-{uuid.uuid4()}")

            label = str(line.get_label())
            # matplotlib auto-labels start with "_child"; don't expose them.
            fill = "" if label.startswith("_child") else label

            # Date numbers attached to the line by the mplfinance patch, if any.
            date_nums = getattr(line, "_maidr_date_nums", None)

            line_data = []
            for i, (x, y) in enumerate(line.get_xydata()):  # type: ignore
                # Skip NaN/inf points to prevent JSON parsing errors
                if not (np.isfinite(x) and np.isfinite(y)):
                    continue

                # ``i`` indexes the original (unfiltered) row, so it lines up
                # with date_nums even after NaN rows are skipped.
                if date_nums is not None and i < len(date_nums):
                    x_value = self._convert_x_to_date(float(date_nums[i]))
                else:
                    x_value = float(x)

                line_data.append(
                    {
                        MaidrKey.X: x_value,
                        MaidrKey.Y: float(y),
                        MaidrKey.FILL: fill,
                    }
                )

            self._elements.append(line)
            all_lines_data.append(line_data)

        return all_lines_data or None

    def _convert_x_to_date(self, x_value: float) -> str:
        """
        Convert x-coordinate to date string for mplfinance plots.

        This method delegates to ``MplfinanceDataExtractor`` to convert
        matplotlib date numbers to date strings.

        Parameters
        ----------
        x_value : float
            The x-coordinate value (matplotlib date number)

        Returns
        -------
        str
            Date string as produced by
            ``MplfinanceDataExtractor._convert_date_num_to_string``.
        """
        return MplfinanceDataExtractor._convert_date_num_to_string(x_value)
|
maidr/patch/__init__.py
CHANGED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import wrapt
|
|
3
|
+
import mplfinance as mpf
|
|
4
|
+
import numpy as np
|
|
5
|
+
from matplotlib.collections import LineCollection, PolyCollection
|
|
6
|
+
from matplotlib.patches import Rectangle
|
|
7
|
+
from matplotlib.lines import Line2D
|
|
8
|
+
from maidr.core.enum import PlotType
|
|
9
|
+
from maidr.patch.common import common
|
|
10
|
+
from maidr.core.context_manager import ContextManager
|
|
11
|
+
from maidr.util.mplfinance_utils import MplfinanceDataExtractor
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _identify_price_and_volume_axes(ax_list):
    """Return ``(price_ax, volume_ax)`` from *ax_list*; the first matching axis wins."""
    price_ax = None
    volume_ax = None

    for ax in ax_list:
        # Price axis has candlestick collections (LineCollection for wicks,
        # PolyCollection for bodies). First match wins so a later panel with,
        # e.g., a fill_between PolyCollection cannot steal the price axis.
        if price_ax is None and any(
            isinstance(c, (LineCollection, PolyCollection)) for c in ax.collections
        ):
            price_ax = ax
        # Volume axis has rectangle patches for volume bars
        elif volume_ax is None and any(isinstance(p, Rectangle) for p in ax.patches):
            volume_ax = ax
        # Fallback: use y-label if content-based detection fails
        elif price_ax is None and "price" in ax.get_ylabel().lower():
            price_ax = ax
        elif volume_ax is None and "volume" in ax.get_ylabel().lower():
            volume_ax = ax

    return price_ax, volume_ax


def _extract_date_nums(args, kwargs):
    """Best-effort list of matplotlib date numbers for the plotted rows, or None."""
    data = args[0] if args else kwargs.get("data")
    if data is None:
        return None

    if hasattr(data, "Date_num"):
        return list(data["Date_num"])

    if hasattr(data, "index"):
        # fallback: use the index, which works when it holds datetimes
        try:
            import matplotlib.dates as mdates

            return [mdates.date2num(d) for d in data.index]
        except Exception:
            # Deliberately best-effort: an index that cannot be converted just
            # means the layers are registered without date numbers.
            return None

    return None


def _register_candlestick(price_ax, instance, args, kwargs, date_nums):
    """Tag the wick/body collections with a shared gid and register a CANDLESTICK layer."""
    wick_collection = next(
        (c for c in price_ax.collections if isinstance(c, LineCollection)), None
    )
    body_collection = next(
        (c for c in price_ax.collections if isinstance(c, PolyCollection)), None
    )
    if wick_collection is None or body_collection is None:
        return

    gid = f"maidr-{uuid.uuid4()}"
    wick_collection.set_gid(gid)
    body_collection.set_gid(gid)

    candlestick_kwargs = dict(
        kwargs,
        _maidr_wick_collection=wick_collection,
        _maidr_body_collection=body_collection,
        _maidr_date_nums=date_nums,
    )
    common(
        PlotType.CANDLESTICK,
        lambda *a, **k: price_ax,
        instance,
        args,
        candlestick_kwargs,
    )


def _register_volume(volume_ax, instance, args, kwargs, date_nums):
    """Collect volume bar Rectangles, give each a gid and register a BAR layer."""
    volume_patches = [p for p in volume_ax.patches if isinstance(p, Rectangle)]

    if not volume_patches:
        # Search in shared axes for volume patches
        for twin_ax in volume_ax.get_shared_x_axes().get_siblings(volume_ax):
            if twin_ax is not volume_ax:
                volume_patches.extend(
                    p for p in twin_ax.patches if isinstance(p, Rectangle)
                )

    if not volume_patches:
        return

    # Set GID for volume patches for highlighting
    for patch in volume_patches:
        if patch.get_gid() is None:
            patch.set_gid(f"maidr-{uuid.uuid4()}")

    bar_kwargs = dict(
        kwargs,
        _maidr_patches=volume_patches,
        _maidr_date_nums=date_nums,
    )
    common(PlotType.BAR, lambda *a, **k: volume_ax, instance, args, bar_kwargs)


def _register_moving_averages(price_ax, instance, args, kwargs, date_nums):
    """Relabel and tag the price axis' non-empty lines, then register one LINE layer."""
    processed_ids = set()
    valid_lines = []

    for line in price_ax.get_lines():
        xydata = line.get_xydata()
        if xydata is None:
            continue
        xy = np.asarray(xydata)
        if xy.size == 0:
            continue

        # Heuristic: estimate the moving-average period from the number of NaN
        # y-values, assuming a rolling window leaves (period - 1) NaNs.
        estimated_period = int(np.isnan(xy[:, 1]).sum()) + 1

        # Create a better label for the line
        label = str(line.get_label())
        if label.startswith("_child"):
            label = f"Moving Average {estimated_period} days"
        else:
            # If it's not a _child label, still add the period info
            label = f"{label}_MA{estimated_period}"
        line.set_label(label)

        # Identify the line by label, shape and first rows to avoid duplicates.
        # ``xy[:3]`` already returns every row when there are fewer than three.
        line_id = f"{label}_{xy.shape}_{xy[:3].flatten()}"
        if line_id in processed_ids:
            continue
        processed_ids.add(line_id)

        # Store date numbers on the line for the line plot class to use
        if date_nums is not None:
            setattr(line, "_maidr_date_nums", date_nums)

        # Ensure GID is set for highlighting
        if line.get_gid() is None:
            line.set_gid(f"maidr-{uuid.uuid4()}")

        valid_lines.append(line)

    # Register all valid lines as a single LINE plot
    if valid_lines:
        common(PlotType.LINE, lambda *a, **k: price_ax, instance, args, dict(kwargs))


def mplfinance_plot_patch(wrapped, instance, args, kwargs):
    """
    Enhanced patch function for `mplfinance.plot` that registers separate layers:
    - CANDLESTICK: For OHLC data (candle bodies and wicks)
    - BAR: For volume data (volume bars)
    - LINE: For moving averages (lines)

    This function intercepts calls to `mplfinance.plot`, forces
    ``returnfig=True`` so the figure and axes can be inspected, identifies the
    candlestick, volume and moving-average components, and registers them with
    maidr using the common patching mechanism (candlestick, then volume, then
    lines).

    Returns
    -------
    The result of ``mplfinance.plot`` if the caller passed ``returnfig=True``,
    otherwise None.
    """
    # Ensure `returnfig=True` to capture the figure and axes objects.
    original_returnfig = kwargs.get("returnfig", False)
    kwargs["returnfig"] = True

    with ContextManager.set_internal_context():
        result = wrapped(*args, **kwargs)

    # Validate that we received the expected figure and axes tuple
    if not (isinstance(result, tuple) and len(result) >= 2):
        return result if original_returnfig else None

    axes = result[1]
    ax_list = axes if isinstance(axes, list) else [axes]

    price_ax, volume_ax = _identify_price_and_volume_axes(ax_list)
    date_nums = _extract_date_nums(args, kwargs)

    if price_ax is not None:
        _register_candlestick(price_ax, instance, args, kwargs, date_nums)
    if volume_ax is not None:
        _register_volume(volume_ax, instance, args, kwargs, date_nums)
    if price_ax is not None:
        _register_moving_averages(price_ax, instance, args, kwargs, date_nums)

    return result if original_returnfig else None
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# Apply the patch to mplfinance.plot.
# This runs at module import time: it wraps the ``plot`` attribute of the
# ``mplfinance`` module (imported above as ``mpf``) so every call goes through
# ``mplfinance_plot_patch``.
# NOTE(review): callers that bound ``plot`` before this module was imported
# (e.g. ``from mplfinance import plot``) presumably keep the unwrapped
# function — confirm import order in maidr's package initialisation.
wrapt.wrap_function_wrapper(mpf, "plot", mplfinance_plot_patch)
|
maidr/util/__init__.py
CHANGED
maidr/util/environment.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
|
+
import sys
|
|
3
4
|
|
|
4
5
|
|
|
5
6
|
class Environment:
|
|
@@ -24,10 +25,19 @@ class Environment:
|
|
|
24
25
|
try:
|
|
25
26
|
from IPython import get_ipython # type: ignore
|
|
26
27
|
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
28
|
+
ipy = get_ipython()
|
|
29
|
+
if ipy is not None:
|
|
30
|
+
# Check for Pyodide/JupyterLite specific indicators
|
|
31
|
+
ipy_str = str(ipy).lower()
|
|
32
|
+
if "pyodide" in ipy_str or "jupyterlite" in ipy_str:
|
|
33
|
+
return True
|
|
34
|
+
# Check for other notebook indicators
|
|
35
|
+
if "ipykernel" in str(ipy) or "google.colab" in str(ipy):
|
|
36
|
+
return True
|
|
37
|
+
# Check for Pyodide platform
|
|
38
|
+
if sys.platform == "emscripten":
|
|
39
|
+
return True
|
|
40
|
+
return False
|
|
31
41
|
except ImportError:
|
|
32
42
|
return False
|
|
33
43
|
|
|
@@ -61,10 +71,19 @@ class Environment:
|
|
|
61
71
|
ipy = ( # pyright: ignore[reportUnknownVariableType]
|
|
62
72
|
IPython.get_ipython() # pyright: ignore[reportUnknownMemberType, reportPrivateImportUsage]
|
|
63
73
|
)
|
|
64
|
-
|
|
74
|
+
if ipy is not None:
|
|
75
|
+
# Check for Pyodide/JupyterLite
|
|
76
|
+
ipy_str = str(ipy).lower()
|
|
77
|
+
if "pyodide" in ipy_str or "jupyterlite" in ipy_str:
|
|
78
|
+
return "ipython"
|
|
79
|
+
# Check for Pyodide platform
|
|
80
|
+
if sys.platform == "emscripten":
|
|
81
|
+
return "ipython"
|
|
82
|
+
return "ipython"
|
|
83
|
+
else:
|
|
84
|
+
return "browser"
|
|
65
85
|
except ImportError:
|
|
66
|
-
|
|
67
|
-
return renderer
|
|
86
|
+
return "browser"
|
|
68
87
|
|
|
69
88
|
@staticmethod
|
|
70
89
|
def initialize_llm_secrets(unique_id: str) -> str:
|