xmanager-slurm 0.4.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +47 -0
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +189 -0
- xm_slurm/console.py +3 -0
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +67 -0
- xm_slurm/contrib/clusters/drac.py +242 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +215 -0
- xm_slurm/execution.py +995 -0
- xm_slurm/executors.py +210 -0
- xm_slurm/experiment.py +1016 -0
- xm_slurm/experimental/parameter_controller.py +206 -0
- xm_slurm/filesystems.py +129 -0
- xm_slurm/job_blocks.py +21 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +309 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker.py +348 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +56 -0
- xm_slurm/packaging/utils.py +22 -0
- xm_slurm/resources.py +350 -0
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +90 -0
- xm_slurm/status.py +197 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
- xm_slurm/templates/docker/mamba.Dockerfile +29 -0
- xm_slurm/templates/docker/python.Dockerfile +32 -0
- xm_slurm/templates/docker/uv.Dockerfile +38 -0
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
- xm_slurm/templates/slurm/job.bash.j2 +90 -0
- xm_slurm/templates/slurm/library/retry.bash +62 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +196 -0
- xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
- xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
- xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
- xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
|
|
3
|
+
import backoff
|
|
4
|
+
import httpx
|
|
5
|
+
|
|
6
|
+
from xm_slurm.api import models
|
|
7
|
+
from xm_slurm.api.abc import XManagerAPI
|
|
8
|
+
|
|
9
|
+
# Define which exceptions should trigger a retry.
# These are transient transport-level httpx failures; HTTP status errors
# surfaced by `response.raise_for_status()` are not in this tuple and are
# therefore not retried.
# NOTE(review): httpx.ConnectError subclasses httpx.NetworkError, so listing
# both may be redundant — confirm against the pinned httpx version.
RETRY_EXCEPTIONS = (
    httpx.ConnectError,
    httpx.ConnectTimeout,
    httpx.ReadTimeout,
    httpx.WriteTimeout,
    httpx.NetworkError,
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Common backoff decorator for all API calls.
def with_backoff(f):
    """Wrap *f* so transient network errors are retried with exponential backoff.

    Retries only the exceptions in ``RETRY_EXCEPTIONS``, giving up after
    3 attempts or 30 seconds total, whichever comes first.
    """
    retry = backoff.on_exception(
        backoff.expo,
        RETRY_EXCEPTIONS,
        max_tries=3,  # Maximum number of attempts
        max_time=30,  # Maximum total time to try in seconds
        jitter=backoff.full_jitter,  # Add jitter to prevent thundering herd
    )
    return retry(f)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class XManagerWebAPI(XManagerAPI):
    """XManager API client backed by a remote web service.

    Every request carries a bearer token and is retried on transient
    network failures via ``with_backoff``. Instances own an ``httpx.Client``
    connection pool; call :meth:`close` (or use the instance as a context
    manager) when done to release the pool.
    """

    def __init__(self, base_url: str, token: str):
        self.base_url = base_url.rstrip("/")
        # NOTE(review): verify=False disables TLS certificate verification,
        # which permits man-in-the-middle attacks on the API token — confirm
        # this is intentional (e.g. self-signed internal deployments) or make
        # it configurable.
        self.client = httpx.Client(headers={"Authorization": f"Bearer {token}"}, verify=False)

    def close(self) -> None:
        """Close the underlying HTTP connection pool.

        Fixes a resource leak: previously the ``httpx.Client`` was never
        closed for the lifetime of the process.
        """
        self.client.close()

    def __enter__(self) -> "XManagerWebAPI":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def _make_url(self, path: str) -> str:
        """Join *path* onto the (already rstripped) base URL."""
        return f"{self.base_url}{path}"

    @with_backoff
    def get_experiment(self, xid: int) -> models.Experiment:
        """Fetch an experiment with its work units, jobs, and artifacts."""
        response = self.client.get(self._make_url(f"/experiment/{xid}"))
        response.raise_for_status()
        data = response.json()
        # Construct work units with nested jobs and artifacts. Nested keys are
        # pop()'d so the remaining payload can be splatted into the models.
        work_units = []
        for wu_data in data.pop("work_units", []):
            # Build jobs for this work unit.
            jobs = [
                models.SlurmJob(
                    name=job["name"],
                    slurm_job_id=job["slurm_job_id"],
                    slurm_ssh_config=job["slurm_ssh_config"],
                )
                for job in wu_data.pop("jobs", [])
            ]

            # Build artifacts for this work unit.
            artifacts = [
                models.Artifact(name=artifact["name"], uri=artifact["uri"])
                for artifact in wu_data.pop("artifacts", [])
            ]

            # Create work unit with its jobs and artifacts.
            wu_data["jobs"] = jobs
            wu_data["artifacts"] = artifacts
            work_units.append(models.WorkUnit(**wu_data))

        # Build experiment-level artifacts.
        artifacts = [
            models.Artifact(name=artifact["name"], uri=artifact["uri"])
            for artifact in data.pop("artifacts", [])
        ]

        return models.Experiment(**data, work_units=work_units, artifacts=artifacts)

    @with_backoff
    def delete_experiment(self, experiment_id: int) -> None:
        """Delete the experiment with the given id."""
        response = self.client.delete(self._make_url(f"/experiment/{experiment_id}"))
        response.raise_for_status()

    @with_backoff
    def insert_experiment(self, experiment: models.ExperimentPatch) -> int:
        """Create a new experiment and return its server-assigned xid.

        Only ``title`` may be set on the patch; other fields are populated
        through :meth:`update_experiment`.
        """
        # NOTE(review): asserts are stripped under ``python -O``; consider
        # raising ValueError if this validation must always run.
        assert experiment.title is not None, "Title must be set in the experiment model."
        assert (
            experiment.description is None and experiment.note is None and experiment.tags is None
        ), "Only title should be set in the experiment model."

        response = self.client.put(
            self._make_url("/experiment"), json=dataclasses.asdict(experiment)
        )
        response.raise_for_status()
        return int(response.json()["xid"])

    @with_backoff
    def update_experiment(
        self, experiment_id: int, experiment_patch: models.ExperimentPatch
    ) -> None:
        """Apply a partial update to an existing experiment."""
        response = self.client.patch(
            self._make_url(f"/experiment/{experiment_id}"),
            json=dataclasses.asdict(experiment_patch),
        )
        response.raise_for_status()

    @with_backoff
    def insert_work_unit(self, experiment_id: int, work_unit: models.WorkUnitPatch) -> None:
        """Create a work unit under the given experiment."""
        response = self.client.put(
            self._make_url(f"/experiment/{experiment_id}/wu"),
            json=dataclasses.asdict(work_unit),
        )
        response.raise_for_status()

    @with_backoff
    def insert_job(self, experiment_id: int, work_unit_id: int, job: models.SlurmJob) -> None:
        """Attach a Slurm job to a work unit."""
        response = self.client.put(
            self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/job"),
            json=dataclasses.asdict(job),
        )
        response.raise_for_status()

    @with_backoff
    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: models.Artifact
    ) -> None:
        """Attach an artifact to a work unit."""
        response = self.client.put(
            self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/artifact"),
            json=dataclasses.asdict(artifact),
        )
        response.raise_for_status()

    @with_backoff
    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Delete a work-unit artifact by name."""
        response = self.client.delete(
            self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/artifact/{name}")
        )
        response.raise_for_status()

    @with_backoff
    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Delete an experiment-level artifact by name."""
        response = self.client.delete(
            self._make_url(f"/experiment/{experiment_id}/artifact/{name}")
        )
        response.raise_for_status()

    @with_backoff
    def insert_experiment_artifact(self, experiment_id: int, artifact: models.Artifact) -> None:
        """Attach an artifact to an experiment."""
        response = self.client.put(
            self._make_url(f"/experiment/{experiment_id}/artifact"),
            json=dataclasses.asdict(artifact),
        )
        response.raise_for_status()

    @with_backoff
    def insert_experiment_config_artifact(
        self, experiment_id: int, artifact: models.ConfigArtifact
    ) -> None:
        """Attach a config artifact to an experiment."""
        response = self.client.put(
            self._make_url(f"/experiment/{experiment_id}/config"), json=dataclasses.asdict(artifact)
        )
        response.raise_for_status()

    @with_backoff
    def delete_experiment_config_artifact(self, experiment_id: int, name: str) -> None:
        """Delete an experiment config artifact by name."""
        response = self.client.delete(self._make_url(f"/experiment/{experiment_id}/config/{name}"))
        response.raise_for_status()

    @with_backoff
    def update_work_unit(
        self, experiment_id: int, work_unit_id: int, patch: models.ExperimentUnitPatch
    ) -> None:
        """Apply a partial update to an existing work unit."""
        response = self.client.patch(
            self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}"),
            json=dataclasses.asdict(patch),
        )
        response.raise_for_status()
