xmanager-slurm 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Click here for more details.

Files changed (38) hide show
  1. xm_slurm/__init__.py +44 -0
  2. xm_slurm/api.py +261 -0
  3. xm_slurm/batching.py +139 -0
  4. xm_slurm/config.py +162 -0
  5. xm_slurm/console.py +3 -0
  6. xm_slurm/contrib/clusters/__init__.py +52 -0
  7. xm_slurm/contrib/clusters/drac.py +169 -0
  8. xm_slurm/executables.py +201 -0
  9. xm_slurm/execution.py +491 -0
  10. xm_slurm/executors.py +127 -0
  11. xm_slurm/experiment.py +737 -0
  12. xm_slurm/job_blocks.py +14 -0
  13. xm_slurm/packageables.py +292 -0
  14. xm_slurm/packaging/__init__.py +8 -0
  15. xm_slurm/packaging/docker/__init__.py +75 -0
  16. xm_slurm/packaging/docker/abc.py +112 -0
  17. xm_slurm/packaging/docker/cloud.py +503 -0
  18. xm_slurm/packaging/docker/local.py +206 -0
  19. xm_slurm/packaging/registry.py +45 -0
  20. xm_slurm/packaging/router.py +52 -0
  21. xm_slurm/packaging/utils.py +202 -0
  22. xm_slurm/resources.py +150 -0
  23. xm_slurm/status.py +188 -0
  24. xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
  25. xm_slurm/templates/docker/mamba.Dockerfile +27 -0
  26. xm_slurm/templates/docker/pdm.Dockerfile +31 -0
  27. xm_slurm/templates/docker/python.Dockerfile +24 -0
  28. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
  29. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  30. xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
  31. xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
  32. xm_slurm/templates/slurm/job.bash.j2 +78 -0
  33. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
  34. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
  35. xm_slurm/utils.py +69 -0
  36. xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
  37. xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
  38. xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,206 @@
