py-cluster-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cluster_api/__init__.py +58 -0
- cluster_api/_types.py +153 -0
- cluster_api/config.py +121 -0
- cluster_api/core.py +375 -0
- cluster_api/exceptions.py +17 -0
- cluster_api/executors/__init__.py +36 -0
- cluster_api/executors/local.py +107 -0
- cluster_api/executors/lsf.py +307 -0
- cluster_api/monitor.py +168 -0
- py_cluster_api-0.1.0.dist-info/METADATA +244 -0
- py_cluster_api-0.1.0.dist-info/RECORD +13 -0
- py_cluster_api-0.1.0.dist-info/WHEEL +4 -0
- py_cluster_api-0.1.0.dist-info/licenses/LICENSE +28 -0
cluster_api/__init__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""py-cluster-api: Generic Python library for running jobs on HPC clusters."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from ._types import ArrayElement, JobExitCondition, JobRecord, JobStatus, ResourceSpec
|
|
8
|
+
from .config import ClusterConfig, load_config
|
|
9
|
+
from .core import Executor
|
|
10
|
+
from .exceptions import (
|
|
11
|
+
ClusterAPIError,
|
|
12
|
+
CommandFailedError,
|
|
13
|
+
CommandTimeoutError,
|
|
14
|
+
SubmitError,
|
|
15
|
+
)
|
|
16
|
+
from .executors import get_executor_class
|
|
17
|
+
from .executors.local import LocalExecutor
|
|
18
|
+
from .executors.lsf import LSFExecutor
|
|
19
|
+
from .monitor import JobMonitor
|
|
20
|
+
|
|
21
|
+
# Public API of the package; every name below is bound by the imports above.
__all__ = [
    "create_executor",
    "ArrayElement", "Executor", "LSFExecutor", "LocalExecutor",
    "JobRecord", "JobStatus", "JobExitCondition", "ResourceSpec",
    "ClusterConfig", "JobMonitor", "load_config",
    "ClusterAPIError", "CommandFailedError", "CommandTimeoutError", "SubmitError",
]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def create_executor(
    profile: str | None = None,
    config_path: str | None = None,
    **overrides: Any,
) -> Executor:
    """Build an executor for the backend named in the resolved configuration.

    Args:
        profile: Name of the config profile to layer over the base config.
        config_path: Explicit path to a config YAML file; when omitted,
            ``load_config`` searches its standard locations.
        **overrides: Individual config values; these win over both the file
            and the profile.

    Returns:
        An instance of the executor class registered for ``config.executor``,
        constructed with the resolved config.
    """
    # An empty kwargs dict is passed on as None ("no overrides").
    merged_overrides = overrides if overrides else None
    resolved = load_config(path=config_path, profile=profile, overrides=merged_overrides)
    executor_cls = get_executor_class(resolved.executor)
    return executor_cls(resolved)
