xmanager_slurm-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +44 -0
- xm_slurm/api.py +261 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +162 -0
- xm_slurm/console.py +3 -0
- xm_slurm/contrib/clusters/__init__.py +52 -0
- xm_slurm/contrib/clusters/drac.py +169 -0
- xm_slurm/executables.py +201 -0
- xm_slurm/execution.py +491 -0
- xm_slurm/executors.py +127 -0
- xm_slurm/experiment.py +737 -0
- xm_slurm/job_blocks.py +14 -0
- xm_slurm/packageables.py +292 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker/__init__.py +75 -0
- xm_slurm/packaging/docker/abc.py +112 -0
- xm_slurm/packaging/docker/cloud.py +503 -0
- xm_slurm/packaging/docker/local.py +206 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +52 -0
- xm_slurm/packaging/utils.py +202 -0
- xm_slurm/resources.py +150 -0
- xm_slurm/status.py +188 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
- xm_slurm/templates/docker/mamba.Dockerfile +27 -0
- xm_slurm/templates/docker/pdm.Dockerfile +31 -0
- xm_slurm/templates/docker/python.Dockerfile +24 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
- xm_slurm/templates/slurm/job.bash.j2 +78 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
- xm_slurm/utils.py +69 -0
- xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
- xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
- xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
xm_slurm/execution.py
ADDED
@@ -0,0 +1,491 @@
import asyncio
import collections.abc
import dataclasses
import functools
import hashlib
import logging
import operator
import re
import shlex
import typing
from typing import Any, Mapping, Sequence

import asyncssh
import backoff
import jinja2 as j2
from asyncssh.auth import KbdIntPrompts, KbdIntResponse
from asyncssh.misc import MaybeAwait
from xmanager import xm

from xm_slurm import batching, config, executors, status
from xm_slurm.console import console

SlurmClusterConfig = config.SlurmClusterConfig
ContainerRuntime = config.ContainerRuntime

"""
=== Runtime Configurations ===
With RunC:
    skopeo copy --dest-creds=<username>:<secret> docker://<image>@<digest> oci:<image>:<digest>

    pushd $SLURM_TMPDIR

    umoci raw unpack --rootless --image <image>:<digest> bundle/<digest>
    umoci raw runtime-config --image <image>:<digest> bundle/<digest>/config.json

    runc run -b bundle/<digest> <container-id>

With Singularity / Apptainer:

    apptainer build --fix-perms --sandbox <digest> docker://<image>@<digest>
    apptainer run --compat <digest>
"""

"""
#SBATCH --error=/dev/null
#SBATCH --output=/dev/null
"""

_POLL_INTERVAL = 30.0
_BATCHED_BATCH_SIZE = 16
_BATCHED_TIMEOUT = 0.2


class SlurmExecutionError(Exception): ...


class NoKBAuthSSHClient(asyncssh.SSHClient):
    """SSHClient that does not prompt for keyboard-interactive authentication."""

    def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
        return ""

    def kbdint_challenge_received(
        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
    ) -> MaybeAwait[KbdIntResponse | None]:
        del name, instructions, lang, prompts
        return []


def _group_by_ssh_options(
    ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
) -> dict[asyncssh.SSHClientConnectionOptions, list[str]]:
    jobs_by_cluster = collections.defaultdict(list)
    for options, job_id in zip(ssh_options, job_ids):
        jobs_by_cluster[options].append(job_id)
    return jobs_by_cluster


