xmanager-slurm 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic. Click here for more details.
- xm_slurm/__init__.py +44 -0
- xm_slurm/api.py +261 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +162 -0
- xm_slurm/console.py +3 -0
- xm_slurm/contrib/clusters/__init__.py +52 -0
- xm_slurm/contrib/clusters/drac.py +169 -0
- xm_slurm/executables.py +201 -0
- xm_slurm/execution.py +491 -0
- xm_slurm/executors.py +127 -0
- xm_slurm/experiment.py +737 -0
- xm_slurm/job_blocks.py +14 -0
- xm_slurm/packageables.py +292 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker/__init__.py +75 -0
- xm_slurm/packaging/docker/abc.py +112 -0
- xm_slurm/packaging/docker/cloud.py +503 -0
- xm_slurm/packaging/docker/local.py +206 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +52 -0
- xm_slurm/packaging/utils.py +202 -0
- xm_slurm/resources.py +150 -0
- xm_slurm/status.py +188 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
- xm_slurm/templates/docker/mamba.Dockerfile +27 -0
- xm_slurm/templates/docker/pdm.Dockerfile +31 -0
- xm_slurm/templates/docker/python.Dockerfile +24 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
- xm_slurm/templates/slurm/job.bash.j2 +78 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
- xm_slurm/utils.py +69 -0
- xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
- xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
- xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import dataclasses
|
|
3
|
+
import enum
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import pathlib
|
|
8
|
+
import shlex
|
|
9
|
+
import shutil
|
|
10
|
+
import subprocess
|
|
11
|
+
import tempfile
|
|
12
|
+
import typing
|
|
13
|
+
from typing import Sequence
|
|
14
|
+
|
|
15
|
+
from xmanager import xm
|
|
16
|
+
|
|
17
|
+
from xm_slurm.executables import (
|
|
18
|
+
Dockerfile,
|
|
19
|
+
ImageURI,
|
|
20
|
+
RemoteImage,
|
|
21
|
+
RemoteRepositoryCredentials,
|
|
22
|
+
)
|
|
23
|
+
from xm_slurm.executors import SlurmSpec
|
|
24
|
+
from xm_slurm.packaging import utils as packaging_utils
|
|
25
|
+
from xm_slurm.packaging.docker.abc import (
|
|
26
|
+
DockerBakeCommand,
|
|
27
|
+
DockerClient,
|
|
28
|
+
DockerVersionCommand,
|
|
29
|
+
)
|
|
30
|
+
from xm_slurm.packaging.registry import IndexedContainer
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LocalDockerClient(DockerClient):
|
|
34
|
+
"""Build Docker images locally."""
|
|
35
|
+
|
|
36
|
+
class Builder(enum.Enum):
|
|
37
|
+
BUILDKIT = enum.auto()
|
|
38
|
+
BUILDAH = enum.auto()
|
|
39
|
+
|
|
40
|
+
def __init__(self):
|
|
41
|
+
if "XM_DOCKER_CLIENT" in os.environ:
|
|
42
|
+
client_call = shlex.split(os.environ["XM_DOCKER_CLIENT"])
|
|
43
|
+
elif shutil.which("docker"):
|
|
44
|
+
client_call = ["docker"]
|
|
45
|
+
elif shutil.which("podman"):
|
|
46
|
+
client_call = ["podman"]
|
|
47
|
+
else:
|
|
48
|
+
raise RuntimeError("No Docker client found.")
|
|
49
|
+
self._client_call = client_call
|
|
50
|
+
|
|
51
|
+
version_command = DockerVersionCommand()
|
|
52
|
+
backend_version = packaging_utils.run_command(
|
|
53
|
+
xm.merge_args(self._client_call, version_command.to_args()), return_stdout=True
|
|
54
|
+
)
|
|
55
|
+
if backend_version.stdout.startswith("github.com/docker/buildx"):
|
|
56
|
+
self._builder = LocalDockerClient.Builder.BUILDKIT
|
|
57
|
+
else:
|
|
58
|
+
raise NotImplementedError(f"Unsupported Docker build backend: {backend_version}")
|
|
59
|
+
|
|
60
|
+
self._credentials_cache: dict[str, RemoteRepositoryCredentials] = {}
|
|
61
|
+
|
|
62
|
+
def credentials(self, hostname: str) -> RemoteRepositoryCredentials | None:
|
|
63
|
+
"""Fetch credentials from the local Docker configuration."""
|
|
64
|
+
if hostname in self._credentials_cache:
|
|
65
|
+
return self._credentials_cache[hostname]
|
|
66
|
+
|
|
67
|
+
def _parse_docker_credentials(helper: str) -> RemoteRepositoryCredentials | None:
|
|
68
|
+
"""Parse credentials from a Docker credential helper."""
|
|
69
|
+
if not shutil.which(f"docker-credential-{helper}"):
|
|
70
|
+
return None
|
|
71
|
+
returncode, output = subprocess.getstatusoutput(
|
|
72
|
+
f"echo {hostname} | docker-credential-{helper} get",
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
if returncode == 0:
|
|
76
|
+
credentials = json.loads(output)
|
|
77
|
+
return RemoteRepositoryCredentials(
|
|
78
|
+
username=credentials["Username"], password=credentials["Secret"]
|
|
79
|
+
)
|
|
80
|
+
return None
|
|
81
|
+
|
|
82
|
+
def _parse_credentials_from_config(
|
|
83
|
+
config_path: pathlib.Path
|
|
84
|
+
) -> RemoteRepositoryCredentials | None:
|
|
85
|
+
"""Parse credentials from the Docker configuration file."""
|
|
86
|
+
if not config_path.exists():
|
|
87
|
+
return None
|
|
88
|
+
config = json.loads(config_path.read_text())
|
|
89
|
+
|
|
90
|
+
# Attempt to parse from the global credential store
|
|
91
|
+
if (creds_store := config.get("credsStore", None)) and (
|
|
92
|
+
credentials := _parse_docker_credentials(creds_store)
|
|
93
|
+
):
|
|
94
|
+
self._credentials_cache[hostname] = credentials
|
|
95
|
+
return credentials
|
|
96
|
+
# Attempt to parse from the credential helper for this registry
|
|
97
|
+
if creds_helper := config.get("credHelpers", {}):
|
|
98
|
+
for registry, helper in creds_helper.items():
|
|
99
|
+
registry = ImageURI(registry)
|
|
100
|
+
if registry.domain == hostname and (
|
|
101
|
+
credentials := _parse_docker_credentials(helper)
|
|
102
|
+
):
|
|
103
|
+
self._credentials_cache[hostname] = credentials
|
|
104
|
+
return credentials
|
|
105
|
+
# Last resort: attempt to parse raw auth
|
|
106
|
+
if auths := config.get("auths", None):
|
|
107
|
+
for registry, metadata in auths.items():
|
|
108
|
+
registry = ImageURI(registry)
|
|
109
|
+
if registry.domain == hostname:
|
|
110
|
+
auth = base64.b64decode(metadata["auth"]).decode("utf-8")
|
|
111
|
+
username, password = auth.split(":")
|
|
112
|
+
credentials = RemoteRepositoryCredentials(username, password)
|
|
113
|
+
self._credentials_cache[hostname] = credentials
|
|
114
|
+
return credentials
|
|
115
|
+
return None
|
|
116
|
+
|
|
117
|
+
# Attempt to parse credentials from the Docker or Podman configuration
|
|
118
|
+
match self._builder:
|
|
119
|
+
case LocalDockerClient.Builder.BUILDKIT:
|
|
120
|
+
docker_config_path = (
|
|
121
|
+
pathlib.Path(os.environ.get("DOCKER_CONFIG", "~/.docker")).expanduser()
|
|
122
|
+
/ "config.json"
|
|
123
|
+
)
|
|
124
|
+
return _parse_credentials_from_config(docker_config_path)
|
|
125
|
+
case LocalDockerClient.Builder.BUILDAH:
|
|
126
|
+
podman_config_path = (
|
|
127
|
+
pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "~/.config")).expanduser()
|
|
128
|
+
/ "containers"
|
|
129
|
+
/ "auth.json"
|
|
130
|
+
)
|
|
131
|
+
return _parse_credentials_from_config(podman_config_path)
|
|
132
|
+
case _:
|
|
133
|
+
return None
|
|
134
|
+
|
|
135
|
+
def bake(
|
|
136
|
+
self,
|
|
137
|
+
*,
|
|
138
|
+
targets: Sequence[IndexedContainer[xm.Packageable]],
|
|
139
|
+
) -> list[IndexedContainer[RemoteImage]]:
|
|
140
|
+
executors_by_executables = packaging_utils.collect_executors_by_executable(targets)
|
|
141
|
+
executors_by_executables = typing.cast(
|
|
142
|
+
dict[Dockerfile, list[SlurmSpec]], executors_by_executables
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
with tempfile.TemporaryDirectory() as tempdir:
|
|
146
|
+
hcl_file = pathlib.Path(tempdir) / "docker-bake.hcl"
|
|
147
|
+
metadata_file = pathlib.Path(tempdir) / "metadata.json"
|
|
148
|
+
|
|
149
|
+
# Write HCL and bake it
|
|
150
|
+
hcl = self._bake_template.render(
|
|
151
|
+
executables=executors_by_executables,
|
|
152
|
+
hash=packaging_utils.hash_digest,
|
|
153
|
+
)
|
|
154
|
+
hcl_file.write_text(hcl)
|
|
155
|
+
logging.debug(hcl)
|
|
156
|
+
|
|
157
|
+
try:
|
|
158
|
+
command = DockerBakeCommand(
|
|
159
|
+
targets=list(
|
|
160
|
+
set(
|
|
161
|
+
[
|
|
162
|
+
packaging_utils.hash_digest(target.value.executable_spec)
|
|
163
|
+
for target in targets
|
|
164
|
+
]
|
|
165
|
+
)
|
|
166
|
+
),
|
|
167
|
+
files=[hcl_file],
|
|
168
|
+
metadata_file=metadata_file,
|
|
169
|
+
pull=True,
|
|
170
|
+
push=True,
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
bake_command = xm.merge_args(self._client_call, command.to_args())
|
|
174
|
+
packaging_utils.run_command(bake_command.to_list(), tty=True, check=True)
|
|
175
|
+
except Exception as ex:
|
|
176
|
+
raise RuntimeError(f"Failed to build Dockerfiles: {ex}") from ex
|
|
177
|
+
else:
|
|
178
|
+
metadata = json.loads(metadata_file.read_text())
|
|
179
|
+
|
|
180
|
+
images = []
|
|
181
|
+
for target in targets:
|
|
182
|
+
assert isinstance(target.value.executable_spec, Dockerfile)
|
|
183
|
+
assert isinstance(target.value.executor_spec, SlurmSpec)
|
|
184
|
+
assert target.value.executor_spec.tag
|
|
185
|
+
|
|
186
|
+
executable_metadata = metadata[
|
|
187
|
+
packaging_utils.hash_digest(target.value.executable_spec)
|
|
188
|
+
]
|
|
189
|
+
uri = ImageURI(target.value.executor_spec.tag).with_digest(
|
|
190
|
+
executable_metadata["containerimage.digest"]
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
images.append(
|
|
194
|
+
dataclasses.replace(
|
|
195
|
+
target,
|
|
196
|
+
value=RemoteImage( # type: ignore
|
|
197
|
+
image=str(uri),
|
|
198
|
+
workdir=target.value.executable_spec.workdir,
|
|
199
|
+
args=target.value.args,
|
|
200
|
+
env_vars=target.value.env_vars,
|
|
201
|
+
credentials=self.credentials(uri.domain),
|
|
202
|
+
),
|
|
203
|
+
)
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
return images
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
from typing import Callable, Generic, ParamSpec, Sequence, Type, TypeVar
|
|
3
|
+
|
|
4
|
+
from xmanager import xm
|
|
5
|
+
|
|
6
|
+
T_co = TypeVar("T_co", covariant=True)
|
|
7
|
+
P = ParamSpec("P")
|
|
8
|
+
ExecutableSpecT = TypeVar("ExecutableSpecT", bound=xm.ExecutableSpec)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclasses.dataclass(frozen=True)
class IndexedContainer(Generic[T_co]):
    """An immutable value paired with its position in an originating sequence.

    Used to carry items through out-of-order batch processing while
    remembering their original index, so results can later be sorted back
    into input order.
    """

    # Position of `value` in the original sequence.
    index: int
    # The wrapped item.
    value: T_co
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# A packaging routine: consumes indexed packageables, yields indexed executables.
RegistrationCallable = Callable[
    [Sequence[IndexedContainer[xm.Packageable]]],
    Sequence[IndexedContainer[xm.Executable]],
]


# Maps each executable spec type onto the routine that packages it.
_REGISTRY: dict[Type[xm.ExecutableSpec], RegistrationCallable] = {}


def register(
    *typs: Type[ExecutableSpecT],
) -> Callable[[RegistrationCallable], RegistrationCallable]:
    """Decorator factory registering a packaging routine for spec types `typs`.

    The decorated callable is recorded as the handler for every executable
    spec type in `typs` and returned unchanged.
    """

    def decorator(
        registration_callable: RegistrationCallable,
    ) -> RegistrationCallable:
        # Mutating the module-level dict in place needs no `global` statement.
        _REGISTRY.update((typ, registration_callable) for typ in typs)
        return registration_callable

    return decorator


def route(
    typ: Type[ExecutableSpecT],
    packageables: Sequence[IndexedContainer[xm.Packageable]],
) -> Sequence[IndexedContainer[xm.Executable]]:
    """Dispatch `packageables` to the routine registered for spec type `typ`.

    Raises:
        KeyError: If no routine has been registered for `typ`.
    """
    handler = _REGISTRY[typ]
    return handler(packageables)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Sequence, Type
|
|
4
|
+
|
|
5
|
+
from xmanager import xm
|
|
6
|
+
|
|
7
|
+
from xm_slurm.console import console
|
|
8
|
+
from xm_slurm.executors import SlurmSpec
|
|
9
|
+
from xm_slurm.packaging import registry
|
|
10
|
+
|
|
11
|
+
IndexedContainer = registry.IndexedContainer
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def package(
    packageables: Sequence[xm.Packageable],
) -> list[xm.Executable]:
    """Package `packageables` into executables, preserving input order.

    Packageables are grouped by their executable-spec type and each group is
    dispatched to the packaging routine registered for that type.

    Raises:
        ValueError: If any packageable's executor spec is not a
            `xm_slurm.SlurmSpec`.
    """
    # Group packageables by executable-spec type, remembering each one's
    # original position so results can be re-ordered afterwards.
    grouped = collections.defaultdict[
        Type[xm.ExecutableSpec], list[IndexedContainer[xm.Packageable]]
    ](list)

    for position, packageable in enumerate(packageables):
        if not isinstance(packageable.executor_spec, SlurmSpec):
            raise ValueError(
                f"Unsupported executor spec for packageable: {packageable}."
                "xm_slurm only supports `xm_slurm.SlurmSpec`."
            )
        grouped[type(packageable.executable_spec)].append(
            IndexedContainer[xm.Packageable](position, packageable)
        )

    packaged: list[IndexedContainer[xm.Executable]] = []
    # TODO(jfarebro): Could make this async as well...?
    with console.status("[magenta] :package: Packaging executables..."):
        for spec_type, group in grouped.items():
            logging.info(f"Packaging {len(group)} {spec_type!r} targets.")
            packaged.extend(registry.route(spec_type, group))

    console.print(
        f"[magenta]:package: Finished packaging [bold]{len(packaged)} executable"
        f"{'s' if len(packaged) > 1 else ''}[/bold]."
    )

    assert len(packaged) == len(packageables), "Number of targets must match packageables"
    return [item.value for item in sorted(packaged, key=lambda item: item.index)]
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import concurrent.futures
|
|
3
|
+
import functools
|
|
4
|
+
import hashlib
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import pathlib
|
|
8
|
+
import pty
|
|
9
|
+
import re
|
|
10
|
+
import select
|
|
11
|
+
import shutil
|
|
12
|
+
import subprocess
|
|
13
|
+
import typing
|
|
14
|
+
from typing import Callable, Concatenate, Hashable, Literal, ParamSpec, Sequence, TypeVar
|
|
15
|
+
|
|
16
|
+
from xmanager import xm
|
|
17
|
+
|
|
18
|
+
from xm_slurm.packaging.registry import IndexedContainer
|
|
19
|
+
|
|
20
|
+
T = TypeVar("T")
|
|
21
|
+
P = ParamSpec("P")
|
|
22
|
+
ReturnT = TypeVar("ReturnT")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def hash_digest(obj: Hashable) -> str:
    """Return a stable SHA-256 hex digest of `obj` derived from its `repr`.

    The builtin `hash()` is randomized per process (PYTHONHASHSEED), so the
    object's `repr` is hashed instead to get digests that agree across runs
    and processes (they are used as bake-target names).

    Note: assumes `obj` has a deterministic `repr` -- TODO confirm this holds
    for every spec type passed in.
    """
    return hashlib.sha256(repr(obj).encode()).hexdigest()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def collect_executors_by_executable(
    targets: Sequence[IndexedContainer[xm.Packageable]],
) -> dict[xm.ExecutableSpec, set[xm.ExecutorSpec]]:
    """Group the executor specs of `targets` by their executable spec.

    Returns a mapping from each distinct executable spec to the set of
    executor specs it should be packaged for.
    """
    grouped: dict[xm.ExecutableSpec, set[xm.ExecutorSpec]] = collections.defaultdict(set)
    for indexed in targets:
        packageable = indexed.value
        grouped[packageable.executable_spec].add(packageable.executor_spec)
    return grouped
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def parallel_map(
    f: Callable[Concatenate[T, P], ReturnT],
) -> Callable[Concatenate[Sequence[T], P], list[ReturnT]]:
    """Lift `f(item, *args, **kwargs)` into a function over sequences.

    The returned callable applies `f` concurrently (on a thread pool) to each
    element of its first argument, forwarding any extra positional and
    keyword arguments, and returns the results in input order.
    """

    @functools.wraps(f)
    def wrapper(sequence: Sequence[T], *args: P.args, **kwargs: P.kwargs) -> list[ReturnT]:
        def apply(item: T) -> ReturnT:
            return f(item, *args, **kwargs)

        with concurrent.futures.ThreadPoolExecutor() as pool:
            return list(pool.map(apply, sequence))

    return wrapper
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Cursor commands to filter out from the command data stream.
# These ANSI/CSI escape sequences only make sense on an interactive terminal
# and would corrupt captured stdout/stderr (see `run_command` below, which
# strips them from every chunk read off the ptys).
cursor_commands_regex = re.compile(
    rb"\x1b\[\?25[hl]"  # Matches cursor show/hide commands (CSI ?25h and CSI ?25l)
    rb"|\x1b\[[0-9;]*[Hf]"  # Matches cursor position commands (CSI n;mH and CSI n;mf)
    rb"|\x1b\[s"  # Matches cursor save position (CSI s)
    rb"|\x1b\[u"  # Matches cursor restore position (CSI u)
    rb"|\x1b\[2J"  # Matches clear screen (CSI 2J)
    rb"|\x1b\[K"  # Matches clear line (CSI K)
)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# `run_command` typing overloads: the return type narrows to
# `CompletedProcess[str]` whenever stdout and/or stderr capture is requested,
# and to `CompletedProcess[None]` when neither stream is captured.
@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[False] = False,
    return_stderr: Literal[False] = False,
) -> subprocess.CompletedProcess[None]: ...


@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[True] = True,
    return_stderr: Literal[False] = False,
) -> subprocess.CompletedProcess[str]: ...


@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[False] = False,
    return_stderr: Literal[True] = True,
) -> subprocess.CompletedProcess[str]: ...


@typing.overload
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = ...,
    tty: bool = ...,
    cwd: str | os.PathLike[str] | None = ...,
    check: bool = ...,
    return_stdout: Literal[True] = True,
    return_stderr: Literal[True] = True,
) -> subprocess.CompletedProcess[str]: ...
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def run_command(
    args: Sequence[str] | xm.SequentialArgs,
    env: dict[str, str] | None = None,
    tty: bool = False,
    cwd: str | os.PathLike[str] | None = None,
    check: bool = False,
    return_stdout: bool = False,
    return_stderr: bool = False,
) -> subprocess.CompletedProcess[str] | subprocess.CompletedProcess[None]:
    """Run `args` under pseudo-terminals, optionally capturing output.

    Both stdout and stderr are attached to ptys so that tools which detect a
    terminal (e.g. Docker's progress UI) behave as they would interactively.
    ANSI cursor-control sequences are stripped from the data read back.

    Args:
        args: Command line to execute (token list or `xm.SequentialArgs`).
        env: Extra environment variables overlaid on `os.environ`.
        tty: Mirror the child's output onto this process's stdout/stderr.
        cwd: Working directory for the child process.
        check: Raise `subprocess.CalledProcessError` on non-zero exit.
        return_stdout: Capture stdout into the returned `CompletedProcess`.
        return_stderr: Capture stderr into the returned `CompletedProcess`.

    Returns:
        A `CompletedProcess` whose `stdout`/`stderr` are decoded strings when
        captured and non-empty, otherwise `None`.

    Raises:
        RuntimeError: If the executable cannot be found on the PATH.
        subprocess.CalledProcessError: If `check` is true and the child fails.
    """
    if isinstance(args, xm.SequentialArgs):
        args = args.to_list()
    args = list(args)

    executable = shutil.which(args[0])
    if not executable:
        raise RuntimeError(f"Couldn't find executable {args[0]}")
    executable = pathlib.Path(executable)

    subprocess_env = os.environ.copy() | (env if env else {})
    # Guard the subcommand lookup: `args` may be a single token.
    if executable.name == "docker" and len(args) > 1 and args[1] == "buildx":
        # Older Docker CLIs hide `buildx` behind the experimental flag.
        subprocess_env |= {"DOCKER_CLI_EXPERIMENTAL": "enabled"}

    logging.debug(f"env: {subprocess_env}")
    logging.debug(f"command: {' '.join(args)}")

    # Attach stdout/stderr to ptys so the child believes it's interactive.
    stdout_master, stdout_slave = pty.openpty()
    stderr_master, stderr_slave = pty.openpty()

    stdout_data, stderr_data = b"", b""
    with subprocess.Popen(
        executable=executable,
        args=args,
        shell=False,
        text=True,
        bufsize=0,
        stdout=stdout_slave,
        stderr=stderr_slave,
        start_new_session=True,
        close_fds=True,
        cwd=cwd,
        env=subprocess_env,
    ) as process:
        # Close our copies of the slave ends so EOF becomes observable on the
        # masters once the child exits.
        os.close(stdout_slave)
        os.close(stderr_slave)

        fds = [stdout_master, stderr_master]
        while fds:
            rlist, _, _ = select.select(fds, [], [])
            for fd in rlist:
                try:
                    data = os.read(fd, 1024)
                except OSError:
                    # Linux raises EIO on the master once the slave closes.
                    data = None

                if not data:
                    os.close(fd)
                    fds.remove(fd)
                    continue

                data = re.sub(cursor_commands_regex, b"", data)

                if fd == stdout_master:
                    if return_stdout:
                        stdout_data += data
                    if tty:
                        os.write(pty.STDOUT_FILENO, data)
                elif fd == stderr_master:
                    if return_stderr:
                        stderr_data += data
                    if tty:
                        os.write(pty.STDERR_FILENO, data)
                else:
                    raise RuntimeError("Unexpected file descriptor")

        stdout = stdout_data.decode(errors="replace") if stdout_data else None
        stderr = stderr_data.decode(errors="replace") if stderr_data else None

        # EOF on both ptys doesn't guarantee the child has been reaped yet;
        # wait() (rather than poll()) avoids racing the process exit.
        retcode = process.wait()

        if check and retcode:
            # Attach captured output so callers can diagnose the failure.
            raise subprocess.CalledProcessError(
                retcode, process.args, output=stdout, stderr=stderr
            )
        return subprocess.CompletedProcess(
            process.args,
            retcode,
            stdout=stdout,
            stderr=stderr,
        )
