xmanager-slurm 0.3.0 (py3-none-any.whl)

This diff shows the content of publicly available package versions as published to their public registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those registries.

Potentially problematic release: this version of xmanager-slurm might be problematic.

Files changed (38)
  1. xm_slurm/__init__.py +44 -0
  2. xm_slurm/api.py +261 -0
  3. xm_slurm/batching.py +139 -0
  4. xm_slurm/config.py +162 -0
  5. xm_slurm/console.py +3 -0
  6. xm_slurm/contrib/clusters/__init__.py +52 -0
  7. xm_slurm/contrib/clusters/drac.py +169 -0
  8. xm_slurm/executables.py +201 -0
  9. xm_slurm/execution.py +491 -0
  10. xm_slurm/executors.py +127 -0
  11. xm_slurm/experiment.py +737 -0
  12. xm_slurm/job_blocks.py +14 -0
  13. xm_slurm/packageables.py +292 -0
  14. xm_slurm/packaging/__init__.py +8 -0
  15. xm_slurm/packaging/docker/__init__.py +75 -0
  16. xm_slurm/packaging/docker/abc.py +112 -0
  17. xm_slurm/packaging/docker/cloud.py +503 -0
  18. xm_slurm/packaging/docker/local.py +206 -0
  19. xm_slurm/packaging/registry.py +45 -0
  20. xm_slurm/packaging/router.py +52 -0
  21. xm_slurm/packaging/utils.py +202 -0
  22. xm_slurm/resources.py +150 -0
  23. xm_slurm/status.py +188 -0
  24. xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
  25. xm_slurm/templates/docker/mamba.Dockerfile +27 -0
  26. xm_slurm/templates/docker/pdm.Dockerfile +31 -0
  27. xm_slurm/templates/docker/python.Dockerfile +24 -0
  28. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
  29. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  30. xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
  31. xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
  32. xm_slurm/templates/slurm/job.bash.j2 +78 -0
  33. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
  34. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
  35. xm_slurm/utils.py +69 -0
  36. xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
  37. xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
  38. xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
xm_slurm/__init__.py ADDED
@@ -0,0 +1,44 @@
+ import logging
+
+ from xm_slurm.executables import Dockerfile, DockerImage
+ from xm_slurm.executors import Slurm, SlurmSpec
+ from xm_slurm.experiment import (
+     Artifact,
+     create_experiment,
+     get_current_experiment,
+     get_current_work_unit,
+     get_experiment,
+ )
+ from xm_slurm.packageables import (
+     conda_container,
+     docker_container,
+     docker_image,
+     mamba_container,
+     pdm_container,
+     python_container,
+ )
+ from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType
+
+ logging.getLogger("asyncssh").setLevel(logging.WARN)
+ logging.getLogger("httpx").setLevel(logging.WARN)
+
+ __all__ = [
+     "Artifact",
+     "conda_container",
+     "create_experiment",
+     "docker_container",
+     "docker_image",
+     "Dockerfile",
+     "DockerImage",
+     "get_current_experiment",
+     "get_current_work_unit",
+     "get_experiment",
+     "JobRequirements",
+     "mamba_container",
+     "pdm_container",
+     "python_container",
+     "ResourceQuantity",
+     "ResourceType",
+     "Slurm",
+     "SlurmSpec",
+ ]
xm_slurm/api.py ADDED
@@ -0,0 +1,261 @@
+ import dataclasses
+ import functools
+ import importlib.util
+ import logging
+ import os
+ import time
+ import typing
+ from typing import Any
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclasses.dataclass(kw_only=True, frozen=True)
+ class ExperimentPatchModel:
+     title: str | None = None
+     description: str | None = None
+     note: str | None = None
+     tags: list[str] | None = None
+
+
+ @dataclasses.dataclass(kw_only=True, frozen=True)
+ class SlurmJobModel:
+     name: str
+     slurm_job_id: int
+     slurm_cluster: str
+
+
+ @dataclasses.dataclass(kw_only=True, frozen=True)
+ class ArtifactModel:
+     name: str
+     uri: str
+
+
+ @dataclasses.dataclass(kw_only=True, frozen=True)
+ class WorkUnitPatchModel:
+     wid: int
+     identity: str | None
+     args: str | None = None
+
+
+ @dataclasses.dataclass(kw_only=True, frozen=True)
+ class WorkUnitModel(WorkUnitPatchModel):
+     jobs: list[SlurmJobModel] = dataclasses.field(default_factory=list)
+     artifacts: list[ArtifactModel] = dataclasses.field(default_factory=list)
+
+
+ @dataclasses.dataclass(kw_only=True, frozen=True)
+ class ExperimentModel:
+     title: str
+     description: str | None
+     note: str | None
+     tags: list[str] | None
+
+     work_units: list[WorkUnitModel]
+     artifacts: list[ArtifactModel]
+
+
+ class XManagerAPI:
+     def get_experiment(self, xid: int) -> ExperimentModel:
+         del xid
+         raise NotImplementedError("`get_experiment` is not implemented without a storage backend.")
+
+     def delete_experiment(self, experiment_id: int) -> None:
+         del experiment_id
+         logger.debug("`delete_experiment` is not implemented without a storage backend.")
+
+     def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
+         del experiment
+         logger.debug("`insert_experiment` is not implemented without a storage backend.")
+         return int(time.time() * 10**3)
+
+     def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
+         del experiment_id, experiment_patch
+         logger.debug("`update_experiment` is not implemented without a storage backend.")
+
+     def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
+         del experiment_id, work_unit_id, job
+         logger.debug("`insert_job` is not implemented without a storage backend.")
+
+     def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitPatchModel) -> None:
+         del experiment_id, work_unit
+         logger.debug("`insert_work_unit` is not implemented without a storage backend.")
+
+     def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
+         del experiment_id, work_unit_id, name
+         logger.debug("`delete_work_unit_artifact` is not implemented without a storage backend.")
+
+     def insert_work_unit_artifact(
+         self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
+     ) -> None:
+         del experiment_id, work_unit_id, artifact
+         logger.debug("`insert_work_unit_artifact` is not implemented without a storage backend.")
+
+     def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
+         del experiment_id, name
+         logger.debug("`delete_experiment_artifact` is not implemented without a storage backend.")
+
+     def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
+         del experiment_id, artifact
+         logger.debug("`insert_experiment_artifact` is not implemented without a storage backend.")
+
+
+ class XManagerWebAPI(XManagerAPI):
+     def __init__(self, base_url: str, token: str):
+         if importlib.util.find_spec("xm_slurm_api_client") is None:
+             raise ImportError("xm_slurm_api_client not found.")
+
+         from xm_slurm_api_client import AuthenticatedClient  # type: ignore
+         from xm_slurm_api_client import models as _models  # type: ignore
+
+         self.models = _models
+         self.client = AuthenticatedClient(
+             base_url,
+             token=token,
+             raise_on_unexpected_status=True,
+             verify_ssl=False,
+         )
+
+     def get_experiment(self, xid: int) -> ExperimentModel:
+         from xm_slurm_api_client.api.experiment import (  # type: ignore
+             get_experiment as _get_experiment,
+         )
+
+         experiment: Any = _get_experiment.sync(xid, client=self.client)  # type: ignore
+         wus = []
+         for wu in experiment.work_units:
+             jobs = []
+             for job in wu.jobs:
+                 jobs.append(SlurmJobModel(**job.dict()))
+             artifacts = []
+             for artifact in wu.artifacts:
+                 artifacts.append(ArtifactModel(**artifact.dict()))
+             wus.append(
+                 WorkUnitModel(
+                     wid=wu.wid,
+                     identity=wu.identity,
+                     args=wu.args,
+                     jobs=jobs,
+                     artifacts=artifacts,
+                 )
+             )
+
+         artifacts = []
+         for artifact in experiment.artifacts:
+             artifacts.append(ArtifactModel(**artifact.dict()))
+
+         return ExperimentModel(
+             title=experiment.title,
+             description=experiment.description,
+             note=experiment.note,
+             tags=experiment.tags,
+             work_units=wus,
+             artifacts=artifacts,
+         )
+
+     def delete_experiment(self, experiment_id: int) -> None:
+         from xm_slurm_api_client.api.experiment import (  # type: ignore
+             delete_experiment as _delete_experiment,
+         )
+
+         _delete_experiment.sync(experiment_id, client=self.client)
+
+     def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
+         from xm_slurm_api_client.api.experiment import (  # type: ignore
+             insert_experiment as _insert_experiment,
+         )
+
+         assert experiment.title is not None, "Title must be set in the experiment model."
+         assert (
+             experiment.description is None and experiment.note is None and experiment.tags is None
+         ), "Only title should be set in the experiment model."
+         experiment_response = _insert_experiment.sync(
+             client=self.client,
+             body=self.models.Experiment(title=experiment.title),
+         )
+         return typing.cast(int, experiment_response["xid"])  # type: ignore
+
+     def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
+         from xm_slurm_api_client.api.experiment import (  # type: ignore
+             update_experiment as _update_experiment,
+         )
+
+         m = self.models.ExperimentPatch(**dataclasses.asdict(experiment_patch))
+
+         _update_experiment.sync(
+             experiment_id,
+             client=self.client,
+             body=self.models.ExperimentPatch(**dataclasses.asdict(experiment_patch)),
+         )
+
+     def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
+         from xm_slurm_api_client.api.job import insert_job as _insert_job  # type: ignore
+
+         _insert_job.sync(
+             experiment_id,
+             work_unit_id,
+             client=self.client,
+             body=self.models.SlurmJob(**dataclasses.asdict(job)),
+         )
+
+     def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitPatchModel) -> None:
+         from xm_slurm_api_client.api.work_unit import (  # type: ignore
+             insert_work_unit as _insert_work_unit,
+         )
+
+         _insert_work_unit.sync(
+             experiment_id,
+             client=self.client,
+             body=self.models.WorkUnit(**dataclasses.asdict(work_unit)),
+         )
+
+     def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
+         from xm_slurm_api_client.api.artifact import (  # type: ignore
+             delete_work_unit_artifact as _delete_work_unit_artifact,
+         )
+
+         _delete_work_unit_artifact.sync(experiment_id, work_unit_id, name, client=self.client)
+
+     def insert_work_unit_artifact(
+         self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
+     ) -> None:
+         from xm_slurm_api_client.api.artifact import (  # type: ignore
+             insert_work_unit_artifact as _insert_work_unit_artifact,
+         )
+
+         _insert_work_unit_artifact.sync(
+             experiment_id,
+             work_unit_id,
+             client=self.client,
+             body=self.models.Artifact(**dataclasses.asdict(artifact)),
+         )
+
+     def delete_experiment_artifact(self, experiment_id: int, name: str) -> None: ...
+
+     def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
+         from xm_slurm_api_client.api.artifact import (  # type: ignore
+             insert_experiment_artifact as _insert_experiment_artifact,
+         )
+
+         _insert_experiment_artifact.sync(
+             experiment_id,
+             client=self.client,
+             body=self.models.Artifact(**dataclasses.asdict(artifact)),
+         )
+
+
+ @functools.cache
+ def client() -> XManagerAPI:
+     if importlib.util.find_spec("xm_slurm_api_client") is not None:
+         if (base_url := os.environ.get("XM_SLURM_API_BASE_URL")) is not None and (
+             token := os.environ.get("XM_SLURM_API_TOKEN")
+         ) is not None:
+             return XManagerWebAPI(base_url=base_url, token=token)
+         else:
+             logger.warn(
+                 "XM_SLURM_API_BASE_URL and XM_SLURM_API_TOKEN not set. "
+                 "Disabling XManager API client."
+             )
+
+     logger.debug("xm_slurm_api_client not found... skipping logging to the API.")
+     return XManagerAPI()
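
The `client()` factory at the bottom of this module picks the storage backend on first use: if the optional `xm_slurm_api_client` package is importable and both `XM_SLURM_API_BASE_URL` and `XM_SLURM_API_TOKEN` are set, it returns an `XManagerWebAPI`; otherwise it falls back to the no-op `XManagerAPI`, whose methods only log. A minimal sketch of that flow (the URL and token below are placeholders, not values from the package):

import os

from xm_slurm import api

# Placeholders; both variables must be set before the first (cached) call to client().
os.environ["XM_SLURM_API_BASE_URL"] = "https://xmanager.example.org"
os.environ["XM_SLURM_API_TOKEN"] = "placeholder-token"

backend = api.client()  # XManagerWebAPI if xm_slurm_api_client is installed, else the no-op base class
# The no-op backend logs a debug message and returns a millisecond timestamp as the experiment id.
xid = backend.insert_experiment(api.ExperimentPatchModel(title="demo"))
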
xm_slurm/batching.py ADDED
@@ -0,0 +1,139 @@
+ import asyncio
+ import collections
+ import dataclasses
+ import inspect
+ import time
+ import types
+ from typing import Any, Callable, Coroutine, Generic, ParamSpec, Sequence, TypeVar
+
+ T = TypeVar("T", contravariant=True)
+ R = TypeVar("R", covariant=True)
+ P = ParamSpec("P")
+
+
+ @dataclasses.dataclass(frozen=True, kw_only=True)
+ class Request:
+     args: inspect.BoundArguments
+     future: asyncio.Future
+
+
+ def stack_bound_arguments(
+     signature: inspect.Signature, bound_arguments: Sequence[inspect.BoundArguments]
+ ) -> inspect.BoundArguments:
+     """Stacks bound arguments into a single bound arguments object."""
+     stacked_args = collections.OrderedDict()
+     for bound_args in bound_arguments:
+         for name, value in bound_args.arguments.items():
+             stacked_args.setdefault(name, [])
+             stacked_args[name].append(value)
+     return inspect.BoundArguments(signature, stacked_args)
+
+
+ class batch(Generic[R]):
+     __slots__ = (
+         "fn",
+         "signature",
+         "max_batch_size",
+         "batch_timeout",
+         "loop",
+         "process_batch_task",
+         "queue",
+     )
+     __name__: str = "batch"
+     __qualname__: str = "batch"
+
+     def __init__(
+         self,
+         fn: Callable[..., Coroutine[None, None, Sequence[R]]],
+         /,
+         *,
+         max_batch_size: int,
+         batch_timeout: float,
+     ) -> None:
+         self.fn = fn
+         self.signature = inspect.signature(fn)
+
+         self.max_batch_size = max_batch_size
+         self.batch_timeout = batch_timeout
+
+         self.loop: asyncio.AbstractEventLoop | None = None
+         self.process_batch_task: asyncio.Task | None = None
+
+         self.queue = asyncio.Queue()
+
+     async def _process_batch(self):
+         assert self.loop is not None
+         while not self.loop.is_closed():
+             batch = await self._wait_for_batch()
+             assert len(batch) > 0
+
+             bound_args = stack_bound_arguments(self.signature, [request.args for request in batch])
+             futures = [request.future for request in batch]
+
+             results_future = self.fn(*bound_args.args, *bound_args.kwargs)
+
+             try:
+                 results = await results_future
+
+                 for result, future in zip(results, futures):
+                     future.set_result(result)
+             except Exception as e:
+                 for future in futures:
+                     future.set_exception(e)
+
+     async def _wait_for_batch(self) -> list[Request]:
+         batch = [await self.queue.get()]
+
+         batch_start_time = time.time()
+         while True:
+             remaining_batch_time_s = max(self.batch_timeout - (time.time() - batch_start_time), 0)
+
+             try:
+                 request = await asyncio.wait_for(self.queue.get(), timeout=remaining_batch_time_s)
+                 batch.append(request)
+             except asyncio.TimeoutError:
+                 break
+
+             if (
+                 time.time() - batch_start_time >= self.batch_timeout
+                 or len(batch) >= self.max_batch_size
+             ):
+                 break
+
+         return batch
+
+     def __get__(self, obj: Any, objtype: type) -> Any:
+         del objtype
+         if isinstance(self.fn, staticmethod):
+             return self.__call__
+         else:
+             return types.MethodType(self, obj)
+
+     @property
+     def __func__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
+         return self.fn
+
+     @property
+     def __wrapped__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
+         return self.fn
+
+     @property
+     def __isabstractmethod__(self) -> bool:
+         """Return whether the wrapped function is abstract."""
+         return getattr(self.fn, "__isabstractmethod__", False)
+
+     @property
+     def _is_coroutine(self) -> bool:
+         # TODO(jfarebro): py312 adds inspect.markcoroutinefunction
+         # until then this is just a hack
+         return asyncio.coroutines._is_coroutine  # type: ignore
+
+     async def __call__(self, *args, **kwargs) -> asyncio.Future[R]:
+         if self.loop is None and self.process_batch_task is None:
+             self.loop = asyncio.get_event_loop()
+             self.process_batch_task = self.loop.create_task(self._process_batch())
+
+         future = asyncio.Future()
+         bound_args = self.signature.bind(*args, **kwargs)
+         self.queue.put_nowait(Request(args=bound_args, future=future))
+         return await future
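
`batch` coalesces concurrent awaited calls into a single invocation of the wrapped coroutine: requests arriving within `batch_timeout` seconds (up to `max_batch_size` of them) have their arguments stacked into lists, and the returned sequence is demultiplexed back to the individual callers. A minimal usage sketch (the `resolve` coroutine below is illustrative, not part of the package):

import asyncio
from typing import Sequence

from xm_slurm.batching import batch


async def resolve(hosts: Sequence[str]) -> Sequence[str]:
    # Invoked once per batch with the stacked `hosts` of every queued call;
    # must return one result per request, in the same order.
    return [h.lower() for h in hosts]


# Each caller passes a single value; batching stacks them behind the scenes.
resolve_batched = batch(resolve, max_batch_size=8, batch_timeout=0.05)


async def main() -> None:
    results = await asyncio.gather(resolve_batched("LOGIN-A"), resolve_batched("LOGIN-B"))
    print(results)  # ['login-a', 'login-b']


asyncio.run(main())
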
xm_slurm/config.py ADDED
@@ -0,0 +1,162 @@
+ import dataclasses
+ import enum
+ import functools
+ import getpass
+ import os
+ import pathlib
+ from typing import Literal, Mapping, NamedTuple
+
+ import asyncssh
+
+
+ class ContainerRuntime(enum.Enum):
+     """The container engine to use."""
+
+     SINGULARITY = enum.auto()
+     APPTAINER = enum.auto()
+     DOCKER = enum.auto()
+     PODMAN = enum.auto()
+
+     @classmethod
+     def from_string(
+         cls, runtime: Literal["singularity", "apptainer", "docker", "podman"]
+     ) -> "ContainerRuntime":
+         return {
+             "singularity": cls.SINGULARITY,
+             "apptainer": cls.APPTAINER,
+             "docker": cls.DOCKER,
+             "podman": cls.PODMAN,
+         }[runtime]
+
+     def __str__(self):
+         if self is self.SINGULARITY:
+             return "singularity"
+         elif self is self.APPTAINER:
+             return "apptainer"
+         elif self is self.DOCKER:
+             return "docker"
+         elif self is self.PODMAN:
+             return "podman"
+         else:
+             raise NotImplementedError
+
+
+ class PublicKey(NamedTuple):
+     algorithm: str
+     key: str
+
+
+ @dataclasses.dataclass(frozen=True, kw_only=True)
+ class SlurmClusterConfig:
+     name: str
+
+     host: str
+     host_public_key: PublicKey | None = None
+     user: str | None = None
+     port: int | None = None
+
+     # Job submission directory
+     cwd: str | None = None
+
+     # Additional scripting
+     prolog: str | None = None
+     epilog: str | None = None
+
+     # Job scheduling
+     account: str | None = None
+     partition: str | None = None
+     qos: str | None = None
+
+     # If true, a reverse proxy is initiated via the submission host.
+     proxy: Literal["submission-host"] | str | None = None
+
+     runtime: ContainerRuntime
+
+     # Environment variables
+     environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
+
+     # Mounts
+     mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
+         default_factory=dict
+     )
+
+     # Resource mapping
+     resources: Mapping[str, "xm_slurm.ResourceType"] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
+
+     def __post_init__(self) -> None:
+         for src, dst in self.mounts.items():
+             if not isinstance(src, (str, os.PathLike)):
+                 raise TypeError(
+                     f"Mount source must be a string or path-like object, not {type(src)}"
+                 )
+             if not isinstance(dst, (str, os.PathLike)):
+                 raise TypeError(
+                     f"Mount destination must be a string or path-like object, not {type(dst)}"
+                 )
+
+             if not pathlib.Path(src).is_absolute():
+                 raise ValueError(f"Mount source must be an absolute path: {src}")
+             if not pathlib.Path(dst).is_absolute():
+                 raise ValueError(f"Mount destination must be an absolute path: {dst}")
+
+     @functools.cached_property
+     def ssh_known_hosts(self) -> asyncssh.SSHKnownHosts | None:
+         if self.host_public_key is None:
+             return None
+
+         return asyncssh.import_known_hosts(
+             f"[{self.host}]:{self.port} {self.host_public_key.algorithm} {self.host_public_key.key}"
+         )
+
+     @functools.cached_property
+     def ssh_config(self) -> asyncssh.config.SSHConfig:
+         ssh_config_paths = []
+         if (ssh_config := pathlib.Path.home() / ".ssh" / "config").exists():
+             ssh_config_paths.append(ssh_config)
+         if (xm_ssh_config := os.environ.get("XM_SLURM_SSH_CONFIG")) and (
+             xm_ssh_config := pathlib.Path(xm_ssh_config).expanduser()
+         ).exists():
+             ssh_config_paths.append(xm_ssh_config)
+
+         config = asyncssh.config.SSHClientConfig.load(
+             None,
+             ssh_config_paths,
+             True,
+             getpass.getuser(),
+             self.user or (),
+             self.host or (),
+             self.port or (),
+         )
+
+         if config.get("Hostname") is None:
+             raise RuntimeError(
+                 f"Failed to parse hostname from host `{self.host}` using SSH configs: {', '.join(map(str, ssh_config_paths))}"
+             )
+         if config.get("User") is None:
+             raise RuntimeError(
+                 f"Failed to parse user from SSH configs: {', '.join(map(str, ssh_config_paths))}"
+             )
+
+         return config
+
+     @functools.cached_property
+     def ssh_connection_options(self) -> asyncssh.SSHClientConnectionOptions:
+         options = asyncssh.SSHClientConnectionOptions(config=None)
+         options.prepare(last_config=self.ssh_config, known_hosts=self.ssh_known_hosts)
+         return options
+
+     def __hash__(self):
+         return hash((
+             self.host,
+             self.user,
+             self.port,
+             self.cwd,
+             self.prolog,
+             self.epilog,
+             self.account,
+             self.partition,
+             self.qos,
+             self.proxy,
+             self.runtime,
+             frozenset(self.environment.items()),
+         ))
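
`SlurmClusterConfig` is the frozen description of a cluster that the rest of the package consumes: the SSH endpoint and host key, scheduling defaults, the container runtime, and environment, mount, and resource mappings, with `__post_init__` rejecting non-absolute mount paths. A hand-written configuration, sketched with placeholder values:

from xm_slurm import config

cluster = config.SlurmClusterConfig(
    name="example",
    host="login.cluster.example.org",  # placeholder hostname
    user="alice",                      # placeholder user
    account="def-alice",               # placeholder Slurm account
    runtime=config.ContainerRuntime.from_string("apptainer"),
    environment={"SCRATCH": "/scratch"},
    # Both sides of a mount must be absolute paths, or __post_init__ raises.
    mounts={"/scratch/alice": "/scratch"},
)
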
xm_slurm/console.py ADDED
@@ -0,0 +1,3 @@
+ from rich.console import Console
+
+ console = Console()
xm_slurm/contrib/clusters/__init__.py ADDED
@@ -0,0 +1,52 @@
+ import os
+
+ from xm_slurm import config, resources
+ from xm_slurm.contrib.clusters import drac
+
+ # ComputeCanada alias
+ cc = drac
+
+ __all__ = ["drac", "mila", "cc"]
+
+
+ def mila(
+     *,
+     user: str | None = None,
+     partition: str | None = None,
+     mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+ ) -> config.SlurmClusterConfig:
+     """Mila Cluster (https://docs.mila.quebec/)."""
+     if mounts is None:
+         mounts = {
+             "/network/scratch/${USER:0:1}/$USER": "/scratch",
+             "/network/archive/${USER:0:1}/$USER": "/archive",
+         }
+
+     return config.SlurmClusterConfig(
+         name="mila",
+         user=user,
+         host="login.server.mila.quebec",
+         host_public_key=config.PublicKey(
+             "ssh-ed25519",
+             "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
+         ),
+         port=2222,
+         runtime=config.ContainerRuntime.SINGULARITY,
+         partition=partition,
+         prolog="module load singularity",
+         environment={
+             "SINGULARITY_CACHEDIR": "$SCRATCH/.apptainer",
+             "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
+             "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
+             "SCRATCH": "/scratch",
+             "ARCHIVE": "/archive",
+         },
+         mounts=mounts,
+         resources={
+             "rtx8000": resources.ResourceType.RTX8000,
+             "v100": resources.ResourceType.V100,
+             "a100": resources.ResourceType.A100,
+             "a100l": resources.ResourceType.A100_80GIB,
+             "a6000": resources.ResourceType.A6000,
+         },
+     )
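
`mila()` returns a ready-made `SlurmClusterConfig` for the Mila cluster, and `cc` is an alias for the `drac` module (ComputeCanada). A short usage sketch with placeholder arguments:

from xm_slurm.contrib import clusters

# Predefined Mila configuration; user and partition are placeholders.
cluster = clusters.mila(user="alice", partition="long")
print(cluster.host, cluster.port, str(cluster.runtime))  # login.server.mila.quebec 2222 singularity

# The ComputeCanada / DRAC clusters live in the drac module, aliased as cc.
assert clusters.cc is clusters.drac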