xmanager-slurm 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic. Click here for more details.
- xm_slurm/__init__.py +4 -2
- xm_slurm/api.py +301 -34
- xm_slurm/batching.py +4 -4
- xm_slurm/config.py +99 -54
- xm_slurm/constants.py +15 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +22 -13
- xm_slurm/contrib/clusters/drac.py +34 -16
- xm_slurm/executables.py +19 -7
- xm_slurm/execution.py +86 -38
- xm_slurm/experiment.py +273 -131
- xm_slurm/experimental/parameter_controller.py +200 -0
- xm_slurm/job_blocks.py +7 -0
- xm_slurm/packageables.py +45 -18
- xm_slurm/packaging/docker/__init__.py +5 -11
- xm_slurm/packaging/docker/local.py +13 -12
- xm_slurm/packaging/utils.py +7 -55
- xm_slurm/resources.py +28 -4
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/status.py +9 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +7 -0
- xm_slurm/templates/docker/mamba.Dockerfile +3 -1
- xm_slurm/templates/docker/python.Dockerfile +18 -10
- xm_slurm/templates/docker/uv.Dockerfile +35 -0
- xm_slurm/utils.py +18 -10
- xmanager_slurm-0.4.0.dist-info/METADATA +26 -0
- xmanager_slurm-0.4.0.dist-info/RECORD +42 -0
- {xmanager_slurm-0.3.1.dist-info → xmanager_slurm-0.4.0.dist-info}/WHEEL +1 -1
- xmanager_slurm-0.4.0.dist-info/licenses/LICENSE.md +227 -0
- xm_slurm/packaging/docker/cloud.py +0 -503
- xm_slurm/templates/docker/pdm.Dockerfile +0 -31
- xmanager_slurm-0.3.1.dist-info/METADATA +0 -25
- xmanager_slurm-0.3.1.dist-info/RECORD +0 -38
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import base64
|
|
3
|
+
import dataclasses
|
|
4
|
+
import enum
|
|
5
|
+
import functools
|
|
6
|
+
import logging
|
|
7
|
+
import zlib
|
|
8
|
+
from typing import Awaitable, Callable, Concatenate, Coroutine, Mapping, ParamSpec, TypeVar
|
|
9
|
+
|
|
10
|
+
import backoff
|
|
11
|
+
import cloudpickle
|
|
12
|
+
from xmanager import xm
|
|
13
|
+
|
|
14
|
+
import xm_slurm
|
|
15
|
+
from xm_slurm import job_blocks, status
|
|
16
|
+
from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
|
|
17
|
+
|
|
18
|
+
P = ParamSpec("P")
|
|
19
|
+
T = TypeVar("T")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
async def _monitor_parameter_controller(
|
|
23
|
+
aux_unit: SlurmAuxiliaryUnit,
|
|
24
|
+
local_parameter_controller_coro: Coroutine[None, None, T],
|
|
25
|
+
*,
|
|
26
|
+
poll_interval: float = 30.0,
|
|
27
|
+
) -> None:
|
|
28
|
+
local_controller_finished = asyncio.Event()
|
|
29
|
+
local_parameter_controller = asyncio.create_task(local_parameter_controller_coro)
|
|
30
|
+
|
|
31
|
+
@local_parameter_controller.add_done_callback
|
|
32
|
+
def _(future: asyncio.Task[T]) -> None:
|
|
33
|
+
try:
|
|
34
|
+
_ = future.result()
|
|
35
|
+
except asyncio.CancelledError:
|
|
36
|
+
logging.info("Local parameter controller was cancelled, resuming on remote controller.")
|
|
37
|
+
pass
|
|
38
|
+
except Exception:
|
|
39
|
+
logging.error("Local parameter controller failed, stopping remote controller.")
|
|
40
|
+
aux_unit.stop(
|
|
41
|
+
mark_as_failed=True, mark_as_completed=False, message="Local controller failed."
|
|
42
|
+
)
|
|
43
|
+
raise
|
|
44
|
+
else:
|
|
45
|
+
logging.info(
|
|
46
|
+
"Local parameter controller finished before remote controller started, "
|
|
47
|
+
"stopping remote controller."
|
|
48
|
+
)
|
|
49
|
+
local_controller_finished.set()
|
|
50
|
+
aux_unit.stop(mark_as_completed=True, message="Local parameter controller finished.")
|
|
51
|
+
|
|
52
|
+
@backoff.on_predicate(
|
|
53
|
+
backoff.constant,
|
|
54
|
+
lambda aux_unit_status: aux_unit_status is status.SlurmWorkUnitStatusEnum.PENDING,
|
|
55
|
+
jitter=None,
|
|
56
|
+
interval=poll_interval,
|
|
57
|
+
)
|
|
58
|
+
async def wait_for_remote_controller() -> status.SlurmWorkUnitStatusEnum:
|
|
59
|
+
logging.info("Waiting for remote parameter controller to start.")
|
|
60
|
+
if local_controller_finished.is_set():
|
|
61
|
+
return status.SlurmWorkUnitStatusEnum.COMPLETED
|
|
62
|
+
return (await aux_unit.get_status()).status
|
|
63
|
+
|
|
64
|
+
logging.info("Monitoring remote parameter controller.")
|
|
65
|
+
# TODO(jfarebro): make get_status() more resiliant to errors when initially scheduling.
|
|
66
|
+
# We run into issues if we call get_status() too quickly when Slurm hasn't ingested the job.
|
|
67
|
+
await asyncio.sleep(15)
|
|
68
|
+
match await wait_for_remote_controller():
|
|
69
|
+
case status.SlurmWorkUnitStatusEnum.RUNNING:
|
|
70
|
+
logging.info("Remote parameter controller started.")
|
|
71
|
+
local_parameter_controller.cancel("Remote parameter controller started.")
|
|
72
|
+
case status.SlurmWorkUnitStatusEnum.COMPLETED:
|
|
73
|
+
if local_parameter_controller.done():
|
|
74
|
+
logging.info("Local parameter controller finished, stopping remote controller.")
|
|
75
|
+
aux_unit.stop(
|
|
76
|
+
mark_as_completed=True, message="Local parameter controller finished."
|
|
77
|
+
)
|
|
78
|
+
else:
|
|
79
|
+
logging.info("Remote parameter controller finished, stopping local controller.")
|
|
80
|
+
local_parameter_controller.cancel()
|
|
81
|
+
case status.SlurmWorkUnitStatusEnum.FAILED:
|
|
82
|
+
logging.error("Remote parameter controller failed, stopping local controller.")
|
|
83
|
+
local_parameter_controller.cancel()
|
|
84
|
+
case status.SlurmWorkUnitStatusEnum.CANCELLED:
|
|
85
|
+
logging.info("Remote parameter controller was cancelled, stopping local controller.")
|
|
86
|
+
local_parameter_controller.cancel()
|
|
87
|
+
case status.SlurmWorkUnitStatusEnum.PENDING:
|
|
88
|
+
raise RuntimeError("Remote parameter controller is still pending, invalid state.")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class ParameterControllerMode(enum.Enum):
    """Selects where a parameter controller runs.

    With `AUTO`, the remote controller job is launched and the controller is
    additionally run locally under monitoring. With `REMOTE_ONLY`, only the
    remote controller job is launched.
    """

    AUTO = enum.auto()
    REMOTE_ONLY = enum.auto()
    # TODO(jfarebro): is it possible to get LOCAL_ONLY?
    # We'd need to have a dummy job type as we need to return a JobType?
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def parameter_controller(
    *,
    executable: xm.Executable,
    executor: xm.Executor,
    controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
    controller_name: str = "parameter_controller",
    controller_args: xm.UserArgs | None = None,
    controller_env_vars: Mapping[str, str] | None = None,
) -> Callable[
    [
        Callable[Concatenate[SlurmExperiment, P], T]
        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
    ],
    Callable[P, xm.AuxiliaryUnitJob],
]:
    """Converts a function to a controller which can be added to an experiment.

    Calling the wrapped function returns an `xm.AuxiliaryUnitJob` which runs the
    function as an auxiliary unit on the specified executor.

    Args:
        executable: An executable that has a Python entrypoint with all the necessary dependencies.
        executor: The executor to launch the controller job on.
        controller_mode: Where the controller runs. With `AUTO` the function is also
            run locally under monitoring in addition to the remote job; otherwise
            only the remote job is launched.
        controller_name: Name of the parameter controller job.
        controller_args: Mapping of flag names and values to be used by the XM
            client running inside the parameter controller job.
        controller_env_vars: Mapping of env variable names and values to be passed
            to the parameter controller job.

    Returns:
        A decorator to be applied to the function.
    """

    def decorator(
        f: Callable[Concatenate[SlurmExperiment, P], T]
        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
    ) -> Callable[P, xm.AuxiliaryUnitJob]:
        @functools.wraps(f)
        def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
            # Modify the function to read the experiment from the API so that it can be pickled.

            async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
                experiment_id = aux_unit.experiment.experiment_id

                async def local_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
                    import inspect

                    # Await based on the result rather than on how `f` was declared, so
                    # plain callables that return an awaitable are handled as well.
                    result = f(aux_unit.experiment, *args, **kwargs)
                    if inspect.isawaitable(result):
                        result = await result
                    return result

                async def remote_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
                    # Imported here so the pickled function is self-contained.
                    import inspect

                    async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
                        result = f(exp, *args, **kwargs)
                        if inspect.isawaitable(result):
                            result = await result
                        return result

                # pickle -> zlib -> URL-safe base64 so the payload fits in a command-line argument.
                remote_controller_serialized = base64.urlsafe_b64encode(
                    zlib.compress(
                        cloudpickle.dumps(
                            functools.partial(remote_controller, *args, **kwargs),
                        )
                    )
                )

                parameter_controller_executable = dataclasses.replace(
                    executable,
                    args=xm.merge_args(
                        job_blocks.get_args_for_python_entrypoint(
                            xm.ModuleName("xm_slurm.scripts._cloudpickle")
                        ),
                        xm.SequentialArgs.from_collection({
                            "cloudpickled_fn": remote_controller_serialized.decode("ascii"),
                        }),
                        xm.SequentialArgs.from_collection(controller_args),
                    ),
                    env_vars=controller_env_vars or {},
                )

                await aux_unit.add(
                    xm.Job(
                        executor=executor,
                        executable=parameter_controller_executable,
                        name=controller_name,
                    )
                )

                # Launch local parameter controller and monitor for when it starts running
                # so we can kill the local controller.
                if controller_mode is ParameterControllerMode.AUTO:
                    aux_unit._create_task(
                        _monitor_parameter_controller(aux_unit, local_controller(*args, **kwargs))
                    )

            return xm.AuxiliaryUnitJob(
                job_generator,
                importance=xm.Importance.HIGH,
                termination_delay_secs=0,  # TODO: add support for termination delay?
            )

        return make_controller

    return decorator
