xmanager-slurm 0.3.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xmanager-slurm might be problematic.
- xm_slurm/__init__.py +44 -0
- xm_slurm/api.py +261 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +162 -0
- xm_slurm/console.py +3 -0
- xm_slurm/contrib/clusters/__init__.py +52 -0
- xm_slurm/contrib/clusters/drac.py +169 -0
- xm_slurm/executables.py +201 -0
- xm_slurm/execution.py +491 -0
- xm_slurm/executors.py +127 -0
- xm_slurm/experiment.py +737 -0
- xm_slurm/job_blocks.py +14 -0
- xm_slurm/packageables.py +292 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker/__init__.py +75 -0
- xm_slurm/packaging/docker/abc.py +112 -0
- xm_slurm/packaging/docker/cloud.py +503 -0
- xm_slurm/packaging/docker/local.py +206 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +52 -0
- xm_slurm/packaging/utils.py +202 -0
- xm_slurm/resources.py +150 -0
- xm_slurm/status.py +188 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
- xm_slurm/templates/docker/mamba.Dockerfile +27 -0
- xm_slurm/templates/docker/pdm.Dockerfile +31 -0
- xm_slurm/templates/docker/python.Dockerfile +24 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
- xm_slurm/templates/slurm/job.bash.j2 +78 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
- xm_slurm/utils.py +69 -0
- xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
- xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
- xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
xm_slurm/experiment.py
ADDED
@@ -0,0 +1,737 @@
import asyncio
import collections.abc
import contextvars
import dataclasses
import functools
import inspect
import json
import os
import typing
from concurrent import futures
from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence

from xmanager import xm
from xmanager.xm import async_packager, id_predictor

from xm_slurm import api, execution, executors
from xm_slurm.console import console
from xm_slurm.packaging import router
from xm_slurm.status import SlurmWorkUnitStatus
from xm_slurm.utils import UserSet

_current_job_array_queue = contextvars.ContextVar[
    asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
]("_current_job_array_queue", default=None)


def _validate_job(
    job: xm.JobType,
    args_view: Mapping[str, Any],
) -> None:
    if not args_view:
        return
    if not isinstance(args_view, collections.abc.Mapping):
        raise ValueError("Job arguments via `experiment.add` must be mappings")

    if isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
        raise ValueError("Job group is empty")

    if isinstance(job, xm.JobGroup) and any(
        isinstance(child, xm.JobGroup) for child in job.jobs.values()
    ):
        raise ValueError("Nested job groups are not supported")

    allowed_keys = {"args", "env_vars"}
    for key, expanded in args_view.items():
        if isinstance(job, xm.JobGroup) and len(job.jobs) > 1 and key not in job.jobs:
            raise ValueError(
                f"Argument key `{key}` doesn't exist in job group with keys {job.jobs.keys()}"
            )

        if isinstance(job, xm.JobGroup) and key in job.jobs:
            _validate_job(job.jobs[key], expanded)
        elif key not in allowed_keys:
            raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")


@dataclasses.dataclass(kw_only=True, frozen=True)
class Artifact:
    name: str
    uri: str

    def __hash__(self) -> int:
        return hash(self.name)


class ContextArtifacts(UserSet[Artifact]):
    def __init__(
        self,
        owner: "SlurmExperiment | SlurmExperimentUnit",
        *,
        artifacts: Sequence[Artifact],
    ):
        super().__init__(
            artifacts,
            on_add=self._on_add_artifact,
            on_remove=self._on_remove_artifact,
            on_discard=self._on_remove_artifact,
        )
        self._owner = owner
        self._create_task = self._owner._create_task

    def _on_add_artifact(self, artifact: Artifact) -> None:
        match self._owner:
            case SlurmExperiment():
                api.client().insert_experiment_artifact(
                    self._owner.experiment_id,
                    api.ArtifactModel(
                        name=artifact.name,
                        uri=artifact.uri,
                    ),
                )
            case SlurmWorkUnit():
                api.client().insert_work_unit_artifact(
                    self._owner.experiment_id,
                    self._owner.work_unit_id,
                    api.ArtifactModel(
                        name=artifact.name,
                        uri=artifact.uri,
                    ),
                )

    def _on_remove_artifact(self, artifact: Artifact) -> None:
        match self._owner:
            case SlurmExperiment():
                api.client().delete_experiment_artifact(self._owner.experiment_id, artifact.name)
            case SlurmWorkUnit():
                api.client().delete_work_unit_artifact(
                    self._owner.experiment_id, self._owner.work_unit_id, artifact.name
                )


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmExperimentUnitMetadataContext:
    artifacts: ContextArtifacts


class SlurmExperimentUnit(xm.ExperimentUnit):
    """ExperimentUnit is a collection of semantically associated `Job`s."""

    experiment: "SlurmExperiment"

    def __init__(
        self,
        experiment: xm.Experiment,
        create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
        args: Mapping[str, Any] | None,
        role: xm.ExperimentUnitRole,
    ) -> None:
        super().__init__(experiment, create_task, args, role)
        self._launched_jobs: list[xm.LaunchedJob] = []
        self._execution_handles: list[execution.SlurmHandle] = []
        self._context = SlurmExperimentUnitMetadataContext(
            artifacts=ContextArtifacts(owner=self, artifacts=[]),
        )

    async def _submit_jobs_for_execution(
        self,
        job: xm.Job | xm.JobGroup,
        args_view: Mapping[str, Any],
        identity: str | None = None,
    ) -> execution.SlurmHandle:
        return await execution.launch(
            job=job,
            args=args_view,
            experiment_id=self.experiment_id,
            identity=identity,
        )

    def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
        match job:
            case xm.JobGroup() as job_group:
                for job in job_group.jobs.values():
                    self._launched_jobs.append(
                        xm.LaunchedJob(
                            name=job.name,  # type: ignore
                            address=str(handle.job_id),
                        )
                    )
            case xm.Job():
                self._launched_jobs.append(
                    xm.LaunchedJob(
                        name=handle.job.name,  # type: ignore
                        address=str(handle.job_id),
                    )
                )

    async def _wait_until_complete(self) -> None:
        try:
            await asyncio.gather(*[handle.wait() for handle in self._execution_handles])
        except RuntimeError as error:
            raise xm.ExperimentUnitFailedError(error)

    def stop(
        self,
        *,
        mark_as_failed: bool = False,
        mark_as_completed: bool = False,
        message: str | None = None,
    ) -> None:
        del mark_as_failed, mark_as_completed, message

        async def _stop_awaitable() -> None:
            try:
                await asyncio.gather(*[handle.stop() for handle in self._execution_handles])
            except RuntimeError as error:
                raise xm.ExperimentUnitFailedError(error)

        self.experiment._create_task(_stop_awaitable())

    async def get_status(self) -> SlurmWorkUnitStatus:
        states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
        return SlurmWorkUnitStatus.aggregate(states)

    def launched_jobs(self) -> list[xm.LaunchedJob]:
        return self._launched_jobs

    @property
    def context(self) -> SlurmExperimentUnitMetadataContext:
        return self._context


class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
    def __init__(
        self,
        experiment: "SlurmExperiment",
        create_task: Callable[[Awaitable[Any]], futures.Future],
        args: Mapping[str, Any],
        role: xm.ExperimentUnitRole,
        work_unit_id_predictor: id_predictor.Predictor,
    ) -> None:
        super().__init__(experiment, create_task, args, role)
        self._work_unit_id_predictor = work_unit_id_predictor
        self._work_unit_id = self._work_unit_id_predictor.reserve_id()
        api.client().insert_work_unit(
            self.experiment_id,
            api.WorkUnitPatchModel(
                wid=self.work_unit_id,
                identity=self.identity,
                args=json.dumps(args),
            ),
        )

    def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
        self._execution_handles.append(handle)
        api.client().insert_job(
            self.experiment_id,
            self.work_unit_id,
            api.SlurmJobModel(
                name=self.experiment_unit_name,
                slurm_job_id=handle.job_id,  # type: ignore
                slurm_cluster=json.dumps({
                    "host": handle.ssh_connection_options.host,
                    "username": handle.ssh_connection_options.username,
                    "port": handle.ssh_connection_options.port,
                    "config": handle.ssh_connection_options.config.get_options(False),
                }),
            ),
        )

    async def _launch_job_group(
        self,
        job: xm.JobGroup,
        args_view: Mapping[str, Any],
        identity: str,
    ) -> None:
        global _current_job_array_queue
        _validate_job(job, args_view)

        future = asyncio.Future()
        async with self._work_unit_id_predictor.submit_id(self.work_unit_id):  # type: ignore
            # If we're scheduling as part of a job queue (i.e., the queue is set on the context)
            # then we'll insert the job and current future that'll get resolved to the
            # proper handle.
            if job_array_queue := _current_job_array_queue.get():
                job_array_queue.put_nowait((job, future))
            # Otherwise we'll resolve the future with the scheduled job immediately
            else:
                # Set the result inside of the context manager so we don't get out-of-order
                # id scheduling...
                future.set_result(
                    await self._submit_jobs_for_execution(job, args_view, identity=identity)
                )

        # Wait for the job handle, this is either coming from scheduling the job array
        # or from the single job above.
        handle = await future
        self._ingest_handle(handle)
        self._ingest_launched_jobs(job, handle)

    @property
    def experiment_unit_name(self) -> str:
        return f"{self.experiment_id}_{self._work_unit_id}"

    @property
    def work_unit_id(self) -> int:
        return self._work_unit_id

    def __repr__(self, /) -> str:
        return f"<SlurmWorkUnit {self.experiment_unit_name}>"


class SlurmAuxiliaryUnit(SlurmExperimentUnit):
    """An auxiliary unit operated by the Slurm backend."""

    def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
        del handle
        console.print("[red]Auxiliary units do not currently support ingestion.[/red]")

    async def _launch_job_group(
        self,
        job: xm.Job | xm.JobGroup,
        args_view: Mapping[str, Any],
        identity: str,
    ) -> None:
        _validate_job(job, args_view)

        slurm_handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
        self._ingest_handle(slurm_handle)
        self._ingest_launched_jobs(job, slurm_handle)

    @property
    def experiment_unit_name(self) -> str:
        return f"{self.experiment_id}_auxiliary"

    def __repr__(self, /) -> str:
        return f"<SlurmAuxiliaryUnit {self.experiment_unit_name}>"


class SlurmExperimentContextAnnotations:
    def __init__(
        self,
        experiment: "SlurmExperiment",
        *,
        title: str,
        tags: set[str] | None = None,
        description: str | None = None,
        note: str | None = None,
    ):
        self._experiment = experiment
        self._create_task = self._experiment._create_task
        self._title = title
        self._tags = UserSet(
            tags or set(),
            on_add=self._on_tag_added,
            on_remove=self._on_tag_removed,
            on_discard=self._on_tag_removed,
        )
        self._description = description or ""
        self._note = note or ""

    @property
    def title(self) -> str:
        return self._title

    @title.setter
    def title(self, value: str) -> None:
        self._title = value
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.ExperimentPatchModel(title=value),
        )

    @property
    def description(self) -> str:
        return self._description

    @description.setter
    def description(self, value: str) -> None:
        self._description = value
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.ExperimentPatchModel(description=value),
        )

    @property
    def note(self) -> str:
        return self._note

    @note.setter
    def note(self, value: str) -> None:
        self._note = value
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.ExperimentPatchModel(note=value),
        )

    @property
    def tags(self) -> MutableSet[str]:
        return self._tags

    @tags.setter
    def tags(self, tags: set[str]) -> None:
        # TODO(jfarebro): Create custom tag collection
        # and set it here, we need this so we can hook add and remove
        # to mutate the database transparently
        self._tags = UserSet(tags, on_add=self._on_tag_added, on_remove=self._on_tag_removed)
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.ExperimentPatchModel(tags=list(self._tags)),
        )

    def _on_tag_added(self, tag: str) -> None:
        del tag
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.ExperimentPatchModel(tags=list(self._tags)),
        )

    def _on_tag_removed(self, tag: str) -> None:
        del tag
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.ExperimentPatchModel(tags=list(self._tags)),
        )


class SlurmExperimentContextArtifacts(ContextArtifacts):
    def add_graphviz_config(self, config: str) -> None:
        self.add(Artifact(name="GRAPHVIZ", uri=f"graphviz://{config}"))

    def add_python_config(self, config: str) -> None:
        self.add(Artifact(name="PYTHON", uri=config))


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmExperimentMetadataContext:
    annotations: SlurmExperimentContextAnnotations
    artifacts: ContextArtifacts


class SlurmExperiment(xm.Experiment):
    _id: int
    _experiment_units: list[SlurmExperimentUnit]
    _experiment_context: SlurmExperimentMetadataContext
    _work_unit_count: int
    _async_packager = async_packager.AsyncPackager(router.package)

    def __init__(
        self,
        experiment_title: str,
        experiment_id: int,
    ) -> None:
        super().__init__()
        self._id = experiment_id
        self._experiment_units = []
        self._experiment_context = SlurmExperimentMetadataContext(
            annotations=SlurmExperimentContextAnnotations(
                experiment=self,
                title=experiment_title,
            ),
            artifacts=ContextArtifacts(self, artifacts=[]),
        )
        self._work_unit_count = 0

    @typing.overload
    def add(
        self,
        job: xm.AuxiliaryUnitJob,
        args: Mapping[str, Any] | None = ...,
        *,
        identity: str = "",
    ) -> asyncio.Future[SlurmExperimentUnit]: ...

    @typing.overload
    def add(
        self,
        job: xm.JobType,
        args: Mapping[str, Any] | None = ...,
        *,
        role: xm.WorkUnitRole = ...,
        identity: str = "",
    ) -> asyncio.Future[SlurmWorkUnit]: ...

    @typing.overload
    def add(
        self,
        job: xm.JobType,
        args: Mapping[str, Any] | None,
        *,
        role: xm.ExperimentUnitRole,
        identity: str = "",
    ) -> asyncio.Future[SlurmExperimentUnit]: ...

    @typing.overload
    def add(
        self,
        job: xm.JobType,
        args: Mapping[str, Any] | None = ...,
        *,
        role: xm.ExperimentUnitRole,
        identity: str = "",
    ) -> asyncio.Future[SlurmExperimentUnit]: ...

    @typing.overload
    def add(
        self,
        job: xm.Job | xm.JobGeneratorType,
        args: Sequence[Mapping[str, Any]],
        *,
        role: xm.WorkUnitRole = ...,
        identity: str = "",
    ) -> asyncio.Future[Sequence[SlurmWorkUnit]]: ...

    def add(
        self,
        job: xm.JobType,
        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None,
        *,
        role: xm.ExperimentUnitRole = xm.WorkUnitRole(),
        identity: str = "",
    ) -> (
        asyncio.Future[SlurmExperimentUnit]
        | asyncio.Future[SlurmWorkUnit]
        | asyncio.Future[Sequence[SlurmWorkUnit]]
    ):
        if isinstance(args, collections.abc.Sequence):
            if not isinstance(role, xm.WorkUnitRole):
                raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
            if identity:
                raise ValueError(
                    "Cannot set an identity on the root add call. "
                    "Please use a job generator and set the identity within."
                )
            if isinstance(job, xm.JobGroup):
                raise ValueError(
                    "Job arrays over `xm.JobGroup`s aren't supported. "
                    "Slurm doesn't support job arrays over heterogeneous jobs. "
                    "Instead you should call `experiment.add` for each of these trials."
                )
            assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"

            return asyncio.wrap_future(
                self._create_task(self._launch_job_array(job, args, role, identity))
            )
        else:
            return super().add(job, args, role=role, identity=identity)  # type: ignore

    async def _launch_job_array(
        self,
        job: xm.Job | xm.JobGeneratorType,
        args: Sequence[Mapping[str, Any]],
        role: xm.WorkUnitRole,
        identity: str = "",
    ) -> Sequence[SlurmWorkUnit]:
        global _current_job_array_queue

        # Create our job array queue and assign it to the current context
        job_array_queue = asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]](maxsize=len(args))
        _current_job_array_queue.set(job_array_queue)

        # For each trial we'll schedule the job
        # and collect the futures
        wu_futures = []
        for trial in args:
            wu_futures.append(super().add(job, args=trial, role=role, identity=identity))

        # TODO(jfarebro): Set a timeout here
        # We'll wait until XManager has filled the queue.
        # There are two cases here, either we were given an xm.Job
        # in which case this will be trivial and filled immediately.
        # The other case is when you have a job generator and this is less
        # trivial, you have to wait for wu.add to be called.
        while not job_array_queue.full():
            await asyncio.sleep(0.1)

        # All jobs have been resolved
        executable, executor, name = None, None, None
        resolved_args, resolved_env_vars, resolved_futures = [], [], []
        while not job_array_queue.empty():
            # XManager automatically converts jobs to job groups so we must check
            # that there's only a single job in this job group
            job_group_view, future = job_array_queue.get_nowait()
            assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
            _, job_view = job_group_view.jobs.popitem()

            if len(job_group_view.jobs) != 0 or not isinstance(job_view, xm.Job):
                raise ValueError("Only `xm.Job` is supported for job arrays. ")

            if executable is None:
                executable = job_view.executable
            if id(job_view.executable) != id(executable):
                raise RuntimeError("Found multiple executables in job array.")

            if executor is None:
                executor = job_view.executor
            if id(job_view.executor) != id(executor):
                raise RuntimeError("Found multiple executors in job array")

            if name is None:
                name = job_view.name
            if job_view.name != name:
                raise RuntimeError("Found multiple names in job array")

            resolved_args.append(
                set(xm.SequentialArgs.from_collection(job_view.args).to_dict().items())
            )
            resolved_env_vars.append(set(job_view.env_vars.items()))
            resolved_futures.append(future)
        assert executable is not None, "No executable found?"
        assert executor is not None, "No executor found?"
        assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
        assert executor.requirements.cluster is not None, "Cluster must be set on executor."

        common_args: set = functools.reduce(lambda a, b: a & b, resolved_args, set())
        common_env_vars: set = functools.reduce(lambda a, b: a & b, resolved_env_vars, set())

        sweep_args = [
            {
                "args": dict(a.difference(common_args)),
                "env_vars": dict(e.difference(common_env_vars)),
            }
            for a, e in zip(resolved_args, resolved_env_vars)
        ]

        # No support for sweep_env_vars right now.
        # We schedule the job array and then we'll resolve all the work units with
        # the handles Slurm gives back to us.
        try:
            handles = await execution.get_client().launch(
                cluster=executor.requirements.cluster,
                job=xm.Job(
                    executable=executable,
                    executor=executor,
                    name=name,
                    args=dict(common_args),
                    env_vars=dict(common_env_vars),
                ),
                args=sweep_args,
                experiment_id=self.experiment_id,
                identity=identity,
            )
        except Exception as e:
            for future in resolved_futures:
                future.set_exception(e)
            raise
        else:
            for handle, future in zip(handles, resolved_futures):
                future.set_result(handle)

        wus = await asyncio.gather(*wu_futures)
        _current_job_array_queue.set(None)
        return wus

    def _create_experiment_unit(
        self,
        args: Mapping[str, Any],
        role: xm.ExperimentUnitRole,
        identity: str,
    ) -> Awaitable[SlurmWorkUnit]:
        del identity

        def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
            work_unit = SlurmWorkUnit(
                self,
                self._create_task,
                args,
                role,
                self._work_unit_id_predictor,
            )
            self._experiment_units.append(work_unit)
            self._work_unit_count += 1

            future = asyncio.Future()
            future.set_result(work_unit)
            return future

        match role:
            case xm.WorkUnitRole():
                return _create_work_unit(role)
            case _:
                raise ValueError(f"Unsupported role {role}")

    def _get_experiment_unit(
        self,
        experiment_id: int,
        identity: str,
        role: xm.ExperimentUnitRole,
        args: Mapping[str, Any] | None = None,
    ) -> Awaitable[xm.ExperimentUnit]:
        del experiment_id, identity, role, args
        raise NotImplementedError

    def _should_reload_experiment_unit(self, role: xm.ExperimentUnitRole) -> bool:
        del role
        return False

    async def __aenter__(self) -> "SlurmExperiment":
        await super().__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # If no work units were added, delete this experiment
        # This is to prevent empty experiments from being persisted
        # and cluttering the database.
        if self.work_unit_count == 0:
            console.print(
                f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
            )
            api.client().delete_experiment(self.experiment_id)

        await super().__aexit__(exc_type, exc_value, traceback)

    @property
    def experiment_id(self) -> int:
        return self._id

    @property
    def experiment_title(self) -> str:
        return self.context.annotations.title

    @property
    def context(self) -> SlurmExperimentMetadataContext:
        return self._experiment_context

    @property
    def work_unit_count(self) -> int:
        return self._work_unit_count

    @property
    def work_units(self) -> Mapping[int, SlurmWorkUnit]:
        """Gets work units created via self.add()."""
        return {
            wu.work_unit_id: wu for wu in self._experiment_units if isinstance(wu, SlurmWorkUnit)
        }

    def __repr__(self, /) -> str:
        return f"<SlurmExperiment {self.experiment_id} {self.experiment_title}>"


def create_experiment(experiment_title: str) -> SlurmExperiment:
    """Create Experiment."""
    experiment_id = api.client().insert_experiment(api.ExperimentPatchModel(title=experiment_title))
    return SlurmExperiment(experiment_title=experiment_title, experiment_id=experiment_id)


def get_experiment(experiment_id: int) -> SlurmExperiment:
    """Get Experiment."""
    experiment_model = api.client().get_experiment(experiment_id)
    # TODO(jfarebro): Fill in jobs and work units and annotations
    return SlurmExperiment(experiment_title=experiment_model.title, experiment_id=experiment_id)


@functools.cache
def get_current_experiment() -> SlurmExperiment | None:
    if xid := os.environ.get("XM_SLURM_EXPERIMENT_ID"):
        return get_experiment(int(xid))
    return None


@functools.cache
def get_current_work_unit() -> SlurmWorkUnit | None:
    if (xid := os.environ.get("XM_SLURM_EXPERIMENT_ID")) and (
        wid := os.environ.get("XM_SLURM_WORK_UNIT_ID")
    ):
        experiment = get_experiment(int(xid))
        return experiment.work_units[int(wid)]
    return None
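For orientation, below is a minimal launcher-side sketch against this module's public surface (`create_experiment`, `SlurmExperiment.add`). It is illustrative only: the `executable` and `executor` parameters are assumed to come from this package's packaging and executor APIs (xm_slurm/packageables.py and xm_slurm/executors.py, also added in this release), and for a job array the executor must be an `executors.Slurm` instance with `requirements.cluster` set, as asserted in `_launch_job_array` above.

import asyncio

from xmanager import xm

from xm_slurm import experiment as xm_slurm_experiment


async def main(executable: xm.Executable, executor: xm.Executor) -> None:
    # create_experiment persists the experiment via api.client(); the async
    # context manager deletes it again on exit if no work units were added.
    async with xm_slurm_experiment.create_experiment("example-sweep") as exp:
        # Passing a sequence of per-trial mappings launches one Slurm job
        # array: _launch_job_array factors out the args/env_vars shared by
        # every trial and submits only the per-trial differences as the sweep.
        work_units = await exp.add(
            xm.Job(executable=executable, executor=executor, name="train"),
            args=[{"args": {"seed": seed}} for seed in range(5)],
        )
        for wu in work_units:
            print(wu.experiment_unit_name)

From inside a running job, `get_current_experiment()` and `get_current_work_unit()` resolve the owning experiment and work unit from the `XM_SLURM_EXPERIMENT_ID` and `XM_SLURM_WORK_UNIT_ID` environment variables set by the Slurm job templates.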