1
+ import base64
2
+ import dataclasses
3
+ import enum
4
+ import json
5
+ import logging
6
+ import os
7
+ import pathlib
8
+ import shlex
9
+ import shutil
10
+ import subprocess
11
+ import tempfile
12
+ import typing
13
+ from typing import Sequence
14
+
15
+ from xmanager import xm
16
+
17
+ from xm_slurm.executables import (
18
+ Dockerfile,
19
+ ImageURI,
20
+ RemoteImage,
21
+ RemoteRepositoryCredentials,
22
+ )
23
+ from xm_slurm.executors import SlurmSpec
24
+ from xm_slurm.packaging import utils as packaging_utils
25
+ from xm_slurm.packaging.docker.abc import (
26
+ DockerBakeCommand,
27
+ DockerClient,
28
+ DockerVersionCommand,
29
+ )
30
+ from xm_slurm.packaging.registry import IndexedContainer
31
+
32
+
33
+ class LocalDockerClient(DockerClient):
34
+ """Build Docker images locally."""
35
+
36
+ class Builder(enum.Enum):
37
+ BUILDKIT = enum.auto()
38
+ BUILDAH = enum.auto()
39
+
40
+ def __init__(self):
41
+ if "XM_DOCKER_CLIENT" in os.environ:
42
+ client_call = shlex.split(os.environ["XM_DOCKER_CLIENT"])
43
+ elif shutil.which("docker"):
44
+ client_call = ["docker"]
45
+ elif shutil.which("podman"):
46
+ client_call = ["podman"]
47
+ else:
48
+ raise RuntimeError("No Docker client found.")
49
+ self._client_call = client_call
50
+
51
+ version_command = DockerVersionCommand()
52
+ backend_version = packaging_utils.run_command(
53
+ xm.merge_args(self._client_call, version_command.to_args()), return_stdout=True
54
+ )
55
+ if backend_version.stdout.startswith("github.com/docker/buildx"):
56
+ self._builder = LocalDockerClient.Builder.BUILDKIT
57
+ else:
58
+ raise NotImplementedError(f"Unsupported Docker build backend: {backend_version}")
59
+
60
+ self._credentials_cache: dict[str, RemoteRepositoryCredentials] = {}
61
+
62
+ def credentials(self, hostname: str) -> RemoteRepositoryCredentials | None:
63
+ """Fetch credentials from the local Docker configuration."""
64
+ if hostname in self._credentials_cache:
65
+ return self._credentials_cache[hostname]
66
+
67
+ def _parse_docker_credentials(helper: str) -> RemoteRepositoryCredentials | None:
68
+ """Parse credentials from a Docker credential helper."""
69
+ if not shutil.which(f"docker-credential-{helper}"):
70
+ return None
71
+ returncode, output = subprocess.getstatusoutput(
72
+ f"echo {hostname} | docker-credential-{helper} get",
73
+ )
74
+
75
+ if returncode == 0:
76
+ credentials = json.loads(output)
77
+ return RemoteRepositoryCredentials(
78
+ username=credentials["Username"], password=credentials["Secret"]
79
+ )
80
+ return None
81
+
82
+ def _parse_credentials_from_config(
83
+ config_path: pathlib.Path
84
+ ) -> RemoteRepositoryCredentials | None:
85
+ """Parse credentials from the Docker configuration file."""
86
+ if not config_path.exists():
87
+ return None
88
+ config = json.loads(config_path.read_text())
89
+
90
+ # Attempt to parse from the global credential store
91
+ if (creds_store := config.get("credsStore", None)) and (
92
+ credentials := _parse_docker_credentials(creds_store)
93
+ ):
94
+ self._credentials_cache[hostname] = credentials
95
+ return credentials
96
+ # Attempt to parse from the credential helper for this registry
97
+ if creds_helper := config.get("credHelpers", {}):
98
+ for registry, helper in creds_helper.items():
99
+ registry = ImageURI(registry)
100
+ if registry.domain == hostname and (
101
+ credentials := _parse_docker_credentials(helper)
102
+ ):
103
+ self._credentials_cache[hostname] = credentials
104
+ return credentials
105
+ # Last resort: attempt to parse raw auth
106
+ if auths := config.get("auths", None):
107
+ for registry, metadata in auths.items():
108
+ registry = ImageURI(registry)
109
+ if registry.domain == hostname:
110
+ auth = base64.b64decode(metadata["auth"]).decode("utf-8")
111
+ username, password = auth.split(":")
112
+ credentials = RemoteRepositoryCredentials(username, password)
113
+ self._credentials_cache[hostname] = credentials
114
+ return credentials
115
+ return None
116
+
117
+ # Attempt to parse credentials from the Docker or Podman configuration
118
+ match self._builder:
119
+ case LocalDockerClient.Builder.BUILDKIT:
120
+ docker_config_path = (
121
+ pathlib.Path(os.environ.get("DOCKER_CONFIG", "~/.docker")).expanduser()
122
+ / "config.json"
123
+ )
124
+ return _parse_credentials_from_config(docker_config_path)
125
+ case LocalDockerClient.Builder.BUILDAH:
126
+ podman_config_path = (
127
+ pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "~/.config")).expanduser()
128
+ / "containers"
129
+ / "auth.json"
130
+ )
131
+ return _parse_credentials_from_config(podman_config_path)
132
+ case _:
133
+ return None
134
+
135
+ def bake(
136
+ self,
137
+ *,
138
+ targets: Sequence[IndexedContainer[xm.Packageable]],
139
+ ) -> list[IndexedContainer[RemoteImage]]:
140
+ executors_by_executables = packaging_utils.collect_executors_by_executable(targets)
141
+ executors_by_executables = typing.cast(
142
+ dict[Dockerfile, list[SlurmSpec]], executors_by_executables
143
+ )
144
+
145
+ with tempfile.TemporaryDirectory() as tempdir:
146
+ hcl_file = pathlib.Path(tempdir) / "docker-bake.hcl"
147
+ metadata_file = pathlib.Path(tempdir) / "metadata.json"
148
+
149
+ # Write HCL and bake it
150
+ hcl = self._bake_template.render(
151
+ executables=executors_by_executables,
152
+ hash=packaging_utils.hash_digest,
153
+ )
154
+ hcl_file.write_text(hcl)
155
+ logging.debug(hcl)
156
+
157
+ try:
158
+ command = DockerBakeCommand(
159
+ targets=list(
160
+ set(
161
+ [
162
+ packaging_utils.hash_digest(target.value.executable_spec)
163
+ for target in targets
164
+ ]
165
+ )
166
+ ),
167
+ files=[hcl_file],
168
+ metadata_file=metadata_file,
169
+ pull=True,
170
+ push=True,
171
+ )
172
+
173
+ bake_command = xm.merge_args(self._client_call, command.to_args())
174
+ packaging_utils.run_command(bake_command.to_list(), tty=True, check=True)
175
+ except Exception as ex:
176
+ raise RuntimeError(f"Failed to build Dockerfiles: {ex}") from ex
177
+ else:
178
+ metadata = json.loads(metadata_file.read_text())
179
+
180
+ images = []
181
+ for target in targets:
182
+ assert isinstance(target.value.executable_spec, Dockerfile)
183
+ assert isinstance(target.value.executor_spec, SlurmSpec)
184
+ assert target.value.executor_spec.tag
185
+
186
+ executable_metadata = metadata[
187
+ packaging_utils.hash_digest(target.value.executable_spec)
188
+ ]
189
+ uri = ImageURI(target.value.executor_spec.tag).with_digest(
190
+ executable_metadata["containerimage.digest"]
191
+ )
192
+
193
+ images.append(
194
+ dataclasses.replace(
195
+ target,
196
+ value=RemoteImage( # type: ignore
197
+ image=str(uri),
198
+ workdir=target.value.executable_spec.workdir,
199
+ args=target.value.args,
200
+ env_vars=target.value.env_vars,
201
+ credentials=self.credentials(uri.domain),
202
+ ),
203
+ )
204
+ )
205
+
206
+ return images
@@ -0,0 +1,45 @@
1
+ import dataclasses
2
+ from typing import Callable, Generic, ParamSpec, Sequence, Type, TypeVar
3
+
4
+ from xmanager import xm
5
+
# Covariant payload type carried by `IndexedContainer`.
T_co = TypeVar("T_co", covariant=True)
# Generic parameter specification.
P = ParamSpec("P")
# Any executable spec type that a packaging handler can be registered for.
ExecutableSpecT = TypeVar("ExecutableSpecT", bound=xm.ExecutableSpec)
9
+
10
+
@dataclasses.dataclass(frozen=True)
class IndexedContainer(Generic[T_co]):
    """Immutable pairing of a value with an integer index."""

    # Integer index associated with the wrapped value.
    index: int
    # The wrapped value itself.
    value: T_co
