xmanager-slurm 0.4.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +47 -0
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +189 -0
- xm_slurm/console.py +3 -0
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +67 -0
- xm_slurm/contrib/clusters/drac.py +242 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +215 -0
- xm_slurm/execution.py +995 -0
- xm_slurm/executors.py +210 -0
- xm_slurm/experiment.py +1016 -0
- xm_slurm/experimental/parameter_controller.py +206 -0
- xm_slurm/filesystems.py +129 -0
- xm_slurm/job_blocks.py +21 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +309 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker.py +348 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +56 -0
- xm_slurm/packaging/utils.py +22 -0
- xm_slurm/resources.py +350 -0
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +90 -0
- xm_slurm/status.py +197 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
- xm_slurm/templates/docker/mamba.Dockerfile +29 -0
- xm_slurm/templates/docker/python.Dockerfile +32 -0
- xm_slurm/templates/docker/uv.Dockerfile +38 -0
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
- xm_slurm/templates/slurm/job.bash.j2 +90 -0
- xm_slurm/templates/slurm/library/retry.bash +62 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +196 -0
- xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
- xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
- xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
- xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
xm_slurm/experiment.py
ADDED
|
@@ -0,0 +1,1016 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import collections.abc
|
|
3
|
+
import contextvars
|
|
4
|
+
import dataclasses
|
|
5
|
+
import datetime as dt
|
|
6
|
+
import functools
|
|
7
|
+
import inspect
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import traceback
|
|
12
|
+
import typing as tp
|
|
13
|
+
from concurrent import futures
|
|
14
|
+
|
|
15
|
+
import more_itertools as mit
|
|
16
|
+
from rich.console import ConsoleRenderable
|
|
17
|
+
from xmanager import xm
|
|
18
|
+
from xmanager.xm import async_packager, core, id_predictor, job_operators
|
|
19
|
+
from xmanager.xm import job_blocks as xm_job_blocks
|
|
20
|
+
|
|
21
|
+
from xm_slurm import api, config, dependencies, execution, executors, metadata_context
|
|
22
|
+
from xm_slurm.console import console
|
|
23
|
+
from xm_slurm.job_blocks import JobArgs
|
|
24
|
+
from xm_slurm.packaging import router
|
|
25
|
+
from xm_slurm.status import SlurmWorkUnitStatus
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)

