xmanager-slurm 0.4.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. xm_slurm/__init__.py +47 -0
  2. xm_slurm/api/__init__.py +33 -0
  3. xm_slurm/api/abc.py +65 -0
  4. xm_slurm/api/models.py +70 -0
  5. xm_slurm/api/sqlite/client.py +358 -0
  6. xm_slurm/api/web/client.py +173 -0
  7. xm_slurm/batching.py +139 -0
  8. xm_slurm/config.py +189 -0
  9. xm_slurm/console.py +3 -0
  10. xm_slurm/constants.py +19 -0
  11. xm_slurm/contrib/__init__.py +0 -0
  12. xm_slurm/contrib/clusters/__init__.py +67 -0
  13. xm_slurm/contrib/clusters/drac.py +242 -0
  14. xm_slurm/dependencies.py +171 -0
  15. xm_slurm/executables.py +215 -0
  16. xm_slurm/execution.py +995 -0
  17. xm_slurm/executors.py +210 -0
  18. xm_slurm/experiment.py +1016 -0
  19. xm_slurm/experimental/parameter_controller.py +206 -0
  20. xm_slurm/filesystems.py +129 -0
  21. xm_slurm/job_blocks.py +21 -0
  22. xm_slurm/metadata_context.py +253 -0
  23. xm_slurm/packageables.py +309 -0
  24. xm_slurm/packaging/__init__.py +8 -0
  25. xm_slurm/packaging/docker.py +348 -0
  26. xm_slurm/packaging/registry.py +45 -0
  27. xm_slurm/packaging/router.py +56 -0
  28. xm_slurm/packaging/utils.py +22 -0
  29. xm_slurm/resources.py +350 -0
  30. xm_slurm/scripts/_cloudpickle.py +28 -0
  31. xm_slurm/scripts/cli.py +90 -0
  32. xm_slurm/status.py +197 -0
  33. xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
  34. xm_slurm/templates/docker/mamba.Dockerfile +29 -0
  35. xm_slurm/templates/docker/python.Dockerfile +32 -0
  36. xm_slurm/templates/docker/uv.Dockerfile +38 -0
  37. xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
  38. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
  39. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  40. xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
  41. xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
  42. xm_slurm/templates/slurm/job.bash.j2 +90 -0
  43. xm_slurm/templates/slurm/library/retry.bash +62 -0
  44. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
  45. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
  46. xm_slurm/types.py +23 -0
  47. xm_slurm/utils.py +196 -0
  48. xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
  49. xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
  50. xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
  51. xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
  52. xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
@@ -0,0 +1,309 @@
1
+ import importlib.resources as resources
2
+ import pathlib
3
+ import sys
4
+ import typing as tp
5
+
6
+ from xmanager import xm
7
+
8
+ from xm_slurm import job_blocks, utils
9
+ from xm_slurm.executables import Dockerfile, DockerImage
10
+ from xm_slurm.executors import SlurmSpec
11
+
12
+
13
def docker_image(
    *,
    image: str,
    args: xm.UserArgs | None = None,
    env_vars: tp.Mapping[str, str] | None = None,
) -> xm.Packageable:
    """Wrap an already-built Docker image in a packageable.

    Args:
        image: URI of the remote image.
        args: User arguments forwarded to the executable.
        env_vars: Environment variables forwarded to the executable.

    Returns:
        A packageable pairing a default `SlurmSpec` with a `DockerImage` for `image`.
    """
    executor = SlurmSpec()
    executable = DockerImage(image=image)
    sequential_args = xm.SequentialArgs.from_collection(args)
    # A missing (or empty) mapping is normalised to a fresh empty dict.
    environment = env_vars if env_vars else {}
    return xm.Packageable(
        executor_spec=executor,
        executable_spec=executable,
        args=sequential_args,
        env_vars=environment,
    )