15
+
16
+
# Signature of a packaging handler: it receives indexed packageables and returns
# indexed executables.
RegistrationCallable = Callable[
    [Sequence[IndexedContainer[xm.Packageable]]],
    Sequence[IndexedContainer[xm.Executable]],
]


# Packaging handlers keyed by the executable spec type they were registered for.
_REGISTRY: dict[Type[xm.ExecutableSpec], RegistrationCallable] = dict()
24
+
25
+
def register(
    *typs: Type[ExecutableSpecT],
) -> Callable[[RegistrationCallable], RegistrationCallable]:
    """Build a decorator that registers a packaging handler for each of `typs`.

    Args:
        *typs: Executable spec types the decorated callable should handle.

    Returns:
        A decorator that stores the callable in the module registry under every
        given type (replacing any earlier entry) and returns it unchanged.
    """

    def decorator(
        registration_callable: RegistrationCallable,
    ) -> RegistrationCallable:
        # The registry is only mutated, never rebound, so no `global` is needed.
        _REGISTRY.update(dict.fromkeys(typs, registration_callable))
        return registration_callable

    return decorator
38
+
39
+
def route(
    typ: Type[ExecutableSpecT],
    packageables: Sequence[IndexedContainer[xm.Packageable]],
) -> Sequence[IndexedContainer[xm.Executable]]:
    """Dispatch `packageables` to the handler registered for `typ`.

    Args:
        typ: Executable spec type used as the registry key.
        packageables: Indexed packageables handed to the handler.

    Returns:
        Whatever the registered handler returns.

    Raises:
        KeyError: If no handler is registered for exactly `typ`.
    """
    handler = _REGISTRY[typ]
    return handler(packageables)
