nb2slurm 0.0.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nb2slurm/__init__.py +37 -0
- nb2slurm/config.py +70 -0
- nb2slurm/done.py +65 -0
- nb2slurm/environment.py +140 -0
- nb2slurm/render.py +48 -0
- nb2slurm/runtime.py +43 -0
- nb2slurm/settings.py +46 -0
- nb2slurm/ssh.py +211 -0
- nb2slurm/structure.py +90 -0
- nb2slurm/templates/cancel_jobs.sh.j2 +14 -0
- nb2slurm/templates/job.slurm.j2 +29 -0
- nb2slurm/templates/run_workflow.py.j2 +65 -0
- nb2slurm/templates/submit_batch.sh.j2 +39 -0
- nb2slurm/templates/submit_jobs.sh.j2 +38 -0
- nb2slurm/workflow.py +407 -0
- nb2slurm-0.0.1.dev1.dist-info/METADATA +530 -0
- nb2slurm-0.0.1.dev1.dist-info/RECORD +19 -0
- nb2slurm-0.0.1.dev1.dist-info/WHEEL +4 -0
- nb2slurm-0.0.1.dev1.dist-info/licenses/LICENSE +201 -0
nb2slurm/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""nb2slurm: scale a single-subject notebook workflow to many subjects on SLURM.
|
|
2
|
+
|
|
3
|
+
The public surface is intentionally tiny so the whole workflow can be driven
|
|
4
|
+
from a notebook with no command line:
|
|
5
|
+
|
|
6
|
+
import nb2slurm
|
|
7
|
+
|
|
8
|
+
wf = nb2slurm.Workflow(
|
|
9
|
+
name="square",
|
|
10
|
+
notebooks=["notebooks/0_settings.ipynb", "notebooks/1_compute.ipynb"],
|
|
11
|
+
kernel="python3",
|
|
12
|
+
varying=["item_id"],
|
|
13
|
+
resources=dict(nodes=1, cpus=2, time="00:10:00"),
|
|
14
|
+
)
|
|
15
|
+
wf.build() # render scripts/ into the project
|
|
16
|
+
wf.submit([1, 2, 3], ssh=cfg) # sbatch one job per item, over SSH
|
|
17
|
+
wf.status(ssh=cfg) # squeue -> list of dicts
|
|
18
|
+
wf.cancel(ssh=cfg) # scancel the jobs we submitted
|
|
19
|
+
|
|
20
|
+
Inside the notebooks themselves, use the Settings helper:
|
|
21
|
+
|
|
22
|
+
nb2slurm.Settings.write(outdir, {...}) # first notebook
|
|
23
|
+
settings = nb2slurm.Settings.load(path) # later notebooks
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from .workflow import Workflow
|
|
27
|
+
from .environment import Environment
|
|
28
|
+
from .ssh import SSHConfig, generate_key, public_key
|
|
29
|
+
from .settings import Settings
|
|
30
|
+
from .done import Done
|
|
31
|
+
from .structure import Structure
|
|
32
|
+
from .config import save_config, load_config
|
|
33
|
+
from .runtime import on_hpc
|
|
34
|
+
|
|
35
|
+
__all__ = ["Workflow", "Environment", "SSHConfig", "generate_key", "public_key",
|
|
36
|
+
"Settings", "Done", "Structure", "save_config", "load_config", "on_hpc"]
|
|
37
|
+
__version__ = "0.0.1.dev1"
|
nb2slurm/config.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Save / load a run's control settings to one JSON file.
|
|
2
|
+
|
|
3
|
+
The config notebook builds a Workflow (+ SSHConfig) and calls ``save_config``;
|
|
4
|
+
the build / submit / sync notebooks call ``load_config`` to get them back. This
|
|
5
|
+
keeps a single source of truth on disk, so the notebooks never duplicate settings.
|
|
6
|
+
|
|
7
|
+
# in 0_config.ipynb
|
|
8
|
+
nb2slurm.save_config("control_config.json", workflow=wf, ssh=cfg)
|
|
9
|
+
|
|
10
|
+
# in 1_build / 2_submit / 3_sync
|
|
11
|
+
wf, cfg = nb2slurm.load_config("control_config.json")
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
from dataclasses import asdict
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
from .environment import Environment
|
|
22
|
+
from .ssh import SSHConfig
|
|
23
|
+
from .workflow import Workflow
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _workflow_to_dict(wf: Workflow) -> dict:
|
|
27
|
+
d = asdict(wf) # recurses Environment, resources, mounts
|
|
28
|
+
d.pop("submitted_jobs", None) # runtime state, not configuration
|
|
29
|
+
return d
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _workflow_from_dict(d: dict) -> Workflow:
    """Rebuild a ``Workflow`` from the dict produced by ``_workflow_to_dict``.

    The input is copied first, so the caller's dict is never modified. Any
    ``submitted_jobs`` entry is ignored, and the nested environment dicts are
    turned back into ``Environment`` objects before the workflow is built.
    """
    fields = dict(d)
    fields.pop("submitted_jobs", None)

    env_spec = fields.pop("environment", None)
    extra_specs = fields.pop("extra_environments", None) or []

    # A missing or empty environment entry becomes None, not an Environment.
    environment = Environment(**env_spec) if env_spec else None
    extra_environments = [Environment(**spec) for spec in extra_specs]

    return Workflow(
        **fields,
        environment=environment,
        extra_environments=extra_environments,
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _ssh_to_dict(ssh: SSHConfig) -> dict:
|
|
45
|
+
d = asdict(ssh)
|
|
46
|
+
d.pop("password", None) # never persist secrets to disk
|
|
47
|
+
return d
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def save_config(path: str | Path, *, workflow: Workflow,
                ssh: Optional[SSHConfig] = None) -> Path:
    """Write the workflow (and, if given, the SSH config) to ``path`` as JSON.

    The file holds a ``"workflow"`` section and, only when ``ssh`` is
    supplied, an ``"ssh"`` section. Returns ``path`` as a ``Path``.
    """
    payload = {"workflow": _workflow_to_dict(workflow)}
    if ssh is not None:
        payload["ssh"] = _ssh_to_dict(ssh)

    target = Path(path)
    serialised = json.dumps(payload, indent=2)
    target.write_text(serialised, encoding="utf-8")
    return target
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def load_config(path: str | Path) -> tuple[Workflow, Optional[SSHConfig]]:
    """Read a config file written by ``save_config`` into ``(workflow, ssh)``.

    ``ssh`` is ``None`` when the file has no (or an empty) ``"ssh"`` section.
    Passwords are never stored, so set ``ssh.password`` yourself afterwards
    if your cluster needs one.
    """
    stored = json.loads(Path(path).read_text(encoding="utf-8"))
    workflow = _workflow_from_dict(stored["workflow"])

    ssh_fields = stored.get("ssh")
    if not ssh_fields:
        return workflow, None
    return workflow, SSHConfig(**ssh_fields)
|
nb2slurm/done.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""Idempotent 'already done?' bookkeeping, as an object.
|
|
2
|
+
|
|
3
|
+
Generalised from cci.py: a single CSV records which subjects have completed, so
|
|
4
|
+
re-submitting the whole set skips finished work. A file lock makes it safe when
|
|
5
|
+
many jobs write concurrently on a shared filesystem.
|
|
6
|
+
|
|
7
|
+
done = nb2slurm.Done("done/done.csv")
|
|
8
|
+
if done.is_done(key):
|
|
9
|
+
...
|
|
10
|
+
done.mark(key)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import csv
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
from filelock import FileLock
|
|
19
|
+
|
|
20
|
+
LOCK_TIMEOUT = 60 * 3
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Done:
    """A CSV ledger of completed subjects, with concurrency-safe writes.

    The ledger is a one-column CSV with a ``key`` header row. Every read and
    write of the file happens while holding a lock on ``<csv_file>.lock``.
    """

    def __init__(self, csv_file: str | Path):
        self.csv_file = Path(csv_file)

    @property
    def _lock(self) -> str:
        """Path of the lock file that guards ``csv_file``."""
        return str(self.csv_file) + ".lock"

    def _read_keys(self) -> set[str]:
        """Return every key recorded in the ledger.

        The caller must already hold the lock and the file must exist.
        """
        with open(self.csv_file, newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            return {row[0] for row in reader if row}

    def is_done(self, key: str) -> bool:
        """Return True if ``key`` is already recorded."""
        if not self.csv_file.exists():
            return False
        with FileLock(self._lock, timeout=LOCK_TIMEOUT):
            return str(key) in self._read_keys()

    def mark(self, key: str) -> None:
        """Record ``key`` as done (no-op if already present).

        The header is created while holding the lock. Creating it beforehand
        would let two jobs both see a missing file, and the slower one would
        then truncate rows the faster one had already appended.
        """
        self.csv_file.parent.mkdir(parents=True, exist_ok=True)
        with FileLock(self._lock, timeout=LOCK_TIMEOUT):
            if not self.csv_file.exists():
                with open(self.csv_file, "w", newline="") as f:
                    csv.writer(f).writerow(["key"])
            if str(key) in self._read_keys():
                print(f"{key} already recorded as done.")
                return
            with open(self.csv_file, "a", newline="") as f:
                csv.writer(f).writerow([key])
            print(f"Marked {key} as done.")
|
nb2slurm/environment.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Create the conda environment + Jupyter kernel the workflow runs in.
|
|
2
|
+
|
|
3
|
+
Most users of nb2slurm are not Linux/conda experts, but the generated SLURM job
|
|
4
|
+
does ``conda activate <env>`` and papermill needs a *registered Jupyter kernel*
|
|
5
|
+
to execute the notebooks. This module writes an ``environment.yml`` and creates
|
|
6
|
+
the environment + kernel on the cluster (or locally), so the user never touches
|
|
7
|
+
the command line.
|
|
8
|
+
|
|
9
|
+
from nb2slurm import Environment
|
|
10
|
+
|
|
11
|
+
env = Environment(
|
|
12
|
+
name="myenv",
|
|
13
|
+
kernel="myenv", # must match Workflow(kernel=...)
|
|
14
|
+
conda_packages=["xarray", "numpy"],
|
|
15
|
+
pip_packages=["nb2slurm", "ewatercycle"],
|
|
16
|
+
)
|
|
17
|
+
env.write() # -> environment.yml
|
|
18
|
+
env.create(ssh=cfg) # build env + register kernel on the HPC
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
from dataclasses import dataclass, field
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Optional
|
|
26
|
+
|
|
27
|
+
from .ssh import CommandResult, SSHConfig, run_shell
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class Environment:
    """Describe a conda environment and the Jupyter kernel registered from it.

    The object renders an ``environment.yml`` and builds the shell commands
    that create, update or remove the environment. Those commands run through
    ``run_shell``, i.e. over SSH when an ``SSHConfig`` is given and locally
    otherwise.
    """

    name: str  # conda environment name
    kernel: str  # Jupyter kernel name registered for the environment
    python: str = "3.11"  # version pinned as ``python=<version>``
    channels: list[str] = field(default_factory=lambda: ["conda-forge"])
    conda_packages: list[str] = field(default_factory=list)
    pip_packages: list[str] = field(default_factory=lambda: ["nb2slurm"])

    def to_yaml(self) -> str:
        """Render an ``environment.yml`` for conda/mamba.

        ``python``, ``pip`` and ``ipykernel`` are always listed. Pip packages
        are nested under the ``pip:`` entry; with an empty ``pip_packages``
        that entry is left out.
        """
        lines = [f"name: {self.name}", "channels:"]
        lines += [f" - {c}" for c in self.channels]
        lines.append("dependencies:")
        lines.append(f" - python={self.python}")
        lines.append(" - pip")
        lines.append(" - ipykernel")  # required so papermill can run the notebooks
        lines += [f" - {p}" for p in self.conda_packages]
        if self.pip_packages:
            lines.append(" - pip:")
            # These must be indented deeper than the ``pip`` key. At the same
            # level they would parse as conda dependencies instead.
            lines += [f"     - {p}" for p in self.pip_packages]
        return "\n".join(lines) + "\n"

    def write(self, project_dir: str | Path = ".", filename: str = "environment.yml") -> Path:
        """Write the ``environment.yml`` into ``project_dir`` and return its path."""
        path = Path(project_dir) / filename
        path.write_text(self.to_yaml(), encoding="utf-8")
        return path

    def _exists_test(self) -> str:
        """Return a shell test that exits 0 when the env exists.

        Matches the name as a whole path component, so ``montecarlo`` doesn't
        match ``montecarlo2``.
        """
        return f'conda env list | grep -qE "[/ ]{self.name}([ /]|$)"'

    def _create_command(self, filename: str = "environment.yml") -> str:
        """Return the shell command that creates or updates the env and registers the kernel."""
        import shlex  # local import: only needed to quote shell arguments

        env_file = shlex.quote(filename)
        env_name = shlex.quote(self.name)
        kernel_name = shlex.quote(self.kernel)
        # mamba is much faster than conda; use it when available.
        return (
            "set -e; "
            # There's no TTY over ssh, so any interactive prompt would hang the
            # build forever. Belt and suspenders: set always-yes AND pipe `yes`
            # into conda (CONDA_ALWAYS_YES alone is ignored by some mamba builds
            # for the 'Confirm changes? [Y/n]' transaction prompt).
            "export CONDA_ALWAYS_YES=yes; "
            "CONDA=conda; command -v mamba >/dev/null 2>&1 && CONDA=mamba; "
            # Update an existing env in place rather than hitting the
            # interactive 'Found conda-prefix ... Overwrite? [y/N]' prompt.
            f"if {self._exists_test()}; then "
            f' echo "Updating existing environment {self.name} with $CONDA..."; '
            f" yes | $CONDA env update -f {env_file}; "
            f"else "
            f' echo "Creating environment {self.name} with $CONDA..."; '
            f" yes | $CONDA env create -f {env_file}; "
            f"fi; "
            f"conda run -n {env_name} python -m ipykernel install --user "
            f'--name {kernel_name} --display-name "{self.kernel}"; '
            f'echo "Environment {self.name} ready; kernel {self.kernel} registered."'
        )

    def _remove_command(self) -> str:
        """Return the shell command that deletes the env and its kernelspec."""
        import shlex  # local import: only needed to quote shell arguments

        env_name = shlex.quote(self.name)
        kernel_name = shlex.quote(self.kernel)
        return (
            "export CONDA_ALWAYS_YES=yes; "
            f'echo "Removing environment {self.name} and kernel {self.kernel}..."; '
            # `|| true`: removing a non-existent env/kernel is not an error here
            f"conda env remove -n {env_name} || true; "
            f"jupyter kernelspec remove -f {kernel_name} 2>/dev/null || true; "
            f'echo "Removed {self.name}."'
        )

    def exists(
        self,
        ssh: Optional[SSHConfig] = None,
        project_dir: str | Path = ".",
    ) -> bool:
        """Return True if the conda env already exists, on the HPC (ssh) or locally."""
        return run_shell(self._exists_test(), ssh, str(project_dir)).exit_status == 0

    def remove(
        self,
        ssh: Optional[SSHConfig] = None,
        project_dir: str | Path = ".",
        stream: bool = True,
    ) -> CommandResult:
        """Delete the conda env and its Jupyter kernel, on the HPC (ssh) or locally.

        Safe to call when nothing is there yet (a missing env/kernel is
        ignored). Use it to recover from a half-built env or to force a clean
        rebuild. Raises ``RuntimeError`` if the command exits non-zero.
        """
        return run_shell(self._remove_command(), ssh, str(project_dir), stream=stream).check()

    def create(
        self,
        ssh: Optional[SSHConfig] = None,
        project_dir: str | Path = ".",
        filename: str = "environment.yml",
        stream: bool = True,
        overwrite: bool = False,
    ) -> CommandResult:
        """Create the env and register the kernel, on the HPC (ssh) or locally.

        Writes ``environment.yml`` into the local ``project_dir`` first if it
        is missing there. If the env already exists it is *updated* in place;
        pass ``overwrite=True`` to delete and rebuild it from scratch.
        ``stream=True`` (the default) echoes conda/pip output live, since a
        solve + downloads can take minutes and would otherwise look like a
        hang. Raises ``RuntimeError`` if the command exits non-zero.
        """
        if not (Path(project_dir) / filename).exists():
            self.write(project_dir, filename)
        if overwrite:
            self.remove(ssh=ssh, project_dir=project_dir, stream=stream)
        command = self._create_command(filename)
        return run_shell(command, ssh, str(project_dir), stream=stream).check()
|
nb2slurm/render.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Render the bundled Jinja2 templates into a project's scripts/ directory."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from importlib import resources
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Mapping
|
|
8
|
+
|
|
9
|
+
from jinja2 import Environment
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _env() -> Environment:
    """Build the Jinja2 environment used for the bundled script templates."""
    options = {
        # No autoescaping: these are shell/python scripts, not HTML.
        "trim_blocks": True,
        "lstrip_blocks": True,
        "keep_trailing_newline": True,
        # The default "{# #}" comment delimiters clash with bash parameter
        # expansions such as ${#arr[@]}, so use markers that cannot appear
        # in a script by accident.
        "comment_start_string": "<#nb2slurm#",
        "comment_end_string": "#nb2slurm#>",
    }
    return Environment(**options)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def render_template(template_name: str, context: Mapping[str, Any]) -> str:
    """Render one bundled template (e.g. 'run_workflow.py.j2') to a string.

    The template text is read from the ``nb2slurm.templates`` package data and
    rendered with ``context`` passed as keyword arguments.
    """
    template_file = resources.files("nb2slurm.templates").joinpath(template_name)
    template_text = template_file.read_text(encoding="utf-8")
    template = _env().from_string(template_text)
    return template.render(**context)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def write_rendered(
    template_name: str,
    out_path: str | Path,
    context: Mapping[str, Any],
    executable: bool = False,
) -> Path:
    """Render a template, write it to ``out_path`` and return that path.

    Missing parent directories are created. With ``executable=True`` the
    execute bits are added on top of the file's existing mode.
    """
    destination = Path(out_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    rendered = render_template(template_name, context)
    destination.write_text(rendered, encoding="utf-8")

    if not executable:
        return destination
    destination.chmod(destination.stat().st_mode | 0o111)
    return destination
|
nb2slurm/runtime.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Detect whether the current notebook is running under nb2slurm on a cluster.
|
|
2
|
+
|
|
3
|
+
Checking ``Path.home()`` for a username is fragile (breaks for other users, can't
|
|
4
|
+
tell a cloud VM from a laptop). Instead we look at environment variables that are
|
|
5
|
+
only present in a batch job:
|
|
6
|
+
|
|
7
|
+
* the ``SLURM_*`` variables SLURM sets in every job, and
|
|
8
|
+
* the ``NB2SLURM`` sentinel that the generated ``job.slurm`` exports (so this also
|
|
9
|
+
works on non-SLURM batch setups that run nb2slurm's job script).
|
|
10
|
+
|
|
11
|
+
Use it in your notebooks for the things that genuinely differ between
|
|
12
|
+
interactive and batch runs — machine-specific data paths, skipping ``!pip install``:
|
|
13
|
+
|
|
14
|
+
import nb2slurm
|
|
15
|
+
if nb2slurm.on_hpc():
|
|
16
|
+
data_dir = "/project/ewater/Data"
|
|
17
|
+
else:
|
|
18
|
+
data_dir = "/data/shared"
|
|
19
|
+
|
|
20
|
+
It also cleans up importing a helper from ``scripts/``. On the cluster the job
|
|
21
|
+
runs from the project root, so ``from scripts.foo import bar`` just works; run
|
|
22
|
+
interactively from a ``notebooks/`` subfolder it doesn't, so add the project root
|
|
23
|
+
to the path only when running locally:
|
|
24
|
+
|
|
25
|
+
import sys
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
import nb2slurm
|
|
28
|
+
if not nb2slurm.on_hpc():
|
|
29
|
+
sys.path.append(str(Path().resolve().parent))
|
|
30
|
+
from scripts.montecarlo import estimate_pi
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations
|
|
34
|
+
|
|
35
|
+
import os
|
|
36
|
+
|
|
37
|
+
# Variables whose presence marks a batch job: the SLURM job-id variables, plus
# the NB2SLURM sentinel exported by the generated job.slurm.
_BATCH_ENV_VARS = ("NB2SLURM", "SLURM_JOB_ID", "SLURM_JOBID")


def on_hpc() -> bool:
    """Return True if running inside an nb2slurm/SLURM batch job.

    Only the presence of a marker variable matters; its value is not read.
    """
    for marker in _BATCH_ENV_VARS:
        if marker in os.environ:
            return True
    return False
|
nb2slurm/settings.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""The per-run settings file, as an object.
|
|
2
|
+
|
|
3
|
+
The paper's notebook structure has the *first* notebook write a settings file
|
|
4
|
+
(JSON) that every later notebook reads. This keeps the varying parameters in one
|
|
5
|
+
place: only notebook 0 is parameterised by papermill, the rest just load it.
|
|
6
|
+
|
|
7
|
+
import nb2slurm
|
|
8
|
+
|
|
9
|
+
# first notebook (parameterised by nb2slurm):
|
|
10
|
+
nb2slurm.Settings.write(outdir, {"region_id": region_id, "outdir": outdir})
|
|
11
|
+
|
|
12
|
+
# every later notebook:
|
|
13
|
+
settings = nb2slurm.Settings.load(settings_path)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any, Mapping
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Settings:
    """Read/write the per-run ``settings.json``."""

    # File name used inside the run's output directory.
    filename = "settings.json"

    @staticmethod
    def write(outdir: str | Path, settings: Mapping[str, Any]) -> Path:
        """Write ``settings`` to ``<outdir>/settings.json`` and return the path.

        Call this from your first notebook. ``outdir`` is created if it does
        not exist yet, and the mapping is stored as indented JSON.
        """
        target_dir = Path(outdir)
        target_dir.mkdir(parents=True, exist_ok=True)
        target = target_dir / Settings.filename
        with open(target, "w") as handle:
            json.dump(dict(settings), handle, indent=2)
        return target

    @staticmethod
    def load(settings_path: str | Path) -> dict[str, Any]:
        """Load a settings file and return its parsed JSON content.

        Call this at the top of every later notebook.
        """
        with open(settings_path) as handle:
            loaded = json.load(handle)
        return loaded
|
nb2slurm/ssh.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
"""Minimal SSH transport so the managing notebook can drive SLURM with no CLI.
|
|
2
|
+
|
|
3
|
+
This mimics the command line / ssh that the paper says should be hidden from the
|
|
4
|
+
user: sbatch/squeue/scancel run on the cluster, but the user only writes Python.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import subprocess
|
|
11
|
+
import time
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional, Tuple
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class CommandResult:
    """The outcome of one shell command: what ran, its exit code and its output."""

    command: str  # the command line that was executed
    exit_status: int  # process exit code; 0 means success
    stdout: str  # captured standard output
    stderr: str  # captured standard error

    def check(self) -> "CommandResult":
        """Return ``self`` on success; raise ``RuntimeError`` on a non-zero exit.

        The error message carries the exit status, the command and the
        captured stderr.
        """
        if self.exit_status == 0:
            return self
        message = f"Remote command failed ({self.exit_status}): {self.command}\n{self.stderr}"
        raise RuntimeError(message)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class SSHConfig:
    """Connection details for the HPC login node.

    Authenticate with ``key_filename`` or ``password`` (or leave both unset
    and rely on an agent/known config). ``remote_dir`` is the project
    directory on the cluster; commands run from there unless a different
    ``cwd`` is passed to ``run``.
    """

    host: str
    user: str
    remote_dir: str
    port: int = 22
    key_filename: Optional[str] = None
    password: Optional[str] = None
    # Extra keyword arguments forwarded verbatim to the client's connect().
    extra_connect_kwargs: dict = field(default_factory=dict)

    def key_path(self) -> Optional[str]:
        """Return the private key path with ``~`` expanded, or ``None`` if unset.

        paramiko opens ``key_filename`` directly and does **not** expand
        ``~``, so it is resolved here (e.g. ``~/.ssh/id_rsa`` -> the absolute
        path).
        """
        if not self.key_filename:
            return None
        return os.path.expanduser(self.key_filename)

    def rsync_ssh(self) -> str:
        """Return the ``-e`` transport string for rsync (ssh + port + key).

        The port is only added when it differs from 22, and the key only when
        one is configured.
        """
        pieces = ["ssh"]
        if self.port != 22:
            pieces.extend(["-p", str(self.port)])
        if self.key_filename:
            pieces.extend(["-i", self.key_path()])
        return " ".join(pieces)

    def rsync_target(self, subpath: str = "") -> str:
        """Return a ``user@host:remote_dir/<subpath>`` spec for rsync."""
        root = self.remote_dir.rstrip("/")
        project = f"{self.user}@{self.host}:{root}/"
        return f"{project}{subpath}" if subpath else project

    def run(self, command: str, cwd: Optional[str] = None,
            stream: bool = False) -> CommandResult:
        """Run a single command on the cluster and return its result.

        The command is prefixed with ``cd <cwd> &&``, where ``cwd`` defaults
        to ``remote_dir``. Output is read continuously while the command
        runs, so a chatty command (``conda env create``, ``pip install``)
        can't fill the channel window and stall. With ``stream=True`` the
        output is also echoed live. The connection is closed on the way out,
        whether or not an error occurred.
        """
        import paramiko  # imported lazily so the package imports without a cluster

        workdir = cwd or self.remote_dir
        full_command = f"cd {workdir} && {command}" if workdir else command

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(
                hostname=self.host,
                port=self.port,
                username=self.user,
                key_filename=self.key_path(),
                password=self.password,
                **self.extra_connect_kwargs,
            )
            chan = client.get_transport().open_session()
            chan.exec_command(full_command)

            stdout_chunks: list[str] = []
            stderr_chunks: list[str] = []

            def _pump(is_ready, receive, sink: list[str]) -> bool:
                # Move everything currently buffered on one stream into
                # ``sink``; report whether any data was read.
                moved = False
                while is_ready():
                    text = receive(32768).decode("utf-8", "replace")
                    sink.append(text)
                    if stream:
                        print(text, end="", flush=True)
                    moved = True
                return moved

            def _drain() -> bool:
                # stdout first, then stderr; both are always attempted.
                got_stdout = _pump(chan.recv_ready, chan.recv, stdout_chunks)
                got_stderr = _pump(chan.recv_stderr_ready, chan.recv_stderr, stderr_chunks)
                return got_stdout or got_stderr

            # Keep reading so the remote side never blocks on a full window;
            # sleep briefly only when there was nothing to read.
            while not chan.exit_status_ready():
                if not _drain():
                    time.sleep(0.05)
            while _drain():  # whatever is left after exit
                pass
            status = chan.recv_exit_status()
        finally:
            client.close()
        return CommandResult(full_command, status, "".join(stdout_chunks), "".join(stderr_chunks))
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _pub_path(path: str) -> Path:
|
|
132
|
+
"""The ``.pub`` file for a key path (accepts the private path or the .pub)."""
|
|
133
|
+
p = Path(os.path.expanduser(path))
|
|
134
|
+
return p if p.suffix == ".pub" else Path(str(p) + ".pub")
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def public_key(path: str = "~/.ssh/id_rsa") -> str:
    """Return the public key line for ``path`` (reads ``<path>.pub``).

    This is the text to hand to your HPC: ``print(nb2slurm.public_key())``,
    then copy it into the cluster's key-upload page (or into
    ``~/.ssh/authorized_keys`` on a login node, if your HPC lets you edit it
    directly). Surrounding whitespace is stripped.
    """
    pub_file = _pub_path(path)
    contents = pub_file.read_text()
    return contents.strip()
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def generate_key(path: str = "~/.ssh/id_rsa", bits: int = 4096,
                 comment: Optional[str] = None, overwrite: bool = False,
                 show: bool = True) -> Tuple[Path, Path]:
    """Create an RSA SSH keypair at ``path`` (+ ``<path>.pub``).

    Returns ``(private_path, public_path)``. An existing private key is left
    alone unless ``overwrite=True``, so calling this repeatedly is safe. After
    writing, the private key is set to mode 0600 and the public key to 0644.
    The public key is a single ``ssh-rsa <base64> [comment]`` line. Point
    your ``SSHConfig(key_filename=...)`` at ``path``.

    nb2slurm can't install the key on the cluster for you. With ``show=True``
    (the default) the public key is printed so you can copy it into your
    cluster's key-upload page or its ``~/.ssh/authorized_keys``;
    ``nb2slurm.public_key(path)`` reprints it later.
    """
    import paramiko  # lazy: keep the package importable without a crypto backend

    private_path = Path(os.path.expanduser(path))
    public_path = _pub_path(path)

    keep_existing = private_path.exists() and not overwrite
    if keep_existing:
        if show:
            print(f"key already exists at {private_path}; its public key is:\n\n{public_key(path)}")
        return private_path, public_path

    private_path.parent.mkdir(parents=True, exist_ok=True)

    rsa = paramiko.RSAKey.generate(bits)
    rsa.write_private_key_file(str(private_path))
    # strip() drops the trailing space left behind when there is no comment.
    pub_line = f"ssh-rsa {rsa.get_base64()} {comment or ''}".strip()
    public_path.write_text(pub_line + "\n")

    for target, mode in ((private_path, 0o600), (public_path, 0o644)):
        try:
            os.chmod(target, mode)
        except OSError:
            pass  # Windows without POSIX perms; OpenSSH there enforces via ACLs

    if show:
        announcement = (
            f"created SSH key: {private_path} (private) and {public_path} (public)\n\n"
            "Add the PUBLIC key below to your HPC - via its key-upload page, or by\n"
            "appending it to ~/.ssh/authorized_keys on a login node. nb2slurm can't\n"
            "do this step for you (clusters usually disable password login):\n\n"
            f"{public_key(path)}"
        )
        print(announcement)
    return private_path, public_path
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def run_shell(command: str, ssh: Optional[SSHConfig] = None,
              cwd: str = ".", stream: bool = False) -> CommandResult:
    """Run a shell command on the cluster (via ``ssh``) or locally (subprocess).

    Shared by Workflow and Environment so the ssh-vs-local branch lives in
    one place.

    With ``ssh`` the command is handed to ``ssh.run`` and ``cwd`` is not
    used. Without it the command runs through the local shell in ``cwd``.
    ``stream=True`` echoes output live (for long-running commands); in the
    local case stderr is then merged into stdout, so the returned ``stderr``
    is empty.
    """
    if ssh is not None:
        return ssh.run(command, stream=stream)

    if not stream:
        done = subprocess.run(command, shell=True, cwd=str(cwd),
                              capture_output=True, text=True)
        return CommandResult(command, done.returncode, done.stdout, done.stderr)

    parts: list[str] = []
    # The context manager closes the pipe and reaps the child even when
    # echoing is interrupted (e.g. by KeyboardInterrupt in a notebook).
    with subprocess.Popen(command, shell=True, cwd=str(cwd), text=True,
                          stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
        for line in proc.stdout:  # tee: capture and echo
            parts.append(line)
            print(line, end="", flush=True)
    # Leaving the ``with`` block has waited for the child, so returncode is set.
    return CommandResult(command, proc.returncode, "".join(parts), "")
|