xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Click here for more details.

Files changed (42) hide show
  1. xm_slurm/__init__.py +6 -2
  2. xm_slurm/api.py +301 -34
  3. xm_slurm/batching.py +4 -4
  4. xm_slurm/config.py +105 -55
  5. xm_slurm/constants.py +19 -0
  6. xm_slurm/contrib/__init__.py +0 -0
  7. xm_slurm/contrib/clusters/__init__.py +47 -13
  8. xm_slurm/contrib/clusters/drac.py +34 -16
  9. xm_slurm/dependencies.py +171 -0
  10. xm_slurm/executables.py +34 -22
  11. xm_slurm/execution.py +305 -107
  12. xm_slurm/executors.py +8 -12
  13. xm_slurm/experiment.py +601 -168
  14. xm_slurm/experimental/parameter_controller.py +202 -0
  15. xm_slurm/job_blocks.py +7 -0
  16. xm_slurm/packageables.py +42 -20
  17. xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
  18. xm_slurm/packaging/router.py +3 -1
  19. xm_slurm/packaging/utils.py +9 -81
  20. xm_slurm/resources.py +28 -4
  21. xm_slurm/scripts/_cloudpickle.py +28 -0
  22. xm_slurm/scripts/cli.py +52 -0
  23. xm_slurm/status.py +9 -0
  24. xm_slurm/templates/docker/mamba.Dockerfile +4 -2
  25. xm_slurm/templates/docker/python.Dockerfile +18 -10
  26. xm_slurm/templates/docker/uv.Dockerfile +35 -0
  27. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
  28. xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
  29. xm_slurm/templates/slurm/job.bash.j2 +4 -3
  30. xm_slurm/types.py +23 -0
  31. xm_slurm/utils.py +18 -10
  32. xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
  33. xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
  34. {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
  35. xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
  36. xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
  37. xm_slurm/packaging/docker/__init__.py +0 -75
  38. xm_slurm/packaging/docker/abc.py +0 -112
  39. xm_slurm/packaging/docker/cloud.py +0 -503
  40. xm_slurm/templates/docker/pdm.Dockerfile +0 -31
  41. xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
  42. xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
@@ -0,0 +1,202 @@
1
+ import asyncio
2
+ import base64
3
+ import dataclasses
4
+ import enum
5
+ import functools
6
+ import logging
7
+ import zlib
8
+ from typing import Awaitable, Callable, Concatenate, Coroutine, Mapping, ParamSpec, TypeVar
9
+
10
+ import backoff
11
+ import cloudpickle
12
+ from xmanager import xm
13
+
14
+ import xm_slurm
15
+ from xm_slurm import job_blocks, status
16
+ from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
17
+
18
+ P = ParamSpec("P")
19
+ T = TypeVar("T")
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ async def _monitor_parameter_controller(
25
+ aux_unit: SlurmAuxiliaryUnit,
26
+ local_parameter_controller_coro: Coroutine[None, None, T],
27
+ *,
28
+ poll_interval: float = 30.0,
29
+ ) -> None:
30
+ local_controller_finished = asyncio.Event()
31
+ local_parameter_controller = asyncio.create_task(local_parameter_controller_coro)
32
+
33
+ @local_parameter_controller.add_done_callback
34
+ def _(future: asyncio.Task[T]) -> None:
35
+ try:
36
+ _ = future.result()
37
+ except asyncio.CancelledError:
38
+ logger.info("Local parameter controller was cancelled, resuming on remote controller.")
39
+ pass
40
+ except Exception:
41
+ logger.error("Local parameter controller failed, stopping remote controller.")
42
+ aux_unit.stop(
43
+ mark_as_failed=True, mark_as_completed=False, message="Local controller failed."
44
+ )
45
+ raise
46
+ else:
47
+ logger.info(
48
+ "Local parameter controller finished before remote controller started, "
49
+ "stopping remote controller."
50
+ )
51
+ local_controller_finished.set()
52
+ aux_unit.stop(mark_as_completed=True, message="Local parameter controller finished.")
53
+
54
+ @backoff.on_predicate(
55
+ backoff.constant,
56
+ lambda aux_unit_status: aux_unit_status is status.SlurmWorkUnitStatusEnum.PENDING,
57
+ jitter=None,
58
+ interval=poll_interval,
59
+ )
60
+ async def wait_for_remote_controller() -> status.SlurmWorkUnitStatusEnum:
61
+ logger.info("Waiting for remote parameter controller to start.")
62
+ if local_controller_finished.is_set():
63
+ return status.SlurmWorkUnitStatusEnum.COMPLETED
64
+ return (await aux_unit.get_status()).status
65
+
66
+ logger.info("Monitoring remote parameter controller.")
67
+ # TODO(jfarebro): make get_status() more resiliant to errors when initially scheduling.
68
+ # We run into issues if we call get_status() too quickly when Slurm hasn't ingested the job.
69
+ await asyncio.sleep(15)
70
+ match await wait_for_remote_controller():
71
+ case status.SlurmWorkUnitStatusEnum.RUNNING:
72
+ logger.info("Remote parameter controller started.")
73
+ local_parameter_controller.cancel("Remote parameter controller started.")
74
+ case status.SlurmWorkUnitStatusEnum.COMPLETED:
75
+ if local_parameter_controller.done():
76
+ logger.info("Local parameter controller finished, stopping remote controller.")
77
+ aux_unit.stop(
78
+ mark_as_completed=True, message="Local parameter controller finished."
79
+ )
80
+ else:
81
+ logger.info("Remote parameter controller finished, stopping local controller.")
82
+ local_parameter_controller.cancel()
83
+ case status.SlurmWorkUnitStatusEnum.FAILED:
84
+ logger.error("Remote parameter controller failed, stopping local controller.")
85
+ local_parameter_controller.cancel()
86
+ case status.SlurmWorkUnitStatusEnum.CANCELLED:
87
+ logger.info("Remote parameter controller was cancelled, stopping local controller.")
88
+ local_parameter_controller.cancel()
89
+ case status.SlurmWorkUnitStatusEnum.PENDING:
90
+ raise RuntimeError("Remote parameter controller is still pending, invalid state.")
91
+
92
+
93
class ParameterControllerMode(enum.Enum):
    """Selects where a parameter controller is allowed to run."""

    # A local controller is launched and monitored in addition to the remote job.
    AUTO = enum.auto()
    # Only the remote controller job is added; no local controller is started.
    REMOTE_ONLY = enum.auto()
    # TODO(jfarebro): is it possible to get LOCAL_ONLY?
    # We'd need to have a dummy job type as we need to return a JobType?
98
+
99
+
100
def parameter_controller(
    *,
    executable: xm.Executable,
    executor: xm.Executor,
    controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
    controller_name: str = "parameter_controller",
    controller_args: xm.UserArgs | None = None,
    controller_env_vars: Mapping[str, str] | None = None,
) -> Callable[
    [
        Callable[Concatenate[SlurmExperiment, P], T]
        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
    ],
    Callable[P, xm.AuxiliaryUnitJob],
]:
    """Converts a function to a controller which can be added to an experiment.

    Calling the wrapped function returns an `xm.AuxiliaryUnitJob` which runs it
    as an auxiliary unit on the specified executor.

    Args:
        executable: An executable that has a Python entrypoint with all the necessary dependencies.
        executor: The executor to launch the controller job on.
        controller_mode: With `ParameterControllerMode.AUTO` the controller is also
            started locally and monitored alongside the remote job; otherwise only
            the remote job is launched.
        controller_name: Name of the parameter controller job.
        controller_args: Mapping of flag names and values to be used by the XM
            client running inside the parameter controller job.
        controller_env_vars: Mapping of env variable names and values to be passed
            to the parameter controller job.

    Returns:
        A decorator to be applied to the function.
    """

    def decorator(
        f: Callable[Concatenate[SlurmExperiment, P], T]
        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
    ) -> Callable[P, xm.AuxiliaryUnitJob]:
        @functools.wraps(f)
        def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
            # Modify the function to read the experiment from the API so that it can be pickled.

            async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
                experiment_id = aux_unit.experiment.experiment_id

                async def local_controller(*fn_args: P.args, **fn_kwargs: P.kwargs) -> T:
                    result = f(aux_unit.experiment, *fn_args, **fn_kwargs)
                    # Await whatever awaitable `f` hands back; this also covers
                    # partials and wrappers that are not `async def` themselves.
                    if isinstance(result, Awaitable):
                        return await result
                    return result

                # NOTE: this closure must only capture `experiment_id` and `f`
                # (never `aux_unit`) since it is serialized with cloudpickle.
                async def remote_controller(*fn_args: P.args, **fn_kwargs: P.kwargs) -> T:
                    async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
                        result = f(exp, *fn_args, **fn_kwargs)
                        if isinstance(result, Awaitable):
                            return await result
                        return result

                # cloudpickle -> zlib -> URL-safe base64 so the payload can be
                # passed to the remote entrypoint as a plain ASCII argument.
                remote_controller_serialized = base64.urlsafe_b64encode(
                    zlib.compress(
                        cloudpickle.dumps(
                            functools.partial(remote_controller, *args, **kwargs),
                        )
                    )
                )

                parameter_controller_executable = dataclasses.replace(
                    executable,
                    args=xm.merge_args(
                        job_blocks.get_args_for_python_entrypoint(
                            xm.ModuleName("xm_slurm.scripts._cloudpickle")
                        ),
                        xm.SequentialArgs.from_collection({
                            "cloudpickled_fn": remote_controller_serialized.decode("ascii"),
                        }),
                        xm.SequentialArgs.from_collection(controller_args),
                    ),
                    env_vars=dict(controller_env_vars or {}),
                )

                await aux_unit.add(
                    xm.Job(
                        executor=executor,
                        executable=parameter_controller_executable,
                        name=controller_name,
                    )
                )

                # Launch local parameter controller and monitor for when it starts running
                # so we can kill the local controller.
                if controller_mode is ParameterControllerMode.AUTO:
                    aux_unit._create_task(
                        _monitor_parameter_controller(aux_unit, local_controller(*args, **kwargs))
                    )

            return xm.AuxiliaryUnitJob(
                job_generator,
                importance=xm.Importance.HIGH,
                termination_delay_secs=0,  # TODO: add support for termination delay.
            )

        return make_controller

    return decorator
xm_slurm/job_blocks.py CHANGED
@@ -1,6 +1,13 @@
1
+ from typing import Mapping, TypedDict
2
+
1
3
  from xmanager import xm
2
4
 
3
5
 
6
class JobArgs(TypedDict, total=False):
    """Typed dictionary of per-job overrides; both keys are optional (`total=False`)."""

    # User arguments for the job.
    args: xm.UserArgs
    # Environment variable names mapped to their values.
    env_vars: Mapping[str, str]
9
+
10
+
4
11
  def get_args_for_python_entrypoint(
5
12
  entrypoint: xm.ModuleName | xm.CommandList,
6
13
  ) -> xm.SequentialArgs:
xm_slurm/packageables.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import importlib.resources as resources
2
2
  import pathlib
3
3
  import sys
4
- from typing import Mapping, Sequence
4
+ from typing import Literal, Mapping, Sequence
5
5
 
6
6
  import immutabledict
7
7
  from xmanager import xm
@@ -29,8 +29,8 @@ def docker_image(
29
29
  return xm.Packageable(
30
30
  executor_spec=SlurmSpec(),
31
31
  executable_spec=DockerImage(image=image),
32
- args=args,
33
- env_vars=env_vars,
32
+ args=xm.SequentialArgs.from_collection(args),
33
+ env_vars=dict(env_vars),
34
34
  )
35
35
 
36
36
 
@@ -40,7 +40,7 @@ def docker_container(
40
40
  dockerfile: pathlib.Path | None = None,
41
41
  context: pathlib.Path | None = None,
42
42
  target: str | None = None,
43
- ssh: list[str] | None = None,
43
+ ssh: Sequence[str] | Literal[True] | None = None,
44
44
  build_args: Mapping[str, str] = immutabledict.immutabledict(),
45
45
  cache_from: str | Sequence[str] | None = None,
46
46
  labels: Mapping[str, str] = immutabledict.immutabledict(),
@@ -55,7 +55,7 @@ def docker_container(
55
55
  dockerfile: The path to the dockerfile.
56
56
  context: The path to the docker context.
57
57
  target: The docker build target.
58
- ssh: A list of SSH sockets/keys for the docker build step.
58
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
59
59
  build_args: Build arguments to docker.
60
60
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
61
61
  labels: The container labels.
@@ -71,8 +71,12 @@ def docker_container(
71
71
  if dockerfile is None:
72
72
  dockerfile = context / "Dockerfile"
73
73
  dockerfile = dockerfile.resolve()
74
+
74
75
  if ssh is None:
75
76
  ssh = []
77
+ elif ssh is True:
78
+ ssh = ["default"]
79
+
76
80
  if cache_from is None and isinstance(executor_spec, SlurmSpec):
77
81
  cache_from = executor_spec.tag
78
82
  if cache_from is None:
@@ -92,8 +96,8 @@ def docker_container(
92
96
  workdir=workdir,
93
97
  labels=labels,
94
98
  ),
95
- args=args,
96
- env_vars=env_vars,
99
+ args=xm.SequentialArgs.from_collection(args),
100
+ env_vars=dict(env_vars),
97
101
  )
98
102
 
99
103
 
@@ -104,14 +108,18 @@ def python_container(
104
108
  context: pathlib.Path | None = None,
105
109
  requirements: pathlib.Path | None = None,
106
110
  base_image: str = "docker.io/python:{major}.{minor}-slim",
111
+ extra_system_packages: Sequence[str] = (),
112
+ extra_python_packages: Sequence[str] = (),
107
113
  cache_from: str | Sequence[str] | None = None,
108
114
  labels: Mapping[str, str] = immutabledict.immutabledict(),
115
+ ssh: Sequence[str] | Literal[True] | None = None,
109
116
  args: xm.UserArgs | None = None,
110
117
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
111
118
  ) -> xm.Packageable:
112
119
  """Creates a Python container from a base image using pip from a `requirements.txt` file.
113
120
 
114
121
  NOTE: The base image will use the Python version of the current interpreter.
122
+ NOTE: uv is used to install packages from `requirements`.
115
123
 
116
124
  Args:
117
125
  executor_spec: The executor specification, i.e. where the container will be stored.
@@ -119,8 +127,11 @@ def python_container(
119
127
  context: The path to the docker context.
120
128
  requirements: The path to the pip requirements file.
121
129
  base_image: The base image to use. NOTE: The base image must contain the Python runtime.
130
+ extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
131
+ extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
122
132
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
123
133
  labels: The container labels.
134
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
124
135
  args: The user arguments to pass to the executable.
125
136
  env_vars: The environment variables to pass to the executable.
126
137
 
@@ -149,11 +160,11 @@ def python_container(
149
160
  executor_spec=executor_spec,
150
161
  dockerfile=dockerfile,
151
162
  context=context,
163
+ ssh=ssh,
152
164
  build_args={
153
165
  "PIP_REQUIREMENTS": requirements.relative_to(context).as_posix(),
154
- "PYTHON_MAJOR": str(sys.version_info.major),
155
- "PYTHON_MINOR": str(sys.version_info.minor),
156
- "PYTHON_MICRO": str(sys.version_info.micro),
166
+ "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
167
+ "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
157
168
  "BASE_IMAGE": base_image.format_map({
158
169
  "major": sys.version_info.major,
159
170
  "minor": sys.version_info.minor,
@@ -178,6 +189,7 @@ def mamba_container(
178
189
  base_image: str = "gcr.io/distroless/base-debian10",
179
190
  cache_from: str | Sequence[str] | None = None,
180
191
  labels: Mapping[str, str] = immutabledict.immutabledict(),
192
+ ssh: Sequence[str] | Literal[True] | None = None,
181
193
  args: xm.UserArgs | None = None,
182
194
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
183
195
  ) -> xm.Packageable:
@@ -193,6 +205,7 @@ def mamba_container(
193
205
  base_image: The base image to use.
194
206
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
195
207
  labels: The container labels.
208
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
196
209
  args: The user arguments to pass to the executable.
197
210
  env_vars: The environment variables to pass to the executable.
198
211
 
@@ -221,6 +234,7 @@ def mamba_container(
221
234
  executor_spec=executor_spec,
222
235
  dockerfile=dockerfile,
223
236
  context=context,
237
+ ssh=ssh,
224
238
  build_args={
225
239
  "CONDA_ENVIRONMENT": environment.relative_to(context).as_posix(),
226
240
  "BASE_IMAGE": base_image,
@@ -237,26 +251,32 @@ def mamba_container(
237
251
  conda_container = mamba_container
238
252
 
239
253
 
240
- def pdm_container(
254
+ def uv_container(
241
255
  *,
242
256
  executor_spec: xm.ExecutorSpec,
243
257
  entrypoint: xm.ModuleName | xm.CommandList,
244
258
  context: pathlib.Path | None = None,
245
- base_image: str = "docker.io/python:{major}.{minor}-slim",
259
+ base_image: str = "docker.io/python:{major}.{minor}-slim-bookworm",
260
+ extra_system_packages: Sequence[str] = (),
261
+ extra_python_packages: Sequence[str] = (),
246
262
  cache_from: str | Sequence[str] | None = None,
247
263
  labels: Mapping[str, str] = immutabledict.immutabledict(),
264
+ ssh: Sequence[str] | Literal[True] | None = None,
248
265
  args: xm.UserArgs | None = None,
249
266
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
250
267
  ) -> xm.Packageable:
251
- """Creates a Python container from a base image using pdm from a `pdm.lock` file.
268
+ """Creates a Python container from a base image using uv from a `uv.lock` file.
252
269
 
253
270
  Args:
254
271
  executor_spec: The executor specification, i.e. where the container will be stored.
255
272
  entrypoint: The entrypoint to run in the container.
256
273
  context: The path to the docker context.
257
274
  base_image: The base image to use. NOTE: The base image must contain the Python runtime.
275
+ extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
276
+ extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
258
277
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
259
278
  labels: The container labels.
279
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
260
280
  args: The user arguments to pass to the executable.
261
281
  env_vars: The environment variables to pass to the executable.
262
282
 
@@ -268,20 +288,22 @@ def pdm_container(
268
288
  if context is None:
269
289
  context = utils.find_project_root()
270
290
  context = context.resolve()
271
- if not (context / "pdm.lock").exists():
272
- raise ValueError(f"PDM lockfile `{context / 'pdm.lock'}` doesn't exist.")
291
+ if not (context / "pyproject.toml").exists():
292
+ raise ValueError(f"Python project file `{context / 'pyproject.toml'}` doesn't exist.")
293
+ if not (context / "uv.lock").exists():
294
+ raise ValueError(f"UV lock file `{context / 'uv.lock'}` doesn't exist.")
273
295
 
274
296
  with resources.as_file(
275
- resources.files("xm_slurm.templates").joinpath("docker/pdm.Dockerfile")
297
+ resources.files("xm_slurm.templates").joinpath("docker/uv.Dockerfile")
276
298
  ) as dockerfile:
277
299
  return docker_container(
278
300
  executor_spec=executor_spec,
279
301
  dockerfile=dockerfile,
280
302
  context=context,
303
+ ssh=ssh,
281
304
  build_args={
282
- "PYTHON_MAJOR": str(sys.version_info.major),
283
- "PYTHON_MINOR": str(sys.version_info.minor),
284
- "PYTHON_MICRO": str(sys.version_info.micro),
305
+ "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
306
+ "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
285
307
  "BASE_IMAGE": base_image.format_map({
286
308
  "major": sys.version_info.major,
287
309
  "minor": sys.version_info.minor,
@@ -291,7 +313,7 @@ def pdm_container(
291
313
  cache_from=cache_from,
292
314
  labels=labels,
293
315
  # We must specify the workdir manually for apptainer support
294
- workdir=pathlib.Path("/workspace/src"),
316
+ workdir=pathlib.Path("/workspace"),
295
317
  args=args,
296
318
  env_vars=env_vars,
297
319
  )