class _BatchedSlurmHandle:
    @functools.partial(
        batching.batch,
        max_batch_size=_BATCHED_BATCH_SIZE,
        batch_timeout=_BATCHED_TIMEOUT,
    )
    @staticmethod
    async def _batched_get_state(
        ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
    ) -> Sequence[status.SlurmJobState]:
        async def _get_state(
            options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
        ) -> Sequence[status.SlurmJobState]:
            result = await get_client().run(
                options,
                [
                    "sacct",
                    "--jobs",
                    ",".join(job_ids),
                    "--format",
                    "JobID,State",
                    "--allocations",
                    "--noheader",
                    "--parsable2",
                ],
                check=True,
            )

            assert isinstance(result.stdout, str)
            states_by_job_id = {}
            for line in result.stdout.splitlines():
                job_id, state = line.split("|")
                states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)

            job_states = []
            for job_id in job_ids:
                if job_id in states_by_job_id:
                    job_states.append(states_by_job_id[job_id])
                # This is a workaround for sacct's inability to display state information for
                # array job elements that haven't begun. We assume that if the job ID is not
                # found and it's an array job, then it's pending.
                elif re.match(r"^\d+_\d+$", job_id) is not None:
                    job_states.append(status.SlurmJobState.PENDING)
                else:
                    raise SlurmExecutionError(f"Failed to find job state info for {job_id}")
            return job_states

        jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)

        job_states_per_cluster = await asyncio.gather(*[
            _get_state(options, job_ids) for options, job_ids in jobs_by_cluster.items()
        ])
        job_states_by_cluster: dict[
            asyncssh.SSHClientConnectionOptions, dict[str, status.SlurmJobState]
        ] = {}
        for options, job_states in zip(jobs_by_cluster, job_states_per_cluster):
            job_states_by_cluster[options] = dict(zip(jobs_by_cluster[options], job_states))

        job_states = []
        for options, job_id in zip(ssh_options, job_ids):
            job_states.append(job_states_by_cluster[options][job_id])
        return job_states

    @functools.partial(
        batching.batch,
        max_batch_size=_BATCHED_BATCH_SIZE,
        batch_timeout=_BATCHED_TIMEOUT,
    )
    @staticmethod
    async def _batched_cancel(
        ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
    ) -> Sequence[None]:
        async def _cancel(
            options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
        ) -> None:
            await get_client().run(options, ["scancel", " ".join(job_ids)], check=True)

        jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)
        return await asyncio.gather(*[
            _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
        ])


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmHandle(_BatchedSlurmHandle):
    """A handle for referring to the launched container."""

    ssh_connection_options: asyncssh.SSHClientConnectionOptions
    job_id: str

    def __post_init__(self):
        if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
            raise ValueError(f"Invalid job ID: {self.job_id}")

    @backoff.on_predicate(
        backoff.constant,
        lambda state: state in status.SlurmActiveJobStates,
        jitter=None,
        interval=_POLL_INTERVAL,
    )
    async def wait(self) -> status.SlurmJobState:
        return await self.get_state()

    async def stop(self) -> None:
        await self._batched_cancel(self.ssh_connection_options, self.job_id)

    async def get_state(self) -> status.SlurmJobState:
        return await self._batched_get_state(self.ssh_connection_options, self.job_id)


@functools.cache
def get_client() -> "Client":
    return Client()


@functools.cache
def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
    template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
    template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)

    def _raise_template_exception(msg: str) -> None:
        raise j2.TemplateRuntimeError(msg)

    template_env.globals["raise"] = _raise_template_exception
    template_env.globals["operator"] = operator

    match container_runtime:
        case ContainerRuntime.SINGULARITY | ContainerRuntime.APPTAINER:
            runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
        case ContainerRuntime.PODMAN:
            runtime_template = template_env.get_template("runtimes/podman.bash.j2")
        case _:
            raise NotImplementedError
    # Update our global env with the runtime template's exported globals
    template_env.globals.update(runtime_template.module.__dict__)

    return template_env


