xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. See the package registry's advisory page for more details.

Files changed (42)
  1. xm_slurm/__init__.py +6 -2
  2. xm_slurm/api.py +301 -34
  3. xm_slurm/batching.py +4 -4
  4. xm_slurm/config.py +105 -55
  5. xm_slurm/constants.py +19 -0
  6. xm_slurm/contrib/__init__.py +0 -0
  7. xm_slurm/contrib/clusters/__init__.py +47 -13
  8. xm_slurm/contrib/clusters/drac.py +34 -16
  9. xm_slurm/dependencies.py +171 -0
  10. xm_slurm/executables.py +34 -22
  11. xm_slurm/execution.py +305 -107
  12. xm_slurm/executors.py +8 -12
  13. xm_slurm/experiment.py +601 -168
  14. xm_slurm/experimental/parameter_controller.py +202 -0
  15. xm_slurm/job_blocks.py +7 -0
  16. xm_slurm/packageables.py +42 -20
  17. xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
  18. xm_slurm/packaging/router.py +3 -1
  19. xm_slurm/packaging/utils.py +9 -81
  20. xm_slurm/resources.py +28 -4
  21. xm_slurm/scripts/_cloudpickle.py +28 -0
  22. xm_slurm/scripts/cli.py +52 -0
  23. xm_slurm/status.py +9 -0
  24. xm_slurm/templates/docker/mamba.Dockerfile +4 -2
  25. xm_slurm/templates/docker/python.Dockerfile +18 -10
  26. xm_slurm/templates/docker/uv.Dockerfile +35 -0
  27. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
  28. xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
  29. xm_slurm/templates/slurm/job.bash.j2 +4 -3
  30. xm_slurm/types.py +23 -0
  31. xm_slurm/utils.py +18 -10
  32. xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
  33. xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
  34. {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
  35. xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
  36. xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
  37. xm_slurm/packaging/docker/__init__.py +0 -75
  38. xm_slurm/packaging/docker/abc.py +0 -112
  39. xm_slurm/packaging/docker/cloud.py +0 -503
  40. xm_slurm/templates/docker/pdm.Dockerfile +0 -31
  41. xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
  42. xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
@@ -1,6 +1,9 @@
1
1
  import base64
2
+ import collections.abc
2
3
  import dataclasses
3
4
  import enum
5
+ import functools
6
+ import hashlib
4
7
  import json
5
8
  import logging
6
9
  import os
@@ -9,35 +12,36 @@ import shlex
9
12
  import shutil
10
13
  import subprocess
11
14
  import tempfile
12
- import typing
13
- from typing import Sequence
15
+ from typing import Hashable, Literal, Mapping, Sequence
14
16
 
17
+ import jinja2 as j2
15
18
  from xmanager import xm
16
19
 
17
20
  from xm_slurm.executables import (
18
21
  Dockerfile,
22
+ DockerImage,
19
23
  ImageURI,
20
24
  RemoteImage,
21
25
  RemoteRepositoryCredentials,
22
26
  )
23
27
  from xm_slurm.executors import SlurmSpec
28
+ from xm_slurm.packaging import registry
24
29
  from xm_slurm.packaging import utils as packaging_utils
25
- from xm_slurm.packaging.docker.abc import (
26
- DockerBakeCommand,
27
- DockerClient,
28
- DockerVersionCommand,
29
- )
30
30
  from xm_slurm.packaging.registry import IndexedContainer
31
31
 
32
+ logger = logging.getLogger(__name__)
33
+
32
34
 
33
- class LocalDockerClient(DockerClient):
34
- """Build Docker images locally."""
35
+ def _hash_digest(obj: Hashable) -> str:
36
+ return hashlib.sha256(repr(obj).encode()).hexdigest()
35
37
 
38
+
39
+ class DockerClient:
36
40
  class Builder(enum.Enum):
37
41
  BUILDKIT = enum.auto()
38
42
  BUILDAH = enum.auto()
39
43
 
40
- def __init__(self):
44
+ def __init__(self) -> None:
41
45
  if "XM_DOCKER_CLIENT" in os.environ:
42
46
  client_call = shlex.split(os.environ["XM_DOCKER_CLIENT"])
43
47
  elif shutil.which("docker"):
@@ -48,12 +52,11 @@ class LocalDockerClient(DockerClient):
48
52
  raise RuntimeError("No Docker client found.")
49
53
  self._client_call = client_call
50
54
 
51
- version_command = DockerVersionCommand()
52
55
  backend_version = packaging_utils.run_command(
53
- xm.merge_args(self._client_call, version_command.to_args()), return_stdout=True
56
+ xm.merge_args(self._client_call, ["buildx", "version"]), return_stdout=True
54
57
  )
55
58
  if backend_version.stdout.startswith("github.com/docker/buildx"):
56
- self._builder = LocalDockerClient.Builder.BUILDKIT
59
+ self._builder = DockerClient.Builder.BUILDKIT
57
60
  else:
58
61
  raise NotImplementedError(f"Unsupported Docker build backend: {backend_version}")
59
62
 
@@ -80,7 +83,7 @@ class LocalDockerClient(DockerClient):
80
83
  return None
81
84
 
82
85
  def _parse_credentials_from_config(
83
- config_path: pathlib.Path
86
+ config_path: pathlib.Path,
84
87
  ) -> RemoteRepositoryCredentials | None:
85
88
  """Parse credentials from the Docker configuration file."""
86
89
  if not config_path.exists():
@@ -116,21 +119,69 @@ class LocalDockerClient(DockerClient):
116
119
 
117
120
  # Attempt to parse credentials from the Docker or Podman configuration
118
121
  match self._builder:
119
- case LocalDockerClient.Builder.BUILDKIT:
122
+ case DockerClient.Builder.BUILDKIT:
120
123
  docker_config_path = (
121
124
  pathlib.Path(os.environ.get("DOCKER_CONFIG", "~/.docker")).expanduser()
122
125
  / "config.json"
123
126
  )
124
127
  return _parse_credentials_from_config(docker_config_path)
125
- case LocalDockerClient.Builder.BUILDAH:
128
+ case DockerClient.Builder.BUILDAH:
126
129
  podman_config_path = (
127
130
  pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "~/.config")).expanduser()
128
131
  / "containers"
129
132
  / "auth.json"
130
133
  )
131
134
  return _parse_credentials_from_config(podman_config_path)
132
- case _:
133
- return None
135
+
136
+ @functools.cached_property
137
+ def _bake_template(self) -> j2.Template:
138
+ template_loader = j2.PackageLoader("xm_slurm", "templates/docker")
139
+ template_env = j2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=False)
140
+
141
+ return template_env.get_template("docker-bake.hcl.j2")
142
+
143
+ def _bake_args(
144
+ self,
145
+ *,
146
+ targets: str | Sequence[str] | None = None,
147
+ builder: str | None = None,
148
+ files: str | os.PathLike[str] | Sequence[os.PathLike[str] | str] | None = None,
149
+ load: bool = False,
150
+ cache: bool = True,
151
+ print: bool = False,
152
+ pull: bool = False,
153
+ push: bool = False,
154
+ metadata_file: str | os.PathLike[str] | None = None,
155
+ progress: Literal["auto", "plain", "tty"] = "auto",
156
+ set: Mapping[str, str] | None = None,
157
+ ) -> xm.SequentialArgs:
158
+ files = files
159
+ if files is None:
160
+ files = []
161
+ if not isinstance(files, collections.abc.Sequence):
162
+ files = [files]
163
+
164
+ targets = targets
165
+ if targets is None:
166
+ targets = []
167
+ elif isinstance(targets, str):
168
+ targets = [targets]
169
+ assert isinstance(targets, collections.abc.Sequence)
170
+
171
+ return xm.merge_args(
172
+ ["buildx", "bake"],
173
+ [f"--progress={progress}"],
174
+ [f"--builder={builder}"] if builder else [],
175
+ [f"--metadata-file={metadata_file}"] if metadata_file else [],
176
+ ["--print"] if print else [],
177
+ ["--push"] if push else [],
178
+ ["--pull"] if pull else [],
179
+ ["--load"] if load else [],
180
+ ["--no-cache"] if not cache else [],
181
+ [f"--file={file}" for file in files],
182
+ [f"--set={key}={value}" for key, value in set.items()] if set else [],
183
+ targets,
184
+ )
134
185
 
135
186
  def bake(
136
187
  self,
@@ -138,39 +189,40 @@ class LocalDockerClient(DockerClient):
138
189
  targets: Sequence[IndexedContainer[xm.Packageable]],
139
190
  ) -> list[IndexedContainer[RemoteImage]]:
140
191
  executors_by_executables = packaging_utils.collect_executors_by_executable(targets)
141
- executors_by_executables = typing.cast(
142
- dict[Dockerfile, list[SlurmSpec]], executors_by_executables
143
- )
192
+ for executable, executors in executors_by_executables.items():
193
+ assert isinstance(
194
+ executable, Dockerfile
195
+ ), "All executables must be Dockerfiles when building Docker images."
196
+ assert all(
197
+ isinstance(executor, SlurmSpec) and executor.tag for executor in executors
198
+ ), "All executors must be SlurmSpecs with tags when building Docker images."
144
199
 
145
200
  with tempfile.TemporaryDirectory() as tempdir:
146
201
  hcl_file = pathlib.Path(tempdir) / "docker-bake.hcl"
147
202
  metadata_file = pathlib.Path(tempdir) / "metadata.json"
148
203
 
149
204
  # Write HCL and bake it
205
+ # TODO(jfarebro): Need a better way to hash the executables
150
206
  hcl = self._bake_template.render(
151
207
  executables=executors_by_executables,
152
- hash=packaging_utils.hash_digest,
208
+ hash=_hash_digest,
153
209
  )
154
210
  hcl_file.write_text(hcl)
155
- logging.debug(hcl)
211
+ logger.debug(hcl)
156
212
 
157
213
  try:
158
- command = DockerBakeCommand(
159
- targets=list(
160
- set(
161
- [
162
- packaging_utils.hash_digest(target.value.executable_spec)
163
- for target in targets
164
- ]
165
- )
214
+ bake_command = xm.merge_args(
215
+ self._client_call,
216
+ self._bake_args(
217
+ targets=list(
218
+ set([_hash_digest(target.value.executable_spec) for target in targets])
219
+ ),
220
+ files=[hcl_file],
221
+ metadata_file=metadata_file,
222
+ pull=True,
223
+ push=True,
166
224
  ),
167
- files=[hcl_file],
168
- metadata_file=metadata_file,
169
- pull=True,
170
- push=True,
171
225
  )
172
-
173
- bake_command = xm.merge_args(self._client_call, command.to_args())
174
226
  packaging_utils.run_command(bake_command.to_list(), tty=True, check=True)
175
227
  except Exception as ex:
176
228
  raise RuntimeError(f"Failed to build Dockerfiles: {ex}") from ex
@@ -183,9 +235,7 @@ class LocalDockerClient(DockerClient):
183
235
  assert isinstance(target.value.executor_spec, SlurmSpec)
184
236
  assert target.value.executor_spec.tag
185
237
 
186
- executable_metadata = metadata[
187
- packaging_utils.hash_digest(target.value.executable_spec)
188
- ]
238
+ executable_metadata = metadata[_hash_digest(target.value.executable_spec)]
189
239
  uri = ImageURI(target.value.executor_spec.tag).with_digest(
190
240
  executable_metadata["containerimage.digest"]
191
241
  )
@@ -204,3 +254,48 @@ class LocalDockerClient(DockerClient):
204
254
  )
205
255
 
206
256
  return images
257
+
258
+
259
+ @functools.cache
260
+ def docker_client() -> DockerClient:
261
+ return DockerClient()
262
+
263
+
264
+ @registry.register(Dockerfile)
265
+ def _(
266
+ targets: Sequence[IndexedContainer[xm.Packageable]],
267
+ ) -> list[IndexedContainer[RemoteImage]]:
268
+ return docker_client().bake(targets=targets)
269
+
270
+
271
+ @registry.register(DockerImage)
272
+ def _(
273
+ targets: Sequence[IndexedContainer[xm.Packageable]],
274
+ ) -> list[IndexedContainer[RemoteImage]]:
275
+ """Build Docker images, this is essentially a passthrough."""
276
+ images = []
277
+ client = docker_client()
278
+ for target in targets:
279
+ assert isinstance(target.value.executable_spec, DockerImage)
280
+ assert isinstance(target.value.executor_spec, SlurmSpec)
281
+ if target.value.executor_spec.tag is not None:
282
+ raise ValueError(
283
+ "Executable `DockerImage` should not be tagged via `SlurmSpec`. "
284
+ "The image URI is provided by the `DockerImage` itself."
285
+ )
286
+
287
+ uri = ImageURI(target.value.executable_spec.image)
288
+ images.append(
289
+ dataclasses.replace(
290
+ target,
291
+ value=RemoteImage( # type: ignore
292
+ image=str(uri),
293
+ workdir=target.value.executable_spec.workdir,
294
+ args=target.value.args,
295
+ env_vars=target.value.env_vars,
296
+ credentials=client.credentials(hostname=uri.domain),
297
+ ),
298
+ )
299
+ )
300
+
301
+ return images
@@ -10,6 +10,8 @@ from xm_slurm.packaging import registry
10
10
 
11
11
  IndexedContainer = registry.IndexedContainer
12
12
 
13
+ logger = logging.getLogger(__name__)
14
+
13
15
 
14
16
  def package(
15
17
  packageables: Sequence[xm.Packageable],
@@ -39,7 +41,7 @@ def package(
39
41
  # TODO(jfarebro): Could make this async as well...?
40
42
  with console.status("[magenta] :package: Packaging executables..."):
41
43
  for executable_spec_type, targets_for_type in targets_by_type.items():
42
- logging.info(f"Packaging {len(targets_for_type)} {executable_spec_type!r} targets.")
44
+ logger.info(f"Packaging {len(targets_for_type)} {executable_spec_type!r} targets.")
43
45
  targets.extend(registry.route(executable_spec_type, targets_for_type))
44
46
 
45
47
  console.print(
@@ -1,7 +1,4 @@
1
1
  import collections
2
- import concurrent.futures
3
- import functools
4
- import hashlib
5
2
  import logging
6
3
  import os
7
4
  import pathlib
@@ -10,8 +7,7 @@ import re
10
7
  import select
11
8
  import shutil
12
9
  import subprocess
13
- import typing
14
- from typing import Callable, Concatenate, Hashable, Literal, ParamSpec, Sequence, TypeVar
10
+ from typing import ParamSpec, Sequence, TypeVar
15
11
 
16
12
  from xmanager import xm
17
13
 
@@ -21,16 +17,7 @@ T = TypeVar("T")
21
17
  P = ParamSpec("P")
22
18
  ReturnT = TypeVar("ReturnT")
23
19
 
24
-
25
- def hash_digest(obj: Hashable) -> str:
26
- # obj_hash = hash(obj)
27
- # unsigned_obj_hash = obj_hash.from_bytes(
28
- # obj_hash.to_bytes((obj_hash.bit_length() + 7) // 8, "big", signed=True),
29
- # "big",
30
- # signed=False,
31
- # )
32
- # return hex(unsigned_obj_hash).removeprefix("0x")
33
- return hashlib.sha256(repr(obj).encode()).hexdigest()
20
+ logger = logging.getLogger(__name__)
34
21
 
35
22
 
36
23
  def collect_executors_by_executable(
@@ -42,19 +29,8 @@ def collect_executors_by_executable(
42
29
  return executors_by_executable
43
30
 
44
31
 
45
- def parallel_map(
46
- f: Callable[Concatenate[T, P], ReturnT],
47
- ) -> Callable[Concatenate[Sequence[T], P], list[ReturnT]]:
48
- @functools.wraps(f)
49
- def decorator(sequence: Sequence[T], *args: P.args, **kwargs: P.kwargs) -> list[ReturnT]:
50
- with concurrent.futures.ThreadPoolExecutor() as executor:
51
- return list(executor.map(lambda x: f(x, *args, **kwargs), sequence))
52
-
53
- return decorator
54
-
55
-
56
32
  # Cursor commands to filter out from the command data stream
57
- cursor_commands_regex = re.compile(
33
+ _CURSOR_ESCAPE_SEQUENCES_REGEX = re.compile(
58
34
  rb"\x1b\[\?25[hl]" # Matches cursor show/hide commands (CSI ?25h and CSI ?25l)
59
35
  rb"|\x1b\[[0-9;]*[Hf]" # Matches cursor position commands (CSI n;mH and CSI n;mf)
60
36
  rb"|\x1b\[s" # Matches cursor save position (CSI s)
@@ -64,54 +40,6 @@ cursor_commands_regex = re.compile(
64
40
  )
65
41
 
66
42
 
67
- @typing.overload
68
- def run_command(
69
- args: Sequence[str] | xm.SequentialArgs,
70
- env: dict[str, str] | None = ...,
71
- tty: bool = ...,
72
- cwd: str | os.PathLike[str] | None = ...,
73
- check: bool = ...,
74
- return_stdout: Literal[False] = False,
75
- return_stderr: Literal[False] = False,
76
- ) -> subprocess.CompletedProcess[None]: ...
77
-
78
-
79
- @typing.overload
80
- def run_command(
81
- args: Sequence[str] | xm.SequentialArgs,
82
- env: dict[str, str] | None = ...,
83
- tty: bool = ...,
84
- cwd: str | os.PathLike[str] | None = ...,
85
- check: bool = ...,
86
- return_stdout: Literal[True] = True,
87
- return_stderr: Literal[False] = False,
88
- ) -> subprocess.CompletedProcess[str]: ...
89
-
90
-
91
- @typing.overload
92
- def run_command(
93
- args: Sequence[str] | xm.SequentialArgs,
94
- env: dict[str, str] | None = ...,
95
- tty: bool = ...,
96
- cwd: str | os.PathLike[str] | None = ...,
97
- check: bool = ...,
98
- return_stdout: Literal[False] = False,
99
- return_stderr: Literal[True] = True,
100
- ) -> subprocess.CompletedProcess[str]: ...
101
-
102
-
103
- @typing.overload
104
- def run_command(
105
- args: Sequence[str] | xm.SequentialArgs,
106
- env: dict[str, str] | None = ...,
107
- tty: bool = ...,
108
- cwd: str | os.PathLike[str] | None = ...,
109
- check: bool = ...,
110
- return_stdout: Literal[True] = True,
111
- return_stderr: Literal[True] = True,
112
- ) -> subprocess.CompletedProcess[str]: ...
113
-
114
-
115
43
  def run_command(
116
44
  args: Sequence[str] | xm.SequentialArgs,
117
45
  env: dict[str, str] | None = None,
@@ -120,7 +48,7 @@ def run_command(
120
48
  check: bool = False,
121
49
  return_stdout: bool = False,
122
50
  return_stderr: bool = False,
123
- ) -> subprocess.CompletedProcess[str] | subprocess.CompletedProcess[None]:
51
+ ) -> subprocess.CompletedProcess[str]:
124
52
  if isinstance(args, xm.SequentialArgs):
125
53
  args = args.to_list()
126
54
  args = list(args)
@@ -134,8 +62,8 @@ def run_command(
134
62
  if executable.name == "docker" and args[1] == "buildx":
135
63
  subprocess_env |= {"DOCKER_CLI_EXPERIMENTAL": "enabled"}
136
64
 
137
- logging.debug(f"env: {subprocess_env}")
138
- logging.debug(f"command: {' '.join(args)}")
65
+ logger.debug(f"env: {subprocess_env}")
66
+ logger.debug(f"command: {' '.join(args)}")
139
67
 
140
68
  stdout_master, stdout_slave = pty.openpty()
141
69
  stderr_master, stderr_slave = pty.openpty()
@@ -171,7 +99,7 @@ def run_command(
171
99
  fds.remove(fd)
172
100
  continue
173
101
 
174
- data = re.sub(cursor_commands_regex, b"", data)
102
+ data = _CURSOR_ESCAPE_SEQUENCES_REGEX.sub(b"", data)
175
103
 
176
104
  if fd == stdout_master:
177
105
  if return_stdout:
@@ -186,8 +114,8 @@ def run_command(
186
114
  else:
187
115
  raise RuntimeError("Unexpected file descriptor")
188
116
 
189
- stdout = stdout_data.decode(errors="replace") if stdout_data else None
190
- stderr = stderr_data.decode(errors="replace") if stderr_data else None
117
+ stdout = stdout_data.decode(errors="replace") if stdout_data else ""
118
+ stderr = stderr_data.decode(errors="replace") if stderr_data else ""
191
119
 
192
120
  retcode = process.poll()
193
121
  assert retcode is not None
xm_slurm/resources.py CHANGED
@@ -36,20 +36,35 @@ class ResourceType(enum.IntEnum):
36
36
 
37
37
 
38
38
  AcceleratorType = set([
39
+ ResourceType.RTX8000,
39
40
  ResourceType.P4,
40
41
  ResourceType.P100,
42
+ ResourceType.P100_16GIB,
41
43
  ResourceType.V100,
44
+ ResourceType.V100_32GIB,
42
45
  ResourceType.A100,
43
46
  ResourceType.A100_80GIB,
47
+ ResourceType.A5000,
44
48
  ResourceType.A6000,
45
49
  ResourceType.H100,
46
- ResourceType.GPU,
47
50
  ])
48
51
 
52
+ assert AcceleratorType | {
53
+ ResourceType.CPU,
54
+ ResourceType.MEMORY,
55
+ ResourceType.DISK,
56
+ ResourceType.GPU,
57
+ } == set(ResourceType.__members__.values()), "Resource types are not exhaustive."
58
+
49
59
 
50
60
  ResourceQuantity = int | float
51
61
 
52
62
 
63
+ class FeatureType(enum.IntEnum):
64
+ NVIDIA_MIG = 1
65
+ NVIDIA_NVLINK = 2
66
+
67
+
53
68
  class JobRequirements:
54
69
  replicas: int
55
70
  location: str | None
@@ -80,8 +95,10 @@ class JobRequirements:
80
95
  case str():
81
96
  resource = resource_name
82
97
 
83
- if resource in AcceleratorType or (
84
- isinstance(resource, str) and resource.startswith("gpu")
98
+ if (
99
+ resource in AcceleratorType
100
+ or resource == ResourceType.GPU
101
+ or (isinstance(resource, str) and resource.startswith("gpu"))
85
102
  ):
86
103
  if self.accelerator is not None:
87
104
  raise ValueError("Accelerator already set.")
@@ -92,6 +109,8 @@ class JobRequirements:
92
109
  self.task_requirements[resource] = value
93
110
 
94
111
  def to_directives(self) -> list[str]:
112
+ if self.cluster is None:
113
+ raise ValueError("Cannnot derive Slurm directives for requirements without a cluster.")
95
114
  directives = []
96
115
 
97
116
  for resource, value in self.task_requirements.items():
@@ -112,7 +131,12 @@ class JobRequirements:
112
131
  directives.append(f"--gpus-per-task={value}")
113
132
  case ResourceType() if resource in AcceleratorType:
114
133
  assert isinstance(value, int), "Accelerator must be an integer"
115
- directives.append(f"--gpus-per-task={resource.name.lower()}:{value}")
134
+ resource_type = self.cluster.resources.get(resource, None)
135
+ if resource_type is None:
136
+ raise ValueError(
137
+ f"Cluster {self.cluster.name} does not map resource type {resource!r}."
138
+ )
139
+ directives.append(f"--gpus-per-task={resource_type}:{value}")
116
140
  case str():
117
141
  directives.append(f"--gres={resource}:{value}")
118
142
 
@@ -0,0 +1,28 @@
1
+ import base64
2
+ import logging
3
+ import zlib
4
+
5
+ import cloudpickle
6
+ from absl import app, flags
7
+ from xmanager import xm
8
+
9
+ CLOUDPICKLED_FN = flags.DEFINE_string(
10
+ "cloudpickled_fn", None, "Base64 encoded cloudpickled function", required=True
11
+ )
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @xm.run_in_asyncio_loop
17
+ async def main(argv):
18
+ del argv
19
+
20
+ logger.info("Loading cloudpickled function...")
21
+ cloudpickled_fn = zlib.decompress(base64.urlsafe_b64decode(CLOUDPICKLED_FN.value))
22
+ function = cloudpickle.loads(cloudpickled_fn)
23
+ logger.info("Running cloudpickled function...")
24
+ await function()
25
+
26
+
27
+ if __name__ == "__main__":
28
+ app.run(main)
@@ -0,0 +1,52 @@
1
+ import argparse
2
+
3
+ from xmanager import xm
4
+
5
+ import xm_slurm
6
+ from xm_slurm.console import console
7
+
8
+
9
+ async def logs(
10
+ experiment_id: int,
11
+ wid: int,
12
+ *,
13
+ follow: bool = True,
14
+ num_lines: int = 10,
15
+ block_size: int = 1024,
16
+ ):
17
+ wu = xm_slurm.get_experiment(experiment_id).work_units()[wid]
18
+ async for log in wu.logs(num_lines=num_lines, block_size=block_size, wait=True, follow=follow):
19
+ console.print(log, end="\n")
20
+
21
+
22
+ @xm.run_in_asyncio_loop
23
+ async def main():
24
+ parser = argparse.ArgumentParser(description="XManager.")
25
+ subparsers = parser.add_subparsers(dest="subcommand", required=True)
26
+
27
+ logs_parser = subparsers.add_parser("logs", help="Display logs for a specific experiment.")
28
+ logs_parser.add_argument("xid", type=int, help="Experiment ID.")
29
+ logs_parser.add_argument("wid", type=int, help="Work Unit ID.")
30
+ logs_parser.add_argument(
31
+ "-n",
32
+ "--n-lines",
33
+ type=int,
34
+ default=50,
35
+ help="Number of lines to display from the end of the log file.",
36
+ )
37
+ logs_parser.add_argument(
38
+ "-f",
39
+ "--follow",
40
+ default=True,
41
+ action="store_true",
42
+ help="Follow the log file as it is updated.",
43
+ )
44
+
45
+ args = parser.parse_args()
46
+ match args.subcommand:
47
+ case "logs":
48
+ await logs(args.xid, args.wid, follow=args.follow, num_lines=args.n_lines)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ main() # type: ignore
xm_slurm/status.py CHANGED
@@ -106,6 +106,15 @@ SlurmFailedJobStates = set([
106
106
  ])
107
107
  SlurmCancelledJobStates = set([SlurmJobState.CANCELLED])
108
108
 
109
+ assert (
110
+ SlurmPendingJobStates
111
+ | SlurmRunningJobStates
112
+ | SlurmActiveJobStates
113
+ | SlurmCompletedJobStates
114
+ | SlurmFailedJobStates
115
+ | SlurmCancelledJobStates
116
+ ) == set(SlurmJobState.__members__.values()), "Slurm job states are not exhaustive."
117
+
109
118
 
110
119
  class SlurmWorkUnitStatusEnum(enum.IntEnum):
111
120
  """Status of a local experiment job."""
@@ -1,7 +1,7 @@
1
1
  # syntax=docker/dockerfile:1.4
2
2
  ARG BASE_IMAGE=gcr.io/distroless/base-debian10
3
3
 
4
- FROM docker.io/mambaorg/micromamba:jammy as mamba
4
+ FROM docker.io/mambaorg/micromamba:bookworm-slim as mamba
5
5
  ARG CONDA_ENVIRONMENT=environment.yml
6
6
 
7
7
  USER root
@@ -9,7 +9,9 @@ USER root
9
9
  COPY $CONDA_ENVIRONMENT /tmp/
10
10
 
11
11
  # Setup mamba environment
12
- RUN --mount=type=cache,target=/opt/conda/pkgs --mount=type=cache,target=/root/.cache/pip \
12
+ RUN --mount=type=cache,target=/opt/conda/pkgs \
13
+ --mount=type=cache,target=/root/.cache/pip \
14
+ --mount=type=ssh \
13
15
  micromamba create --yes --always-copy --no-pyc --prefix /opt/env --file /tmp/environment.yml
14
16
 
15
17
  RUN find /opt/env/ -follow -type f -name '*.a' -delete && \
@@ -1,24 +1,32 @@
1
1
  # syntax=docker/dockerfile:1.4
2
- ARG BASE_IMAGE=docker.io/python:3.10-slim
2
+ ARG BASE_IMAGE=docker.io/python:3.10-slim-bookworm
3
3
  FROM $BASE_IMAGE as builder
4
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
5
+
6
+ ARG EXTRA_SYSTEM_PACKAGES=""
7
+ ARG EXTRA_PYTHON_PACKAGES=""
8
+
9
+ ENV UV_PYTHON_DOWNLOADS=0
10
+ ENV UV_COMPILE_BYTECODE=1
11
+ ENV UV_LINK_MODE=copy
12
+
13
+ WORKDIR /workspace
4
14
 
5
15
  RUN apt-get update \
6
16
  && apt-get install -y --no-install-recommends \
7
- git \
17
+ git $EXTRA_SYSTEM_PACKAGES \
8
18
  && rm -rf /var/lib/apt/lists/*
9
19
 
10
20
  # Install and update necesarry global Python packages
11
- RUN pip install -U pip setuptools wheel pysocks
21
+ RUN uv pip install --system pysocks $EXTRA_PYTHON_PACKAGES
12
22
 
13
23
  ARG PIP_REQUIREMENTS=requirements.txt
14
24
 
15
- RUN python -m venv --copies --upgrade --upgrade-deps --system-site-packages /venv
16
- COPY $PIP_REQUIREMENTS /tmp/requirements.txt
17
- RUN --mount=type=cache,target=/root/.cache/pip \
18
- PIP_CACHE_DIR=/root/.cache/pip /venv/bin/pip install -r /tmp/requirements.txt \
19
- && rm -rf /tmp/requirements.txt
25
+ RUN --mount=type=cache,target=/root/.cache/uv \
26
+ --mount=type=bind,source=$PIP_REQUIREMENTS,target=requirements.txt \
27
+ --mount=type=ssh \
28
+ uv pip install --system --requirement requirements.txt
20
29
 
21
30
  COPY --link . /workspace
22
- WORKDIR /workspace
23
31
 
24
- ENTRYPOINT [ "/venv/bin/python" ]
32
+ ENTRYPOINT [ "python" ]