34
+
35
+
36
def docker_container(
    *,
    executor_spec: xm.ExecutorSpec,
    dockerfile: pathlib.Path | None = None,
    context: pathlib.Path | None = None,
    target: str | None = None,
    ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
    build_args: tp.Mapping[str, str] | None = None,
    cache_from: str | tp.Sequence[str] | None = None,
    labels: tp.Mapping[str, str] | None = None,
    args: xm.UserArgs | None = None,
    env_vars: tp.Mapping[str, str] | None = None,
) -> xm.Packageable:
    """Build a packageable describing a container built from a Dockerfile.

    Args:
        executor_spec: The executor specification. When it is a `SlurmSpec` and
            `cache_from` is not given, its `tag` is used as the cache source.
        dockerfile: Path to the Dockerfile. Defaults to `<context>/Dockerfile`.
        context: Path to the docker build context. Defaults to the project root
            reported by `utils.find_project_root()`.
        target: The docker build target.
        ssh: SSH sockets/keys for the docker build step, or `True` for the default SSH agent.
        build_args: Build arguments passed to docker.
        cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
        labels: The container labels.
        args: The user arguments to pass to the executable.
        env_vars: The environment variables to pass to the executable.

    Returns:
        A packageable whose executable spec is a `Dockerfile`.
    """
    # Both paths are made absolute; the Dockerfile default depends on the resolved context.
    build_context = (utils.find_project_root() if context is None else context).resolve()
    dockerfile_path = (build_context / "Dockerfile" if dockerfile is None else dockerfile).resolve()

    # `True` is shorthand for the default agent; `None` means no SSH forwarding at all.
    if ssh is None:
        ssh_sources = []
    elif ssh is True:
        ssh_sources = ["default"]
    else:
        ssh_sources = ssh

    # Fall back to the executor's tag as a cache source before normalising to a list.
    cache_sources = cache_from
    if cache_sources is None and isinstance(executor_spec, SlurmSpec):
        cache_sources = executor_spec.tag
    if cache_sources is None:
        cache_sources = []
    elif isinstance(cache_sources, str):
        cache_sources = [cache_sources]

    executable = Dockerfile(
        dockerfile=dockerfile_path,
        context=build_context,
        target=target,
        ssh=ssh_sources,
        build_args=build_args or {},
        cache_from=cache_sources,
        labels=labels or {},
    )
    return xm.Packageable(
        executor_spec=executor_spec,
        executable_spec=executable,
        args=xm.SequentialArgs.from_collection(args),
        env_vars=env_vars or {},
    )
98
+
99
+
100
def python_container(
    *,
    executor_spec: xm.ExecutorSpec,
    entrypoint: xm.ModuleName | xm.CommandList,
    context: pathlib.Path | None = None,
    requirements: pathlib.Path | None = None,
    base_image: str = "docker.io/python:{major}.{minor}-slim",
    extra_system_packages: tp.Sequence[str] = (),
    extra_python_packages: tp.Sequence[str] = (),
    cache_from: str | tp.Sequence[str] | None = None,
    labels: tp.Mapping[str, str] | None = None,
    ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
    args: xm.UserArgs | None = None,
    env_vars: tp.Mapping[str, str] | None = None,
) -> xm.Packageable:
    """Build a Python container packageable from a pip `requirements.txt` file.

    The bundled `docker/python.Dockerfile` template is used as the Dockerfile.

    NOTE: `{major}`, `{minor}` and `{micro}` in `base_image` are filled in from the
    version of the interpreter running this function.
    NOTE: uv is used to install packages from `requirements`.

    Args:
        executor_spec: The executor specification, forwarded to `docker_container`.
        entrypoint: The entrypoint to run in the container.
        context: The path to the docker context. Defaults to the project root.
        requirements: The path to the pip requirements file. Defaults to
            `<context>/requirements.txt`; it must live inside `context`.
        base_image: The base image to use. NOTE: The base image must contain the Python runtime.
        extra_system_packages: Additional system packages to install.
            NOTE: These are installed via `apt-get`.
        extra_python_packages: Additional Python packages to install.
            NOTE: These are installed via `uv pip`.
        cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
        labels: The container labels.
        ssh: SSH sockets/keys for the docker build step, or `True` for the default SSH agent.
        args: The user arguments to pass to the executable.
        env_vars: The environment variables to pass to the executable.

    Returns:
        A Python container packageable.

    Raises:
        ValueError: If the requirements file is missing or lies outside `context`.
    """
    # Entrypoint-derived arguments come first, user arguments are merged after them.
    merged_args = xm.merge_args(
        job_blocks.get_args_for_python_entrypoint(entrypoint), args or {}
    )

    build_context = (utils.find_project_root() if context is None else context).resolve()
    requirements_path = (
        build_context / "requirements.txt" if requirements is None else requirements
    ).resolve()
    if not requirements_path.exists():
        raise ValueError(f"Pip requirements `{requirements_path}` doesn't exist.")
    if not requirements_path.is_relative_to(build_context):
        raise ValueError(
            f"Pip requirements `{requirements_path}` must be relative to context: `{build_context}`"
        )

    template = resources.files("xm_slurm.templates").joinpath("docker/python.Dockerfile")
    with resources.as_file(template) as dockerfile:
        interpreter_version = dict(zip(("major", "minor", "micro"), sys.version_info[:3]))
        docker_build_args = {
            # The Dockerfile receives the requirements path relative to the build context.
            "PIP_REQUIREMENTS": requirements_path.relative_to(build_context).as_posix(),
            "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
            "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
            "BASE_IMAGE": base_image.format_map(interpreter_version),
        }
        return docker_container(
            executor_spec=executor_spec,
            dockerfile=dockerfile,
            context=build_context,
            ssh=ssh,
            build_args=docker_build_args,
            cache_from=cache_from,
            labels=labels,
            args=merged_args,
            env_vars=env_vars,
        )
