xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic.
- xm_slurm/__init__.py +6 -2
- xm_slurm/api.py +301 -34
- xm_slurm/batching.py +4 -4
- xm_slurm/config.py +105 -55
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +47 -13
- xm_slurm/contrib/clusters/drac.py +34 -16
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +34 -22
- xm_slurm/execution.py +305 -107
- xm_slurm/executors.py +8 -12
- xm_slurm/experiment.py +601 -168
- xm_slurm/experimental/parameter_controller.py +202 -0
- xm_slurm/job_blocks.py +7 -0
- xm_slurm/packageables.py +42 -20
- xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
- xm_slurm/packaging/router.py +3 -1
- xm_slurm/packaging/utils.py +9 -81
- xm_slurm/resources.py +28 -4
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +52 -0
- xm_slurm/status.py +9 -0
- xm_slurm/templates/docker/mamba.Dockerfile +4 -2
- xm_slurm/templates/docker/python.Dockerfile +18 -10
- xm_slurm/templates/docker/uv.Dockerfile +35 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
- xm_slurm/templates/slurm/job.bash.j2 +4 -3
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +18 -10
- xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
- xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
- {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
- xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
- xm_slurm/packaging/docker/__init__.py +0 -75
- xm_slurm/packaging/docker/abc.py +0 -112
- xm_slurm/packaging/docker/cloud.py +0 -503
- xm_slurm/templates/docker/pdm.Dockerfile +0 -31
- xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
- xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
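The largest change in this release is first-class Slurm job dependencies: work units gain `after_started()`, `after_finished()`, `after_completed()`, and `after_failed()` helpers, and `SlurmExperiment.add()` accepts a `dependency=` argument (including one dependency per trial for job arrays), as the `xm_slurm/experiment.py` diff below shows. The following sketch is illustrative only; it assumes `create_experiment` is used as an async context manager and that the executables/executor are packaged elsewhere, and it is not taken from the package's documentation.

```python
# Illustrative sketch of the 0.4.1 dependency API (names as they appear in the
# diff below). Packaging of executables and the Slurm executor is assumed to
# happen elsewhere and is passed in here as placeholders.
import asyncio

from xmanager import xm

from xm_slurm import experiment as xm_slurm_experiment


async def run(train_exe: xm.Executable, eval_exe: xm.Executable, executor: xm.Executor) -> None:
    async with xm_slurm_experiment.create_experiment("dependency-example") as exp:
        train = xm.Job(executable=train_exe, executor=executor, args={"seed": 0})
        evaluate = xm.Job(executable=eval_exe, executor=executor, args={"seed": 0})

        # `add()` returns an asyncio.Future[SlurmWorkUnit]; await it to get the work unit.
        train_wu = await exp.add(train)

        # Run evaluation only once training finishes successfully (Slurm `afterok`).
        await exp.add(evaluate, dependency=train_wu.after_completed())
```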
xm_slurm/experiment.py
CHANGED
@@ -2,23 +2,31 @@ import asyncio
 import collections.abc
 import contextvars
 import dataclasses
+import datetime as dt
 import functools
 import inspect
 import json
+import logging
 import os
-import
+import traceback
+import typing as tp
 from concurrent import futures
-from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence

+import more_itertools as mit
+from rich.console import ConsoleRenderable
 from xmanager import xm
-from xmanager.xm import async_packager, id_predictor
+from xmanager.xm import async_packager, core, id_predictor, job_operators
+from xmanager.xm import job_blocks as xm_job_blocks

-from xm_slurm import api, execution, executors
+from xm_slurm import api, config, dependencies, execution, executors
 from xm_slurm.console import console
+from xm_slurm.job_blocks import JobArgs
 from xm_slurm.packaging import router
 from xm_slurm.status import SlurmWorkUnitStatus
 from xm_slurm.utils import UserSet

+logger = logging.getLogger(__name__)
+
 _current_job_array_queue = contextvars.ContextVar[
     asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
 ]("_current_job_array_queue", default=None)
@@ -26,7 +34,7 @@ _current_job_array_queue = contextvars.ContextVar[

 def _validate_job(
     job: xm.JobType,
-    args_view: Mapping[str,
+    args_view: JobArgs | tp.Mapping[str, JobArgs],
 ) -> None:
     if not args_view:
         return
@@ -49,7 +57,7 @@ def _validate_job(
             )

         if isinstance(job, xm.JobGroup) and key in job.jobs:
-            _validate_job(job.jobs[key], expanded)
+            _validate_job(job.jobs[key], tp.cast(JobArgs, expanded))
         elif key not in allowed_keys:
             raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")

@@ -60,7 +68,7 @@ class Artifact:
     uri: str

     def __hash__(self) -> int:
-        return hash(self.name)
+        return hash((type(self), self.name))


 class ContextArtifacts(UserSet[Artifact]):
@@ -68,7 +76,7 @@ class ContextArtifacts(UserSet[Artifact]):
         self,
         owner: "SlurmExperiment | SlurmExperimentUnit",
         *,
-        artifacts: Sequence[Artifact],
+        artifacts: tp.Sequence[Artifact],
     ):
         super().__init__(
             artifacts,
@@ -117,50 +125,194 @@ class SlurmExperimentUnitMetadataContext:
 class SlurmExperimentUnit(xm.ExperimentUnit):
     """ExperimentUnit is a collection of semantically associated `Job`s."""

-    experiment: "SlurmExperiment"
+    experiment: "SlurmExperiment"  # type: ignore

     def __init__(
         self,
         experiment: xm.Experiment,
-        create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
-        args: Mapping[str,
+        create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future[tp.Any]],
+        args: JobArgs | tp.Mapping[str, JobArgs] | None,
         role: xm.ExperimentUnitRole,
+        identity: str = "",
     ) -> None:
-        super().__init__(experiment, create_task, args, role)
+        super().__init__(experiment, create_task, args, role, identity=identity)
         self._launched_jobs: list[xm.LaunchedJob] = []
         self._execution_handles: list[execution.SlurmHandle] = []
         self._context = SlurmExperimentUnitMetadataContext(
             artifacts=ContextArtifacts(owner=self, artifacts=[]),
         )

+    def add(  # type: ignore
+        self,
+        job: xm.JobType,
+        args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
+        *,
+        dependency: dependencies.SlurmJobDependency | None = None,
+        identity: str = "",
+    ) -> tp.Awaitable[None]:
+        # Prioritize the identity given directly to the work unit at work unit
+        # creation time, as opposed to the identity passed when adding jobs to it as
+        # this is more consistent between job generator work units and regular work
+        # units.
+        identity = self.identity or identity
+
+        job = job_operators.shallow_copy_job_type(job)  # type: ignore
+        if args is not None:
+            core._apply_args(job, args)
+        job_operators.populate_job_names(job)  # type: ignore
+
+        def launch_job(job: xm.Job) -> tp.Awaitable[None]:
+            core._current_experiment.set(self.experiment)
+            core._current_experiment_unit.set(self)
+            return self._launch_job_group(
+                xm.JobGroup(**{job.name: job}),  # type: ignore
+                core._work_unit_arguments(job, self._args),
+                dependency=dependency,
+                identity=identity,
+            )
+
+        def launch_job_group(group: xm.JobGroup) -> tp.Awaitable[None]:
+            core._current_experiment.set(self.experiment)
+            core._current_experiment_unit.set(self)
+            return self._launch_job_group(
+                group,
+                core._work_unit_arguments(group, self._args),
+                dependency=dependency,
+                identity=identity,
+            )
+
+        def launch_job_generator(
+            job_generator: xm.JobGeneratorType,
+        ) -> tp.Awaitable[None]:
+            if not inspect.iscoroutinefunction(job_generator) and not inspect.iscoroutinefunction(
+                getattr(job_generator, "__call__")
+            ):
+                raise ValueError(
+                    "Job generator must be an async function. Signature needs to be "
+                    "`async def job_generator(work_unit: xm.WorkUnit) -> None:`"
+                )
+            core._current_experiment.set(self.experiment)
+            core._current_experiment_unit.set(self)
+            coroutine = job_generator(self, **(args or {}))
+            assert coroutine is not None
+            return coroutine
+
+        def launch_job_config(job_config: xm.JobConfig) -> tp.Awaitable[None]:
+            core._current_experiment.set(self.experiment)
+            core._current_experiment_unit.set(self)
+            return self._launch_job_config(
+                job_config, dependency, tp.cast(JobArgs, args) or {}, identity
+            )
+
+        job_awaitable: tp.Awaitable[tp.Any]
+        match job:
+            case xm.Job() as job:
+                job_awaitable = launch_job(job)
+            case xm.JobGroup() as job_group:
+                job_awaitable = launch_job_group(job_group)
+            case job_generator if xm_job_blocks.is_job_generator(job):
+                job_awaitable = launch_job_generator(job_generator)  # type: ignore
+            case xm.JobConfig() as job_config:
+                job_awaitable = launch_job_config(job_config)
+            case _:
+                raise TypeError(f"Unsupported job type: {job!r}")
+
+        launch_task = self._create_task(job_awaitable)
+        self._launch_tasks.append(launch_task)
+        return asyncio.wrap_future(launch_task)
+
+    async def _launch_job_group(  # type: ignore
+        self,
+        job_group: xm.JobGroup,
+        args_view: tp.Mapping[str, JobArgs],
+        *,
+        dependency: dependencies.SlurmJobDependency | None,
+        identity: str,
+    ) -> None:
+        del job_group, dependency, args_view, identity
+        raise NotImplementedError
+
+    async def _launch_job_config(  # type: ignore
+        self,
+        job_config: xm.JobConfig,
+        dependency: dependencies.SlurmJobDependency | None,
+        args_view: JobArgs,
+        identity: str,
+    ) -> None:
+        del job_config, dependency, args_view, identity
+        raise NotImplementedError
+
+    @tp.overload
+    async def _submit_jobs_for_execution(
+        self,
+        job: xm.Job,
+        dependency: dependencies.SlurmJobDependency | None,
+        args_view: JobArgs,
+        identity: str | None = ...,
+    ) -> execution.SlurmHandle: ...
+
+    @tp.overload
+    async def _submit_jobs_for_execution(
+        self,
+        job: xm.JobGroup,
+        dependency: dependencies.SlurmJobDependency | None,
+        args_view: tp.Mapping[str, JobArgs],
+        identity: str | None = ...,
+    ) -> execution.SlurmHandle: ...
+
+    @tp.overload
     async def _submit_jobs_for_execution(
         self,
-        job: xm.Job
-
-
-
+        job: xm.Job,
+        dependency: dependencies.SlurmJobDependency | None,
+        args_view: tp.Sequence[JobArgs],
+        identity: str | None = ...,
+    ) -> list[execution.SlurmHandle]: ...
+
+    async def _submit_jobs_for_execution(self, job, dependency, args_view, identity=None):
         return await execution.launch(
             job=job,
+            dependency=dependency,
             args=args_view,
             experiment_id=self.experiment_id,
             identity=identity,
         )

     def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
+        self._execution_handles.append(handle)
+
+        def _ingest_job(job: xm.Job) -> None:
+            if not isinstance(self._role, xm.WorkUnitRole):
+                return
+            assert isinstance(self, SlurmWorkUnit)
+            assert job.name is not None
+            api.client().insert_job(
+                self.experiment_id,
+                self.work_unit_id,
+                api.SlurmJobModel(
+                    name=job.name,
+                    slurm_job_id=handle.slurm_job.job_id,
+                    slurm_ssh_config=handle.ssh.serialize(),
+                ),
+            )
+
         match job:
             case xm.JobGroup() as job_group:
                 for job in job_group.jobs.values():
+                    assert isinstance(job, xm.Job)
+                    _ingest_job(job)
                     self._launched_jobs.append(
                         xm.LaunchedJob(
                             name=job.name,  # type: ignore
-                            address=str(handle.job_id),
+                            address=str(handle.slurm_job.job_id),
                         )
                     )
             case xm.Job():
+                _ingest_job(job)
                 self._launched_jobs.append(
                     xm.LaunchedJob(
                         name=handle.job.name,  # type: ignore
-                        address=str(handle.job_id),
+                        address=str(handle.slurm_job.job_id),
                     )
                 )

@@ -187,85 +339,116 @@ class SlurmExperimentUnit(xm.ExperimentUnit):

         self.experiment._create_task(_stop_awaitable())

-    async def get_status(self) -> SlurmWorkUnitStatus:
+    async def get_status(self) -> SlurmWorkUnitStatus:  # type: ignore
         states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
         return SlurmWorkUnitStatus.aggregate(states)

+    async def logs(
+        self,
+        *,
+        num_lines: int = 10,
+        block_size: int = 1024,
+        wait: bool = True,
+        follow: bool = False,
+    ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
+        assert len(self._execution_handles) == 1, "Only one job handle is supported for logs."
+        handle = self._execution_handles[0]  # TODO(jfarebro): interleave?
+        async for log in handle.logs(
+            num_lines=num_lines, block_size=block_size, wait=wait, follow=follow
+        ):
+            yield log
+
+    @property
     def launched_jobs(self) -> list[xm.LaunchedJob]:
         return self._launched_jobs

     @property
-    def context(self) -> SlurmExperimentUnitMetadataContext:
+    def context(self) -> SlurmExperimentUnitMetadataContext:  # type: ignore
         return self._context

+    def after_started(
+        self, *, time: dt.timedelta | None = None
+    ) -> dependencies.SlurmJobDependencyAfter:
+        return dependencies.SlurmJobDependencyAfter(self._execution_handles, time=time)
+
+    def after_finished(self) -> dependencies.SlurmJobDependencyAfterAny:
+        return dependencies.SlurmJobDependencyAfterAny(self._execution_handles)
+
+    def after_completed(self) -> dependencies.SlurmJobDependencyAfterOK:
+        return dependencies.SlurmJobDependencyAfterOK(self._execution_handles)
+
+    def after_failed(self) -> dependencies.SlurmJobDependencyAfterNotOK:
+        return dependencies.SlurmJobDependencyAfterNotOK(self._execution_handles)
+

 class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
     def __init__(
         self,
         experiment: "SlurmExperiment",
-        create_task: Callable[[Awaitable[Any]], futures.Future],
-        args: Mapping[str,
+        create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future],
+        args: JobArgs | tp.Mapping[str, JobArgs] | None,
         role: xm.ExperimentUnitRole,
         work_unit_id_predictor: id_predictor.Predictor,
+        identity: str = "",
     ) -> None:
-        super().__init__(experiment, create_task, args, role)
+        super().__init__(experiment, create_task, args, role, identity=identity)
         self._work_unit_id_predictor = work_unit_id_predictor
         self._work_unit_id = self._work_unit_id_predictor.reserve_id()
-        api.client().insert_work_unit(
-            self.experiment_id,
-            api.WorkUnitPatchModel(
-                wid=self.work_unit_id,
-                identity=self.identity,
-                args=json.dumps(args),
-            ),
-        )

-    def
-
-
-
-
-
-                name=self.experiment_unit_name,
-                slurm_job_id=handle.job_id,  # type: ignore
-                slurm_cluster=json.dumps({
-                    "host": handle.ssh_connection_options.host,
-                    "username": handle.ssh_connection_options.username,
-                    "port": handle.ssh_connection_options.port,
-                    "config": handle.ssh_connection_options.config.get_options(False),
-                }),
-            ),
-        )
+    def _get_existing_handle(self, job: xm.JobGroup) -> execution.SlurmHandle | None:
+        job_name = mit.one(job.jobs.keys())
+        for handle in self._execution_handles:
+            if handle.job_name == job_name:
+                return handle
+        return None

-    async def _launch_job_group(
+    async def _launch_job_group(  # type: ignore
         self,
         job: xm.JobGroup,
-        args_view: Mapping[str,
+        args_view: tp.Mapping[str, JobArgs],
+        *,
+        dependency: dependencies.SlurmJobDependency | None,
         identity: str,
     ) -> None:
         global _current_job_array_queue
         _validate_job(job, args_view)
+        future = asyncio.Future[execution.SlurmHandle]()
+
+        # If we already have a handle for this job, we don't need to submit it again.
+        # We'll just resolve the future with the existing handle.
+        # Otherwise we'll add callbacks to ingest the handle and the launched jobs.
+        if existing_handle := self._get_existing_handle(job):
+            future.set_result(existing_handle)
+        else:
+            future.add_done_callback(
+                lambda handle: self._ingest_launched_jobs(job, handle.result())
+            )
+
+        api.client().update_work_unit(
+            self.experiment_id,
+            self.work_unit_id,
+            api.ExperimentUnitPatchModel(args=json.dumps(args_view), identity=None),
+        )

-        future = asyncio.Future()
         async with self._work_unit_id_predictor.submit_id(self.work_unit_id):  # type: ignore
             # If we're scheduling as part of a job queue (i.e., the queue is set on the context)
-            # then we'll insert the job and
-            #
+            # then we'll insert the job and future that'll get resolved to the proper handle
+            # when the Slurm job array is scheduled.
             if job_array_queue := _current_job_array_queue.get():
                 job_array_queue.put_nowait((job, future))
-            # Otherwise we'
-
-
-
-
-
+            # Otherwise, we're scheduling a single job and we'll submit it for execution.
+            # If the future is already done, i.e., the handle is already resolved, we don't need
+            # to submit the job again.
+            elif not future.done():
+                handle = await self._submit_jobs_for_execution(
+                    job, dependency, args_view, identity=identity
                 )
+                future.set_result(handle)

             # Wait for the job handle, this is either coming from scheduling the job array
-            # or from
-
-
-            self._ingest_launched_jobs(job, handle)
+            # or from a single job submission. If an existing handle was found, this will be
+            # a no-op.
+            await future

     @property
     def experiment_unit_name(self) -> str:
@@ -282,20 +465,19 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
 class SlurmAuxiliaryUnit(SlurmExperimentUnit):
     """An auxiliary unit operated by the Slurm backend."""

-    def
-        del handle
-        console.print("[red]Auxiliary units do not currently support ingestion.[/red]")
-
-    async def _launch_job_group(
+    async def _launch_job_group(  # type: ignore
         self,
-        job: xm.
-        args_view: Mapping[str,
+        job: xm.JobGroup,
+        args_view: tp.Mapping[str, JobArgs],
+        *,
+        dependency: dependencies.SlurmJobDependency | None,
         identity: str,
     ) -> None:
         _validate_job(job, args_view)

-        slurm_handle = await self._submit_jobs_for_execution(
-
+        slurm_handle = await self._submit_jobs_for_execution(
+            job, dependency, args_view, identity=identity
+        )
         self._ingest_launched_jobs(job, slurm_handle)

     @property
@@ -365,7 +547,7 @@ class SlurmExperimentContextAnnotations:
         )

     @property
-    def tags(self) -> MutableSet[str]:
+    def tags(self) -> tp.MutableSet[str]:
         return self._tags

     @tags.setter
@@ -432,75 +614,85 @@ class SlurmExperiment(xm.Experiment):
         )
         self._work_unit_count = 0

-    @
-    def add(
+    @tp.overload
+    def add(  # type: ignore
         self,
         job: xm.AuxiliaryUnitJob,
-        args: Mapping[str,
+        args: JobArgs | tp.Mapping[str, JobArgs] | None = ...,
         *,
-
-
+        dependency: dependencies.SlurmJobDependency | None = ...,
+        identity: str = ...,
+    ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...

-    @
+    @tp.overload
     def add(
         self,
-        job: xm.
-        args: Mapping[str,
+        job: xm.JobGroup,
+        args: tp.Mapping[str, JobArgs] | None = ...,
         *,
-        role: xm.WorkUnitRole = ...,
-
+        role: xm.WorkUnitRole | None = ...,
+        dependency: dependencies.SlurmJobDependency | None = ...,
+        identity: str = ...,
     ) -> asyncio.Future[SlurmWorkUnit]: ...

-    @
+    @tp.overload
     def add(
         self,
-        job: xm.
-        args:
+        job: xm.Job | xm.JobGeneratorType,
+        args: tp.Sequence[JobArgs],
         *,
-        role: xm.
-
-
-
-
+        role: xm.WorkUnitRole | None = ...,
+        dependency: dependencies.SlurmJobDependency
+        | tp.Sequence[dependencies.SlurmJobDependency]
+        | None = ...,
+        identity: str = ...,
+    ) -> asyncio.Future[tp.Sequence[SlurmWorkUnit]]: ...
+
+    @tp.overload
     def add(
         self,
-        job: xm.
-        args:
+        job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
+        args: JobArgs | None = ...,
         *,
-        role: xm.
-
-
+        role: xm.WorkUnitRole | None = ...,
+        dependency: dependencies.SlurmJobDependency | None = ...,
+        identity: str = ...,
+    ) -> asyncio.Future[SlurmWorkUnit]: ...

-    @
+    @tp.overload
     def add(
         self,
-        job: xm.
-        args: Sequence[Mapping[str, Any]],
+        job: xm.JobType,
         *,
-        role: xm.
-
-
+        role: xm.AuxiliaryUnitRole,
+        dependency: dependencies.SlurmJobDependency | None = ...,
+        identity: str = ...,
+    ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...

-    def add(
+    def add(  # type: ignore
         self,
         job: xm.JobType,
-        args:
+        args: JobArgs
+        | tp.Mapping[str, JobArgs]
+        | tp.Sequence[tp.Mapping[str, tp.Any]]
+        | None = None,
         *,
-        role: xm.ExperimentUnitRole =
+        role: xm.ExperimentUnitRole | None = None,
+        dependency: dependencies.SlurmJobDependency
+        | tp.Sequence[dependencies.SlurmJobDependency]
+        | None = None,
         identity: str = "",
     ) -> (
-        asyncio.Future[
+        asyncio.Future[SlurmAuxiliaryUnit]
         | asyncio.Future[SlurmWorkUnit]
-        | asyncio.Future[Sequence[SlurmWorkUnit]]
+        | asyncio.Future[tp.Sequence[SlurmWorkUnit]]
     ):
+        if role is None:
+            role = xm.WorkUnitRole()
+
         if isinstance(args, collections.abc.Sequence):
             if not isinstance(role, xm.WorkUnitRole):
                 raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
-            if identity:
-                raise ValueError(
-                    "Cannot set an identity on the root add call. "
-                    "Please use a job generator and set the identity within."
-                )
             if isinstance(job, xm.JobGroup):
                 raise ValueError(
                     "Job arrays over `xm.JobGroup`s aren't supported. "
@@ -509,19 +701,79 @@ class SlurmExperiment(xm.Experiment):
                 )
             assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"

+            # Validate job & args
+            for trial in args:
+                _validate_job(job, trial)
+            args = tp.cast(tp.Sequence[JobArgs], args)
+
             return asyncio.wrap_future(
-                self._create_task(self._launch_job_array(job, args, role, identity))
+                self._create_task(self._launch_job_array(job, dependency, args, role, identity)),
+                loop=self._event_loop,
+            )
+        if not (isinstance(dependency, dependencies.SlurmJobDependency) or dependency is None):
+            raise ValueError("Invalid dependency type, expected a SlurmJobDependency or None")
+
+        if isinstance(job, xm.AuxiliaryUnitJob):
+            role = job.role
+        self._added_roles[type(role)] += 1
+
+        if self._should_reload_experiment_unit(role):
+            experiment_unit_future = self._get_experiment_unit(
+                self.experiment_id, identity, role, args
             )
         else:
-
+            experiment_unit_future = self._create_experiment_unit(args, role, identity)
+
+        async def launch():
+            experiment_unit = await experiment_unit_future
+            try:
+                await experiment_unit.add(job, args, dependency=dependency, identity=identity)
+            except Exception as experiment_exception:
+                logger.error(
+                    "Stopping experiment unit (identity %r) after it failed with: %s",
+                    identity,
+                    experiment_exception,
+                )
+                try:
+                    if isinstance(job, xm.AuxiliaryUnitJob):
+                        experiment_unit.stop()
+                    else:
+                        experiment_unit.stop(
+                            mark_as_failed=True,
+                            message=f"Work unit creation failed. {traceback.format_exc()}",
+                        )
+                except Exception as stop_exception:  # pylint: disable=broad-except
+                    logger.error("Couldn't stop experiment unit: %s", stop_exception)
+                raise
+            return experiment_unit
+
+        async def reload():
+            experiment_unit = await experiment_unit_future
+            try:
+                await experiment_unit.add(job, args, dependency=dependency, identity=identity)
+            except Exception as update_exception:
+                logging.error(
+                    "Could not reload the experiment unit: %s",
+                    update_exception,
+                )
+                raise
+            return experiment_unit
+
+        return asyncio.wrap_future(
+            self._create_task(reload() if self._should_reload_experiment_unit(role) else launch()),
+            loop=self._event_loop,
+        )

     async def _launch_job_array(
         self,
         job: xm.Job | xm.JobGeneratorType,
-
+        dependency: dependencies.SlurmJobDependency
+        | tp.Sequence[dependencies.SlurmJobDependency]
+        | None,
+        args: tp.Sequence[JobArgs],
         role: xm.WorkUnitRole,
         identity: str = "",
-    ) -> Sequence[SlurmWorkUnit]:
+    ) -> tp.Sequence[SlurmWorkUnit]:
         global _current_job_array_queue

         # Create our job array queue and assign it to the current context
@@ -531,10 +783,13 @@ class SlurmExperiment(xm.Experiment):
         # For each trial we'll schedule the job
         # and collect the futures
         wu_futures = []
-        for trial in args:
-            wu_futures.append(
+        for idx, trial in enumerate(args):
+            wu_futures.append(
+                self.add(
+                    job, args=trial, role=role, identity=f"{identity}_{idx}" if identity else ""
+                )
+            )

-        # TODO(jfarebro): Set a timeout here
         # We'll wait until XManager has filled the queue.
         # There are two cases here, either we were given an xm.Job
         # in which case this will be trivial and filled immediately.
@@ -543,7 +798,8 @@ class SlurmExperiment(xm.Experiment):
         while not job_array_queue.full():
             await asyncio.sleep(0.1)

-        # All jobs have been resolved
+        # All jobs have been resolved so now we'll perform sanity checks
+        # to make sure we can infer the sweep
         executable, executor, name = None, None, None
         resolved_args, resolved_env_vars, resolved_futures = [], [], []
         while not job_array_queue.empty():
@@ -551,9 +807,13 @@ class SlurmExperiment(xm.Experiment):
             # that there's only a single job in this job group
             job_group_view, future = job_array_queue.get_nowait()
             assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
-
+            job_view = mit.one(
+                job_group_view.jobs.values(),
+                too_short=ValueError("Expected a single `xm.Job` in job group."),
+                too_long=ValueError("Only one `xm.Job` is supported for job arrays."),
+            )

-            if
+            if not isinstance(job_view, xm.Job):
                 raise ValueError("Only `xm.Job` is supported for job arrays. ")

             if executable is None:
@@ -571,92 +831,223 @@ class SlurmExperiment(xm.Experiment):
             if job_view.name != name:
                 raise RuntimeError("Found multiple names in job array")

-            resolved_args.append(
-                set(xm.SequentialArgs.from_collection(job_view.args).to_dict().items())
-            )
+            resolved_args.append(xm.SequentialArgs.from_collection(job_view.args).to_list())
             resolved_env_vars.append(set(job_view.env_vars.items()))
             resolved_futures.append(future)
         assert executable is not None, "No executable found?"
         assert executor is not None, "No executor found?"
         assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
-        assert
-
-
+        assert (
+            executor.requirements.cluster is not None
+        ), "Cluster must be specified on requirements."
+
+        # XManager merges job arguments with keyword arguments with job arguments
+        # coming first. These are the arguments that may be common across all jobs
+        # so we can find the largest common prefix and remove them from each job.
+        common_args: list[str] = list(mit.longest_common_prefix(resolved_args))
         common_env_vars: set = functools.reduce(lambda a, b: a & b, resolved_env_vars, set())

         sweep_args = [
-
-
-
-
+            JobArgs(
+                args=functools.reduce(
+                    # Remove the common arguments from each job
+                    lambda args, to_remove: args.remove_args(to_remove),
+                    common_args,
+                    xm.SequentialArgs.from_collection(a),
+                ),
+                env_vars=dict(e.difference(common_env_vars)),
+            )
             for a, e in zip(resolved_args, resolved_env_vars)
         ]

+        # Dependency resolution
+        resolved_dependency = None
+        resolved_dependency_task_id_order = None
+        # one-to-one
+        if isinstance(dependency, collections.abc.Sequence):
+            if len(dependency) != len(wu_futures):
+                raise ValueError("Dependency list must be the same length as the number of trials.")
+            assert len(dependency) > 0, "Dependency list must not be empty."
+
+            # Convert any SlurmJobDependencyAfterOK to SlurmJobArrayDependencyAfterOK
+            # for any array jobs.
+            def _maybe_convert_afterok(
+                dep: dependencies.SlurmJobDependency,
+            ) -> dependencies.SlurmJobDependency:
+                if isinstance(dep, dependencies.SlurmJobDependencyAfterOK) and all([
+                    handle.slurm_job.is_array_job for handle in dep.handles
+                ]):
+                    return dependencies.SlurmJobArrayDependencyAfterOK([
+                        dataclasses.replace(
+                            handle,
+                            slurm_job=handle.slurm_job.array_job_id,
+                        )
+                        for handle in dep.handles
+                    ])
+                return dep
+
+            dependencies_converted = [dep.traverse(_maybe_convert_afterok) for dep in dependency]
+            dependency_sets = [set(dep.flatten()) for dep in dependencies_converted]
+            dependency_differences = functools.reduce(set.difference, dependency_sets, set())
+            # There should be NO differences between the dependencies of each trial after conversion.
+            if len(dependency_differences) > 0:
+                raise ValueError(
+                    f"Found variable dependencies across trials: {dependency_differences}. "
+                    "Slurm job arrays require the same dependencies across all trials. "
+                )
+            resolved_dependency = dependencies_converted[0]
+
+            # This is slightly annoying but we need to re-sort the sweep arguments in case the dependencies were passed
+            # in a different order than 1, 2, ..., N as the Job array can only have correspondance with the same task id.
+            original_array_dependencies = [
+                mit.one(
+                    filter(
+                        lambda dep: isinstance(dep, dependencies.SlurmJobDependencyAfterOK)
+                        and all([handle.slurm_job.is_array_job for handle in dep.handles]),
+                        deps.flatten(),
+                    )
+                )
+                for deps in dependency
+            ]
+            resolved_dependency_task_id_order = [
+                int(
+                    mit.one(
+                        functools.reduce(
+                            set.difference,
+                            [handle.slurm_job.array_task_id for handle in dep.handles],  # type: ignore
+                        )
+                    )
+                )
+                for dep in original_array_dependencies
+            ]
+            assert len(resolved_dependency_task_id_order) == len(sweep_args)
+            assert set(resolved_dependency_task_id_order) == set(range(len(sweep_args))), (
+                "Dependent job array tasks should have task ids 0, 1, ..., N. "
+                f"Found: {resolved_dependency_task_id_order}"
+            )
+        # one-to-many
+        elif isinstance(dependency, dependencies.SlurmJobDependency):
+            resolved_dependency = dependency
+        assert resolved_dependency is None or isinstance(
+            resolved_dependency, dependencies.SlurmJobDependency
+        ), "Invalid dependency type"
+
         # No support for sweep_env_vars right now.
         # We schedule the job array and then we'll resolve all the work units with
         # the handles Slurm gives back to us.
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # If we already have handles for all the work units, we don't need to submit the
+        # job array to SLURM.
+        num_resolved_handles = sum(future.done() for future in resolved_futures)
+        if num_resolved_handles == 0:
+            try:
+                handles = await execution.launch(
+                    job=xm.Job(
+                        executable=executable,
+                        executor=executor,
+                        name=name,
+                        args=xm.SequentialArgs.from_collection(common_args),
+                        env_vars=dict(common_env_vars),
+                    ),
+                    dependency=resolved_dependency,
+                    args=[
+                        sweep_args[resolved_dependency_task_id_order.index(i)]
+                        for i in range(len(sweep_args))
+                    ]
+                    if resolved_dependency_task_id_order
+                    else sweep_args,
+                    experiment_id=self.experiment_id,
+                    identity=identity,
+                )
+                if resolved_dependency_task_id_order:
+                    handles = [handles[i] for i in resolved_dependency_task_id_order]
+            except Exception as e:
+                for future in resolved_futures:
+                    future.set_exception(e)
+                raise
+            else:
+                for handle, future in zip(handles, resolved_futures):
+                    future.set_result(handle)
+        elif 0 < num_resolved_handles < len(resolved_futures):
+            raise RuntimeError(
+                "Some array job elements have handles, but some don't. This shouldn't happen."
             )
-        except Exception as e:
-            for future in resolved_futures:
-                future.set_exception(e)
-            raise
-        else:
-            for handle, future in zip(handles, resolved_futures):
-                future.set_result(handle)

         wus = await asyncio.gather(*wu_futures)
+
         _current_job_array_queue.set(None)
         return wus

-    def
+    def _get_work_unit_by_identity(self, identity: str) -> SlurmWorkUnit | None:
+        if identity == "":
+            return None
+        for unit in self._experiment_units:
+            if isinstance(unit, SlurmWorkUnit) and unit.identity == identity:
+                return unit
+        return None
+
+    def _create_experiment_unit(  # type: ignore
         self,
-        args: Mapping[str,
+        args: JobArgs | tp.Mapping[str, JobArgs] | None,
         role: xm.ExperimentUnitRole,
         identity: str,
-    ) -> Awaitable[SlurmWorkUnit]:
-
-
-        def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
+    ) -> tp.Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
+        def _create_work_unit(role: xm.WorkUnitRole) -> tp.Awaitable[SlurmWorkUnit]:
             work_unit = SlurmWorkUnit(
                 self,
                 self._create_task,
                 args,
                 role,
                 self._work_unit_id_predictor,
+                identity=identity,
             )
             self._experiment_units.append(work_unit)
             self._work_unit_count += 1

-
+            api.client().insert_work_unit(
+                self.experiment_id,
+                api.WorkUnitModel(
+                    wid=work_unit.work_unit_id,
+                    identity=work_unit.identity,
+                    args=json.dumps(args),
+                ),
+            )
+
+            future = asyncio.Future[SlurmWorkUnit]()
             future.set_result(work_unit)
             return future

+        def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> tp.Awaitable[SlurmAuxiliaryUnit]:
+            auxiliary_unit = SlurmAuxiliaryUnit(
+                self,
+                self._create_task,
+                args,
+                role,
+                identity=identity,
+            )
+            self._experiment_units.append(auxiliary_unit)
+            future = asyncio.Future[SlurmAuxiliaryUnit]()
+            future.set_result(auxiliary_unit)
+            return future
+
         match role:
             case xm.WorkUnitRole():
+                if (existing_unit := self._get_work_unit_by_identity(identity)) is not None:
+                    future = asyncio.Future[SlurmWorkUnit]()
+                    future.set_result(existing_unit)
+                    return future
                 return _create_work_unit(role)
+            case xm.AuxiliaryUnitRole():
+                return _create_auxiliary_unit(role)
             case _:
                 raise ValueError(f"Unsupported role {role}")

-    def _get_experiment_unit(
+    def _get_experiment_unit(  # type: ignore
         self,
         experiment_id: int,
         identity: str,
         role: xm.ExperimentUnitRole,
-        args: Mapping[str,
-    ) -> Awaitable[
+        args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
+    ) -> tp.Awaitable[SlurmExperimentUnit]:
         del experiment_id, identity, role, args
         raise NotImplementedError

@@ -672,7 +1063,7 @@ class SlurmExperiment(xm.Experiment):
         # If no work units were added, delete this experiment
         # This is to prevent empty experiments from being persisted
         # and cluttering the database.
-        if self.
+        if len(self._experiment_units) == 0:
             console.print(
                 f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
             )
@@ -689,15 +1080,14 @@ class SlurmExperiment(xm.Experiment):
         return self.context.annotations.title

     @property
-    def context(self) -> SlurmExperimentMetadataContext:
+    def context(self) -> SlurmExperimentMetadataContext:  # type: ignore
         return self._experiment_context

     @property
     def work_unit_count(self) -> int:
         return self._work_unit_count

-
-    def work_units(self) -> Mapping[int, SlurmWorkUnit]:
+    def work_units(self) -> dict[int, SlurmWorkUnit]:
         """Gets work units created via self.add()."""
         return {
             wu.work_unit_id: wu for wu in self._experiment_units if isinstance(wu, SlurmWorkUnit)
@@ -716,8 +1106,51 @@ def create_experiment(experiment_title: str) -> SlurmExperiment:
 def get_experiment(experiment_id: int) -> SlurmExperiment:
     """Get Experiment."""
     experiment_model = api.client().get_experiment(experiment_id)
-
-
+    experiment = SlurmExperiment(
+        experiment_title=experiment_model.title, experiment_id=experiment_id
+    )
+    experiment._work_unit_id_predictor = id_predictor.Predictor(1)
+
+    # Populate annotations
+    experiment.context.annotations.description = experiment_model.description or ""
+    experiment.context.annotations.note = experiment_model.note or ""
+    experiment.context.annotations.tags = set(experiment_model.tags or [])
+
+    # Populate artifacts
+    for artifact in experiment_model.artifacts:
+        experiment.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
+
+    # Populate work units
+    for wu_model in experiment_model.work_units:
+        work_unit = SlurmWorkUnit(
+            experiment=experiment,
+            create_task=experiment._create_task,
+            args=json.loads(wu_model.args) if wu_model.args else {},
+            role=xm.WorkUnitRole(),
+            identity=wu_model.identity or "",
+            work_unit_id_predictor=experiment._work_unit_id_predictor,
+        )
+        work_unit._work_unit_id = wu_model.wid
+
+        # Populate jobs for each work unit
+        for job_model in wu_model.jobs:
+            slurm_ssh_config = config.SlurmSSHConfig.deserialize(job_model.slurm_ssh_config)
+            handle = execution.SlurmHandle(
+                experiment_id=experiment_id,
+                ssh=slurm_ssh_config,
+                slurm_job=str(job_model.slurm_job_id),
+                job_name=job_model.name,
+            )
+            work_unit._execution_handles.append(handle)
+
+        # Populate artifacts for each work unit
+        for artifact in wu_model.artifacts:
+            work_unit.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
+
+        experiment._experiment_units.append(work_unit)
+        experiment._work_unit_count += 1
+
+    return experiment


 @functools.cache
@@ -733,5 +1166,5 @@ def get_current_work_unit() -> SlurmWorkUnit | None:
         wid := os.environ.get("XM_SLURM_WORK_UNIT_ID")
     ):
         experiment = get_experiment(int(xid))
-        return experiment.work_units[int(wid)]
+        return experiment.work_units()[int(wid)]
     return None