class Client:
    def __init__(self):
        self._connections: dict[
            asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
        ] = {}
        self._connection_lock = asyncio.Lock()

    @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
    async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
        # Make sure the xm-slurm state directory exists
        await conn.run("mkdir -p ~/.local/state/xm-slurm", check=True)

    async def connection(
        self,
        options: asyncssh.SSHClientConnectionOptions,
    ) -> asyncssh.SSHClientConnection:
        if options not in self._connections:
            async with self._connection_lock:
                try:
                    conn, _ = await asyncssh.create_connection(NoKBAuthSSHClient, options=options)
                    await self._setup_remote_connection(conn)
                    self._connections[options] = conn
                except asyncssh.misc.PermissionDenied as ex:
                    raise RuntimeError(f"Permission denied connecting to {options.host}") from ex
        return self._connections[options]

    @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
    async def run(
        self,
        options: asyncssh.SSHClientConnectionOptions,
        command: xm.SequentialArgs | str | Sequence[str],
        *,
        check: bool = False,
        timeout: float | None = None,
    ) -> asyncssh.SSHCompletedProcess:
        client = await self.connection(options)
        if isinstance(command, xm.SequentialArgs):
            command = command.to_list()
        if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
            command = shlex.join(command)
        assert isinstance(command, str)
        logging.debug("Running command on %s: %s", options.host, command)

        return await client.run(command, check=check, timeout=timeout)

    async def template(
        self,
        *,
        job: xm.Job | xm.JobGroup,
        cluster: SlurmClusterConfig,
        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
        experiment_id: int,
        identity: str | None,
    ) -> str:
        if args is None:
            args = {}

        template_env = get_template_env(cluster.runtime)

        # Sanitize job groups
        if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
            job = typing.cast(xm.Job, list(job.jobs.values())[0])
        elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
            raise ValueError("Job group must have at least one job")

        match job:
            case xm.Job() as job_array if isinstance(args, collections.abc.Sequence):
                template = template_env.get_template("job-array.bash.j2")
                sequential_args = [
                    xm.SequentialArgs.from_collection(trial.get("args", None)) for trial in args
                ]
                env_vars = [trial.get("env_vars") for trial in args]
                if any(env_vars):
                    raise NotImplementedError(
                        "Job arrays over environment variables are not yet supported."
                    )

                return template.render(
                    job=job_array,
                    cluster=cluster,
                    args=sequential_args,
                    env_vars=env_vars,
                    experiment_id=experiment_id,
                    identity=identity,
                )
            case xm.Job() if isinstance(args, collections.abc.Mapping):
                template = template_env.get_template("job.bash.j2")
                sequential_args = xm.SequentialArgs.from_collection(args.get("args", None))
                env_vars = args.get("env_vars", None)
                return template.render(
                    job=job,
                    cluster=cluster,
                    args=sequential_args,
                    env_vars=env_vars,
                    experiment_id=experiment_id,
                    identity=identity,
                )
            case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
                template = template_env.get_template("job-group.bash.j2")
                sequential_args = {
                    job_name: {
                        "args": args.get(job_name, {}).get("args", None),
                    }
                    for job_name in job_group.jobs.keys()
                }
                env_vars = {
                    job_name: args.get(job_name, {}).get("env_vars", None)
                    for job_name in job_group.jobs.keys()
                }
                return template.render(
                    job_group=job_group,
                    cluster=cluster,
                    args=sequential_args,
                    env_vars=env_vars,
                    experiment_id=experiment_id,
                    identity=identity,
                )
            case _:
                raise ValueError(f"Unsupported job type: {type(job)}")

    @typing.overload
    async def launch(
        self,
        *,
        cluster: SlurmClusterConfig,
        job: xm.Job | xm.JobGroup,
        args: Mapping[str, Any] | None,
        experiment_id: int,
        identity: str | None = ...,
    ) -> SlurmHandle: ...

    @typing.overload
    async def launch(
        self,
        *,
        cluster: SlurmClusterConfig,
        job: xm.Job | xm.JobGroup,
        args: Sequence[Mapping[str, Any]],
        experiment_id: int,
        identity: str | None = ...,
    ) -> Sequence[SlurmHandle]: ...

    async def launch(
        self,
        *,
        cluster: SlurmClusterConfig,
        job: xm.Job | xm.JobGroup,
        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
        experiment_id: int,
        identity: str | None = None,
    ) -> SlurmHandle | Sequence[SlurmHandle]:
        # Construct template
        template = await self.template(
            job=job,
            cluster=cluster,
            args=args,
            experiment_id=experiment_id,
            identity=identity,
        )
        logging.debug("Slurm submission script:\n%s", template)

        # Hash submission script
        template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]

        conn = await self.connection(cluster.ssh_connection_options)
        async with conn.start_sftp_client() as sftp:
            # Write the submission script to the cluster.
            # TODO(jfarebro): Find a way to get the home directory instead of
            # assuming SFTP puts us in the home directory.
            await sftp.makedirs(f".local/state/xm-slurm/{experiment_id}", exist_ok=True)
            async with sftp.open(
                f".local/state/xm-slurm/{experiment_id}/submission-script-{template_hash}.sh", "w"
            ) as fp:
                await fp.write(template)

        # Construct and run command on the cluster
        command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
        result = await self.run(cluster.ssh_connection_options, command)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to schedule job on {cluster.host}: {result.stderr}")

        assert isinstance(result.stdout, str)
        slurm_job_id, *_ = result.stdout.split(",")
        slurm_job_id = slurm_job_id.strip()

        console.log(
            f"[magenta]:rocket: Job [cyan]{slurm_job_id}[/cyan] will be launched on "
            f"[cyan]{cluster.name}[/cyan]"
        )

        if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
            return [
                SlurmHandle(
                    ssh_connection_options=cluster.ssh_connection_options,
                    job_id=f"{slurm_job_id}_{array_index}",
                )
                for array_index in range(len(args))
            ]

        return SlurmHandle(
            ssh_connection_options=cluster.ssh_connection_options,
            job_id=slurm_job_id,
        )


