xmanager-slurm 0.3.2 (py3-none-any.whl) → 0.4.1 (py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Click here for more details.

Files changed (42)
  1. xm_slurm/__init__.py +6 -2
  2. xm_slurm/api.py +301 -34
  3. xm_slurm/batching.py +4 -4
  4. xm_slurm/config.py +105 -55
  5. xm_slurm/constants.py +19 -0
  6. xm_slurm/contrib/__init__.py +0 -0
  7. xm_slurm/contrib/clusters/__init__.py +47 -13
  8. xm_slurm/contrib/clusters/drac.py +34 -16
  9. xm_slurm/dependencies.py +171 -0
  10. xm_slurm/executables.py +34 -22
  11. xm_slurm/execution.py +305 -107
  12. xm_slurm/executors.py +8 -12
  13. xm_slurm/experiment.py +601 -168
  14. xm_slurm/experimental/parameter_controller.py +202 -0
  15. xm_slurm/job_blocks.py +7 -0
  16. xm_slurm/packageables.py +42 -20
  17. xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
  18. xm_slurm/packaging/router.py +3 -1
  19. xm_slurm/packaging/utils.py +9 -81
  20. xm_slurm/resources.py +28 -4
  21. xm_slurm/scripts/_cloudpickle.py +28 -0
  22. xm_slurm/scripts/cli.py +52 -0
  23. xm_slurm/status.py +9 -0
  24. xm_slurm/templates/docker/mamba.Dockerfile +4 -2
  25. xm_slurm/templates/docker/python.Dockerfile +18 -10
  26. xm_slurm/templates/docker/uv.Dockerfile +35 -0
  27. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
  28. xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
  29. xm_slurm/templates/slurm/job.bash.j2 +4 -3
  30. xm_slurm/types.py +23 -0
  31. xm_slurm/utils.py +18 -10
  32. xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
  33. xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
  34. {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
  35. xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
  36. xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
  37. xm_slurm/packaging/docker/__init__.py +0 -75
  38. xm_slurm/packaging/docker/abc.py +0 -112
  39. xm_slurm/packaging/docker/cloud.py +0 -503
  40. xm_slurm/templates/docker/pdm.Dockerfile +0 -31
  41. xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
  42. xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
xm_slurm/execution.py CHANGED
@@ -5,24 +5,29 @@ import functools
5
5
  import hashlib
6
6
  import logging
7
7
  import operator
8
- import re
9
8
  import shlex
10
- import typing
11
- from typing import Any, Mapping, Sequence
9
+ import typing as tp
12
10
 
13
11
  import asyncssh
14
12
  import backoff
15
13
  import jinja2 as j2
14
+ import more_itertools as mit
16
15
  from asyncssh.auth import KbdIntPrompts, KbdIntResponse
17
16
  from asyncssh.misc import MaybeAwait
17
+ from rich.console import ConsoleRenderable
18
+ from rich.rule import Rule
18
19
  from xmanager import xm
19
20
 
20
- from xm_slurm import batching, config, executors, status
21
+ from xm_slurm import batching, config, constants, dependencies, executors, status
21
22
  from xm_slurm.console import console
23
+ from xm_slurm.job_blocks import JobArgs
24
+ from xm_slurm.types import Descriptor
22
25
 
23
26
  SlurmClusterConfig = config.SlurmClusterConfig
24
27
  ContainerRuntime = config.ContainerRuntime
25
28
 
29
+ logger = logging.getLogger(__name__)
30
+
26
31
  """
27
32
  === Runtime Configurations ===
28
33
  With RunC:
@@ -41,11 +46,6 @@ With Singularity / Apptainer:
41
46
  apptainer run --compat <digest>
42
47
  """
43
48
 
44
- """
45
- #SBATCH --error=/dev/null
46
- #SBATCH --output=/dev/null
47
- """
48
-
49
49
  _POLL_INTERVAL = 30.0
50
50
  _BATCHED_BATCH_SIZE = 16
51
51
  _BATCHED_TIMEOUT = 0.2
@@ -67,12 +67,79 @@ class NoKBAuthSSHClient(asyncssh.SSHClient):
67
67
  return []
68
68
 
69
69
 
70
- def _group_by_ssh_options(
71
- ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
72
- ) -> dict[asyncssh.SSHClientConnectionOptions, list[str]]:
70
+ @dataclasses.dataclass(frozen=True, kw_only=True)
71
+ class SlurmJob:
72
+ job_id: str
73
+
74
+ @property
75
+ def is_array_job(self) -> bool:
76
+ return isinstance(self, SlurmArrayJob)
77
+
78
+ @property
79
+ def is_heterogeneous_job(self) -> bool:
80
+ return isinstance(self, SlurmHeterogeneousJob)
81
+
82
+ def __hash__(self) -> int:
83
+ return hash((type(self), self.job_id))
84
+
85
+
86
+ @dataclasses.dataclass(frozen=True, kw_only=True)
87
+ class SlurmArrayJob(SlurmJob):
88
+ array_job_id: str
89
+ array_task_id: str
90
+
91
+
92
+ @dataclasses.dataclass(frozen=True, kw_only=True)
93
+ class SlurmHeterogeneousJob(SlurmJob):
94
+ het_job_id: str
95
+ het_component_id: str
96
+
97
+
98
+ SlurmJobT = tp.TypeVar("SlurmJobT", bound=SlurmJob, covariant=True)
99
+
100
+
101
+ class SlurmJobDescriptor(Descriptor[SlurmJobT, str]):
102
+ def __set_name__(self, owner: type, name: str):
103
+ del owner
104
+ self.job = f"_{name}"
105
+
106
+ def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> SlurmJobT:
107
+ del owner
108
+ return getattr(instance, self.job)
109
+
110
+ def __set__(self, instance: object, value: str):
111
+ _setattr = object.__setattr__ if not hasattr(instance, self.job) else setattr
112
+
113
+ match = constants.SLURM_JOB_ID_REGEX.match(value)
114
+ if match is None:
115
+ raise ValueError(f"Invalid Slurm job ID: {value}")
116
+ groups = match.groupdict()
117
+
118
+ job_id = groups["jobid"]
119
+ if array_task_id := groups.get("arraytaskid", None):
120
+ _setattr(
121
+ instance,
122
+ self.job,
123
+ SlurmArrayJob(job_id=value, array_job_id=job_id, array_task_id=array_task_id),
124
+ )
125
+ elif het_component_id := groups.get("componentid", None):
126
+ _setattr(
127
+ instance,
128
+ self.job,
129
+ SlurmHeterogeneousJob(
130
+ job_id=value, het_job_id=job_id, het_component_id=het_component_id
131
+ ),
132
+ )
133
+ else:
134
+ _setattr(instance, self.job, SlurmJob(job_id=value))
135
+
136
+
137
+ def _group_by_ssh_configs(
138
+ ssh_configs: tp.Sequence[config.SlurmSSHConfig], slurm_jobs: tp.Sequence[SlurmJob]
139
+ ) -> dict[config.SlurmSSHConfig, list[SlurmJob]]:
73
140
  jobs_by_cluster = collections.defaultdict(list)
74
- for options, job_id in zip(ssh_options, job_ids):
75
- jobs_by_cluster[options].append(job_id)
141
+ for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
142
+ jobs_by_cluster[ssh_config].append(slurm_job)
76
143
  return jobs_by_cluster
77
144
 
78
145
 
@@ -83,18 +150,20 @@ class _BatchedSlurmHandle:
83
150
  batch_timeout=_BATCHED_TIMEOUT,
84
151
  )
85
152
  @staticmethod
153
+ @backoff.on_exception(backoff.expo, SlurmExecutionError, max_tries=5, max_time=60.0)
86
154
  async def _batched_get_state(
87
- ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
88
- ) -> Sequence[status.SlurmJobState]:
155
+ ssh_configs: tp.Sequence[config.SlurmSSHConfig],
156
+ slurm_jobs: tp.Sequence[SlurmJob],
157
+ ) -> tp.Sequence[status.SlurmJobState]:
89
158
  async def _get_state(
90
- options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
91
- ) -> Sequence[status.SlurmJobState]:
159
+ options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
160
+ ) -> tp.Sequence[status.SlurmJobState]:
92
161
  result = await get_client().run(
93
162
  options,
94
163
  [
95
164
  "sacct",
96
165
  "--jobs",
97
- ",".join(job_ids),
166
+ ",".join([slurm_job.job_id for slurm_job in slurm_jobs]),
98
167
  "--format",
99
168
  "JobID,State",
100
169
  "--allocations",
@@ -111,32 +180,35 @@ class _BatchedSlurmHandle:
111
180
  states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)
112
181
 
113
182
  job_states = []
114
- for job_id in job_ids:
115
- if job_id in states_by_job_id:
116
- job_states.append(states_by_job_id[job_id])
183
+ for slurm_job in slurm_jobs:
184
+ if slurm_job.job_id in states_by_job_id:
185
+ job_states.append(states_by_job_id[slurm_job.job_id])
117
186
  # This is a stupid hack around sacct's inability to display state information for
118
187
  # array job elements that haven't begun. We'll assume that if the job ID is not found,
119
188
  # and it's an array job, then it's pending.
120
- elif re.match(r"^\d+_\d+$", job_id) is not None:
189
+ elif slurm_job.is_array_job:
121
190
  job_states.append(status.SlurmJobState.PENDING)
122
191
  else:
123
- raise SlurmExecutionError(f"Failed to find job state info for {job_id}")
192
+ raise SlurmExecutionError(f"Failed to find job state info for {slurm_job!r}")
124
193
  return job_states
125
194
 
126
- jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)
195
+ # Group Slurm jobs by their cluster so we can batch requests
196
+ jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
127
197
 