|
xm_slurm/batching.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import collections
|
|
3
|
+
import dataclasses
|
|
4
|
+
import inspect
|
|
5
|
+
import time
|
|
6
|
+
import types
|
|
7
|
+
import typing as tp
|
|
8
|
+
|
|
9
|
+
# Per-call input type (declared contravariant) and per-call result type
# (declared covariant) for the batching decorator below.
# NOTE(review): T and P do not appear to be referenced elsewhere in this
# module — confirm before removing.
T = tp.TypeVar("T", contravariant=True)
R = tp.TypeVar("R", covariant=True)
P = tp.ParamSpec("P")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class Request:
    """A single queued invocation awaiting batched execution."""

    # Bound arguments of this individual (pre-stacking) call.
    args: inspect.BoundArguments
    # Future resolved with this call's slice of the batched result
    # (or with the batch's exception).
    future: asyncio.Future
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def stack_bound_arguments(
    signature: inspect.Signature, bound_arguments: tp.Sequence[inspect.BoundArguments]
) -> inspect.BoundArguments:
    """Stacks bound arguments into a single bound arguments object.

    Each parameter of the result maps to the list of values that parameter
    took across ``bound_arguments`` (parameters absent from a given call are
    simply skipped for that call).
    """
    merged = collections.OrderedDict[str, tp.Any]()
    for bound in bound_arguments:
        for param_name, param_value in bound.arguments.items():
            merged.setdefault(param_name, []).append(param_value)
    return inspect.BoundArguments(signature, merged)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class batch(tp.Generic[R]):
    """Decorator that transparently batches concurrent async invocations.

    Wraps an async function whose parameters accept *sequences* of values and
    which returns a sequence of per-item results. Individual awaited calls are
    queued; a background task gathers up to ``max_batch_size`` requests
    (waiting at most ``batch_timeout`` seconds from the first request), stacks
    their arguments with ``stack_bound_arguments``, performs one batched call,
    and fans the results back out to the per-call futures in order.
    """

    __slots__ = (
        "fn",
        "signature",
        "max_batch_size",
        "batch_timeout",
        "loop",
        "process_batch_task",
        "queue",
    )
    __name__: str = "batch"
    __qualname__: str = "batch"

    def __init__(
        self,
        fn: tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]],
        /,
        *,
        max_batch_size: int,
        batch_timeout: float,
    ) -> None:
        self.fn = fn
        self.signature = inspect.signature(fn)

        self.max_batch_size = max_batch_size
        self.batch_timeout = batch_timeout

        # Lazily initialized on first call so the decorator can be constructed
        # outside a running event loop.
        self.loop: asyncio.AbstractEventLoop | None = None
        self.process_batch_task: asyncio.Task | None = None

        self.queue = asyncio.Queue[Request]()

    async def _process_batch(self):
        """Background task: repeatedly drain one batch and dispatch it."""
        assert self.loop is not None
        while not self.loop.is_closed():
            batch = await self._wait_for_batch()
            assert len(batch) > 0

            bound_args = stack_bound_arguments(self.signature, [request.args for request in batch])
            futures = [request.future for request in batch]

            # BUG FIX: keyword arguments must be splatted with `**`. The
            # previous `*bound_args.kwargs` unpacked only the dict *keys* as
            # extra positional arguments.
            results_future = self.fn(*bound_args.args, **bound_args.kwargs)

            try:
                results = await results_future

                for result, future in zip(results, futures):
                    # A caller may have been cancelled while waiting; calling
                    # set_result on a done future raises InvalidStateError and
                    # would kill this background task for all future callers.
                    if not future.done():
                        future.set_result(result)
            except Exception as e:
                # Fan the batch's failure out to every pending caller.
                for future in futures:
                    if not future.done():
                        future.set_exception(e)

    async def _wait_for_batch(self) -> list[Request]:
        """Collect requests until the timeout elapses or the batch is full.

        Blocks for the first request, then keeps accepting requests for the
        remainder of the ``batch_timeout`` window or until ``max_batch_size``
        requests have been gathered.
        """
        batch = [await self.queue.get()]

        batch_start_time = time.time()
        while True:
            remaining_batch_time_s = max(self.batch_timeout - (time.time() - batch_start_time), 0)

            try:
                request = await asyncio.wait_for(self.queue.get(), timeout=remaining_batch_time_s)
                batch.append(request)
            except asyncio.TimeoutError:
                break

            if (
                time.time() - batch_start_time >= self.batch_timeout
                or len(batch) >= self.max_batch_size
            ):
                break

        return batch

    def __get__(self, obj: tp.Any, objtype: tp.Type[tp.Any]) -> tp.Any:
        """Descriptor protocol: support decorating methods and staticmethods."""
        del objtype
        if isinstance(self.fn, staticmethod):
            return self.__call__
        else:
            return types.MethodType(self, obj)

    @property
    def __func__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
        return self.fn

    @property
    def __wrapped__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
        return self.fn

    @property
    def __isabstractmethod__(self) -> bool:
        """Return whether the wrapped function is abstract."""
        return getattr(self.fn, "__isabstractmethod__", False)

    @property
    def _is_coroutine(self) -> bool:
        # TODO(jfarebro): py312 adds inspect.markcoroutinefunction;
        # until then this relies on asyncio's private sentinel so that
        # `asyncio.iscoroutinefunction(self)` returns True.
        return asyncio.coroutines._is_coroutine  # type: ignore

    async def __call__(self, *args, **kwargs) -> R:
        """Enqueue one invocation and await its slice of the batched result."""
        if self.loop is None and self.process_batch_task is None:
            self.loop = asyncio.get_event_loop()
            self.process_batch_task = self.loop.create_task(self._process_batch())

        future = asyncio.Future[R]()
        bound_args = self.signature.bind(*args, **kwargs)
        self.queue.put_nowait(Request(args=bound_args, future=future))
        return await future
