xmanager-slurm 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic. Click here for more details.
- xm_slurm/__init__.py +44 -0
- xm_slurm/api.py +261 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +162 -0
- xm_slurm/console.py +3 -0
- xm_slurm/contrib/clusters/__init__.py +52 -0
- xm_slurm/contrib/clusters/drac.py +169 -0
- xm_slurm/executables.py +201 -0
- xm_slurm/execution.py +491 -0
- xm_slurm/executors.py +127 -0
- xm_slurm/experiment.py +737 -0
- xm_slurm/job_blocks.py +14 -0
- xm_slurm/packageables.py +292 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker/__init__.py +75 -0
- xm_slurm/packaging/docker/abc.py +112 -0
- xm_slurm/packaging/docker/cloud.py +503 -0
- xm_slurm/packaging/docker/local.py +206 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +52 -0
- xm_slurm/packaging/utils.py +202 -0
- xm_slurm/resources.py +150 -0
- xm_slurm/status.py +188 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
- xm_slurm/templates/docker/mamba.Dockerfile +27 -0
- xm_slurm/templates/docker/pdm.Dockerfile +31 -0
- xm_slurm/templates/docker/python.Dockerfile +24 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
- xm_slurm/templates/slurm/job.bash.j2 +78 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
- xm_slurm/utils.py +69 -0
- xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
- xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
- xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
xm_slurm/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from xm_slurm.executables import Dockerfile, DockerImage
|
|
4
|
+
from xm_slurm.executors import Slurm, SlurmSpec
|
|
5
|
+
from xm_slurm.experiment import (
|
|
6
|
+
Artifact,
|
|
7
|
+
create_experiment,
|
|
8
|
+
get_current_experiment,
|
|
9
|
+
get_current_work_unit,
|
|
10
|
+
get_experiment,
|
|
11
|
+
)
|
|
12
|
+
from xm_slurm.packageables import (
|
|
13
|
+
conda_container,
|
|
14
|
+
docker_container,
|
|
15
|
+
docker_image,
|
|
16
|
+
mamba_container,
|
|
17
|
+
pdm_container,
|
|
18
|
+
python_container,
|
|
19
|
+
)
|
|
20
|
+
from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType
|
|
21
|
+
|
|
22
|
+
logging.getLogger("asyncssh").setLevel(logging.WARN)
|
|
23
|
+
logging.getLogger("httpx").setLevel(logging.WARN)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"Artifact",
|
|
27
|
+
"conda_container",
|
|
28
|
+
"create_experiment",
|
|
29
|
+
"docker_container",
|
|
30
|
+
"docker_image",
|
|
31
|
+
"Dockerfile",
|
|
32
|
+
"DockerImage",
|
|
33
|
+
"get_current_experiment",
|
|
34
|
+
"get_current_work_unit",
|
|
35
|
+
"get_experiment",
|
|
36
|
+
"JobRequirements",
|
|
37
|
+
"mamba_container",
|
|
38
|
+
"pdm_container",
|
|
39
|
+
"python_container",
|
|
40
|
+
"ResourceQuantity",
|
|
41
|
+
"ResourceType",
|
|
42
|
+
"Slurm",
|
|
43
|
+
"SlurmSpec",
|
|
44
|
+
]
|
xm_slurm/api.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import functools
|
|
3
|
+
import importlib.util
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
import typing
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
14
|
+
class ExperimentPatchModel:
|
|
15
|
+
title: str | None = None
|
|
16
|
+
description: str | None = None
|
|
17
|
+
note: str | None = None
|
|
18
|
+
tags: list[str] | None = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
22
|
+
class SlurmJobModel:
|
|
23
|
+
name: str
|
|
24
|
+
slurm_job_id: int
|
|
25
|
+
slurm_cluster: str
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
29
|
+
class ArtifactModel:
|
|
30
|
+
name: str
|
|
31
|
+
uri: str
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
35
|
+
class WorkUnitPatchModel:
|
|
36
|
+
wid: int
|
|
37
|
+
identity: str | None
|
|
38
|
+
args: str | None = None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class WorkUnitModel(WorkUnitPatchModel):
    """A work unit together with its Slurm jobs and artifacts.

    Extends ``WorkUnitPatchModel``; both list fields default to a fresh
    empty list per instance.
    """

    # Slurm jobs belonging to this work unit.
    jobs: list[SlurmJobModel] = dataclasses.field(default_factory=list)
    # Artifacts attached to this work unit.
    artifacts: list[ArtifactModel] = dataclasses.field(default_factory=list)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class ExperimentModel:
    """Immutable, keyword-only view of an experiment.

    All fields are required; the metadata fields other than ``title`` may be
    ``None``.
    """

    # Experiment metadata.
    title: str
    description: str | None
    note: str | None
    tags: list[str] | None

    # Work units belonging to the experiment.
    work_units: list[WorkUnitModel]
    # Artifacts attached to the experiment itself.
    artifacts: list[ArtifactModel]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class XManagerAPI:
    """API implementation used when no storage backend is available.

    ``get_experiment`` raises ``NotImplementedError``. Every other method
    ignores its arguments and only emits a debug log message;
    ``insert_experiment`` additionally returns a timestamp-derived ID.
    """

    def get_experiment(self, xid: int) -> ExperimentModel:
        """Always raise ``NotImplementedError``; there is nothing to read from."""
        raise NotImplementedError("`get_experiment` is not implemented without a storage backend.")

    def delete_experiment(self, experiment_id: int) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`delete_experiment` is not implemented without a storage backend.")

    def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
        """Ignore ``experiment`` and return the current Unix time in milliseconds as the ID."""
        logger.debug("`insert_experiment` is not implemented without a storage backend.")
        milliseconds_since_epoch = time.time() * 1000
        return int(milliseconds_since_epoch)

    def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`update_experiment` is not implemented without a storage backend.")

    def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`insert_job` is not implemented without a storage backend.")

    def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitPatchModel) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`insert_work_unit` is not implemented without a storage backend.")

    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`delete_work_unit_artifact` is not implemented without a storage backend.")

    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
    ) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`insert_work_unit_artifact` is not implemented without a storage backend.")

    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`delete_experiment_artifact` is not implemented without a storage backend.")

    def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
        """Ignore the request and log at debug level."""
        logger.debug("`insert_experiment_artifact` is not implemented without a storage backend.")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class XManagerWebAPI(XManagerAPI):
    """``XManagerAPI`` implementation that talks to a remote service.

    All calls are delegated to the optional ``xm_slurm_api_client`` package,
    which is imported lazily so that this module can be imported without it.
    """

    def __init__(self, base_url: str, token: str):
        """Create an authenticated client.

        Args:
            base_url: Base URL of the API service.
            token: Token passed to the authenticated client.

        Raises:
            ImportError: If ``xm_slurm_api_client`` is not installed.
        """
        if importlib.util.find_spec("xm_slurm_api_client") is None:
            raise ImportError("xm_slurm_api_client not found.")

        from xm_slurm_api_client import AuthenticatedClient  # type: ignore
        from xm_slurm_api_client import models as _models  # type: ignore

        self.models = _models
        # NOTE(review): `verify_ssl=False` turns off TLS certificate verification
        # while a token is being sent. Kept as-is because deployments may rely on
        # it (e.g. self-signed certificates) -- confirm and make configurable.
        self.client = AuthenticatedClient(
            base_url,
            token=token,
            raise_on_unexpected_status=True,
            verify_ssl=False,
        )

    def get_experiment(self, xid: int) -> ExperimentModel:
        """Fetch experiment ``xid`` and convert the response into local models."""
        from xm_slurm_api_client.api.experiment import (  # type: ignore
            get_experiment as _get_experiment,
        )

        experiment: Any = _get_experiment.sync(xid, client=self.client)  # type: ignore

        # NOTE(review): assumes the client models expose `.dict()` returning
        # exactly the keyword arguments of the local dataclasses -- confirm.
        work_units = [
            WorkUnitModel(
                wid=wu.wid,
                identity=wu.identity,
                args=wu.args,
                jobs=[SlurmJobModel(**job.dict()) for job in wu.jobs],
                artifacts=[ArtifactModel(**artifact.dict()) for artifact in wu.artifacts],
            )
            for wu in experiment.work_units
        ]
        experiment_artifacts = [
            ArtifactModel(**artifact.dict()) for artifact in experiment.artifacts
        ]

        return ExperimentModel(
            title=experiment.title,
            description=experiment.description,
            note=experiment.note,
            tags=experiment.tags,
            work_units=work_units,
            artifacts=experiment_artifacts,
        )

    def delete_experiment(self, experiment_id: int) -> None:
        """Delete the experiment with ID ``experiment_id``."""
        from xm_slurm_api_client.api.experiment import (  # type: ignore
            delete_experiment as _delete_experiment,
        )

        _delete_experiment.sync(experiment_id, client=self.client)

    def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
        """Create an experiment from ``experiment.title`` and return its ``xid``.

        Only ``title`` may be set on ``experiment``; this is enforced with
        assertions (which are skipped when Python runs with ``-O``).
        """
        from xm_slurm_api_client.api.experiment import (  # type: ignore
            insert_experiment as _insert_experiment,
        )

        assert experiment.title is not None, "Title must be set in the experiment model."
        assert (
            experiment.description is None and experiment.note is None and experiment.tags is None
        ), "Only title should be set in the experiment model."
        experiment_response = _insert_experiment.sync(
            client=self.client,
            body=self.models.Experiment(title=experiment.title),
        )
        return typing.cast(int, experiment_response["xid"])  # type: ignore

    def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
        """Send ``experiment_patch`` as an update for experiment ``experiment_id``."""
        from xm_slurm_api_client.api.experiment import (  # type: ignore
            update_experiment as _update_experiment,
        )

        # The patch model is built once; the original built it twice and
        # discarded the first instance.
        _update_experiment.sync(
            experiment_id,
            client=self.client,
            body=self.models.ExperimentPatch(**dataclasses.asdict(experiment_patch)),
        )

    def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
        """Record ``job`` under the given experiment and work unit."""
        from xm_slurm_api_client.api.job import insert_job as _insert_job  # type: ignore

        _insert_job.sync(
            experiment_id,
            work_unit_id,
            client=self.client,
            body=self.models.SlurmJob(**dataclasses.asdict(job)),
        )

    def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitPatchModel) -> None:
        """Record ``work_unit`` under experiment ``experiment_id``."""
        from xm_slurm_api_client.api.work_unit import (  # type: ignore
            insert_work_unit as _insert_work_unit,
        )

        _insert_work_unit.sync(
            experiment_id,
            client=self.client,
            body=self.models.WorkUnit(**dataclasses.asdict(work_unit)),
        )

    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Delete the artifact called ``name`` from the given work unit."""
        from xm_slurm_api_client.api.artifact import (  # type: ignore
            delete_work_unit_artifact as _delete_work_unit_artifact,
        )

        _delete_work_unit_artifact.sync(experiment_id, work_unit_id, name, client=self.client)

    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
    ) -> None:
        """Attach ``artifact`` to the given work unit."""
        from xm_slurm_api_client.api.artifact import (  # type: ignore
            insert_work_unit_artifact as _insert_work_unit_artifact,
        )

        _insert_work_unit_artifact.sync(
            experiment_id,
            work_unit_id,
            client=self.client,
            body=self.models.Artifact(**dataclasses.asdict(artifact)),
        )

    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Do nothing.

        NOTE(review): unlike the other methods this override performs no
        request and logs nothing, so experiment artifacts are never deleted
        remotely. Looks like an unfinished stub -- confirm whether the client
        package offers a matching endpoint.
        """

    def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
        """Attach ``artifact`` to experiment ``experiment_id``."""
        from xm_slurm_api_client.api.artifact import (  # type: ignore
            insert_experiment_artifact as _insert_experiment_artifact,
        )

        _insert_experiment_artifact.sync(
            experiment_id,
            client=self.client,
            body=self.models.Artifact(**dataclasses.asdict(artifact)),
        )
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@functools.cache
def client() -> XManagerAPI:
    """Return the process-wide API client (the result is cached).

    An ``XManagerWebAPI`` is returned when the ``xm_slurm_api_client`` package
    is installed and both ``XM_SLURM_API_BASE_URL`` and ``XM_SLURM_API_TOKEN``
    are set in the environment. Otherwise the no-op ``XManagerAPI`` is
    returned.
    """
    if importlib.util.find_spec("xm_slurm_api_client") is None:
        # Only report the package as missing when it really is missing; the
        # original also logged this when just the environment was unset.
        logger.debug("xm_slurm_api_client not found... skipping logging to the API.")
        return XManagerAPI()

    base_url = os.environ.get("XM_SLURM_API_BASE_URL")
    token = os.environ.get("XM_SLURM_API_TOKEN")
    if base_url is not None and token is not None:
        return XManagerWebAPI(base_url=base_url, token=token)

    # `Logger.warn` is a deprecated alias of `Logger.warning`.
    logger.warning(
        "XM_SLURM_API_BASE_URL and XM_SLURM_API_TOKEN not set. "
        "Disabling XManager API client."
    )
    return XManagerAPI()