175
+
176
+
177
def mamba_container(
    *,
    executor_spec: xm.ExecutorSpec,
    entrypoint: xm.ModuleName | xm.CommandList,
    context: pathlib.Path | None = None,
    environment: pathlib.Path | None = None,
    base_image: str = "gcr.io/distroless/base-debian10",
    cache_from: str | tp.Sequence[str] | None = None,
    labels: tp.Mapping[str, str] | None = None,
    ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
    args: xm.UserArgs | None = None,
    env_vars: tp.Mapping[str, str] | None = None,
) -> xm.Packageable:
    """Build a Conda container packageable from an `environment.yml` file using mamba.

    The bundled `docker/mamba.Dockerfile` template is used as the Dockerfile.

    Note: The base image *doesn't* need to contain the Python runtime.

    Args:
        executor_spec: The executor specification, forwarded to `docker_container`.
        entrypoint: The entrypoint to run in the container.
        context: The path to the docker context. Defaults to the project root.
        environment: The path to the conda environment file. Defaults to
            `<context>/environment.yml`; it must live inside `context`.
        base_image: The base image to use.
        cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
        labels: The container labels.
        ssh: SSH sockets/keys for the docker build step, or `True` for the default SSH agent.
        args: The user arguments to pass to the executable.
        env_vars: The environment variables to pass to the executable.

    Returns:
        A Conda container packageable.

    Raises:
        ValueError: If the environment manifest is missing or lies outside `context`.
    """
    # Entrypoint-derived arguments come first, user arguments are merged after them.
    merged_args = xm.merge_args(
        job_blocks.get_args_for_python_entrypoint(entrypoint), args or {}
    )

    build_context = (utils.find_project_root() if context is None else context).resolve()
    manifest = (build_context / "environment.yml" if environment is None else environment).resolve()
    if not manifest.exists():
        raise ValueError(f"Conda environment manifest `{manifest}` doesn't exist.")
    if not manifest.is_relative_to(build_context):
        raise ValueError(
            f"Conda environment manifest `{manifest}` must be relative to context: `{build_context}`"
        )

    template = resources.files("xm_slurm.templates").joinpath("docker/mamba.Dockerfile")
    with resources.as_file(template) as dockerfile:
        docker_build_args = {
            # The Dockerfile receives the manifest path relative to the build context.
            "CONDA_ENVIRONMENT": manifest.relative_to(build_context).as_posix(),
            "BASE_IMAGE": base_image,
        }
        return docker_container(
            executor_spec=executor_spec,
            dockerfile=dockerfile,
            context=build_context,
            ssh=ssh,
            build_args=docker_build_args,
            cache_from=cache_from,
            labels=labels,
            args=merged_args,
            env_vars=env_vars,
        )
