clusterpilot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ """ClusterPilot — AI-assisted HPC workflow manager."""
2
+
3
+ __version__ = "0.1.0"
@@ -0,0 +1,139 @@
1
+ """ClusterPilot entry point.
2
+
3
+ Usage
4
+ -----
5
+ clusterpilot # launch TUI
6
+ clusterpilot init # create starter ~/.config/clusterpilot/config.toml
7
+ clusterpilot daemon run # run poll daemon in foreground (no TUI)
8
+ clusterpilot daemon install # install systemd user service
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import sys
14
+
15
+
16
def main() -> None:
    """Parse the command line and dispatch to the matching subcommand.

    No subcommand launches the TUI.  ``init`` writes the starter config.
    ``daemon run`` / ``daemon install`` manage the poll daemon; ``daemon``
    on its own prints the daemon help text.
    """
    cli = argparse.ArgumentParser(
        prog="clusterpilot",
        description="AI-assisted HPC workflow manager",
    )
    commands = cli.add_subparsers(dest="cmd")
    commands.add_parser(
        "init",
        help="Create starter config at ~/.config/clusterpilot/config.toml",
    )

    daemon_cli = commands.add_parser("daemon", help="Daemon management")
    daemon_commands = daemon_cli.add_subparsers(dest="daemon_cmd")
    daemon_commands.add_parser("run", help="Run poll daemon in foreground")
    daemon_commands.add_parser("install", help="Install systemd user service unit")

    opts = cli.parse_args()

    if opts.cmd == "init":
        _cmd_init()
        return
    if opts.cmd != "daemon":
        # No subcommand given: fall through to the interactive UI.
        _cmd_tui()
        return

    if opts.daemon_cmd == "run":
        _cmd_daemon_run()
    elif opts.daemon_cmd == "install":
        _cmd_daemon_install()
    else:
        daemon_cli.print_help()
43
+
44
+
45
+ # ── Subcommands ───────────────────────────────────────────────────────────────
46
+
47
def _cmd_init() -> None:
    """Write the starter config unless a config file is already present."""
    from clusterpilot.config import CONFIG_PATH, write_default_config

    if not CONFIG_PATH.exists():
        write_default_config()
        print(f"Config written to {CONFIG_PATH}")
        print("Edit it to set your cluster username and account, then run: clusterpilot")
    else:
        # Leave an existing config alone; just tell the user where it is.
        print(f"Config already exists: {CONFIG_PATH}")
55
+
56
+
57
def _cmd_daemon_run() -> None:
    """Run the poll daemon in the foreground until interrupted with Ctrl-C.

    Exits with status 1 (message on stderr) when the config cannot be loaded.
    """
    import asyncio
    import aiosqlite
    from clusterpilot.config import ConfigError, load_config
    from clusterpilot.db import DB_PATH, init_db
    from clusterpilot.jobs.daemon import PollDaemon

    try:
        settings = load_config()
    except ConfigError as err:
        print(f"Error: {err}", file=sys.stderr)
        sys.exit(1)

    async def _serve() -> None:
        # Make sure the database directory and schema exist before polling.
        DB_PATH.parent.mkdir(parents=True, exist_ok=True)
        async with aiosqlite.connect(DB_PATH) as conn:
            await init_db(conn)
        # The daemon is handed the database path, not the connection above.
        poller = PollDaemon(settings, DB_PATH)
        await poller.run_forever()

    print("ClusterPilot daemon running. Press Ctrl-C to stop.")
    try:
        asyncio.run(_serve())
    except KeyboardInterrupt:
        print("\nDaemon stopped.")
82
+
83
+
84
def _cmd_daemon_install() -> None:
    """Write the systemd user service file and print how to enable it."""
    from clusterpilot.jobs.daemon import write_service_file

    unit_path = write_service_file()
    instructions = (
        f"Service file written to: {unit_path}",
        "",
        "Enable and start with:",
        " systemctl --user daemon-reload",
        " systemctl --user enable --now clusterpilot-poll.service",
    )
    for text in instructions:
        print(text)
92
+
93
+
94
def _cmd_tui() -> None:
    """Launch the TUI, bootstrapping a starter config on first run.

    Behaviour:
      * No config file yet: write the starter config, print a welcome
        message and exit with status 0.
      * Config file exists but fails to load: report the error on stderr and
        exit with status 1, leaving the file untouched.
      * No clusters configured: print a hint and exit with status 1.
      * No API key in config or ``ANTHROPIC_API_KEY``: warn, then continue.
    """
    import logging
    from clusterpilot.config import (
        CONFIG_PATH,
        ConfigError,
        load_config,
        write_default_config,
    )
    from clusterpilot.tui.app import ClusterPilotApp

    logging.basicConfig(
        level=logging.WARNING,
        format="%(levelname)s %(name)s: %(message)s",
    )

    try:
        config = load_config()
    except ConfigError as exc:
        if CONFIG_PATH.exists():
            # The file is present but invalid.  Writing the starter config
            # here could replace the user's edits (_cmd_init guards the same
            # call with an exists() check), so report the problem instead.
            print(f"Error: {exc}", file=sys.stderr)
            sys.exit(1)
        # First run — write default config and guide the user.
        write_default_config()
        print("Welcome to ClusterPilot!")
        print()
        print(f"A starter config has been written to:\n {CONFIG_PATH}")
        print()
        print("Edit it to add your cluster username and account,")
        print("then run 'clusterpilot' again.")
        sys.exit(0)

    if not config.clusters:
        print("No clusters defined in config. Edit ~/.config/clusterpilot/config.toml.")
        sys.exit(1)

    if not config.api_key:
        import os
        if not os.environ.get("ANTHROPIC_API_KEY"):
            print(
                "Warning: no API key configured. "
                "Script generation will fail.\n"
                "Set api_key in config.toml or export ANTHROPIC_API_KEY."
            )

    # Ensure the database directory exists before the app starts.
    from clusterpilot.db import DB_PATH
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)

    app = ClusterPilotApp(config)
    app.run()
136
+
137
+
138
+ if __name__ == "__main__":
139
+ main()
@@ -0,0 +1,33 @@
1
+ from clusterpilot.cluster.probe import (
2
+ ClusterProbe,
3
+ PartitionInfo,
4
+ load_cache,
5
+ probe_cluster,
6
+ save_cache,
7
+ )
8
+ from clusterpilot.cluster.slurm import (
9
+ TERMINAL_STATES,
10
+ SlurmError,
11
+ cancel,
12
+ find_log,
13
+ job_status,
14
+ submit,
15
+ tail_log,
16
+ )
17
+
18
+ __all__ = [
19
+ # probe
20
+ "ClusterProbe",
21
+ "PartitionInfo",
22
+ "load_cache",
23
+ "probe_cluster",
24
+ "save_cache",
25
+ # slurm
26
+ "TERMINAL_STATES",
27
+ "SlurmError",
28
+ "cancel",
29
+ "find_log",
30
+ "job_status",
31
+ "submit",
32
+ "tail_log",
33
+ ]
@@ -0,0 +1,230 @@
1
+ """Cluster probe: query sinfo, module avail, and sacctmgr; cache 24h.
2
+
3
+ Results are stored in ~/.cache/clusterpilot/<cluster_name>/probe.json and
4
+ returned from cache on subsequent calls until the TTL expires or force=True.
5
+
6
+ Parsed output is based on confirmed Grex (yak.hpc.umanitoba.ca) format.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import json
12
+ import time
13
+ from dataclasses import asdict, dataclass, field
14
+ from pathlib import Path
15
+
16
+ from clusterpilot.ssh.connection import run_remote
17
+
18
+ _CACHE_ROOT = Path.home() / ".cache" / "clusterpilot"
19
+ _CACHE_TTL = 24 * 3600 # seconds
20
+
21
+
22
+ # ── Data classes ──────────────────────────────────────────────────────────────
23
+
24
@dataclass
class PartitionInfo:
    """One partition row parsed from ``sinfo`` output."""

    # Partition name with the trailing "*" default marker removed.
    name: str
    # Time limit as printed by sinfo, e.g. "7-00:00:00" or "21-00:00:00".
    max_time: str
    # GRES string, e.g. "gpu:v100:4"; "" for CPU-only partitions.
    gres: str
    # Node count reported for the row.
    nodes: int
    # True when sinfo marked the partition as the default.
    is_default: bool