@@ -0,0 +1,52 @@
1
+ import collections
2
+ import logging
3
+ from typing import Sequence, Type
4
+
5
+ from xmanager import xm
6
+
7
+ from xm_slurm.console import console
8
+ from xm_slurm.executors import SlurmSpec
9
+ from xm_slurm.packaging import registry
10
+
# Local alias so this module can refer to the registry's container type directly.
IndexedContainer = registry.IndexedContainer
12
+
13
+
def package(
    packageables: Sequence[xm.Packageable],
) -> list[xm.Executable]:
    """Package every packageable and return the executables in input order.

    Packageables are grouped by the type of their executable spec. Each group is
    handed to the handler registered for that type via `registry.route`, and the
    results are re-ordered by original index.

    Args:
        packageables: Packageables to package; each must use a `SlurmSpec`
            executor spec.

    Returns:
        One executable per packageable, in the same order as `packageables`.

    Raises:
        ValueError: If a packageable's executor spec is not a `SlurmSpec`.
    """
    # Packageables grouped by executable spec type, remembering their input position.
    targets_by_type = collections.defaultdict[
        Type[xm.ExecutableSpec], list[IndexedContainer[xm.Packageable]]
    ](list)

    for index, packageable in enumerate(packageables):
        if not isinstance(packageable.executor_spec, SlurmSpec):
            # Trailing space keeps the two sentences apart once the literals are joined.
            raise ValueError(
                f"Unsupported executor spec for packageable: {packageable}. "
                "xm_slurm only supports `xm_slurm.SlurmSpec`."
            )
        targets_by_type[type(packageable.executable_spec)].append(
            IndexedContainer[xm.Packageable](index, packageable)
        )

    targets: list[IndexedContainer[xm.Executable]] = []
    # TODO(jfarebro): Could make this async as well...?
    with console.status("[magenta] :package: Packaging executables..."):
        for executable_spec_type, targets_for_type in targets_by_type.items():
            logging.info(f"Packaging {len(targets_for_type)} {executable_spec_type!r} targets.")
            targets.extend(registry.route(executable_spec_type, targets_for_type))

    # Pluralise for every count except exactly one (including zero).
    console.print(
        f"[magenta]:package: Finished packaging [bold]{len(targets)} executable"
        f"{'s' if len(targets) != 1 else ''}[/bold]."
    )

    assert len(targets) == len(packageables), "Number of targets must match packageables"
    # Handlers run group by group, so restore the caller's ordering.
    targets = sorted(targets, key=lambda t: t.index)
    return [target.value for target in targets]
@@ -0,0 +1,202 @@
1
+ import collections
2
+ import concurrent.futures
3
+ import functools
4
+ import hashlib
5
+ import logging
6
+ import os
7
+ import pathlib
8
+ import pty
9
+ import re
10
+ import select
11
+ import shutil
12
+ import subprocess
13
+ import typing
14
+ from typing import Callable, Concatenate, Hashable, Literal, ParamSpec, Sequence, TypeVar
15
+
16
+ from xmanager import xm
17
+
18
+ from xm_slurm.packaging.registry import IndexedContainer
19
+
20
+ T = TypeVar("T")
21
+ P = ParamSpec("P")
22
+ ReturnT = TypeVar("ReturnT")
23
+
24
+
def hash_digest(obj: Hashable) -> str:
    """Return the SHA-256 hex digest of `repr(obj)`.

    The digest is derived from the object's `repr`, so it is only as stable as
    that representation.
    """
    representation = repr(obj).encode()
    return hashlib.sha256(representation).hexdigest()
34
+
35
+
def collect_executors_by_executable(
    targets: Sequence[IndexedContainer[xm.Packageable]],
) -> dict[xm.ExecutableSpec, set[xm.ExecutorSpec]]:
    """Group the executor specs of `targets` by their executable spec.

    Args:
        targets: Indexed packageables to group.

    Returns:
        A `defaultdict` mapping each distinct executable spec to the set of
        executor specs it was paired with.
    """
    grouped = collections.defaultdict(set)
    for container in targets:
        packageable = container.value
        grouped[packageable.executable_spec].add(packageable.executor_spec)
    return grouped
