nb2slurm 0.0.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nb2slurm/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ """nb2slurm: scale a single-subject notebook workflow to many subjects on SLURM.
2
+
3
+ The public surface is intentionally tiny so the whole workflow can be driven
4
+ from a notebook with no command line:
5
+
6
+ import nb2slurm
7
+
8
+ wf = nb2slurm.Workflow(
9
+ name="square",
10
+ notebooks=["notebooks/0_settings.ipynb", "notebooks/1_compute.ipynb"],
11
+ kernel="python3",
12
+ varying=["item_id"],
13
+ resources=dict(nodes=1, cpus=2, time="00:10:00"),
14
+ )
15
+ wf.build() # render scripts/ into the project
16
+ wf.submit([1, 2, 3], ssh=cfg) # sbatch one job per item, over SSH
17
+ wf.status(ssh=cfg) # squeue -> list of dicts
18
+ wf.cancel(ssh=cfg) # scancel the jobs we submitted
19
+
20
+ Inside the notebooks themselves, use the Settings helper:
21
+
22
+ nb2slurm.Settings.write(outdir, {...}) # first notebook
23
+ settings = nb2slurm.Settings.load(path) # later notebooks
24
+ """
25
+
26
+ from .workflow import Workflow
27
+ from .environment import Environment
28
+ from .ssh import SSHConfig, generate_key, public_key
29
+ from .settings import Settings
30
+ from .done import Done
31
+ from .structure import Structure
32
+ from .config import save_config, load_config
33
+ from .runtime import on_hpc
34
+
35
# Public API: the names ``from nb2slurm import *`` exposes, all imported above.
__all__ = ["Workflow", "Environment", "SSHConfig", "generate_key", "public_key",
           "Settings", "Done", "Structure", "save_config", "load_config", "on_hpc"]
