xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xmanager-slurm has been flagged as potentially problematic; see the registry's advisory page for details.
- xm_slurm/__init__.py +6 -2
- xm_slurm/api.py +301 -34
- xm_slurm/batching.py +4 -4
- xm_slurm/config.py +105 -55
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +47 -13
- xm_slurm/contrib/clusters/drac.py +34 -16
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +34 -22
- xm_slurm/execution.py +305 -107
- xm_slurm/executors.py +8 -12
- xm_slurm/experiment.py +601 -168
- xm_slurm/experimental/parameter_controller.py +202 -0
- xm_slurm/job_blocks.py +7 -0
- xm_slurm/packageables.py +42 -20
- xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
- xm_slurm/packaging/router.py +3 -1
- xm_slurm/packaging/utils.py +9 -81
- xm_slurm/resources.py +28 -4
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +52 -0
- xm_slurm/status.py +9 -0
- xm_slurm/templates/docker/mamba.Dockerfile +4 -2
- xm_slurm/templates/docker/python.Dockerfile +18 -10
- xm_slurm/templates/docker/uv.Dockerfile +35 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
- xm_slurm/templates/slurm/job.bash.j2 +4 -3
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +18 -10
- xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
- xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
- {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
- xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
- xm_slurm/packaging/docker/__init__.py +0 -75
- xm_slurm/packaging/docker/abc.py +0 -112
- xm_slurm/packaging/docker/cloud.py +0 -503
- xm_slurm/templates/docker/pdm.Dockerfile +0 -31
- xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
- xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
xm_slurm/constants.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import re

# Parses Slurm job identifiers as reported by scontrol/sacct/squeue:
# a plain job id ("123"), a heterogeneous-job component ("123+0"),
# or a job-array task ("123_4").
SLURM_JOB_ID_REGEX = re.compile(
    r"^(?P<jobid>\d+)(?:(?:\+(?P<componentid>\d+))|(?:_(?P<arraytaskid>\d+)))?$"
)

# Splits an OCI/Docker image reference into optional scheme, registry
# domain, repository path, tag, and digest components.
IMAGE_URI_REGEX = re.compile(
    r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
)

# Dotted hostname: labels of 1-63 characters that may not start or end
# with a hyphen, ending in an alphabetic TLD of at least two characters
# (so single-label names such as "localhost" do not match).
# NOTE(review): the `(?!.*--)` lookahead also rejects names containing
# consecutive hyphens anywhere, including punycode/IDN labels such as
# `xn--...` -- confirm that is intended.
DOMAIN_NAME_REGEX = re.compile(
    r"^(?!-)(?!.*--)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*(\.[A-Za-z]{2,})$"
)
# Dotted-quad IPv4 address with each octet restricted to 0-255.
IPV4_REGEX = re.compile(
    r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
# IPv6 address, covering both the full eight-group form and the
# `::`-compressed forms.
IPV6_REGEX = re.compile(
    r"^(([0-9a-fA-F]{1,4}:){7}([0-9a-fA-F]{1,4}|:)|(([0-9a-fA-F]{1,4}:){1,7}|:):(([0-9a-fA-F]{1,4}:){1,6}|:)|(([0-9a-fA-F]{1,4}:){1,6}|:):(([0-9a-fA-F]{1,4}:){1,5}|:)|(([0-9a-fA-F]{1,4}:){1,5}|:):(([0-9a-fA-F]{1,4}:){1,4}|:)|(([0-9a-fA-F]{1,4}:){1,4}|:):(([0-9a-fA-F]{1,4}:){1,3}|:)|(([0-9a-fA-F]{1,4}:){1,3}|:):(([0-9a-fA-F]{1,4}:){1,2}|:)|(([0-9a-fA-F]{1,4}:){1,2}|:):([0-9a-fA-F]{1,4}|:)|([0-9a-fA-F]{1,4}|:):([0-9a-fA-F]{1,4}|:)(:([0-9a-fA-F]{1,4}|:)){1,6})$"
)
|
|
File without changes
|
|
@@ -1,13 +1,20 @@
|
|
|
1
|
+
import datetime as dt
|
|
2
|
+
import logging
|
|
1
3
|
import os
|
|
2
4
|
|
|
5
|
+
from xmanager import xm
|
|
6
|
+
|
|
3
7
|
from xm_slurm import config, resources
|
|
4
8
|
from xm_slurm.contrib.clusters import drac
|
|
9
|
+
from xm_slurm.executors import Slurm
|
|
5
10
|
|
|
6
11
|
# ComputeCanada alias
|
|
7
12
|
cc = drac
|
|
8
13
|
|
|
9
14
|
__all__ = ["drac", "mila", "cc"]
|
|
10
15
|
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
11
18
|
|
|
12
19
|
def mila(
|
|
13
20
|
*,
|
|
@@ -19,18 +26,39 @@ def mila(
|
|
|
19
26
|
if mounts is None:
|
|
20
27
|
mounts = {
|
|
21
28
|
"/network/scratch/${USER:0:1}/$USER": "/scratch",
|
|
22
|
-
|
|
29
|
+
# TODO: move these somewhere common to all cluster configs.
|
|
30
|
+
"/home/mila/${USER:0:1}/$USER/.local/state/xm-slurm": "/xm-slurm-state",
|
|
31
|
+
"/home/mila/${USER:0:1}/$USER/.ssh": "/home/mila/${USER:0:1}/$USER/.ssh",
|
|
23
32
|
}
|
|
24
33
|
|
|
34
|
+
def validate(job: xm.Job) -> None:
|
|
35
|
+
assert isinstance(job.executor, Slurm)
|
|
36
|
+
|
|
37
|
+
wants_requeue_with_grace_period = (
|
|
38
|
+
job.executor.requeue and job.executor.timeout_signal_grace_period > dt.timedelta(0)
|
|
39
|
+
)
|
|
40
|
+
partition = job.executor.partition or "main"
|
|
41
|
+
|
|
42
|
+
if wants_requeue_with_grace_period and (
|
|
43
|
+
partition is None or not partition.endswith("-grace")
|
|
44
|
+
):
|
|
45
|
+
logger.warning(
|
|
46
|
+
f"Job {job.name} wants requeue with grace period, but partition `{partition}` does not end with '-grace'. "
|
|
47
|
+
"Mila Cluster requires you specify a grace partition. "
|
|
48
|
+
"This may result in the job not being requeued properly."
|
|
49
|
+
)
|
|
50
|
+
|
|
25
51
|
return config.SlurmClusterConfig(
|
|
26
52
|
name="mila",
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
53
|
+
ssh=config.SlurmSSHConfig(
|
|
54
|
+
user=user,
|
|
55
|
+
host="login.server.mila.quebec",
|
|
56
|
+
host_public_key=config.PublicKey(
|
|
57
|
+
"ssh-ed25519",
|
|
58
|
+
"AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
|
|
59
|
+
),
|
|
60
|
+
port=2222,
|
|
32
61
|
),
|
|
33
|
-
port=2222,
|
|
34
62
|
runtime=config.ContainerRuntime.SINGULARITY,
|
|
35
63
|
partition=partition,
|
|
36
64
|
prolog="module load singularity",
|
|
@@ -39,14 +67,20 @@ def mila(
|
|
|
39
67
|
"SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
|
|
40
68
|
"SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
|
|
41
69
|
"SCRATCH": "/scratch",
|
|
42
|
-
|
|
70
|
+
# TODO: move this somewhere common to all cluster configs.
|
|
71
|
+
"XM_SLURM_STATE_DIR": "/xm-slurm-state",
|
|
43
72
|
},
|
|
44
73
|
mounts=mounts,
|
|
45
74
|
resources={
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
75
|
+
resources.ResourceType.RTX8000: "rtx8000",
|
|
76
|
+
resources.ResourceType.V100: "v100",
|
|
77
|
+
resources.ResourceType.A100: "a100",
|
|
78
|
+
resources.ResourceType.A100_80GIB: "a100l",
|
|
79
|
+
resources.ResourceType.A6000: "a6000",
|
|
80
|
+
},
|
|
81
|
+
features={
|
|
82
|
+
resources.FeatureType.NVIDIA_MIG: "mig",
|
|
83
|
+
resources.FeatureType.NVIDIA_NVLINK: "nvlink",
|
|
51
84
|
},
|
|
85
|
+
validate=validate,
|
|
52
86
|
)
|
|
@@ -2,7 +2,7 @@ import os
|
|
|
2
2
|
from typing import Literal
|
|
3
3
|
|
|
4
4
|
from xm_slurm import config
|
|
5
|
-
from xm_slurm.resources import ResourceType
|
|
5
|
+
from xm_slurm.resources import FeatureType, ResourceType
|
|
6
6
|
|
|
7
7
|
__all__ = ["narval", "beluga", "cedar", "graham"]
|
|
8
8
|
|
|
@@ -18,18 +18,26 @@ def _drac_cluster(
|
|
|
18
18
|
modules: list[str] | None = None,
|
|
19
19
|
proxy: Literal["submission-host"] | str | None = None,
|
|
20
20
|
mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
|
|
21
|
-
resources: dict[
|
|
21
|
+
resources: dict[ResourceType, str] | None = None,
|
|
22
|
+
features: dict[FeatureType, str] | None = None,
|
|
22
23
|
) -> config.SlurmClusterConfig:
|
|
23
24
|
"""DRAC Cluster."""
|
|
24
25
|
if mounts is None:
|
|
25
|
-
mounts = {
|
|
26
|
+
mounts = {
|
|
27
|
+
"/scratch/$USER": "/scratch",
|
|
28
|
+
# TODO: move these somewhere common to all cluster configs.
|
|
29
|
+
"/home/$USER/.ssh": "/home/$USER/.ssh",
|
|
30
|
+
"/home/$USER/.local/state/xm-slurm": "/xm-slurm-state",
|
|
31
|
+
}
|
|
26
32
|
|
|
27
33
|
return config.SlurmClusterConfig(
|
|
28
34
|
name=name,
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
35
|
+
ssh=config.SlurmSSHConfig(
|
|
36
|
+
user=user,
|
|
37
|
+
host=host,
|
|
38
|
+
host_public_key=host_public_key,
|
|
39
|
+
port=port,
|
|
40
|
+
),
|
|
33
41
|
account=account,
|
|
34
42
|
proxy=proxy,
|
|
35
43
|
runtime=config.ContainerRuntime.APPTAINER,
|
|
@@ -40,9 +48,12 @@ def _drac_cluster(
|
|
|
40
48
|
"APPTAINER_LOCALCACHEDIR": "$SLURM_TMPDIR",
|
|
41
49
|
"_XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
|
|
42
50
|
"SCRATCH": "/scratch",
|
|
51
|
+
# TODO: move this somewhere common to all cluster configs.
|
|
52
|
+
"XM_SLURM_STATE_DIR": "/xm-slurm-state",
|
|
43
53
|
},
|
|
44
54
|
mounts=mounts,
|
|
45
55
|
resources=resources or {},
|
|
56
|
+
features=features or {},
|
|
46
57
|
)
|
|
47
58
|
|
|
48
59
|
|
|
@@ -70,7 +81,11 @@ def narval(
|
|
|
70
81
|
mounts=mounts,
|
|
71
82
|
proxy=proxy,
|
|
72
83
|
modules=modules,
|
|
73
|
-
resources={"a100"
|
|
84
|
+
resources={ResourceType.A100: "a100"},
|
|
85
|
+
features={
|
|
86
|
+
FeatureType.NVIDIA_MIG: "a100mig",
|
|
87
|
+
FeatureType.NVIDIA_NVLINK: "nvlink",
|
|
88
|
+
},
|
|
74
89
|
)
|
|
75
90
|
|
|
76
91
|
|
|
@@ -98,7 +113,10 @@ def beluga(
|
|
|
98
113
|
mounts=mounts,
|
|
99
114
|
proxy=proxy,
|
|
100
115
|
modules=modules,
|
|
101
|
-
resources={"tesla_v100-sxm2-16gb"
|
|
116
|
+
resources={ResourceType.V100: "tesla_v100-sxm2-16gb"},
|
|
117
|
+
features={
|
|
118
|
+
FeatureType.NVIDIA_NVLINK: "nvlink",
|
|
119
|
+
},
|
|
102
120
|
)
|
|
103
121
|
|
|
104
122
|
|
|
@@ -120,9 +138,9 @@ def cedar(
|
|
|
120
138
|
account=account,
|
|
121
139
|
mounts=mounts,
|
|
122
140
|
resources={
|
|
123
|
-
"v100l"
|
|
124
|
-
"p100"
|
|
125
|
-
"p100l"
|
|
141
|
+
ResourceType.V100_32GIB: "v100l",
|
|
142
|
+
ResourceType.P100: "p100",
|
|
143
|
+
ResourceType.P100_16GIB: "p100l",
|
|
126
144
|
},
|
|
127
145
|
)
|
|
128
146
|
|
|
@@ -147,10 +165,10 @@ def graham(
|
|
|
147
165
|
mounts=mounts,
|
|
148
166
|
proxy=proxy,
|
|
149
167
|
resources={
|
|
150
|
-
"v100"
|
|
151
|
-
"p100"
|
|
152
|
-
"a100"
|
|
153
|
-
"a5000"
|
|
168
|
+
ResourceType.V100: "v100",
|
|
169
|
+
ResourceType.P100: "p100",
|
|
170
|
+
ResourceType.A100: "a100",
|
|
171
|
+
ResourceType.A5000: "a5000",
|
|
154
172
|
},
|
|
155
173
|
)
|
|
156
174
|
|
xm_slurm/dependencies.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import dataclasses
|
|
3
|
+
import datetime as dt
|
|
4
|
+
from typing import Callable, Sequence
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SlurmDependencyException(Exception): ...
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
NoChainingException = SlurmDependencyException(
|
|
11
|
+
"Slurm only supports chaining dependencies with the same logical operator. "
|
|
12
|
+
"For example, `dep1 & dep2 | dep3` is not supported but `dep1 & dep2 & dep3` is."
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SlurmJobDependency(abc.ABC):
|
|
17
|
+
@abc.abstractmethod
|
|
18
|
+
def to_dependency_str(self) -> str: ...
|
|
19
|
+
|
|
20
|
+
def to_directive(self) -> str:
|
|
21
|
+
return f"--dependency={self.to_dependency_str()}"
|
|
22
|
+
|
|
23
|
+
def __and__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyAND":
|
|
24
|
+
if isinstance(self, SlurmJobDependencyOR):
|
|
25
|
+
raise NoChainingException
|
|
26
|
+
return SlurmJobDependencyAND(self, other_dependency)
|
|
27
|
+
|
|
28
|
+
def __or__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyOR":
|
|
29
|
+
if isinstance(other_dependency, SlurmJobDependencyAND):
|
|
30
|
+
raise NoChainingException
|
|
31
|
+
return SlurmJobDependencyOR(self, other_dependency)
|
|
32
|
+
|
|
33
|
+
def flatten(self) -> tuple["SlurmJobDependency", ...]:
|
|
34
|
+
if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
|
|
35
|
+
return self.first_dependency.flatten() + self.second_dependency.flatten()
|
|
36
|
+
return (self,)
|
|
37
|
+
|
|
38
|
+
def traverse(
|
|
39
|
+
self, mapper: Callable[["SlurmJobDependency"], "SlurmJobDependency"]
|
|
40
|
+
) -> "SlurmJobDependency":
|
|
41
|
+
if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
|
|
42
|
+
return type(self)(
|
|
43
|
+
first_dependency=self.first_dependency.traverse(mapper),
|
|
44
|
+
second_dependency=self.second_dependency.traverse(mapper),
|
|
45
|
+
)
|
|
46
|
+
return mapper(self)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclasses.dataclass(frozen=True)
|
|
50
|
+
class SlurmJobDependencyAND(SlurmJobDependency):
|
|
51
|
+
first_dependency: SlurmJobDependency
|
|
52
|
+
second_dependency: SlurmJobDependency
|
|
53
|
+
|
|
54
|
+
def to_dependency_str(self) -> str:
|
|
55
|
+
return f"{self.first_dependency.to_dependency_str()},{self.second_dependency.to_dependency_str()}"
|
|
56
|
+
|
|
57
|
+
def __or__(self, other_dependency: SlurmJobDependency):
|
|
58
|
+
del other_dependency
|
|
59
|
+
raise NoChainingException
|
|
60
|
+
|
|
61
|
+
def __hash__(self) -> int:
|
|
62
|
+
return hash((type(self), self.first_dependency, self.second_dependency))
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclasses.dataclass(frozen=True)
|
|
66
|
+
class SlurmJobDependencyOR(SlurmJobDependency):
|
|
67
|
+
first_dependency: SlurmJobDependency
|
|
68
|
+
second_dependency: SlurmJobDependency
|
|
69
|
+
|
|
70
|
+
def to_dependency_str(self) -> str:
|
|
71
|
+
return f"{self.first_dependency.to_dependency_str()}?{self.second_dependency.to_dependency_str()}"
|
|
72
|
+
|
|
73
|
+
def __and__(self, other_dependency: SlurmJobDependency):
|
|
74
|
+
del other_dependency
|
|
75
|
+
raise NoChainingException
|
|
76
|
+
|
|
77
|
+
def __hash__(self) -> int:
|
|
78
|
+
return hash((type(self), self.first_dependency, self.second_dependency))
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclasses.dataclass(frozen=True)
|
|
82
|
+
class SlurmJobDependencyAfter(SlurmJobDependency):
|
|
83
|
+
handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
|
|
84
|
+
time: dt.timedelta | None = None
|
|
85
|
+
|
|
86
|
+
def __post_init__(self):
|
|
87
|
+
if len(self.handles) == 0:
|
|
88
|
+
raise SlurmDependencyException("Dependency doesn't have any handles.")
|
|
89
|
+
if self.time is not None and self.time.total_seconds() % 60 != 0:
|
|
90
|
+
raise SlurmDependencyException("Time must be specified in exact minutes")
|
|
91
|
+
|
|
92
|
+
def to_dependency_str(self) -> str:
|
|
93
|
+
directive = "after"
|
|
94
|
+
|
|
95
|
+
for handle in self.handles:
|
|
96
|
+
directive += f":{handle.slurm_job.job_id}"
|
|
97
|
+
if self.time is not None:
|
|
98
|
+
directive += f"+{self.time.total_seconds() // 60:.0f}"
|
|
99
|
+
return directive
|
|
100
|
+
|
|
101
|
+
def __hash__(self) -> int:
|
|
102
|
+
return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclasses.dataclass(frozen=True)
|
|
106
|
+
class SlurmJobDependencyAfterAny(SlurmJobDependency):
|
|
107
|
+
handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
|
|
108
|
+
|
|
109
|
+
def __post_init__(self):
|
|
110
|
+
if len(self.handles) == 0:
|
|
111
|
+
raise SlurmDependencyException("Dependency doesn't have any handles.")
|
|
112
|
+
|
|
113
|
+
def to_dependency_str(self) -> str:
|
|
114
|
+
return ":".join(["afterany"] + [handle.slurm_job.job_id for handle in self.handles])
|
|
115
|
+
|
|
116
|
+
def __hash__(self) -> int:
|
|
117
|
+
return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclasses.dataclass(frozen=True)
|
|
121
|
+
class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
|
|
122
|
+
handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
|
|
123
|
+
|
|
124
|
+
def __post_init__(self):
|
|
125
|
+
if len(self.handles) == 0:
|
|
126
|
+
raise SlurmDependencyException("Dependency doesn't have any handles.")
|
|
127
|
+
|
|
128
|
+
def to_dependency_str(self) -> str:
|
|
129
|
+
return ":".join(["afternotok"] + [handle.slurm_job.job_id for handle in self.handles])
|
|
130
|
+
|
|
131
|
+
def __hash__(self) -> int:
|
|
132
|
+
return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@dataclasses.dataclass(frozen=True)
|
|
136
|
+
class SlurmJobDependencyAfterOK(SlurmJobDependency):
|
|
137
|
+
handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
|
|
138
|
+
|
|
139
|
+
def __post_init__(self):
|
|
140
|
+
if len(self.handles) == 0:
|
|
141
|
+
raise SlurmDependencyException("Dependency doesn't have any handles.")
|
|
142
|
+
|
|
143
|
+
def to_dependency_str(self) -> str:
|
|
144
|
+
return ":".join(["afterok"] + [handle.slurm_job.job_id for handle in self.handles])
|
|
145
|
+
|
|
146
|
+
def __hash__(self) -> int:
|
|
147
|
+
return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@dataclasses.dataclass(frozen=True)
|
|
151
|
+
class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
|
|
152
|
+
handles: Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"] # type: ignore # noqa: F821
|
|
153
|
+
|
|
154
|
+
def __post_init__(self):
|
|
155
|
+
if len(self.handles) == 0:
|
|
156
|
+
raise SlurmDependencyException("Dependency doesn't have any handles.")
|
|
157
|
+
|
|
158
|
+
def to_dependency_str(self) -> str:
|
|
159
|
+
job_ids = []
|
|
160
|
+
for handle in self.handles:
|
|
161
|
+
job = handle.slurm_job
|
|
162
|
+
if job.is_array_job:
|
|
163
|
+
job_ids.append(job.array_job_id)
|
|
164
|
+
elif job.is_heterogeneous_job:
|
|
165
|
+
job_ids.append(job.het_job_id)
|
|
166
|
+
else:
|
|
167
|
+
job_ids.append(job.job_id)
|
|
168
|
+
return ":".join(["aftercorr"] + job_ids)
|
|
169
|
+
|
|
170
|
+
def __hash__(self) -> int:
|
|
171
|
+
return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
|
xm_slurm/executables.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
import pathlib
|
|
3
|
-
import
|
|
4
|
-
from typing import Mapping, NamedTuple, Sequence
|
|
3
|
+
import typing as tp
|
|
5
4
|
|
|
6
5
|
from xmanager import xm
|
|
7
6
|
|
|
7
|
+
from xm_slurm import constants
|
|
8
|
+
from xm_slurm.types import Descriptor
|
|
9
|
+
|
|
8
10
|
|
|
9
11
|
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
10
12
|
class Dockerfile(xm.ExecutableSpec):
|
|
@@ -31,22 +33,22 @@ class Dockerfile(xm.ExecutableSpec):
|
|
|
31
33
|
target: str | None = None
|
|
32
34
|
|
|
33
35
|
# SSH sockets/keys for the docker build step.
|
|
34
|
-
ssh:
|
|
36
|
+
ssh: tp.Sequence[str] = dataclasses.field(default_factory=list)
|
|
35
37
|
|
|
36
38
|
# Build arguments to docker
|
|
37
|
-
build_args: Mapping[str, str] = dataclasses.field(default_factory=dict)
|
|
39
|
+
build_args: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
|
|
38
40
|
|
|
39
41
|
# --cache-from field in BuildKit
|
|
40
|
-
cache_from: Sequence[str] = dataclasses.field(default_factory=list)
|
|
42
|
+
cache_from: tp.Sequence[str] = dataclasses.field(default_factory=list)
|
|
41
43
|
|
|
42
44
|
# Working directory in container
|
|
43
45
|
workdir: pathlib.Path | None = None
|
|
44
46
|
|
|
45
47
|
# Container labels
|
|
46
|
-
labels: Mapping[str, str] = dataclasses.field(default_factory=dict)
|
|
48
|
+
labels: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
|
|
47
49
|
|
|
48
50
|
# Target platform
|
|
49
|
-
platforms: Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])
|
|
51
|
+
platforms: tp.Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])
|
|
50
52
|
|
|
51
53
|
@property
|
|
52
54
|
def name(self) -> str:
|
|
@@ -57,6 +59,7 @@ class Dockerfile(xm.ExecutableSpec):
|
|
|
57
59
|
|
|
58
60
|
def __hash__(self) -> int:
|
|
59
61
|
return hash((
|
|
62
|
+
type(self),
|
|
60
63
|
self.dockerfile,
|
|
61
64
|
self.context,
|
|
62
65
|
self.target,
|
|
@@ -89,12 +92,7 @@ class DockerImage(xm.ExecutableSpec):
|
|
|
89
92
|
return self.image
|
|
90
93
|
|
|
91
94
|
def __hash__(self) -> int:
|
|
92
|
-
return hash((self.image, self.workdir))
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
_IMAGE_URI_REGEX = re.compile(
|
|
96
|
-
r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
|
|
97
|
-
)
|
|
95
|
+
return hash((type(self), self.image, self.workdir))
|
|
98
96
|
|
|
99
97
|
|
|
100
98
|
@dataclasses.dataclass
|
|
@@ -108,7 +106,7 @@ class ImageURI:
|
|
|
108
106
|
digest: str | None = dataclasses.field(init=False, default=None)
|
|
109
107
|
|
|
110
108
|
def __post_init__(self, image: str):
|
|
111
|
-
match =
|
|
109
|
+
match = constants.IMAGE_URI_REGEX.match(image)
|
|
112
110
|
if not match:
|
|
113
111
|
raise ValueError(f"Invalid OCI image URI: {image}")
|
|
114
112
|
groups = {k: v for k, v in match.groupdict().items() if v is not None}
|
|
@@ -152,6 +150,7 @@ class ImageURI:
|
|
|
152
150
|
|
|
153
151
|
def __hash__(self) -> int:
|
|
154
152
|
return hash((
|
|
153
|
+
type(self),
|
|
155
154
|
self.scheme,
|
|
156
155
|
self.domain,
|
|
157
156
|
self.path,
|
|
@@ -165,30 +164,31 @@ class ImageURI:
|
|
|
165
164
|
return format.format(**fields)
|
|
166
165
|
|
|
167
166
|
|
|
168
|
-
class ImageDescriptor:
|
|
167
|
+
class ImageDescriptor(Descriptor[ImageURI, str | ImageURI]):
|
|
169
168
|
def __set_name__(self, owner: type, name: str):
|
|
170
169
|
del owner
|
|
171
170
|
self.image = f"_{name}"
|
|
172
171
|
|
|
173
|
-
def __get__(self, instance: object, owner:
|
|
172
|
+
def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> ImageURI:
|
|
174
173
|
del owner
|
|
175
174
|
return getattr(instance, self.image)
|
|
176
175
|
|
|
177
176
|
def __set__(self, instance: object, value: str | ImageURI):
|
|
177
|
+
_setattr = object.__setattr__ if not hasattr(instance, self.image) else setattr
|
|
178
178
|
if isinstance(value, str):
|
|
179
179
|
value = ImageURI(value)
|
|
180
|
-
|
|
180
|
+
_setattr(instance, self.image, value)
|
|
181
181
|
|
|
182
182
|
|
|
183
|
-
class RemoteRepositoryCredentials(NamedTuple):
|
|
183
|
+
class RemoteRepositoryCredentials(tp.NamedTuple):
|
|
184
184
|
username: str
|
|
185
185
|
password: str
|
|
186
186
|
|
|
187
187
|
|
|
188
|
-
@dataclasses.dataclass(kw_only=True) # type: ignore
|
|
188
|
+
@dataclasses.dataclass(frozen=True, kw_only=True) # type: ignore
|
|
189
189
|
class RemoteImage(xm.Executable):
|
|
190
190
|
# Remote base image
|
|
191
|
-
image:
|
|
191
|
+
image: Descriptor[ImageURI, str | ImageURI] = ImageDescriptor()
|
|
192
192
|
|
|
193
193
|
# Working directory in container
|
|
194
194
|
workdir: pathlib.Path | None = None
|
|
@@ -196,11 +196,23 @@ class RemoteImage(xm.Executable):
|
|
|
196
196
|
# Container arguments
|
|
197
197
|
args: xm.SequentialArgs = dataclasses.field(default_factory=xm.SequentialArgs)
|
|
198
198
|
# Container environment variables
|
|
199
|
-
env_vars: Mapping[str, str] = dataclasses.field(default_factory=dict)
|
|
199
|
+
env_vars: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
|
|
200
200
|
|
|
201
201
|
# Remote repository credentials
|
|
202
202
|
credentials: RemoteRepositoryCredentials | None = None
|
|
203
203
|
|
|
204
204
|
@property
|
|
205
|
-
def name(self) -> str:
|
|
205
|
+
def name(self) -> str: # type: ignore
|
|
206
206
|
return str(self.image)
|
|
207
|
+
|
|
208
|
+
def __hash__(self) -> int:
|
|
209
|
+
return hash(
|
|
210
|
+
(
|
|
211
|
+
type(self),
|
|
212
|
+
self.image,
|
|
213
|
+
self.workdir,
|
|
214
|
+
tuple(sorted(self.args.to_list())),
|
|
215
|
+
tuple(sorted(self.env_vars.items())),
|
|
216
|
+
self.credentials,
|
|
217
|
+
),
|
|
218
|
+
)
|