@typing.overload
async def launch(
    *,
    job: xm.Job | xm.JobGroup,
    args: Mapping[str, Any],
    experiment_id: int,
    identity: str | None = ...,
) -> SlurmHandle: ...


@typing.overload
async def launch(
    *,
    job: xm.Job | xm.JobGroup,
    args: Sequence[Mapping[str, Any]],
    experiment_id: int,
    identity: str | None = ...,
) -> Sequence[SlurmHandle]: ...


async def launch(
    *,
    job: xm.Job | xm.JobGroup,
    args: Mapping[str, Any] | Sequence[Mapping[str, Any]],
    experiment_id: int,
    identity: str | None = None,
) -> SlurmHandle | Sequence[SlurmHandle]:
    match job:
        case xm.Job():
            if not isinstance(job.executor, executors.Slurm):
                raise ValueError("Job must have a Slurm executor")
            job_requirements = job.executor.requirements
            cluster = job_requirements.cluster
            if cluster is None:
                raise ValueError("Job must have a cluster requirement")

            return await get_client().launch(
                cluster=cluster,
                job=job,
                args=args,
                experiment_id=experiment_id,
                identity=identity,
            )
        case xm.JobGroup() as job_group:
            job_group_executors = set()
            job_group_clusters = set()
            for job_item in job_group.jobs.values():
                if not isinstance(job_item, xm.Job):
                    raise ValueError("Job group must contain only jobs")
                if not isinstance(job_item.executor, executors.Slurm):
                    raise ValueError("Job must have a Slurm executor")
                if job_item.executor.requirements.cluster is None:
                    raise ValueError("Job must have a cluster requirement")
                job_group_clusters.add(job_item.executor.requirements.cluster)
                job_group_executors.add(id(job_item.executor))
            if len(job_group_executors) != 1:
                raise ValueError("Job group must have the same executor for all jobs")
            if len(job_group_clusters) != 1:
                raise ValueError("Job group must have the same cluster for all jobs")

            return await get_client().launch(
                cluster=job_group_clusters.pop(),
                job=job,
                args=args,
                experiment_id=experiment_id,
                identity=identity,
            )
        case _:
            raise ValueError("Unsupported job type")
xm_slurm/executors.py
ADDED
@@ -0,0 +1,127 @@
import dataclasses
import datetime as dt
import signal

from xmanager import xm

