xmanager-slurm 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry.


xm_slurm/__init__.py CHANGED
@@ -10,6 +10,7 @@ from xm_slurm.experiment import (
     get_current_work_unit,
     get_experiment,
 )
+from xm_slurm.job_blocks import JobArgs
 from xm_slurm.packageables import (
     conda_container,
     docker_container,
@@ -34,13 +35,14 @@ __all__ = [
     "get_current_experiment",
     "get_current_work_unit",
     "get_experiment",
+    "JobArgs",
     "JobRequirements",
     "mamba_container",
-    "uv_container",
     "python_container",
     "ResourceQuantity",
     "ResourceType",
     "Slurm",
-    "SlurmSpec",
     "SlurmExperiment",
+    "SlurmSpec",
+    "uv_container",
 ]
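
The new `JobArgs` re-export means the type can now be imported directly from the package root; a minimal sketch:

```python
# Both names refer to the same class once JobArgs is re-exported at the top level.
from xm_slurm import JobArgs
from xm_slurm.job_blocks import JobArgs as JobArgsFromModule

assert JobArgs is JobArgsFromModule
```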
xm_slurm/api.py CHANGED
@@ -322,7 +322,7 @@ class XManagerSqliteAPI(XManagerAPI):
             db_path = Path(os.environ["XM_SLURM_STATE_DIR"]) / "db.sqlite3"
         else:
             db_path = Path.home() / ".local" / "state" / "xm-slurm" / "db.sqlite3"
-        logging.debug("Looking for db at: ", db_path)
+        logger.debug("Looking for db at: ", db_path)
         db_path.parent.mkdir(parents=True, exist_ok=True)
         engine = create_engine(f"sqlite:///{db_path}")
         Base.metadata.create_all(engine)
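
The change switches from the root `logging` module to the module-level `logger`. For reference, the conventional module-logger pattern with lazy `%s` interpolation looks like the sketch below (a generic illustration with a made-up path, not the package's exact code):

```python
import logging

logger = logging.getLogger(__name__)
db_path = "/tmp/xm-slurm/db.sqlite3"  # illustrative value

# Lazy interpolation: db_path is only formatted into the message if DEBUG is enabled.
logger.debug("Looking for db at: %s", db_path)
```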
xm_slurm/config.py CHANGED
@@ -5,9 +5,10 @@ import getpass
 import json
 import os
 import pathlib
-from typing import Literal, Mapping, NamedTuple
+from typing import Callable, Literal, Mapping, NamedTuple

 import asyncssh
+from xmanager import xm

 from xm_slurm import constants

@@ -137,7 +138,7 @@ class SlurmSSHConfig:
         )

     def __hash__(self):
-        return hash((self.host, self.host_public_key, self.user, self.port))
+        return hash((type(self), self.host, self.host_public_key, self.user, self.port))


 @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -176,6 +177,9 @@ class SlurmClusterConfig:

     features: Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821

+    # Function to validate the Slurm executor config
+    validate: Callable[[xm.Job], None] | None = None
+
     def __post_init__(self) -> None:
         for src, dst in self.mounts.items():
             if not isinstance(src, (str, os.PathLike)):
@@ -194,6 +198,7 @@ class SlurmClusterConfig:

     def __hash__(self):
         return hash((
+            type(self),
             self.ssh,
             self.cwd,
             self.prolog,
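
The new `validate` field lets a cluster config reject unsupported jobs before submission: it receives an `xm.Job` and is expected to raise on invalid configurations. A hypothetical validator matching that signature (the `time` attribute checked here is an assumption used purely for illustration, not a documented xm-slurm field):

```python
from xmanager import xm


def require_time_limit(job: xm.Job) -> None:
    # Hypothetical policy: refuse jobs whose executor does not set a walltime.
    if getattr(job.executor, "time", None) is None:
        raise ValueError("Jobs on this cluster must set an explicit time limit.")


# This callable could then be supplied as `validate=require_time_limit` when
# constructing a SlurmClusterConfig.
```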
xm_slurm/constants.py CHANGED
@@ -1,5 +1,9 @@
 import re

+SLURM_JOB_ID_REGEX = re.compile(
+    r"^(?P<jobid>\d+)(?:(?:\+(?P<componentid>\d+))|(?:_(?P<arraytaskid>\d+)))?$"
+)
+
 IMAGE_URI_REGEX = re.compile(
     r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
 )
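
The new `SLURM_JOB_ID_REGEX` distinguishes plain job IDs, heterogeneous-job components (`+`), and array tasks (`_`). A quick check of the named groups, using made-up IDs:

```python
import re

# Pattern copied from the diff above.
SLURM_JOB_ID_REGEX = re.compile(
    r"^(?P<jobid>\d+)(?:(?:\+(?P<componentid>\d+))|(?:_(?P<arraytaskid>\d+)))?$"
)

for job_id in ("12345", "12345+2", "12345_7"):
    match = SLURM_JOB_ID_REGEX.match(job_id)
    assert match is not None
    print(job_id, match.groupdict())
# 12345   -> {'jobid': '12345', 'componentid': None, 'arraytaskid': None}
# 12345+2 -> {'jobid': '12345', 'componentid': '2', 'arraytaskid': None}   (het-job component)
# 12345_7 -> {'jobid': '12345', 'componentid': None, 'arraytaskid': '7'}   (array task)
```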
@@ -1,13 +1,20 @@
+import datetime as dt
+import logging
 import os

+from xmanager import xm
+
 from xm_slurm import config, resources
 from xm_slurm.contrib.clusters import drac
+from xm_slurm.executors import Slurm

 # ComputeCanada alias
 cc = drac

 __all__ = ["drac", "mila", "cc"]

+logger = logging.getLogger(__name__)
+

 def mila(
     *,
@@ -53,6 +60,8 @@ def mila(
             resources.ResourceType.A100: "a100",
             resources.ResourceType.A100_80GIB: "a100l",
             resources.ResourceType.A6000: "a6000",
+            resources.ResourceType.L40S: "l40s",
+            resources.ResourceType.H100: "h100",
         },
         features={
             resources.FeatureType.NVIDIA_MIG: "mig",
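
These entries expose Mila's L40S and H100 nodes through the cluster's resource-to-GRES mapping. A heavily hedged sketch of requesting one, assuming `xm_slurm.JobRequirements` accepts lowercase resource keywords the way xmanager's `xm.JobRequirements` does (verify against the actual API):

```python
import xm_slurm

# Assumption: resource types map to lowercase keyword arguments; this mirrors
# xmanager's JobRequirements convention and may differ in xm-slurm.
requirements = xm_slurm.JobRequirements(h100=1)
```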
@@ -0,0 +1,171 @@
+import abc
+import dataclasses
+import datetime as dt
+from typing import Callable, Sequence
+
+
+class SlurmDependencyException(Exception): ...
+
+
+NoChainingException = SlurmDependencyException(
+    "Slurm only supports chaining dependencies with the same logical operator. "
+    "For example, `dep1 & dep2 | dep3` is not supported but `dep1 & dep2 & dep3` is."
+)
+
+
+class SlurmJobDependency(abc.ABC):
+    @abc.abstractmethod
+    def to_dependency_str(self) -> str: ...
+
+    def to_directive(self) -> str:
+        return f"--dependency={self.to_dependency_str()}"
+
+    def __and__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyAND":
+        if isinstance(self, SlurmJobDependencyOR):
+            raise NoChainingException
+        return SlurmJobDependencyAND(self, other_dependency)
+
+    def __or__(self, other_dependency: "SlurmJobDependency") -> "SlurmJobDependencyOR":
+        if isinstance(other_dependency, SlurmJobDependencyAND):
+            raise NoChainingException
+        return SlurmJobDependencyOR(self, other_dependency)
+
+    def flatten(self) -> tuple["SlurmJobDependency", ...]:
+        if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
+            return self.first_dependency.flatten() + self.second_dependency.flatten()
+        return (self,)
+
+    def traverse(
+        self, mapper: Callable[["SlurmJobDependency"], "SlurmJobDependency"]
+    ) -> "SlurmJobDependency":
+        if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
+            return type(self)(
+                first_dependency=self.first_dependency.traverse(mapper),
+                second_dependency=self.second_dependency.traverse(mapper),
+            )
+        return mapper(self)
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAND(SlurmJobDependency):
+    first_dependency: SlurmJobDependency
+    second_dependency: SlurmJobDependency
+
+    def to_dependency_str(self) -> str:
+        return f"{self.first_dependency.to_dependency_str()},{self.second_dependency.to_dependency_str()}"
+
+    def __or__(self, other_dependency: SlurmJobDependency):
+        del other_dependency
+        raise NoChainingException
+
+    def __hash__(self) -> int:
+        return hash((type(self), self.first_dependency, self.second_dependency))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyOR(SlurmJobDependency):
+    first_dependency: SlurmJobDependency
+    second_dependency: SlurmJobDependency
+
+    def to_dependency_str(self) -> str:
+        return f"{self.first_dependency.to_dependency_str()}?{self.second_dependency.to_dependency_str()}"
+
+    def __and__(self, other_dependency: SlurmJobDependency):
+        del other_dependency
+        raise NoChainingException
+
+    def __hash__(self) -> int:
+        return hash((type(self), self.first_dependency, self.second_dependency))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfter(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    time: dt.timedelta | None = None
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+        if self.time is not None and self.time.total_seconds() % 60 != 0:
+            raise SlurmDependencyException("Time must be specified in exact minutes")
+
+    def to_dependency_str(self) -> str:
+        directive = "after"
+
+        for handle in self.handles:
+            directive += f":{handle.slurm_job.job_id}"
+            if self.time is not None:
+                directive += f"+{self.time.total_seconds() // 60:.0f}"
+        return directive
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfterAny(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        return ":".join(["afterany"] + [handle.slurm_job.job_id for handle in self.handles])
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        return ":".join(["afternotok"] + [handle.slurm_job.job_id for handle in self.handles])
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobDependencyAfterOK(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        return ":".join(["afterok"] + [handle.slurm_job.job_id for handle in self.handles])
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
+
+
+@dataclasses.dataclass(frozen=True)
+class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
+    handles: Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821
+
+    def __post_init__(self):
+        if len(self.handles) == 0:
+            raise SlurmDependencyException("Dependency doesn't have any handles.")
+
+    def to_dependency_str(self) -> str:
+        job_ids = []
+        for handle in self.handles:
+            job = handle.slurm_job
+            if job.is_array_job:
+                job_ids.append(job.array_job_id)
+            elif job.is_heterogeneous_job:
+                job_ids.append(job.het_job_id)
+            else:
+                job_ids.append(job.job_id)
+        return ":".join(["aftercorr"] + job_ids)
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + tuple([handle.slurm_job for handle in self.handles]))
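
The operator overloads above compose dependencies with `&` (Slurm's `,`) and `|` (Slurm's `?`) and refuse mixed chains, which Slurm cannot express. A minimal sketch with the new classes in scope and hypothetical stand-in handles (real code would pass `xm_slurm.execution.SlurmHandle` objects):

```python
import types


def fake_handle(job_id: str):
    # Stand-in mimicking handle.slurm_job.job_id; for illustration only.
    return types.SimpleNamespace(slurm_job=types.SimpleNamespace(job_id=job_id))


ok = SlurmJobDependencyAfterOK(handles=[fake_handle("1001")])
any_ = SlurmJobDependencyAfterAny(handles=[fake_handle("1002")])

print((ok & any_).to_directive())  # --dependency=afterok:1001,afterany:1002
print((ok | any_).to_directive())  # --dependency=afterok:1001?afterany:1002
# Mixing operators, e.g. (ok & any_) | ok, raises SlurmDependencyException.
```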
xm_slurm/executables.py CHANGED
@@ -1,10 +1,11 @@
 import dataclasses
 import pathlib
-from typing import Mapping, NamedTuple, Sequence
+import typing as tp

 from xmanager import xm

 from xm_slurm import constants
+from xm_slurm.types import Descriptor


 @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -32,22 +33,22 @@ class Dockerfile(xm.ExecutableSpec):
     target: str | None = None

     # SSH sockets/keys for the docker build step.
-    ssh: Sequence[str] = dataclasses.field(default_factory=list)
+    ssh: tp.Sequence[str] = dataclasses.field(default_factory=list)

     # Build arguments to docker
-    build_args: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    build_args: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

     # --cache-from field in BuildKit
-    cache_from: Sequence[str] = dataclasses.field(default_factory=list)
+    cache_from: tp.Sequence[str] = dataclasses.field(default_factory=list)

     # Working directory in container
     workdir: pathlib.Path | None = None

     # Container labels
-    labels: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    labels: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

     # Target platform
-    platforms: Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])
+    platforms: tp.Sequence[str] = dataclasses.field(default_factory=lambda: ["linux/amd64"])

     @property
     def name(self) -> str:
@@ -58,6 +59,7 @@ class Dockerfile(xm.ExecutableSpec):

     def __hash__(self) -> int:
         return hash((
+            type(self),
             self.dockerfile,
             self.context,
             self.target,
@@ -90,7 +92,7 @@ class DockerImage(xm.ExecutableSpec):
         return self.image

     def __hash__(self) -> int:
-        return hash((self.image, self.workdir))
+        return hash((type(self), self.image, self.workdir))


 @dataclasses.dataclass
@@ -148,6 +150,7 @@ class ImageURI:

     def __hash__(self) -> int:
         return hash((
+            type(self),
             self.scheme,
             self.domain,
             self.path,
@@ -161,30 +164,31 @@ class ImageURI:
         return format.format(**fields)


-class ImageDescriptor:
+class ImageDescriptor(Descriptor[ImageURI, str | ImageURI]):
     def __set_name__(self, owner: type, name: str):
         del owner
         self.image = f"_{name}"

-    def __get__(self, instance: object, owner: type) -> ImageURI:
+    def __get__(self, instance: object | None, owner: tp.Type[object] | None = None) -> ImageURI:
         del owner
         return getattr(instance, self.image)

     def __set__(self, instance: object, value: str | ImageURI):
+        _setattr = object.__setattr__ if not hasattr(instance, self.image) else setattr
         if isinstance(value, str):
             value = ImageURI(value)
-        setattr(instance, self.image, value)
+        _setattr(instance, self.image, value)


-class RemoteRepositoryCredentials(NamedTuple):
+class RemoteRepositoryCredentials(tp.NamedTuple):
     username: str
     password: str


-@dataclasses.dataclass(kw_only=True)  # type: ignore
+@dataclasses.dataclass(frozen=True, kw_only=True)  # type: ignore
 class RemoteImage(xm.Executable):
     # Remote base image
-    image: ImageDescriptor = ImageDescriptor()
+    image: Descriptor[ImageURI, str | ImageURI] = ImageDescriptor()

     # Working directory in container
     workdir: pathlib.Path | None = None
@@ -192,18 +196,19 @@ class RemoteImage(xm.Executable):
     # Container arguments
     args: xm.SequentialArgs = dataclasses.field(default_factory=xm.SequentialArgs)
     # Container environment variables
-    env_vars: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    env_vars: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)

     # Remote repository credentials
     credentials: RemoteRepositoryCredentials | None = None

     @property
-    def name(self) -> str:
+    def name(self) -> str:  # type: ignore
         return str(self.image)

     def __hash__(self) -> int:
         return hash(
             (
+                type(self),
                 self.image,
                 self.workdir,
                 tuple(sorted(self.args.to_list())),
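
`RemoteImage` is now a frozen dataclass, and `ImageDescriptor` writes through `object.__setattr__` on the first assignment so the generated `__init__` can still populate the field. A self-contained sketch of the same pattern (the `PathDescriptor`/`Workdir` names are hypothetical, not part of xm-slurm):

```python
import dataclasses
import pathlib


class PathDescriptor:
    """Coerces strings to pathlib.Path; also works as a field of a frozen dataclass."""

    def __set_name__(self, owner: type, name: str) -> None:
        self._name = f"_{name}"

    def __get__(self, instance, owner=None) -> pathlib.Path:
        return getattr(instance, self._name)

    def __set__(self, instance, value) -> None:
        # During the generated __init__ the backing attribute does not exist yet,
        # so bypass the frozen guard with object.__setattr__; afterwards fall back
        # to plain setattr so normal (frozen) semantics apply.
        _setattr = object.__setattr__ if not hasattr(instance, self._name) else setattr
        if isinstance(value, str):
            value = pathlib.Path(value)
        _setattr(instance, self._name, value)


@dataclasses.dataclass(frozen=True)
class Workdir:
    path: PathDescriptor = PathDescriptor()


wd = Workdir(path="/tmp/build")
print(wd.path)  # /tmp/build (a pathlib.Path, coerced by the descriptor)
```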