xmanager-slurm 0.4.19 (xmanager_slurm-0.4.19-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. xm_slurm/__init__.py +47 -0
  2. xm_slurm/api/__init__.py +33 -0
  3. xm_slurm/api/abc.py +65 -0
  4. xm_slurm/api/models.py +70 -0
  5. xm_slurm/api/sqlite/client.py +358 -0
  6. xm_slurm/api/web/client.py +173 -0
  7. xm_slurm/batching.py +139 -0
  8. xm_slurm/config.py +189 -0
  9. xm_slurm/console.py +3 -0
  10. xm_slurm/constants.py +19 -0
  11. xm_slurm/contrib/__init__.py +0 -0
  12. xm_slurm/contrib/clusters/__init__.py +67 -0
  13. xm_slurm/contrib/clusters/drac.py +242 -0
  14. xm_slurm/dependencies.py +171 -0
  15. xm_slurm/executables.py +215 -0
  16. xm_slurm/execution.py +995 -0
  17. xm_slurm/executors.py +210 -0
  18. xm_slurm/experiment.py +1016 -0
  19. xm_slurm/experimental/parameter_controller.py +206 -0
  20. xm_slurm/filesystems.py +129 -0
  21. xm_slurm/job_blocks.py +21 -0
  22. xm_slurm/metadata_context.py +253 -0
  23. xm_slurm/packageables.py +309 -0
  24. xm_slurm/packaging/__init__.py +8 -0
  25. xm_slurm/packaging/docker.py +348 -0
  26. xm_slurm/packaging/registry.py +45 -0
  27. xm_slurm/packaging/router.py +56 -0
  28. xm_slurm/packaging/utils.py +22 -0
  29. xm_slurm/resources.py +350 -0
  30. xm_slurm/scripts/_cloudpickle.py +28 -0
  31. xm_slurm/scripts/cli.py +90 -0
  32. xm_slurm/status.py +197 -0
  33. xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
  34. xm_slurm/templates/docker/mamba.Dockerfile +29 -0
  35. xm_slurm/templates/docker/python.Dockerfile +32 -0
  36. xm_slurm/templates/docker/uv.Dockerfile +38 -0
  37. xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
  38. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
  39. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  40. xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
  41. xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
  42. xm_slurm/templates/slurm/job.bash.j2 +90 -0
  43. xm_slurm/templates/slurm/library/retry.bash +62 -0
  44. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
  45. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
  46. xm_slurm/types.py +23 -0
  47. xm_slurm/utils.py +196 -0
  48. xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
  49. xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
  50. xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
  51. xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
  52. xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
xm_slurm/api/web/client.py ADDED
@@ -0,0 +1,173 @@
+ import dataclasses
+
+ import backoff
+ import httpx
+
+ from xm_slurm.api import models
+ from xm_slurm.api.abc import XManagerAPI
+
+ # Define which exceptions should trigger a retry
+ RETRY_EXCEPTIONS = (
+     httpx.ConnectError,
+     httpx.ConnectTimeout,
+     httpx.ReadTimeout,
+     httpx.WriteTimeout,
+     httpx.NetworkError,
+ )
+
+
+ # Common backoff decorator for all API calls
+ def with_backoff(f):
+     return backoff.on_exception(
+         backoff.expo,
+         RETRY_EXCEPTIONS,
+         max_tries=3,  # Maximum number of attempts
+         max_time=30,  # Maximum total time to try in seconds
+         jitter=backoff.full_jitter,  # Add jitter to prevent thundering herd
+     )(f)
+
+
+ class XManagerWebAPI(XManagerAPI):
+     def __init__(self, base_url: str, token: str):
+         self.base_url = base_url.rstrip("/")
+         self.client = httpx.Client(headers={"Authorization": f"Bearer {token}"}, verify=False)
+
+     def _make_url(self, path: str) -> str:
+         return f"{self.base_url}{path}"
+
+     @with_backoff
+     def get_experiment(self, xid: int) -> models.Experiment:
+         response = self.client.get(self._make_url(f"/experiment/{xid}"))
+         response.raise_for_status()
+         data = response.json()
+         # Construct work units with nested jobs and artifacts
+         work_units = []
+         for wu_data in data.pop("work_units", []):
+             # Build jobs for this work unit
+             jobs = [
+                 models.SlurmJob(
+                     name=job["name"],
+                     slurm_job_id=job["slurm_job_id"],
+                     slurm_ssh_config=job["slurm_ssh_config"],
+                 )
+                 for job in wu_data.pop("jobs", [])
+             ]
+
+             # Build artifacts for this work unit
+             artifacts = [
+                 models.Artifact(name=artifact["name"], uri=artifact["uri"])
+                 for artifact in wu_data.pop("artifacts", [])
+             ]
+
+             # Create work unit with its jobs and artifacts
+             wu_data["jobs"] = jobs
+             wu_data["artifacts"] = artifacts
+             work_units.append(models.WorkUnit(**wu_data))
+
+         # Build experiment artifacts
+         artifacts = [
+             models.Artifact(name=artifact["name"], uri=artifact["uri"])
+             for artifact in data.pop("artifacts", [])
+         ]
+
+         return models.Experiment(**data, work_units=work_units, artifacts=artifacts)
+
+     @with_backoff
+     def delete_experiment(self, experiment_id: int) -> None:
+         response = self.client.delete(self._make_url(f"/experiment/{experiment_id}"))
+         response.raise_for_status()
+
+     @with_backoff
+     def insert_experiment(self, experiment: models.ExperimentPatch) -> int:
+         assert experiment.title is not None, "Title must be set in the experiment model."
+         assert (
+             experiment.description is None and experiment.note is None and experiment.tags is None
+         ), "Only title should be set in the experiment model."
+
+         response = self.client.put(
+             self._make_url("/experiment"), json=dataclasses.asdict(experiment)
+         )
+         response.raise_for_status()
+         return int(response.json()["xid"])
+
+     @with_backoff
+     def update_experiment(
+         self, experiment_id: int, experiment_patch: models.ExperimentPatch
+     ) -> None:
+         response = self.client.patch(
+             self._make_url(f"/experiment/{experiment_id}"),
+             json=dataclasses.asdict(experiment_patch),
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def insert_work_unit(self, experiment_id: int, work_unit: models.WorkUnitPatch) -> None:
+         response = self.client.put(
+             self._make_url(f"/experiment/{experiment_id}/wu"),
+             json=dataclasses.asdict(work_unit),
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def insert_job(self, experiment_id: int, work_unit_id: int, job: models.SlurmJob) -> None:
+         response = self.client.put(
+             self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/job"),
+             json=dataclasses.asdict(job),
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def insert_work_unit_artifact(
+         self, experiment_id: int, work_unit_id: int, artifact: models.Artifact
+     ) -> None:
+         response = self.client.put(
+             self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/artifact"),
+             json=dataclasses.asdict(artifact),
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
+         response = self.client.delete(
+             self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/artifact/{name}")
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
+         response = self.client.delete(
+             self._make_url(f"/experiment/{experiment_id}/artifact/{name}")
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def insert_experiment_artifact(self, experiment_id: int, artifact: models.Artifact) -> None:
+         response = self.client.put(
+             self._make_url(f"/experiment/{experiment_id}/artifact"),
+             json=dataclasses.asdict(artifact),
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def insert_experiment_config_artifact(
+         self, experiment_id: int, artifact: models.ConfigArtifact
+     ) -> None:
+         response = self.client.put(
+             self._make_url(f"/experiment/{experiment_id}/config"), json=dataclasses.asdict(artifact)
+         )
+         response.raise_for_status()
+
+     @with_backoff
+     def delete_experiment_config_artifact(self, experiment_id: int, name: str) -> None:
+         response = self.client.delete(self._make_url(f"/experiment/{experiment_id}/config/{name}"))
+         response.raise_for_status()
+
+     @with_backoff
+     def update_work_unit(
+         self, experiment_id: int, work_unit_id: int, patch: models.ExperimentUnitPatch
+     ) -> None:
+         response = self.client.patch(
+             self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}"),
+             json=dataclasses.asdict(patch),
+         )
+         response.raise_for_status()
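The client above wraps every call in `with_backoff`, so transient connection and timeout errors are retried (3 tries, 30 s budget) before HTTP errors are surfaced via `raise_for_status()`. A minimal usage sketch, not part of the package diff; the base URL, token, and experiment id are placeholders:

    from xm_slurm.api.web.client import XManagerWebAPI

    # Placeholders: point these at a real XManager web API deployment.
    api = XManagerWebAPI(base_url="https://xm.example.org", token="YOUR_TOKEN")

    # Retried with exponential backoff on the network errors listed in RETRY_EXCEPTIONS.
    experiment = api.get_experiment(42)
    for wu in experiment.work_units:
        print(wu.jobs, wu.artifacts)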
xm_slurm/batching.py ADDED
@@ -0,0 +1,139 @@
+ import asyncio
+ import collections
+ import dataclasses
+ import inspect
+ import time
+ import types
+ import typing as tp
+
+ T = tp.TypeVar("T", contravariant=True)
+ R = tp.TypeVar("R", covariant=True)
+ P = tp.ParamSpec("P")
+
+
+ @dataclasses.dataclass(frozen=True, kw_only=True)
+ class Request:
+     args: inspect.BoundArguments
+     future: asyncio.Future
+
+
+ def stack_bound_arguments(
+     signature: inspect.Signature, bound_arguments: tp.Sequence[inspect.BoundArguments]
+ ) -> inspect.BoundArguments:
+     """Stacks bound arguments into a single bound arguments object."""
+     stacked_args = collections.OrderedDict[str, tp.Any]()
+     for bound_args in bound_arguments:
+         for name, value in bound_args.arguments.items():
+             stacked_args.setdefault(name, [])
+             stacked_args[name].append(value)
+     return inspect.BoundArguments(signature, stacked_args)
+
+
+ class batch(tp.Generic[R]):
+     __slots__ = (
+         "fn",
+         "signature",
+         "max_batch_size",
+         "batch_timeout",
+         "loop",
+         "process_batch_task",
+         "queue",
+     )
+     __name__: str = "batch"
+     __qualname__: str = "batch"
+
+     def __init__(
+         self,
+         fn: tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]],
+         /,
+         *,
+         max_batch_size: int,
+         batch_timeout: float,
+     ) -> None:
+         self.fn = fn
+         self.signature = inspect.signature(fn)
+
+         self.max_batch_size = max_batch_size
+         self.batch_timeout = batch_timeout
+
+         self.loop: asyncio.AbstractEventLoop | None = None
+         self.process_batch_task: asyncio.Task | None = None
+
+         self.queue = asyncio.Queue[Request]()
+
+     async def _process_batch(self):
+         assert self.loop is not None
+         while not self.loop.is_closed():
+             batch = await self._wait_for_batch()
+             assert len(batch) > 0
+
+             bound_args = stack_bound_arguments(self.signature, [request.args for request in batch])
+             futures = [request.future for request in batch]
+
+             results_future = self.fn(*bound_args.args, *bound_args.kwargs)
+
+             try:
+                 results = await results_future
+
+                 for result, future in zip(results, futures):
+                     future.set_result(result)
+             except Exception as e:
+                 for future in futures:
+                     future.set_exception(e)
+
+     async def _wait_for_batch(self) -> list[Request]:
+         batch = [await self.queue.get()]
+
+         batch_start_time = time.time()
+         while True:
+             remaining_batch_time_s = max(self.batch_timeout - (time.time() - batch_start_time), 0)
+
+             try:
+                 request = await asyncio.wait_for(self.queue.get(), timeout=remaining_batch_time_s)
+                 batch.append(request)
+             except asyncio.TimeoutError:
+                 break
+
+             if (
+                 time.time() - batch_start_time >= self.batch_timeout
+                 or len(batch) >= self.max_batch_size
+             ):
+                 break
+
+         return batch
+
+     def __get__(self, obj: tp.Any, objtype: tp.Type[tp.Any]) -> tp.Any:
+         del objtype
+         if isinstance(self.fn, staticmethod):
+             return self.__call__
+         else:
+             return types.MethodType(self, obj)
+
+     @property
+     def __func__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
+         return self.fn
+
+     @property
+     def __wrapped__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
+         return self.fn
+
+     @property
+     def __isabstractmethod__(self) -> bool:
+         """Return whether the wrapped function is abstract."""
+         return getattr(self.fn, "__isabstractmethod__", False)
+
+     @property
+     def _is_coroutine(self) -> bool:
+         # TODO(jfarebro): py312 adds inspect.markcoroutinefunction
+         # until then this is just a hack
+         return asyncio.coroutines._is_coroutine  # type: ignore
+
+     async def __call__(self, *args, **kwargs) -> R:
+         if self.loop is None and self.process_batch_task is None:
+             self.loop = asyncio.get_event_loop()
+             self.process_batch_task = self.loop.create_task(self._process_batch())
+
+         future = asyncio.Future[R]()
+         bound_args = self.signature.bind(*args, **kwargs)
+         self.queue.put_nowait(Request(args=bound_args, future=future))
+         return await future
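`batch` coalesces concurrent awaits of the same coroutine into one call: requests queue until either `batch_timeout` elapses or `max_batch_size` requests accumulate, the wrapped function then receives each argument stacked into a list, and every caller is handed back its corresponding element of the returned sequence. A minimal sketch, not part of the package diff; the resolver function and values are hypothetical:

    import asyncio

    from xm_slurm.batching import batch

    async def _resolve_batched(names):
        # Called once per batch: `names` holds one entry per queued call,
        # and one result must be returned per entry, in the same order.
        return [name.upper() for name in names]

    resolve = batch(_resolve_batched, max_batch_size=8, batch_timeout=0.05)

    async def main():
        # Three concurrent calls are coalesced into a single _resolve_batched invocation.
        print(await asyncio.gather(resolve("a"), resolve("b"), resolve("c")))  # ['A', 'B', 'C']

    asyncio.run(main())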
xm_slurm/config.py ADDED
@@ -0,0 +1,189 @@
+ import dataclasses
+ import enum
+ import functools
+ import json
+ import os
+ import pathlib
+ import typing as tp
+
+ import asyncssh
+ from xmanager import xm
+
+
+ class ContainerRuntime(enum.Enum):
+     """The container engine to use."""
+
+     SINGULARITY = enum.auto()
+     APPTAINER = enum.auto()
+     DOCKER = enum.auto()
+     PODMAN = enum.auto()
+
+     @classmethod
+     def from_string(
+         cls, runtime: tp.Literal["singularity", "apptainer", "docker", "podman"]
+     ) -> "ContainerRuntime":
+         return {
+             "singularity": cls.SINGULARITY,
+             "apptainer": cls.APPTAINER,
+             "docker": cls.DOCKER,
+             "podman": cls.PODMAN,
+         }[runtime]
+
+     def __str__(self):
+         if self is self.SINGULARITY:
+             return "singularity"
+         elif self is self.APPTAINER:
+             return "apptainer"
+         elif self is self.DOCKER:
+             return "docker"
+         elif self is self.PODMAN:
+             return "podman"
+         else:
+             raise NotImplementedError
+
+
+ class Endpoint(tp.NamedTuple):
+     hostname: str
+     port: int | None
+
+     def __str__(self) -> str:
+         if self.port is None or self.port == asyncssh.DEFAULT_PORT:
+             return self.hostname
+         return f"[{self.hostname}]:{self.port}"
+
+
+ class PublicKey(tp.NamedTuple):
+     algorithm: str
+     key: str
+
+
+ @dataclasses.dataclass
+ class SSHConfig:
+     endpoints: tuple[Endpoint, ...]
+     public_key: PublicKey | None = None
+     user: str | None = None
+
+     def __post_init__(self) -> None:
+         if not isinstance(self.endpoints, tuple):
+             raise TypeError(f"endpoints must be a tuple, not {type(self.endpoints)}")
+         if len(self.endpoints) == 0:
+             raise ValueError("endpoints must be a non-empty tuple")
+         if not all(isinstance(endpoint, Endpoint) for endpoint in self.endpoints):
+             raise TypeError(f"endpoints must be a tuple of strings, not {type(self.endpoints)}")
+
+         if not isinstance(self.user, str | None):
+             raise TypeError(f"user must be a string or None, not {type(self.user)}")
+         if not isinstance(self.public_key, PublicKey | None):
+             raise TypeError(
+                 f"public_key must be a SSHPublicKey or None, not {type(self.public_key)}"
+             )
+
+     @functools.cached_property
+     def known_hosts(self) -> asyncssh.SSHKnownHosts | None:
+         if self.public_key is None:
+             return None
+
+         known_hosts = []
+         for endpoint in self.endpoints:
+             known_hosts.append(f"{endpoint!s} {self.public_key.algorithm} {self.public_key.key}")
+
+         return asyncssh.SSHKnownHosts("\n".join(known_hosts))
+
+     def serialize(self):
+         return json.dumps({
+             "endpoints": tuple(tuple(endpoint) for endpoint in self.endpoints),
+             "public_key": tuple(self.public_key),
+             "user": self.user,
+         })
+
+     @classmethod
+     def deserialize(cls, data):
+         data = json.loads(data)
+         return cls(
+             endpoints=tuple(Endpoint(*endpoint) for endpoint in data["endpoints"]),
+             public_key=PublicKey(*data["public_key"]) if data["public_key"] else None,
+             user=data["user"],
+         )
+
+     def __hash__(self):
+         return hash((
+             type(self),
+             *(tuple(endpoint) for endpoint in self.endpoints),
+             self.public_key,
+             self.user,
+         ))
+
+
+ @dataclasses.dataclass(frozen=True, kw_only=True)
+ class SlurmClusterConfig:
+     name: str
+
+     ssh: SSHConfig
+
+     # Job submission directory
+     cwd: str | None = None
+
+     # Additional scripting
+     prolog: str | None = None
+     epilog: str | None = None
+
+     # Job scheduling
+     account: str | None = None
+     partition: str | None = None
+     qos: str | None = None
+
+     # If true, a reverse proxy is initiated via the submission host.
+     proxy: tp.Literal["submission-host"] | str | None = None
+
+     runtime: ContainerRuntime
+
+     # Environment variables
+     host_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
+     container_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
+
+     # Mounts
+     mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
+         default_factory=dict
+     )
+
+     # Resource mapping
+     resources: tp.Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
+
+     features: tp.Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
+
+     # Function to validate the Slurm executor config
+     validate: tp.Callable[[xm.Job], None] | None = None
+
+     def __post_init__(self) -> None:
+         if not isinstance(self.ssh, SSHConfig):
+             raise TypeError(f"ssh must be a SlurmSSHConfig, not {type(self.ssh)}")
+         for src, dst in self.mounts.items():
+             if not isinstance(src, (str, os.PathLike)):
+                 raise TypeError(
+                     f"Mount source must be a string or path-like object, not {type(src)}"
+                 )
+             if not isinstance(dst, (str, os.PathLike)):
+                 raise TypeError(
+                     f"Mount destination must be a string or path-like object, not {type(dst)}"
+                 )
+
+             if not pathlib.Path(src).is_absolute():
+                 raise ValueError(f"Mount source must be an absolute path: {src}")
+             if not pathlib.Path(dst).is_absolute():
+                 raise ValueError(f"Mount destination must be an absolute path: {dst}")
+
+     def __hash__(self):
+         return hash((
+             type(self),
+             self.ssh,
+             self.cwd,
+             self.prolog,
+             self.epilog,
+             self.account,
+             self.partition,
+             self.qos,
+             self.proxy,
+             self.runtime,
+             frozenset(self.host_environment.items()),
+             frozenset(self.container_environment.items()),
+         ))
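A short sketch of assembling a cluster config from these dataclasses, not part of the package diff; the host name, key, user, and account below are placeholders:

    from xm_slurm.config import (
        ContainerRuntime,
        Endpoint,
        PublicKey,
        SlurmClusterConfig,
        SSHConfig,
    )

    ssh = SSHConfig(
        endpoints=(Endpoint("login.cluster.example.org", 22),),
        public_key=PublicKey("ssh-ed25519", "AAAA...placeholder..."),
        user="alice",
    )
    # SSHConfig round-trips through its JSON serialization.
    assert SSHConfig.deserialize(ssh.serialize()) == ssh

    cluster = SlurmClusterConfig(
        name="example",
        ssh=ssh,
        runtime=ContainerRuntime.APPTAINER,
        account="def-alice",
        mounts={"/scratch": "/scratch"},  # mount paths must be absolute
    )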
xm_slurm/console.py ADDED
@@ -0,0 +1,3 @@
+ from rich.console import Console
+
+ console = Console()
xm_slurm/constants.py ADDED
@@ -0,0 +1,19 @@
+ import re
+
+ SLURM_JOB_ID_REGEX = re.compile(
+     r"^(?P<jobid>\d+)(?:(?:\+(?P<componentid>\d+))|(?:_(?P<arraytaskid>\d+)))?$"
+ )
+
+ IMAGE_URI_REGEX = re.compile(
+     r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
+ )
+
+ DOMAIN_NAME_REGEX = re.compile(
+     r"^(?!-)(?!.*--)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*(\.[A-Za-z]{2,})$"
+ )
+ IPV4_REGEX = re.compile(
+     r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
+ )
+ IPV6_REGEX = re.compile(
+     r"^(([0-9a-fA-F]{1,4}:){7}([0-9a-fA-F]{1,4}|:)|(([0-9a-fA-F]{1,4}:){1,7}|:):(([0-9a-fA-F]{1,4}:){1,6}|:)|(([0-9a-fA-F]{1,4}:){1,6}|:):(([0-9a-fA-F]{1,4}:){1,5}|:)|(([0-9a-fA-F]{1,4}:){1,5}|:):(([0-9a-fA-F]{1,4}:){1,4}|:)|(([0-9a-fA-F]{1,4}:){1,4}|:):(([0-9a-fA-F]{1,4}:){1,3}|:)|(([0-9a-fA-F]{1,4}:){1,3}|:):(([0-9a-fA-F]{1,4}:){1,2}|:)|(([0-9a-fA-F]{1,4}:){1,2}|:):([0-9a-fA-F]{1,4}|:)|([0-9a-fA-F]{1,4}|:):([0-9a-fA-F]{1,4}|:)(:([0-9a-fA-F]{1,4}|:)){1,6})$"
+ )
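For reference, a quick check of how the Slurm job-ID pattern separates array tasks from heterogeneous-job components (the IDs below are made up):

    from xm_slurm.constants import SLURM_JOB_ID_REGEX

    m = SLURM_JOB_ID_REGEX.match("123456_7")   # array task 7 of job 123456
    assert m and m.group("jobid") == "123456" and m.group("arraytaskid") == "7"

    m = SLURM_JOB_ID_REGEX.match("123456+1")   # component 1 of a heterogeneous job
    assert m and m.group("componentid") == "1"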
xm_slurm/contrib/__init__.py (file without changes)
xm_slurm/contrib/clusters/__init__.py ADDED
@@ -0,0 +1,67 @@
+ import logging
+ import os
+
+ from xm_slurm import config, resources
+ from xm_slurm.contrib.clusters import drac
+
+ # ComputeCanada alias
+ cc = drac
+
+ __all__ = ["drac", "mila", "cc"]
+
+ logger = logging.getLogger(__name__)
+
+
+ def mila(
+     *,
+     user: str | None = None,
+     partition: str | None = None,
+     mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+ ) -> config.SlurmClusterConfig:
+     """Mila Cluster (https://docs.mila.quebec/)."""
+     if mounts is None:
+         mounts = {
+             "/network/scratch/${USER:0:1}/$USER": "/scratch",
+             # TODO: move these somewhere common to all cluster configs.
+             "/home/mila/${USER:0:1}/$USER/.local/state/xm-slurm": "/xm-slurm-state",
+             "/home/mila/${USER:0:1}/$USER/.ssh": "/home/mila/${USER:0:1}/$USER/.ssh",
+         }
+     mounts = dict(mounts) | {"/dev/infiniband": "/dev/infiniband"}
+
+     return config.SlurmClusterConfig(
+         name="mila",
+         ssh=config.SSHConfig(
+             user=user,
+             endpoints=(config.Endpoint("login.server.mila.quebec", 2222),),
+             public_key=config.PublicKey(
+                 "ssh-ed25519",
+                 "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
+             ),
+         ),
+         runtime=config.ContainerRuntime.SINGULARITY,
+         partition=partition,
+         prolog="module load singularity",
+         host_environment={
+             "SINGULARITY_CACHEDIR": "$SCRATCH/.apptainer",
+             "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
+             "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
+         },
+         container_environment={
+             "SCRATCH": "/scratch",
+             "XM_SLURM_STATE_DIR": "/xm-slurm-state",
+         },
+         mounts=mounts,
+         resources={
+             resources.ResourceType.RTX8000: "rtx8000",
+             resources.ResourceType.V100: "v100",
+             resources.ResourceType.A100: "a100",
+             resources.ResourceType.A100_80GIB: "a100l",
+             resources.ResourceType.A6000: "a6000",
+             resources.ResourceType.L40S: "l40s",
+             resources.ResourceType.H100: "h100",
+         },
+         features={
+             resources.FeatureType.NVIDIA_MIG: "mig",
+             resources.FeatureType.NVIDIA_NVLINK: "nvlink",
+         },
+     )
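A hypothetical call that builds the Mila cluster config from this helper (the user and partition values below are placeholders, not defaults shipped with the package):

    from xm_slurm.contrib import clusters

    cfg = clusters.mila(user="alice", partition="long")
    print(cfg.ssh.endpoints)  # (Endpoint(hostname='login.server.mila.quebec', port=2222),)
    print(cfg.runtime)        # singularity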