198
+ # Async get state for each cluster
128
199
  job_states_per_cluster = await asyncio.gather(*[
129
- _get_state(options, job_ids) for options, job_ids in jobs_by_cluster.items()
200
+ _get_state(options, jobs) for options, jobs in jobs_by_cluster.items()
130
201
  ])
131
- job_states_by_cluster: dict[
132
- asyncssh.SSHClientConnectionOptions, dict[str, status.SlurmJobState]
133
- ] = {}
134
- for options, job_states in zip(ssh_options, job_states_per_cluster):
135
- job_states_by_cluster[options] = dict(zip(jobs_by_cluster[options], job_states))
136
202
 
203
+ # Reconstruct the job states by cluster
204
+ job_states_by_cluster = {}
205
+ for ssh_config, job_states in zip(ssh_configs, job_states_per_cluster):
206
+ job_states_by_cluster[ssh_config] = dict(zip(jobs_by_cluster[ssh_config], job_states))
207
+
208
+ # Reconstruct the job states in the original order
137
209
  job_states = []
138
- for options, job_id in zip(ssh_options, job_ids):
139
- job_states.append(job_states_by_cluster[options][job_id])
210
+ for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
211
+ job_states.append(job_states_by_cluster[ssh_config][slurm_job])
140
212
  return job_states
