xmanager-slurm 0.4.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (52)
  1. xm_slurm/__init__.py +47 -0
  2. xm_slurm/api/__init__.py +33 -0
  3. xm_slurm/api/abc.py +65 -0
  4. xm_slurm/api/models.py +70 -0
  5. xm_slurm/api/sqlite/client.py +358 -0
  6. xm_slurm/api/web/client.py +173 -0
  7. xm_slurm/batching.py +139 -0
  8. xm_slurm/config.py +189 -0
  9. xm_slurm/console.py +3 -0
  10. xm_slurm/constants.py +19 -0
  11. xm_slurm/contrib/__init__.py +0 -0
  12. xm_slurm/contrib/clusters/__init__.py +67 -0
  13. xm_slurm/contrib/clusters/drac.py +242 -0
  14. xm_slurm/dependencies.py +171 -0
  15. xm_slurm/executables.py +215 -0
  16. xm_slurm/execution.py +995 -0
  17. xm_slurm/executors.py +210 -0
  18. xm_slurm/experiment.py +1016 -0
  19. xm_slurm/experimental/parameter_controller.py +206 -0
  20. xm_slurm/filesystems.py +129 -0
  21. xm_slurm/job_blocks.py +21 -0
  22. xm_slurm/metadata_context.py +253 -0
  23. xm_slurm/packageables.py +309 -0
  24. xm_slurm/packaging/__init__.py +8 -0
  25. xm_slurm/packaging/docker.py +348 -0
  26. xm_slurm/packaging/registry.py +45 -0
  27. xm_slurm/packaging/router.py +56 -0
  28. xm_slurm/packaging/utils.py +22 -0
  29. xm_slurm/resources.py +350 -0
  30. xm_slurm/scripts/_cloudpickle.py +28 -0
  31. xm_slurm/scripts/cli.py +90 -0
  32. xm_slurm/status.py +197 -0
  33. xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
  34. xm_slurm/templates/docker/mamba.Dockerfile +29 -0
  35. xm_slurm/templates/docker/python.Dockerfile +32 -0
  36. xm_slurm/templates/docker/uv.Dockerfile +38 -0
  37. xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
  38. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
  39. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  40. xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
  41. xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
  42. xm_slurm/templates/slurm/job.bash.j2 +90 -0
  43. xm_slurm/templates/slurm/library/retry.bash +62 -0
  44. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
  45. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
  46. xm_slurm/types.py +23 -0
  47. xm_slurm/utils.py +196 -0
  48. xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
  49. xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
  50. xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
  51. xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
  52. xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
