xmanager-slurm 0.4.14__py3-none-any.whl → 0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Click here for more details.

xm_slurm/__init__.py CHANGED
@@ -18,7 +18,7 @@ from xm_slurm.packageables import (
18
18
  python_container,
19
19
  uv_container,
20
20
  )
21
- from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType
21
+ from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType, Topology
22
22
 
23
23
  logging.getLogger("asyncssh").setLevel(logging.WARN)
24
24
  logging.getLogger("httpx").setLevel(logging.WARN)
@@ -42,5 +42,6 @@ __all__ = [
42
42
  "Slurm",
43
43
  "SlurmExperiment",
44
44
  "SlurmSpec",
45
+ "Topology",
45
46
  "uv_container",
46
47
  ]
xm_slurm/batching.py CHANGED
@@ -4,11 +4,11 @@ import dataclasses
4
4
  import inspect
5
5
  import time
6
6
  import types
7
- from typing import Any, Callable, Coroutine, Generic, ParamSpec, Sequence, TypeVar
7
+ import typing as tp
8
8
 
9
- T = TypeVar("T", contravariant=True)
10
- R = TypeVar("R", covariant=True)
11
- P = ParamSpec("P")
9
+ T = tp.TypeVar("T", contravariant=True)
10
+ R = tp.TypeVar("R", covariant=True)
11
+ P = tp.ParamSpec("P")
12
12
 
13
13
 
14
14
  @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -18,10 +18,10 @@ class Request:
18
18
 
19
19
 
20
20
  def stack_bound_arguments(
21
- signature: inspect.Signature, bound_arguments: Sequence[inspect.BoundArguments]
21
+ signature: inspect.Signature, bound_arguments: tp.Sequence[inspect.BoundArguments]
22
22
  ) -> inspect.BoundArguments:
23
23
  """Stacks bound arguments into a single bound arguments object."""
24
- stacked_args = collections.OrderedDict[str, Any]()
24
+ stacked_args = collections.OrderedDict[str, tp.Any]()
25
25
  for bound_args in bound_arguments:
26
26
  for name, value in bound_args.arguments.items():
27
27
  stacked_args.setdefault(name, [])