141
213
 
142
214
  @functools.partial(
@@ -146,29 +218,32 @@ class _BatchedSlurmHandle:
146
218
  )
147
219
  @staticmethod
148
220
  async def _batched_cancel(
149
- ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
150
- ) -> Sequence[None]:
221
+ ssh_configs: tp.Sequence[config.SlurmSSHConfig],
222
+ slurm_jobs: tp.Sequence[SlurmJob],
223
+ ) -> tp.Sequence[None]:
151
224
  async def _cancel(
152
- options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
225
+ options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
153
226
  ) -> None:
154
- await get_client().run(options, ["scancel", " ".join(job_ids)], check=True)
227
+ await get_client().run(
228
+ options,
229
+ ["scancel", " ".join([slurm_job.job_id for slurm_job in slurm_jobs])],
230
+ check=True,
231
+ )
155
232
 
156
- jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)
233
+ jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
157
234
  return await asyncio.gather(*[
158
235
  _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
159
236
  ])
160
237
 
161
238
 
162
239
  @dataclasses.dataclass(frozen=True, kw_only=True)
163
- class SlurmHandle(_BatchedSlurmHandle):
240
+ class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
164
241
  """A handle for referring to the launched container."""
165
242
 
166
- ssh_connection_options: asyncssh.SSHClientConnectionOptions
167
- job_id: str
168
-
169
- def __post_init__(self):
170
- if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
171
- raise ValueError(f"Invalid job ID: {self.job_id}")
243
+ experiment_id: int
244
+ ssh: config.SlurmSSHConfig
245
+ slurm_job: Descriptor[SlurmJobT, str] = SlurmJobDescriptor[SlurmJobT]()
246
+ job_name: str # XManager job name associated with this handle
172
247
 