241
+
242
+
243
# `conda_container` is a second name for `mamba_container`; both refer to the same function.
conda_container = mamba_container
244
+
245
+
246
def uv_container(
    *,
    executor_spec: xm.ExecutorSpec,
    entrypoint: xm.ModuleName | xm.CommandList,
    context: pathlib.Path | None = None,
    base_image: str = "docker.io/python:{major}.{minor}-slim-bookworm",
    extra_system_packages: tp.Sequence[str] = (),
    extra_python_packages: tp.Sequence[str] = (),
    cache_from: str | tp.Sequence[str] | None = None,
    labels: tp.Mapping[str, str] | None = None,
    ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
    args: xm.UserArgs | None = None,
    env_vars: tp.Mapping[str, str] | None = None,
) -> xm.Packageable:
    """Build a Python container packageable from a uv project (`pyproject.toml` + `uv.lock`).

    The bundled `docker/uv.Dockerfile` template is used as the Dockerfile. `{major}`,
    `{minor}` and `{micro}` in `base_image` are filled in from the running interpreter.

    Args:
        executor_spec: The executor specification, forwarded to `docker_container`.
        entrypoint: The entrypoint to run in the container.
        context: The path to the docker context. Defaults to the project root. It must
            contain both `pyproject.toml` and `uv.lock`.
        base_image: The base image to use. NOTE: The base image must contain the Python runtime.
        extra_system_packages: Additional system packages to install.
            NOTE: These are installed via `apt-get`.
        extra_python_packages: Additional Python packages to install.
            NOTE: These are installed via `uv pip`.
        cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
        labels: The container labels.
        ssh: SSH sockets/keys for the docker build step, or `True` for the default SSH agent.
        args: The user arguments to pass to the executable.
        env_vars: The environment variables to pass to the executable.

    Returns:
        A Python container packageable.

    Raises:
        ValueError: If `pyproject.toml` or `uv.lock` is missing from `context`.
    """
    # Entrypoint-derived arguments come first, user arguments are merged after them.
    merged_args = xm.merge_args(
        job_blocks.get_args_for_python_entrypoint(entrypoint), args or {}
    )

    build_context = (utils.find_project_root() if context is None else context).resolve()
    project_file = build_context / "pyproject.toml"
    if not project_file.exists():
        raise ValueError(f"Python project file `{project_file}` doesn't exist.")
    lock_file = build_context / "uv.lock"
    if not lock_file.exists():
        raise ValueError(f"UV lock file `{lock_file}` doesn't exist.")

    template = resources.files("xm_slurm.templates").joinpath("docker/uv.Dockerfile")
    with resources.as_file(template) as dockerfile:
        interpreter_version = dict(zip(("major", "minor", "micro"), sys.version_info[:3]))
        docker_build_args = {
            "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
            "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
            "BASE_IMAGE": base_image.format_map(interpreter_version),
        }
        return docker_container(
            executor_spec=executor_spec,
            dockerfile=dockerfile,
            context=build_context,
            ssh=ssh,
            build_args=docker_build_args,
            cache_from=cache_from,
            labels=labels,
            args=merged_args,
            env_vars=env_vars,
        )
@@ -0,0 +1,8 @@
1
+ # First register our built-in packaging methods
2
+ import xm_slurm.packaging.docker # noqa: F401
3
+ from xm_slurm.packaging import registry, router
4
+
5
# Re-export the packaging entry points so they can be used from this package directly.
package = router.package
register = registry.register

# Names exported by a star-import of this package.
__all__ = ["package", "register"]
@@ -0,0 +1,348 @@
1
+ import base64
2
+ import collections.abc
3
+ import dataclasses
4
+ import enum
5
+ import functools
6
+ import hashlib
7
+ import json
8
+ import logging
9
+ import os
10
+ import pathlib
11
+ import shlex
12
+ import shutil
13
+ import tempfile
14
+ import typing as tp
15
+
16
+ import jinja2 as j2
17
+ from xmanager import xm
18
+
19
+ from xm_slurm import utils
20
+ from xm_slurm.executables import (
21
+ Dockerfile,
22
+ DockerImage,
23
+ ImageURI,
24
+ RemoteImage,
25
+ RemoteRepositoryCredentials,
26
+ )
27
+ from xm_slurm.executors import SlurmSpec
28
+ from xm_slurm.packaging import registry
29
+ from xm_slurm.packaging import utils as packaging_utils
30
+ from xm_slurm.packaging.registry import IndexedContainer
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
def _hash_digest(obj: tp.Hashable) -> str:
    """Return the hex SHA-256 digest of `repr(obj)`.

    The digest therefore only identifies `obj` as well as its `repr` does.
    """
    hasher = hashlib.sha256()
    hasher.update(repr(obj).encode())
    return hasher.hexdigest()