# Package version string; matches the released wheel (0.0.1.dev1).
__version__ = "0.0.1.dev1"
nb2slurm/config.py ADDED
@@ -0,0 +1,70 @@
1
+ """Save / load a run's control settings to one JSON file.
2
+
3
+ The config notebook builds a Workflow (+ SSHConfig) and calls ``save_config``;
4
+ the build / submit / sync notebooks call ``load_config`` to get them back. This
5
+ keeps a single source of truth on disk, so the notebooks never duplicate settings.
6
+
7
+ # in 0_config.ipynb
8
+ nb2slurm.save_config("control_config.json", workflow=wf, ssh=cfg)
9
+
10
+ # in 1_build / 2_submit / 3_sync
11
+ wf, cfg = nb2slurm.load_config("control_config.json")
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ from dataclasses import asdict
18
+ from pathlib import Path
19
+ from typing import Optional
20
+
21
+ from .environment import Environment
22
+ from .ssh import SSHConfig
23
+ from .workflow import Workflow
24
+
25
+
26
def _workflow_to_dict(wf: Workflow) -> dict:
    """Turn a Workflow into a plain dict of its configuration.

    ``asdict`` recurses into nested dataclasses (e.g. the Environment). The
    ``submitted_jobs`` entry is left out because it is runtime state, not
    configuration.
    """
    return {
        field_name: value
        for field_name, value in asdict(wf).items()
        if field_name != "submitted_jobs"
    }
30
+
31
+
32
def _workflow_from_dict(d: dict) -> Workflow:
    """Rebuild a Workflow from the dict written by ``_workflow_to_dict``.

    ``environment`` and ``extra_environments`` are turned back into
    ``Environment`` objects. A stray ``submitted_jobs`` entry is ignored. The
    caller's dict is not modified.
    """
    special = ("submitted_jobs", "environment", "extra_environments")
    plain_fields = {key: value for key, value in d.items() if key not in special}
    env_data = d.get("environment")
    extra_data = d.get("extra_environments") or []
    return Workflow(
        **plain_fields,
        environment=Environment(**env_data) if env_data else None,
        extra_environments=[Environment(**item) for item in extra_data],
    )
42
+
43
+
44
def _ssh_to_dict(ssh: SSHConfig) -> dict:
    """Turn an SSHConfig into a plain dict with its credentials removed.

    ``password`` is dropped. So are the credential-bearing entries of
    ``extra_connect_kwargs`` (paramiko's ``password`` / ``passphrase`` /
    ``pkey`` connect arguments), because those are another route for a secret
    to reach the config file. Re-supply any of them after ``load_config``.
    """
    secret_kwargs = ("password", "passphrase", "pkey")
    d = asdict(ssh)
    d.pop("password", None)  # never persist secrets to disk
    extra = d.get("extra_connect_kwargs")
    if isinstance(extra, dict):
        # asdict already deep-copied this, so the caller's config is untouched
        d["extra_connect_kwargs"] = {
            k: v for k, v in extra.items() if k not in secret_kwargs
        }
    return d
48
+
49
+
50
def save_config(path: str | Path, *, workflow: Workflow,
                ssh: Optional[SSHConfig] = None) -> Path:
    """Write the workflow (and optional SSH) config to ``path`` as JSON.

    Missing parent directories are created, so e.g.
    ``save_config("configs/run.json", ...)`` works on a fresh checkout.
    Runtime state (``submitted_jobs``) and the SSH password are not written.
    Returns the path written.
    """
    data = {"workflow": _workflow_to_dict(workflow)}
    if ssh is not None:
        data["ssh"] = _ssh_to_dict(ssh)
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data, indent=2), encoding="utf-8")
    return path
59
+
60
+
61
def load_config(path: str | Path) -> tuple[Workflow, Optional[SSHConfig]]:
    """Read a config file back into ``(workflow, ssh)``.

    ``ssh`` is ``None`` when no (or an empty) SSH section was saved. Passwords
    are never stored, so set ``ssh.password`` yourself afterwards if your
    cluster needs one.
    """
    with open(path, encoding="utf-8") as fh:
        payload = json.load(fh)
    workflow = _workflow_from_dict(payload["workflow"])
    ssh_section = payload.get("ssh")
    if not ssh_section:
        return workflow, None
    return workflow, SSHConfig(**ssh_section)
nb2slurm/done.py ADDED
@@ -0,0 +1,65 @@
1
+ """Idempotent 'already done?' bookkeeping, as an object.
2
+
3
+ Generalised from cci.py: a single CSV records which subjects have completed, so
4
+ re-submitting the whole set skips finished work. A file lock makes it safe when
5
+ many jobs write concurrently on a shared filesystem.
6
+
7
+ done = nb2slurm.Done("done/done.csv")
8
+ if done.is_done(key):
9
+ ...
10
+ done.mark(key)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import csv
16
+ from pathlib import Path
17
+
18
+ from filelock import FileLock
19
+
20
# How long (in seconds, filelock's unit) to wait for the ledger lock.
LOCK_TIMEOUT = 60 * 3