|
xm_slurm/config.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import enum
|
|
3
|
+
import functools
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import pathlib
|
|
7
|
+
import typing as tp
|
|
8
|
+
|
|
9
|
+
import asyncssh
|
|
10
|
+
from xmanager import xm
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ContainerRuntime(enum.Enum):
    """The container engine to use."""

    SINGULARITY = enum.auto()
    APPTAINER = enum.auto()
    DOCKER = enum.auto()
    PODMAN = enum.auto()

    @classmethod
    def from_string(
        cls, runtime: tp.Literal["singularity", "apptainer", "docker", "podman"]
    ) -> "ContainerRuntime":
        """Parse a lowercase runtime name into its enum member."""
        by_name = {
            "singularity": cls.SINGULARITY,
            "apptainer": cls.APPTAINER,
            "docker": cls.DOCKER,
            "podman": cls.PODMAN,
        }
        return by_name[runtime]

    def __str__(self):
        """Inverse of `from_string`: the lowercase runtime name."""
        names = {
            ContainerRuntime.SINGULARITY: "singularity",
            ContainerRuntime.APPTAINER: "apptainer",
            ContainerRuntime.DOCKER: "docker",
            ContainerRuntime.PODMAN: "podman",
        }
        try:
            return names[self]
        except KeyError:
            raise NotImplementedError from None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Endpoint(tp.NamedTuple):