@@ -0,0 +1,242 @@
1
+ import os
2
+ import typing as tp
3
+
4
+ from xm_slurm import config
5
+ from xm_slurm.resources import FeatureType, ResourceType
6
+
7
# Public cluster factories exported by this module.
__all__ = ["fir", "narval", "nibi", "rorqual", "killarney", "tamia", "vulcan"]
16
+
17
+
18
def _drac_cluster(
    *,
    name: str,
    host: str,
    port: int = 22,
    robot_host: str | None = None,
    robot_port: int | None = None,
    public_key: config.PublicKey,
    user: str | None = None,
    account: str | None = None,
    modules: list[str] | None = None,
    proxy: tp.Literal["submission-host"] | str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
    resources: tp.Mapping[ResourceType, str] | None = None,
    features: tp.Mapping[FeatureType, str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC Cluster.

    Build the `config.SlurmClusterConfig` shared by every DRAC cluster factory
    in this module. It configures the Apptainer runtime, a
    `module load apptainer ...` prolog, and a fixed set of host/container
    environment variables and bind mounts.

    Args:
        name: Cluster name stored on the config.
        host: Login host, always included as an SSH endpoint.
        port: SSH port for `host`.
        robot_host: Optional extra SSH host. It is listed before `host`, but
            only when it is set and differs from `host`.
        robot_port: Port passed alongside `robot_host`. NOTE(review): this is
            forwarded as-is (possibly None); confirm `config.Endpoint` treats
            None as the default SSH port.
        public_key: Host public key handed to the SSH config.
        user: SSH user name.
        account: Slurm account.
        modules: Extra module names appended to `module load apptainer`.
        proxy: Proxy setting forwarded unchanged to the config.
        mounts: Host -> container bind mounts. When None, a default set is used
            (scratch, ~/.ssh and the xm-slurm state directory).
            `/dev/infiniband` is always added on top of whichever set applies.
        resources: Resource-type -> Slurm name mapping (empty when None).
        features: Feature-type -> Slurm feature mapping (empty when None).

    Returns:
        The assembled cluster configuration.
    """
    default_mounts = {
        "/scratch/$USER": "/scratch",
        # TODO: move these somewhere common to all cluster configs.
        "/home/$USER/.ssh": "/home/$USER/.ssh",
        "/home/$USER/.local/state/xm-slurm": "/xm-slurm-state",
    }
    # The InfiniBand device is bound in regardless of user-supplied mounts.
    bind_mounts = dict(default_mounts if mounts is None else mounts) | {
        "/dev/infiniband": "/dev/infiniband"
    }

    if robot_host is None or robot_host == host:
        endpoints = (config.Endpoint(host, port),)
    else:
        endpoints = (
            config.Endpoint(robot_host, robot_port),
            config.Endpoint(host, port),
        )

    module_suffix = " ".join(modules) if modules else ""
    # rstrip() drops the trailing space left when there are no extra modules.
    prolog = f"module load apptainer {module_suffix}".rstrip()

    ssh_config = config.SSHConfig(user=user, endpoints=endpoints, public_key=public_key)
    return config.SlurmClusterConfig(
        name=name,
        ssh=ssh_config,
        account=account,
        proxy=proxy,
        runtime=config.ContainerRuntime.APPTAINER,
        prolog=prolog,
        host_environment={
            "XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
            "APPTAINER_CACHEDIR": "$SCRATCH/.apptainer",
            "APPTAINER_TMPDIR": "$SLURM_TMPDIR",
            "APPTAINER_LOCALCACHEDIR": "$SLURM_TMPDIR",
        },
        container_environment={
            "SCRATCH": "/scratch",
            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
        },
        mounts=bind_mounts,
        resources=resources or {},
        features=features or {},
    )
71
+
72
+
73
def narval(
    *,
    user: str | None = None,
    account: str | None = None,
    proxy: tp.Literal["submission-host"] | str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC Narval Cluster (https://docs.alliancecan.ca/wiki/Narval/en).

    Args:
        user: SSH user name.
        account: Slurm account.
        proxy: Proxy setting. Unless it is ``"submission-host"``, the
            ``httpproxy`` module is added to the job prolog.
        mounts: Optional bind mounts. When None, `_drac_cluster` applies its
            defaults.

    Returns:
        Cluster config mapping A100 GPUs and the MIG / NVLink features.
    """
    extra_modules = [] if proxy == "submission-host" else ["httpproxy"]
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAILFxB0spH5RApc43sBx0zOxo1ARVH0ezU+FbQH95FW+h",
    )
    return _drac_cluster(
        name="narval",
        host="narval.alliancecan.ca",
        robot_host="robot.narval.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        proxy=proxy,
        modules=extra_modules,
        resources={ResourceType.A100: "a100"},
        features={
            FeatureType.NVIDIA_MIG: "a100mig",
            FeatureType.NVIDIA_NVLINK: "nvlink",
        },
    )
104
+
105
+
106
def rorqual(
    *,
    user: str | None = None,
    account: str | None = None,
    proxy: tp.Literal["submission-host"] | str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC Rorqual Cluster (https://docs.alliancecan.ca/wiki/Rorqual/en).

    Args:
        user: SSH user name.
        account: Slurm account.
        proxy: Proxy setting. Unless it is ``"submission-host"``, the
            ``httpproxy`` module is added to the job prolog.
        mounts: Optional bind mounts. When None, `_drac_cluster` applies its
            defaults.

    Returns:
        Cluster config mapping H100 GPUs and the NVLink feature.
    """
    extra_modules = [] if proxy == "submission-host" else ["httpproxy"]
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAINME5e9bifKZbuKKOQSpe3xrvC4g1b0QLMYj+AXBQGJe",
    )
    return _drac_cluster(
        name="rorqual",
        host="rorqual.alliancecan.ca",
        robot_host="robot.rorqual.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        proxy=proxy,
        modules=extra_modules,
        resources={ResourceType.H100: "h100"},
        features={FeatureType.NVIDIA_NVLINK: "nvlink"},
    )
136
+
137
+
138
def fir(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC Fir Cluster (https://docs.alliancecan.ca/wiki/Fir/en).

    No proxy or extra modules are configured, so the prolog is only
    ``module load apptainer``.

    Args:
        user: SSH user name.
        account: Slurm account.
        mounts: Optional bind mounts. When None, `_drac_cluster` applies its
            defaults.

    Returns:
        Cluster config mapping H100 GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIJtenyJz+inwobvlJntWYFNu+ANcVWNcOHRKcEN6zmDo",
    )
    return _drac_cluster(
        name="fir",
        host="fir.alliancecan.ca",
        robot_host="robot.fir.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.H100: "h100"},
    )
158
+
159
+
160
def nibi(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC Nibi Cluster (https://docs.alliancecan.ca/wiki/Nibi/en).

    No proxy or extra modules are configured, so the prolog is only
    ``module load apptainer``.

    Args:
        user: SSH user name.
        account: Slurm account.
        mounts: Optional bind mounts. When None, `_drac_cluster` applies its
            defaults.

    Returns:
        Cluster config mapping H100 GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIEcmFoQZr6+KUHm/zm/BJpnNIlME7GytMxbHgfAUfoQX",
    )
    return _drac_cluster(
        name="nibi",
        host="nibi.alliancecan.ca",
        robot_host="robot.nibi.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.H100: "h100"},
    )
180
+
181
+
182
def killarney(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC (PAICE) Killarney Cluster (https://docs.alliancecan.ca/wiki/Killarney/en).

    Only the login host is used as an SSH endpoint; no robot host is set.

    Args:
        user: SSH user name.
        account: Slurm account.
        mounts: Optional bind mounts. When None, `_drac_cluster` applies its
            defaults.

    Returns:
        Cluster config mapping L40S and H100 GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIGlzaBBtvhJsSr23rMoY41gy8Svv1IOct8TBRH9CGuJf",
    )
    gpu_names = {ResourceType.L40S: "l40s", ResourceType.H100: "h100"}
    return _drac_cluster(
        name="killarney",
        host="killarney.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources=gpu_names,
    )
201
+
202
+
203
def tamia(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC (PAICE) Tamia Cluster (https://docs.alliancecan.ca/wiki/Tamia/en).

    Only the login host is used as an SSH endpoint; no robot host is set.

    Args:
        user: SSH user name.
        account: Slurm account.
        mounts: Optional bind mounts. When None, `_drac_cluster` applies its
            defaults.

    Returns:
        Cluster config mapping H100 GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIN2wL9wOa0VveA/2l2ky/OhPsQfYtKuX99dyNnUTSYeU",
    )
    return _drac_cluster(
        name="tamia",
        host="tamia.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.H100: "h100"},
    )
222
+
223
+
224
def vulcan(
    *,
    user: str | None = None,
    account: str | None = None,
    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
) -> config.SlurmClusterConfig:
    """DRAC (PAICE) Vulcan Cluster (https://docs.alliancecan.ca/wiki/Vulcan/en).

    Only the login host is used as an SSH endpoint; no robot host is set.

    Args:
        user: SSH user name.
        account: Slurm account.
        mounts: Optional bind mounts. When None, `_drac_cluster` applies its
            defaults.

    Returns:
        Cluster config mapping L40S GPUs.
    """
    host_key = config.PublicKey(
        "ssh-ed25519",
        "AAAAC3NzaC1lZDI1NTE5AAAAIMuIj6T45HqVeJgRotH9Qq46FzidekS2lXkD7FOTltnC",
    )
    return _drac_cluster(
        name="vulcan",
        host="vulcan.alliancecan.ca",
        public_key=host_key,
        user=user,
        account=account,
        mounts=mounts,
        resources={ResourceType.L40S: "l40s"},
    )
@@ -0,0 +1,171 @@
1
+ import abc
2
+ import dataclasses
3
+ import datetime as dt
4
+ import typing as tp
5
+
6
+
7
class SlurmDependencyException(Exception):
    """Raised when a Slurm job dependency is invalid or cannot be expressed."""


# Shared instance raised whenever `&` and `|` are mixed in one dependency
# expression; Slurm's `--dependency` value accepts a single separator type.
NoChainingException = SlurmDependencyException(
    "Slurm only supports chaining dependencies with the same logical operator. "
    "For example, `dep1 & dep2 | dep3` is not supported but `dep1 & dep2 & dep3` is."
)
14
+
15
+
16
class SlurmJobDependency(abc.ABC):
    """Abstract Slurm job dependency that can be combined with `&` and `|`.

    `a & b` builds a `SlurmJobDependencyAND`, rendered with `,`.
    `a | b` builds a `SlurmJobDependencyOR`, rendered with `?`.
    A single `--dependency` value may use only one kind of separator, so any
    expression that mixes the two operators raises `NoChainingException`.
    """

    @abc.abstractmethod
    def to_dependency_str(self) -> str:
        """Return the value placed after `--dependency=`."""

    def to_directive(self) -> str:
        """Return the full `--dependency=...` directive."""
        return f"--dependency={self.to_dependency_str()}"

    def __and__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyAND":
        # Reject an OR on either side. Checking only `self` would let
        # `a & (b | c)` through, and it would render as `x,y?z`.
        if isinstance(self, SlurmJobDependencyOR) or isinstance(
            other_dependency, SlurmJobDependencyOR
        ):
            raise NoChainingException
        return SlurmJobDependencyAND(self, other_dependency)

    def __or__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyOR":
        # Symmetric to `__and__`: reject an AND on either side.
        if isinstance(self, SlurmJobDependencyAND) or isinstance(
            other_dependency, SlurmJobDependencyAND
        ):
            raise NoChainingException
        return SlurmJobDependencyOR(self, other_dependency)

    def flatten(self) -> tuple["SlurmJobDependency", ...]:
        """Return the leaf dependencies of this expression, left to right."""
        if isinstance(self, (SlurmJobDependencyAND, SlurmJobDependencyOR)):
            return self.first_dependency.flatten() + self.second_dependency.flatten()
        return (self,)

    def traverse(
        self, mapper: tp.Callable[["SlurmJobDependency"], "SlurmJobDependency"]
    ) -> "SlurmJobDependency":
        """Return a copy of this expression with `mapper` applied to every leaf.

        AND/OR nodes are rebuilt with the same node type. Each leaf is replaced
        by ``mapper(leaf)``.
        """
        if isinstance(self, (SlurmJobDependencyAND, SlurmJobDependencyOR)):
            return type(self)(
                first_dependency=self.first_dependency.traverse(mapper),
                second_dependency=self.second_dependency.traverse(mapper),
            )
        return mapper(self)
47
+
48
+
49
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAND(SlurmJobDependency):
    """Conjunction of two dependencies, rendered with Slurm's `,` separator."""

    first_dependency: SlurmJobDependency
    second_dependency: SlurmJobDependency

    def to_dependency_str(self) -> str:
        left = self.first_dependency.to_dependency_str()
        right = self.second_dependency.to_dependency_str()
        return f"{left},{right}"

    def __or__(self, other_dependency: SlurmJobDependency):
        # `,` and `?` cannot be mixed in a single dependency string.
        del other_dependency
        raise NoChainingException

    def __hash__(self) -> int:
        identity = (type(self), self.first_dependency, self.second_dependency)
        return hash(identity)
63
+
64
+
65
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyOR(SlurmJobDependency):
    """Disjunction of two dependencies, rendered with Slurm's `?` separator."""

    first_dependency: SlurmJobDependency
    second_dependency: SlurmJobDependency

    def to_dependency_str(self) -> str:
        left = self.first_dependency.to_dependency_str()
        right = self.second_dependency.to_dependency_str()
        return f"{left}?{right}"

    def __and__(self, other_dependency: SlurmJobDependency):
        # `?` and `,` cannot be mixed in a single dependency string.
        del other_dependency
        raise NoChainingException

    def __hash__(self) -> int:
        identity = (type(self), self.first_dependency, self.second_dependency)
        return hash(identity)
79
+
80
+
81
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfter(SlurmJobDependency):
    """Slurm `after` dependency on one or more jobs.

    Renders as ``after:<id>[+<minutes>]:<id>[+<minutes>]...``. When `time` is
    set, the same whole-minute delay is appended to every job id.

    Raises:
        SlurmDependencyException: if `handles` is empty or `time` is not a
            whole number of minutes.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
    time: dt.timedelta | None = None

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")
        if self.time is None:
            return
        if self.time.total_seconds() % 60 != 0:
            raise SlurmDependencyException("Time must be specified in exact minutes")

    def to_dependency_str(self) -> str:
        delay = "" if self.time is None else f"+{self.time.total_seconds() // 60:.0f}"
        return "after" + "".join(
            f":{handle.slurm_job.job_id}{delay}" for handle in self.handles
        )

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
103
+
104
+
105
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfterAny(SlurmJobDependency):
    """Slurm `afterany` dependency, rendered as ``afterany:<id>:<id>...``.

    Raises:
        SlurmDependencyException: if `handles` is empty.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        job_ids = [handle.slurm_job.job_id for handle in self.handles]
        return ":".join(["afterany", *job_ids])

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
118
+
119
+
120
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
    """Slurm `afternotok` dependency, rendered as ``afternotok:<id>:<id>...``.

    Raises:
        SlurmDependencyException: if `handles` is empty.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        job_ids = [handle.slurm_job.job_id for handle in self.handles]
        return ":".join(["afternotok", *job_ids])

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
133
+
134
+
135
@dataclasses.dataclass(frozen=True)
class SlurmJobDependencyAfterOK(SlurmJobDependency):
    """Slurm `afterok` dependency, rendered as ``afterok:<id>:<id>...``.

    Raises:
        SlurmDependencyException: if `handles` is empty.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        job_ids = [handle.slurm_job.job_id for handle in self.handles]
        return ":".join(["afterok", *job_ids])

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
148
+
149
+
150
@dataclasses.dataclass(frozen=True)
class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
    """Per-task dependency rendered with Slurm's `aftercorr` type.

    Slurm documents `aftercorr` as: a task of this array may start once the
    corresponding task of the referenced job has completed successfully.
    Each handle contributes one id: the array job id for array jobs, the
    heterogeneous job id for het jobs, and the plain job id otherwise.

    Raises:
        SlurmDependencyException: if `handles` is empty.
    """

    handles: tp.Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821

    def __post_init__(self):
        if not len(self.handles):
            raise SlurmDependencyException("Dependency doesn't have any handles.")

    def to_dependency_str(self) -> str:
        def root_id(job):
            # Array and het jobs are referenced by their parent id.
            if job.is_array_job:
                return job.array_job_id
            if job.is_heterogeneous_job:
                return job.het_job_id
            return job.job_id

        return ":".join(["aftercorr", *(root_id(handle.slurm_job) for handle in self.handles)])

    def __hash__(self) -> int:
        return hash((type(self), *(handle.slurm_job for handle in self.handles)))
@@ -0,0 +1,215 @@
1
+ import dataclasses
2
+ import os
3
+ import pathlib
4
+ import typing as tp
5
+
6
+ from xmanager import xm
7
+
8
+ from xm_slurm import constants
9
+ from xm_slurm.types import Descriptor
10
+
11
+
12
@dataclasses.dataclass(frozen=True, kw_only=True)
class Dockerfile(xm.ExecutableSpec):
    """Specification of a Dockerfile to build.

    Args:
        dockerfile: Path to the Dockerfile.
        context: Path to the Docker build context.
        target: Optional build stage to target.
        ssh: Docker SSH sockets/keys exposed to the build.
        build_args: Build arguments passed to docker.
        cache_from: BuildKit cache sources (see `--cache-from` in `docker build`).
        labels: Container labels.
        platforms: Target platforms.
    """

    # Dockerfile
    dockerfile: pathlib.Path
    # Docker context
    context: pathlib.Path

    # Docker build target
    target: str | None = None

    # SSH sockets/keys for the docker build step.
    ssh: tp.Sequence[str] = dataclasses.field(default_factory=list)

    # Build arguments to docker
    build_args: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

    # --cache-from field in BuildKit
    cache_from: tp.Sequence[str] = dataclasses.field(default_factory=list)

    # Container labels
    labels: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

    # Target platform
    platforms: tp.Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])

    @property
    def name(self) -> str:
        """Dockerfile stem, suffixed with ``-<target>`` when a target is set."""
        stem = self.dockerfile.stem
        return stem if self.target is None else f"{stem}-{self.target}"

    def __hash__(self) -> int:
        def canonical(items):
            # Sorted so the hash is independent of element order.
            return tuple(sorted(items))

        return hash(
            (
                type(self),
                self.dockerfile,
                self.context,
                self.target,
                canonical(self.ssh),
                canonical(self.build_args.items()),
                canonical(self.cache_from),
                canonical(self.labels.items()),
                canonical(self.platforms),
            )
        )
69
+
70
+
71
@dataclasses.dataclass(frozen=True, kw_only=True)
class DockerImage(xm.ExecutableSpec):
    """Specification of a pre-built Docker image.

    Args:
        image: Remote image URI; also used as the spec's name.
        workdir: Optional working directory inside the container.
    """

    image: str

    # Working directory in container
    workdir: pathlib.Path | None = None

    @property
    def name(self) -> str:
        image_uri = self.image
        return image_uri

    def __hash__(self) -> int:
        identity = (type(self), self.image, self.workdir)
        return hash(identity)
92
+
93
+
94
+ @dataclasses.dataclass
95
+ class ImageURI:
96
+ image: dataclasses.InitVar[str]
97
+
98
+ scheme: str | None = dataclasses.field(init=False, default=None)
99
+ domain: str = dataclasses.field(init=False)
100
+ path: str = dataclasses.field(init=False)
101
+ tag: str | None = dataclasses.field(init=False, default=None)
102
+ digest: str | None = dataclasses.field(init=False, default=None)
103
+
104
+ def __post_init__(self, image: str):
105
+ match = constants.IMAGE_URI_REGEX.match(image)
106
+ if not match:
107
+ raise ValueError(f"Invalid OCI image URI: {image}")
108
+ groups = {k: v for k, v in match.groupdict().items() if v is not None}
109
+ for k, v in groups.items():
110
+ setattr(self, k, v)
111
+
112
+ if self.tag is None and self.digest is None:
113
+ self.tag = "latest"
114
+
115
+ @property
116
+ def locator(self) -> str:
117
+ """Unique locator for this image.
118
+
119
+ Locator will return the digest if it exists otherwise the tag format.
120
+ If neither are present, it will raise an AssertionError.
121
+ """
122
+ if self.digest is not None:
123
+ return f"@{self.digest}"
124
+ assert self.tag is not None
125
+ return f":{self.tag}"
126
+
127
+ @property
128
+ def url(self) -> str:
129
+ """URL for this image without the locator."""
130
+ return f"{self.origin}{self.path}"
131
+
132
+ @property
133
+ def origin(self) -> str:
134
+ return f"{self.scheme}{self.domain}"
135
+
136
+ def with_tag(self, tag: str) -> "ImageURI":
137
+ self.tag = tag
138
+ return self
139
+
140
+ def with_digest(self, digest: str) -> "ImageURI":
141
+ self.digest = digest
142
+ return self
143
+
144
+ def __str__(self) -> str:
145
+ return self.format("{url}{locator}")
146
+
147
+ def __hash__(self) -> int:
148
+ return hash((
149
+ type(self),
150
+ self.scheme,
151
+ self.domain,
152
+ self.path,
153
+ self.tag,
154
+ self.digest,
155
+ ))
156
+
157
+ def format(self, format: str) -> str:
158
+ fields = {k: v for k, v in dataclasses.asdict(self).items() if v is not None}
159
+ fields |= {"locator": self.locator, "url": self.url}
160
+ return format.format(**fields)
161
+
162
+
163
class ImageDescriptor(Descriptor[ImageURI, str | ImageURI]):
    """Data descriptor that stores an `ImageURI` and parses strings on assignment.

    The value is kept on the owning instance under the attribute name prefixed
    with an underscore. The first assignment uses `object.__setattr__`, which
    bypasses the owner's `__setattr__` (e.g. a frozen dataclass). Later
    assignments go through the regular `setattr`.
    """

    def __set_name__(self, owner: type, name: str):
        del owner
        # Despite its name, `self.image` holds the backing attribute's name.
        self.image = f"_{name}"

    def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> ImageURI:
        del owner
        # Class-level access (instance is None) raises AttributeError here.
        # Dataclasses interpret that as "field has no default".
        return getattr(instance, self.image)

    def __set__(self, instance: object, value: str | ImageURI):
        already_set = hasattr(instance, self.image)
        uri = ImageURI(value) if isinstance(value, str) else value
        if already_set:
            setattr(instance, self.image, uri)
        else:
            object.__setattr__(instance, self.image, uri)
177
+
178
+
179
class RemoteRepositoryCredentials(tp.NamedTuple):
    """Username/password pair for authenticating against a remote registry.

    The repr masks the password, so the secret does not leak into logs or into
    the repr of objects that embed these credentials (e.g. a dataclass field).
    Field access, tuple equality and hashing are unchanged.
    """

    username: str
    password: str

    def __repr__(self) -> str:
        return f"{type(self).__name__}(username={self.username!r}, password='***')"
182
+
183
+
184
@dataclasses.dataclass(frozen=True, kw_only=True)  # type: ignore
class RemoteImage(xm.Executable):
    """A container image that is already available in a remote registry.

    Attributes are listed in field order:
        image: Image reference; strings are parsed into `ImageURI` by the
            descriptor on assignment.
        workdir: Working directory inside the container.
        entrypoint: Container entrypoint.
        args: Container arguments.
        env_vars: Container environment variables.
        credentials: Optional registry credentials.
    """

    # Remote base image
    image: Descriptor[ImageURI, str | ImageURI] = ImageDescriptor()

    workdir: os.PathLike[str] | str
    entrypoint: xm.SequentialArgs

    # Container arguments
    args: xm.SequentialArgs = dataclasses.field(default_factory=xm.SequentialArgs)
    # Container environment variables
    env_vars: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

    # Remote repository credentials
    credentials: RemoteRepositoryCredentials | None = None

    @property
    def name(self) -> str:  # type: ignore
        uri = self.image
        return str(uri)

    def __hash__(self) -> int:
        def canonical(items):
            # Sorted so the hash does not depend on element order.
            return tuple(sorted(items))

        fingerprint = (
            type(self),
            self.image,
            self.workdir,
            canonical(self.entrypoint.to_list()),
            canonical(self.args.to_list()),
            canonical(self.env_vars.items()),
            self.credentials,
        )
        return hash(fingerprint)