31
+
32
+
33
@dataclass
class ClusterProbe:
    """Snapshot of one cluster's partitions, modules and accounts."""

    cluster_name: str
    probed_at: float  # Unix timestamp of the probe
    partitions: list[PartitionInfo]
    julia_versions: list[str]  # e.g. ["julia/1.10.3", "julia/1.11.3"]
    accounts: list[str]  # e.g. ["def-stamps"]
    account_max_wall: dict[str, str]  # account -> max walltime, "" = no limit
    python_versions: list[str] = field(default_factory=list)  # e.g. ["python/3.11.5"]

    def gpu_partitions(self) -> list[PartitionInfo]:
        """Return the partitions whose GRES string starts with ``gpu:``."""
        selected = []
        for part in self.partitions:
            if part.gres.startswith("gpu:"):
                selected.append(part)
        return selected

    def cpu_partitions(self) -> list[PartitionInfo]:
        """Return the partitions that have no GRES at all."""
        return list(filter(lambda part: not part.gres, self.partitions))

    def default_partition(self) -> PartitionInfo | None:
        """Return the first partition flagged as default, or None."""
        return next((part for part in self.partitions if part.is_default), None)
56
+
57
+
58
+ # ── Public API ────────────────────────────────────────────────────────────────
59
+
60
def load_cache(cluster_name: str) -> ClusterProbe | None:
    """Return the cached probe if it exists and is within the TTL, else None.

    An unreadable, malformed, or structurally unexpected cache file counts as
    a cache miss, so the caller can probe again.
    """
    path = _cache_path(cluster_name)
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text())
        if time.time() - data["probed_at"] > _CACHE_TTL:
            return None
        return _from_dict(data)
    except (KeyError, ValueError, TypeError, OSError):
        # KeyError   - an expected field is missing
        # ValueError - invalid JSON or undecodable bytes
        # TypeError  - wrong shape (JSON is not an object, unexpected
        #              partition keys, non-numeric probed_at)
        # OSError    - file vanished or is unreadable after the exists() check
        return None