173
248
  @backoff.on_predicate(
174
249
  backoff.constant,
@@ -180,15 +255,56 @@ class SlurmHandle(_BatchedSlurmHandle):
180
255
  return await self.get_state()
181
256
 
182
257
  async def stop(self) -> None:
183
- await self._batched_cancel(self.ssh_connection_options, self.job_id)
258
+ await self._batched_cancel(self.ssh, self.slurm_job)
184
259
 
185
260
  async def get_state(self) -> status.SlurmJobState:
186
- return await self._batched_get_state(self.ssh_connection_options, self.job_id)
187
-
261
+ return await self._batched_get_state(self.ssh, self.slurm_job)
188
262
 
189
- @functools.cache
190
- def get_client() -> "Client":
191
- return Client()
263
+ async def logs(
264
+ self, *, num_lines: int, block_size: int, wait: bool, follow: bool
265
+ ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
266
+ file = f".local/state/xm-slurm/{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
267
+ conn = await get_client().connection(self.ssh)
268
+ async with conn.start_sftp_client() as sftp:
269
+ if wait:
270
+ while not (await sftp.exists(file)):
271
+ await asyncio.sleep(5)
272
+
273
+ async with sftp.open(file, "rb") as remote_file:
274
+ file_stat = await remote_file.stat()
275
+ file_size = file_stat.size
276
+ assert file_size is not None
277
+
278
+ data = b""
279
+ lines = []
280
+ position = file_size
281
+
282
+ while len(lines) <= num_lines and position > 0:
283
+ read_size = min(block_size, position)
284
+ position -= read_size
285
+ await remote_file.seek(position)
286
+ chunk = await remote_file.read(read_size)
287
+ data = chunk + data
288
+ lines = data.splitlines()
289
+
290
+ if position <= 0:
291
+ yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
292
+ for line in lines[-num_lines:]:
293
+ yield line.decode("utf-8", errors="replace")
294
+
295
+ if (await self.get_state()) not in status.SlurmActiveJobStates:
296
+ yield Rule("[bold red]END OF FILE[/bold red]")
297
+ return
298
+
299
+ if not follow:
300
+ return
301
+
302
+ await remote_file.seek(file_size)
303
+ while True:
304
+ if new_data := (await remote_file.read(block_size)):
305
+ yield new_data.decode("utf-8", errors="replace")
306
+ else:
307
+ await asyncio.sleep(0.25)
192
308
 
193
309
 
194
310
  @functools.cache
@@ -208,55 +324,75 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
208
324
  case ContainerRuntime.PODMAN:
209
325
  runtime_template = template_env.get_template("runtimes/podman.bash.j2")
210
326
  case _:
211
- raise NotImplementedError
327
+ raise NotImplementedError(f"Container runtime {container_runtime} is not implemented.")
212
328
  # Update our global env with the runtime template's exported globals
213
329
  template_env.globals.update(runtime_template.module.__dict__)
214
330
 
215
331
  return template_env
216
332
 
217
333
 
334
+ @functools.cache
335
+ def get_client() -> "Client":
336
+ return Client()
337
+
338
+
218
339
  class Client:
219
- def __init__(self):
220
- self._connections: dict[
221
- asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
222
- ] = {}
340
+ def __init__(self) -> None:
341
+ self._connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
223
342
  self._connection_lock = asyncio.Lock()
224
343
 
225
344
  @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
226
345
  async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
227
346
  # Make sure the xm-slurm state directory exists
228
- await conn.run("mkdir -p ~/.local/state/xm-slurm", check=True)
347
+ async with conn.start_sftp_client() as sftp_client:
348
+ await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)
229
349
 
230
- async def connection(
231
- self,
232
- options: asyncssh.SSHClientConnectionOptions,
233
- ) -> asyncssh.SSHClientConnection:
234
- if options not in self._connections:
350
+ async def connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
351
+ if ssh_config not in self._connections:
235
352
  async with self._connection_lock:
236
353
  try:
237
- conn, _ = await asyncssh.create_connection(NoKBAuthSSHClient, options=options)
354
+ conn, _ = await asyncssh.create_connection(
355
+ NoKBAuthSSHClient, options=ssh_config.connection_options
356
+ )
238
357
  await self._setup_remote_connection(conn)
239
- self._connections[options] = conn
358
+ self._connections[ssh_config] = conn
240
359
  except asyncssh.misc.PermissionDenied as ex:
241
- raise RuntimeError(f"Permission denied connecting to {options.host}") from ex
242
- return self._connections[options]
360
+ raise SlurmExecutionError(
361
+ f"Permission denied connecting to {ssh_config.host}"
362
+ ) from ex
363
+ except asyncssh.misc.ConnectionLost as ex:
364
+ raise SlurmExecutionError(f"Connection lost to host {ssh_config.host}") from ex
365
+ except asyncssh.misc.HostKeyNotVerifiable as ex:
366
+ raise SlurmExecutionError(
367
+ f"Cannot verify the public key for host {ssh_config.host}"
368
+ ) from ex
369
+ except asyncssh.misc.KeyExchangeFailed as ex:
370
+ raise SlurmExecutionError(
371
+ f"Failed to exchange keys with host {ssh_config.host}"
372
+ ) from ex
373
+ except asyncssh.Error as ex:
374
+ raise SlurmExecutionError(
375
+ f"SSH connection error when connecting to {ssh_config.host}"
376
+ ) from ex
377
+
378
+ return self._connections[ssh_config]
243
379
 
244
380
  @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
245
381
  async def run(
246
382
  self,
247
- options: asyncssh.SSHClientConnectionOptions,
248
- command: xm.SequentialArgs | str | Sequence[str],
383
+ ssh_config: config.SlurmSSHConfig,
384
+ command: xm.SequentialArgs | str | tp.Sequence[str],
249
385
  *,
250
386
  check: bool = False,
251
387
  timeout: float | None = None,
252
388
  ) -> asyncssh.SSHCompletedProcess:
253
- client = await self.connection(options)
389
+ client = await self.connection(ssh_config)
254
390
  if isinstance(command, xm.SequentialArgs):
255
391
  command = command.to_list()
256
392
  if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
257
393
  command = shlex.join(command)
258
394
  assert isinstance(command, str)
259
- logging.debug("Running command on %s: %s", options.host, command)
395
+ logger.debug("Running command on %s: %s", ssh_config.host, command)
260
396
 
261
397
  return await client.run(command, check=check, timeout=timeout)
262
398
 
@@ -264,8 +400,9 @@ class Client:
264
400
  self,
265
401
  *,
266
402
  job: xm.Job | xm.JobGroup,
403
+ dependency: dependencies.SlurmJobDependency | None = None,
267
404
  cluster: SlurmClusterConfig,
268
- args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
405
+ args: tp.Mapping[str, tp.Any] | tp.Sequence[tp.Mapping[str, tp.Any]] | None,
269
406
  experiment_id: int,
270
407
  identity: str | None,
271
408
  ) -> str:
@@ -276,7 +413,7 @@ class Client:
276
413
 
277
414
  # Sanitize job groups
278
415
  if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
279
- job = typing.cast(xm.Job, list(job.jobs.values())[0])
416
+ job = tp.cast(xm.Job, list(job.jobs.values())[0])
280
417
  elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
281
418
  raise ValueError("Job group must have at least one job")
282
419
 
@@ -294,6 +431,7 @@ class Client:
294
431
 
295
432
  return template.render(
296
433
  job=job_array,
434
+ dependency=dependency,
297
435
  cluster=cluster,
298
436
  args=sequential_args,
299
437
  env_vars=env_vars,
@@ -306,6 +444,7 @@ class Client:
306
444
  env_vars = args.get("env_vars", None)
307
445
  return template.render(
308
446
  job=job,
447
+ dependency=dependency,
309
448
  cluster=cluster,
310
449
  args=sequential_args,
311
450
  env_vars=env_vars,
@@ -326,6 +465,7 @@ class Client:
326
465
  }
327
466
  return template.render(
328
467
  job_group=job_group,
468
+ dependency=dependency,
329
469
  cluster=cluster,
330
470
  args=sequential_args,
331
471
  env_vars=env_vars,
@@ -335,51 +475,66 @@ class Client:
335
475
  case _:
336
476
  raise ValueError(f"Unsupported job type: {type(job)}")
337
477
 
338
- @typing.overload
478
+ @tp.overload
339
479
  async def launch(
340
480
  self,
341
481
  *,
342
482
  cluster: SlurmClusterConfig,
343
- job: xm.Job | xm.JobGroup,
344
- args: Mapping[str, Any] | None,
483
+ job: xm.JobGroup,
484
+ dependency: dependencies.SlurmJobDependency | None = None,
485
+ args: tp.Mapping[str, JobArgs] | None,
345
486
  experiment_id: int,
346
487
  identity: str | None = ...,
347
488
  ) -> SlurmHandle: ...
348
489
 
349
- @typing.overload
490
+ @tp.overload
350
491
  async def launch(
351
492
  self,
352
493
  *,
353
494
  cluster: SlurmClusterConfig,
354
- job: xm.Job | xm.JobGroup,
355
- args: Sequence[Mapping[str, Any]],
495
+ job: xm.Job,
496
+ dependency: dependencies.SlurmJobDependency | None = None,
497
+ args: tp.Sequence[JobArgs],
498
+ experiment_id: int,
499
+ identity: str | None = ...,
500
+ ) -> list[SlurmHandle]: ...
501
+
502
+ @tp.overload
503
+ async def launch(
504
+ self,
505
+ *,
506
+ cluster: SlurmClusterConfig,
507
+ job: xm.Job,
508
+ dependency: dependencies.SlurmJobDependency | None = None,
509
+ args: JobArgs,
356
510
  experiment_id: int,
357
511
  identity: str | None = ...,
358
- ) -> Sequence[SlurmHandle]: ...
512
+ ) -> SlurmHandle: ...
359
513
 
360
514
  async def launch(
361
515
  self,
362
516
  *,
363
517
  cluster: SlurmClusterConfig,
364
518
  job: xm.Job | xm.JobGroup,
365
- args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
519
+ dependency: dependencies.SlurmJobDependency | None = None,
520
+ args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs | None,
366
521
  experiment_id: int,
367
522
  identity: str | None = None,
368
- ) -> SlurmHandle | Sequence[SlurmHandle]:
369
- # Construct template
523
+ ):
370
524
  template = await self.template(
371
525
  job=job,
526
+ dependency=dependency,
372
527
  cluster=cluster,
373
528
  args=args,
374
529
  experiment_id=experiment_id,
375
530
  identity=identity,
376
531
  )
