xmanager-slurm 0.4.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +47 -0
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +189 -0
- xm_slurm/console.py +3 -0
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +67 -0
- xm_slurm/contrib/clusters/drac.py +242 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +215 -0
- xm_slurm/execution.py +995 -0
- xm_slurm/executors.py +210 -0
- xm_slurm/experiment.py +1016 -0
- xm_slurm/experimental/parameter_controller.py +206 -0
- xm_slurm/filesystems.py +129 -0
- xm_slurm/job_blocks.py +21 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +309 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker.py +348 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +56 -0
- xm_slurm/packaging/utils.py +22 -0
- xm_slurm/resources.py +350 -0
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +90 -0
- xm_slurm/status.py +197 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
- xm_slurm/templates/docker/mamba.Dockerfile +29 -0
- xm_slurm/templates/docker/python.Dockerfile +32 -0
- xm_slurm/templates/docker/uv.Dockerfile +38 -0
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
- xm_slurm/templates/slurm/job.bash.j2 +90 -0
- xm_slurm/templates/slurm/library/retry.bash +62 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +196 -0
- xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
- xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
- xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
- xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import typing as tp
|
|
3
|
+
|
|
4
|
+
from xm_slurm import config
|
|
5
|
+
from xm_slurm.resources import FeatureType, ResourceType
|
|
6
|
+
|
|
7
|
+
# Public cluster factories exported by this module.
__all__ = ["fir", "narval", "nibi", "rorqual", "killarney", "tamia", "vulcan"]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _drac_cluster(
    *,
    name: str,
    host: str,
    port: int = 22,
    robot_host: str | None = None,
    robot_port: int | None = None,
    public_key: config.PublicKey,
    user: str | None = None,
    account: str | None = None,
    modules: list[str] | None = None,
    proxy: tp.Literal["submission-host"] | str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
    resources: tp.Mapping[ResourceType, str] | None = None,
    features: tp.Mapping[FeatureType, str] | None = None,
) -> config.SlurmClusterConfig:
    """Build the `SlurmClusterConfig` shared by all DRAC clusters.

    Every cluster built here uses the Apptainer runtime, loads the
    ``apptainer`` module (followed by any extra ``modules``) in the prolog,
    and always bind-mounts ``/dev/infiniband`` into the container.

    Args:
        name: Cluster name.
        host: Login host used for SSH.
        port: SSH port for ``host``.
        robot_host: Optional additional host. When set and different from
            ``host``, it is listed as the first SSH endpoint.
        robot_port: Port for ``robot_host``, passed through unchanged (may be
            ``None``).
        public_key: Expected SSH host key.
        user: SSH user name.
        account: Slurm account.
        modules: Extra environment modules loaded after ``apptainer``.
        proxy: Proxy setting forwarded to the cluster config.
        mounts: Host -> container mounts. When ``None``, defaults to the
            user's scratch, ``~/.ssh`` and the xm-slurm state directory.
        resources: Resource-name mapping; an empty dict when not given.
        features: Feature-name mapping; an empty dict when not given.

    Returns:
        The assembled cluster configuration.
    """
    default_mounts = {
        "/scratch/$USER": "/scratch",
        # TODO: move these somewhere common to all cluster configs.
        "/home/$USER/.ssh": "/home/$USER/.ssh",
        "/home/$USER/.local/state/xm-slurm": "/xm-slurm-state",
    }
    # The InfiniBand device is always exposed, even when the caller passes
    # custom mounts.
    all_mounts = {
        **(default_mounts if mounts is None else mounts),
        "/dev/infiniband": "/dev/infiniband",
    }

    # The robot host (if any, and if distinct) is tried before the login host.
    robot_endpoints = (
        (config.Endpoint(robot_host, robot_port),)
        if robot_host is not None and robot_host != host
        else ()
    )
    ssh_endpoints = (*robot_endpoints, config.Endpoint(host, port))

    prolog = " ".join(["module load apptainer", *(modules or [])]).rstrip()

    return config.SlurmClusterConfig(
        name=name,
        ssh=config.SSHConfig(user=user, endpoints=ssh_endpoints, public_key=public_key),
        account=account,
        proxy=proxy,
        runtime=config.ContainerRuntime.APPTAINER,
        prolog=prolog,
        host_environment={
            "XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
            "APPTAINER_CACHEDIR": "$SCRATCH/.apptainer",
            "APPTAINER_TMPDIR": "$SLURM_TMPDIR",
            "APPTAINER_LOCALCACHEDIR": "$SLURM_TMPDIR",
        },
        container_environment={
            "SCRATCH": "/scratch",
            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
        },
        mounts=all_mounts,
        resources=resources or {},
        features=features or {},
    )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def narval(
    *,
    user: str | None = None,
    account: str | None = None,
    proxy: tp.Literal["submission-host"] | str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Return the config for the DRAC Narval cluster (https://docs.alliancecan.ca/wiki/Narval/en).

    The ``httpproxy`` module is loaded unless ``proxy == "submission-host"``.
    Exposes A100 GPUs plus the MIG and NVLink features.
    """
    extra_modules = [] if proxy == "submission-host" else ["httpproxy"]
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAILFxB0spH5RApc43sBx0zOxo1ARVH0ezU+FbQH95FW+h",
    )
    return _drac_cluster(
        name="narval",
        host="narval.alliancecan.ca",
        robot_host="robot.narval.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        proxy=proxy,
        modules=extra_modules,
        resources={ResourceType.A100: "a100"},
        features={
            FeatureType.NVIDIA_MIG: "a100mig",
            FeatureType.NVIDIA_NVLINK: "nvlink",
        },
    )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def rorqual(
    *,
    user: str | None = None,
    account: str | None = None,
    proxy: tp.Literal["submission-host"] | str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Return the config for the DRAC Rorqual cluster (https://docs.alliancecan.ca/wiki/Rorqual/en).

    The ``httpproxy`` module is loaded unless ``proxy == "submission-host"``.
    Exposes H100 GPUs plus the NVLink feature.
    """
    extra_modules = [] if proxy == "submission-host" else ["httpproxy"]
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAINME5e9bifKZbuKKOQSpe3xrvC4g1b0QLMYj+AXBQGJe",
    )
    return _drac_cluster(
        name="rorqual",
        host="rorqual.alliancecan.ca",
        robot_host="robot.rorqual.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        proxy=proxy,
        modules=extra_modules,
        resources={ResourceType.H100: "h100"},
        features={FeatureType.NVIDIA_NVLINK: "nvlink"},
    )
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def fir(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Return the config for the DRAC Fir cluster (https://docs.alliancecan.ca/wiki/Fir/en).

    Uses the robot host as the first SSH endpoint and exposes H100 GPUs.
    No extra modules or proxy are configured.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIJtenyJz+inwobvlJntWYFNu+ANcVWNcOHRKcEN6zmDo",
    )
    return _drac_cluster(
        name="fir",
        host="fir.alliancecan.ca",
        robot_host="robot.fir.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.H100: "h100"},
    )
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def nibi(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Return the config for the DRAC Nibi cluster (https://docs.alliancecan.ca/wiki/Nibi/en).

    Uses the robot host as the first SSH endpoint and exposes H100 GPUs.
    No extra modules or proxy are configured.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIEcmFoQZr6+KUHm/zm/BJpnNIlME7GytMxbHgfAUfoQX",
    )
    return _drac_cluster(
        name="nibi",
        host="nibi.alliancecan.ca",
        robot_host="robot.nibi.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.H100: "h100"},
    )
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def killarney(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Return the config for the DRAC (PAICE) Killarney cluster (https://docs.alliancecan.ca/wiki/Killarney/en).

    Only the login host is used for SSH (no robot host). Exposes L40S and
    H100 GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIGlzaBBtvhJsSr23rMoY41gy8Svv1IOct8TBRH9CGuJf",
    )
    gpu_resources = {ResourceType.L40S: "l40s", ResourceType.H100: "h100"}
    return _drac_cluster(
        name="killarney",
        host="killarney.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources=gpu_resources,
    )
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def tamia(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Return the config for the DRAC (PAICE) Tamia cluster (https://docs.alliancecan.ca/wiki/Tamia/en).

    Only the login host is used for SSH (no robot host). Exposes H100 GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIN2wL9wOa0VveA/2l2ky/OhPsQfYtKuX99dyNnUTSYeU",
    )
    return _drac_cluster(
        name="tamia",
        host="tamia.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.H100: "h100"},
    )
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def vulcan(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """Return the config for the DRAC (PAICE) Vulcan cluster (https://docs.alliancecan.ca/wiki/Vulcan/en).

    Only the login host is used for SSH (no robot host). Exposes L40S GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIMuIj6T45HqVeJgRotH9Qq46FzidekS2lXkD7FOTltnC",
    )
    return _drac_cluster(
        name="vulcan",
        host="vulcan.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.L40S: "l40s"},
    )
|
xm_slurm/dependencies.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import dataclasses
|
|
3
|
+
import datetime as dt
|
|
4
|
+
import typing as tp
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SlurmDependencyException(Exception):
    """Raised when a Slurm job dependency is malformed or cannot be expressed."""


# Kept as a module-level *instance* for backward compatibility; catch
# `SlurmDependencyException` to handle it. Methods below raise fresh copies
# (see `_no_chaining_error`) rather than this shared object.
NoChainingException = SlurmDependencyException(
    "Slurm only supports chaining dependencies with the same logical operator. "
    "For example, `dep1 & dep2 | dep3` is not supported but `dep1 & dep2 & dep3` is."
)


def _no_chaining_error() -> SlurmDependencyException:
    """Return a fresh exception carrying `NoChainingException`'s message.

    Re-raising one shared exception instance makes its ``__traceback__`` grow
    on every raise (keeping old frames alive), so each raise gets its own copy.
    """
    return SlurmDependencyException(*NoChainingException.args)


class SlurmJobDependency(abc.ABC):
    """Base class for a Slurm job-dependency expression.

    Leaf subclasses render a single clause via `to_dependency_str`. Leaves can
    be combined with ``&`` (rendered joined by ``,``) or ``|`` (rendered joined
    by ``?``). Mixing the two operators in one expression raises
    `SlurmDependencyException`.
    """

    @abc.abstractmethod
    def to_dependency_str(self) -> str:
        """Render this dependency as the value of ``--dependency``."""

    def to_directive(self) -> str:
        """Render this dependency as a full ``--dependency=...`` option."""
        return f"--dependency={self.to_dependency_str()}"

    def __and__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyAND":
        """Combine with ``&``.

        Raises:
            SlurmDependencyException: if either operand is an OR-expression.
        """
        # Check both sides: `(a | b) & c` and `a & (b | c)` would each render
        # ',' and '?' in the same string.
        if isinstance(self, SlurmJobDependencyOR) or isinstance(
            other_dependency, SlurmJobDependencyOR
        ):
            raise _no_chaining_error()
        return SlurmJobDependencyAND(self, other_dependency)

    def __or__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyOR":
        """Combine with ``|``.

        Raises:
            SlurmDependencyException: if either operand is an AND-expression.
        """
        if isinstance(self, SlurmJobDependencyAND) or isinstance(
            other_dependency, SlurmJobDependencyAND
        ):
            raise _no_chaining_error()
        return SlurmJobDependencyOR(self, other_dependency)

    def flatten(self) -> tuple["SlurmJobDependency", ...]:
        """Return the leaf dependencies of this expression, left to right."""
        if isinstance(self, (SlurmJobDependencyAND, SlurmJobDependencyOR)):
            return self.first_dependency.flatten() + self.second_dependency.flatten()
        return (self,)

    def traverse(
        self, mapper: tp.Callable[["SlurmJobDependency"], "SlurmJobDependency"]
    ) -> "SlurmJobDependency":
        """Rebuild the expression with ``mapper`` applied to every leaf.

        AND/OR nodes are reconstructed with the same type; only leaves are
        passed to ``mapper``.
        """
        if isinstance(self, (SlurmJobDependencyAND, SlurmJobDependencyOR)):
            return type(self)(
                first_dependency=self.first_dependency.traverse(mapper),
                second_dependency=self.second_dependency.traverse(mapper),
            )
        return mapper(self)


@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAND(SlurmJobDependency):
    """Conjunction of two dependencies, rendered as ``first,second``."""

    first_dependency: SlurmJobDependency
    second_dependency: SlurmJobDependency

    def to_dependency_str(self) -> str:
        return f"{self.first_dependency.to_dependency_str()},{self.second_dependency.to_dependency_str()}"

    def __or__(self, other_dependency: SlurmJobDependency):
        # An AND-expression can never be OR-ed.
        del other_dependency
        raise _no_chaining_error()

    def __hash__(self) -> int:
        return hash((type(self), self.first_dependency, self.second_dependency))


@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyOR(SlurmJobDependency):
    """Disjunction of two dependencies, rendered as ``first?second``."""

    first_dependency: SlurmJobDependency
    second_dependency: SlurmJobDependency

    def to_dependency_str(self) -> str:
        return f"{self.first_dependency.to_dependency_str()}?{self.second_dependency.to_dependency_str()}"

    def __and__(self, other_dependency: SlurmJobDependency):
        # An OR-expression can never be AND-ed.
        del other_dependency
        raise _no_chaining_error()

    def __hash__(self) -> int:
        return hash((type(self), self.first_dependency, self.second_dependency))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfter(SlurmJobDependency):
    """Slurm ``after`` dependency on one or more jobs.

    Renders as ``after:<id>[+<min>][:<id>[+<min>]...]``, using each handle's
    ``slurm_job.job_id``. When ``time`` is set, the same whole-minute delay is
    appended to every job id.

    Raises:
        SlurmDependencyException: if ``handles`` is empty, or ``time`` is not a
            whole number of minutes, or ``time`` is negative.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
    # Optional delay rendered as ``+<minutes>`` after each job id.
    time: dt.timedelta | None = None

    def __post_init__(self):
        if len(self.handles) == 0:
            raise SlurmDependencyException("Dependency doesn't have any handles.")
        if self.time is not None:
            if self.time.total_seconds() % 60 != 0:
                raise SlurmDependencyException("Time must be specified in exact minutes")
            # A negative whole-minute delta passes the modulo check above but
            # would render as "+-N", which is not a valid delay.
            if self.time < dt.timedelta(0):
                raise SlurmDependencyException("Time must not be negative")

    def to_dependency_str(self) -> str:
        directive = "after"

        for handle in self.handles:
            directive += f":{handle.slurm_job.job_id}"
            if self.time is not None:
                directive += f"+{self.time.total_seconds() // 60:.0f}"
        return directive

    def __hash__(self) -> int:
        # `time` is omitted; equal objects still hash equally.
        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfterAny(SlurmJobDependency):
    """Slurm ``afterany`` dependency on one or more jobs.

    Renders as ``afterany:<id>[:<id>...]`` from each handle's
    ``slurm_job.job_id``. Raises `SlurmDependencyException` when constructed
    with no handles.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        job_ids = [handle.slurm_job.job_id for handle in self.handles]
        return ":".join(("afterany", *job_ids))

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
    """Slurm ``afternotok`` dependency on one or more jobs.

    Renders as ``afternotok:<id>[:<id>...]`` from each handle's
    ``slurm_job.job_id``. Raises `SlurmDependencyException` when constructed
    with no handles.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        job_ids = [handle.slurm_job.job_id for handle in self.handles]
        return ":".join(("afternotok", *job_ids))

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfterOK(SlurmJobDependency):
    """Slurm ``afterok`` dependency on one or more jobs.

    Renders as ``afterok:<id>[:<id>...]`` from each handle's
    ``slurm_job.job_id``. Raises `SlurmDependencyException` when constructed
    with no handles.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        job_ids = [handle.slurm_job.job_id for handle in self.handles]
        return ":".join(("afterok", *job_ids))

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@dataclasses.dataclass(frozen=True)
class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
    """Per-task dependency on job arrays.

    Despite the class name, this renders Slurm's ``aftercorr`` dependency type:
    ``aftercorr:<id>[:<id>...]``. The id for each handle is chosen as follows:
    the array job id for array jobs, otherwise the heterogeneous job id for
    heterogeneous jobs, otherwise the plain job id. Raises
    `SlurmDependencyException` when constructed with no handles.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        def _dependency_id(job):
            # Array/heterogeneous jobs are referenced by their parent id.
            if job.is_array_job:
                return job.array_job_id
            if job.is_heterogeneous_job:
                return job.het_job_id
            return job.job_id

        ids = (_dependency_id(handle.slurm_job) for handle in self.handles)
        return ":".join(["aftercorr", *ids])

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
|
xm_slurm/executables.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import os
|
|
3
|
+
import pathlib
|
|
4
|
+
import typing as tp
|
|
5
|
+
|
|
6
|
+
from xmanager import xm
|
|
7
|
+
|
|
8
|
+
from xm_slurm import constants
|
|
9
|
+
from xm_slurm.types import Descriptor
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class Dockerfile(xm.ExecutableSpec):
    """Specification of a Dockerfile to build.

    Args:
        dockerfile: Path to the Dockerfile.
        context: Path to the Docker build context.
        target: Optional Docker build target (multi-stage stage name).
        ssh: Docker SSH sockets/keys forwarded to the build.
        build_args: Build arguments passed to docker.
        cache_from: Where to pull the BuildKit cache from (``--cache-from``).
        labels: Container labels.
        platforms: Target platforms; defaults to ``["linux/amd64"]``.
    """

    # Dockerfile
    dockerfile: pathlib.Path
    # Docker context
    context: pathlib.Path

    # Docker build target
    target: str | None = None

    # SSH sockets/keys for the docker build step.
    ssh: tp.Sequence[str] = dataclasses.field(default_factory=list)

    # Build arguments to docker
    build_args: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

    # --cache-from field in BuildKit
    cache_from: tp.Sequence[str] = dataclasses.field(default_factory=list)

    # Container labels
    labels: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

    # Target platform
    platforms: tp.Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])

    @property
    def name(self) -> str:
        """The Dockerfile stem, suffixed with ``-<target>`` when a target is set."""
        stem = self.dockerfile.stem
        return stem if self.target is None else f"{stem}-{self.target}"

    def __hash__(self) -> int:
        # Collections are sorted so the hash is independent of their order.
        def _canonical(items):
            return tuple(sorted(items))

        return hash(
            (
                type(self),
                self.dockerfile,
                self.context,
                self.target,
                _canonical(self.ssh),
                _canonical(self.build_args.items()),
                _canonical(self.cache_from),
                _canonical(self.labels.items()),
                _canonical(self.platforms),
            )
        )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class DockerImage(xm.ExecutableSpec):
    """Specification of a pre-built Docker image.

    Args:
        image: The remote image URI.
        workdir: Optional working directory inside the container.
    """

    image: str

    # Working directory in container
    workdir: pathlib.Path | None = None

    @property
    def name(self) -> str:
        """The image URI, used as this spec's name."""
        return self.image

    def __hash__(self) -> int:
        identity = (type(self), self.image, self.workdir)
        return hash(identity)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclasses.dataclass
|
|
95
|
+
class ImageURI:
|
|
96
|
+
image: dataclasses.InitVar[str]
|
|
97
|
+
|
|
98
|
+
scheme: str | None = dataclasses.field(init=False, default=None)
|
|
99
|
+
domain: str = dataclasses.field(init=False)
|
|
100
|
+
path: str = dataclasses.field(init=False)
|
|
101
|
+
tag: str | None = dataclasses.field(init=False, default=None)
|
|
102
|
+
digest: str | None = dataclasses.field(init=False, default=None)
|
|
103
|
+
|
|
104
|
+
def __post_init__(self, image: str):
|
|
105
|
+
match = constants.IMAGE_URI_REGEX.match(image)
|
|
106
|
+
if not match:
|
|
107
|
+
raise ValueError(f"Invalid OCI image URI: {image}")
|
|
108
|
+
groups = {k: v for k, v in match.groupdict().items() if v is not None}
|
|
109
|
+
for k, v in groups.items():
|
|
110
|
+
setattr(self, k, v)
|
|
111
|
+
|
|
112
|
+
if self.tag is None and self.digest is None:
|
|
113
|
+
self.tag = "latest"
|
|
114
|
+
|
|
115
|
+
@property
|
|
116
|
+
def locator(self) -> str:
|
|
117
|
+
"""Unique locator for this image.
|
|
118
|
+
|
|
119
|
+
Locator will return the digest if it exists otherwise the tag format.
|
|
120
|
+
If neither are present, it will raise an AssertionError.
|
|
121
|
+
"""
|
|
122
|
+
if self.digest is not None:
|
|
123
|
+
return f"@{self.digest}"
|
|
124
|
+
assert self.tag is not None
|
|
125
|
+
return f":{self.tag}"
|
|
126
|
+
|
|
127
|
+
@property
|
|
128
|
+
def url(self) -> str:
|
|
129
|
+
"""URL for this image without the locator."""
|
|
130
|
+
return f"{self.origin}{self.path}"
|
|
131
|
+
|
|
132
|
+
@property
|
|
133
|
+
def origin(self) -> str:
|
|
134
|
+
return f"{self.scheme}{self.domain}"
|
|
135
|
+
|
|
136
|
+
def with_tag(self, tag: str) -> "ImageURI":
|
|
137
|
+
self.tag = tag
|
|
138
|
+
return self
|
|
139
|
+
|
|
140
|
+
def with_digest(self, digest: str) -> "ImageURI":
|
|
141
|
+
self.digest = digest
|
|
142
|
+
return self
|
|
143
|
+
|
|
144
|
+
def __str__(self) -> str:
|
|
145
|
+
return self.format("{url}{locator}")
|
|
146
|
+
|
|
147
|
+
def __hash__(self) -> int:
|
|
148
|
+
return hash((
|
|
149
|
+
type(self),
|
|
150
|
+
self.scheme,
|
|
151
|
+
self.domain,
|
|
152
|
+
self.path,
|
|
153
|
+
self.tag,
|
|
154
|
+
self.digest,
|
|
155
|
+
))
|
|
156
|
+
|
|
157
|
+
def format(self, format: str) -> str:
|
|
158
|
+
fields = {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
|
|
159
|
+
fields |= {"locator": self.locator, "url": self.url}
|
|
160
|
+
return format.format(**fields)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class ImageDescriptor(Descriptor[ImageURI, str | ImageURI]):
    """Data descriptor that stores an `ImageURI`, parsing strings on assignment.

    The value is kept on the instance under ``_<attribute name>``. The first
    assignment goes through ``object.__setattr__`` so it also works during a
    frozen dataclass's ``__init__``. Later assignments use plain ``setattr``,
    so a frozen instance's ``__setattr__`` applies to them.
    """

    def __set_name__(self, owner: type, name: str):
        del owner
        # Name of the backing attribute on the instance.
        self.image = "_" + name

    def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> ImageURI:
        del owner
        # NOTE(review): class-level access (instance is None) raises
        # AttributeError here; dataclasses treat that as "no default", which
        # appears to be what keeps descriptor-typed fields required.
        return getattr(instance, self.image)

    def __set__(self, instance: object, value: str | ImageURI):
        already_initialised = hasattr(instance, self.image)
        uri = ImageURI(value) if isinstance(value, str) else value
        if already_initialised:
            setattr(instance, self.image, uri)
        else:
            object.__setattr__(instance, self.image, uri)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class RemoteRepositoryCredentials(tp.NamedTuple):
    """Username/password pair for authenticating to a remote repository."""

    username: str
    password: str

    def __repr__(self) -> str:
        # Mask the password so it does not leak into logs, or into the repr of
        # objects that embed these credentials (e.g. a dataclass's generated
        # repr). Attribute access and tuple behaviour are unchanged.
        return f"{type(self).__name__}(username={self.username!r}, password='********')"
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)  # type: ignore
class RemoteImage(xm.Executable):
    """A remote container image plus the command and environment to run it.

    ``image`` is managed by `ImageDescriptor`: it accepts a string or an
    `ImageURI`, and strings are parsed into an `ImageURI`.
    """

    # Remote base image
    image: Descriptor[ImageURI, str | ImageURI] = ImageDescriptor()

    workdir: os.PathLike[str] | str
    entrypoint: xm.SequentialArgs

    # Container arguments
    args: xm.SequentialArgs = dataclasses.field(default_factory=xm.SequentialArgs)
    # Container environment variables
    env_vars: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

    # Remote repository credentials
    credentials: RemoteRepositoryCredentials | None = None

    @property
    def name(self) -> str:  # type: ignore
        """String form of the image URI."""
        return f"{self.image}"

    def __hash__(self) -> int:
        # Argument lists and env vars are sorted, so the hash ignores their order.
        entrypoint_key = tuple(sorted(self.entrypoint.to_list()))
        args_key = tuple(sorted(self.args.to_list()))
        env_key = tuple(sorted(self.env_vars.items()))
        return hash(
            (
                type(self),
                self.image,
                self.workdir,
                entrypoint_key,
                args_key,
                env_key,
                self.credentials,
            )
        )
|