72
+
73
+
74
def save_cache(probe: ClusterProbe) -> None:
    """Serialise *probe* as indented JSON into its per-cluster cache file.

    The parent directory is created if it does not exist.
    """
    target = _cache_path(probe.cluster_name)
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(asdict(probe), indent=2)
    target.write_text(payload)
79
+
80
+
81
async def probe_cluster(
    cluster_name: str,
    host: str,
    user: str,
    *,
    force: bool = False,
) -> ClusterProbe:
    """Probe *host* for partitions, julia/python modules and accounts.

    A still-valid cached probe is returned as-is unless ``force`` is true.
    Fresh results are written to the cache before being returned.

    Requires an active SSH ControlMaster socket (call open_connection first).
    """
    cached = None if force else load_cache(cluster_name)
    if cached is not None:
        return cached

    sinfo_out, julia_out, python_out, sacctmgr_out = await _fetch_all(host, user)

    fresh = ClusterProbe(
        cluster_name=cluster_name,
        probed_at=time.time(),
        partitions=_parse_sinfo(sinfo_out),
        julia_versions=_parse_julia_modules(julia_out),
        python_versions=_parse_python_modules(python_out),
        accounts=_parse_accounts(sacctmgr_out),
        account_max_wall=_parse_max_wall(sacctmgr_out),
    )
    save_cache(fresh)
    return fresh
113
+
114
+
115
+ # ── Remote fetching ───────────────────────────────────────────────────────────
116
+
117
async def _fetch_all(host: str, user: str) -> tuple[str, str, str, str]:
    """Run all four probe commands concurrently.

    Returns:
        ``(sinfo, julia modules, python modules, sacctmgr)`` raw outputs, in
        that order.
    """
    # asyncio.gather yields a list; unpack and rebuild so the value matches
    # the annotated 4-tuple.
    sinfo_out, julia_out, python_out, sacctmgr_out = await asyncio.gather(
        run_remote(host, user, "sinfo -o '%P %l %G %D' --noheader"),
        run_remote(host, user, "module avail julia 2>&1"),
        run_remote(host, user, "module avail python 2>&1"),
        run_remote(
            host, user,
            f"sacctmgr show user {user} withassoc "
            f"format=account,maxjobs,maxwall -p --noheader",
        ),
    )
    return sinfo_out, julia_out, python_out, sacctmgr_out
129
+
130
+
131
+ # ── Parsers ───────────────────────────────────────────────────────────────────
132
+
133
def _parse_sinfo(output: str) -> list[PartitionInfo]:
    """Parse `sinfo -o '%P %l %G %D' --noheader` output.

    Example line: "stamps 21-00:00:00 gpu:v100:4(S:0-1) 3"

    Lines with fewer than four whitespace-separated columns are ignored, and
    a node count that is not an integer is recorded as 0.
    """
    found = []
    for row in output.splitlines():
        cols = row.split()
        if len(cols) < 4:
            continue
        raw_name, time_limit, raw_gres, node_count = cols[:4]

        if raw_gres == "(null)":
            gres = ""
        else:
            # Drop the socket-affinity suffix: "gpu:v100:4(S:0-1)" -> "gpu:v100:4"
            gres, _, _ = raw_gres.partition("(")

        try:
            nodes = int(node_count)
        except ValueError:
            nodes = 0

        found.append(PartitionInfo(
            name=raw_name.rstrip("*"),
            max_time=time_limit,
            gres=gres,
            nodes=nodes,
            # A trailing "*" marks the default partition.
            is_default=raw_name.endswith("*"),
        ))
    return found
