xmanager_slurm-0.4.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. xm_slurm/__init__.py +47 -0
  2. xm_slurm/api/__init__.py +33 -0
  3. xm_slurm/api/abc.py +65 -0
  4. xm_slurm/api/models.py +70 -0
  5. xm_slurm/api/sqlite/client.py +358 -0
  6. xm_slurm/api/web/client.py +173 -0
  7. xm_slurm/batching.py +139 -0
  8. xm_slurm/config.py +189 -0
  9. xm_slurm/console.py +3 -0
  10. xm_slurm/constants.py +19 -0
  11. xm_slurm/contrib/__init__.py +0 -0
  12. xm_slurm/contrib/clusters/__init__.py +67 -0
  13. xm_slurm/contrib/clusters/drac.py +242 -0
  14. xm_slurm/dependencies.py +171 -0
  15. xm_slurm/executables.py +215 -0
  16. xm_slurm/execution.py +995 -0
  17. xm_slurm/executors.py +210 -0
  18. xm_slurm/experiment.py +1016 -0
  19. xm_slurm/experimental/parameter_controller.py +206 -0
  20. xm_slurm/filesystems.py +129 -0
  21. xm_slurm/job_blocks.py +21 -0
  22. xm_slurm/metadata_context.py +253 -0
  23. xm_slurm/packageables.py +309 -0
  24. xm_slurm/packaging/__init__.py +8 -0
  25. xm_slurm/packaging/docker.py +348 -0
  26. xm_slurm/packaging/registry.py +45 -0
  27. xm_slurm/packaging/router.py +56 -0
  28. xm_slurm/packaging/utils.py +22 -0
  29. xm_slurm/resources.py +350 -0
  30. xm_slurm/scripts/_cloudpickle.py +28 -0
  31. xm_slurm/scripts/cli.py +90 -0
  32. xm_slurm/status.py +197 -0
  33. xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
  34. xm_slurm/templates/docker/mamba.Dockerfile +29 -0
  35. xm_slurm/templates/docker/python.Dockerfile +32 -0
  36. xm_slurm/templates/docker/uv.Dockerfile +38 -0
  37. xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
  38. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
  39. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  40. xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
  41. xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
  42. xm_slurm/templates/slurm/job.bash.j2 +90 -0
  43. xm_slurm/templates/slurm/library/retry.bash +62 -0
  44. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
  45. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
  46. xm_slurm/types.py +23 -0
  47. xm_slurm/utils.py +196 -0
  48. xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
  49. xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
  50. xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
  51. xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
  52. xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
xm_slurm/execution.py ADDED
@@ -0,0 +1,995 @@
1
+ import asyncio
2
+ import collections.abc
3
+ import dataclasses
4
+ import functools
5
+ import getpass
6
+ import hashlib
7
+ import importlib
8
+ import importlib.resources
9
+ import logging
10
+ import operator
11
+ import os
12
+ import pathlib
13
+ import re
14
+ import shlex
15
+ import shutil
16
+ import subprocess
17
+ import sys
18
+ import typing as tp
19
+
20
+ import asyncssh
21
+ import backoff
22
+ import jinja2 as j2
23
+ import more_itertools as mit
24
+ from asyncssh.auth import KbdIntPrompts, KbdIntResponse
25
+ from asyncssh.misc import MaybeAwait
26
+ from rich.console import ConsoleRenderable
27
+ from rich.rule import Rule
28
+ from rich.text import Text
29
+ from xmanager import xm
30
+
31
+ from xm_slurm import (
32
+ batching,
33
+ config,
34
+ constants,
35
+ dependencies,
36
+ executors,
37
+ filesystems,
38
+ job_blocks,
39
+ status,
40
+ utils,
41
+ )
42
+ from xm_slurm.console import console
43
+ from xm_slurm.types import Descriptor
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ _POLL_INTERVAL = 30.0
48
+ _BATCHED_BATCH_SIZE = 16
49
+ _BATCHED_TIMEOUT = 0.2
50
+
51
+
52
+ class SlurmExecutionError(Exception): ...
53
+
54
+
55
+ @dataclasses.dataclass(frozen=True, kw_only=True)
56
+ class SlurmJob:
57
+ job_id: str
58
+
59
+ @property
60
+ def is_array_job(self) -> bool:
61
+ return isinstance(self, SlurmArrayJob)
62
+
63
+ @property
64
+ def is_heterogeneous_job(self) -> bool:
65
+ return isinstance(self, SlurmHeterogeneousJob)
66
+
67
+ def __hash__(self) -> int:
68
+ return hash((type(self), self.job_id))
69
+
70
+
71
+ @dataclasses.dataclass(frozen=True, kw_only=True)
72
+ class SlurmArrayJob(SlurmJob):
73
+ array_job_id: str
74
+ array_task_id: str
75
+
76
+
77
+ @dataclasses.dataclass(frozen=True, kw_only=True)
78
+ class SlurmHeterogeneousJob(SlurmJob):
79
+ het_job_id: str
80
+ het_component_id: str
81
+
82
+
83
+ SlurmJobT = tp.TypeVar("SlurmJobT", bound=SlurmJob, covariant=True)
84
+
85
+
86
+ class SlurmJobDescriptor(Descriptor[SlurmJobT, str]):
87
+ def __set_name__(self, owner: type, name: str):
88
+ del owner
89
+ self.job = f"_{name}"
90
+
91
+ def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> SlurmJobT:
92
+ del owner
93
+ return getattr(instance, self.job)
94
+
95
+ def __set__(self, instance: object, value: str):
96
+ _setattr = object.__setattr__ if not hasattr(instance, self.job) else setattr
97
+
98
+ match = constants.SLURM_JOB_ID_REGEX.match(value)
99
+ if match is None:
100
+ raise ValueError(f"Invalid Slurm job ID: {value}")
101
+ groups = match.groupdict()
102
+
103
+ job_id = groups["jobid"]
104
+ if array_task_id := groups.get("arraytaskid"):
105
+ _setattr(
106
+ instance,
107
+ self.job,
108
+ SlurmArrayJob(job_id=value, array_job_id=job_id, array_task_id=array_task_id),
109
+ )
110
+ elif het_component_id := groups.get("componentid"):
111
+ _setattr(
112
+ instance,
113
+ self.job,
114
+ SlurmHeterogeneousJob(
115
+ job_id=value, het_job_id=job_id, het_component_id=het_component_id
116
+ ),
117
+ )
118
+ else:
119
+ _setattr(instance, self.job, SlurmJob(job_id=value))
120
+
121
+
122
+ def _group_by_ssh_configs(
123
+ ssh_configs: tp.Sequence[config.SSHConfig], slurm_jobs: tp.Sequence[SlurmJob]
124
+ ) -> dict[config.SSHConfig, list[SlurmJob]]:
125
+ jobs_by_cluster = collections.defaultdict(list)
126
+ for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
127
+ jobs_by_cluster[ssh_config].append(slurm_job)
128
+ return jobs_by_cluster
129
+
130
+
131
+ class _BatchedSlurmHandle:
132
+ @functools.partial(
133
+ batching.batch,
134
+ max_batch_size=_BATCHED_BATCH_SIZE,
135
+ batch_timeout=_BATCHED_TIMEOUT,
136
+ )
137
+ @staticmethod
138
+ @backoff.on_exception(backoff.expo, SlurmExecutionError, max_tries=5, max_time=60.0)
139
+ async def _batched_get_state(
140
+ ssh_configs: tp.Sequence[config.SSHConfig],
141
+ slurm_jobs: tp.Sequence[SlurmJob],
142
+ ) -> tp.Sequence[status.SlurmJobState]:
143
+ async def _get_state(
144
+ options: config.SSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
145
+ ) -> tp.Sequence[status.SlurmJobState]:
146
+ result = await get_client().run(
147
+ options,
148
+ [
149
+ "sacct",
150
+ "--jobs",
151
+ ",".join([slurm_job.job_id for slurm_job in slurm_jobs]),
152
+ "--format",
153
+ "JobID,State",
154
+ "--allocations",
155
+ "--noheader",
156
+ "--parsable2",
157
+ ],
158
+ check=True,
159
+ )
160
+
161
+ assert isinstance(result.stdout, str)
162
+ states_by_job_id = {}
163
+ for line in result.stdout.splitlines():
164
+ job_id, state = line.split("|")
165
+ states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)
166
+
167
+ job_states = []
168
+ for slurm_job in slurm_jobs:
169
+ if slurm_job.job_id in states_by_job_id:
170
+ job_states.append(states_by_job_id[slurm_job.job_id])
171
+ # This is a stupid hack around sacct's inability to display state information for
172
+ # array job elements that haven't begun. We'll assume that if the job ID is not found,
173
+ # and it's an array job, then it's pending.
174
+ elif slurm_job.is_array_job:
175
+ job_states.append(status.SlurmJobState.PENDING)
176
+ else:
177
+ raise SlurmExecutionError(f"Failed to find job state info for {slurm_job!r}")
178
+ return job_states
179
+
180
+ # Group Slurm jobs by their cluster so we can batch requests
181
+ jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
182
+
183
+ # Async get state for each cluster
184
+ job_states_per_cluster = await asyncio.gather(*[
185
+ _get_state(options, jobs) for options, jobs in jobs_by_cluster.items()
186
+ ])
187
+
188
+ # Reconstruct the job states by cluster
189
+ job_states_by_cluster = {}
190
+ for ssh_config, job_states in zip(jobs_by_cluster.keys(), job_states_per_cluster):
191
+ job_states_by_cluster[ssh_config] = dict(zip(jobs_by_cluster[ssh_config], job_states))
192
+
193
+ # Reconstruct the job states in the original order
194
+ job_states = []
195
+ for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
196
+ job_states.append(job_states_by_cluster[ssh_config][slurm_job]) # type: ignore
197
+ return job_states
198
+
199
+ @functools.partial(
200
+ batching.batch,
201
+ max_batch_size=_BATCHED_BATCH_SIZE,
202
+ batch_timeout=_BATCHED_TIMEOUT,
203
+ )
204
+ @staticmethod
205
+ async def _batched_cancel(
206
+ ssh_configs: tp.Sequence[config.SSHConfig],
207
+ slurm_jobs: tp.Sequence[SlurmJob],
208
+ ) -> tp.Sequence[None]:
209
+ async def _cancel(options: config.SSHConfig, slurm_jobs: tp.Sequence[SlurmJob]) -> None:
210
+ await get_client().run(
211
+ options,
212
+ ["scancel", " ".join([slurm_job.job_id for slurm_job in slurm_jobs])],
213
+ check=True,
214
+ )
215
+
216
+ jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
217
+ return await asyncio.gather(*[
218
+ _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
219
+ ])
220
+
221
+
222
+ @dataclasses.dataclass(frozen=True, kw_only=True)
223
+ class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
224
+ """A handle for referring to the launched container."""
225
+
226
+ experiment_id: int
227
+ ssh: config.SSHConfig
228
+ slurm_job: Descriptor[SlurmJobT, str] = SlurmJobDescriptor[SlurmJobT]()
229
+ job_name: str # XManager job name associated with this handle
230
+
231
+ @backoff.on_predicate(
232
+ backoff.constant,
233
+ lambda state: state in status.SlurmActiveJobStates,
234
+ jitter=None,
235
+ interval=_POLL_INTERVAL,
236
+ )
237
+ async def wait(self) -> status.SlurmJobState:
238
+ return await self.get_state()
239
+
240
+ async def stop(self) -> None:
241
+ await self._batched_cancel(self.ssh, self.slurm_job)
242
+
243
+ async def get_state(self) -> status.SlurmJobState:
244
+ return await self._batched_get_state(self.ssh, self.slurm_job)
245
+
246
+ async def logs(
247
+ self, *, num_lines: int, block_size: int, wait: bool, follow: bool, raw: bool = False
248
+ ) -> tp.AsyncGenerator[tp.Union[str, ConsoleRenderable], None]:
249
+ experiment_dir = await get_client().experiment_dir(self.ssh, self.experiment_id)
250
+ file = experiment_dir / f"slurm-{self.slurm_job.job_id}.out"
251
+ fs = await get_client().fs(self.ssh)
252
+
253
+ if wait:
254
+ while not (await fs.exists(file)):
255
+ await asyncio.sleep(5)
256
+
257
+ file_size = await fs.size(file)
258
+ assert file_size is not None
259
+
260
+ async with await fs.open(file, "rb") as remote_file:
261
+ data = b""
262
+ lines = []
263
+ position = file_size
264
+
265
+ while len(lines) <= num_lines and position > 0:
266
+ read_size = min(block_size, position)
267
+ position -= read_size
268
+ await remote_file.seek(position)
269
+ chunk = await remote_file.read(read_size)
270
+ data = chunk + data
271
+ lines = data.splitlines()
272
+
273
+ if position <= 0:
274
+ if raw:
275
+ yield "\033[31mBEGINNING OF FILE\033[0m\n"
276
+ else:
277
+ yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
278
+ for line in lines[-num_lines:]:
279
+ if raw:
280
+ yield line.decode("utf-8", errors="replace") + "\n"
281
+ else:
282
+ yield Text.from_ansi(line.decode("utf-8", errors="replace"))
283
+
284
+ if (await self.get_state()) not in status.SlurmActiveJobStates:
285
+ if raw:
286
+ yield "\033[31mEND OF FILE\033[0m\n"
287
+ return
288
+ else:
289
+ yield Rule("[bold red]END OF FILE[/bold red]")
290
+ return
291
+
292
+ if not follow:
293
+ return
294
+
295
+ await remote_file.seek(file_size)
296
+ while True:
297
+ if new_data := (await remote_file.read(block_size)):
298
+ if raw:
299
+ yield new_data.decode("utf-8", errors="replace")
300
+ else:
301
+ yield Text.from_ansi(new_data.decode("utf-8", errors="replace"))
302
+ else:
303
+ await asyncio.sleep(0.25)
304
+
305
+
306
+ class CompletedProcess(tp.Protocol):
307
+ returncode: int | None
308
+ stdout: bytes | str
309
+ stderr: bytes | str
310
+
311
+
312
+ @functools.cache
313
+ def get_template_env(runtime: config.ContainerRuntime) -> j2.Environment:
314
+ template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
315
+ template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)
316
+
317
+ def _raise_template_exception(msg: str) -> None:
318
+ raise j2.TemplateRuntimeError(msg)
319
+
320
+ template_env.globals["raise"] = _raise_template_exception
321
+ template_env.globals["operator"] = operator
322
+
323
+ # Iterate over stdlib files and insert them into the template environment
324
+ stdlib = []
325
+ for file in importlib.resources.files("xm_slurm.templates.slurm.library").iterdir():
326
+ if not file.is_file() or not file.name.endswith(".bash"):
327
+ continue
328
+ stdlib.append(file.read_text())
329
+ template_env.globals["stdlib"] = stdlib
330
+
331
+ entrypoint_template = template_env.get_template("entrypoint.bash.j2")
332
+ template_env.globals.update(entrypoint_template.module.__dict__)
333
+
334
+ match runtime:
335
+ case config.ContainerRuntime.SINGULARITY | config.ContainerRuntime.APPTAINER:
336
+ runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
337
+ case config.ContainerRuntime.PODMAN:
338
+ runtime_template = template_env.get_template("runtimes/podman.bash.j2")
339
+ case _:
340
+ raise NotImplementedError(f"Container runtime {runtime} is not implemented.")
341
+ template_env.globals.update(runtime_template.module.__dict__)
342
+
343
+ return template_env
344
+
345
+
346
+ class SlurmSSHClient(asyncssh.SSHClient):
347
+ """SSHClient that handles keyboard-interactive 2FA authentication."""
348
+
349
+ _kbdint_auth_lock: tp.ClassVar[asyncio.Lock] = asyncio.Lock()
350
+ _host: str
351
+
352
+ def __init__(self, host: str):
353
+ self._host = host
354
+
355
+ def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
356
+ return ""
357
+
358
+ async def kbdint_challenge_received(
359
+ self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
360
+ ) -> MaybeAwait[KbdIntResponse | None]:
361
+ """Handle 2FA prompts by prompting user for input."""
362
+ del name, lang
363
+ if not sys.stdin.isatty():
364
+ raise SlurmExecutionError(
365
+ f"Two-factor authentication is not supported for non-interactive sessions on {self._host}"
366
+ )
367
+
368
+ async with self._kbdint_auth_lock:
369
+ if len(prompts) > 0:
370
+ console.rule(f"Two-Factor Authentication for {self._host}")
371
+
372
+ if instructions:
373
+ console.print(instructions, style="bold yellow")
374
+
375
+ responses = []
376
+ for prompt, echo in prompts:
377
+ # Manually disable password authentication
378
+ if prompt.strip() == "Password:":
379
+ return None
380
+
381
+ try:
382
+ response = await asyncio.to_thread(
383
+ console.input,
384
+ f"{prompt}\a",
385
+ password=not echo,
386
+ )
387
+ except (EOFError, KeyboardInterrupt):
388
+ console.print("\n[red]Authentication cancelled[/red]")
389
+ return None
390
+ else:
391
+ responses.append(response)
392
+
393
+ return responses
394
+
395
+
396
+ class SlurmExecutionClient:
397
+ def __init__(self):
398
+ self._remote_connections = dict[config.SSHConfig, asyncssh.SSHClientConnection]()
399
+ self._remote_connection_locks = collections.defaultdict(asyncio.Lock)
400
+ self._remote_filesystems = dict[config.SSHConfig, filesystems.AsyncSSHFileSystem]()
401
+
402
+ self._local_fs = filesystems.AsyncLocalFileSystem()
403
+
404
+ @backoff.on_exception(backoff.expo, asyncio.exceptions.TimeoutError, max_tries=5, max_time=60.0)
405
+ async def _local_run( # type: ignore
406
+ self,
407
+ command: str,
408
+ *,
409
+ check: bool = False,
410
+ timeout: float | None = None,
411
+ ) -> subprocess.CompletedProcess[str]:
412
+ process = await asyncio.subprocess.create_subprocess_shell(
413
+ command,
414
+ stdout=asyncio.subprocess.PIPE,
415
+ stderr=asyncio.subprocess.PIPE,
416
+ # Filter out all SLURM_ environment variables as this could be running on a
417
+ # compute node and xm-slurm should act stateless.
418
+ env=dict(filter(lambda x: not x[0].startswith("SLURM_"), os.environ.items())),
419
+ )
420
+ stdout, stderr = await asyncio.wait_for(process.communicate(), timeout)
421
+
422
+ stdout = stdout.decode("utf-8").strip() if stdout else ""
423
+ stderr = stderr.decode("utf-8").strip() if stderr else ""
424
+
425
+ assert process.returncode is not None
426
+ if check and process.returncode != 0:
427
+ raise RuntimeError(f"Command failed with return code {process.returncode}: {command}\n")
428
+
429
+ return subprocess.CompletedProcess[str](command, process.returncode, stdout, stderr)
430
+
431
+ @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
432
+ async def _remote_run( # type: ignore
433
+ self,
434
+ ssh_config: config.SSHConfig,
435
+ command: str,
436
+ *,
437
+ check: bool = False,
438
+ timeout: float | None = None,
439
+ ) -> asyncssh.SSHCompletedProcess:
440
+ client = await self._connection(ssh_config)
441
+ return await client.run(command, check=check, timeout=timeout)
442
+
443
+ @functools.cache
444
+ def _is_ssh_config_local(self, ssh_config: config.SSHConfig) -> bool:
445
+ """A best effort check to see if the SSH config is local so we can bypass ssh."""
446
+
447
+ # We can't verify the connection so bail out
448
+ if ssh_config.public_key is None:
449
+ return False
450
+ if "SSH_CONNECTION" not in os.environ:
451
+ return False
452
+
453
+ def _is_host_local(host: str) -> bool:
454
+ nonlocal ssh_config
455
+ assert ssh_config.public_key is not None
456
+
457
+ if shutil.which("ssh-keyscan") is None:
458
+ return False
459
+
460
+ keyscan_result = utils.run_command(
461
+ ["ssh-keyscan", "-t", ssh_config.public_key.algorithm, host],
462
+ return_stdout=True,
463
+ )
464
+
465
+ if keyscan_result.returncode != 0:
466
+ return False
467
+
468
+ try:
469
+ key = mit.one(
470
+ filter(
471
+ lambda x: not x.startswith("#"), keyscan_result.stdout.strip().split("\n")
472
+ )
473
+ )
474
+ _, algorithm, key = key.split(" ")
475
+
476
+ if (
477
+ algorithm == ssh_config.public_key.algorithm
478
+ and key == ssh_config.public_key.key
479
+ ):
480
+ return True
481
+
482
+ except Exception:
483
+ pass
484
+
485
+ return False
486
+
487
+ # 1): we're directly connected to the host
488
+ ssh_connection_str = os.environ["SSH_CONNECTION"]
489
+ _, _, server_ip, _ = ssh_connection_str.split()
490
+
491
+ logger.debug("Checking if SSH_CONNECTION server %s is local", server_ip)
492
+ if _is_host_local(server_ip):
493
+ return True
494
+
495
+ # 2): we're in a Slurm job and the submission host is the host
496
+ if "SLURM_JOB_ID" in os.environ and "SLURM_SUBMIT_HOST" in os.environ:
497
+ submit_host = os.environ["SLURM_SUBMIT_HOST"]
498
+ logger.debug("Checking if SLURM_SUBMIT_HOST %s is local", submit_host)
499
+ if _is_host_local(submit_host):
500
+ return True
501
+ elif "SLURM_JOB_ID" in os.environ and shutil.which("scontrol") is not None:
502
+ # Stupid edge case: when the job is launched via srun, SLURM_SUBMIT_HOST isn't forwarded,
503
+ # so we'll parse it from scontrol...
504
+ scontrol_result = utils.run_command(
505
+ ["scontrol", "show", "job", os.environ["SLURM_JOB_ID"]],
506
+ return_stdout=True,
507
+ )
508
+ if scontrol_result.returncode != 0:
509
+ return False
510
+
511
+ match = re.search(
512
+ r"AllocNode:Sid=(?P<host>[^ ]+):\d+", scontrol_result.stdout.strip(), re.MULTILINE
513
+ )
514
+ if match is not None:
515
+ host = match.group("host")
516
+ logger.debug("Checking if AllocNode %s is local", host)
517
+ if _is_host_local(host):
518
+ return True
519
+
520
+ return False
521
+
522
+ @functools.cache
523
+ @utils.reawaitable
524
+ async def _state_dir(self, ssh_config: config.SSHConfig) -> pathlib.Path:
525
+ state_dirs = [
526
+ ("XM_SLURM_STATE_DIR", ""),
527
+ ("XDG_STATE_HOME", "xm-slurm"),
528
+ ("HOME", ".local/state/xm-slurm"),
529
+ ]
530
+
531
+ for env_var, subpath in state_dirs:
532
+ cmd = await self.run(ssh_config, f"printenv {env_var}", check=False)
533
+ assert isinstance(cmd.stdout, str)
534
+ if cmd.returncode == 0:
535
+ return pathlib.Path(cmd.stdout.strip()) / subpath
536
+
537
+ raise SlurmExecutionError(
538
+ "Failed to find a valid state directory for XManager. "
539
+ "We weren't able to resolve any of the following paths: "
540
+ f"{', '.join(env_var + ('/' + subpath if subpath else '') for env_var, subpath in state_dirs)}."
541
+ )
542
+
543
+ @functools.cached_property
544
+ def _ssh_config_dirs(self) -> list[pathlib.Path]:
545
+ ssh_config_paths = []
546
+
547
+ if (ssh_config_path := pathlib.Path.home() / ".ssh" / "config").exists():
548
+ ssh_config_paths.append(ssh_config_path)
549
+ if (xm_ssh_config_var := os.environ.get("XM_SLURM_SSH_CONFIG")) and (
550
+ xm_ssh_config_path := pathlib.Path(xm_ssh_config_var).expanduser()
551
+ ).exists():
552
+ ssh_config_paths.append(xm_ssh_config_path)
553
+
554
+ return ssh_config_paths
555
+
556
+ async def experiment_dir(
557
+ self, ssh_config: config.SSHConfig, experiment_id: int
558
+ ) -> pathlib.Path:
559
+ return (await self._state_dir(ssh_config)) / f"{experiment_id:08d}"
560
+
561
+ async def run(
562
+ self,
563
+ ssh_config: config.SSHConfig,
564
+ command: xm.SequentialArgs | str | tp.Sequence[str],
565
+ *,
566
+ check: bool = False,
567
+ timeout: float | None = None,
568
+ ) -> CompletedProcess:
569
+ if isinstance(command, xm.SequentialArgs):
570
+ command = command.to_list()
571
+ if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
572
+ command = shlex.join(command)
573
+ assert isinstance(command, str)
574
+
575
+ if self._is_ssh_config_local(ssh_config):
576
+ logger.debug("Running command locally: %s", command)
577
+ return await self._local_run(command, check=check, timeout=timeout) # type: ignore
578
+ else:
579
+ logger.debug(
580
+ "Running command on %s: %s", ", ".join(map(str, ssh_config.endpoints)), command
581
+ )
582
+ return await self._remote_run(ssh_config, command, check=check, timeout=timeout) # type: ignore
583
+
584
+ async def fs(self, ssh_config: config.SSHConfig) -> filesystems.AsyncFileSystem:
585
+ if self._is_ssh_config_local(ssh_config):
586
+ return self._local_fs
587
+
588
+ if ssh_config not in self._remote_filesystems:
589
+ self._remote_filesystems[ssh_config] = filesystems.AsyncSSHFileSystem(
590
+ await (await self._connection(ssh_config)).start_sftp_client()
591
+ )
592
+ return self._remote_filesystems[ssh_config]
593
+
594
+ async def _connection(self, ssh_config: config.SSHConfig) -> asyncssh.SSHClientConnection:
595
+ async def _connect_to_endpoint(
596
+ endpoint: config.Endpoint,
597
+ ) -> asyncssh.SSHClientConnection:
598
+ __tracebackhide__ = True
599
+ try:
600
+ config = asyncssh.config.SSHClientConfig.load(
601
+ None,
602
+ self._ssh_config_dirs,
603
+ True,
604
+ True,
605
+ True,
606
+ getpass.getuser(),
607
+ ssh_config.user or (),
608
+ endpoint.hostname,
609
+ endpoint.port or (),
610
+ )
611
+ if config.get("Hostname") is None and (
612
+ constants.DOMAIN_NAME_REGEX.match(endpoint.hostname)
613
+ or constants.IPV4_REGEX.match(endpoint.hostname)
614
+ or constants.IPV6_REGEX.match(endpoint.hostname)
615
+ ):
616
+ config._options["Hostname"] = endpoint.hostname
617
+ elif config.get("Hostname") is None:
618
+ raise RuntimeError(
619
+ f"Failed to parse hostname from host `{endpoint.hostname}` using "
620
+ f"SSH configs: {', '.join(map(str, self._ssh_config_dirs))} and "
621
+ f"provided hostname `{endpoint.hostname}` isn't a valid domain name "
622
+ "or IPv{4,6} address."
623
+ )
624
+
625
+ if config.get("User") is None:
626
+ raise RuntimeError(
627
+ f"We could not find a user for the cluster configuration: `{endpoint.hostname}`. "
628
+ "No user was specified in the configuration and we could not parse "
629
+ f"any users for host `{config.get('Hostname')}` from the SSH configs: "
630
+ f"{', '.join(map(lambda h: f'`{h}`', self._ssh_config_dirs))}. Please either specify a user "
631
+ "in the configuration or add a user to your SSH configuration under the block "
632
+ f"`Host {config.get('Hostname')}`."
633
+ )
634
+
635
+ options = await asyncssh.SSHClientConnectionOptions.construct(
636
+ config=None,
637
+ disable_trivial_auth=True,
638
+ password_auth=False,
639
+ server_host_key_algs=ssh_config.public_key.algorithm
640
+ if ssh_config.public_key
641
+ else None,
642
+ login_timeout=60 * 10, # 10 minutes
643
+ known_hosts=ssh_config.known_hosts,
644
+ )
645
+ options.prepare(last_config=config)
646
+
647
+ conn, _ = await asyncssh.create_connection(
648
+ lambda: SlurmSSHClient(endpoint.hostname),
649
+ host=endpoint.hostname,
650
+ port=endpoint.port,
651
+ options=options,
652
+ )
653
+ return conn
654
+ except asyncssh.misc.PermissionDenied as ex:
655
+ raise SlurmExecutionError(
656
+ f"Permission denied connecting to {endpoint.hostname}"
657
+ ) from ex
658
+ except asyncssh.misc.ConnectionLost as ex:
659
+ raise SlurmExecutionError(f"Connection lost to host {endpoint.hostname}") from ex
660
+ except asyncssh.misc.HostKeyNotVerifiable as ex:
661
+ raise SlurmExecutionError(
662
+ f"Cannot verify the public key for host {endpoint.hostname}"
663
+ ) from ex
664
+ except asyncssh.misc.KeyExchangeFailed as ex:
665
+ raise SlurmExecutionError(
666
+ f"Failed to exchange keys with host {endpoint.hostname}"
667
+ ) from ex
668
+ except asyncssh.Error as ex:
669
+ raise SlurmExecutionError(
670
+ f"SSH connection error when connecting to {endpoint.hostname}"
671
+ ) from ex
672
+
673
+ conn = self._remote_connections.get(ssh_config)
674
+ if conn is not None and not conn.is_closed():
675
+ return conn
676
+
677
+ async with self._remote_connection_locks[ssh_config]:
678
+ conn = self._remote_connections.get(ssh_config)
679
+ if conn is not None and not conn.is_closed():
680
+ return conn
681
+
682
+ exceptions: list[Exception] = []
683
+ for endpoint in ssh_config.endpoints:
684
+ try:
685
+ conn = await _connect_to_endpoint(endpoint)
686
+ except Exception as ex:
687
+ exceptions.append(ex)
688
+ else:
689
+ self._remote_connections[ssh_config] = conn
690
+ return conn
691
+
692
+ if sys.version_info >= (3, 11):
693
+ raise ExceptionGroup("Failed to connect to all hosts", exceptions) # noqa: F821
694
+ raise exceptions[-1]
695
+
696
+ async def _submission_script_template(
697
+ self,
698
+ *,
699
+ job: xm.Job | xm.JobGroup,
700
+ dependency: dependencies.SlurmJobDependency | None = None,
701
+ cluster: config.SlurmClusterConfig,
702
+ args: tp.Mapping[str, tp.Any] | tp.Sequence[tp.Mapping[str, tp.Any]] | None,
703
+ experiment_id: int,
704
+ identity: str | None,
705
+ ) -> str:
706
+ # Sanitize args
707
+ match args:
708
+ case None:
709
+ args = {}
710
+ case collections.abc.Mapping():
711
+ args = dict(args)
712
+ case collections.abc.Sequence():
713
+ assert all(isinstance(trial, collections.abc.Mapping) for trial in args)
714
+ args = [dict(trial) for trial in args]
715
+ case _:
716
+ raise ValueError("Invalid args type")
717
+ args = tp.cast(dict[str, tp.Any] | list[dict[str, tp.Any]], args)
718
+
719
+ template_env = get_template_env(cluster.runtime)
720
+ template_context = dict(
721
+ dependency=dependency,
722
+ cluster=cluster,
723
+ experiment_id=experiment_id,
724
+ identity=identity,
725
+ )
726
+
727
+ # Sanitize job groups
728
+ if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
729
+ job = tp.cast(xm.Job, list(job.jobs.values())[0])
730
+ elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
731
+ raise ValueError("Job group must have at least one job")
732
+
733
+ match job:
734
+ case xm.Job() as job_array if isinstance(args, collections.abc.Sequence):
735
+ assert isinstance(args, list)
736
+ template = template_env.get_template("job-array.bash.j2")
737
+ sequential_args = [
738
+ xm.SequentialArgs.from_collection(trial.get("args")) for trial in args
739
+ ]
740
+ env_vars = [trial.get("env_vars") for trial in args]
741
+ if any(env_vars):
742
+ raise NotImplementedError(
743
+ "Job arrays over environment variables are not yet supported."
744
+ )
745
+
746
+ return template.render(
747
+ job=job_array, args=sequential_args, env_vars=env_vars, **template_context
748
+ )
749
+ case xm.Job() if isinstance(args, collections.abc.Mapping):
750
+ assert isinstance(args, dict)
751
+ template = template_env.get_template("job.bash.j2")
752
+ sequential_args = xm.SequentialArgs.from_collection(args.get("args"))
753
+ env_vars = args.get("env_vars")
754
+ return template.render(
755
+ job=job, args=sequential_args, env_vars=env_vars, **template_context
756
+ )
757
+ case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
758
+ assert isinstance(args, dict)
759
+ template = template_env.get_template("job-group.bash.j2")
760
+ sequential_args = {
761
+ job_name: {
762
+ "args": args.get(job_name, {}).get("args"),
763
+ }
764
+ for job_name in job_group.jobs.keys()
765
+ }
766
+ env_vars = {
767
+ job_name: args.get(job_name, {}).get("env_vars")
768
+ for job_name in job_group.jobs.keys()
769
+ }
770
+ return template.render(
771
+ job_group=job_group, args=sequential_args, env_vars=env_vars, **template_context
772
+ )
773
+ case _:
774
+ raise ValueError(f"Unsupported job type: {type(job)}")
775
+
776
+ @tp.overload
777
+ async def launch(
778
+ self,
779
+ *,
780
+ cluster: config.SlurmClusterConfig,
781
+ job: xm.JobGroup,
782
+ dependency: dependencies.SlurmJobDependency | None = None,
783
+ args: tp.Mapping[str, job_blocks.JobArgs] | None,
784
+ experiment_id: int,
785
+ identity: str | None = ...,
786
+ ) -> SlurmHandle: ...
787
+
788
+ @tp.overload
789
+ async def launch(
790
+ self,
791
+ *,
792
+ cluster: config.SlurmClusterConfig,
793
+ job: xm.Job,
794
+ dependency: dependencies.SlurmJobDependency | None = None,
795
+ args: tp.Sequence[job_blocks.JobArgs],
796
+ experiment_id: int,
797
+ identity: str | None = ...,
798
+ ) -> list[SlurmHandle]: ...
799
+
800
+ @tp.overload
801
+ async def launch(
802
+ self,
803
+ *,
804
+ cluster: config.SlurmClusterConfig,
805
+ job: xm.Job,
806
+ dependency: dependencies.SlurmJobDependency | None = None,
807
+ args: job_blocks.JobArgs,
808
+ experiment_id: int,
809
+ identity: str | None = ...,
810
+ ) -> SlurmHandle: ...
811
+
812
+ async def launch(
813
+ self,
814
+ *,
815
+ cluster: config.SlurmClusterConfig,
816
+ job: xm.Job | xm.JobGroup,
817
+ dependency: dependencies.SlurmJobDependency | None = None,
818
+ args: tp.Mapping[str, job_blocks.JobArgs]
819
+ | tp.Sequence[job_blocks.JobArgs]
820
+ | job_blocks.JobArgs
821
+ | None,
822
+ experiment_id: int,
823
+ identity: str | None = None,
824
+ ):
825
+ submission_script = await self._submission_script_template(
826
+ job=job,
827
+ dependency=dependency,
828
+ cluster=cluster,
829
+ args=args,
830
+ experiment_id=experiment_id,
831
+ identity=identity,
832
+ )
833
+ logger.debug("Slurm submission script:\n%s", submission_script)
834
+ submission_script_hash = hashlib.blake2s(submission_script.encode()).hexdigest()[:8]
835
+ submission_script_path = f"submission-script-{submission_script_hash}.sh"
836
+
837
+ fs = await self.fs(cluster.ssh)
838
+
839
+ template_dir = await self.experiment_dir(cluster.ssh, experiment_id)
840
+
841
+ await fs.makedirs(template_dir, exist_ok=True)
842
+ await fs.write(template_dir / submission_script_path, submission_script.encode())
843
+
844
+ # Construct and run command on the cluster
845
+ command = f"sbatch --chdir {template_dir.as_posix()} --parsable {submission_script_path}"
846
+ result = await self.run(cluster.ssh, command)
847
+ if result.returncode != 0:
848
+ raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
849
+
850
+ assert isinstance(result.stdout, str)
851
+ slurm_job_id, *_ = result.stdout.split(",")
852
+ slurm_job_id = slurm_job_id.strip()
853
+
854
+ console.log(
855
+ f"[magenta]:rocket: Job [cyan]{slurm_job_id}[/cyan] will be launched on "
856
+ f"[cyan]{cluster.name}[/cyan] "
857
+ )
858
+
859
+ # If we scheduled an array job make sure to return a list of handles
860
+ # The indexing is always sequential in 0, 1, ..., n - 1
861
+ if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
862
+ assert job.name is not None
863
+ return [
864
+ SlurmHandle(
865
+ experiment_id=experiment_id,
866
+ ssh=cluster.ssh,
867
+ slurm_job=f"{slurm_job_id}_{array_index}",
868
+ job_name=job.name,
869
+ )
870
+ for array_index in range(len(args))
871
+ ]
872
+ elif isinstance(job, xm.Job):
873
+ assert job.name is not None
874
+ return SlurmHandle(
875
+ experiment_id=experiment_id,
876
+ ssh=cluster.ssh,
877
+ slurm_job=slurm_job_id,
878
+ job_name=job.name,
879
+ )
880
+ elif isinstance(job, xm.JobGroup):
881
+ # TODO: make this work for actual job groups.
882
+ job = tp.cast(xm.Job, mit.one(job.jobs.values()))
883
+ assert isinstance(job, xm.Job)
884
+ assert job.name is not None
885
+ return SlurmHandle(
886
+ experiment_id=experiment_id,
887
+ ssh=cluster.ssh,
888
+ slurm_job=slurm_job_id,
889
+ job_name=job.name,
890
+ )
891
+ else:
892
+ raise ValueError(f"Unsupported job type: {type(job)}")
893
+
894
+ def __del__(self):
895
+ for fs in self._remote_filesystems.values():
896
+ del fs
897
+ for conn in self._remote_connections.values():
898
+ conn.close()
899
+ del conn
900
+
901
+
902
+ @functools.cache
903
+ def get_client() -> SlurmExecutionClient:
904
+ return SlurmExecutionClient()
905
+
906
+
907
+ @tp.overload
908
+ async def launch(
909
+ *,
910
+ job: xm.JobGroup,
911
+ dependency: dependencies.SlurmJobDependency | None = None,
912
+ args: tp.Mapping[str, job_blocks.JobArgs],
913
+ experiment_id: int,
914
+ identity: str | None = ...,
915
+ ) -> SlurmHandle: ...
916
+
917
+
918
+ @tp.overload
919
+ async def launch(
920
+ *,
921
+ job: xm.Job,
922
+ dependency: dependencies.SlurmJobDependency | None = None,
923
+ args: tp.Sequence[job_blocks.JobArgs],
924
+ experiment_id: int,
925
+ identity: str | None = ...,
926
+ ) -> list[SlurmHandle]: ...
927
+
928
+
929
+ @tp.overload
930
+ async def launch(
931
+ *,
932
+ job: xm.Job,
933
+ dependency: dependencies.SlurmJobDependency | None = None,
934
+ args: job_blocks.JobArgs,
935
+ experiment_id: int,
936
+ identity: str | None = ...,
937
+ ) -> SlurmHandle: ...
938
+
939
+
940
+ async def launch(
941
+ *,
942
+ job: xm.Job | xm.JobGroup,
943
+ dependency: dependencies.SlurmJobDependency | None = None,
944
+ args: tp.Mapping[str, job_blocks.JobArgs]
945
+ | tp.Sequence[job_blocks.JobArgs]
946
+ | job_blocks.JobArgs,
947
+ experiment_id: int,
948
+ identity: str | None = None,
949
+ ) -> SlurmHandle | list[SlurmHandle]:
950
+ match job:
951
+ case xm.Job() as job:
952
+ if not isinstance(job.executor, executors.Slurm):
953
+ raise ValueError("Job must have a Slurm executor")
954
+ job_requirements = job.executor.requirements
955
+ cluster = job_requirements.cluster
956
+ if cluster is None:
957
+ raise ValueError("Job must have a cluster requirement")
958
+ if cluster.validate is not None:
959
+ cluster.validate(job)
960
+
961
+ return await get_client().launch(
962
+ cluster=cluster,
963
+ job=job,
964
+ dependency=dependency,
965
+ args=tp.cast(job_blocks.JobArgs | tp.Sequence[job_blocks.JobArgs], args),
966
+ experiment_id=experiment_id,
967
+ identity=identity,
968
+ )
969
+ case xm.JobGroup() as job_group:
970
+ job_group_executors = set()
971
+ job_group_clusters = set()
972
+ for job_item in job_group.jobs.values():
973
+ if not isinstance(job_item, xm.Job):
974
+ raise ValueError("Job group must contain only jobs")
975
+ if not isinstance(job_item.executor, executors.Slurm):
976
+ raise ValueError("Job must have a Slurm executor")
977
+ if job_item.executor.requirements.cluster is None:
978
+ raise ValueError("Job must have a cluster requirement")
979
+ if job_item.executor.requirements.cluster.validate is not None:
980
+ job_item.executor.requirements.cluster.validate(job_item)
981
+ job_group_clusters.add(job_item.executor.requirements.cluster)
982
+ job_group_executors.add(id(job_item.executor))
983
+ if len(job_group_executors) != 1:
984
+ raise ValueError("Job group must have the same executor for all jobs")
985
+ if len(job_group_clusters) != 1:
986
+ raise ValueError("Job group must have the same cluster for all jobs")
987
+
988
+ return await get_client().launch(
989
+ cluster=job_group_clusters.pop(),
990
+ job=job_group,
991
+ dependency=dependency,
992
+ args=tp.cast(tp.Mapping[str, job_blocks.JobArgs], args),
993
+ experiment_id=experiment_id,
994
+ identity=identity,
995
+ )
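
The module-level launch() overloads above are the submission entry point of this file: they validate that each xm.Job carries an executors.Slurm executor with a cluster requirement and then delegate to SlurmExecutionClient.launch(), which renders a submission script, uploads it, and runs sbatch over SSH. For orientation only, below is a minimal sketch of driving that entry point directly for a single pre-built job. The job construction itself (executable, executor, cluster) is assumed to happen elsewhere via xm_slurm's packaging and executor APIs, which are not shown in this file, and the args mapping keys follow the args.get("args") / args.get("env_vars") usage visible in _submission_script_template; treat it as a sketch, not the package's documented workflow.

import sys

from xmanager import xm

from xm_slurm import execution


async def launch_and_tail(job: xm.Job, experiment_id: int) -> None:
    # Assumes job.executor is an executors.Slurm whose requirements.cluster is set,
    # as enforced by execution.launch(); otherwise a ValueError is raised.
    handle = await execution.launch(
        job=job,
        args={"args": {"seed": 0}},  # a single JobArgs mapping -> a single SlurmHandle
        experiment_id=experiment_id,
    )
    print(await handle.get_state())  # one batched sacct query via _BatchedSlurmHandle

    # Stream the tail of the job's slurm-<jobid>.out once it exists.
    async for chunk in handle.logs(
        num_lines=50, block_size=8192, wait=True, follow=False, raw=True
    ):
        sys.stdout.write(str(chunk))

    final_state = await handle.wait()  # polls sacct until the job leaves an active state
    print(final_state)

# Usage from synchronous code (job built elsewhere):
#   import asyncio; asyncio.run(launch_and_tail(job, experiment_id=1))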