py-cluster-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,17 @@
1
+ """Exception hierarchy for cluster_api."""
2
+
3
+
4
class ClusterAPIError(Exception):
    """Root of the cluster-api exception hierarchy.

    Every error type defined by this package derives from this class, so
    catching it handles all of them at once.
    """


class CommandTimeoutError(ClusterAPIError):
    """Raised when a subprocess command runs longer than its allowed timeout."""


class CommandFailedError(ClusterAPIError):
    """Raised when a subprocess command exits with a non-zero status."""


class SubmitError(ClusterAPIError):
    """Raised when job submission fails or the scheduler output has no job ID."""
@@ -0,0 +1,36 @@
1
+ """Executor registry."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from ..core import Executor
9
+
10
# Executor name (e.g. "lsf", "local") -> Executor subclass.
_REGISTRY: dict[str, type[Executor]] = {}


def _ensure_builtins() -> None:
    """Register the bundled executors the first time they are needed.

    The executor modules are imported lazily. A name that is already present
    in the registry (for example a custom class registered as "lsf") is
    left untouched.
    """
    builtin_names = ("lsf", "local")
    if all(key in _REGISTRY for key in builtin_names):
        return
    from .lsf import LSFExecutor
    from .local import LocalExecutor

    for key, executor_cls in zip(builtin_names, (LSFExecutor, LocalExecutor)):
        _REGISTRY.setdefault(key, executor_cls)


def get_executor_class(name: str) -> type[Executor]:
    """Return the executor class registered under *name*.

    Raises:
        ValueError: If nothing is registered under *name*; the message lists
            the names that are available.
    """
    _ensure_builtins()
    if name in _REGISTRY:
        return _REGISTRY[name]
    available = list(_REGISTRY)
    raise ValueError(
        f"Unknown executor: {name!r}. Available: {available}"
    )


def register_executor(name: str, cls: type[Executor]) -> None:
    """Register *cls* under *name*, replacing any previous entry."""
    _REGISTRY.update({name: cls})
@@ -0,0 +1,107 @@
1
+ """Local subprocess executor for testing without a real scheduler."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import os
8
+ from datetime import datetime, timezone
9
+ from typing import Any
10
+
11
+ from .._types import JobStatus, ResourceSpec
12
+ from ..config import ClusterConfig
13
+ from ..core import Executor
14
+
15
logger = logging.getLogger(__name__)