from xm_slurm import resources


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmSpec(xm.ExecutorSpec):
    """Slurm executor specification that describes the location of the container runtime.

    Args:
        tag: The image URI to push and pull the container image from.
            For example, using the GitHub Container Registry: `ghcr.io/my-project/my-image:latest`.
    """

    tag: str | None = None


@dataclasses.dataclass(frozen=True, kw_only=True)
class Slurm(xm.Executor):
    """Slurm executor describing the runtime environment.

    Args:
        requirements: The requirements for the job.
        time: The maximum time to run the job.
        account: The account to charge the job to.
        partition: The partition to run the job in.
        qos: The quality of service to run the job with.
        priority: The priority of the job.
        timeout_signal: The signal to send to the job when it runs out of time.
        timeout_signal_grace_period: The time to wait before sending `timeout_signal`.
        requeue: Whether or not the job is eligible for requeueing.
        requeue_on_exit_code: The exit code that triggers requeueing.
        requeue_max_attempts: The maximum number of times to attempt requeueing.
    """

    # Job requirements
    requirements: resources.JobRequirements
    time: dt.timedelta

    # Placement
    account: str | None = None
    partition: str | None = None
    qos: str | None = None
    priority: int | None = None

    # Job rescheduling
    timeout_signal: signal.Signals = signal.SIGUSR2
    timeout_signal_grace_period: dt.timedelta = dt.timedelta(seconds=90)

    requeue: bool = True  # Is this job eligible for requeueing?
    requeue_on_exit_code: int = 42  # The exit code that triggers requeueing
    requeue_max_attempts: int = 5  # How many times to attempt requeueing

    def __post_init__(self) -> None:
        if not isinstance(self.time, dt.timedelta):
            raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
        if not isinstance(self.requirements, resources.JobRequirements):
            raise TypeError(
                f"requirements must be a `xm_slurm.JobRequirements`, got {type(self.requirements)}. "
                "If you're still using `xm.JobRequirements`, please update to `xm_slurm.JobRequirements`."
            )
        if not isinstance(self.timeout_signal, signal.Signals):
            raise TypeError(
                f"timeout_signal must be a `signal.Signals`, got {type(self.timeout_signal)}"
            )
        if not isinstance(self.timeout_signal_grace_period, dt.timedelta):
            raise TypeError(
                f"timeout_signal_grace_period must be a `datetime.timedelta`, got {type(self.timeout_signal_grace_period)}"
            )
        if self.requeue_max_attempts < 0:
            raise ValueError(
                f"requeue_max_attempts must be greater than or equal to 0, got {self.requeue_max_attempts}"
            )
        if self.requeue_on_exit_code == 0:
            raise ValueError("requeue_on_exit_code should not be 0 to avoid unexpected behavior.")

    @classmethod
    def Spec(cls, tag: str | None = None) -> SlurmSpec:
        return SlurmSpec(tag=tag)

    def to_directives(self) -> list[str]:
        # Job requirements
        directives = self.requirements.to_directives()

        # Time
        days = self.time.days
        hours, remainder = divmod(self.time.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        directives.append(f"--time={days}-{hours:02}:{minutes:02}:{seconds:02}")

        # Placement
        if self.account:
            directives.append(f"--account={self.account}")
        if self.partition:
            directives.append(f"--partition={self.partition}")
        if self.qos:
            directives.append(f"--qos={self.qos}")
        if self.priority:
            directives.append(f"--priority={self.priority}")

        # Job rescheduling
        directives.append(
            f"--signal={self.timeout_signal.name.removeprefix('SIG')}@{self.timeout_signal_grace_period.seconds}"
        )
        if self.requeue and self.requeue_max_attempts > 0:
            directives.append("--requeue")
        else:
            directives.append("--no-requeue")

        return directives


class DockerSpec(xm.ExecutorSpec):
    """Local Docker executor specification that describes the container runtime."""


class Docker(xm.Executor):
    """Local Docker executor describing the runtime environment."""

    @classmethod
    def Spec(cls) -> DockerSpec:
        return DockerSpec()