plot-saver 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plot_saver-0.1.0/.github/workflows/publish.yml +47 -0
- plot_saver-0.1.0/PKG-INFO +115 -0
- plot_saver-0.1.0/README.md +105 -0
- plot_saver-0.1.0/pyproject.toml +18 -0
- plot_saver-0.1.0/src/plot_saver/__init__.py +25 -0
- plot_saver-0.1.0/src/plot_saver/anywidget_compat.py +65 -0
- plot_saver-0.1.0/src/plot_saver/config.py +50 -0
- plot_saver-0.1.0/src/plot_saver/figure_save_widget.css +66 -0
- plot_saver-0.1.0/src/plot_saver/figure_save_widget.js +40 -0
- plot_saver-0.1.0/src/plot_saver/resources/default_config.toml +31 -0
- plot_saver-0.1.0/src/plot_saver/save_widget.py +331 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
release:
|
|
5
|
+
types: [published]
|
|
6
|
+
|
|
7
|
+
jobs:
|
|
8
|
+
build:
|
|
9
|
+
runs-on: ubuntu-latest
|
|
10
|
+
|
|
11
|
+
steps:
|
|
12
|
+
- name: Check out repository
|
|
13
|
+
uses: actions/checkout@v4
|
|
14
|
+
|
|
15
|
+
- name: Set up Python
|
|
16
|
+
uses: actions/setup-python@v5
|
|
17
|
+
with:
|
|
18
|
+
python-version: "3.11"
|
|
19
|
+
|
|
20
|
+
- name: Install build tooling
|
|
21
|
+
run: python -m pip install --upgrade build
|
|
22
|
+
|
|
23
|
+
- name: Build distributions
|
|
24
|
+
run: python -m build
|
|
25
|
+
|
|
26
|
+
- name: Upload distributions
|
|
27
|
+
uses: actions/upload-artifact@v4
|
|
28
|
+
with:
|
|
29
|
+
name: dist
|
|
30
|
+
path: dist/*
|
|
31
|
+
|
|
32
|
+
publish:
|
|
33
|
+
needs: build
|
|
34
|
+
runs-on: ubuntu-latest
|
|
35
|
+
environment: pypi
|
|
36
|
+
permissions:
|
|
37
|
+
id-token: write
|
|
38
|
+
|
|
39
|
+
steps:
|
|
40
|
+
- name: Download distributions
|
|
41
|
+
uses: actions/download-artifact@v4
|
|
42
|
+
with:
|
|
43
|
+
name: dist
|
|
44
|
+
path: dist
|
|
45
|
+
|
|
46
|
+
- name: Publish to PyPI
|
|
47
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: plot-saver
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Reusable anywidget plot saver configurable with a config.toml.
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: anywidget>=0.9.21
|
|
7
|
+
Requires-Dist: marimo>=0.21.1
|
|
8
|
+
Requires-Dist: traitlets>=5.14.3
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
|
|
11
|
+
# plot-saver
|
|
12
|
+
|
|
13
|
+
`plot-saver` is a small reusable package for adding save buttons to Matplotlib figures in marimo notebooks.
|
|
14
|
+
|
|
15
|
+
It saves plots under a project results directory and can be configured with a project-level `config.toml`.
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install plot-saver
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
or with `uv`:
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
uv add plot-saver
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
## Config
|
|
30
|
+
|
|
31
|
+
`plot-saver` looks for a `config.toml` by searching upward from the current working directory, unless you pass an explicit `config_path`.
|
|
32
|
+
|
|
33
|
+
All configuration keys are optional. If you do not provide a `config.toml`, or if you only provide some keys, `plot-saver` falls back to its built-in defaults.
|
|
34
|
+
|
|
35
|
+
The default output format is `png`.
|
|
36
|
+
|
|
37
|
+
You can override the defaults with a section like this:
|
|
38
|
+
|
|
39
|
+
```toml
|
|
40
|
+
[plot-saver]
|
|
41
|
+
format = "png"
|
|
42
|
+
default_label = "Save"
|
|
43
|
+
save_all_label = "Save all model plots"
|
|
44
|
+
empty_title = "No plots available"
|
|
45
|
+
empty_detail = "Render the notebook plots first."
|
|
46
|
+
saved_title = "Saved"
|
|
47
|
+
saved_many_title = "Saved {count} plots"
|
|
48
|
+
saved_error_title = "Saved with errors"
|
|
49
|
+
failed_title = "Could not save plots"
|
|
50
|
+
toast_detail_color = "#6b7280"
|
|
51
|
+
|
|
52
|
+
[plot-saver.theme]
|
|
53
|
+
radius = "8px"
|
|
54
|
+
padding_y = "0.45rem"
|
|
55
|
+
padding_x = "0.9rem"
|
|
56
|
+
font_size = "0.9rem"
|
|
57
|
+
light_border = "rgba(107, 114, 128, 0.35)"
|
|
58
|
+
light_background = "rgba(249, 250, 251, 0.95)"
|
|
59
|
+
light_text = "#111827"
|
|
60
|
+
light_hover_background = "rgba(243, 244, 246, 1)"
|
|
61
|
+
light_hover_border = "rgba(107, 114, 128, 0.55)"
|
|
62
|
+
light_disabled_background = "rgba(229, 231, 235, 0.9)"
|
|
63
|
+
light_disabled_border = "rgba(156, 163, 175, 0.35)"
|
|
64
|
+
dark_border = "rgba(156, 163, 175, 0.35)"
|
|
65
|
+
dark_background = "rgba(31, 41, 55, 0.95)"
|
|
66
|
+
dark_text = "#f3f4f6"
|
|
67
|
+
dark_hover_background = "rgba(55, 65, 81, 1)"
|
|
68
|
+
dark_hover_border = "rgba(209, 213, 219, 0.45)"
|
|
69
|
+
dark_disabled_background = "rgba(55, 65, 81, 0.8)"
|
|
70
|
+
dark_disabled_border = "rgba(107, 114, 128, 0.35)"
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
## Usage
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
from pathlib import Path
|
|
77
|
+
|
|
78
|
+
import marimo as mo
|
|
79
|
+
import matplotlib.pyplot as plt
|
|
80
|
+
from plot_saver import make_plot_saver
|
|
81
|
+
|
|
82
|
+
fig, ax = plt.subplots()
|
|
83
|
+
ax.plot([0, 1, 2], [1, 3, 2])
|
|
84
|
+
|
|
85
|
+
save_plot = make_plot_saver(
|
|
86
|
+
mo,
|
|
87
|
+
results_dir=Path("results"),
|
|
88
|
+
config_path=None,
|
|
89
|
+
task_name="2AFC",
|
|
90
|
+
model_id="example-model",
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
button = save_plot(fig, "Example figure")
|
|
94
|
+
save_all = save_plot.save_all_widget()
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
Saved plots go to:
|
|
98
|
+
|
|
99
|
+
```text
|
|
100
|
+
results/plots/<task_name>/<model_id>/
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
`save_plot(fig, "Example figure")` registers that figure and returns an individual save button for it.
|
|
104
|
+
|
|
105
|
+
`save_plot.save_all_widget()` returns a single button that saves every figure previously registered with that `PlotSaver` instance. This is useful in notebooks where you render several figures for the same task and model and want one action to export all of them together.
|
|
106
|
+
|
|
107
|
+
## API
|
|
108
|
+
|
|
109
|
+
Main entry points:
|
|
110
|
+
|
|
111
|
+
- `make_plot_saver(...)`
|
|
112
|
+
- `save_button(...)`
|
|
113
|
+
- `save_figure(...)`
|
|
114
|
+
- `get_plot_save_format(...)`
|
|
115
|
+
- `find_project_config_path(...)`
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
# plot-saver
|
|
2
|
+
|
|
3
|
+
`plot-saver` is a small reusable package for adding save buttons to Matplotlib figures in marimo notebooks.
|
|
4
|
+
|
|
5
|
+
It saves plots under a project results directory and can be configured with a project-level `config.toml`.
|
|
6
|
+
|
|
7
|
+
## Installation
|
|
8
|
+
|
|
9
|
+
```bash
|
|
10
|
+
pip install plot-saver
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
or with `uv`:
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
uv add plot-saver
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Config
|
|
20
|
+
|
|
21
|
+
`plot-saver` looks for a `config.toml` by searching upward from the current working directory, unless you pass an explicit `config_path`.
|
|
22
|
+
|
|
23
|
+
All configuration keys are optional. If you do not provide a `config.toml`, or if you only provide some keys, `plot-saver` falls back to its built-in defaults.
|
|
24
|
+
|
|
25
|
+
The default output format is `png`.
|
|
26
|
+
|
|
27
|
+
You can override the defaults with a section like this:
|
|
28
|
+
|
|
29
|
+
```toml
|
|
30
|
+
[plot-saver]
|
|
31
|
+
format = "png"
|
|
32
|
+
default_label = "Save"
|
|
33
|
+
save_all_label = "Save all model plots"
|
|
34
|
+
empty_title = "No plots available"
|
|
35
|
+
empty_detail = "Render the notebook plots first."
|
|
36
|
+
saved_title = "Saved"
|
|
37
|
+
saved_many_title = "Saved {count} plots"
|
|
38
|
+
saved_error_title = "Saved with errors"
|
|
39
|
+
failed_title = "Could not save plots"
|
|
40
|
+
toast_detail_color = "#6b7280"
|
|
41
|
+
|
|
42
|
+
[plot-saver.theme]
|
|
43
|
+
radius = "8px"
|
|
44
|
+
padding_y = "0.45rem"
|
|
45
|
+
padding_x = "0.9rem"
|
|
46
|
+
font_size = "0.9rem"
|
|
47
|
+
light_border = "rgba(107, 114, 128, 0.35)"
|
|
48
|
+
light_background = "rgba(249, 250, 251, 0.95)"
|
|
49
|
+
light_text = "#111827"
|
|
50
|
+
light_hover_background = "rgba(243, 244, 246, 1)"
|
|
51
|
+
light_hover_border = "rgba(107, 114, 128, 0.55)"
|
|
52
|
+
light_disabled_background = "rgba(229, 231, 235, 0.9)"
|
|
53
|
+
light_disabled_border = "rgba(156, 163, 175, 0.35)"
|
|
54
|
+
dark_border = "rgba(156, 163, 175, 0.35)"
|
|
55
|
+
dark_background = "rgba(31, 41, 55, 0.95)"
|
|
56
|
+
dark_text = "#f3f4f6"
|
|
57
|
+
dark_hover_background = "rgba(55, 65, 81, 1)"
|
|
58
|
+
dark_hover_border = "rgba(209, 213, 219, 0.45)"
|
|
59
|
+
dark_disabled_background = "rgba(55, 65, 81, 0.8)"
|
|
60
|
+
dark_disabled_border = "rgba(107, 114, 128, 0.35)"
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## Usage
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
from pathlib import Path
|
|
67
|
+
|
|
68
|
+
import marimo as mo
|
|
69
|
+
import matplotlib.pyplot as plt
|
|
70
|
+
from plot_saver import make_plot_saver
|
|
71
|
+
|
|
72
|
+
fig, ax = plt.subplots()
|
|
73
|
+
ax.plot([0, 1, 2], [1, 3, 2])
|
|
74
|
+
|
|
75
|
+
save_plot = make_plot_saver(
|
|
76
|
+
mo,
|
|
77
|
+
results_dir=Path("results"),
|
|
78
|
+
config_path=None,
|
|
79
|
+
task_name="2AFC",
|
|
80
|
+
model_id="example-model",
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
button = save_plot(fig, "Example figure")
|
|
84
|
+
save_all = save_plot.save_all_widget()
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
Saved plots go to:
|
|
88
|
+
|
|
89
|
+
```text
|
|
90
|
+
results/plots/<task_name>/<model_id>/
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
`save_plot(fig, "Example figure")` registers that figure and returns an individual save button for it.
|
|
94
|
+
|
|
95
|
+
`save_plot.save_all_widget()` returns a single button that saves every figure previously registered with that `PlotSaver` instance. This is useful in notebooks where you render several figures for the same task and model and want one action to export all of them together.
|
|
96
|
+
|
|
97
|
+
## API
|
|
98
|
+
|
|
99
|
+
Main entry points:
|
|
100
|
+
|
|
101
|
+
- `make_plot_saver(...)`
|
|
102
|
+
- `save_button(...)`
|
|
103
|
+
- `save_figure(...)`
|
|
104
|
+
- `get_plot_save_format(...)`
|
|
105
|
+
- `find_project_config_path(...)`
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "plot-saver"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Reusable anywidget plot saver configurable with a config.toml."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.11"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"anywidget>=0.9.21",
|
|
9
|
+
"marimo>=0.21.1",
|
|
10
|
+
"traitlets>=5.14.3",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
[build-system]
|
|
14
|
+
requires = ["hatchling"]
|
|
15
|
+
build-backend = "hatchling.build"
|
|
16
|
+
|
|
17
|
+
[tool.hatch.build.targets.wheel]
|
|
18
|
+
packages = ["src/plot_saver"]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from .save_widget import (
|
|
2
|
+
PlotSaver,
|
|
3
|
+
SaveFigureAnyWidget,
|
|
4
|
+
build_plot_path,
|
|
5
|
+
find_project_config_path,
|
|
6
|
+
get_plot_save_format,
|
|
7
|
+
load_app_config,
|
|
8
|
+
make_plot_saver,
|
|
9
|
+
sanitize_stem,
|
|
10
|
+
save_button,
|
|
11
|
+
save_figure,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"PlotSaver",
|
|
16
|
+
"SaveFigureAnyWidget",
|
|
17
|
+
"build_plot_path",
|
|
18
|
+
"find_project_config_path",
|
|
19
|
+
"get_plot_save_format",
|
|
20
|
+
"load_app_config",
|
|
21
|
+
"make_plot_saver",
|
|
22
|
+
"sanitize_stem",
|
|
23
|
+
"save_button",
|
|
24
|
+
"save_figure",
|
|
25
|
+
]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from weakref import WeakKeyDictionary
|
|
6
|
+
|
|
7
|
+
from marimo._plugins.ui._core.ui_element import UIElement
|
|
8
|
+
from marimo._plugins.ui._impl.from_anywidget import (
|
|
9
|
+
anywidget as MarimoAnyWidget,
|
|
10
|
+
get_anywidget_model_id,
|
|
11
|
+
)
|
|
12
|
+
from marimo._utils.code import hash_code
|
|
13
|
+
from marimo._utils.data_uri import build_data_url
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@lru_cache(maxsize=32)
def _build_js_data_url(js: str) -> str:
    """Return a ``text/javascript`` data URL carrying *js*.

    The source text is UTF-8 encoded, base64 encoded and then handed to
    marimo's ``build_data_url``.  Results are memoised for the 32 most
    recently used source strings.
    """
    utf8_source = js.encode("utf-8")
    return build_data_url("text/javascript", base64.b64encode(utf8_source))
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _resolve_js_url(js: str) -> str:
|
|
23
|
+
if not js:
|
|
24
|
+
return ""
|
|
25
|
+
if js.startswith(("data:", "http://", "https://")):
|
|
26
|
+
return js
|
|
27
|
+
return _build_js_data_url(js)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class _StableAnyWidget(MarimoAnyWidget):
    """marimo wrapper around an anywidget instance with a hand-built constructor.

    ``__init__`` deliberately does not call ``MarimoAnyWidget.__init__``; it
    initialises the ``UIElement`` base directly with a fixed set of arguments
    derived from the wrapped widget.
    """

    def __init__(self, widget):
        self.widget = widget
        self._initialized = False

        # ``_esm`` may be missing or falsy; normalise it to a (possibly empty) string.
        js = str(getattr(widget, "_esm", "") or "")
        js_hash = hash_code(js)

        # The comm attribute is read and discarded before the model id is
        # looked up.  NOTE(review): presumably this forces the comm to be
        # created so the id exists -- confirm against marimo/anywidget.
        _ = widget.comm
        model_id = get_anywidget_model_id(widget)

        # Bypass the parent class constructor and set up the UI element directly.
        UIElement.__init__(
            self,
            component_name="marimo-anywidget",
            initial_value={"model_id": model_id},
            label=None,
            args={
                "js-url": _resolve_js_url(js),
                "js-hash": js_hash,
                "model-id": model_id,
            },
            on_change=None,
        )
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Maps each raw anywidget instance to the wrapper built for it, so repeated
# wrap_anywidget() calls with the same widget hand back the same wrapper.
_WIDGET_CACHE: WeakKeyDictionary[object, _StableAnyWidget] = WeakKeyDictionary()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def wrap_anywidget(widget):
    """Return the marimo wrapper for *widget*, creating it on first use.

    Wrappers are stored in ``_WIDGET_CACHE`` keyed by the widget, so calling
    this repeatedly with the same widget yields the same wrapper object.

    NOTE(review): the wrapper keeps a strong reference to the widget
    (``self.widget``) while being stored as the value of a weak-key mapping,
    which may keep entries alive longer than intended -- confirm.
    """
    existing = _WIDGET_CACHE.get(widget)
    if existing is None:
        existing = _StableAnyWidget(widget)
        _WIDGET_CACHE[widget] = existing
    return existing
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from importlib.resources import files
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import tomllib
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# File name looked for in the start directory and each of its parents.
PROJECT_CONFIG_NAME = "config.toml"
# Path of the bundled defaults, relative to the ``plot_saver`` package.
DEFAULT_CONFIG_RESOURCE = "resources/default_config.toml"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _read_toml(path: Path) -> dict:
|
|
14
|
+
with path.open("rb") as f:
|
|
15
|
+
return tomllib.load(f)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _merge_dicts(base: dict, override: dict) -> dict:
|
|
19
|
+
merged = deepcopy(base)
|
|
20
|
+
for key, value in override.items():
|
|
21
|
+
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
|
22
|
+
merged[key] = _merge_dicts(merged[key], value)
|
|
23
|
+
else:
|
|
24
|
+
merged[key] = deepcopy(value)
|
|
25
|
+
return merged
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def find_project_config_path(start: str | Path | None = None) -> Path | None:
    """Locate the nearest project ``config.toml`` at or above *start*.

    Args:
        start: Directory (or file) to begin from.  ``~`` is expanded and the
            path is resolved.  When it names a file, the search begins in the
            file's directory.  ``None`` means the current working directory.

    Returns:
        The path of the first regular file named ``PROJECT_CONFIG_NAME`` found
        while walking from the start directory up to the filesystem root, or
        ``None`` when no such file exists.
    """
    origin = Path.cwd() if start is None else Path(start).expanduser()
    current = origin.resolve()
    if current.is_file():
        current = current.parent
    for directory in (current, *current.parents):
        candidate = directory / PROJECT_CONFIG_NAME
        # Require a regular file: a directory that merely happens to be named
        # like the config file cannot be parsed and must not end the search.
        if candidate.is_file():
            return candidate
    return None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _load_default_config() -> dict:
    """Load the default configuration shipped inside the ``plot_saver`` package."""
    resource = files("plot_saver").joinpath(DEFAULT_CONFIG_RESOURCE)
    # Read the packaged resource as bytes and decode as UTF-8, as tomllib.load would.
    return tomllib.loads(resource.read_bytes().decode())
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def load_app_config(config_path: str | Path | None = None) -> dict:
    """Return the packaged defaults overlaid with the project's configuration.

    Args:
        config_path: Explicit configuration file.  When falsy, the file is
            discovered with ``find_project_config_path()``.

    Returns:
        The merged configuration.  If no project file is found, or the
        resolved path does not exist, only the packaged defaults are returned.
    """
    defaults = deepcopy(_load_default_config())
    if config_path:
        project_file = Path(config_path).expanduser().resolve()
    else:
        project_file = find_project_config_path()
    if project_file is None or not project_file.exists():
        return defaults
    return _merge_dicts(defaults, _read_toml(project_file))
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
.save-figure-wrap {
|
|
2
|
+
display: inline-flex;
|
|
3
|
+
flex-direction: column;
|
|
4
|
+
align-items: flex-start;
|
|
5
|
+
gap: 0.35rem;
|
|
6
|
+
--save-figure-radius: 8px;
|
|
7
|
+
--save-figure-padding-y: 0.45rem;
|
|
8
|
+
--save-figure-padding-x: 0.9rem;
|
|
9
|
+
--save-figure-font-size: 0.9rem;
|
|
10
|
+
--save-figure-light-border: rgba(107, 114, 128, 0.35);
|
|
11
|
+
--save-figure-light-background: rgba(249, 250, 251, 0.95);
|
|
12
|
+
--save-figure-light-text: #111827;
|
|
13
|
+
--save-figure-light-hover-background: rgba(243, 244, 246, 1);
|
|
14
|
+
--save-figure-light-hover-border: rgba(107, 114, 128, 0.55);
|
|
15
|
+
--save-figure-light-disabled-background: rgba(229, 231, 235, 0.9);
|
|
16
|
+
--save-figure-light-disabled-border: rgba(156, 163, 175, 0.35);
|
|
17
|
+
--save-figure-dark-border: rgba(156, 163, 175, 0.35);
|
|
18
|
+
--save-figure-dark-background: rgba(31, 41, 55, 0.95);
|
|
19
|
+
--save-figure-dark-text: #f3f4f6;
|
|
20
|
+
--save-figure-dark-hover-background: rgba(55, 65, 81, 1);
|
|
21
|
+
--save-figure-dark-hover-border: rgba(209, 213, 219, 0.45);
|
|
22
|
+
--save-figure-dark-disabled-background: rgba(55, 65, 81, 0.8);
|
|
23
|
+
--save-figure-dark-disabled-border: rgba(107, 114, 128, 0.35);
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
.save-figure-btn {
|
|
27
|
+
border: 1px solid var(--save-figure-light-border);
|
|
28
|
+
background: var(--save-figure-light-background);
|
|
29
|
+
color: var(--save-figure-light-text);
|
|
30
|
+
border-radius: var(--save-figure-radius);
|
|
31
|
+
padding: var(--save-figure-padding-y) var(--save-figure-padding-x);
|
|
32
|
+
font-size: var(--save-figure-font-size);
|
|
33
|
+
line-height: 1;
|
|
34
|
+
cursor: pointer;
|
|
35
|
+
transition: background 120ms ease, border-color 120ms ease;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
.save-figure-btn:hover {
|
|
39
|
+
background: var(--save-figure-light-hover-background);
|
|
40
|
+
border-color: var(--save-figure-light-hover-border);
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
.save-figure-btn:disabled {
|
|
44
|
+
cursor: not-allowed;
|
|
45
|
+
opacity: 0.55;
|
|
46
|
+
background: var(--save-figure-light-disabled-background);
|
|
47
|
+
border-color: var(--save-figure-light-disabled-border);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
@media (prefers-color-scheme: dark) {
|
|
51
|
+
.save-figure-btn {
|
|
52
|
+
background: var(--save-figure-dark-background);
|
|
53
|
+
color: var(--save-figure-dark-text);
|
|
54
|
+
border-color: var(--save-figure-dark-border);
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
.save-figure-btn:hover {
|
|
58
|
+
background: var(--save-figure-dark-hover-background);
|
|
59
|
+
border-color: var(--save-figure-dark-hover-border);
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
.save-figure-btn:disabled {
|
|
63
|
+
background: var(--save-figure-dark-disabled-background);
|
|
64
|
+
border-color: var(--save-figure-dark-disabled-border);
|
|
65
|
+
}
|
|
66
|
+
}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
// Widget render entry point: builds a single themed button inside `el` and
// keeps it in sync with the model's `label`, `disabled` and `theme_tokens`
// fields. Each click increments the model's `clicks` counter.
function render({ model, el }) {
  // Start from an empty container.
  el.innerHTML = "";

  const wrap = document.createElement("div");
  wrap.className = "save-figure-wrap";

  const button = document.createElement("button");
  button.className = "save-figure-btn";

  // Copy every theme token onto the wrapper as a `--save-figure-<key>`
  // CSS custom property.
  const applyTheme = () => {
    const theme = model.get("theme_tokens") || {};
    for (const [key, value] of Object.entries(theme)) {
      wrap.style.setProperty(`--save-figure-${key}`, value);
    }
  };

  // Mirror the model's label and disabled state onto the button element.
  const updateButton = () => {
    button.textContent = model.get("label");
    button.disabled = !!model.get("disabled");
  };

  // Report a click by bumping the `clicks` counter and saving the change.
  button.addEventListener("click", () => {
    if (button.disabled) return;
    const clicks = model.get("clicks") || 0;
    model.set("clicks", clicks + 1);
    model.save_changes();
  });

  model.on("change:label", updateButton);
  model.on("change:disabled", updateButton);
  model.on("change:theme_tokens", applyTheme);

  // Apply the initial state before attaching the elements.
  applyTheme();
  updateButton();

  wrap.appendChild(button);
  el.appendChild(wrap);
}
|
|
39
|
+
|
|
40
|
+
export default { render };
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
[plot-saver]
|
|
2
|
+
format = "png"
|
|
3
|
+
default_label = "Save"
|
|
4
|
+
save_all_label = "Save all model plots"
|
|
5
|
+
empty_title = "No plots available"
|
|
6
|
+
empty_detail = "Render the notebook plots first."
|
|
7
|
+
saved_title = "Saved"
|
|
8
|
+
saved_many_title = "Saved {count} plots"
|
|
9
|
+
saved_error_title = "Saved with errors"
|
|
10
|
+
failed_title = "Could not save plots"
|
|
11
|
+
toast_detail_color = "#6b7280"
|
|
12
|
+
|
|
13
|
+
[plot-saver.theme]
|
|
14
|
+
radius = "8px"
|
|
15
|
+
padding_y = "0.45rem"
|
|
16
|
+
padding_x = "0.9rem"
|
|
17
|
+
font_size = "0.9rem"
|
|
18
|
+
light_border = "rgba(107, 114, 128, 0.35)"
|
|
19
|
+
light_background = "rgba(249, 250, 251, 0.95)"
|
|
20
|
+
light_text = "#111827"
|
|
21
|
+
light_hover_background = "rgba(243, 244, 246, 1)"
|
|
22
|
+
light_hover_border = "rgba(107, 114, 128, 0.55)"
|
|
23
|
+
light_disabled_background = "rgba(229, 231, 235, 0.9)"
|
|
24
|
+
light_disabled_border = "rgba(156, 163, 175, 0.35)"
|
|
25
|
+
dark_border = "rgba(156, 163, 175, 0.35)"
|
|
26
|
+
dark_background = "rgba(31, 41, 55, 0.95)"
|
|
27
|
+
dark_text = "#f3f4f6"
|
|
28
|
+
dark_hover_background = "rgba(55, 65, 81, 1)"
|
|
29
|
+
dark_hover_border = "rgba(209, 213, 219, 0.45)"
|
|
30
|
+
dark_disabled_background = "rgba(55, 65, 81, 0.8)"
|
|
31
|
+
dark_disabled_border = "rgba(107, 114, 128, 0.35)"
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import anywidget
|
|
8
|
+
import traitlets
|
|
9
|
+
|
|
10
|
+
from .anywidget_compat import wrap_anywidget
|
|
11
|
+
from .config import find_project_config_path, load_app_config
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Directory containing this module; _read_asset() resolves asset names against it.
_ASSET_DIR = Path(__file__).parent
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _read_asset(name: str) -> str:
    """Return the UTF-8 text of the asset file *name* stored next to this module."""
    asset_path = _ASSET_DIR.joinpath(name)
    return asset_path.read_text(encoding="utf-8")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _get_save_figure_config(config_path: Path | None) -> dict[str, str]:
    """Flatten the ``[plot-saver]`` configuration into a string-to-string dict.

    Scalar entries of the section are copied first; entries of its ``theme``
    sub-table are then added on top, so a theme key wins over a top-level key
    of the same name.  Other nested tables are dropped.  An empty dict is
    returned when the section is not a table.
    """
    section = load_app_config(config_path).get("plot-saver", {})
    if not isinstance(section, dict):
        return {}

    flattened: dict[str, str] = {}
    for key, value in section.items():
        if isinstance(value, dict):
            continue
        flattened[str(key)] = str(value)

    theme = section.get("theme", {})
    if isinstance(theme, dict):
        for key, value in theme.items():
            flattened[str(key)] = str(value)
    return flattened
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _save_figure_theme_tokens(save_cfg: dict[str, str]) -> dict[str, str]:
|
|
34
|
+
token_keys = (
|
|
35
|
+
"radius",
|
|
36
|
+
"padding_y",
|
|
37
|
+
"padding_x",
|
|
38
|
+
"font_size",
|
|
39
|
+
"light_border",
|
|
40
|
+
"light_background",
|
|
41
|
+
"light_text",
|
|
42
|
+
"light_hover_background",
|
|
43
|
+
"light_hover_border",
|
|
44
|
+
"light_disabled_background",
|
|
45
|
+
"light_disabled_border",
|
|
46
|
+
"dark_border",
|
|
47
|
+
"dark_background",
|
|
48
|
+
"dark_text",
|
|
49
|
+
"dark_hover_background",
|
|
50
|
+
"dark_hover_border",
|
|
51
|
+
"dark_disabled_background",
|
|
52
|
+
"dark_disabled_border",
|
|
53
|
+
)
|
|
54
|
+
return {
|
|
55
|
+
key.replace("_", "-"): value
|
|
56
|
+
for key in token_keys
|
|
57
|
+
if (value := save_cfg.get(key))
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _toast_detail_html(detail: str, save_cfg: dict[str, str]) -> str:
|
|
62
|
+
color = save_cfg.get("toast_detail_color", "#6b7280")
|
|
63
|
+
return f"<span style='color:{color}'>{detail}</span>"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class SaveFigureAnyWidget(anywidget.AnyWidget):
    """anywidget button whose front end is loaded from the bundled JS and CSS files."""

    # Front-end module and stylesheet, read from the package directory at import time.
    _esm = _read_asset("figure_save_widget.js")
    _css = _read_asset("figure_save_widget.css")

    # Click counter; the front end increments it on every button press.
    clicks = traitlets.Int(0).tag(sync=True)
    # Text shown on the button.
    label = traitlets.Unicode("Save").tag(sync=True)
    # Whether the button is rendered disabled.
    disabled = traitlets.Bool(False).tag(sync=True)
    # Theme values the front end applies as ``--save-figure-<key>`` CSS properties.
    theme_tokens = traitlets.Dict(default_value={}).tag(sync=True)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_plot_save_format(config_path: Path | None = None) -> str:
    """Return the file format plots are saved in: ``"pdf"``, ``"svg"`` or ``"png"``.

    The value comes from ``format`` in the ``[plot-saver]`` configuration
    section.  It is lower-cased and stripped of surrounding dots and spaces,
    so ``".PNG"`` is accepted.

    Args:
        config_path: Configuration file to read; ``None`` lets
            ``load_app_config`` discover the project file.

    Returns:
        The normalised format.  ``"png"`` -- the documented default -- is
        returned when the setting is missing, is not one of the supported
        formats, or the ``plot-saver`` entry is not a table.
    """
    section = load_app_config(config_path).get("plot-saver", {})
    # A non-table ``plot-saver`` entry is treated as absent, matching how
    # _get_save_figure_config handles the same situation.
    configured = section.get("format", "png") if isinstance(section, dict) else "png"
    fmt = str(configured).lower().strip(". ")
    return fmt if fmt in {"pdf", "svg", "png"} else "png"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def sanitize_stem(stem: str) -> str:
    """Convert *stem* into a conservative file-name stem.

    Every run of characters other than ASCII letters, digits, ``.``, ``_``
    and ``-`` becomes a single underscore, after which leading and trailing
    ``.``, ``_`` and ``-`` are removed.  ``"figure"`` is returned when nothing
    is left.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", stem)
    trimmed = collapsed.strip("._-")
    if trimmed:
        return trimmed
    return "figure"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def build_plot_path(results_dir: Path, task_name: str, model_id: str, stem: str, fmt: str) -> Path:
    """Return the output path for a plot, creating its directory if needed.

    The file is placed at
    ``<results_dir>/plots/<task_name>/<model_id>/<sanitized stem>.<fmt>``.
    Only *stem* is sanitised; *task_name* and *model_id* are used as given.
    """
    target_dir = results_dir.joinpath("plots", task_name, model_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"{sanitize_stem(stem)}.{fmt}"
    return target_dir / file_name
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _axis_for_location(fig: Any, location: tuple[int, int]):
|
|
96
|
+
row, col = (int(location[0]), int(location[1]))
|
|
97
|
+
for ax in getattr(fig, "axes", []):
|
|
98
|
+
get_subplotspec = getattr(ax, "get_subplotspec", None)
|
|
99
|
+
if get_subplotspec is None:
|
|
100
|
+
continue
|
|
101
|
+
spec = get_subplotspec()
|
|
102
|
+
if spec is None:
|
|
103
|
+
continue
|
|
104
|
+
if row in range(spec.rowspan.start, spec.rowspan.stop) and col in range(spec.colspan.start, spec.colspan.stop):
|
|
105
|
+
return ax
|
|
106
|
+
raise ValueError(f"No subplot found at location ({row}, {col}).")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _save_axis(fig: Any, ax: Any, out_path: Path, fmt: str) -> Path:
|
|
110
|
+
if hasattr(fig, "canvas") and fig.canvas is not None:
|
|
111
|
+
fig.canvas.draw()
|
|
112
|
+
renderer = fig.canvas.get_renderer()
|
|
113
|
+
bbox = ax.get_tightbbox(renderer).transformed(fig.dpi_scale_trans.inverted())
|
|
114
|
+
save_kwargs = {"bbox_inches": bbox}
|
|
115
|
+
if fmt != "svg":
|
|
116
|
+
save_kwargs["dpi"] = 300
|
|
117
|
+
fig.savefig(out_path, **save_kwargs)
|
|
118
|
+
return out_path
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def save_figure(
    fig,
    *,
    results_dir: Path,
    config_path: Path | None,
    task_name: str,
    model_id: str,
    stem: str,
    location: tuple[int, int] | None = None,
) -> Path:
    """Save *fig*, or one of its subplots, and return the written path.

    The format is read from the configuration and the destination is built
    with ``build_plot_path``.  When *location* is given, only the axes
    covering that ``(row, col)`` cell are saved; otherwise the whole figure
    is saved with a tight bounding box.  Every format except ``"svg"`` is
    written at 300 dpi.
    """
    fmt = get_plot_save_format(config_path)
    target = build_plot_path(results_dir, task_name, model_id, stem, fmt)

    if location is None:
        canvas = getattr(fig, "canvas", None)
        if canvas is not None:
            # Draw first so the tight bounding box reflects the final layout.
            canvas.draw()
        extra = {} if fmt == "svg" else {"dpi": 300}
        fig.savefig(target, bbox_inches="tight", **extra)
        return target

    return _save_axis(fig, _axis_for_location(fig, location), target, fmt)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class PlotSaver:
    """Factory for per-figure save buttons plus one shared "save all" button.

    Calling an instance registers a figure and returns a button that saves
    it.  The widget returned by ``save_all_widget()`` saves every figure
    registered so far.  Outcomes are reported through ``mo.status.toast``.
    """

    def __init__(self, mo, *, results_dir: Path, config_path: Path | None, task_name: str, model_id: str):
        """Read the configuration once and create the disabled "save all" widget.

        Args:
            mo: Object providing ``status.toast`` (the marimo module in the README example).
            results_dir: Root directory that plots are saved under.
            config_path: Configuration file; when falsy it is discovered with
                ``find_project_config_path()``.
            task_name: Directory component used below ``plots/``.
            model_id: Directory component used below the task directory.
        """
        self.mo = mo
        self.results_dir = results_dir
        self.config_path = config_path or find_project_config_path()
        self.task_name = task_name
        self.model_id = model_id
        # Used for default button labels only; save_figure() re-reads the
        # format from the configuration each time a plot is saved.
        self.fmt = get_plot_save_format(self.config_path)
        self.save_cfg = _get_save_figure_config(self.config_path)
        self.theme_tokens = _save_figure_theme_tokens(self.save_cfg)
        # Registered figures keyed by stem.
        self._registry: dict[str, dict[str, object]] = {}
        self._save_all = SaveFigureAnyWidget(
            label=self.save_cfg.get("save_all_label", "Save all model plots"),
            disabled=True,
            theme_tokens=self.theme_tokens,
        )
        self._save_all.observe(self._handle_save_all_click, names="clicks")
        # NOTE(review): the callback is also stored on the widget, presumably
        # to keep a reference to it alive with the widget -- confirm.
        self._save_all._save_observer = self._handle_save_all_click
        # marimo wrapper for the "save all" widget, created on first request.
        self._save_all_ui = None

    def _save_one(self, fig, *, stem: str, location: tuple[int, int] | None = None) -> Path:
        """Save one figure using this saver's directories and configuration."""
        return save_figure(
            fig,
            results_dir=self.results_dir,
            config_path=self.config_path,
            task_name=self.task_name,
            model_id=self.model_id,
            stem=stem,
            location=location,
        )

    def _register(self, fig, *, name: str, stem: str, location: tuple[int, int] | None = None) -> None:
        """Record a figure for "save all"; an existing entry with the same stem is replaced."""
        self._registry[stem] = {
            "fig": fig,
            "name": name,
            "stem": stem,
            "location": location,
        }
        # The "save all" button is enabled as soon as anything is registered.
        self._save_all.disabled = not bool(self._registry)

    def _saved_message(self, saved_paths: list[Path]) -> str:
        """Summarise saved files as the first file name plus a count of the rest."""
        if not saved_paths:
            return "No files saved."
        if len(saved_paths) == 1:
            return saved_paths[0].name
        return f"{saved_paths[0].name} + {len(saved_paths) - 1} more"

    def save_all(self) -> tuple[list[Path], list[tuple[str, Exception]]]:
        """Save every registered figure, collecting failures instead of raising.

        Returns:
            A pair ``(saved_paths, errors)`` where ``errors`` holds
            ``(figure name, exception)`` for each figure that failed.
        """
        saved_paths: list[Path] = []
        errors: list[tuple[str, Exception]] = []
        # Iterate over a snapshot of the registry values.
        for item in list(self._registry.values()):
            try:
                out_path = self._save_one(
                    item["fig"],
                    stem=str(item["stem"]),
                    location=item["location"],
                )
                saved_paths.append(out_path)
            except Exception as exc:
                # One failing figure must not stop the remaining ones.
                errors.append((str(item["name"]), exc))
        return saved_paths, errors

    def _handle_save_all_click(self, change) -> None:
        """Observer for the "save all" click counter: save everything and show a toast."""
        # Only react when the counter actually increased.
        if int(change["new"]) <= int(change["old"]):
            return
        if not self._registry:
            self.mo.status.toast(
                self.save_cfg.get("empty_title", "No plots available"),
                _toast_detail_html(
                    self.save_cfg.get("empty_detail", "Render the notebook plots first."),
                    self.save_cfg,
                ),
                kind="danger",
            )
            return

        saved_paths, errors = self.save_all()
        if errors:
            detail = self._saved_message(saved_paths)
            if saved_paths:
                detail = f"{detail}; {len(errors)} failed"
            else:
                detail = f"{len(errors)} failed"
            # Partial success and total failure use different titles.
            self.mo.status.toast(
                self.save_cfg.get(
                    "saved_error_title" if saved_paths else "failed_title",
                    "Saved with errors" if saved_paths else "Could not save plots",
                ),
                _toast_detail_html(detail, self.save_cfg),
                kind="danger",
            )
            return

        count = len(saved_paths)
        # The plural title is a format template with a ``{count}`` placeholder.
        title = (
            self.save_cfg.get("saved_title", "Saved")
            if count == 1
            else self.save_cfg.get("saved_many_title", "Saved {count} plots").format(count=count)
        )
        self.mo.status.toast(
            title,
            _toast_detail_html(self._saved_message(saved_paths), self.save_cfg),
        )

    def save_all_widget(self, label: str | None = None):
        """Return the wrapped "save all" button, updating its label on every call.

        The wrapper is created once and reused on later calls.
        """
        self._save_all.label = label or self.save_cfg.get("save_all_label", "Save all model plots")
        if self._save_all_ui is None:
            self._save_all_ui = wrap_anywidget(self._save_all)
        return self._save_all_ui

    def __call__(
        self,
        fig,
        name: str,
        *,
        stem: str | None = None,
        label: str | None = None,
        location: tuple[int, int] | None = None,
    ):
        """Register *fig* and return a wrapped button that saves it when clicked.

        Args:
            fig: Figure to save.
            name: Display name; also the source of the default stem.
            stem: File-name stem.  Defaults to the sanitised, lower-cased
                *name*, with ``_r<row>_c<col>`` appended when *location* is given.
            label: Button text.  Defaults to the configured ``default_label``
                followed by `` .<format>``.
            location: Optional ``(row, col)`` of a single subplot to save.
        """
        if location is not None:
            row, col = (int(location[0]), int(location[1]))
            default_stem = f"{sanitize_stem(name.lower())}_r{row}_c{col}"
        else:
            default_stem = sanitize_stem(name.lower())
        resolved_stem = stem or default_stem
        button_label = label or f"{self.save_cfg.get('default_label', 'Save')} .{self.fmt}"
        self._register(fig, name=name, stem=resolved_stem, location=location)
        widget = SaveFigureAnyWidget(label=button_label, theme_tokens=self.theme_tokens)

        def _handle_click(change):
            # Only react when the click counter actually increased.
            if int(change["new"]) <= int(change["old"]):
                return
            try:
                out_path = self._save_one(fig, stem=resolved_stem, location=location)
                self.mo.status.toast(
                    self.save_cfg.get("saved_title", "Saved"),
                    _toast_detail_html(out_path.name, self.save_cfg),
                )
            except Exception as exc:
                # Report the failure in the notebook instead of raising from the observer.
                self.mo.status.toast(
                    self.save_cfg.get("failed_title", "Could not save plots"),
                    _toast_detail_html(f"{type(exc).__name__}: {exc}", self.save_cfg),
                    kind="danger",
                )

        widget.observe(_handle_click, names="clicks")
        # NOTE(review): stored on the widget, presumably to keep the closure
        # referenced for as long as the widget lives -- confirm.
        widget._save_observer = _handle_click
        return wrap_anywidget(widget)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def make_plot_saver(mo, *, results_dir: Path, config_path: Path | None, task_name: str, model_id: str):
    """Create a ``PlotSaver``; all arguments are forwarded to its constructor unchanged."""
    options = {
        "results_dir": results_dir,
        "config_path": config_path,
        "task_name": task_name,
        "model_id": model_id,
    }
    return PlotSaver(mo, **options)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def save_button(
    mo,
    fig,
    *,
    results_dir: Path,
    config_path: Path | None,
    task_name: str,
    model_id: str,
    stem: str,
    label: str | None = None,
    location: tuple[int, int] | None = None,
):
    """Build a one-off save button for *fig* without keeping a ``PlotSaver``.

    A fresh saver is created for this call only.  The button text is *label*
    (or the configured ``default_label`` when *label* is falsy) followed by
    `` .<format>``; the same base text is used as the figure's name.
    """
    fmt = get_plot_save_format(config_path)
    if label:
        base_label = label
    else:
        base_label = _get_save_figure_config(config_path).get("default_label", "Save")

    saver = make_plot_saver(
        mo,
        results_dir=results_dir,
        config_path=config_path,
        task_name=task_name,
        model_id=model_id,
    )
    return saver(
        fig,
        name=base_label,
        stem=stem,
        label=f"{base_label} .{fmt}",
        location=location,
    )
|