xmanager-slurm 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xmanager-slurm might be problematic.

xm_slurm/execution.py CHANGED
@@ -5,7 +5,12 @@ import functools
  import hashlib
  import logging
  import operator
+ import os
+ import pathlib
+ import re
  import shlex
+ import shutil
+ import subprocess
  import typing as tp
 
  import asyncssh
@@ -16,36 +21,18 @@ from asyncssh.auth import KbdIntPrompts, KbdIntResponse
  from asyncssh.misc import MaybeAwait
  from rich.console import ConsoleRenderable
  from rich.rule import Rule
+ from rich.text import Text
  from xmanager import xm
 
- from xm_slurm import batching, config, constants, dependencies, executors, status
+ from xm_slurm import batching, config, constants, dependencies, executors, status, utils
+ from xm_slurm.config import ContainerRuntime, SlurmClusterConfig, SlurmSSHConfig
  from xm_slurm.console import console
+ from xm_slurm.filesystem import AsyncFileSystem, AsyncLocalFileSystem, AsyncSSHFileSystem
  from xm_slurm.job_blocks import JobArgs
  from xm_slurm.types import Descriptor
 
- SlurmClusterConfig = config.SlurmClusterConfig
- ContainerRuntime = config.ContainerRuntime
-
  logger = logging.getLogger(__name__)
 
- """
- === Runtime Configurations ===
- With RunC:
- skopeo copy --dest-creds=<username>:<secret> docker://<image>@<digest> oci:<image>:<digest>
-
- pushd $SLURM_TMPDIR
-
- umoci raw unpack --rootless --image <image>:<digest> bundle/<digest>
- umoci raw runtime-config --image <image>:<digest> bundle/<digest>/config.json
-
- runc run -b bundle/<digest> <container-id>
-
- With Singularity / Apptainer:
-
- apptainer build --fix-perms --sandbox <digest> docker://<image>@<digest>
- apptainer run --compat <digest>
- """
-
  _POLL_INTERVAL = 30.0
  _BATCHED_BATCH_SIZE = 16
  _BATCHED_TIMEOUT = 0.2
@@ -54,19 +41,6 @@ _BATCHED_TIMEOUT = 0.2
  class SlurmExecutionError(Exception): ...
 
 
- class NoKBAuthSSHClient(asyncssh.SSHClient):
-     """SSHClient that does not prompt for keyboard-interactive authentication."""
-
-     def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
-         return ""
-
-     def kbdint_challenge_received(
-         self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
-     ) -> MaybeAwait[KbdIntResponse | None]:
-         del name, instructions, lang, prompts
-         return []
-
-
  @dataclasses.dataclass(frozen=True, kw_only=True)
  class SlurmJob:
      job_id: str
@@ -208,7 +182,7 @@ class _BatchedSlurmHandle:
          # Reconstruct the job states in the original order
          job_states = []
          for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