class LocalExecutor(Executor):
    """Runs jobs as local bash subprocesses. Useful for testing."""

    submit_command = "bash"
    cancel_command = "kill"
    status_command = "ps"
    directive_prefix = "# LOCAL"

    def __init__(self, config: ClusterConfig) -> None:
        super().__init__(config)
        # job_id -> handle of the bash process started for that job.
        self._processes: dict[str, asyncio.subprocess.Process] = {}
        # Next job ID to hand out; IDs are "1", "2", ... per executor instance.
        self._next_id = 1

    def build_header(
        self, name: str, resources: ResourceSpec | None = None
    ) -> list[str]:
        """Local executor doesn't need scheduler directives.

        Only a descriptive comment line is emitted; *resources* is ignored.
        """
        return [f"# LOCAL Job: {name}"]

    async def _submit_job(
        self,
        script_path: str,
        name: str,
        env: dict[str, str] | None = None,
    ) -> str:
        """Run *script_path* with bash in the background and return its job ID.

        *env* is layered on top of the current process environment.

        The child's stdout/stderr are appended to ``<log_directory>/<name>.out``
        and ``<log_directory>/<name>.err`` (the same paths the LSF executor
        requests via ``-o``/``-e``), or discarded when no log directory is
        configured. They are deliberately not pipes: nothing reads them, and
        a child that filled the OS pipe buffer would block forever.
        """
        full_env = {**os.environ, **(env or {})}

        stdout_target: Any = asyncio.subprocess.DEVNULL
        stderr_target: Any = asyncio.subprocess.DEVNULL
        opened: list[Any] = []
        log_dir = self.config.log_directory
        try:
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
                stdout_target = open(os.path.join(log_dir, f"{name}.out"), "ab")
                opened.append(stdout_target)
                stderr_target = open(os.path.join(log_dir, f"{name}.err"), "ab")
                opened.append(stderr_target)

            proc = await asyncio.create_subprocess_exec(
                "bash", script_path,
                stdout=stdout_target,
                stderr=stderr_target,
                env=full_env,
            )
        finally:
            # The child holds its own copies of the descriptors; the parent's
            # handles are not needed once it has been spawned (or failed to).
            for handle in opened:
                handle.close()

        job_id = str(self._next_id)
        self._next_id += 1
        self._processes[job_id] = proc
        return job_id

    def _build_status_args(self) -> list[str]:
        # Not used for local executor; poll() is overridden
        return []

    def _parse_job_statuses(
        self, output: str
    ) -> dict[str, tuple[JobStatus, dict[str, Any]]]:
        # Not used for local executor; poll() is overridden
        return {}

    async def poll(self) -> dict[str, JobStatus]:
        """Update non-terminal jobs from their subprocess return codes.

        A process that is still running is marked RUNNING. A finished one
        becomes DONE (exit code 0) or FAILED (any other code), with
        ``finish_time`` and ``exit_code`` recorded; its process handle is
        then dropped because it is no longer needed.

        Returns:
            A mapping of every tracked job ID to its current status.
        """
        for job_id, record in self._jobs.items():
            if record.is_terminal:
                continue

            proc = self._processes.get(job_id)
            if proc is None:
                continue

            returncode = proc.returncode
            if returncode is None:
                record.status = JobStatus.RUNNING
                record._last_seen = datetime.now(timezone.utc)
                continue

            # Process finished
            now = datetime.now(timezone.utc)
            record.finish_time = now
            record._last_seen = now
            record.status = JobStatus.DONE if returncode == 0 else JobStatus.FAILED
            record.exit_code = returncode
            self._processes.pop(job_id, None)

        return {jid: r.status for jid, r in self._jobs.items()}

    async def cancel(self, job_id: str) -> None:
        """Terminate a local subprocess and mark its job KILLED.

        The process first receives ``terminate()``. If it has not exited
        within 5 seconds it is sent ``kill()`` and then reaped. A job that is
        already terminal keeps its status.
        """
        proc = self._processes.get(job_id)
        if proc is not None and proc.returncode is None:
            try:
                proc.terminate()
            except ProcessLookupError:
                pass  # exited between the returncode check and the signal
            else:
                try:
                    await asyncio.wait_for(proc.wait(), timeout=5.0)
                except asyncio.TimeoutError:
                    try:
                        proc.kill()
                    except ProcessLookupError:
                        pass
                    # Reap the child so it does not linger as a zombie.
                    await proc.wait()

        record = self._jobs.get(job_id)
        if record is not None and not record.is_terminal:
            record.status = JobStatus.KILLED
            # Gives the retention purge a reference time even if poll()
            # never saw this job.
            record.finish_time = datetime.now(timezone.utc)
        self._processes.pop(job_id, None)
        logger.info("Cancelled local job %s", job_id)
