xmanager-slurm 0.4.13__py3-none-any.whl → 0.4.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm has been flagged as potentially problematic; consult the package registry's advisory page for details.

xm_slurm/packageables.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import importlib.resources as resources
2
2
  import pathlib
3
3
  import sys
4
- from typing import Literal, Mapping, Sequence
4
+ import typing as tp
5
5
 
6
6
  from xmanager import xm
7
7
 
@@ -14,7 +14,7 @@ def docker_image(
14
14
  *,
15
15
  image: str,
16
16
  args: xm.UserArgs | None = None,
17
- env_vars: Mapping[str, str] | None = None,
17
+ env_vars: tp.Mapping[str, str] | None = None,
18
18
  ) -> xm.Packageable:
19
19
  """Creates a packageable for a pre-built Docker image.
20
20
 
@@ -39,12 +39,12 @@ def docker_container(
39
39
  dockerfile: pathlib.Path | None = None,
40
40
  context: pathlib.Path | None = None,
41
41
  target: str | None = None,
42
- ssh: Sequence[str] | Literal[True] | None = None,
43
- build_args: Mapping[str, str] | None = None,
44
- cache_from: str | Sequence[str] | None = None,
45
- labels: Mapping[str, str] | None = None,
42
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
43
+ build_args: tp.Mapping[str, str] | None = None,
44
+ cache_from: str | tp.Sequence[str] | None = None,
45
+ labels: tp.Mapping[str, str] | None = None,
46
46
  args: xm.UserArgs | None = None,
47
- env_vars: Mapping[str, str] | None = None,
47
+ env_vars: tp.Mapping[str, str] | None = None,
48
48
  ) -> xm.Packageable:
49
49
  """Creates a Docker container packageable from a dockerfile.
50
50
 
@@ -104,13 +104,13 @@ def python_container(
104
104
  context: pathlib.Path | None = None,
105
105
  requirements: pathlib.Path | None = None,
106
106
  base_image: str = "docker.io/python:{major}.{minor}-slim",
107
- extra_system_packages: Sequence[str] = (),
108
- extra_python_packages: Sequence[str] = (),
109
- cache_from: str | Sequence[str] | None = None,
110
- labels: Mapping[str, str] | None = None,
111
- ssh: Sequence[str] | Literal[True] | None = None,
107
+ extra_system_packages: tp.Sequence[str] = (),
108
+ extra_python_packages: tp.Sequence[str] = (),
109
+ cache_from: str | tp.Sequence[str] | None = None,
110
+ labels: tp.Mapping[str, str] | None = None,
111
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
112
112
  args: xm.UserArgs | None = None,
113
- env_vars: Mapping[str, str] | None = None,
113
+ env_vars: tp.Mapping[str, str] | None = None,
114
114
  ) -> xm.Packageable:
115
115
  """Creates a Python container from a base image using pip from a `requirements.txt` file.
116
116
 
@@ -181,11 +181,11 @@ def mamba_container(
181
181
  context: pathlib.Path | None = None,
182
182
  environment: pathlib.Path | None = None,
183
183
  base_image: str = "gcr.io/distroless/base-debian10",
184
- cache_from: str | Sequence[str] | None = None,
185
- labels: Mapping[str, str] | None = None,
186
- ssh: Sequence[str] | Literal[True] | None = None,
184
+ cache_from: str | tp.Sequence[str] | None = None,
185
+ labels: tp.Mapping[str, str] | None = None,
186
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
187
187
  args: xm.UserArgs | None = None,
188
- env_vars: Mapping[str, str] | None = None,
188
+ env_vars: tp.Mapping[str, str] | None = None,
189
189
  ) -> xm.Packageable:
190
190
  """Creates a Conda container from a base image using mamba from a `environment.yml` file.
191
191
 