|
xm_slurm/job_blocks.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
1
|
+
from typing import Mapping, TypedDict
|
|
2
|
+
|
|
1
3
|
from xmanager import xm
|
|
2
4
|
|
|
3
5
|
|
|
6
|
+
class JobArgs(TypedDict, total=False):
    """Optional per-job overrides; both keys may be omitted (`total=False`)."""

    # User arguments for the job.
    args: xm.UserArgs
    # Environment variable names mapped to their values.
    env_vars: Mapping[str, str]
|
|
9
|
+
|
|
10
|
+
|
|
4
11
|
def get_args_for_python_entrypoint(
|
|
5
12
|
entrypoint: xm.ModuleName | xm.CommandList,
|
|
6
13
|
) -> xm.SequentialArgs:
|
xm_slurm/packageables.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import importlib.resources as resources
|
|
2
2
|
import pathlib
|
|
3
3
|
import sys
|
|
4
|
-
from typing import Mapping, Sequence
|
|
4
|
+
from typing import Literal, Mapping, Sequence
|
|
5
5
|
|
|
6
6
|
import immutabledict
|
|
7
7
|
from xmanager import xm
|
|
@@ -29,8 +29,8 @@ def docker_image(
|
|
|
29
29
|
return xm.Packageable(
|
|
30
30
|
executor_spec=SlurmSpec(),
|
|
31
31
|
executable_spec=DockerImage(image=image),
|
|
32
|
-
args=args,
|
|
33
|
-
env_vars=env_vars,
|
|
32
|
+
args=xm.SequentialArgs.from_collection(args),
|
|
33
|
+
env_vars=dict(env_vars),
|
|
34
34
|
)
|
|
35
35
|
|
|
36
36
|
|
|
@@ -40,6 +40,7 @@ def docker_container(
|
|
|
40
40
|
dockerfile: pathlib.Path | None = None,
|
|
41
41
|
context: pathlib.Path | None = None,
|
|
42
42
|
target: str | None = None,
|
|
43
|
+
ssh: Sequence[str] | Literal[True] | None = None,
|
|
43
44
|
build_args: Mapping[str, str] = immutabledict.immutabledict(),
|
|
44
45
|
cache_from: str | Sequence[str] | None = None,
|
|
45
46
|
labels: Mapping[str, str] = immutabledict.immutabledict(),
|
|
@@ -54,6 +55,7 @@ def docker_container(
|
|
|
54
55
|
dockerfile: The path to the dockerfile.
|
|
55
56
|
context: The path to the docker context.
|
|
56
57
|
target: The docker build target.
|
|
58
|
+
ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
|
|
57
59
|
build_args: Build arguments to docker.
|
|
58
60
|
cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
|
|
59
61
|
labels: The container labels.
|
|
@@ -69,6 +71,12 @@ def docker_container(
|
|
|
69
71
|
if dockerfile is None:
|
|
70
72
|
dockerfile = context / "Dockerfile"
|
|
71
73
|
dockerfile = dockerfile.resolve()
|
|
74
|
+
|
|
75
|
+
if ssh is None:
|
|
76
|
+
ssh = []
|
|
77
|
+
elif ssh is True:
|
|
78
|
+
ssh = ["default"]
|
|
79
|
+
|
|
72
80
|
if cache_from is None and isinstance(executor_spec, SlurmSpec):
|
|
73
81
|
cache_from = executor_spec.tag
|
|
74
82
|
if cache_from is None:
|
|
@@ -82,13 +90,14 @@ def docker_container(
|
|
|
82
90
|
dockerfile=dockerfile,
|
|
83
91
|
context=context,
|
|
84
92
|
target=target,
|
|
93
|
+
ssh=ssh,
|
|
85
94
|
build_args=build_args,
|
|
86
95
|
cache_from=cache_from,
|
|
87
96
|
workdir=workdir,
|
|
88
97
|
labels=labels,
|
|
89
98
|
),
|
|
90
|
-
args=args,
|
|
91
|
-
env_vars=env_vars,
|
|
99
|
+
args=xm.SequentialArgs.from_collection(args),
|
|
100
|
+
env_vars=dict(env_vars),
|
|
92
101
|
)
|
|
93
102
|
|
|
94
103
|
|
|
@@ -99,14 +108,18 @@ def python_container(
|
|
|
99
108
|
context: pathlib.Path | None = None,
|
|
100
109
|
requirements: pathlib.Path | None = None,
|
|
101
110
|
base_image: str = "docker.io/python:{major}.{minor}-slim",
|
|
111
|
+
extra_system_packages: Sequence[str] = (),
|
|
112
|
+
extra_python_packages: Sequence[str] = (),
|
|
102
113
|
cache_from: str | Sequence[str] | None = None,
|
|
103
114
|
labels: Mapping[str, str] = immutabledict.immutabledict(),
|
|
115
|
+
ssh: Sequence[str] | Literal[True] | None = None,
|
|
104
116
|
args: xm.UserArgs | None = None,
|
|
105
117
|
env_vars: Mapping[str, str] = immutabledict.immutabledict(),
|
|
106
118
|
) -> xm.Packageable:
|
|
107
119
|
"""Creates a Python container from a base image using pip from a `requirements.txt` file.
|
|
108
120
|
|
|
109
121
|
NOTE: The base image will use the Python version of the current interpreter.
|
|
122
|
+
NOTE: uv is used to install packages from `requirements`.
|
|
110
123
|
|
|
111
124
|
Args:
|
|
112
125
|
executor_spec: The executor specification, where will the container be stored at.
|
|
@@ -114,8 +127,11 @@ def python_container(
|
|
|
114
127
|
context: The path to the docker context.
|
|
115
128
|
requirements: The path to the pip requirements file.
|
|
116
129
|
base_image: The base image to use. NOTE: The base image must contain the Python runtime.
|
|
130
|
+
extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
|
|
131
|
+
extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
|
|
117
132
|
cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
|
|
118
133
|
labels: The container labels.
|
|
134
|
+
ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
|
|
119
135
|
args: The user arguments to pass to the executable.
|
|
120
136
|
env_vars: The environment variables to pass to the executable.
|
|
121
137
|
|
|
@@ -144,11 +160,11 @@ def python_container(
|
|
|
144
160
|
executor_spec=executor_spec,
|
|
145
161
|
dockerfile=dockerfile,
|
|
146
162
|
context=context,
|
|
163
|
+
ssh=ssh,
|
|
147
164
|
build_args={
|
|
148
165
|
"PIP_REQUIREMENTS": requirements.relative_to(context).as_posix(),
|
|
149
|
-
"
|
|
150
|
-
"
|
|
151
|
-
"PYTHON_MICRO": str(sys.version_info.micro),
|
|
166
|
+
"EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
|
|
167
|
+
"EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
|
|
152
168
|
"BASE_IMAGE": base_image.format_map({
|
|
153
169
|
"major": sys.version_info.major,
|
|
154
170
|
"minor": sys.version_info.minor,
|
|
@@ -173,6 +189,7 @@ def mamba_container(
|
|
|
173
189
|
base_image: str = "gcr.io/distroless/base-debian10",
|
|
174
190
|
cache_from: str | Sequence[str] | None = None,
|
|
175
191
|
labels: Mapping[str, str] = immutabledict.immutabledict(),
|
|
192
|
+
ssh: Sequence[str] | Literal[True] | None = None,
|
|
176
193
|
args: xm.UserArgs | None = None,
|
|
177
194
|
env_vars: Mapping[str, str] = immutabledict.immutabledict(),
|
|
178
195
|
) -> xm.Packageable:
|
|
@@ -188,6 +205,7 @@ def mamba_container(
|
|
|
188
205
|
base_image: The base image to use.
|
|
189
206
|
cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
|
|
190
207
|
labels: The container labels.
|
|
208
|
+
ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
|
|
191
209
|
args: The user arguments to pass to the executable.
|
|
192
210
|
env_vars: The environment variables to pass to the executable.
|
|
193
211
|
|
|
@@ -216,6 +234,7 @@ def mamba_container(
|
|
|
216
234
|
executor_spec=executor_spec,
|
|
217
235
|
dockerfile=dockerfile,
|
|
218
236
|
context=context,
|
|
237
|
+
ssh=ssh,
|
|
219
238
|
build_args={
|
|
220
239
|
"CONDA_ENVIRONMENT": environment.relative_to(context).as_posix(),
|
|
221
240
|
"BASE_IMAGE": base_image,
|
|
@@ -232,26 +251,32 @@ def mamba_container(
|
|
|
232
251
|
conda_container = mamba_container
|
|
233
252
|
|
|
234
253
|
|
|
235
|
-
def
|
|
254
|
+
def uv_container(
|
|
236
255
|
*,
|
|
237
256
|
executor_spec: xm.ExecutorSpec,
|
|
238
257
|
entrypoint: xm.ModuleName | xm.CommandList,
|
|
239
258
|
context: pathlib.Path | None = None,
|
|
240
|
-
base_image: str = "docker.io/python:{major}.{minor}-slim",
|
|
259
|
+
base_image: str = "docker.io/python:{major}.{minor}-slim-bookworm",
|
|
260
|
+
extra_system_packages: Sequence[str] = (),
|
|
261
|
+
extra_python_packages: Sequence[str] = (),
|
|
241
262
|
cache_from: str | Sequence[str] | None = None,
|
|
242
263
|
labels: Mapping[str, str] = immutabledict.immutabledict(),
|
|
264
|
+
ssh: Sequence[str] | Literal[True] | None = None,
|
|
243
265
|
args: xm.UserArgs | None = None,
|
|
244
266
|
env_vars: Mapping[str, str] = immutabledict.immutabledict(),
|
|
245
267
|
) -> xm.Packageable:
|
|
246
|
-
"""Creates a Python container from a base image using
|
|
268
|
+
"""Creates a Python container from a base image using uv from a `uv.lock` file.
|
|
247
269
|
|
|
248
270
|
Args:
|
|
249
271
|
executor_spec: The executor specification, where will the container be stored at.
|
|
250
272
|
entrypoint: The entrypoint to run in the container.
|
|
251
273
|
context: The path to the docker context.
|
|
252
274
|
base_image: The base image to use. NOTE: The base image must contain the Python runtime.
|
|
275
|
+
extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
|
|
276
|
+
extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
|
|
253
277
|
cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
|
|
254
278
|
labels: The container labels.
|
|
279
|
+
ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
|
|
255
280
|
args: The user arguments to pass to the executable.
|
|
256
281
|
env_vars: The environment variables to pass to the executable.
|
|
257
282
|
|
|
@@ -263,20 +288,22 @@ def pdm_container(
|
|
|
263
288
|
if context is None:
|
|
264
289
|
context = utils.find_project_root()
|
|
265
290
|
context = context.resolve()
|
|
266
|
-
if not (context / "
|
|
267
|
-
raise ValueError(f"
|
|
291
|
+
if not (context / "pyproject.toml").exists():
|
|
292
|
+
raise ValueError(f"Python project file `{context / 'pyproject.toml'}` doesn't exist.")
|
|
293
|
+
if not (context / "uv.lock").exists():
|
|
294
|
+
raise ValueError(f"UV lock file `{context / 'uv.lock'}` doesn't exist.")
|
|
268
295
|
|
|
269
296
|
with resources.as_file(
|
|
270
|
-
resources.files("xm_slurm.templates").joinpath("docker/
|
|
297
|
+
resources.files("xm_slurm.templates").joinpath("docker/uv.Dockerfile")
|
|
271
298
|
) as dockerfile:
|
|
272
299
|
return docker_container(
|
|
273
300
|
executor_spec=executor_spec,
|
|
274
301
|
dockerfile=dockerfile,
|
|
275
302
|
context=context,
|
|
303
|
+
ssh=ssh,
|
|
276
304
|
build_args={
|
|
277
|
-
"
|
|
278
|
-
"
|
|
279
|
-
"PYTHON_MICRO": str(sys.version_info.micro),
|
|
305
|
+
"EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
|
|
306
|
+
"EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
|
|
280
307
|
"BASE_IMAGE": base_image.format_map({
|
|
281
308
|
"major": sys.version_info.major,
|
|
282
309
|
"minor": sys.version_info.minor,
|
|
@@ -286,7 +313,7 @@ def pdm_container(
|
|
|
286
313
|
cache_from=cache_from,
|
|
287
314
|
labels=labels,
|
|
288
315
|
# We must specify the workdir manually for apptainer support
|
|
289
|
-
workdir=pathlib.Path("/workspace
|
|
316
|
+
workdir=pathlib.Path("/workspace"),
|
|
290
317
|
args=args,
|
|
291
318
|
env_vars=env_vars,
|
|
292
319
|
)
|
|
@@ -11,8 +11,8 @@ from xm_slurm.packaging import registry
|
|
|
11
11
|
from xm_slurm.packaging.docker.abc import DockerClient
|
|
12
12
|
|
|
13
13
|
FLAGS = flags.FLAGS
|
|
14
|
-
|
|
15
|
-
"
|
|
14
|
+
DOCKER_CLIENT_PROVIDER = flags.DEFINE_enum(
|
|
15
|
+
"xm_docker_client", "docker", ["docker"], "Docker image build client."
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
IndexedContainer = registry.IndexedContainer
|
|
@@ -20,19 +20,13 @@ IndexedContainer = registry.IndexedContainer
|
|
|
20
20
|
|
|
21
21
|
@functools.cache
|
|
22
22
|
def docker_client() -> DockerClient:
|
|
23
|
-
match
|
|
24
|
-
case "
|
|
23
|
+
match DOCKER_CLIENT_PROVIDER.value:
|
|
24
|
+
case "docker":
|
|
25
25
|
from xm_slurm.packaging.docker.local import LocalDockerClient
|
|
26
26
|
|
|
27
27
|
return LocalDockerClient()
|
|
28
|
-
case "gcp":
|
|
29
|
-
from xm_slurm.packaging.docker.cloud import GoogleCloudRemoteDockerClient
|
|
30
|
-
|
|
31
|
-
return GoogleCloudRemoteDockerClient()
|
|
32
|
-
case "azure":
|
|
33
|
-
raise NotImplementedError("Azure remote build is not yet supported.")
|
|
34
28
|
case _:
|
|
35
|
-
raise ValueError(f"Unknown
|
|
29
|
+
raise ValueError(f"Unknown build client: {DOCKER_CLIENT_PROVIDER.value}")
|
|
36
30
|
|
|
37
31
|
|
|
38
32
|
@registry.register(Dockerfile)
|
|
@@ -9,7 +9,6 @@ import shlex
|
|
|
9
9
|
import shutil
|
|
10
10
|
import subprocess
|
|
11
11
|
import tempfile
|
|
12
|
-
import typing
|
|
13
12
|
from typing import Sequence
|
|
14
13
|
|
|
15
14
|
from xmanager import xm
|
|
@@ -37,7 +36,7 @@ class LocalDockerClient(DockerClient):
|
|
|
37
36
|
BUILDKIT = enum.auto()
|
|
38
37
|
BUILDAH = enum.auto()
|
|
39
38
|
|
|
40
|
-
def __init__(self):
|
|
39
|
+
def __init__(self) -> None:
|
|
41
40
|
if "XM_DOCKER_CLIENT" in os.environ:
|
|
42
41
|
client_call = shlex.split(os.environ["XM_DOCKER_CLIENT"])
|
|
43
42
|
elif shutil.which("docker"):
|
|
@@ -80,7 +79,7 @@ class LocalDockerClient(DockerClient):
|
|
|
80
79
|
return None
|
|
81
80
|
|
|
82
81
|
def _parse_credentials_from_config(
|
|
83
|
-
config_path: pathlib.Path
|
|
82
|
+
config_path: pathlib.Path,
|
|
84
83
|
) -> RemoteRepositoryCredentials | None:
|
|
85
84
|
"""Parse credentials from the Docker configuration file."""
|
|
86
85
|
if not config_path.exists():
|
|
@@ -138,9 +137,13 @@ class LocalDockerClient(DockerClient):
|
|
|
138
137
|
targets: Sequence[IndexedContainer[xm.Packageable]],
|
|
139
138
|
) -> list[IndexedContainer[RemoteImage]]:
|
|
140
139
|
executors_by_executables = packaging_utils.collect_executors_by_executable(targets)
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
140
|
+
for executable, executors in executors_by_executables.items():
|
|
141
|
+
assert isinstance(
|
|
142
|
+
executable, Dockerfile
|
|
143
|
+
), "All executables must be Dockerfiles when building Docker images."
|
|
144
|
+
assert all(
|
|
145
|
+
isinstance(executor, SlurmSpec) and executor.tag for executor in executors
|
|
146
|
+
), "All executors must be SlurmSpecs with tags when building Docker images."
|
|
144
147
|
|
|
145
148
|
with tempfile.TemporaryDirectory() as tempdir:
|
|
146
149
|
hcl_file = pathlib.Path(tempdir) / "docker-bake.hcl"
|
|
@@ -157,12 +160,10 @@ class LocalDockerClient(DockerClient):
|
|
|
157
160
|
try:
|
|
158
161
|
command = DockerBakeCommand(
|
|
159
162
|
targets=list(
|
|
160
|
-
set(
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
]
|
|
165
|
-
)
|
|
163
|
+
set([
|
|
164
|
+
packaging_utils.hash_digest(target.value.executable_spec)
|
|
165
|
+
for target in targets
|
|
166
|
+
])
|
|
166
167
|
),
|
|
167
168
|
files=[hcl_file],
|
|
168
169
|
metadata_file=metadata_file,
|
xm_slurm/packaging/utils.py
CHANGED
|
@@ -10,8 +10,7 @@ import re
|
|
|
10
10
|
import select
|
|
11
11
|
import shutil
|
|
12
12
|
import subprocess
|
|
13
|
-
import
|
|
14
|
-
from typing import Callable, Concatenate, Hashable, Literal, ParamSpec, Sequence, TypeVar
|
|
13
|
+
from typing import Callable, Concatenate, Hashable, ParamSpec, Sequence, TypeVar
|
|
15
14
|
|
|
16
15
|
from xmanager import xm
|
|
17
16
|
|
|
@@ -23,6 +22,7 @@ ReturnT = TypeVar("ReturnT")
|
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
def hash_digest(obj: Hashable) -> str:
|
|
25
|
+
# TODO(jfarebro): Need a better way to hash these objects
|
|
26
26
|
# obj_hash = hash(obj)
|
|
27
27
|
# unsigned_obj_hash = obj_hash.from_bytes(
|
|
28
28
|
# obj_hash.to_bytes((obj_hash.bit_length() + 7) // 8, "big", signed=True),
|
|
@@ -54,7 +54,7 @@ def parallel_map(
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
# Cursor commands to filter out from the command data stream
|
|
57
|
-
|
|
57
|
+
_CURSOR_ESCAPE_SEQUENCES_REGEX = re.compile(
|
|
58
58
|
rb"\x1b\[\?25[hl]" # Matches cursor show/hide commands (CSI ?25h and CSI ?25l)
|
|
59
59
|
rb"|\x1b\[[0-9;]*[Hf]" # Matches cursor position commands (CSI n;mH and CSI n;mf)
|
|
60
60
|
rb"|\x1b\[s" # Matches cursor save position (CSI s)
|
|
@@ -64,54 +64,6 @@ cursor_commands_regex = re.compile(
|
|
|
64
64
|
)
|
|
65
65
|
|
|
66
66
|
|
|
67
|
-
@typing.overload
|
|
68
|
-
def run_command(
|
|
69
|
-
args: Sequence[str] | xm.SequentialArgs,
|
|
70
|
-
env: dict[str, str] | None = ...,
|
|
71
|
-
tty: bool = ...,
|
|
72
|
-
cwd: str | os.PathLike[str] | None = ...,
|
|
73
|
-
check: bool = ...,
|
|
74
|
-
return_stdout: Literal[False] = False,
|
|
75
|
-
return_stderr: Literal[False] = False,
|
|
76
|
-
) -> subprocess.CompletedProcess[None]: ...
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
@typing.overload
|
|
80
|
-
def run_command(
|
|
81
|
-
args: Sequence[str] | xm.SequentialArgs,
|
|
82
|
-
env: dict[str, str] | None = ...,
|
|
83
|
-
tty: bool = ...,
|
|
84
|
-
cwd: str | os.PathLike[str] | None = ...,
|
|
85
|
-
check: bool = ...,
|
|
86
|
-
return_stdout: Literal[True] = True,
|
|
87
|
-
return_stderr: Literal[False] = False,
|
|
88
|
-
) -> subprocess.CompletedProcess[str]: ...
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
@typing.overload
|
|
92
|
-
def run_command(
|
|
93
|
-
args: Sequence[str] | xm.SequentialArgs,
|
|
94
|
-
env: dict[str, str] | None = ...,
|
|
95
|
-
tty: bool = ...,
|
|
96
|
-
cwd: str | os.PathLike[str] | None = ...,
|
|
97
|
-
check: bool = ...,
|
|
98
|
-
return_stdout: Literal[False] = False,
|
|
99
|
-
return_stderr: Literal[True] = True,
|
|
100
|
-
) -> subprocess.CompletedProcess[str]: ...
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
@typing.overload
|
|
104
|
-
def run_command(
|
|
105
|
-
args: Sequence[str] | xm.SequentialArgs,
|
|
106
|
-
env: dict[str, str] | None = ...,
|
|
107
|
-
tty: bool = ...,
|
|
108
|
-
cwd: str | os.PathLike[str] | None = ...,
|
|
109
|
-
check: bool = ...,
|
|
110
|
-
return_stdout: Literal[True] = True,
|
|
111
|
-
return_stderr: Literal[True] = True,
|
|
112
|
-
) -> subprocess.CompletedProcess[str]: ...
|
|
113
|
-
|
|
114
|
-
|
|
115
67
|
def run_command(
|
|
116
68
|
args: Sequence[str] | xm.SequentialArgs,
|
|
117
69
|
env: dict[str, str] | None = None,
|
|
@@ -120,7 +72,7 @@ def run_command(
|
|
|
120
72
|
check: bool = False,
|
|
121
73
|
return_stdout: bool = False,
|
|
122
74
|
return_stderr: bool = False,
|
|
123
|
-
) -> subprocess.CompletedProcess[str]
|
|
75
|
+
) -> subprocess.CompletedProcess[str]:
|
|
124
76
|
if isinstance(args, xm.SequentialArgs):
|
|
125
77
|
args = args.to_list()
|
|
126
78
|
args = list(args)
|
|
@@ -171,7 +123,7 @@ def run_command(
|
|
|
171
123
|
fds.remove(fd)
|
|
172
124
|
continue
|
|
173
125
|
|
|
174
|
-
data =
|
|
126
|
+
data = _CURSOR_ESCAPE_SEQUENCES_REGEX.sub(b"", data)
|
|
175
127
|
|
|
176
128
|
if fd == stdout_master:
|
|
177
129
|
if return_stdout:
|
|
@@ -186,8 +138,8 @@ def run_command(
|
|
|
186
138
|
else:
|
|
187
139
|
raise RuntimeError("Unexpected file descriptor")
|
|
188
140
|
|
|
189
|
-
stdout = stdout_data.decode(errors="replace") if stdout_data else
|
|
190
|
-
stderr = stderr_data.decode(errors="replace") if stderr_data else
|
|
141
|
+
stdout = stdout_data.decode(errors="replace") if stdout_data else ""
|
|
142
|
+
stderr = stderr_data.decode(errors="replace") if stderr_data else ""
|
|
191
143
|
|
|
192
144
|
retcode = process.poll()
|
|
193
145
|
assert retcode is not None
|