# Context-local queue used while a Slurm job array is being assembled.
# When set (see `SlurmExperiment._launch_job_array`), work units enqueue
# `(job_group, future)` pairs here instead of submitting jobs individually;
# the future is later resolved with the array's `SlurmHandle`.
# Default is `None`, meaning jobs are submitted one at a time.
_current_job_array_queue = contextvars.ContextVar[
    asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
]("_current_job_array_queue", default=None)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _validate_job(
    job: xm.JobType,
    args_view: JobArgs | tp.Mapping[str, JobArgs],
) -> None:
    """Validate that ``args_view`` is a well-formed argument mapping for ``job``.

    Recurses into job groups: a key naming a child job is validated against
    that child; any other key must be one of ``args`` / ``env_vars``.

    Raises:
        ValueError: on non-mapping arguments, empty or nested job groups, or
            unrecognized argument keys.
    """
    # Nothing to check when no arguments were supplied.
    if not args_view:
        return
    if not isinstance(args_view, collections.abc.Mapping):
        raise ValueError("Job arguments via `experiment.add` must be mappings")

    is_group = isinstance(job, xm.JobGroup)
    if is_group:
        if len(job.jobs) == 0:
            raise ValueError("Job group is empty")
        if any(isinstance(child, xm.JobGroup) for child in job.jobs.values()):
            raise ValueError("Nested job groups are not supported")

    for key, expanded in args_view.items():
        # With multiple jobs in a group, every key must address a child job.
        if is_group and len(job.jobs) > 1 and key not in job.jobs:
            raise ValueError(
                f"Argument key `{key}` doesn't exist in job group with keys {job.jobs.keys()}"
            )

        if is_group and key in job.jobs:
            # Recurse into the child job's argument mapping.
            _validate_job(job.jobs[key], tp.cast(JobArgs, expanded))
        elif key not in {"args", "env_vars"}:
            raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SlurmExperimentUnit(xm.ExperimentUnit):
    """ExperimentUnit is a collection of semantically associated `Job`s."""

    # Narrowed from the base class: Slurm units always belong to a SlurmExperiment.
    experiment: "SlurmExperiment"  # type: ignore

    def __init__(
        self,
        experiment: xm.Experiment,
        create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future[tp.Any]],
        args: JobArgs | tp.Mapping[str, JobArgs] | None,
        role: xm.ExperimentUnitRole,
        identity: str = "",
    ) -> None:
        super().__init__(experiment, create_task, args, role, identity=identity)
        # Jobs that have been handed to Slurm, in launch order.
        self._launched_jobs: list[xm.LaunchedJob] = []
        # One handle per submitted Slurm job; used for wait/stop/status/logs.
        self._execution_handles: list[execution.SlurmHandle] = []
        self._context = metadata_context.SlurmExperimentUnitMetadataContext(
            self,
            artifacts=metadata_context.SlurmContextArtifacts(owner=self, artifacts=[]),
        )

    def add(  # type: ignore
        self,
        job: xm.JobType,
        args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
        *,
        dependency: dependencies.SlurmJobDependency | None = None,
        identity: str = "",
    ) -> tp.Awaitable[None]:
        """Adds a job, job group, job generator, or job config to this unit.

        The job is shallow-copied, arguments are applied, and the launch is
        dispatched to the matching `_launch_*` helper as a tracked task.

        Returns:
            An awaitable that resolves once the launch task completes.
        """
        # Prioritize the identity given directly to the work unit at work unit
        # creation time, as opposed to the identity passed when adding jobs to it as
        # this is more consistent between job generator work units and regular work
        # units.
        identity = self.identity or identity

        # Copy before mutation so the caller's job object stays untouched.
        job = job_operators.shallow_copy_job_type(job)  # type: ignore
        if args is not None:
            core._apply_args(job, args)
        job_operators.populate_job_names(job)  # type: ignore

        def launch_job(job: xm.Job) -> tp.Awaitable[None]:
            # A single job is normalized into a one-element job group.
            core._current_experiment.set(self.experiment)
            core._current_experiment_unit.set(self)
            return self._launch_job_group(
                xm.JobGroup(**{job.name: job}),  # type: ignore
                core._work_unit_arguments(job, self._args),
                dependency=dependency,
                identity=identity,
            )

        def launch_job_group(group: xm.JobGroup) -> tp.Awaitable[None]:
            core._current_experiment.set(self.experiment)
            core._current_experiment_unit.set(self)
            return self._launch_job_group(
                group,
                core._work_unit_arguments(group, self._args),
                dependency=dependency,
                identity=identity,
            )

        def launch_job_generator(
            job_generator: xm.JobGeneratorType,
        ) -> tp.Awaitable[None]:
            # Accept either an async function or a callable object whose
            # __call__ is a coroutine function.
            if not inspect.iscoroutinefunction(job_generator) and not inspect.iscoroutinefunction(
                getattr(job_generator, "__call__")
            ):
                raise ValueError(
                    "Job generator must be an async function. Signature needs to be "
                    "`async def job_generator(work_unit: xm.WorkUnit) -> None:`"
                )
            core._current_experiment.set(self.experiment)
            core._current_experiment_unit.set(self)
            coroutine = job_generator(self, **(args or {}))
            assert coroutine is not None
            return coroutine

        def launch_job_config(job_config: xm.JobConfig) -> tp.Awaitable[None]:
            core._current_experiment.set(self.experiment)
            core._current_experiment_unit.set(self)
            return self._launch_job_config(
                job_config, dependency, tp.cast(JobArgs, args) or {}, identity
            )

        job_awaitable: tp.Awaitable[tp.Any]
        match job:
            case xm.Job() as job:
                job_awaitable = launch_job(job)
            case xm.JobGroup() as job_group:
                job_awaitable = launch_job_group(job_group)
            case job_generator if xm_job_blocks.is_job_generator(job):
                job_awaitable = launch_job_generator(job_generator)  # type: ignore
            case xm.JobConfig() as job_config:
                job_awaitable = launch_job_config(job_config)
            case _:
                raise TypeError(f"Unsupported job type: {job!r}")

        # Track the launch so experiment teardown can await it.
        launch_task = self._create_task(job_awaitable)
        self._launch_tasks.append(launch_task)
        return asyncio.wrap_future(launch_task)

    async def _launch_job_group(  # type: ignore
        self,
        job_group: xm.JobGroup,
        args_view: tp.Mapping[str, JobArgs],
        *,
        dependency: dependencies.SlurmJobDependency | None,
        identity: str,
    ) -> None:
        # Subclasses implement the actual submission strategy.
        del job_group, dependency, args_view, identity
        raise NotImplementedError

    async def _launch_job_config(  # type: ignore
        self,
        job_config: xm.JobConfig,
        dependency: dependencies.SlurmJobDependency | None,
        args_view: JobArgs,
        identity: str,
    ) -> None:
        # Subclasses implement the actual submission strategy.
        del job_config, dependency, args_view, identity
        raise NotImplementedError

    @tp.overload
    async def _submit_jobs_for_execution(
        self,
        job: xm.Job,
        dependency: dependencies.SlurmJobDependency | None,
        args_view: JobArgs,
        identity: str | None = ...,
    ) -> execution.SlurmHandle: ...

    @tp.overload
    async def _submit_jobs_for_execution(
        self,
        job: xm.JobGroup,
        dependency: dependencies.SlurmJobDependency | None,
        args_view: tp.Mapping[str, JobArgs],
        identity: str | None = ...,
    ) -> execution.SlurmHandle: ...

    @tp.overload
    async def _submit_jobs_for_execution(
        self,
        job: xm.Job,
        dependency: dependencies.SlurmJobDependency | None,
        args_view: tp.Sequence[JobArgs],
        identity: str | None = ...,
    ) -> list[execution.SlurmHandle]: ...

    async def _submit_jobs_for_execution(self, job, dependency, args_view, identity=None):
        """Delegate submission to `execution.launch`; see overloads for shapes."""
        return await execution.launch(
            job=job,
            dependency=dependency,
            args=args_view,
            experiment_id=self.experiment_id,
            identity=identity,
        )

    def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
        """Record *handle* and mirror the launched job(s) into the API store."""
        self._execution_handles.append(handle)

        def _ingest_job(job: xm.Job) -> None:
            # Only work units are persisted via the API client.
            if not isinstance(self._role, xm.WorkUnitRole):
                return
            assert isinstance(self, SlurmWorkUnit)
            assert job.name is not None
            api.client().insert_job(
                self.experiment_id,
                self.work_unit_id,
                api.models.SlurmJob(
                    name=job.name,
                    slurm_job_id=handle.slurm_job.job_id,
                    slurm_ssh_config=handle.ssh.serialize(),
                ),
            )

        match job:
            case xm.JobGroup() as job_group:
                for job in job_group.jobs.values():
                    assert isinstance(job, xm.Job)
                    _ingest_job(job)
                    self._launched_jobs.append(
                        xm.LaunchedJob(
                            name=job.name,  # type: ignore
                            address=str(handle.slurm_job.job_id),
                        )
                    )
            case xm.Job():
                _ingest_job(job)
                self._launched_jobs.append(
                    xm.LaunchedJob(
                        name=handle.job.name,  # type: ignore
                        address=str(handle.slurm_job.job_id),
                    )
                )

    async def _wait_until_complete(self) -> None:
        """Block until every submitted Slurm job reaches a terminal state."""
        try:
            await asyncio.gather(*[handle.wait() for handle in self._execution_handles])
        except RuntimeError as error:
            raise xm.ExperimentUnitFailedError(error)

    def stop(
        self,
        *,
        mark_as_failed: bool = False,
        mark_as_completed: bool = False,
        message: str | None = None,
    ) -> None:
        """Request cancellation of all Slurm jobs owned by this unit.

        The keyword flags are accepted for interface compatibility but are
        ignored by the Slurm backend.
        """
        del mark_as_failed, mark_as_completed, message

        async def _stop_awaitable() -> None:
            try:
                await asyncio.gather(*[handle.stop() for handle in self._execution_handles])
            except RuntimeError as error:
                raise xm.ExperimentUnitFailedError(error)

        # Fire-and-forget: stopping happens on the experiment's task pool.
        self.experiment._create_task(_stop_awaitable())

    async def get_status(self) -> SlurmWorkUnitStatus:  # type: ignore
        """Aggregate the states of all handles into a single unit status."""
        states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
        return SlurmWorkUnitStatus.aggregate(states)

    async def logs(
        self,
        *,
        num_lines: int = 10,
        block_size: int = 1024,
        wait: bool = True,
        follow: bool = False,
    ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
        """Stream renderable log output for this unit's single Slurm job.

        Raises:
            ValueError: if the unit has zero handles or more than one handle
                (multi-handle log interleaving is not supported).
        """
        if not self._execution_handles:
            raise ValueError(f"No execution handles found for experiment unit {self!r}")
        elif len(self._execution_handles) > 1:
            raise ValueError(f"Multiple execution handles found for experiment unit {self!r}")
        assert len(self._execution_handles) == 1

        handle = self._execution_handles[0]  # TODO(jfarebro): interleave?
        async for log in handle.logs(
            num_lines=num_lines, block_size=block_size, wait=wait, follow=follow
        ):
            yield log

    @property
    def launched_jobs(self) -> list[xm.LaunchedJob]:
        return self._launched_jobs

    @property
    def context(self) -> metadata_context.SlurmExperimentUnitMetadataContext:  # type: ignore
        return self._context

    def after_started(
        self, *, time: dt.timedelta | None = None
    ) -> dependencies.SlurmJobDependencyAfter:
        """Dependency satisfied after this unit's jobs have started."""
        return dependencies.SlurmJobDependencyAfter(self._execution_handles, time=time)

    def after_finished(self) -> dependencies.SlurmJobDependencyAfterAny:
        """Dependency satisfied after this unit's jobs finish (any outcome)."""
        return dependencies.SlurmJobDependencyAfterAny(self._execution_handles)

    def after_completed(self) -> dependencies.SlurmJobDependencyAfterOK:
        """Dependency satisfied after this unit's jobs complete successfully."""
        return dependencies.SlurmJobDependencyAfterOK(self._execution_handles)

    def after_failed(self) -> dependencies.SlurmJobDependencyAfterNotOK:
        """Dependency satisfied after this unit's jobs fail."""
        return dependencies.SlurmJobDependencyAfterNotOK(self._execution_handles)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
    """A work unit backed by Slurm jobs, with a predicted numeric id."""

    def __init__(
        self,
        experiment: "SlurmExperiment",
        create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future],
        args: JobArgs | tp.Mapping[str, JobArgs] | None,
        role: xm.ExperimentUnitRole,
        work_unit_id_predictor: id_predictor.Predictor,
        identity: str = "",
    ) -> None:
        super().__init__(experiment, create_task, args, role, identity=identity)
        # Reserve this unit's id eagerly so `work_unit_id` is stable before launch.
        self._work_unit_id_predictor = work_unit_id_predictor
        self._work_unit_id = self._work_unit_id_predictor.reserve_id()

    def _get_existing_handle(self, job: xm.JobGroup) -> execution.SlurmHandle | None:
        """Return the handle already recorded for *job*'s single entry, if any."""
        # `job` is a single-job group; match on that job's name.
        job_name = mit.one(job.jobs.keys())
        for handle in self._execution_handles:
            if handle.job_name == job_name:
                return handle
        return None

    async def _launch_job_group(  # type: ignore
        self,
        job: xm.JobGroup,
        args_view: tp.Mapping[str, JobArgs],
        *,
        dependency: dependencies.SlurmJobDependency | None,
        identity: str,
    ) -> None:
        """Launch *job* (a single-job group), either directly or via a job array queue."""
        global _current_job_array_queue
        _validate_job(job, args_view)
        future = asyncio.Future[execution.SlurmHandle]()

        # If we already have a handle for this job, we don't need to submit it again.
        # We'll just resolve the future with the existing handle.
        # Otherwise we'll add callbacks to ingest the handle and the launched jobs.
        if existing_handle := self._get_existing_handle(job):
            future.set_result(existing_handle)
        else:
            future.add_done_callback(
                lambda handle: self._ingest_launched_jobs(job, handle.result())
            )

        # Persist this unit's arguments in the experiment store.
        api.client().update_work_unit(
            self.experiment_id,
            self.work_unit_id,
            api.models.ExperimentUnitPatch(args=json.dumps(args_view), identity=None),
        )

        async with self._work_unit_id_predictor.submit_id(self.work_unit_id):  # type: ignore
            # If we're scheduling as part of a job queue (i.e., the queue is set on the context)
            # then we'll insert the job and future that'll get resolved to the proper handle
            # when the Slurm job array is scheduled.
            if job_array_queue := _current_job_array_queue.get():
                job_array_queue.put_nowait((job, future))
            # Otherwise, we're scheduling a single job and we'll submit it for execution.
            # If the future is already done, i.e., the handle is already resolved, we don't need
            # to submit the job again.
            elif not future.done():
                handle = await self._submit_jobs_for_execution(
                    job, dependency, args_view, identity=identity
                )
                future.set_result(handle)

        # Wait for the job handle, this is either coming from scheduling the job array
        # or from a single job submission. If an existing handle was found, this will be
        # a no-op.
        await future

    @property
    def experiment_unit_name(self) -> str:
        return f"{self.experiment_id}_{self._work_unit_id}"

    @property
    def work_unit_id(self) -> int:
        return self._work_unit_id

    def __repr__(self, /) -> str:
        return f"<SlurmWorkUnit {self.experiment_unit_name}>"
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class SlurmAuxiliaryUnit(SlurmExperimentUnit):
    """An auxiliary unit operated by the Slurm backend."""

    async def _launch_job_group(  # type: ignore
        self,
        job: xm.JobGroup,
        args_view: tp.Mapping[str, JobArgs],
        *,
        dependency: dependencies.SlurmJobDependency | None,
        identity: str,
    ) -> None:
        """Validate, submit, and record ``job`` for this auxiliary unit."""
        _validate_job(job, args_view)

        handle = await self._submit_jobs_for_execution(
            job, dependency, args_view, identity=identity
        )
        self._ingest_launched_jobs(job, handle)

    @property
    def experiment_unit_name(self) -> str:
        """Name of this unit, derived from the owning experiment's id."""
        return f"{self.experiment_id}_auxiliary"

    def __repr__(self, /) -> str:
        return f"<SlurmAuxiliaryUnit {self.experiment_unit_name}>"
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class SlurmExperiment(xm.Experiment):
    """An XManager experiment whose units execute on Slurm clusters."""

    # Experiment id assigned at construction time.
    _id: int
    # All units (work and auxiliary) created for this experiment.
    _experiment_units: list[SlurmExperimentUnit]
    # Metadata context (title annotations, artifacts) for the experiment.
    _experiment_context: metadata_context.SlurmExperimentMetadataContext
    # Count of work units; initialized to 0 in __init__ (incremented
    # elsewhere — not visible in this chunk).
    _work_unit_count: int
    # Packages executables through the xm-slurm packaging router.
    _async_packager = async_packager.AsyncPackager(router.package)