@@ -0,0 +1,307 @@
1
+ """LSF executor using bsub/bjobs/bkill."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import fnmatch
6
+ import json
7
+ import logging
8
+ import math
9
+ import re
10
+ from datetime import datetime, timezone
11
+ from typing import Any
12
+
13
+ from .._types import JobStatus, ResourceSpec
14
+ from ..config import ClusterConfig, parse_memory_bytes
15
+ from ..core import Executor
16
+ from ..exceptions import ClusterAPIError
17
+
18
logger = logging.getLogger(__name__)


# LSF job state (bjobs STAT field) -> JobStatus. Suspended states
# (USUSP/PSUSP/SSUSP) are reported as PENDING. A state missing from this
# table is reported as JobStatus.UNKNOWN by LSFExecutor._parse_job_statuses.
_LSF_STATUS_MAP: dict[str, JobStatus] = {
    "PEND": JobStatus.PENDING,
    "PROV": JobStatus.PENDING,  # dispatched, waiting for a host to be provisioned
    "WAIT": JobStatus.PENDING,  # chunk-job member waiting to run
    "RUN": JobStatus.RUNNING,
    "DONE": JobStatus.DONE,
    "EXIT": JobStatus.FAILED,
    "ZOMBI": JobStatus.FAILED,
    "USUSP": JobStatus.PENDING,
    "PSUSP": JobStatus.PENDING,
    "SSUSP": JobStatus.PENDING,
}

# Columns requested from `bjobs -o ... -json`.
_BJOBS_FIELDS = (
    "jobid stat exit_code exec_host max_mem "
    "submit_time start_time finish_time"
)
36
+
37
+
38
def lsf_format_bytes_ceil(n_bytes: int, lsf_units: str = "MB") -> str:
    """Format bytes into LSF memory units, rounding up.

    Inspired by dask-jobqueue's lsf_format_bytes_ceil.

    Args:
        n_bytes: Number of bytes to convert.
        lsf_units: Target unit, one of the values accepted by LSF's
            ``LSF_UNIT_FOR_LIMITS`` (KB, MB, GB, TB, PB, EB, ZB). Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        The amount in *lsf_units*, rounded up to a whole number, as a string.

    Raises:
        ValueError: If *lsf_units* is not a recognised unit.
    """
    units = {
        "KB": 1024,
        "MB": 1024**2,
        "GB": 1024**3,
        "TB": 1024**4,
        "PB": 1024**5,
        "EB": 1024**6,
        "ZB": 1024**7,
    }
    # str() so a None/non-string unit still reaches the ValueError below.
    divisor = units.get(str(lsf_units).strip().upper())
    if divisor is None:
        raise ValueError(f"Unknown LSF units: {lsf_units}")
    return str(math.ceil(n_bytes / divisor))
47
+
48
+
49
async def lsf_detect_units(
    timeout: float = 100.0,
) -> str:
    """Detect LSF memory units from lsadmin output.

    Runs ``lsadmin showconf lim`` and returns the upper-cased value of the
    first non-commented ``LSF_UNIT_FOR_LIMITS = <unit>`` line that has a
    non-empty value (surrounding quotes are removed). Falls back to ``"KB"``
    when the command fails or cannot be run, or when no usable value is
    reported.

    Inspired by dask-jobqueue's approach.
    """
    try:
        out = await Executor._call(
            ["lsadmin", "showconf", "lim"],
            timeout=timeout,
        )
    except (ClusterAPIError, OSError):
        return "KB"  # LSF default

    for line in out.splitlines():
        key, sep, value = line.partition("=")
        key = key.strip()
        if not sep or key.startswith("#") or "LSF_UNIT_FOR_LIMITS" not in key:
            continue
        unit = value.strip().strip("\"'").upper()
        # An empty value would later make lsf_format_bytes_ceil raise.
        if unit:
            return unit
    return "KB"  # LSF default
67
+
68
+
69
class LSFExecutor(Executor):
    """LSF executor using bsub, bjobs, bkill."""

    submit_command = "bsub"
    cancel_command = "bkill"
    status_command = "bjobs"
    directive_prefix = "#BSUB"
    # Extracts the numeric ID from bsub's "Job <12345> ..." confirmation.
    job_id_regexp = r"Job <(?P<job_id>\d+)>"

    def __init__(self, config: ClusterConfig) -> None:
        super().__init__(config)
        # Unit used when rendering -M / rusage[mem=...] (see lsf_format_bytes_ceil).
        self._lsf_units = config.lsf_units

    def build_header(
        self, name: str, resources: ResourceSpec | None = None
    ) -> list[str]:
        """Build #BSUB directive lines.

        Per-job values in *resources* (queue, account, cpus, gpus, memory,
        walltime) take precedence over the executor config; a falsy value
        falls back to the config. ``work_dir`` and ``cluster_options`` are
        taken from *resources* only.
        """
        lines: list[str] = []
        p = self.directive_prefix

        lines.append(f"{p} -J {name}")

        log_dir = self.config.log_directory
        lines.append(f"{p} -o {log_dir}/{name}.out")
        lines.append(f"{p} -e {log_dir}/{name}.err")

        # Queue
        queue = (resources and resources.queue) or self.config.queue
        if queue:
            lines.append(f"{p} -q {queue}")

        # Account/project
        account = (resources and resources.account) or self.config.account
        if account:
            lines.append(f"{p} -P {account}")

        # CPUs; multi-core jobs are kept on a single host.
        cpus = (resources and resources.cpus) or self.config.cpus
        if cpus:
            lines.append(f"{p} -n {cpus}")
            if cpus > 1:
                lines.append(f'{p} -R "span[hosts=1]"')

        # GPUs
        gpus = (resources and resources.gpus) or self.config.gpus
        if gpus:
            lines.append(f'{p} -gpu "num={gpus}"')

        # Memory: same value as the hard limit (-M) and the reservation.
        memory_str = (resources and resources.memory) or self.config.memory
        if memory_str:
            mem_bytes = parse_memory_bytes(memory_str)
            mem_val = lsf_format_bytes_ceil(mem_bytes, self._lsf_units)
            lines.append(f"{p} -M {mem_val}")
            lines.append(f'{p} -R "rusage[mem={mem_val}]"')

        # Walltime
        walltime = (resources and resources.walltime) or self.config.walltime
        if walltime:
            lines.append(f"{p} -W {walltime}")

        # Working directory
        work_dir = resources and resources.work_dir
        if work_dir:
            lines.append(f"{p} -cwd {work_dir}")

        # Custom cluster options, passed through verbatim.
        if resources and resources.cluster_options:
            for opt in resources.cluster_options:
                lines.append(f"{p} {opt}")

        return lines

    def _build_submit_env(self, env: dict[str, str] | None) -> dict[str, str] | None:
        """Build environment dict for bsub, applying email suppression.

        Returns ``None`` when there is nothing to set.
        """
        submit_env = dict(env) if env else {}
        if self.config.suppress_job_email:
            submit_env["LSB_JOB_REPORT_MAIL"] = "N"
        return submit_env or None

    async def _bsub(
        self, script_path: str, content: str | None, env: dict[str, str] | None,
    ) -> str:
        """Run bsub via stdin or file and return raw output.

        In stdin mode the script text is *content*, or is read from
        *script_path* when *content* is ``None``.
        """
        submit_env = self._build_submit_env(env)
        if self.config.use_stdin:
            if content is None:
                with open(script_path) as f:
                    content = f.read()
            return await self._call(
                [self.submit_command],
                env=submit_env,
                timeout=self.config.command_timeout,
                stdin_data=content,
            )
        return await self._call(
            [self.submit_command, script_path],
            env=submit_env,
            timeout=self.config.command_timeout,
        )

    async def _submit_job(
        self,
        script_path: str,
        name: str,
        env: dict[str, str] | None = None,
    ) -> str:
        """Submit via bsub with stdin mode support and return the job ID."""
        out = await self._bsub(script_path, None, env)
        return self._job_id_from_submit_output(out)

    async def _submit_array_job(
        self,
        script_path: str,
        name: str,
        array_range: tuple[int, int],
        env: dict[str, str] | None = None,
        max_concurrent: int | None = None,
    ) -> str:
        """Submit an array job with -J 'name[start-end]' and return its job ID.

        The ``#BSUB`` lines already in *script_path* are rewritten in place:
        ``-J`` receives the array spec and the ``.out``/``.err`` log names get
        a ``%I`` (array index) component so elements do not share log files.

        LSF puts an optional job-slot limit *after* the index list:
        ``name[start-end]%limit``.
        """
        array_name = f"{name}[{array_range[0]}-{array_range[1]}]"
        if max_concurrent is not None:
            array_name += f"%{max_concurrent}"

        # Rewrite only #BSUB directive lines for array syntax
        with open(script_path) as f:
            lines = f.readlines()
        new_lines = []
        for line in lines:
            if line.startswith(self.directive_prefix):
                line = line.replace(f"-J {name}", f"-J {array_name}")
                line = line.replace(f"{name}.out", f"{name}.%I.out")
                line = line.replace(f"{name}.err", f"{name}.%I.err")
            new_lines.append(line)
        content = "".join(new_lines)
        with open(script_path, "w") as f:
            f.write(content)

        out = await self._bsub(script_path, content, env)
        return self._job_id_from_submit_output(out)

    def _build_status_args(self) -> list[str]:
        """Build bjobs command with JSON output.

        Selects this executor's jobs by name (``<prefix>-*``), including
        recently finished ones (``-a``).
        """
        prefix = self._prefix
        args = [
            self.status_command,
            "-J", f"{prefix}-*",
            "-a",
            "-o", _BJOBS_FIELDS,
            "-json",
        ]
        return args

    def _parse_job_statuses(
        self, output: str
    ) -> dict[str, tuple[JobStatus, dict[str, Any]]]:
        """Parse bjobs JSON output into ``{job_id: (status, metadata)}``.

        Unmapped LSF states become ``JobStatus.UNKNOWN``. Empty or malformed
        output, or output of an unexpected shape, yields an empty result
        rather than raising. Records without a usable JOBID are skipped.
        ``metadata`` holds exec_host, max_mem, exit_code and the
        submit/start/finish times.
        """
        result: dict[str, tuple[JobStatus, dict[str, Any]]] = {}

        if not output.strip():
            return result

        try:
            data = json.loads(output)
        except json.JSONDecodeError:
            logger.warning("Failed to parse bjobs JSON output")
            return result
        if not isinstance(data, dict):
            logger.warning(
                "Unexpected bjobs JSON output type: %s", type(data).__name__
            )
            return result

        # "RECORDS" may be missing or null when no jobs match.
        for rec in data.get("RECORDS") or []:
            if not isinstance(rec, dict):
                continue
            job_id = _clean_field(rec.get("JOBID"))
            if job_id is None:
                continue

            stat = _clean_field(rec.get("STAT")) or ""
            status = _LSF_STATUS_MAP.get(stat, JobStatus.UNKNOWN)

            exit_code: int | None = None
            exit_code_str = _clean_field(rec.get("EXIT_CODE"))
            if exit_code_str is not None:
                try:
                    exit_code = int(exit_code_str)
                except ValueError:
                    pass  # non-numeric value: leave exit_code unknown
            # LSF returns "" for exit_code on DONE jobs — infer 0
            if exit_code is None and status == JobStatus.DONE:
                exit_code = 0

            meta: dict[str, Any] = {
                "exec_host": _clean_field(rec.get("EXEC_HOST")),
                "max_mem": _clean_field(rec.get("MAX_MEM")),
                "exit_code": exit_code,
                "submit_time": _parse_lsf_time(rec.get("SUBMIT_TIME")),
                "start_time": _parse_lsf_time(rec.get("START_TIME")),
                "finish_time": _parse_lsf_time(rec.get("FINISH_TIME")),
            }

            result[job_id] = (status, meta)

        return result

    async def cancel_by_name(self, name_pattern: str) -> None:
        """Cancel jobs matching name pattern via bkill -J.

        Non-terminal tracked jobs whose name matches *name_pattern*
        (``fnmatch`` semantics) are marked KILLED locally.
        """
        await self._call(
            [self.cancel_command, "-J", name_pattern],
            timeout=self.config.command_timeout,
        )
        # Update in-memory state for matching jobs
        for record in self._jobs.values():
            if not record.is_terminal and fnmatch.fnmatch(record.name, name_pattern):
                record.status = JobStatus.KILLED
        logger.info("Cancelled jobs matching %s", name_pattern)
282
+
283
+
284
def _clean_field(value: Any) -> str | None:
    """Normalize a raw bjobs field value.

    Returns the stripped string form of *value*, or ``None`` when the value
    is missing, blank, or LSF's ``-`` placeholder.
    """
    text = "" if value is None else str(value).strip()
    return text if text and text != "-" else None
292
+
293
+
294
def _parse_lsf_time(value: Any) -> datetime | None:
    """Parse an LSF timestamp string into a timezone-aware UTC datetime.

    Accepted forms, after any trailing single-letter marker is dropped
    (e.g. the ``L`` in ``"Feb 8 10:31:17 2026 L"``):

    * ``"Feb 8 10:31:17 2026"``
    * ``"Sun Feb 8 10:31:17 2026"`` (ctime style, with weekday)
    * ``"2026/02/08-10:31:17"``

    These strings carry no UTC offset; they are presumably the cluster's
    local wall-clock time. The naive value is therefore interpreted in this
    process's local timezone and converted to UTC. That keeps it comparable
    with ``datetime.now(timezone.utc)``.
    NOTE(review): assumes the client runs in the same timezone as LSF.

    Returns ``None`` for empty or placeholder values and unrecognised formats.
    """
    s = _clean_field(value)
    if s is None:
        return None
    # Strip trailing timezone indicator (e.g. " L" in "Feb 8 10:31:17 2026 L")
    s = re.sub(r"\s+[A-Z]$", "", s)
    # strptime treats whitespace in the format as "one or more spaces", so
    # "Feb  8" (LSF's double-spaced day) also matches "%b %d".
    for fmt in ("%b %d %H:%M:%S %Y", "%a %b %d %H:%M:%S %Y", "%Y/%m/%d-%H:%M:%S"):
        try:
            naive = datetime.strptime(s, fmt)
        except ValueError:
            continue
        # astimezone() on a naive datetime assumes system local time.
        return naive.astimezone(timezone.utc)
    return None
cluster_api/monitor.py ADDED
@@ -0,0 +1,168 @@
1
+ """Async polling loop and callback dispatch."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import inspect
7
+ import logging
8
+ from datetime import datetime, timezone
9
+ from typing import TYPE_CHECKING
10
+
11
+ from ._types import JobExitCondition, JobRecord, JobStatus
12
+
13
+ if TYPE_CHECKING:
14
+ from .core import Executor
15
+
16
logger = logging.getLogger(__name__)