|
cluster_api/_types.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Shared types for cluster_api."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import enum
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import Any, Callable
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class JobStatus(enum.Enum):
    """Lifecycle state of a cluster job as tracked by this library."""

    # In-progress states.
    PENDING = "pending"
    RUNNING = "running"
    # Terminal states.
    DONE = "done"
    FAILED = "failed"
    KILLED = "killed"
    # Treated as non-terminal (still in progress) when aggregating arrays.
    UNKNOWN = "unknown"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class JobExitCondition(enum.Enum):
    """Outcome a registered exit callback is dispatched for."""

    SUCCESS = "success"  # registered via JobRecord.on_success
    FAILURE = "failure"  # registered via JobRecord.on_failure
    KILLED = "killed"
    ANY = "any"  # default condition of JobRecord.on_exit
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Statuses that end a job's lifecycle (see JobRecord.is_terminal).
_TERMINAL_STATUSES = frozenset((JobStatus.DONE, JobStatus.FAILED, JobStatus.KILLED))
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class ResourceSpec:
    """Resource requirements for a single job; every field defaults to None (unset)."""

    # Compute.
    cpus: int | None = None
    gpus: int | None = None
    # Size / duration strings.
    # NOTE(review): memory presumably uses the "8 GB" form parsed by
    # config.parse_memory_bytes; confirm against the executors.
    memory: str | None = None
    walltime: str | None = None
    # Placement and accounting.
    queue: str | None = None
    account: str | None = None
    work_dir: str | None = None
    # Extra scheduler options.
    cluster_options: list[str] | None = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ArrayElement:
    """State of one element of a job array, identified by its array index."""

    index: int
    status: JobStatus = JobStatus.PENDING
    # The fields below are filled from scheduler output by Executor.poll
    # when the scheduler reports them; they stay None otherwise.
    exit_code: int | None = None
    exec_host: str | None = None
    max_mem: str | None = None
    submit_time: datetime | None = None
    start_time: datetime | None = None
    finish_time: datetime | None = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class JobRecord:
    """Tracks a submitted job and its metadata.

    A record is an array job when ``metadata`` contains ``"array_range"``,
    an inclusive ``(start, end)`` pair of element indices; per-element state
    is kept in ``array_elements`` keyed by element index.
    """

    job_id: str
    name: str
    command: str
    status: JobStatus = JobStatus.PENDING
    exit_code: int | None = None
    resources: ResourceSpec | None = None
    script_path: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    exec_host: str | None = None
    max_mem: str | None = None
    submit_time: datetime | None = None
    start_time: datetime | None = None
    finish_time: datetime | None = None
    array_elements: dict[int, ArrayElement] = field(default_factory=dict)
    # (condition, callback) pairs, in registration order.
    _callbacks: list[tuple[JobExitCondition, Callable]] = field(
        default_factory=list, repr=False
    )
    # When the job was submitted or last reported by a status poll.
    _last_seen: datetime | None = field(default=None, repr=False)

    def on_exit(self, callback: Callable, condition: JobExitCondition = JobExitCondition.ANY) -> JobRecord:
        """Register a callback for the given exit condition. Returns self for chaining."""
        self._callbacks.append((condition, callback))
        return self

    def on_success(self, callback: Callable) -> JobRecord:
        """Register a callback for successful completion. Returns self for chaining."""
        return self.on_exit(callback, JobExitCondition.SUCCESS)

    def on_failure(self, callback: Callable) -> JobRecord:
        """Register a callback for failure. Returns self for chaining."""
        return self.on_exit(callback, JobExitCondition.FAILURE)

    @property
    def is_terminal(self) -> bool:
        """Whether the job has reached a terminal state (DONE, FAILED or KILLED)."""
        return self.status in _TERMINAL_STATUSES

    @property
    def is_array(self) -> bool:
        """Whether this is an array job (``metadata`` has ``"array_range"``)."""
        return "array_range" in self.metadata

    @property
    def element_count(self) -> int:
        """Total number of expected array elements (0 if not an array)."""
        if not self.is_array:
            return 0
        start, end = self.metadata["array_range"]
        # The range is inclusive at both ends.
        return end - start + 1

    @property
    def completed_elements(self) -> int:
        """Number of seen elements that are in a terminal state."""
        return sum(1 for e in self.array_elements.values() if e.status in _TERMINAL_STATUSES)

    @property
    def failed_element_indices(self) -> list[int]:
        """Indices of failed or killed elements, in the order they were first seen."""
        return [
            e.index for e in self.array_elements.values()
            if e.status in {JobStatus.FAILED, JobStatus.KILLED}
        ]

    def compute_array_status(self) -> JobStatus:
        """Derive the overall status of an array job from its element statuses.

        Conservative: a terminal status is returned only once ALL expected
        elements have been seen and are themselves terminal. Until then the
        parent is RUNNING, except while every element seen so far is still
        PENDING, in which case the parent is PENDING as well.

        Returns:
            ``self.status`` unchanged when no element has been seen yet;
            otherwise PENDING, RUNNING, KILLED, FAILED or DONE.
        """
        if not self.array_elements:
            return self.status

        statuses = {e.status for e in self.array_elements.values()}

        # Nothing has started yet: report the parent as pending, not running.
        if statuses == {JobStatus.PENDING}:
            return JobStatus.PENDING

        # Any other non-terminal element → still in progress.
        if statuses & {JobStatus.RUNNING, JobStatus.PENDING, JobStatus.UNKNOWN}:
            return JobStatus.RUNNING

        # All seen elements are terminal — but have we seen them all?
        if len(self.array_elements) < self.element_count:
            return JobStatus.RUNNING

        # All expected elements accounted for and terminal; KILLED outranks
        # FAILED, which outranks DONE.
        if JobStatus.KILLED in statuses:
            return JobStatus.KILLED
        if JobStatus.FAILED in statuses:
            return JobStatus.FAILED
        return JobStatus.DONE