43
+
44
+
def parallel_map(
    f: Callable[Concatenate[T, P], ReturnT],
) -> Callable[Concatenate[Sequence[T], P], list[ReturnT]]:
    """Lift `f` so it is applied to every element of a sequence on a thread pool.

    The returned function takes a sequence as its first argument, forwards any
    further arguments unchanged to each call of `f`, and returns the results as
    a list in input order.
    """

    @functools.wraps(f)
    def decorator(sequence: Sequence[T], *args: P.args, **kwargs: P.kwargs) -> list[ReturnT]:
        def call_one(item: T) -> ReturnT:
            return f(item, *args, **kwargs)

        with concurrent.futures.ThreadPoolExecutor() as pool:
            results = pool.map(call_one, sequence)
            return list(results)

    return decorator
54
+
55
+
# Terminal cursor/erase escape sequences that are stripped from a command's
# output stream. Alternatives are joined with "|" into a single bytes pattern.
cursor_commands_regex = re.compile(
    b"|".join(
        (
            rb"\x1b\[\?25[hl]",  # Cursor show/hide (CSI ?25h and CSI ?25l)
            rb"\x1b\[[0-9;]*[Hf]",  # Cursor position (CSI n;mH and CSI n;mf)
            rb"\x1b\[s",  # Cursor save position (CSI s)
            rb"\x1b\[u",  # Cursor restore position (CSI u)
            rb"\x1b\[2J",  # Clear screen (CSI 2J)
            rb"\x1b\[K",  # Clear line (CSI K)
        )
    )
)
65
+
66
+
# The overloads narrow the type of the captured streams on the returned
# `CompletedProcess`: `None` when nothing is captured, `str` otherwise.
@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[False] = False,
    return_stderr: Literal[False] = False,
) -> subprocess.CompletedProcess[None]: ...


@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[True] = True,
    return_stderr: Literal[False] = False,
) -> subprocess.CompletedProcess[str]: ...


@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[False] = False,
    return_stderr: Literal[True] = True,
) -> subprocess.CompletedProcess[str]: ...


@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[True] = True,
    return_stderr: Literal[True] = True,
) -> subprocess.CompletedProcess[str]: ...