|
|
46
|
+
hostname: str
|
|
47
|
+
port: int | None
|
|
48
|
+
|
|
49
|
+
def __str__(self) -> str:
|
|
50
|
+
if self.port is None or self.port == asyncssh.DEFAULT_PORT:
|
|
51
|
+
return self.hostname
|
|
52
|
+
return f"[{self.hostname}]:{self.port}"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class PublicKey(tp.NamedTuple):
    """An SSH host public key, as it appears in a known_hosts entry."""

    # Key algorithm identifier, e.g. "ssh-ed25519".
    algorithm: str
    # Base64-encoded key material.
    key: str
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclasses.dataclass
class SSHConfig:
    """SSH connection description for a cluster's login endpoint(s).

    Attributes:
        endpoints: Non-empty tuple of interchangeable login endpoints.
        public_key: Optional pinned host key shared by all endpoints.
        user: Optional login user name.
    """

    endpoints: tuple[Endpoint, ...]
    public_key: PublicKey | None = None
    user: str | None = None

    def __post_init__(self) -> None:
        """Validate field types eagerly so bad configs fail at construction."""
        if not isinstance(self.endpoints, tuple):
            raise TypeError(f"endpoints must be a tuple, not {type(self.endpoints)}")
        if len(self.endpoints) == 0:
            raise ValueError("endpoints must be a non-empty tuple")
        if not all(isinstance(endpoint, Endpoint) for endpoint in self.endpoints):
            # Fixed message: elements must be Endpoints (previously claimed
            # "tuple of strings").
            raise TypeError(f"endpoints must be a tuple of Endpoints, not {type(self.endpoints)}")

        if not isinstance(self.user, str | None):
            raise TypeError(f"user must be a string or None, not {type(self.user)}")
        if not isinstance(self.public_key, PublicKey | None):
            raise TypeError(
                f"public_key must be a SSHPublicKey or None, not {type(self.public_key)}"
            )

    @functools.cached_property
    def known_hosts(self) -> asyncssh.SSHKnownHosts | None:
        """Known-hosts database pinning `public_key` for every endpoint.

        Returns None when no host key is pinned.
        """
        if self.public_key is None:
            return None

        known_hosts = []
        for endpoint in self.endpoints:
            known_hosts.append(f"{endpoint!s} {self.public_key.algorithm} {self.public_key.key}")

        return asyncssh.SSHKnownHosts("\n".join(known_hosts))

    def serialize(self):
        """Serialize to a JSON string; inverse of `deserialize`."""
        return json.dumps({
            "endpoints": tuple(tuple(endpoint) for endpoint in self.endpoints),
            # BUG FIX: public_key defaults to None and deserialize() already
            # accepts None here; `tuple(None)` raised TypeError.
            "public_key": tuple(self.public_key) if self.public_key is not None else None,
            "user": self.user,
        })

    @classmethod
    def deserialize(cls, data):
        """Rebuild an SSHConfig from the output of `serialize`."""
        data = json.loads(data)
        return cls(
            endpoints=tuple(Endpoint(*endpoint) for endpoint in data["endpoints"]),
            public_key=PublicKey(*data["public_key"]) if data["public_key"] else None,
            user=data["user"],
        )

    def __hash__(self):
        # Value-based hash over all fields (endpoints flattened to plain
        # tuples); defined explicitly because the dataclass is mutable.
        return hash((
            type(self),
            *(tuple(endpoint) for endpoint in self.endpoints),
            self.public_key,
            self.user,
        ))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmClusterConfig:
    """Immutable description of how to reach and use a Slurm cluster.

    ``kw_only=True`` allows non-defaulted fields (e.g. ``runtime``) to appear
    after defaulted ones.
    """

    # Human-readable cluster name, e.g. "mila".
    name: str

    # How to reach the cluster's login node(s).
    ssh: SSHConfig

    # Job submission directory
    cwd: str | None = None

    # Additional scripting (shell fragments run around the job).
    prolog: str | None = None
    epilog: str | None = None

    # Job scheduling defaults passed through to Slurm.
    account: str | None = None
    partition: str | None = None
    qos: str | None = None

    # If true, a reverse proxy is initiated via the submission host.
    # NOTE(review): despite the comment, the type is a string
    # ("submission-host" or presumably an arbitrary proxy host) or None —
    # confirm the accepted values against execution.py.
    proxy: tp.Literal["submission-host"] | str | None = None

    runtime: ContainerRuntime

    # Environment variables (host-side vs. inside the container).
    host_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
    container_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

    # Mounts: absolute host path -> absolute container path (validated below).
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
        default_factory=dict
    )

    # Resource mapping
    resources: tp.Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821

    features: tp.Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821

    # Function to validate the Slurm executor config
    validate: tp.Callable[[xm.Job], None] | None = None

    def __post_init__(self) -> None:
        """Validate the ssh field's type and every mount path."""
        if not isinstance(self.ssh, SSHConfig):
            # NOTE(review): message says "SlurmSSHConfig" but the class is
            # named SSHConfig — consider correcting the message text.
            raise TypeError(f"ssh must be a SlurmSSHConfig, not {type(self.ssh)}")
        for src, dst in self.mounts.items():
            if not isinstance(src, (str, os.PathLike)):
                raise TypeError(
                    f"Mount source must be a string or path-like object, not {type(src)}"
                )
            if not isinstance(dst, (str, os.PathLike)):
                raise TypeError(
                    f"Mount destination must be a string or path-like object, not {type(dst)}"
                )

            if not pathlib.Path(src).is_absolute():
                raise ValueError(f"Mount source must be an absolute path: {src}")
            if not pathlib.Path(dst).is_absolute():
                raise ValueError(f"Mount destination must be an absolute path: {dst}")

    def __hash__(self):
        # NOTE(review): the hash omits name, mounts, resources, features, and
        # validate (mappings/callables are not hashable as-is) — confirm that
        # any equality/caching semantics relying on this are intentional.
        return hash((
            type(self),
            self.ssh,
            self.cwd,
            self.prolog,
            self.epilog,
            self.account,
            self.partition,
            self.qos,
            self.proxy,
            self.runtime,
            frozenset(self.host_environment.items()),
            frozenset(self.container_environment.items()),
        ))
