maidr 1.2.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,146 @@
1
+ from typing import List, Union
2
+
3
+ from matplotlib.axes import Axes
4
+ from matplotlib.lines import Line2D
5
+ import numpy as np
6
+
7
+ from maidr.core.enum.maidr_key import MaidrKey
8
+ from maidr.core.enum.plot_type import PlotType
9
+ from maidr.core.plot.maidr_plot import MaidrPlot
10
+ from maidr.exception.extraction_error import ExtractionError
11
+ from maidr.util.mixin.extractor_mixin import LineExtractorMixin
12
+ from maidr.util.mplfinance_utils import MplfinanceDataExtractor
13
+ import uuid
14
+
15
+
16
class MplfinanceLinePlot(MaidrPlot, LineExtractorMixin):
    """
    Line plot layer for the ``Line2D`` artists on an mplfinance price axis.

    Every line on ``ax`` that has at least one finite point becomes one series in
    the extracted data. Non-finite points (NaN / +-inf, e.g. the warm-up region of
    a rolling mean) are dropped so the output stays JSON-safe. When a line carries
    a ``_maidr_date_nums`` attribute, x positions are translated into date strings
    via ``MplfinanceDataExtractor``.

    ``_get_selector`` and ``_extract_line_data`` consider exactly the same lines,
    in ``ax.get_lines()`` order, so the selector list and the series list cover
    the same lines in the same order.
    """

    # Generic CSS selector used when a line has no gid, or when no line qualifies.
    _DEFAULT_SELECTOR = "g[maidr='true'] > path"

    def __init__(self, ax: Axes, **kwargs):
        # Extra keyword arguments are accepted but not used here (presumably so a
        # plot factory can pass its full kwargs through — TODO confirm).
        super().__init__(ax, PlotType.LINE)

    @staticmethod
    def _finite_line_points(line: Line2D) -> List[tuple]:
        """
        Return ``(index, x, y)`` for each point of ``line`` with finite coordinates.

        ``index`` is the point's position in ``line.get_xydata()``; it is kept so x
        values can be matched against a parallel ``_maidr_date_nums`` list.
        NaN and +-inf points are dropped because they cannot be serialized to JSON.
        """
        xydata = line.get_xydata()
        if xydata is None:
            return []
        xy = np.asarray(xydata)
        if xy.size == 0:
            return []
        # A row is kept only if both its x and y are finite.
        finite_rows = np.isfinite(xy).all(axis=1)
        return [(int(i), xy[i, 0], xy[i, 1]) for i in np.flatnonzero(finite_rows)]

    def _get_selector(self) -> Union[str, List[str]]:
        """
        Return one CSS selector per line that has at least one finite point.

        Lines with a gid get ``g[id='<gid>'] path``; lines without one fall back to
        the generic maidr selector. Lines whose points are all non-finite are
        skipped. This mirrors ``_extract_line_data``, so selectors and data series
        stay aligned. If no line qualifies, the generic selector is returned alone.
        """
        selectors = []
        for line in self.ax.get_lines():
            if not self._finite_line_points(line):
                continue
            gid = line.get_gid()
            selectors.append(f"g[id='{gid}'] path" if gid else self._DEFAULT_SELECTOR)
        return selectors or [self._DEFAULT_SELECTOR]

    def _extract_plot_data(self) -> Union[List[List[dict]], None]:
        """
        Extract data from the mplfinance lines on this axis.

        Raises
        ------
        ExtractionError
            If no line on the axis has any finite point.
        """
        data = self._extract_line_data()

        if data is None:
            raise ExtractionError(self.type, None)

        return data

    def _x_value_for_point(self, index: int, x, date_nums) -> Union[str, float]:
        """
        Return the x value to emit for the point at position ``index``.

        If ``date_nums`` is provided and covers ``index``, the matching date number
        is converted to a date string. Otherwise the raw x coordinate is returned
        as a float.
        """
        if date_nums is not None and index < len(date_nums):
            return self._convert_x_to_date(float(date_nums[index]))
        return float(x)

    def _extract_line_data(self) -> Union[List[List[dict]], None]:
        """
        Extract data from all line objects and return as separate arrays.

        For every line with at least one finite point, this method:
        - records the line in ``self._elements``;
        - assigns a unique ``maidr-<uuid>`` gid if the line has none;
        - converts x values to date strings when the line carries
          ``_maidr_date_nums``;
        - skips NaN / inf points.

        Returns
        -------
        list[list[dict]] | None
            One inner list per line. Each inner list holds dictionaries with
            ``MaidrKey.X``, ``MaidrKey.Y`` and ``MaidrKey.FILL`` entries.
            Returns None if no line on the axis has a finite point.
        """
        all_lines_data = []

        for line in self.ax.get_lines():
            points = self._finite_line_points(line)
            if not points:
                # Nothing serializable on this line; _get_selector skips it too.
                continue

            self._elements.append(line)

            # A stable gid lets the SVG element be targeted by _get_selector.
            if line.get_gid() is None:
                line.set_gid(f"maidr-{uuid.uuid4()}")

            # matplotlib may hold a None label; treat it as empty instead of crashing.
            raw_label = line.get_label()
            label = "" if raw_label is None else str(raw_label)
            # Auto-generated "_child<N>" labels are not meaningful to users.
            fill = "" if label.startswith("_child") else label

            # Date numbers attached by the mplfinance patch, if any.
            date_nums = getattr(line, "_maidr_date_nums", None)

            all_lines_data.append(
                [
                    {
                        MaidrKey.X: self._x_value_for_point(index, x, date_nums),
                        MaidrKey.Y: float(y),
                        MaidrKey.FILL: fill,
                    }
                    for index, x, y in points
                ]
            )

        return all_lines_data or None

    def _convert_x_to_date(self, x_value: float) -> str:
        """
        Convert x-coordinate to date string for mplfinance plots.

        This method uses the MplfinanceDataExtractor utility to convert
        matplotlib date numbers to proper date strings.

        Parameters
        ----------
        x_value : float
            The x-coordinate value (matplotlib date number)

        Returns
        -------
        str
            Date string in YYYY-MM-DD format
        """
        return MplfinanceDataExtractor._convert_date_num_to_string(x_value)