|
xm_slurm/batching.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import collections
|
|
3
|
+
import dataclasses
|
|
4
|
+
import inspect
|
|
5
|
+
import time
|
|
6
|
+
import types
|
|
7
|
+
from typing import Any, Callable, Coroutine, Generic, ParamSpec, Sequence, TypeVar
|
|
8
|
+
|
|
9
|
+
T = TypeVar("T", contravariant=True)
|
|
10
|
+
R = TypeVar("R", covariant=True)
|
|
11
|
+
P = ParamSpec("P")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
15
|
+
class Request:
|
|
16
|
+
args: inspect.BoundArguments
|
|
17
|
+
future: asyncio.Future
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def stack_bound_arguments(
    signature: inspect.Signature, bound_arguments: Sequence[inspect.BoundArguments]
) -> inspect.BoundArguments:
    """Merge several bound calls into one whose arguments are per-parameter lists.

    For every parameter name seen in ``bound_arguments`` the result maps that
    name to the list of values it was bound to, in input order. Parameter
    names keep the order in which they are first encountered.
    """
    columns = collections.OrderedDict()
    for call in bound_arguments:
        for param_name, param_value in call.arguments.items():
            columns.setdefault(param_name, []).append(param_value)
    return inspect.BoundArguments(signature, columns)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class batch(Generic[R]):
    """Wrap an async batch function so that individual calls are coalesced.

    Each call binds its arguments against the signature of ``fn``, enqueues
    them together with a future and awaits that future. A background task
    (started lazily on the first call) gathers queued requests into a batch,
    stacks their arguments into per-parameter lists, awaits ``fn`` once and
    hands the i-th result to the i-th caller.

    A batch is closed once ``max_batch_size`` requests were collected or
    ``batch_timeout`` seconds elapsed since its first request.
    """

    __slots__ = (
        "fn",
        "signature",
        "max_batch_size",
        "batch_timeout",
        "loop",
        "process_batch_task",
        "queue",
    )
    __name__: str = "batch"
    __qualname__: str = "batch"

    def __init__(
        self,
        fn: Callable[..., Coroutine[None, None, Sequence[R]]],
        /,
        *,
        max_batch_size: int,
        batch_timeout: float,
    ) -> None:
        """Store the wrapped function and the batching limits.

        Args:
            fn: Async function receiving one list per parameter and returning
                one result per request, in request order.
            max_batch_size: Maximum number of requests in a single batch.
            batch_timeout: Seconds to keep collecting after the first request.
        """
        self.fn = fn
        self.signature = inspect.signature(fn)

        self.max_batch_size = max_batch_size
        self.batch_timeout = batch_timeout

        self.loop: asyncio.AbstractEventLoop | None = None
        self.process_batch_task: asyncio.Task | None = None

        self.queue = asyncio.Queue()

    async def _process_batch(self):
        """Run forever: collect a batch, call ``fn`` and resolve the callers' futures."""
        assert self.loop is not None
        while not self.loop.is_closed():
            batch = await self._wait_for_batch()
            assert len(batch) > 0

            bound_args = stack_bound_arguments(self.signature, [request.args for request in batch])
            futures = [request.future for request in batch]

            try:
                # Keyword arguments must be unpacked with `**`; `*kwargs` would
                # pass the parameter *names* as extra positional arguments.
                # The call itself is inside the `try` so that a failure while
                # invoking `fn` reaches the callers instead of killing this task.
                results = await self.fn(*bound_args.args, **bound_args.kwargs)
                if len(results) != len(futures):
                    raise RuntimeError(
                        f"Batched function returned {len(results)} results "
                        f"for {len(futures)} requests."
                    )
            except Exception as e:
                for future in futures:
                    # A caller may have been cancelled in the meantime.
                    if not future.done():
                        future.set_exception(e)
            else:
                for result, future in zip(results, futures):
                    if not future.done():
                        future.set_result(result)

    async def _wait_for_batch(self) -> list[Request]:
        """Block for the first request, then collect more until size or time runs out."""
        batch = [await self.queue.get()]

        # Monotonic clock: unaffected by wall-clock adjustments.
        deadline = time.monotonic() + self.batch_timeout
        # Check the size *before* waiting so a batch never exceeds the limit.
        while len(batch) < self.max_batch_size:
            remaining_batch_time_s = max(deadline - time.monotonic(), 0)

            try:
                request = await asyncio.wait_for(self.queue.get(), timeout=remaining_batch_time_s)
            except asyncio.TimeoutError:
                break
            batch.append(request)

            if time.monotonic() >= deadline:
                break

        return batch

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        """Descriptor hook so the wrapper can be used on methods.

        Access through the class (``obj is None``) returns the wrapper itself;
        wrapped ``staticmethod`` objects return the plain call; otherwise the
        wrapper is bound to ``obj`` like a regular method.
        """
        del objtype
        if obj is None:
            # `types.MethodType(self, None)` would raise `TypeError`.
            return self
        if isinstance(self.fn, staticmethod):
            return self.__call__
        return types.MethodType(self, obj)

    @property
    def __func__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
        """The wrapped function."""
        return self.fn

    @property
    def __wrapped__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
        """The wrapped function."""
        return self.fn

    @property
    def __isabstractmethod__(self) -> bool:
        """Return whether the wrapped function is abstract."""
        return getattr(self.fn, "__isabstractmethod__", False)

    @property
    def _is_coroutine(self) -> object:
        # TODO(jfarebro): py312 adds inspect.markcoroutinefunction
        # until then this is just a hack
        # NOTE(review): relies on a private asyncio marker -- confirm it exists
        # on every supported Python version.
        return asyncio.coroutines._is_coroutine  # type: ignore

    async def __call__(self, *args, **kwargs) -> R:
        """Enqueue one request and return its individual result.

        Raises:
            TypeError: If the arguments do not match the signature of ``fn``.
            Exception: Whatever the batched call raised.
        """
        # Bind first so that invalid arguments fail before anything is queued.
        bound_args = self.signature.bind(*args, **kwargs)

        running_loop = asyncio.get_running_loop()
        if self.loop is None and self.process_batch_task is None:
            self.loop = running_loop
            self.process_batch_task = self.loop.create_task(self._process_batch())

        future = running_loop.create_future()
        self.queue.put_nowait(Request(args=bound_args, future=future))
        return await future
