xmanager-slurm 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Consult the package registry's advisory page for more details.

xm_slurm/execution.py CHANGED
@@ -5,10 +5,8 @@ import functools
5
5
  import hashlib
6
6
  import logging
7
7
  import operator
8
- import re
9
8
  import shlex
10
- import typing
11
- from typing import Any, Mapping, Sequence
9
+ import typing as tp
12
10
 
13
11
  import asyncssh
14
12
  import backoff
@@ -16,15 +14,20 @@ import jinja2 as j2
16
14
  import more_itertools as mit
17
15
  from asyncssh.auth import KbdIntPrompts, KbdIntResponse
18
16
  from asyncssh.misc import MaybeAwait
17
+ from rich.console import ConsoleRenderable
18
+ from rich.rule import Rule
19
19
  from xmanager import xm
20
20
 
21
- from xm_slurm import batching, config, executors, status
21
+ from xm_slurm import batching, config, constants, dependencies, executors, status
22
22
  from xm_slurm.console import console
23
23
  from xm_slurm.job_blocks import JobArgs
24
+ from xm_slurm.types import Descriptor
24
25
 
25
26
  SlurmClusterConfig = config.SlurmClusterConfig
26
27
  ContainerRuntime = config.ContainerRuntime
27
28
 
29
+ logger = logging.getLogger(__name__)
30
+
28
31
  """
29
32
  === Runtime Configurations ===
30
33
  With RunC:
@@ -43,11 +46,6 @@ With Singularity / Apptainer:
43
46
  apptainer run --compat <digest>
44
47
  """
45
48
 
46
- """
47
- #SBATCH --error=/dev/null
48
- #SBATCH --output=/dev/null
49
- """
50
-
51
49
  _POLL_INTERVAL = 30.0
52
50
  _BATCHED_BATCH_SIZE = 16
53
51
  _BATCHED_TIMEOUT = 0.2
@@ -69,12 +67,79 @@ class NoKBAuthSSHClient(asyncssh.SSHClient):
69
67
  return []
70
68
 
71
69
 
72
- def _group_by_ssh_options(
73
- ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
74
- ) -> dict[asyncssh.SSHClientConnectionOptions, list[str]]:
70
+ @dataclasses.dataclass(frozen=True, kw_only=True)
71
+ class SlurmJob:
72
+ job_id: str
73
+
74
+ @property
75
+ def is_array_job(self) -> bool:
76
+ return isinstance(self, SlurmArrayJob)
77
+
78
+ @property
79
+ def is_heterogeneous_job(self) -> bool:
80
+ return isinstance(self, SlurmHeterogeneousJob)
81
+
82
+ def __hash__(self) -> int:
83
+ return hash((type(self), self.job_id))
84
+
85
+
86
+ @dataclasses.dataclass(frozen=True, kw_only=True)
87
+ class SlurmArrayJob(SlurmJob):
88
+ array_job_id: str
89
+ array_task_id: str
90
+
91
+
92
+ @dataclasses.dataclass(frozen=True, kw_only=True)
93
+ class SlurmHeterogeneousJob(SlurmJob):
94
+ het_job_id: str
95
+ het_component_id: str
96
+
97
+
98
+ SlurmJobT = tp.TypeVar("SlurmJobT", bound=SlurmJob, covariant=True)
99
+
100
+
101
+ class SlurmJobDescriptor(Descriptor[SlurmJobT, str]):
102
+ def __set_name__(self, owner: type, name: str):
103
+ del owner
104
+ self.job = f"_{name}"
105
+
106
+ def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> SlurmJobT:
107
+ del owner
108
+ return getattr(instance, self.job)
109
+
110
+ def __set__(self, instance: object, value: str):
111
+ _setattr = object.__setattr__ if not hasattr(instance, self.job) else setattr
112
+
113
+ match = constants.SLURM_JOB_ID_REGEX.match(value)
114
+ if match is None:
115
+ raise ValueError(f"Invalid Slurm job ID: {value}")
116
+ groups = match.groupdict()
117
+
118
+ job_id = groups["jobid"]
119
+ if array_task_id := groups.get("arraytaskid", None):
120
+ _setattr(
121
+ instance,
122
+ self.job,
123
+ SlurmArrayJob(job_id=value, array_job_id=job_id, array_task_id=array_task_id),
124
+ )
125
+ elif het_component_id := groups.get("componentid", None):
126
+ _setattr(
127
+ instance,
128
+ self.job,
129
+ SlurmHeterogeneousJob(
130
+ job_id=value, het_job_id=job_id, het_component_id=het_component_id
131
+ ),
132
+ )
133
+ else:
134
+ _setattr(instance, self.job, SlurmJob(job_id=value))
135
+
136
+
137
+ def _group_by_ssh_configs(
138
+ ssh_configs: tp.Sequence[config.SlurmSSHConfig], slurm_jobs: tp.Sequence[SlurmJob]
139
+ ) -> dict[config.SlurmSSHConfig, list[SlurmJob]]:
75
140
  jobs_by_cluster = collections.defaultdict(list)