377
- logging.debug("Slurm submission script:\n%s", template)
532
+ logger.debug("Slurm submission script:\n%s", template)
378
533
 
379
534
  # Hash submission script
380
535
  template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]
381
536
 
382
- conn = await self.connection(cluster.ssh_connection_options)
537
+ conn = await self.connection(cluster.ssh)
383
538
  async with conn.start_sftp_client() as sftp:
384
539
  # Write the submission script to the cluster
385
540
  # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
@@ -392,9 +547,9 @@ class Client:
392
547
 
393
548
  # Construct and run command on the cluster
394
549
  command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
395
- result = await self.run(cluster.ssh_connection_options, command)
550
+ result = await self.run(cluster.ssh, command)
396
551
  if result.returncode != 0:
397
- raise RuntimeError(f"Failed to schedule job on {cluster.host}: {result.stderr}")
552
+ raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
398
553
 
399
554
  assert isinstance(result.stdout, str)
400
555
  slurm_job_id, *_ = result.stdout.split(",")
@@ -405,61 +560,103 @@ class Client:
405
560
  f"[cyan]{cluster.name}[/cyan] "
406
561
  )
407
562
 
563
+ # If we scheduled an array job make sure to return a list of handles
564
+ # The indexing is always sequential in 0, 1, ..., n - 1
408
565
  if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
566
+ assert job.name is not None
409
567
  return [
410
568
  SlurmHandle(
411
- ssh_connection_options=cluster.ssh_connection_options,
412
- job_id=f"{slurm_job_id}_{array_index}",
569
+ experiment_id=experiment_id,
570
+ ssh=cluster.ssh,
571
+ slurm_job=f"{slurm_job_id}_{array_index}",
572
+ job_name=job.name,
413
573
  )
414
574
  for array_index in range(len(args))
415
575
  ]
576
+ elif isinstance(job, xm.Job):
577
+ assert job.name is not None
578
+ return SlurmHandle(
579
+ experiment_id=experiment_id,
580
+ ssh=cluster.ssh,
581
+ slurm_job=slurm_job_id,
582
+ job_name=job.name,
583
+ )
584
+ elif isinstance(job, xm.JobGroup):
585
+ # TODO: make this work for actual job groups.
586
+ job = tp.cast(xm.Job, mit.one(job.jobs.values()))
587
+ assert isinstance(job, xm.Job)
588
+ assert job.name is not None
589
+ return SlurmHandle(
590
+ experiment_id=experiment_id,
591
+ ssh=cluster.ssh,
592
+ slurm_job=slurm_job_id,
593
+ job_name=job.name,
594
+ )
595
+ else:
596
+ raise ValueError(f"Unsupported job type: {type(job)}")
416
597
 
