xmanager-slurm 0.4.13__py3-none-any.whl → 0.4.15__py3-none-any.whl
Potentially problematic release: this version of xmanager-slurm might be problematic.
- xm_slurm/__init__.py +2 -1
- xm_slurm/batching.py +11 -11
- xm_slurm/config.py +10 -10
- xm_slurm/contrib/clusters/drac.py +15 -29
- xm_slurm/dependencies.py +7 -7
- xm_slurm/execution.py +10 -0
- xm_slurm/executors.py +82 -12
- xm_slurm/experimental/parameter_controller.py +18 -14
- xm_slurm/job_blocks.py +3 -3
- xm_slurm/packageables.py +23 -23
- xm_slurm/packaging/registry.py +14 -14
- xm_slurm/packaging/router.py +3 -3
- xm_slurm/packaging/utils.py +5 -5
- xm_slurm/resources.py +198 -28
- xm_slurm/status.py +2 -2
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +2 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +1 -1
- xm_slurm/templates/slurm/job-group.bash.j2 +1 -0
- xm_slurm/templates/slurm/job.bash.j2 +8 -1
- xm_slurm/templates/slurm/library/retry.bash +62 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +8 -7
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +4 -3
- xm_slurm/utils.py +8 -0
- {xmanager_slurm-0.4.13.dist-info → xmanager_slurm-0.4.15.dist-info}/METADATA +1 -1
- xmanager_slurm-0.4.15.dist-info/RECORD +52 -0
- xmanager_slurm-0.4.13.dist-info/RECORD +0 -51
- {xmanager_slurm-0.4.13.dist-info → xmanager_slurm-0.4.15.dist-info}/WHEEL +0 -0
- {xmanager_slurm-0.4.13.dist-info → xmanager_slurm-0.4.15.dist-info}/entry_points.txt +0 -0
- {xmanager_slurm-0.4.13.dist-info → xmanager_slurm-0.4.15.dist-info}/licenses/LICENSE.md +0 -0
xm_slurm/__init__.py  CHANGED
@@ -18,7 +18,7 @@ from xm_slurm.packageables import (
     python_container,
     uv_container,
 )
-from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType
+from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType, Topology
 
 logging.getLogger("asyncssh").setLevel(logging.WARN)
 logging.getLogger("httpx").setLevel(logging.WARN)
@@ -42,5 +42,6 @@ __all__ = [
     "Slurm",
     "SlurmExperiment",
     "SlurmSpec",
+    "Topology",
     "uv_container",
 ]
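The only functional change in this file is the new Topology re-export. A minimal import check of the 0.4.15 surface (names taken directly from the __all__ hunk above; Topology's own definition lives in xm_slurm/resources.py, whose diff body is not shown here):

    # All four names are re-exported at the top level as of 0.4.15;
    # Topology is the newly added one.
    from xm_slurm import JobRequirements, ResourceQuantity, ResourceType, Topology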
xm_slurm/batching.py  CHANGED
@@ -4,11 +4,11 @@ import dataclasses
 import inspect
 import time
 import types
-from typing import Any, Callable, Coroutine, Generic, ParamSpec, Sequence, Type, TypeVar
+import typing as tp
 
-T = TypeVar("T", contravariant=True)
-R = TypeVar("R", covariant=True)
-P = ParamSpec("P")
+T = tp.TypeVar("T", contravariant=True)
+R = tp.TypeVar("R", covariant=True)
+P = tp.ParamSpec("P")
 
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -18,10 +18,10 @@ class Request:
 
 
 def stack_bound_arguments(
-    signature: inspect.Signature, bound_arguments: Sequence[inspect.BoundArguments]
+    signature: inspect.Signature, bound_arguments: tp.Sequence[inspect.BoundArguments]
 ) -> inspect.BoundArguments:
     """Stacks bound arguments into a single bound arguments object."""
-    stacked_args = collections.OrderedDict[str, Any]()
+    stacked_args = collections.OrderedDict[str, tp.Any]()
     for bound_args in bound_arguments:
         for name, value in bound_args.arguments.items():
             stacked_args.setdefault(name, [])
@@ -29,7 +29,7 @@ def stack_bound_arguments(
     return inspect.BoundArguments(signature, stacked_args)
 
 
-class batch(Generic[R]):
+class batch(tp.Generic[R]):
     __slots__ = (
         "fn",
         "signature",
@@ -44,7 +44,7 @@ class batch(Generic[R]):
 
     def __init__(
        self,
-        fn: Callable[..., Coroutine[None, None, Sequence[R]]],
+        fn: tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]],
        /,
        *,
        max_batch_size: int,
@@ -102,7 +102,7 @@ class batch(Generic[R]):
 
        return batch
 
-    def __get__(self, obj: Any, objtype: Type[Any]) -> Any:
+    def __get__(self, obj: tp.Any, objtype: tp.Type[tp.Any]) -> tp.Any:
        del objtype
        if isinstance(self.fn, staticmethod):
            return self.__call__
@@ -110,11 +110,11 @@ class batch(Generic[R]):
        return types.MethodType(self, obj)
 
    @property
-    def __func__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
+    def __func__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
        return self.fn
 
    @property
-    def __wrapped__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
+    def __wrapped__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
        return self.fn
 
    @property
xm_slurm/config.py  CHANGED
@@ -5,7 +5,7 @@ import getpass
 import json
 import os
 import pathlib
-from typing import Callable, Literal, Mapping, NamedTuple
+import typing as tp
 
 import asyncssh
 from xmanager import xm
@@ -23,7 +23,7 @@ class ContainerRuntime(enum.Enum):
 
     @classmethod
     def from_string(
-        cls, runtime: Literal["singularity", "apptainer", "docker", "podman"]
+        cls, runtime: tp.Literal["singularity", "apptainer", "docker", "podman"]
     ) -> "ContainerRuntime":
        return {
            "singularity": cls.SINGULARITY,
@@ -45,7 +45,7 @@ class ContainerRuntime(enum.Enum):
        raise NotImplementedError
 
 
-class PublicKey(NamedTuple):
+class PublicKey(tp.NamedTuple):
    algorithm: str
    key: str
 
@@ -172,26 +172,26 @@ class SlurmClusterConfig:
    qos: str | None = None
 
    # If true, a reverse proxy is initiated via the submission host.
-    proxy: Literal["submission-host"] | str | None = None
+    proxy: tp.Literal["submission-host"] | str | None = None
 
    runtime: ContainerRuntime
 
    # Environment variables
-    host_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
-    container_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    host_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
+    container_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
 
    # Mounts
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
        default_factory=dict
    )
 
    # Resource mapping
-    resources: Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
+    resources: tp.Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
 
-    features: Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
+    features: tp.Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
 
    # Function to validate the Slurm executor config
-    validate: Callable[[xm.Job], None] | None = None
+    validate: tp.Callable[[xm.Job], None] | None = None
 
    def __post_init__(self) -> None:
        for src, dst in self.mounts.items():
xm_slurm/contrib/clusters/drac.py  CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Literal, Mapping
+import typing as tp
 
 from xm_slurm import config
 from xm_slurm.resources import FeatureType, ResourceType
@@ -16,10 +16,10 @@ def _drac_cluster(
     user: str | None = None,
     account: str | None = None,
     modules: list[str] | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
-    resources: Mapping[ResourceType, str] | None = None,
-    features: Mapping[FeatureType, str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    resources: tp.Mapping[ResourceType, str] | None = None,
+    features: tp.Mapping[FeatureType, str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cluster."""
     if mounts is None:
@@ -62,8 +62,8 @@ def narval(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Narval Cluster (https://docs.alliancecan.ca/wiki/Narval/en)."""
     modules = []
@@ -94,8 +94,8 @@ def beluga(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Beluga Cluster (https://docs.alliancecan.ca/wiki/B%C3%A9luga/en)."""
     modules = []
@@ -125,8 +125,8 @@ def rorqual(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Beluga Cluster (https://docs.alliancecan.ca/wiki/Rorqual/en)."""
     modules = []
@@ -155,7 +155,7 @@ def cedar(
     *,
     user: str | None = None,
     account: str | None = None,
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cedar Cluster (https://docs.alliancecan.ca/wiki/Cedar/en)."""
     return _drac_cluster(
@@ -180,7 +180,7 @@ def fir(
     *,
     user: str | None = None,
     account: str | None = None,
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Fir Cluster (https://docs.alliancecan.ca/wiki/Fir/en)."""
     return _drac_cluster(
@@ -201,8 +201,8 @@ def graham(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = "submission-host",
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = "submission-host",
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cedar Cluster (https://docs.alliancecan.ca/wiki/Graham/en)."""
     return _drac_cluster(
@@ -223,17 +223,3 @@ def graham(
             ResourceType.A5000: "a5000",
         },
     )
-
-
-def all(
-    user: str | None = None,
-    account: str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
-) -> list[config.SlurmClusterConfig]:
-    """All DRAC clusters."""
-    return [
-        narval(user=user, account=account, mounts=mounts),
-        beluga(user=user, account=account, mounts=mounts),
-        cedar(user=user, account=account, mounts=mounts),
-        graham(user=user, account=account, mounts=mounts),
-    ]
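The final hunk removes the drac.all() helper. A minimal replacement sketch, assuming callers migrate by listing the clusters explicitly with the same keyword arguments the removed helper forwarded ("def-someuser" is a hypothetical account name):

    from xm_slurm.contrib.clusters import drac

    # drac.all() is gone as of 0.4.15; enumerate the clusters yourself.
    # "def-someuser" is a placeholder account.
    clusters = [
        drac.narval(account="def-someuser"),
        drac.beluga(account="def-someuser"),
        drac.cedar(account="def-someuser"),
        drac.graham(account="def-someuser"),
    ]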
xm_slurm/dependencies.py  CHANGED
@@ -1,7 +1,7 @@
 import abc
 import dataclasses
 import datetime as dt
-from typing import Callable, Sequence
+import typing as tp
 
 
 class SlurmDependencyException(Exception): ...
@@ -36,7 +36,7 @@ class SlurmJobDependency(abc.ABC):
        return (self,)
 
    def traverse(
-        self, mapper: Callable[["SlurmJobDependency"], "SlurmJobDependency"]
+        self, mapper: tp.Callable[["SlurmJobDependency"], "SlurmJobDependency"]
    ) -> "SlurmJobDependency":
        if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
            return type(self)(
@@ -80,7 +80,7 @@ class SlurmJobDependencyOR(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfter(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
    time: dt.timedelta | None = None
 
    def __post_init__(self):
@@ -104,7 +104,7 @@ class SlurmJobDependencyAfter(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfterAny(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
 
    def __post_init__(self):
        if len(self.handles) == 0:
@@ -119,7 +119,7 @@ class SlurmJobDependencyAfterAny(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
 
    def __post_init__(self):
        if len(self.handles) == 0:
@@ -134,7 +134,7 @@ class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfterOK(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
 
    def __post_init__(self):
        if len(self.handles) == 0:
@@ -149,7 +149,7 @@ class SlurmJobDependencyAfterOK(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821
 
    def __post_init__(self):
        if len(self.handles) == 0:
xm_slurm/execution.py  CHANGED
@@ -3,6 +3,8 @@ import collections.abc
 import dataclasses
 import functools
 import hashlib
+import importlib
+import importlib.resources
 import logging
 import operator
 import os
@@ -311,6 +313,14 @@ def get_template_env(runtime: ContainerRuntime) -> j2.Environment:
     template_env.globals["raise"] = _raise_template_exception
     template_env.globals["operator"] = operator
 
+    # Iterate over stdlib files and insert them into the template environment
+    stdlib = []
+    for file in importlib.resources.files("xm_slurm.templates.slurm.library").iterdir():
+        if not file.is_file() or not file.name.endswith(".bash"):
+            continue
+        stdlib.append(file.read_text())
+    template_env.globals["stdlib"] = stdlib
+
     entrypoint_template = template_env.get_template("entrypoint.bash.j2")
     template_env.globals.update(entrypoint_template.module.__dict__)
 
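The new block gathers every *.bash file under xm_slurm/templates/slurm/library (the new retry.bash is the only such file in this release) and exposes the sources to the Jinja templates as the stdlib global. A standalone sketch of the same gathering logic, mirroring the hunk above:

    import importlib.resources

    def load_bash_stdlib() -> list[str]:
        """Collect the bash library sources shipped with xm_slurm."""
        stdlib = []
        for file in importlib.resources.files("xm_slurm.templates.slurm.library").iterdir():
            # Skip directories and anything that isn't a .bash source.
            if not file.is_file() or not file.name.endswith(".bash"):
                continue
            stdlib.append(file.read_text())
        return stdlib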
xm_slurm/executors.py  CHANGED
@@ -1,10 +1,18 @@
+import collections.abc
 import dataclasses
 import datetime as dt
 import signal
+import typing as tp
 
 from xmanager import xm
 
-from xm_slurm import resources
+from xm_slurm import resources, utils
+
+ResourceBindType = tp.Literal[
+    resources.ResourceType.GPU,
+    resources.ResourceType.MEMORY,
+    resources.ResourceType.RAM,
+]
 
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -26,10 +34,19 @@ class Slurm(xm.Executor):
     Args:
         requirements: The requirements for the job.
         time: The maximum time to run the job.
+        switches: Maximum count of leaf switches desired for the job allocation.
+        switches_grace_period: Maximum time to wait for that number of switches.
+        bind: How to bind tasks to resources (memory, GPU, or a generic resource).
         account: The account to charge the job to.
        partition: The partition to run the job in.
        qos: The quality of service to run the job with.
        priority: The priority of the job.
+        reservation: Allocate resources for the job from the named reservation.
+        exclusive: Do not share nodes with other running jobs.
+        oversubscribe: Allow over-subscribing resources with other running jobs.
+        overcommit: Allow sharing of allocated resources as if only one task per node was requested.
+        nice: Run the job with an adjusted scheduling priority.
+        kill_on_invalid_dependencies: Whether to kill the job if it has invalid dependencies.
        timeout_signal: The signal to send to the job when it runs out of time.
        timeout_signal_grace_period: The time to wait before sending `timeout_signal`.
        requeue: Whether or not the job is eligible for requeueing.
@@ -41,12 +58,18 @@ class Slurm(xm.Executor):
    # Job requirements
    requirements: resources.JobRequirements
    time: dt.timedelta
+    bind: tp.Mapping[ResourceBindType | str, str | None] | None = None
 
    # Placement
    account: str | None = None
    partition: str | None = None
    qos: str | None = None
    priority: int | None = None
+    reservation: str | tp.Iterable[str] | None = None
+    exclusive: bool = False
+    oversubscribe: bool = False
+    overcommit: bool = False
+    nice: int | None = None
 
    # Job dependency handling
    kill_on_invalid_dependencies: bool = True
@@ -65,13 +88,28 @@ class Slurm(xm.Executor):
        return self.time - self.timeout_signal_grace_period
 
    def __post_init__(self) -> None:
-        if not isinstance(self.time, dt.timedelta):
-            raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
        if not isinstance(self.requirements, resources.JobRequirements):
            raise TypeError(
                f"requirements must be a `xm_slurm.JobRequirements`, got {type(self.requirements)}. "
                "If you're still using `xm.JobRequirements`, please update to `xm_slurm.JobRequirements`."
            )
+        if not isinstance(self.time, dt.timedelta):
+            raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
+        if self.bind is not None:
+            if not isinstance(self.bind, collections.abc.Mapping):
+                raise TypeError(f"bind must be a mapping, got {type(self.bind)}")
+            for resource, value in self.bind.items():
+                if resource not in (
+                    resources.ResourceType.GPU,
+                    resources.ResourceType.MEMORY,
+                    resources.ResourceType.RAM,
+                ) and not isinstance(resource, str):
+                    raise TypeError(
+                        f"bind resource must be a {resources.ResourceType.GPU.name}, {resources.ResourceType.MEMORY.name}, or {resources.ResourceType.RAM.name}, got {type(resource)}"
+                    )
+                if value is not None and not isinstance(value, str):
+                    raise TypeError(f"bind value must be None or a string, got {type(value)}")
+
        if not isinstance(self.timeout_signal, signal.Signals):
            raise TypeError(
                f"termination_signal must be a `signal.Signals`, got {type(self.timeout_signal)}"
@@ -86,6 +124,10 @@ class Slurm(xm.Executor):
            )
        if self.requeue_on_exit_code == 0:
            raise ValueError("requeue_on_exit_code should not be 0 to avoid unexpected behavior.")
+        if self.exclusive and self.oversubscribe:
+            raise ValueError("exclusive and oversubscribe are mutually exclusive.")
+        if self.nice is not None and not (-2147483645 <= self.nice <= 2147483645):
+            raise ValueError(f"nice must be between -2147483645 and 2147483645, got {self.nice}")
 
    @classmethod
    def Spec(cls, tag: str | None = None) -> SlurmSpec:
@@ -96,10 +138,22 @@ class Slurm(xm.Executor):
        directives = self.requirements.to_directives()
 
        # Time
-
-
-
-
+        directives.append(f"--time={utils.timestr_from_timedelta(self.time)}")
+
+        # Resource binding
+        if self.bind is not None:
+            for resource, value in self.bind.items():
+                if value is None:
+                    value = "none"
+                match resource:
+                    case resources.ResourceType.MEMORY | resources.ResourceType.RAM:
+                        directives.append(f"--mem-bind={value}")
+                    case resources.ResourceType.GPU:
+                        directives.append(f"--gpu-bind={value}")
+                    case str():
+                        directives.append(f"--tres-bind=gres/{resource}:{value}")
+                    case _:
+                        raise ValueError(f"Unsupported resource type {resource!r} for binding.")
 
        # Job dependency handling
        directives.append(
@@ -107,20 +161,36 @@ class Slurm(xm.Executor):
        )
 
        # Placement
-        if self.account:
+        if self.account is not None:
            directives.append(f"--account={self.account}")
-        if self.partition:
+        if self.partition is not None:
            directives.append(f"--partition={self.partition}")
-        if self.qos:
+        if self.qos is not None:
            directives.append(f"--qos={self.qos}")
-        if self.priority:
+        if self.priority is not None:
            directives.append(f"--priority={self.priority}")
+        if self.reservation is not None:
+            match self.reservation:
+                case str():
+                    directives.append(f"--reservation={self.reservation}")
+                case collections.abc.Iterable():
+                    directives.append(f"--reservation={','.join(self.reservation)}")
+                case _:
+                    raise ValueError(f"Invalid reservation type: {type(self.reservation)}")
+        if self.exclusive:
+            directives.append("--exclusive")
+        if self.oversubscribe:
+            directives.append("--oversubscribe")
+        if self.overcommit:
+            directives.append("--overcommit")
+        if self.nice is not None:
+            directives.append(f"--nice={self.nice}")
 
        # Job rescheduling
        directives.append(
            f"--signal={self.timeout_signal.name.removeprefix('SIG')}@{self.timeout_signal_grace_period.seconds}"
        )
-        if self.requeue and self.requeue_max_attempts > 0:
+        if self.requeue is not None and self.requeue_max_attempts > 0:
            directives.append("--requeue")
        else:
            directives.append("--no-requeue")
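Taken together, the to_directives hunks map the new fields onto sbatch flags: bind entries become --mem-bind, --gpu-bind, or --tres-bind=gres/...; reservation accepts a single name or an iterable of names; and the booleans emit --exclusive, --oversubscribe, and --overcommit. A hedged construction sketch (the empty JobRequirements is a placeholder, since its constructor arguments live in xm_slurm/resources.py, whose diff body is not shown here):

    import datetime as dt

    from xm_slurm import JobRequirements, ResourceType, Slurm

    executor = Slurm(
        requirements=JobRequirements(),      # placeholder; resource kwargs omitted
        time=dt.timedelta(hours=4),
        bind={ResourceType.GPU: "closest"},  # emitted as --gpu-bind=closest
        reservation=("maint", "debug"),      # emitted as --reservation=maint,debug
        exclusive=True,                      # emitted as --exclusive
        nice=100,                            # emitted as --nice=100
    )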
xm_slurm/experimental/parameter_controller.py  CHANGED
@@ -4,8 +4,8 @@ import dataclasses
 import enum
 import functools
 import logging
+import typing as tp
 import zlib
-from typing import Awaitable, Callable, Concatenate, Coroutine, Mapping, ParamSpec, TypeVar
 
 import backoff
 import cloudpickle
@@ -15,15 +15,15 @@ import xm_slurm
 from xm_slurm import job_blocks, status
 from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
 
-P = ParamSpec("P")
-T = TypeVar("T")
+P = tp.ParamSpec("P")
+T = tp.TypeVar("T")
 
 logger = logging.getLogger(__name__)
 
 
 async def _monitor_parameter_controller(
     aux_unit: SlurmAuxiliaryUnit,
-    local_parameter_controller_coro: Coroutine[None, None, T],
+    local_parameter_controller_coro: tp.Coroutine[None, None, T],
     *,
     poll_interval: float = 30.0,
 ) -> None:
@@ -104,13 +104,13 @@ def parameter_controller(
     controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
     controller_name: str = "parameter_controller",
     controller_args: xm.UserArgs | None = None,
-    controller_env_vars: Mapping[str, str] | None = None,
-) -> Callable[
+    controller_env_vars: tp.Mapping[str, str] | None = None,
+) -> tp.Callable[
     [
-        Callable[Concatenate[SlurmExperiment, P], T]
-        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
+        tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
+        | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
     ],
-    Callable[P, xm.AuxiliaryUnitJob],
+    tp.Callable[P, xm.AuxiliaryUnitJob],
 ]:
     """Converts a function to a controller which can be added to an experiment.
 
@@ -131,9 +131,9 @@ def parameter_controller(
     """
 
     def decorator(
-        f: Callable[Concatenate[SlurmExperiment, P], T]
-        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
-    ) -> Callable[P, xm.AuxiliaryUnitJob]:
+        f: tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
+        | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
+    ) -> tp.Callable[P, xm.AuxiliaryUnitJob]:
        @functools.wraps(f)
        def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
            # Modify the function to read the experiment from the API so that it can be pickled.
@@ -141,13 +141,17 @@ def parameter_controller(
            async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
                experiment_id = aux_unit.experiment.experiment_id
 
-                async def local_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
+                async def local_controller(
+                    *args: P.args, **kwargs: P.kwargs
+                ) -> T | tp.Awaitable[T]:
                    if asyncio.iscoroutinefunction(f):
                        return await f(aux_unit.experiment, *args, **kwargs)
                    else:
                        return f(aux_unit.experiment, *args, **kwargs)
 
-                async def remote_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
+                async def remote_controller(
+                    *args: P.args, **kwargs: P.kwargs
+                ) -> T | tp.Awaitable[T]:
                    async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
                        if asyncio.iscoroutinefunction(f):
                            return await f(exp, *args, **kwargs)
xm_slurm/job_blocks.py  CHANGED
@@ -1,11 +1,11 @@
-from typing import Mapping, TypedDict
+import typing as tp
 
 from xmanager import xm
 
 
-class JobArgs(TypedDict, total=False):
+class JobArgs(tp.TypedDict, total=False):
     args: xm.UserArgs
-    env_vars: Mapping[str, str]
+    env_vars: tp.Mapping[str, str]
 
 
 def get_args_for_python_entrypoint(
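JobArgs keeps its two optional keys; only the annotations moved under the tp. namespace. A minimal construction sketch (the values are placeholders, assuming a plain dict is an acceptable xm.UserArgs):

    from xm_slurm.job_blocks import JobArgs

    # Both keys are optional (total=False).
    job_args: JobArgs = {
        "args": {"seed": 1},
        "env_vars": {"WANDB_MODE": "offline"},
    }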
|