maidr/patch/__init__.py CHANGED
@@ -0,0 +1,15 @@
1
+ # Import all patches to ensure they are applied
2
+ from . import (
3
+ barplot,
4
+ boxplot,
5
+ candlestick,
6
+ clear,
7
+ heatmap,
8
+ highlight,
9
+ histogram,
10
+ lineplot,
11
+ scatterplot,
12
+ regplot,
13
+ kdeplot,
14
+ mplfinance,
15
+ )
@@ -0,0 +1,213 @@
1
+ import uuid
2
+ import wrapt
3
+ import mplfinance as mpf
4
+ import numpy as np
5
+ from matplotlib.collections import LineCollection, PolyCollection
6
+ from matplotlib.patches import Rectangle
7
+ from matplotlib.lines import Line2D
8
+ from maidr.core.enum import PlotType
9
+ from maidr.patch.common import common
10
+ from maidr.core.context_manager import ContextManager
11
+ from maidr.util.mplfinance_utils import MplfinanceDataExtractor
12
+
13
+
14
def _identify_axes(ax_list):
    """
    Return ``(price_ax, volume_ax)`` chosen by inspecting each axis's artists.

    An axis holding a ``LineCollection``/``PolyCollection`` (candle wicks/bodies)
    is taken as the price axis. One holding ``Rectangle`` patches (volume bars) is
    taken as the volume axis. A later content match overwrites an earlier pick.
    The y-label ("price"/"volume") is only consulted as a fallback, when the axis
    has no matching content and that role is still unassigned. Either element of
    the result may be ``None``.
    """
    price_ax = None
    volume_ax = None
    for ax in ax_list:
        if any(isinstance(c, (LineCollection, PolyCollection)) for c in ax.collections):
            price_ax = ax
        elif any(isinstance(p, Rectangle) for p in ax.patches):
            volume_ax = ax
        elif price_ax is None and "price" in ax.get_ylabel().lower():
            price_ax = ax
        elif volume_ax is None and "volume" in ax.get_ylabel().lower():
            volume_ax = ax
    return price_ax, volume_ax


def _extract_date_nums(args, kwargs):
    """
    Return the matplotlib date numbers for the plotted data, or ``None``.

    The data is the first positional argument if there is one, otherwise
    ``kwargs["data"]``. A ``Date_num`` column is used directly when present.
    Otherwise each entry of ``data.index`` is converted with
    ``matplotlib.dates.date2num``. That conversion is best-effort: any failure
    yields ``None``.
    """
    data = None
    if len(args) > 0:
        data = args[0]
    elif "data" in kwargs:
        data = kwargs["data"]

    if data is None:
        return None
    if hasattr(data, "Date_num"):
        return list(data["Date_num"])
    if hasattr(data, "index"):
        # Deliberately best-effort: date numbers only improve labelling, so a
        # failure here must not break the user's plot call.
        try:
            import matplotlib.dates as mdates

            return [mdates.date2num(d) for d in data.index]
        except Exception:
            return None
    return None