37
+
38
+
39
+ class DockerClient:
40
+ class Builder(enum.Enum):
41
+ BUILDKIT = enum.auto()
42
+ BUILDAH = enum.auto()
43
+
44
+ def __init__(self) -> None:
45
+ if "XM_DOCKER_CLIENT" in os.environ:
46
+ client_call = shlex.split(os.environ["XM_DOCKER_CLIENT"])
47
+ elif shutil.which("docker"):
48
+ client_call = ["docker"]
49
+ elif shutil.which("podman"):
50
+ client_call = ["podman"]
51
+ else:
52
+ raise RuntimeError("No Docker client found.")
53
+ self._client_call = client_call
54
+
55
+ backend_version = utils.run_command(
56
+ xm.merge_args(self._client_call, ["buildx", "version"]), return_stdout=True
57
+ )
58
+ if backend_version.stdout.startswith("github.com/docker/buildx"):
59
+ self._builder = DockerClient.Builder.BUILDKIT
60
+ else:
61
+ raise NotImplementedError(f"Unsupported Docker build backend: {backend_version}")
62
+
63
+ self._credentials_cache: dict[str, RemoteRepositoryCredentials] = {}
64
+
65
+ def credentials(self, hostname: str) -> RemoteRepositoryCredentials | None:
66
+ """Fetch credentials from the local Docker configuration."""
67
+ if hostname in self._credentials_cache:
68
+ return self._credentials_cache[hostname]
69
+
70
+ def _parse_docker_credentials(helper: str) -> RemoteRepositoryCredentials | None:
71
+ """Parse credentials from a Docker credential helper."""
72
+ if not shutil.which(f"docker-credential-{helper}"):
73
+ return None
74
+ result = utils.run_command(
75
+ [f"docker-credential-{helper}", "get"],
76
+ stdin=hostname,
77
+ return_stdout=True,
78
+ )
79
+
80
+ if result.returncode == 0:
81
+ credentials = json.loads(result.stdout)
82
+ return RemoteRepositoryCredentials(
83
+ username=str.strip(credentials["Username"]),
84
+ password=str.strip(credentials["Secret"]),
85
+ )
86
+ return None
87
+
88
+ def _parse_credentials_from_config(
89
+ config_path: pathlib.Path,
90
+ ) -> RemoteRepositoryCredentials | None:
91
+ """Parse credentials from the Docker configuration file."""
92
+ if not config_path.exists():
93
+ return None
94
+ config = json.loads(config_path.read_text())
95
+
96
+ # Attempt to parse from the global credential store
97
+ if (creds_store := config.get("credsStore", None)) and (
98
+ credentials := _parse_docker_credentials(creds_store)
99
+ ):
100
+ self._credentials_cache[hostname] = credentials
101
+ return credentials
102
+ # Attempt to parse from the credential helper for this registry
103
+ if creds_helper := config.get("credHelpers", {}):
104
+ for registry, helper in creds_helper.items():
105
+ registry = ImageURI(registry)
106
+ if registry.domain == hostname and (
107
+ credentials := _parse_docker_credentials(helper)
108
+ ):
109
+ self._credentials_cache[hostname] = credentials
110
+ return credentials
111
+ # Last resort: attempt to parse raw auth
112
+ if auths := config.get("auths", None):
113
+ for registry, metadata in auths.items():
114
+ registry = ImageURI(registry)
115
+ if registry.domain == hostname:
116
+ auth = base64.b64decode(metadata["auth"]).decode("utf-8")
117
+ username, password = auth.split(":")
118
+ credentials = RemoteRepositoryCredentials(
119
+ str.strip(username),
120
+ str.strip(password),
121
+ )
122
+ self._credentials_cache[hostname] = credentials
123
+ return credentials
124
+ return None
125
+
126
+ # Attempt to parse credentials from the Docker or Podman configuration
127
+ match self._builder:
128
+ case DockerClient.Builder.BUILDKIT:
129
+ docker_config_path = (
130
+ pathlib.Path(os.environ.get("DOCKER_CONFIG", "~/.docker")).expanduser()
131
+ / "config.json"
132
+ )
133
+ return _parse_credentials_from_config(docker_config_path)
134
+ case DockerClient.Builder.BUILDAH:
135
+ podman_config_path = (
136
+ pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "~/.config")).expanduser()
137
+ / "containers"
138
+ / "auth.json"
139
+ )
140
+ return _parse_credentials_from_config(podman_config_path)
141
+
142
+ def inspect(self, image: ImageURI, element: str) -> dict[str, tp.Any]:
143
+ output = utils.run_command(
144
+ xm.merge_args(
145
+ self._client_call,
146
+ ["buildx", "imagetools", "inspect"],
147
+ ["--format", f"{{{{json .{element}}}}}"],
148
+ [str(image)],
149
+ ),
150
+ check=True,
151
+ return_stdout=True,
152
+ )
153
+ return json.loads(output.stdout.strip().strip("'"))
154
+
155
+ @functools.cached_property
156
+ def _bake_template(self) -> j2.Template:
157
+ template_loader = j2.PackageLoader("xm_slurm", "templates/docker")
158
+ template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)
159
+
160
+ return template_env.get_template("docker-bake.hcl.j2")
161
+
162
+ def _bake_args(
163
+ self,
164
+ *,
165
+ targets: str | tp.Sequence[str] | None = None,
166
+ builder: str | None = None,
167
+ files: str | os.PathLike[str] | tp.Sequence[os.PathLike[str] | str] | None = None,
168
+ load: bool = False,
169
+ cache: bool = True,
170
+ print: bool = False,
171
+ pull: bool = False,
172
+ push: bool = False,
173
+ metadata_file: str | os.PathLike[str] | None = None,
174
+ progress: tp.Literal["auto", "plain", "tty"] = "auto",
175
+ set: tp.Mapping[str, str] | None = None,
176
+ ) -> xm.SequentialArgs:
177
+ files = files
178
+ if files is None:
179
+ files = []
180
+ if not isinstance(files, collections.abc.Sequence):
181
+ files = [files]
182
+
183
+ targets = targets
184
+ if targets is None:
185
+ targets = []
186
+ elif isinstance(targets, str):
187
+ targets = [targets]
188
+ assert isinstance(targets, collections.abc.Sequence)
189
+
190
+ return xm.merge_args(
191
+ ["buildx", "bake"],
192
+ [f"--progress={progress}"],
193
+ [f"--builder={builder}"] if builder else [],
194
+ [f"--metadata-file={metadata_file}"] if metadata_file else [],
195
+ ["--print"] if print else [],
196
+ ["--push"] if push else [],
197
+ ["--pull"] if pull else [],
198
+ ["--load"] if load else [],
199
+ ["--no-cache"] if not cache else [],
200
+ [f"--file={file}" for file in files],
201
+ [f"--set={key}={value}" for key, value in set.items()] if set else [],
202
+ targets,
203
+ )
204
+
205
+ def bake(
206
+ self, *, targets: tp.Sequence[IndexedContainer[xm.Packageable]]
207
+ ) -> list[IndexedContainer[RemoteImage]]:
208
+ executors_by_executables = packaging_utils.collect_executors_by_executable(targets)
209
+ for executable, executors in executors_by_executables.items():
210
+ assert isinstance(
211
+ executable, Dockerfile
212
+ ), "All executables must be Dockerfiles when building Docker images."
213
+ assert all(
214
+ isinstance(executor, SlurmSpec) and executor.tag for executor in executors
215
+ ), "All executors must be SlurmSpecs with tags when building Docker images."
216
+
217
+ with tempfile.TemporaryDirectory() as tempdir:
218
+ hcl_file = pathlib.Path(tempdir) / "docker-bake.hcl"
219
+ metadata_file = pathlib.Path(tempdir) / "metadata.json"
220
+
221
+ # Write HCL and bake it
222
+ # TODO(jfarebro): Need a better way to hash the executables
223
+ hcl = self._bake_template.render(
224
+ executables=executors_by_executables,
225
+ hash=_hash_digest,
226
+ )
227
+ hcl_file.write_text(hcl)
228
+ logger.debug(hcl)
229
+
230
+ try:
231
+ bake_command = xm.merge_args(
232
+ self._client_call,
233
+ self._bake_args(
234
+ targets=list(
235
+ set([_hash_digest(target.value.executable_spec) for target in targets])
236
+ ),
237
+ files=[hcl_file],
238
+ metadata_file=metadata_file,
239
+ pull=False,
240
+ push=True,
241
+ ),
242
+ )
243
+ utils.run_command(bake_command.to_list(), tty=True, check=True)
244
+ except Exception as ex:
245
+ raise RuntimeError(f"Failed to build Dockerfiles: {ex}") from ex
246
+ else:
247
+ metadata = json.loads(metadata_file.read_text())
248
+
249
+ images = []
250
+ for target in targets:
251
+ assert isinstance(target.value.executable_spec, Dockerfile)
252
+ assert isinstance(target.value.executor_spec, SlurmSpec)
253
+ assert target.value.executor_spec.tag
254
+
255
+ executable_metadata = metadata[_hash_digest(target.value.executable_spec)]
256
+ uri = ImageURI(target.value.executor_spec.tag).with_digest(
257
+ executable_metadata["containerimage.digest"]
258
+ )
259
+ config = self.inspect(uri, "Image.Config")
260
+ if "WorkingDir" not in config:
261
+ raise ValueError(
262
+ "Docker image does not have a working directory. "
263
+ "To support all runtimes, we need to set a working directory. "
264
+ "Please set `WORKDIR` in the `Dockerfile`."
265
+ )
266
+ if "Entrypoint" not in config:
267
+ raise ValueError(
268
+ "Docker image does not have an entrypoint. "
269
+ "To support all runtimes, we need to set an entrypoint. "
270
+ "Please set `ENTRYPOINT` in the `Dockerfile`."
271
+ )
272
+
273
+ images.append(
274
+ dataclasses.replace(
275
+ target,
276
+ value=RemoteImage( # type: ignore
277
+ image=str(uri),
278
+ workdir=config["WorkingDir"],
279
+ entrypoint=xm.SequentialArgs.from_collection(config["Entrypoint"]),
280
+ args=target.value.args,
281
+ env_vars=target.value.env_vars,
282
+ credentials=self.credentials(uri.domain),
283
+ ),
284
+ )
285
+ )
286
+
287
+ return images
288
+
289
+
290
@functools.cache
def docker_client() -> DockerClient:
    """Return the shared `DockerClient`, constructing it on the first call only."""
    client = DockerClient()
    return client