@@ -249,13 +249,13 @@ def uv_container(
249
249
  entrypoint: xm.ModuleName | xm.CommandList,
250
250
  context: pathlib.Path | None = None,
251
251
  base_image: str = "docker.io/python:{major}.{minor}-slim-bookworm",
252
- extra_system_packages: Sequence[str] = (),
253
- extra_python_packages: Sequence[str] = (),
254
- cache_from: str | Sequence[str] | None = None,
255
- labels: Mapping[str, str] | None = None,
256
- ssh: Sequence[str] | Literal[True] | None = None,
252
+ extra_system_packages: tp.Sequence[str] = (),
253
+ extra_python_packages: tp.Sequence[str] = (),
254
+ cache_from: str | tp.Sequence[str] | None = None,
255
+ labels: tp.Mapping[str, str] | None = None,
256
+ ssh: tp.Sequence[str] | tp.Literal[True] | None = None,
257
257
  args: xm.UserArgs | None = None,
258
- env_vars: Mapping[str, str] | None = None,
258
+ env_vars: tp.Mapping[str, str] | None = None,
259
259
  ) -> xm.Packageable:
260
260
  """Creates a Python container from a base image using uv from a `uv.lock` file.
261
261
 
@@ -1,31 +1,31 @@
1
1
  import dataclasses
2
- from typing import Callable, Generic, ParamSpec, Sequence, Type, TypeVar
2
+ import typing as tp
3
3
 
4
4
  from xmanager import xm
5
5
 
6
- T_co = TypeVar("T_co", covariant=True)
7
- P = ParamSpec("P")
8
- ExecutableSpecT = TypeVar("ExecutableSpecT", bound=xm.ExecutableSpec)
6
+ T_co = tp.TypeVar("T_co", covariant=True)
7
+ P = tp.ParamSpec("P")
8
+ ExecutableSpecT = tp.TypeVar("ExecutableSpecT", bound=xm.ExecutableSpec)
9
9
 
10
10
 
11
11
  @dataclasses.dataclass(frozen=True)
12
- class IndexedContainer(Generic[T_co]):
12
+ class IndexedContainer(tp.Generic[T_co]):
13
13
  index: int
14
14
  value: T_co
15
15
 
16
16
 
17
- RegistrationCallable = Callable[
18
- [Sequence[IndexedContainer[xm.Packageable]]],
19
- Sequence[IndexedContainer[xm.Executable]],
17
+ RegistrationCallable = tp.Callable[
18
+ [tp.Sequence[IndexedContainer[xm.Packageable]]],
19
+ tp.Sequence[IndexedContainer[xm.Executable]],
20
20
  ]
21
21
 
22
22
 
23
- _REGISTRY: dict[Type[xm.ExecutableSpec], RegistrationCallable] = {}
23
+ _REGISTRY: dict[tp.Type[xm.ExecutableSpec], RegistrationCallable] = {}
24
24
 
25
25
 
26
26
  def register(
27
- *typs: Type[ExecutableSpecT],
28
- ) -> Callable[[RegistrationCallable], RegistrationCallable]:
27
+ *typs: tp.Type[ExecutableSpecT],
28
+ ) -> tp.Callable[[RegistrationCallable], RegistrationCallable]:
29
29
  def decorator(
30
30
  registration_callable: RegistrationCallable,
31
31
  ) -> RegistrationCallable:
@@ -38,8 +38,8 @@ def register(
38
38
 
39
39
 
40
40
  def route(
41
- typ: Type[ExecutableSpecT],
42
- packageables: Sequence[IndexedContainer[xm.Packageable]],
43
- ) -> Sequence[IndexedContainer[xm.Executable]]:
41
+ typ: tp.Type[ExecutableSpecT],
42
+ packageables: tp.Sequence[IndexedContainer[xm.Packageable]],
43
+ ) -> tp.Sequence[IndexedContainer[xm.Executable]]:
44
44
  global _REGISTRY
45
45
  return _REGISTRY[typ](packageables)
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import logging
3
- from typing import Sequence, Type
3
+ import typing as tp
4
4
 
5
5
  from xmanager import xm
6
6
 
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  def package(
17
- packageables: Sequence[xm.Packageable],
17
+ packageables: tp.Sequence[xm.Packageable],
18
18
  ) -> list[xm.Executable]:
19
19
  """
20
20
  Takes as input a list of packageables and returns a mapping of