|
xm_slurm/console.py
ADDED
xm_slurm/constants.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import re

# Slurm job identifier: a numeric base job id, optionally followed by either
# a heterogeneous-job component ("+<componentid>") or an array task
# ("_<arraytaskid>"), but not both.
SLURM_JOB_ID_REGEX = re.compile(
    r"^(?P<jobid>\d+)(?:(?:\+(?P<componentid>\d+))|(?:_(?P<arraytaskid>\d+)))?$"
)

# Loose split of a container image URI into scheme ("proto://"), registry
# domain, repository path, tag (":tag"), and digest ("@digest") groups; each
# component other than the domain is optional.
IMAGE_URI_REGEX = re.compile(
    r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
)

# DNS-style domain name: labels of 1-63 alphanumeric/hyphen characters that
# neither start nor end with a hyphen, ending in an alphabetic TLD.
# NOTE(review): the global (?!.*--) lookahead rejects any name containing
# "--", which also excludes legitimate punycode names ("xn--...") — confirm
# this restriction is intended.
DOMAIN_NAME_REGEX = re.compile(
    r"^(?!-)(?!.*--)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*(\.[A-Za-z]{2,})$"
)
# Dotted-quad IPv4 address with each octet constrained to 0-255.
IPV4_REGEX = re.compile(
    r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
# IPv6 address (full and "::"-compressed forms).
IPV6_REGEX = re.compile(
    r"^(([0-9a-fA-F]{1,4}:){7}([0-9a-fA-F]{1,4}|:)|(([0-9a-fA-F]{1,4}:){1,7}|:):(([0-9a-fA-F]{1,4}:){1,6}|:)|(([0-9a-fA-F]{1,4}:){1,6}|:):(([0-9a-fA-F]{1,4}:){1,5}|:)|(([0-9a-fA-F]{1,4}:){1,5}|:):(([0-9a-fA-F]{1,4}:){1,4}|:)|(([0-9a-fA-F]{1,4}:){1,4}|:):(([0-9a-fA-F]{1,4}:){1,3}|:)|(([0-9a-fA-F]{1,4}:){1,3}|:):(([0-9a-fA-F]{1,4}:){1,2}|:)|(([0-9a-fA-F]{1,4}:){1,2}|:):([0-9a-fA-F]{1,4}|:)|([0-9a-fA-F]{1,4}|:):([0-9a-fA-F]{1,4}|:)(:([0-9a-fA-F]{1,4}|:)){1,6})$"
)
|
|
File without changes
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from xm_slurm import config, resources
|
|
5
|
+
from xm_slurm.contrib.clusters import drac
|
|
6
|
+
|
|
7
|
+
# ComputeCanada alias: `cc` is kept as a backwards-compatible name for the
# DRAC cluster configs.
cc = drac

# Public API of this package: the DRAC module, the mila() factory, and the
# `cc` alias.
__all__ = ["drac", "mila", "cc"]

logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def mila(
    *,
    user: str | None = None,
    partition: str | None = None,
    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Mila Cluster (https://docs.mila.quebec/)."""
    if mounts is None:
        # Default bind mounts when the caller does not supply any.
        mounts = {
            "/network/scratch/${USER:0:1}/$USER": "/scratch",
            # TODO: move these somewhere common to all cluster configs.
            "/home/mila/${USER:0:1}/$USER/.local/state/xm-slurm": "/xm-slurm-state",
            "/home/mila/${USER:0:1}/$USER/.ssh": "/home/mila/${USER:0:1}/$USER/.ssh",
        }
    # InfiniBand devices are always passed through into the container.
    effective_mounts = dict(mounts) | {"/dev/infiniband": "/dev/infiniband"}

    ssh = config.SSHConfig(
        user=user,
        endpoints=(config.Endpoint("login.server.mila.quebec", 2222),),
        public_key=config.PublicKey(
            "ssh-ed25519",
            "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
        ),
    )

    # GPU resource types mapped to their Slurm gres/constraint names.
    accelerators = {
        resources.ResourceType.RTX8000: "rtx8000",
        resources.ResourceType.V100: "v100",
        resources.ResourceType.A100: "a100",
        resources.ResourceType.A100_80GIB: "a100l",
        resources.ResourceType.A6000: "a6000",
        resources.ResourceType.L40S: "l40s",
        resources.ResourceType.H100: "h100",
    }

    return config.SlurmClusterConfig(
        name="mila",
        ssh=ssh,
        runtime=config.ContainerRuntime.SINGULARITY,
        partition=partition,
        prolog="module load singularity",
        host_environment={
            "SINGULARITY_CACHEDIR": "$SCRATCH/.apptainer",
            "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
            "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
        },
        container_environment={
            "SCRATCH": "/scratch",
            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
        },
        mounts=effective_mounts,
        resources=accelerators,
        features={
            resources.FeatureType.NVIDIA_MIG: "mig",
            resources.FeatureType.NVIDIA_NVLINK: "nvlink",
        },
    )
|