maidr 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,213 @@
1
+ import uuid
2
+ import wrapt
3
+ import mplfinance as mpf
4
+ import numpy as np
5
+ from matplotlib.collections import LineCollection, PolyCollection
6
+ from matplotlib.patches import Rectangle
7
+ from matplotlib.lines import Line2D
8
+ from maidr.core.enum import PlotType
9
+ from maidr.patch.common import common
10
+ from maidr.core.context_manager import ContextManager
11
+ from maidr.util.mplfinance_utils import MplfinanceDataExtractor
12
+
13
+
14
def _identify_price_and_volume_axes(ax_list):
    """Return ``(price_ax, volume_ax)`` detected from the content of ``ax_list``.

    An axis holding any LineCollection/PolyCollection (candle wicks/bodies) is
    taken as the price axis, and one holding Rectangle patches (volume bars)
    as the volume axis. A y-label containing "price"/"volume" is only used as
    a fallback. Later content-based matches overwrite earlier ones.
    """
    price_ax = None
    volume_ax = None
    for ax in ax_list:
        if any(isinstance(c, (LineCollection, PolyCollection)) for c in ax.collections):
            price_ax = ax
        elif any(isinstance(p, Rectangle) for p in ax.patches):
            volume_ax = ax
        elif price_ax is None and "price" in ax.get_ylabel().lower():
            price_ax = ax
        elif volume_ax is None and "volume" in ax.get_ylabel().lower():
            volume_ax = ax
    return price_ax, volume_ax


def _extract_date_nums(args, kwargs):
    """Return matplotlib date numbers for the plotted rows, or ``None``.

    The data is the first positional argument or the ``data`` keyword. A
    ``Date_num`` column takes priority; otherwise each index entry is
    converted with ``matplotlib.dates.date2num`` on a best-effort basis.
    """
    data = None
    if len(args) > 0:
        data = args[0]
    elif "data" in kwargs:
        data = kwargs["data"]

    if data is None:
        return None
    if hasattr(data, "Date_num"):
        return list(data["Date_num"])
    if hasattr(data, "index"):
        # Best effort: a failed conversion must never break plotting.
        try:
            import matplotlib.dates as mdates

            return [mdates.date2num(d) for d in data.index]
        except Exception:
            return None
    return None


def _new_maidr_gid():
    """Return a fresh ``maidr-<uuid4>`` gid used to tag artists for highlighting."""
    return f"maidr-{uuid.uuid4()}"


def _register_candlestick(price_ax, date_nums, instance, args, kwargs):
    """Register the first wick LineCollection + body PolyCollection as a CANDLESTICK layer."""
    wick_collection = next(
        (c for c in price_ax.collections if isinstance(c, LineCollection)), None
    )
    body_collection = next(
        (c for c in price_ax.collections if isinstance(c, PolyCollection)), None
    )
    if wick_collection is None or body_collection is None:
        return

    # Wicks and bodies share one gid so they highlight together.
    gid = _new_maidr_gid()
    wick_collection.set_gid(gid)
    body_collection.set_gid(gid)

    candlestick_kwargs = dict(
        kwargs,
        _maidr_wick_collection=wick_collection,
        _maidr_body_collection=body_collection,
        _maidr_date_nums=date_nums,
    )
    common(
        PlotType.CANDLESTICK,
        lambda *a, **k: price_ax,
        instance,
        args,
        candlestick_kwargs,
    )


def _register_volume(volume_ax, date_nums, instance, args, kwargs):
    """Register the volume bars (Rectangle patches) as a BAR layer.

    If ``volume_ax`` has no Rectangle patches itself, the axes sharing its
    x-axis are searched instead.
    """
    volume_patches = [p for p in volume_ax.patches if isinstance(p, Rectangle)]

    if not volume_patches:
        for twin_ax in volume_ax.get_shared_x_axes().get_siblings(volume_ax):
            if twin_ax is not volume_ax:
                volume_patches.extend(
                    p for p in twin_ax.patches if isinstance(p, Rectangle)
                )

    if not volume_patches:
        return

    # Give each bar its own gid for highlighting, keeping any gid already set.
    for patch in volume_patches:
        if patch.get_gid() is None:
            patch.set_gid(_new_maidr_gid())

    bar_kwargs = dict(
        kwargs,
        _maidr_patches=volume_patches,
        _maidr_date_nums=date_nums,
    )
    common(PlotType.BAR, lambda *a, **k: volume_ax, instance, args, bar_kwargs)