|
xm_slurm/config.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import enum
|
|
3
|
+
import functools
|
|
4
|
+
import getpass
|
|
5
|
+
import os
|
|
6
|
+
import pathlib
|
|
7
|
+
from typing import Literal, Mapping, NamedTuple
|
|
8
|
+
|
|
9
|
+
import asyncssh
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ContainerRuntime(enum.Enum):
    """The container engine to use."""

    SINGULARITY = enum.auto()
    APPTAINER = enum.auto()
    DOCKER = enum.auto()
    PODMAN = enum.auto()

    @classmethod
    def from_string(
        cls, runtime: Literal["singularity", "apptainer", "docker", "podman"]
    ) -> "ContainerRuntime":
        """Return the member whose lowercase name equals ``runtime``.

        Raises:
            KeyError: If ``runtime`` is not one of the four known names.
        """
        by_name = {
            "singularity": cls.SINGULARITY,
            "apptainer": cls.APPTAINER,
            "docker": cls.DOCKER,
            "podman": cls.PODMAN,
        }
        member = by_name[runtime]
        return member

    def __str__(self):
        """Return the lowercase member name, e.g. ``"docker"``."""
        return self.name.lower()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class PublicKey(NamedTuple):
    """A public key given as an algorithm name and the key text."""

    # Name of the key algorithm.
    algorithm: str
    # The key itself, as a string.
    key: str
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmClusterConfig:
    """Immutable description of a Slurm cluster and how to reach it over SSH.

    Mount sources and destinations are validated on construction. The SSH
    helpers (`ssh_known_hosts`, `ssh_config`, `ssh_connection_options`) are
    computed lazily and cached per instance.
    """

    name: str

    # SSH endpoint. `host_public_key`, when given, pins the host key.
    host: str
    host_public_key: PublicKey | None = None
    user: str | None = None
    port: int | None = None

    # Job submission directory
    cwd: str | None = None

    # Additional scripting
    prolog: str | None = None
    epilog: str | None = None

    # Job scheduling
    account: str | None = None
    partition: str | None = None
    qos: str | None = None

    # If true, a reverse proxy is initiated via the submission host.
    proxy: Literal["submission-host"] | str | None = None

    runtime: ContainerRuntime

    # Environment variables
    environment: Mapping[str, str] = dataclasses.field(default_factory=dict)

    # Mounts
    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
        default_factory=dict
    )

    # Resource mapping
    resources: Mapping[str, "xm_slurm.ResourceType"] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821

    def __post_init__(self) -> None:
        """Validate that every mount source and destination is an absolute path.

        Raises:
            TypeError: If a source or destination is neither ``str`` nor path-like.
            ValueError: If a source or destination is not an absolute path.
        """
        for src, dst in self.mounts.items():
            if not isinstance(src, (str, os.PathLike)):
                raise TypeError(
                    f"Mount source must be a string or path-like object, not {type(src)}"
                )
            if not isinstance(dst, (str, os.PathLike)):
                raise TypeError(
                    f"Mount destination must be a string or path-like object, not {type(dst)}"
                )

            if not pathlib.Path(src).is_absolute():
                raise ValueError(f"Mount source must be an absolute path: {src}")
            if not pathlib.Path(dst).is_absolute():
                raise ValueError(f"Mount destination must be an absolute path: {dst}")

    @functools.cached_property
    def ssh_known_hosts(self) -> asyncssh.SSHKnownHosts | None:
        """Known-hosts object pinning `host_public_key`, or ``None`` if no key is set."""
        if self.host_public_key is None:
            return None

        # known_hosts entries use the bare host name for the default SSH port
        # and `[host]:port` only for non-default ports. Formatting an unset
        # port would yield the unmatchable pattern `[host]:None`.
        # NOTE(review): a port that comes only from the user's SSH config is
        # not considered here -- confirm whether that case matters.
        if self.port is None or self.port == 22:
            host_pattern = self.host
        else:
            host_pattern = f"[{self.host}]:{self.port}"

        return asyncssh.import_known_hosts(
            f"{host_pattern} {self.host_public_key.algorithm} {self.host_public_key.key}"
        )

    @functools.cached_property
    def ssh_config(self) -> asyncssh.config.SSHConfig:
        """SSH client configuration resolved for this cluster.

        Reads ``~/.ssh/config`` and the file named by ``XM_SLURM_SSH_CONFIG``
        (each only if it exists), in that order.

        Raises:
            RuntimeError: If no hostname or no user could be resolved.
        """
        ssh_config_paths = []
        if (ssh_config := pathlib.Path.home() / ".ssh" / "config").exists():
            ssh_config_paths.append(ssh_config)
        if (xm_ssh_config := os.environ.get("XM_SLURM_SSH_CONFIG")) and (
            xm_ssh_config := pathlib.Path(xm_ssh_config).expanduser()
        ).exists():
            ssh_config_paths.append(xm_ssh_config)

        # NOTE(review): positional arguments follow asyncssh's config loader
        # signature, which is not visible here -- confirm against the pinned
        # asyncssh version before touching them.
        config = asyncssh.config.SSHClientConfig.load(
            None,
            ssh_config_paths,
            True,
            getpass.getuser(),
            self.user or (),
            self.host or (),
            self.port or (),
        )

        if config.get("Hostname") is None:
            raise RuntimeError(
                f"Failed to parse hostname from host `{self.host}` using SSH configs: {', '.join(map(str, ssh_config_paths))}"
            )
        if config.get("User") is None:
            raise RuntimeError(
                f"Failed to parse user from SSH configs: {', '.join(map(str, ssh_config_paths))}"
            )

        return config

    @functools.cached_property
    def ssh_connection_options(self) -> asyncssh.SSHClientConnectionOptions:
        """Connection options prepared from `ssh_config` and `ssh_known_hosts`."""
        options = asyncssh.SSHClientConnectionOptions(config=None)
        options.prepare(last_config=self.ssh_config, known_hosts=self.ssh_known_hosts)
        return options

    def __hash__(self):
        """Hash over a fixed subset of the fields, listed explicitly below.

        ``name``, ``host_public_key``, ``mounts`` and ``resources`` are not
        part of the hash.
        """
        return hash((
            self.host,
            self.user,
            self.port,
            self.cwd,
            self.prolog,
            self.epilog,
            self.account,
            self.partition,
            self.qos,
            self.proxy,
            self.runtime,
            frozenset(self.environment.items()),
        ))