|
cluster_api/config.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""YAML config loader with Nextflow-style profiles."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import yaml
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Binary (1024-based) multipliers for the accepted memory units, smallest first.
_MEMORY_UNITS = {
    unit: 1024**power for power, unit in enumerate(("B", "KB", "MB", "GB", "TB"))
}

# "<number> <unit>", whitespace-tolerant and case-insensitive; the unit
# alternation is generated from the table above.
_MEMORY_RE = re.compile(
    r"^\s*([\d.]+)\s*(" + "|".join(_MEMORY_UNITS) + r")\s*$", re.IGNORECASE
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def parse_memory_bytes(s: str) -> int:
    """Parse a memory string like ``'8 GB'`` into a number of bytes.

    The number may be fractional (``'1.5 GB'``) and the unit is matched
    case-insensitively against ``_MEMORY_UNITS`` (binary multiples, so
    ``1 KB == 1024 B``). Fractional byte counts are truncated.

    Args:
        s: Memory size string of the form ``"<number> <unit>"``.

    Returns:
        The size in bytes.

    Raises:
        ValueError: If ``s`` is not a valid memory string. This includes
            numeric parts such as ``'1.2.3'`` or ``'.'`` that the pattern
            admits but that are not numbers.
    """
    m = _MEMORY_RE.match(s)
    if not m:
        raise ValueError(f"Cannot parse memory string: {s!r}")
    try:
        value = float(m.group(1))
    except ValueError:
        # [\d.]+ also matches things like "1.2.3"; report them with the same
        # message as any other malformed input instead of float()'s error.
        raise ValueError(f"Cannot parse memory string: {s!r}") from None
    unit = m.group(2).upper()
    return int(value * _MEMORY_UNITS[unit])
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ClusterConfig:
    """Settings controlling how jobs are rendered, submitted and tracked.

    Every field has a default, so an empty configuration is valid.
    """

    # Backend name, mapped to an executor class by get_executor_class().
    executor: str = "local"

    # Resource settings (None means unset).
    cpus: int | None = None
    gpus: int | None = None
    memory: str | None = None
    walltime: str | None = None
    queue: str | None = None
    account: str | None = None

    # NOTE(review): presumably the monitor's polling period in seconds; confirm.
    poll_interval: float = 10.0

    # Job-script rendering: shebang line, lines placed before/after the
    # command, and header directives to append or drop (a header line is
    # dropped when it contains any directives_skip substring).
    shebang: str = "#!/bin/bash"
    script_prologue: list[str] = field(default_factory=list)
    script_epilogue: list[str] = field(default_factory=list)
    extra_directives: list[str] = field(default_factory=list)
    directives_skip: list[str] = field(default_factory=list)

    # Directory where job scripts are written (created if missing).
    log_directory: str = "./logs"
    # NOTE(review): the next two are presumably LSF-executor settings; confirm.
    lsf_units: str = "MB"
    use_stdin: bool = False
    # Prefix for job names; a random 5-character prefix is used when unset.
    job_name_prefix: str | None = None
    # NOTE(review): presumably consumed by the job monitor; confirm.
    zombie_timeout_minutes: float = 30.0
    completed_retention_minutes: float = 10.0
    # Timeout in seconds for each scheduler command.
    command_timeout: float = 100.0
    suppress_job_email: bool = True
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Default config locations, tried in order after $CLUSTER_API_CONFIG:
# a file in the current working directory, then the per-user config file.
_CONFIG_SEARCH_PATHS = [
    "cluster_api.yaml",
    "~/.config/cluster_api/config.yaml",
]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _find_config_path() -> Path | None:
    """Locate a config file in the standard locations.

    ``$CLUSTER_API_CONFIG`` is tried first. If it is unset, or does not point
    to an existing regular file, each entry of ``_CONFIG_SEARCH_PATHS`` is
    tried in order. ``~`` is expanded in every candidate.

    Returns:
        The first candidate that is an existing file, or ``None``.
    """
    candidates: list[str] = []
    env_path = os.environ.get("CLUSTER_API_CONFIG")
    if env_path:
        candidates.append(env_path)
    candidates.extend(_CONFIG_SEARCH_PATHS)

    for candidate in candidates:
        p = Path(candidate).expanduser()
        # is_file() rather than exists(): a directory would pass exists() and
        # then make open() in load_config fail with IsADirectoryError.
        if p.is_file():
            return p

    return None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _read_yaml_mapping(config_path: Path) -> dict[str, Any]:
    """Load *config_path* as UTF-8 YAML and return its top-level mapping.

    An empty (or otherwise falsy) document yields ``{}``.

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def load_config(
    path: str | Path | None = None,
    profile: str | None = None,
    overrides: dict[str, Any] | None = None,
) -> ClusterConfig:
    """Load configuration from YAML with an optional profile and overrides.

    Merges: base config → profile → overrides (later sources win). Keys that
    are not ``ClusterConfig`` fields are ignored. A ``profile`` name that is
    not defined under ``profiles:`` is ignored as well, leaving the base config.

    Args:
        path: Explicit config file (``~`` is expanded). When ``None``, the
            standard locations are searched; finding nothing means defaults.
        profile: Name of an entry under the top-level ``profiles`` mapping.
        overrides: Values that take precedence over the file and profile.

    Returns:
        The merged ``ClusterConfig``.

    Raises:
        FileNotFoundError: If an explicit path is given but doesn't exist.
        ValueError: If the file's top level, its ``profiles`` value, or the
            selected profile is not a mapping.
    """
    if path is not None:
        # Expand "~" for parity with $CLUSTER_API_CONFIG and the default paths.
        config_path: Path | None = Path(path).expanduser()
        if not config_path.exists():
            raise FileNotFoundError(f"Config file not found: {config_path}")
    else:
        config_path = _find_config_path()

    raw: dict[str, Any] = {}
    if config_path is not None and config_path.exists():
        raw = _read_yaml_mapping(config_path)

    # A bare "profiles:" key parses as None; treat it as "no profiles".
    profiles = raw.pop("profiles", None) or {}
    if not isinstance(profiles, dict):
        raise ValueError(
            f"'profiles' in {config_path} must be a mapping, "
            f"got {type(profiles).__name__}"
        )

    if profile and profile in profiles:
        # A profile declared with an empty body parses as None.
        selected = profiles[profile] or {}
        if not isinstance(selected, dict):
            raise ValueError(
                f"Profile {profile!r} in {config_path} must be a mapping, "
                f"got {type(selected).__name__}"
            )
        raw = {**raw, **selected}

    if overrides:
        raw = {**raw, **overrides}

    # Build ClusterConfig from the merged dict, ignoring unknown keys
    known_fields = {f.name for f in ClusterConfig.__dataclass_fields__.values()}
    filtered = {k: v for k, v in raw.items() if k in known_fields}

    return ClusterConfig(**filtered)
|
cluster_api/core.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
"""Abstract Executor base class."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import abc
|
|
6
|
+
import asyncio
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
import re
|
|
10
|
+
import secrets
|
|
11
|
+
import string
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from .config import ClusterConfig
|
|
17
|
+
from .exceptions import ClusterAPIError, CommandFailedError, CommandTimeoutError, SubmitError
|
|
18
|
+
from ._types import ArrayElement, JobRecord, JobStatus, ResourceSpec
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)