def _relabel_moving_average(line, xydata_array):
    """Give ``line`` a descriptive label built from its estimated MA period.

    An N-period moving average has N-1 leading NaNs, so the period is
    estimated as ``NaN count + 1``. Labels starting with an underscore are
    Matplotlib's auto-generated/hidden labels ("_child0", older "_line0",
    "_nolegend_") and are replaced outright. User labels get an
    ``_MA<period>`` suffix instead.
    """
    if xydata_array.ndim != 2 or xydata_array.shape[1] < 2:
        # Not an (N, 2) xy array: there are no y-values to inspect.
        return
    nan_count = np.sum(np.isnan(xydata_array[:, 1]))
    estimated_period = nan_count + 1

    label = str(line.get_label())
    if label.startswith("_"):
        line.set_label(f"Moving Average {estimated_period} days")
    else:
        line.set_label(f"{label}_MA{estimated_period}")


def _line_identity(line, xydata_array):
    """Return a de-duplication key built from the line label, data shape and leading values."""
    if xydata_array is not None and xydata_array.size > 0:
        first_values = (
            xydata_array[:3].flatten()
            if xydata_array.size >= 6
            else xydata_array.flatten()
        )
        data_hash = hash(f"{xydata_array.shape}_{str(first_values)}")
        return f"{line.get_label()}_{data_hash}"
    return f"{line.get_label()}"


def _register_moving_averages(price_ax, date_nums, instance, args, kwargs):
    """Label, tag and register the price axis' Line2D objects as one LINE layer."""
    processed_lines = set()
    valid_lines = []

    for line in price_ax.get_lines():
        if not isinstance(line, Line2D):
            continue

        xydata = line.get_xydata()
        xydata_array = None if xydata is None else np.asarray(xydata)

        # Relabel before computing the identity: the key includes the new label.
        if xydata_array is not None:
            _relabel_moving_average(line, xydata_array)

        line_id = _line_identity(line, xydata_array)
        if line_id in processed_lines:
            continue
        processed_lines.add(line_id)

        # Skip lines without any data.
        if xydata_array is None or xydata_array.size == 0:
            continue

        # Date numbers are stashed on the line for the line plot class to use.
        if date_nums is not None:
            setattr(line, "_maidr_date_nums", date_nums)

        if line.get_gid() is None:
            line.set_gid(_new_maidr_gid())

        valid_lines.append(line)

    # All valid lines are registered together as a single LINE plot.
    if valid_lines:
        line_kwargs = dict(kwargs)
        common(PlotType.LINE, lambda *a, **k: price_ax, instance, args, line_kwargs)


def mplfinance_plot_patch(wrapped, instance, args, kwargs):
    """
    Enhanced patch function for `mplfinance.plot` that registers separate layers:
    - CANDLESTICK: For OHLC data (candle bodies and wicks)
    - BAR: For volume data (volume bars)
    - LINE: For moving averages (lines)

    This function intercepts calls to `mplfinance.plot`, identifies the resulting
    candlestick, volume, and moving average components, and registers them with
    maidr using the common patching mechanism.

    Parameters
    ----------
    wrapped : callable
        The original ``mplfinance.plot`` (wrapt wrapper protocol).
    instance : object
        Bound instance supplied by wrapt (``None`` for a plain function).
    args, kwargs : tuple, dict
        Arguments of the intercepted call. ``kwargs["returnfig"]`` is forced
        to ``True`` so that the figure and axes can be inspected.

    Returns
    -------
    tuple or None
        The ``(fig, axes, ...)`` tuple if the caller asked for
        ``returnfig=True``, otherwise ``None``.
    """
    # Ensure `returnfig=True` to capture the figure and axes objects.
    # NOTE(review): this presumably also changes what mplfinance does at the
    # end of plot() when the caller did not request returnfig (e.g. whether
    # it shows the figure itself) — confirm this is intended.
    original_returnfig = kwargs.get("returnfig", False)
    kwargs["returnfig"] = True

    with ContextManager.set_internal_context():
        result = wrapped(*args, **kwargs)

    # Without the expected (fig, axes) tuple there is nothing to register.
    if not (isinstance(result, tuple) and len(result) >= 2):
        return result if original_returnfig else None

    axes = result[1]
    ax_list = axes if isinstance(axes, list) else [axes]

    price_ax, volume_ax = _identify_price_and_volume_axes(ax_list)
    date_nums = _extract_date_nums(args, kwargs)

    # Registration order is kept stable: candlestick, volume, moving averages.
    if price_ax is not None:
        _register_candlestick(price_ax, date_nums, instance, args, kwargs)
    if volume_ax is not None:
        _register_volume(volume_ax, date_nums, instance, args, kwargs)
    if price_ax is not None:
        _register_moving_averages(price_ax, date_nums, instance, args, kwargs)

    return result if original_returnfig else None
