xmanager_slurm-0.4.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +47 -0
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +189 -0
- xm_slurm/console.py +3 -0
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +67 -0
- xm_slurm/contrib/clusters/drac.py +242 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +215 -0
- xm_slurm/execution.py +995 -0
- xm_slurm/executors.py +210 -0
- xm_slurm/experiment.py +1016 -0
- xm_slurm/experimental/parameter_controller.py +206 -0
- xm_slurm/filesystems.py +129 -0
- xm_slurm/job_blocks.py +21 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +309 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker.py +348 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +56 -0
- xm_slurm/packaging/utils.py +22 -0
- xm_slurm/resources.py +350 -0
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +90 -0
- xm_slurm/status.py +197 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
- xm_slurm/templates/docker/mamba.Dockerfile +29 -0
- xm_slurm/templates/docker/python.Dockerfile +32 -0
- xm_slurm/templates/docker/uv.Dockerfile +38 -0
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
- xm_slurm/templates/slurm/job.bash.j2 +90 -0
- xm_slurm/templates/slurm/library/retry.bash +62 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +196 -0
- xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
- xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
- xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
- xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
xm_slurm/execution.py
ADDED
@@ -0,0 +1,995 @@
import asyncio
import collections.abc
import dataclasses
import functools
import getpass
import hashlib
import importlib
import importlib.resources
import logging
import operator
import os
import pathlib
import re
import shlex
import shutil
import subprocess
import sys
import typing as tp

import asyncssh
import backoff
import jinja2 as j2
import more_itertools as mit
from asyncssh.auth import KbdIntPrompts, KbdIntResponse
from asyncssh.misc import MaybeAwait
from rich.console import ConsoleRenderable
from rich.rule import Rule
from rich.text import Text
from xmanager import xm

from xm_slurm import (
    batching,
    config,
    constants,
    dependencies,
    executors,
    filesystems,
    job_blocks,
    status,
    utils,
)
from xm_slurm.console import console
from xm_slurm.types import Descriptor

logger = logging.getLogger(__name__)

_POLL_INTERVAL = 30.0
_BATCHED_BATCH_SIZE = 16
_BATCHED_TIMEOUT = 0.2


class SlurmExecutionError(Exception): ...


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmJob:
    job_id: str

    @property
    def is_array_job(self) -> bool:
        return isinstance(self, SlurmArrayJob)

    @property
    def is_heterogeneous_job(self) -> bool:
        return isinstance(self, SlurmHeterogeneousJob)

    def __hash__(self) -> int:
        return hash((type(self), self.job_id))


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmArrayJob(SlurmJob):
    array_job_id: str
    array_task_id: str


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmHeterogeneousJob(SlurmJob):
    het_job_id: str
    het_component_id: str


SlurmJobT = tp.TypeVar("SlurmJobT", bound=SlurmJob, covariant=True)


class SlurmJobDescriptor(Descriptor[SlurmJobT, str]):
    def __set_name__(self, owner: type, name: str):
        del owner
        self.job = f"_{name}"

    def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> SlurmJobT:
        del owner
        return getattr(instance, self.job)

    def __set__(self, instance: object, value: str):
        _setattr = object.__setattr__ if not hasattr(instance, self.job) else setattr

        match = constants.SLURM_JOB_ID_REGEX.match(value)
        if match is None:
            raise ValueError(f"Invalid Slurm job ID: {value}")
        groups = match.groupdict()

        job_id = groups["jobid"]
        if array_task_id := groups.get("arraytaskid"):
            _setattr(
                instance,
                self.job,
                SlurmArrayJob(job_id=value, array_job_id=job_id, array_task_id=array_task_id),
            )
        elif het_component_id := groups.get("componentid"):
            _setattr(
                instance,
                self.job,
                SlurmHeterogeneousJob(
                    job_id=value, het_job_id=job_id, het_component_id=het_component_id
                ),
            )
        else:
            _setattr(instance, self.job, SlurmJob(job_id=value))


def _group_by_ssh_configs(
    ssh_configs: tp.Sequence[config.SSHConfig], slurm_jobs: tp.Sequence[SlurmJob]
) -> dict[config.SSHConfig, list[SlurmJob]]:
    jobs_by_cluster = collections.defaultdict(list)
    for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
        jobs_by_cluster[ssh_config].append(slurm_job)
    return jobs_by_cluster


class _BatchedSlurmHandle:
    @functools.partial(
        batching.batch,
        max_batch_size=_BATCHED_BATCH_SIZE,
        batch_timeout=_BATCHED_TIMEOUT,
    )
    @staticmethod
    @backoff.on_exception(backoff.expo, SlurmExecutionError, max_tries=5, max_time=60.0)
    async def _batched_get_state(
        ssh_configs: tp.Sequence[config.SSHConfig],
        slurm_jobs: tp.Sequence[SlurmJob],
    ) -> tp.Sequence[status.SlurmJobState]:
        async def _get_state(
            options: config.SSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
        ) -> tp.Sequence[status.SlurmJobState]:
            result = await get_client().run(
                options,
                [
                    "sacct",
                    "--jobs",
                    ",".join([slurm_job.job_id for slurm_job in slurm_jobs]),
                    "--format",
                    "JobID,State",
                    "--allocations",
                    "--noheader",
                    "--parsable2",
                ],
                check=True,
            )

            assert isinstance(result.stdout, str)
            states_by_job_id = {}
            for line in result.stdout.splitlines():
                job_id, state = line.split("|")
                states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)

            job_states = []
            for slurm_job in slurm_jobs:
                if slurm_job.job_id in states_by_job_id:
                    job_states.append(states_by_job_id[slurm_job.job_id])
                # This is a stupid hack around sacct's inability to display state information for
                # array job elements that haven't begun. We'll assume that if the job ID is not found,
                # and it's an array job, then it's pending.
                elif slurm_job.is_array_job:
                    job_states.append(status.SlurmJobState.PENDING)
                else:
                    raise SlurmExecutionError(f"Failed to find job state info for {slurm_job!r}")
            return job_states

        # Group Slurm jobs by their cluster so we can batch requests
        jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)

        # Async get state for each cluster
        job_states_per_cluster = await asyncio.gather(*[
            _get_state(options, jobs) for options, jobs in jobs_by_cluster.items()
        ])

        # Reconstruct the job states by cluster
        job_states_by_cluster = {}
        for ssh_config, job_states in zip(ssh_configs, job_states_per_cluster):
            job_states_by_cluster[ssh_config] = dict(zip(jobs_by_cluster[ssh_config], job_states))

        # Reconstruct the job states in the original order
        job_states = []
        for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
            job_states.append(job_states_by_cluster[ssh_config][slurm_job])  # type: ignore
        return job_states

    @functools.partial(
        batching.batch,
        max_batch_size=_BATCHED_BATCH_SIZE,
        batch_timeout=_BATCHED_TIMEOUT,
    )
    @staticmethod
    async def _batched_cancel(
        ssh_configs: tp.Sequence[config.SSHConfig],
        slurm_jobs: tp.Sequence[SlurmJob],
    ) -> tp.Sequence[None]:
        async def _cancel(options: config.SSHConfig, slurm_jobs: tp.Sequence[SlurmJob]) -> None:
            await get_client().run(
                options,
                ["scancel", " ".join([slurm_job.job_id for slurm_job in slurm_jobs])],
                check=True,
            )

        jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
        return await asyncio.gather(*[
            _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
        ])


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
    """A handle for referring to the launched container."""

    experiment_id: int
    ssh: config.SSHConfig
    slurm_job: Descriptor[SlurmJobT, str] = SlurmJobDescriptor[SlurmJobT]()
    job_name: str  # XManager job name associated with this handle

    @backoff.on_predicate(
        backoff.constant,
        lambda state: state in status.SlurmActiveJobStates,
        jitter=None,
        interval=_POLL_INTERVAL,
    )
    async def wait(self) -> status.SlurmJobState:
        return await self.get_state()

    async def stop(self) -> None:
        await self._batched_cancel(self.ssh, self.slurm_job)

    async def get_state(self) -> status.SlurmJobState:
        return await self._batched_get_state(self.ssh, self.slurm_job)

    async def logs(
        self, *, num_lines: int, block_size: int, wait: bool, follow: bool, raw: bool = False
    ) -> tp.AsyncGenerator[tp.Union[str, ConsoleRenderable], None]:
        experiment_dir = await get_client().experiment_dir(self.ssh, self.experiment_id)
        file = experiment_dir / f"slurm-{self.slurm_job.job_id}.out"
        fs = await get_client().fs(self.ssh)

        if wait:
            while not (await fs.exists(file)):
                await asyncio.sleep(5)

        file_size = await fs.size(file)
        assert file_size is not None

        async with await fs.open(file, "rb") as remote_file:
            data = b""
            lines = []
            position = file_size

            while len(lines) <= num_lines and position > 0:
                read_size = min(block_size, position)
                position -= read_size
                await remote_file.seek(position)
                chunk = await remote_file.read(read_size)
                data = chunk + data
                lines = data.splitlines()

            if position <= 0:
                if raw:
                    yield "\033[31mBEGINNING OF FILE\033[0m\n"
                else:
                    yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
            for line in lines[-num_lines:]:
                if raw:
                    yield line.decode("utf-8", errors="replace") + "\n"
                else:
                    yield Text.from_ansi(line.decode("utf-8", errors="replace"))

            if (await self.get_state()) not in status.SlurmActiveJobStates:
                if raw:
                    yield "\033[31mEND OF FILE\033[0m\n"
                    return
                else:
                    yield Rule("[bold red]END OF FILE[/bold red]")
                    return

            if not follow:
                return

            await remote_file.seek(file_size)
            while True:
                if new_data := (await remote_file.read(block_size)):
                    if raw:
                        yield new_data.decode("utf-8", errors="replace")
                    else:
                        yield Text.from_ansi(new_data.decode("utf-8", errors="replace"))
                else:
                    await asyncio.sleep(0.25)


class CompletedProcess(tp.Protocol):
    returncode: int | None
    stdout: bytes | str
    stderr: bytes | str


@functools.cache
def get_template_env(runtime: config.ContainerRuntime) -> j2.Environment:
    template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
    template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)

    def _raise_template_exception(msg: str) -> None:
        raise j2.TemplateRuntimeError(msg)

    template_env.globals["raise"] = _raise_template_exception
    template_env.globals["operator"] = operator

    # Iterate over stdlib files and insert them into the template environment
    stdlib = []
    for file in importlib.resources.files("xm_slurm.templates.slurm.library").iterdir():
        if not file.is_file() or not file.name.endswith(".bash"):
            continue
        stdlib.append(file.read_text())
    template_env.globals["stdlib"] = stdlib

    entrypoint_template = template_env.get_template("entrypoint.bash.j2")
    template_env.globals.update(entrypoint_template.module.__dict__)

    match runtime:
        case config.ContainerRuntime.SINGULARITY | config.ContainerRuntime.APPTAINER:
            runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
        case config.ContainerRuntime.PODMAN:
            runtime_template = template_env.get_template("runtimes/podman.bash.j2")
        case _:
            raise NotImplementedError(f"Container runtime {runtime} is not implemented.")
    template_env.globals.update(runtime_template.module.__dict__)

    return template_env


class SlurmSSHClient(asyncssh.SSHClient):
    """SSHClient that handles keyboard-interactive 2FA authentication."""

    _kbdint_auth_lock: tp.ClassVar[asyncio.Lock] = asyncio.Lock()
    _host: str

    def __init__(self, host: str):
        self._host = host

    def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
        return ""

    async def kbdint_challenge_received(
        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
    ) -> MaybeAwait[KbdIntResponse | None]:
        """Handle 2FA prompts by prompting user for input."""
        del name, lang
        if not sys.stdin.isatty():
            raise SlurmExecutionError(
                f"Two-factor authentication is not supported for non-interactive sessions on {self._host}"
            )

        async with self._kbdint_auth_lock:
            if len(prompts) > 0:
                console.rule(f"Two-Factor Authentication for {self._host}")

            if instructions:
                console.print(instructions, style="bold yellow")

            responses = []
            for prompt, echo in prompts:
                # Manually disable password authentication
                if prompt.strip() == "Password:":
                    return None

                try:
                    response = await asyncio.to_thread(
                        console.input,
                        f"{prompt}\a",
                        password=not echo,
                    )
                except (EOFError, KeyboardInterrupt):
                    console.print("\n[red]Authentication cancelled[/red]")
                    return None
                else:
                    responses.append(response)

            return responses


class SlurmExecutionClient:
    def __init__(self):
        self._remote_connections = dict[config.SSHConfig, asyncssh.SSHClientConnection]()
        self._remote_connection_locks = collections.defaultdict(asyncio.Lock)
        self._remote_filesystems = dict[config.SSHConfig, filesystems.AsyncSSHFileSystem]()

        self._local_fs = filesystems.AsyncLocalFileSystem()

    @backoff.on_exception(backoff.expo, asyncio.exceptions.TimeoutError, max_tries=5, max_time=60.0)
    async def _local_run(  # type: ignore
        self,
        command: str,
        *,
        check: bool = False,
        timeout: float | None = None,
    ) -> subprocess.CompletedProcess[str]:
        process = await asyncio.subprocess.create_subprocess_shell(
            command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            # Filter out all SLURM_ environment variables as this could be running on a
            # compute node and xm-slurm should act stateless.
            env=dict(filter(lambda x: not x[0].startswith("SLURM_"), os.environ.items())),
        )
        stdout, stderr = await asyncio.wait_for(process.communicate(), timeout)

        stdout = stdout.decode("utf-8").strip() if stdout else ""
        stderr = stderr.decode("utf-8").strip() if stderr else ""

        assert process.returncode is not None
        if check and process.returncode != 0:
            raise RuntimeError(f"Command failed with return code {process.returncode}: {command}\n")

        return subprocess.CompletedProcess[str](command, process.returncode, stdout, stderr)

    @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
    async def _remote_run(  # type: ignore
        self,
        ssh_config: config.SSHConfig,
        command: str,
        *,
        check: bool = False,
        timeout: float | None = None,
    ) -> asyncssh.SSHCompletedProcess:
        client = await self._connection(ssh_config)
        return await client.run(command, check=check, timeout=timeout)

    @functools.cache
    def _is_ssh_config_local(self, ssh_config: config.SSHConfig) -> bool:
        """A best effort check to see if the SSH config is local so we can bypass ssh."""

        # We can't verify the connection so bail out
        if ssh_config.public_key is None:
            return False
        if "SSH_CONNECTION" not in os.environ:
            return False

        def _is_host_local(host: str) -> bool:
            nonlocal ssh_config
            assert ssh_config.public_key is not None

            if shutil.which("ssh-keyscan") is None:
                return False

            keyscan_result = utils.run_command(
                ["ssh-keyscan", "-t", ssh_config.public_key.algorithm, host],
                return_stdout=True,
            )

            if keyscan_result.returncode != 0:
                return False

            try:
                key = mit.one(
                    filter(
                        lambda x: not x.startswith("#"), keyscan_result.stdout.strip().split("\n")
                    )
                )
                _, algorithm, key = key.split(" ")

                if (
                    algorithm == ssh_config.public_key.algorithm
                    and key == ssh_config.public_key.key
                ):
                    return True

            except Exception:
                pass

            return False

        # 1): we're directly connected to the host
        ssh_connection_str = os.environ["SSH_CONNECTION"]
        _, _, server_ip, _ = ssh_connection_str.split()

        logger.debug("Checking if SSH_CONNECTION server %s is local", server_ip)
        if _is_host_local(server_ip):
            return True

        # 2): we're in a Slurm job and the submission host is the host
        if "SLURM_JOB_ID" in os.environ and "SLURM_SUBMIT_HOST" in os.environ:
            submit_host = os.environ["SLURM_SUBMIT_HOST"]
            logger.debug("Checking if SLURM_SUBMIT_HOST %s is local", submit_host)
            if _is_host_local(submit_host):
                return True
        elif "SLURM_JOB_ID" in os.environ and shutil.which("scontrol") is not None:
            # Stupid edge case where if you run srun SLURM_SUBMIT_HOST isn't forwarded
            # so we'll parse it from scontrol...
            scontrol_result = utils.run_command(
                ["scontrol", "show", "job", os.environ["SLURM_JOB_ID"]],
                return_stdout=True,
            )
            if scontrol_result.returncode != 0:
                return False

            match = re.search(
                r"AllocNode:Sid=(?P<host>[^ ]+):\d+", scontrol_result.stdout.strip(), re.MULTILINE
            )
            if match is not None:
                host = match.group("host")
                logger.debug("Checking if AllocNode %s is local", host)
                if _is_host_local(host):
                    return True

        return False

    @functools.cache
    @utils.reawaitable
    async def _state_dir(self, ssh_config: config.SSHConfig) -> pathlib.Path:
        state_dirs = [
            ("XM_SLURM_STATE_DIR", ""),
            ("XDG_STATE_HOME", "xm-slurm"),
            ("HOME", ".local/state/xm-slurm"),
        ]

        for env_var, subpath in state_dirs:
            cmd = await self.run(ssh_config, f"printenv {env_var}", check=False)
            assert isinstance(cmd.stdout, str)
            if cmd.returncode == 0:
                return pathlib.Path(cmd.stdout.strip()) / subpath

        raise SlurmExecutionError(
            "Failed to find a valid state directory for XManager. "
            "We weren't able to resolve any of the following paths: "
            f"{', '.join(env_var + ('/' + subpath if subpath else '') for env_var, subpath in state_dirs)}."
        )

    @functools.cached_property
    def _ssh_config_dirs(self) -> list[pathlib.Path]:
        ssh_config_paths = []

        if (ssh_config_path := pathlib.Path.home() / ".ssh" / "config").exists():
            ssh_config_paths.append(ssh_config_path)
        if (xm_ssh_config_var := os.environ.get("XM_SLURM_SSH_CONFIG")) and (
            xm_ssh_config_path := pathlib.Path(xm_ssh_config_var).expanduser()
        ).exists():
            ssh_config_paths.append(xm_ssh_config_path)

        return ssh_config_paths

    async def experiment_dir(
        self, ssh_config: config.SSHConfig, experiment_id: int
    ) -> pathlib.Path:
        return (await self._state_dir(ssh_config)) / f"{experiment_id:08d}"

    async def run(
        self,
        ssh_config: config.SSHConfig,
        command: xm.SequentialArgs | str | tp.Sequence[str],
        *,
        check: bool = False,
        timeout: float | None = None,
    ) -> CompletedProcess:
        if isinstance(command, xm.SequentialArgs):
            command = command.to_list()
        if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
            command = shlex.join(command)
        assert isinstance(command, str)

        if self._is_ssh_config_local(ssh_config):
            logger.debug("Running command locally: %s", command)
            return await self._local_run(command, check=check, timeout=timeout)  # type: ignore
        else:
            logger.debug(
                "Running command on %s: %s", ", ".join(map(str, ssh_config.endpoints)), command
            )
            return await self._remote_run(ssh_config, command, check=check, timeout=timeout)  # type: ignore

    async def fs(self, ssh_config: config.SSHConfig) -> filesystems.AsyncFileSystem:
        if self._is_ssh_config_local(ssh_config):
            return self._local_fs

        if ssh_config not in self._remote_filesystems:
            self._remote_filesystems[ssh_config] = filesystems.AsyncSSHFileSystem(
                await (await self._connection(ssh_config)).start_sftp_client()
            )
        return self._remote_filesystems[ssh_config]

    async def _connection(self, ssh_config: config.SSHConfig) -> asyncssh.SSHClientConnection:
        async def _connect_to_endpoint(
            endpoint: config.Endpoint,
        ) -> asyncssh.SSHClientConnection:
            __tracebackhide__ = True
            try:
                config = asyncssh.config.SSHClientConfig.load(
                    None,
                    self._ssh_config_dirs,
                    True,
                    True,
                    True,
                    getpass.getuser(),
                    ssh_config.user or (),
                    endpoint.hostname,
                    endpoint.port or (),
                )
                if config.get("Hostname") is None and (
                    constants.DOMAIN_NAME_REGEX.match(endpoint.hostname)
                    or constants.IPV4_REGEX.match(endpoint.hostname)
                    or constants.IPV6_REGEX.match(endpoint.hostname)
                ):
                    config._options["Hostname"] = endpoint.hostname
                elif config.get("Hostname") is None:
                    raise RuntimeError(
                        f"Failed to parse hostname from host `{endpoint.hostname}` using "
                        f"SSH configs: {', '.join(map(str, self._ssh_config_dirs))} and "
                        f"provided hostname `{endpoint.hostname}` isn't a valid domain name "
                        "or IPv{4,6} address."
                    )

                if config.get("User") is None:
                    raise RuntimeError(
                        f"We could not find a user for the cluster configuration: `{endpoint.hostname}`. "
                        "No user was specified in the configuration and we could not parse "
                        f"any users for host `{config.get('Hostname')}` from the SSH configs: "
                        f"{', '.join(map(lambda h: f'`{h}`', self._ssh_config_dirs))}. Please either specify a user "
                        "in the configuration or add a user to your SSH configuration under the block "
                        f"`Host {config.get('Hostname')}`."
                    )

                options = await asyncssh.SSHClientConnectionOptions.construct(
                    config=None,
                    disable_trivial_auth=True,
                    password_auth=False,
                    server_host_key_algs=ssh_config.public_key.algorithm
                    if ssh_config.public_key
                    else None,
                    login_timeout=60 * 10,  # 10 minutes
                    known_hosts=ssh_config.known_hosts,
                )
                options.prepare(last_config=config)

                conn, _ = await asyncssh.create_connection(
                    lambda: SlurmSSHClient(endpoint.hostname),
                    host=endpoint.hostname,
                    port=endpoint.port,
                    options=options,
                )
                return conn
            except asyncssh.misc.PermissionDenied as ex:
                raise SlurmExecutionError(
                    f"Permission denied connecting to {endpoint.hostname}"
                ) from ex
            except asyncssh.misc.ConnectionLost as ex:
                raise SlurmExecutionError(f"Connection lost to host {endpoint.hostname}") from ex
            except asyncssh.misc.HostKeyNotVerifiable as ex:
                raise SlurmExecutionError(
                    f"Cannot verify the public key for host {endpoint.hostname}"
                ) from ex
            except asyncssh.misc.KeyExchangeFailed as ex:
                raise SlurmExecutionError(
                    f"Failed to exchange keys with host {endpoint.hostname}"
                ) from ex
            except asyncssh.Error as ex:
                raise SlurmExecutionError(
                    f"SSH connection error when connecting to {endpoint.hostname}"
                ) from ex

        conn = self._remote_connections.get(ssh_config)
        if conn is not None and not conn.is_closed():
            return conn

        async with self._remote_connection_locks[ssh_config]:
            conn = self._remote_connections.get(ssh_config)
            if conn is not None and not conn.is_closed():
                return conn

            exceptions: list[Exception] = []
            for endpoint in ssh_config.endpoints:
                try:
                    conn = await _connect_to_endpoint(endpoint)
                except Exception as ex:
                    exceptions.append(ex)
                else:
                    self._remote_connections[ssh_config] = conn
                    return conn

            if sys.version_info >= (3, 11):
                raise ExceptionGroup("Failed to connect to all hosts", exceptions)  # noqa: F821
            raise exceptions[-1]

    async def _submission_script_template(
        self,
        *,
        job: xm.Job | xm.JobGroup,
        dependency: dependencies.SlurmJobDependency | None = None,
        cluster: config.SlurmClusterConfig,
        args: tp.Mapping[str, tp.Any] | tp.Sequence[tp.Mapping[str, tp.Any]] | None,
        experiment_id: int,
        identity: str | None,
    ) -> str:
        # Sanitize args
        match args:
            case None:
                args = {}
            case collections.abc.Mapping():
                args = dict(args)
            case collections.abc.Sequence():
                assert all(isinstance(trial, collections.abc.Mapping) for trial in args)
                args = [dict(trial) for trial in args]
            case _:
                raise ValueError("Invalid args type")
        args = tp.cast(dict[str, tp.Any] | list[dict[str, tp.Any]], args)

        template_env = get_template_env(cluster.runtime)
        template_context = dict(
            dependency=dependency,
            cluster=cluster,
            experiment_id=experiment_id,
            identity=identity,
        )

        # Sanitize job groups
        if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
            job = tp.cast(xm.Job, list(job.jobs.values())[0])
        elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
            raise ValueError("Job group must have at least one job")

        match job:
            case xm.Job() as job_array if isinstance(args, collections.abc.Sequence):
                assert isinstance(args, list)
                template = template_env.get_template("job-array.bash.j2")
                sequential_args = [
                    xm.SequentialArgs.from_collection(trial.get("args")) for trial in args
                ]
                env_vars = [trial.get("env_vars") for trial in args]
                if any(env_vars):
                    raise NotImplementedError(
                        "Job arrays over environment variables are not yet supported."
                    )

                return template.render(
                    job=job_array, args=sequential_args, env_vars=env_vars, **template_context
                )
            case xm.Job() if isinstance(args, collections.abc.Mapping):
                assert isinstance(args, dict)
                template = template_env.get_template("job.bash.j2")
                sequential_args = xm.SequentialArgs.from_collection(args.get("args"))
                env_vars = args.get("env_vars")
                return template.render(
                    job=job, args=sequential_args, env_vars=env_vars, **template_context
                )
            case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
                assert isinstance(args, dict)
                template = template_env.get_template("job-group.bash.j2")
                sequential_args = {
                    job_name: {
                        "args": args.get(job_name, {}).get("args"),
                    }
                    for job_name in job_group.jobs.keys()
                }
                env_vars = {
                    job_name: args.get(job_name, {}).get("env_vars")
                    for job_name in job_group.jobs.keys()
                }
                return template.render(
                    job_group=job_group, args=sequential_args, env_vars=env_vars, **template_context
                )
            case _:
                raise ValueError(f"Unsupported job type: {type(job)}")

    @tp.overload
    async def launch(
        self,
        *,
        cluster: config.SlurmClusterConfig,
        job: xm.JobGroup,
        dependency: dependencies.SlurmJobDependency | None = None,
        args: tp.Mapping[str, job_blocks.JobArgs] | None,
        experiment_id: int,
        identity: str | None = ...,
    ) -> SlurmHandle: ...

    @tp.overload
    async def launch(
        self,
        *,
        cluster: config.SlurmClusterConfig,
        job: xm.Job,
        dependency: dependencies.SlurmJobDependency | None = None,
        args: tp.Sequence[job_blocks.JobArgs],
        experiment_id: int,
        identity: str | None = ...,
    ) -> list[SlurmHandle]: ...

    @tp.overload
    async def launch(
        self,
        *,
        cluster: config.SlurmClusterConfig,
        job: xm.Job,
        dependency: dependencies.SlurmJobDependency | None = None,
        args: job_blocks.JobArgs,
        experiment_id: int,
        identity: str | None = ...,
    ) -> SlurmHandle: ...

    async def launch(
        self,
        *,
        cluster: config.SlurmClusterConfig,
        job: xm.Job | xm.JobGroup,
        dependency: dependencies.SlurmJobDependency | None = None,
        args: tp.Mapping[str, job_blocks.JobArgs]
        | tp.Sequence[job_blocks.JobArgs]
        | job_blocks.JobArgs
        | None,
        experiment_id: int,
        identity: str | None = None,
    ):
        submission_script = await self._submission_script_template(
            job=job,
            dependency=dependency,
            cluster=cluster,
            args=args,
            experiment_id=experiment_id,
            identity=identity,
        )
        logger.debug("Slurm submission script:\n%s", submission_script)
        submission_script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
        submission_script_path = f"submission-script-{submission_script_hash}.sh"

        fs = await self.fs(cluster.ssh)

        template_dir = await self.experiment_dir(cluster.ssh, experiment_id)

        await fs.makedirs(template_dir, exist_ok=True)
        await fs.write(template_dir / submission_script_path, submission_script.encode())

        # Construct and run command on the cluster
        command = f"sbatch --chdir {template_dir.as_posix()} --parsable {submission_script_path}"
        result = await self.run(cluster.ssh, command)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")

        assert isinstance(result.stdout, str)
        slurm_job_id, *_ = result.stdout.split(",")
        slurm_job_id = slurm_job_id.strip()

        console.log(
            f"[magenta]:rocket: Job [cyan]{slurm_job_id}[/cyan] will be launched on "
            f"[cyan]{cluster.name}[/cyan] "
        )

        # If we scheduled an array job make sure to return a list of handles
        # The indexing is always sequential in 0, 1, ..., n - 1
        if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
            assert job.name is not None
            return [
                SlurmHandle(
                    experiment_id=experiment_id,
                    ssh=cluster.ssh,
                    slurm_job=f"{slurm_job_id}_{array_index}",
                    job_name=job.name,
                )
                for array_index in range(len(args))
            ]
        elif isinstance(job, xm.Job):
            assert job.name is not None
            return SlurmHandle(
                experiment_id=experiment_id,
                ssh=cluster.ssh,
                slurm_job=slurm_job_id,
                job_name=job.name,
            )
        elif isinstance(job, xm.JobGroup):
            # TODO: make this work for actual job groups.
            job = tp.cast(xm.Job, mit.one(job.jobs.values()))
            assert isinstance(job, xm.Job)
            assert job.name is not None
            return SlurmHandle(
                experiment_id=experiment_id,
                ssh=cluster.ssh,
                slurm_job=slurm_job_id,
                job_name=job.name,
            )
        else:
            raise ValueError(f"Unsupported job type: {type(job)}")

    def __del__(self):
        for fs in self._remote_filesystems.values():
            del fs
        for conn in self._remote_connections.values():
            conn.close()
            del conn


@functools.cache
def get_client() -> SlurmExecutionClient:
    return SlurmExecutionClient()


@tp.overload
async def launch(
    *,
    job: xm.JobGroup,
    dependency: dependencies.SlurmJobDependency | None = None,
    args: tp.Mapping[str, job_blocks.JobArgs],
    experiment_id: int,
    identity: str | None = ...,
) -> SlurmHandle: ...


@tp.overload
async def launch(
    *,
    job: xm.Job,
    dependency: dependencies.SlurmJobDependency | None = None,
    args: tp.Sequence[job_blocks.JobArgs],
    experiment_id: int,
    identity: str | None = ...,
) -> list[SlurmHandle]: ...


@tp.overload
async def launch(
    *,
    job: xm.Job,
    dependency: dependencies.SlurmJobDependency | None = None,
    args: job_blocks.JobArgs,
    experiment_id: int,
    identity: str | None = ...,
) -> SlurmHandle: ...


async def launch(
    *,
    job: xm.Job | xm.JobGroup,
    dependency: dependencies.SlurmJobDependency | None = None,
    args: tp.Mapping[str, job_blocks.JobArgs]
    | tp.Sequence[job_blocks.JobArgs]
    | job_blocks.JobArgs,
    experiment_id: int,
    identity: str | None = None,
) -> SlurmHandle | list[SlurmHandle]:
    match job:
        case xm.Job() as job:
            if not isinstance(job.executor, executors.Slurm):
                raise ValueError("Job must have a Slurm executor")
            job_requirements = job.executor.requirements
            cluster = job_requirements.cluster
            if cluster is None:
                raise ValueError("Job must have a cluster requirement")
            if cluster.validate is not None:
                cluster.validate(job)

            return await get_client().launch(
                cluster=cluster,
                job=job,
                dependency=dependency,
                args=tp.cast(job_blocks.JobArgs | tp.Sequence[job_blocks.JobArgs], args),
                experiment_id=experiment_id,
                identity=identity,
            )
        case xm.JobGroup() as job_group:
            job_group_executors = set()
            job_group_clusters = set()
            for job_item in job_group.jobs.values():
                if not isinstance(job_item, xm.Job):
                    raise ValueError("Job group must contain only jobs")
                if not isinstance(job_item.executor, executors.Slurm):
                    raise ValueError("Job must have a Slurm executor")
                if job_item.executor.requirements.cluster is None:
                    raise ValueError("Job must have a cluster requirement")
                if job_item.executor.requirements.cluster.validate is not None:
                    job_item.executor.requirements.cluster.validate(job_item)
                job_group_clusters.add(job_item.executor.requirements.cluster)
                job_group_executors.add(id(job_item.executor))
            if len(job_group_executors) != 1:
                raise ValueError("Job group must have the same executor for all jobs")
            if len(job_group_clusters) != 1:
                raise ValueError("Job group must have the same cluster for all jobs")

            return await get_client().launch(
                cluster=job_group_clusters.pop(),
                job=job_group,
                dependency=dependency,
                args=tp.cast(tp.Mapping[str, job_blocks.JobArgs], args),
                experiment_id=experiment_id,
                identity=identity,
            )
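Usage note (illustrative, not part of the wheel contents): the module's public surface is the cached get_client() factory and the overloaded module-level launch(...) coroutine above. The following is a minimal, hypothetical sketch of driving it directly, assuming an already-packaged executable and an executors.Slurm executor whose requirements.cluster is populated (both are built by xm_slurm modules not shown in this diff); passing a sequence of per-trial argument mappings takes the job-array path and yields one SlurmHandle per trial.

import asyncio

from xmanager import xm
from xm_slurm import execution, executors


async def run_sweep(executable: xm.Executable, executor: executors.Slurm) -> None:
    # executor.requirements.cluster must already be set; launch() raises ValueError otherwise.
    job = xm.Job(executable=executable, executor=executor, name="train")

    # A sequence of per-trial argument mappings renders job-array.bash.j2 and
    # submits a single sbatch array job, returning one SlurmHandle per trial.
    handles = await execution.launch(
        job=job,
        args=[{"args": {"seed": seed}} for seed in range(4)],
        experiment_id=1,
    )

    # wait() polls sacct (batched across handles) until the job leaves the
    # active states, then returns its final SlurmJobState.
    final_states = await asyncio.gather(*(handle.wait() for handle in handles))
    print(final_states)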