class JobMonitor:
    """Monitors job status via polling and dispatches callbacks.

    Each poll cycle calls ``executor.poll()`` and then does the following:
    fires callbacks of terminal jobs, flags zombies, wakes ``wait_for``
    callers and purges old terminal jobs. It then sleeps for
    ``poll_interval`` seconds, or until ``stop()`` is called.
    """

    def __init__(self, executor: Executor, poll_interval: float | None = None) -> None:
        self.executor = executor
        # None (or 0) falls back to the executor's configured interval.
        self.poll_interval = poll_interval or executor.config.poll_interval
        self._task: asyncio.Task | None = None
        self._stopped = asyncio.Event()
        # job_id -> event set once that job is observed in a terminal state.
        self._completion_events: dict[str, asyncio.Event] = {}

    async def start(self) -> None:
        """Start the polling loop.

        This is a no-op while the loop is already running. Without the guard,
        a repeated call would spawn a second loop and orphan the first task,
        which ``stop()`` would never await.
        """
        if self._task is not None and not self._task.done():
            return
        self._stopped.clear()
        self._task = asyncio.create_task(self._poll_loop())
        logger.info("Monitor started (interval=%.1fs)", self.poll_interval)

    async def stop(self) -> None:
        """Stop the polling loop and wait for it to finish."""
        self._stopped.set()
        if self._task:
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None
        logger.info("Monitor stopped")

    async def _poll_loop(self) -> None:
        """Main polling loop; errors in one cycle are logged, not raised."""
        while not self._stopped.is_set():
            try:
                await self.executor.poll()
                await self._dispatch_all_callbacks()
                await self._check_zombies()
                self._notify_waiters()
                await self._purge_completed()
            except Exception:
                logger.exception("Error in poll loop")

            try:
                await asyncio.wait_for(self._stopped.wait(), timeout=self.poll_interval)
                break  # stopped was set
            except asyncio.TimeoutError:
                pass  # normal: just means poll_interval elapsed

    async def _dispatch_all_callbacks(self) -> None:
        """Check all jobs for pending callbacks."""
        for record in list(self.executor.jobs.values()):
            if record.is_terminal and record._callbacks:
                await self._dispatch_callbacks(record)

    async def _dispatch_callbacks(self, record: JobRecord) -> None:
        """Fire matching callbacks for a terminal job, then clear them.

        Every callback registered when dispatch starts is consumed, including
        those whose condition does not match the final status. The job is
        terminal, so those could never fire, and keeping them would stop
        ``_purge_completed`` from ever removing the job. Callbacks registered
        while these run are kept and handled on the next poll. A callback that
        raises is logged and does not prevent the others from running.
        """
        condition_map = {
            JobStatus.DONE: JobExitCondition.SUCCESS,
            JobStatus.FAILED: JobExitCondition.FAILURE,
            JobStatus.KILLED: JobExitCondition.KILLED,
        }
        job_condition = condition_map.get(record.status)

        pending = list(record._callbacks)
        record._callbacks.clear()
        for condition, callback in pending:
            if condition != JobExitCondition.ANY and condition != job_condition:
                continue
            try:
                result = callback(record)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                logger.exception(
                    "Callback error for job %s", record.job_id
                )

    async def _check_zombies(self) -> None:
        """Mark jobs as FAILED if not seen by scheduler for > zombie_timeout."""
        timeout_minutes = self.executor.config.zombie_timeout_minutes
        now = datetime.now(timezone.utc)

        for record in self.executor.jobs.values():
            if record.is_terminal:
                continue
            if record._last_seen is None:
                continue
            elapsed = (now - record._last_seen).total_seconds() / 60.0
            if elapsed > timeout_minutes:
                logger.warning(
                    "Zombie detected: job %s not seen for %.1f minutes",
                    record.job_id, elapsed,
                )
                record.status = JobStatus.FAILED
                record.metadata["zombie"] = True

    async def _purge_completed(self) -> None:
        """Remove terminal jobs older than completed_retention.

        Jobs that still have callbacks attached are kept.
        """
        retention_minutes = self.executor.config.completed_retention_minutes
        now = datetime.now(timezone.utc)

        to_remove = []
        for job_id, record in self.executor.jobs.items():
            if not record.is_terminal or record._callbacks:
                continue
            # Use finish_time if available, else _last_seen
            ref_time = record.finish_time or record._last_seen
            if ref_time is None:
                continue
            elapsed = (now - ref_time).total_seconds() / 60.0
            if elapsed > retention_minutes:
                to_remove.append(job_id)

        for job_id in to_remove:
            self.executor.remove_job(job_id)
            logger.debug("Purged completed job %s", job_id)

    def _notify_waiters(self) -> None:
        """Set completion events for terminal jobs that have waiters."""
        done = []
        for job_id, event in self._completion_events.items():
            record = self.executor.jobs.get(job_id)
            if record is not None and record.is_terminal:
                event.set()
                done.append(job_id)
        for job_id in done:
            del self._completion_events[job_id]

    async def wait_for(
        self, *records: JobRecord, timeout: float | None = None
    ) -> None:
        """Wait until all given jobs reach a terminal state.

        Jobs that are already terminal are not waited on. The others are
        released by the poll loop via ``_notify_waiters``, so the monitor
        must be running.

        Raises:
            asyncio.TimeoutError: If *timeout* seconds elapse first.
        """
        events = []
        for r in records:
            if r.is_terminal:
                continue
            if r.job_id not in self._completion_events:
                self._completion_events[r.job_id] = asyncio.Event()
            events.append(self._completion_events[r.job_id])

        if not events:
            return

        async def _wait() -> None:
            await asyncio.gather(*(e.wait() for e in events))

        if timeout is not None:
            await asyncio.wait_for(_wait(), timeout=timeout)
        else:
            await _wait()