160
+
161
+
162
+ def _parse_julia_modules(output: str) -> list[str]:
163
+ """Extract julia/X.Y.Z tokens from `module avail julia 2>&1` output.
164
+
165
+ Example: " julia/1.10.3 julia/1.11.3 (D)"
166
+ Tokens "(D)" are separate and filtered naturally by the startswith check.
167
+ """
168
+ versions: set[str] = set()
169
+ for line in output.splitlines():
170
+ for token in line.split():
171
+ if token.startswith("julia/"):
172
+ versions.add(token)
173
+ return sorted(versions)
174
+
175
+
176
+ def _parse_python_modules(output: str) -> list[str]:
177
+ """Extract python/X.Y.Z tokens from `module avail python 2>&1` output."""
178
+ versions: set[str] = set()
179
+ for line in output.splitlines():
180
+ for token in line.split():
181
+ if token.startswith("python/"):
182
+ versions.add(token)
183
+ return sorted(versions)
184
+
185
+
186
+ def _parse_accounts(output: str) -> list[str]:
187
+ """Extract account names from pipe-delimited sacctmgr output."""
188
+ accounts = []
189
+ for line in output.splitlines():
190
+ if "|" not in line:
191
+ continue
192
+ account = line.split("|")[0].strip()
193
+ if account and account.lower() != "account":
194
+ accounts.append(account)
195
+ return accounts
196
+
197
+
198
+ def _parse_max_wall(output: str) -> dict[str, str]:
199
+ """Extract account → max_walltime mapping from sacctmgr output.
200
+
201
+ Empty string means no limit set at the account level.
202
+ """
203
+ result: dict[str, str] = {}
204
+ for line in output.splitlines():
205
+ parts = line.split("|")
206
+ if len(parts) < 3:
207
+ continue
208
+ account = parts[0].strip()
209
+ max_wall = parts[2].strip()
210
+ if account and account.lower() != "account":
211
+ result[account] = max_wall
212
+ return result
213
+
214
+
215
+ # ── Cache helpers ─────────────────────────────────────────────────────────────
216
+
217
def _cache_path(cluster_name: str) -> Path:
    """Return the probe cache file location for *cluster_name*."""
    return _CACHE_ROOT.joinpath(cluster_name, "probe.json")
219
+
220
+
221
def _from_dict(data: dict) -> ClusterProbe:
    """Rebuild a ClusterProbe from its cached dict form.

    ``python_versions`` is optional and defaults to an empty list; every
    other key is required.
    """
    partitions = [PartitionInfo(**entry) for entry in data["partitions"]]
    return ClusterProbe(
        cluster_name=data["cluster_name"],
        probed_at=data["probed_at"],
        partitions=partitions,
        julia_versions=data["julia_versions"],
        python_versions=data.get("python_versions", []),  # backwards-compat
        accounts=data["accounts"],
        account_max_wall=data["account_max_wall"],
    )
