maidr 1.3.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,139 @@
1
+ from __future__ import annotations
2
+
3
+ from matplotlib.axes import Axes
4
+ from matplotlib.patches import Rectangle
5
+
6
+ from maidr.core.enum import PlotType
7
+ from maidr.core.plot import MaidrPlot
8
+ from maidr.exception import ExtractionError
9
+ from maidr.util.mixin import (
10
+ ContainerExtractorMixin,
11
+ DictMergerMixin,
12
+ LevelExtractorMixin,
13
+ )
14
+ from maidr.util.mplfinance_utils import MplfinanceDataExtractor
15
+
16
+
17
class MplfinanceBarPlot(
    MaidrPlot, ContainerExtractorMixin, LevelExtractorMixin, DictMergerMixin
):
    """
    Specialized bar plot class for mplfinance volume bars.

    This class handles the extraction and processing of volume data from mplfinance
    plots, including proper date conversion and data validation. When the
    mplfinance patch did not supply volume patches, data is extracted from the
    axes' regular ``BarContainer`` objects instead.
    """

    def __init__(self, ax: Axes, **kwargs) -> None:
        """
        Create a volume bar plot bound to ``ax``.

        Parameters
        ----------
        ax : Axes
            The axes holding the volume bars.
        **kwargs
            ``_maidr_patches``: volume ``Rectangle`` patches collected by the
            mplfinance patch. ``_maidr_date_nums``: date numbers for the bars,
            forwarded to ``MplfinanceDataExtractor.extract_volume_data``.
            Other keys are ignored.
        """
        super().__init__(ax, PlotType.BAR)
        # Volume Rectangle patches passed from the mplfinance patch (may be None).
        self._custom_patches = kwargs.get("_maidr_patches", None)
        # Date numbers for the volume bars (from mplfinance, may be None).
        self._maidr_date_nums = kwargs.get("_maidr_date_nums", None)

    def set_title(self, title: str) -> None:
        """Set a custom title for this volume bar plot; ``render`` uses it."""
        self._title = title

    def _extract_plot_data(self) -> list:
        """
        Extract the plot data.

        Uses the mplfinance volume patches when they were supplied. Otherwise it
        falls back to the axes' ``BarContainer`` objects and pairs each bar height
        with its level (orientation decides which one is ``x``).

        Raises
        ------
        ExtractionError
            If the fallback path finds no bar containers or no usable bar data.
        """
        if self._custom_patches:
            return self._extract_volume_patches_data(self._custom_patches)

        # Local import: only the fallback path needs the container type.
        from matplotlib.container import BarContainer

        # Fallback to the regular bar plot logic if no custom patches.
        plot = self.extract_container(self.ax, BarContainer, include_all=True)
        data = self._extract_bar_container_data(plot)
        # Fail with the domain error before ``plot[0]`` / ``zip(..., None)`` blow up.
        if not plot or data is None:
            raise ExtractionError(self.type, plot)

        levels = self.extract_level(self.ax)
        if levels is None or len(levels) == 0:  # type: ignore
            # Same placeholder labels _extract_bar_container_data validated against.
            levels = ["" for _ in range(len(data))]

        if plot[0].orientation == "vertical":  # type: ignore
            pairs = zip(levels, data)
        else:
            pairs = zip(data, levels)
        formatted_data = [{"x": x, "y": y} for x, y in pairs]

        if not formatted_data:
            raise ExtractionError(self.type, plot)
        return formatted_data

    def _extract_volume_patches_data(self, volume_patches: list[Rectangle]) -> list:
        """
        Extract data from volume Rectangle patches (used by mplfinance volume bars).

        The patches are sorted by x-coordinate and stored as the highlightable
        elements, then converted by ``MplfinanceDataExtractor.extract_volume_data``.
        Returns an empty list when no patches are given.
        """
        if not volume_patches:
            return []

        # Sort patches by x-coordinate to maintain order.
        sorted_patches = sorted(volume_patches, key=lambda p: p.get_x())

        # Set elements for highlighting (use the patches directly).
        self._elements = sorted_patches

        return MplfinanceDataExtractor.extract_volume_data(
            sorted_patches, self._maidr_date_nums
        )

    def _extract_bar_container_data(self, plot: list | None) -> list | None:
        """
        Return bar heights from the given bar containers.

        Returns ``None`` if ``plot`` is ``None`` or if the number of patches does
        not match the number of axis levels. Otherwise the patches are appended
        to ``self._elements`` for highlighting.
        """
        if plot is None:
            return None

        # Since v0.13, Seaborn has transitioned from using `list[Patch]` to
        # `list[BarContainers]` for plotting bar plots.
        # Flatten all the `list[BarContainer]` to `list[Patch]`.
        patches = [patch for container in plot for patch in container.patches]
        level = self.extract_level(self.ax)
        if level is None or len(level) == 0:  # type: ignore
            level = ["" for _ in range(len(patches))]  # type: ignore

        if len(patches) != len(level):
            return None

        self._elements.extend(patches)

        return [float(patch.get_height()) for patch in patches]

    def _extract_axes_data(self) -> dict:
        """Extract axes data, cleaning the y-axis label for mplfinance volume plots."""
        ax_data = super()._extract_axes_data()

        # For mplfinance volume plots, clean up the y-axis label.
        if self._custom_patches:
            y_label = ax_data.get("y")
            if y_label:
                ax_data["y"] = MplfinanceDataExtractor.clean_axis_label(y_label)

        return ax_data

    def _get_selector(self) -> str:
        """Return the CSS selector for highlighting bar elements in the SVG output."""
        # Standard selector that the Maidr class rewrites with the plot's UUID;
        # works for both regular bar plots and mplfinance volume bars.
        return "g[maidr='true'] > path"

    def render(self) -> dict:
        """Build the MAIDR schema dictionary (type, title, axes, data, selector)."""
        from maidr.core.enum.maidr_key import MaidrKey

        # Prefer a title set through ``set_title``; otherwise use the default.
        title = getattr(self, "_title", None) or "Volume Bar Plot"

        maidr_schema = {
            MaidrKey.TYPE: self.type,
            MaidrKey.TITLE: title,
            MaidrKey.AXES: self._extract_axes_data(),
            MaidrKey.DATA: self._extract_plot_data(),
        }

        # Include selector only if the plot supports highlighting.
        if self._support_highlighting:
            maidr_schema[MaidrKey.SELECTOR] = self._get_selector()

        return maidr_schema