76
- for options, job_id in zip(ssh_options, job_ids):
77
- jobs_by_cluster[options].append(job_id)
141
+ for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
142
+ jobs_by_cluster[ssh_config].append(slurm_job)
78
143
  return jobs_by_cluster
79
144
 
80
145
 
@@ -85,18 +150,20 @@ class _BatchedSlurmHandle:
85
150
  batch_timeout=_BATCHED_TIMEOUT,
86
151
  )
87
152
  @staticmethod
153
+ @backoff.on_exception(backoff.expo, SlurmExecutionError, max_tries=5, max_time=60.0)
88
154
  async def _batched_get_state(
89
- ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
90
- ) -> Sequence[status.SlurmJobState]:
155
+ ssh_configs: tp.Sequence[config.SlurmSSHConfig],
156
+ slurm_jobs: tp.Sequence[SlurmJob],
157
+ ) -> tp.Sequence[status.SlurmJobState]:
91
158
  async def _get_state(
92
- options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
93
- ) -> Sequence[status.SlurmJobState]:
159
+ options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
160
+ ) -> tp.Sequence[status.SlurmJobState]:
94
161
  result = await get_client().run(
95
162
  options,
96
163
  [
97
164
  "sacct",
98
165
  "--jobs",
99
- ",".join(job_ids),
166
+ ",".join([slurm_job.job_id for slurm_job in slurm_jobs]),
100
167
  "--format",
101
168
  "JobID,State",
102
169
  "--allocations",
@@ -113,32 +180,35 @@ class _BatchedSlurmHandle:
113
180
  states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)
114
181
 
115
182
  job_states = []
116
- for job_id in job_ids:
117
- if job_id in states_by_job_id:
118
- job_states.append(states_by_job_id[job_id])
183
+ for slurm_job in slurm_jobs:
184
+ if slurm_job.job_id in states_by_job_id:
185
+ job_states.append(states_by_job_id[slurm_job.job_id])
119
186
  # This is a stupid hack around sacct's inability to display state information for
120
187
  # array job elements that haven't begun. We'll assume that if the job ID is not found,
121
188
  # and it's an array job, then it's pending.
122
- elif re.match(r"^\d+_\d+$", job_id) is not None:
189
+ elif slurm_job.is_array_job:
123
190
  job_states.append(status.SlurmJobState.PENDING)
124
191
  else:
125
- raise SlurmExecutionError(f"Failed to find job state info for {job_id}")
192
+ raise SlurmExecutionError(f"Failed to find job state info for {slurm_job!r}")
126
193
  return job_states
127
194
 
128
- jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)
195
+ # Group Slurm jobs by their cluster so we can batch requests
196
+ jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
129
197
 
198
+ # Async get state for each cluster
130
199
  job_states_per_cluster = await asyncio.gather(*[
131
- _get_state(options, job_ids) for options, job_ids in jobs_by_cluster.items()
200
+ _get_state(options, jobs) for options, jobs in jobs_by_cluster.items()
132
201
  ])
133
- job_states_by_cluster: dict[
134
- asyncssh.SSHClientConnectionOptions, dict[str, status.SlurmJobState]
135
- ] = {}
136
- for options, job_states in zip(ssh_options, job_states_per_cluster):
137
- job_states_by_cluster[options] = dict(zip(jobs_by_cluster[options], job_states))
138
202
 