def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = None,
    tty: bool = False,
    cwd: str | os.PathLike[str] | None = None,
    check: bool = False,
    return_stdout: bool = False,
    return_stderr: bool = False,
) -> subprocess.CompletedProcess[str] | subprocess.CompletedProcess[None]:
    """Run a command with stdout/stderr attached to pseudo-terminals.

    Cursor-control escape sequences matched by `cursor_commands_regex` are
    stripped from each chunk read from the child (sequences split across two
    reads are not detected).

    Args:
        args: Command and arguments; `args[0]` is resolved through `PATH`.
        env: Extra environment variables layered over `os.environ`.
        tty: Mirror the child's output to this process's stdout/stderr.
        cwd: Working directory for the child.
        check: Raise `subprocess.CalledProcessError` on a non-zero exit code.
        return_stdout: Capture stdout on the returned object.
        return_stderr: Capture stderr on the returned object.

    Returns:
        A `CompletedProcess` whose `stdout`/`stderr` are strings (possibly empty)
        for the streams whose capture was requested, and `None` otherwise.

    Raises:
        ValueError: If `args` is empty.
        RuntimeError: If the executable cannot be found.
        subprocess.CalledProcessError: If `check` is set and the command fails.
    """
    if isinstance(args, xm.SequentialArgs):
        args = args.to_list()
    args = list(args)
    if not args:
        raise ValueError("Cannot run an empty command.")

    executable = shutil.which(args[0])
    if not executable:
        raise RuntimeError(f"Couldn't find executable {args[0]}")
    executable = pathlib.Path(executable)

    subprocess_env = os.environ.copy() | (env if env else {})
    # Slice rather than index so a bare `docker` invocation cannot raise IndexError.
    if executable.name == "docker" and args[1:2] == ["buildx"]:
        subprocess_env |= {"DOCKER_CLI_EXPERIMENTAL": "enabled"}

    logging.debug(f"env: {subprocess_env}")
    logging.debug(f"command: {' '.join(args)}")

    stdout_master, stdout_slave = pty.openpty()
    # Every descriptor still listed here is closed in the `finally` below, so a
    # failure while spawning or reading does not leak PTY descriptors.
    open_fds = [stdout_master, stdout_slave]
    stdout_data, stderr_data = bytearray(), bytearray()
    try:
        stderr_master, stderr_slave = pty.openpty()
        open_fds += [stderr_master, stderr_slave]

        with subprocess.Popen(
            executable=executable,
            args=args,
            shell=False,
            text=True,
            bufsize=0,
            stdout=stdout_slave,
            stderr=stderr_slave,
            start_new_session=True,
            close_fds=True,
            cwd=cwd,
            env=subprocess_env,
        ) as process:
            # The child owns the slave ends now; closing ours lets the master
            # ends report end-of-stream once the child exits.
            for slave_fd in (stdout_slave, stderr_slave):
                open_fds.remove(slave_fd)
                os.close(slave_fd)

            fds = [stdout_master, stderr_master]
            while fds:
                rlist, _, _ = select.select(fds, [], [])
                for fd in rlist:
                    try:
                        data = os.read(fd, 1024)
                    except OSError:
                        # Treated the same as end-of-stream.
                        data = None

                    if not data:
                        fds.remove(fd)
                        open_fds.remove(fd)
                        os.close(fd)
                        continue

                    data = re.sub(cursor_commands_regex, b"", data)

                    if fd == stdout_master:
                        if return_stdout:
                            stdout_data += data
                        if tty:
                            os.write(pty.STDOUT_FILENO, data)
                    elif fd == stderr_master:
                        if return_stderr:
                            stderr_data += data
                        if tty:
                            os.write(pty.STDERR_FILENO, data)
                    else:
                        raise RuntimeError("Unexpected file descriptor")
    finally:
        for fd in open_fds:
            try:
                os.close(fd)
            except OSError:
                pass

    # A requested stream is always a string (empty when nothing was printed), as
    # the overloads promise; an unrequested stream stays None.
    stdout = stdout_data.decode(errors="replace") if return_stdout else None
    stderr = stderr_data.decode(errors="replace") if return_stderr else None

    # Leaving the `with` block has already waited for the child, so this returns
    # the stored exit code immediately.
    retcode = process.wait()

    if check and retcode:
        raise subprocess.CalledProcessError(retcode, process.args, output=stdout, stderr=stderr)
    return subprocess.CompletedProcess(
        process.args,
        retcode,
        stdout=stdout,
        stderr=stderr,
    )
xm_slurm/resources.py ADDED
@@ -0,0 +1,150 @@
1
+ import enum
2
+ import itertools
3
+ import math
4
+ from typing import Mapping
5
+
6
+ import immutabledict
7
+
8
+ from xm_slurm import config
9
+
10
+
class ResourceType(enum.IntEnum):
    """Kinds of resources a job can request.

    Members that share a value are aliases of the first member declared with
    that value (`RAM` for `MEMORY`, `DISK` for `EPHEMERAL_STORAGE`).
    """

    # --- General compute resources ---
    CPU = 1

    MEMORY = 2
    RAM = 2  # alias of MEMORY

    EPHEMERAL_STORAGE = 3
    DISK = 3  # alias of EPHEMERAL_STORAGE

    # --- Accelerators: generic GPU first, then specific models ---
    GPU = 1000
    RTX8000 = 1001
    P4 = 1010

    P100 = 1011
    P100_16GIB = 1012

    V100 = 1020
    V100_32GIB = 1021

    A100 = 1030
    A100_80GIB = 1031
    A5000 = 1032
    A6000 = 1033

    H100 = 1040