-             job_states.append(job_states_by_cluster[ssh_config][slurm_job])
+             job_states.append(job_states_by_cluster[ssh_config][slurm_job])  # type: ignore
          return job_states
 
      @functools.partial(
@@ -263,52 +237,59 @@ class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
      async def logs(
          self, *, num_lines: int, block_size: int, wait: bool, follow: bool
      ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
-         file = f".local/state/xm-slurm/{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
-         conn = await get_client().connection(self.ssh)
-         async with conn.start_sftp_client() as sftp:
-             if wait:
-                 while not (await sftp.exists(file)):
-                     await asyncio.sleep(5)
-
-             async with sftp.open(file, "rb") as remote_file:
-                 file_stat = await remote_file.stat()
-                 file_size = file_stat.size
-                 assert file_size is not None
-
-                 data = b""
-                 lines = []
-                 position = file_size
-
-                 while len(lines) <= num_lines and position > 0:
-                     read_size = min(block_size, position)
-                     position -= read_size
-                     await remote_file.seek(position)
-                     chunk = await remote_file.read(read_size)
-                     data = chunk + data
-                     lines = data.splitlines()
-
-                 if position <= 0:
-                     yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
-                 for line in lines[-num_lines:]:
-                     yield line.decode("utf-8", errors="replace")
-
-                 if (await self.get_state()) not in status.SlurmActiveJobStates:
-                     yield Rule("[bold red]END OF FILE[/bold red]")
-                     return
-
-                 if not follow:
-                     return
-
-                 await remote_file.seek(file_size)
-                 while True:
-                     if new_data := (await remote_file.read(block_size)):
-                         yield new_data.decode("utf-8", errors="replace")
-                     else:
-                         await asyncio.sleep(0.25)
+         experiment_dir = await get_client().experiment_dir(self.ssh, self.experiment_id)
+         file = experiment_dir / f"slurm-{self.slurm_job.job_id}.out"
+
+         fs = await get_client().fs(self.ssh)
+
+         if wait:
+             while not (await fs.exists(file)):
+                 await asyncio.sleep(5)
+
+         file_size = await fs.size(file)
+         assert file_size is not None
+
+         async with await fs.open(file, "rb") as remote_file:  # type: ignore
+             data = b""
+             lines = []
+             position = file_size
+
+             while len(lines) <= num_lines and position > 0:
+                 read_size = min(block_size, position)
+                 position -= read_size
+                 await remote_file.seek(position)  # type: ignore
+                 chunk = await remote_file.read(read_size)
+                 data = chunk + data
+                 lines = data.splitlines()
+
+             if position <= 0:
+                 yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
+             for line in lines[-num_lines:]:
+                 yield Text.from_ansi(line.decode("utf-8", errors="replace"))
+
+             if (await self.get_state()) not in status.SlurmActiveJobStates:
+                 yield Rule("[bold red]END OF FILE[/bold red]")
+                 return
+
+             if not follow:
+                 return
+
+             await remote_file.seek(file_size)  # type: ignore
+             while True:
+                 if new_data := (await remote_file.read(block_size)):
+                     yield Text.from_ansi(new_data.decode("utf-8", errors="replace"))
+                 else:
+                     await asyncio.sleep(0.25)
+
+
+ class CompletedProcess(tp.Protocol):
+     returncode: int | None
+     stdout: bytes | str
+     stderr: bytes | str
 
 
  @functools.cache
- def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
+ def get_template_env(runtime: ContainerRuntime) -> j2.Environment:
      template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
      template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)
 
@@ -318,44 +299,222 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
      template_env.globals["raise"] = _raise_template_exception
      template_env.globals["operator"] = operator
 
-     match container_runtime:
+     entrypoint_template = template_env.get_template("entrypoint.bash.j2")
+     template_env.globals.update(entrypoint_template.module.__dict__)
+
+     match runtime:
          case ContainerRuntime.SINGULARITY | ContainerRuntime.APPTAINER:
              runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
          case ContainerRuntime.PODMAN:
              runtime_template = template_env.get_template("runtimes/podman.bash.j2")
          case _:
-             raise NotImplementedError(f"Container runtime {container_runtime} is not implemented.")
-     # Update our global env with the runtime template's exported globals
+             raise NotImplementedError(f"Container runtime {runtime} is not implemented.")
      template_env.globals.update(runtime_template.module.__dict__)
 
      return template_env
 
 
- @functools.cache
- def get_client() -> "Client":
-     return Client()
+ class NoKBAuthSSHClient(asyncssh.SSHClient):
+     """SSHClient that does not prompt for keyboard-interactive authentication."""
 
+     def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
+         return ""
 
- class Client:
-     def __init__(self) -> None:
-         self._connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
-         self._connection_lock = asyncio.Lock()
+     def kbdint_challenge_received(
+         self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
+     ) -> MaybeAwait[KbdIntResponse | None]:
+         del name, instructions, lang, prompts
+         return []
+
+
+ class SlurmExecutionClient:
+     def __init__(self):
+         self._remote_connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
+         self._remote_filesystems = dict[config.SlurmSSHConfig, AsyncSSHFileSystem]()
+         self._remote_connection_lock = asyncio.Lock()
+
+         self._local_fs = AsyncLocalFileSystem()
+
+     @backoff.on_exception(backoff.expo, asyncio.exceptions.TimeoutError, max_tries=5, max_time=60.0)
+     async def _local_run(  # type: ignore
+         self,
+         command: str,
+         *,
+         check: bool = False,
+         timeout: float | None = None,
+     ) -> subprocess.CompletedProcess[str]:
+         process = await asyncio.subprocess.create_subprocess_shell(
+             command,
+             stdout=asyncio.subprocess.PIPE,
+             stderr=asyncio.subprocess.PIPE,
+             # Filter out all SLURM_ environment variables as this could be running on a
+             # compute node and xm-slurm should act stateless.
+             env=dict(filter(lambda x: not x[0].startswith("SLURM_"), os.environ.items())),
+         )
+         stdout, stderr = await asyncio.wait_for(process.communicate(), timeout)
+
+         stdout = stdout.decode("utf-8").strip() if stdout else ""
+         stderr = stderr.decode("utf-8").strip() if stderr else ""
+
+         assert process.returncode is not None
+         if check and process.returncode != 0:
+             raise RuntimeError(f"Command failed with return code {process.returncode}: {command}\n")
+
+         return subprocess.CompletedProcess[str](command, process.returncode, stdout, stderr)
 
      @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-     async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
-         # Make sure the xm-slurm state directory exists
-         async with conn.start_sftp_client() as sftp_client:
-             await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)
-
-     async def connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
-         if ssh_config not in self._connections:
-             async with self._connection_lock:
+     async def _remote_run(  # type: ignore
+         self,
+         ssh_config: config.SlurmSSHConfig,
+         command: str,
+         *,
+         check: bool = False,
+         timeout: float | None = None,
+     ) -> asyncssh.SSHCompletedProcess:
+         client = await self._connection(ssh_config)
+         return await client.run(command, check=check, timeout=timeout)
+
+     @functools.cache
+     def _is_ssh_config_local(self, ssh_config: SlurmSSHConfig) -> bool:
+         """A best effort check to see if the SSH config is local so we can bypass ssh."""
+
+         # We can't verify the connection so bail out
+         if ssh_config.host_public_key is None:
+             return False
+         if "SSH_CONNECTION" not in os.environ:
+             return False
+
+         def _is_host_local(host: str) -> bool:
+             nonlocal ssh_config
+             assert ssh_config.host_public_key is not None
+
+             if shutil.which("ssh-keyscan") is None:
+                 return False
+
+             keyscan_result = utils.run_command(
+                 ["ssh-keyscan", "-t", ssh_config.host_public_key.algorithm, host],
+                 return_stdout=True,
+             )
+
+             if keyscan_result.returncode != 0:
+                 return False
+
+             try:
+                 key = mit.one(
+                     filter(
+                         lambda x: not x.startswith("#"), keyscan_result.stdout.strip().split("\n")
+                     )
+                 )
+                 _, algorithm, key = key.split(" ")
+
+                 if (
+                     algorithm == ssh_config.host_public_key.algorithm
+                     and key == ssh_config.host_public_key.key
+                 ):
+                     return True
+
+             except Exception:
+                 pass
+
+             return False
+
+         # 1): we're directly connected to the host
+         ssh_connection_str = os.environ["SSH_CONNECTION"]
+         _, _, server_ip, _ = ssh_connection_str.split()
+
+         logger.debug("Checking if SSH_CONNECTION server %s is local", server_ip)
+         if _is_host_local(server_ip):
+             return True
+
+         # 2): we're in a Slurm job and the submission host is the host
+         if "SLURM_JOB_ID" in os.environ and "SLURM_SUBMIT_HOST" in os.environ:
+             submit_host = os.environ["SLURM_SUBMIT_HOST"]
+             logger.debug("Checking if SLURM_SUBMIT_HOST %s is local", submit_host)
+             if _is_host_local(submit_host):
+                 return True
+         elif "SLURM_JOB_ID" in os.environ and shutil.which("scontrol") is not None:
+             # Stupid edge case where if you run srun SLURM_SUBMIT_HOST isn't forwarded
+             # so we'll parse it from scontrol...
+             scontrol_result = utils.run_command(
+                 ["scontrol", "show", "job", os.environ["SLURM_JOB_ID"]],
+                 return_stdout=True,
+             )
+             if scontrol_result.returncode != 0:
+                 return False
+
+             match = re.search(
+                 r"AllocNode:Sid=(?P<host>[^ ]+):\d+", scontrol_result.stdout.strip(), re.MULTILINE
+             )
+             if match is not None:
+                 host = match.group("host")
+                 logger.debug("Checking if AllocNode %s is local", host)
+                 if _is_host_local(host):
+                     return True
+
+         return False
+
+     @functools.cache
+     async def _state_dir(self, ssh_config: SlurmSSHConfig) -> pathlib.Path:
+         state_dirs = [
+             ("XM_SLURM_STATE_DIR", ""),
+             ("XDG_STATE_HOME", "xm-slurm"),
+             ("HOME", ".local/state/xm-slurm"),
+         ]
+
+         for env_var, subpath in state_dirs:
+             cmd = await self.run(ssh_config, f"printenv {env_var}", check=False)
+             assert isinstance(cmd.stdout, str)
+             if cmd.returncode == 0:
+                 return pathlib.Path(cmd.stdout.strip()) / subpath
+
+         raise SlurmExecutionError(
+             "Failed to find a valid state directory for XManager. "
+             "We weren't able to resolve any of the following paths: "
+             f"{', '.join(env_var + ('/' + subpath if subpath else '') for env_var, subpath in state_dirs)}."
+         )
+
+     async def experiment_dir(self, ssh_config: SlurmSSHConfig, experiment_id: int) -> pathlib.Path:
+         return (await self._state_dir(ssh_config)) / f"{experiment_id:08d}"
+
+     async def run(
+         self,
+         ssh_config: SlurmSSHConfig,
+         command: xm.SequentialArgs | str | tp.Sequence[str],
+         *,
+         check: bool = False,
+         timeout: float | None = None,
+     ) -> CompletedProcess:
+         if isinstance(command, xm.SequentialArgs):
+             command = command.to_list()
+         if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
+             command = shlex.join(command)
+         assert isinstance(command, str)
+
+         if self._is_ssh_config_local(ssh_config):
+             logger.debug("Running command locally: %s", command)
+             return await self._local_run(command, check=check, timeout=timeout)  # type: ignore
+         else:
+             logger.debug("Running command on %s: %s", ssh_config.host, command)
+             return await self._remote_run(ssh_config, command, check=check, timeout=timeout)  # type: ignore
+
+     async def fs(self, ssh_config: SlurmSSHConfig) -> AsyncFileSystem:
+         if self._is_ssh_config_local(ssh_config):
+             return self._local_fs
+
+         if ssh_config not in self._remote_filesystems:
+             self._remote_filesystems[ssh_config] = AsyncSSHFileSystem(
+                 await (await self._connection(ssh_config)).start_sftp_client()
+             )
+         return self._remote_filesystems[ssh_config]
+
+     async def _connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
+         if ssh_config not in self._remote_connections:
+             async with self._remote_connection_lock:
                  try:
                      conn, _ = await asyncssh.create_connection(
                          NoKBAuthSSHClient, options=ssh_config.connection_options
                      )