203
+ # Reconstruct the job states by cluster
204
+ job_states_by_cluster = {}
205
+ for ssh_config, job_states in zip(ssh_configs, job_states_per_cluster):
206
+ job_states_by_cluster[ssh_config] = dict(zip(jobs_by_cluster[ssh_config], job_states))
207
+
208
+ # Reconstruct the job states in the original order
139
209
  job_states = []
140
- for options, job_id in zip(ssh_options, job_ids):
141
- job_states.append(job_states_by_cluster[options][job_id])
210
+ for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
211
+ job_states.append(job_states_by_cluster[ssh_config][slurm_job])
142
212
  return job_states
143
213
 
144
214
  @functools.partial(
@@ -148,31 +218,33 @@ class _BatchedSlurmHandle:
148
218
  )
149
219
  @staticmethod
150
220
  async def _batched_cancel(
151
- ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
152
- ) -> Sequence[None]:
221
+ ssh_configs: tp.Sequence[config.SlurmSSHConfig],
222
+ slurm_jobs: tp.Sequence[SlurmJob],
223
+ ) -> tp.Sequence[None]:
153
224
  async def _cancel(
154
- options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
225
+ options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
155
226
  ) -> None:
156
- await get_client().run(options, ["scancel", " ".join(job_ids)], check=True)
227
+ await get_client().run(
228
+ options,
229
+ ["scancel", " ".join([slurm_job.job_id for slurm_job in slurm_jobs])],
230
+ check=True,
231
+ )
157
232
 
158
- jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)
233
+ jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
159
234
  return await asyncio.gather(*[
160
235
  _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
161
236
  ])
162
237
 
163
238
 
164
239
  @dataclasses.dataclass(frozen=True, kw_only=True)
165
- class SlurmHandle(_BatchedSlurmHandle):
240
+ class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
166
241
  """A handle for referring to the launched container."""
167
242
 
243
+ experiment_id: int
168
244
  ssh: config.SlurmSSHConfig
169
- job_id: str
245
+ slurm_job: Descriptor[SlurmJobT, str] = SlurmJobDescriptor[SlurmJobT]()
170
246
  job_name: str # XManager job name associated with this handle
171
247
 