417
- return SlurmHandle(
418
- ssh_connection_options=cluster.ssh_connection_options,
419
- job_id=slurm_job_id,
420
- )
598
+ def __del__(self):
599
+ for conn in self._connections.values():
600
+ conn.close()
421
601
 
422
602
 
423
- @typing.overload
603
+ @tp.overload
424
604
  async def launch(
425
605
  *,
426
- job: xm.Job | xm.JobGroup,
427
- args: Mapping[str, Any],
606
+ job: xm.JobGroup,
607
+ dependency: dependencies.SlurmJobDependency | None = None,
608
+ args: tp.Mapping[str, JobArgs],
428
609
  experiment_id: int,
429
610
  identity: str | None = ...,
430
611
  ) -> SlurmHandle: ...
431
612
 
432
613
 
433
- @typing.overload
614
+ @tp.overload
434
615
  async def launch(
435
616
  *,
436
- job: xm.Job | xm.JobGroup,
437
- args: Sequence[Mapping[str, Any]],
617
+ job: xm.Job,
618
+ dependency: dependencies.SlurmJobDependency | None = None,
619
+ args: tp.Sequence[JobArgs],
438
620
  experiment_id: int,
439
621
  identity: str | None = ...,
440
- ) -> Sequence[SlurmHandle]: ...
622
+ ) -> list[SlurmHandle]: ...
623
+
624
+
625
+ @tp.overload
626
+ async def launch(
627
+ *,
628
+ job: xm.Job,
629
+ dependency: dependencies.SlurmJobDependency | None = None,
630
+ args: JobArgs,
631
+ experiment_id: int,
632
+ identity: str | None = ...,
633
+ ) -> SlurmHandle: ...
441
634
 
442
635
 
443
636
  async def launch(
444
637
  *,
445
638
  job: xm.Job | xm.JobGroup,
446
- args: Mapping[str, Any] | Sequence[Mapping[str, Any]],
639
+ dependency: dependencies.SlurmJobDependency | None = None,
640
+ args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs,
447
641
  experiment_id: int,
448
642
  identity: str | None = None,
449
- ) -> SlurmHandle | Sequence[SlurmHandle]:
643
+ ) -> SlurmHandle | list[SlurmHandle]:
450
644
  match job:
451
- case xm.Job():
645
+ case xm.Job() as job:
452
646
  if not isinstance(job.executor, executors.Slurm):
453
647
  raise ValueError("Job must have a Slurm executor")
454
648
  job_requirements = job.executor.requirements
455
649
  cluster = job_requirements.cluster
456
650
  if cluster is None:
457
651
  raise ValueError("Job must have a cluster requirement")
652
+ if cluster.validate is not None:
653
+ cluster.validate(job)
458
654
 
459
655
  return await get_client().launch(
460
656
  cluster=cluster,
461
657
  job=job,
462
- args=args,
658
+ dependency=dependency,
659
+ args=tp.cast(JobArgs | tp.Sequence[JobArgs], args),
463
660
  experiment_id=experiment_id,
464
661
  identity=identity,
465
662
  )
@@ -473,6 +670,8 @@ async def launch(
473
670
  raise ValueError("Job must have a Slurm executor")
474
671
  if job_item.executor.requirements.cluster is None:
475
672
  raise ValueError("Job must have a cluster requirement")
673
+ if job_item.executor.requirements.cluster.validate is not None:
674
+ job_item.executor.requirements.cluster.validate(job_item)
476
675
  job_group_clusters.add(job_item.executor.requirements.cluster)
477
676
  job_group_executors.add(id(job_item.executor))
478
677
  if len(job_group_executors) != 1:
@@ -482,10 +681,9 @@ async def launch(
482
681
 
483
682
  return await get_client().launch(
484
683
  cluster=job_group_clusters.pop(),
485
- job=job,
486
- args=args,
684
+ job=job_group,
685
+ dependency=dependency,
686
+ args=tp.cast(tp.Mapping[str, JobArgs], args),
487
687
  experiment_id=experiment_id,
488
688
  identity=identity,
489
689
  )
490
- case _:
491
- raise ValueError("Unsupported job type")