clusterpilot 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clusterpilot/__init__.py +3 -0
- clusterpilot/__main__.py +139 -0
- clusterpilot/cluster/__init__.py +33 -0
- clusterpilot/cluster/probe.py +230 -0
- clusterpilot/cluster/slurm.py +174 -0
- clusterpilot/config.py +197 -0
- clusterpilot/db.py +286 -0
- clusterpilot/jobs/__init__.py +4 -0
- clusterpilot/jobs/ai_gen.py +375 -0
- clusterpilot/jobs/daemon.py +328 -0
- clusterpilot/jobs/env_detect.py +170 -0
- clusterpilot/notify/__init__.py +19 -0
- clusterpilot/notify/ntfy.py +131 -0
- clusterpilot/ssh/__init__.py +15 -0
- clusterpilot/ssh/connection.py +122 -0
- clusterpilot/ssh/rsync.py +167 -0
- clusterpilot/tui/__init__.py +3 -0
- clusterpilot/tui/app.py +679 -0
- clusterpilot/tui/config_view.py +87 -0
- clusterpilot/tui/jobs.py +369 -0
- clusterpilot/tui/submit.py +692 -0
- clusterpilot/tui/widgets/__init__.py +0 -0
- clusterpilot/tui/widgets/file_explorer.py +155 -0
- clusterpilot-0.1.0.dist-info/METADATA +367 -0
- clusterpilot-0.1.0.dist-info/RECORD +28 -0
- clusterpilot-0.1.0.dist-info/WHEEL +4 -0
- clusterpilot-0.1.0.dist-info/entry_points.txt +2 -0
- clusterpilot-0.1.0.dist-info/licenses/LICENSE +21 -0
clusterpilot/__init__.py
ADDED
clusterpilot/__main__.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""ClusterPilot entry point.
|
|
2
|
+
|
|
3
|
+
Usage
|
|
4
|
+
-----
|
|
5
|
+
clusterpilot # launch TUI
|
|
6
|
+
clusterpilot init # create starter ~/.config/clusterpilot/config.toml
|
|
7
|
+
clusterpilot daemon run # run poll daemon in foreground (no TUI)
|
|
8
|
+
clusterpilot daemon install # install systemd user service
|
|
9
|
+
"""
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import sys
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main() -> None:
    """Parse command-line arguments and dispatch to the matching subcommand.

    With no subcommand the interactive TUI is launched. ``daemon`` without a
    nested action prints the daemon sub-parser's help text.
    """
    parser = argparse.ArgumentParser(
        prog="clusterpilot",
        description="AI-assisted HPC workflow manager",
    )
    commands = parser.add_subparsers(dest="cmd")
    commands.add_parser(
        "init",
        help="Create starter config at ~/.config/clusterpilot/config.toml",
    )

    daemon_parser = commands.add_parser("daemon", help="Daemon management")
    daemon_actions = daemon_parser.add_subparsers(dest="daemon_cmd")
    daemon_actions.add_parser("run", help="Run poll daemon in foreground")
    daemon_actions.add_parser("install", help="Install systemd user service unit")

    args = parser.parse_args()
    command = args.cmd

    if command == "init":
        _cmd_init()
        return

    if command != "daemon":
        # No subcommand given: default to the interactive TUI.
        _cmd_tui()
        return

    daemon_handlers = {
        "run": _cmd_daemon_run,
        "install": _cmd_daemon_install,
    }
    handler = daemon_handlers.get(args.daemon_cmd)
    if handler is None:
        daemon_parser.print_help()
    else:
        handler()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ── Subcommands ───────────────────────────────────────────────────────────────
|
|
46
|
+
|
|
47
|
+
def _cmd_init() -> None:
    """Write the starter config file unless one is already present."""
    from clusterpilot.config import CONFIG_PATH, write_default_config

    if not CONFIG_PATH.exists():
        write_default_config()
        print(f"Config written to {CONFIG_PATH}")
        print("Edit it to set your cluster username and account, then run: clusterpilot")
    else:
        print(f"Config already exists: {CONFIG_PATH}")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _cmd_daemon_run() -> None:
    """Run the poll daemon in the foreground until interrupted with Ctrl-C.

    Exits with status 1 (message on stderr) if the config cannot be loaded.
    The database schema is initialised with ``init_db`` before the daemon
    starts polling.
    """
    import asyncio

    import aiosqlite

    from clusterpilot.config import ConfigError, load_config
    from clusterpilot.db import DB_PATH, init_db
    from clusterpilot.jobs.daemon import PollDaemon

    try:
        config = load_config()
    except ConfigError as err:
        print(f"Error: {err}", file=sys.stderr)
        sys.exit(1)

    async def _serve() -> None:
        # The database directory may not exist yet on a fresh install.
        DB_PATH.parent.mkdir(parents=True, exist_ok=True)
        async with aiosqlite.connect(DB_PATH) as conn:
            await init_db(conn)
            await PollDaemon(config, DB_PATH).run_forever()

    print("ClusterPilot daemon running. Press Ctrl-C to stop.")
    try:
        asyncio.run(_serve())
    except KeyboardInterrupt:
        print("\nDaemon stopped.")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _cmd_daemon_install() -> None:
    """Write the systemd user unit file and print how to activate it."""
    from clusterpilot.jobs.daemon import write_service_file

    unit_path = write_service_file()
    instructions = (
        f"Service file written to: {unit_path}",
        "",
        "Enable and start with:",
        " systemctl --user daemon-reload",
        " systemctl --user enable --now clusterpilot-poll.service",
    )
    for line in instructions:
        print(line)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _cmd_tui() -> None:
    """Launch the interactive TUI, bootstrapping a starter config on first run.

    Exit paths before the TUI starts:
      * no config file on disk -> write the starter config, explain the next
        steps, exit 0;
      * config file exists but ``load_config`` raises ConfigError -> report
        the error on stderr and exit 1, leaving the user's file untouched;
      * config defines no clusters -> exit 1.
    A missing API key (neither ``api_key`` in the config nor the
    ``ANTHROPIC_API_KEY`` environment variable) only prints a warning.
    """
    import logging
    import os

    from clusterpilot.config import (
        CONFIG_PATH,
        ConfigError,
        load_config,
        write_default_config,
    )
    from clusterpilot.tui.app import ClusterPilotApp

    logging.basicConfig(
        level=logging.WARNING,
        format="%(levelname)s %(name)s: %(message)s",
    )

    try:
        config = load_config()
    except ConfigError as exc:
        if CONFIG_PATH.exists():
            # The file is present but invalid. Regenerating it here would
            # replace the user's edits with the template, so report instead.
            print(f"Error: {exc}", file=sys.stderr)
            print(f"Fix {CONFIG_PATH} and run 'clusterpilot' again.", file=sys.stderr)
            sys.exit(1)
        # First run — write default config and guide the user.
        write_default_config()
        print("Welcome to ClusterPilot!")
        print()
        print(f"A starter config has been written to:\n {CONFIG_PATH}")
        print()
        print("Edit it to add your cluster username and account,")
        print("then run 'clusterpilot' again.")
        sys.exit(0)

    if not config.clusters:
        print(
            "No clusters defined in config. Edit ~/.config/clusterpilot/config.toml.",
            file=sys.stderr,
        )
        sys.exit(1)

    if not config.api_key and not os.environ.get("ANTHROPIC_API_KEY"):
        print(
            "Warning: no API key configured. "
            "Script generation will fail.\n"
            "Set api_key in config.toml or export ANTHROPIC_API_KEY."
        )

    from clusterpilot.db import DB_PATH

    # Make sure the directory for DB_PATH exists before the app starts.
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)

    app = ClusterPilotApp(config)
    app.run()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from clusterpilot.cluster.probe import (
|
|
2
|
+
ClusterProbe,
|
|
3
|
+
PartitionInfo,
|
|
4
|
+
load_cache,
|
|
5
|
+
probe_cluster,
|
|
6
|
+
save_cache,
|
|
7
|
+
)
|
|
8
|
+
from clusterpilot.cluster.slurm import (
|
|
9
|
+
TERMINAL_STATES,
|
|
10
|
+
SlurmError,
|
|
11
|
+
cancel,
|
|
12
|
+
find_log,
|
|
13
|
+
job_status,
|
|
14
|
+
submit,
|
|
15
|
+
tail_log,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
# probe
|
|
20
|
+
"ClusterProbe",
|
|
21
|
+
"PartitionInfo",
|
|
22
|
+
"load_cache",
|
|
23
|
+
"probe_cluster",
|
|
24
|
+
"save_cache",
|
|
25
|
+
# slurm
|
|
26
|
+
"TERMINAL_STATES",
|
|
27
|
+
"SlurmError",
|
|
28
|
+
"cancel",
|
|
29
|
+
"find_log",
|
|
30
|
+
"job_status",
|
|
31
|
+
"submit",
|
|
32
|
+
"tail_log",
|
|
33
|
+
]
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""Cluster probe: query sinfo, module avail, and sacctmgr; cache 24h.
|
|
2
|
+
|
|
3
|
+
Results are stored in ~/.cache/clusterpilot/<cluster_name>/probe.json and
|
|
4
|
+
returned from cache on subsequent calls until the TTL expires or force=True.
|
|
5
|
+
|
|
6
|
+
Parsed output is based on confirmed Grex (yak.hpc.umanitoba.ca) format.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import json
|
|
12
|
+
import time
|
|
13
|
+
from dataclasses import asdict, dataclass, field
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
from clusterpilot.ssh.connection import run_remote
|
|
17
|
+
|
|
18
|
+
_CACHE_ROOT = Path.home() / ".cache" / "clusterpilot"
|
|
19
|
+
_CACHE_TTL = 24 * 3600 # seconds
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ── Data classes ──────────────────────────────────────────────────────────────
|
|
23
|
+
|
|
24
|
+
@dataclass
class PartitionInfo:
    """One SLURM partition as parsed from ``sinfo`` output.

    Attributes:
        name: Partition name with any trailing default marker ``*`` removed.
        max_time: Maximum walltime string, e.g. ``"7-00:00:00"``.
        gres: Generic resources such as ``"gpu:v100:4"``; ``""`` for CPU-only.
        nodes: Node count reported for the partition.
        is_default: True when the partition is the cluster default.
    """

    name: str
    max_time: str  # e.g. "7-00:00:00" or "21-00:00:00"
    gres: str  # e.g. "gpu:v100:4"; "" when the partition has no GRES
    nodes: int
    is_default: bool  # sinfo marks the default partition with a trailing "*"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class ClusterProbe:
    """Snapshot of a cluster's partitions, software modules and accounts.

    Serialised to the JSON cache with ``dataclasses.asdict`` and rebuilt from
    it field by field.
    """

    cluster_name: str
    probed_at: float  # Unix timestamp of the probe
    partitions: list[PartitionInfo]
    julia_versions: list[str]  # e.g. ["julia/1.10.3", "julia/1.11.3"]
    accounts: list[str]  # e.g. ["def-stamps"]
    account_max_wall: dict[str, str]  # account -> max walltime; "" = no limit
    python_versions: list[str] = field(default_factory=list)  # e.g. ["python/3.11.5"]

    def gpu_partitions(self) -> list[PartitionInfo]:
        """Return the partitions whose GRES string starts with ``gpu:``."""
        return list(filter(lambda part: part.gres.startswith("gpu:"), self.partitions))

    def cpu_partitions(self) -> list[PartitionInfo]:
        """Return the partitions that declare no GRES at all."""
        return [part for part in self.partitions if not part.gres]

    def default_partition(self) -> PartitionInfo | None:
        """Return the first partition flagged as default, or None if none is."""
        return next((part for part in self.partitions if part.is_default), None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ── Public API ────────────────────────────────────────────────────────────────
|
|
59
|
+
|
|
60
|
+
def load_cache(cluster_name: str) -> ClusterProbe | None:
    """Return the cached probe for *cluster_name* if it is fresh, else None.

    An entry is fresh while it is younger than ``_CACHE_TTL`` seconds
    (24 hours). A missing, unreadable, corrupt or schema-incompatible cache
    file is treated as a cache miss, so the caller simply re-probes.
    """
    path = _cache_path(cluster_name)
    try:
        data = json.loads(path.read_text())
        # A stale entry is a miss; the caller re-probes and overwrites it.
        if time.time() - data["probed_at"] > _CACHE_TTL:
            return None
        return _from_dict(data)
    except (OSError, KeyError, TypeError, ValueError):
        # OSError: file missing or unreadable. ValueError: invalid JSON/text.
        # KeyError/TypeError: JSON from an incompatible version (missing
        # fields, unexpected PartitionInfo keys, wrong value types).
        return None
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def save_cache(probe: ClusterProbe) -> None:
    """Write probe data to ~/.cache/clusterpilot/<cluster>/probe.json.

    The JSON goes to a sibling ``.tmp`` file first and is then renamed over
    the target. A crash mid-write therefore never leaves a truncated cache
    file in place.
    """
    path = _cache_path(probe.cluster_name)
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(asdict(probe), indent=2))
    # Path.replace uses os.replace, an atomic rename on POSIX within one filesystem.
    tmp_path.replace(path)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
async def probe_cluster(
    cluster_name: str,
    host: str,
    user: str,
    *,
    force: bool = False,
) -> ClusterProbe:
    """Collect partition, module and account information for a cluster.

    Unless *force* is set, a cached probe younger than 24 hours is returned
    without contacting the cluster. Otherwise sinfo, ``module avail`` (julia
    and python) and sacctmgr are run on *host*. Their output is parsed,
    written to the cache, and returned.

    Requires an active SSH ControlMaster socket (call open_connection first).
    """
    cached = None if force else load_cache(cluster_name)
    if cached is not None:
        return cached

    raw_sinfo, raw_julia, raw_python, raw_sacctmgr = await _fetch_all(host, user)
    fresh = ClusterProbe(
        cluster_name=cluster_name,
        probed_at=time.time(),
        partitions=_parse_sinfo(raw_sinfo),
        julia_versions=_parse_julia_modules(raw_julia),
        python_versions=_parse_python_modules(raw_python),
        accounts=_parse_accounts(raw_sacctmgr),
        account_max_wall=_parse_max_wall(raw_sacctmgr),
    )
    save_cache(fresh)
    return fresh
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# ── Remote fetching ───────────────────────────────────────────────────────────
|
|
116
|
+
|
|
117
|
+
async def _fetch_all(host: str, user: str) -> tuple[str, str, str, str]:
    """Run all four probe commands concurrently on *host*.

    Returns:
        ``(sinfo, module avail julia, module avail python, sacctmgr)`` output,
        in that order.
    """
    import shlex

    # The username is interpolated into a remote shell command, so quote it.
    # For ordinary usernames shlex.quote returns the name unchanged.
    sacctmgr_cmd = (
        f"sacctmgr show user {shlex.quote(user)} withassoc "
        "format=account,maxjobs,maxwall -p --noheader"
    )
    sinfo_out, julia_out, python_out, sacctmgr_out = await asyncio.gather(
        run_remote(host, user, "sinfo -o '%P %l %G %D' --noheader"),
        run_remote(host, user, "module avail julia 2>&1"),
        run_remote(host, user, "module avail python 2>&1"),
        run_remote(host, user, sacctmgr_cmd),
    )
    return sinfo_out, julia_out, python_out, sacctmgr_out
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# ── Parsers ───────────────────────────────────────────────────────────────────
|
|
132
|
+
|
|
133
|
+
def _parse_sinfo(output: str) -> list[PartitionInfo]:
    """Parse `sinfo -o '%P %l %G %D' --noheader` output into partitions.

    Example line: "stamps 21-00:00:00 gpu:v100:4(S:0-1) 3"

    Lines with fewer than four whitespace-separated fields are skipped. A
    trailing "*" on the name marks the default partition. A GRES of "(null)"
    becomes "", and any "(...)" suffix such as socket affinity is dropped.
    A non-integer node count becomes 0.
    """
    def _node_count(text: str) -> int:
        try:
            return int(text)
        except ValueError:
            return 0

    result: list[PartitionInfo] = []
    for raw in output.splitlines():
        fields = raw.split()
        if len(fields) >= 4:
            partition, walltime, gres_field, node_field = fields[:4]
            result.append(
                PartitionInfo(
                    name=partition.rstrip("*"),
                    max_time=walltime,
                    gres="" if gres_field == "(null)" else gres_field.split("(")[0],
                    nodes=_node_count(node_field),
                    is_default=partition.endswith("*"),
                )
            )
    return result
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _parse_julia_modules(output: str) -> list[str]:
|
|
163
|
+
"""Extract julia/X.Y.Z tokens from `module avail julia 2>&1` output.
|
|
164
|
+
|
|
165
|
+
Example: " julia/1.10.3 julia/1.11.3 (D)"
|
|
166
|
+
Tokens "(D)" are separate and filtered naturally by the startswith check.
|
|
167
|
+
"""
|
|
168
|
+
versions: set[str] = set()
|
|
169
|
+
for line in output.splitlines():
|
|
170
|
+
for token in line.split():
|
|
171
|
+
if token.startswith("julia/"):
|
|
172
|
+
versions.add(token)
|
|
173
|
+
return sorted(versions)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _parse_python_modules(output: str) -> list[str]:
|
|
177
|
+
"""Extract python/X.Y.Z tokens from `module avail python 2>&1` output."""
|
|
178
|
+
versions: set[str] = set()
|
|
179
|
+
for line in output.splitlines():
|
|
180
|
+
for token in line.split():
|
|
181
|
+
if token.startswith("python/"):
|
|
182
|
+
versions.add(token)
|
|
183
|
+
return sorted(versions)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _parse_accounts(output: str) -> list[str]:
|
|
187
|
+
"""Extract account names from pipe-delimited sacctmgr output."""
|
|
188
|
+
accounts = []
|
|
189
|
+
for line in output.splitlines():
|
|
190
|
+
if "|" not in line:
|
|
191
|
+
continue
|
|
192
|
+
account = line.split("|")[0].strip()
|
|
193
|
+
if account and account.lower() != "account":
|
|
194
|
+
accounts.append(account)
|
|
195
|
+
return accounts
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _parse_max_wall(output: str) -> dict[str, str]:
|
|
199
|
+
"""Extract account → max_walltime mapping from sacctmgr output.
|
|
200
|
+
|
|
201
|
+
Empty string means no limit set at the account level.
|
|
202
|
+
"""
|
|
203
|
+
result: dict[str, str] = {}
|
|
204
|
+
for line in output.splitlines():
|
|
205
|
+
parts = line.split("|")
|
|
206
|
+
if len(parts) < 3:
|
|
207
|
+
continue
|
|
208
|
+
account = parts[0].strip()
|
|
209
|
+
max_wall = parts[2].strip()
|
|
210
|
+
if account and account.lower() != "account":
|
|
211
|
+
result[account] = max_wall
|
|
212
|
+
return result
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ── Cache helpers ─────────────────────────────────────────────────────────────
|
|
216
|
+
|
|
217
|
+
def _cache_path(cluster_name: str) -> Path:
    """Return the probe cache file location for *cluster_name*."""
    return _CACHE_ROOT.joinpath(cluster_name, "probe.json")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _from_dict(data: dict) -> ClusterProbe:
    """Rebuild a ClusterProbe from its ``asdict``/JSON representation.

    ``python_versions`` may be absent in caches written before that field
    existed; it then defaults to an empty list. Any other missing key raises
    KeyError.
    """
    name = data["cluster_name"]
    timestamp = data["probed_at"]
    partitions = [PartitionInfo(**entry) for entry in data["partitions"]]
    return ClusterProbe(
        cluster_name=name,
        probed_at=timestamp,
        partitions=partitions,
        julia_versions=data["julia_versions"],
        python_versions=data.get("python_versions", []),  # backwards-compat
        accounts=data["accounts"],
        account_max_wall=data["account_max_wall"],
    )
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""SLURM commands: submit, poll status, cancel, fetch log output.
|
|
2
|
+
|
|
3
|
+
All functions require an active SSH ControlMaster socket.
|
|
4
|
+
"""
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import re
|
|
8
|
+
|
|
9
|
+
from clusterpilot.ssh.connection import SSHError, run_remote
|
|
10
|
+
|
|
11
|
+
_SUBMITTED_RE = re.compile(r"Submitted batch job (\d+)")
|
|
12
|
+
|
|
13
|
+
# States that mean the job will never run again.
|
|
14
|
+
TERMINAL_STATES = frozenset({
|
|
15
|
+
"COMPLETED", "FAILED", "CANCELLED", "TIMEOUT",
|
|
16
|
+
"OUT_OF_MEMORY", "NODE_FAIL",
|
|
17
|
+
})
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SlurmError(SSHError):
    """A SLURM command (sbatch, scancel, ...) failed or returned unexpected output.

    Subclasses SSHError, so handlers written for SSH failures also catch it.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ── Job submission ────────────────────────────────────────────────────────────
|
|
25
|
+
|
|
26
|
+
async def submit(
    host: str,
    user: str,
    remote_script_path: str,
    *,
    working_dir: str | None = None,
) -> str:
    """Submit a batch script with sbatch and return the numeric job ID string.

    Args:
        host: SSH hostname.
        user: Remote username.
        remote_script_path: Absolute path to the .sh script on the cluster.
        working_dir: If given, cd here before running sbatch.

    Raises:
        SlurmError: if the remote command fails, or sbatch output doesn't
            contain "Submitted batch job NNN".
    """
    # NOTE(review): paths are passed to the remote shell unquoted, so "~" and
    # $VARS expand, but paths with spaces or shell metacharacters will break.
    sbatch = f"sbatch {remote_script_path}"
    command = f"cd {working_dir} && {sbatch}" if working_dir else sbatch

    try:
        output = await run_remote(host, user, command)
    except SSHError as exc:
        raise SlurmError(f"sbatch failed: {exc}") from exc

    found = _SUBMITTED_RE.search(output)
    if found is None:
        raise SlurmError(f"Unexpected sbatch output: {output!r}")
    return found.group(1)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ── Status polling ────────────────────────────────────────────────────────────
|
|
59
|
+
|
|
60
|
+
async def job_status(host: str, user: str, job_id: str) -> str | None:
    """Return the SLURM state for job_id, or None if the job cannot be found.

    Strategy:
      1. squeue (fast, in-memory) — works while the job is queued or running.
      2. sacct (historical records) — works after the job has left the queue.

    Only the first reported state is returned. For job arrays squeue prints
    one line per task, and the raw multi-line text would match no SLURM state.

    SSHError from either query is swallowed (best effort). None is returned
    when neither query yields a state.

    Common return values: PENDING, RUNNING, COMPLETED, FAILED,
    CANCELLED, TIMEOUT, OUT_OF_MEMORY.
    """
    def _first_state(text: str) -> str | None:
        # Strip a "+" marker ("CANCELLED+") and a trailing reason
        # ("CANCELLED by 12345" -> "CANCELLED"). Blank lines are skipped.
        for line in text.splitlines():
            words = line.split("+", 1)[0].split()
            if words:
                return words[0]
        return None

    # 1. squeue — job still in queue
    try:
        out = await run_remote(
            host, user,
            f"squeue -j {job_id} -h -o '%T' 2>/dev/null",
        )
    except SSHError:
        out = ""
    state = _first_state(out)
    if state:
        return state

    # 2. sacct — job already finished; -X = summary record only (no steps)
    try:
        out = await run_remote(
            host, user,
            f"sacct -j {job_id} -n -X -o State --parsable2 2>/dev/null",
        )
    except SSHError:
        return None
    return _first_state(out)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# ── Job control ───────────────────────────────────────────────────────────────
|
|
101
|
+
|
|
102
|
+
async def cancel(host: str, user: str, job_id: str) -> None:
    """Cancel a queued or running SLURM job by running scancel on *host*.

    Raises:
        SlurmError: wrapping the SSHError when the remote command fails.
    """
    command = f"scancel {job_id}"
    try:
        await run_remote(host, user, command)
    except SSHError as err:
        raise SlurmError(f"scancel failed for job {job_id}: {err}") from err
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
# ── Log access ────────────────────────────────────────────────────────────────
|
|
111
|
+
|
|
112
|
+
async def tail_log(
    host: str,
    user: str,
    remote_log_path: str,
    n_lines: int = 50,
) -> str:
    """Return the last *n_lines* lines of a remote file.

    Best effort: if the remote command raises SSHError the result is "".
    This presumably includes a missing file, if run_remote raises on a
    non-zero exit. tail's stderr is discarded.
    """
    command = f"tail -n {n_lines} {remote_log_path} 2>/dev/null"
    try:
        text = await run_remote(host, user, command)
    except SSHError:
        text = ""
    return text
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
async def cat_log(
    host: str,
    user: str,
    remote_log_path: str,
) -> str:
    """Return the full contents of a remote log file.

    Best effort: if the remote command raises SSHError the result is "".
    This presumably includes a missing file, if run_remote raises on a
    non-zero exit. cat's stderr is discarded.
    """
    command = f"cat {remote_log_path} 2>/dev/null"
    try:
        text = await run_remote(host, user, command)
    except SSHError:
        text = ""
    return text
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
async def find_log(
    host: str,
    user: str,
    job_name: str,
    job_id: str,
    working_dir: str,
) -> str | None:
    """Locate the SLURM stdout log for this job on the remote host.

    Candidates are checked one at a time, in this order:
      <working_dir>/<job_name>-<job_id>.out (ClusterPilot default: %x-%j.out)
      <working_dir>/slurm-<job_id>.out (SLURM default)
      <working_dir>/<job_id>.out

    Returns the first candidate that ``test -f`` confirms, or None. An
    SSHError while checking a candidate counts as "not found".
    """
    names = (f"{job_name}-{job_id}.out", f"slurm-{job_id}.out", f"{job_id}.out")
    for candidate in (f"{working_dir}/{name}" for name in names):
        try:
            reply = await run_remote(host, user, f"test -f {candidate} && echo exists")
        except SSHError:
            continue
        if reply.strip() == "exists":
            return candidate
    return None
|