@@ -0,0 +1,273 @@
1
+ from typing import List, Union
2
+
3
+ from matplotlib.axes import Axes
4
+ from matplotlib.lines import Line2D
5
+ import numpy as np
6
+
7
+ from maidr.core.enum.maidr_key import MaidrKey
8
+ from maidr.core.enum.plot_type import PlotType
9
+ from maidr.core.plot.maidr_plot import MaidrPlot
10
+ from maidr.exception.extraction_error import ExtractionError
11
+ from maidr.util.mixin.extractor_mixin import LineExtractorMixin
12
+ from maidr.util.mplfinance_utils import MplfinanceDataExtractor
13
+ import uuid
14
+
15
+
16
class MplfinanceLinePlot(MaidrPlot, LineExtractorMixin):
    """
    Specialized line plot class for mplfinance moving averages.

    This class handles the extraction and processing of moving average data from
    mplfinance plots. That covers date conversion, NaN/inf filtering, and period
    lookup. Each line's period comes from the ``_maidr_ma_period`` attribute
    attached by the mplfinance patch.
    """

    # Selector used when a line has no gid, or when no line contributes data.
    _DEFAULT_SELECTOR = "g[maidr='true'] > path"

    def __init__(self, ax: Axes, **kwargs):
        super().__init__(ax, PlotType.LINE)
        self._line_titles = []  # Store line titles separately

    @staticmethod
    def _format_line_title(ma_period) -> str:
        """Return the per-line title for a moving average of ``ma_period`` days."""
        return (
            f"{ma_period}-Day Moving Average Line Plot"
            if ma_period
            else "Moving Average Line Plot"
        )

    @staticmethod
    def _format_y_label(ma_period) -> str:
        """Return the y-axis label for a moving average of ``ma_period`` days."""
        return (
            f"{ma_period}-day mav price ($)"
            if ma_period
            else "Moving Average Price ($)"
        )

    @staticmethod
    def _is_finite_point(x, y) -> bool:
        """Return True if ``y`` is finite and ``x`` is a string or a finite number."""
        if not np.isfinite(y):
            return False
        # String x values (dates) are kept as-is; check them before numeric tests.
        return isinstance(x, str) or bool(np.isfinite(x))

    def _line_has_points(self, line: Line2D) -> bool:
        """Return True if ``line`` has at least one point that survives filtering."""
        xydata = line.get_xydata()
        if xydata is None or not xydata.size:  # type: ignore
            return False
        return any(self._is_finite_point(x, y) for x, y in xydata)  # type: ignore

    def _get_selector(self) -> Union[str, List[str]]:
        """
        Return one selector per line that contributes data points.

        Lines dropped by ``_extract_line_data`` (empty, or only NaN/inf points)
        are skipped, so selectors stay aligned with the extracted data lines.
        """
        selectors = []
        for line in self.ax.get_lines():
            if not self._line_has_points(line):
                continue
            gid = line.get_gid()
            selectors.append(f"g[id='{gid}'] path" if gid else self._DEFAULT_SELECTOR)

        return selectors or [self._DEFAULT_SELECTOR]

    def _extract_axes_data(self) -> dict:
        """
        Extract axis labels for the plot.

        Returns
        -------
        dict
            Dictionary containing x and y axis labels with custom y-label for moving averages.
        """
        x_labels = self.ax.get_xlabel()
        if not x_labels:
            x_labels = self.extract_shared_xlabel(self.ax)
        if not x_labels:
            x_labels = "Date"

        # The smallest period labels the shared y-axis.
        ma_period = self._extract_moving_average_period()
        y_label = self._format_y_label(ma_period)

        return {MaidrKey.X: x_labels, MaidrKey.Y: y_label}

    def _extract_moving_average_periods(self) -> List[str]:
        """
        Extract all moving average periods from the _maidr_ma_period attributes set by the mplfinance patch.

        Returns
        -------
        List[str]
            Unique periods sorted numerically (e.g., ["3", "6", "30"]); any
            non-numeric values sort after the numeric ones.
        """
        periods = {
            str(ma_period)
            for ma_period in (
                getattr(line, "_maidr_ma_period", None) for line in self.ax.get_lines()
            )
            if ma_period is not None
        }

        def sort_key(period: str):
            # Numeric order ("3" < "30"), not lexicographic ("30" < "6").
            return (0, int(period), "") if period.isdigit() else (1, 0, period)

        return sorted(periods, key=sort_key)

    def _extract_moving_average_period(self) -> str:
        """
        Return the smallest moving average period found on the axes.

        Returns
        -------
        str
            The period (e.g., "3"), or an empty string if no period was found.
        """
        periods = self._extract_moving_average_periods()
        return periods[0] if periods else ""

    def _extract_plot_data(self) -> Union[List[dict], None]:
        """
        Extract data from mplfinance moving average lines.

        Raises
        ------
        ExtractionError
            If no line yields any plottable point.
        """
        data = self._extract_line_data()

        if data is None:
            raise ExtractionError(self.type, None)

        return data

    def _convert_point(self, index: int, x, y, date_nums) -> Union[dict, None]:
        """
        Convert one ``(x, y)`` sample to a MAIDR point dict.

        Returns ``None`` for points with NaN/inf values, which would otherwise
        break JSON parsing. When ``date_nums`` covers ``index``, the x value
        becomes a date string; otherwise it stays numeric (or string).
        """
        if not self._is_finite_point(x, y):
            return None

        if isinstance(x, str):
            x_value = x  # Keep string as-is (for dates)
        elif date_nums is not None and index < len(date_nums):
            x_value = self._convert_x_to_date(float(date_nums[index]))
        else:
            x_value = float(x)

        return {MaidrKey.X: x_value, MaidrKey.Y: float(y)}

    def _extract_line_data(self) -> Union[List[dict], None]:
        """
        Extract data from all line objects, one entry per line.

        This method handles mplfinance-specific logic including:
        - Date conversion from matplotlib date numbers
        - NaN/inf filtering for moving averages
        - Moving average period lookup
        - Proper data validation

        Only lines that yield at least one point are appended to
        ``self._elements`` and given a gid. This keeps highlighting aligned
        with the returned data.

        Returns
        -------
        list[dict] | None
            One dict per line with ``title``, ``axes`` and ``points`` keys.
            Returns None if no line has any plottable point.
        """
        all_lines = self.ax.get_lines()
        if not all_lines:
            return None

        all_lines_data = []

        for line in all_lines:
            xydata = line.get_xydata()
            if xydata is None or not xydata.size:  # type: ignore
                continue

            ma_period = getattr(line, "_maidr_ma_period", None)
            # Date numbers attached by the mplfinance patch, if any.
            date_nums = getattr(line, "_maidr_date_nums", None)

            line_data = []
            for i, (x, y) in enumerate(xydata):  # type: ignore
                point = self._convert_point(i, x, y, date_nums)
                if point is not None:
                    line_data.append(point)

            if not line_data:
                continue

            self._elements.append(line)

            # Assign unique GID to each line if not already set.
            if line.get_gid() is None:
                line.set_gid(f"maidr-{uuid.uuid4()}")

            all_lines_data.append(
                {
                    "title": self._format_line_title(ma_period),
                    "axes": {
                        "x": "Date",
                        "y": self._format_y_label(ma_period),
                    },
                    "points": line_data,
                }
            )

        return all_lines_data if all_lines_data else None

    def _convert_x_to_date(self, x_value: float) -> str:
        """
        Convert x-coordinate to date string for mplfinance plots.

        This method uses the MplfinanceDataExtractor utility to convert
        matplotlib date numbers to proper date strings.

        Parameters
        ----------
        x_value : float
            The x-coordinate value (matplotlib date number)

        Returns
        -------
        str
            Date string in YYYY-MM-DD format
        """
        return MplfinanceDataExtractor._convert_date_num_to_string(x_value)

    def _extract_line_titles(self) -> List[str]:
        """
        Extract titles for all moving average lines.

        Returns
        -------
        List[str]
            List of titles for each line on the axes.
        """
        return [
            self._format_line_title(getattr(line, "_maidr_ma_period", None))
            for line in self.ax.get_lines()
        ]

    def render(self) -> dict:
        """Build the MAIDR schema dictionary (type, title, axes, data, selector)."""
        # Use the smallest period for the main title.
        ma_period = self._extract_moving_average_period()
        title = (
            f"{ma_period}-Day Moving Averages Line Plot"
            if ma_period
            else "Moving Averages Line Plot"
        )

        maidr_schema = {
            MaidrKey.TYPE: self.type,
            MaidrKey.TITLE: title,
            MaidrKey.AXES: self._extract_axes_data(),
            MaidrKey.DATA: self._extract_plot_data(),
        }

        # Include selector only if the plot supports highlighting.
        if self._support_highlighting:
            maidr_schema[MaidrKey.SELECTOR] = self._get_selector()

        return maidr_schema
