plotjs 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plotjs/javascript.py ADDED
@@ -0,0 +1,23 @@
1
def from_file(javascript_file: str) -> str:
    """
    Get raw javascript from a javascript file.
    This function just reads the js from a given
    file.

    Args:
        javascript_file: Path to a js file (a `str` or any path-like object
            accepted by `open()`).

    Returns:
        A string of raw javascript.

    Examples:
        ```python
        from plotjs import javascript

        javascript.from_file("path/to/script.js")
        ```
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(javascript_file, "r", encoding="utf-8") as f:
        return f.read()
plotjs/main.py ADDED
@@ -0,0 +1,359 @@
1
+ import os
2
+ import io
3
+ import random
4
+ import uuid
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ from pathlib import Path
9
+ from jinja2 import Environment, FileSystemLoader
10
+ from narwhals.typing import SeriesT
11
+ import matplotlib.pyplot as plt
12
+ from matplotlib.figure import Figure
13
+ from matplotlib.axes import Axes
14
+
15
+ from plotjs.utils import _vector_to_list, _get_and_sanitize_js
16
+ from plotjs import css, javascript
17
+
18
# Package directory and the bundled "static" folder holding the HTML template,
# the default stylesheet and the JS parser.
MAIN_DIR: Path = Path(__file__).parent
TEMPLATE_DIR: Path = MAIN_DIR / "static"
# Kept as plain strings (os.path.join returns str for a Path argument).
CSS_PATH: str = os.path.join(TEMPLATE_DIR, "default.css")
JS_PARSER_PATH: str = os.path.join(TEMPLATE_DIR, "plotparser.js")

# Jinja environment that loads templates (e.g. "template.html") from TEMPLATE_DIR.
env: Environment = Environment(loader=FileSystemLoader(str(TEMPLATE_DIR)))
24
+
25
+
26
class PlotJS:
    """
    Class to convert static matplotlib plots to interactive charts.

    Attributes:
        - additional_css: All the CSS added via `add_css()`
        - additional_javascript: All the JavaScript added via `add_javascript()`
    """

    # Page metadata defaults. They are also assigned in `__init__` so that
    # `as_html()` works even when `save()` has never been called.
    _DEFAULT_FAVICON_PATH: str = (
        "https://github.com/JosephBARBIERDARNAL/static/blob/main/"
        "python-libs/plotjs/favicon.ico?raw=true"
    )
    _DEFAULT_DOCUMENT_TITLE: str = "Made with plotjs"

    def __init__(
        self,
        fig: Figure | None = None,
        **savefig_kws: object,
    ):
        """
        Initialize a `PlotJS` instance to convert matplotlib
        figures to interactive charts.

        Args:
            fig: An optional matplotlib figure. If None, uses `plt.gcf()`.
            savefig_kws: Additional keyword arguments passed to `fig.savefig()`.
        """
        if fig is None:
            fig = plt.gcf()

        # Fix the svg hashsalt and id so the generated SVG is reproducible.
        # `rc_context` restores the caller's previous rcParams on exit, even
        # if `savefig` raises.
        # https://matplotlib.org/stable/users/explain/customizing.html#the-default-matplotlibrc-file
        buf = io.StringIO()
        with plt.rc_context({"svg.hashsalt": "svg-hashsalt", "svg.id": "svg-id"}):
            fig.savefig(buf, format="svg", **savefig_kws)
        self._svg_content = buf.getvalue()

        self._axes: list[Axes] = fig.get_axes()

        self.additional_css = ""
        self.additional_javascript = ""
        # NOTE(review): not modified anywhere in this class; the effective
        # per-axes setting is the "hover_nearest" entry in `_axes_tooltip`.
        self._hover_nearest = False
        # Per-axes tooltip configuration, keyed "axes_<1-based index>".
        self._axes_tooltip: dict[str, dict] = {}
        self._favicon_path = self._DEFAULT_FAVICON_PATH
        self._document_title = self._DEFAULT_DOCUMENT_TITLE
        self._template = env.get_template("template.html")

        with open(CSS_PATH, encoding="utf-8") as f:
            self._default_css = f.read()

        self._js_parser = _get_and_sanitize_js(
            file_path=JS_PARSER_PATH,
            after_pattern=r"class PlotSVGParser.*",
        )

        # Fixed seed: every instance gets the same UUID, which keeps the
        # rendered HTML deterministic. NOTE(review): several plots embedded in
        # one page therefore share this id.
        rnd = random.Random(22022001)
        self._uuid = uuid.UUID(int=rnd.getrandbits(128))

    @staticmethod
    def _append_code(existing: str, snippets: list[str]) -> str:
        """Append code snippets to `existing`, newline-separated.

        The separator prevents consecutive snippets from fusing, e.g. two JS
        statements on one line or a trailing `//` comment swallowing the next
        snippet. Empty parts are skipped, so a single snippet added to an empty
        string is returned unchanged.
        """
        return "\n".join(part for part in (existing, *snippets) if part)

    def add_tooltip(
        self,
        *,
        labels: list | tuple | np.ndarray | SeriesT | None = None,
        groups: list | tuple | np.ndarray | SeriesT | None = None,
        tooltip_x_shift: int = 0,
        tooltip_y_shift: int = 0,
        hover_nearest: bool = False,
        ax: Axes | None = None,
    ) -> "PlotJS":
        """
        Add a tooltip to the interactive plot. You can set either
        just `labels`, just `groups`, both or none.

        Args:
            labels: An iterable containing the labels for the tooltip.
                It corresponds to the text that will appear on hover.
            groups: An iterable containing the group for tooltip. It
                corresponds to how to 'group' the tooltip. The easiest
                way to understand this argument is to check the examples
                below. Also note that the use of this argument is required
                to 'connect' the legend with plot elements.
            tooltip_x_shift: Number of pixels to shift the tooltip from
                the cursor, on the x axis.
            tooltip_y_shift: Number of pixels to shift the tooltip from
                the cursor, on the y axis.
            hover_nearest: When `True`, hover the nearest plot element.
            ax: A matplotlib Axes. If `None` (default), uses first Axes.

        Returns:
            self: Returns the instance to allow method chaining.

        Raises:
            ValueError: If `ax` is not one of the figure's Axes.

        Examples:
            ```python
            PlotJS(...).add_tooltip(
                labels=["S&P500", "CAC40", "Sunflower"],
            )
            ```

            ```python
            PlotJS(...).add_tooltip(
                labels=["S&P500", "CAC40", "Sunflower"],
                groups=["S&P500", "CAC40", "Sunflower"],
            )
            ```

            ```python
            PlotJS(...).add_tooltip(
                labels=["S&P500", "CAC40", "Sunflower"],
                hover_nearest=True,
            )
            ```
        """
        if ax is None:
            ax = self._axes[0]
        # Validate before touching any state so a bad `ax` leaves the
        # instance unchanged.
        try:
            axe_idx = self._axes.index(ax) + 1
        except ValueError:
            raise ValueError(
                "`ax` must be one of the Axes of the figure given to PlotJS."
            ) from None

        self._tooltip_x_shift = tooltip_x_shift
        self._tooltip_y_shift = tooltip_y_shift

        self._legend_handles, self._legend_handles_labels = (
            ax.get_legend_handles_labels()
        )

        # Legend labels are appended after user-supplied labels/groups.
        # `list(...)` copies so extending never mutates a caller-owned list
        # (in case `_vector_to_list` returns its input as-is).
        if labels is None:
            self._tooltip_labels = []
        else:
            self._tooltip_labels = list(_vector_to_list(labels))
            self._tooltip_labels.extend(self._legend_handles_labels)
        if groups is None:
            # One group per label: every element is highlighted on its own.
            self._tooltip_groups = list(range(len(self._tooltip_labels)))
        else:
            self._tooltip_groups = list(_vector_to_list(groups))
            self._tooltip_groups.extend(self._legend_handles_labels)

        self._axes_tooltip[f"axes_{axe_idx}"] = {
            "tooltip_labels": self._tooltip_labels,
            "tooltip_groups": self._tooltip_groups,
            "hover_nearest": "true" if hover_nearest else "false",  # js boolean
        }

        return self

    def add_css(
        self,
        from_string: Optional[str] = None,
        *,
        from_dict: Optional[dict] = None,
        from_file: Optional[str] = None,
    ) -> "PlotJS":
        """
        Add CSS to the final HTML output. This function allows you to override
        default styles or add custom CSS rules.

        See the [CSS guide](../guides/css/index.md) for more info on how to work with CSS.

        Every source that is provided is added, in the order
        `from_string`, `from_dict`, `from_file`.

        Args:
            from_string: CSS rules to apply, as a string.
            from_dict: CSS rules to apply, as a dictionary.
            from_file: CSS rules to apply from a CSS file.

        Returns:
            self: Returns the instance to allow method chaining.

        Raises:
            ValueError: If no (non-empty) source is provided.

        Examples:
            ```python
            PlotJS(...).add_css(".tooltip {color: red;}")
            ```

            ```python
            PlotJS(...).add_css(from_file="path/to/styles.css")
            ```

            ```python
            PlotJS(...).add_css(from_dict={".tooltip": {"color": "red"}})
            ```

            ```python
            PlotJS(...).add_css(
                from_dict={".tooltip": {"color": "red"}},
            ).add_css(
                from_dict={".tooltip": {"background": "blue"}},
            )
            ```
        """
        snippets: list[str] = []
        if from_string:
            snippets.append(from_string)
        if from_dict:
            snippets.append(css.from_dict(from_dict))
        if from_file:
            snippets.append(css.from_file(from_file))
        if not snippets:
            raise ValueError(
                "Must provide at least one of: `from_string`, `from_dict`, `from_file`."
            )
        self.additional_css = self._append_code(self.additional_css, snippets)
        return self

    def add_javascript(
        self, from_string: Optional[str] = None, *, from_file: Optional[str] = None
    ) -> "PlotJS":
        """
        Add custom JavaScript to the final HTML output. This function allows
        users to enhance interactivity, define custom behaviors, or extend
        the existing chart logic.

        Every source that is provided is added, `from_string` first. Snippets
        are newline-separated, so consecutive additions cannot fuse into
        invalid code.

        Args:
            from_string: JavaScript code to include, as a string.
            from_file: JavaScript code to apply from a JS file.

        Returns:
            self: Returns the instance to allow method chaining.

        Raises:
            ValueError: If no (non-empty) source is provided.

        Examples:
            ```python
            PlotJS(...).add_javascript("console.log('Custom JS loaded!');")
            ```

            ```python
            from plotjs import javascript

            custom_js = javascript.from_file("script.js")
            PlotJS(...).add_javascript(custom_js)
            ```
        """
        snippets: list[str] = []
        if from_string:
            snippets.append(from_string)
        if from_file:
            snippets.append(javascript.from_file(from_file))
        if not snippets:
            raise ValueError(
                "Must provide at least one of: `from_string`, `from_file`."
            )
        self.additional_javascript = self._append_code(
            self.additional_javascript, snippets
        )
        return self

    def save(
        self,
        file_path: str | os.PathLike[str],
        favicon_path: str = _DEFAULT_FAVICON_PATH,
        document_title: str = _DEFAULT_DOCUMENT_TITLE,
    ) -> "PlotJS":
        """
        Save the interactive matplotlib plots to an HTML file.

        Args:
            file_path: Where to save the HTML file (str or path-like). If the
                ".html" extension is missing, it's added.
            favicon_path: Path to a favicon file, remote or local.
                The default is the logo of plotjs.
            document_title: String used for the page title (the title
                tag inside the head of the html document).

        Returns:
            The instance itself to allow method chaining.

        Examples:
            ```python
            PlotJS(...).save("index.html")
            ```

            ```python
            PlotJS(...).save("path/to/my_chart.html")
            ```
        """
        self._favicon_path = favicon_path
        self._document_title = document_title

        self._set_html()

        file_path = os.fspath(file_path)
        if not file_path.endswith(".html"):
            file_path += ".html"
        # UTF-8 so non-ASCII labels are written correctly on every platform.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self.html)

        return self

    def as_html(self) -> str:
        """
        Retrieve the interactive plot as an HTML string.
        This can be useful to display the plot in
        environments such as marimo, or to do advanced customization.

        The favicon and page title are the defaults, or the values from the
        last `save()` call.

        Returns:
            A string with all the HTML of the plot.

        Examples:
            ```python
            import marimo as mo
            from plotjs import PlotJS, data

            df = data.load_iris()
            # `fig`: a matplotlib Figure plotting `df`, created beforehand

            html_plot = (
                PlotJS(fig=fig)
                .add_tooltip(labels=df["species"])
                .as_html()
            )

            # display in marimo
            mo.iframe(html_plot)
            ```
        """
        self._set_html()
        return self.html

    def _set_plot_data_json(self) -> None:
        """Build `self.plot_data_json`, the tooltip data passed to the template.

        Adds a default tooltip (first Axes, no labels) when `add_tooltip()`
        was never called.
        """
        if not hasattr(self, "_tooltip_labels"):
            self.add_tooltip()

        self.plot_data_json = {
            "tooltip_labels": self._tooltip_labels,
            "tooltip_groups": self._tooltip_groups,
            "tooltip_x_shift": self._tooltip_x_shift,
            "tooltip_y_shift": self._tooltip_y_shift,
            "hover_nearest": self._hover_nearest,
            "axes": self._axes_tooltip,
        }

    def _set_html(self) -> None:
        """Render the Jinja template into `self.html`."""
        self._set_plot_data_json()
        self.html: str = self._template.render(
            uuid=str(self._uuid),
            default_css=self._default_css,
            js_parser=self._js_parser,
            additional_css=self.additional_css,
            additional_javascript=self.additional_javascript,
            svg=self._svg_content,
            plot_data_json=self.plot_data_json,
            favicon_path=self._favicon_path,
            document_title=self._document_title,
        )
@@ -0,0 +1,40 @@
1
/* Theme variables used by the plot-element rules below. */
:root {
  /* Opacity of plot elements at rest and of the hovered group. */
  --default-opacity: 1;
  /* Opacity of elements outside the hovered group (dimmed). */
  --default-not-hovered-opacity: 0.2;
  --default-transition: opacity 0.1s ease;
}

/* Let the embedded SVG scale to the width of its container. */
svg {
  width: 100%;
  height: auto;
}

/* Tooltip box: hidden by default (display: none); the hover handlers in
   plotparser.js set its display, left/top position and inner HTML. */
.tooltip {
  position: absolute;
  background: #001d3d;
  padding: 8px 12px;
  border-radius: 6px;
  color: #ffffff;
  font-size: 14px;
  box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
  /* Do not intercept mouse events meant for the plot under the tooltip. */
  pointer-events: none;
  display: none;
  font-family: "Helvetica Neue", "Arial", sans-serif;
}

/* Base state of every interactive element tagged by the JS parser. */
.plot-element {
  opacity: var(--default-opacity);
  transition: var(--default-transition);
}

/* Same opacity as the base rule. Because `.plot-element.not-hovered` has
   equal specificity and comes later, that rule still wins when both apply. */
.plot-element:hover {
  opacity: var(--default-opacity);
}

/* Elements outside the currently hovered group are dimmed. */
.plot-element.not-hovered {
  opacity: var(--default-not-hovered-opacity);
}

/* Elements in the currently hovered group stay fully visible. */
.plot-element.hovered {
  opacity: var(--default-opacity);
}
@@ -0,0 +1,229 @@
1
+ import * as d3 from "d3-selection";
2
+
3
/**
 * Core utility for parsing and interacting with matplotlib-generated SVG outputs.
 * Provides methods to query common plot elements (bars, points, lines, areas),
 * and to attach interactive hover tooltips.
 *
 * Example usage:
 * ```js
 * const parser = new PlotSVGParser(svg, tooltip, xShift, yShift);
 * const points = parser.findPoints(svg, "axes_1", tooltipGroups);
 * parser.setHoverEffect(points, "axes_1", tooltipLabels, tooltipGroups, "block", true);
 * ```
 */
export default class PlotSVGParser {
  /**
   * Create a new parser bound to an SVG figure.
   *
   * @param {d3.Selection} svg - D3 selection of the target SVG element (e.g. the entire plot).
   * @param {d3.Selection} tooltip - D3 selection of the tooltip container (e.g. a div).
   * @param {number} tooltip_x_shift - Horizontal offset (px) added to the cursor position.
   * @param {number} tooltip_y_shift - Vertical offset (px) added to the cursor position.
   */
  constructor(svg, tooltip, tooltip_x_shift, tooltip_y_shift) {
    this.svg = svg;
    this.tooltip = tooltip;
    this.tooltip_x_shift = tooltip_x_shift;
    this.tooltip_y_shift = tooltip_y_shift;
  }

  /**
   * Find bar elements inside a given axes: `patch` groups whose <path> has a
   * `clip-path="url(...)"` attribute. Unclipped patches are skipped.
   * The matched groups' `class` attribute is replaced by "bar plot-element".
   *
   * @param {d3.Selection} svg - D3 selection of the SVG element.
   * @param {string} axes_class - ID of the axes group (e.g. "axes_1").
   * @returns {d3.Selection} D3 selection of bar elements.
   */
  findBars(svg, axes_class) {
    const bars = svg
      .selectAll(`g#${axes_class} g[id^="patch"]`)
      .filter(function () {
        const clip = d3.select(this).select("path").attr("clip-path");
        return clip && clip.startsWith("url(");
      });

    bars.attr("class", "bar plot-element");

    console.log(`Found ${bars.size()} "bar" element`);
    return bars;
  }

  /**
   * Find scatter plot points inside a given axes.
   * Prefers clipped `<use>` markers and falls back to `<path>` elements when
   * none exist. Each point gets `data-group` set from `tooltip_groups` by
   * position, and class "point plot-element".
   *
   * @param {d3.Selection} svg - D3 selection of the SVG element.
   * @param {string} axes_class - ID of the axes group (e.g. "axes_1").
   * @param {string[]} tooltip_groups - Group identifiers for tooltips, parallel to points.
   * @returns {d3.Selection} D3 selection of point elements.
   */
  findPoints(svg, axes_class, tooltip_groups) {
    let points = svg.selectAll(
      `g#${axes_class} g[id^="PathCollection"] g[clip-path] use`
    );

    if (points.empty()) {
      // fallback: no <use> found → grab <path> instead
      points = svg.selectAll(`g#${axes_class} g[id^="PathCollection"] path`);
    }

    points.each(function (_, i) {
      d3.select(this).attr("data-group", tooltip_groups[i]);
    });
    points.attr("class", "point plot-element");

    console.log(`Found ${points.size()} "point" element`);
    return points;
  }

  /**
   * Find line elements (`line2d` paths) inside a given axes, excluding
   * anything nested in a `matplotlib.axis` group (ticks / grid lines).
   *
   * @param {d3.Selection} svg - D3 selection of the SVG element.
   * @param {string} axes_class - ID of the axes group.
   * @returns {d3.Selection} D3 selection of line elements.
   */
  findLines(svg, axes_class) {
    const lines = svg
      .selectAll(`g#${axes_class} g[id^="line2d"] path`)
      .filter(function () {
        return !this.closest('g[id^="matplotlib.axis"]');
      });

    lines.attr("class", "line plot-element");

    console.log(`Found ${lines.size()} "line" element`);
    return lines;
  }

  /**
   * Find filled area elements (`FillBetweenPolyCollection` paths) inside a given axes.
   *
   * @param {d3.Selection} svg - D3 selection of the SVG element.
   * @param {string} axes_class - ID of the axes group.
   * @returns {d3.Selection} D3 selection of area elements.
   */
  findAreas(svg, axes_class) {
    const areas = svg.selectAll(
      `g#${axes_class} g[id^="FillBetweenPolyCollection"] path`
    );
    areas.attr("class", "area plot-element");

    console.log(`Found ${areas.size()} "area" element`);
    return areas;
  }

  /**
   * Compute the element whose bounding-box center is closest to the cursor.
   * Used when the `hover_nearest` option is true.
   *
   * NOTE(review): `getBBox()` is in each element's local user space, while
   * the caller passes coordinates local to the axes group; this assumes the
   * elements add no transforms of their own relative to that group.
   *
   * @param {number} mouseX - X coordinate of the mouse (axes-group coordinates).
   * @param {number} mouseY - Y coordinate of the mouse (axes-group coordinates).
   * @param {d3.Selection} elements - Selection of candidate elements.
   * @returns {Element|null} The nearest DOM element, or `null` if `elements` is empty.
   */
  nearestElementFromMouse(mouseX, mouseY, elements) {
    let nearestElem = null;
    let minDist = Infinity;

    elements.each(function () {
      const bbox = this.getBBox();
      const cx = bbox.x + bbox.width / 2;
      const cy = bbox.y + bbox.height / 2;
      const dist = Math.hypot(mouseX - cx, mouseY - cy);
      if (dist < minDist) {
        minDist = dist;
        nearestElem = this;
      }
    });

    return nearestElem;
  }

  /**
   * Attach hover interaction and tooltip display to plot elements.
   * On hover, every `.plot-element` in the axes sharing the hovered element's
   * group gets class "hovered", the others "not-hovered", and the tooltip is
   * shown with the hovered label (inserted as HTML).
   *
   * @param {d3.Selection} plot_element - Selection of plot elements (points, lines, etc.).
   * @param {string} axes_class - ID of the axes group.
   * @param {string[]} tooltip_labels - Tooltip labels for each element.
   * @param {string[]} tooltip_groups - Group identifiers for each element.
   * @param {"block"|"none"} show_tooltip - CSS display value used for the tooltip.
   * @param {boolean} hover_nearest - If true, highlight nearest element instead of hovered one.
   */
  setHoverEffect(
    plot_element,
    axes_class,
    tooltip_labels,
    tooltip_groups,
    show_tooltip,
    hover_nearest
  ) {
    const self = this;
    const axesGroup = this.svg.select(`g#${axes_class}`);

    // Clear all highlight classes in this axes and hide the tooltip.
    // This resets every `.plot-element` (not just `plot_element`) because
    // the hover handler below marks all of them. Otherwise other element
    // types in the same axes would stay dimmed after the cursor leaves.
    const resetHover = () => {
      axesGroup
        .selectAll(".plot-element")
        .classed("hovered", false)
        .classed("not-hovered", false);
      self.tooltip.style("display", "none");
    };

    // NOTE(review): in non-nearest mode the index is relative to
    // `plot_element`, while the group filter below indexes `tooltip_groups`
    // by position among all `.plot-element`s of the axes; the two agree only
    // when `plot_element` covers every plot element in the axes.
    const getHoveredIndex = hover_nearest
      ? (event) => {
          const [mouseX, mouseY] = d3.pointer(event);
          const allElements = axesGroup.selectAll(".plot-element");
          const nearestElem = self.nearestElementFromMouse(
            mouseX,
            mouseY,
            allElements
          );
          return nearestElem ? allElements.nodes().indexOf(nearestElem) : null;
        }
      : (event) => {
          const index = plot_element.nodes().indexOf(event.currentTarget);
          return index === -1 ? null : index;
        };

    const mousemoveHandler = (event) => {
      const hoveredIndex = getHoveredIndex(event);
      const allElements = axesGroup.selectAll(".plot-element");

      allElements.classed("hovered", false).classed("not-hovered", false);

      if (hoveredIndex === null) {
        self.tooltip.style("display", "none");
        return;
      }

      const hoveredGroup = tooltip_groups[hoveredIndex];

      allElements
        .filter((_, j) => tooltip_groups[j] === hoveredGroup)
        .classed("hovered", true);

      allElements
        .filter((_, j) => tooltip_groups[j] !== hoveredGroup)
        .classed("not-hovered", true);

      self.tooltip
        .style("display", show_tooltip)
        .style("left", event.pageX + self.tooltip_x_shift + "px")
        .style("top", event.pageY + self.tooltip_y_shift + "px")
        .html(tooltip_labels[hoveredIndex]);
    };

    if (hover_nearest) {
      // Track the cursor anywhere over the axes and pick the closest element.
      axesGroup.on("mousemove", mousemoveHandler).on("mouseout", resetHover);
    } else {
      // React only when the cursor enters one of the given elements.
      plot_element.on("mouseover", mousemoveHandler).on("mouseout", resetHover);
    }
  }
}