plot-saver 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,47 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ release:
5
+ types: [published]
6
+
7
+ jobs:
8
+ build:
9
+ runs-on: ubuntu-latest
10
+
11
+ steps:
12
+ - name: Check out repository
13
+ uses: actions/checkout@v4
14
+
15
+ - name: Set up Python
16
+ uses: actions/setup-python@v5
17
+ with:
18
+ python-version: "3.11"
19
+
20
+ - name: Install build tooling
21
+ run: python -m pip install --upgrade build
22
+
23
+ - name: Build distributions
24
+ run: python -m build
25
+
26
+ - name: Upload distributions
27
+ uses: actions/upload-artifact@v4
28
+ with:
29
+ name: dist
30
+ path: dist/*
31
+
32
+ publish:
33
+ needs: build
34
+ runs-on: ubuntu-latest
35
+ environment: pypi
36
+ permissions:
37
+ id-token: write
38
+
39
+ steps:
40
+ - name: Download distributions
41
+ uses: actions/download-artifact@v4
42
+ with:
43
+ name: dist
44
+ path: dist
45
+
46
+ - name: Publish to PyPI
47
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,115 @@
1
+ Metadata-Version: 2.4
2
+ Name: plot-saver
3
+ Version: 0.1.0
4
+ Summary: Reusable anywidget plot saver configurable with a config.toml.
5
+ Requires-Python: >=3.11
6
+ Requires-Dist: anywidget>=0.9.21
7
+ Requires-Dist: marimo>=0.21.1
8
+ Requires-Dist: traitlets>=5.14.3
9
+ Description-Content-Type: text/markdown
10
+
11
+ # plot-saver
12
+
13
+ `plot-saver` is a small reusable package for adding save buttons to Matplotlib figures in marimo notebooks.
14
+
15
+ It saves plots under a project results directory and can be configured with a project-level `config.toml`.
16
+
17
+ ## Installation
18
+
19
+ ```bash
20
+ pip install plot-saver
21
+ ```
22
+
23
+ or with `uv`:
24
+
25
+ ```bash
26
+ uv add plot-saver
27
+ ```
28
+
29
+ ## Config
30
+
31
+ `plot-saver` looks for a `config.toml` by searching upward from the current working directory.
32
+
33
+ All configuration keys are optional. If you do not provide a `config.toml`, or if you only provide some keys, `plot-saver` falls back to its built-in defaults.
34
+
35
+ The default output format is `png`.
36
+
37
+ You can override the defaults with a section like this:
38
+
39
+ ```toml
40
+ [plot-saver]
41
+ format = "png"
42
+ default_label = "Save"
43
+ save_all_label = "Save all model plots"
44
+ empty_title = "No plots available"
45
+ empty_detail = "Render the notebook plots first."
46
+ saved_title = "Saved"
47
+ saved_many_title = "Saved {count} plots"
48
+ saved_error_title = "Saved with errors"
49
+ failed_title = "Could not save plots"
50
+ toast_detail_color = "#6b7280"
51
+
52
+ [plot-saver.theme]
53
+ radius = "8px"
54
+ padding_y = "0.45rem"
55
+ padding_x = "0.9rem"
56
+ font_size = "0.9rem"
57
+ light_border = "rgba(107, 114, 128, 0.35)"
58
+ light_background = "rgba(249, 250, 251, 0.95)"
59
+ light_text = "#111827"
60
+ light_hover_background = "rgba(243, 244, 246, 1)"
61
+ light_hover_border = "rgba(107, 114, 128, 0.55)"
62
+ light_disabled_background = "rgba(229, 231, 235, 0.9)"
63
+ light_disabled_border = "rgba(156, 163, 175, 0.35)"
64
+ dark_border = "rgba(156, 163, 175, 0.35)"
65
+ dark_background = "rgba(31, 41, 55, 0.95)"
66
+ dark_text = "#f3f4f6"
67
+ dark_hover_background = "rgba(55, 65, 81, 1)"
68
+ dark_hover_border = "rgba(209, 213, 219, 0.45)"
69
+ dark_disabled_background = "rgba(55, 65, 81, 0.8)"
70
+ dark_disabled_border = "rgba(107, 114, 128, 0.35)"
71
+ ```
72
+
73
+ ## Usage
74
+
75
+ ```python
76
+ from pathlib import Path
77
+
78
+ import marimo as mo
79
+ import matplotlib.pyplot as plt
80
+ from plot_saver import make_plot_saver
81
+
82
+ fig, ax = plt.subplots()
83
+ ax.plot([0, 1, 2], [1, 3, 2])
84
+
85
+ save_plot = make_plot_saver(
86
+ mo,
87
+ results_dir=Path("results"),
88
+ config_path=None,
89
+ task_name="2AFC",
90
+ model_id="example-model",
91
+ )
92
+
93
+ button = save_plot(fig, "Example figure")
94
+ save_all = save_plot.save_all_widget()
95
+ ```
96
+
97
+ Saved plots go to:
98
+
99
+ ```text
100
+ results/plots/<task_name>/<model_id>/
101
+ ```
102
+
103
+ `save_plot(fig, "Example figure")` registers that figure and returns an individual save button for it.
104
+
105
+ `save_plot.save_all_widget()` returns a single button that saves every figure previously registered with that `PlotSaver` instance. This is useful in notebooks where you render several figures for the same task and model and want one action to export all of them together.
106
+
107
+ ## API
108
+
109
+ Main entry points:
110
+
111
+ - `make_plot_saver(...)`
112
+ - `save_button(...)`
113
+ - `save_figure(...)`
114
+ - `get_plot_save_format(...)`
115
+ - `find_project_config_path(...)`
@@ -0,0 +1,105 @@
1
+ # plot-saver
2
+
3
+ `plot-saver` is a small reusable package for adding save buttons to Matplotlib figures in marimo notebooks.
4
+
5
+ It saves plots under a project results directory and can be configured with a project-level `config.toml`.
6
+
7
+ ## Installation
8
+
9
+ ```bash
10
+ pip install plot-saver
11
+ ```
12
+
13
+ or with `uv`:
14
+
15
+ ```bash
16
+ uv add plot-saver
17
+ ```
18
+
19
+ ## Config
20
+
21
+ `plot-saver` looks for a `config.toml` by searching upward from the current working directory.
22
+
23
+ All configuration keys are optional. If you do not provide a `config.toml`, or if you only provide some keys, `plot-saver` falls back to its built-in defaults.
24
+
25
+ The default output format is `png`.
26
+
27
+ You can override the defaults with a section like this:
28
+
29
+ ```toml
30
+ [plot-saver]
31
+ format = "png"
32
+ default_label = "Save"
33
+ save_all_label = "Save all model plots"
34
+ empty_title = "No plots available"
35
+ empty_detail = "Render the notebook plots first."
36
+ saved_title = "Saved"
37
+ saved_many_title = "Saved {count} plots"
38
+ saved_error_title = "Saved with errors"
39
+ failed_title = "Could not save plots"
40
+ toast_detail_color = "#6b7280"
41
+
42
+ [plot-saver.theme]
43
+ radius = "8px"
44
+ padding_y = "0.45rem"
45
+ padding_x = "0.9rem"
46
+ font_size = "0.9rem"
47
+ light_border = "rgba(107, 114, 128, 0.35)"
48
+ light_background = "rgba(249, 250, 251, 0.95)"
49
+ light_text = "#111827"
50
+ light_hover_background = "rgba(243, 244, 246, 1)"
51
+ light_hover_border = "rgba(107, 114, 128, 0.55)"
52
+ light_disabled_background = "rgba(229, 231, 235, 0.9)"
53
+ light_disabled_border = "rgba(156, 163, 175, 0.35)"
54
+ dark_border = "rgba(156, 163, 175, 0.35)"
55
+ dark_background = "rgba(31, 41, 55, 0.95)"
56
+ dark_text = "#f3f4f6"
57
+ dark_hover_background = "rgba(55, 65, 81, 1)"
58
+ dark_hover_border = "rgba(209, 213, 219, 0.45)"
59
+ dark_disabled_background = "rgba(55, 65, 81, 0.8)"
60
+ dark_disabled_border = "rgba(107, 114, 128, 0.35)"
61
+ ```
62
+
63
+ ## Usage
64
+
65
+ ```python
66
+ from pathlib import Path
67
+
68
+ import marimo as mo
69
+ import matplotlib.pyplot as plt
70
+ from plot_saver import make_plot_saver
71
+
72
+ fig, ax = plt.subplots()
73
+ ax.plot([0, 1, 2], [1, 3, 2])
74
+
75
+ save_plot = make_plot_saver(
76
+ mo,
77
+ results_dir=Path("results"),
78
+ config_path=None,
79
+ task_name="2AFC",
80
+ model_id="example-model",
81
+ )
82
+
83
+ button = save_plot(fig, "Example figure")
84
+ save_all = save_plot.save_all_widget()
85
+ ```
86
+
87
+ Saved plots go to:
88
+
89
+ ```text
90
+ results/plots/<task_name>/<model_id>/
91
+ ```
92
+
93
+ `save_plot(fig, "Example figure")` registers that figure and returns an individual save button for it.
94
+
95
+ `save_plot.save_all_widget()` returns a single button that saves every figure previously registered with that `PlotSaver` instance. This is useful in notebooks where you render several figures for the same task and model and want one action to export all of them together.
96
+
97
+ ## API
98
+
99
+ Main entry points:
100
+
101
+ - `make_plot_saver(...)`
102
+ - `save_button(...)`
103
+ - `save_figure(...)`
104
+ - `get_plot_save_format(...)`
105
+ - `find_project_config_path(...)`
@@ -0,0 +1,18 @@
1
+ [project]
2
+ name = "plot-saver"
3
+ version = "0.1.0"
4
+ description = "Reusable anywidget plot saver configurable with a config.toml."
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "anywidget>=0.9.21",
9
+ "marimo>=0.21.1",
10
+ "traitlets>=5.14.3",
11
+ ]
12
+
13
+ [build-system]
14
+ requires = ["hatchling"]
15
+ build-backend = "hatchling.build"
16
+
17
+ [tool.hatch.build.targets.wheel]
18
+ packages = ["src/plot_saver"]
@@ -0,0 +1,25 @@
1
+ from .save_widget import (
2
+ PlotSaver,
3
+ SaveFigureAnyWidget,
4
+ build_plot_path,
5
+ find_project_config_path,
6
+ get_plot_save_format,
7
+ load_app_config,
8
+ make_plot_saver,
9
+ sanitize_stem,
10
+ save_button,
11
+ save_figure,
12
+ )
13
+
14
+ __all__ = [
15
+ "PlotSaver",
16
+ "SaveFigureAnyWidget",
17
+ "build_plot_path",
18
+ "find_project_config_path",
19
+ "get_plot_save_format",
20
+ "load_app_config",
21
+ "make_plot_saver",
22
+ "sanitize_stem",
23
+ "save_button",
24
+ "save_figure",
25
+ ]
@@ -0,0 +1,65 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ from functools import lru_cache
5
+ from weakref import WeakKeyDictionary
6
+
7
+ from marimo._plugins.ui._core.ui_element import UIElement
8
+ from marimo._plugins.ui._impl.from_anywidget import (
9
+ anywidget as MarimoAnyWidget,
10
+ get_anywidget_model_id,
11
+ )
12
+ from marimo._utils.code import hash_code
13
+ from marimo._utils.data_uri import build_data_url
14
+
15
+
16
+ @lru_cache(maxsize=32)
17
+ def _build_js_data_url(js: str) -> str:
18
+ encoded = base64.b64encode(js.encode("utf-8"))
19
+ return build_data_url("text/javascript", encoded)
20
+
21
+
22
+ def _resolve_js_url(js: str) -> str:
23
+ if not js:
24
+ return ""
25
+ if js.startswith(("data:", "http://", "https://")):
26
+ return js
27
+ return _build_js_data_url(js)
28
+
29
+
30
+ class _StableAnyWidget(MarimoAnyWidget):
31
+ def __init__(self, widget):
32
+ self.widget = widget
33
+ self._initialized = False
34
+
35
+ js = str(getattr(widget, "_esm", "") or "")
36
+ js_hash = hash_code(js)
37
+
38
+ _ = widget.comm
39
+ model_id = get_anywidget_model_id(widget)
40
+
41
+ UIElement.__init__(
42
+ self,
43
+ component_name="marimo-anywidget",
44
+ initial_value={"model_id": model_id},
45
+ label=None,
46
+ args={
47
+ "js-url": _resolve_js_url(js),
48
+ "js-hash": js_hash,
49
+ "model-id": model_id,
50
+ },
51
+ on_change=None,
52
+ )
53
+
54
+
55
+ _WIDGET_CACHE: WeakKeyDictionary[object, _StableAnyWidget] = WeakKeyDictionary()
56
+
57
+
58
+ def wrap_anywidget(widget):
59
+ cached = _WIDGET_CACHE.get(widget)
60
+ if cached is not None:
61
+ return cached
62
+
63
+ wrapped = _StableAnyWidget(widget)
64
+ _WIDGET_CACHE[widget] = wrapped
65
+ return wrapped
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ from copy import deepcopy
4
+ from importlib.resources import files
5
+ from pathlib import Path
6
+ import tomllib
7
+
8
+
9
+ PROJECT_CONFIG_NAME = "config.toml"
10
+ DEFAULT_CONFIG_RESOURCE = "resources/default_config.toml"
11
+
12
+
13
+ def _read_toml(path: Path) -> dict:
14
+ with path.open("rb") as f:
15
+ return tomllib.load(f)
16
+
17
+
18
+ def _merge_dicts(base: dict, override: dict) -> dict:
19
+ merged = deepcopy(base)
20
+ for key, value in override.items():
21
+ if isinstance(value, dict) and isinstance(merged.get(key), dict):
22
+ merged[key] = _merge_dicts(merged[key], value)
23
+ else:
24
+ merged[key] = deepcopy(value)
25
+ return merged
26
+
27
+
28
+ def find_project_config_path(start: str | Path | None = None) -> Path | None:
29
+ current = Path(start).expanduser().resolve() if start is not None else Path.cwd().resolve()
30
+ if current.is_file():
31
+ current = current.parent
32
+ for directory in (current, *current.parents):
33
+ candidate = directory / PROJECT_CONFIG_NAME
34
+ if candidate.exists():
35
+ return candidate
36
+ return None
37
+
38
+
39
+ def _load_default_config() -> dict:
40
+ default_config = files("plot_saver").joinpath(DEFAULT_CONFIG_RESOURCE)
41
+ with default_config.open("rb") as f:
42
+ return tomllib.load(f)
43
+
44
+
45
+ def load_app_config(config_path: str | Path | None = None) -> dict:
46
+ cfg = deepcopy(_load_default_config())
47
+ resolved = Path(config_path).expanduser().resolve() if config_path else find_project_config_path()
48
+ if resolved is not None and resolved.exists():
49
+ cfg = _merge_dicts(cfg, _read_toml(resolved))
50
+ return cfg
@@ -0,0 +1,66 @@
1
+ .save-figure-wrap {
2
+ display: inline-flex;
3
+ flex-direction: column;
4
+ align-items: flex-start;
5
+ gap: 0.35rem;
6
+ --save-figure-radius: 8px;
7
+ --save-figure-padding-y: 0.45rem;
8
+ --save-figure-padding-x: 0.9rem;
9
+ --save-figure-font-size: 0.9rem;
10
+ --save-figure-light-border: rgba(107, 114, 128, 0.35);
11
+ --save-figure-light-background: rgba(249, 250, 251, 0.95);
12
+ --save-figure-light-text: #111827;
13
+ --save-figure-light-hover-background: rgba(243, 244, 246, 1);
14
+ --save-figure-light-hover-border: rgba(107, 114, 128, 0.55);
15
+ --save-figure-light-disabled-background: rgba(229, 231, 235, 0.9);
16
+ --save-figure-light-disabled-border: rgba(156, 163, 175, 0.35);
17
+ --save-figure-dark-border: rgba(156, 163, 175, 0.35);
18
+ --save-figure-dark-background: rgba(31, 41, 55, 0.95);
19
+ --save-figure-dark-text: #f3f4f6;
20
+ --save-figure-dark-hover-background: rgba(55, 65, 81, 1);
21
+ --save-figure-dark-hover-border: rgba(209, 213, 219, 0.45);
22
+ --save-figure-dark-disabled-background: rgba(55, 65, 81, 0.8);
23
+ --save-figure-dark-disabled-border: rgba(107, 114, 128, 0.35);
24
+ }
25
+
26
+ .save-figure-btn {
27
+ border: 1px solid var(--save-figure-light-border);
28
+ background: var(--save-figure-light-background);
29
+ color: var(--save-figure-light-text);
30
+ border-radius: var(--save-figure-radius);
31
+ padding: var(--save-figure-padding-y) var(--save-figure-padding-x);
32
+ font-size: var(--save-figure-font-size);
33
+ line-height: 1;
34
+ cursor: pointer;
35
+ transition: background 120ms ease, border-color 120ms ease;
36
+ }
37
+
38
+ .save-figure-btn:hover {
39
+ background: var(--save-figure-light-hover-background);
40
+ border-color: var(--save-figure-light-hover-border);
41
+ }
42
+
43
+ .save-figure-btn:disabled {
44
+ cursor: not-allowed;
45
+ opacity: 0.55;
46
+ background: var(--save-figure-light-disabled-background);
47
+ border-color: var(--save-figure-light-disabled-border);
48
+ }
49
+
50
+ @media (prefers-color-scheme: dark) {
51
+ .save-figure-btn {
52
+ background: var(--save-figure-dark-background);
53
+ color: var(--save-figure-dark-text);
54
+ border-color: var(--save-figure-dark-border);
55
+ }
56
+
57
+ .save-figure-btn:hover {
58
+ background: var(--save-figure-dark-hover-background);
59
+ border-color: var(--save-figure-dark-hover-border);
60
+ }
61
+
62
+ .save-figure-btn:disabled {
63
+ background: var(--save-figure-dark-disabled-background);
64
+ border-color: var(--save-figure-dark-disabled-border);
65
+ }
66
+ }
@@ -0,0 +1,40 @@
1
+ function render({ model, el }) {
2
+ el.innerHTML = "";
3
+
4
+ const wrap = document.createElement("div");
5
+ wrap.className = "save-figure-wrap";
6
+
7
+ const button = document.createElement("button");
8
+ button.className = "save-figure-btn";
9
+
10
+ const applyTheme = () => {
11
+ const theme = model.get("theme_tokens") || {};
12
+ for (const [key, value] of Object.entries(theme)) {
13
+ wrap.style.setProperty(`--save-figure-${key}`, value);
14
+ }
15
+ };
16
+
17
+ const updateButton = () => {
18
+ button.textContent = model.get("label");
19
+ button.disabled = !!model.get("disabled");
20
+ };
21
+
22
+ button.addEventListener("click", () => {
23
+ if (button.disabled) return;
24
+ const clicks = model.get("clicks") || 0;
25
+ model.set("clicks", clicks + 1);
26
+ model.save_changes();
27
+ });
28
+
29
+ model.on("change:label", updateButton);
30
+ model.on("change:disabled", updateButton);
31
+ model.on("change:theme_tokens", applyTheme);
32
+
33
+ applyTheme();
34
+ updateButton();
35
+
36
+ wrap.appendChild(button);
37
+ el.appendChild(wrap);
38
+ }
39
+
40
+ export default { render };
@@ -0,0 +1,31 @@
1
+ [plot-saver]
2
+ format = "png"
3
+ default_label = "Save"
4
+ save_all_label = "Save all model plots"
5
+ empty_title = "No plots available"
6
+ empty_detail = "Render the notebook plots first."
7
+ saved_title = "Saved"
8
+ saved_many_title = "Saved {count} plots"
9
+ saved_error_title = "Saved with errors"
10
+ failed_title = "Could not save plots"
11
+ toast_detail_color = "#6b7280"
12
+
13
+ [plot-saver.theme]
14
+ radius = "8px"
15
+ padding_y = "0.45rem"
16
+ padding_x = "0.9rem"
17
+ font_size = "0.9rem"
18
+ light_border = "rgba(107, 114, 128, 0.35)"
19
+ light_background = "rgba(249, 250, 251, 0.95)"
20
+ light_text = "#111827"
21
+ light_hover_background = "rgba(243, 244, 246, 1)"
22
+ light_hover_border = "rgba(107, 114, 128, 0.55)"
23
+ light_disabled_background = "rgba(229, 231, 235, 0.9)"
24
+ light_disabled_border = "rgba(156, 163, 175, 0.35)"
25
+ dark_border = "rgba(156, 163, 175, 0.35)"
26
+ dark_background = "rgba(31, 41, 55, 0.95)"
27
+ dark_text = "#f3f4f6"
28
+ dark_hover_background = "rgba(55, 65, 81, 1)"
29
+ dark_hover_border = "rgba(209, 213, 219, 0.45)"
30
+ dark_disabled_background = "rgba(55, 65, 81, 0.8)"
31
+ dark_disabled_border = "rgba(107, 114, 128, 0.35)"
@@ -0,0 +1,331 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ import re
5
+ from typing import Any
6
+
7
+ import anywidget
8
+ import traitlets
9
+
10
+ from .anywidget_compat import wrap_anywidget
11
+ from .config import find_project_config_path, load_app_config
12
+
13
+
14
+ _ASSET_DIR = Path(__file__).parent
15
+
16
+
17
+ def _read_asset(name: str) -> str:
18
+ return (_ASSET_DIR / name).read_text(encoding="utf-8")
19
+
20
+
21
+ def _get_save_figure_config(config_path: Path | None) -> dict[str, str]:
22
+ cfg = load_app_config(config_path)
23
+ section = cfg.get("plot-saver", {})
24
+ if not isinstance(section, dict):
25
+ return {}
26
+ save_cfg = {str(key): str(value) for key, value in section.items() if not isinstance(value, dict)}
27
+ theme_section = section.get("theme", {})
28
+ if isinstance(theme_section, dict):
29
+ save_cfg.update({str(key): str(value) for key, value in theme_section.items()})
30
+ return save_cfg
31
+
32
+
33
+ def _save_figure_theme_tokens(save_cfg: dict[str, str]) -> dict[str, str]:
34
+ token_keys = (
35
+ "radius",
36
+ "padding_y",
37
+ "padding_x",
38
+ "font_size",
39
+ "light_border",
40
+ "light_background",
41
+ "light_text",
42
+ "light_hover_background",
43
+ "light_hover_border",
44
+ "light_disabled_background",
45
+ "light_disabled_border",
46
+ "dark_border",
47
+ "dark_background",
48
+ "dark_text",
49
+ "dark_hover_background",
50
+ "dark_hover_border",
51
+ "dark_disabled_background",
52
+ "dark_disabled_border",
53
+ )
54
+ return {
55
+ key.replace("_", "-"): value
56
+ for key in token_keys
57
+ if (value := save_cfg.get(key))
58
+ }
59
+
60
+
61
+ def _toast_detail_html(detail: str, save_cfg: dict[str, str]) -> str:
62
+ color = save_cfg.get("toast_detail_color", "#6b7280")
63
+ return f"<span style='color:{color}'>{detail}</span>"
64
+
65
+
66
+ class SaveFigureAnyWidget(anywidget.AnyWidget):
67
+ _esm = _read_asset("figure_save_widget.js")
68
+ _css = _read_asset("figure_save_widget.css")
69
+
70
+ clicks = traitlets.Int(0).tag(sync=True)
71
+ label = traitlets.Unicode("Save").tag(sync=True)
72
+ disabled = traitlets.Bool(False).tag(sync=True)
73
+ theme_tokens = traitlets.Dict(default_value={}).tag(sync=True)
74
+
75
+
76
+ def get_plot_save_format(config_path: Path | None = None) -> str:
77
+ cfg = load_app_config(config_path)
78
+ fmt = str(cfg.get("plot-saver", {}).get("format", "pdf")).lower().strip(". ")
79
+ if fmt not in {"pdf", "svg", "png"}:
80
+ fmt = "pdf"
81
+ return fmt
82
+
83
+
84
+ def sanitize_stem(stem: str) -> str:
85
+ stem = re.sub(r"[^A-Za-z0-9._-]+", "_", stem).strip("._-")
86
+ return stem or "figure"
87
+
88
+
89
+ def build_plot_path(results_dir: Path, task_name: str, model_id: str, stem: str, fmt: str) -> Path:
90
+ out_dir = results_dir / "plots" / task_name / model_id
91
+ out_dir.mkdir(parents=True, exist_ok=True)
92
+ return out_dir / f"{sanitize_stem(stem)}.{fmt}"
93
+
94
+
95
+ def _axis_for_location(fig: Any, location: tuple[int, int]):
96
+ row, col = (int(location[0]), int(location[1]))
97
+ for ax in getattr(fig, "axes", []):
98
+ get_subplotspec = getattr(ax, "get_subplotspec", None)
99
+ if get_subplotspec is None:
100
+ continue
101
+ spec = get_subplotspec()
102
+ if spec is None:
103
+ continue
104
+ if row in range(spec.rowspan.start, spec.rowspan.stop) and col in range(spec.colspan.start, spec.colspan.stop):
105
+ return ax
106
+ raise ValueError(f"No subplot found at location ({row}, {col}).")
107
+
108
+
109
+ def _save_axis(fig: Any, ax: Any, out_path: Path, fmt: str) -> Path:
110
+ if hasattr(fig, "canvas") and fig.canvas is not None:
111
+ fig.canvas.draw()
112
+ renderer = fig.canvas.get_renderer()
113
+ bbox = ax.get_tightbbox(renderer).transformed(fig.dpi_scale_trans.inverted())
114
+ save_kwargs = {"bbox_inches": bbox}
115
+ if fmt != "svg":
116
+ save_kwargs["dpi"] = 300
117
+ fig.savefig(out_path, **save_kwargs)
118
+ return out_path
119
+
120
+
121
+ def save_figure(
122
+ fig,
123
+ *,
124
+ results_dir: Path,
125
+ config_path: Path | None,
126
+ task_name: str,
127
+ model_id: str,
128
+ stem: str,
129
+ location: tuple[int, int] | None = None,
130
+ ) -> Path:
131
+ fmt = get_plot_save_format(config_path)
132
+ out_path = build_plot_path(results_dir, task_name, model_id, stem, fmt)
133
+ if location is not None:
134
+ ax = _axis_for_location(fig, location)
135
+ return _save_axis(fig, ax, out_path, fmt)
136
+ if hasattr(fig, "canvas") and fig.canvas is not None:
137
+ fig.canvas.draw()
138
+ save_kwargs = {"bbox_inches": "tight"}
139
+ if fmt != "svg":
140
+ save_kwargs["dpi"] = 300
141
+ fig.savefig(out_path, **save_kwargs)
142
+ return out_path
143
+
144
+
145
+ class PlotSaver:
146
+ def __init__(self, mo, *, results_dir: Path, config_path: Path | None, task_name: str, model_id: str):
147
+ self.mo = mo
148
+ self.results_dir = results_dir
149
+ self.config_path = config_path or find_project_config_path()
150
+ self.task_name = task_name
151
+ self.model_id = model_id
152
+ self.fmt = get_plot_save_format(self.config_path)
153
+ self.save_cfg = _get_save_figure_config(self.config_path)
154
+ self.theme_tokens = _save_figure_theme_tokens(self.save_cfg)
155
+ self._registry: dict[str, dict[str, object]] = {}
156
+ self._save_all = SaveFigureAnyWidget(
157
+ label=self.save_cfg.get("save_all_label", "Save all model plots"),
158
+ disabled=True,
159
+ theme_tokens=self.theme_tokens,
160
+ )
161
+ self._save_all.observe(self._handle_save_all_click, names="clicks")
162
+ self._save_all._save_observer = self._handle_save_all_click
163
+ self._save_all_ui = None
164
+
165
+ def _save_one(self, fig, *, stem: str, location: tuple[int, int] | None = None) -> Path:
166
+ return save_figure(
167
+ fig,
168
+ results_dir=self.results_dir,
169
+ config_path=self.config_path,
170
+ task_name=self.task_name,
171
+ model_id=self.model_id,
172
+ stem=stem,
173
+ location=location,
174
+ )
175
+
176
+ def _register(self, fig, *, name: str, stem: str, location: tuple[int, int] | None = None) -> None:
177
+ self._registry[stem] = {
178
+ "fig": fig,
179
+ "name": name,
180
+ "stem": stem,
181
+ "location": location,
182
+ }
183
+ self._save_all.disabled = not bool(self._registry)
184
+
185
+ def _saved_message(self, saved_paths: list[Path]) -> str:
186
+ if not saved_paths:
187
+ return "No files saved."
188
+ if len(saved_paths) == 1:
189
+ return saved_paths[0].name
190
+ return f"{saved_paths[0].name} + {len(saved_paths) - 1} more"
191
+
192
+ def save_all(self) -> tuple[list[Path], list[tuple[str, Exception]]]:
193
+ saved_paths: list[Path] = []
194
+ errors: list[tuple[str, Exception]] = []
195
+ for item in list(self._registry.values()):
196
+ try:
197
+ out_path = self._save_one(
198
+ item["fig"],
199
+ stem=str(item["stem"]),
200
+ location=item["location"],
201
+ )
202
+ saved_paths.append(out_path)
203
+ except Exception as exc:
204
+ errors.append((str(item["name"]), exc))
205
+ return saved_paths, errors
206
+
207
+ def _handle_save_all_click(self, change) -> None:
208
+ if int(change["new"]) <= int(change["old"]):
209
+ return
210
+ if not self._registry:
211
+ self.mo.status.toast(
212
+ self.save_cfg.get("empty_title", "No plots available"),
213
+ _toast_detail_html(
214
+ self.save_cfg.get("empty_detail", "Render the notebook plots first."),
215
+ self.save_cfg,
216
+ ),
217
+ kind="danger",
218
+ )
219
+ return
220
+
221
+ saved_paths, errors = self.save_all()
222
+ if errors:
223
+ detail = self._saved_message(saved_paths)
224
+ if saved_paths:
225
+ detail = f"{detail}; {len(errors)} failed"
226
+ else:
227
+ detail = f"{len(errors)} failed"
228
+ self.mo.status.toast(
229
+ self.save_cfg.get(
230
+ "saved_error_title" if saved_paths else "failed_title",
231
+ "Saved with errors" if saved_paths else "Could not save plots",
232
+ ),
233
+ _toast_detail_html(detail, self.save_cfg),
234
+ kind="danger",
235
+ )
236
+ return
237
+
238
+ count = len(saved_paths)
239
+ title = (
240
+ self.save_cfg.get("saved_title", "Saved")
241
+ if count == 1
242
+ else self.save_cfg.get("saved_many_title", "Saved {count} plots").format(count=count)
243
+ )
244
+ self.mo.status.toast(
245
+ title,
246
+ _toast_detail_html(self._saved_message(saved_paths), self.save_cfg),
247
+ )
248
+
249
+ def save_all_widget(self, label: str | None = None):
250
+ self._save_all.label = label or self.save_cfg.get("save_all_label", "Save all model plots")
251
+ if self._save_all_ui is None:
252
+ self._save_all_ui = wrap_anywidget(self._save_all)
253
+ return self._save_all_ui
254
+
255
+ def __call__(
256
+ self,
257
+ fig,
258
+ name: str,
259
+ *,
260
+ stem: str | None = None,
261
+ label: str | None = None,
262
+ location: tuple[int, int] | None = None,
263
+ ):
264
+ if location is not None:
265
+ row, col = (int(location[0]), int(location[1]))
266
+ default_stem = f"{sanitize_stem(name.lower())}_r{row}_c{col}"
267
+ else:
268
+ default_stem = sanitize_stem(name.lower())
269
+ resolved_stem = stem or default_stem
270
+ button_label = label or f"{self.save_cfg.get('default_label', 'Save')} .{self.fmt}"
271
+ self._register(fig, name=name, stem=resolved_stem, location=location)
272
+ widget = SaveFigureAnyWidget(label=button_label, theme_tokens=self.theme_tokens)
273
+
274
+ def _handle_click(change):
275
+ if int(change["new"]) <= int(change["old"]):
276
+ return
277
+ try:
278
+ out_path = self._save_one(fig, stem=resolved_stem, location=location)
279
+ self.mo.status.toast(
280
+ self.save_cfg.get("saved_title", "Saved"),
281
+ _toast_detail_html(out_path.name, self.save_cfg),
282
+ )
283
+ except Exception as exc:
284
+ self.mo.status.toast(
285
+ self.save_cfg.get("failed_title", "Could not save plots"),
286
+ _toast_detail_html(f"{type(exc).__name__}: {exc}", self.save_cfg),
287
+ kind="danger",
288
+ )
289
+
290
+ widget.observe(_handle_click, names="clicks")
291
+ widget._save_observer = _handle_click
292
+ return wrap_anywidget(widget)
293
+
294
+
295
+ def make_plot_saver(mo, *, results_dir: Path, config_path: Path | None, task_name: str, model_id: str):
296
+ return PlotSaver(
297
+ mo,
298
+ results_dir=results_dir,
299
+ config_path=config_path,
300
+ task_name=task_name,
301
+ model_id=model_id,
302
+ )
303
+
304
+
305
+ def save_button(
306
+ mo,
307
+ fig,
308
+ *,
309
+ results_dir: Path,
310
+ config_path: Path | None,
311
+ task_name: str,
312
+ model_id: str,
313
+ stem: str,
314
+ label: str | None = None,
315
+ location: tuple[int, int] | None = None,
316
+ ):
317
+ fmt = get_plot_save_format(config_path)
318
+ default_label = label or _get_save_figure_config(config_path).get("default_label", "Save")
319
+ return make_plot_saver(
320
+ mo,
321
+ results_dir=results_dir,
322
+ config_path=config_path,
323
+ task_name=task_name,
324
+ model_id=model_id,
325
+ )(
326
+ fig,
327
+ name=default_label,
328
+ stem=stem,
329
+ label=f"{default_label} .{fmt}",
330
+ location=location,
331
+ )