maidr/patch/__init__.py CHANGED
@@ -0,0 +1,15 @@
1
+ # Import all patches to ensure they are applied
2
+ from . import (
3
+ barplot,
4
+ boxplot,
5
+ candlestick,
6
+ clear,
7
+ heatmap,
8
+ highlight,
9
+ histogram,
10
+ lineplot,
11
+ scatterplot,
12
+ regplot,
13
+ kdeplot,
14
+ mplfinance,
15
+ )
@@ -0,0 +1,215 @@
1
+ import uuid
2
+ import wrapt
3
+ import mplfinance as mpf
4
+ import numpy as np
5
+ from matplotlib.collections import LineCollection, PolyCollection
6
+ from matplotlib.patches import Rectangle
7
+ from matplotlib.lines import Line2D
8
+ from maidr.core.enum import PlotType
9
+ from maidr.patch.common import common
10
+ from maidr.core.context_manager import ContextManager
11
+
12
+
13
def mplfinance_plot_patch(wrapped, instance, args, kwargs):
    """
    Enhanced patch function for `mplfinance.plot` that registers separate layers:
    - CANDLESTICK: For OHLC data (candle bodies and wicks)
    - BAR: For volume data (volume bars)
    - LINE: For moving averages (lines)

    This function intercepts calls to `mplfinance.plot`, identifies the resulting
    candlestick, volume, and moving average components, and registers them with
    maidr using the common patching mechanism.

    Parameters
    ----------
    wrapped : callable
        The original ``mplfinance.plot``.
    instance : object
        Instance supplied by ``wrapt``; forwarded to ``common``.
    args, kwargs
        Positional and keyword arguments of the ``mplfinance.plot`` call.
        ``kwargs["returnfig"]`` is forced to ``True`` in place.

    Returns
    -------
    The result of ``mplfinance.plot`` if the caller passed ``returnfig=True``,
    otherwise ``None``.
    """
    # Ensure `returnfig=True` to capture the figure and axes objects; the
    # caller's original choice still decides what is returned.
    original_returnfig = kwargs.get("returnfig", False)
    kwargs["returnfig"] = True

    with ContextManager.set_internal_context():
        result = wrapped(*args, **kwargs)

    # Validate that we received the expected (figure, axes) tuple.
    if not (isinstance(result, tuple) and len(result) >= 2):
        return result if original_returnfig else None

    axes = result[1]
    ax_list = list(axes) if isinstance(axes, (list, tuple)) else [axes]

    price_ax, volume_ax = _identify_axes(ax_list)
    date_nums = _extract_date_nums(args, kwargs)

    # Registration order (candlestick, volume, moving averages) is preserved.
    if price_ax is not None:
        _register_candlestick(price_ax, date_nums, instance, args, kwargs)
    if volume_ax is not None:
        _register_volume(volume_ax, date_nums, instance, args, kwargs)
    if price_ax is not None:
        _register_moving_averages(price_ax, date_nums, instance, args, kwargs)

    return result if original_returnfig else None