@@ -23,7 +23,7 @@ def package(
23
23
  # Docker targets to be collected.
24
24
  # These are a mapping from `DockerTarget` to the latest digest of the image.
25
25
  targets_by_type = collections.defaultdict[
26
- Type[xm.ExecutableSpec], list[IndexedContainer[xm.Packageable]]
26
+ tp.Type[xm.ExecutableSpec], list[IndexedContainer[xm.Packageable]]
27
27
  ](list)
28
28
 
29
29
  # Collect dockerfiles that need to be built locally
@@ -1,20 +1,20 @@
1
1
  import collections
2
2
  import logging
3
- from typing import ParamSpec, Sequence, TypeVar
3
+ import typing as tp
4
4
 
5
5
  from xmanager import xm
6
6
 
7
7
  from xm_slurm.packaging.registry import IndexedContainer
8
8
 
9
- T = TypeVar("T")
10
- P = ParamSpec("P")
11
- ReturnT = TypeVar("ReturnT")
9
+ T = tp.TypeVar("T")
10
+ P = tp.ParamSpec("P")
11
+ ReturnT = tp.TypeVar("ReturnT")
12
12
 
13
13
  logger = logging.getLogger(__name__)
14
14
 
15
15
 
16
16
  def collect_executors_by_executable(
17
- targets: Sequence[IndexedContainer[xm.Packageable]],
17
+ targets: tp.Sequence[IndexedContainer[xm.Packageable]],
18
18
  ) -> dict[xm.ExecutableSpec, set[xm.ExecutorSpec]]:
19
19
  executors_by_executable = collections.defaultdict(set)
20
20
  for target in targets:
xm_slurm/resources.py CHANGED
@@ -1,9 +1,13 @@
1
+ import builtins
2
+ import collections.abc
3
+ import datetime as dt
1
4
  import enum
2
5
  import itertools
3
6
  import math
4
- from typing import Mapping
7
+ import re
8
+ import typing as tp
5
9
 
6
- from xm_slurm import config
10
+ from xm_slurm import config, utils
7
11
 
8
12
 
9
13
  class ResourceType(enum.IntEnum):
@@ -57,32 +61,123 @@ assert AcceleratorType | {
57
61
  } == set(ResourceType.__members__.values()), "Resource types are not exhaustive."
58
62
 
59
63
 
60
- ResourceQuantity = int | float
61
-
62
-
63
64
  class FeatureType(enum.IntEnum):
64
65
  NVIDIA_MIG = 1
65
66
  NVIDIA_NVLINK = 2
66
67
 
67
68
 
69
+ class InvalidTopologyError(Exception):
70
+ """An unrecognized topology has been provided."""
71
+
72
+
73
+ TOPOLOGY_REGEX = re.compile(r"^(?P<dims>[\d]+(?:x[\d]+)*)$")
74
+
75
+
76
+ class Topology:
77
+ mesh: str
78
+ dimensions: list[int]
79
+ switches: int | None
80
+ switches_grace_period: dt.timedelta | None
81
+
82
+ def __init__(
83
+ self,
84
+ mesh: str,
85
+ /,
86
+ *,
87
+ switches: int | None = None,
88
+ switches_grace_period: dt.timedelta | None = None,
89
+ ):
90
+ mesh_match = TOPOLOGY_REGEX.fullmatch(mesh)
91
+ if not mesh_match:
92
+ raise InvalidTopologyError(f"Invalid topology mesh: {mesh!r}.")
93
+
94
+ self.mesh = mesh
95
+ self.dimensions = list(map(int, mesh_match.group("dims").split("x")))
96
+ if switches is not None:
97
+ assert (
98
+ isinstance(switches, int) and switches > 0
99
+ ), "Switches must be a positive integer."
100
+ self.switches = switches
101
+ if switches_grace_period is not None:
102
+ assert isinstance(
103
+ switches_grace_period, dt.timedelta
104
+ ), "Switches grace period must be a `datetime.timedelta`."
105
+ self.switches_grace_period = switches_grace_period
106
+
107
+ @property
108
+ def chip_count(self) -> int:
109
+ return math.prod(self.dimensions)
110
+
111
+ @property
112
+ def ndim(self) -> int:
113
+ return len(self.dimensions)
114
+
115
+ def __eq__(self, other: object) -> bool:
116
+ if not isinstance(other, Topology):
117
+ return False
118
+ return (
119
+ self.mesh == other.mesh
120
+ and self.switches == other.switches
121
+ and self.switches_grace_period == other.switches_grace_period
122
+ )
123
+
124
+ def __hash__(self) -> int:
125
+ return hash((self.mesh, self.switches, self.switches_grace_period))
126
+
127
+
128
+ ResourceQuantity = int | float | Topology
129
+
130
+
131
+ def _parse_resource_quantity(
132
+ resource_name: ResourceType | str, value: ResourceQuantity
133
+ ) -> tuple[float, Topology | None]:
134
+ if isinstance(resource_name, ResourceType):
135
+ resource_name = resource_name.name
136
+ match value:
137
+ case Topology() as topology:
138
+ return topology.chip_count, topology
139
+ case builtins.str() as topology_str if (
140
+ "x" in topology_str and TOPOLOGY_REGEX.fullmatch(topology_str) is not None
141
+ ):
142
+ topology = Topology(topology_str)
143
+ return topology.chip_count, topology
144
+ case builtins.str() as num_str:
145
+ try:
146
+ value = float(num_str)
147
+ return int(value) if value.is_integer() else value, None
148
+ except ValueError as e:
149
+ raise ValueError(
150
+ f"Couldn't parse resource quantity for {resource_name}. "
151
+ f"{num_str!r} was given."
152
+ ) from e
153
+ case int() | float():
154
+ return value, None
155
+ case _:
156
+ raise ValueError(f"Invalid resource quantity: {value!r} for {resource_name!r}.")
157
+
158
+
68
159
  class JobRequirements:
69
160
  replicas: int
70
161
  location: str | None
71
162
  accelerator: ResourceType | None
72
- cluster: config.SlurmClusterConfig | None = None
163
+ topology: Topology | None
164
+ cluster: config.SlurmClusterConfig
73
165
 
74
166
  def __init__(
75
167
  self,
76
168
  *,
77
- resources: Mapping[ResourceType | str, ResourceQuantity] | None = None,
78
- replicas: int = 1,
79
- location: str | None = None,
80
- cluster: config.SlurmClusterConfig | None = None,
169
+ resources: tp.Mapping[ResourceType | str, ResourceQuantity] | None = None,
170
+ replicas: int | None = None,
171
+ location: str | tp.Iterable[str] | None = None,
172
+ cluster: config.SlurmClusterConfig,
81
173
  **kw_resources: ResourceQuantity,
82
174
  ):
83
- self.replicas = replicas or 1
175
+ if isinstance(location, collections.abc.Iterable) and not isinstance(location, str):
176
+ location = ",".join(location)
84
177
  self.location = location
178
+
85
179
  self.accelerator = None
180
+ self.topology = None
86
181
  self.cluster = cluster
87
182
 
88
183
  if resources is None:
@@ -90,6 +185,7 @@ class JobRequirements:
90
185
 
91
186
  self.task_requirements: dict[ResourceType | str, ResourceQuantity] = {}
92
187
  for resource_name, value in itertools.chain(resources.items(), kw_resources.items()):
188
+ quantity, topology = _parse_resource_quantity(resource_name, value)
93
189
  match resource_name:
94
190
  case str() if resource_name.upper() in ResourceType.__members__:
95
191
  resource = ResourceType[resource_name.upper()]
@@ -106,58 +202,132 @@ class JobRequirements:
106
202
  if self.accelerator is not None:
107
203
  raise ValueError("Accelerator already set.")
108
204
  self.accelerator = resource # type: ignore
205
+ self.topology = topology or Topology(f"{quantity:g}")
206
+ elif topology is not None:
207
+ raise ValueError(
208
+ f"A topology was specified for a non-accelerator resource: {resource_name!r}."
209
+ )
109
210
 
110
211
  if resource in self.task_requirements:
111
212
  raise ValueError(f"{resource} has been specified twice.")
112
- self.task_requirements[resource] = value
213
+ self.task_requirements[resource] = quantity
214
+
215
+ if self.topology is not None and self.topology.ndim > 2:
216
+ raise ValueError("Topologies with more than 2 dimensions are not supported.")
217
+
218
+ if (
219
+ self.accelerator is not None
220
+ and self.topology is not None
221
+ and len(self.topology.dimensions) == 2
222
+ ):
223
+ if replicas is not None and replicas != self.topology.dimensions[1]:
224
+ raise ValueError(
225
+ f"For multihost GPUs with topology {self.topology}, replicas should"
226
+ f"be either None or {self.topology.dimensions[1]}. Found: "
227
+ f"{replicas}"
228
+ )
229
+ replicas = self.topology.dimensions[1]
230
+
231
+ if replicas is not None and replicas <= 0:
232
+ raise ValueError(f"Replicas must be a positive integer, got {replicas!r}")
233
+ self.replicas = replicas or 1
113
234
 
114
235
  def to_directives(self) -> list[str]:
115
- if self.cluster is None:
116
- raise ValueError("Cannnot derive Slurm directives for requirements without a cluster.")
117
236
  directives = []
118
237
 
119
238
  for resource, value in self.task_requirements.items():
120
239
  match resource:
121
240
  case ResourceType.EPHEMERAL_STORAGE | ResourceType.DISK:
122
- assert isinstance(value, int), "Disk space must be an integer"
241
+ assert isinstance(
242
+ value, int
243
+ ), f"Disk space must be an integer, got {type(value)!r}"
123
244
  directives.append(f"--tmp={math.ceil(value / 2**20)}M")
124
245
  case ResourceType.MEMORY | ResourceType.RAM:
125
246
  num_cpus = self.task_requirements.get(ResourceType.CPU, 1)
126
- assert isinstance(value, (int, float)), "Memory must be an integer or float"
127
- assert isinstance(num_cpus, int), "CPU must be an integer"
247
+ assert isinstance(
248
+ value, (int, float)
249
+ ), f"Memory must be an integer or float, got {type(value)!r}"
250
+ assert isinstance(
251
+ num_cpus, int
252
+ ), f"CPU must be an integer, got {type(num_cpus)!r}"
128
253
  directives.append(f"--mem-per-cpu={math.ceil(value / num_cpus / 2**20)}M")
129
254
  case ResourceType.CPU:
130
- assert isinstance(value, int), "CPU must be an integer"
255
+ assert isinstance(value, int), f"CPU must be an integer, got {type(value)!r}"
131
256
  directives.append(f"--cpus-per-task={value}")
132
257
  case ResourceType.GPU:
133
- assert isinstance(value, int), "GPU must be an integer"
134
- directives.append(f"--gpus-per-task={value}")
258
+ assert isinstance(value, int), f"GPU must be an integer, got {type(value)!r}"
259
+ directives.append(f"--gpus={value}")
135
260
  case ResourceType() if resource in AcceleratorType:
136
- assert isinstance(value, int), "Accelerator must be an integer"
261
+ assert isinstance(
262
+ value, int
263
+ ), f"Accelerator must be an integer, got {type(value)!r}"
137
264
  resource_type = self.cluster.resources.get(resource, None)
138
265
  if resource_type is None:
139
266
  raise ValueError(
140
267
  f"Cluster {self.cluster.name} does not map resource type {resource!r}."
141
268
  )
142
- directives.append(f"--gpus-per-task={resource_type}:{value}")
269
+ directives.append(f"--gpus={resource_type}:{value}")
143
270
  case str():
144
271
  directives.append(f"--gres={resource}:{value}")
145
272
 
146
- directives.append(f"--ntasks={self.replicas}")
147
273
  if self.location:
274
+ assert isinstance(
275
+ self.location, str
276
+ ), f"Location must be a string, got {type(self.location)!r}"
148
277
  directives.append(f"--nodelist={self.location}")
149
278
 
279
+ assert (
280
+ isinstance(self.replicas, int) and self.replicas > 0
281
+ ), f"Replicas must be a positive integer, got {self.replicas!r}"
282
+ directives.append(f"--ntasks={self.replicas}")
283
+
284
+ if self.topology is not None:
285
+ assert self.accelerator is not None, "Accelerator must be set."
286
+ match self.accelerator:
287
+ case ResourceType.GPU:
288
+ directives.append(f"--gpus-per-task={self.topology.dimensions[0]}")
289
+ case ResourceType() if self.accelerator in AcceleratorType:
290
+ resource_type = self.cluster.resources[self.accelerator]
291
+ directives.append(
292
+ f"--gpus-per-task={resource_type}:{self.topology.dimensions[0]}"
293
+ )
294
+
295
+ if self.topology.switches is not None:
296
+ switches_timeout = (
297
+ f"@{utils.timestr_from_timedelta(self.topology.switches_grace_period)}"
298
+ if self.topology.switches_grace_period is not None
299
+ else ""
300
+ )
301
+ directives.append(f"--switches={self.topology.switches}{switches_timeout}")
302
+
150
303
  return directives
151
304
 
152
305
  def replace(
153
306
  self,
154
- cluster: config.SlurmClusterConfig | None,
307
+ replicas: int | None = None,
308
+ location: str | None = None,
309
+ cluster: config.SlurmClusterConfig | None = None,
155
310
  **kw_resources: ResourceQuantity,
156
311
  ) -> "JobRequirements":
312
+ # Merge kw_resources into existing task_requirements, removing conflicting enum keys
313
+ merged_resources = dict(self.task_requirements)
314
+
315
+ # Remove ResourceType keys that will be overridden by string keys in kw_resources
316
+ for key in list(merged_resources.keys()):
317
+ if isinstance(key, ResourceType) and any(
318
+ ResourceType[name.upper()] == key
319
+ for name in kw_resources
320
+ if name.upper() in ResourceType.__members__
321
+ ):
322
+ del merged_resources[key]
323
+
324
+ merged_resources.update(kw_resources) # type: ignore
325
+
157
326
  return JobRequirements(
158
- resources=self.task_requirements | kw_resources, # type: ignore
159
- replicas=self.replicas,
160
- cluster=cluster or self.cluster,
327
+ resources=merged_resources,
328
+ replicas=replicas if replicas is not None else self.replicas,
329
+ location=location if location is not None else self.location,
330
+ cluster=cluster if cluster is not None else self.cluster,
161
331
  )
162
332
 
163
333
  def __repr__(self) -> str:
@@ -169,7 +339,7 @@ class JobRequirements:
169
339
  args.append(f"{resource.lower()}={value!r}")
170
340
 
171
341
  if self.replicas != 1:
172
- args.append(f"replicas={self.replicas}")
342
+ args.append(f"replicas={self.replicas!r}")
173
343
 
174
344
  if self.cluster is not None:
175
345
  args.append(f"cluster={self.cluster!r}")
xm_slurm/status.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import enum
4
4
  import re
5
- from typing import Sequence
5
+ import typing as tp
6
6
 
7
7
  from xmanager import xm
8
8
 
@@ -151,7 +151,7 @@ class SlurmWorkUnitStatus(xm.ExperimentUnitStatus):
151
151
  """Status of a Slurm experiment job."""
152
152
 
153
153
  @classmethod
154
- def aggregate(cls, states: Sequence[SlurmJobState]) -> "SlurmWorkUnitStatus":
154
+ def aggregate(cls, states: tp.Sequence[SlurmJobState]) -> "SlurmWorkUnitStatus":
155
155
  """Aggregate a sequence of statuses into a single status."""
156
156
  assert len(states) > 0, "Cannot aggregate empty sequence of statuses."
157
157
  max_error_state: SlurmJobState | None = None
@@ -25,6 +25,7 @@ __xm_slurm_wait_for_children() {
25
25
  while [ ${#children[@]} -gt 0 ]; do
26
26
  {% endraw %}
27
27
  echo "INFO: Waiting for child processes to finish..."
28
+ set +e
28
29
  {% if requeue_on_timeout %}
29
30
  # Wait on either one of the child processes or the timeout process.
30
31
  wait -n -p child_pid "${children[@]}" "${timeout_pid}"
@@ -32,6 +33,7 @@ __xm_slurm_wait_for_children() {
32
33
  wait -n -p child_pid "${children[@]}"
33
34
  {% endif %}
34
35
  local child_exit_code=$?
36
+ set -e
35
37
 
36
38
  {% if requeue_on_timeout %}
37
39
  # If the finished process is the watchdog, trigger the timeout handling.
@@ -7,9 +7,9 @@
7
7
 
8
8
  {% block bootstrap %}
9
9
  srun \
10
+ --label \
10
11
  --unbuffered \
11
12
  --kill-on-bad-exit=0 \
12
- --overlap \
13
13
  --export="ALL" \
14
14
  bash <<'SRUN_EOF' &
15
15
  set -Eeuxo pipefail
@@ -29,6 +29,7 @@
29
29
  {% block bootstrap %}
30
30
  {% for job in job_group.jobs.values() +%}
31
31
  srun \
32
+ --label \
32
33
  --unbuffered \
33
34
  --kill-on-bad-exit=0 \
34
35
  --export="ALL" \
@@ -27,6 +27,13 @@
27
27
  {% endblock directives %}
28
28
  set -Eeuxo pipefail
29
29
 
30
+ {% if stdlib %}
31
+ # --- Helper functions ---
32
+ {% for fn in stdlib %}
33
+ {{ fn }}
34
+ {% endfor %}
35
+ {% endif %}
36
+
30
37
  {% block prolog %}
31
38
  {% if cluster.prolog %}
32
39
  {{- cluster.prolog -}}
@@ -52,9 +59,9 @@ export {{ key }}="{{ value }}"
52
59
 
53
60
  {% block bootstrap %}
54
61
  srun \
62
+ --label \
55
63
  --unbuffered \
56
64
  --kill-on-bad-exit=0 \
57
- --overlap \
58
65
  --export="ALL" \
59
66
  bash <<'SRUN_EOF' &
60
67
  set -Eeuxo pipefail
@@ -0,0 +1,62 @@
1
+ # retry: rerun a command if it exits with certain codes
2
+ # Options:
3
+ # -c CODE Retry on this exit code (repeatable).
4
+ # -n N Max attempts (incl. first). Default: unlimited
5
+ # -d SECS Initial delay before first retry. Default: 1
6
+ # -b FACTOR Integer backoff multiplier per retry. Default: 1 (no backoff)
7
+ # -q Quiet (no logs)
8
+ # Usage:
9
+ # retry [-c CODE ...] [-n N] [-d SECS] [-b FACTOR] [-q] -- cmd arg1 arg2 ...
10
+ retry() {
11
+ local -a codes=()
12
+ local -i max=-1 delay=1 backoff=1 quiet=0 status
13
+ local opt OPTIND=1
14
+
15
+ while getopts ":c:n:d:b:q" opt; do
16
+ case "$opt" in
17
+ c) codes+=("$OPTARG") ;;
18
+ n) max=$OPTARG ;;
19
+ d) delay=$OPTARG ;;
20
+ b) backoff=$OPTARG ;;
21
+ q) quiet=1 ;;
22
+ :) printf 'retry: option -%s requires an argument\n' "$OPTARG" >&2; return 2 ;;
23
+ \?) printf 'retry: invalid option -- %s\n' "$OPTARG" >&2; return 2 ;;
24
+ esac
25
+ done
26
+ shift $((OPTIND-1))
27
+ (( $# )) || { printf 'retry: missing command\n' >&2; return 2; }
28
+
29
+ ((${#codes[@]})) || { printf 'retry: no return codes specified\n' >&2; return 2; }
30
+
31
+ for ((attempt=1; ; attempt++)); do
32
+ if "$@"; then # safe with set -e (exception context)
33
+ return 0
34
+ else
35
+ status=$? # capture failing status immediately
36
+ fi
37
+
38
+ # retryable?
39
+ local retryable=0 c
40
+ for c in "${codes[@]}"; do
41
+ (( status == c )) && { retryable=1; break; }
42
+ done
43
+
44
+ # stop if not retryable OR we've just hit the max attempt
45
+ if (( !retryable )) || (( max >= 0 && attempt >= max )); then
46
+ (( quiet )) || {
47
+ if (( attempt > 1 )); then
48
+ printf 'retry: giving up after %d attempts; last exit=%d\n' "$attempt" "$status" >&2
49
+ else
50
+ printf 'retry: command failed; exit=%d\n' "$status" >&2
51
+ fi
52
+ }
53
+ return "$status" # propagate exact code; errexit will catch
54
+ fi
55
+
56
+ (( quiet )) || printf 'retry: attempt %d failed with %d; retrying in %ds...\n' \
57
+ "$attempt" "$status" "$delay" >&2
58
+ sleep "$delay" || : # never trip set -e if sleep errors
59
+ (( delay *= backoff ))
60
+ done
61
+ }
62
+ export -f retry