|
xm_slurm/resources.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import itertools
|
|
3
|
+
import math
|
|
4
|
+
from typing import Mapping
|
|
5
|
+
|
|
6
|
+
import immutabledict
|
|
7
|
+
|
|
8
|
+
from xm_slurm import config
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ResourceType(enum.IntEnum):
    """Resource types requestable in a job's requirements.

    Values below 1000 are host resources; values at or above 1000 are
    accelerators. Some members are aliases (`RAM` is `MEMORY`, `DISK` is
    `EPHEMERAL_STORAGE`).
    """

    CPU = 1

    MEMORY = 2
    RAM = 2  # Alias of MEMORY.

    EPHEMERAL_STORAGE = 3
    DISK = 3  # Alias of EPHEMERAL_STORAGE.

    # Accelerators (values >= 1000).
    GPU = 1000
    RTX8000 = 1001
    P4 = 1010

    P100 = 1011
    P100_16GIB = 1012

    V100 = 1020
    V100_32GIB = 1021

    A100 = 1030
    A100_80GIB = 1031
    A5000 = 1032
    A6000 = 1033

    H100 = 1040


# Every accelerator resource type. Derived from the enum (all members with
# value >= GPU) so newly added accelerators are picked up automatically --
# the previous hand-written listing silently omitted RTX8000, P100_16GIB,
# V100_32GIB, and A5000, causing those resources to be dropped when
# rendering directives.
AcceleratorType = frozenset(
    resource for resource in ResourceType if resource >= ResourceType.GPU
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
ResourceQuantity = int | float
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class JobRequirements:
|
|
54
|
+
replicas: int
|
|
55
|
+
location: str | None
|
|
56
|
+
accelerator: ResourceType | None
|
|
57
|
+
cluster: config.SlurmClusterConfig | None = None
|
|
58
|
+
|
|
59
|
+
def __init__(
|
|
60
|
+
self,
|
|
61
|
+
*,
|
|
62
|
+
resources: Mapping[ResourceType | str, ResourceQuantity] = immutabledict.immutabledict(),
|
|
63
|
+
replicas: int = 1,
|
|
64
|
+
location: str | None = None,
|
|
65
|
+
cluster: config.SlurmClusterConfig | None = None,
|
|
66
|
+
**kw_resources: ResourceQuantity,
|
|
67
|
+
):
|
|
68
|
+
self.replicas = replicas or 1
|
|
69
|
+
self.location = location
|
|
70
|
+
self.accelerator = None
|
|
71
|
+
self.cluster = cluster
|
|
72
|
+
|
|
73
|
+
self.task_requirements: dict[ResourceType | str, ResourceQuantity] = {}
|
|
74
|
+
for resource_name, value in itertools.chain(resources.items(), kw_resources.items()):
|
|
75
|
+
match resource_name:
|
|
76
|
+
case str() if resource_name.upper() in ResourceType.__members__:
|
|
77
|
+
resource = ResourceType[resource_name.upper()]
|
|
78
|
+
case ResourceType():
|
|
79
|
+
resource = resource_name
|
|
80
|
+
case str():
|
|
81
|
+
resource = resource_name
|
|
82
|
+
|
|
83
|
+
if resource in AcceleratorType or (
|
|
84
|
+
isinstance(resource, str) and resource.startswith("gpu")
|
|
85
|
+
):
|
|
86
|
+
if self.accelerator is not None:
|
|
87
|
+
raise ValueError("Accelerator already set.")
|
|
88
|
+
self.accelerator = resource # type: ignore
|
|
89
|
+
|
|
90
|
+
if resource in self.task_requirements:
|
|
91
|
+
raise ValueError(f"{resource} has been specified twice.")
|
|
92
|
+
self.task_requirements[resource] = value
|
|
93
|
+
|
|
94
|
+
def to_directives(self) -> list[str]:
|
|
95
|
+
directives = []
|
|
96
|
+
|
|
97
|
+
for resource, value in self.task_requirements.items():
|
|
98
|
+
match resource:
|
|
99
|
+
case ResourceType.EPHEMERAL_STORAGE | ResourceType.DISK:
|
|
100
|
+
assert isinstance(value, int), "Disk space must be an integer"
|
|
101
|
+
directives.append(f"--tmp={math.ceil(value / 2**20)}M")
|
|
102
|
+
case ResourceType.MEMORY | ResourceType.RAM:
|
|
103
|
+
num_cpus = self.task_requirements.get(ResourceType.CPU, 1)
|
|
104
|
+
assert isinstance(value, (int, float)), "Memory must be an integer or float"
|
|
105
|
+
assert isinstance(num_cpus, int), "CPU must be an integer"
|
|
106
|
+
directives.append(f"--mem-per-cpu={math.ceil(value / num_cpus / 2**20)}M")
|
|
107
|
+
case ResourceType.CPU:
|
|
108
|
+
assert isinstance(value, int), "CPU must be an integer"
|
|
109
|
+
directives.append(f"--cpus-per-task={value}")
|
|
110
|
+
case ResourceType.GPU:
|
|
111
|
+
assert isinstance(value, int), "GPU must be an integer"
|
|
112
|
+
directives.append(f"--gpus-per-task={value}")
|
|
113
|
+
case ResourceType() if resource in AcceleratorType:
|
|
114
|
+
assert isinstance(value, int), "Accelerator must be an integer"
|
|
115
|
+
directives.append(f"--gpus-per-task={resource.name.lower()}:{value}")
|
|
116
|
+
case str():
|
|
117
|
+
directives.append(f"--gres={resource}:{value}")
|
|
118
|
+
|
|
119
|
+
directives.append(f"--ntasks={self.replicas}")
|
|
120
|
+
if self.location:
|
|
121
|
+
directives.append(f"--nodelist={self.location}")
|
|
122
|
+
|
|
123
|
+
return directives
|
|
124
|
+
|
|
125
|
+
def replace(
|
|
126
|
+
self,
|
|
127
|
+
cluster: config.SlurmClusterConfig | None,
|
|
128
|
+
**kw_resources: ResourceQuantity,
|
|
129
|
+
) -> "JobRequirements":
|
|
130
|
+
return JobRequirements(
|
|
131
|
+
resources=self.task_requirements | kw_resources, # type: ignore
|
|
132
|
+
replicas=self.replicas,
|
|
133
|
+
cluster=cluster or self.cluster,
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
def __repr__(self) -> str:
|
|
137
|
+
args = []
|
|
138
|
+
|
|
139
|
+
for resource, value in self.task_requirements.items():
|
|
140
|
+
if isinstance(resource, ResourceType):
|
|
141
|
+
resource = resource.name
|
|
142
|
+
args.append(f"{resource.lower()}={value!r}")
|
|
143
|
+
|
|
144
|
+
if self.replicas != 1:
|
|
145
|
+
args.append(f"replicas={self.replicas}")
|
|
146
|
+
|
|
147
|
+
if self.cluster is not None:
|
|
148
|
+
args.append(f"cluster={self.cluster!r}")
|
|
149
|
+
|
|
150
|
+
return f'xm_slurm.JobRequirements({", ".join(args)})'
|