36
+
37
+
# Resource types that count as accelerators. Every GPU member of `ResourceType`
# is listed: a model missing from this set is neither recorded as the job's
# accelerator nor turned into a `--gpus-per-task` directive.
AcceleratorType = {
    ResourceType.RTX8000,
    ResourceType.P4,
    ResourceType.P100,
    ResourceType.P100_16GIB,
    ResourceType.V100,
    ResourceType.V100_32GIB,
    ResourceType.A100,
    ResourceType.A100_80GIB,
    ResourceType.A5000,
    ResourceType.A6000,
    ResourceType.H100,
    ResourceType.GPU,
}
48
+
49
+
50
+ ResourceQuantity = int | float
51
+
52
+
53
+ class JobRequirements:
54
+ replicas: int
55
+ location: str | None
56
+ accelerator: ResourceType | None
57
+ cluster: config.SlurmClusterConfig | None = None
58
+
59
+ def __init__(
60
+ self,
61
+ *,
62
+ resources: Mapping[ResourceType | str, ResourceQuantity] = immutabledict.immutabledict(),
63
+ replicas: int = 1,
64
+ location: str | None = None,
65
+ cluster: config.SlurmClusterConfig | None = None,
66
+ **kw_resources: ResourceQuantity,
67
+ ):
68
+ self.replicas = replicas or 1
69
+ self.location = location
70
+ self.accelerator = None
71
+ self.cluster = cluster
72
+
73
+ self.task_requirements: dict[ResourceType | str, ResourceQuantity] = {}
74
+ for resource_name, value in itertools.chain(resources.items(), kw_resources.items()):
75
+ match resource_name:
76
+ case str() if resource_name.upper() in ResourceType.__members__:
77
+ resource = ResourceType[resource_name.upper()]
78
+ case ResourceType():
79
+ resource = resource_name
80
+ case str():
81
+ resource = resource_name
82
+
83
+ if resource in AcceleratorType or (
84
+ isinstance(resource, str) and resource.startswith("gpu")
85
+ ):
86
+ if self.accelerator is not None:
87
+ raise ValueError("Accelerator already set.")
88
+ self.accelerator = resource # type: ignore
89
+
90
+ if resource in self.task_requirements:
91
+ raise ValueError(f"{resource} has been specified twice.")
92
+ self.task_requirements[resource] = value
93
+
94
+ def to_directives(self) -> list[str]:
95
+ directives = []
96
+
97
+ for resource, value in self.task_requirements.items():
98
+ match resource:
99
+ case ResourceType.EPHEMERAL_STORAGE | ResourceType.DISK:
100
+ assert isinstance(value, int), "Disk space must be an integer"
101
+ directives.append(f"--tmp={math.ceil(value / 2**20)}M")
102
+ case ResourceType.MEMORY | ResourceType.RAM:
103
+ num_cpus = self.task_requirements.get(ResourceType.CPU, 1)
104
+ assert isinstance(value, (int, float)), "Memory must be an integer or float"
105
+ assert isinstance(num_cpus, int), "CPU must be an integer"
106
+ directives.append(f"--mem-per-cpu={math.ceil(value / num_cpus / 2**20)}M")
107
+ case ResourceType.CPU:
108
+ assert isinstance(value, int), "CPU must be an integer"
109
+ directives.append(f"--cpus-per-task={value}")
110
+ case ResourceType.GPU:
111
+ assert isinstance(value, int), "GPU must be an integer"
112
+ directives.append(f"--gpus-per-task={value}")
113
+ case ResourceType() if resource in AcceleratorType:
114
+ assert isinstance(value, int), "Accelerator must be an integer"
115
+ directives.append(f"--gpus-per-task={resource.name.lower()}:{value}")
116
+ case str():
117
+ directives.append(f"--gres={resource}:{value}")
118
+
119
+ directives.append(f"--ntasks={self.replicas}")
120
+ if self.location:
121
+ directives.append(f"--nodelist={self.location}")
122
+
123
+ return directives
124
+
125
+ def replace(
126
+ self,
127
+ cluster: config.SlurmClusterConfig | None,
128
+ **kw_resources: ResourceQuantity,
129
+ ) -> "JobRequirements":
130
+ return JobRequirements(
131
+ resources=self.task_requirements | kw_resources, # type: ignore
132
+ replicas=self.replicas,
133
+ cluster=cluster or self.cluster,
134
+ )
135
+
136
+ def __repr__(self) -> str:
137
+ args = []
138
+
139
+ for resource, value in self.task_requirements.items():
140
+ if isinstance(resource, ResourceType):
141
+ resource = resource.name
142
+ args.append(f"{resource.lower()}={value!r}")
143
+
144
+ if self.replicas != 1:
145
+ args.append(f"replicas={self.replicas}")
146
+
147
+ if self.cluster is not None:
148
+ args.append(f"cluster={self.cluster!r}")
149
+
150
+ return f'xm_slurm.JobRequirements({", ".join(args)})'