xmanager-slurm 0.4.5-py3-none-any.whl → 0.4.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic.
- xm_slurm/__init__.py +0 -2
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/config.py +11 -3
- xm_slurm/contrib/clusters/__init__.py +3 -6
- xm_slurm/contrib/clusters/drac.py +4 -3
- xm_slurm/executables.py +4 -7
- xm_slurm/execution.py +273 -159
- xm_slurm/experiment.py +26 -180
- xm_slurm/filesystem.py +129 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +0 -9
- xm_slurm/packaging/docker.py +72 -22
- xm_slurm/packaging/utils.py +0 -108
- xm_slurm/scripts/cli.py +9 -2
- xm_slurm/templates/docker/uv.Dockerfile +6 -3
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +4 -4
- xm_slurm/templates/slurm/job-group.bash.j2 +2 -2
- xm_slurm/templates/slurm/job.bash.j2 +5 -4
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +18 -54
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +10 -24
- xm_slurm/utils.py +122 -41
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.6.dist-info}/METADATA +7 -3
- xmanager_slurm-0.4.6.dist-info/RECORD +51 -0
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.6.dist-info}/WHEEL +1 -1
- xm_slurm/api.py +0 -528
- xmanager_slurm-0.4.5.dist-info/RECORD +0 -44
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.6.dist-info}/entry_points.txt +0 -0
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.6.dist-info}/licenses/LICENSE.md +0 -0
xm_slurm/execution.py
CHANGED
@@ -5,7 +5,12 @@ import functools
 import hashlib
 import logging
 import operator
+import os
+import pathlib
+import re
 import shlex
+import shutil
+import subprocess
 import typing as tp
 
 import asyncssh
@@ -16,36 +21,18 @@ from asyncssh.auth import KbdIntPrompts, KbdIntResponse
 from asyncssh.misc import MaybeAwait
 from rich.console import ConsoleRenderable
 from rich.rule import Rule
+from rich.text import Text
 from xmanager import xm
 
-from xm_slurm import batching, config, constants, dependencies, executors, status
+from xm_slurm import batching, config, constants, dependencies, executors, status, utils
+from xm_slurm.config import ContainerRuntime, SlurmClusterConfig, SlurmSSHConfig
 from xm_slurm.console import console
+from xm_slurm.filesystem import AsyncFileSystem, AsyncLocalFileSystem, AsyncSSHFileSystem
 from xm_slurm.job_blocks import JobArgs
 from xm_slurm.types import Descriptor
 
-SlurmClusterConfig = config.SlurmClusterConfig
-ContainerRuntime = config.ContainerRuntime
-
 logger = logging.getLogger(__name__)
 
-"""
-=== Runtime Configurations ===
-With RunC:
-    skopeo copy --dest-creds=<username>:<secret> docker://<image>@<digest> oci:<image>:<digest>
-
-    pushd $SLURM_TMPDIR
-
-    umoci raw unpack --rootless --image <image>:<digest> bundle/<digest>
-    umoci raw runtime-config --image <image>:<digest> bundle/<digest>/config.json
-
-    runc run -b bundle/<digest> <container-id>
-
-With Singularity / Apptainer:
-
-    apptainer build --fix-perms --sandbox <digest> docker://<image>@<digest>
-    apptainer run --compat <digest>
-"""
-
 _POLL_INTERVAL = 30.0
 _BATCHED_BATCH_SIZE = 16
 _BATCHED_TIMEOUT = 0.2
@@ -54,19 +41,6 @@ _BATCHED_TIMEOUT = 0.2
 class SlurmExecutionError(Exception): ...
 
 
-class NoKBAuthSSHClient(asyncssh.SSHClient):
-    """SSHClient that does not prompt for keyboard-interactive authentication."""
-
-    def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
-        return ""
-
-    def kbdint_challenge_received(
-        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
-    ) -> MaybeAwait[KbdIntResponse | None]:
-        del name, instructions, lang, prompts
-        return []
-
-
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class SlurmJob:
     job_id: str
@@ -208,7 +182,7 @@ class _BatchedSlurmHandle:
         # Reconstruct the job states in the original order
         job_states = []
         for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
-            job_states.append(job_states_by_cluster[ssh_config][slurm_job])
+            job_states.append(job_states_by_cluster[ssh_config][slurm_job])  # type: ignore
         return job_states
 
     @functools.partial(
@@ -263,52 +237,59 @@ class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
     async def logs(
         self, *, num_lines: int, block_size: int, wait: bool, follow: bool
     ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
-        [... 42 lines of the previous implementation, collapsed in the diff viewer ...]
+        statedir = await get_client()._state_dir(self.ssh)
+        file = statedir / f"{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
+
+        fs = await get_client().fs(self.ssh)
+
+        if wait:
+            while not (await fs.exists(file)):
+                await asyncio.sleep(5)
+
+        file_size = await fs.size(file)
+        assert file_size is not None
+
+        async with await fs.open(file, "rb") as remote_file:  # type: ignore
+            data = b""
+            lines = []
+            position = file_size
+
+            while len(lines) <= num_lines and position > 0:
+                read_size = min(block_size, position)
+                position -= read_size
+                await remote_file.seek(position)  # type: ignore
+                chunk = await remote_file.read(read_size)
+                data = chunk + data
+                lines = data.splitlines()
+
+            if position <= 0:
+                yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
+            for line in lines[-num_lines:]:
+                yield Text.from_ansi(line.decode("utf-8", errors="replace"))
+
+            if (await self.get_state()) not in status.SlurmActiveJobStates:
+                yield Rule("[bold red]END OF FILE[/bold red]")
+                return
+
+            if not follow:
+                return
+
+            await remote_file.seek(file_size)  # type: ignore
+            while True:
+                if new_data := (await remote_file.read(block_size)):
+                    yield Text.from_ansi(new_data.decode("utf-8", errors="replace"))
+                else:
+                    await asyncio.sleep(0.25)
+
+
+class CompletedProcess(tp.Protocol):
+    returncode: int | None
+    stdout: bytes | str
+    stderr: bytes | str
 
 
 @functools.cache
-def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
+def get_template_env(runtime: ContainerRuntime) -> j2.Environment:
     template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
     template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)
 
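The rewritten `logs` coroutine above tails the job's output file by reading fixed-size blocks backwards from the end until it has collected `num_lines` lines, then optionally follows new output. A minimal synchronous sketch of the same backward-read technique on a local file (the `tail_lines` helper is illustrative, not part of xm-slurm):

import os
import pathlib

def tail_lines(path: pathlib.Path, num_lines: int, block_size: int = 4096) -> list[bytes]:
    """Return the last `num_lines` lines of a file by reading blocks backwards."""
    with path.open("rb") as f:
        f.seek(0, os.SEEK_END)
        position = f.tell()  # total file size
        data = b""
        lines: list[bytes] = []
        # Walk backwards one block at a time until we have enough lines
        # or we reach the beginning of the file.
        while len(lines) <= num_lines and position > 0:
            read_size = min(block_size, position)
            position -= read_size
            f.seek(position)
            data = f.read(read_size) + data
            lines = data.splitlines()
        return lines[-num_lines:]

The version in the diff performs the same loop over the remote filesystem abstraction and, when `follow` is set, switches to forward polling from the end of the file.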
@@ -318,44 +299,205 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
     template_env.globals["raise"] = _raise_template_exception
     template_env.globals["operator"] = operator
 
-    match container_runtime:
+    entrypoint_template = template_env.get_template("entrypoint.bash.j2")
+    template_env.globals.update(entrypoint_template.module.__dict__)
+
+    match runtime:
         case ContainerRuntime.SINGULARITY | ContainerRuntime.APPTAINER:
             runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
         case ContainerRuntime.PODMAN:
             runtime_template = template_env.get_template("runtimes/podman.bash.j2")
         case _:
-            raise NotImplementedError(f"Container runtime {container_runtime} is not implemented.")
-    # Update our global env with the runtime template's exported globals
+            raise NotImplementedError(f"Container runtime {runtime} is not implemented.")
     template_env.globals.update(runtime_template.module.__dict__)
 
     return template_env
 
 
-
-
-    return Client()
+class NoKBAuthSSHClient(asyncssh.SSHClient):
+    """SSHClient that does not prompt for keyboard-interactive authentication."""
 
+    def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
+        return ""
 
-
-
-
-
+    def kbdint_challenge_received(
+        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
+    ) -> MaybeAwait[KbdIntResponse | None]:
+        del name, instructions, lang, prompts
+        return []
+
+
+class SlurmExecutionClient:
+    def __init__(self):
+        self._remote_connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
+        self._remote_filesystems = dict[config.SlurmSSHConfig, AsyncSSHFileSystem]()
+        self._remote_connection_lock = asyncio.Lock()
+
+        self._local_fs = AsyncLocalFileSystem()
+
+    @backoff.on_exception(backoff.expo, asyncio.exceptions.TimeoutError, max_tries=5, max_time=60.0)
+    async def _local_run(  # type: ignore
+        self,
+        command: str,
+        *,
+        check: bool = False,
+        timeout: float | None = None,
+    ) -> subprocess.CompletedProcess[str]:
+        process = await asyncio.subprocess.create_subprocess_shell(
+            command,
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE,
+            # Filter out all SLURM_ environment variables as this could be running on a
+            # compute node and xm-slurm should act stateless.
+            env=dict(filter(lambda x: not x[0].startswith("SLURM_"), os.environ.items())),
+        )
+        stdout, stderr = await asyncio.wait_for(process.communicate(), timeout)
+
+        stdout = stdout.decode("utf-8").strip() if stdout else ""
+        stderr = stderr.decode("utf-8").strip() if stderr else ""
+
+        assert process.returncode is not None
+        if check and process.returncode != 0:
+            raise RuntimeError(f"Command failed with return code {process.returncode}: {command}\n")
+
+        return subprocess.CompletedProcess[str](command, process.returncode, stdout, stderr)
 
     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-    async def …
-    [... 7 lines of the previous implementation, collapsed in the diff viewer ...]
+    async def _remote_run(  # type: ignore
+        self,
+        ssh_config: config.SlurmSSHConfig,
+        command: str,
+        *,
+        check: bool = False,
+        timeout: float | None = None,
+    ) -> asyncssh.SSHCompletedProcess:
+        client = await self._connection(ssh_config)
+        return await client.run(command, check=check, timeout=timeout)
+
+    @functools.cache
+    def _is_ssh_config_local(self, ssh_config: SlurmSSHConfig) -> bool:
+        """A best effort check to see if the SSH config is local so we can bypass ssh."""
+
+        # We can't verify the connection so bail out
+        if ssh_config.host_public_key is None:
+            return False
+        if "SSH_CONNECTION" not in os.environ:
+            return False
+
+        def _is_host_local(host: str) -> bool:
+            nonlocal ssh_config
+            assert ssh_config.host_public_key is not None
+
+            if shutil.which("ssh-keyscan") is None:
+                return False
+
+            keyscan_result = utils.run_command(
+                ["ssh-keyscan", "-t", ssh_config.host_public_key.algorithm, host],
+                return_stdout=True,
+            )
+
+            if keyscan_result.returncode != 0:
+                return False
+
+            try:
+                key = mit.one(
+                    filter(
+                        lambda x: not x.startswith("#"), keyscan_result.stdout.strip().split("\n")
+                    )
+                )
+                _, algorithm, key = key.split(" ")
+
+                if (
+                    algorithm == ssh_config.host_public_key.algorithm
+                    and key == ssh_config.host_public_key.key
+                ):
+                    return True
+
+            except Exception:
+                pass
+
+            return False
+
+        # 1): we're directly connected to the host
+        ssh_connection_str = os.environ["SSH_CONNECTION"]
+        _, _, server_ip, _ = ssh_connection_str.split()
+
+        logger.debug("Checking if SSH_CONNECTION server %s is local", server_ip)
+        if _is_host_local(server_ip):
+            return True
+
+        # 2): we're in a Slurm job and the submission host is the host
+        if "SLURM_JOB_ID" in os.environ and "SLURM_SUBMIT_HOST" in os.environ:
+            submit_host = os.environ["SLURM_SUBMIT_HOST"]
+            logger.debug("Checking if SLURM_SUBMIT_HOST %s is local", submit_host)
+            if _is_host_local(submit_host):
+                return True
+        elif "SLURM_JOB_ID" in os.environ and shutil.which("scontrol") is not None:
+            # Stupid edge case where if you run srun SLURM_SUBMIT_HOST isn't forwarded
+            # so we'll parse it from scontrol...
+            scontrol_result = utils.run_command(
+                ["scontrol", "show", "job", os.environ["SLURM_JOB_ID"]],
+                return_stdout=True,
+            )
+            if scontrol_result.returncode != 0:
+                return False
+
+            match = re.search(
+                r"AllocNode:Sid=(?P<host>[^ ]+):\d+", scontrol_result.stdout.strip(), re.MULTILINE
+            )
+            if match is not None:
+                host = match.group("host")
+                logger.debug("Checking if AllocNode %s is local", host)
+                if _is_host_local(host):
+                    return True
+
+        return False
+
+    @functools.cache
+    async def _state_dir(self, ssh_config: SlurmSSHConfig) -> pathlib.Path:
+        cmd = await self.run(ssh_config, "printenv HOME", check=True)
+        assert isinstance(cmd.stdout, str)
+        return pathlib.Path(cmd.stdout.strip()) / ".local" / "state" / "xm-slurm"
+
+    async def run(
+        self,
+        ssh_config: SlurmSSHConfig,
+        command: xm.SequentialArgs | str | tp.Sequence[str],
+        *,
+        check: bool = False,
+        timeout: float | None = None,
+    ) -> CompletedProcess:
+        if isinstance(command, xm.SequentialArgs):
+            command = command.to_list()
+        if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
+            command = shlex.join(command)
+        assert isinstance(command, str)
+
+        if self._is_ssh_config_local(ssh_config):
+            logger.debug("Running command locally: %s", command)
+            return await self._local_run(command, check=check, timeout=timeout)  # type: ignore
+        else:
+            logger.debug("Running command on %s: %s", ssh_config.host, command)
+            return await self._remote_run(ssh_config, command, check=check, timeout=timeout)  # type: ignore
+
+    async def fs(self, ssh_config: SlurmSSHConfig) -> AsyncFileSystem:
+        if self._is_ssh_config_local(ssh_config):
+            return self._local_fs
+
+        if ssh_config not in self._remote_filesystems:
+            self._remote_filesystems[ssh_config] = AsyncSSHFileSystem(
+                await (await self._connection(ssh_config)).start_sftp_client()
+            )
+        return self._remote_filesystems[ssh_config]
+
+    async def _connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
+        if ssh_config not in self._remote_connections:
+            async with self._remote_connection_lock:
                 try:
                     conn, _ = await asyncssh.create_connection(
                         NoKBAuthSSHClient, options=ssh_config.connection_options
                     )
-
-                    self._connections[ssh_config] = conn
+                    self._remote_connections[ssh_config] = conn
                 except asyncssh.misc.PermissionDenied as ex:
                     raise SlurmExecutionError(
                         f"Permission denied connecting to {ssh_config.host}"
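The new `_is_ssh_config_local` check above bypasses SSH when the configured cluster appears to be the machine the code is already running on: it compares the pinned host public key from the cluster config against what `ssh-keyscan` reports for candidate hosts (the SSH_CONNECTION server, SLURM_SUBMIT_HOST, or scontrol's AllocNode). A rough standalone sketch of the key comparison, with a hypothetical `PinnedKey` standing in for xm-slurm's config type:

import dataclasses
import shutil
import subprocess

@dataclasses.dataclass(frozen=True)
class PinnedKey:  # hypothetical stand-in for ssh_config.host_public_key
    algorithm: str  # e.g. "ssh-ed25519"
    key: str        # base64-encoded public key

def host_matches_pinned_key(host: str, pinned: PinnedKey) -> bool:
    """Best effort: does `host` present the pinned public key?"""
    if shutil.which("ssh-keyscan") is None:
        return False
    result = subprocess.run(
        ["ssh-keyscan", "-t", pinned.algorithm, host],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        return False
    for line in result.stdout.splitlines():
        if line.startswith("#"):
            continue  # ssh-keyscan emits comment lines
        parts = line.split()  # format: "<host> <keytype> <base64-key>"
        if len(parts) < 3:
            continue
        _, algorithm, key = parts[:3]
        if algorithm == pinned.algorithm and key == pinned.key:
            return True
    return False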
@@ -375,28 +517,9 @@ class Client:
                         f"SSH connection error when connecting to {ssh_config.host}"
                     ) from ex
 
-        return self._connections[ssh_config]
-
-    @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-    async def run(
-        self,
-        ssh_config: config.SlurmSSHConfig,
-        command: xm.SequentialArgs | str | tp.Sequence[str],
-        *,
-        check: bool = False,
-        timeout: float | None = None,
-    ) -> asyncssh.SSHCompletedProcess:
-        client = await self.connection(ssh_config)
-        if isinstance(command, xm.SequentialArgs):
-            command = command.to_list()
-        if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
-            command = shlex.join(command)
-        assert isinstance(command, str)
-        logger.debug("Running command on %s: %s", ssh_config.host, command)
-
-        return await client.run(command, check=check, timeout=timeout)
+        return self._remote_connections[ssh_config]
 
-    async def …
+    async def _submission_script_template(
         self,
         *,
         job: xm.Job | xm.JobGroup,
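With the old `run` removed, the unified `run` shown earlier accepts `xm.SequentialArgs`, a raw string, or a sequence of tokens, and funnels all three into a single shell string before dispatching locally or over SSH. The normalization step is plain `shlex.join`; a sketch without the xmanager dependency:

import collections.abc
import shlex

def normalize_command(command: str | collections.abc.Sequence[str]) -> str:
    """Collapse a token sequence into one safely quoted shell string."""
    if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
        command = shlex.join(command)  # quotes each token as needed
    assert isinstance(command, str)
    return command

# shlex.join quotes tokens containing spaces or shell metacharacters:
assert normalize_command(["echo", "hello world"]) == "echo 'hello world'"
assert normalize_command("sbatch --parsable job.sh") == "sbatch --parsable job.sh"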
@@ -410,6 +533,12 @@ class Client:
         args = {}
 
         template_env = get_template_env(cluster.runtime)
+        template_context = dict(
+            dependency=dependency,
+            cluster=cluster,
+            experiment_id=experiment_id,
+            identity=identity,
+        )
 
         # Sanitize job groups
         if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
@@ -430,26 +559,14 @@ class Client:
                 )
 
                 return template.render(
-                    job=job_array,
-                    dependency=dependency,
-                    cluster=cluster,
-                    args=sequential_args,
-                    env_vars=env_vars,
-                    experiment_id=experiment_id,
-                    identity=identity,
+                    job=job_array, args=sequential_args, env_vars=env_vars, **template_context
                 )
             case xm.Job() if isinstance(args, collections.abc.Mapping):
                 template = template_env.get_template("job.bash.j2")
                 sequential_args = xm.SequentialArgs.from_collection(args.get("args", None))
                 env_vars = args.get("env_vars", None)
                 return template.render(
-                    job=job,
-                    dependency=dependency,
-                    cluster=cluster,
-                    args=sequential_args,
-                    env_vars=env_vars,
-                    experiment_id=experiment_id,
-                    identity=identity,
+                    job=job, args=sequential_args, env_vars=env_vars, **template_context
                 )
             case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
                 template = template_env.get_template("job-group.bash.j2")
@@ -464,13 +581,7 @@ class Client:
                     for job_name in job_group.jobs.keys()
                 }
                 return template.render(
-                    job_group=job_group,
-                    dependency=dependency,
-                    cluster=cluster,
-                    args=sequential_args,
-                    env_vars=env_vars,
-                    experiment_id=experiment_id,
-                    identity=identity,
+                    job_group=job_group, args=sequential_args, env_vars=env_vars, **template_context
                 )
             case _:
                 raise ValueError(f"Unsupported job type: {type(job)}")
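The three `template.render(...)` call sites in the hunks above previously repeated `dependency`, `cluster`, `experiment_id`, and `identity`; the new `template_context` dict is built once and unpacked into each call. A minimal illustration of the pattern (the inline templates are placeholders, not the real `job*.bash.j2` files):

import jinja2 as j2

env = j2.Environment()
job_template = env.from_string("job {{ job }} on {{ cluster }} (experiment {{ experiment_id }})")
group_template = env.from_string("group {{ job_group }} on {{ cluster }} (experiment {{ experiment_id }})")

# Shared keyword arguments are assembled once...
template_context = dict(cluster="mila", experiment_id=42)

# ...and unpacked into every render call alongside the per-call arguments.
print(job_template.render(job="train", **template_context))
print(group_template.render(job_group="sweep", **template_context))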
@@ -521,7 +632,7 @@ class Client:
         experiment_id: int,
         identity: str | None = None,
     ):
-        …
+        submission_script = await self._submission_script_template(
             job=job,
             dependency=dependency,
             cluster=cluster,
@@ -529,24 +640,19 @@ class Client:
             experiment_id=experiment_id,
             identity=identity,
         )
-        logger.debug("Slurm submission script:\n%s", …
-        [... 9 lines of the previous implementation, collapsed in the diff viewer ...]
-        await sftp.makedirs(f".local/state/xm-slurm/{experiment_id}", exist_ok=True)
-        async with sftp.open(
-            f".local/state/xm-slurm/{experiment_id}/submission-script-{template_hash}.sh", "w"
-        ) as fp:
-            await fp.write(template)
+        logger.debug("Slurm submission script:\n%s", submission_script)
+        submission_script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
+        submission_script_path = f"submission-script-{submission_script_hash}.sh"
+
+        fs = await self.fs(cluster.ssh)
+
+        template_dir = (await self._state_dir(cluster.ssh)) / f"{experiment_id}"
+
+        await fs.makedirs(template_dir, exist_ok=True)
+        await fs.write(template_dir / submission_script_path, submission_script.encode())
 
         # Construct and run command on the cluster
-        command = f"sbatch --chdir . …
+        command = f"sbatch --chdir {template_dir.as_posix()} --parsable {submission_script_path}"
         result = await self.run(cluster.ssh, command)
         if result.returncode != 0:
             raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
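In the hunk above, the submission script now gets a content-derived name, the first 8 hex characters of a BLAKE2s digest of the script body, and is written under the per-experiment state directory before being submitted with `sbatch --parsable`, which prints only the job ID on stdout. A sketch of the naming and command construction (the paths are illustrative):

import hashlib
import pathlib

submission_script = "#!/bin/bash\n#SBATCH --time=1:00:00\nsrun hostname\n"

# Content-addressed filename: identical scripts map to the same file.
script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
script_path = f"submission-script-{script_hash}.sh"

template_dir = pathlib.Path.home() / ".local" / "state" / "xm-slurm" / "42"
# --parsable makes sbatch print just "<job_id>[;<cluster>]" for machine consumption.
command = f"sbatch --chdir {template_dir.as_posix()} --parsable {script_path}"
print(command)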
@@ -596,8 +702,16 @@ class Client:
                 raise ValueError(f"Unsupported job type: {type(job)}")
 
     def __del__(self):
-        for …
+        for fs in self._remote_filesystems.values():
+            del fs
+        for conn in self._remote_connections.values():
             conn.close()
+            del conn
+
+
+@functools.cache
+def get_client() -> SlurmExecutionClient:
+    return SlurmExecutionClient()
 
 
 @tp.overload
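The module now exposes `get_client()` under `@functools.cache` instead of constructing a client eagerly: the `SlurmExecutionClient` is built on first call and the same instance is returned thereafter. The general pattern, sketched independently of xm-slurm:

import functools

class ExpensiveClient:
    def __init__(self) -> None:
        print("connecting...")  # runs only once

@functools.cache
def get_client() -> ExpensiveClient:
    # Zero-argument cached factory: a lazy, process-wide singleton.
    return ExpensiveClient()

a = get_client()  # prints "connecting..."
b = get_client()  # cached; no side effects
assert a is b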