|
xm_slurm/console.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from xm_slurm import config, resources
|
|
4
|
+
from xm_slurm.contrib.clusters import drac
|
|
5
|
+
|
|
6
|
+
# ComputeCanada alias
|
|
7
|
+
cc = drac
|
|
8
|
+
|
|
9
|
+
__all__ = ["drac", "mila", "cc"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def mila(
    *,
    user: str | None = None,
    partition: str | None = None,
    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Mila Cluster (https://docs.mila.quebec/).

    Args:
        user: User name passed to the cluster config.
        partition: Slurm partition passed to the cluster config.
        mounts: Source -> destination mounts. When ``None``, the user's
            network scratch and archive directories are mounted at
            ``/scratch`` and ``/archive``.

    Returns:
        The cluster configuration for the Mila login node.
    """
    default_mounts = {
        "/network/scratch/${USER:0:1}/$USER": "/scratch",
        "/network/archive/${USER:0:1}/$USER": "/archive",
    }
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
    )
    environment = {
        "SINGULARITY_CACHEDIR": "$SCRATCH/.apptainer",
        "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
        "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
        "SCRATCH": "/scratch",
        "ARCHIVE": "/archive",
    }
    # Cluster-side resource names mapped to resource types.
    resource_types = {
        "rtx8000": resources.ResourceType.RTX8000,
        "v100": resources.ResourceType.V100,
        "a100": resources.ResourceType.A100,
        "a100l": resources.ResourceType.A100_80GIB,
        "a6000": resources.ResourceType.A6000,
    }

    return config.SlurmClusterConfig(
        name="mila",
        user=user,
        host="login.server.mila.quebec",
        host_public_key=host_key,
        port=2222,
        runtime=config.ContainerRuntime.SINGULARITY,
        partition=partition,
        prolog="module load singularity",
        environment=environment,
        mounts=default_mounts if mounts is None else mounts,
        resources=resource_types,
    )
|