-                     await self._setup_remote_connection(conn)
-                     self._connections[ssh_config] = conn
+                     self._remote_connections[ssh_config] = conn
                  except asyncssh.misc.PermissionDenied as ex:
                      raise SlurmExecutionError(
                          f"Permission denied connecting to {ssh_config.host}"
@@ -375,28 +534,9 @@ class Client:
                          f"SSH connection error when connecting to {ssh_config.host}"
                      ) from ex
 
-         return self._connections[ssh_config]
-
-     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-     async def run(
-         self,
-         ssh_config: config.SlurmSSHConfig,
-         command: xm.SequentialArgs | str | tp.Sequence[str],
-         *,
-         check: bool = False,
-         timeout: float | None = None,
-     ) -> asyncssh.SSHCompletedProcess:
-         client = await self.connection(ssh_config)
-         if isinstance(command, xm.SequentialArgs):
-             command = command.to_list()
-         if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
-             command = shlex.join(command)
-         assert isinstance(command, str)
-         logger.debug("Running command on %s: %s", ssh_config.host, command)
-
-         return await client.run(command, check=check, timeout=timeout)
+         return self._remote_connections[ssh_config]
 
-     async def template(
+     async def _submission_script_template(
          self,
         *,
          job: xm.Job | xm.JobGroup,
@@ -410,6 +550,12 @@ class Client:
              args = {}
 
          template_env = get_template_env(cluster.runtime)
+         template_context = dict(
+             dependency=dependency,
+             cluster=cluster,
+             experiment_id=experiment_id,
+             identity=identity,
+         )
 
          # Sanitize job groups
          if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
@@ -430,26 +576,14 @@ class Client:
                  )
 
                  return template.render(
-                     job=job_array,
-                     dependency=dependency,
-                     cluster=cluster,
-                     args=sequential_args,
-                     env_vars=env_vars,
-                     experiment_id=experiment_id,
-                     identity=identity,
+                     job=job_array, args=sequential_args, env_vars=env_vars, **template_context
                  )
              case xm.Job() if isinstance(args, collections.abc.Mapping):
                  template = template_env.get_template("job.bash.j2")
                  sequential_args = xm.SequentialArgs.from_collection(args.get("args", None))
                  env_vars = args.get("env_vars", None)
                  return template.render(
-                     job=job,
-                     dependency=dependency,
-                     cluster=cluster,
-                     args=sequential_args,
-                     env_vars=env_vars,
-                     experiment_id=experiment_id,
-                     identity=identity,
+                     job=job, args=sequential_args, env_vars=env_vars, **template_context
                  )
              case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
                  template = template_env.get_template("job-group.bash.j2")
@@ -464,13 +598,7 @@ class Client:
                      for job_name in job_group.jobs.keys()
                  }
                  return template.render(
-                     job_group=job_group,
-                     dependency=dependency,
-                     cluster=cluster,
-                     args=sequential_args,
-                     env_vars=env_vars,
-                     experiment_id=experiment_id,
-                     identity=identity,
+                     job_group=job_group, args=sequential_args, env_vars=env_vars, **template_context
                  )
              case _:
                  raise ValueError(f"Unsupported job type: {type(job)}")
