xmanager-slurm 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic. Click here for more details.
- xm_slurm/__init__.py +4 -2
- xm_slurm/api.py +1 -1
- xm_slurm/config.py +7 -2
- xm_slurm/constants.py +4 -0
- xm_slurm/contrib/clusters/__init__.py +9 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +20 -15
- xm_slurm/execution.py +246 -96
- xm_slurm/executors.py +8 -12
- xm_slurm/experiment.py +374 -83
- xm_slurm/experimental/parameter_controller.py +12 -10
- xm_slurm/packaging/{docker/local.py → docker.py} +126 -32
- xm_slurm/packaging/router.py +3 -1
- xm_slurm/packaging/utils.py +4 -28
- xm_slurm/resources.py +2 -0
- xm_slurm/scripts/cli.py +77 -0
- xm_slurm/templates/docker/mamba.Dockerfile +1 -1
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
- xm_slurm/templates/slurm/job.bash.j2 +4 -3
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +1 -0
- xm_slurm/types.py +23 -0
- {xmanager_slurm-0.4.0.dist-info → xmanager_slurm-0.4.2.dist-info}/METADATA +1 -1
- xmanager_slurm-0.4.2.dist-info/RECORD +44 -0
- xmanager_slurm-0.4.2.dist-info/entry_points.txt +2 -0
- xm_slurm/packaging/docker/__init__.py +0 -69
- xm_slurm/packaging/docker/abc.py +0 -112
- xmanager_slurm-0.4.0.dist-info/RECORD +0 -42
- {xmanager_slurm-0.4.0.dist-info → xmanager_slurm-0.4.2.dist-info}/WHEEL +0 -0
- {xmanager_slurm-0.4.0.dist-info → xmanager_slurm-0.4.2.dist-info}/licenses/LICENSE.md +0 -0
xm_slurm/experiment.py
CHANGED
|
@@ -2,25 +2,31 @@ import asyncio
|
|
|
2
2
|
import collections.abc
|
|
3
3
|
import contextvars
|
|
4
4
|
import dataclasses
|
|
5
|
+
import datetime as dt
|
|
5
6
|
import functools
|
|
6
7
|
import inspect
|
|
7
8
|
import json
|
|
9
|
+
import logging
|
|
8
10
|
import os
|
|
9
|
-
import
|
|
11
|
+
import traceback
|
|
12
|
+
import typing as tp
|
|
10
13
|
from concurrent import futures
|
|
11
|
-
from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence
|
|
12
14
|
|
|
13
15
|
import more_itertools as mit
|
|
16
|
+
from rich.console import ConsoleRenderable
|
|
14
17
|
from xmanager import xm
|
|
15
|
-
from xmanager.xm import async_packager, id_predictor
|
|
18
|
+
from xmanager.xm import async_packager, core, id_predictor, job_operators
|
|
19
|
+
from xmanager.xm import job_blocks as xm_job_blocks
|
|
16
20
|
|
|
17
|
-
from xm_slurm import api, config, execution, executors
|
|
21
|
+
from xm_slurm import api, config, dependencies, execution, executors
|
|
18
22
|
from xm_slurm.console import console
|
|
19
23
|
from xm_slurm.job_blocks import JobArgs
|
|
20
24
|
from xm_slurm.packaging import router
|
|
21
25
|
from xm_slurm.status import SlurmWorkUnitStatus
|
|
22
26
|
from xm_slurm.utils import UserSet
|
|
23
27
|
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
24
30
|
_current_job_array_queue = contextvars.ContextVar[
|
|
25
31
|
asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
|
|
26
32
|
]("_current_job_array_queue", default=None)
|
|
@@ -28,7 +34,7 @@ _current_job_array_queue = contextvars.ContextVar[
|
|
|
28
34
|
|
|
29
35
|
def _validate_job(
|
|
30
36
|
job: xm.JobType,
|
|
31
|
-
args_view: JobArgs | Mapping[str, JobArgs],
|
|
37
|
+
args_view: JobArgs | tp.Mapping[str, JobArgs],
|
|
32
38
|
) -> None:
|
|
33
39
|
if not args_view:
|
|
34
40
|
return
|
|
@@ -51,7 +57,7 @@ def _validate_job(
|
|
|
51
57
|
)
|
|
52
58
|
|
|
53
59
|
if isinstance(job, xm.JobGroup) and key in job.jobs:
|
|
54
|
-
_validate_job(job.jobs[key],
|
|
60
|
+
_validate_job(job.jobs[key], tp.cast(JobArgs, expanded))
|
|
55
61
|
elif key not in allowed_keys:
|
|
56
62
|
raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")
|
|
57
63
|
|
|
@@ -62,7 +68,7 @@ class Artifact:
|
|
|
62
68
|
uri: str
|
|
63
69
|
|
|
64
70
|
def __hash__(self) -> int:
|
|
65
|
-
return hash(self.name)
|
|
71
|
+
return hash((type(self), self.name))
|
|
66
72
|
|
|
67
73
|
|
|
68
74
|
class ContextArtifacts(UserSet[Artifact]):
|
|
@@ -70,7 +76,7 @@ class ContextArtifacts(UserSet[Artifact]):
|
|
|
70
76
|
self,
|
|
71
77
|
owner: "SlurmExperiment | SlurmExperimentUnit",
|
|
72
78
|
*,
|
|
73
|
-
artifacts: Sequence[Artifact],
|
|
79
|
+
artifacts: tp.Sequence[Artifact],
|
|
74
80
|
):
|
|
75
81
|
super().__init__(
|
|
76
82
|
artifacts,
|
|
@@ -124,8 +130,8 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
124
130
|
def __init__(
|
|
125
131
|
self,
|
|
126
132
|
experiment: xm.Experiment,
|
|
127
|
-
create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
|
|
128
|
-
args: JobArgs | None,
|
|
133
|
+
create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future[tp.Any]],
|
|
134
|
+
args: JobArgs | tp.Mapping[str, JobArgs] | None,
|
|
129
135
|
role: xm.ExperimentUnitRole,
|
|
130
136
|
identity: str = "",
|
|
131
137
|
) -> None:
|
|
@@ -136,25 +142,137 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
136
142
|
artifacts=ContextArtifacts(owner=self, artifacts=[]),
|
|
137
143
|
)
|
|
138
144
|
|
|
139
|
-
|
|
145
|
+
def add( # type: ignore
|
|
146
|
+
self,
|
|
147
|
+
job: xm.JobType,
|
|
148
|
+
args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
|
|
149
|
+
*,
|
|
150
|
+
dependency: dependencies.SlurmJobDependency | None = None,
|
|
151
|
+
identity: str = "",
|
|
152
|
+
) -> tp.Awaitable[None]:
|
|
153
|
+
# Prioritize the identity given directly to the work unit at work unit
|
|
154
|
+
# creation time, as opposed to the identity passed when adding jobs to it as
|
|
155
|
+
# this is more consistent between job generator work units and regular work
|
|
156
|
+
# units.
|
|
157
|
+
identity = self.identity or identity
|
|
158
|
+
|
|
159
|
+
job = job_operators.shallow_copy_job_type(job) # type: ignore
|
|
160
|
+
if args is not None:
|
|
161
|
+
core._apply_args(job, args)
|
|
162
|
+
job_operators.populate_job_names(job) # type: ignore
|
|
163
|
+
|
|
164
|
+
def launch_job(job: xm.Job) -> tp.Awaitable[None]:
|
|
165
|
+
core._current_experiment.set(self.experiment)
|
|
166
|
+
core._current_experiment_unit.set(self)
|
|
167
|
+
return self._launch_job_group(
|
|
168
|
+
xm.JobGroup(**{job.name: job}), # type: ignore
|
|
169
|
+
core._work_unit_arguments(job, self._args),
|
|
170
|
+
dependency=dependency,
|
|
171
|
+
identity=identity,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
def launch_job_group(group: xm.JobGroup) -> tp.Awaitable[None]:
|
|
175
|
+
core._current_experiment.set(self.experiment)
|
|
176
|
+
core._current_experiment_unit.set(self)
|
|
177
|
+
return self._launch_job_group(
|
|
178
|
+
group,
|
|
179
|
+
core._work_unit_arguments(group, self._args),
|
|
180
|
+
dependency=dependency,
|
|
181
|
+
identity=identity,
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
def launch_job_generator(
|
|
185
|
+
job_generator: xm.JobGeneratorType,
|
|
186
|
+
) -> tp.Awaitable[None]:
|
|
187
|
+
if not inspect.iscoroutinefunction(job_generator) and not inspect.iscoroutinefunction(
|
|
188
|
+
getattr(job_generator, "__call__")
|
|
189
|
+
):
|
|
190
|
+
raise ValueError(
|
|
191
|
+
"Job generator must be an async function. Signature needs to be "
|
|
192
|
+
"`async def job_generator(work_unit: xm.WorkUnit) -> None:`"
|
|
193
|
+
)
|
|
194
|
+
core._current_experiment.set(self.experiment)
|
|
195
|
+
core._current_experiment_unit.set(self)
|
|
196
|
+
coroutine = job_generator(self, **(args or {}))
|
|
197
|
+
assert coroutine is not None
|
|
198
|
+
return coroutine
|
|
199
|
+
|
|
200
|
+
def launch_job_config(job_config: xm.JobConfig) -> tp.Awaitable[None]:
|
|
201
|
+
core._current_experiment.set(self.experiment)
|
|
202
|
+
core._current_experiment_unit.set(self)
|
|
203
|
+
return self._launch_job_config(
|
|
204
|
+
job_config, dependency, tp.cast(JobArgs, args) or {}, identity
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
job_awaitable: tp.Awaitable[tp.Any]
|
|
208
|
+
match job:
|
|
209
|
+
case xm.Job() as job:
|
|
210
|
+
job_awaitable = launch_job(job)
|
|
211
|
+
case xm.JobGroup() as job_group:
|
|
212
|
+
job_awaitable = launch_job_group(job_group)
|
|
213
|
+
case job_generator if xm_job_blocks.is_job_generator(job):
|
|
214
|
+
job_awaitable = launch_job_generator(job_generator) # type: ignore
|
|
215
|
+
case xm.JobConfig() as job_config:
|
|
216
|
+
job_awaitable = launch_job_config(job_config)
|
|
217
|
+
case _:
|
|
218
|
+
raise TypeError(f"Unsupported job type: {job!r}")
|
|
219
|
+
|
|
220
|
+
launch_task = self._create_task(job_awaitable)
|
|
221
|
+
self._launch_tasks.append(launch_task)
|
|
222
|
+
return asyncio.wrap_future(launch_task)
|
|
223
|
+
|
|
224
|
+
async def _launch_job_group( # type: ignore
|
|
225
|
+
self,
|
|
226
|
+
job_group: xm.JobGroup,
|
|
227
|
+
args_view: tp.Mapping[str, JobArgs],
|
|
228
|
+
*,
|
|
229
|
+
dependency: dependencies.SlurmJobDependency | None,
|
|
230
|
+
identity: str,
|
|
231
|
+
) -> None:
|
|
232
|
+
del job_group, dependency, args_view, identity
|
|
233
|
+
raise NotImplementedError
|
|
234
|
+
|
|
235
|
+
async def _launch_job_config( # type: ignore
|
|
236
|
+
self,
|
|
237
|
+
job_config: xm.JobConfig,
|
|
238
|
+
dependency: dependencies.SlurmJobDependency | None,
|
|
239
|
+
args_view: JobArgs,
|
|
240
|
+
identity: str,
|
|
241
|
+
) -> None:
|
|
242
|
+
del job_config, dependency, args_view, identity
|
|
243
|
+
raise NotImplementedError
|
|
244
|
+
|
|
245
|
+
@tp.overload
|
|
140
246
|
async def _submit_jobs_for_execution(
|
|
141
247
|
self,
|
|
142
248
|
job: xm.Job,
|
|
249
|
+
dependency: dependencies.SlurmJobDependency | None,
|
|
143
250
|
args_view: JobArgs,
|
|
144
251
|
identity: str | None = ...,
|
|
145
252
|
) -> execution.SlurmHandle: ...
|
|
146
253
|
|
|
147
|
-
@
|
|
254
|
+
@tp.overload
|
|
148
255
|
async def _submit_jobs_for_execution(
|
|
149
256
|
self,
|
|
150
257
|
job: xm.JobGroup,
|
|
151
|
-
|
|
258
|
+
dependency: dependencies.SlurmJobDependency | None,
|
|
259
|
+
args_view: tp.Mapping[str, JobArgs],
|
|
152
260
|
identity: str | None = ...,
|
|
153
261
|
) -> execution.SlurmHandle: ...
|
|
154
262
|
|
|
155
|
-
|
|
263
|
+
@tp.overload
|
|
264
|
+
async def _submit_jobs_for_execution(
|
|
265
|
+
self,
|
|
266
|
+
job: xm.Job,
|
|
267
|
+
dependency: dependencies.SlurmJobDependency | None,
|
|
268
|
+
args_view: tp.Sequence[JobArgs],
|
|
269
|
+
identity: str | None = ...,
|
|
270
|
+
) -> list[execution.SlurmHandle]: ...
|
|
271
|
+
|
|
272
|
+
async def _submit_jobs_for_execution(self, job, dependency, args_view, identity=None):
|
|
156
273
|
return await execution.launch(
|
|
157
274
|
job=job,
|
|
275
|
+
dependency=dependency,
|
|
158
276
|
args=args_view,
|
|
159
277
|
experiment_id=self.experiment_id,
|
|
160
278
|
identity=identity,
|
|
@@ -173,7 +291,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
173
291
|
self.work_unit_id,
|
|
174
292
|
api.SlurmJobModel(
|
|
175
293
|
name=job.name,
|
|
176
|
-
slurm_job_id=handle.job_id,
|
|
294
|
+
slurm_job_id=handle.slurm_job.job_id,
|
|
177
295
|
slurm_ssh_config=handle.ssh.serialize(),
|
|
178
296
|
),
|
|
179
297
|
)
|
|
@@ -186,7 +304,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
186
304
|
self._launched_jobs.append(
|
|
187
305
|
xm.LaunchedJob(
|
|
188
306
|
name=job.name, # type: ignore
|
|
189
|
-
address=str(handle.job_id),
|
|
307
|
+
address=str(handle.slurm_job.job_id),
|
|
190
308
|
)
|
|
191
309
|
)
|
|
192
310
|
case xm.Job():
|
|
@@ -194,7 +312,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
194
312
|
self._launched_jobs.append(
|
|
195
313
|
xm.LaunchedJob(
|
|
196
314
|
name=handle.job.name, # type: ignore
|
|
197
|
-
address=str(handle.job_id),
|
|
315
|
+
address=str(handle.slurm_job.job_id),
|
|
198
316
|
)
|
|
199
317
|
)
|
|
200
318
|
|
|
@@ -221,10 +339,25 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
221
339
|
|
|
222
340
|
self.experiment._create_task(_stop_awaitable())
|
|
223
341
|
|
|
224
|
-
async def get_status(self) -> SlurmWorkUnitStatus:
|
|
342
|
+
async def get_status(self) -> SlurmWorkUnitStatus: # type: ignore
|
|
225
343
|
states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
|
|
226
344
|
return SlurmWorkUnitStatus.aggregate(states)
|
|
227
345
|
|
|
346
|
+
async def logs(
|
|
347
|
+
self,
|
|
348
|
+
*,
|
|
349
|
+
num_lines: int = 10,
|
|
350
|
+
block_size: int = 1024,
|
|
351
|
+
wait: bool = True,
|
|
352
|
+
follow: bool = False,
|
|
353
|
+
) -> tp.AsyncGenerator[ConsoleRenderable, None]:
|
|
354
|
+
assert len(self._execution_handles) == 1, "Only one job handle is supported for logs."
|
|
355
|
+
handle = self._execution_handles[0] # TODO(jfarebro): interleave?
|
|
356
|
+
async for log in handle.logs(
|
|
357
|
+
num_lines=num_lines, block_size=block_size, wait=wait, follow=follow
|
|
358
|
+
):
|
|
359
|
+
yield log
|
|
360
|
+
|
|
228
361
|
@property
|
|
229
362
|
def launched_jobs(self) -> list[xm.LaunchedJob]:
|
|
230
363
|
return self._launched_jobs
|
|
@@ -233,13 +366,27 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
233
366
|
def context(self) -> SlurmExperimentUnitMetadataContext: # type: ignore
|
|
234
367
|
return self._context
|
|
235
368
|
|
|
369
|
+
def after_started(
|
|
370
|
+
self, *, time: dt.timedelta | None = None
|
|
371
|
+
) -> dependencies.SlurmJobDependencyAfter:
|
|
372
|
+
return dependencies.SlurmJobDependencyAfter(self._execution_handles, time=time)
|
|
373
|
+
|
|
374
|
+
def after_finished(self) -> dependencies.SlurmJobDependencyAfterAny:
|
|
375
|
+
return dependencies.SlurmJobDependencyAfterAny(self._execution_handles)
|
|
376
|
+
|
|
377
|
+
def after_completed(self) -> dependencies.SlurmJobDependencyAfterOK:
|
|
378
|
+
return dependencies.SlurmJobDependencyAfterOK(self._execution_handles)
|
|
379
|
+
|
|
380
|
+
def after_failed(self) -> dependencies.SlurmJobDependencyAfterNotOK:
|
|
381
|
+
return dependencies.SlurmJobDependencyAfterNotOK(self._execution_handles)
|
|
382
|
+
|
|
236
383
|
|
|
237
384
|
class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
|
|
238
385
|
def __init__(
|
|
239
386
|
self,
|
|
240
387
|
experiment: "SlurmExperiment",
|
|
241
|
-
create_task: Callable[[Awaitable[Any]], futures.Future],
|
|
242
|
-
args: JobArgs,
|
|
388
|
+
create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future],
|
|
389
|
+
args: JobArgs | tp.Mapping[str, JobArgs] | None,
|
|
243
390
|
role: xm.ExperimentUnitRole,
|
|
244
391
|
work_unit_id_predictor: id_predictor.Predictor,
|
|
245
392
|
identity: str = "",
|
|
@@ -258,7 +405,9 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
|
|
|
258
405
|
async def _launch_job_group( # type: ignore
|
|
259
406
|
self,
|
|
260
407
|
job: xm.JobGroup,
|
|
261
|
-
args_view: Mapping[str, JobArgs],
|
|
408
|
+
args_view: tp.Mapping[str, JobArgs],
|
|
409
|
+
*,
|
|
410
|
+
dependency: dependencies.SlurmJobDependency | None,
|
|
262
411
|
identity: str,
|
|
263
412
|
) -> None:
|
|
264
413
|
global _current_job_array_queue
|
|
@@ -291,7 +440,9 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
|
|
|
291
440
|
# If the future is already done, i.e., the handle is already resolved, we don't need
|
|
292
441
|
# to submit the job again.
|
|
293
442
|
elif not future.done():
|
|
294
|
-
handle = await self._submit_jobs_for_execution(
|
|
443
|
+
handle = await self._submit_jobs_for_execution(
|
|
444
|
+
job, dependency, args_view, identity=identity
|
|
445
|
+
)
|
|
295
446
|
future.set_result(handle)
|
|
296
447
|
|
|
297
448
|
# Wait for the job handle, this is either coming from scheduling the job array
|
|
@@ -317,12 +468,16 @@ class SlurmAuxiliaryUnit(SlurmExperimentUnit):
|
|
|
317
468
|
async def _launch_job_group( # type: ignore
|
|
318
469
|
self,
|
|
319
470
|
job: xm.JobGroup,
|
|
320
|
-
args_view: Mapping[str, JobArgs],
|
|
471
|
+
args_view: tp.Mapping[str, JobArgs],
|
|
472
|
+
*,
|
|
473
|
+
dependency: dependencies.SlurmJobDependency | None,
|
|
321
474
|
identity: str,
|
|
322
475
|
) -> None:
|
|
323
476
|
_validate_job(job, args_view)
|
|
324
477
|
|
|
325
|
-
slurm_handle = await self._submit_jobs_for_execution(
|
|
478
|
+
slurm_handle = await self._submit_jobs_for_execution(
|
|
479
|
+
job, dependency, args_view, identity=identity
|
|
480
|
+
)
|
|
326
481
|
self._ingest_launched_jobs(job, slurm_handle)
|
|
327
482
|
|
|
328
483
|
@property
|
|
@@ -392,7 +547,7 @@ class SlurmExperimentContextAnnotations:
|
|
|
392
547
|
)
|
|
393
548
|
|
|
394
549
|
@property
|
|
395
|
-
def tags(self) -> MutableSet[str]:
|
|
550
|
+
def tags(self) -> tp.MutableSet[str]:
|
|
396
551
|
return self._tags
|
|
397
552
|
|
|
398
553
|
@tags.setter
|
|
@@ -459,87 +614,82 @@ class SlurmExperiment(xm.Experiment):
|
|
|
459
614
|
)
|
|
460
615
|
self._work_unit_count = 0
|
|
461
616
|
|
|
462
|
-
@
|
|
463
|
-
def add(
|
|
617
|
+
@tp.overload
|
|
618
|
+
def add( # type: ignore
|
|
464
619
|
self,
|
|
465
620
|
job: xm.AuxiliaryUnitJob,
|
|
466
|
-
args: Mapping[str,
|
|
621
|
+
args: JobArgs | tp.Mapping[str, JobArgs] | None = ...,
|
|
467
622
|
*,
|
|
623
|
+
dependency: dependencies.SlurmJobDependency | None = ...,
|
|
468
624
|
identity: str = ...,
|
|
469
625
|
) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
|
|
470
626
|
|
|
471
|
-
@
|
|
627
|
+
@tp.overload
|
|
472
628
|
def add(
|
|
473
629
|
self,
|
|
474
630
|
job: xm.JobGroup,
|
|
475
|
-
args: Mapping[str,
|
|
631
|
+
args: tp.Mapping[str, JobArgs] | None = ...,
|
|
476
632
|
*,
|
|
477
|
-
role: xm.WorkUnitRole =
|
|
633
|
+
role: xm.WorkUnitRole | None = ...,
|
|
634
|
+
dependency: dependencies.SlurmJobDependency | None = ...,
|
|
478
635
|
identity: str = ...,
|
|
479
636
|
) -> asyncio.Future[SlurmWorkUnit]: ...
|
|
480
637
|
|
|
481
|
-
@
|
|
638
|
+
@tp.overload
|
|
482
639
|
def add(
|
|
483
640
|
self,
|
|
484
|
-
job: xm.
|
|
485
|
-
args:
|
|
641
|
+
job: xm.Job | xm.JobGeneratorType,
|
|
642
|
+
args: tp.Sequence[JobArgs],
|
|
486
643
|
*,
|
|
487
|
-
role: xm.
|
|
644
|
+
role: xm.WorkUnitRole | None = ...,
|
|
645
|
+
dependency: dependencies.SlurmJobDependency
|
|
646
|
+
| tp.Sequence[dependencies.SlurmJobDependency]
|
|
647
|
+
| None = ...,
|
|
488
648
|
identity: str = ...,
|
|
489
|
-
) -> asyncio.Future[
|
|
649
|
+
) -> asyncio.Future[tp.Sequence[SlurmWorkUnit]]: ...
|
|
490
650
|
|
|
491
|
-
@
|
|
651
|
+
@tp.overload
|
|
492
652
|
def add(
|
|
493
653
|
self,
|
|
494
654
|
job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
|
|
495
|
-
args:
|
|
655
|
+
args: JobArgs | None = ...,
|
|
496
656
|
*,
|
|
497
|
-
role: xm.WorkUnitRole =
|
|
657
|
+
role: xm.WorkUnitRole | None = ...,
|
|
658
|
+
dependency: dependencies.SlurmJobDependency | None = ...,
|
|
498
659
|
identity: str = ...,
|
|
499
660
|
) -> asyncio.Future[SlurmWorkUnit]: ...
|
|
500
661
|
|
|
501
|
-
@
|
|
502
|
-
def add(
|
|
503
|
-
self,
|
|
504
|
-
job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
|
|
505
|
-
args: Mapping[str, Any] | None,
|
|
506
|
-
*,
|
|
507
|
-
role: xm.ExperimentUnitRole,
|
|
508
|
-
identity: str = ...,
|
|
509
|
-
) -> asyncio.Future[SlurmExperimentUnit]: ...
|
|
510
|
-
|
|
511
|
-
@typing.overload
|
|
512
|
-
def add(
|
|
513
|
-
self,
|
|
514
|
-
job: xm.Job | xm.JobGeneratorType,
|
|
515
|
-
args: Sequence[Mapping[str, Any]],
|
|
516
|
-
*,
|
|
517
|
-
role: xm.WorkUnitRole = xm.WorkUnitRole(),
|
|
518
|
-
identity: str = ...,
|
|
519
|
-
) -> asyncio.Future[Sequence[SlurmWorkUnit]]: ...
|
|
520
|
-
|
|
521
|
-
@typing.overload
|
|
662
|
+
@tp.overload
|
|
522
663
|
def add(
|
|
523
664
|
self,
|
|
524
665
|
job: xm.JobType,
|
|
525
666
|
*,
|
|
526
|
-
role: xm.AuxiliaryUnitRole
|
|
667
|
+
role: xm.AuxiliaryUnitRole,
|
|
668
|
+
dependency: dependencies.SlurmJobDependency | None = ...,
|
|
527
669
|
identity: str = ...,
|
|
528
670
|
) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
|
|
529
671
|
|
|
530
672
|
def add( # type: ignore
|
|
531
673
|
self,
|
|
532
674
|
job: xm.JobType,
|
|
533
|
-
args:
|
|
675
|
+
args: JobArgs
|
|
676
|
+
| tp.Mapping[str, JobArgs]
|
|
677
|
+
| tp.Sequence[tp.Mapping[str, tp.Any]]
|
|
678
|
+
| None = None,
|
|
534
679
|
*,
|
|
535
|
-
role: xm.ExperimentUnitRole =
|
|
680
|
+
role: xm.ExperimentUnitRole | None = None,
|
|
681
|
+
dependency: dependencies.SlurmJobDependency
|
|
682
|
+
| tp.Sequence[dependencies.SlurmJobDependency]
|
|
683
|
+
| None = None,
|
|
536
684
|
identity: str = "",
|
|
537
685
|
) -> (
|
|
538
686
|
asyncio.Future[SlurmAuxiliaryUnit]
|
|
539
|
-
| asyncio.Future[SlurmExperimentUnit]
|
|
540
687
|
| asyncio.Future[SlurmWorkUnit]
|
|
541
|
-
| asyncio.Future[Sequence[SlurmWorkUnit]]
|
|
688
|
+
| asyncio.Future[tp.Sequence[SlurmWorkUnit]]
|
|
542
689
|
):
|
|
690
|
+
if role is None:
|
|
691
|
+
role = xm.WorkUnitRole()
|
|
692
|
+
|
|
543
693
|
if isinstance(args, collections.abc.Sequence):
|
|
544
694
|
if not isinstance(role, xm.WorkUnitRole):
|
|
545
695
|
raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
|
|
@@ -554,21 +704,76 @@ class SlurmExperiment(xm.Experiment):
|
|
|
554
704
|
# Validate job & args
|
|
555
705
|
for trial in args:
|
|
556
706
|
_validate_job(job, trial)
|
|
557
|
-
args =
|
|
707
|
+
args = tp.cast(tp.Sequence[JobArgs], args)
|
|
558
708
|
|
|
559
709
|
return asyncio.wrap_future(
|
|
560
|
-
self._create_task(self._launch_job_array(job, args, role, identity))
|
|
710
|
+
self._create_task(self._launch_job_array(job, dependency, args, role, identity)),
|
|
711
|
+
loop=self._event_loop,
|
|
712
|
+
)
|
|
713
|
+
if not (isinstance(dependency, dependencies.SlurmJobDependency) or dependency is None):
|
|
714
|
+
raise ValueError("Invalid dependency type, expected a SlurmJobDependency or None")
|
|
715
|
+
|
|
716
|
+
if isinstance(job, xm.AuxiliaryUnitJob):
|
|
717
|
+
role = job.role
|
|
718
|
+
self._added_roles[type(role)] += 1
|
|
719
|
+
|
|
720
|
+
if self._should_reload_experiment_unit(role):
|
|
721
|
+
experiment_unit_future = self._get_experiment_unit(
|
|
722
|
+
self.experiment_id, identity, role, args
|
|
561
723
|
)
|
|
562
724
|
else:
|
|
563
|
-
|
|
725
|
+
experiment_unit_future = self._create_experiment_unit(args, role, identity)
|
|
726
|
+
|
|
727
|
+
async def launch():
|
|
728
|
+
experiment_unit = await experiment_unit_future
|
|
729
|
+
try:
|
|
730
|
+
await experiment_unit.add(job, args, dependency=dependency, identity=identity)
|
|
731
|
+
except Exception as experiment_exception:
|
|
732
|
+
logger.error(
|
|
733
|
+
"Stopping experiment unit (identity %r) after it failed with: %s",
|
|
734
|
+
identity,
|
|
735
|
+
experiment_exception,
|
|
736
|
+
)
|
|
737
|
+
try:
|
|
738
|
+
if isinstance(job, xm.AuxiliaryUnitJob):
|
|
739
|
+
experiment_unit.stop()
|
|
740
|
+
else:
|
|
741
|
+
experiment_unit.stop(
|
|
742
|
+
mark_as_failed=True,
|
|
743
|
+
message=f"Work unit creation failed. {traceback.format_exc()}",
|
|
744
|
+
)
|
|
745
|
+
except Exception as stop_exception: # pylint: disable=broad-except
|
|
746
|
+
logger.error("Couldn't stop experiment unit: %s", stop_exception)
|
|
747
|
+
raise
|
|
748
|
+
return experiment_unit
|
|
749
|
+
|
|
750
|
+
async def reload():
|
|
751
|
+
experiment_unit = await experiment_unit_future
|
|
752
|
+
try:
|
|
753
|
+
await experiment_unit.add(job, args, dependency=dependency, identity=identity)
|
|
754
|
+
except Exception as update_exception:
|
|
755
|
+
logging.error(
|
|
756
|
+
"Could not reload the experiment unit: %s",
|
|
757
|
+
update_exception,
|
|
758
|
+
)
|
|
759
|
+
raise
|
|
760
|
+
return experiment_unit
|
|
761
|
+
|
|
762
|
+
return asyncio.wrap_future(
|
|
763
|
+
self._create_task(reload() if self._should_reload_experiment_unit(role) else launch()),
|
|
764
|
+
loop=self._event_loop,
|
|
765
|
+
)
|
|
564
766
|
|
|
565
767
|
async def _launch_job_array(
|
|
566
768
|
self,
|
|
567
769
|
job: xm.Job | xm.JobGeneratorType,
|
|
568
|
-
|
|
770
|
+
dependency: dependencies.SlurmJobDependency
|
|
771
|
+
| tp.Sequence[dependencies.SlurmJobDependency]
|
|
772
|
+
| None,
|
|
773
|
+
args: tp.Sequence[JobArgs],
|
|
569
774
|
role: xm.WorkUnitRole,
|
|
570
775
|
identity: str = "",
|
|
571
|
-
) -> Sequence[SlurmWorkUnit]:
|
|
776
|
+
) -> tp.Sequence[SlurmWorkUnit]:
|
|
572
777
|
global _current_job_array_queue
|
|
573
778
|
|
|
574
779
|
# Create our job array queue and assign it to the current context
|
|
@@ -579,7 +784,11 @@ class SlurmExperiment(xm.Experiment):
|
|
|
579
784
|
# and collect the futures
|
|
580
785
|
wu_futures = []
|
|
581
786
|
for idx, trial in enumerate(args):
|
|
582
|
-
wu_futures.append(
|
|
787
|
+
wu_futures.append(
|
|
788
|
+
self.add(
|
|
789
|
+
job, args=trial, role=role, identity=f"{identity}_{idx}" if identity else ""
|
|
790
|
+
)
|
|
791
|
+
)
|
|
583
792
|
|
|
584
793
|
# We'll wait until XManager has filled the queue.
|
|
585
794
|
# There are two cases here, either we were given an xm.Job
|
|
@@ -589,7 +798,8 @@ class SlurmExperiment(xm.Experiment):
|
|
|
589
798
|
while not job_array_queue.full():
|
|
590
799
|
await asyncio.sleep(0.1)
|
|
591
800
|
|
|
592
|
-
# All jobs have been resolved
|
|
801
|
+
# All jobs have been resolved so now we'll perform sanity checks
|
|
802
|
+
# to make sure we can infer the sweep
|
|
593
803
|
executable, executor, name = None, None, None
|
|
594
804
|
resolved_args, resolved_env_vars, resolved_futures = [], [], []
|
|
595
805
|
while not job_array_queue.empty():
|
|
@@ -650,6 +860,78 @@ class SlurmExperiment(xm.Experiment):
|
|
|
650
860
|
for a, e in zip(resolved_args, resolved_env_vars)
|
|
651
861
|
]
|
|
652
862
|
|
|
863
|
+
# Dependency resolution
|
|
864
|
+
resolved_dependency = None
|
|
865
|
+
resolved_dependency_task_id_order = None
|
|
866
|
+
# one-to-one
|
|
867
|
+
if isinstance(dependency, collections.abc.Sequence):
|
|
868
|
+
if len(dependency) != len(wu_futures):
|
|
869
|
+
raise ValueError("Dependency list must be the same length as the number of trials.")
|
|
870
|
+
assert len(dependency) > 0, "Dependency list must not be empty."
|
|
871
|
+
|
|
872
|
+
# Convert any SlurmJobDependencyAfterOK to SlurmJobArrayDependencyAfterOK
|
|
873
|
+
# for any array jobs.
|
|
874
|
+
def _maybe_convert_afterok(
|
|
875
|
+
dep: dependencies.SlurmJobDependency,
|
|
876
|
+
) -> dependencies.SlurmJobDependency:
|
|
877
|
+
if isinstance(dep, dependencies.SlurmJobDependencyAfterOK) and all([
|
|
878
|
+
handle.slurm_job.is_array_job for handle in dep.handles
|
|
879
|
+
]):
|
|
880
|
+
return dependencies.SlurmJobArrayDependencyAfterOK([
|
|
881
|
+
dataclasses.replace(
|
|
882
|
+
handle,
|
|
883
|
+
slurm_job=handle.slurm_job.array_job_id,
|
|
884
|
+
)
|
|
885
|
+
for handle in dep.handles
|
|
886
|
+
])
|
|
887
|
+
return dep
|
|
888
|
+
|
|
889
|
+
dependencies_converted = [dep.traverse(_maybe_convert_afterok) for dep in dependency]
|
|
890
|
+
dependency_sets = [set(dep.flatten()) for dep in dependencies_converted]
|
|
891
|
+
dependency_differences = functools.reduce(set.difference, dependency_sets, set())
|
|
892
|
+
# There should be NO differences between the dependencies of each trial after conversion.
|
|
893
|
+
if len(dependency_differences) > 0:
|
|
894
|
+
raise ValueError(
|
|
895
|
+
f"Found variable dependencies across trials: {dependency_differences}. "
|
|
896
|
+
"Slurm job arrays require the same dependencies across all trials. "
|
|
897
|
+
)
|
|
898
|
+
resolved_dependency = dependencies_converted[0]
|
|
899
|
+
|
|
900
|
+
# This is slightly annoying but we need to re-sort the sweep arguments in case the dependencies were passed
|
|
901
|
+
# in a different order than 1, 2, ..., N as the Job array can only have correspondance with the same task id.
|
|
902
|
+
original_array_dependencies = [
|
|
903
|
+
mit.one(
|
|
904
|
+
filter(
|
|
905
|
+
lambda dep: isinstance(dep, dependencies.SlurmJobDependencyAfterOK)
|
|
906
|
+
and all([handle.slurm_job.is_array_job for handle in dep.handles]),
|
|
907
|
+
deps.flatten(),
|
|
908
|
+
)
|
|
909
|
+
)
|
|
910
|
+
for deps in dependency
|
|
911
|
+
]
|
|
912
|
+
resolved_dependency_task_id_order = [
|
|
913
|
+
int(
|
|
914
|
+
mit.one(
|
|
915
|
+
functools.reduce(
|
|
916
|
+
set.difference,
|
|
917
|
+
[handle.slurm_job.array_task_id for handle in dep.handles], # type: ignore
|
|
918
|
+
)
|
|
919
|
+
)
|
|
920
|
+
)
|
|
921
|
+
for dep in original_array_dependencies
|
|
922
|
+
]
|
|
923
|
+
assert len(resolved_dependency_task_id_order) == len(sweep_args)
|
|
924
|
+
assert set(resolved_dependency_task_id_order) == set(range(len(sweep_args))), (
|
|
925
|
+
"Dependent job array tasks should have task ids 0, 1, ..., N. "
|
|
926
|
+
f"Found: {resolved_dependency_task_id_order}"
|
|
927
|
+
)
|
|
928
|
+
# one-to-many
|
|
929
|
+
elif isinstance(dependency, dependencies.SlurmJobDependency):
|
|
930
|
+
resolved_dependency = dependency
|
|
931
|
+
assert resolved_dependency is None or isinstance(
|
|
932
|
+
resolved_dependency, dependencies.SlurmJobDependency
|
|
933
|
+
), "Invalid dependency type"
|
|
934
|
+
|
|
653
935
|
# No support for sweep_env_vars right now.
|
|
654
936
|
# We schedule the job array and then we'll resolve all the work units with
|
|
655
937
|
# the handles Slurm gives back to us.
|
|
@@ -666,10 +948,18 @@ class SlurmExperiment(xm.Experiment):
|
|
|
666
948
|
args=xm.SequentialArgs.from_collection(common_args),
|
|
667
949
|
env_vars=dict(common_env_vars),
|
|
668
950
|
),
|
|
669
|
-
|
|
951
|
+
dependency=resolved_dependency,
|
|
952
|
+
args=[
|
|
953
|
+
sweep_args[resolved_dependency_task_id_order.index(i)]
|
|
954
|
+
for i in range(len(sweep_args))
|
|
955
|
+
]
|
|
956
|
+
if resolved_dependency_task_id_order
|
|
957
|
+
else sweep_args,
|
|
670
958
|
experiment_id=self.experiment_id,
|
|
671
959
|
identity=identity,
|
|
672
960
|
)
|
|
961
|
+
if resolved_dependency_task_id_order:
|
|
962
|
+
handles = [handles[i] for i in resolved_dependency_task_id_order]
|
|
673
963
|
except Exception as e:
|
|
674
964
|
for future in resolved_futures:
|
|
675
965
|
future.set_exception(e)
|
|
@@ -697,11 +987,11 @@ class SlurmExperiment(xm.Experiment):
|
|
|
697
987
|
|
|
698
988
|
def _create_experiment_unit( # type: ignore
|
|
699
989
|
self,
|
|
700
|
-
args: JobArgs,
|
|
990
|
+
args: JobArgs | tp.Mapping[str, JobArgs] | None,
|
|
701
991
|
role: xm.ExperimentUnitRole,
|
|
702
992
|
identity: str,
|
|
703
|
-
) -> Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
|
|
704
|
-
def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
|
|
993
|
+
) -> tp.Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
|
|
994
|
+
def _create_work_unit(role: xm.WorkUnitRole) -> tp.Awaitable[SlurmWorkUnit]:
|
|
705
995
|
work_unit = SlurmWorkUnit(
|
|
706
996
|
self,
|
|
707
997
|
self._create_task,
|
|
@@ -726,7 +1016,7 @@ class SlurmExperiment(xm.Experiment):
|
|
|
726
1016
|
future.set_result(work_unit)
|
|
727
1017
|
return future
|
|
728
1018
|
|
|
729
|
-
def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> Awaitable[SlurmAuxiliaryUnit]:
|
|
1019
|
+
def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> tp.Awaitable[SlurmAuxiliaryUnit]:
|
|
730
1020
|
auxiliary_unit = SlurmAuxiliaryUnit(
|
|
731
1021
|
self,
|
|
732
1022
|
self._create_task,
|
|
@@ -756,8 +1046,8 @@ class SlurmExperiment(xm.Experiment):
|
|
|
756
1046
|
experiment_id: int,
|
|
757
1047
|
identity: str,
|
|
758
1048
|
role: xm.ExperimentUnitRole,
|
|
759
|
-
args: JobArgs | None = None,
|
|
760
|
-
) -> Awaitable[
|
|
1049
|
+
args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
|
|
1050
|
+
) -> tp.Awaitable[SlurmExperimentUnit]:
|
|
761
1051
|
del experiment_id, identity, role, args
|
|
762
1052
|
raise NotImplementedError
|
|
763
1053
|
|
|
@@ -797,7 +1087,7 @@ class SlurmExperiment(xm.Experiment):
|
|
|
797
1087
|
def work_unit_count(self) -> int:
|
|
798
1088
|
return self._work_unit_count
|
|
799
1089
|
|
|
800
|
-
def work_units(self) ->
|
|
1090
|
+
def work_units(self) -> dict[int, SlurmWorkUnit]:
|
|
801
1091
|
"""Gets work units created via self.add()."""
|
|
802
1092
|
return {
|
|
803
1093
|
wu.work_unit_id: wu for wu in self._experiment_units if isinstance(wu, SlurmWorkUnit)
|
|
@@ -822,8 +1112,8 @@ def get_experiment(experiment_id: int) -> SlurmExperiment:
|
|
|
822
1112
|
experiment._work_unit_id_predictor = id_predictor.Predictor(1)
|
|
823
1113
|
|
|
824
1114
|
# Populate annotations
|
|
825
|
-
experiment.context.annotations.description = experiment_model.description
|
|
826
|
-
experiment.context.annotations.note = experiment_model.note
|
|
1115
|
+
experiment.context.annotations.description = experiment_model.description or ""
|
|
1116
|
+
experiment.context.annotations.note = experiment_model.note or ""
|
|
827
1117
|
experiment.context.annotations.tags = set(experiment_model.tags or [])
|
|
828
1118
|
|
|
829
1119
|
# Populate artifacts
|
|
@@ -846,8 +1136,9 @@ def get_experiment(experiment_id: int) -> SlurmExperiment:
|
|
|
846
1136
|
for job_model in wu_model.jobs:
|
|
847
1137
|
slurm_ssh_config = config.SlurmSSHConfig.deserialize(job_model.slurm_ssh_config)
|
|
848
1138
|
handle = execution.SlurmHandle(
|
|
1139
|
+
experiment_id=experiment_id,
|
|
849
1140
|
ssh=slurm_ssh_config,
|
|
850
|
-
|
|
1141
|
+
slurm_job=str(job_model.slurm_job_id),
|
|
851
1142
|
job_name=job_model.name,
|
|
852
1143
|
)
|
|
853
1144
|
work_unit._execution_handles.append(handle)
|