293
+
294
+
295
@registry.register(Dockerfile)
def _(
    targets: tp.Sequence[IndexedContainer[xm.Packageable]],
) -> list[IndexedContainer[RemoteImage]]:
    """Package `Dockerfile` targets by baking them with the shared Docker client."""
    client = docker_client()
    return client.bake(targets=targets)
300
+
301
+
302
@registry.register(DockerImage)
def _(
    targets: tp.Sequence[IndexedContainer[xm.Packageable]],
) -> list[IndexedContainer[RemoteImage]]:
    """Package pre-built `DockerImage` targets; nothing is built, images are only inspected.

    Raises:
        ValueError: If a target's `SlurmSpec` carries a tag, or if the image has no
            `WorkingDir` or no `Entrypoint` in its configuration.
    """
    # Image configuration keys every image must define, with the error raised when absent.
    required_config = (
        (
            "WorkingDir",
            "Docker image does not have a working directory. "
            "To support all runtimes, we need to set a working directory. "
            "Please set `WORKDIR` in the `Dockerfile`.",
        ),
        (
            "Entrypoint",
            "Docker image does not have an entrypoint. "
            "To support all runtimes, we need to set an entrypoint. "
            "Please set `ENTRYPOINT` in the `Dockerfile`.",
        ),
    )

    client = docker_client()
    remote_images = []
    for target in targets:
        packageable = target.value
        assert isinstance(packageable.executable_spec, DockerImage)
        assert isinstance(packageable.executor_spec, SlurmSpec)
        if packageable.executor_spec.tag is not None:
            raise ValueError(
                "Executable `DockerImage` should not be tagged via `SlurmSpec`. "
                "The image URI is provided by the `DockerImage` itself."
            )

        uri = ImageURI(packageable.executable_spec.image)

        config = client.inspect(uri, "Image.Config")
        for key, message in required_config:
            if key not in config:
                raise ValueError(message)

        remote_image = RemoteImage(  # type: ignore
            image=str(uri),
            workdir=config["WorkingDir"],
            entrypoint=xm.SequentialArgs.from_collection(config["Entrypoint"]),
            args=packageable.args,
            env_vars=packageable.env_vars,
            credentials=client.credentials(hostname=uri.domain),
        )
        remote_images.append(dataclasses.replace(target, value=remote_image))

    return remote_images