@@ -521,7 +649,7 @@ class Client:
          experiment_id: int,
          identity: str | None = None,
      ):
-         template = await self.template(
+         submission_script = await self._submission_script_template(
              job=job,
              dependency=dependency,
             cluster=cluster,
@@ -529,24 +657,19 @@ class Client:
              experiment_id=experiment_id,
              identity=identity,
          )
-         logger.debug("Slurm submission script:\n%s", template)
-
-         # Hash submission script
-         template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]
-
-         conn = await self.connection(cluster.ssh)
-         async with conn.start_sftp_client() as sftp:
-             # Write the submission script to the cluster
-             # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
-             # INSTEAD OF ASSUMING SFTP PUTS US IN THE HOME DIRECTORY
-             await sftp.makedirs(f".local/state/xm-slurm/{experiment_id}", exist_ok=True)
-             async with sftp.open(
-                 f".local/state/xm-slurm/{experiment_id}/submission-script-{template_hash}.sh", "w"
-             ) as fp:
-                 await fp.write(template)
+         logger.debug("Slurm submission script:\n%s", submission_script)
+         submission_script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
+         submission_script_path = f"submission-script-{submission_script_hash}.sh"
+
+         fs = await self.fs(cluster.ssh)
+
+         template_dir = await self.experiment_dir(cluster.ssh, experiment_id)
+
+         await fs.makedirs(template_dir, exist_ok=True)
+         await fs.write(template_dir / submission_script_path, submission_script.encode())
 
          # Construct and run command on the cluster
-         command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
+         command = f"sbatch --chdir {template_dir.as_posix()} --parsable {submission_script_path}"
          result = await self.run(cluster.ssh, command)
          if result.returncode != 0:
              raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
@@ -596,8 +719,16 @@ class Client:
                  raise ValueError(f"Unsupported job type: {type(job)}")
 
      def __del__(self):
-         for conn in self._connections.values():
+         for fs in self._remote_filesystems.values():
+             del fs
+         for conn in self._remote_connections.values():
              conn.close()
+             del conn
+
+
+ @functools.cache
+ def get_client() -> SlurmExecutionClient:
+     return SlurmExecutionClient()
 
 
  @tp.overload