def _register_candlestick(price_ax, instance, args, kwargs, date_nums):
    """
    Register the candle wicks and bodies on ``price_ax`` as a CANDLESTICK layer.

    The first ``LineCollection`` (wicks) and first ``PolyCollection`` (bodies) are
    used, and both receive the same gid. Nothing is registered unless both exist.
    """
    wick_collection = next(
        (c for c in price_ax.collections if isinstance(c, LineCollection)), None
    )
    body_collection = next(
        (c for c in price_ax.collections if isinstance(c, PolyCollection)), None
    )
    if wick_collection is None or body_collection is None:
        return

    # Wicks and bodies share one gid so they can be highlighted together.
    gid = f"maidr-{uuid.uuid4()}"
    wick_collection.set_gid(gid)
    body_collection.set_gid(gid)

    candlestick_kwargs = dict(
        kwargs,
        _maidr_wick_collection=wick_collection,
        _maidr_body_collection=body_collection,
        _maidr_date_nums=date_nums,
    )
    common(
        PlotType.CANDLESTICK,
        lambda *a, **k: price_ax,
        instance,
        args,
        candlestick_kwargs,
    )


def _register_volume(volume_ax, instance, args, kwargs, date_nums):
    """
    Register the volume bars on ``volume_ax`` as a BAR layer.

    ``Rectangle`` patches are taken from ``volume_ax``. Only if it has none are
    the axes sharing its x-axis searched instead. Each patch without a gid gets
    its own unique gid. Nothing is registered if no patch is found.
    """
    volume_patches = [p for p in volume_ax.patches if isinstance(p, Rectangle)]

    if not volume_patches:
        # Bars may live on a twin axis that shares the volume axis's x-axis.
        for twin_ax in volume_ax.get_shared_x_axes().get_siblings(volume_ax):
            if twin_ax is not volume_ax:
                volume_patches.extend(
                    p for p in twin_ax.patches if isinstance(p, Rectangle)
                )

    if not volume_patches:
        return

    # Set GID for volume patches for highlighting
    for patch in volume_patches:
        if patch.get_gid() is None:
            patch.set_gid(f"maidr-{uuid.uuid4()}")

    bar_kwargs = dict(
        kwargs,
        _maidr_patches=volume_patches,
        _maidr_date_nums=date_nums,
    )
    common(PlotType.BAR, lambda *a, **k: volume_ax, instance, args, bar_kwargs)


def _line_identity(line, xydata_array):
    """
    Build a de-duplication key from a line's label and the start of its data.

    The key is the label plus a hash of the array shape and its first (up to
    three) rows. When there is no data, the label alone is used.
    """
    label = line.get_label()
    if xydata_array is None or xydata_array.size == 0:
        return f"{label}"
    first_values = (
        xydata_array[:3].flatten()
        if xydata_array.size >= 6
        else xydata_array.flatten()
    )
    data_hash = hash(f"{xydata_array.shape}_{str(first_values)}")
    return f"{label}_{data_hash}"


def _prepare_moving_average_lines(price_ax, date_nums):
    """
    Relabel and tag the ``Line2D`` artists on ``price_ax``; return the kept ones.

    For every line with xydata:
    - the number of NaN y-values plus one is used as the estimated
      moving-average period;
    - ``_child*`` labels become ``"Moving Average <period> days"``;
    - any other label gets ``"_MA<period>"`` appended.

    Lines are then de-duplicated by ``_line_identity``. Each kept line with data
    gets ``_maidr_date_nums`` (when available) and a gid if it has none.
    """
    processed_lines = set()
    valid_lines = []

    for line in price_ax.get_lines():
        if not isinstance(line, Line2D):
            continue

        xydata = line.get_xydata()
        xydata_array = None if xydata is None else np.asarray(xydata)

        if xydata_array is not None:
            # Leading NaNs of a rolling mean hint at its window: period - 1 NaNs.
            nan_count = np.sum(np.isnan(xydata_array[:, 1]))
            estimated_period = nan_count + 1

            label = str(line.get_label())
            if label.startswith("_child"):
                line.set_label(f"Moving Average {estimated_period} days")
            else:
                line.set_label(f"{label}_MA{estimated_period}")

        # Identity uses the label set just above.
        line_id = _line_identity(line, xydata_array)
        if line_id in processed_lines:
            # NOTE(review): the skipped duplicate still lives on price_ax. A line
            # layer that reads every line of the axis will still see it, just
            # without _maidr_date_nums — confirm whether that is intended.
            continue
        processed_lines.add(line_id)

        if xydata_array is None or xydata_array.size == 0:
            continue

        # Store date numbers on the line for the line plot class to use
        if date_nums is not None:
            setattr(line, "_maidr_date_nums", date_nums)

        # Ensure GID is set for highlighting
        if line.get_gid() is None:
            line.set_gid(f"maidr-{uuid.uuid4()}")

        valid_lines.append(line)

    return valid_lines


