xmanager-slurm 0.3.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xmanager-slurm might be problematic.

Files changed (38)
  1. xm_slurm/__init__.py +44 -0
  2. xm_slurm/api.py +261 -0
  3. xm_slurm/batching.py +139 -0
  4. xm_slurm/config.py +162 -0
  5. xm_slurm/console.py +3 -0
  6. xm_slurm/contrib/clusters/__init__.py +52 -0
  7. xm_slurm/contrib/clusters/drac.py +169 -0
  8. xm_slurm/executables.py +201 -0
  9. xm_slurm/execution.py +491 -0
  10. xm_slurm/executors.py +127 -0
  11. xm_slurm/experiment.py +737 -0
  12. xm_slurm/job_blocks.py +14 -0
  13. xm_slurm/packageables.py +292 -0
  14. xm_slurm/packaging/__init__.py +8 -0
  15. xm_slurm/packaging/docker/__init__.py +75 -0
  16. xm_slurm/packaging/docker/abc.py +112 -0
  17. xm_slurm/packaging/docker/cloud.py +503 -0
  18. xm_slurm/packaging/docker/local.py +206 -0
  19. xm_slurm/packaging/registry.py +45 -0
  20. xm_slurm/packaging/router.py +52 -0
  21. xm_slurm/packaging/utils.py +202 -0
  22. xm_slurm/resources.py +150 -0
  23. xm_slurm/status.py +188 -0
  24. xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
  25. xm_slurm/templates/docker/mamba.Dockerfile +27 -0
  26. xm_slurm/templates/docker/pdm.Dockerfile +31 -0
  27. xm_slurm/templates/docker/python.Dockerfile +24 -0
  28. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
  29. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  30. xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
  31. xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
  32. xm_slurm/templates/slurm/job.bash.j2 +78 -0
  33. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
  34. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
  35. xm_slurm/utils.py +69 -0
  36. xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
  37. xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
  38. xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
xm_slurm/execution.py ADDED
@@ -0,0 +1,491 @@
import asyncio
import collections.abc
import dataclasses
import functools
import hashlib
import logging
import operator
import re
import shlex
import typing
from typing import Any, Mapping, Sequence

import asyncssh
import backoff
import jinja2 as j2
from asyncssh.auth import KbdIntPrompts, KbdIntResponse
from asyncssh.misc import MaybeAwait
from xmanager import xm

from xm_slurm import batching, config, executors, status
from xm_slurm.console import console

SlurmClusterConfig = config.SlurmClusterConfig
ContainerRuntime = config.ContainerRuntime

"""
=== Runtime Configurations ===
With RunC:
skopeo copy --dest-creds=<username>:<secret> docker://<image>@<digest> oci:<image>:<digest>

pushd $SLURM_TMPDIR

umoci raw unpack --rootless --image <image>:<digest> bundle/<digest>
umoci raw runtime-config --image <image>:<digest> bundle/<digest>/config.json

runc run -b bundle/<digest> <container-id>

With Singularity / Apptainer:

apptainer build --fix-perms --sandbox <digest> docker://<image>@<digest>
apptainer run --compat <digest>
"""

"""
#SBATCH --error=/dev/null
#SBATCH --output=/dev/null
"""

_POLL_INTERVAL = 30.0
_BATCHED_BATCH_SIZE = 16
_BATCHED_TIMEOUT = 0.2


class SlurmExecutionError(Exception): ...


class NoKBAuthSSHClient(asyncssh.SSHClient):
    """SSHClient that does not prompt for keyboard-interactive authentication."""

    def kbdint_auth_requested(self) -> MaybeAwait[str | None]:
        return ""

    def kbdint_challenge_received(
        self, name: str, instructions: str, lang: str, prompts: KbdIntPrompts
    ) -> MaybeAwait[KbdIntResponse | None]:
        del name, instructions, lang, prompts
        return []


def _group_by_ssh_options(
    ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
) -> dict[asyncssh.SSHClientConnectionOptions, list[str]]:
    jobs_by_cluster = collections.defaultdict(list)
    for options, job_id in zip(ssh_options, job_ids):
        jobs_by_cluster[options].append(job_id)
    return jobs_by_cluster


class _BatchedSlurmHandle:
    @functools.partial(
        batching.batch,
        max_batch_size=_BATCHED_BATCH_SIZE,
        batch_timeout=_BATCHED_TIMEOUT,
    )
    @staticmethod
    async def _batched_get_state(
        ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
    ) -> Sequence[status.SlurmJobState]:
        async def _get_state(
            options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
        ) -> Sequence[status.SlurmJobState]:
            result = await get_client().run(
                options,
                [
                    "sacct",
                    "--jobs",
                    ",".join(job_ids),
                    "--format",
                    "JobID,State",
                    "--allocations",
                    "--noheader",
                    "--parsable2",
                ],
                check=True,
            )

            assert isinstance(result.stdout, str)
            states_by_job_id = {}
            for line in result.stdout.splitlines():
                job_id, state = line.split("|")
                states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)

            job_states = []
            for job_id in job_ids:
                if job_id in states_by_job_id:
                    job_states.append(states_by_job_id[job_id])
                # This is a stupid hack around sacct's inability to display state information for
                # array job elements that haven't begun. We'll assume that if the job ID is not found,
                # and it's an array job, then it's pending.
                elif re.match(r"^\d+_\d+$", job_id) is not None:
                    job_states.append(status.SlurmJobState.PENDING)
                else:
                    raise SlurmExecutionError(f"Failed to find job state info for {job_id}")
            return job_states

        jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)

        job_states_per_cluster = await asyncio.gather(*[
            _get_state(options, job_ids) for options, job_ids in jobs_by_cluster.items()
        ])
        job_states_by_cluster: dict[
            asyncssh.SSHClientConnectionOptions, dict[str, status.SlurmJobState]
        ] = {}
        for options, job_states in zip(ssh_options, job_states_per_cluster):
            job_states_by_cluster[options] = dict(zip(jobs_by_cluster[options], job_states))

        job_states = []
        for options, job_id in zip(ssh_options, job_ids):
            job_states.append(job_states_by_cluster[options][job_id])
        return job_states

    @functools.partial(
        batching.batch,
        max_batch_size=_BATCHED_BATCH_SIZE,
        batch_timeout=_BATCHED_TIMEOUT,
    )
    @staticmethod
    async def _batched_cancel(
        ssh_options: Sequence[asyncssh.SSHClientConnectionOptions], job_ids: Sequence[str]
    ) -> Sequence[None]:
        async def _cancel(
            options: asyncssh.SSHClientConnectionOptions, job_ids: Sequence[str]
        ) -> None:
            await get_client().run(options, ["scancel", " ".join(job_ids)], check=True)

        jobs_by_cluster = _group_by_ssh_options(ssh_options, job_ids)
        return await asyncio.gather(*[
            _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
        ])


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmHandle(_BatchedSlurmHandle):
    """A handle for referring to the launched container."""

    ssh_connection_options: asyncssh.SSHClientConnectionOptions
    job_id: str

    def __post_init__(self):
        if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
            raise ValueError(f"Invalid job ID: {self.job_id}")

    @backoff.on_predicate(
        backoff.constant,
        lambda state: state in status.SlurmActiveJobStates,
        jitter=None,
        interval=_POLL_INTERVAL,
    )
    async def wait(self) -> status.SlurmJobState:
        return await self.get_state()

    async def stop(self) -> None:
        await self._batched_cancel(self.ssh_connection_options, self.job_id)

    async def get_state(self) -> status.SlurmJobState:
        return await self._batched_get_state(self.ssh_connection_options, self.job_id)


@functools.cache
def get_client() -> "Client":
    return Client()


@functools.cache
def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
    template_loader = j2.PackageLoader("xm_slurm", "templates/slurm")
    template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)

    def _raise_template_exception(msg: str) -> None:
        raise j2.TemplateRuntimeError(msg)

    template_env.globals["raise"] = _raise_template_exception
    template_env.globals["operator"] = operator

    match container_runtime:
        case ContainerRuntime.SINGULARITY | ContainerRuntime.APPTAINER:
            runtime_template = template_env.get_template("runtimes/apptainer.bash.j2")
        case ContainerRuntime.PODMAN:
            runtime_template = template_env.get_template("runtimes/podman.bash.j2")
        case _:
            raise NotImplementedError
    # Update our global env with the runtime template's exported globals
    template_env.globals.update(runtime_template.module.__dict__)

    return template_env


class Client:
    def __init__(self):
        self._connections: dict[
            asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
        ] = {}
        self._connection_lock = asyncio.Lock()

    @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
    async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
        # Make sure the xm-slurm state directory exists
        await conn.run("mkdir -p ~/.local/state/xm-slurm", check=True)

    async def connection(
        self,
        options: asyncssh.SSHClientConnectionOptions,
    ) -> asyncssh.SSHClientConnection:
        if options not in self._connections:
            async with self._connection_lock:
                try:
                    conn, _ = await asyncssh.create_connection(NoKBAuthSSHClient, options=options)
                    await self._setup_remote_connection(conn)
                    self._connections[options] = conn
                except asyncssh.misc.PermissionDenied as ex:
                    raise RuntimeError(f"Permission denied connecting to {options.host}") from ex
        return self._connections[options]

    @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
    async def run(
        self,
        options: asyncssh.SSHClientConnectionOptions,
        command: xm.SequentialArgs | str | Sequence[str],
        *,
        check: bool = False,
        timeout: float | None = None,
    ) -> asyncssh.SSHCompletedProcess:
        client = await self.connection(options)
        if isinstance(command, xm.SequentialArgs):
            command = command.to_list()
        if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
            command = shlex.join(command)
        assert isinstance(command, str)
        logging.debug("Running command on %s: %s", options.host, command)

        return await client.run(command, check=check, timeout=timeout)

    async def template(
        self,
        *,
        job: xm.Job | xm.JobGroup,
        cluster: SlurmClusterConfig,
        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
        experiment_id: int,
        identity: str | None,
    ) -> str:
        if args is None:
            args = {}

        template_env = get_template_env(cluster.runtime)

        # Sanitize job groups
        if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
            job = typing.cast(xm.Job, list(job.jobs.values())[0])
        elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
            raise ValueError("Job group must have at least one job")

        match job:
            case xm.Job() as job_array if isinstance(args, collections.abc.Sequence):
                template = template_env.get_template("job-array.bash.j2")
                sequential_args = [
                    xm.SequentialArgs.from_collection(trial.get("args", None)) for trial in args
                ]
                env_vars = [trial.get("env_vars") for trial in args]
                if any(env_vars):
                    raise NotImplementedError(
                        "Job arrays over environment variables are not yet supported."
                    )

                return template.render(
                    job=job_array,
                    cluster=cluster,
                    args=sequential_args,
                    env_vars=env_vars,
                    experiment_id=experiment_id,
                    identity=identity,
                )
            case xm.Job() if isinstance(args, collections.abc.Mapping):
                template = template_env.get_template("job.bash.j2")
                sequential_args = xm.SequentialArgs.from_collection(args.get("args", None))
                env_vars = args.get("env_vars", None)
                return template.render(
                    job=job,
                    cluster=cluster,
                    args=sequential_args,
                    env_vars=env_vars,
                    experiment_id=experiment_id,
                    identity=identity,
                )
            case xm.JobGroup() as job_group if isinstance(args, collections.abc.Mapping):
                template = template_env.get_template("job-group.bash.j2")
                sequential_args = {
                    job_name: {
                        "args": args.get(job_name, {}).get("args", None),
                    }
                    for job_name in job_group.jobs.keys()
                }
                env_vars = {
                    job_name: args.get(job_name, {}).get("env_vars", None)
                    for job_name in job_group.jobs.keys()
                }
                return template.render(
                    job_group=job_group,
                    cluster=cluster,
                    args=sequential_args,
                    env_vars=env_vars,
                    experiment_id=experiment_id,
                    identity=identity,
                )
            case _:
                raise ValueError(f"Unsupported job type: {type(job)}")

    @typing.overload
    async def launch(
        self,
        *,
        cluster: SlurmClusterConfig,
        job: xm.Job | xm.JobGroup,
        args: Mapping[str, Any] | None,
        experiment_id: int,
        identity: str | None = ...,
    ) -> SlurmHandle: ...

    @typing.overload
    async def launch(
        self,
        *,
        cluster: SlurmClusterConfig,
        job: xm.Job | xm.JobGroup,
        args: Sequence[Mapping[str, Any]],
        experiment_id: int,
        identity: str | None = ...,
    ) -> Sequence[SlurmHandle]: ...

    async def launch(
        self,
        *,
        cluster: SlurmClusterConfig,
        job: xm.Job | xm.JobGroup,
        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
        experiment_id: int,
        identity: str | None = None,
    ) -> SlurmHandle | Sequence[SlurmHandle]:
        # Construct template
        template = await self.template(
            job=job,
            cluster=cluster,
            args=args,
            experiment_id=experiment_id,
            identity=identity,
        )
        logging.debug("Slurm submission script:\n%s", template)

        # Hash submission script
        template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]

        conn = await self.connection(cluster.ssh_connection_options)
        async with conn.start_sftp_client() as sftp:
            # Write the submission script to the cluster
            # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
            # INSTEAD OF ASSUMING SFTP PUTS US IN THE HOME DIRECTORY
            await sftp.makedirs(f".local/state/xm-slurm/{experiment_id}", exist_ok=True)
            async with sftp.open(
                f".local/state/xm-slurm/{experiment_id}/submission-script-{template_hash}.sh", "w"
            ) as fp:
                await fp.write(template)

        # Construct and run command on the cluster
        command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
        result = await self.run(cluster.ssh_connection_options, command)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to schedule job on {cluster.host}: {result.stderr}")

        assert isinstance(result.stdout, str)
        slurm_job_id, *_ = result.stdout.split(",")
        slurm_job_id = slurm_job_id.strip()

        console.log(
            f"[magenta]:rocket: Job [cyan]{slurm_job_id}[/cyan] will be launched on "
            f"[cyan]{cluster.name}[/cyan] "
        )

        if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
            return [
                SlurmHandle(
                    ssh_connection_options=cluster.ssh_connection_options,
                    job_id=f"{slurm_job_id}_{array_index}",
                )
                for array_index in range(len(args))
            ]

        return SlurmHandle(
            ssh_connection_options=cluster.ssh_connection_options,
            job_id=slurm_job_id,
        )


@typing.overload
async def launch(
    *,
    job: xm.Job | xm.JobGroup,
    args: Mapping[str, Any],
    experiment_id: int,
    identity: str | None = ...,
) -> SlurmHandle: ...


@typing.overload
async def launch(
    *,
    job: xm.Job | xm.JobGroup,
    args: Sequence[Mapping[str, Any]],
    experiment_id: int,
    identity: str | None = ...,
) -> Sequence[SlurmHandle]: ...


async def launch(
    *,
    job: xm.Job | xm.JobGroup,
    args: Mapping[str, Any] | Sequence[Mapping[str, Any]],
    experiment_id: int,
    identity: str | None = None,
) -> SlurmHandle | Sequence[SlurmHandle]:
    match job:
        case xm.Job():
            if not isinstance(job.executor, executors.Slurm):
                raise ValueError("Job must have a Slurm executor")
            job_requirements = job.executor.requirements
            cluster = job_requirements.cluster
            if cluster is None:
                raise ValueError("Job must have a cluster requirement")

            return await get_client().launch(
                cluster=cluster,
                job=job,
                args=args,
                experiment_id=experiment_id,
                identity=identity,
            )
        case xm.JobGroup() as job_group:
            job_group_executors = set()
            job_group_clusters = set()
            for job_item in job_group.jobs.values():
                if not isinstance(job_item, xm.Job):
                    raise ValueError("Job group must contain only jobs")
                if not isinstance(job_item.executor, executors.Slurm):
                    raise ValueError("Job must have a Slurm executor")
                if job_item.executor.requirements.cluster is None:
                    raise ValueError("Job must have a cluster requirement")
                job_group_clusters.add(job_item.executor.requirements.cluster)
                job_group_executors.add(id(job_item.executor))
            if len(job_group_executors) != 1:
                raise ValueError("Job group must have the same executor for all jobs")
            if len(job_group_clusters) != 1:
                raise ValueError("Job group must have the same cluster for all jobs")

            return await get_client().launch(
                cluster=job_group_clusters.pop(),
                job=job,
                args=args,
                experiment_id=experiment_id,
                identity=identity,
            )
        case _:
            raise ValueError("Unsupported job type")
xm_slurm/executors.py ADDED
@@ -0,0 +1,127 @@
import dataclasses
import datetime as dt
import signal

from xmanager import xm

from xm_slurm import resources


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmSpec(xm.ExecutorSpec):
    """Slurm executor specification that describes the location of the container runtime.

    Args:
        tag: The Image URI to push and pull the container image from.
            For example, using the GitHub Container Registry: `ghcr.io/my-project/my-image:latest`.
    """

    tag: str | None = None


@dataclasses.dataclass(frozen=True, kw_only=True)
class Slurm(xm.Executor):
    """Slurm Executor describing the runtime environment.

    Args:
        requirements: The requirements for the job.
        time: The maximum time to run the job.
        account: The account to charge the job to.
        partition: The partition to run the job in.
        qos: The quality of service to run the job with.
        priority: The priority of the job.
        timeout_signal: The signal to send to the job when it runs out of time.
        timeout_signal_grace_period: The time to wait before sending `timeout_signal`.
        requeue: Whether or not the job is eligible for requeueing.
        requeue_on_exit_code: The exit code that triggers requeueing.
        requeue_max_attempts: The maximum number of times to attempt requeueing.

    """

    # Job requirements
    requirements: resources.JobRequirements
    time: dt.timedelta

    # Placement
    account: str | None = None
    partition: str | None = None
    qos: str | None = None
    priority: int | None = None

    # Job rescheduling
    timeout_signal: signal.Signals = signal.SIGUSR2
    timeout_signal_grace_period: dt.timedelta = dt.timedelta(seconds=90)

    requeue: bool = True  # Is this job eligible for requeueing?
    requeue_on_exit_code: int = 42  # The exit code that triggers requeueing
    requeue_max_attempts: int = 5  # How many times to attempt requeueing

    def __post_init__(self) -> None:
        if not isinstance(self.time, dt.timedelta):
            raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
        if not isinstance(self.requirements, resources.JobRequirements):
            raise TypeError(
                f"requirements must be a `xm_slurm.JobRequirements`, got {type(self.requirements)}. "
                "If you're still using `xm.JobRequirements`, please update to `xm_slurm.JobRequirements`."
            )
        if not isinstance(self.timeout_signal, signal.Signals):
            raise TypeError(
                f"termination_signal must be a `signal.Signals`, got {type(self.timeout_signal)}"
            )
        if not isinstance(self.timeout_signal_grace_period, dt.timedelta):
            raise TypeError(
                f"termination_signal_delay_time must be a `datetime.timedelta`, got {type(self.timeout_signal_grace_period)}"
            )
        if self.requeue_max_attempts < 0:
            raise ValueError(
                f"requeue_max_attempts must be greater than or equal to 0, got {self.requeue_max_attempts}"
            )
        if self.requeue_on_exit_code == 0:
            raise ValueError("requeue_on_exit_code should not be 0 to avoid unexpected behavior.")

    @classmethod
    def Spec(cls, tag: str | None = None) -> SlurmSpec:
        return SlurmSpec(tag=tag)

    def to_directives(self) -> list[str]:
        # Job requirements
        directives = self.requirements.to_directives()

        # Time
        days = self.time.days
        hours, remainder = divmod(self.time.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        directives.append(f"--time={days}-{hours:02}:{minutes:02}:{seconds:02}")

        # Placement
        if self.account:
            directives.append(f"--account={self.account}")
        if self.partition:
            directives.append(f"--partition={self.partition}")
        if self.qos:
            directives.append(f"--qos={self.qos}")
        if self.priority:
            directives.append(f"--priority={self.priority}")

        # Job rescheduling
        directives.append(
            f"--signal={self.timeout_signal.name.removeprefix('SIG')}@{self.timeout_signal_grace_period.seconds}"
        )
        if self.requeue and self.requeue_max_attempts > 0:
            directives.append("--requeue")
        else:
            directives.append("--no-requeue")

        return directives


class DockerSpec(xm.ExecutorSpec):
    """Local Docker executor specification that describes the container runtime."""


class Docker(xm.Executor):
    """Local Docker executor describing the runtime environment."""

    @classmethod
    def Spec(cls) -> DockerSpec:
        return DockerSpec()
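
For a concrete sense of the sbatch directives this executor emits, here is a minimal standalone sketch of the `--time` and `--signal` formatting used by `to_directives()` above. Only the formatting expressions are reproduced; constructing `Slurm` itself requires the package's `resources.JobRequirements`, which is defined in xm_slurm/resources.py and not shown in this diff.

import datetime as dt
import signal

time = dt.timedelta(days=1, hours=2, minutes=30)
timeout_signal = signal.SIGUSR2
grace_period = dt.timedelta(seconds=90)

# Slurm expects D-HH:MM:SS for --time.
days = time.days
hours, remainder = divmod(time.seconds, 3600)
minutes, seconds = divmod(remainder, 60)
print(f"--time={days}-{hours:02}:{minutes:02}:{seconds:02}")
# -> --time=1-02:30:00

# The signal name is stripped of its "SIG" prefix and paired with the grace period in seconds.
print(f"--signal={timeout_signal.name.removeprefix('SIG')}@{grace_period.seconds}")
# -> --signal=USR2@90
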