xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic.
- xm_slurm/__init__.py +6 -2
- xm_slurm/api.py +301 -34
- xm_slurm/batching.py +4 -4
- xm_slurm/config.py +105 -55
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +47 -13
- xm_slurm/contrib/clusters/drac.py +34 -16
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +34 -22
- xm_slurm/execution.py +305 -107
- xm_slurm/executors.py +8 -12
- xm_slurm/experiment.py +601 -168
- xm_slurm/experimental/parameter_controller.py +202 -0
- xm_slurm/job_blocks.py +7 -0
- xm_slurm/packageables.py +42 -20
- xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
- xm_slurm/packaging/router.py +3 -1
- xm_slurm/packaging/utils.py +9 -81
- xm_slurm/resources.py +28 -4
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +52 -0
- xm_slurm/status.py +9 -0
- xm_slurm/templates/docker/mamba.Dockerfile +4 -2
- xm_slurm/templates/docker/python.Dockerfile +18 -10
- xm_slurm/templates/docker/uv.Dockerfile +35 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
- xm_slurm/templates/slurm/job.bash.j2 +4 -3
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +18 -10
- xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
- xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
- {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
- xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
- xm_slurm/packaging/docker/__init__.py +0 -75
- xm_slurm/packaging/docker/abc.py +0 -112
- xm_slurm/packaging/docker/cloud.py +0 -503
- xm_slurm/templates/docker/pdm.Dockerfile +0 -31
- xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
- xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
xm_slurm/execution.py
CHANGED
@@ -5,24 +5,29 @@ import functools
 import hashlib
 import logging
 import operator
-import re
 import shlex
-import typing
-from typing import Any, Mapping, Sequence
+import typing as tp

 import asyncssh
 import backoff
 import jinja2 as j2
+import more_itertools as mit
 from asyncssh.auth import KbdIntPrompts, KbdIntResponse
 from asyncssh.misc import MaybeAwait
+from rich.console import ConsoleRenderable
+from rich.rule import Rule
 from xmanager import xm

-from xm_slurm import batching, config, executors, status
+from xm_slurm import batching, config, constants, dependencies, executors, status
 from xm_slurm.console import console
+from xm_slurm.job_blocks import JobArgs
+from xm_slurm.types import Descriptor

 SlurmClusterConfig = config.SlurmClusterConfig
 ContainerRuntime = config.ContainerRuntime

+logger = logging.getLogger(__name__)
+
 """
 === Runtime Configurations ===
 With RunC:
@@ -41,11 +46,6 @@ With Singularity / Apptainer:
     apptainer run --compat <digest>
 """

-"""
-#SBATCH --error=/dev/null
-#SBATCH --output=/dev/null
-"""
-
 _POLL_INTERVAL = 30.0
 _BATCHED_BATCH_SIZE = 16
 _BATCHED_TIMEOUT = 0.2
@@ -67,12 +67,79 @@ class NoKBAuthSSHClient(asyncssh.SSHClient):
         return []


-
-
-
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmJob:
+    job_id: str
+
+    @property
+    def is_array_job(self) -> bool:
+        return isinstance(self, SlurmArrayJob)
+
+    @property
+    def is_heterogeneous_job(self) -> bool:
+        return isinstance(self, SlurmHeterogeneousJob)
+
+    def __hash__(self) -> int:
+        return hash((type(self), self.job_id))
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmArrayJob(SlurmJob):
+    array_job_id: str
+    array_task_id: str
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmHeterogeneousJob(SlurmJob):
+    het_job_id: str
+    het_component_id: str
+
+
+SlurmJobT = tp.TypeVar("SlurmJobT", bound=SlurmJob, covariant=True)
+
+
+class SlurmJobDescriptor(Descriptor[SlurmJobT, str]):
+    def __set_name__(self, owner: type, name: str):
+        del owner
+        self.job = f"_{name}"
+
+    def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> SlurmJobT:
+        del owner
+        return getattr(instance, self.job)
+
+    def __set__(self, instance: object, value: str):
+        _setattr = object.__setattr__ if not hasattr(instance, self.job) else setattr
+
+        match = constants.SLURM_JOB_ID_REGEX.match(value)
+        if match is None:
+            raise ValueError(f"Invalid Slurm job ID: {value}")
+        groups = match.groupdict()
+
+        job_id = groups["jobid"]
+        if array_task_id := groups.get("arraytaskid", None):
+            _setattr(
+                instance,
+                self.job,
+                SlurmArrayJob(job_id=value, array_job_id=job_id, array_task_id=array_task_id),
+            )
+        elif het_component_id := groups.get("componentid", None):
+            _setattr(
+                instance,
+                self.job,
+                SlurmHeterogeneousJob(
+                    job_id=value, het_job_id=job_id, het_component_id=het_component_id
+                ),
+            )
+        else:
+            _setattr(instance, self.job, SlurmJob(job_id=value))
+
+
+def _group_by_ssh_configs(
+    ssh_configs: tp.Sequence[config.SlurmSSHConfig], slurm_jobs: tp.Sequence[SlurmJob]
+) -> dict[config.SlurmSSHConfig, list[SlurmJob]]:
     jobs_by_cluster = collections.defaultdict(list)
-    for
-        jobs_by_cluster[
+    for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
+        jobs_by_cluster[ssh_config].append(slurm_job)
     return jobs_by_cluster


@@ -83,18 +150,20 @@ class _BatchedSlurmHandle:
         batch_timeout=_BATCHED_TIMEOUT,
     )
     @staticmethod
+    @backoff.on_exception(backoff.expo, SlurmExecutionError, max_tries=5, max_time=60.0)
     async def _batched_get_state(
-
-
+        ssh_configs: tp.Sequence[config.SlurmSSHConfig],
+        slurm_jobs: tp.Sequence[SlurmJob],
+    ) -> tp.Sequence[status.SlurmJobState]:
         async def _get_state(
-            options:
-        ) -> Sequence[status.SlurmJobState]:
+            options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
+        ) -> tp.Sequence[status.SlurmJobState]:
            result = await get_client().run(
                options,
                [
                    "sacct",
                    "--jobs",
-                   ",".join(
+                   ",".join([slurm_job.job_id for slurm_job in slurm_jobs]),
                    "--format",
                    "JobID,State",
                    "--allocations",
@@ -111,32 +180,35 @@ class _BatchedSlurmHandle:
                 states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)

             job_states = []
-            for
-                if job_id in states_by_job_id:
-                    job_states.append(states_by_job_id[job_id])
+            for slurm_job in slurm_jobs:
+                if slurm_job.job_id in states_by_job_id:
+                    job_states.append(states_by_job_id[slurm_job.job_id])
                 # This is a stupid hack around sacct's inability to display state information for
                 # array job elements that haven't begun. We'll assume that if the job ID is not found,
                 # and it's an array job, then it's pending.
-                elif
+                elif slurm_job.is_array_job:
                     job_states.append(status.SlurmJobState.PENDING)
                 else:
-                    raise SlurmExecutionError(f"Failed to find job state info for {
+                    raise SlurmExecutionError(f"Failed to find job state info for {slurm_job!r}")
             return job_states

-
+        # Group Slurm jobs by their cluster so we can batch requests
+        jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)

+        # Async get state for each cluster
         job_states_per_cluster = await asyncio.gather(*[
-            _get_state(options,
+            _get_state(options, jobs) for options, jobs in jobs_by_cluster.items()
         ])
-        job_states_by_cluster: dict[
-            asyncssh.SSHClientConnectionOptions, dict[str, status.SlurmJobState]
-        ] = {}
-        for options, job_states in zip(ssh_options, job_states_per_cluster):
-            job_states_by_cluster[options] = dict(zip(jobs_by_cluster[options], job_states))

+        # Reconstruct the job states by cluster
+        job_states_by_cluster = {}
+        for ssh_config, job_states in zip(ssh_configs, job_states_per_cluster):
+            job_states_by_cluster[ssh_config] = dict(zip(jobs_by_cluster[ssh_config], job_states))
+
+        # Reconstruct the job states in the original order
         job_states = []
-        for
-            job_states.append(job_states_by_cluster[
+        for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
+            job_states.append(job_states_by_cluster[ssh_config][slurm_job])
         return job_states

     @functools.partial(
@@ -146,29 +218,32 @@ class _BatchedSlurmHandle:
     )
     @staticmethod
     async def _batched_cancel(
-
-
+        ssh_configs: tp.Sequence[config.SlurmSSHConfig],
+        slurm_jobs: tp.Sequence[SlurmJob],
+    ) -> tp.Sequence[None]:
         async def _cancel(
-            options:
+            options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
         ) -> None:
-            await get_client().run(
+            await get_client().run(
+                options,
+                ["scancel", " ".join([slurm_job.job_id for slurm_job in slurm_jobs])],
+                check=True,
+            )

-        jobs_by_cluster =
+        jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
         return await asyncio.gather(*[
             _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
         ])


 @dataclasses.dataclass(frozen=True, kw_only=True)
-class SlurmHandle(_BatchedSlurmHandle):
+class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
     """A handle for referring to the launched container."""

-
-
-
-
-        if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
-            raise ValueError(f"Invalid job ID: {self.job_id}")
+    experiment_id: int
+    ssh: config.SlurmSSHConfig
+    slurm_job: Descriptor[SlurmJobT, str] = SlurmJobDescriptor[SlurmJobT]()
+    job_name: str  # XManager job name associated with this handle

     @backoff.on_predicate(
         backoff.constant,
@@ -180,15 +255,56 @@ class SlurmHandle(_BatchedSlurmHandle):
         return await self.get_state()

     async def stop(self) -> None:
-        await self._batched_cancel(self.
+        await self._batched_cancel(self.ssh, self.slurm_job)

     async def get_state(self) -> status.SlurmJobState:
-        return await self._batched_get_state(self.
-
+        return await self._batched_get_state(self.ssh, self.slurm_job)

-
-
-
+    async def logs(
+        self, *, num_lines: int, block_size: int, wait: bool, follow: bool
+    ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
+        file = f".local/state/xm-slurm/{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
+        conn = await get_client().connection(self.ssh)
+        async with conn.start_sftp_client() as sftp:
+            if wait:
+                while not (await sftp.exists(file)):
+                    await asyncio.sleep(5)
+
+            async with sftp.open(file, "rb") as remote_file:
+                file_stat = await remote_file.stat()
+                file_size = file_stat.size
+                assert file_size is not None
+
+                data = b""
+                lines = []
+                position = file_size
+
+                while len(lines) <= num_lines and position > 0:
+                    read_size = min(block_size, position)
+                    position -= read_size
+                    await remote_file.seek(position)
+                    chunk = await remote_file.read(read_size)
+                    data = chunk + data
+                    lines = data.splitlines()
+
+                if position <= 0:
+                    yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
+                for line in lines[-num_lines:]:
+                    yield line.decode("utf-8", errors="replace")
+
+                if (await self.get_state()) not in status.SlurmActiveJobStates:
+                    yield Rule("[bold red]END OF FILE[/bold red]")
+                    return
+
+                if not follow:
+                    return
+
+                await remote_file.seek(file_size)
+                while True:
+                    if new_data := (await remote_file.read(block_size)):
+                        yield new_data.decode("utf-8", errors="replace")
+                    else:
+                        await asyncio.sleep(0.25)


 @functools.cache
@@ -208,55 +324,75 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
         case ContainerRuntime.PODMAN:
             runtime_template = template_env.get_template("runtimes/podman.bash.j2")
         case _:
-            raise NotImplementedError
+            raise NotImplementedError(f"Container runtime {container_runtime} is not implemented.")
     # Update our global env with the runtime template's exported globals
     template_env.globals.update(runtime_template.module.__dict__)

     return template_env


+@functools.cache
+def get_client() -> "Client":
+    return Client()
+
+
 class Client:
-    def __init__(self):
-        self._connections
-            asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
-        ] = {}
+    def __init__(self) -> None:
+        self._connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
         self._connection_lock = asyncio.Lock()

     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
     async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
         # Make sure the xm-slurm state directory exists
-
+        async with conn.start_sftp_client() as sftp_client:
+            await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)

-    async def connection(
-        self
-        options: asyncssh.SSHClientConnectionOptions,
-    ) -> asyncssh.SSHClientConnection:
-        if options not in self._connections:
+    async def connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
+        if ssh_config not in self._connections:
             async with self._connection_lock:
                 try:
-                    conn, _ = await asyncssh.create_connection(
+                    conn, _ = await asyncssh.create_connection(
+                        NoKBAuthSSHClient, options=ssh_config.connection_options
+                    )
                     await self._setup_remote_connection(conn)
-                    self._connections[
+                    self._connections[ssh_config] = conn
                 except asyncssh.misc.PermissionDenied as ex:
-                    raise
-
+                    raise SlurmExecutionError(
+                        f"Permission denied connecting to {ssh_config.host}"
+                    ) from ex
+                except asyncssh.misc.ConnectionLost as ex:
+                    raise SlurmExecutionError(f"Connection lost to host {ssh_config.host}") from ex
+                except asyncssh.misc.HostKeyNotVerifiable as ex:
+                    raise SlurmExecutionError(
+                        f"Cannot verify the public key for host {ssh_config.host}"
+                    ) from ex
+                except asyncssh.misc.KeyExchangeFailed as ex:
+                    raise SlurmExecutionError(
+                        f"Failed to exchange keys with host {ssh_config.host}"
+                    ) from ex
+                except asyncssh.Error as ex:
+                    raise SlurmExecutionError(
+                        f"SSH connection error when connecting to {ssh_config.host}"
+                    ) from ex
+
+        return self._connections[ssh_config]

     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
     async def run(
         self,
-
-        command: xm.SequentialArgs | str | Sequence[str],
+        ssh_config: config.SlurmSSHConfig,
+        command: xm.SequentialArgs | str | tp.Sequence[str],
         *,
         check: bool = False,
         timeout: float | None = None,
     ) -> asyncssh.SSHCompletedProcess:
-        client = await self.connection(
+        client = await self.connection(ssh_config)
         if isinstance(command, xm.SequentialArgs):
             command = command.to_list()
         if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
             command = shlex.join(command)
         assert isinstance(command, str)
-
+        logger.debug("Running command on %s: %s", ssh_config.host, command)

         return await client.run(command, check=check, timeout=timeout)

@@ -264,8 +400,9 @@ class Client:
         self,
         *,
         job: xm.Job | xm.JobGroup,
+        dependency: dependencies.SlurmJobDependency | None = None,
         cluster: SlurmClusterConfig,
-        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
+        args: tp.Mapping[str, tp.Any] | tp.Sequence[tp.Mapping[str, tp.Any]] | None,
         experiment_id: int,
         identity: str | None,
     ) -> str:
@@ -276,7 +413,7 @@ class Client:

         # Sanitize job groups
         if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
-            job =
+            job = tp.cast(xm.Job, list(job.jobs.values())[0])
         elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
             raise ValueError("Job group must have at least one job")

@@ -294,6 +431,7 @@ class Client:

                 return template.render(
                     job=job_array,
+                    dependency=dependency,
                     cluster=cluster,
                     args=sequential_args,
                     env_vars=env_vars,
@@ -306,6 +444,7 @@ class Client:
                 env_vars = args.get("env_vars", None)
                 return template.render(
                     job=job,
+                    dependency=dependency,
                     cluster=cluster,
                     args=sequential_args,
                     env_vars=env_vars,
@@ -326,6 +465,7 @@ class Client:
                 }
                 return template.render(
                     job_group=job_group,
+                    dependency=dependency,
                     cluster=cluster,
                     args=sequential_args,
                     env_vars=env_vars,
@@ -335,51 +475,66 @@ class Client:
             case _:
                 raise ValueError(f"Unsupported job type: {type(job)}")

-    @
+    @tp.overload
     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.
-
+        job: xm.JobGroup,
+        dependency: dependencies.SlurmJobDependency | None = None,
+        args: tp.Mapping[str, JobArgs] | None,
         experiment_id: int,
         identity: str | None = ...,
     ) -> SlurmHandle: ...

-    @
+    @tp.overload
     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.Job
-
+        job: xm.Job,
+        dependency: dependencies.SlurmJobDependency | None = None,
+        args: tp.Sequence[JobArgs],
+        experiment_id: int,
+        identity: str | None = ...,
+    ) -> list[SlurmHandle]: ...
+
+    @tp.overload
+    async def launch(
+        self,
+        *,
+        cluster: SlurmClusterConfig,
+        job: xm.Job,
+        dependency: dependencies.SlurmJobDependency | None = None,
+        args: JobArgs,
         experiment_id: int,
         identity: str | None = ...,
-    ) ->
+    ) -> SlurmHandle: ...

     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
         job: xm.Job | xm.JobGroup,
-
+        dependency: dependencies.SlurmJobDependency | None = None,
+        args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs | None,
         experiment_id: int,
         identity: str | None = None,
-    )
-        # Construct template
+    ):
         template = await self.template(
             job=job,
+            dependency=dependency,
             cluster=cluster,
             args=args,
             experiment_id=experiment_id,
             identity=identity,
         )
-
+        logger.debug("Slurm submission script:\n%s", template)

         # Hash submission script
         template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]

-        conn = await self.connection(cluster.
+        conn = await self.connection(cluster.ssh)
         async with conn.start_sftp_client() as sftp:
             # Write the submission script to the cluster
             # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
@@ -392,9 +547,9 @@ class Client:

             # Construct and run command on the cluster
             command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
-            result = await self.run(cluster.
+            result = await self.run(cluster.ssh, command)
             if result.returncode != 0:
-                raise RuntimeError(f"Failed to schedule job on {cluster.host}: {result.stderr}")
+                raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")

         assert isinstance(result.stdout, str)
         slurm_job_id, *_ = result.stdout.split(",")
@@ -405,61 +560,103 @@ class Client:
             f"[cyan]{cluster.name}[/cyan] "
         )

+        # If we scheduled an array job make sure to return a list of handles
+        # The indexing is always sequential in 0, 1, ..., n - 1
         if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
+            assert job.name is not None
             return [
                 SlurmHandle(
-
-
+                    experiment_id=experiment_id,
+                    ssh=cluster.ssh,
+                    slurm_job=f"{slurm_job_id}_{array_index}",
+                    job_name=job.name,
                 )
                 for array_index in range(len(args))
             ]
+        elif isinstance(job, xm.Job):
+            assert job.name is not None
+            return SlurmHandle(
+                experiment_id=experiment_id,
+                ssh=cluster.ssh,
+                slurm_job=slurm_job_id,
+                job_name=job.name,
+            )
+        elif isinstance(job, xm.JobGroup):
+            # TODO: make this work for actual job groups.
+            job = tp.cast(xm.Job, mit.one(job.jobs.values()))
+            assert isinstance(job, xm.Job)
+            assert job.name is not None
+            return SlurmHandle(
+                experiment_id=experiment_id,
+                ssh=cluster.ssh,
+                slurm_job=slurm_job_id,
+                job_name=job.name,
+            )
+        else:
+            raise ValueError(f"Unsupported job type: {type(job)}")

-
-
-
-        )
+    def __del__(self):
+        for conn in self._connections.values():
+            conn.close()


-@
+@tp.overload
 async def launch(
     *,
-    job: xm.
-
+    job: xm.JobGroup,
+    dependency: dependencies.SlurmJobDependency | None = None,
+    args: tp.Mapping[str, JobArgs],
     experiment_id: int,
     identity: str | None = ...,
 ) -> SlurmHandle: ...


-@
+@tp.overload
 async def launch(
     *,
-    job: xm.Job
-
+    job: xm.Job,
+    dependency: dependencies.SlurmJobDependency | None = None,
+    args: tp.Sequence[JobArgs],
     experiment_id: int,
     identity: str | None = ...,
-) ->
+) -> list[SlurmHandle]: ...
+
+
+@tp.overload
+async def launch(
+    *,
+    job: xm.Job,
+    dependency: dependencies.SlurmJobDependency | None = None,
+    args: JobArgs,
+    experiment_id: int,
+    identity: str | None = ...,
+) -> SlurmHandle: ...


 async def launch(
     *,
     job: xm.Job | xm.JobGroup,
-
+    dependency: dependencies.SlurmJobDependency | None = None,
+    args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs,
     experiment_id: int,
     identity: str | None = None,
-) -> SlurmHandle |
+) -> SlurmHandle | list[SlurmHandle]:
     match job:
-        case xm.Job():
+        case xm.Job() as job:
             if not isinstance(job.executor, executors.Slurm):
                 raise ValueError("Job must have a Slurm executor")
             job_requirements = job.executor.requirements
             cluster = job_requirements.cluster
             if cluster is None:
                 raise ValueError("Job must have a cluster requirement")
+            if cluster.validate is not None:
+                cluster.validate(job)

             return await get_client().launch(
                 cluster=cluster,
                 job=job,
-
+                dependency=dependency,
+                args=tp.cast(JobArgs | tp.Sequence[JobArgs], args),
                 experiment_id=experiment_id,
                 identity=identity,
             )
@@ -473,6 +670,8 @@ async def launch(
                 raise ValueError("Job must have a Slurm executor")
             if job_item.executor.requirements.cluster is None:
                 raise ValueError("Job must have a cluster requirement")
+            if job_item.executor.requirements.cluster.validate is not None:
+                job_item.executor.requirements.cluster.validate(job_item)
             job_group_clusters.add(job_item.executor.requirements.cluster)
             job_group_executors.add(id(job_item.executor))
             if len(job_group_executors) != 1:
@@ -482,10 +681,9 @@ async def launch(

             return await get_client().launch(
                 cluster=job_group_clusters.pop(),
-                job=
-
+                job=job_group,
+                dependency=dependency,
+                args=tp.cast(tp.Mapping[str, JobArgs], args),
                 experiment_id=experiment_id,
                 identity=identity,
             )
-        case _:
-            raise ValueError("Unsupported job type")