xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (42)
  1. xm_slurm/__init__.py +6 -2
  2. xm_slurm/api.py +301 -34
  3. xm_slurm/batching.py +4 -4
  4. xm_slurm/config.py +105 -55
  5. xm_slurm/constants.py +19 -0
  6. xm_slurm/contrib/__init__.py +0 -0
  7. xm_slurm/contrib/clusters/__init__.py +47 -13
  8. xm_slurm/contrib/clusters/drac.py +34 -16
  9. xm_slurm/dependencies.py +171 -0
  10. xm_slurm/executables.py +34 -22
  11. xm_slurm/execution.py +305 -107
  12. xm_slurm/executors.py +8 -12
  13. xm_slurm/experiment.py +601 -168
  14. xm_slurm/experimental/parameter_controller.py +202 -0
  15. xm_slurm/job_blocks.py +7 -0
  16. xm_slurm/packageables.py +42 -20
  17. xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
  18. xm_slurm/packaging/router.py +3 -1
  19. xm_slurm/packaging/utils.py +9 -81
  20. xm_slurm/resources.py +28 -4
  21. xm_slurm/scripts/_cloudpickle.py +28 -0
  22. xm_slurm/scripts/cli.py +52 -0
  23. xm_slurm/status.py +9 -0
  24. xm_slurm/templates/docker/mamba.Dockerfile +4 -2
  25. xm_slurm/templates/docker/python.Dockerfile +18 -10
  26. xm_slurm/templates/docker/uv.Dockerfile +35 -0
  27. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
  28. xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
  29. xm_slurm/templates/slurm/job.bash.j2 +4 -3
  30. xm_slurm/types.py +23 -0
  31. xm_slurm/utils.py +18 -10
  32. xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
  33. xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
  34. {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
  35. xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
  36. xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
  37. xm_slurm/packaging/docker/__init__.py +0 -75
  38. xm_slurm/packaging/docker/abc.py +0 -112
  39. xm_slurm/packaging/docker/cloud.py +0 -503
  40. xm_slurm/templates/docker/pdm.Dockerfile +0 -31
  41. xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
  42. xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
xm_slurm/constants.py ADDED
@@ -0,0 +1,19 @@
+import re
+
+SLURM_JOB_ID_REGEX = re.compile(
+    r"^(?P<jobid>\d+)(?:(?:\+(?P<componentid>\d+))|(?:_(?P<arraytaskid>\d+)))?$"
+)
+
+IMAGE_URI_REGEX = re.compile(
+    r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
+)
+
+DOMAIN_NAME_REGEX = re.compile(
+    r"^(?!-)(?!.*--)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*(\.[A-Za-z]{2,})$"
+)
+IPV4_REGEX = re.compile(
+    r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
+)
+IPV6_REGEX = re.compile(
+    r"^(([0-9a-fA-F]{1,4}:){7}([0-9a-fA-F]{1,4}|:)|(([0-9a-fA-F]{1,4}:){1,7}|:):(([0-9a-fA-F]{1,4}:){1,6}|:)|(([0-9a-fA-F]{1,4}:){1,6}|:):(([0-9a-fA-F]{1,4}:){1,5}|:)|(([0-9a-fA-F]{1,4}:){1,5}|:):(([0-9a-fA-F]{1,4}:){1,4}|:)|(([0-9a-fA-F]{1,4}:){1,4}|:):(([0-9a-fA-F]{1,4}:){1,3}|:)|(([0-9a-fA-F]{1,4}:){1,3}|:):(([0-9a-fA-F]{1,4}:){1,2}|:)|(([0-9a-fA-F]{1,4}:){1,2}|:):([0-9a-fA-F]{1,4}|:)|([0-9a-fA-F]{1,4}|:):([0-9a-fA-F]{1,4}|:)(:([0-9a-fA-F]{1,4}|:)){1,6})$"
+)
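
A quick sketch of what the new constants match (a minimal sketch, assuming the 0.4.1 wheel is installed; the sample job IDs and image reference below are made up):

from xm_slurm.constants import IMAGE_URI_REGEX, SLURM_JOB_ID_REGEX

# Plain ID, heterogeneous-job component ("+"), and array task ("_").
for job_id in ("12345", "12345+2", "12345_7"):
    match = SLURM_JOB_ID_REGEX.match(job_id)
    assert match is not None
    print(match.groupdict())

# OCI image references decompose into scheme/domain/path/tag/digest.
match = IMAGE_URI_REGEX.match("ghcr.io/org/app:1.2.3@sha256:abcd")
assert match is not None
print(match.group("domain"), match.group("path"), match.group("tag"), match.group("digest"))
# ghcr.io /org/app 1.2.3 sha256:abcd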
xm_slurm/contrib/__init__.py ADDED (file without changes)

xm_slurm/contrib/clusters/__init__.py CHANGED
@@ -1,13 +1,20 @@
+import datetime as dt
+import logging
 import os
 
+from xmanager import xm
+
 from xm_slurm import config, resources
 from xm_slurm.contrib.clusters import drac
+from xm_slurm.executors import Slurm
 
 # ComputeCanada alias
 cc = drac
 
 __all__ = ["drac", "mila", "cc"]
 
+logger = logging.getLogger(__name__)
+
 
 def mila(
     *,
@@ -19,18 +26,39 @@ def mila(
     if mounts is None:
         mounts = {
             "/network/scratch/${USER:0:1}/$USER": "/scratch",
-            "/network/archive/${USER:0:1}/$USER": "/archive",
+            # TODO: move these somewhere common to all cluster configs.
+            "/home/mila/${USER:0:1}/$USER/.local/state/xm-slurm": "/xm-slurm-state",
+            "/home/mila/${USER:0:1}/$USER/.ssh": "/home/mila/${USER:0:1}/$USER/.ssh",
         }
 
+    def validate(job: xm.Job) -> None:
+        assert isinstance(job.executor, Slurm)
+
+        wants_requeue_with_grace_period = (
+            job.executor.requeue and job.executor.timeout_signal_grace_period > dt.timedelta(0)
+        )
+        partition = job.executor.partition or "main"
+
+        if wants_requeue_with_grace_period and (
+            partition is None or not partition.endswith("-grace")
+        ):
+            logger.warning(
+                f"Job {job.name} wants requeue with grace period, but partition `{partition}` does not end with '-grace'. "
+                "Mila Cluster requires you specify a grace partition. "
+                "This may result in the job not being requeued properly."
+            )
+
     return config.SlurmClusterConfig(
         name="mila",
-        user=user,
-        host="login.server.mila.quebec",
-        host_public_key=config.PublicKey(
-            "ssh-ed25519",
-            "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
+        ssh=config.SlurmSSHConfig(
+            user=user,
+            host="login.server.mila.quebec",
+            host_public_key=config.PublicKey(
+                "ssh-ed25519",
+                "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
+            ),
+            port=2222,
         ),
-        port=2222,
         runtime=config.ContainerRuntime.SINGULARITY,
         partition=partition,
         prolog="module load singularity",
@@ -39,14 +67,20 @@ def mila(
             "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
             "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
            "SCRATCH": "/scratch",
-            "ARCHIVE": "/archive",
+            # TODO: move this somewhere common to all cluster configs.
+            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
        },
        mounts=mounts,
        resources={
-            "rtx8000": resources.ResourceType.RTX8000,
-            "v100": resources.ResourceType.V100,
-            "a100": resources.ResourceType.A100,
-            "a100l": resources.ResourceType.A100_80GIB,
-            "a6000": resources.ResourceType.A6000,
+            resources.ResourceType.RTX8000: "rtx8000",
+            resources.ResourceType.V100: "v100",
+            resources.ResourceType.A100: "a100",
+            resources.ResourceType.A100_80GIB: "a100l",
+            resources.ResourceType.A6000: "a6000",
+        },
+        features={
+            resources.FeatureType.NVIDIA_MIG: "mig",
+            resources.FeatureType.NVIDIA_NVLINK: "nvlink",
        },
+        validate=validate,
    )
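
Usage sketch for the reworked factory (the username is a placeholder, and attribute access on the returned config assumes plain dataclass fields): connection details now nest under config.SlurmSSHConfig, and the validate hook above warns when a requeue-with-grace job targets a partition that does not end in "-grace".

from xm_slurm.contrib import clusters

cfg = clusters.mila(user="alice", partition="long-grace")
print(cfg.ssh.host, cfg.ssh.port)  # login.server.mila.quebec 2222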
xm_slurm/contrib/clusters/drac.py CHANGED
@@ -2,7 +2,7 @@ import os
 from typing import Literal
 
 from xm_slurm import config
-from xm_slurm.resources import ResourceType
+from xm_slurm.resources import FeatureType, ResourceType
 
 __all__ = ["narval", "beluga", "cedar", "graham"]
 
@@ -18,18 +18,26 @@ def _drac_cluster(
     modules: list[str] | None = None,
     proxy: Literal["submission-host"] | str | None = None,
     mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
-    resources: dict[str, ResourceType] | None = None,
+    resources: dict[ResourceType, str] | None = None,
+    features: dict[FeatureType, str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cluster."""
     if mounts is None:
-        mounts = {"/scratch/$USER": "/scratch"}
+        mounts = {
+            "/scratch/$USER": "/scratch",
+            # TODO: move these somewhere common to all cluster configs.
+            "/home/$USER/.ssh": "/home/$USER/.ssh",
+            "/home/$USER/.local/state/xm-slurm": "/xm-slurm-state",
+        }
 
     return config.SlurmClusterConfig(
         name=name,
-        user=user,
-        host=host,
-        host_public_key=host_public_key,
-        port=port,
+        ssh=config.SlurmSSHConfig(
+            user=user,
+            host=host,
+            host_public_key=host_public_key,
+            port=port,
+        ),
         account=account,
         proxy=proxy,
         runtime=config.ContainerRuntime.APPTAINER,
@@ -40,9 +48,12 @@ def _drac_cluster(
             "APPTAINER_LOCALCACHEDIR": "$SLURM_TMPDIR",
             "_XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
             "SCRATCH": "/scratch",
+            # TODO: move this somewhere common to all cluster configs.
+            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
        },
        mounts=mounts,
        resources=resources or {},
+        features=features or {},
    )
 
 
@@ -70,7 +81,11 @@ def narval(
        mounts=mounts,
        proxy=proxy,
        modules=modules,
-        resources={"a100": ResourceType.A100},
+        resources={ResourceType.A100: "a100"},
+        features={
+            FeatureType.NVIDIA_MIG: "a100mig",
+            FeatureType.NVIDIA_NVLINK: "nvlink",
+        },
    )
 
 
@@ -98,7 +113,10 @@ def beluga(
        mounts=mounts,
        proxy=proxy,
        modules=modules,
-        resources={"tesla_v100-sxm2-16gb": ResourceType.V100},
+        resources={ResourceType.V100: "tesla_v100-sxm2-16gb"},
+        features={
+            FeatureType.NVIDIA_NVLINK: "nvlink",
+        },
    )
 
 
@@ -120,9 +138,9 @@ def cedar(
        account=account,
        mounts=mounts,
        resources={
-            "v100l": ResourceType.V100_32GIB,
-            "p100": ResourceType.P100,
-            "p100l": ResourceType.P100_16GIB,
+            ResourceType.V100_32GIB: "v100l",
+            ResourceType.P100: "p100",
+            ResourceType.P100_16GIB: "p100l",
        },
    )
 
@@ -147,10 +165,10 @@ def graham(
        mounts=mounts,
        proxy=proxy,
        resources={
-            "v100": ResourceType.V100,
-            "p100": ResourceType.P100,
-            "a100": ResourceType.A100,
-            "a5000": ResourceType.A5000,
+            ResourceType.V100: "v100",
+            ResourceType.P100: "p100",
+            ResourceType.A100: "a100",
+            ResourceType.A5000: "a5000",
        },
    )
 
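Note the mapping direction flip throughout: resources now maps ResourceType to the cluster-specific GRES name instead of the reverse. A lookup sketch (the account string is a placeholder, and it assumes narval accepts an account keyword and that SlurmClusterConfig exposes the mapping as a plain resources attribute):

from xm_slurm.contrib.clusters import drac
from xm_slurm.resources import ResourceType

cfg = drac.narval(account="def-placeholder")
# 0.3.2 keyed these by GRES name; 0.4.1 keys them by ResourceType.
print(cfg.resources[ResourceType.A100])  # "a100"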
xm_slurm/dependencies.py ADDED
@@ -0,0 +1,171 @@
+import abc
+import dataclasses
+import datetime as dt
+from typing import Callable, Sequence
+
+
+class SlurmDependencyException(Exception): ...
+
+
+NoChainingException = SlurmDependencyException(
+    "Slurm only supports chaining dependencies with the same logical operator. "
+    "For example, `dep1 & dep2 | dep3` is not supported but `dep1 & dep2 & dep3` is."
+)
+
+
+class SlurmJobDependency(abc.ABC):
+    @abc.abstractmethod
+    def to_dependency_str(self) -> str: ...
+
+    def to_directive(self) -> str:
+        return f"--dependency={self.to_dependency_str()}"
+
+    def __and__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyAND":
+        if isinstance(self, SlurmJobDependencyOR):
+            raise NoChainingException
+        return SlurmJobDependencyAND(self, other_dependency)
+
+    def __or__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyOR":
+        if isinstance(other_dependency, SlurmJobDependencyAND):
+            raise NoChainingException
+        return SlurmJobDependencyOR(self, other_dependency)
+
+    def flatten(self) -> tuple["SlurmJobDependency", ...]:
+        if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
+            return self.first_dependency.flatten() + self.second_dependency.flatten()
+        return (self,)
+
+    def traverse(
+        self, mapper: Callable[["SlurmJobDependency"], "SlurmJobDependency"]
+    ) -> "SlurmJobDependency":
+        if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
+            return type(self)(
+                first_dependency=self.first_dependency.traverse(mapper),
+                second_dependency=self.second_dependency.traverse(mapper),
+            )
+        return mapper(self)
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAND(SlurmJobDependency):
+    first_dependency: SlurmJobDependency
+    second_dependency: SlurmJobDependency
+
+    def to_dependency_str(self) -> str:
+        return f"{self.first_dependency.to_dependency_str()},{self.second_dependency.to_dependency_str()}"
+
+    def __or__(self, other_dependency: SlurmJobDependency):
+        del other_dependency
+        raise NoChainingException
+
+    def __hash__(self) -> int:
+        return hash((type(self), self.first_dependency, self.second_dependency))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyOR(SlurmJobDependency):
+    first_dependency: SlurmJobDependency
+    second_dependency: SlurmJobDependency
+
+    def to_dependency_str(self) -> str:
+        return f"{self.first_dependency.to_dependency_str()}?{self.second_dependency.to_dependency_str()}"
+
+    def __and__(self, other_dependency: SlurmJobDependency):
+        del other_dependency
+        raise NoChainingException
+
+    def __hash__(self) -> int:
+        return hash((type(self), self.first_dependency, self.second_dependency))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfter(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    time: dt.timedelta | None = None
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+        if self.time is not None and self.time.total_seconds() % 60 != 0:
+            raise SlurmDependencyException("Time must be specified in exact minutes")
+
+    def to_dependency_str(self) -> str:
+        directive = "after"
+
+        for handle in self.handles:
+            directive += f":{handle.slurm_job.job_id}"
+            if self.time is not None:
+                directive += f"+{self.time.total_seconds() // 60:.0f}"
+        return directive
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfterAny(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        return ":".join(["afterany"] + [handle.slurm_job.job_id for handle in self.handles])
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        return ":".join(["afternotok"] + [handle.slurm_job.job_id for handle in self.handles])
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfterOK(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        return ":".join(["afterok"] + [handle.slurm_job.job_id for handle in self.handles])
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        job_ids = []
+        for handle in self.handles:
+            job = handle.slurm_job
+            if job.is_array_job:
+                job_ids.append(job.array_job_id)
+            elif job.is_heterogeneous_job:
+                job_ids.append(job.het_job_id)
+            else:
+                job_ids.append(job.job_id)
+        return ":".join(["aftercorr"] + job_ids)
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
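
The new module composes dependencies with & (AND, rendered ",") and | (OR, rendered "?") and refuses mixed chains, since Slurm's --dependency syntax supports only one operator. A minimal sketch (the SimpleNamespace handles are stand-ins for the SlurmHandle objects that xm_slurm.execution returns):

from types import SimpleNamespace

from xm_slurm.dependencies import SlurmJobDependencyAfterNotOK, SlurmJobDependencyAfterOK

# Stand-in handles: to_dependency_str() only reads .slurm_job.job_id.
h1 = SimpleNamespace(slurm_job=SimpleNamespace(job_id="1001"))
h2 = SimpleNamespace(slurm_job=SimpleNamespace(job_id="1002"))

dep = SlurmJobDependencyAfterOK([h1]) & SlurmJobDependencyAfterNotOK([h2])
print(dep.to_directive())  # --dependency=afterok:1001,afternotok:1002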
xm_slurm/executables.py CHANGED
@@ -1,10 +1,12 @@
 import dataclasses
 import pathlib
-import re
-from typing import Mapping, NamedTuple, Sequence
+import typing as tp
 
 from xmanager import xm
 
+from xm_slurm import constants
+from xm_slurm.types import Descriptor
+
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class Dockerfile(xm.ExecutableSpec):
@@ -31,22 +33,22 @@ class Dockerfile(xm.ExecutableSpec):
     target: str | None = None
 
     # SSH sockets/keys for the docker build step.
-    ssh: list[str] = dataclasses.field(default_factory=list)
+    ssh: tp.Sequence[str] = dataclasses.field(default_factory=list)
 
     # Build arguments to docker
-    build_args: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    build_args: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
 
     # --cache-from field in BuildKit
-    cache_from: Sequence[str] = dataclasses.field(default_factory=list)
+    cache_from: tp.Sequence[str] = dataclasses.field(default_factory=list)
 
     # Working directory in container
     workdir: pathlib.Path | None = None
 
     # Container labels
-    labels: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    labels: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
 
     # Target platform
-    platforms: Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])
+    platforms: tp.Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])
 
     @property
     def name(self) -> str:
@@ -57,6 +59,7 @@ class Dockerfile(xm.ExecutableSpec):
 
     def __hash__(self) -> int:
         return hash((
+            type(self),
            self.dockerfile,
            self.context,
            self.target,
@@ -89,12 +92,7 @@ class DockerImage(xm.ExecutableSpec):
        return self.image
 
     def __hash__(self) -> int:
-        return hash((self.image, self.workdir))
-
-
-_IMAGE_URI_REGEX = re.compile(
-    r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
-)
+        return hash((type(self), self.image, self.workdir))
 
 
 @dataclasses.dataclass
@@ -108,7 +106,7 @@ class ImageURI:
     digest: str | None = dataclasses.field(init=False, default=None)
 
     def __post_init__(self, image: str):
-        match = _IMAGE_URI_REGEX.match(image)
+        match = constants.IMAGE_URI_REGEX.match(image)
        if not match:
            raise ValueError(f"Invalid OCI image URI: {image}")
        groups = {k: v for k, v in match.groupdict().items() if v is not None}
@@ -152,6 +150,7 @@ class ImageURI:
 
     def __hash__(self) -> int:
         return hash((
+            type(self),
            self.scheme,
            self.domain,
            self.path,
@@ -165,30 +164,31 @@ class ImageURI:
        return format.format(**fields)
 
 
-class ImageDescriptor:
+class ImageDescriptor(Descriptor[ImageURI, str | ImageURI]):
     def __set_name__(self, owner: type, name: str):
        del owner
        self.image = f"_{name}"
 
-    def __get__(self, instance: object, owner: type) -> ImageURI:
+    def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> ImageURI:
        del owner
        return getattr(instance, self.image)
 
     def __set__(self, instance: object, value: str | ImageURI):
+        _setattr = object.__setattr__ if not hasattr(instance, self.image) else setattr
        if isinstance(value, str):
            value = ImageURI(value)
-        setattr(instance, self.image, value)
+        _setattr(instance, self.image, value)
 
 
-class RemoteRepositoryCredentials(NamedTuple):
+class RemoteRepositoryCredentials(tp.NamedTuple):
     username: str
     password: str
 
 
-@dataclasses.dataclass(kw_only=True)  # type: ignore
+@dataclasses.dataclass(frozen=True, kw_only=True)  # type: ignore
 class RemoteImage(xm.Executable):
     # Remote base image
-    image: ImageDescriptor = ImageDescriptor()
+    image: Descriptor[ImageURI, str | ImageURI] = ImageDescriptor()
 
     # Working directory in container
     workdir: pathlib.Path | None = None
@@ -196,11 +196,23 @@ class RemoteImage(xm.Executable):
     # Container arguments
     args: xm.SequentialArgs = dataclasses.field(default_factory=xm.SequentialArgs)
     # Container environment variables
-    env_vars: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    env_vars: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
 
     # Remote repository credentials
     credentials: RemoteRepositoryCredentials | None = None
 
     @property
-    def name(self) -> str:
+    def name(self) -> str:  # type: ignore
        return str(self.image)
+
+    def __hash__(self) -> int:
+        return hash(
+            (
+                type(self),
+                self.image,
+                self.workdir,
+                tuple(sorted(self.args.to_list())),
+                tuple(sorted(self.env_vars.items())),
+                self.credentials,
+            ),
+        )
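
Because RemoteImage is now a frozen dataclass, ImageDescriptor.__set__ routes the first write through object.__setattr__ (plain setattr on a frozen instance raises FrozenInstanceError). A generic sketch of that pattern, independent of xm_slurm:

import dataclasses


class Coerced:
    """Descriptor that coerces str -> int, compatible with frozen dataclasses."""

    def __set_name__(self, owner: type, name: str) -> None:
        self.attr = f"_{name}"

    def __get__(self, instance: object, owner: type | None = None) -> int:
        # Class access (instance=None) raises AttributeError, which tells
        # dataclasses the field has no default.
        return getattr(instance, self.attr)

    def __set__(self, instance: object, value: int | str) -> None:
        # The first (init-time) write must bypass the frozen __setattr__.
        _setattr = object.__setattr__ if not hasattr(instance, self.attr) else setattr
        _setattr(instance, self.attr, int(value))


@dataclasses.dataclass(frozen=True)
class Point:
    x: Coerced = Coerced()


p = Point(x="3")
print(p.x)  # 3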