def _identify_axes(ax_list):
    """Return ``(price_ax, volume_ax)``, detected by content with a y-label fallback."""
    price_ax = None
    volume_ax = None
    for ax in ax_list:
        # Price axis has candlestick collections (LineCollection wicks, PolyCollection bodies).
        if any(isinstance(c, (LineCollection, PolyCollection)) for c in ax.collections):
            price_ax = ax
        # Volume axis has rectangle patches for volume bars.
        elif any(isinstance(p, Rectangle) for p in ax.patches):
            volume_ax = ax
        # Fallback: use the y-label if content-based detection fails.
        elif price_ax is None and "price" in ax.get_ylabel().lower():
            price_ax = ax
        elif volume_ax is None and "volume" in ax.get_ylabel().lower():
            volume_ax = ax
    return price_ax, volume_ax


def _extract_date_nums(args, kwargs):
    """
    Best-effort extraction of matplotlib date numbers from the plotted data.

    Prefers a ``Date_num`` column. Otherwise converts each element of the data's
    index with ``matplotlib.dates.date2num``. Returns ``None`` when no data was
    passed or the conversion fails.
    """
    data = args[0] if len(args) > 0 else kwargs.get("data")
    if data is None:
        return None

    if hasattr(data, "Date_num"):
        return list(data["Date_num"])

    if hasattr(data, "index"):
        try:
            import matplotlib.dates as mdates

            return [mdates.date2num(d) for d in data.index]
        except Exception:
            # Deliberately best-effort: missing dates must not break plotting.
            return None

    return None