# Job script layout, one line per section, filled by %-formatting in
# Executor.render_script (list-valued sections are pre-joined with "\n").
_SCRIPT_TEMPLATE = "\n".join(
    (
        "%(shebang)s",
        "%(job_header)s",
        "%(prologue)s",
        "%(command)s",
        "%(epilogue)s",
    )
) + "\n"

# Array-element IDs such as "12345[7]": group 1 is the parent job ID,
# group 2 the element index.
_ARRAY_ELEMENT_RE = re.compile(r"^(.+)\[(\d+)\]$")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Executor(abc.ABC):
    """Abstract base for cluster job executors.

    A concrete executor sets the scheduler command attributes and implements
    :meth:`build_header`, :meth:`_build_status_args` and
    :meth:`_parse_job_statuses`. This base class renders and writes job
    scripts, submits and cancels jobs, polls status, and keeps the
    ``job_id -> JobRecord`` table.
    """

    # Scheduler CLI programs; concrete executors must set these.
    submit_command: str
    cancel_command: str
    status_command: str
    # Regex with a ``job_id`` named group, searched in the submit command's
    # stdout to recover the new job's ID.
    job_id_regexp: str = r"(?P<job_id>\d+)"
    # NOTE(review): unused by the base class; presumably the directive marker
    # used by subclasses in build_header. Confirm.
    directive_prefix: str = ""

    # Scheduler-reported attributes that poll() copies onto job records and
    # array elements when present and not None.
    _META_FIELDS: tuple[str, ...] = (
        "exec_host",
        "max_mem",
        "exit_code",
        "submit_time",
        "start_time",
        "finish_time",
    )

    def __init__(self, config: ClusterConfig) -> None:
        """Store *config*, create the log directory and choose a job-name prefix."""
        self.config = config
        self._jobs: dict[str, JobRecord] = {}
        # Incremented per written script so file names stay unique.
        self._script_counter = 0
        self._log_dir = Path(config.log_directory).expanduser()
        self._log_dir.mkdir(parents=True, exist_ok=True)
        if config.job_name_prefix:
            self._prefix = config.job_name_prefix
        else:
            # Generate a random prefix so concurrent users/sessions don't
            # see each other's jobs when polling by name.
            alphabet = string.ascii_lowercase + string.digits
            self._prefix = "".join(secrets.choice(alphabet) for _ in range(5))

    # --- Script rendering ---

    def render_script(
        self,
        command: str,
        name: str,
        resources: ResourceSpec | None = None,
        prologue: list[str] | None = None,
        epilogue: list[str] | None = None,
    ) -> str:
        """Render a job script from the template.

        The header is :meth:`build_header`'s output, minus every line that
        contains a ``config.directives_skip`` substring, plus
        ``config.extra_directives``. Config-level prologue/epilogue lines
        come before the per-call ones.
        """
        # Copy so extending below never mutates a list owned by build_header().
        header_lines = list(self.build_header(name, resources))
        skip = set(self.config.directives_skip)
        if skip:
            header_lines = [
                line for line in header_lines if not any(s in line for s in skip)
            ]
        header_lines.extend(self.config.extra_directives)

        all_prologue = [*self.config.script_prologue, *(prologue or [])]
        all_epilogue = [*self.config.script_epilogue, *(epilogue or [])]

        return _SCRIPT_TEMPLATE % {
            "shebang": self.config.shebang,
            "job_header": "\n".join(header_lines),
            "prologue": "\n".join(all_prologue),
            "command": command,
            "epilogue": "\n".join(all_epilogue),
        }

    @abc.abstractmethod
    def build_header(
        self, name: str, resources: ResourceSpec | None = None
    ) -> list[str]:
        """Build scheduler-specific directive lines."""
        ...

    # --- Submission ---

    async def submit(
        self,
        command: str,
        name: str,
        resources: ResourceSpec | None = None,
        prologue: list[str] | None = None,
        epilogue: list[str] | None = None,
        env: dict[str, str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> JobRecord:
        """Submit a job to the scheduler and start tracking it.

        The job name is prefixed with this executor's prefix. The rendered
        script is written to the log directory before submission.

        Returns:
            The new PENDING ``JobRecord``.
        """
        full_name = f"{self._prefix}-{name}"
        script = self.render_script(command, full_name, resources, prologue, epilogue)
        script_path = self._write_script(script, full_name)

        job_id = await self._submit_job(script_path, full_name, env)

        record = JobRecord(
            job_id=job_id,
            name=full_name,
            command=command,
            status=JobStatus.PENDING,
            resources=resources,
            script_path=script_path,
            metadata=metadata or {},
            _last_seen=datetime.now(timezone.utc),
        )
        self._jobs[job_id] = record
        logger.info("Submitted job %s (%s)", job_id, full_name)
        return record

    async def submit_array(
        self,
        command: str,
        name: str,
        array_range: tuple[int, int],
        resources: ResourceSpec | None = None,
        prologue: list[str] | None = None,
        epilogue: list[str] | None = None,
        env: dict[str, str] | None = None,
        metadata: dict[str, Any] | None = None,
        max_concurrent: int | None = None,
    ) -> JobRecord:
        """Submit a job array to the scheduler and start tracking it.

        ``array_range`` is stored in the record's metadata (which is what
        makes it an array record); ``max_concurrent`` is stored too when given.
        """
        full_name = f"{self._prefix}-{name}"
        script = self.render_script(command, full_name, resources, prologue, epilogue)
        script_path = self._write_script(script, full_name)

        job_id = await self._submit_array_job(
            script_path, full_name, array_range, env, max_concurrent
        )

        meta = {**(metadata or {}), "array_range": array_range}
        if max_concurrent is not None:
            meta["max_concurrent"] = max_concurrent

        record = JobRecord(
            job_id=job_id,
            name=full_name,
            command=command,
            status=JobStatus.PENDING,
            resources=resources,
            script_path=script_path,
            metadata=meta,
            _last_seen=datetime.now(timezone.utc),
        )
        self._jobs[job_id] = record
        logger.info(
            "Submitted array job %s (%s[%d-%d])",
            job_id, full_name, array_range[0], array_range[1],
        )
        return record

    async def _submit_job(
        self,
        script_path: str,
        name: str,
        env: dict[str, str] | None = None,
    ) -> str:
        """Submit a script and return the job ID. Override for stdin submission."""
        out = await self._call(
            [self.submit_command, script_path],
            env=env,
            timeout=self.config.command_timeout,
        )
        return self._job_id_from_submit_output(out)

    async def _submit_array_job(
        self,
        script_path: str,
        name: str,
        array_range: tuple[int, int],
        env: dict[str, str] | None = None,
        max_concurrent: int | None = None,
    ) -> str:
        """Submit an array job. Override in subclasses.

        The base implementation ignores ``array_range`` and ``max_concurrent``
        and submits the script as a plain job.
        """
        return await self._submit_job(script_path, name, env)

    def _write_script(self, script_content: str, name: str) -> str:
        """Write the script as ``<sanitised name>.<counter>.sh`` (mode 0o755) and return its path."""
        safe_name = re.sub(r"[^\w\-.]", "_", name)
        self._script_counter += 1
        script_path = self._log_dir / f"{safe_name}.{self._script_counter}.sh"
        script_path.write_text(script_content)
        script_path.chmod(0o755)
        return str(script_path)

    def _job_id_from_submit_output(self, out: str) -> str:
        """Extract the job ID from submission output using ``job_id_regexp``.

        Raises:
            SubmitError: If the regex does not match.
        """
        match = re.search(self.job_id_regexp, out)
        if not match:
            raise SubmitError(
                f"Could not parse job ID from output: {out!r}"
            )
        return match.group("job_id")

    # --- Cancellation ---

    async def cancel(self, job_id: str) -> None:
        """Cancel a job by ID and, if it is tracked, mark it KILLED.

        If the cancel command fails, its exception propagates and the record
        is left unchanged.
        """
        await self._call(
            [self.cancel_command, job_id],
            timeout=self.config.command_timeout,
        )
        if job_id in self._jobs:
            self._jobs[job_id].status = JobStatus.KILLED
        logger.info("Cancelled job %s", job_id)

    async def cancel_by_name(self, name_pattern: str) -> None:
        """Cancel jobs by name pattern. Override in subclasses for native support."""
        raise NotImplementedError("cancel_by_name not supported by this executor")

    async def cancel_all(self) -> None:
        """Cancel every non-terminal tracked job concurrently.

        The first failing cancellation's exception propagates (asyncio.gather
        default behaviour).
        """
        to_cancel = [jid for jid, r in self._jobs.items() if not r.is_terminal]
        await asyncio.gather(*(self.cancel(jid) for jid in to_cancel))

    # --- Status polling ---

    @abc.abstractmethod
    def _build_status_args(self) -> list[str]:
        """Build args for the status query command."""
        ...

    @abc.abstractmethod
    def _parse_job_statuses(
        self, output: str
    ) -> dict[str, tuple[JobStatus, dict[str, Any]]]:
        """Parse status command output into {job_id: (status, metadata_dict)}."""
        ...

    def _status_snapshot(self) -> dict[str, JobStatus]:
        """Return ``{job_id: status}`` for every tracked job."""
        return {jid: r.status for jid, r in self._jobs.items()}

    @classmethod
    def _apply_meta(cls, target: Any, meta: dict[str, Any]) -> None:
        """Copy every ``_META_FIELDS`` entry of *meta* that is not None onto *target*."""
        for key in cls._META_FIELDS:
            value = meta.get(key)
            if value is not None:
                setattr(target, key, value)

    async def poll(self) -> dict[str, JobStatus]:
        """Query the scheduler once and update the tracked job records.

        Reported non-array jobs and array elements get their status and the
        ``_META_FIELDS`` attributes updated. The owning record's ``_last_seen``
        is refreshed. Array parents that had element updates get their status
        re-derived via ``JobRecord.compute_array_status``. Terminal records
        are never modified. If the status command fails, the cycle is skipped
        with a warning.

        NOTE(review): no zombie detection happens here; presumably a caller
        (e.g. JobMonitor) compares ``_last_seen`` against
        ``config.zombie_timeout_minutes``. Confirm.

        Returns:
            ``{job_id: status}`` for every tracked job.
        """
        if all(r.is_terminal for r in self._jobs.values()):
            return self._status_snapshot()

        args = self._build_status_args()
        try:
            out = await self._call(args, timeout=self.config.command_timeout)
        except (ClusterAPIError, OSError):
            logger.warning("Status query failed, skipping poll cycle")
            return self._status_snapshot()

        statuses = self._parse_job_statuses(out)
        now = datetime.now(timezone.utc)
        array_jobs_updated: set[str] = set()

        for raw_id, (new_status, meta) in statuses.items():
            # Check if this is an array element ID like "12345[1]"
            m = _ARRAY_ELEMENT_RE.match(raw_id)
            if m:
                parent_id, element_index = m.group(1), int(m.group(2))
                record = self._jobs.get(parent_id)
                if record and not record.is_terminal and record.is_array:
                    if element_index not in record.array_elements:
                        record.array_elements[element_index] = ArrayElement(index=element_index)
                    elem = record.array_elements[element_index]
                    elem.status = new_status
                    self._apply_meta(elem, meta)
                    record._last_seen = now
                    array_jobs_updated.add(parent_id)
            else:
                record = self._jobs.get(raw_id)
                if record and not record.is_terminal:
                    record.status = new_status
                    record._last_seen = now
                    self._apply_meta(record, meta)

        # Aggregate parent status for array jobs that got element updates
        for parent_id in array_jobs_updated:
            record = self._jobs[parent_id]
            record.status = record.compute_array_status()

        return self._status_snapshot()

    # --- Subprocess helper ---

    @staticmethod
    async def _call(
        cmd: list[str],
        shell: bool = False,
        timeout: float = 100.0,
        env: dict[str, str] | None = None,
        stdin_data: str | None = None,
    ) -> str:
        """Run a subprocess and return its stripped stdout.

        Inspired by dask-jobqueue's _call(), with added timeout support.

        Args:
            cmd: Argument vector. With ``shell=True`` a single string is also
                accepted; a list is joined with spaces (no quoting applied).
            shell: Run via the system shell instead of exec'ing directly.
            timeout: Seconds to wait for the process to finish.
            env: Variables layered over ``os.environ``; when empty/None the
                current environment is inherited unchanged.
            stdin_data: Text written to the process's stdin, if any.

        Returns:
            Decoded, whitespace-stripped stdout. Undecodable bytes are
            replaced with U+FFFD instead of raising.

        Raises:
            CommandTimeoutError: If the process doesn't finish within
                ``timeout``; it is killed (and reaped, best effort) first.
            CommandFailedError: If the process exits non-zero.
        """
        full_env = {**os.environ, **env} if env else None
        # Open a stdin pipe only when there is data to send; it was previously
        # missing in shell mode, which silently dropped stdin_data.
        stdin_pipe = asyncio.subprocess.PIPE if stdin_data else None

        if shell:
            proc = await asyncio.create_subprocess_shell(
                cmd if isinstance(cmd, str) else " ".join(cmd),
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=full_env,
                stdin=stdin_pipe,
            )
        else:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=full_env,
                stdin=stdin_pipe,
            )

        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(stdin_data.encode() if stdin_data else None),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # the process already exited on its own
            try:
                # Reap the killed child so its transport is closed promptly;
                # bounded so a stuck child cannot block the caller.
                await asyncio.wait_for(proc.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                logger.warning("Killed command did not exit promptly: %s", cmd)
            raise CommandTimeoutError(
                f"Command timed out after {timeout}s: {cmd}"
            )

        out = stdout.decode(errors="replace").strip()
        err = stderr.decode(errors="replace").strip()

        if proc.returncode != 0:
            raise CommandFailedError(
                f"Command failed (exit {proc.returncode}): {cmd}\nstderr: {err}"
            )

        return out

    # --- Job tracking ---

    def remove_job(self, job_id: str) -> None:
        """Stop tracking a job; unknown IDs are ignored."""
        self._jobs.pop(job_id, None)

    @property
    def jobs(self) -> dict[str, JobRecord]:
        """All tracked jobs (a shallow copy of the internal table)."""
        return dict(self._jobs)

    @property
    def active_jobs(self) -> dict[str, JobRecord]:
        """Tracked jobs that are not in a terminal state."""
        return {jid: r for jid, r in self._jobs.items() if not r.is_terminal}
|