class Done:
    """A CSV ledger of completed subjects, with concurrency-safe writes.

    The ledger is a one-column CSV whose first row is the header ``key``. All
    reads and writes happen while holding the sibling ``<csv_file>.lock`` file
    lock, so many jobs can share it on a common filesystem.
    """

    def __init__(self, csv_file: str | Path):
        self.csv_file = Path(csv_file)

    @property
    def _lock(self) -> str:
        """Path of the lock file guarding the ledger."""
        return str(self.csv_file) + ".lock"

    def _read_keys(self) -> set[str]:
        """Return every recorded key. Call with the lock held."""
        with open(self.csv_file, newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            return {row[0] for row in reader if row}

    def is_done(self, key: str) -> bool:
        """Return True if ``key`` (compared as ``str(key)``) is already recorded."""
        if not self.csv_file.exists():
            return False
        with FileLock(self._lock, timeout=LOCK_TIMEOUT):
            with open(self.csv_file, newline="") as f:
                reader = csv.reader(f)
                next(reader, None)  # header
                return any(row and row[0] == str(key) for row in reader)

    def mark(self, key: str) -> None:
        """Record ``key`` as done (no-op if already present).

        The header is created *inside* the lock. If it were created before
        taking the lock, two jobs could both see "no file"; the slower one's
        ``"w"`` open would then truncate a ledger the faster one had already
        appended to, silently losing that key. An existing but empty file also
        gets its header. Otherwise the first key appended would later be
        skipped as if it were the header row.
        """
        self.csv_file.parent.mkdir(parents=True, exist_ok=True)
        with FileLock(self._lock, timeout=LOCK_TIMEOUT):
            if not self.csv_file.exists() or self.csv_file.stat().st_size == 0:
                with open(self.csv_file, "w", newline="") as f:
                    csv.writer(f).writerow(["key"])
            if str(key) not in self._read_keys():
                with open(self.csv_file, "a", newline="") as f:
                    csv.writer(f).writerow([key])
                print(f"Marked {key} as done.")
            else:
                print(f"{key} already recorded as done.")
@@ -0,0 +1,140 @@
1
+ """Create the conda environment + Jupyter kernel the workflow runs in.
2
+
3
+ Most users of nb2slurm are not Linux/conda experts, but the generated SLURM job
4
+ does ``conda activate <env>`` and papermill needs a *registered Jupyter kernel*
5
+ to execute the notebooks. This module writes an ``environment.yml`` and creates
6
+ the environment + kernel on the cluster (or locally), so the user never touches
7
+ the command line.
8
+
9
+ from nb2slurm import Environment
10
+
11
+ env = Environment(
12
+ name="myenv",
13
+ kernel="myenv", # must match Workflow(kernel=...)
14
+ conda_packages=["xarray", "numpy"],
15
+ pip_packages=["nb2slurm", "ewatercycle"],
16
+ )
17
+ env.write() # -> environment.yml
18
+ env.create(ssh=cfg) # build env + register kernel on the HPC
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from dataclasses import dataclass, field
24
+ from pathlib import Path
25
+ from typing import Optional
26
+
27
+ from .ssh import CommandResult, SSHConfig, run_shell
28
+
29
+
30
@dataclass
class Environment:
    """A conda environment plus the Jupyter kernel registered from it.

    ``name`` is the conda env the generated job activates. ``kernel`` is the
    Jupyter kernel registered from it, which papermill uses to run the
    notebooks; it must match ``Workflow(kernel=...)``. Each command method
    runs on the HPC when an ``SSHConfig`` is passed, otherwise locally.
    """

    name: str
    kernel: str
    python: str = "3.11"
    channels: list[str] = field(default_factory=lambda: ["conda-forge"])
    conda_packages: list[str] = field(default_factory=list)
    pip_packages: list[str] = field(default_factory=lambda: ["nb2slurm"])

    def to_yaml(self) -> str:
        """Render an ``environment.yml`` for conda/mamba.

        pip requirements are nested *under* the ``- pip:`` entry. If they sat
        at the same level, conda would read them as conda packages and ``pip:``
        as an empty mapping, so pip-only packages would fail to resolve.
        """
        lines = [f"name: {self.name}", "channels:"]
        lines += [f"  - {c}" for c in self.channels]
        lines.append("dependencies:")
        lines.append(f"  - python={self.python}")
        lines.append("  - pip")
        lines.append("  - ipykernel")  # required so papermill can run the notebooks
        lines += [f"  - {p}" for p in self.conda_packages]
        if self.pip_packages:
            lines.append("  - pip:")
            lines += [f"      - {p}" for p in self.pip_packages]
        return "\n".join(lines) + "\n"

    def write(self, project_dir: str | Path = ".", filename: str = "environment.yml") -> Path:
        """Write the ``environment.yml`` into ``project_dir`` and return its path."""
        path = Path(project_dir) / filename
        path.write_text(self.to_yaml(), encoding="utf-8")
        return path

    def _exists_test(self) -> str:
        """A shell test (exit 0 = env exists) over ``conda env list``.

        The name must match as a whole path component, so ``montecarlo``
        doesn't match ``montecarlo2``. NOTE(review): the name is inserted into
        an extended regex unescaped, so regex metacharacters in it (e.g. ``.``)
        are not matched literally.
        """
        return f'conda env list | grep -qE "[/ ]{self.name}([ /]|$)"'

    def _create_command(self, filename: str = "environment.yml") -> str:
        """Shell script that creates (or updates) the env and registers the kernel."""
        # mamba is much faster than conda; use it when available.
        return (
            "set -e; "
            # There's no TTY over ssh, so any interactive prompt would hang the
            # build forever. Belt and suspenders: set always-yes AND pipe `yes`
            # into conda (CONDA_ALWAYS_YES alone is ignored by some mamba builds
            # for the 'Confirm changes? [Y/n]' transaction prompt).
            "export CONDA_ALWAYS_YES=yes; "
            "CONDA=conda; command -v mamba >/dev/null 2>&1 && CONDA=mamba; "
            # update an existing env in place rather than hitting the interactive
            # 'Found conda-prefix ... Overwrite? [y/N]' prompt.
            f"if {self._exists_test()}; then "
            f' echo "Updating existing environment {self.name} with $CONDA..."; '
            f" yes | $CONDA env update -f {filename}; "
            f"else "
            f' echo "Creating environment {self.name} with $CONDA..."; '
            f" yes | $CONDA env create -f {filename}; "
            f"fi; "
            f"conda run -n {self.name} python -m ipykernel install --user "
            f'--name {self.kernel} --display-name "{self.kernel}"; '
            f'echo "Environment {self.name} ready; kernel {self.kernel} registered."'
        )

    def _remove_command(self) -> str:
        """Shell script that deletes the env and its kernelspec, tolerating absence."""
        return (
            "export CONDA_ALWAYS_YES=yes; "
            f'echo "Removing environment {self.name} and kernel {self.kernel}..."; '
            # `|| true`: removing a non-existent env/kernel is not an error here
            f"conda env remove -n {self.name} || true; "
            f"jupyter kernelspec remove -f {self.kernel} 2>/dev/null || true; "
            f'echo "Removed {self.name}."'
        )

    def exists(
        self,
        ssh: Optional[SSHConfig] = None,
        project_dir: str | Path = ".",
    ) -> bool:
        """Return True if the conda env already exists — on the HPC (ssh) or locally."""
        return run_shell(self._exists_test(), ssh, str(project_dir)).exit_status == 0

    def remove(
        self,
        ssh: Optional[SSHConfig] = None,
        project_dir: str | Path = ".",
        stream: bool = True,
    ) -> CommandResult:
        """Delete the conda env and its Jupyter kernel — on the HPC (ssh) or locally.

        Safe to call when nothing is there yet (a missing env/kernel is ignored).
        Use it to recover from a half-built env or to force a clean rebuild.
        Raises ``RuntimeError`` (via ``CommandResult.check``) on failure.
        """
        return run_shell(self._remove_command(), ssh, str(project_dir), stream=stream).check()

    def create(
        self,
        ssh: Optional[SSHConfig] = None,
        project_dir: str | Path = ".",
        filename: str = "environment.yml",
        stream: bool = True,
        overwrite: bool = False,
    ) -> CommandResult:
        """Create the env and register the kernel — on the HPC (ssh) or locally.

        Writes ``environment.yml`` into the local ``project_dir`` first if it
        is missing. If the env already exists it is *updated* in place; pass
        ``overwrite=True`` to delete and rebuild it from scratch.
        ``stream=True`` (the default) echoes conda/pip output live, since a
        solve + downloads can take minutes and would otherwise look like a
        hang. Raises ``RuntimeError`` (via ``CommandResult.check``) on failure.
        """
        if not (Path(project_dir) / filename).exists():
            self.write(project_dir, filename)
        if overwrite:
            self.remove(ssh=ssh, project_dir=project_dir, stream=stream)
        command = self._create_command(filename)
        return run_shell(command, ssh, str(project_dir), stream=stream).check()
nb2slurm/render.py ADDED
@@ -0,0 +1,48 @@
1
+ """Render the bundled Jinja2 templates into a project's scripts/ directory."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from importlib import resources
6
+ from pathlib import Path
7
+ from typing import Any, Mapping
8
+
9
+ from jinja2 import Environment
10
+
11
+
12
def _env() -> Environment:
    """Build the Jinja2 environment used to render every bundled template.

    Autoescaping stays off because the output is shell/Python, not HTML. The
    comment delimiters are moved off Jinja's default ``{# #}`` so bash
    parameter expansions such as ``${#arr[@]}`` are not taken as comments.
    """
    options = {
        "trim_blocks": True,
        "lstrip_blocks": True,
        "keep_trailing_newline": True,
        "comment_start_string": "<#nb2slurm#",
        "comment_end_string": "#nb2slurm#>",
    }
    return Environment(**options)
23
+
24
+
25
def render_template(template_name: str, context: Mapping[str, Any]) -> str:
    """Render the bundled template ``template_name`` with ``context`` to a string.

    Templates are read as UTF-8 from the ``nb2slurm.templates`` package data
    (e.g. ``'run_workflow.py.j2'``). The rendered text is returned, not written.
    """
    package_dir = resources.files("nb2slurm.templates")
    template_file = package_dir.joinpath(template_name)
    template_source = template_file.read_text(encoding="utf-8")
    template = _env().from_string(template_source)
    return template.render(**context)
33
+
34
+
35
def write_rendered(
    template_name: str,
    out_path: str | Path,
    context: Mapping[str, Any],
    executable: bool = False,
) -> Path:
    """Render a template and write it to ``out_path``; return that path.

    Parent directories are created as needed. The file is always written as
    UTF-8 with LF line endings. The outputs are scripts run on the cluster,
    and the platform default in text mode on Windows (CRLF) leaves a trailing
    carriage return on every line, which breaks bash. With
    ``executable=True`` the execute bits are added for user, group and other.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    text = render_template(template_name, context)
    # newline="\n" disables newline translation so no CRLF is ever written
    with open(out_path, "w", encoding="utf-8", newline="\n") as f:
        f.write(text)
    if executable:
        out_path.chmod(out_path.stat().st_mode | 0o111)
    return out_path
nb2slurm/runtime.py ADDED
@@ -0,0 +1,43 @@
1
+ """Detect whether the current notebook is running under nb2slurm on a cluster.
2
+
3
+ Checking ``Path.home()`` for a username is fragile (breaks for other users, can't
4
+ tell a cloud VM from a laptop). Instead we look at environment variables that are
5
+ only present in a batch job:
6
+
7
+ * the ``SLURM_*`` variables SLURM sets in every job, and
8
+ * the ``NB2SLURM`` sentinel that the generated ``job.slurm`` exports (so this also
9
+ works on non-SLURM batch setups that run nb2slurm's job script).
10
+
11
+ Use it in your notebooks for the things that genuinely differ between
12
+ interactive and batch runs — machine-specific data paths, skipping ``!pip install``:
13
+
14
+ import nb2slurm
15
+ if nb2slurm.on_hpc():
16
+ data_dir = "/project/ewater/Data"
17
+ else:
18
+ data_dir = "/data/shared"
19
+
20
+ It also cleans up importing a helper from ``scripts/``. On the cluster the job
21
+ runs from the project root, so ``from scripts.foo import bar`` just works; run
22
+ interactively from a ``notebooks/`` subfolder it doesn't, so add the project root
23
+ to the path only when running locally:
24
+
25
+ import sys
26
+ from pathlib import Path
27
+ import nb2slurm
28
+ if not nb2slurm.on_hpc():
29
+ sys.path.append(str(Path().resolve().parent))
30
+ from scripts.montecarlo import estimate_pi
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ import os
36
+
37
# Environment variables that indicate a batch job: SLURM sets the SLURM_JOB*
# ones in every job, and the generated job.slurm exports NB2SLURM.
_BATCH_ENV_VARS = ("NB2SLURM", "SLURM_JOB_ID", "SLURM_JOBID")


def on_hpc() -> bool:
    """Return True if running inside an nb2slurm/SLURM batch job.

    The check is purely whether any name in ``_BATCH_ENV_VARS`` is set in the
    process environment; the value (even an empty string) is not inspected.
    """
    present = os.environ.keys()
    return not present.isdisjoint(_BATCH_ENV_VARS)
nb2slurm/settings.py ADDED
@@ -0,0 +1,46 @@
1
+ """The per-run settings file, as an object.
2
+
3
+ The paper's notebook structure has the *first* notebook write a settings file
4
+ (JSON) that every later notebook reads. This keeps the varying parameters in one
5
+ place: only notebook 0 is parameterised by papermill, the rest just load it.
6
+
7
+ import nb2slurm
8
+
9
+ # first notebook (parameterised by nb2slurm):
10
+ nb2slurm.Settings.write(outdir, {"region_id": region_id, "outdir": outdir})
11
+
12
+ # every later notebook:
13
+ settings = nb2slurm.Settings.load(settings_path)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ from pathlib import Path
20
+ from typing import Any, Mapping
21
+
22
+
23
class Settings:
    """Read/write the per-run ``settings.json``.

    Both directions use UTF-8 explicitly. A settings file written on one
    machine (or edited by hand) then reads the same everywhere, regardless of
    the local default encoding.
    """

    filename = "settings.json"

    @staticmethod
    def write(outdir: str | Path, settings: Mapping[str, Any]) -> Path:
        """Write ``settings`` to ``<outdir>/settings.json`` and return the path.

        Call this from your first notebook. ``outdir`` is supplied to that
        notebook by the generated runner, so you never hardcode a path. The
        directory is created if needed.
        """
        outdir = Path(outdir)
        outdir.mkdir(parents=True, exist_ok=True)
        path = outdir / Settings.filename
        with open(path, "w", encoding="utf-8") as f:
            json.dump(dict(settings), f, indent=2)
        return path

    @staticmethod
    def load(settings_path: str | Path) -> dict[str, Any]:
        """Load a settings file. Call this at the top of every later notebook."""
        with open(settings_path, encoding="utf-8") as f:
            return json.load(f)
nb2slurm/ssh.py ADDED
@@ -0,0 +1,211 @@
1
+ """Minimal SSH transport so the managing notebook can drive SLURM with no CLI.
2
+
3
+ This mimics the command line / ssh that the paper says should be hidden from the
4
+ user: sbatch/squeue/scancel run on the cluster, but the user only writes Python.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import subprocess
11
+ import time
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+ from typing import Optional, Tuple
15
+
16
+
17
@dataclass
class CommandResult:
    """The outcome of one shell command, run locally or over SSH."""

    command: str  # the command line as executed (over SSH: with its ``cd`` prefix)
    exit_status: int
    stdout: str
    stderr: str

    def check(self) -> "CommandResult":
        """Return ``self`` if the command exited 0, else raise ``RuntimeError``.

        The error message ends with the command's stderr. When stderr is empty
        (always the case for a local ``run_shell(..., stream=True)``, which
        merges stderr into stdout), the last lines of stdout are shown instead
        so the failure is still diagnosable.
        """
        if self.exit_status == 0:
            return self
        detail = self.stderr
        if not detail.strip():
            # only the tail: conda/pip output can run to thousands of lines
            detail = "\n".join(self.stdout.splitlines()[-20:])
        raise RuntimeError(
            f"Remote command failed ({self.exit_status}): {self.command}\n{detail}"
        )
30
+
31
+
32
@dataclass
class SSHConfig:
    """Connection details for the HPC login node.

    Provide either ``key_filename`` or ``password`` (or rely on an agent/known
    config). ``remote_dir`` is the project directory on the cluster that the
    generated scripts live in; commands are run from there.
    """

    host: str
    user: str
    remote_dir: str
    port: int = 22
    key_filename: Optional[str] = None
    password: Optional[str] = None
    # passed as extra keyword arguments to paramiko's SSHClient.connect() in run()
    extra_connect_kwargs: dict = field(default_factory=dict)

    def key_path(self) -> Optional[str]:
        """The private key path with ``~`` expanded, or ``None`` if unset.

        paramiko opens ``key_filename`` directly and does **not** expand ``~``,
        so we resolve it here (e.g. ``~/.ssh/id_rsa`` -> the absolute path).
        """
        return os.path.expanduser(self.key_filename) if self.key_filename else None

    def rsync_ssh(self) -> str:
        """The ``-e`` transport string rsync should use (ssh + port + key).

        ``-p`` is only added for a non-default port, ``-i`` only when a key is set.
        """
        parts = ["ssh"]
        if self.port != 22:
            parts += ["-p", str(self.port)]
        if self.key_filename:
            parts += ["-i", self.key_path()]
        return " ".join(parts)

    def rsync_target(self, subpath: str = "") -> str:
        """A ``user@host:remote_dir/<subpath>`` spec for rsync.

        Trailing slashes on ``remote_dir`` are normalised; without ``subpath``
        the spec ends in ``/``.
        """
        base = self.remote_dir.rstrip("/")
        return f"{self.user}@{self.host}:{base}/{subpath}" if subpath else f"{self.user}@{self.host}:{base}/"

    def run(self, command: str, cwd: Optional[str] = None,
            stream: bool = False) -> CommandResult:
        """Run a single command on the cluster and return its result.

        The command runs after ``cd`` into ``cwd`` (default: ``remote_dir``).
        Output is drained continuously while the command runs, so a chatty
        command (``conda env create``, ``pip install``) can't fill paramiko's
        channel window and deadlock against ``recv_exit_status``. Pass
        ``stream=True`` to also echo output live — useful for long-running
        builds where you'd otherwise see nothing until they finish.

        Output is decoded as UTF-8 with one incremental decoder per stream. A
        multi-byte character split across two ``recv`` chunks is therefore
        reassembled instead of becoming two replacement characters; genuinely
        invalid bytes are still replaced rather than raising.
        """
        import codecs

        import paramiko  # imported lazily so the package imports without a cluster

        cwd = cwd or self.remote_dir
        wrapped = f"cd {cwd} && {command}" if cwd else command

        out_decoder = codecs.getincrementaldecoder("utf-8")("replace")
        err_decoder = codecs.getincrementaldecoder("utf-8")("replace")
        out_parts: list[str] = []
        err_parts: list[str] = []

        def _emit(text: str, parts: list[str]) -> None:
            # record (and optionally echo) decoded text; a chunk ending mid-
            # character can decode to "" until the rest of it arrives
            if text:
                parts.append(text)
                if stream:
                    print(text, end="", flush=True)

        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(
                hostname=self.host,
                port=self.port,
                username=self.user,
                key_filename=self.key_path(),
                password=self.password,
                **self.extra_connect_kwargs,
            )
            chan = client.get_transport().open_session()
            chan.exec_command(wrapped)

            def _drain() -> bool:
                got = False
                while chan.recv_ready():
                    _emit(out_decoder.decode(chan.recv(32768)), out_parts)
                    got = True
                while chan.recv_stderr_ready():
                    _emit(err_decoder.decode(chan.recv_stderr(32768)), err_parts)
                    got = True
                return got

            # keep reading so the remote side never blocks on a full window
            while not chan.exit_status_ready():
                if not _drain():
                    time.sleep(0.05)
            while _drain():  # whatever is left after exit
                pass
            status = chan.recv_exit_status()
        finally:
            client.close()
        # flush any incomplete trailing byte sequence (becomes U+FFFD)
        _emit(out_decoder.decode(b"", final=True), out_parts)
        _emit(err_decoder.decode(b"", final=True), err_parts)
        return CommandResult(wrapped, status, "".join(out_parts), "".join(err_parts))
129
+
130
+
131
def _pub_path(path: str) -> Path:
    """Map a key path to its public-key file.

    ``~`` is expanded. A path already ending in ``.pub`` is returned unchanged;
    anything else gets ``.pub`` appended (``id_rsa`` -> ``id_rsa.pub``).
    """
    key_file = Path(os.path.expanduser(path))
    if key_file.suffix == ".pub":
        return key_file
    return Path(f"{key_file}.pub")
135
+
136
+
137
def public_key(path: str = "~/.ssh/id_rsa") -> str:
    """Return the public key line for ``path``, read from ``<path>.pub``.

    ``path`` may be the private key or the ``.pub`` file itself. This is the
    text you paste into your HPC: ``print(nb2slurm.public_key())``, then copy
    it into the cluster's key-upload page (or ``~/.ssh/authorized_keys`` on a
    login node, if your HPC lets you edit it directly).
    """
    contents = _pub_path(path).read_text()
    return contents.strip()
145
+
146
+
147
def generate_key(path: str = "~/.ssh/id_rsa", bits: int = 4096,
                 comment: Optional[str] = None, overwrite: bool = False,
                 show: bool = True) -> Tuple[Path, Path]:
    """Create an RSA SSH keypair at ``path`` (+ ``<path>.pub``).

    Returns ``(private_path, public_path)``. The private key is chmod-ed 0600
    and the public key is written in ``authorized_keys`` format. An existing
    key is left alone unless ``overwrite=True``, so this is safe to call
    repeatedly. Point your ``SSHConfig(key_filename=...)`` at the private path.

    As with ``public_key``, ``path`` may name either the private key or its
    ``.pub`` file. The private key always goes to the path *without* ``.pub``,
    so passing the public path never writes private key material into it.

    nb2slurm can't install the key for you — many HPCs disable password login, so
    there's no way in. With ``show=True`` (default) the public key is printed so
    you can copy it into your cluster's key-upload page (or its
    ``~/.ssh/authorized_keys``); ``nb2slurm.public_key(path)`` reprints it later.
    If a private key already exists but its ``.pub`` file does not, a note is
    printed instead of raising.
    """
    import paramiko  # lazy: keep the package importable without a crypto backend

    pub = _pub_path(path)
    priv = pub.with_suffix("")  # drop ".pub" so both spellings of path agree
    if priv.exists() and not overwrite:
        if show:
            if pub.exists():
                print(f"key already exists at {priv}; its public key is:\n\n{public_key(path)}")
            else:
                print(f"key already exists at {priv}, but its public key file {pub} is missing.")
        return priv, pub
    priv.parent.mkdir(parents=True, exist_ok=True)

    key = paramiko.RSAKey.generate(bits)
    key.write_private_key_file(str(priv))
    pub.write_text(f"ssh-rsa {key.get_base64()} {comment or ''}".strip() + "\n")
    for p, mode in ((priv, 0o600), (pub, 0o644)):
        try:
            os.chmod(p, mode)
        except OSError:
            pass  # Windows without POSIX perms; OpenSSH there enforces via ACLs
    if show:
        print(
            f"created SSH key: {priv} (private) and {pub} (public)\n\n"
            "Add the PUBLIC key below to your HPC - via its key-upload page, or by\n"
            "appending it to ~/.ssh/authorized_keys on a login node. nb2slurm can't\n"
            "do this step for you (clusters usually disable password login):\n\n"
            f"{public_key(path)}"
        )
    return priv, pub
189
+
190
+
191
def run_shell(command: str, ssh: Optional[SSHConfig] = None,
              cwd: str = ".", stream: bool = False) -> CommandResult:
    """Run a shell command on the cluster (via ``ssh``) or locally (subprocess).

    Shared by Workflow and Environment so the ssh-vs-local branch lives in one
    place. With ``ssh`` the command runs in ``ssh.remote_dir`` and ``cwd`` is
    ignored; locally it runs in ``cwd``. ``stream=True`` echoes output live
    (for long-running commands); in local streaming mode stderr is merged into
    stdout, so the result's ``stderr`` is ``""``. Output bytes that cannot be
    decoded are replaced rather than raising ``UnicodeDecodeError``.
    """
    if ssh is not None:
        return ssh.run(command, stream=stream)
    if stream:
        parts: list[str] = []
        # the context manager closes the pipe and waits for the process
        with subprocess.Popen(command, shell=True, cwd=str(cwd), text=True,
                              errors="replace", stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT) as proc:
            for line in proc.stdout:  # tee: capture and echo
                parts.append(line)
                print(line, end="", flush=True)
        return CommandResult(command, proc.returncode, "".join(parts), "")
    proc = subprocess.run(command, shell=True, cwd=str(cwd),
                          capture_output=True, text=True, errors="replace")
    return CommandResult(command, proc.returncode, proc.stdout, proc.stderr)
+ return CommandResult(command, proc.returncode, proc.stdout, proc.stderr)