def _register_candlestick(price_ax, date_nums, instance, args, kwargs):
    """Register the wick/body collections of ``price_ax`` as a CANDLESTICK layer."""
    wick_collection = next(
        (c for c in price_ax.collections if isinstance(c, LineCollection)), None
    )
    body_collection = next(
        (c for c in price_ax.collections if isinstance(c, PolyCollection)), None
    )
    if wick_collection is None or body_collection is None:
        return

    # Wicks and bodies share one gid so they highlight together.
    gid = f"maidr-{uuid.uuid4()}"
    wick_collection.set_gid(gid)
    body_collection.set_gid(gid)

    candlestick_kwargs = dict(
        kwargs,
        _maidr_wick_collection=wick_collection,
        _maidr_body_collection=body_collection,
        _maidr_date_nums=date_nums,
    )
    common(
        PlotType.CANDLESTICK,
        lambda *a, **k: price_ax,
        instance,
        args,
        candlestick_kwargs,
    )


def _register_volume(volume_ax, date_nums, instance, args, kwargs):
    """Register the Rectangle patches of ``volume_ax`` (or its x-siblings) as a BAR layer."""
    volume_patches = [p for p in volume_ax.patches if isinstance(p, Rectangle)]

    if not volume_patches:
        # Search axes sharing the x-axis for volume patches.
        for twin_ax in volume_ax.get_shared_x_axes().get_siblings(volume_ax):
            if twin_ax is not volume_ax:
                volume_patches.extend(
                    p for p in twin_ax.patches if isinstance(p, Rectangle)
                )

    if not volume_patches:
        return

    # Set a GID on each volume patch for highlighting.
    for patch in volume_patches:
        if patch.get_gid() is None:
            patch.set_gid(f"maidr-{uuid.uuid4()}")

    bar_kwargs = dict(
        kwargs,
        _maidr_patches=volume_patches,
        _maidr_date_nums=date_nums,
    )
    common(PlotType.BAR, lambda *a, **k: volume_ax, instance, args, bar_kwargs)


