xmanager-slurm 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xmanager-slurm might be problematic.

xm_slurm/execution.py CHANGED
@@ -5,7 +5,12 @@ import functools
  import hashlib
  import logging
  import operator
+ import os
+ import pathlib
+ import re
  import shlex
+ import shutil
+ import subprocess
  import typing as tp
 
  import asyncssh
@@ -16,36 +21,18 @@ from asyncssh.auth import KbdIntPrompts, KbdIntResponse
  from asyncssh.misc import MaybeAwait
  from rich.console import ConsoleRenderable
  from rich.rule import Rule
+ from rich.text import Text
  from xmanager import xm
 
- from xm_slurm import batching, config, constants, dependencies, executors, status
+ from xm_slurm import batching, config, constants, dependencies, executors, status, utils
+ from xm_slurm.config import ContainerRuntime, SlurmClusterConfig, SlurmSSHConfig
  from xm_slurm.console import console
+ from xm_slurm.filesystem import AsyncFileSystem, AsyncLocalFileSystem, AsyncSSHFileSystem
  from xm_slurm.job_blocks import JobArgs
  from xm_slurm.types import Descriptor
 
- SlurmClusterConfig = config.SlurmClusterConfig
- ContainerRuntime = config.ContainerRuntime
-
  logger = logging.getLogger(__name__)
 
- """
- === Runtime Configurations ===
- With RunC:
- skopeo copy --dest-creds=<username>:<secret> docker://<image>@<digest> oci:<image>:<digest>
-
- pushd $SLURM_TMPDIR
-
- umoci raw unpack --rootless --image <image>:<digest> bundle/<digest>
- umoci raw runtime-config --image <image>:<digest> bundle/<digest>/config.json
-
- runc run -b bundle/<digest> <container-id>
-
- With Singularity / Apptainer:
-
- apptainer build --fix-perms --sandbox <digest> docker://<image>@<digest>
- apptainer run --compat <digest>
- """
-
  _POLL_INTERVAL = 30.0
  _BATCHED_BATCH_SIZE = 16
  _BATCHED_TIMEOUT = 0.2
@@ -54,19 +41,6 @@ _BATCHED_TIMEOUT = 0.2
  class SlurmExecutionError(Exception): ...
 
 
- class NoKBAuthSSHClient(asyncssh.SSHClient):
-     """SSHClient that does not prompt for keyboard-interactive authentication."""
-
-     def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
-         return ""
-
-     def kbdint_challenge_received(
-         self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
-     ) -> MaybeAwait[KbdIntResponse | None]:
-         del name, instructions, lang, prompts
-         return []
-
-
  @dataclasses.dataclass(frozen=True, kw_only=True)
  class SlurmJob:
      job_id: str
@@ -208,7 +182,7 @@ class _BatchedSlurmHandle:
          # Reconstruct the job states in the original order
          job_states = []
          for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
-             job_states.append(job_states_by_cluster[ssh_config][slurm_job])
+             job_states.append(job_states_by_cluster[ssh_config][slurm_job])  # type: ignore
          return job_states
 
      @functools.partial(
@@ -263,52 +237,59 @@ class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
      async def logs(
          self, *, num_lines: int, block_size: int, wait: bool, follow: bool
      ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
-         file = f".local/state/xm-slurm/{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
-         conn = await get_client().connection(self.ssh)
-         async with conn.start_sftp_client() as sftp:
-             if wait:
-                 while not (await sftp.exists(file)):
-                     await asyncio.sleep(5)
-
-             async with sftp.open(file, "rb") as remote_file:
-                 file_stat = await remote_file.stat()
-                 file_size = file_stat.size
-                 assert file_size is not None
-
-                 data = b""
-                 lines = []
-                 position = file_size
-
-                 while len(lines) <= num_lines and position > 0:
-                     read_size = min(block_size, position)
-                     position -= read_size
-                     await remote_file.seek(position)
-                     chunk = await remote_file.read(read_size)
-                     data = chunk + data
-                     lines = data.splitlines()
-
-                 if position <= 0:
-                     yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
-                 for line in lines[-num_lines:]:
-                     yield line.decode("utf-8", errors="replace")
-
-                 if (await self.get_state()) not in status.SlurmActiveJobStates:
-                     yield Rule("[bold red]END OF FILE[/bold red]")
-                     return
-
-                 if not follow:
-                     return
-
-                 await remote_file.seek(file_size)
-                 while True:
-                     if new_data := (await remote_file.read(block_size)):
-                         yield new_data.decode("utf-8", errors="replace")
-                     else:
-                         await asyncio.sleep(0.25)
+         statedir = await get_client()._state_dir(self.ssh)
+         file = statedir / f"{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
+
+         fs = await get_client().fs(self.ssh)
+
+         if wait:
+             while not (await fs.exists(file)):
+                 await asyncio.sleep(5)
+
+         file_size = await fs.size(file)
+         assert file_size is not None
+
+         async with await fs.open(file, "rb") as remote_file:  # type: ignore
+             data = b""
+             lines = []
+             position = file_size
+
+             while len(lines) <= num_lines and position > 0:
+                 read_size = min(block_size, position)
+                 position -= read_size
+                 await remote_file.seek(position)  # type: ignore
+                 chunk = await remote_file.read(read_size)
+                 data = chunk + data
+                 lines = data.splitlines()
+
+             if position <= 0:
+                 yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
+             for line in lines[-num_lines:]:
+                 yield Text.from_ansi(line.decode("utf-8", errors="replace"))
+
+             if (await self.get_state()) not in status.SlurmActiveJobStates:
+                 yield Rule("[bold red]END OF FILE[/bold red]")
+                 return
+
+             if not follow:
+                 return
+
+             await remote_file.seek(file_size)  # type: ignore
+             while True:
+                 if new_data := (await remote_file.read(block_size)):
+                     yield Text.from_ansi(new_data.decode("utf-8", errors="replace"))
+                 else:
+                     await asyncio.sleep(0.25)
+
+
+ class CompletedProcess(tp.Protocol):
+     returncode: int | None
+     stdout: bytes | str
+     stderr: bytes | str
 
 
  @functools.cache
- def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
+ def get_template_env(runtime: ContainerRuntime) -> j2.Environment:
      template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
      template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)
 
@@ -318,44 +299,205 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
      template_env.globals["raise"] = _raise_template_exception
      template_env.globals["operator"] = operator
 
-     match container_runtime:
+     entrypoint_template = template_env.get_template("entrypoint.bash.j2")
+     template_env.globals.update(entrypoint_template.module.__dict__)
+
+     match runtime:
          case ContainerRuntime.SINGULARITY | ContainerRuntime.APPTAINER:
              runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
          case ContainerRuntime.PODMAN:
              runtime_template = template_env.get_template("runtimes/podman.bash.j2")
          case _:
-             raise NotImplementedError(f"Container runtime {container_runtime} is not implemented.")
-     # Update our global env with the runtime template's exported globals
+             raise NotImplementedError(f"Container runtime {runtime} is not implemented.")
      template_env.globals.update(runtime_template.module.__dict__)
 
      return template_env
 
 
- @functools.cache
- def get_client() -> "Client":
-     return Client()
+ class NoKBAuthSSHClient(asyncssh.SSHClient):
+     """SSHClient that does not prompt for keyboard-interactive authentication."""
 
+     def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
+         return ""
 
- class Client:
-     def __init__(self) -> None:
-         self._connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
-         self._connection_lock = asyncio.Lock()
+     def kbdint_challenge_received(
+         self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
+     ) -> MaybeAwait[KbdIntResponse | None]:
+         del name, instructions, lang, prompts
+         return []
+
+
+ class SlurmExecutionClient:
+     def __init__(self):
+         self._remote_connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
+         self._remote_filesystems = dict[config.SlurmSSHConfig, AsyncSSHFileSystem]()
+         self._remote_connection_lock = asyncio.Lock()
+
+         self._local_fs = AsyncLocalFileSystem()
+
+     @backoff.on_exception(backoff.expo, asyncio.exceptions.TimeoutError, max_tries=5, max_time=60.0)
+     async def _local_run(  # type: ignore
+         self,
+         command: str,
+         *,
+         check: bool = False,
+         timeout: float | None = None,
+     ) -> subprocess.CompletedProcess[str]:
+         process = await asyncio.subprocess.create_subprocess_shell(
+             command,
+             stdout=asyncio.subprocess.PIPE,
+             stderr=asyncio.subprocess.PIPE,
+             # Filter out all SLURM_ environment variables as this could be running on a
+             # compute node and xm-slurm should act stateless.
+             env=dict(filter(lambda x: not x[0].startswith("SLURM_"), os.environ.items())),
+         )
+         stdout, stderr = await asyncio.wait_for(process.communicate(), timeout)
+
+         stdout = stdout.decode("utf-8").strip() if stdout else ""
+         stderr = stderr.decode("utf-8").strip() if stderr else ""
+
+         assert process.returncode is not None
+         if check and process.returncode != 0:
+             raise RuntimeError(f"Command failed with return code {process.returncode}: {command}\n")
+
+         return subprocess.CompletedProcess[str](command, process.returncode, stdout, stderr)
 
      @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-     async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
-         # Make sure the xm-slurm state directory exists
-         async with conn.start_sftp_client() as sftp_client:
-             await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)
-
-     async def connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
-         if ssh_config not in self._connections:
-             async with self._connection_lock:
+     async def _remote_run(  # type: ignore
+         self,
+         ssh_config: config.SlurmSSHConfig,
+         command: str,
+         *,
+         check: bool = False,
+         timeout: float | None = None,
+     ) -> asyncssh.SSHCompletedProcess:
+         client = await self._connection(ssh_config)
+         return await client.run(command, check=check, timeout=timeout)
+
+     @functools.cache
+     def _is_ssh_config_local(self, ssh_config: SlurmSSHConfig) -> bool:
+         """A best effort check to see if the SSH config is local so we can bypass ssh."""
+
+         # We can't verify the connection so bail out
+         if ssh_config.host_public_key is None:
+             return False
+         if "SSH_CONNECTION" not in os.environ:
+             return False
+
+         def _is_host_local(host: str) -> bool:
+             nonlocal ssh_config
+             assert ssh_config.host_public_key is not None
+
+             if shutil.which("ssh-keyscan") is None:
+                 return False
+
+             keyscan_result = utils.run_command(
+                 ["ssh-keyscan", "-t", ssh_config.host_public_key.algorithm, host],
+                 return_stdout=True,
+             )
+
+             if keyscan_result.returncode != 0:
+                 return False
+
+             try:
+                 key = mit.one(
+                     filter(
+                         lambda x: not x.startswith("#"), keyscan_result.stdout.strip().split("\n")
+                     )
+                 )
+                 _, algorithm, key = key.split(" ")
+
+                 if (
+                     algorithm == ssh_config.host_public_key.algorithm
+                     and key == ssh_config.host_public_key.key
+                 ):
+                     return True
+
+             except Exception:
+                 pass
+
+             return False
+
+         # 1): we're directly connected to the host
+         ssh_connection_str = os.environ["SSH_CONNECTION"]
+         _, _, server_ip, _ = ssh_connection_str.split()
+
+         logger.debug("Checking if SSH_CONNECTION server %s is local", server_ip)
+         if _is_host_local(server_ip):
+             return True
+
+         # 2): we're in a Slurm job and the submission host is the host
+         if "SLURM_JOB_ID" in os.environ and "SLURM_SUBMIT_HOST" in os.environ:
+             submit_host = os.environ["SLURM_SUBMIT_HOST"]
+             logger.debug("Checking if SLURM_SUBMIT_HOST %s is local", submit_host)
+             if _is_host_local(submit_host):
+                 return True
+         elif "SLURM_JOB_ID" in os.environ and shutil.which("scontrol") is not None:
+             # Stupid edge case where if you run srun SLURM_SUBMIT_HOST isn't forwarded
+             # so we'll parse it from scontrol...
+             scontrol_result = utils.run_command(
+                 ["scontrol", "show", "job", os.environ["SLURM_JOB_ID"]],
+                 return_stdout=True,
+             )
+             if scontrol_result.returncode != 0:
+                 return False
+
+             match = re.search(
+                 r"AllocNode:Sid=(?P<host>[^ ]+):\d+", scontrol_result.stdout.strip(), re.MULTILINE
+             )
+             if match is not None:
+                 host = match.group("host")
+                 logger.debug("Checking if AllocNode %s is local", host)
+                 if _is_host_local(host):
+                     return True
+
+         return False
+
+     @functools.cache
+     async def _state_dir(self, ssh_config: SlurmSSHConfig) -> pathlib.Path:
+         cmd = await self.run(ssh_config, "printenv HOME", check=True)
+         assert isinstance(cmd.stdout, str)
+         return pathlib.Path(cmd.stdout.strip()) / ".local" / "state" / "xm-slurm"
+
+     async def run(
+         self,
+         ssh_config: SlurmSSHConfig,
+         command: xm.SequentialArgs | str | tp.Sequence[str],
+         *,
+         check: bool = False,
+         timeout: float | None = None,
+     ) -> CompletedProcess:
+         if isinstance(command, xm.SequentialArgs):
+             command = command.to_list()
+         if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
+             command = shlex.join(command)
+         assert isinstance(command, str)
+
+         if self._is_ssh_config_local(ssh_config):
+             logger.debug("Running command locally: %s", command)
+             return await self._local_run(command, check=check, timeout=timeout)  # type: ignore
+         else:
+             logger.debug("Running command on %s: %s", ssh_config.host, command)
+             return await self._remote_run(ssh_config, command, check=check, timeout=timeout)  # type: ignore
+
+     async def fs(self, ssh_config: SlurmSSHConfig) -> AsyncFileSystem:
+         if self._is_ssh_config_local(ssh_config):
+             return self._local_fs
+
+         if ssh_config not in self._remote_filesystems:
+             self._remote_filesystems[ssh_config] = AsyncSSHFileSystem(
+                 await (await self._connection(ssh_config)).start_sftp_client()
+             )
+         return self._remote_filesystems[ssh_config]
+
+     async def _connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
+         if ssh_config not in self._remote_connections:
+             async with self._remote_connection_lock:
                  try:
                      conn, _ = await asyncssh.create_connection(
                          NoKBAuthSSHClient, options=ssh_config.connection_options
                      )
-                     await self._setup_remote_connection(conn)
-                     self._connections[ssh_config] = conn
+                     self._remote_connections[ssh_config] = conn
                  except asyncssh.misc.PermissionDenied as ex:
                      raise SlurmExecutionError(
                          f"Permission denied connecting to {ssh_config.host}"
@@ -375,28 +517,9 @@ class Client:
                          f"SSH connection error when connecting to {ssh_config.host}"
                      ) from ex
 
-         return self._connections[ssh_config]
-
-     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
-     async def run(
-         self,
-         ssh_config: config.SlurmSSHConfig,
-         command: xm.SequentialArgs | str | tp.Sequence[str],
-         *,
-         check: bool = False,
-         timeout: float | None = None,
-     ) -> asyncssh.SSHCompletedProcess:
-         client = await self.connection(ssh_config)
-         if isinstance(command, xm.SequentialArgs):
-             command = command.to_list()
-         if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
-             command = shlex.join(command)
-         assert isinstance(command, str)
-         logger.debug("Running command on %s: %s", ssh_config.host, command)
-
-         return await client.run(command, check=check, timeout=timeout)
+         return self._remote_connections[ssh_config]
 
-     async def template(
+     async def _submission_script_template(
          self,
         *,
          job: xm.Job | xm.JobGroup,
@@ -410,6 +533,12 @@ class Client:
              args = {}
 
          template_env = get_template_env(cluster.runtime)
+         template_context = dict(
+             dependency=dependency,
+             cluster=cluster,
+             experiment_id=experiment_id,
+             identity=identity,
+         )
 
          # Sanitize job groups
          if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
@@ -430,26 +559,14 @@ class Client:
                  )
 
                  return template.render(
-                     job=job_array,
-                     dependency=dependency,
-                     cluster=cluster,
-                     args=sequential_args,
-                     env_vars=env_vars,
-                     experiment_id=experiment_id,
-                     identity=identity,
+                     job=job_array, args=sequential_args, env_vars=env_vars, **template_context
                  )
              case xm.Job() if isinstance(args, collections.abc.Mapping):
                  template = template_env.get_template("job.bash.j2")
                  sequential_args = xm.SequentialArgs.from_collection(args.get("args", None))
                  env_vars = args.get("env_vars", None)
                  return template.render(
-                     job=job,
-                     dependency=dependency,
-                     cluster=cluster,
-                     args=sequential_args,
-                     env_vars=env_vars,
-                     experiment_id=experiment_id,
-                     identity=identity,
+                     job=job, args=sequential_args, env_vars=env_vars, **template_context
                  )
              case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
                  template = template_env.get_template("job-group.bash.j2")
@@ -464,13 +581,7 @@ class Client:
                      for job_name in job_group.jobs.keys()
                  }
                  return template.render(
-                     job_group=job_group,
-                     dependency=dependency,
-                     cluster=cluster,
-                     args=sequential_args,
-                     env_vars=env_vars,
-                     experiment_id=experiment_id,
-                     identity=identity,
+                     job_group=job_group, args=sequential_args, env_vars=env_vars, **template_context
                  )
              case _:
                  raise ValueError(f"Unsupported job type: {type(job)}")
@@ -521,7 +632,7 @@ class Client:
          experiment_id: int,
          identity: str | None = None,
      ):
-         template = await self.template(
+         submission_script = await self._submission_script_template(
              job=job,
             dependency=dependency,
             cluster=cluster,
@@ -529,24 +640,19 @@ class Client:
              experiment_id=experiment_id,
              identity=identity,
          )
-         logger.debug("Slurm submission script:\n%s", template)
-
-         # Hash submission script
-         template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]
-
-         conn = await self.connection(cluster.ssh)
-         async with conn.start_sftp_client() as sftp:
-             # Write the submission script to the cluster
-             # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
-             # INSTEAD OF ASSUMING SFTP PUTS US IN THE HOME DIRECTORY
-             await sftp.makedirs(f".local/state/xm-slurm/{experiment_id}", exist_ok=True)
-             async with sftp.open(
-                 f".local/state/xm-slurm/{experiment_id}/submission-script-{template_hash}.sh", "w"
-             ) as fp:
-                 await fp.write(template)
+         logger.debug("Slurm submission script:\n%s", submission_script)
+         submission_script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
+         submission_script_path = f"submission-script-{submission_script_hash}.sh"
+
+         fs = await self.fs(cluster.ssh)
+
+         template_dir = (await self._state_dir(cluster.ssh)) / f"{experiment_id}"
+
+         await fs.makedirs(template_dir, exist_ok=True)
+         await fs.write(template_dir / submission_script_path, submission_script.encode())
 
          # Construct and run command on the cluster
-         command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
+         command = f"sbatch --chdir {template_dir.as_posix()} --parsable {submission_script_path}"
          result = await self.run(cluster.ssh, command)
          if result.returncode != 0:
              raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
@@ -596,8 +702,16 @@ class Client:
                  raise ValueError(f"Unsupported job type: {type(job)}")
 
      def __del__(self):
-         for conn in self._connections.values():
+         for fs in self._remote_filesystems.values():
+             del fs
+         for conn in self._remote_connections.values():
              conn.close()
+             del conn
+
+
+ @functools.cache
+ def get_client() -> SlurmExecutionClient:
+     return SlurmExecutionClient()
 
 
  @tp.overload
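
For orientation, here is a minimal, illustrative usage sketch (not part of the diff above) of the refactored execution client introduced in 0.4.6: `get_client()` now returns a cached `SlurmExecutionClient`, and its `run()` dispatches either to a local subprocess or over a cached asyncssh connection depending on whether the SSH config resolves to the current host. The `SlurmSSHConfig(host=...)` construction and the host name below are assumptions for illustration only.

# Illustrative sketch only -- not part of the 0.4.4 -> 0.4.6 diff.
# Assumes SlurmSSHConfig can be built from a host name; adapt to your cluster config.
import asyncio

from xm_slurm import config
from xm_slurm.execution import get_client


async def main() -> None:
    ssh = config.SlurmSSHConfig(host="login.example.org")  # hypothetical host
    client = get_client()  # cached SlurmExecutionClient

    # run() shells out locally when the config resolves to this host,
    # otherwise it executes the command over the cached SSH connection.
    result = await client.run(ssh, "sinfo --version", check=True)
    print(result.stdout)


asyncio.run(main())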