@@ -0,0 +1,174 @@
1
+ """SLURM commands: submit, poll status, cancel, fetch log output.
2
+
3
+ All functions require an active SSH ControlMaster socket.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import re
8
+
9
+ from clusterpilot.ssh.connection import SSHError, run_remote
10
+
11
+ _SUBMITTED_RE = re.compile(r"Submitted batch job (\d+)")
12
+
13
# States that mean the job will never run again.
# BOOT_FAIL and DEADLINE are final SLURM states too; without them a job that
# ends that way is never recognised as finished.
# NOTE(review): PREEMPTED is left out on purpose because a preempted job may
# be requeued, depending on cluster configuration — confirm the desired policy.
TERMINAL_STATES = frozenset({
    "COMPLETED", "FAILED", "CANCELLED", "TIMEOUT",
    "OUT_OF_MEMORY", "NODE_FAIL",
    "BOOT_FAIL", "DEADLINE",
})
18
+
19
+
20
class SlurmError(SSHError):
    """A SLURM command failed or produced output that could not be understood.

    Subclasses SSHError, so handlers that catch SSHError also catch this.
    """
22
+
23
+
24
+ # ── Job submission ────────────────────────────────────────────────────────────
25
+
26
async def submit(
    host: str,
    user: str,
    remote_script_path: str,
    *,
    working_dir: str | None = None,
) -> str:
    """Submit a script with sbatch and return the numeric job ID as a string.

    Args:
        host: SSH hostname.
        user: Remote username.
        remote_script_path: Absolute path to the .sh script on the cluster.
        working_dir: If given, cd here before running sbatch.

    Raises:
        SlurmError: if the remote command fails, or if sbatch output doesn't
            contain "Submitted batch job NNN".
    """
    # NOTE(review): paths are interpolated into the shell command unquoted, so
    # a path containing spaces or shell metacharacters will not work as-is.
    prefix = f"cd {working_dir} && " if working_dir else ""
    command = f"{prefix}sbatch {remote_script_path}"

    try:
        output = await run_remote(host, user, command)
    except SSHError as exc:
        raise SlurmError(f"sbatch failed: {exc}") from exc

    found = _SUBMITTED_RE.search(output)
    if found is None:
        raise SlurmError(f"Unexpected sbatch output: {output!r}")
    return found.group(1)
56
+
57
+
58
+ # ── Status polling ────────────────────────────────────────────────────────────
59
+
60
async def job_status(host: str, user: str, job_id: str) -> str | None:
    """Return the SLURM state for job_id, or None if the job cannot be found.

    Strategy:
      1. squeue — asked first; answers while the job is queued or running.
      2. sacct  — asked when squeue fails or prints nothing.

    Only the first non-blank line of either command is used, so the result is
    always a single state word.  SSH failures from either command are treated
    as "no answer" rather than raised.

    Common return values: PENDING, RUNNING, COMPLETED, FAILED,
    CANCELLED, TIMEOUT, OUT_OF_MEMORY.
    """
    # 1. squeue — job still in queue
    try:
        queue_out = await run_remote(
            host, user,
            f"squeue -j {job_id} -h -o '%T' 2>/dev/null",
        )
    except SSHError:
        queue_out = ""
    for line in queue_out.splitlines():
        state = line.strip()
        if state:
            return state

    # 2. sacct — job already finished; -X = summary record only (no steps)
    try:
        acct_out = await run_remote(
            host, user,
            f"sacct -j {job_id} -n -X -o State --parsable2 2>/dev/null",
        )
    except SSHError:
        return None
    for line in acct_out.splitlines():
        # Drop anything after "+", then keep the first word:
        # "CANCELLED by 12345" -> "CANCELLED".  Blank lines, or lines with
        # nothing before the "+", yield no words and are skipped.
        words = line.split("+", 1)[0].split()
        if words:
            return words[0]

    return None
98
+
99
+
100
+ # ── Job control ───────────────────────────────────────────────────────────────
101
+
102
async def cancel(host: str, user: str, job_id: str) -> None:
    """Run scancel for *job_id* on the remote host.

    Raises:
        SlurmError: if the remote command fails.
    """
    command = f"scancel {job_id}"
    try:
        await run_remote(host, user, command)
    except SSHError as exc:
        raise SlurmError(f"scancel failed for job {job_id}: {exc}") from exc
108
+
109
+
110
+ # ── Log access ────────────────────────────────────────────────────────────────
111
+
112
async def tail_log(
    host: str,
    user: str,
    remote_log_path: str,
    n_lines: int = 50,
) -> str:
    """Return the last *n_lines* lines of a remote file.

    Returns an empty string when the remote command fails (for example when
    the file does not exist).
    """
    command = f"tail -n {n_lines} {remote_log_path} 2>/dev/null"
    try:
        text = await run_remote(host, user, command)
    except SSHError:
        return ""
    return text
126
+
127
+
128
async def cat_log(
    host: str,
    user: str,
    remote_log_path: str,
) -> str:
    """Return the whole of a remote log file.

    Returns an empty string when the remote command fails (for example when
    the file does not exist).
    """
    command = f"cat {remote_log_path} 2>/dev/null"
    try:
        text = await run_remote(host, user, command)
    except SSHError:
        return ""
    return text
141
+
142
+
143
async def find_log(
    host: str,
    user: str,
    job_name: str,
    job_id: str,
    working_dir: str,
) -> str | None:
    """Locate the SLURM stdout log for this job on the remote host.

    Candidate file names are checked in this order inside *working_dir*:
      <job_name>-<job_id>.out   (ClusterPilot default: %x-%j.out)
      slurm-<job_id>.out        (SLURM default)
      <job_id>.out

    Each candidate costs one remote ``test -f``.  Returns the first path that
    exists, or None; a candidate whose check fails over SSH is skipped.
    """
    filenames = (
        f"{job_name}-{job_id}.out",
        f"slurm-{job_id}.out",
        f"{job_id}.out",
    )
    for filename in filenames:
        candidate = f"{working_dir}/{filename}"
        try:
            reply = await run_remote(
                host, user,
                f"test -f {candidate} && echo exists",
            )
        except SSHError:
            continue
        if reply.strip() == "exists":
            return candidate
    return None