xmanager-slurm 0.4.13__py3-none-any.whl → 0.4.15__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.

xm_slurm/__init__.py CHANGED
@@ -18,7 +18,7 @@ from xm_slurm.packageables import (
     python_container,
     uv_container,
 )
-from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType
+from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType, Topology
 
 logging.getLogger("asyncssh").setLevel(logging.WARN)
 logging.getLogger("httpx").setLevel(logging.WARN)
@@ -42,5 +42,6 @@ __all__ = [
     "Slurm",
     "SlurmExperiment",
     "SlurmSpec",
+    "Topology",
     "uv_container",
 ]
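
With the re-export above, `Topology` becomes importable from the package root alongside the other resource types. A minimal sketch of the import surface this hunk guarantees (the type's constructor is not part of this diff):

# Only the re-export is confirmed by the diff; Topology's constructor lives in
# xm_slurm/resources.py and is not shown here.
from xm_slurm import JobRequirements, ResourceType, Topology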
xm_slurm/batching.py CHANGED
@@ -4,11 +4,11 @@ import dataclasses
 import inspect
 import time
 import types
-from typing import Any, Callable, Coroutine, Generic, ParamSpec, Sequence, TypeVar
+import typing as tp
 
-T = TypeVar("T", contravariant=True)
-R = TypeVar("R", covariant=True)
-P = ParamSpec("P")
+T = tp.TypeVar("T", contravariant=True)
+R = tp.TypeVar("R", covariant=True)
+P = tp.ParamSpec("P")
 
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -18,10 +18,10 @@ class Request:
 
 
 def stack_bound_arguments(
-    signature: inspect.Signature, bound_arguments: Sequence[inspect.BoundArguments]
+    signature: inspect.Signature, bound_arguments: tp.Sequence[inspect.BoundArguments]
 ) -> inspect.BoundArguments:
     """Stacks bound arguments into a single bound arguments object."""
-    stacked_args = collections.OrderedDict[str, Any]()
+    stacked_args = collections.OrderedDict[str, tp.Any]()
     for bound_args in bound_arguments:
         for name, value in bound_args.arguments.items():
             stacked_args.setdefault(name, [])
@@ -29,7 +29,7 @@ def stack_bound_arguments(
     return inspect.BoundArguments(signature, stacked_args)
 
 
-class batch(Generic[R]):
+class batch(tp.Generic[R]):
     __slots__ = (
         "fn",
         "signature",
@@ -44,7 +44,7 @@ class batch(Generic[R]):
 
     def __init__(
         self,
-        fn: Callable[..., Coroutine[None, None, Sequence[R]]],
+        fn: tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]],
         /,
         *,
         max_batch_size: int,
@@ -102,7 +102,7 @@ class batch(Generic[R]):
 
         return batch
 
-    def __get__(self, obj: Any, objtype: type) -> Any:
+    def __get__(self, obj: tp.Any, objtype: tp.Type[tp.Any]) -> tp.Any:
        del objtype
        if isinstance(self.fn, staticmethod):
            return self.__call__
@@ -110,11 +110,11 @@ class batch(Generic[R]):
        return types.MethodType(self, obj)
 
     @property
-    def __func__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
+    def __func__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
        return self.fn
 
     @property
-    def __wrapped__(self) -> Callable[..., Coroutine[None, None, Sequence[R]]]:
+    def __wrapped__(self) -> tp.Callable[..., tp.Coroutine[None, None, tp.Sequence[R]]]:
        return self.fn
 
     @property
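
The `batch` descriptor wraps an async function that maps a batch of inputs to a batch of results and coalesces individual calls up to `max_batch_size`. A hedged usage sketch: only `fn` and `max_batch_size` are visible in this diff, further keyword parameters may exist past the hunk boundary, and applying the class via `functools.partial` is an assumption rather than documented usage:

import functools
import typing as tp

from xm_slurm.batching import batch

@functools.partial(batch, max_batch_size=32)
async def query_states(job_ids: tp.Sequence[str]) -> tp.Sequence[str]:
    # One backend round-trip per batch instead of one call per job id.
    return [f"RUNNING:{job_id}" for job_id in job_ids]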
xm_slurm/config.py CHANGED
@@ -5,7 +5,7 @@ import getpass
 import json
 import os
 import pathlib
-from typing import Callable, Literal, Mapping, NamedTuple
+import typing as tp
 
 import asyncssh
 from xmanager import xm
@@ -23,7 +23,7 @@ class ContainerRuntime(enum.Enum):
 
     @classmethod
     def from_string(
-        cls, runtime: Literal["singularity", "apptainer", "docker", "podman"]
+        cls, runtime: tp.Literal["singularity", "apptainer", "docker", "podman"]
     ) -> "ContainerRuntime":
         return {
             "singularity": cls.SINGULARITY,
@@ -45,7 +45,7 @@ class ContainerRuntime(enum.Enum):
         raise NotImplementedError
 
 
-class PublicKey(NamedTuple):
+class PublicKey(tp.NamedTuple):
     algorithm: str
     key: str
 
@@ -172,26 +172,26 @@ class SlurmClusterConfig:
     qos: str | None = None
 
     # If true, a reverse proxy is initiated via the submission host.
-    proxy: Literal["submission-host"] | str | None = None
+    proxy: tp.Literal["submission-host"] | str | None = None
 
     runtime: ContainerRuntime
 
     # Environment variables
-    host_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
-    container_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
+    host_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
+    container_environment: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
 
     # Mounts
-    mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
         default_factory=dict
     )
 
     # Resource mapping
-    resources: Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
+    resources: tp.Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
 
-    features: Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
+    features: tp.Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict)  # type: ignore # noqa: F821
 
     # Function to validate the Slurm executor config
-    validate: Callable[[xm.Job], None] | None = None
+    validate: tp.Callable[[xm.Job], None] | None = None
 
     def __post_init__(self) -> None:
         for src, dst in self.mounts.items():
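
The `validate` field is a per-cluster hook that can reject a job before submission. A minimal sketch of a conforming callable; only the `Callable[[xm.Job], None]` contract is visible here, and reading `account` off `job.executor` assumes the executor is an `xm_slurm.Slurm`:

from xmanager import xm

def require_account(job: xm.Job) -> None:
    # Raising aborts submission; returning None accepts the job.
    if getattr(job.executor, "account", None) is None:
        raise ValueError("This cluster requires an explicit account.")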
@@ -1,5 +1,5 @@
 import os
-from typing import Literal
+import typing as tp
 
 from xm_slurm import config
 from xm_slurm.resources import FeatureType, ResourceType
@@ -16,10 +16,10 @@ def _drac_cluster(
     user: str | None = None,
     account: str | None = None,
     modules: list[str] | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
-    resources: dict[ResourceType, str] | None = None,
-    features: dict[FeatureType, str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    resources: tp.Mapping[ResourceType, str] | None = None,
+    features: tp.Mapping[FeatureType, str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cluster."""
     if mounts is None:
@@ -62,8 +62,8 @@ def narval(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Narval Cluster (https://docs.alliancecan.ca/wiki/Narval/en)."""
     modules = []
@@ -94,8 +94,8 @@ def beluga(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Beluga Cluster (https://docs.alliancecan.ca/wiki/B%C3%A9luga/en)."""
     modules = []
@@ -125,8 +125,8 @@ def rorqual(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Rorqual Cluster (https://docs.alliancecan.ca/wiki/Rorqual/en)."""
     modules = []
@@ -155,7 +155,7 @@ def cedar(
     *,
     user: str | None = None,
     account: str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cedar Cluster (https://docs.alliancecan.ca/wiki/Cedar/en)."""
     return _drac_cluster(
@@ -180,7 +180,7 @@ def fir(
     *,
     user: str | None = None,
     account: str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Fir Cluster (https://docs.alliancecan.ca/wiki/Fir/en)."""
     return _drac_cluster(
@@ -201,8 +201,8 @@ def graham(
     *,
     user: str | None = None,
     account: str | None = None,
-    proxy: Literal["submission-host"] | str | None = "submission-host",
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
+    proxy: tp.Literal["submission-host"] | str | None = "submission-host",
+    mounts: tp.Mapping[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Graham Cluster (https://docs.alliancecan.ca/wiki/Graham/en)."""
     return _drac_cluster(
@@ -223,17 +223,3 @@ def graham(
             ResourceType.A5000: "a5000",
         },
     )
-
-
-def all(
-    user: str | None = None,
-    account: str | None = None,
-    mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
-) -> list[config.SlurmClusterConfig]:
-    """All DRAC clusters."""
-    return [
-        narval(user=user, account=account, mounts=mounts),
-        beluga(user=user, account=account, mounts=mounts),
-        cedar(user=user, account=account, mounts=mounts),
-        graham(user=user, account=account, mounts=mounts),
-    ]
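
The hunks from `@@ -1,5 +1,5 @@` above belong to the DRAC cluster-preset module. Its `all()` helper was removed in this release; callers can reproduce it from the surviving per-cluster constructors, as in this sketch (assumes `narval` et al. are imported and that `user`, `account`, and `mounts` are defined by the caller):

# Equivalent of the removed all() helper, built from the remaining functions.
configs = [
    narval(user=user, account=account, mounts=mounts),
    beluga(user=user, account=account, mounts=mounts),
    cedar(user=user, account=account, mounts=mounts),
    graham(user=user, account=account, mounts=mounts),
]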
xm_slurm/dependencies.py CHANGED
@@ -1,7 +1,7 @@
 import abc
 import dataclasses
 import datetime as dt
-from typing import Callable, Sequence
+import typing as tp
 
 
 class SlurmDependencyException(Exception): ...
@@ -36,7 +36,7 @@ class SlurmJobDependency(abc.ABC):
         return (self,)
 
     def traverse(
-        self, mapper: Callable[["SlurmJobDependency"], "SlurmJobDependency"]
+        self, mapper: tp.Callable[["SlurmJobDependency"], "SlurmJobDependency"]
     ) -> "SlurmJobDependency":
         if isinstance(self, SlurmJobDependencyAND) or isinstance(self, SlurmJobDependencyOR):
             return type(self)(
@@ -80,7 +80,7 @@ class SlurmJobDependencyOR(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfter(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
     time: dt.timedelta | None = None
 
     def __post_init__(self):
@@ -104,7 +104,7 @@ class SlurmJobDependencyAfter(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfterAny(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
 
     def __post_init__(self):
         if len(self.handles) == 0:
@@ -119,7 +119,7 @@ class SlurmJobDependencyAfterAny(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
 
     def __post_init__(self):
         if len(self.handles) == 0:
@@ -134,7 +134,7 @@ class SlurmJobDependencyAfterNotOK(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobDependencyAfterOK(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle"]  # type: ignore # noqa: F821
 
     def __post_init__(self):
         if len(self.handles) == 0:
@@ -149,7 +149,7 @@ class SlurmJobDependencyAfterOK(SlurmJobDependency):
 
 @dataclasses.dataclass(frozen=True)
 class SlurmJobArrayDependencyAfterOK(SlurmJobDependency):
-    handles: Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821
+    handles: tp.Sequence["xm_slurm.execution.SlurmHandle[SlurmJob]"]  # type: ignore # noqa: F821
 
     def __post_init__(self):
         if len(self.handles) == 0:
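
Each dependency dataclass wraps a sequence of `SlurmHandle`s, and the `__post_init__` guards shown above check for empty sequences. A construction sketch, where `handle` stands in for a `SlurmHandle` obtained from a previously launched job:

dep = SlurmJobDependencyAfterOK(handles=(handle,))
# SlurmJobDependencyAfterOK(handles=()) would trip the len(...) == 0 guard in __post_init__.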
xm_slurm/execution.py CHANGED
@@ -3,6 +3,8 @@ import collections.abc
 import dataclasses
 import functools
 import hashlib
+import importlib
+import importlib.resources
 import logging
 import operator
 import os
@@ -311,6 +313,14 @@ def get_template_env(runtime: ContainerRuntime) -> j2.Environment:
     template_env.globals["raise"] = _raise_template_exception
     template_env.globals["operator"] = operator
 
+    # Iterate over stdlib files and insert them into the template environment
+    stdlib = []
+    for file in importlib.resources.files("xm_slurm.templates.slurm.library").iterdir():
+        if not file.is_file() or not file.name.endswith(".bash"):
+            continue
+        stdlib.append(file.read_text())
+    template_env.globals["stdlib"] = stdlib
+
     entrypoint_template = template_env.get_template("entrypoint.bash.j2")
     template_env.globals.update(entrypoint_template.module.__dict__)
 
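The new block collects every `.bash` file shipped in `xm_slurm/templates/slurm/library` and exposes the list to Jinja as the `stdlib` global. A self-contained sketch of how a template can consume such a global (illustrative; the package's real templates, e.g. `entrypoint.bash.j2`, are not shown in this diff):

import jinja2

env = jinja2.Environment()
# Stand-ins for the .bash library files read via importlib.resources.
env.globals["stdlib"] = ["echo 'helper one'", "echo 'helper two'"]
template = env.from_string("{% for snippet in stdlib %}{{ snippet }}\n{% endfor %}")
print(template.render())  # emits each library snippet on its own line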
 
xm_slurm/executors.py CHANGED
@@ -1,10 +1,18 @@
+import collections.abc
 import dataclasses
 import datetime as dt
 import signal
+import typing as tp
 
 from xmanager import xm
 
-from xm_slurm import resources
+from xm_slurm import resources, utils
+
+ResourceBindType = tp.Literal[
+    resources.ResourceType.GPU,
+    resources.ResourceType.MEMORY,
+    resources.ResourceType.RAM,
+]
 
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
@@ -26,10 +34,19 @@ class Slurm(xm.Executor):
     Args:
         requirements: The requirements for the job.
         time: The maximum time to run the job.
+        switches: Maximum count of leaf switches desired for the job allocation.
+        switches_grace_period: Maximum time to wait for that number of switches.
+        bind: How to bind tasks to resources (memory, GPU, or a generic resource).
         account: The account to charge the job to.
         partition: The partition to run the job in.
         qos: The quality of service to run the job with.
         priority: The priority of the job.
+        reservation: Allocate resources for the job from the named reservation.
+        exclusive: Do not share nodes with other running jobs.
+        oversubscribe: Allow over-subscribing resources with other running jobs.
+        overcommit: Allow sharing of allocated resources as if only one task per node was requested.
+        nice: Run the job with an adjusted scheduling priority.
+        kill_on_invalid_dependencies: Whether to kill the job if it has invalid dependencies.
         timeout_signal: The signal to send to the job when it runs out of time.
         timeout_signal_grace_period: The time to wait before sending `timeout_signal`.
         requeue: Whether or not the job is eligible for requeueing.
@@ -41,12 +58,18 @@ class Slurm(xm.Executor):
     # Job requirements
     requirements: resources.JobRequirements
     time: dt.timedelta
+    bind: tp.Mapping[ResourceBindType | str, str | None] | None = None
 
     # Placement
     account: str | None = None
     partition: str | None = None
     qos: str | None = None
     priority: int | None = None
+    reservation: str | tp.Iterable[str] | None = None
+    exclusive: bool = False
+    oversubscribe: bool = False
+    overcommit: bool = False
+    nice: int | None = None
 
     # Job dependency handling
     kill_on_invalid_dependencies: bool = True
@@ -65,13 +88,28 @@ class Slurm(xm.Executor):
         return self.time - self.timeout_signal_grace_period
 
     def __post_init__(self) -> None:
-        if not isinstance(self.time, dt.timedelta):
-            raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
         if not isinstance(self.requirements, resources.JobRequirements):
             raise TypeError(
                 f"requirements must be a `xm_slurm.JobRequirements`, got {type(self.requirements)}. "
                 "If you're still using `xm.JobRequirements`, please update to `xm_slurm.JobRequirements`."
             )
+        if not isinstance(self.time, dt.timedelta):
+            raise TypeError(f"time must be a `datetime.timedelta`, got {type(self.time)}")
+        if self.bind is not None:
+            if not isinstance(self.bind, collections.abc.Mapping):
+                raise TypeError(f"bind must be a mapping, got {type(self.bind)}")
+            for resource, value in self.bind.items():
+                if resource not in (
+                    resources.ResourceType.GPU,
+                    resources.ResourceType.MEMORY,
+                    resources.ResourceType.RAM,
+                ) and not isinstance(resource, str):
+                    raise TypeError(
+                        f"bind resource must be a {resources.ResourceType.GPU.name}, {resources.ResourceType.MEMORY.name}, or {resources.ResourceType.RAM.name}, got {type(resource)}"
+                    )
+                if value is not None and not isinstance(value, str):
+                    raise TypeError(f"bind value must be None or a string, got {type(value)}")
+
         if not isinstance(self.timeout_signal, signal.Signals):
             raise TypeError(
                 f"termination_signal must be a `signal.Signals`, got {type(self.timeout_signal)}"
@@ -86,6 +124,10 @@ class Slurm(xm.Executor):
             )
         if self.requeue_on_exit_code == 0:
             raise ValueError("requeue_on_exit_code should not be 0 to avoid unexpected behavior.")
+        if self.exclusive and self.oversubscribe:
+            raise ValueError("exclusive and oversubscribe are mutually exclusive.")
+        if self.nice is not None and not (-2147483645 <= self.nice <= 2147483645):
+            raise ValueError(f"nice must be between -2147483645 and 2147483645, got {self.nice}")
 
     @classmethod
     def Spec(cls, tag: str | None = None) -> SlurmSpec:
@@ -96,10 +138,22 @@ class Slurm(xm.Executor):
         directives = self.requirements.to_directives()
 
         # Time
-        days = self.time.days
-        hours, remainder = divmod(self.time.seconds, 3600)
-        minutes, seconds = divmod(remainder, 60)
-        directives.append(f"--time={days}-{hours:02}:{minutes:02}:{seconds:02}")
+        directives.append(f"--time={utils.timestr_from_timedelta(self.time)}")
+
+        # Resource binding
+        if self.bind is not None:
+            for resource, value in self.bind.items():
+                if value is None:
+                    value = "none"
+                match resource:
+                    case resources.ResourceType.MEMORY | resources.ResourceType.RAM:
+                        directives.append(f"--mem-bind={value}")
+                    case resources.ResourceType.GPU:
+                        directives.append(f"--gpu-bind={value}")
+                    case str():
+                        directives.append(f"--tres-bind=gres/{resource}:{value}")
+                    case _:
+                        raise ValueError(f"Unsupported resource type {resource!r} for binding.")
 
         # Job dependency handling
         directives.append(
@@ -107,20 +161,36 @@ class Slurm(xm.Executor):
         )
 
         # Placement
-        if self.account:
+        if self.account is not None:
             directives.append(f"--account={self.account}")
-        if self.partition:
+        if self.partition is not None:
             directives.append(f"--partition={self.partition}")
-        if self.qos:
+        if self.qos is not None:
             directives.append(f"--qos={self.qos}")
-        if self.priority:
+        if self.priority is not None:
             directives.append(f"--priority={self.priority}")
+        if self.reservation is not None:
+            match self.reservation:
+                case str():
+                    directives.append(f"--reservation={self.reservation}")
+                case collections.abc.Iterable():
+                    directives.append(f"--reservation={','.join(self.reservation)}")
+                case _:
+                    raise ValueError(f"Invalid reservation type: {type(self.reservation)}")
+        if self.exclusive:
+            directives.append("--exclusive")
+        if self.oversubscribe:
+            directives.append("--oversubscribe")
+        if self.overcommit:
+            directives.append("--overcommit")
+        if self.nice is not None:
+            directives.append(f"--nice={self.nice}")
 
         # Job rescheduling
         directives.append(
             f"--signal={self.timeout_signal.name.removeprefix('SIG')}@{self.timeout_signal_grace_period.seconds}"
         )
-        if self.requeue and self.requeue_max_attempts > 0:
+        if self.requeue is not None and self.requeue_max_attempts > 0:
             directives.append("--requeue")
         else:
             directives.append("--no-requeue")
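
Taken together, the new executor fields map one-to-one onto sbatch directives: `bind` renders `--mem-bind`/`--gpu-bind`/`--tres-bind`, while `reservation`, `exclusive`, `oversubscribe`, `overcommit`, and `nice` render the flags of the same names. A hedged construction sketch; `JobRequirements()` is assumed to be constructible without arguments, as `xm.JobRequirements` is:

import datetime as dt

import xm_slurm

executor = xm_slurm.Slurm(
    requirements=xm_slurm.JobRequirements(),      # fill in per-cluster resources
    time=dt.timedelta(hours=4),                   # rendered via utils.timestr_from_timedelta
    bind={xm_slurm.ResourceType.GPU: "closest"},  # emits --gpu-bind=closest
    reservation=("maint-a", "maint-b"),           # emits --reservation=maint-a,maint-b
    exclusive=True,                               # emits --exclusive; cannot combine with oversubscribe
    nice=100,                                     # emits --nice=100
)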
@@ -4,8 +4,8 @@ import dataclasses
 import enum
 import functools
 import logging
+import typing as tp
 import zlib
-from typing import Awaitable, Callable, Concatenate, Coroutine, Mapping, ParamSpec, TypeVar
 
 import backoff
 import cloudpickle
@@ -15,15 +15,15 @@ import xm_slurm
 from xm_slurm import job_blocks, status
 from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
 
-P = ParamSpec("P")
-T = TypeVar("T")
+P = tp.ParamSpec("P")
+T = tp.TypeVar("T")
 
 logger = logging.getLogger(__name__)
 
 
 async def _monitor_parameter_controller(
     aux_unit: SlurmAuxiliaryUnit,
-    local_parameter_controller_coro: Coroutine[None, None, T],
+    local_parameter_controller_coro: tp.Coroutine[None, None, T],
     *,
     poll_interval: float = 30.0,
 ) -> None:
@@ -104,13 +104,13 @@ def parameter_controller(
     controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
     controller_name: str = "parameter_controller",
     controller_args: xm.UserArgs | None = None,
-    controller_env_vars: Mapping[str, str] | None = None,
-) -> Callable[
+    controller_env_vars: tp.Mapping[str, str] | None = None,
+) -> tp.Callable[
     [
-        Callable[Concatenate[SlurmExperiment, P], T]
-        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
+        tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
+        | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
     ],
-    Callable[P, xm.AuxiliaryUnitJob],
+    tp.Callable[P, xm.AuxiliaryUnitJob],
 ]:
     """Converts a function to a controller which can be added to an experiment.
@@ -131,9 +131,9 @@ def parameter_controller(
     """
 
     def decorator(
-        f: Callable[Concatenate[SlurmExperiment, P], T]
-        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
-    ) -> Callable[P, xm.AuxiliaryUnitJob]:
+        f: tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
+        | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
+    ) -> tp.Callable[P, xm.AuxiliaryUnitJob]:
         @functools.wraps(f)
         def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
             # Modify the function to read the experiment from the API so that it can be pickled.
@@ -141,13 +141,17 @@ def parameter_controller(
             async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
                 experiment_id = aux_unit.experiment.experiment_id
 
-                async def local_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
+                async def local_controller(
+                    *args: P.args, **kwargs: P.kwargs
+                ) -> T | tp.Awaitable[T]:
                     if asyncio.iscoroutinefunction(f):
                         return await f(aux_unit.experiment, *args, **kwargs)
                     else:
                         return f(aux_unit.experiment, *args, **kwargs)
 
-                async def remote_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
+                async def remote_controller(
+                    *args: P.args, **kwargs: P.kwargs
+                ) -> T | tp.Awaitable[T]:
                     async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
                         if asyncio.iscoroutinefunction(f):
                             return await f(exp, *args, **kwargs)
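
These hunks come from the parameter-controller module: `parameter_controller` turns a (possibly async) function of a `SlurmExperiment` into a factory for auxiliary-unit jobs. A usage sketch, assuming the decorator parameters that fall outside this hunk have defaults:

@parameter_controller(controller_name="sweep_driver")
async def sweep(experiment: SlurmExperiment, seeds: list[int]) -> None:
    for seed in seeds:
        ...  # add and await jobs on `experiment`

controller_job = sweep(seeds=[0, 1, 2])  # returns an xm.AuxiliaryUnitJob per the diff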
xm_slurm/job_blocks.py CHANGED
@@ -1,11 +1,11 @@
-from typing import Mapping, TypedDict
+import typing as tp
 
 from xmanager import xm
 
 
-class JobArgs(TypedDict, total=False):
+class JobArgs(tp.TypedDict, total=False):
     args: xm.UserArgs
-    env_vars: Mapping[str, str]
+    env_vars: tp.Mapping[str, str]
 
 
 def get_args_for_python_entrypoint(
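
`JobArgs` is declared with `total=False`, so both keys are optional and callers may supply any subset. A minimal sketch (the environment variable is illustrative):

from xm_slurm.job_blocks import JobArgs

job_args: JobArgs = {"env_vars": {"WANDB_MODE": "offline"}}  # `args` may be omitted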