xmanager-slurm 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of xmanager-slurm might be problematic.
- xm_slurm/__init__.py +0 -2
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/config.py +11 -3
- xm_slurm/contrib/clusters/__init__.py +3 -6
- xm_slurm/contrib/clusters/drac.py +4 -3
- xm_slurm/executables.py +4 -7
- xm_slurm/execution.py +290 -159
- xm_slurm/experiment.py +26 -180
- xm_slurm/filesystem.py +129 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +0 -9
- xm_slurm/packaging/docker.py +72 -22
- xm_slurm/packaging/utils.py +0 -108
- xm_slurm/scripts/cli.py +9 -2
- xm_slurm/templates/docker/uv.Dockerfile +6 -3
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +4 -4
- xm_slurm/templates/slurm/job-group.bash.j2 +2 -2
- xm_slurm/templates/slurm/job.bash.j2 +5 -4
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +18 -54
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +9 -24
- xm_slurm/utils.py +122 -41
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/METADATA +7 -3
- xmanager_slurm-0.4.7.dist-info/RECORD +51 -0
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/WHEEL +1 -1
- xm_slurm/api.py +0 -528
- xmanager_slurm-0.4.5.dist-info/RECORD +0 -44
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/entry_points.txt +0 -0
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/licenses/LICENSE.md +0 -0
xm_slurm/execution.py
CHANGED
@@ -5,7 +5,12 @@ import functools
 import hashlib
 import logging
 import operator
+import os
+import pathlib
+import re
 import shlex
+import shutil
+import subprocess
 import typing as tp
 
 import asyncssh
@@ -16,36 +21,18 @@ from asyncssh.auth import KbdIntPrompts, KbdIntResponse
 from asyncssh.misc import MaybeAwait
 from rich.console import ConsoleRenderable
 from rich.rule import Rule
+from rich.text import Text
 from xmanager import xm
 
-from xm_slurm import batching, config, constants, dependencies, executors, status
+from xm_slurm import batching, config, constants, dependencies, executors, status, utils
+from xm_slurm.config import ContainerRuntime, SlurmClusterConfig, SlurmSSHConfig
 from xm_slurm.console import console
+from xm_slurm.filesystem import AsyncFileSystem, AsyncLocalFileSystem, AsyncSSHFileSystem
 from xm_slurm.job_blocks import JobArgs
 from xm_slurm.types import Descriptor
 
-SlurmClusterConfig = config.SlurmClusterConfig
-ContainerRuntime = config.ContainerRuntime
-
 logger = logging.getLogger(__name__)
 
-"""
-=== Runtime Configurations ===
-With RunC:
-    skopeo copy --dest-creds=<username>:<secret> docker://<image>@<digest> oci:<image>:<digest>
-
-    pushd $SLURM_TMPDIR
-
-    umoci raw unpack --rootless --image <image>:<digest> bundle/<digest>
-    umoci raw runtime-config --image <image>:<digest> bundle/<digest>/config.json
-
-    runc run -b bundle/<digest> <container-id>
-
-With Singularity / Apptainer:
-
-    apptainer build --fix-perms --sandbox <digest> docker://<image>@<digest>
-    apptainer run --compat <digest>
-"""
-
 _POLL_INTERVAL = 30.0
 _BATCHED_BATCH_SIZE = 16
 _BATCHED_TIMEOUT = 0.2
@@ -54,19 +41,6 @@ _BATCHED_TIMEOUT = 0.2
 class SlurmExecutionError(Exception): ...
 
 
-class NoKBAuthSSHClient(asyncssh.SSHClient):
-    """SSHClient that does not prompt for keyboard-interactive authentication."""
-
-    def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
-        return ""
-
-    def kbdint_challenge_received(
-        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
-    ) -> MaybeAwait[KbdIntResponse | None]:
-        del name, instructions, lang, prompts
-        return []
-
-
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class SlurmJob:
     job_id: str
@@ -208,7 +182,7 @@ class _BatchedSlurmHandle:
         # Reconstruct the job states in the original order
         job_states = []
         for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
-            job_states.append(job_states_by_cluster[ssh_config][slurm_job])
+            job_states.append(job_states_by_cluster[ssh_config][slurm_job])  # type: ignore
         return job_states
 
     @functools.partial(
@@ -263,52 +237,59 @@ class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
     async def logs(
         self, *, num_lines: int, block_size: int, wait: bool, follow: bool
     ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
-        [42 removed lines not shown in the rendered diff]
+        experiment_dir = await get_client().experiment_dir(self.ssh, self.experiment_id)
+        file = experiment_dir / f"slurm-{self.slurm_job.job_id}.out"
+
+        fs = await get_client().fs(self.ssh)
+
+        if wait:
+            while not (await fs.exists(file)):
+                await asyncio.sleep(5)
+
+        file_size = await fs.size(file)
+        assert file_size is not None
+
+        async with await fs.open(file, "rb") as remote_file:  # type: ignore
+            data = b""
+            lines = []
+            position = file_size
+
+            while len(lines) <= num_lines and position > 0:
+                read_size = min(block_size, position)
+                position -= read_size
+                await remote_file.seek(position)  # type: ignore
+                chunk = await remote_file.read(read_size)
+                data = chunk + data
+                lines = data.splitlines()
+
+            if position <= 0:
+                yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
+            for line in lines[-num_lines:]:
+                yield Text.from_ansi(line.decode("utf-8", errors="replace"))
+
+            if (await self.get_state()) not in status.SlurmActiveJobStates:
+                yield Rule("[bold red]END OF FILE[/bold red]")
+                return
+
+            if not follow:
+                return
+
+            await remote_file.seek(file_size)  # type: ignore
+            while True:
+                if new_data := (await remote_file.read(block_size)):
+                    yield Text.from_ansi(new_data.decode("utf-8", errors="replace"))
+                else:
+                    await asyncio.sleep(0.25)
+
+
+class CompletedProcess(tp.Protocol):
+    returncode: int | None
+    stdout: bytes | str
+    stderr: bytes | str
 
 
 @functools.cache
-def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
+def get_template_env(runtime: ContainerRuntime) -> j2.Environment:
     template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
     template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)
 
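The new logs() implementation tails the Slurm output file by reading fixed-size blocks backwards from the end until it has collected at least num_lines lines, then optionally follows the file for new output. The sketch below is a minimal, local-filesystem rendition of that reverse block-read strategy; tail_lines and the plain open() I/O are illustrative stand-ins, not xm-slurm API:

    # Minimal local sketch of the reverse block-read used by the new logs():
    # read fixed-size blocks from the end of the file until enough newlines
    # have been collected, then keep only the last `num_lines` lines.
    import os

    def tail_lines(path: str, num_lines: int = 10, block_size: int = 4096) -> list[bytes]:
        with open(path, "rb") as f:
            position = f.seek(0, os.SEEK_END)  # seek() returns the new offset, i.e. the file size
            data = b""
            lines: list[bytes] = []
            while len(lines) <= num_lines and position > 0:
                read_size = min(block_size, position)
                position -= read_size
                f.seek(position)
                data = f.read(read_size) + data  # prepend the earlier block
                lines = data.splitlines()
            return lines[-num_lines:]

Reading from the tail keeps the cost proportional to the requested output rather than to the log size, which matters for long-running jobs with large output files.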
@@ -318,44 +299,222 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
     template_env.globals["raise"] = _raise_template_exception
     template_env.globals["operator"] = operator
 
-    [1 removed line not shown in the rendered diff]
+    entrypoint_template = template_env.get_template("entrypoint.bash.j2")
+    template_env.globals.update(entrypoint_template.module.__dict__)
+
+    match runtime:
         case ContainerRuntime.SINGULARITY | ContainerRuntime.APPTAINER:
             runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
         case ContainerRuntime.PODMAN:
             runtime_template = template_env.get_template("runtimes/podman.bash.j2")
         case _:
-            raise NotImplementedError(f"Container runtime {
-    # Update our global env with the runtime template's exported globals
+            raise NotImplementedError(f"Container runtime {runtime} is not implemented.")
     template_env.globals.update(runtime_template.module.__dict__)
 
     return template_env
 
 
-    [2 removed lines not shown in the rendered diff]
-    return Client()
+class NoKBAuthSSHClient(asyncssh.SSHClient):
+    """SSHClient that does not prompt for keyboard-interactive authentication."""
 
+    def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
+        return ""
 
-    [4 removed lines not shown in the rendered diff]
+    def kbdint_challenge_received(
+        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
+    ) -> MaybeAwait[KbdIntResponse | None]:
+        del name, instructions, lang, prompts
+        return []
+
+
+class SlurmExecutionClient:
+    def __init__(self):
+        self._remote_connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
+        self._remote_filesystems = dict[config.SlurmSSHConfig, AsyncSSHFileSystem]()
+        self._remote_connection_lock = asyncio.Lock()
+
+        self._local_fs = AsyncLocalFileSystem()
+
+    @backoff.on_exception(backoff.expo, asyncio.exceptions.TimeoutError, max_tries=5, max_time=60.0)
+    async def _local_run(  # type: ignore
+        self,
+        command: str,
+        *,
+        check: bool = False,
+        timeout: float | None = None,
+    ) -> subprocess.CompletedProcess[str]:
+        process = await asyncio.subprocess.create_subprocess_shell(
+            command,
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE,
+            # Filter out all SLURM_ environment variables as this could be running on a
+            # compute node and xm-slurm should act stateless.
+            env=dict(filter(lambda x: not x[0].startswith("SLURM_"), os.environ.items())),
+        )
+        stdout, stderr = await asyncio.wait_for(process.communicate(), timeout)
+
+        stdout = stdout.decode("utf-8").strip() if stdout else ""
+        stderr = stderr.decode("utf-8").strip() if stderr else ""
+
+        assert process.returncode is not None
+        if check and process.returncode != 0:
+            raise RuntimeError(f"Command failed with return code {process.returncode}: {command}\n")
+
+        return subprocess.CompletedProcess[str](command, process.returncode, stdout, stderr)
 
     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-    async def
-    [7 removed lines not shown in the rendered diff]
+    async def _remote_run(  # type: ignore
+        self,
+        ssh_config: config.SlurmSSHConfig,
+        command: str,
+        *,
+        check: bool = False,
+        timeout: float | None = None,
+    ) -> asyncssh.SSHCompletedProcess:
+        client = await self._connection(ssh_config)
+        return await client.run(command, check=check, timeout=timeout)
+
+    @functools.cache
+    def _is_ssh_config_local(self, ssh_config: SlurmSSHConfig) -> bool:
+        """A best effort check to see if the SSH config is local so we can bypass ssh."""
+
+        # We can't verify the connection so bail out
+        if ssh_config.host_public_key is None:
+            return False
+        if "SSH_CONNECTION" not in os.environ:
+            return False
+
+        def _is_host_local(host: str) -> bool:
+            nonlocal ssh_config
+            assert ssh_config.host_public_key is not None
+
+            if shutil.which("ssh-keyscan") is None:
+                return False
+
+            keyscan_result = utils.run_command(
+                ["ssh-keyscan", "-t", ssh_config.host_public_key.algorithm, host],
+                return_stdout=True,
+            )
+
+            if keyscan_result.returncode != 0:
+                return False
+
+            try:
+                key = mit.one(
+                    filter(
+                        lambda x: not x.startswith("#"), keyscan_result.stdout.strip().split("\n")
+                    )
+                )
+                _, algorithm, key = key.split(" ")
+
+                if (
+                    algorithm == ssh_config.host_public_key.algorithm
+                    and key == ssh_config.host_public_key.key
+                ):
+                    return True
+
+            except Exception:
+                pass
+
+            return False
+
+        # 1): we're directly connected to the host
+        ssh_connection_str = os.environ["SSH_CONNECTION"]
+        _, _, server_ip, _ = ssh_connection_str.split()
+
+        logger.debug("Checking if SSH_CONNECTION server %s is local", server_ip)
+        if _is_host_local(server_ip):
+            return True
+
+        # 2): we're in a Slurm job and the submission host is the host
+        if "SLURM_JOB_ID" in os.environ and "SLURM_SUBMIT_HOST" in os.environ:
+            submit_host = os.environ["SLURM_SUBMIT_HOST"]
+            logger.debug("Checking if SLURM_SUBMIT_HOST %s is local", submit_host)
+            if _is_host_local(submit_host):
+                return True
+        elif "SLURM_JOB_ID" in os.environ and shutil.which("scontrol") is not None:
+            # Stupid edge case where if you run srun SLURM_SUBMIT_HOST isn't forwarded
+            # so we'll parse it from scontrol...
+            scontrol_result = utils.run_command(
+                ["scontrol", "show", "job", os.environ["SLURM_JOB_ID"]],
+                return_stdout=True,
+            )
+            if scontrol_result.returncode != 0:
+                return False
+
+            match = re.search(
+                r"AllocNode:Sid=(?P<host>[^ ]+):\d+", scontrol_result.stdout.strip(), re.MULTILINE
+            )
+            if match is not None:
+                host = match.group("host")
+                logger.debug("Checking if AllocNode %s is local", host)
+                if _is_host_local(host):
+                    return True
+
+        return False
+
+    @functools.cache
+    async def _state_dir(self, ssh_config: SlurmSSHConfig) -> pathlib.Path:
+        state_dirs = [
+            ("XM_SLURM_STATE_DIR", ""),
+            ("XDG_STATE_HOME", "xm-slurm"),
+            ("HOME", ".local/state/xm-slurm"),
+        ]
+
+        for env_var, subpath in state_dirs:
+            cmd = await self.run(ssh_config, f"printenv {env_var}", check=False)
+            assert isinstance(cmd.stdout, str)
+            if cmd.returncode == 0:
+                return pathlib.Path(cmd.stdout.strip()) / subpath
+
+        raise SlurmExecutionError(
+            "Failed to find a valid state directory for XManager. "
+            "We weren't able to resolve any of the following paths: "
+            f"{', '.join(env_var + ('/' + subpath if subpath else '') for env_var, subpath in state_dirs)}."
+        )
+
+    async def experiment_dir(self, ssh_config: SlurmSSHConfig, experiment_id: int) -> pathlib.Path:
+        return (await self._state_dir(ssh_config)) / f"{experiment_id:08d}"
+
+    async def run(
+        self,
+        ssh_config: SlurmSSHConfig,
+        command: xm.SequentialArgs | str | tp.Sequence[str],
+        *,
+        check: bool = False,
+        timeout: float | None = None,
+    ) -> CompletedProcess:
+        if isinstance(command, xm.SequentialArgs):
+            command = command.to_list()
+        if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
+            command = shlex.join(command)
+        assert isinstance(command, str)
+
+        if self._is_ssh_config_local(ssh_config):
+            logger.debug("Running command locally: %s", command)
+            return await self._local_run(command, check=check, timeout=timeout)  # type: ignore
+        else:
+            logger.debug("Running command on %s: %s", ssh_config.host, command)
+            return await self._remote_run(ssh_config, command, check=check, timeout=timeout)  # type: ignore
+
+    async def fs(self, ssh_config: SlurmSSHConfig) -> AsyncFileSystem:
+        if self._is_ssh_config_local(ssh_config):
+            return self._local_fs
+
+        if ssh_config not in self._remote_filesystems:
+            self._remote_filesystems[ssh_config] = AsyncSSHFileSystem(
+                await (await self._connection(ssh_config)).start_sftp_client()
+            )
+        return self._remote_filesystems[ssh_config]
+
+    async def _connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
+        if ssh_config not in self._remote_connections:
+            async with self._remote_connection_lock:
                 try:
                     conn, _ = await asyncssh.create_connection(
                         NoKBAuthSSHClient, options=ssh_config.connection_options
                     )
-
-                self._connections[ssh_config] = conn
+                    self._remote_connections[ssh_config] = conn
                 except asyncssh.misc.PermissionDenied as ex:
                     raise SlurmExecutionError(
                         f"Permission denied connecting to {ssh_config.host}"
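_is_ssh_config_local() lets the client bypass SSH entirely when the configured cluster is the machine it is already running on: it compares the pinned host public key from the config against what ssh-keyscan reports for candidate hosts (the SSH_CONNECTION peer, SLURM_SUBMIT_HOST, or the AllocNode parsed from scontrol). Below is a rough standalone sketch of just the key-comparison idea; host_matches_pinned_key and its parameters are hypothetical names for illustration, not xm-slurm API:

    # Hypothetical standalone version of the pinned-host-key check: scan the
    # host's public key over the network and compare it with a known
    # (algorithm, key) pair.
    import subprocess

    def host_matches_pinned_key(host: str, expected_algorithm: str, expected_key: str) -> bool:
        result = subprocess.run(
            ["ssh-keyscan", "-t", expected_algorithm, host],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            return False
        for line in result.stdout.splitlines():
            if line.startswith("#"):
                continue  # ssh-keyscan emits comment lines
            try:
                _, algorithm, key = line.split(" ")
            except ValueError:
                continue
            if algorithm == expected_algorithm and key == expected_key:
                return True
        return False

Because the check trusts a key the user pinned ahead of time rather than whatever the network answers, a spoofed host cannot trick the client into treating a remote cluster as local.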
@@ -375,28 +534,9 @@ class Client:
                         f"SSH connection error when connecting to {ssh_config.host}"
                     ) from ex
 
-        return self.
-
-    @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-    async def run(
-        self,
-        ssh_config: config.SlurmSSHConfig,
-        command: xm.SequentialArgs | str | tp.Sequence[str],
-        *,
-        check: bool = False,
-        timeout: float | None = None,
-    ) -> asyncssh.SSHCompletedProcess:
-        client = await self.connection(ssh_config)
-        if isinstance(command, xm.SequentialArgs):
-            command = command.to_list()
-        if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
-            command = shlex.join(command)
-        assert isinstance(command, str)
-        logger.debug("Running command on %s: %s", ssh_config.host, command)
-
-        return await client.run(command, check=check, timeout=timeout)
+        return self._remote_connections[ssh_config]
 
-    async def
+    async def _submission_script_template(
         self,
         *,
         job: xm.Job | xm.JobGroup,
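The template_context dict introduced in the next hunk collects the render variables shared by the job, job-array, and job-group templates, so each render call only adds its per-job arguments. A self-contained illustration of the **-splatting pattern (the inline template string and values here are made up for the example; "drac" borrows the name of the cluster module in this package):

    # Build the shared render variables once, then splat them into each call.
    import jinja2

    env = jinja2.Environment()
    template = env.from_string("{{ job }} on {{ cluster }} (experiment {{ experiment_id }})")

    template_context = dict(cluster="drac", experiment_id=42)  # shared across calls
    print(template.render(job="train", **template_context))  # -> train on drac (experiment 42)
    print(template.render(job="eval", **template_context))   # -> eval on drac (experiment 42)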
@@ -410,6 +550,12 @@ class Client:
         args = {}
 
         template_env = get_template_env(cluster.runtime)
+        template_context = dict(
+            dependency=dependency,
+            cluster=cluster,
+            experiment_id=experiment_id,
+            identity=identity,
+        )
 
         # Sanitize job groups
         if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
@@ -430,26 +576,14 @@ class Client:
                 )
 
                 return template.render(
-                    job=job_array,
-                    dependency=dependency,
-                    cluster=cluster,
-                    args=sequential_args,
-                    env_vars=env_vars,
-                    experiment_id=experiment_id,
-                    identity=identity,
+                    job=job_array, args=sequential_args, env_vars=env_vars, **template_context
                 )
             case xm.Job() if isinstance(args, collections.abc.Mapping):
                 template = template_env.get_template("job.bash.j2")
                 sequential_args = xm.SequentialArgs.from_collection(args.get("args", None))
                 env_vars = args.get("env_vars", None)
                 return template.render(
-                    job=job,
-                    dependency=dependency,
-                    cluster=cluster,
-                    args=sequential_args,
-                    env_vars=env_vars,
-                    experiment_id=experiment_id,
-                    identity=identity,
+                    job=job, args=sequential_args, env_vars=env_vars, **template_context
                 )
             case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
                 template = template_env.get_template("job-group.bash.j2")
@@ -464,13 +598,7 @@ class Client:
                     for job_name in job_group.jobs.keys()
                 }
                 return template.render(
-                    job_group=job_group,
-                    dependency=dependency,
-                    cluster=cluster,
-                    args=sequential_args,
-                    env_vars=env_vars,
-                    experiment_id=experiment_id,
-                    identity=identity,
+                    job_group=job_group, args=sequential_args, env_vars=env_vars, **template_context
                 )
             case _:
                 raise ValueError(f"Unsupported job type: {type(job)}")
@@ -521,7 +649,7 @@ class Client:
         experiment_id: int,
         identity: str | None = None,
     ):
-        [1 removed line not shown in the rendered diff]
+        submission_script = await self._submission_script_template(
             job=job,
             dependency=dependency,
             cluster=cluster,
@@ -529,24 +657,19 @@ class Client:
             experiment_id=experiment_id,
             identity=identity,
         )
-        logger.debug("Slurm submission script:\n%s",
-        [9 removed lines not shown in the rendered diff]
-        await sftp.makedirs(f".local/state/xm-slurm/{experiment_id}", exist_ok=True)
-        async with sftp.open(
-            f".local/state/xm-slurm/{experiment_id}/submission-script-{template_hash}.sh", "w"
-        ) as fp:
-            await fp.write(template)
+        logger.debug("Slurm submission script:\n%s", submission_script)
+        submission_script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
+        submission_script_path = f"submission-script-{submission_script_hash}.sh"
+
+        fs = await self.fs(cluster.ssh)
+
+        template_dir = await self.experiment_dir(cluster.ssh, experiment_id)
+
+        await fs.makedirs(template_dir, exist_ok=True)
+        await fs.write(template_dir / submission_script_path, submission_script.encode())
 
         # Construct and run command on the cluster
-        command = f"sbatch --chdir .
+        command = f"sbatch --chdir {template_dir.as_posix()} --parsable {submission_script_path}"
         result = await self.run(cluster.ssh, command)
         if result.returncode != 0:
             raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
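Submission scripts are now written under the per-experiment state directory with a content-addressed name: the first 8 hex characters of the script's BLAKE2s digest. Identical scripts therefore resolve to the same path, and sbatch is invoked with --parsable so the job ID can be read straight from stdout. A minimal sketch of the naming scheme, with a made-up script body and placeholder directory:

    # Content-addressed script naming: the digest of the script's bytes
    # determines the filename, so re-submitting an identical script is
    # idempotent with respect to the path it lands at.
    import hashlib

    submission_script = "#!/bin/bash\necho hello\n"  # made-up script body
    script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
    script_path = f"submission-script-{script_hash}.sh"
    command = f"sbatch --chdir /path/to/experiment-dir --parsable {script_path}"
    # --parsable makes sbatch print just "<jobid>" (or "<jobid>;<cluster>"),
    # which is trivial to parse from stdout.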
@@ -596,8 +719,16 @@ class Client:
                 raise ValueError(f"Unsupported job type: {type(job)}")
 
     def __del__(self):
-        for
+        for fs in self._remote_filesystems.values():
+            del fs
+        for conn in self._remote_connections.values():
             conn.close()
+            del conn
+
+
+@functools.cache
+def get_client() -> SlurmExecutionClient:
+    return SlurmExecutionClient()
 
 
 @tp.overload