def mplfinance_plot_patch(wrapped, instance, args, kwargs):
    """
    Enhanced patch function for `mplfinance.plot` that registers separate layers:
    - CANDLESTICK: For OHLC data (candle bodies and wicks)
    - BAR: For volume data (volume bars)
    - LINE: For moving averages (lines)

    This function intercepts calls to `mplfinance.plot`, identifies the resulting
    candlestick, volume, and moving average components, and registers them with
    maidr using the common patching mechanism.

    Parameters
    ----------
    wrapped : callable
        The original ``mplfinance.plot`` (supplied by ``wrapt``).
    instance : object
        Bound instance supplied by ``wrapt``; forwarded to ``common`` unchanged.
    args, kwargs
        Arguments of the ``mplfinance.plot`` call. ``kwargs["returnfig"]`` is
        forced to ``True`` so the figure and axes can be inspected.

    Returns
    -------
    The ``mplfinance.plot`` result if the caller passed a truthy ``returnfig``,
    otherwise ``None``.
    """
    # Ensure `returnfig=True` to capture the figure and axes objects.
    # NOTE(review): with returnfig=True mplfinance presumably skips its own
    # show step; this looks intentional so maidr controls rendering — confirm.
    original_returnfig = kwargs.get("returnfig", False)
    kwargs["returnfig"] = True

    with ContextManager.set_internal_context():
        result = wrapped(*args, **kwargs)

    # Validate that we received the expected figure and axes tuple
    if not (isinstance(result, tuple) and len(result) >= 2):
        return result if original_returnfig else None

    axes = result[1]
    ax_list = axes if isinstance(axes, list) else [axes]

    price_ax, volume_ax = _identify_axes(ax_list)
    date_nums = _extract_date_nums(args, kwargs)

    # Registration order (candlestick, volume, moving averages) is preserved.
    if price_ax is not None:
        _register_candlestick(price_ax, instance, args, kwargs, date_nums)

    if volume_ax is not None:
        _register_volume(volume_ax, instance, args, kwargs, date_nums)

    # All kept lines on the price axis are registered together as one LINE plot.
    if price_ax is not None and _prepare_moving_average_lines(price_ax, date_nums):
        common(PlotType.LINE, lambda *a, **k: price_ax, instance, args, dict(kwargs))

    return result if original_returnfig else None
210
+
211
+
212
# Apply the patch to mplfinance.plot at import time: wrapt replaces the `plot`
# attribute of the `mplfinance` module, so calls made through `mpf.plot(...)`
# after this module is imported are routed through `mplfinance_plot_patch`.
wrapt.wrap_function_wrapper(mpf, "plot", mplfinance_plot_patch)
maidr/util/__init__.py CHANGED
@@ -0,0 +1,3 @@
1
+ from .plot_detection import PlotDetectionUtils
2
+
3
+ __all__ = ["PlotDetectionUtils"]
maidr/util/environment.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import os
3
+ import sys
3
4
 
4
5
 
5
6
  class Environment:
@@ -24,10 +25,19 @@ class Environment:
24
25
  try:
25
26
  from IPython import get_ipython # type: ignore
26
27
 
27
- return get_ipython() is not None and (
28
- "ipykernel" in str(get_ipython())
29
- or "google.colab" in str(get_ipython())
30
- )
28
+ ipy = get_ipython()
29
+ if ipy is not None:
30
+ # Check for Pyodide/JupyterLite specific indicators
31
+ ipy_str = str(ipy).lower()
32
+ if "pyodide" in ipy_str or "jupyterlite" in ipy_str:
33
+ return True
34
+ # Check for other notebook indicators
35
+ if "ipykernel" in str(ipy) or "google.colab" in str(ipy):
36
+ return True
37
+ # Check for Pyodide platform
38
+ if sys.platform == "emscripten":
39
+ return True
40
+ return False
31
41
  except ImportError:
32
42
  return False
33
43
 
@@ -61,10 +71,19 @@ class Environment:
61
71
  ipy = ( # pyright: ignore[reportUnknownVariableType]
62
72
  IPython.get_ipython() # pyright: ignore[reportUnknownMemberType, reportPrivateImportUsage]
63
73
  )
64
- renderer = "ipython" if ipy else "browser"
74
+ if ipy is not None:
75
+ # Check for Pyodide/JupyterLite
76
+ ipy_str = str(ipy).lower()
77
+ if "pyodide" in ipy_str or "jupyterlite" in ipy_str:
78
+ return "ipython"
79
+ # Check for Pyodide platform
80
+ if sys.platform == "emscripten":
81
+ return "ipython"
82
+ return "ipython"
83
+ else:
84
+ return "browser"
65
85
  except ImportError:
66
- renderer = "browser"
67
- return renderer
86
+ return "browser"
68
87
 
69
88
  @staticmethod
70
89
  def initialize_llm_secrets(unique_id: str) -> str: