maidr 1.3.0__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
maidr/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "1.3.0"
1
+ __version__ = "1.4.1"
2
2
 
3
3
  from .api import close, render, save_html, show, stacked
4
4
  from .core import Maidr
@@ -15,6 +15,7 @@ from .patch import (
15
15
  scatterplot,
16
16
  regplot,
17
17
  kdeplot,
18
+ mplfinance,
18
19
  )
19
20
 
20
21
  __all__ = [
maidr/core/maidr.py CHANGED
@@ -1,6 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from datetime import datetime
4
+ import urllib.request
5
+ import urllib.error
6
+ import json
4
7
 
5
8
  import io
6
9
  import json
@@ -22,6 +25,11 @@ from maidr.core.plot import MaidrPlot
22
25
  from maidr.util.environment import Environment
23
26
  from maidr.util.dedup_utils import deduplicate_smooth_and_line
24
27
 
28
+ # Module-level cache for version to avoid repeated API calls
29
+ _MAIDR_VERSION_CACHE: str | None = None
30
+ _MAIDR_VERSION_CACHE_TIME: float = 0.0
31
+ _MAIDR_CACHE_DURATION = 3600 # Cache for 1 hour
32
+
25
33
 
26
34
  class Maidr:
27
35
  """
@@ -34,6 +42,8 @@ class Maidr:
34
42
  The matplotlib figure associated with this instance.
35
43
  _plots : list[MaidrPlot]
36
44
  A list of MaidrPlot objects which hold additional plot-specific configurations.
45
+ _cached_version : str | None
46
+ Cached version of maidr from npm registry to avoid repeated API calls.
37
47
 
38
48
  Methods
39
49
  -------
@@ -271,20 +281,77 @@ class Maidr:
271
281
  """Generate a unique identifier string using UUID4."""
272
282
  return str(uuid.uuid4())
273
283
 
284
+ @staticmethod
285
+ def _get_latest_maidr_version() -> str:
286
+ """
287
+ Query the npm registry API to get the latest version of maidr with caching.
288
+
289
+ Returns
290
+ -------
291
+ str
292
+ The latest version of maidr from npm registry, or 'latest' as fallback.
293
+ """
294
+ import time
295
+
296
+ global _MAIDR_VERSION_CACHE, _MAIDR_VERSION_CACHE_TIME
297
+
298
+ # Check if version fetching is disabled via environment variable
299
+ if os.getenv("MAIDR_DISABLE_VERSION_FETCH", "").lower() in ("true", "1", "yes"):
300
+ return "latest"
301
+
302
+ current_time = time.time()
303
+
304
+ # Check if we have a valid cached version
305
+ if (
306
+ _MAIDR_VERSION_CACHE is not None
307
+ and current_time - _MAIDR_VERSION_CACHE_TIME < _MAIDR_CACHE_DURATION
308
+ ):
309
+ return _MAIDR_VERSION_CACHE
310
+
311
+ try:
312
+ # Query npm registry API for maidr package
313
+ with urllib.request.urlopen(
314
+ "https://registry.npmjs.org/maidr/latest", timeout=5 # 5 second timeout
315
+ ) as response:
316
+ if response.status == 200:
317
+ data = json.loads(response.read().decode("utf-8"))
318
+ version = data.get("version", "latest")
319
+
320
+ # Cache the successful result
321
+ _MAIDR_VERSION_CACHE = version
322
+ _MAIDR_VERSION_CACHE_TIME = current_time
323
+
324
+ return version
325
+
326
+ except Exception:
327
+ # Any error - just use latest
328
+ pass
329
+
330
+ # Fallback to 'latest' if API call fails
331
+ return "latest"
332
+
333
+ @staticmethod
334
+ def clear_version_cache() -> None:
335
+ """Clear the cached version to force a fresh API call on next request."""
336
+ global _MAIDR_VERSION_CACHE, _MAIDR_VERSION_CACHE_TIME
337
+ _MAIDR_VERSION_CACHE = None
338
+ _MAIDR_VERSION_CACHE_TIME = 0.0
339
+
274
340
  @staticmethod
275
341
  def _inject_plot(plot: HTML, maidr: str, maidr_id, use_iframe: bool = True) -> Tag:
276
342
  """Embed the plot and associated MAIDR scripts into the HTML structure."""
277
- # MAIDR_TS_CDN_URL = "http://localhost:8888/tree/maidr/core/maidr.js" # DEMO URL
278
- MAIDR_TS_CDN_URL = "https://cdn.jsdelivr.net/npm/maidr@latest/dist/maidr.js"
279
- # Append a query parameter (using TIMESTAMP) to bust the cache (so that the latest (non-cached) version is always loaded).
280
- TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
343
+ # Get the latest version from npm registry
344
+ latest_version = Maidr._get_latest_maidr_version()
345
+ MAIDR_TS_CDN_URL = (
346
+ f"https://cdn.jsdelivr.net/npm/maidr@{latest_version}/dist/maidr.js"
347
+ )
281
348
 
282
349
  script = f"""
283
- if (!document.querySelector('script[src="{MAIDR_TS_CDN_URL}?v={TIMESTAMP}"]'))
350
+ if (!document.querySelector('script[src="{MAIDR_TS_CDN_URL}"]'))
284
351
  {{
285
352
  var script = document.createElement('script');
286
353
  script.type = 'module';
287
- script.src = '{MAIDR_TS_CDN_URL}?v={TIMESTAMP}';
354
+ script.src = '{MAIDR_TS_CDN_URL}';
288
355
  script.addEventListener('load', function() {{
289
356
  window.main();
290
357
  }});
@@ -299,7 +366,7 @@ class Maidr:
299
366
  base_html = tags.div(
300
367
  tags.link(
301
368
  rel="stylesheet",
302
- href="https://cdn.jsdelivr.net/npm/maidr/dist/maidr_style.css",
369
+ href=f"https://cdn.jsdelivr.net/npm/maidr@{latest_version}/dist/maidr_style.css",
303
370
  ),
304
371
  tags.script(script, type="text/javascript"),
305
372
  tags.div(plot),
@@ -1,14 +1,24 @@
1
+ from __future__ import annotations
2
+
1
3
  import matplotlib.dates as mdates
2
4
  import numpy as np
3
5
  from matplotlib.axes import Axes
4
- from matplotlib.collections import LineCollection, PatchCollection
5
6
  from matplotlib.patches import Rectangle
6
7
 
7
- from maidr.core.enum.plot_type import PlotType
8
- from maidr.core.plot.maidr_plot import MaidrPlot
8
+ from maidr.core.enum import PlotType
9
+ from maidr.core.plot import MaidrPlot
10
+ from maidr.core.enum.maidr_key import MaidrKey
11
+ from maidr.util.mplfinance_utils import MplfinanceDataExtractor
9
12
 
10
13
 
11
14
  class CandlestickPlot(MaidrPlot):
15
+ """
16
+ Specialized candlestick plot class for mplfinance OHLC data.
17
+
18
+ This class handles the extraction and processing of candlestick data from mplfinance
19
+ plots, including proper date conversion and data validation.
20
+ """
21
+
12
22
  def __init__(self, axes: list[Axes], **kwargs) -> None:
13
23
  """
14
24
  Initialize the CandlestickPlot.
@@ -27,213 +37,111 @@ class CandlestickPlot(MaidrPlot):
27
37
  raise ValueError("Axes list cannot be empty.")
28
38
  super().__init__(axes[0], PlotType.CANDLESTICK)
29
39
 
40
+ # Store custom collections passed from mplfinance patch
41
+ self._maidr_wick_collection = kwargs.get("_maidr_wick_collection", None)
42
+ self._maidr_body_collection = kwargs.get("_maidr_body_collection", None)
43
+ self._maidr_date_nums = kwargs.get("_maidr_date_nums", None)
44
+
45
+ # Store the GID for proper selector generation
46
+ self._maidr_gid = None
47
+ if self._maidr_body_collection:
48
+ self._maidr_gid = self._maidr_body_collection.get_gid()
49
+ elif self._maidr_wick_collection:
50
+ self._maidr_gid = self._maidr_wick_collection.get_gid()
51
+
30
52
  def _extract_plot_data(self) -> list[dict]:
31
- """
32
- Extracts candlestick (OHLC) and volume data from the plot axes.
53
+ """Extract candlestick data from the plot."""
33
54
 
34
- This method assumes that the candlestick chart is structured with
35
- LineCollection for wicks and PatchCollection of Rectangles for bodies
36
- on the first axis (self.axes[0]). Volume data is expected as a
37
- PatchCollection of Rectangles on the second axis (self.axes[1]), if present.
38
- Open and close prices are inferred from the body rectangle's color.
55
+ # Get the custom collections from kwargs
56
+ body_collection = self._maidr_body_collection
57
+ wick_collection = self._maidr_wick_collection
39
58
 
40
- Returns
41
- -------
42
- list[dict]
43
- A list of dictionaries, where each dictionary represents a data point
44
- with 'value' (date string YYYY-MM-DD), 'open', 'high', 'low',
45
- 'close', and 'volume'. Fields that cannot be extracted will be None.
46
-
47
- Examples
48
- --------
49
- Assuming a plot has been generated and `plot_instance.axes` is populated:
50
- >>> data = plot_instance._extract_plot_data()
51
- >>> print(data[0])
52
- {
53
- 'value': '2021-01-01',
54
- 'open': 100.0,
55
- 'high': 100.9,
56
- 'low': 99.27,
57
- 'close': 100.75,
58
- 'volume': 171914,
59
- }
60
- """
59
+ if body_collection and wick_collection:
60
+ # Store the GID from the body collection for highlighting
61
+ self._maidr_gid = body_collection.get_gid()
62
+
63
+ # Use the original collections for highlighting
64
+ self._elements = [body_collection, wick_collection]
65
+
66
+ # Use the utility class to extract data
67
+ return MplfinanceDataExtractor.extract_candlestick_data(
68
+ body_collection, wick_collection, self._maidr_date_nums
69
+ )
70
+
71
+ # Fallback to original detection method
61
72
  if not self.axes:
62
73
  return []
63
74
 
64
- plot_data: list[dict] = []
65
- ax_ohlc: Axes = self.axes[0]
66
-
67
- body_rectangles: list[Rectangle] = []
68
- wick_collection: LineCollection | None = None
69
-
70
- # Find candlestick body Rectangles from the OHLC axis
71
- # Prefer PatchCollection containing Rectangles, fallback to individual Rectangles in ax.patches
72
- for collection in ax_ohlc.collections:
73
- if isinstance(collection, PatchCollection):
74
- # Check if the collection's patches are Rectangles
75
- try:
76
- # Iterating a PatchCollection yields its constituent Patch objects
77
- patches_are_rects = all(
78
- isinstance(p, Rectangle) for p in collection
79
- )
80
- if (
81
- patches_are_rects and len(collection.get_paths()) > 0
82
- ): # Ensure it has paths and they are Rectangles
83
- for (
84
- patch
85
- ) in collection: # Iterate to get actual Rectangle objects
86
- if isinstance(patch, Rectangle):
87
- body_rectangles.append(patch)
88
- if (
89
- body_rectangles
90
- ): # If we found rectangles this way, assume this is the primary body collection
91
- break
92
- except Exception:
93
- # Could fail if collection is not iterable in the expected way or patches are not Rectangles
94
- pass
95
-
96
- if not body_rectangles:
97
- for patch in ax_ohlc.patches:
98
- if isinstance(patch, Rectangle):
99
- body_rectangles.append(patch)
100
-
101
- if not body_rectangles:
102
- pass
103
-
104
- ax_for_wicks: Axes | None = None
105
- if len(self.axes) > 1:
106
- ax_for_wicks = self.axes[1]
107
-
108
- if ax_for_wicks:
109
- # Attempt 1: Find wicks in ax_for_wicks.collections (as a LineCollection)
110
- for collection in ax_for_wicks.collections:
111
- if isinstance(collection, LineCollection):
112
- segments = collection.get_segments()
113
- # Check if the collection contains segments and the first segment looks like a vertical line
114
- if segments is not None and len(segments) > 0:
115
- first_segment = segments[0]
116
- if (
117
- len(first_segment) == 2 # Segment consists of two points
118
- and len(first_segment[0]) == 2 # First point has (x, y)
119
- and len(first_segment[1]) == 2 # Second point has (x, y)
120
- and np.isclose(
121
- first_segment[0][0], first_segment[1][0]
122
- ) # X-coordinates are close (vertical)
123
- ):
124
- wick_collection = collection
125
- break # Found a suitable LineCollection
126
-
127
- # Attempt 2: If no LineCollection found, try to find wicks from individual Line2D objects in ax_for_wicks.get_lines()
128
- if not wick_collection and hasattr(ax_for_wicks, "get_lines"):
129
- potential_wick_segments = []
130
- for line in ax_for_wicks.get_lines(): # Iterate over Line2D objects
131
- x_data, y_data = line.get_data()
132
- # A wick is typically a vertical line defined by two points.
133
- if len(x_data) == 2 and len(y_data) == 2:
134
- if np.isclose(x_data[0], x_data[1]): # Check for verticality
135
- # Create a segment in the format [[x1, y1], [x2, y2]]
136
- segment = [
137
- [x_data[0], y_data[0]],
138
- [x_data[1], y_data[1]],
139
- ]
140
- potential_wick_segments.append(segment)
141
-
142
- if potential_wick_segments:
143
- # If wick segments were found from individual lines,
144
- # create a new LineCollection to hold them.
145
- # This allows the downstream processing logic
146
- # for wicks to remain consistent.
147
- # Basic properties are set; color/linestyle
148
- # are defaults and may not match
149
- # the original plot's styling if that
150
- # were relevant for segment extraction.
151
- wick_collection = LineCollection(
152
- potential_wick_segments,
153
- colors="k", # Default color for the temporary collection
154
- linestyles="solid", # Default linestyle
155
- )
156
-
157
- # Process wicks into a map: x_coordinate -> (low_price, high_price)
158
- wick_segments_map: dict[float, tuple[float, float]] = {}
159
- if wick_collection:
160
- for seg in wick_collection.get_segments():
161
- if len(seg) == 2 and len(seg[0]) == 2 and len(seg[1]) == 2:
162
- # Ensure x-coordinates are (nearly) identical for a vertical wick line
163
- if np.isclose(seg[0][0], seg[1][0]):
164
- x_coord = seg[0][0] # Matplotlib date number
165
- low_price = min(seg[0][1], seg[1][1])
166
- high_price = max(seg[0][1], seg[1][1])
167
- wick_segments_map[x_coord] = (low_price, high_price)
168
-
169
- body_rectangles.sort(key=lambda r: r.get_x())
170
-
171
- for rect in body_rectangles:
172
- x_left = rect.get_x()
173
- width = rect.get_width()
174
- x_center_num = x_left + width / 2.0
175
-
176
- try:
177
- date_dt = mdates.num2date(x_center_num)
178
- date_str = date_dt.strftime("%Y-%m-%d")
179
- except ValueError:
180
- date_str = f"raw_date_{x_center_num:.2f}"
181
-
182
- y_bottom = rect.get_y()
183
- height = rect.get_height()
184
- face_color = rect.get_facecolor() # RGBA tuple
185
-
186
- # Infer open and close prices
187
- # Heuristic: Green component > Red component for an "up" candle (close > open)
188
- # This assumes standard green for up, red for down.
189
- # A more robust method would involve knowing the exact up/down colors used.
190
- is_up_candle = (
191
- face_color[1] > face_color[0]
192
- ) # Compare Green and Red components
193
-
194
- if is_up_candle: # Typically green: price went up
195
- open_price = y_bottom
196
- close_price = y_bottom + height
197
- else: # Typically red: price went down (or other color)
198
- close_price = y_bottom
199
- open_price = y_bottom + height
200
-
201
- matched_wick_data = None
202
- closest_wick_x = None
203
- min_diff = float("inf")
204
-
205
- for wick_x, prices in wick_segments_map.items():
206
- diff = abs(wick_x - x_center_num)
207
- if diff < min_diff:
208
- min_diff = diff
209
- closest_wick_x = wick_x
210
-
211
- # Tolerance for matching wick x-coordinate (e.g., 10% of candle width)
212
- if closest_wick_x is not None and min_diff < (width * 0.1):
213
- matched_wick_data = wick_segments_map[closest_wick_x]
214
-
215
- if matched_wick_data:
216
- low_price, high_price = matched_wick_data
217
- # Ensure high >= max(open,close) and low <= min(open,close)
218
- high_price = max(high_price, open_price, close_price)
219
- low_price = min(low_price, open_price, close_price)
220
- else:
221
- # Fallback if no wick found: high is max(open,close), low is min(open,close)
222
- high_price = max(open_price, close_price)
223
- low_price = min(open_price, close_price)
224
-
225
- plot_data.append(
226
- {
227
- "value": date_str,
228
- "open": open_price,
229
- "high": high_price,
230
- "low": low_price,
231
- "close": close_price,
232
- "volume": 0,
233
- }
75
+ ax_ohlc = self.axes[0]
76
+ candles = []
77
+
78
+ # Look for Rectangle patches (original_flavor candlestick)
79
+ body_rectangles = []
80
+ for patch in ax_ohlc.patches:
81
+ if isinstance(patch, Rectangle):
82
+ body_rectangles.append(patch)
83
+
84
+ if body_rectangles:
85
+ # Set elements for highlighting
86
+ self._elements = body_rectangles
87
+
88
+ # Generate a GID for highlighting if none exists
89
+ if not self._maidr_gid:
90
+ import uuid
91
+
92
+ self._maidr_gid = f"maidr-{uuid.uuid4()}"
93
+ # Set GID on all rectangles
94
+ for rect in body_rectangles:
95
+ rect.set_gid(self._maidr_gid)
96
+
97
+ # Use the utility class to extract data
98
+ return MplfinanceDataExtractor.extract_rectangle_candlestick_data(
99
+ body_rectangles, self._maidr_date_nums
234
100
  )
235
- self._elements.extend(body_rectangles)
236
- return plot_data
101
+
102
+ return []
237
103
 
238
104
  def _extract_axes_data(self) -> dict:
239
- return {}
105
+ """
106
+ Extract the plot's axes data including labels.
107
+
108
+ Returns
109
+ -------
110
+ dict
111
+ Dictionary containing x and y axis labels.
112
+ """
113
+ x_labels = self.ax.get_xlabel()
114
+ if not x_labels:
115
+ x_labels = self.extract_shared_xlabel(self.ax)
116
+ if not x_labels:
117
+ x_labels = "X"
118
+ return {MaidrKey.X: x_labels, MaidrKey.Y: self.ax.get_ylabel()}
119
+
120
+ def _get_selector(self) -> str:
121
+ """Return the CSS selector for highlighting candlestick elements in the SVG output."""
122
+ # Use the stored GID if available, otherwise fall back to generic selector
123
+ if self._maidr_gid:
124
+ # Use the full GID as the id attribute (since that's what's in the SVG)
125
+ selector = (
126
+ f"g[id='{self._maidr_gid}'] > path, g[id='{self._maidr_gid}'] > rect"
127
+ )
128
+ else:
129
+ selector = "g[maidr='true'] > path, g[maidr='true'] > rect"
130
+ return selector
131
+
132
+ def render(self) -> dict:
133
+ """Initialize the MAIDR schema dictionary with basic plot information."""
134
+ title = "Candlestick Chart"
135
+
136
+ maidr_schema = {
137
+ MaidrKey.TYPE: self.type,
138
+ MaidrKey.TITLE: title,
139
+ MaidrKey.AXES: self._extract_axes_data(),
140
+ MaidrKey.DATA: self._extract_plot_data(),
141
+ }
142
+
143
+ # Include selector only if the plot supports highlighting.
144
+ if self._support_highlighting:
145
+ maidr_schema[MaidrKey.SELECTOR] = self._get_selector()
146
+
147
+ return maidr_schema
@@ -42,8 +42,13 @@ class MaidrPlot(ABC):
42
42
  self._support_highlighting = True
43
43
  self._elements = []
44
44
  ss = self.ax.get_subplotspec()
45
- self.row_index = ss.rowspan.start
46
- self.col_index = ss.colspan.start
45
+ # Handle cases where subplotspec is None (dynamically created axes)
46
+ if ss is not None:
47
+ self.row_index = ss.rowspan.start
48
+ self.col_index = ss.colspan.start
49
+ else:
50
+ self.row_index = 0
51
+ self.col_index = 0
47
52
 
48
53
  # MAIDR data
49
54
  self.type = plot_type
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from matplotlib.axes import Axes
4
-
5
4
  from maidr.core.enum import PlotType
6
5
  from maidr.core.plot.barplot import BarPlot
7
6
  from maidr.core.plot.boxplot import BoxPlot
@@ -13,6 +12,9 @@ from maidr.core.plot.lineplot import MultiLinePlot
13
12
  from maidr.core.plot.maidr_plot import MaidrPlot
14
13
  from maidr.core.plot.scatterplot import ScatterPlot
15
14
  from maidr.core.plot.regplot import SmoothPlot
15
+ from maidr.core.plot.mplfinance_barplot import MplfinanceBarPlot
16
+ from maidr.core.plot.mplfinance_lineplot import MplfinanceLinePlot
17
+ from maidr.util.plot_detection import PlotDetectionUtils
16
18
 
17
19
 
18
20
  class MaidrPlotFactory:
@@ -38,17 +40,13 @@ class MaidrPlotFactory:
38
40
  single_ax = ax
39
41
 
40
42
  if plot_type == PlotType.CANDLESTICK:
41
- if isinstance(ax, list):
42
- # If ax is a list of lists, flatten it
43
- if ax and isinstance(ax[0], list):
44
- axes = ax[0] # Take the first inner list
45
- else:
46
- axes = ax # Use the list as-is
47
- else:
48
- axes = [ax] # Wrap single axes in list
43
+ axes = PlotDetectionUtils.get_candlestick_axes(ax)
49
44
  return CandlestickPlot(axes, **kwargs)
50
45
  elif PlotType.BAR == plot_type or PlotType.COUNT == plot_type:
51
- return BarPlot(single_ax)
46
+ if PlotDetectionUtils.is_mplfinance_bar_plot(**kwargs):
47
+ return MplfinanceBarPlot(single_ax, **kwargs)
48
+ else:
49
+ return BarPlot(single_ax)
52
50
  elif PlotType.BOX == plot_type:
53
51
  return BoxPlot(single_ax, **kwargs)
54
52
  elif PlotType.HEAT == plot_type:
@@ -56,7 +54,10 @@ class MaidrPlotFactory:
56
54
  elif PlotType.HIST == plot_type:
57
55
  return HistPlot(single_ax)
58
56
  elif PlotType.LINE == plot_type:
59
- return MultiLinePlot(single_ax)
57
+ if PlotDetectionUtils.is_mplfinance_line_plot(single_ax, **kwargs):
58
+ return MplfinanceLinePlot(single_ax, **kwargs)
59
+ else:
60
+ return MultiLinePlot(single_ax, **kwargs)
60
61
  elif PlotType.SCATTER == plot_type:
61
62
  return ScatterPlot(single_ax)
62
63
  elif PlotType.DODGED == plot_type or PlotType.STACKED == plot_type: