xmanager-slurm 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic.
- xm_slurm/__init__.py +4 -2
- xm_slurm/api.py +301 -34
- xm_slurm/batching.py +4 -4
- xm_slurm/config.py +99 -54
- xm_slurm/constants.py +15 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +22 -13
- xm_slurm/contrib/clusters/drac.py +34 -16
- xm_slurm/executables.py +19 -7
- xm_slurm/execution.py +86 -38
- xm_slurm/experiment.py +273 -131
- xm_slurm/experimental/parameter_controller.py +200 -0
- xm_slurm/job_blocks.py +7 -0
- xm_slurm/packageables.py +45 -18
- xm_slurm/packaging/docker/__init__.py +5 -11
- xm_slurm/packaging/docker/local.py +13 -12
- xm_slurm/packaging/utils.py +7 -55
- xm_slurm/resources.py +28 -4
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/status.py +9 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +7 -0
- xm_slurm/templates/docker/mamba.Dockerfile +3 -1
- xm_slurm/templates/docker/python.Dockerfile +18 -10
- xm_slurm/templates/docker/uv.Dockerfile +35 -0
- xm_slurm/utils.py +18 -10
- xmanager_slurm-0.4.0.dist-info/METADATA +26 -0
- xmanager_slurm-0.4.0.dist-info/RECORD +42 -0
- {xmanager_slurm-0.3.1.dist-info → xmanager_slurm-0.4.0.dist-info}/WHEEL +1 -1
- xmanager_slurm-0.4.0.dist-info/licenses/LICENSE.md +227 -0
- xm_slurm/packaging/docker/cloud.py +0 -503
- xm_slurm/templates/docker/pdm.Dockerfile +0 -31
- xmanager_slurm-0.3.1.dist-info/METADATA +0 -25
- xmanager_slurm-0.3.1.dist-info/RECORD +0 -38
xm_slurm/contrib/clusters/__init__.py
CHANGED

@@ -19,18 +19,22 @@ def mila(
     if mounts is None:
         mounts = {
             "/network/scratch/${USER:0:1}/$USER": "/scratch",
-
+            # TODO: move these somewhere common to all cluster configs.
+            "/home/mila/${USER:0:1}/$USER/.local/state/xm-slurm": "/xm-slurm-state",
+            "/home/mila/${USER:0:1}/$USER/.ssh": "/home/mila/${USER:0:1}/$USER/.ssh",
         }
 
     return config.SlurmClusterConfig(
         name="mila",
-
-
-
-
-
+        ssh=config.SlurmSSHConfig(
+            user=user,
+            host="login.server.mila.quebec",
+            host_public_key=config.PublicKey(
+                "ssh-ed25519",
+                "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
+            ),
+            port=2222,
         ),
-        port=2222,
         runtime=config.ContainerRuntime.SINGULARITY,
         partition=partition,
         prolog="module load singularity",
@@ -39,14 +43,19 @@ def mila(
             "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
             "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
             "SCRATCH": "/scratch",
-
+            # TODO: move this somewhere common to all cluster configs.
+            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
         },
         mounts=mounts,
         resources={
-
-
-
-
-
+            resources.ResourceType.RTX8000: "rtx8000",
+            resources.ResourceType.V100: "v100",
+            resources.ResourceType.A100: "a100",
+            resources.ResourceType.A100_80GIB: "a100l",
+            resources.ResourceType.A6000: "a6000",
+        },
+        features={
+            resources.FeatureType.NVIDIA_MIG: "mig",
+            resources.FeatureType.NVIDIA_NVLINK: "nvlink",
         },
     )
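The Mila config now nests its SSH settings in config.SlurmSSHConfig and exposes a features mapping alongside resources. A hedged sketch of consuming it follows; the user and partition keyword arguments and the plain attribute access on the returned config are assumptions inferred from the diff, not a documented API.

from xm_slurm import resources
from xm_slurm.contrib import clusters

# Assumed signature: the diff's body references `user` and `partition`,
# so the helper presumably accepts them as keyword arguments.
mila = clusters.mila(user="alice", partition="long")

print(mila.ssh.host)                                       # "login.server.mila.quebec"
print(mila.resources[resources.ResourceType.A100_80GIB])   # "a100l"
print(mila.features[resources.FeatureType.NVIDIA_NVLINK])  # "nvlink"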
xm_slurm/contrib/clusters/drac.py
CHANGED

@@ -2,7 +2,7 @@ import os
 from typing import Literal
 
 from xm_slurm import config
-from xm_slurm.resources import ResourceType
+from xm_slurm.resources import FeatureType, ResourceType
 
 __all__ = ["narval", "beluga", "cedar", "graham"]
 
@@ -18,18 +18,26 @@ def _drac_cluster(
     modules: list[str] | None = None,
     proxy: Literal["submission-host"] | str | None = None,
     mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
-    resources: dict[
+    resources: dict[ResourceType, str] | None = None,
+    features: dict[FeatureType, str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cluster."""
     if mounts is None:
-        mounts = {
+        mounts = {
+            "/scratch/$USER": "/scratch",
+            # TODO: move these somewhere common to all cluster configs.
+            "/home/$USER/.ssh": "/home/$USER/.ssh",
+            "/home/$USER/.local/state/xm-slurm": "/xm-slurm-state",
+        }
 
     return config.SlurmClusterConfig(
         name=name,
-
-
-
-
+        ssh=config.SlurmSSHConfig(
+            user=user,
+            host=host,
+            host_public_key=host_public_key,
+            port=port,
+        ),
         account=account,
         proxy=proxy,
         runtime=config.ContainerRuntime.APPTAINER,
@@ -40,9 +48,12 @@ def _drac_cluster(
             "APPTAINER_LOCALCACHEDIR": "$SLURM_TMPDIR",
             "_XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
             "SCRATCH": "/scratch",
+            # TODO: move this somewhere common to all cluster configs.
+            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
         },
         mounts=mounts,
         resources=resources or {},
+        features=features or {},
     )
 
 
@@ -70,7 +81,11 @@ def narval(
         mounts=mounts,
         proxy=proxy,
         modules=modules,
-        resources={"a100"
+        resources={ResourceType.A100: "a100"},
+        features={
+            FeatureType.NVIDIA_MIG: "a100mig",
+            FeatureType.NVIDIA_NVLINK: "nvlink",
+        },
     )
 
 
@@ -98,7 +113,10 @@ def beluga(
         mounts=mounts,
         proxy=proxy,
         modules=modules,
-        resources={"tesla_v100-sxm2-16gb"
+        resources={ResourceType.V100: "tesla_v100-sxm2-16gb"},
+        features={
+            FeatureType.NVIDIA_NVLINK: "nvlink",
+        },
     )
 
 
@@ -120,9 +138,9 @@ def cedar(
         account=account,
         mounts=mounts,
         resources={
-            "v100l"
-            "p100"
-            "p100l"
+            ResourceType.V100_32GIB: "v100l",
+            ResourceType.P100: "p100",
+            ResourceType.P100_16GIB: "p100l",
         },
     )
 
@@ -147,10 +165,10 @@ def graham(
         mounts=mounts,
         proxy=proxy,
         resources={
-            "v100"
-            "p100"
-            "a100"
-            "a5000"
+            ResourceType.V100: "v100",
+            ResourceType.P100: "p100",
+            ResourceType.A100: "a100",
+            ResourceType.A5000: "a5000",
         },
     )
 
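Each DRAC helper now keys resources by ResourceType and adds a features mapping, so abstract accelerator requests can be resolved to that cluster's Slurm GRES and constraint names. The sketch below illustrates the kind of translation these per-cluster mappings enable; it is not xm-slurm's internal scheduling code, and the attribute access on the returned config is an assumption.

from xm_slurm.resources import FeatureType, ResourceType


def gpu_flags(cluster, gpu: ResourceType, count: int = 1, feature: FeatureType | None = None) -> list[str]:
    """Translate an abstract GPU request into Slurm's standard --gres/--constraint flags."""
    flags = [f"--gres=gpu:{cluster.resources[gpu]}:{count}"]
    if feature is not None:
        flags.append(f"--constraint={cluster.features[feature]}")
    return flags


# e.g. on Narval: gpu_flags(cluster, ResourceType.A100, feature=FeatureType.NVIDIA_NVLINK)
# -> ["--gres=gpu:a100:1", "--constraint=nvlink"]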
xm_slurm/executables.py
CHANGED
@@ -1,10 +1,11 @@
 import dataclasses
 import pathlib
-import re
 from typing import Mapping, NamedTuple, Sequence
 
 from xmanager import xm
 
+from xm_slurm import constants
+
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class Dockerfile(xm.ExecutableSpec):
@@ -14,6 +15,7 @@ class Dockerfile(xm.ExecutableSpec):
         dockerfile: The path to the Dockerfile.
         context: The path to the Docker context.
         target: The Docker build target.
+        ssh: A list of docker SSH sockets/keys.
         build_args: Build arguments to docker.
         cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
         workdir: The working directory in container.
@@ -29,6 +31,9 @@ class Dockerfile(xm.ExecutableSpec):
     # Docker build target
     target: str | None = None
 
+    # SSH sockets/keys for the docker build step.
+    ssh: Sequence[str] = dataclasses.field(default_factory=list)
+
     # Build arguments to docker
     build_args: Mapping[str, str] = dataclasses.field(default_factory=dict)
 
@@ -56,6 +61,7 @@ class Dockerfile(xm.ExecutableSpec):
             self.dockerfile,
             self.context,
             self.target,
+            tuple(sorted(self.ssh)),
             tuple(sorted(self.build_args.items())),
             tuple(sorted(self.cache_from)),
             self.workdir,
@@ -87,11 +93,6 @@ class DockerImage(xm.ExecutableSpec):
         return hash((self.image, self.workdir))
 
 
-_IMAGE_URI_REGEX = re.compile(
-    r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
-)
-
-
 @dataclasses.dataclass
 class ImageURI:
     image: dataclasses.InitVar[str]
@@ -103,7 +104,7 @@ class ImageURI:
     digest: str | None = dataclasses.field(init=False, default=None)
 
     def __post_init__(self, image: str):
-        match =
+        match = constants.IMAGE_URI_REGEX.match(image)
         if not match:
             raise ValueError(f"Invalid OCI image URI: {image}")
         groups = {k: v for k, v in match.groupdict().items() if v is not None}
@@ -199,3 +200,14 @@ class RemoteImage(xm.Executable):
     @property
     def name(self) -> str:
        return str(self.image)
+
+    def __hash__(self) -> int:
+        return hash(
+            (
+                self.image,
+                self.workdir,
+                tuple(sorted(self.args.to_list())),
+                tuple(sorted(self.env_vars.items())),
+                self.credentials,
+            ),
+        )
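The Dockerfile spec gains an ssh field that is also folded into its hash, so builds that forward SSH sockets or keys are cached and deduplicated correctly. A hedged sketch of declaring it follows; the field names come from the diff, while the path values and the BuildKit-style "default" agent socket are illustrative assumptions.

import pathlib

from xm_slurm import executables

# Hypothetical spec: forward the local SSH agent to the build, mirroring
# `docker build --ssh default`, e.g. for cloning private dependencies.
spec = executables.Dockerfile(
    dockerfile=pathlib.Path("Dockerfile"),
    context=pathlib.Path("."),
    ssh=["default"],
    build_args={"PYTHON_VERSION": "3.11"},
)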
xm_slurm/execution.py
CHANGED
@@ -13,12 +13,14 @@ from typing import Any, Mapping, Sequence
 import asyncssh
 import backoff
 import jinja2 as j2
+import more_itertools as mit
 from asyncssh.auth import KbdIntPrompts, KbdIntResponse
 from asyncssh.misc import MaybeAwait
 from xmanager import xm
 
 from xm_slurm import batching, config, executors, status
 from xm_slurm.console import console
+from xm_slurm.job_blocks import JobArgs
 
 SlurmClusterConfig = config.SlurmClusterConfig
 ContainerRuntime = config.ContainerRuntime
@@ -163,8 +165,9 @@ class _BatchedSlurmHandle:
 class SlurmHandle(_BatchedSlurmHandle):
     """A handle for referring to the launched container."""
 
-
+    ssh: config.SlurmSSHConfig
     job_id: str
+    job_name: str  # XManager job name associated with this handle
 
     def __post_init__(self):
         if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
@@ -180,10 +183,10 @@ class SlurmHandle(_BatchedSlurmHandle):
         return await self.get_state()
 
     async def stop(self) -> None:
-        await self._batched_cancel(self.
+        await self._batched_cancel(self.ssh.connection_options, self.job_id)
 
     async def get_state(self) -> status.SlurmJobState:
-        return await self._batched_get_state(self.
+        return await self._batched_get_state(self.ssh.connection_options, self.job_id)
 
 
 @functools.cache
@@ -208,7 +211,7 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
         case ContainerRuntime.PODMAN:
             runtime_template = template_env.get_template("runtimes/podman.bash.j2")
         case _:
-            raise NotImplementedError
+            raise NotImplementedError(f"Container runtime {container_runtime} is not implemented.")
     # Update our global env with the runtime template's exported globals
     template_env.globals.update(runtime_template.module.__dict__)
 
@@ -216,16 +219,17 @@
 
 
 class Client:
-    def __init__(self):
-        self._connections
+    def __init__(self) -> None:
+        self._connections = dict[
             asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
-        ]
+        ]()
         self._connection_lock = asyncio.Lock()
 
     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
     async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
         # Make sure the xm-slurm state directory exists
-
+        async with conn.start_sftp_client() as sftp_client:
+            await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)
 
     async def connection(
         self,
@@ -238,7 +242,24 @@ class Client:
                     await self._setup_remote_connection(conn)
                     self._connections[options] = conn
                 except asyncssh.misc.PermissionDenied as ex:
-                    raise
+                    raise SlurmExecutionError(
+                        f"Permission denied connecting to {options.host}"
+                    ) from ex
+                except asyncssh.misc.ConnectionLost as ex:
+                    raise SlurmExecutionError(f"Connection lost to host {options.host}") from ex
+                except asyncssh.misc.HostKeyNotVerifiable as ex:
+                    raise SlurmExecutionError(
+                        f"Cannot verify the public key for host {options.host}"
+                    ) from ex
+                except asyncssh.misc.KeyExchangeFailed as ex:
+                    raise SlurmExecutionError(
+                        f"Failed to exchange keys with host {options.host}"
+                    ) from ex
+                except asyncssh.Error as ex:
+                    raise SlurmExecutionError(
+                        f"SSH connection error when connecting to {options.host}"
+                    ) from ex
+
         return self._connections[options]
 
     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
@@ -340,8 +361,8 @@ class Client:
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.
-        args: Mapping[str,
+        job: xm.JobGroup,
+        args: Mapping[str, JobArgs] | None,
         experiment_id: int,
         identity: str | None = ...,
     ) -> SlurmHandle: ...
@@ -351,21 +372,24 @@ class Client:
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.Job
-        args: Sequence[
+        job: xm.Job,
+        args: Sequence[JobArgs],
         experiment_id: int,
         identity: str | None = ...,
-    ) ->
+    ) -> list[SlurmHandle]: ...
 
+    @typing.overload
     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.Job
-        args:
+        job: xm.Job,
+        args: JobArgs,
         experiment_id: int,
-        identity: str | None =
-    ) -> SlurmHandle
+        identity: str | None = ...,
+    ) -> SlurmHandle: ...
+
+    async def launch(self, *, cluster, job, args, experiment_id, identity=None):
         # Construct template
         template = await self.template(
             job=job,
@@ -379,7 +403,7 @@ class Client:
         # Hash submission script
         template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]
 
-        conn = await self.connection(cluster.
+        conn = await self.connection(cluster.ssh.connection_options)
         async with conn.start_sftp_client() as sftp:
             # Write the submission script to the cluster
             # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
@@ -392,9 +416,9 @@ class Client:
 
         # Construct and run command on the cluster
         command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
-        result = await self.run(cluster.
+        result = await self.run(cluster.ssh.connection_options, command)
         if result.returncode != 0:
-            raise RuntimeError(f"Failed to schedule job on {cluster.host}: {result.stderr}")
+            raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
 
         assert isinstance(result.stdout, str)
         slurm_job_id, *_ = result.stdout.split(",")
@@ -405,26 +429,40 @@ class Client:
             f"[cyan]{cluster.name}[/cyan] "
         )
 
+        # If we scheduled an array job make sure to return a list of handles
+        # The indexing is always sequential in 0, 1, ..., n - 1
        if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
+            assert job.name is not None
            return [
                SlurmHandle(
-
+                    ssh=cluster.ssh,
                    job_id=f"{slurm_job_id}_{array_index}",
+                    job_name=job.name,
                )
                for array_index in range(len(args))
            ]
-
-
-
-
-
+        elif isinstance(job, xm.Job):
+            assert job.name is not None
+            return SlurmHandle(ssh=cluster.ssh, job_id=slurm_job_id, job_name=job.name)
+        elif isinstance(job, xm.JobGroup):
+            # TODO: make this work for actual job groups.
+            job = mit.one(job.jobs.values())
+            assert isinstance(job, xm.Job)
+            assert job.name is not None
+            return SlurmHandle(ssh=cluster.ssh, job_id=slurm_job_id, job_name=job.name)
+        else:
+            raise ValueError(f"Unsupported job type: {type(job)}")
+
+    def __del__(self):
+        for conn in self._connections.values():
+            conn.close()
 
 
 @typing.overload
 async def launch(
     *,
-    job: xm.
-    args: Mapping[str,
+    job: xm.JobGroup,
+    args: Mapping[str, JobArgs],
     experiment_id: int,
     identity: str | None = ...,
 ) -> SlurmHandle: ...
@@ -433,22 +471,32 @@ async def launch(
 @typing.overload
 async def launch(
     *,
-    job: xm.Job
-    args: Sequence[
+    job: xm.Job,
+    args: Sequence[JobArgs],
     experiment_id: int,
     identity: str | None = ...,
-) ->
+) -> list[SlurmHandle]: ...
+
+
+@typing.overload
+async def launch(
+    *,
+    job: xm.Job,
+    args: JobArgs,
+    experiment_id: int,
+    identity: str | None = ...,
+) -> SlurmHandle: ...
 
 
 async def launch(
     *,
     job: xm.Job | xm.JobGroup,
-    args: Mapping[str,
+    args: Mapping[str, JobArgs] | Sequence[JobArgs] | JobArgs,
     experiment_id: int,
     identity: str | None = None,
-) -> SlurmHandle |
+) -> SlurmHandle | list[SlurmHandle]:
     match job:
-        case xm.Job():
+        case xm.Job() as job:
            if not isinstance(job.executor, executors.Slurm):
                raise ValueError("Job must have a Slurm executor")
            job_requirements = job.executor.requirements
@@ -459,7 +507,7 @@ async def launch(
            return await get_client().launch(
                cluster=cluster,
                job=job,
-                args=args,
+                args=typing.cast(JobArgs | Sequence[JobArgs], args),
                experiment_id=experiment_id,
                identity=identity,
            )
@@ -482,8 +530,8 @@
 
            return await get_client().launch(
                cluster=job_group_clusters.pop(),
-                job=
-                args=args,
+                job=job_group,
+                args=typing.cast(Mapping[str, JobArgs], args),
                experiment_id=experiment_id,
                identity=identity,
            )
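SlurmHandle now records the SSH config and the XManager job name next to the Slurm job id, and array launches return one handle per array index. A minimal sketch of driving those handles follows; only get_state(), stop(), job_id, and job_name are taken from the diff, and the surrounding scaffolding is assumed.

import asyncio


async def report_and_cancel(handles) -> None:
    # `handles` is assumed to be the list[SlurmHandle] returned by an
    # array launch, i.e. one handle per Slurm array index.
    states = await asyncio.gather(*(handle.get_state() for handle in handles))
    for handle, state in zip(handles, states):
        print(f"{handle.job_name} ({handle.job_id}): {state}")
    # Cancel every array element through its SSH-backed handle.
    await asyncio.gather(*(handle.stop() for handle in handles))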