210
+
211
+
212
# Apply the patch to mplfinance.plot.
# This runs at import time: importing this module wraps `mplfinance.plot` so
# that every call goes through `mplfinance_plot_patch`.
# NOTE(review): code that bound `mplfinance.plot` to another name before this
# import presumably keeps the unwrapped function — confirm import order.
wrapt.wrap_function_wrapper(mpf, "plot", mplfinance_plot_patch)
maidr/util/__init__.py CHANGED
@@ -0,0 +1,3 @@
1
+ from .plot_detection import PlotDetectionUtils
2
+
3
# Public names exported by `from maidr.util import *`; PlotDetectionUtils is
# re-exported from the plot_detection submodule.
__all__ = ["PlotDetectionUtils"]
@@ -0,0 +1,409 @@
1
+ """
2
+ Utility functions for handling mplfinance-specific data extraction and processing.
3
+ """
4
+
5
+ import matplotlib.dates as mdates
6
+ import numpy as np
7
+ from matplotlib.patches import Rectangle
8
+ from typing import List, Optional, Tuple, Any
9
+
10
+
11
class MplfinanceDataExtractor:
    """
    Utility class for extracting and processing mplfinance-specific data.

    This class handles the conversion of mplfinance plot elements (patches,
    collections) into standardized data formats that can be used by the core
    plot classes. All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def extract_volume_data(
        volume_patches: "List[Rectangle]", date_nums: Optional[List[float]] = None
    ) -> List[dict]:
        """
        Extract volume data from mplfinance Rectangle patches.

        Parameters
        ----------
        volume_patches : List[Rectangle]
            List of Rectangle patches representing volume bars. It is not modified.
        date_nums : Optional[List[float]], default=None
            List of matplotlib date numbers corresponding to the patches

        Returns
        -------
        List[dict]
            List of dictionaries with 'x' (label) and 'y' (bar height) keys,
            ordered by the patches' x-coordinate.
        """
        if not volume_patches:
            return []

        formatted_data = []

        # Sort patches by x-coordinate to maintain time order.
        sorted_patches = sorted(volume_patches, key=lambda p: p.get_x())

        for i, patch in enumerate(sorted_patches):
            height = patch.get_height()

            # Use date mapping if available, otherwise use the index.
            if date_nums is not None and i < len(date_nums):
                x_label = MplfinanceDataExtractor._convert_date_num_to_string(
                    date_nums[i]
                )
            else:
                x_label = f"date_{i:03d}"

            formatted_data.append({"x": str(x_label), "y": float(height)})

        return formatted_data

    @staticmethod
    def extract_candlestick_data(
        body_collection: Any,
        wick_collection: Any,
        date_nums: Optional[List[float]] = None,
    ) -> List[dict]:
        """
        Extract candlestick data from mplfinance collections.

        Open/close come from each body polygon's vertical extent, oriented by
        the body colour. High/low come from the wick segments whose x-position
        lies within the body's horizontal extent. When no wick matches (or
        ``wick_collection`` is unusable), high/low fall back to an estimate of
        the body bounds padded by 10% of the body height.

        Parameters
        ----------
        body_collection : Any
            PolyCollection containing candlestick bodies (needs ``get_paths``
            and ``get_facecolor``).
        wick_collection : Any
            LineCollection containing candlestick wicks (``get_segments``), or None.
        date_nums : Optional[List[float]], default=None
            List of matplotlib date numbers corresponding to the candles

        Returns
        -------
        List[dict]
            List of dictionaries with value/open/high/low/close/volume keys.
            Volume is always 0 here; it is handled separately.
        """
        if not body_collection or not hasattr(body_collection, "get_paths"):
            return []

        wick_x, wick_low, wick_high = MplfinanceDataExtractor._wick_extents(
            wick_collection
        )

        candles = []
        paths = body_collection.get_paths()
        face_colors = body_collection.get_facecolor()

        for i, path in enumerate(paths):
            vertices = path.vertices
            # Fewer than 4 vertices cannot describe a rectangular body.
            if len(vertices) < 4:
                continue

            x_min, x_max = vertices[:, 0].min(), vertices[:, 0].max()
            y_min, y_max = vertices[:, 1].min(), vertices[:, 1].max()
            x_center = (x_min + x_max) / 2

            # Use date mapping if available, otherwise the body's x-centre.
            if date_nums is not None and i < len(date_nums):
                date_str = MplfinanceDataExtractor._convert_date_num_to_string(
                    date_nums[i]
                )
            else:
                date_str = MplfinanceDataExtractor._convert_date_num_to_string(
                    x_center
                )

            # Determine if this is an up or down candle based on colour.
            is_up = MplfinanceDataExtractor._is_up_candle(face_colors, i)
            open_val, close_val = MplfinanceDataExtractor._extract_ohlc_from_rectangle(
                y_min, y_max, is_up
            )

            wick = MplfinanceDataExtractor._match_wick(
                wick_x, wick_low, wick_high, x_center, (x_max - x_min) / 2
            )
            if wick is not None:
                # The candle's range spans both the wick and the body.
                low_val = min(wick[0], float(y_min))
                high_val = max(wick[1], float(y_max))
            else:
                # No wick data for this candle: estimate from the body.
                high_val = y_max + (y_max - y_min) * 0.1
                low_val = y_min - (y_max - y_min) * 0.1

            candles.append(
                {
                    "value": date_str,
                    "open": round(float(open_val), 2),
                    "high": round(float(high_val), 2),
                    "low": round(float(low_val), 2),
                    "close": round(float(close_val), 2),
                    "volume": 0,  # Volume is handled separately
                }
            )

        return candles

    @staticmethod
    def extract_rectangle_candlestick_data(
        body_rectangles: "List[Rectangle]", date_nums: Optional[List[float]] = None
    ) -> List[dict]:
        """
        Extract candlestick data from Rectangle patches (original_flavor).

        High/low are estimated as the body bounds padded by 10% of the body
        height, because no wick data is available here.

        Parameters
        ----------
        body_rectangles : List[Rectangle]
            List of Rectangle patches representing candlestick bodies. It is
            not modified; a sorted copy is used.
        date_nums : Optional[List[float]], default=None
            List of matplotlib date numbers corresponding to the candles

        Returns
        -------
        List[dict]
            List of dictionaries with OHLC data, ordered by x-coordinate.
        """
        if not body_rectangles:
            return []

        candles = []

        # Sort a copy by x-coordinate so the caller's list is left untouched.
        sorted_rectangles = sorted(body_rectangles, key=lambda r: r.get_x())

        for i, rect in enumerate(sorted_rectangles):
            x_center_num = rect.get_x() + rect.get_width() / 2.0

            # Convert x coordinate to date
            if date_nums is not None and i < len(date_nums):
                date_str = MplfinanceDataExtractor._convert_date_num_to_string(
                    date_nums[i]
                )
            else:
                date_str = MplfinanceDataExtractor._convert_date_num_to_string(
                    x_center_num
                )

            y_bottom = rect.get_y()
            height = rect.get_height()

            is_up_candle = MplfinanceDataExtractor._is_up_candle_from_color(
                rect.get_facecolor()
            )
            open_price, close_price = MplfinanceDataExtractor._extract_ohlc_from_rectangle(
                y_bottom, y_bottom + height, is_up_candle
            )

            # Estimate high and low
            high_price = max(open_price, close_price) + height * 0.1
            low_price = min(open_price, close_price) - height * 0.1

            # Replace NaN with 0.0 so downstream consumers always get numbers.
            open_price = float(open_price) if not np.isnan(open_price) else 0.0
            high_price = float(high_price) if not np.isnan(high_price) else 0.0
            low_price = float(low_price) if not np.isnan(low_price) else 0.0
            close_price = float(close_price) if not np.isnan(close_price) else 0.0

            candles.append(
                {
                    "value": date_str,
                    "open": round(open_price, 2),
                    "high": round(high_price, 2),
                    "low": round(low_price, 2),
                    "close": round(close_price, 2),
                    "volume": 0,
                }
            )

        return candles

    @staticmethod
    def clean_axis_label(label: str) -> str:
        """
        Clean up axis labels by removing LaTeX formatting.

        Parameters
        ----------
        label : str
            The original axis label

        Returns
        -------
        str
            Cleaned axis label. Returns the original if cleaning would leave
            it empty; non-string or empty input is returned unchanged.
        """
        if not label or not isinstance(label, str):
            return label

        import re

        # Removes LaTeX-like scientific notation, e.g., "$10^{6}$"
        cleaned_label = re.sub(r"\s*\$.*?\$", "", label).strip()
        return cleaned_label if cleaned_label else label

    @staticmethod
    def _convert_date_num_to_string(date_num: float) -> str:
        """
        Convert matplotlib date number to date string.

        Values above 700000 are decoded with ``matplotlib.dates.num2date``.
        Values above 1000 are decoded via pandas as days since 1970-01-01.
        Anything else, or any value that fails to decode, becomes an
        index-style ``date_NNN`` label. Non-finite or non-numeric input yields
        ``date_<value>`` instead of raising.

        NOTE(review): with Matplotlib >= 3.3's default 1970 epoch, values above
        700000 presumably come from the legacy epoch — confirm the source.

        Parameters
        ----------
        date_num : float
            Matplotlib date number

        Returns
        -------
        str
            Date string in YYYY-MM-DD format or fallback index
        """
        try:
            # Check if this looks like a matplotlib date number (typically > 700000)
            if date_num > 700000:
                date_dt = mdates.num2date(date_num)
                if hasattr(date_dt, "replace"):
                    date_dt = date_dt.replace(tzinfo=None)
                return date_dt.strftime("%Y-%m-%d")
            if date_num > 1000:
                # Try converting as days since the Unix epoch.
                try:
                    import pandas as pd

                    date_dt = pd.to_datetime(date_num, unit="D")
                    return date_dt.strftime("%Y-%m-%d")
                except Exception:
                    # Best effort only: fall through to the index label.
                    pass
        except (ValueError, TypeError, OverflowError):
            pass

        # Fallback to index-based date string
        try:
            return f"date_{int(date_num):03d}"
        except (ValueError, TypeError, OverflowError):
            # NaN/inf or non-numeric input cannot become an integer index.
            return f"date_{date_num}"

    @staticmethod
    def convert_x_to_date(x_center_num: float, axes: Optional[List] = None) -> str:
        """
        Convert matplotlib x-coordinate to date string.

        If ``axes`` is given and its first axis has large (date-like) x-ticks,
        ``round(x_center_num)`` is used as an index into those ticks. Otherwise,
        or on any failure, ``_convert_date_num_to_string`` is used.

        Parameters
        ----------
        x_center_num : float
            X-coordinate value to convert
        axes : Optional[List], optional
            List of matplotlib axes to help with date conversion

        Returns
        -------
        str
            Date string in YYYY-MM-DD format or fallback
        """
        if axes and len(axes) > 0:
            ax = axes[0]
            try:
                x_ticks = ax.get_xticks()

                # Only trust ticks that look like date numbers (large values).
                if len(x_ticks) > 0 and x_ticks[0] > 1000:
                    tick_index = int(round(x_center_num))
                    if 0 <= tick_index < len(x_ticks):
                        actual_date_num = x_ticks[tick_index]

                        if actual_date_num > 700000:
                            date_dt = mdates.num2date(actual_date_num)
                            if hasattr(date_dt, "replace"):
                                date_dt = date_dt.replace(tzinfo=None)
                            return date_dt.strftime("%Y-%m-%d")
            except Exception:
                # Best effort: fall back to direct conversion below.
                pass

        return MplfinanceDataExtractor._convert_date_num_to_string(x_center_num)

    @staticmethod
    def _wick_extents(
        wick_collection: Any,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Collapse wick segments into parallel arrays of (x, low, high).

        Each segment contributes its first point's x and the min/max of its
        finite y-values. Several segments may share an x, e.g. separate upper
        and lower wicks; ``_match_wick`` merges them.

        Parameters
        ----------
        wick_collection : Any
            Object exposing ``get_segments()`` (e.g. a LineCollection), or None.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            ``(xs, lows, highs)``. All three are empty if no usable segments exist.
        """
        xs: List[float] = []
        lows: List[float] = []
        highs: List[float] = []

        if wick_collection is not None and hasattr(wick_collection, "get_segments"):
            for segment in wick_collection.get_segments():
                points = np.asarray(segment, dtype=float)
                if points.ndim != 2 or points.shape[0] == 0 or points.shape[1] < 2:
                    continue
                ys = points[:, 1]
                ys = ys[np.isfinite(ys)]
                if ys.size == 0:
                    continue
                xs.append(float(points[0, 0]))
                lows.append(float(ys.min()))
                highs.append(float(ys.max()))

        return np.asarray(xs), np.asarray(lows), np.asarray(highs)

    @staticmethod
    def _match_wick(
        wick_x: np.ndarray,
        wick_low: np.ndarray,
        wick_high: np.ndarray,
        x_center: float,
        tolerance: float,
    ) -> Optional[Tuple[float, float]]:
        """
        Return ``(low, high)`` over all wick segments within ``tolerance`` of ``x_center``.

        Returns None if no segment is close enough. A zero tolerance
        (degenerate body) is widened slightly to absorb float rounding.
        """
        if wick_x.size == 0:
            return None
        mask = np.abs(wick_x - x_center) <= max(float(tolerance), 1e-9)
        if not mask.any():
            return None
        return float(wick_low[mask].min()), float(wick_high[mask].max())

    @staticmethod
    def _is_up_candle(face_colors: Any, index: int) -> bool:
        """
        Determine if a candle is up based on face color.

        A colour whose R, G and B are all below 0.5 (dark) is treated as a
        down candle. Like Matplotlib collections, colours are cycled when
        there are fewer colours than candles. A single flat RGB(A) colour
        therefore applies to every candle.

        Parameters
        ----------
        face_colors : Any
            Face colors from the collection
        index : int
            Index of the candle

        Returns
        -------
        bool
            True if up candle, False if down candle (True when undeterminable)
        """
        if not hasattr(face_colors, "__len__") or len(face_colors) == 0:
            return True  # Default to up candle

        entry = face_colors[index % len(face_colors)]
        # A scalar entry means face_colors is itself one flat colour.
        color = entry if hasattr(entry, "__len__") else face_colors

        if isinstance(color, (list, tuple, np.ndarray)) and len(color) >= 3:
            if color[0] < 0.5 and color[1] < 0.5 and color[2] < 0.5:
                return False
        return True

    @staticmethod
    def _is_up_candle_from_color(face_color: Any) -> bool:
        """
        Determine if a candle is up based on face color (for Rectangle patches).

        Parameters
        ----------
        face_color : Any
            Face color of the rectangle

        Returns
        -------
        bool
            True if the green component exceeds the red one (up candle),
            False otherwise. Defaults to True when the colour is not an
            RGB(A) sequence.
        """
        try:
            if (
                isinstance(face_color, (list, tuple, np.ndarray))
                and len(face_color) >= 3
            ):
                # Green colors typically indicate up candles
                if face_color[1] > face_color[0]:
                    return True
                else:
                    return False
        except (TypeError, IndexError):
            pass
        return True  # Default to up candle

    @staticmethod
    def _extract_ohlc_from_rectangle(
        y_min: float, y_max: float, is_up: bool
    ) -> Tuple[float, float]:
        """
        Extract open and close values from rectangle bounds.

        Parameters
        ----------
        y_min : float
            Minimum y value of rectangle
        y_max : float
            Maximum y value of rectangle
        is_up : bool
            Whether this is an up candle

        Returns
        -------
        Tuple[float, float]
            (open_price, close_price)
        """
        if is_up:
            # Up candle: open at bottom, close at top
            return y_min, y_max
        # Down candle: open at top, close at bottom
        return y_max, y_min