172
- def __post_init__(self):
173
- if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
174
- raise ValueError(f"Invalid job ID: {self.job_id}")
175
-
176
248
  @backoff.on_predicate(
177
249
  backoff.constant,
178
250
  lambda state: state in status.SlurmActiveJobStates,
@@ -183,15 +255,56 @@ class SlurmHandle(_BatchedSlurmHandle):
183
255
  return await self.get_state()
184
256
 
185
257
  async def stop(self) -> None:
186
- await self._batched_cancel(self.ssh.connection_options, self.job_id)
258
+ await self._batched_cancel(self.ssh, self.slurm_job)
187
259
 
188
260
  async def get_state(self) -> status.SlurmJobState:
189
- return await self._batched_get_state(self.ssh.connection_options, self.job_id)
190
-
261
+ return await self._batched_get_state(self.ssh, self.slurm_job)
191
262
 
192
- @functools.cache
193
- def get_client() -> "Client":
194
- return Client()
263
+ async def logs(
264
+ self, *, num_lines: int, block_size: int, wait: bool, follow: bool
265
+ ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
266
+ file = f".local/state/xm-slurm/{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
267
+ conn = await get_client().connection(self.ssh)
268
+ async with conn.start_sftp_client() as sftp:
269
+ if wait:
270
+ while not (await sftp.exists(file)):
271
+ await asyncio.sleep(5)
272
+
273
+ async with sftp.open(file, "rb") as remote_file:
274
+ file_stat = await remote_file.stat()
275
+ file_size = file_stat.size
276
+ assert file_size is not None
277
+
278
+ data = b""
279
+ lines = []
280
+ position = file_size
281
+
282
+ while len(lines) <= num_lines and position > 0:
283
+ read_size = min(block_size, position)
284
+ position -= read_size
285
+ await remote_file.seek(position)
286
+ chunk = await remote_file.read(read_size)
287
+ data = chunk + data
288
+ lines = data.splitlines()
289
+
290
+ if position <= 0:
291
+ yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
292
+ for line in lines[-num_lines:]:
293
+ yield line.decode("utf-8", errors="replace")
294
+
295
+ if (await self.get_state()) not in status.SlurmActiveJobStates:
296
+ yield Rule("[bold red]END OF FILE[/bold red]")
297
+ return
298
+
299
+ if not follow:
300
+ return
301
+
302
+ await remote_file.seek(file_size)
303
+ while True:
304
+ if new_data := (await remote_file.read(block_size)):
305
+ yield new_data.decode("utf-8", errors="replace")
306
+ else:
307
+ await asyncio.sleep(0.25)
195
308
 
196
309
 
197
310
  @functools.cache
@@ -218,11 +331,14 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
218
331
  return template_env
219
332
 
220
333
 
334
+ @functools.cache
335
+ def get_client() -> "Client":
336
+ return Client()
337
+
338
+
221
339
  class Client:
222
340
  def __init__(self) -> None:
223
- self._connections = dict[
224
- asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
225
- ]()
341
+ self._connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
226
342
  self._connection_lock = asyncio.Lock()
227
343
 
228
344
  @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
@@ -231,53 +347,52 @@ class Client:
231
347
  async with conn.start_sftp_client() as sftp_client:
232
348
  await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)
233
349
 
234
- async def connection(
235
- self,
236
- options: asyncssh.SSHClientConnectionOptions,
237
- ) -> asyncssh.SSHClientConnection:
238
- if options not in self._connections:
350
+ async def connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
351
+ if ssh_config not in self._connections:
239
352
  async with self._connection_lock:
240
353
  try:
241
- conn, _ = await asyncssh.create_connection(NoKBAuthSSHClient, options=options)
354
+ conn, _ = await asyncssh.create_connection(
355
+ NoKBAuthSSHClient, options=ssh_config.connection_options
356
+ )
242
357
  await self._setup_remote_connection(conn)
243
- self._connections[options] = conn
358
+ self._connections[ssh_config] = conn
244
359
  except asyncssh.misc.PermissionDenied as ex:
245
360
  raise SlurmExecutionError(
246
- f"Permission denied connecting to {options.host}"
361
+ f"Permission denied connecting to {ssh_config.host}"
247
362
  ) from ex
248
363
  except asyncssh.misc.ConnectionLost as ex:
249
- raise SlurmExecutionError(f"Connection lost to host {options.host}") from ex
364
+ raise SlurmExecutionError(f"Connection lost to host {ssh_config.host}") from ex
250
365
  except asyncssh.misc.HostKeyNotVerifiable as ex:
251
366
  raise SlurmExecutionError(
252
- f"Cannot verify the public key for host {options.host}"
367
+ f"Cannot verify the public key for host {ssh_config.host}"
253
368
  ) from ex
254
369
  except asyncssh.misc.KeyExchangeFailed as ex:
255
370
  raise SlurmExecutionError(
256
- f"Failed to exchange keys with host {options.host}"
371
+ f"Failed to exchange keys with host {ssh_config.host}"
257
372
  ) from ex
258
373
  except asyncssh.Error as ex:
259
374
  raise SlurmExecutionError(
260
- f"SSH connection error when connecting to {options.host}"
375
+ f"SSH connection error when connecting to {ssh_config.host}"
261
376
  ) from ex
262
377
 
263
- return self._connections[options]
378
+ return self._connections[ssh_config]
264
379
 
265
380
  @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
266
381
  async def run(
267
382
  self,
268
- options: asyncssh.SSHClientConnectionOptions,
269
- command: xm.SequentialArgs | str | Sequence[str],
383
+ ssh_config: config.SlurmSSHConfig,
384
+ command: xm.SequentialArgs | str | tp.Sequence[str],
270
385
  *,
271
386
  check: bool = False,
272
387
  timeout: float | None = None,
273
388
  ) -> asyncssh.SSHCompletedProcess:
274
- client = await self.connection(options)
389
+ client = await self.connection(ssh_config)
275
390
  if isinstance(command, xm.SequentialArgs):
276
391
  command = command.to_list()
277
392
  if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
278
393
  command = shlex.join(command)
279
394
  assert isinstance(command, str)
280
- logging.debug("Running command on %s: %s", options.host, command)
395
+ logger.debug("Running command on %s: %s", ssh_config.host, command)
281
396
 
282
397
  return await client.run(command, check=check, timeout=timeout)
283
398
 
@@ -285,8 +400,9 @@ class Client:
285
400
  self,
286
401
  *,
287
402
  job: xm.Job | xm.JobGroup,
403
+ dependency: dependencies.SlurmJobDependency | None = None,
288
404
  cluster: SlurmClusterConfig,
289
- args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
405
+ args: tp.Mapping[str, tp.Any] | tp.Sequence[tp.Mapping[str, tp.Any]] | None,
290
406
  experiment_id: int,
291
407
  identity: str | None,
292
408
  ) -> str:
@@ -297,7 +413,7 @@ class Client:
297
413
 
298
414
  # Sanitize job groups
299
415
  if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
300
- job = typing.cast(xm.Job, list(job.jobs.values())[0])
416
+ job = tp.cast(xm.Job, list(job.jobs.values())[0])
301
417
  elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
302
418
  raise ValueError("Job group must have at least one job")
303
419
 
@@ -315,6 +431,7 @@ class Client:
315
431
 
316
432
  return template.render(
317
433
  job=job_array,
434
+ dependency=dependency,
318
435
  cluster=cluster,
319
436
  args=sequential_args,
320
437
  env_vars=env_vars,
@@ -327,6 +444,7 @@ class Client:
327
444
  env_vars = args.get("env_vars", None)
328
445
  return template.render(
329
446
  job=job,
447
+ dependency=dependency,
330
448
  cluster=cluster,
331
449
  args=sequential_args,
332
450
  env_vars=env_vars,
@@ -347,6 +465,7 @@ class Client:
347
465
  }
348
466
  return template.render(
349
467
  job_group=job_group,
468
+ dependency=dependency,
350
469
  cluster=cluster,
351
470
  args=sequential_args,
352
471
  env_vars=env_vars,
@@ -356,54 +475,66 @@ class Client:
356
475
  case _:
357
476
  raise ValueError(f"Unsupported job type: {type(job)}")
358
477
 
359
- @typing.overload
478
+ @tp.overload
360
479
  async def launch(
361
480
  self,
362
481
  *,
363
482
  cluster: SlurmClusterConfig,
364
483
  job: xm.JobGroup,
365
- args: Mapping[str, JobArgs] | None,
484
+ dependency: dependencies.SlurmJobDependency | None = None,
485
+ args: tp.Mapping[str, JobArgs] | None,
366
486
  experiment_id: int,
367
487
  identity: str | None = ...,
368
488
  ) -> SlurmHandle: ...
369
489
 
370
- @typing.overload
490
+ @tp.overload
371
491
  async def launch(
372
492
  self,
373
493
  *,
374
494
  cluster: SlurmClusterConfig,
375
495
  job: xm.Job,
376
- args: Sequence[JobArgs],
496
+ dependency: dependencies.SlurmJobDependency | None = None,
497
+ args: tp.Sequence[JobArgs],
377
498
  experiment_id: int,
378
499
  identity: str | None = ...,
379
500
  ) -> list[SlurmHandle]: ...
380
501
 
381
- @typing.overload
502
+ @tp.overload
382
503
  async def launch(
383
504
  self,
384
505
  *,
385
506
  cluster: SlurmClusterConfig,
386
507
  job: xm.Job,
508
+ dependency: dependencies.SlurmJobDependency | None = None,
387
509
  args: JobArgs,
388
510
  experiment_id: int,
389
511
  identity: str | None = ...,
390
512
  ) -> SlurmHandle: ...
391
513
 
392
- async def launch(self, *, cluster, job, args, experiment_id, identity=None):
393
- # Construct template
514
+ async def launch(
515
+ self,
516
+ *,
517
+ cluster: SlurmClusterConfig,
518
+ job: xm.Job | xm.JobGroup,
519
+ dependency: dependencies.SlurmJobDependency | None = None,
520
+ args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs | None,
521
+ experiment_id: int,
522
+ identity: str | None = None,
523
+ ):
394
524
  template = await self.template(
395
525
  job=job,
526
+ dependency=dependency,
396
527
  cluster=cluster,
397
528
  args=args,
398
529
  experiment_id=experiment_id,
399
530
  identity=identity,
400
531
  )
401
- logging.debug("Slurm submission script:\n%s", template)
532
+ logger.debug("Slurm submission script:\n%s", template)
402
533
 
403
534
  # Hash submission script
404
535
  template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]
405
536
 
406
- conn = await self.connection(cluster.ssh.connection_options)
537
+ conn = await self.connection(cluster.ssh)
407
538
  async with conn.start_sftp_client() as sftp:
408
539
  # Write the submission script to the cluster
409
540
  # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
@@ -416,7 +547,7 @@ class Client:
416
547
 
417
548
  # Construct and run command on the cluster
418
549
  command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
419
- result = await self.run(cluster.ssh.connection_options, command)
550
+ result = await self.run(cluster.ssh, command)
420
551
  if result.returncode != 0:
421
552
  raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
422
553
 
@@ -435,21 +566,32 @@ class Client:
435
566
  assert job.name is not None
436
567
  return [
437
568
  SlurmHandle(
569
+ experiment_id=experiment_id,
438
570
  ssh=cluster.ssh,
439
- job_id=f"{slurm_job_id}_{array_index}",
571
+ slurm_job=f"{slurm_job_id}_{array_index}",
440
572
  job_name=job.name,
441
573
  )
442
574
  for array_index in range(len(args))
443
575
  ]
444
576
  elif isinstance(job, xm.Job):
445
577
  assert job.name is not None
446
- return SlurmHandle(ssh=cluster.ssh, job_id=slurm_job_id, job_name=job.name)
578
+ return SlurmHandle(
579
+ experiment_id=experiment_id,
580
+ ssh=cluster.ssh,
581
+ slurm_job=slurm_job_id,
582
+ job_name=job.name,
583
+ )
447
584
  elif isinstance(job, xm.JobGroup):
448
585
  # TODO: make this work for actual job groups.
449
- job = mit.one(job.jobs.values())
586
+ job = tp.cast(xm.Job, mit.one(job.jobs.values()))
450
587
  assert isinstance(job, xm.Job)
451
588
  assert job.name is not None
452
- return SlurmHandle(ssh=cluster.ssh, job_id=slurm_job_id, job_name=job.name)
589
+ return SlurmHandle(
590
+ experiment_id=experiment_id,
591
+ ssh=cluster.ssh,
592
+ slurm_job=slurm_job_id,
593
+ job_name=job.name,
594
+ )
453
595
  else:
454
596
  raise ValueError(f"Unsupported job type: {type(job)}")
455
597
 
@@ -458,30 +600,33 @@ class Client:
458
600
  conn.close()
459
601
 
460
602
 
461
- @typing.overload
603
+ @tp.overload
462
604
  async def launch(
463
605
  *,
464
606
  job: xm.JobGroup,
465
- args: Mapping[str, JobArgs],
607
+ dependency: dependencies.SlurmJobDependency | None = None,
608
+ args: tp.Mapping[str, JobArgs],
466
609
  experiment_id: int,
467
610
  identity: str | None = ...,
468
611
  ) -> SlurmHandle: ...
469
612
 
470
613
 
471
- @typing.overload
614
+ @tp.overload
472
615
  async def launch(
473
616
  *,
474
617
  job: xm.Job,
475
- args: Sequence[JobArgs],
618
+ dependency: dependencies.SlurmJobDependency | None = None,
619
+ args: tp.Sequence[JobArgs],
476
620
  experiment_id: int,
477
621
  identity: str | None = ...,
478
622
  ) -> list[SlurmHandle]: ...
479
623
 
480
624
 
481
- @typing.overload
625
+ @tp.overload
482
626
  async def launch(
483
627
  *,
484
628
  job: xm.Job,
629
+ dependency: dependencies.SlurmJobDependency | None = None,
485
630
  args: JobArgs,
486
631
  experiment_id: int,
487
632
  identity: str | None = ...,
@@ -491,7 +636,8 @@ async def launch(
491
636
  async def launch(
492
637
  *,
493
638
  job: xm.Job | xm.JobGroup,
494
- args: Mapping[str, JobArgs] | Sequence[JobArgs] | JobArgs,
639
+ dependency: dependencies.SlurmJobDependency | None = None,
640
+ args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs,
495
641
  experiment_id: int,
496
642
  identity: str | None = None,
497
643
  ) -> SlurmHandle | list[SlurmHandle]:
@@ -503,11 +649,14 @@ async def launch(
503
649
  cluster = job_requirements.cluster
504
650
  if cluster is None:
505
651
  raise ValueError("Job must have a cluster requirement")
652
+ if cluster.validate is not None:
653
+ cluster.validate(job)
506
654
 
507
655
  return await get_client().launch(
508
656
  cluster=cluster,
509
657
  job=job,
510
- args=typing.cast(JobArgs | Sequence[JobArgs], args),
658
+ dependency=dependency,
659
+ args=tp.cast(JobArgs | tp.Sequence[JobArgs], args),
511
660
  experiment_id=experiment_id,
512
661
  identity=identity,
513
662
  )
@@ -521,6 +670,8 @@ async def launch(
521
670
  raise ValueError("Job must have a Slurm executor")
522
671
  if job_item.executor.requirements.cluster is None:
523
672
  raise ValueError("Job must have a cluster requirement")
673
+ if job_item.executor.requirements.cluster.validate is not None:
674
+ job_item.executor.requirements.cluster.validate(job_item)
524
675
  job_group_clusters.add(job_item.executor.requirements.cluster)
525
676
  job_group_executors.add(id(job_item.executor))
526
677
  if len(job_group_executors) != 1:
@@ -531,9 +682,8 @@ async def launch(
531
682
  return await get_client().launch(
532
683
  cluster=job_group_clusters.pop(),
533
684
  job=job_group,
534
- args=typing.cast(Mapping[str, JobArgs], args),
685
+ dependency=dependency,
686
+ args=tp.cast(tp.Mapping[str, JobArgs], args),
535
687
  experiment_id=experiment_id,
536
688
  identity=identity,
537
689
  )
538
- case _:
539
- raise ValueError("Unsupported job type")
xm_slurm/executors.py CHANGED
@@ -48,6 +48,9 @@ class Slurm(xm.Executor):
48
48
  qos: str | None = None
49
49
  priority: int | None = None
50
50
 
51
+ # Job dependency handling
52
+ kill_on_invalid_dependencies: bool = True
53
+
51
54
  # Job rescheduling
52
55
  timeout_signal: signal.Signals = signal.SIGUSR2
53
56
  timeout_signal_grace_period: dt.timedelta = dt.timedelta(seconds=90)
@@ -93,6 +96,11 @@ class Slurm(xm.Executor):
93
96
  minutes, seconds = divmod(remainder, 60)
94
97
  directives.append(f"--time={days}-{hours:02}:{minutes:02}:{seconds:02}")
95
98
 
99
+ # Job dependency handling
100
+ directives.append(
101
+ f"--kill-on-invalid-dep={'yes' if self.kill_on_invalid_dependencies else 'no'}"
102
+ )
103
+
96
104
  # Placement
97
105
  if self.account:
98
106
  directives.append(f"--account={self.account}")
@@ -113,15 +121,3 @@ class Slurm(xm.Executor):
113
121
  directives.append("--no-requeue")
114
122
 
115
123
  return directives
116
-
117
-
118
- class DockerSpec(xm.ExecutorSpec):
119
- """Local Docker executor specification that describes the container runtime."""
120
-
121
-
122
- class Docker(xm.Executor):
123
- """Local Docker executor describing the runtime environment."""
124
-
125
- @classmethod
126
- def Spec(cls) -> DockerSpec:
127
- return DockerSpec()