@@ -29,7 +29,7 @@ def stack_bound_arguments(
29
29
  return inspect.BoundArguments(signature, stacked_args)
30
30
 
31
31
 
32
- class batch(Generic[R]):
32
+ class batch(tp.Generic[R]):
33
33
  __slots__ = (
34
34
  "fn",
35
35
  "signature",
@@ -44,7 +44,7 @@ class batch(Generic[R]):
44
44
 
45
45
  def __init__(
46
46
  self,
47
- fn: Callable[..., Coroutine[None, None, Sequence[R]]],
47
+ fn: tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]],
48
48
  /,
49
49
  *,
50
50
  max_batch_size: int,
@@ -102,7 +102,7 @@ class batch(Generic[R]):
102
102
 
103
103
  return batch
104
104
 
105
- def __get__(self, obj: Any, objtype: type) -> Any:
105
+ def __get__(self, obj: tp.Any, objtype: tp.Type[tp.Any]) -> tp.Any:
106
106
  del objtype
107
107
  if isinstance(self.fn, staticmethod):
108
108
  return self.__call__
@@ -110,11 +110,11 @@ class batch(Generic[R]):
110
110
  return types.MethodType(self, obj)
111
111
 
112
112
  @property
113
- def __func__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
113
+ def __func__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
114
114
  return self.fn
115
115
 
116
116
  @property
117
- def __wrapped__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
117
+ def __wrapped__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
118
118
  return self.fn
119
119
 
120
120
  @property
xm_slurm/config.py CHANGED
@@ -5,7 +5,7 @@ import getpass
5
5
  import json
6
6
  import os
7
7
  import pathlib
8
- from typing import Callable, Literal, Mapping, NamedTuple
8
+ import typing as tp
9
9
 
10
10
  import asyncssh
11
11
  from xmanager import xm
@@ -23,7 +23,7 @@ class ContainerRuntime(enum.Enum):
23
23
 
24
24
  @classmethod
25
25
  def from_string(
26
- cls, runtime: Literal["singularity", "apptainer", "docker", "podman"]
26
+ cls, runtime: tp.Literal["singularity", "apptainer", "docker", "podman"]
27
27
  ) -> "ContainerRuntime":
28
28
  return {
29
29
  "singularity": cls.SINGULARITY,
@@ -45,7 +45,7 @@ class ContainerRuntime(enum.Enum):
45
45
  raise NotImplementedError
46
46
 
47
47
 
48
- class PublicKey(NamedTuple):
48
+ class PublicKey(tp.NamedTuple):
49
49
  algorithm: str
50
50
  key: str
51
51
 
@@ -172,26 +172,26 @@ class SlurmClusterConfig:
172
172
  qos: str | None = None
173
173
 
174
174
  # If true, a reverse proxy is initiated via the submission host.
175
- proxy: Literal["submission-host"] | str | None = None
175
+ proxy: tp.Literal["submission-host"] | str | None = None
176
176
 
177
177
  runtime: ContainerRuntime
178
178
 
179
179
  # Environment variables
180
- host_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
181
- container_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
180
+ host_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
181
+ container_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
182
182
 
183
183
  # Mounts
184
- mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
184
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
185
185
  default_factory=dict
186
186
  )
187
187
 
188
188
  # Resource mapping
189
- resources: Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
189
+ resources: tp.Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
190
190
 
191
- features: Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
191
+ features: tp.Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
192
192
 
193
193
  # Function to validate the Slurm executor config
194
- validate: Callable[[xm.Job], None] | None = None
194
+ validate: tp.Callable[[xm.Job], None] | None = None
195
195
 
196
196
  def __post_init__(self) -> None:
197
197
  for src, dst in self.mounts.items():
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import Literal
2
+ import typing as tp
3
3
 
4
4
  from xm_slurm import config
5
5
  from xm_slurm.resources import FeatureType, ResourceType
@@ -16,10 +16,10 @@ def _drac_cluster(
16
16
  user: str | None = None,
17
17
  account: str | None = None,
18
18
  modules: list[str] | None = None,
19
- proxy: Literal["submission-host"] | str | None = None,
20
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
21
- resources: dict[ResourceType, str] | None = None,
22
- features: dict[FeatureType, str] | None = None,
19
+ proxy: tp.Literal["submission-host"] | str | None = None,
20
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
21
+ resources: tp.Mapping[ResourceType, str] | None = None,
22
+ features: tp.Mapping[FeatureType, str] | None = None,
23
23
  ) -> config.SlurmClusterConfig:
24
24
  """DRAC Cluster."""
25
25
  if mounts is None:
@@ -62,8 +62,8 @@ def narval(
62
62
  *,
63
63
  user: str | None = None,
64
64
  account: str | None = None,
65
- proxy: Literal["submission-host"] | str | None = None,
66
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
65
+ proxy: tp.Literal["submission-host"] | str | None = None,
66
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
67
67
  ) -> config.SlurmClusterConfig:
68
68
  """DRAC Narval Cluster (https://docs.alliancecan.ca/wiki/Narval/en)."""
69
69
  modules = []
@@ -94,8 +94,8 @@ def beluga(
94
94
  *,
95
95
  user: str | None = None,
96
96
  account: str | None = None,
97
- proxy: Literal["submission-host"] | str | None = None,
98
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
97
+ proxy: tp.Literal["submission-host"] | str | None = None,
98
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
99
99
  ) -> config.SlurmClusterConfig:
100
100
  """DRAC Beluga Cluster (https://docs.alliancecan.ca/wiki/B%C3%A9luga/en)."""
101
101
  modules = []
@@ -125,8 +125,8 @@ def rorqual(
125
125
  *,
126
126
  user: str | None = None,
127
127
  account: str | None = None,
128
- proxy: Literal["submission-host"] | str | None = None,
129
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
128
+ proxy: tp.Literal["submission-host"] | str | None = None,
129
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
130
130
  ) -> config.SlurmClusterConfig:
131
131
  """DRAC Rorqual Cluster (https://docs.alliancecan.ca/wiki/Rorqual/en)."""
132
132
  modules = []
@@ -155,7 +155,7 @@ def cedar(
155
155
  *,
156
156
  user: str | None = None,
157
157
  account: str | None = None,
158
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
158
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
159
159
  ) -> config.SlurmClusterConfig:
160
160
  """DRAC Cedar Cluster (https://docs.alliancecan.ca/wiki/Cedar/en)."""
161
161
  return _drac_cluster(
@@ -180,7 +180,7 @@ def fir(
180
180
  *,
181
181
  user: str | None = None,
182
182
  account: str | None = None,
183
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
183
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
184
184
  ) -> config.SlurmClusterConfig:
185
185
  """DRAC Fir Cluster (https://docs.alliancecan.ca/wiki/Fir/en)."""
186
186
  return _drac_cluster(
@@ -201,8 +201,8 @@ def graham(
201
201
  *,
202
202
  user: str | None = None,
203
203
  account: str | None = None,
204
- proxy: Literal["submission-host"] | str | None = "submission-host",
205
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
204
+ proxy: tp.Literal["submission-host"] | str | None = "submission-host",
205
+ mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
206
206
  ) -> config.SlurmClusterConfig:
207
207
  """DRAC Graham Cluster (https://docs.alliancecan.ca/wiki/Graham/en)."""
208
208
  return _drac_cluster(
@@ -223,17 +223,3 @@ def graham(
223
223
  ResourceType.A5000: "a5000",
224
224
  },
225
225
  )
226
-
227
-
228
- def all(
229
- user: str | None = None,
230
- account: str | None = None,
231
- mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
232
- ) -> list[config.SlurmClusterConfig]:
233
- """All DRAC clusters."""
234
- return [
235
- narval(user=user, account=account, mounts=mounts),
236
- beluga(user=user, account=account, mounts=mounts),
237
- cedar(user=user, account=account, mounts=mounts),
238
- graham(user=user, account=account, mounts=mounts),
239
- ]
xm_slurm/dependencies.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import abc
2
2
  import dataclasses
3
3
  import datetime as dt
4
- from typing import Callable, Sequence
4
+ import typing as tp
5
5
 
6
6
 
7
7
  class SlurmDependencyException(Exception): ...
@@ -36,7 +36,7 @@ class SlurmJobDependency(abc.ABC):
36
36
  return (self,)
37
37
 
38
38
  def traverse(
39
- self, mapper: Callable[["SlurmJobDependency"], "SlurmJobDependency"]
39
+ self, mapper: tp.Callable[["SlurmJobDependency"], "SlurmJobDependency"]
40
40
  ) -> "SlurmJobDependency":
41
41
  if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
42
42
  return type(self)(
@@ -80,7 +80,7 @@ class SlurmJobDependencyOR(SlurmJobDependency):
80
80
 
81
81
  @dataclasses.dataclass(frozen=True)
82
82
  class SlurmJobDependencyAfter(SlurmJobDependency):
83
- handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
83
+ handles: tp.Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
84
84
  time: dt.timedelta | None = None
85
85
 
86
86
  def __post_init__(self):
@@ -104,7 +104,7 @@ class SlurmJobDependencyAfter(SlurmJobDependency):
104
104
 
105
105
  @dataclasses.dataclass(frozen=True)
106
106
  class SlurmJobDependencyAfterAny(SlurmJobDependency):
107
- handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
107
+ handles: tp.Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
108
108
 
109
109
  def __post_init__(self):
110
110
  if len(self.handles) == 0:
@@ -119,7 +119,7 @@ class SlurmJobDependencyAfterAny(SlurmJobDependency):
119
119
 
120
120
  @dataclasses.dataclass(frozen=True)
121
121
  class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
122
- handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
122
+ handles: tp.Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
123
123
 
124
124
  def __post_init__(self):
125
125
  if len(self.handles) == 0:
@@ -134,7 +134,7 @@ class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
134
134
 
135
135
  @dataclasses.dataclass(frozen=True)
136
136
  class SlurmJobDependencyAfterOK(SlurmJobDependency):
137
- handles: Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
137
+ handles: tp.Sequence["xm_slurm.execution.SlurmHandle"] # type: ignore # noqa: F821
138
138
 
139
139
  def __post_init__(self):
140
140
  if len(self.handles) == 0:
@@ -149,7 +149,7 @@ class SlurmJobDependencyAfterOK(SlurmJobDependency):
149
149
 
150
150
  @dataclasses.dataclass(frozen=True)
151
151
  class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
152
- handles: Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"] # type: ignore # noqa: F821
152
+ handles: tp.Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"] # type: ignore # noqa: F821
153
153
 
154
154
  def __post_init__(self):
155
155
  if len(self.handles) == 0:
xm_slurm/executors.py CHANGED
@@ -1,10 +1,18 @@
1
+ import collections.abc
1
2
  import dataclasses
2
3
  import datetime as dt
3
4
  import signal
5
+ import typing as tp
4
6
 
5
7
  from xmanager import xm
6
8
 
7
- from xm_slurm import resources
9
+ from xm_slurm import resources, utils
10
+
11
+ ResourceBindType = tp.Literal[
12
+ resources.ResourceType.GPU,
13
+ resources.ResourceType.MEMORY,
14
+ resources.ResourceType.RAM,
15
+ ]
8
16
 
9
17
 
10
18
  @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -26,10 +34,19 @@ class Slurm(xm.Executor):
26
34
  Args:
27
35
  requirements: The requirements for the job.
28
36
  time: The maximum time to run the job.
37
+ switches: Maximum count of leaf switches desired for the job allocation.
38
+ switches_grace_period: Maximum time to wait for that number of switches.
39
+ bind: How to bind tasks to resource (memory, GPU, or generic resource).
29
40
  account: The account to charge the job to.
30
41
  partition: The partition to run the job in.
31
42
  qos: The quality of service to run the job with.
32
43
  priority: The priority of the job.
44
+ reservation: Allocate resources for the job from the named reservation.
45
+ exclusive: Disallow sharing nodes with other running jobs.
46
+ oversubscribe: Allow over-subscribing resources with other running jobs.
47
+ overcommit: Allow sharing of allocated resources as if only one task per node was requested.
48
+ nice: Run the job with an adjusted scheduling priority.
49
+ kill_on_invalid_dependencies: Whether to kill the job if it has invalid dependencies.
33
50
  timeout_signal: The signal to send to the job when it runs out of time.
34
51
  timeout_signal_grace_period: The time to wait before sending `timeout_signal`.
35
52
  requeue: Whether or not the job is eligible for requeueing.
@@ -41,12 +58,18 @@ class Slurm(xm.Executor):
41
58
  # Job requirements
42
59
  requirements: resources.JobRequirements
43
60
  time: dt.timedelta
61
+ bind: tp.Mapping[ResourceBindType | str, str | None] | None = None
44
62
 
45
63
  # Placement
46
64
  account: str | None = None
47
65
  partition: str | None = None
48
66
  qos: str | None = None
49
67
  priority: int | None = None
68
+ reservation: str | tp.Iterable[str] | None = None
69
+ exclusive: bool = False
70
+ oversubscribe: bool = False
71
+ overcommit: bool = False
72
+ nice: int | None = None
50
73
 
51
74
  # Job dependency handling
52
75
  kill_on_invalid_dependencies: bool = True
@@ -65,13 +88,28 @@ class Slurm(xm.Executor):
65
88
  return self.time - self.timeout_signal_grace_period
66
89
 
67
90
  def __post_init__(self) -> None:
68
- if not isinstance(self.time, dt.timedelta):
69
- raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
70
91
  if not isinstance(self.requirements, resources.JobRequirements):
71
92
  raise TypeError(
72
93
  f"requirements must be a `xm_slurm.JobRequirements`, got {type(self.requirements)}. "
73
94
  "If you're still using `xm.JobRequirements`, please update to `xm_slurm.JobRequirements`."
74
95
  )
96
+ if not isinstance(self.time, dt.timedelta):
97
+ raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
98
+ if self.bind is not None:
99
+ if not isinstance(self.bind, collections.abc.Mapping):
100
+ raise TypeError(f"bind must be a mapping, got {type(self.bind)}")
101
+ for resource, value in self.bind.items():
102
+ if resource not in (
103
+ resources.ResourceType.GPU,
104
+ resources.ResourceType.MEMORY,
105
+ resources.ResourceType.RAM,
106
+ ) and not isinstance(resource, str):
107
+ raise TypeError(
108
+ f"bind resource must be a {resources.ResourceType.GPU.name}, {resources.ResourceType.MEMORY.name}, or {resources.ResourceType.RAM.name}, got {type(resource)}"
109
+ )
110
+ if value is not None and not isinstance(value, str):
111
+ raise TypeError(f"bind value must be None or a string, got {type(value)}")
112
+
75
113
  if not isinstance(self.timeout_signal, signal.Signals):
76
114
  raise TypeError(
77
115
  f"termination_signal must be a `signal.Signals`, got {type(self.timeout_signal)}"
@@ -86,6 +124,10 @@ class Slurm(xm.Executor):
86
124
  )
87
125
  if self.requeue_on_exit_code == 0:
88
126
  raise ValueError("requeue_on_exit_code should not be 0 to avoid unexpected behavior.")
127
+ if self.exclusive and self.oversubscribe:
128
+ raise ValueError("exclusive and oversubscribe are mutually exclusive.")
129
+ if self.nice is not None and not (-2147483645 <= self.nice <= 2147483645):
130
+ raise ValueError(f"nice must be between -2147483645 and 2147483645, got {self.nice}")
89
131
 
90
132
  @classmethod
91
133
  def Spec(cls, tag: str | None = None) -> SlurmSpec:
@@ -96,10 +138,22 @@ class Slurm(xm.Executor):
96
138
  directives = self.requirements.to_directives()
97
139
 
98
140
  # Time
99
- days = self.time.days
100
- hours, remainder = divmod(self.time.seconds, 3600)
101
- minutes, seconds = divmod(remainder, 60)
102
- directives.append(f"--time={days}-{hours:02}:{minutes:02}:{seconds:02}")
141
+ directives.append(f"--time={utils.timestr_from_timedelta(self.time)}")
142
+
143
+ # Resource binding
144
+ if self.bind is not None:
145
+ for resource, value in self.bind.items():
146
+ if value is None:
147
+ value = "none"
148
+ match resource:
149
+ case resources.ResourceType.MEMORY | resources.ResourceType.RAM:
150
+ directives.append(f"--mem-bind={value}")
151
+ case resources.ResourceType.GPU:
152
+ directives.append(f"--gpu-bind={value}")
153
+ case str():
154
+ directives.append(f"--tres-bind=gres/{resource}:{value}")
155
+ case _:
156
+ raise ValueError(f"Unsupported resource type {resource!r} for binding.")
103
157
 
104
158
  # Job dependency handling
105
159
  directives.append(
@@ -107,20 +161,36 @@ class Slurm(xm.Executor):
107
161
  )
108
162
 
109
163
  # Placement
110
- if self.account:
164
+ if self.account is not None:
111
165
  directives.append(f"--account={self.account}")
112
- if self.partition:
166
+ if self.partition is not None:
113
167
  directives.append(f"--partition={self.partition}")
114
- if self.qos:
168
+ if self.qos is not None:
115
169
  directives.append(f"--qos={self.qos}")
116
- if self.priority:
170
+ if self.priority is not None:
117
171
  directives.append(f"--priority={self.priority}")
172
+ if self.reservation is not None:
173
+ match self.reservation:
174
+ case str():
175
+ directives.append(f"--reservation={self.reservation}")
176
+ case collections.abc.Iterable():
177
+ directives.append(f"--reservation={','.join(self.reservation)}")
178
+ case _:
179
+ raise ValueError(f"Invalid reservation type: {type(self.reservation)}")
180
+ if self.exclusive:
181
+ directives.append("--exclusive")
182
+ if self.oversubscribe:
183
+ directives.append("--oversubscribe")
184
+ if self.overcommit:
185
+ directives.append("--overcommit")
186
+ if self.nice is not None:
187
+ directives.append(f"--nice={self.nice}")
118
188
 
119
189
  # Job rescheduling
120
190
  directives.append(
121
191
  f"--signal={self.timeout_signal.name.removeprefix('SIG')}@{self.timeout_signal_grace_period.seconds}"
122
192
  )
123
- if self.requeue and self.requeue_max_attempts > 0:
193
+ if self.requeue is not None and self.requeue_max_attempts > 0:
124
194
  directives.append("--requeue")
125
195
  else:
126
196
  directives.append("--no-requeue")
@@ -4,8 +4,8 @@ import dataclasses
4
4
  import enum
5
5
  import functools
6
6
  import logging
7
+ import typing as tp
7
8
  import zlib
8
- from typing import Awaitable, Callable, Concatenate, Coroutine, Mapping, ParamSpec, TypeVar
9
9
 
10
10
  import backoff
11
11
  import cloudpickle
@@ -15,15 +15,15 @@ import xm_slurm
15
15
  from xm_slurm import job_blocks, status
16
16
  from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
17
17
 
18
- P = ParamSpec("P")
19
- T = TypeVar("T")
18
+ P = tp.ParamSpec("P")
19
+ T = tp.TypeVar("T")
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
23
23
 
24
24
  async def _monitor_parameter_controller(
25
25
  aux_unit: SlurmAuxiliaryUnit,
26
- local_parameter_controller_coro: Coroutine[None, None, T],
26
+ local_parameter_controller_coro: tp.Coroutine[None, None, T],
27
27
  *,
28
28
  poll_interval: float = 30.0,
29
29
  ) -> None:
@@ -104,13 +104,13 @@ def parameter_controller(
104
104
  controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
105
105
  controller_name: str = "parameter_controller",
106
106
  controller_args: xm.UserArgs | None = None,
107
- controller_env_vars: Mapping[str, str] | None = None,
108
- ) -> Callable[
107
+ controller_env_vars: tp.Mapping[str, str] | None = None,
108
+ ) -> tp.Callable[
109
109
  [
110
- Callable[Concatenate[SlurmExperiment, P], T]
111
- | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
110
+ tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
111
+ | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
112
112
  ],
113
- Callable[P, xm.AuxiliaryUnitJob],
113
+ tp.Callable[P, xm.AuxiliaryUnitJob],
114
114
  ]:
115
115
  """Converts a function to a controller which can be added to an experiment.
116
116
 
@@ -131,9 +131,9 @@ def parameter_controller(
131
131
  """
132
132
 
133
133
  def decorator(
134
- f: Callable[Concatenate[SlurmExperiment, P], T]
135
- | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
136
- ) -> Callable[P, xm.AuxiliaryUnitJob]:
134
+ f: tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
135
+ | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
136
+ ) -> tp.Callable[P, xm.AuxiliaryUnitJob]:
137
137
  @functools.wraps(f)
138
138
  def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
139
139
  # Modify the function to read the experiment from the API so that it can be pickled.
@@ -141,13 +141,17 @@ def parameter_controller(
141
141
  async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
142
142
  experiment_id = aux_unit.experiment.experiment_id
143
143
 
144
- async def local_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
144
+ async def local_controller(
145
+ *args: P.args, **kwargs: P.kwargs
146
+ ) -> T | tp.Awaitable[T]:
145
147
  if asyncio.iscoroutinefunction(f):
146
148
  return await f(aux_unit.experiment, *args, **kwargs)
147
149
  else:
148
150
  return f(aux_unit.experiment, *args, **kwargs)
149
151
 
150
- async def remote_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
152
+ async def remote_controller(
153
+ *args: P.args, **kwargs: P.kwargs
154
+ ) -> T | tp.Awaitable[T]:
151
155
  async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
152
156
  if asyncio.iscoroutinefunction(f):
153
157
  return await f(exp, *args, **kwargs)
xm_slurm/job_blocks.py CHANGED
@@ -1,11 +1,11 @@
1
- from typing import Mapping, TypedDict
1
+ import typing as tp
2
2
 
3
3
  from xmanager import xm
4
4
 
5
5
 
6
- class JobArgs(TypedDict, total=False):
6
+ class JobArgs(tp.TypedDict, total=False):
7
7
  args: xm.UserArgs
8
- env_vars: Mapping[str, str]
8
+ env_vars: tp.Mapping[str, str]
9
9
 
10
10
 
11
11
  def get_args_for_python_entrypoint(
xm_slurm/packageables.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import importlib.resources as resources
2
2
  import pathlib
3
3
  import sys
4
- from typing import Literal, Mapping, Sequence
4
+ import typing as tp
5
5
 
6
6
  from xmanager import xm
7
7
 
@@ -14,7 +14,7 @@ def docker_image(
14
14
  *,
15
15
  image: str,
16
16
  args: xm.UserArgs | None = None,
17
- env_vars: Mapping[str, str] | None = None,
17
+ env_vars: tp.Mapping[str, str] | None = None,
18
18
  ) -> xm.Packageable:
19
19
  """Creates a packageable for a pre-built Docker image.
20
20
 
@@ -39,12 +39,12 @@ def docker_container(
39
39
  dockerfile: pathlib.Path | None = None,
40
40
  context: pathlib.Path | None = None,
41
41
  target: str | None = None,
42
- ssh: Sequence[str] | Literal[True] | None = None,
43
- build_args: Mapping[str, str] | None = None,
44
- cache_from: str | Sequence[str] | None = None,
45
- labels: Mapping[str, str] | None = None,
42
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
43
+ build_args: tp.Mapping[str, str] | None = None,
44
+ cache_from: str | tp.Sequence[str] | None = None,
45
+ labels: tp.Mapping[str, str] | None = None,
46
46
  args: xm.UserArgs | None = None,
47
- env_vars: Mapping[str, str] | None = None,
47
+ env_vars: tp.Mapping[str, str] | None = None,
48
48
  ) -> xm.Packageable:
49
49
  """Creates a Docker container packageable from a dockerfile.
50
50
 
@@ -104,13 +104,13 @@ def python_container(
104
104
  context: pathlib.Path | None = None,
105
105
  requirements: pathlib.Path | None = None,
106
106
  base_image: str = "docker.io/python:{major}.{minor}-slim",
107
- extra_system_packages: Sequence[str] = (),
108
- extra_python_packages: Sequence[str] = (),
109
- cache_from: str | Sequence[str] | None = None,
110
- labels: Mapping[str, str] | None = None,
111
- ssh: Sequence[str] | Literal[True] | None = None,
107
+ extra_system_packages: tp.Sequence[str] = (),
108
+ extra_python_packages: tp.Sequence[str] = (),
109
+ cache_from: str | tp.Sequence[str] | None = None,
110
+ labels: tp.Mapping[str, str] | None = None,
111
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
112
112
  args: xm.UserArgs | None = None,
113
- env_vars: Mapping[str, str] | None = None,
113
+ env_vars: tp.Mapping[str, str] | None = None,
114
114
  ) -> xm.Packageable:
115
115
  """Creates a Python container from a base image using pip from a `requirements.txt` file.
116
116
 
@@ -181,11 +181,11 @@ def mamba_container(
181
181
  context: pathlib.Path | None = None,
182
182
  environment: pathlib.Path | None = None,
183
183
  base_image: str = "gcr.io/distroless/base-debian10",
184
- cache_from: str | Sequence[str] | None = None,
185
- labels: Mapping[str, str] | None = None,
186
- ssh: Sequence[str] | Literal[True] | None = None,
184
+ cache_from: str | tp.Sequence[str] | None = None,
185
+ labels: tp.Mapping[str, str] | None = None,
186
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
187
187
  args: xm.UserArgs | None = None,
188
- env_vars: Mapping[str, str] | None = None,
188
+ env_vars: tp.Mapping[str, str] | None = None,
189
189
  ) -> xm.Packageable:
190
190
  """Creates a Conda container from a base image using mamba from a `environment.yml` file.
191
191
 
@@ -249,13 +249,13 @@ def uv_container(
249
249
  entrypoint: xm.ModuleName | xm.CommandList,
250
250
  context: pathlib.Path | None = None,
251
251
  base_image: str = "docker.io/python:{major}.{minor}-slim-bookworm",
252
- extra_system_packages: Sequence[str] = (),
253
- extra_python_packages: Sequence[str] = (),
254
- cache_from: str | Sequence[str] | None = None,
255
- labels: Mapping[str, str] | None = None,
256
- ssh: Sequence[str] | Literal[True] | None = None,
252
+ extra_system_packages: tp.Sequence[str] = (),
253
+ extra_python_packages: tp.Sequence[str] = (),
254
+ cache_from: str | tp.Sequence[str] | None = None,
255
+ labels: tp.Mapping[str, str] | None = None,
256
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
257
257
  args: xm.UserArgs | None = None,
258
- env_vars: Mapping[str, str] | None = None,
258
+ env_vars: tp.Mapping[str, str] | None = None,
259
259
  ) -> xm.Packageable:
260
260
  """Creates a Python container from a base image using uv from a `uv.lock` file.
261
261
 
@@ -1,31 +1,31 @@
1
1
  import dataclasses
2
- from typing import Callable, Generic, ParamSpec, Sequence, Type, TypeVar
2
+ import typing as tp
3
3
 
4
4
  from xmanager import xm
5
5
 
6
- T_co = TypeVar("T_co", covariant=True)
7
- P = ParamSpec("P")
8
- ExecutableSpecT = TypeVar("ExecutableSpecT", bound=xm.ExecutableSpec)
6
+ T_co = tp.TypeVar("T_co", covariant=True)
7
+ P = tp.ParamSpec("P")
8
+ ExecutableSpecT = tp.TypeVar("ExecutableSpecT", bound=xm.ExecutableSpec)
9
9
 
10
10
 
11
11
  @dataclasses.dataclass(frozen=True)
12
- class IndexedContainer(Generic[T_co]):
12
+ class IndexedContainer(tp.Generic[T_co]):
13
13
  index: int
14
14
  value: T_co
15
15
 
16
16
 
17
- RegistrationCallable = Callable[
18
- [Sequence[IndexedContainer[xm.Packageable]]],
19
- Sequence[IndexedContainer[xm.Executable]],
17
+ RegistrationCallable = tp.Callable[
18
+ [tp.Sequence[IndexedContainer[xm.Packageable]]],
19
+ tp.Sequence[IndexedContainer[xm.Executable]],
20
20
  ]
21
21
 
22
22
 
23
- _REGISTRY: dict[Type[xm.ExecutableSpec], RegistrationCallable] = {}
23
+ _REGISTRY: dict[tp.Type[xm.ExecutableSpec], RegistrationCallable] = {}
24
24
 
25
25
 
26
26
  def register(
27
- *typs: Type[ExecutableSpecT],
28
- ) -> Callable[[RegistrationCallable], RegistrationCallable]:
27
+ *typs: tp.Type[ExecutableSpecT],
28
+ ) -> tp.Callable[[RegistrationCallable], RegistrationCallable]:
29
29
  def decorator(
30
30
  registration_callable: RegistrationCallable,
31
31
  ) -> RegistrationCallable:
@@ -38,8 +38,8 @@ def register(
38
38
 
39
39
 
40
40
  def route(
41
- typ: Type[ExecutableSpecT],
42
- packageables: Sequence[IndexedContainer[xm.Packageable]],
43
- ) -> Sequence[IndexedContainer[xm.Executable]]:
41
+ typ: tp.Type[ExecutableSpecT],
42
+ packageables: tp.Sequence[IndexedContainer[xm.Packageable]],
43
+ ) -> tp.Sequence[IndexedContainer[xm.Executable]]:
44
44
  global _REGISTRY
45
45
  return _REGISTRY[typ](packageables)
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import logging
3
- from typing import Sequence, Type
3
+ import typing as tp
4
4
 
5
5
  from xmanager import xm
6
6
 
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  def package(
17
- packageables: Sequence[xm.Packageable],
17
+ packageables: tp.Sequence[xm.Packageable],
18
18
  ) -> list[xm.Executable]:
19
19
  """
20
20
  Takes as input a list of packageables and returns a mapping of
@@ -23,7 +23,7 @@ def package(
23
23
  # Docker targets to be collected.
24
24
  # These are a mapping from `DockerTarget` to the latest digest of the image.
25
25
  targets_by_type = collections.defaultdict[
26
- Type[xm.ExecutableSpec], list[IndexedContainer[xm.Packageable]]
26
+ tp.Type[xm.ExecutableSpec], list[IndexedContainer[xm.Packageable]]
27
27
  ](list)
28
28
 
29
29
  # Collect dockerfiles that need to be built locally
@@ -1,20 +1,20 @@
1
1
  import collections
2
2
  import logging
3
- from typing import ParamSpec, Sequence, TypeVar
3
+ import typing as tp
4
4
 
5
5
  from xmanager import xm
6
6
 
7
7
  from xm_slurm.packaging.registry import IndexedContainer
8
8
 
9
- T = TypeVar("T")
10
- P = ParamSpec("P")
11
- ReturnT = TypeVar("ReturnT")
9
+ T = tp.TypeVar("T")
10
+ P = tp.ParamSpec("P")
11
+ ReturnT = tp.TypeVar("ReturnT")
12
12
 
13
13
  logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  def collect_executors_by_executable(
17
- targets: Sequence[IndexedContainer[xm.Packageable]],
17
+ targets: tp.Sequence[IndexedContainer[xm.Packageable]],
18
18
  ) -> dict[xm.ExecutableSpec, set[xm.ExecutorSpec]]:
19
19
  executors_by_executable = collections.defaultdict(set)
20
20
  for target in targets:
xm_slurm/resources.py CHANGED
@@ -1,9 +1,13 @@
1
+ import builtins
2
+ import collections.abc
3
+ import datetime as dt
1
4
  import enum
2
5
  import itertools
3
6
  import math
4
- from typing import Mapping
7
+ import re
8
+ import typing as tp
5
9
 
6
- from xm_slurm import config
10
+ from xm_slurm import config, utils
7
11
 
8
12
 
9
13
  class ResourceType(enum.IntEnum):
@@ -57,32 +61,123 @@ assert AcceleratorType | {
57
61
  } == set(ResourceType.__members__.values()), "Resource types are not exhaustive."
58
62
 
59
63
 
60
- ResourceQuantity = int | float
61
-
62
-
63
64
  class FeatureType(enum.IntEnum):
64
65
  NVIDIA_MIG = 1
65
66
  NVIDIA_NVLINK = 2
66
67
 
67
68
 
69
+ class InvalidTopologyError(Exception):
70
+ """An unrecognized topology has been provided."""
71
+
72
+
73
+ TOPOLOGY_REGEX = re.compile(r"^(?P<dims>[\d]+(?:x[\d]+)*)$")
74
+
75
+
76
+ class Topology:
77
+ mesh: str
78
+ dimensions: list[int]
79
+ switches: int | None
80
+ switches_grace_period: dt.timedelta | None
81
+
82
+ def __init__(
83
+ self,
84
+ mesh: str,
85
+ /,
86
+ *,
87
+ switches: int | None = None,
88
+ switches_grace_period: dt.timedelta | None = None,
89
+ ):
90
+ mesh_match = TOPOLOGY_REGEX.fullmatch(mesh)
91
+ if not mesh_match:
92
+ raise InvalidTopologyError(f"Invalid topology mesh: {mesh!r}.")
93
+
94
+ self.mesh = mesh
95
+ self.dimensions = list(map(int, mesh_match.group("dims").split("x")))
96
+ if switches is not None:
97
+ assert (
98
+ isinstance(switches, int) and switches > 0
99
+ ), "Switches must be a positive integer."
100
+ self.switches = switches
101
+ if switches_grace_period is not None:
102
+ assert isinstance(
103
+ switches_grace_period, dt.timedelta
104
+ ), "Switches grace period must be a `datetime.timedelta`."
105
+ self.switches_grace_period = switches_grace_period
106
+
107
+ @property
108
+ def chip_count(self) -> int:
109
+ return math.prod(self.dimensions)
110
+
111
+ @property
112
+ def ndim(self) -> int:
113
+ return len(self.dimensions)
114
+
115
+ def __eq__(self, other: object) -> bool:
116
+ if not isinstance(other, Topology):
117
+ return False
118
+ return (
119
+ self.mesh == other.mesh
120
+ and self.switches == other.switches
121
+ and self.switches_grace_period == other.switches_grace_period
122
+ )
123
+
124
+ def __hash__(self) -> int:
125
+ return hash((self.mesh, self.switches, self.switches_grace_period))
126
+
127
+
128
+ ResourceQuantity = int | float | Topology
129
+
130
+
131
+ def _parse_resource_quantity(
132
+ resource_name: ResourceType | str, value: ResourceQuantity
133
+ ) -> tuple[float, Topology | None]:
134
+ if isinstance(resource_name, ResourceType):
135
+ resource_name = resource_name.name
136
+ match value:
137
+ case Topology() as topology:
138
+ return topology.chip_count, topology
139
+ case builtins.str() as topology_str if (
140
+ "x" in topology_str and TOPOLOGY_REGEX.fullmatch(topology_str) is not None
141
+ ):
142
+ topology = Topology(topology_str)
143
+ return topology.chip_count, topology
144
+ case builtins.str() as num_str:
145
+ try:
146
+ value = float(num_str)
147
+ return int(value) if value.is_integer() else value, None
148
+ except ValueError as e:
149
+ raise ValueError(
150
+ f"Couldn't parse resource quantity for {resource_name}. "
151
+ f"{num_str!r} was given."
152
+ ) from e
153
+ case int() | float():
154
+ return value, None
155
+ case _:
156
+ raise ValueError(f"Invalid resource quantity: {value!r} for {resource_name!r}.")
157
+
158
+
68
159
  class JobRequirements:
69
160
  replicas: int
70
161
  location: str | None
71
162
  accelerator: ResourceType | None
72
- cluster: config.SlurmClusterConfig | None = None
163
+ topology: Topology | None
164
+ cluster: config.SlurmClusterConfig
73
165
 
74
166
  def __init__(
75
167
  self,
76
168
  *,
77
- resources: Mapping[ResourceType | str, ResourceQuantity] | None = None,
78
- replicas: int = 1,
79
- location: str | None = None,
80
- cluster: config.SlurmClusterConfig | None = None,
169
+ resources: tp.Mapping[ResourceType | str, ResourceQuantity] | None = None,
170
+ replicas: int | None = None,
171
+ location: str | tp.Iterable[str] | None = None,
172
+ cluster: config.SlurmClusterConfig,
81
173
  **kw_resources: ResourceQuantity,
82
174
  ):
83
- self.replicas = replicas or 1
175
+ if isinstance(location, collections.abc.Iterable) and not isinstance(location, str):
176
+ location = ",".join(location)
84
177
  self.location = location
178
+
85
179
  self.accelerator = None
180
+ self.topology = None
86
181
  self.cluster = cluster
87
182
 
88
183
  if resources is None:
@@ -90,6 +185,7 @@ class JobRequirements:
90
185
 
91
186
  self.task_requirements: dict[ResourceType | str, ResourceQuantity] = {}
92
187
  for resource_name, value in itertools.chain(resources.items(), kw_resources.items()):
188
+ quantity, topology = _parse_resource_quantity(resource_name, value)
93
189
  match resource_name:
94
190
  case str() if resource_name.upper() in ResourceType.__members__:
95
191
  resource = ResourceType[resource_name.upper()]
@@ -106,58 +202,132 @@ class JobRequirements:
106
202
  if self.accelerator is not None:
107
203
  raise ValueError("Accelerator already set.")
108
204
  self.accelerator = resource # type: ignore
205
+ self.topology = topology or Topology(f"{quantity:g}")
206
+ elif topology is not None:
207
+ raise ValueError(
208
+ f"A topology was specified for a non-accelerator resource: {resource_name!r}."
209
+ )
109
210
 
110
211
  if resource in self.task_requirements:
111
212
  raise ValueError(f"{resource} has been specified twice.")
112
- self.task_requirements[resource] = value
213
+ self.task_requirements[resource] = quantity
214
+
215
+ if self.topology is not None and self.topology.ndim > 2:
216
+ raise ValueError("Topologies with more than 2 dimensions are not supported.")
217
+
218
+ if (
219
+ self.accelerator is not None
220
+ and self.topology is not None
221
+ and len(self.topology.dimensions) == 2
222
+ ):
223
+ if replicas is not None and replicas != self.topology.dimensions[1]:
224
+ raise ValueError(
225
+ f"For multihost GPUs with topology {self.topology}, replicas should"
226
+ f"be either None or {self.topology.dimensions[1]}. Found: "
227
+ f"{replicas}"
228
+ )
229
+ replicas = self.topology.dimensions[1]
230
+
231
+ if replicas is not None and replicas <= 0:
232
+ raise ValueError(f"Replicas must be a positive integer, got {replicas!r}")
233
+ self.replicas = replicas or 1
113
234
 
114
235
  def to_directives(self) -> list[str]:
115
- if self.cluster is None:
116
- raise ValueError("Cannnot derive Slurm directives for requirements without a cluster.")
117
236
  directives = []
118
237
 
119
238
  for resource, value in self.task_requirements.items():
120
239
  match resource:
121
240
  case ResourceType.EPHEMERAL_STORAGE | ResourceType.DISK:
122
- assert isinstance(value, int), "Disk space must be an integer"
241
+ assert isinstance(
242
+ value, int
243
+ ), f"Disk space must be an integer, got {type(value)!r}"
123
244
  directives.append(f"--tmp={math.ceil(value / 2**20)}M")
124
245
  case ResourceType.MEMORY | ResourceType.RAM:
125
246
  num_cpus = self.task_requirements.get(ResourceType.CPU, 1)
126
- assert isinstance(value, (int, float)), "Memory must be an integer or float"
127
- assert isinstance(num_cpus, int), "CPU must be an integer"
247
+ assert isinstance(
248
+ value, (int, float)
249
+ ), f"Memory must be an integer or float, got {type(value)!r}"
250
+ assert isinstance(
251
+ num_cpus, int
252
+ ), f"CPU must be an integer, got {type(num_cpus)!r}"
128
253
  directives.append(f"--mem-per-cpu={math.ceil(value / num_cpus / 2**20)}M")
129
254
  case ResourceType.CPU:
130
- assert isinstance(value, int), "CPU must be an integer"
255
+ assert isinstance(value, int), f"CPU must be an integer, got {type(value)!r}"
131
256
  directives.append(f"--cpus-per-task={value}")
132
257
  case ResourceType.GPU:
133
- assert isinstance(value, int), "GPU must be an integer"
134
- directives.append(f"--gpus-per-task={value}")
258
+ assert isinstance(value, int), f"GPU must be an integer, got {type(value)!r}"
259
+ directives.append(f"--gpus={value}")
135
260
  case ResourceType() if resource in AcceleratorType:
136
- assert isinstance(value, int), "Accelerator must be an integer"
261
+ assert isinstance(
262
+ value, int
263
+ ), f"Accelerator must be an integer, got {type(value)!r}"
137
264
  resource_type = self.cluster.resources.get(resource, None)
138
265
  if resource_type is None:
139
266
  raise ValueError(
140
267
  f"Cluster {self.cluster.name} does not map resource type {resource!r}."
141
268
  )
142
- directives.append(f"--gpus-per-task={resource_type}:{value}")
269
+ directives.append(f"--gpus={resource_type}:{value}")
143
270
  case str():
144
271
  directives.append(f"--gres={resource}:{value}")
145
272
 
146
- directives.append(f"--ntasks={self.replicas}")
147
273
  if self.location:
274
+ assert isinstance(
275
+ self.location, str
276
+ ), f"Location must be a string, got {type(self.location)!r}"
148
277
  directives.append(f"--nodelist={self.location}")
149
278
 
279
+ assert (
280
+ isinstance(self.replicas, int) and self.replicas > 0
281
+ ), f"Replicas must be a positive integer, got {self.replicas!r}"
282
+ directives.append(f"--ntasks={self.replicas}")
283
+
284
+ if self.topology is not None:
285
+ assert self.accelerator is not None, "Accelerator must be set."
286
+ match self.accelerator:
287
+ case ResourceType.GPU:
288
+ directives.append(f"--gpus-per-task={self.topology.dimensions[0]}")
289
+ case ResourceType() if self.accelerator in AcceleratorType:
290
+ resource_type = self.cluster.resources[self.accelerator]
291
+ directives.append(
292
+ f"--gpus-per-task={resource_type}:{self.topology.dimensions[0]}"
293
+ )
294
+
295
+ if self.topology.switches is not None:
296
+ switches_timeout = (
297
+ f"@{utils.timestr_from_timedelta(self.topology.switches_grace_period)}"
298
+ if self.topology.switches_grace_period is not None
299
+ else ""
300
+ )
301
+ directives.append(f"--switches={self.topology.switches}{switches_timeout}")
302
+
150
303
  return directives
151
304
 
152
305
  def replace(
153
306
  self,
154
- cluster: config.SlurmClusterConfig | None,
307
+ replicas: int | None = None,
308
+ location: str | None = None,
309
+ cluster: config.SlurmClusterConfig | None = None,
155
310
  **kw_resources: ResourceQuantity,
156
311
  ) -> "JobRequirements":
312
+ # Merge kw_resources into existing task_requirements, removing conflicting enum keys
313
+ merged_resources = dict(self.task_requirements)
314
+
315
+ # Remove ResourceType keys that will be overridden by string keys in kw_resources
316
+ for key in list(merged_resources.keys()):
317
+ if isinstance(key, ResourceType) and any(
318
+ ResourceType[name.upper()] == key
319
+ for name in kw_resources
320
+ if name.upper() in ResourceType.__members__
321
+ ):
322
+ del merged_resources[key]
323
+
324
+ merged_resources.update(kw_resources) # type: ignore
325
+
157
326
  return JobRequirements(
158
- resources=self.task_requirements | kw_resources, # type: ignore
159
- replicas=self.replicas,
160
- cluster=cluster or self.cluster,
327
+ resources=merged_resources,
328
+ replicas=replicas if replicas is not None else self.replicas,
329
+ location=location if location is not None else self.location,
330
+ cluster=cluster if cluster is not None else self.cluster,
161
331
  )
162
332
 
163
333
  def __repr__(self) -> str:
@@ -169,7 +339,7 @@ class JobRequirements:
169
339
  args.append(f"{resource.lower()}={value!r}")
170
340
 
171
341
  if self.replicas != 1:
172
- args.append(f"replicas={self.replicas}")
342
+ args.append(f"replicas={self.replicas!r}")
173
343
 
174
344
  if self.cluster is not None:
175
345
  args.append(f"cluster={self.cluster!r}")
xm_slurm/status.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import enum
4
4
  import re
5
- from typing import Sequence
5
+ import typing as tp
6
6
 
7
7
  from xmanager import xm
8
8
 
@@ -151,7 +151,7 @@ class SlurmWorkUnitStatus(xm.ExperimentUnitStatus):
151
151
  """Status of a Slurm experiment job."""
152
152
 
153
153
  @classmethod
154
- def aggregate(cls, states: Sequence[SlurmJobState]) -> "SlurmWorkUnitStatus":
154
+ def aggregate(cls, states: tp.Sequence[SlurmJobState]) -> "SlurmWorkUnitStatus":
155
155
  """Aggregate a sequence of statuses into a single status."""
156
156
  assert len(states) > 0, "Cannot aggregate empty sequence of statuses."
157
157
  max_error_state: SlurmJobState | None = None
@@ -7,9 +7,9 @@
7
7
 
8
8
  {% block bootstrap %}
9
9
  srun \
10
+ --label \
10
11
  --unbuffered \
11
12
  --kill-on-bad-exit=0 \
12
- --overlap \
13
13
  --export="ALL" \
14
14
  bash <<'SRUN_EOF' &
15
15
  set -Eeuxo pipefail
@@ -29,6 +29,7 @@
29
29
  {% block bootstrap %}
30
30
  {% for job in job_group.jobs.values() +%}
31
31
  srun \
32
+ --label \
32
33
  --unbuffered \
33
34
  --kill-on-bad-exit=0 \
34
35
  --export="ALL" \
@@ -59,9 +59,9 @@ export {{ key }}="{{ value }}"
59
59
 
60
60
  {% block bootstrap %}
61
61
  srun \
62
+ --label \
62
63
  --unbuffered \
63
64
  --kill-on-bad-exit=0 \
64
- --overlap \
65
65
  --export="ALL" \
66
66
  bash <<'SRUN_EOF' &
67
67
  set -Eeuxo pipefail
xm_slurm/utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import datetime as dt
2
3
  import functools
3
4
  import logging
4
5
  import os
@@ -186,3 +187,10 @@ def run_command(
186
187
  stdout=stdout,
187
188
  stderr=stderr,
188
189
  )
190
+
191
+
192
+ def timestr_from_timedelta(time: dt.timedelta) -> str:
193
+ days = time.days
194
+ hours, remainder = divmod(time.seconds, 3600)
195
+ minutes, seconds = divmod(remainder, 60)
196
+ return f"{days}-{hours:02}:{minutes:02}:{seconds:02}"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xmanager-slurm
3
- Version: 0.4.14
3
+ Version: 0.4.15
4
4
  Summary: Slurm backend for XManager.
5
5
  Project-URL: GitHub, https://github.com/jessefarebro/xm-slurm
6
6
  Author-email: Jesse Farebrother <jfarebro@cs.mcgill.ca>
@@ -1,21 +1,21 @@
1
- xm_slurm/__init__.py,sha256=WgRn9HDYa5H3sfIH-HZu33liBOh98jM4GqcR349RaSY,1086
2
- xm_slurm/batching.py,sha256=GbKBsNz9w8gIc2fHLZpslC0e4K9YUfLXFHmjduRRCfQ,4385
3
- xm_slurm/config.py,sha256=i9WuxjfGBgVoHHDNk3BvO7LCwpBTJeRnOiFvTU-FHrk,7330
1
+ xm_slurm/__init__.py,sha256=VNbvBLbv5ccbPxQUpbiwgoo72qI3FrATTloevufstzY,1112
2
+ xm_slurm/batching.py,sha256=ynbMRItxNtBn0SbkhHbrv5ugYuHeMw-7BP7a_-I6Oqg,4384
3
+ xm_slurm/config.py,sha256=6RAdMTgteUJ_tmr3H_RE6MHtzda202_unU2S24iPHUE,7319
4
4
  xm_slurm/console.py,sha256=UpMqeJ0C8i0pkue1AHnnyyX0bFJ9zZeJ7HBR6yhuA8A,54
5
5
  xm_slurm/constants.py,sha256=zefVtlFdflgSolie5g_rVxWV-Zpydxapchm3y0a2FDc,999
6
- xm_slurm/dependencies.py,sha256=-5gN_tpfs3dOA7H5_MIHO2ratb7F5Pm_yjkR5rZcgI8,6421
6
+ xm_slurm/dependencies.py,sha256=G-8vfmvSptZH6c_Ow51SwT84Dr6LI1clRj8F8wOUkiw,6421
7
7
  xm_slurm/executables.py,sha256=fGmrFBl-258bMn6ip5adYeM7xxUHAeIbDN9zD2FDGtY,6373
8
8
  xm_slurm/execution.py,sha256=mTy5u2oP2StIbGzjaSiGCUAwXuBFOiaJ5ephWoc25hI,31799
9
- xm_slurm/executors.py,sha256=bUgKcgtvf-nPGjcuHRzUAqD1r3_vwea_h-Y9MAB-Kqo,4887
9
+ xm_slurm/executors.py,sha256=karM5u2UEG2IWi0z548_vasyBACrXGV675rCllJmwZw,8616
10
10
  xm_slurm/experiment.py,sha256=94r0mhtUPUzw4eaUEz0kpsufC25wEGqlDhV4Fcr1ukY,39883
11
11
  xm_slurm/filesystem.py,sha256=4rKtq3t-KDgxJbSGt6JVyRJT_3lCN_vIKTcwKHpTo3I,4389
12
- xm_slurm/job_blocks.py,sha256=_F8CKCs5BQFj40a2-mjG71HfacvWoBXBDPDKEaKTbXc,616
12
+ xm_slurm/job_blocks.py,sha256=BFOOYgeodoGIQsB5PdC7SsOUou5aZx-1qbQ7lcqqylI,604
13
13
  xm_slurm/metadata_context.py,sha256=mksVRbVUuistL1uE7TC-fkW-Y69On52jN_svP1e1kiQ,7841
14
- xm_slurm/packageables.py,sha256=K6vNhLvASdnqsc8vXlT3h9cObJpC9Rbw93pUBvBwapQ,12209
15
- xm_slurm/resources.py,sha256=T7uje3E6oWbZSrsxykgW-40DE-Bvw_NWDM2qXbw2rgI,5740
16
- xm_slurm/status.py,sha256=WTWiDHi-ZHtwHRnDP0cGa-27zTSm6LkA-GCKsN-zBgg,6916
14
+ xm_slurm/packageables.py,sha256=aEZUQpddfq4FK6h4f6kgGEI4XcOufhm68MjoDFOYR4U,12261
15
+ xm_slurm/resources.py,sha256=aC8MzO_7fB9IAdTACvhwVOaNDjLOlWnCh428-8_IDYA,12322
16
+ xm_slurm/status.py,sha256=JIBCJPOYsmeJOQbzdACXA2vTWK7g8YWWhzpGP79e7JE,6911
17
17
  xm_slurm/types.py,sha256=TsVykDm-LazVkrjeJrTwCMs4Q8APKhy7BTk0yKIhFNg,805
18
- xm_slurm/utils.py,sha256=xtFvktaxr0z65sTdu6HhOVfyo0OAB9t-EYXWcYrQQEU,5958
18
+ xm_slurm/utils.py,sha256=9w98HlXF0U9cKKtoB8QtGm0CnB0MnnzBARKlbbVNNpU,6211
19
19
  xm_slurm/api/__init__.py,sha256=cyao3LZ3uLftu1wIv1aN7Qvsl6gYzYpkxeehTHZ0fA8,1089
20
20
  xm_slurm/api/abc.py,sha256=-lS2OndnOuEiwNdr8ccQKkwMd1iDmKMmkBOSTvo5H5w,1816
21
21
  xm_slurm/api/models.py,sha256=_INVh0j-4-rRs0WASyg4fNB6NF1L1nUeGgQ6-XnbwsM,1610
@@ -23,13 +23,13 @@ xm_slurm/api/sqlite/client.py,sha256=jAesCKDuYwnNcAxwJk_1b1TB8cT_QGbSjo1UE3mZjEQ
23
23
  xm_slurm/api/web/client.py,sha256=uO67Y7fnQ-w__Vm_A5BEuy7Qi8wQcWk3vIsBGEBkyfk,6261
24
24
  xm_slurm/contrib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  xm_slurm/contrib/clusters/__init__.py,sha256=XFCVnkThiU3_8uA_tUgDByOBanXNHrxDvfmuptmQ2KE,2214
26
- xm_slurm/contrib/clusters/drac.py,sha256=vY3dxrNUk12H9Gq-tuCcqo2YcdTGq-4LJnQF6DzD4_k,7431
27
- xm_slurm/experimental/parameter_controller.py,sha256=b5LfglHV307F6QcPrHeZX5GJBtyOK9aQydke_SZ3Wto,8457
26
+ xm_slurm/contrib/clusters/drac.py,sha256=_iubsmzz5tK2KTaKqSuykS3IDtsdXqJ0MXep1THSJUM,7020
27
+ xm_slurm/experimental/parameter_controller.py,sha256=IrFzq104LkZrhzuirit5GUZDXDvv2bBSYNMh3orsiPY,8518
28
28
  xm_slurm/packaging/__init__.py,sha256=dh307yLpUT9KN7rJ1e9fYC6hegGKfZcGboUq9nGpDVQ,233
29
29
  xm_slurm/packaging/docker.py,sha256=-DWcB9qqbeHmIEqyfF0-v6xOT25ae90u2x-QZ7kluOw,13579
30
- xm_slurm/packaging/registry.py,sha256=GrdmQg9MgSo38OiqOzMKWSkQyBuyryOfc3zcdgZ4CUE,1148
31
- xm_slurm/packaging/router.py,sha256=yPbdA9clrhly97cLgDsSRZG2LZRKE-oz8Hhdb7WtYqk,2070
32
- xm_slurm/packaging/utils.py,sha256=6EAb17zKQQeuyNY2EV9AoW1RvnDGrQwmIT9wtQEsC4c,632
30
+ xm_slurm/packaging/registry.py,sha256=Hq56KhqsQRxgr_y1EQhcZORlnrs13xY5vDGge5WEgYU,1134
31
+ xm_slurm/packaging/router.py,sha256=MLWvy-shJzRAx4YCi9z9Dj_PoWrXZO8T71DDlLOcjaM,2062
32
+ xm_slurm/packaging/utils.py,sha256=KlU_GGkFH1Xu5VZkAMqRilmq6SV1iLai80beEZ3UQmw,616
33
33
  xm_slurm/scripts/_cloudpickle.py,sha256=dlJYf2SceOuUn8wi-ozuoYAQg71wqD2MUVOUCyOwWIY,647
34
34
  xm_slurm/scripts/cli.py,sha256=zzsQpvkx9VThAeQPM34iDK9wAWfCVCIIvLNI12UaMhw,2577
35
35
  xm_slurm/templates/docker/docker-bake.hcl.j2,sha256=7qSJl2VN5poz-Hh8Gjo7--qR-k3lmfGtBu2mNbfG2uA,1499
@@ -37,16 +37,16 @@ xm_slurm/templates/docker/mamba.Dockerfile,sha256=Sgxr5IA5T-pT1Shumb5k3JngoG4pgC
37
37
  xm_slurm/templates/docker/python.Dockerfile,sha256=U4b4QVkopckQ0o9jJIE7d_M6TvExEYlYDirNwCoZ7W4,865
38
38
  xm_slurm/templates/docker/uv.Dockerfile,sha256=L2UJMX2c8waMdrRhiqPytQe3pTBu6u5PpMhJYsKkbEg,1040
39
39
  xm_slurm/templates/slurm/entrypoint.bash.j2,sha256=MRdSVwgGrgQdpEhqfkP35IidgsblrtVXB1YWzvE9hkk,666
40
- xm_slurm/templates/slurm/job-array.bash.j2,sha256=smxmSSzBEUHm6MJF-nYPVVjK6CLKrb1fRxF_tfrzAX8,552
41
- xm_slurm/templates/slurm/job-group.bash.j2,sha256=Cp8YhNOxYqaOkl4MFjQlcaLMGZwdDh97m8OGT5RWbAo,1101
42
- xm_slurm/templates/slurm/job.bash.j2,sha256=DrDipliaEfiHbq9vDfOdfD8zBVFLy1jjlvCV-9-6k9s,2086
40
+ xm_slurm/templates/slurm/job-array.bash.j2,sha256=j7jkJjSbe39XvSTJ9rmK2oVnHdntElIhdS5PFpZzpFs,550
41
+ xm_slurm/templates/slurm/job-group.bash.j2,sha256=vH5HwneVsVSHx6dPZwbLa4KT9NedRbrZ7cWNE5pXi-M,1113
42
+ xm_slurm/templates/slurm/job.bash.j2,sha256=JnK0D8_3tVNpnvPwM5yL_rjLcjqhuHiCtolDjUGAwpk,2084
43
43
  xm_slurm/templates/slurm/fragments/monitor.bash.j2,sha256=ri5FgoKs6_bQVf5DO8SL4rJf4UsLxV34aOV-OD8VWDU,2526
44
44
  xm_slurm/templates/slurm/fragments/proxy.bash.j2,sha256=VJLglZo-Nvx9R-qe3rHTxr07CylTQ6Z9NwBzvIpAZrA,814
45
45
  xm_slurm/templates/slurm/library/retry.bash,sha256=bLe59qvfWEk17rE1wZ4EHiHba3RvR2WWZPq-kSe8RAA,2164
46
46
  xm_slurm/templates/slurm/runtimes/apptainer.bash.j2,sha256=v0LwHM-kBW8sJqVcVA2jYr1n44imDSZrJqmqlr5uTGc,1980
47
47
  xm_slurm/templates/slurm/runtimes/podman.bash.j2,sha256=zWLsFEuVzOMSETOmv4A5ZCV4oQHwCipiR6wi79XVzNI,1188
48
- xmanager_slurm-0.4.14.dist-info/METADATA,sha256=T7xNy0jmrKhQemaDhCg9E-J64gWkabjJGpxDYgdsBx8,1007
49
- xmanager_slurm-0.4.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
50
- xmanager_slurm-0.4.14.dist-info/entry_points.txt,sha256=_HLGmLgxuQLOPmF2gOFYDVq2HqtMVD_SzigHvUh8TCY,49
51
- xmanager_slurm-0.4.14.dist-info/licenses/LICENSE.md,sha256=IxstXr3MPHwTJ5jMrByHrQsR1ZAGQ2U_uz_4qzI_15Y,11756
52
- xmanager_slurm-0.4.14.dist-info/RECORD,,
48
+ xmanager_slurm-0.4.15.dist-info/METADATA,sha256=xgJUFDConlb4R5W0cOK3xb2fuIR7x_tjNpOybMagT_A,1007
49
+ xmanager_slurm-0.4.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
50
+ xmanager_slurm-0.4.15.dist-info/entry_points.txt,sha256=_HLGmLgxuQLOPmF2gOFYDVq2HqtMVD_SzigHvUh8TCY,49
51
+ xmanager_slurm-0.4.15.dist-info/licenses/LICENSE.md,sha256=IxstXr3MPHwTJ5jMrByHrQsR1ZAGQ2U_uz_4qzI_15Y,11756
52
+ xmanager_slurm-0.4.15.dist-info/RECORD,,