def _estimate_ma_period(y_values) -> int:
    """
    Estimate a moving-average window from the number of leading NaNs.

    A rolling mean over ``w`` samples yields ``w - 1`` leading NaNs. Interior
    NaNs, for example from gaps in the prices, are ignored so they cannot
    inflate the estimate.
    """
    nan_mask = np.isnan(np.asarray(y_values))
    if nan_mask.all():  # also covers an empty array
        leading = int(nan_mask.size)
    else:
        leading = int(np.argmax(~nan_mask))
    return leading + 1


def _line_identity(line, xydata_array):
    """Build a dedup key from a line's label, data shape and first values."""
    label = f"{line.get_label()}"
    if xydata_array is None or xydata_array.size == 0:
        return label
    first_values = (
        xydata_array[:3].flatten()
        if xydata_array.size >= 6
        else xydata_array.flatten()
    )
    data_hash = hash(f"{xydata_array.shape}_{str(first_values)}")
    return f"{label}_{data_hash}"


def _register_moving_averages(price_ax, date_nums, instance, args, kwargs):
    """Tag the Line2D objects on ``price_ax`` with MA metadata and register one LINE layer."""
    processed_lines = set()
    valid_lines = []

    for line in [ln for ln in price_ax.get_lines() if isinstance(ln, Line2D)]:
        xydata = line.get_xydata()
        xydata_array = None if xydata is None else np.asarray(xydata)

        if xydata_array is not None:
            estimated_period = _estimate_ma_period(xydata_array[:, 1])
            # Stored on the line for MplfinanceLinePlot to read.
            setattr(line, "_maidr_ma_period", estimated_period)

            # Give the line a descriptive label.
            label = str(line.get_label())
            if label.startswith("_child"):
                line.set_label(f"Moving Average {estimated_period} days")
            else:
                line.set_label(f"{label}_MA{estimated_period}")

        line_id = _line_identity(line, xydata_array)
        if line_id in processed_lines:
            continue
        processed_lines.add(line_id)

        # Validate that the line has data.
        if xydata_array is None or xydata_array.size == 0:
            continue

        # Store date numbers on the line for the line plot class to use.
        if date_nums is not None:
            setattr(line, "_maidr_date_nums", date_nums)

        # Ensure GID is set for highlighting.
        if line.get_gid() is None:
            line.set_gid(f"maidr-{uuid.uuid4()}")

        valid_lines.append(line)

    # Register all valid lines as a single LINE plot.
    if valid_lines:
        line_kwargs = dict(kwargs)
        common(PlotType.LINE, lambda *a, **k: price_ax, instance, args, line_kwargs)
212
+
213
+
214
# Apply the patch to mplfinance.plot when this module is imported: from then on,
# every call to ``mplfinance.plot`` is routed through ``mplfinance_plot_patch``.
wrapt.wrap_function_wrapper(mpf, "plot", mplfinance_plot_patch)
maidr/util/__init__.py CHANGED
@@ -0,0 +1,3 @@
1
+ from .plot_detection import PlotDetectionUtils
2
+
3
# Public names re-exported by ``maidr.util`` (also what ``import *`` exposes).
__all__ = ["PlotDetectionUtils"]