xmanager-slurm 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +4 -2
- xm_slurm/api.py +1 -1
- xm_slurm/config.py +7 -2
- xm_slurm/constants.py +4 -0
- xm_slurm/contrib/clusters/__init__.py +9 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +20 -15
- xm_slurm/execution.py +246 -96
- xm_slurm/executors.py +8 -12
- xm_slurm/experiment.py +374 -83
- xm_slurm/experimental/parameter_controller.py +12 -10
- xm_slurm/packaging/{docker/local.py → docker.py} +126 -32
- xm_slurm/packaging/router.py +3 -1
- xm_slurm/packaging/utils.py +4 -28
- xm_slurm/resources.py +2 -0
- xm_slurm/scripts/cli.py +77 -0
- xm_slurm/templates/docker/mamba.Dockerfile +1 -1
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
- xm_slurm/templates/slurm/job.bash.j2 +4 -3
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +1 -0
- xm_slurm/types.py +23 -0
- {xmanager_slurm-0.4.0.dist-info → xmanager_slurm-0.4.2.dist-info}/METADATA +1 -1
- xmanager_slurm-0.4.2.dist-info/RECORD +44 -0
- xmanager_slurm-0.4.2.dist-info/entry_points.txt +2 -0
- xm_slurm/packaging/docker/__init__.py +0 -69
- xm_slurm/packaging/docker/abc.py +0 -112
- xmanager_slurm-0.4.0.dist-info/RECORD +0 -42
- {xmanager_slurm-0.4.0.dist-info → xmanager_slurm-0.4.2.dist-info}/WHEEL +0 -0
- {xmanager_slurm-0.4.0.dist-info → xmanager_slurm-0.4.2.dist-info}/licenses/LICENSE.md +0 -0
xm_slurm/execution.py
CHANGED
@@ -5,10 +5,8 @@ import functools
 import hashlib
 import logging
 import operator
-import re
 import shlex
-import typing
-from typing import Any, Mapping, Sequence
+import typing as tp

 import asyncssh
 import backoff
@@ -16,15 +14,20 @@ import jinja2 as j2
 import more_itertools as mit
 from asyncssh.auth import KbdIntPrompts, KbdIntResponse
 from asyncssh.misc import MaybeAwait
+from rich.console import ConsoleRenderable
+from rich.rule import Rule
 from xmanager import xm

-from xm_slurm import batching, config, executors, status
+from xm_slurm import batching, config, constants, dependencies, executors, status
 from xm_slurm.console import console
 from xm_slurm.job_blocks import JobArgs
+from xm_slurm.types import Descriptor

 SlurmClusterConfig = config.SlurmClusterConfig
 ContainerRuntime = config.ContainerRuntime

+logger = logging.getLogger(__name__)
+
 """
 === Runtime Configurations ===
 With RunC:
@@ -43,11 +46,6 @@ With Singularity / Apptainer:
     apptainer run --compat <digest>
 """

-"""
-#SBATCH --error=/dev/null
-#SBATCH --output=/dev/null
-"""
-
 _POLL_INTERVAL = 30.0
 _BATCHED_BATCH_SIZE = 16
 _BATCHED_TIMEOUT = 0.2
@@ -69,12 +67,79 @@ class NoKBAuthSSHClient(asyncssh.SSHClient):
         return []


-
-
-
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmJob:
+    job_id: str
+
+    @property
+    def is_array_job(self) -> bool:
+        return isinstance(self, SlurmArrayJob)
+
+    @property
+    def is_heterogeneous_job(self) -> bool:
+        return isinstance(self, SlurmHeterogeneousJob)
+
+    def __hash__(self) -> int:
+        return hash((type(self), self.job_id))
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmArrayJob(SlurmJob):
+    array_job_id: str
+    array_task_id: str
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmHeterogeneousJob(SlurmJob):
+    het_job_id: str
+    het_component_id: str
+
+
+SlurmJobT = tp.TypeVar("SlurmJobT", bound=SlurmJob, covariant=True)
+
+
+class SlurmJobDescriptor(Descriptor[SlurmJobT, str]):
+    def __set_name__(self, owner: type, name: str):
+        del owner
+        self.job = f"_{name}"
+
+    def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> SlurmJobT:
+        del owner
+        return getattr(instance, self.job)
+
+    def __set__(self, instance: object, value: str):
+        _setattr = object.__setattr__ if not hasattr(instance, self.job) else setattr
+
+        match = constants.SLURM_JOB_ID_REGEX.match(value)
+        if match is None:
+            raise ValueError(f"Invalid Slurm job ID: {value}")
+        groups = match.groupdict()
+
+        job_id = groups["jobid"]
+        if array_task_id := groups.get("arraytaskid", None):
+            _setattr(
+                instance,
+                self.job,
+                SlurmArrayJob(job_id=value, array_job_id=job_id, array_task_id=array_task_id),
+            )
+        elif het_component_id := groups.get("componentid", None):
+            _setattr(
+                instance,
+                self.job,
+                SlurmHeterogeneousJob(
+                    job_id=value, het_job_id=job_id, het_component_id=het_component_id
+                ),
+            )
+        else:
+            _setattr(instance, self.job, SlurmJob(job_id=value))
+
+
+def _group_by_ssh_configs(
+    ssh_configs: tp.Sequence[config.SlurmSSHConfig], slurm_jobs: tp.Sequence[SlurmJob]
+) -> dict[config.SlurmSSHConfig, list[SlurmJob]]:
     jobs_by_cluster = collections.defaultdict(list)
-    for
-    jobs_by_cluster[
+    for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
+        jobs_by_cluster[ssh_config].append(slurm_job)
     return jobs_by_cluster


@@ -85,18 +150,20 @@ class _BatchedSlurmHandle:
         batch_timeout=_BATCHED_TIMEOUT,
     )
     @staticmethod
+    @backoff.on_exception(backoff.expo, SlurmExecutionError, max_tries=5, max_time=60.0)
     async def _batched_get_state(
-
-
+        ssh_configs: tp.Sequence[config.SlurmSSHConfig],
+        slurm_jobs: tp.Sequence[SlurmJob],
+    ) -> tp.Sequence[status.SlurmJobState]:
         async def _get_state(
-            options:
-        ) -> Sequence[status.SlurmJobState]:
+            options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
+        ) -> tp.Sequence[status.SlurmJobState]:
             result = await get_client().run(
                 options,
                 [
                     "sacct",
                     "--jobs",
-                    ",".join(
+                    ",".join([slurm_job.job_id for slurm_job in slurm_jobs]),
                     "--format",
                     "JobID,State",
                     "--allocations",
@@ -113,32 +180,35 @@ class _BatchedSlurmHandle:
                 states_by_job_id[job_id] = status.SlurmJobState.from_slurm_str(state)

             job_states = []
-            for
-                if job_id in states_by_job_id:
-                    job_states.append(states_by_job_id[job_id])
+            for slurm_job in slurm_jobs:
+                if slurm_job.job_id in states_by_job_id:
+                    job_states.append(states_by_job_id[slurm_job.job_id])
                 # This is a stupid hack around sacct's inability to display state information for
                 # array job elements that haven't begun. We'll assume that if the job ID is not found,
                 # and it's an array job, then it's pending.
-                elif
+                elif slurm_job.is_array_job:
                     job_states.append(status.SlurmJobState.PENDING)
                 else:
-                    raise SlurmExecutionError(f"Failed to find job state info for {
+                    raise SlurmExecutionError(f"Failed to find job state info for {slurm_job!r}")
             return job_states

-
+        # Group Slurm jobs by their cluster so we can batch requests
+        jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)

+        # Async get state for each cluster
         job_states_per_cluster = await asyncio.gather(*[
-            _get_state(options,
+            _get_state(options, jobs) for options, jobs in jobs_by_cluster.items()
         ])
-        job_states_by_cluster: dict[
-            asyncssh.SSHClientConnectionOptions, dict[str, status.SlurmJobState]
-        ] = {}
-        for options, job_states in zip(ssh_options, job_states_per_cluster):
-            job_states_by_cluster[options] = dict(zip(jobs_by_cluster[options], job_states))

+        # Reconstruct the job states by cluster
+        job_states_by_cluster = {}
+        for ssh_config, job_states in zip(ssh_configs, job_states_per_cluster):
+            job_states_by_cluster[ssh_config] = dict(zip(jobs_by_cluster[ssh_config], job_states))
+
+        # Reconstruct the job states in the original order
         job_states = []
-        for
-            job_states.append(job_states_by_cluster[
+        for ssh_config, slurm_job in zip(ssh_configs, slurm_jobs):
+            job_states.append(job_states_by_cluster[ssh_config][slurm_job])
         return job_states

     @functools.partial(
@@ -148,31 +218,33 @@ class _BatchedSlurmHandle:
     )
     @staticmethod
     async def _batched_cancel(
-
-
+        ssh_configs: tp.Sequence[config.SlurmSSHConfig],
+        slurm_jobs: tp.Sequence[SlurmJob],
+    ) -> tp.Sequence[None]:
         async def _cancel(
-            options:
+            options: config.SlurmSSHConfig, slurm_jobs: tp.Sequence[SlurmJob]
         ) -> None:
-            await get_client().run(
+            await get_client().run(
+                options,
+                ["scancel", " ".join([slurm_job.job_id for slurm_job in slurm_jobs])],
+                check=True,
+            )

-        jobs_by_cluster =
+        jobs_by_cluster = _group_by_ssh_configs(ssh_configs, slurm_jobs)
         return await asyncio.gather(*[
             _cancel(options, job_ids) for options, job_ids in jobs_by_cluster.items()
         ])


 @dataclasses.dataclass(frozen=True, kw_only=True)
-class SlurmHandle(_BatchedSlurmHandle):
+class SlurmHandle(_BatchedSlurmHandle, tp.Generic[SlurmJobT]):
     """A handle for referring to the launched container."""

+    experiment_id: int
     ssh: config.SlurmSSHConfig
-
+    slurm_job: Descriptor[SlurmJobT, str] = SlurmJobDescriptor[SlurmJobT]()
     job_name: str  # XManager job name associated with this handle

-    def __post_init__(self):
-        if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
-            raise ValueError(f"Invalid job ID: {self.job_id}")
-
     @backoff.on_predicate(
         backoff.constant,
         lambda state: state in status.SlurmActiveJobStates,
@@ -183,15 +255,56 @@ class SlurmHandle(_BatchedSlurmHandle):
         return await self.get_state()

     async def stop(self) -> None:
-        await self._batched_cancel(self.ssh
+        await self._batched_cancel(self.ssh, self.slurm_job)

     async def get_state(self) -> status.SlurmJobState:
-        return await self._batched_get_state(self.ssh
-
+        return await self._batched_get_state(self.ssh, self.slurm_job)

-
-
-
+    async def logs(
+        self, *, num_lines: int, block_size: int, wait: bool, follow: bool
+    ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
+        file = f".local/state/xm-slurm/{self.experiment_id}/slurm-{self.slurm_job.job_id}.out"
+        conn = await get_client().connection(self.ssh)
+        async with conn.start_sftp_client() as sftp:
+            if wait:
+                while not (await sftp.exists(file)):
+                    await asyncio.sleep(5)
+
+            async with sftp.open(file, "rb") as remote_file:
+                file_stat = await remote_file.stat()
+                file_size = file_stat.size
+                assert file_size is not None
+
+                data = b""
+                lines = []
+                position = file_size
+
+                while len(lines) <= num_lines and position > 0:
+                    read_size = min(block_size, position)
+                    position -= read_size
+                    await remote_file.seek(position)
+                    chunk = await remote_file.read(read_size)
+                    data = chunk + data
+                    lines = data.splitlines()
+
+                if position <= 0:
+                    yield Rule("[bold red]BEGINNING OF FILE[/bold red]")
+                for line in lines[-num_lines:]:
+                    yield line.decode("utf-8", errors="replace")
+
+                if (await self.get_state()) not in status.SlurmActiveJobStates:
+                    yield Rule("[bold red]END OF FILE[/bold red]")
+                    return
+
+                if not follow:
+                    return
+
+                await remote_file.seek(file_size)
+                while True:
+                    if new_data := (await remote_file.read(block_size)):
+                        yield new_data.decode("utf-8", errors="replace")
+                    else:
+                        await asyncio.sleep(0.25)


 @functools.cache
@@ -218,11 +331,14 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
     return template_env


+@functools.cache
+def get_client() -> "Client":
+    return Client()
+
+
 class Client:
     def __init__(self) -> None:
-        self._connections = dict[
-            asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
-        ]()
+        self._connections = dict[config.SlurmSSHConfig, asyncssh.SSHClientConnection]()
         self._connection_lock = asyncio.Lock()

     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
@@ -231,53 +347,52 @@ class Client:
         async with conn.start_sftp_client() as sftp_client:
             await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)

-    async def connection(
-        self
-        options: asyncssh.SSHClientConnectionOptions,
-    ) -> asyncssh.SSHClientConnection:
-        if options not in self._connections:
+    async def connection(self, ssh_config: config.SlurmSSHConfig) -> asyncssh.SSHClientConnection:
+        if ssh_config not in self._connections:
             async with self._connection_lock:
                 try:
-                    conn, _ = await asyncssh.create_connection(
+                    conn, _ = await asyncssh.create_connection(
+                        NoKBAuthSSHClient, options=ssh_config.connection_options
+                    )
                     await self._setup_remote_connection(conn)
-                    self._connections[
+                    self._connections[ssh_config] = conn
                 except asyncssh.misc.PermissionDenied as ex:
                     raise SlurmExecutionError(
-                        f"Permission denied connecting to {
+                        f"Permission denied connecting to {ssh_config.host}"
                     ) from ex
                 except asyncssh.misc.ConnectionLost as ex:
-                    raise SlurmExecutionError(f"Connection lost to host {
+                    raise SlurmExecutionError(f"Connection lost to host {ssh_config.host}") from ex
                 except asyncssh.misc.HostKeyNotVerifiable as ex:
                     raise SlurmExecutionError(
-                        f"Cannot verify the public key for host {
+                        f"Cannot verify the public key for host {ssh_config.host}"
                     ) from ex
                 except asyncssh.misc.KeyExchangeFailed as ex:
                     raise SlurmExecutionError(
-                        f"Failed to exchange keys with host {
+                        f"Failed to exchange keys with host {ssh_config.host}"
                     ) from ex
                 except asyncssh.Error as ex:
                     raise SlurmExecutionError(
-                        f"SSH connection error when connecting to {
+                        f"SSH connection error when connecting to {ssh_config.host}"
                     ) from ex

-        return self._connections[
+        return self._connections[ssh_config]

     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
     async def run(
         self,
-
-        command: xm.SequentialArgs | str | Sequence[str],
+        ssh_config: config.SlurmSSHConfig,
+        command: xm.SequentialArgs | str | tp.Sequence[str],
         *,
         check: bool = False,
         timeout: float | None = None,
     ) -> asyncssh.SSHCompletedProcess:
-        client = await self.connection(
+        client = await self.connection(ssh_config)
         if isinstance(command, xm.SequentialArgs):
             command = command.to_list()
         if not isinstance(command, str) and isinstance(command, collections.abc.Sequence):
             command = shlex.join(command)
         assert isinstance(command, str)
-
+        logger.debug("Running command on %s: %s", ssh_config.host, command)

         return await client.run(command, check=check, timeout=timeout)

@@ -285,8 +400,9 @@ class Client:
         self,
         *,
         job: xm.Job | xm.JobGroup,
+        dependency: dependencies.SlurmJobDependency | None = None,
         cluster: SlurmClusterConfig,
-        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
+        args: tp.Mapping[str, tp.Any] | tp.Sequence[tp.Mapping[str, tp.Any]] | None,
         experiment_id: int,
         identity: str | None,
     ) -> str:
@@ -297,7 +413,7 @@ class Client:

         # Sanitize job groups
         if isinstance(job, xm.JobGroup) and len(job.jobs) == 1:
-            job =
+            job = tp.cast(xm.Job, list(job.jobs.values())[0])
         elif isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
             raise ValueError("Job group must have at least one job")

@@ -315,6 +431,7 @@ class Client:

                 return template.render(
                     job=job_array,
+                    dependency=dependency,
                     cluster=cluster,
                     args=sequential_args,
                     env_vars=env_vars,
@@ -327,6 +444,7 @@ class Client:
                 env_vars = args.get("env_vars", None)
                 return template.render(
                     job=job,
+                    dependency=dependency,
                     cluster=cluster,
                     args=sequential_args,
                     env_vars=env_vars,
@@ -347,6 +465,7 @@ class Client:
                 }
                 return template.render(
                     job_group=job_group,
+                    dependency=dependency,
                     cluster=cluster,
                     args=sequential_args,
                     env_vars=env_vars,
@@ -356,54 +475,66 @@ class Client:
             case _:
                 raise ValueError(f"Unsupported job type: {type(job)}")

-    @
+    @tp.overload
     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
         job: xm.JobGroup,
-
+        dependency: dependencies.SlurmJobDependency | None = None,
+        args: tp.Mapping[str, JobArgs] | None,
         experiment_id: int,
         identity: str | None = ...,
     ) -> SlurmHandle: ...

-    @
+    @tp.overload
     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
         job: xm.Job,
-
+        dependency: dependencies.SlurmJobDependency | None = None,
+        args: tp.Sequence[JobArgs],
         experiment_id: int,
         identity: str | None = ...,
     ) -> list[SlurmHandle]: ...

-    @
+    @tp.overload
     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
         job: xm.Job,
+        dependency: dependencies.SlurmJobDependency | None = None,
         args: JobArgs,
         experiment_id: int,
         identity: str | None = ...,
     ) -> SlurmHandle: ...

-    async def launch(
-
+    async def launch(
+        self,
+        *,
+        cluster: SlurmClusterConfig,
+        job: xm.Job | xm.JobGroup,
+        dependency: dependencies.SlurmJobDependency | None = None,
+        args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs | None,
+        experiment_id: int,
+        identity: str | None = None,
+    ):
         template = await self.template(
             job=job,
+            dependency=dependency,
             cluster=cluster,
             args=args,
             experiment_id=experiment_id,
             identity=identity,
         )
-
+        logger.debug("Slurm submission script:\n%s", template)

         # Hash submission script
         template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]

-        conn = await self.connection(cluster.ssh
+        conn = await self.connection(cluster.ssh)
         async with conn.start_sftp_client() as sftp:
             # Write the submission script to the cluster
             # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
@@ -416,7 +547,7 @@ class Client:

             # Construct and run command on the cluster
             command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
-            result = await self.run(cluster.ssh
+            result = await self.run(cluster.ssh, command)
             if result.returncode != 0:
                 raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")

@@ -435,21 +566,32 @@ class Client:
             assert job.name is not None
             return [
                 SlurmHandle(
+                    experiment_id=experiment_id,
                     ssh=cluster.ssh,
-
+                    slurm_job=f"{slurm_job_id}_{array_index}",
                     job_name=job.name,
                 )
                 for array_index in range(len(args))
             ]
         elif isinstance(job, xm.Job):
             assert job.name is not None
-            return SlurmHandle(
+            return SlurmHandle(
+                experiment_id=experiment_id,
+                ssh=cluster.ssh,
+                slurm_job=slurm_job_id,
+                job_name=job.name,
+            )
         elif isinstance(job, xm.JobGroup):
             # TODO: make this work for actual job groups.
-            job = mit.one(job.jobs.values())
+            job = tp.cast(xm.Job, mit.one(job.jobs.values()))
             assert isinstance(job, xm.Job)
             assert job.name is not None
-            return SlurmHandle(
+            return SlurmHandle(
+                experiment_id=experiment_id,
+                ssh=cluster.ssh,
+                slurm_job=slurm_job_id,
+                job_name=job.name,
+            )
         else:
             raise ValueError(f"Unsupported job type: {type(job)}")

@@ -458,30 +600,33 @@ class Client:
             conn.close()


-@
+@tp.overload
 async def launch(
     *,
     job: xm.JobGroup,
-
+    dependency: dependencies.SlurmJobDependency | None = None,
+    args: tp.Mapping[str, JobArgs],
     experiment_id: int,
     identity: str | None = ...,
 ) -> SlurmHandle: ...


-@
+@tp.overload
 async def launch(
     *,
     job: xm.Job,
-
+    dependency: dependencies.SlurmJobDependency | None = None,
+    args: tp.Sequence[JobArgs],
     experiment_id: int,
     identity: str | None = ...,
 ) -> list[SlurmHandle]: ...


-@
+@tp.overload
 async def launch(
     *,
     job: xm.Job,
+    dependency: dependencies.SlurmJobDependency | None = None,
     args: JobArgs,
     experiment_id: int,
     identity: str | None = ...,
@@ -491,7 +636,8 @@ async def launch(
 async def launch(
     *,
     job: xm.Job | xm.JobGroup,
-
+    dependency: dependencies.SlurmJobDependency | None = None,
+    args: tp.Mapping[str, JobArgs] | tp.Sequence[JobArgs] | JobArgs,
     experiment_id: int,
     identity: str | None = None,
 ) -> SlurmHandle | list[SlurmHandle]:
@@ -503,11 +649,14 @@ async def launch(
             cluster = job_requirements.cluster
             if cluster is None:
                 raise ValueError("Job must have a cluster requirement")
+            if cluster.validate is not None:
+                cluster.validate(job)

             return await get_client().launch(
                 cluster=cluster,
                 job=job,
-
+                dependency=dependency,
+                args=tp.cast(JobArgs | tp.Sequence[JobArgs], args),
                 experiment_id=experiment_id,
                 identity=identity,
             )
@@ -521,6 +670,8 @@ async def launch(
                     raise ValueError("Job must have a Slurm executor")
                 if job_item.executor.requirements.cluster is None:
                     raise ValueError("Job must have a cluster requirement")
+                if job_item.executor.requirements.cluster.validate is not None:
+                    job_item.executor.requirements.cluster.validate(job_item)
                 job_group_clusters.add(job_item.executor.requirements.cluster)
                 job_group_executors.add(id(job_item.executor))
             if len(job_group_executors) != 1:
@@ -531,9 +682,8 @@ async def launch(
             return await get_client().launch(
                 cluster=job_group_clusters.pop(),
                 job=job_group,
-
+                dependency=dependency,
+                args=tp.cast(tp.Mapping[str, JobArgs], args),
                 experiment_id=experiment_id,
                 identity=identity,
             )
-        case _:
-            raise ValueError("Unsupported job type")
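Note: the new SlurmJobDescriptor.__set__ dispatches on named groups from constants.SLURM_JOB_ID_REGEX, which lives in xm_slurm/constants.py and is not shown in this diff. Below is a minimal sketch of a pattern with the group names the descriptor expects, mirroring the ^\d+(_\d+|\+\d+)?$ check that was removed from SlurmHandle.__post_init__; the exact pattern and group names are assumptions for illustration, not the package's actual constant.

import re

# Assumed shape of constants.SLURM_JOB_ID_REGEX: "123" (plain job),
# "123_4" (array task), "123+0" (heterogeneous component).
SLURM_JOB_ID_REGEX = re.compile(
    r"^(?P<jobid>\d+)(?:_(?P<arraytaskid>\d+)|\+(?P<componentid>\d+))?$"
)

for value in ("12345", "12345_7", "12345+1"):
    groups = SLURM_JOB_ID_REGEX.match(value).groupdict()
    if groups["arraytaskid"]:
        kind = "SlurmArrayJob"          # descriptor would build SlurmArrayJob
    elif groups["componentid"]:
        kind = "SlurmHeterogeneousJob"  # descriptor would build SlurmHeterogeneousJob
    else:
        kind = "SlurmJob"               # plain job
    print(value, "->", kind)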
xm_slurm/executors.py
CHANGED
@@ -48,6 +48,9 @@ class Slurm(xm.Executor):
     qos: str | None = None
     priority: int | None = None

+    # Job dependency handling
+    kill_on_invalid_dependencies: bool = True
+
     # Job rescheduling
     timeout_signal: signal.Signals = signal.SIGUSR2
     timeout_signal_grace_period: dt.timedelta = dt.timedelta(seconds=90)
@@ -93,6 +96,11 @@ class Slurm(xm.Executor):
         minutes, seconds = divmod(remainder, 60)
         directives.append(f"--time={days}-{hours:02}:{minutes:02}:{seconds:02}")

+        # Job dependency handling
+        directives.append(
+            f"--kill-on-invalid-dep={'yes' if self.kill_on_invalid_dependencies else 'no'}"
+        )
+
         # Placement
         if self.account:
             directives.append(f"--account={self.account}")
@@ -113,15 +121,3 @@ class Slurm(xm.Executor):
             directives.append("--no-requeue")

         return directives
-
-
-class DockerSpec(xm.ExecutorSpec):
-    """Local Docker executor specification that describes the container runtime."""
-
-
-class Docker(xm.Executor):
-    """Local Docker executor describing the runtime environment."""
-
-    @classmethod
-    def Spec(cls) -> DockerSpec:
-        return DockerSpec()
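For reference, the new kill_on_invalid_dependencies field maps directly onto Slurm's --kill-on-invalid-dep sbatch option. A standalone sketch of the directive the added code emits (reproduced outside the Slurm dataclass purely for illustration):

# Standalone reproduction of the new directive logic added to Slurm in 0.4.2.
kill_on_invalid_dependencies = True  # the field's default
directive = f"--kill-on-invalid-dep={'yes' if kill_on_invalid_dependencies else 'no'}"
assert directive == "--kill-on-invalid-dep=yes"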