|
|
442
|
+
|
|
443
|
+
    def __init__(
        self,
        experiment_title: str,
        experiment_id: int,
    ) -> None:
        """Initialize the experiment with its title and pre-assigned id.

        Args:
            experiment_title: Human-readable title stored as an annotation.
            experiment_id: Identifier assigned by the backing store.
        """
        super().__init__()
        self._id = experiment_id
        self._experiment_units = []
        self._experiment_context = metadata_context.SlurmExperimentMetadataContext(
            self,
            annotations=metadata_context.SlurmExperimentContextAnnotations(
                experiment=self,
                title=experiment_title,
            ),
            artifacts=metadata_context.SlurmContextArtifacts(self, artifacts=[]),
        )
        self._work_unit_count = 0
|
|
460
|
+
|
|
461
|
+
    # Typing overloads for `add`: the return type depends on the job kind,
    # the role, and whether `args` is a sequence (job array sweep).

    @tp.overload
    def add(  # type: ignore
        self,
        job: xm.AuxiliaryUnitJob,
        args: JobArgs | tp.Mapping[str, JobArgs] | None = ...,
        *,
        dependency: dependencies.SlurmJobDependency | None = ...,
        identity: str = ...,
    ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...

    @tp.overload
    def add(
        self,
        job: xm.JobGroup,
        args: tp.Mapping[str, JobArgs] | None = ...,
        *,
        role: xm.WorkUnitRole | None = ...,
        dependency: dependencies.SlurmJobDependency | None = ...,
        identity: str = ...,
    ) -> asyncio.Future[SlurmWorkUnit]: ...

    @tp.overload
    def add(
        self,
        job: xm.Job | xm.JobGeneratorType,
        args: tp.Sequence[JobArgs],
        *,
        role: xm.WorkUnitRole | None = ...,
        dependency: dependencies.SlurmJobDependency
        | tp.Sequence[dependencies.SlurmJobDependency]
        | None = ...,
        identity: str = ...,
    ) -> asyncio.Future[tp.Sequence[SlurmWorkUnit]]: ...

    @tp.overload
    def add(
        self,
        job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
        args: JobArgs | None = ...,
        *,
        role: xm.WorkUnitRole | None = ...,
        dependency: dependencies.SlurmJobDependency | None = ...,
        identity: str = ...,
    ) -> asyncio.Future[SlurmWorkUnit]: ...

    @tp.overload
    def add(
        self,
        job: xm.JobType,
        *,
        role: xm.AuxiliaryUnitRole,
        dependency: dependencies.SlurmJobDependency | None = ...,
        identity: str = ...,
    ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
|
|
515
|
+
|
|
516
|
+
def add( # type: ignore
|
|
517
|
+
self,
|
|
518
|
+
job: xm.JobType,
|
|
519
|
+
args: JobArgs
|
|
520
|
+
| tp.Mapping[str, JobArgs]
|
|
521
|
+
| tp.Sequence[tp.Mapping[str, tp.Any]]
|
|
522
|
+
| None = None,
|
|
523
|
+
*,
|
|
524
|
+
role: xm.ExperimentUnitRole | None = None,
|
|
525
|
+
dependency: dependencies.SlurmJobDependency
|
|
526
|
+
| tp.Sequence[dependencies.SlurmJobDependency]
|
|
527
|
+
| None = None,
|
|
528
|
+
identity: str = "",
|
|
529
|
+
) -> (
|
|
530
|
+
asyncio.Future[SlurmAuxiliaryUnit]
|
|
531
|
+
| asyncio.Future[SlurmWorkUnit]
|
|
532
|
+
| asyncio.Future[tp.Sequence[SlurmWorkUnit]]
|
|
533
|
+
):
|
|
534
|
+
if role is None:
|
|
535
|
+
role = xm.WorkUnitRole()
|
|
536
|
+
|
|
537
|
+
if isinstance(args, collections.abc.Sequence):
|
|
538
|
+
if not isinstance(role, xm.WorkUnitRole):
|
|
539
|
+
raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
|
|
540
|
+
if isinstance(job, xm.JobGroup):
|
|
541
|
+
raise ValueError(
|
|
542
|
+
"Job arrays over `xm.JobGroup`s aren't supported. "
|
|
543
|
+
"Slurm doesn't support job arrays over heterogeneous jobs. "
|
|
544
|
+
"Instead you should call `experiment.add` for each of these trials."
|
|
545
|
+
)
|
|
546
|
+
assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"
|
|
547
|
+
|
|
548
|
+
# Validate job & args
|
|
549
|
+
for trial in args:
|
|
550
|
+
_validate_job(job, trial)
|
|
551
|
+
args = tp.cast(tp.Sequence[JobArgs], args)
|
|
552
|
+
|
|
553
|
+
return asyncio.wrap_future(
|
|
554
|
+
self._create_task(self._launch_job_array(job, dependency, args, role, identity)),
|
|
555
|
+
loop=self._event_loop,
|
|
556
|
+
)
|
|
557
|
+
if not (isinstance(dependency, dependencies.SlurmJobDependency) or dependency is None):
|
|
558
|
+
raise ValueError("Invalid dependency type, expected a SlurmJobDependency or None")
|
|
559
|
+
|
|
560
|
+
if isinstance(job, xm.AuxiliaryUnitJob):
|
|
561
|
+
role = job.role
|
|
562
|
+
self._added_roles[type(role)] += 1
|
|
563
|
+
|
|
564
|
+
if self._should_reload_experiment_unit(role):
|
|
565
|
+
experiment_unit_future = self._get_experiment_unit(
|
|
566
|
+
self.experiment_id, identity, role, args
|
|
567
|
+
)
|
|
568
|
+
else:
|
|
569
|
+
experiment_unit_future = self._create_experiment_unit(args, role, identity)
|
|
570
|
+
|
|
571
|
+
async def launch():
|
|
572
|
+
experiment_unit = await experiment_unit_future
|
|
573
|
+
try:
|
|
574
|
+
await experiment_unit.add(job, args, dependency=dependency, identity=identity)
|
|
575
|
+
except Exception as experiment_exception:
|
|
576
|
+
logger.error(
|
|
577
|
+
"Stopping experiment unit (identity %r) after it failed with: %s",
|
|
578
|
+
identity,
|
|
579
|
+
experiment_exception,
|
|
580
|
+
)
|
|
581
|
+
try:
|
|
582
|
+
if isinstance(job, xm.AuxiliaryUnitJob):
|
|
583
|
+
experiment_unit.stop()
|
|
584
|
+
else:
|
|
585
|
+
experiment_unit.stop(
|
|
586
|
+
mark_as_failed=True,
|
|
587
|
+
message=f"Work unit creation failed. {traceback.format_exc()}",
|
|
588
|
+
)
|
|
589
|
+
except Exception as stop_exception: # pylint: disable=broad-except
|
|
590
|
+
logger.error("Couldn't stop experiment unit: %s", stop_exception)
|
|
591
|
+
raise
|
|
592
|
+
return experiment_unit
|
|
593
|
+
|
|
594
|
+
async def reload():
|
|
595
|
+
experiment_unit = await experiment_unit_future
|
|
596
|
+
try:
|
|
597
|
+
await experiment_unit.add(job, args, dependency=dependency, identity=identity)
|
|
598
|
+
except Exception as update_exception:
|
|
599
|
+
logging.error(
|
|
600
|
+
"Could not reload the experiment unit: %s",
|
|
601
|
+
update_exception,
|
|
602
|
+
)
|
|
603
|
+
raise
|
|
604
|
+
return experiment_unit
|
|
605
|
+
|
|
606
|
+
return asyncio.wrap_future(
|
|
607
|
+
self._create_task(reload() if self._should_reload_experiment_unit(role) else launch()),
|
|
608
|
+
loop=self._event_loop,
|
|
609
|
+
)
|
|
610
|
+
|
|
611
|
+
async def _launch_job_array(
|
|
612
|
+
self,
|
|
613
|
+
job: xm.Job | xm.JobGeneratorType,
|
|
614
|
+
dependency: dependencies.SlurmJobDependency
|
|
615
|
+
| tp.Sequence[dependencies.SlurmJobDependency]
|
|
616
|
+
| None,
|
|
617
|
+
args: tp.Sequence[JobArgs],
|
|
618
|
+
role: xm.WorkUnitRole,
|
|
619
|
+
identity: str = "",
|
|
620
|
+
) -> tp.Sequence[SlurmWorkUnit]:
|
|
621
|
+
global _current_job_array_queue
|
|
622
|
+
|
|
623
|
+
# Create our job array queue and assign it to the current context
|
|
624
|
+
job_array_queue = asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]](maxsize=len(args))
|
|
625
|
+
_current_job_array_queue.set(job_array_queue)
|
|
626
|
+
|
|
627
|
+
# For each trial we'll schedule the job
|
|
628
|
+
# and collect the futures
|
|
629
|
+
wu_futures = []
|
|
630
|
+
for idx, trial in enumerate(args):
|
|
631
|
+
wu_futures.append(
|
|
632
|
+
self.add(
|
|
633
|
+
job, args=trial, role=role, identity=f"{identity}_{idx}" if identity else ""
|
|
634
|
+
)
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
# We'll wait until XManager has filled the queue.
|
|
638
|
+
# There are two cases here, either we were given an xm.Job
|
|
639
|
+
# in which case this will be trivial and filled immediately.
|
|
640
|
+
# The other case is when you have a job generator and this is less
|
|
641
|
+
# trivial, you have to wait for wu.add to be called.
|
|
642
|
+
while not job_array_queue.full():
|
|
643
|
+
await asyncio.sleep(0.1)
|
|
644
|
+
|
|
645
|
+
# All jobs have been resolved so now we'll perform sanity checks
|
|
646
|
+
# to make sure we can infer the sweep
|
|
647
|
+
executable, executor, name = None, None, None
|
|
648
|
+
resolved_args, resolved_env_vars, resolved_futures = [], [], []
|
|
649
|
+
while not job_array_queue.empty():
|
|
650
|
+
# XManager automatically converts jobs to job groups so we must check
|
|
651
|
+
# that there's only a single job in this job group
|
|
652
|
+
job_group_view, future = job_array_queue.get_nowait()
|
|
653
|
+
assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
|
|
654
|
+
job_view = mit.one(
|
|
655
|
+
job_group_view.jobs.values(),
|
|
656
|
+
too_short=ValueError("Expected a single `xm.Job` in job group."),
|
|
657
|
+
too_long=ValueError("Only one `xm.Job` is supported for job arrays."),
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
if not isinstance(job_view, xm.Job):
|
|
661
|
+
raise ValueError("Only `xm.Job` is supported for job arrays. ")
|
|
662
|
+
|
|
663
|
+
if executable is None:
|
|
664
|
+
executable = job_view.executable
|
|
665
|
+
if id(job_view.executable) != id(executable):
|
|
666
|
+
raise RuntimeError("Found multiple executables in job array.")
|
|
667
|
+
|
|
668
|
+
if executor is None:
|
|
669
|
+
executor = job_view.executor
|
|
670
|
+
if id(job_view.executor) != id(executor):
|
|
671
|
+
raise RuntimeError("Found multiple executors in job array")
|
|
672
|
+
|
|
673
|
+
if name is None:
|
|
674
|
+
name = job_view.name
|
|
675
|
+
if job_view.name != name:
|
|
676
|
+
raise RuntimeError("Found multiple names in job array")
|
|
677
|
+
|
|
678
|
+
resolved_args.append(xm.SequentialArgs.from_collection(job_view.args).to_list())
|
|
679
|
+
resolved_env_vars.append(set(job_view.env_vars.items()))
|
|
680
|
+
resolved_futures.append(future)
|
|
681
|
+
assert executable is not None, "No executable found?"
|
|
682
|
+
assert executor is not None, "No executor found?"
|
|
683
|
+
assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
|
|
684
|
+
assert (
|
|
685
|
+
executor.requirements.cluster is not None
|
|
686
|
+
), "Cluster must be specified on requirements."
|
|
687
|
+
|
|
688
|
+
# XManager merges job arguments with keyword arguments with job arguments
|
|
689
|
+
# coming first. These are the arguments that may be common across all jobs
|
|
690
|
+
# so we can find the largest common prefix and remove them from each job.
|
|
691
|
+
common_args: list[str] = list(mit.longest_common_prefix(resolved_args))
|
|
692
|
+
common_env_vars: set = functools.reduce(lambda a, b: a & b, resolved_env_vars, set())
|
|
693
|
+
|
|
694
|
+
sweep_args = [
|
|
695
|
+
JobArgs(
|
|
696
|
+
args=functools.reduce(
|
|
697
|
+
# Remove the common arguments from each job
|
|
698
|
+
lambda args, to_remove: args.remove_args(to_remove),
|
|
699
|
+
common_args,
|
|
700
|
+
xm.SequentialArgs.from_collection(a),
|
|
701
|
+
),
|
|
702
|
+
env_vars=dict(e.difference(common_env_vars)),
|
|
703
|
+
)
|
|
704
|
+
for a, e in zip(resolved_args, resolved_env_vars)
|
|
705
|
+
]
|
|
706
|
+
|
|
707
|
+
# Dependency resolution
|
|
708
|
+
resolved_dependency = None
|
|
709
|
+
resolved_dependency_task_id_order = None
|
|
710
|
+
# one-to-one
|
|
711
|
+
if isinstance(dependency, collections.abc.Sequence):
|
|
712
|
+
if len(dependency) != len(wu_futures):
|
|
713
|
+
raise ValueError("Dependency list must be the same length as the number of trials.")
|
|
714
|
+
assert len(dependency) > 0, "Dependency list must not be empty."
|
|
715
|
+
|
|
716
|
+
# Convert any SlurmJobDependencyAfterOK to SlurmJobArrayDependencyAfterOK
|
|
717
|
+
# for any array jobs.
|
|
718
|
+
def _maybe_convert_afterok(
|
|
719
|
+
dep: dependencies.SlurmJobDependency,
|
|
720
|
+
) -> dependencies.SlurmJobDependency:
|
|
721
|
+
if isinstance(dep, dependencies.SlurmJobDependencyAfterOK) and all([
|
|
722
|
+
handle.slurm_job.is_array_job for handle in dep.handles
|
|
723
|
+
]):
|
|
724
|
+
return dependencies.SlurmJobArrayDependencyAfterOK([
|
|
725
|
+
dataclasses.replace(
|
|
726
|
+
handle,
|
|
727
|
+
slurm_job=handle.slurm_job.array_job_id,
|
|
728
|
+
)
|
|
729
|
+
for handle in dep.handles
|
|
730
|
+
])
|
|
731
|
+
return dep
|
|
732
|
+
|
|
733
|
+
dependencies_converted = [dep.traverse(_maybe_convert_afterok) for dep in dependency]
|
|
734
|
+
dependency_sets = [set(dep.flatten()) for dep in dependencies_converted]
|
|
735
|
+
dependency_differences = functools.reduce(set.difference, dependency_sets, set())
|
|
736
|
+
# There should be NO differences between the dependencies of each trial after conversion.
|
|
737
|
+
if len(dependency_differences) > 0:
|
|
738
|
+
raise ValueError(
|
|
739
|
+
f"Found variable dependencies across trials: {dependency_differences}. "
|
|
740
|
+
"Slurm job arrays require the same dependencies across all trials. "
|
|
741
|
+
)
|
|
742
|
+
resolved_dependency = dependencies_converted[0]
|
|
743
|
+
|
|
744
|
+
# This is slightly annoying but we need to re-sort the sweep arguments in case the dependencies were passed
|
|
745
|
+
# in a different order than 1, 2, ..., N as the Job array can only have correspondance with the same task id.
|
|
746
|
+
original_array_dependencies = [
|
|
747
|
+
mit.one(
|
|
748
|
+
filter(
|
|
749
|
+
lambda dep: isinstance(dep, dependencies.SlurmJobDependencyAfterOK)
|
|
750
|
+
and all([handle.slurm_job.is_array_job for handle in dep.handles]),
|
|
751
|
+
deps.flatten(),
|
|
752
|
+
)
|
|
753
|
+
)
|
|
754
|
+
for deps in dependency
|
|
755
|
+
]
|
|
756
|
+
resolved_dependency_task_id_order = [
|
|
757
|
+
int(
|
|
758
|
+
mit.one(
|
|
759
|
+
functools.reduce(
|
|
760
|
+
set.difference,
|
|
761
|
+
[handle.slurm_job.array_task_id for handle in dep.handles], # type: ignore
|
|
762
|
+
)
|
|
763
|
+
)
|
|
764
|
+
)
|
|
765
|
+
for dep in original_array_dependencies
|
|
766
|
+
]
|
|
767
|
+
assert len(resolved_dependency_task_id_order) == len(sweep_args)
|
|
768
|
+
assert set(resolved_dependency_task_id_order) == set(range(len(sweep_args))), (
|
|
769
|
+
"Dependent job array tasks should have task ids 0, 1, ..., N. "
|
|
770
|
+
f"Found: {resolved_dependency_task_id_order}"
|
|
771
|
+
)
|
|
772
|
+
# one-to-many
|
|
773
|
+
elif isinstance(dependency, dependencies.SlurmJobDependency):
|
|
774
|
+
resolved_dependency = dependency
|
|
775
|
+
assert resolved_dependency is None or isinstance(
|
|
776
|
+
resolved_dependency, dependencies.SlurmJobDependency
|
|
777
|
+
), "Invalid dependency type"
|
|
778
|
+
|
|
779
|
+
# No support for sweep_env_vars right now.
|
|
780
|
+
# We schedule the job array and then we'll resolve all the work units with
|
|
781
|
+
# the handles Slurm gives back to us.
|
|
782
|
+
# If we already have handles for all the work units, we don't need to submit the
|
|
783
|
+
# job array to SLURM.
|
|
784
|
+
num_resolved_handles = sum(future.done() for future in resolved_futures)
|
|
785
|
+
if num_resolved_handles == 0:
|
|
786
|
+
try:
|
|
787
|
+
handles = await execution.launch(
|
|
788
|
+
job=xm.Job(
|
|
789
|
+
executable=executable,
|
|
790
|
+
executor=executor,
|
|
791
|
+
name=name,
|
|
792
|
+
args=xm.SequentialArgs.from_collection(common_args),
|
|
793
|
+
env_vars=dict(common_env_vars),
|
|
794
|
+
),
|
|
795
|
+
dependency=resolved_dependency,
|
|
796
|
+
args=[
|
|
797
|
+
sweep_args[resolved_dependency_task_id_order.index(i)]
|
|
798
|
+
for i in range(len(sweep_args))
|
|
799
|
+
]
|
|
800
|
+
if resolved_dependency_task_id_order
|
|
801
|
+
else sweep_args,
|
|
802
|
+
experiment_id=self.experiment_id,
|
|
803
|
+
identity=identity,
|
|
804
|
+
)
|
|
805
|
+
if resolved_dependency_task_id_order:
|
|
806
|
+
handles = [handles[i] for i in resolved_dependency_task_id_order]
|
|
807
|
+
except Exception as e:
|
|
808
|
+
for future in resolved_futures:
|
|
809
|
+
future.set_exception(e)
|
|
810
|
+
raise
|
|
811
|
+
else:
|
|
812
|
+
for handle, future in zip(handles, resolved_futures):
|
|
813
|
+
future.set_result(handle)
|
|
814
|
+
elif 0 < num_resolved_handles < len(resolved_futures):
|
|
815
|
+
raise RuntimeError(
|
|
816
|
+
"Some array job elements have handles, but some don't. This shouldn't happen."
|
|
817
|
+
)
|
|
818
|
+
|
|
819
|
+
wus = await asyncio.gather(*wu_futures)
|
|
820
|
+
|
|
821
|
+
_current_job_array_queue.set(None)
|
|
822
|
+
return wus
|
|
823
|
+
|
|
824
|
+
def _get_work_unit_by_identity(self, identity: str) -> SlurmWorkUnit | None:
|
|
825
|
+
if identity == "":
|
|
826
|
+
return None
|
|
827
|
+
for unit in self._experiment_units:
|
|
828
|
+
if isinstance(unit, SlurmWorkUnit) and unit.identity == identity:
|
|
829
|
+
return unit
|
|
830
|
+
return None
|
|
831
|
+
|
|
832
|
+
    def _create_experiment_unit(  # type: ignore
        self,
        args: JobArgs | tp.Mapping[str, JobArgs] | None,
        role: xm.ExperimentUnitRole,
        identity: str,
    ) -> tp.Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
        """Create (or reuse) an experiment unit for the given role.

        For work-unit roles, an existing unit with the same non-empty identity
        is returned instead of creating a duplicate. Newly created work units
        are registered both in memory and with the API backend; auxiliary
        units are only tracked in memory. The unit is returned wrapped in an
        already-resolved future to satisfy the awaitable interface.

        Args:
            args: Arguments for the unit's job(s), forwarded to the unit
                constructor and persisted (JSON-encoded) for work units.
            role: Work-unit or auxiliary role; anything else is rejected.
            identity: Stable identity used for deduplicating work units.

        Raises:
            ValueError: If `role` is neither a work-unit nor an auxiliary role.
        """

        def _create_work_unit(role: xm.WorkUnitRole) -> tp.Awaitable[SlurmWorkUnit]:
            # Register a brand-new work unit locally and with the API backend.
            work_unit = SlurmWorkUnit(
                self,
                self._create_task,
                args,
                role,
                self._work_unit_id_predictor,
                identity=identity,
            )
            self._experiment_units.append(work_unit)
            self._work_unit_count += 1

            # Persist the work unit so it can be restored by `get_experiment`.
            api.client().insert_work_unit(
                self.experiment_id,
                api.models.WorkUnitPatch(
                    wid=work_unit.work_unit_id,
                    identity=work_unit.identity,
                    args=json.dumps(args),
                ),
            )

            # Wrap in a pre-resolved future: creation here is synchronous but
            # callers expect an awaitable.
            future = asyncio.Future[SlurmWorkUnit]()
            future.set_result(work_unit)
            return future

        def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> tp.Awaitable[SlurmAuxiliaryUnit]:
            # Auxiliary units are tracked in memory only — note there is no
            # API-backend registration, unlike work units above.
            auxiliary_unit = SlurmAuxiliaryUnit(
                self,
                self._create_task,
                args,
                role,
                identity=identity,
            )
            self._experiment_units.append(auxiliary_unit)
            future = asyncio.Future[SlurmAuxiliaryUnit]()
            future.set_result(auxiliary_unit)
            return future

        match role:
            case xm.WorkUnitRole():
                # Deduplicate: reuse an existing unit with the same identity.
                if (existing_unit := self._get_work_unit_by_identity(identity)) is not None:
                    future = asyncio.Future[SlurmWorkUnit]()
                    future.set_result(existing_unit)
                    return future
                return _create_work_unit(role)
            case xm.AuxiliaryUnitRole():
                return _create_auxiliary_unit(role)
            case _:
                raise ValueError(f"Unsupported role {role}")
|
|
887
|
+
|
|
888
|
+
def _get_experiment_unit( # type: ignore
|
|
889
|
+
self,
|
|
890
|
+
experiment_id: int,
|
|
891
|
+
identity: str,
|
|
892
|
+
role: xm.ExperimentUnitRole,
|
|
893
|
+
args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
|
|
894
|
+
) -> tp.Awaitable[SlurmExperimentUnit]:
|
|
895
|
+
del experiment_id, identity, role, args
|
|
896
|
+
raise NotImplementedError
|
|
897
|
+
|
|
898
|
+
def _should_reload_experiment_unit(self, role: xm.ExperimentUnitRole) -> bool:
|
|
899
|
+
del role
|
|
900
|
+
return False
|
|
901
|
+
|
|
902
|
+
    async def __aenter__(self) -> "SlurmExperiment":
        """Enter the async context and return this experiment."""
        await super().__aenter__()
        return self
|
|
905
|
+
|
|
906
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
907
|
+
# If no work units were added, delete this experiment
|
|
908
|
+
# This is to prevent empty experiments from being persisted
|
|
909
|
+
# and cluttering the database.
|
|
910
|
+
if len(self._experiment_units) == 0:
|
|
911
|
+
console.print(
|
|
912
|
+
f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
|
|
913
|
+
)
|
|
914
|
+
api.client().delete_experiment(self.experiment_id)
|
|
915
|
+
|
|
916
|
+
await super().__aexit__(exc_type, exc_value, traceback)
|
|
917
|
+
|
|
918
|
+
    @property
    def experiment_id(self) -> int:
        """Numeric identifier of this experiment."""
        return self._id
|
|
921
|
+
|
|
922
|
+
    @property
    def experiment_title(self) -> str:
        """Human-readable title, read from the experiment's annotations."""
        return self.context.annotations.title
|
|
925
|
+
|
|
926
|
+
    @property
    def context(self) -> metadata_context.SlurmExperimentMetadataContext:  # type: ignore
        """Metadata context holding this experiment's annotations and artifacts."""
        return self._experiment_context
|
|
929
|
+
|
|
930
|
+
    @property
    def work_unit_count(self) -> int:
        """Number of work units added to this experiment so far."""
        return self._work_unit_count
|
|
933
|
+
|
|
934
|
+
def work_units(self) -> dict[int, SlurmWorkUnit]:
|
|
935
|
+
"""Gets work units created via self.add()."""
|
|
936
|
+
return {
|
|
937
|
+
wu.work_unit_id: wu for wu in self._experiment_units if isinstance(wu, SlurmWorkUnit)
|
|
938
|
+
}
|
|
939
|
+
|
|
940
|
+
def __repr__(self, /) -> str:
|
|
941
|
+
return f"<SlurmExperiment {self.experiment_id} {self.experiment_title}>"
|
|
942
|
+
|
|
943
|
+
|
|
944
|
+
def create_experiment(experiment_title: str) -> SlurmExperiment:
    """Create and persist a new `SlurmExperiment` with the given title."""
    patch = api.models.ExperimentPatch(title=experiment_title)
    experiment_id = api.client().insert_experiment(patch)
    return SlurmExperiment(experiment_title=experiment_title, experiment_id=experiment_id)
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
def get_experiment(experiment_id: int) -> SlurmExperiment:
    """Reconstruct a `SlurmExperiment` from its persisted state.

    Fetches the experiment record from the API client and rebuilds the
    in-memory experiment: annotations, artifacts, and every stored work unit
    together with its Slurm execution handles and artifacts.
    """
    record = api.client().get_experiment(experiment_id)
    experiment = SlurmExperiment(
        experiment_title=record.title, experiment_id=experiment_id
    )
    experiment._work_unit_id_predictor = id_predictor.Predictor(1)

    # Restore experiment-level annotations.
    experiment.context.annotations.description = record.description or ""
    experiment.context.annotations.note = record.note or ""
    experiment.context.annotations.tags = record.tags or []

    # Restore experiment-level artifacts.
    for stored_artifact in record.artifacts:
        experiment.context.artifacts[stored_artifact.name] = stored_artifact.uri

    # Rebuild each stored work unit.
    for unit_record in record.work_units:
        work_unit = SlurmWorkUnit(
            experiment=experiment,
            create_task=experiment._create_task,
            args=json.loads(unit_record.args) if unit_record.args else {},
            role=xm.WorkUnitRole(),
            identity=unit_record.identity or "",
            work_unit_id_predictor=experiment._work_unit_id_predictor,
        )
        work_unit._work_unit_id = unit_record.wid

        # Re-attach a Slurm execution handle for every job stored on the unit.
        for job_record in unit_record.jobs:
            ssh_config = config.SSHConfig.deserialize(job_record.slurm_ssh_config)
            work_unit._execution_handles.append(
                execution.SlurmHandle(
                    experiment_id=experiment_id,
                    ssh=ssh_config,
                    slurm_job=str(job_record.slurm_job_id),
                    job_name=job_record.name,
                )
            )

        # Restore the unit's own artifacts.
        for stored_artifact in unit_record.artifacts:
            work_unit.context.artifacts[stored_artifact.name] = stored_artifact.uri

        experiment._experiment_units.append(work_unit)
        experiment._work_unit_count += 1

    return experiment
|
|
1000
|
+
|
|
1001
|
+
|
|
1002
|
+
@functools.cache
def get_current_experiment() -> SlurmExperiment | None:
    """Return the experiment this process belongs to, if launched by xm-slurm.

    The experiment id is read from the `XM_SLURM_EXPERIMENT_ID` environment
    variable; `None` is returned when it is unset or empty.
    """
    xid = os.environ.get("XM_SLURM_EXPERIMENT_ID")
    if not xid:
        return None
    return get_experiment(int(xid))
|
|
1007
|
+
|
|
1008
|
+
|
|
1009
|
+
@functools.cache
def get_current_work_unit() -> SlurmWorkUnit | None:
    """Return the work unit this process belongs to, if launched by xm-slurm.

    Requires both `XM_SLURM_EXPERIMENT_ID` and `XM_SLURM_WORK_UNIT_ID` to be
    set and non-empty; otherwise returns `None`.
    """
    xid = os.environ.get("XM_SLURM_EXPERIMENT_ID")
    wid = os.environ.get("XM_SLURM_WORK_UNIT_ID")
    if not xid or not wid:
        return None
    experiment = get_experiment(int(xid))
    return experiment.work_units()[int(wid)]
|