xmanager-slurm 0.4.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +47 -0
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +189 -0
- xm_slurm/console.py +3 -0
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +67 -0
- xm_slurm/contrib/clusters/drac.py +242 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +215 -0
- xm_slurm/execution.py +995 -0
- xm_slurm/executors.py +210 -0
- xm_slurm/experiment.py +1016 -0
- xm_slurm/experimental/parameter_controller.py +206 -0
- xm_slurm/filesystems.py +129 -0
- xm_slurm/job_blocks.py +21 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +309 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker.py +348 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +56 -0
- xm_slurm/packaging/utils.py +22 -0
- xm_slurm/resources.py +350 -0
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +90 -0
- xm_slurm/status.py +197 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
- xm_slurm/templates/docker/mamba.Dockerfile +29 -0
- xm_slurm/templates/docker/python.Dockerfile +32 -0
- xm_slurm/templates/docker/uv.Dockerfile +38 -0
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
- xm_slurm/templates/slurm/job.bash.j2 +90 -0
- xm_slurm/templates/slurm/library/retry.bash +62 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +196 -0
- xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
- xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
- xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
- xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import typing as tp
|
|
3
|
+
|
|
4
|
+
from xmanager import xm
|
|
5
|
+
|
|
6
|
+
T_co = tp.TypeVar("T_co", covariant=True)
|
|
7
|
+
P = tp.ParamSpec("P")
|
|
8
|
+
ExecutableSpecT = tp.TypeVar("ExecutableSpecT", bound=xm.ExecutableSpec)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclasses.dataclass(frozen=True)
class IndexedContainer(tp.Generic[T_co]):
    """Immutable pairing of a payload with an integer index.

    The index lets callers regroup or reorder payloads and still restore
    the original ordering afterwards.
    """

    # Integer tag carried alongside the payload.
    index: int
    # The wrapped payload.
    value: T_co
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
RegistrationCallable = tp.Callable[
|
|
18
|
+
[tp.Sequence[IndexedContainer[xm.Packageable]]],
|
|
19
|
+
tp.Sequence[IndexedContainer[xm.Executable]],
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
_REGISTRY: dict[tp.Type[xm.ExecutableSpec], RegistrationCallable] = {}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def register(
    *typs: tp.Type[ExecutableSpecT],
) -> tp.Callable[[RegistrationCallable], RegistrationCallable]:
    """Build a decorator that registers a handler for the given spec types.

    Args:
        *typs: Executable spec classes the decorated callable should handle.

    Returns:
        A decorator that stores the callable in the module registry under
        every type in `typs` and hands the callable back unchanged.
    """

    def decorator(
        registration_callable: RegistrationCallable,
    ) -> RegistrationCallable:
        # The registry dict is mutated in place, so no `global` is needed.
        _REGISTRY.update(dict.fromkeys(typs, registration_callable))
        return registration_callable

    return decorator
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def route(
    typ: tp.Type[ExecutableSpecT],
    packageables: tp.Sequence[IndexedContainer[xm.Packageable]],
) -> tp.Sequence[IndexedContainer[xm.Executable]]:
    """Dispatch packageables to the handler registered for `typ`.

    Args:
        typ: Executable spec class used as the registry key.
        packageables: Indexed packageables forwarded to the handler.

    Returns:
        Whatever the registered handler returns.

    Raises:
        KeyError: If nothing has been registered for `typ`.
    """
    handler = _REGISTRY[typ]
    return handler(packageables)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import logging
|
|
3
|
+
import typing as tp
|
|
4
|
+
|
|
5
|
+
from xmanager import xm
|
|
6
|
+
|
|
7
|
+
from xm_slurm.console import console
|
|
8
|
+
from xm_slurm.executors import SlurmSpec
|
|
9
|
+
from xm_slurm.packaging import registry
|
|
10
|
+
|
|
11
|
+
IndexedContainer = registry.IndexedContainer
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def package(
    packageables: tp.Sequence[xm.Packageable],
) -> list[xm.Executable]:
    """Package every packageable and return the executables in input order.

    Packageables are grouped by the concrete type of their
    `executable_spec`; each group is passed to `registry.route`, which
    invokes the handler registered for that type. The results carry the
    original positions, so the returned list lines up with `packageables`.

    Args:
        packageables: Packageables to build. Each one must use an
            `xm_slurm.SlurmSpec` executor spec.

    Returns:
        One executable per packageable, in the same order as the input.

    Raises:
        ValueError: If a packageable's executor spec is not a `SlurmSpec`.
        KeyError: If no handler is registered for an executable spec type
            (raised by `registry.route`).
    """
    # Packageables bucketed by executable spec type, each tagged with its
    # position in the input so the original order can be restored later.
    targets_by_type = collections.defaultdict[
        tp.Type[xm.ExecutableSpec], list[IndexedContainer[xm.Packageable]]
    ](list)

    for index, packageable in enumerate(packageables):
        if not isinstance(packageable.executor_spec, SlurmSpec):
            raise ValueError(
                f"Unsupported executor spec for packageable: {packageable}. "
                "xm_slurm only supports `xm_slurm.SlurmSpec`."
            )
        targets_by_type[type(packageable.executable_spec)].append(
            IndexedContainer[xm.Packageable](index, packageable)
        )

    targets: list[IndexedContainer[xm.Executable]] = []
    # TODO(jfarebro): Could make this async as well...?
    with console.status("[magenta] :package: Packaging executables..."):
        for executable_spec_type, targets_for_type in targets_by_type.items():
            logger.info(
                "Packaging %d %s target%s.",
                len(targets_for_type),
                executable_spec_type.__name__,
                "s" if len(targets_for_type) != 1 else "",
            )
            targets.extend(registry.route(executable_spec_type, targets_for_type))

    # `!= 1` so that zero results also read as a plural ("0 executables").
    console.print(
        f"[magenta]:package: Finished packaging [bold]{len(targets)} executable"
        f"{'s' if len(targets) != 1 else ''}[/bold]."
    )

    # Handlers are expected to return exactly one executable per packageable.
    assert len(targets) == len(packageables), "Number of targets must match packageables"
    targets = sorted(targets, key=lambda t: t.index)
    return [target.value for target in targets]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import logging
|
|
3
|
+
import typing as tp
|
|
4
|
+
|
|
5
|
+
from xmanager import xm
|
|
6
|
+
|
|
7
|
+
from xm_slurm.packaging.registry import IndexedContainer
|
|
8
|
+
|
|
9
|
+
T = tp.TypeVar("T")
|
|
10
|
+
P = tp.ParamSpec("P")
|
|
11
|
+
ReturnT = tp.TypeVar("ReturnT")
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def collect_executors_by_executable(
    targets: tp.Sequence[IndexedContainer[xm.Packageable]],
) -> dict[xm.ExecutableSpec, set[xm.ExecutorSpec]]:
    """Group executor specs by the executable spec they are paired with.

    Args:
        targets: Indexed packageables to inspect.

    Returns:
        A `defaultdict` mapping each executable spec to the set of executor
        specs found alongside it in `targets`.
    """
    grouped = collections.defaultdict(set)
    for container in targets:
        packageable = container.value
        grouped[packageable.executable_spec].add(packageable.executor_spec)
    return grouped
|
xm_slurm/resources.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
import builtins
|
|
2
|
+
import collections.abc
|
|
3
|
+
import datetime as dt
|
|
4
|
+
import enum
|
|
5
|
+
import itertools
|
|
6
|
+
import math
|
|
7
|
+
import re
|
|
8
|
+
import typing as tp
|
|
9
|
+
|
|
10
|
+
from xm_slurm import config, utils
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ResourceType(enum.IntEnum):
    """Integer-valued identifiers for the resources a job can request.

    `RAM` and `DISK` reuse the values of `MEMORY` and `EPHEMERAL_STORAGE`,
    which makes them enum aliases rather than distinct members.
    """

    # --- General-purpose resources ---
    CPU = 1

    MEMORY = 2
    RAM = 2  # alias of MEMORY

    EPHEMERAL_STORAGE = 3
    DISK = 3  # alias of EPHEMERAL_STORAGE

    # --- GPUs: generic entry first, then specific models ---
    GPU = 1000
    RTX8000 = 1001
    P4 = 1010

    P100 = 1011
    P100_16GIB = 1012

    V100 = 1020
    V100_32GIB = 1021

    A100 = 1030
    A100_80GIB = 1031
    A5000 = 1032
    A6000 = 1033

    H100 = 1040
    L40S = 1041
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Every specific GPU model in `ResourceType`; the generic `ResourceType.GPU`
# is deliberately not part of this set.
AcceleratorType = {
    ResourceType.RTX8000,
    ResourceType.P4,
    ResourceType.P100,
    ResourceType.P100_16GIB,
    ResourceType.V100,
    ResourceType.V100_32GIB,
    ResourceType.A100,
    ResourceType.A100_80GIB,
    ResourceType.A5000,
    ResourceType.A6000,
    ResourceType.H100,
    ResourceType.L40S,
}

# Import-time guard: every `ResourceType` member must be either one of the
# basic resources listed here or an entry of `AcceleratorType`.
assert set(ResourceType.__members__.values()) == AcceleratorType | {
    ResourceType.CPU,
    ResourceType.MEMORY,
    ResourceType.DISK,
    ResourceType.GPU,
}, "Resource types are not exhaustive."
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class FeatureType(enum.IntEnum):
    """Integer-valued identifiers for named features."""

    NVIDIA_MIG = 1
    NVIDIA_NVLINK = 2
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class InvalidTopologyError(Exception):
|
|
70
|
+
"""An unrecognized topology has been provided."""
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
TOPOLOGY_REGEX = re.compile(r"^(?P<dims>[\d]+(?:x[\d]+)*)$")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class Topology:
|
|
77
|
+
mesh: str
|
|
78
|
+
dimensions: list[int]
|
|
79
|
+
switches: int | None
|
|
80
|
+
switches_grace_period: dt.timedelta | None
|
|
81
|
+
|
|
82
|
+
def __init__(
|
|
83
|
+
self,
|
|
84
|
+
mesh: str,
|
|
85
|
+
/,
|
|
86
|
+
*,
|
|
87
|
+
switches: int | None = None,
|
|
88
|
+
switches_grace_period: dt.timedelta | None = None,
|
|
89
|
+
):
|
|
90
|
+
mesh_match = TOPOLOGY_REGEX.fullmatch(mesh)
|
|
91
|
+
if not mesh_match:
|
|
92
|
+
raise InvalidTopologyError(f"Invalid topology mesh: {mesh!r}.")
|
|
93
|
+
|
|
94
|
+
self.mesh = mesh
|
|
95
|
+
self.dimensions = list(map(int, mesh_match.group("dims").split("x")))
|
|
96
|
+
if switches is not None:
|
|
97
|
+
assert (
|
|
98
|
+
isinstance(switches, int) and switches > 0
|
|
99
|
+
), "Switches must be a positive integer."
|
|
100
|
+
self.switches = switches
|
|
101
|
+
if switches_grace_period is not None:
|
|
102
|
+
assert isinstance(
|
|
103
|
+
switches_grace_period, dt.timedelta
|
|
104
|
+
), "Switches grace period must be a `datetime.timedelta`."
|
|
105
|
+
self.switches_grace_period = switches_grace_period
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def chip_count(self) -> int:
|
|
109
|
+
return math.prod(self.dimensions)
|
|
110
|
+
|
|
111
|
+
@property
|
|
112
|
+
def ndim(self) -> int:
|
|
113
|
+
return len(self.dimensions)
|
|
114
|
+
|
|
115
|
+
def __eq__(self, other: object) -> bool:
|
|
116
|
+
if not isinstance(other, Topology):
|
|
117
|
+
return False
|
|
118
|
+
return (
|
|
119
|
+
self.mesh == other.mesh
|
|
120
|
+
and self.switches == other.switches
|
|
121
|
+
and self.switches_grace_period == other.switches_grace_period
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
def __hash__(self) -> int:
|
|
125
|
+
return hash((self.mesh, self.switches, self.switches_grace_period))
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
ResourceQuantity = int | float | Topology
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _parse_resource_quantity(
    resource_name: ResourceType | str, value: ResourceQuantity
) -> tuple[float, Topology | None]:
    """Normalize a resource value into a numeric quantity and a topology.

    Args:
        resource_name: Resource the value belongs to; only used to build
            error messages.
        value: A `Topology`, a number, or a string holding either a mesh
            (e.g. ``"2x4"``) or a number.

    Returns:
        ``(quantity, topology)``. For topologies the quantity is the chip
        count; for plain numbers the topology is None. Numeric strings with
        an integral value are returned as `int`.

    Raises:
        ValueError: If a string is neither a mesh nor a number, or if the
            value has an unsupported type.
    """
    label = resource_name.name if isinstance(resource_name, ResourceType) else resource_name

    if isinstance(value, Topology):
        return value.chip_count, value

    if isinstance(value, builtins.str):
        # Only strings containing "x" are candidates for a mesh.
        if "x" in value and TOPOLOGY_REGEX.fullmatch(value) is not None:
            parsed = Topology(value)
            return parsed.chip_count, parsed
        try:
            number = float(value)
        except ValueError as e:
            raise ValueError(
                f"Couldn't parse resource quantity for {label}. "
                f"{value!r} was given."
            ) from e
        return (int(number) if number.is_integer() else number), None

    if isinstance(value, (int, float)):
        return value, None

    raise ValueError(f"Invalid resource quantity: {value!r} for {label!r}.")
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class JobRequirements:
|
|
160
|
+
replicas: int
|
|
161
|
+
location: str | None
|
|
162
|
+
accelerator: ResourceType | None
|
|
163
|
+
topology: Topology | None
|
|
164
|
+
cluster: config.SlurmClusterConfig
|
|
165
|
+
|
|
166
|
+
def __init__(
|
|
167
|
+
self,
|
|
168
|
+
*,
|
|
169
|
+
resources: tp.Mapping[ResourceType | str, ResourceQuantity] | None = None,
|
|
170
|
+
replicas: int | None = None,
|
|
171
|
+
location: str | tp.Iterable[str] | None = None,
|
|
172
|
+
cluster: config.SlurmClusterConfig,
|
|
173
|
+
**kw_resources: ResourceQuantity,
|
|
174
|
+
):
|
|
175
|
+
if isinstance(location, collections.abc.Iterable) and not isinstance(location, str):
|
|
176
|
+
location = ",".join(location)
|
|
177
|
+
self.location = location
|
|
178
|
+
|
|
179
|
+
self.accelerator = None
|
|
180
|
+
self.topology = None
|
|
181
|
+
self.cluster = cluster
|
|
182
|
+
|
|
183
|
+
if resources is None:
|
|
184
|
+
resources = {}
|
|
185
|
+
|
|
186
|
+
self.task_requirements: dict[ResourceType | str, ResourceQuantity] = {}
|
|
187
|
+
for resource_name, value in itertools.chain(resources.items(), kw_resources.items()):
|
|
188
|
+
quantity, topology = _parse_resource_quantity(resource_name, value)
|
|
189
|
+
match resource_name:
|
|
190
|
+
case str() if resource_name.upper() in ResourceType.__members__:
|
|
191
|
+
resource = ResourceType[resource_name.upper()]
|
|
192
|
+
case ResourceType():
|
|
193
|
+
resource = resource_name
|
|
194
|
+
case str():
|
|
195
|
+
resource = resource_name
|
|
196
|
+
|
|
197
|
+
if (
|
|
198
|
+
resource in AcceleratorType
|
|
199
|
+
or resource == ResourceType.GPU
|
|
200
|
+
or (isinstance(resource, str) and resource.startswith("gpu"))
|
|
201
|
+
):
|
|
202
|
+
if self.accelerator is not None:
|
|
203
|
+
raise ValueError("Accelerator already set.")
|
|
204
|
+
self.accelerator = resource # type: ignore
|
|
205
|
+
self.topology = topology
|
|
206
|
+
elif topology is not None:
|
|
207
|
+
raise ValueError(
|
|
208
|
+
f"A topology was specified for a non-accelerator resource: {resource_name!r}."
|
|
209
|
+
)
|
|
210
|
+
|
|
211
|
+
if resource in self.task_requirements:
|
|
212
|
+
raise ValueError(f"{resource} has been specified twice.")
|
|
213
|
+
self.task_requirements[resource] = quantity
|
|
214
|
+
|
|
215
|
+
if self.topology is not None and self.topology.ndim > 2:
|
|
216
|
+
raise ValueError("Topologies with more than 2 dimensions are not supported.")
|
|
217
|
+
|
|
218
|
+
if (
|
|
219
|
+
self.accelerator is not None
|
|
220
|
+
and self.topology is not None
|
|
221
|
+
and len(self.topology.dimensions) == 2
|
|
222
|
+
):
|
|
223
|
+
if replicas is not None and replicas != self.topology.dimensions[1]:
|
|
224
|
+
raise ValueError(
|
|
225
|
+
f"For multihost GPUs with topology {self.topology}, replicas should"
|
|
226
|
+
f"be either None or {self.topology.dimensions[1]}. Found: "
|
|
227
|
+
f"{replicas}"
|
|
228
|
+
)
|
|
229
|
+
replicas = self.topology.dimensions[1]
|
|
230
|
+
|
|
231
|
+
if replicas is not None and replicas <= 0:
|
|
232
|
+
raise ValueError(f"Replicas must be a positive integer, got {replicas!r}")
|
|
233
|
+
self.replicas = replicas or 1
|
|
234
|
+
|
|
235
|
+
def batch_directives(self) -> list[str]:
|
|
236
|
+
directives = []
|
|
237
|
+
|
|
238
|
+
for resource, value in self.task_requirements.items():
|
|
239
|
+
match resource:
|
|
240
|
+
case ResourceType.EPHEMERAL_STORAGE | ResourceType.DISK:
|
|
241
|
+
assert isinstance(
|
|
242
|
+
value, int
|
|
243
|
+
), f"Disk space must be an integer, got {type(value)!r}"
|
|
244
|
+
directives.append(f"--tmp={math.ceil(value / 2**20)}M")
|
|
245
|
+
case ResourceType.MEMORY | ResourceType.RAM:
|
|
246
|
+
num_cpus = self.task_requirements.get(ResourceType.CPU, 1)
|
|
247
|
+
assert isinstance(
|
|
248
|
+
value, (int, float)
|
|
249
|
+
), f"Memory must be an integer or float, got {type(value)!r}"
|
|
250
|
+
assert isinstance(
|
|
251
|
+
num_cpus, int
|
|
252
|
+
), f"CPU must be an integer, got {type(num_cpus)!r}"
|
|
253
|
+
directives.append(f"--mem-per-cpu={math.ceil(value / num_cpus / 2**20)}M")
|
|
254
|
+
case ResourceType.CPU:
|
|
255
|
+
assert isinstance(value, int), f"CPU must be an integer, got {type(value)!r}"
|
|
256
|
+
directives.append(f"--cpus-per-task={value}")
|
|
257
|
+
case ResourceType.GPU:
|
|
258
|
+
assert isinstance(value, int), f"GPU must be an integer, got {type(value)!r}"
|
|
259
|
+
directives.append(f"--gpus={value}")
|
|
260
|
+
case ResourceType() if resource in AcceleratorType:
|
|
261
|
+
assert isinstance(
|
|
262
|
+
value, int
|
|
263
|
+
), f"Accelerator must be an integer, got {type(value)!r}"
|
|
264
|
+
resource_type = self.cluster.resources.get(resource, None)
|
|
265
|
+
if resource_type is None:
|
|
266
|
+
raise ValueError(
|
|
267
|
+
f"Cluster {self.cluster.name} does not map resource type {resource!r}."
|
|
268
|
+
)
|
|
269
|
+
directives.append(f"--gpus={resource_type}:{value}")
|
|
270
|
+
case str():
|
|
271
|
+
directives.append(f"--gres={resource}:{value}")
|
|
272
|
+
|
|
273
|
+
if self.location:
|
|
274
|
+
assert isinstance(
|
|
275
|
+
self.location, str
|
|
276
|
+
), f"Location must be a string, got {type(self.location)!r}"
|
|
277
|
+
directives.append(f"--nodelist={self.location}")
|
|
278
|
+
|
|
279
|
+
assert (
|
|
280
|
+
isinstance(self.replicas, int) and self.replicas > 0
|
|
281
|
+
), f"Replicas must be a positive integer, got {self.replicas!r}"
|
|
282
|
+
directives.append(f"--ntasks={self.replicas}")
|
|
283
|
+
|
|
284
|
+
if self.topology is not None:
|
|
285
|
+
assert self.accelerator is not None, "Accelerator must be set."
|
|
286
|
+
match self.accelerator:
|
|
287
|
+
case ResourceType.GPU:
|
|
288
|
+
directives.append(f"--gpus-per-task={self.topology.dimensions[0]}")
|
|
289
|
+
case ResourceType() if self.accelerator in AcceleratorType:
|
|
290
|
+
resource_type = self.cluster.resources[self.accelerator]
|
|
291
|
+
directives.append(
|
|
292
|
+
f"--gpus-per-task={resource_type}:{self.topology.dimensions[0]}"
|
|
293
|
+
)
|
|
294
|
+
|
|
295
|
+
if self.topology.switches is not None:
|
|
296
|
+
switches_timeout = (
|
|
297
|
+
f"@{utils.timestr_from_timedelta(self.topology.switches_grace_period)}"
|
|
298
|
+
if self.topology.switches_grace_period is not None
|
|
299
|
+
else ""
|
|
300
|
+
)
|
|
301
|
+
directives.append(f"--switches={self.topology.switches}{switches_timeout}")
|
|
302
|
+
|
|
303
|
+
return directives
|
|
304
|
+
|
|
305
|
+
def step_directives(self) -> list[str]:
|
|
306
|
+
return []
|
|
307
|
+
|
|
308
|
+
def replace(
|
|
309
|
+
self,
|
|
310
|
+
replicas: int | None = None,
|
|
311
|
+
location: str | None = None,
|
|
312
|
+
cluster: config.SlurmClusterConfig | None = None,
|
|
313
|
+
**kw_resources: ResourceQuantity,
|
|
314
|
+
) -> "JobRequirements":
|
|
315
|
+
# Merge kw_resources into existing task_requirements, removing conflicting enum keys
|
|
316
|
+
merged_resources = dict(self.task_requirements)
|
|
317
|
+
|
|
318
|
+
# Remove ResourceType keys that will be overridden by string keys in kw_resources
|
|
319
|
+
for key in list(merged_resources.keys()):
|
|
320
|
+
if isinstance(key, ResourceType) and any(
|
|
321
|
+
ResourceType[name.upper()] == key
|
|
322
|
+
for name in kw_resources
|
|
323
|
+
if name.upper() in ResourceType.__members__
|
|
324
|
+
):
|
|
325
|
+
del merged_resources[key]
|
|
326
|
+
|
|
327
|
+
merged_resources.update(kw_resources) # type: ignore
|
|
328
|
+
|
|
329
|
+
return JobRequirements(
|
|
330
|
+
resources=merged_resources,
|
|
331
|
+
replicas=replicas if replicas is not None else self.replicas,
|
|
332
|
+
location=location if location is not None else self.location,
|
|
333
|
+
cluster=cluster if cluster is not None else self.cluster,
|
|
334
|
+
)
|
|
335
|
+
|
|
336
|
+
def __repr__(self) -> str:
|
|
337
|
+
args = []
|
|
338
|
+
|
|
339
|
+
for resource, value in self.task_requirements.items():
|
|
340
|
+
if isinstance(resource, ResourceType):
|
|
341
|
+
resource = resource.name
|
|
342
|
+
args.append(f"{resource.lower()}={value!r}")
|
|
343
|
+
|
|
344
|
+
if self.replicas != 1:
|
|
345
|
+
args.append(f"replicas={self.replicas!r}")
|
|
346
|
+
|
|
347
|
+
if self.cluster is not None:
|
|
348
|
+
args.append(f"cluster={self.cluster!r}")
|
|
349
|
+
|
|
350
|
+
return f'xm_slurm.JobRequirements({", ".join(args)})'
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import logging
|
|
3
|
+
import zlib
|
|
4
|
+
|
|
5
|
+
import cloudpickle
|
|
6
|
+
from absl import app, flags
|
|
7
|
+
from xmanager import xm
|
|
8
|
+
|
|
9
|
+
CLOUDPICKLED_FN = flags.DEFINE_string(
|
|
10
|
+
"cloudpickled_fn", None, "Base64 encoded cloudpickled function", required=True
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@xm.run_in_asyncio_loop
async def main(argv):
    """Decode the function passed via `--cloudpickled_fn` and await it.

    The flag value is URL-safe base64 text wrapping zlib-compressed
    cloudpickle bytes.
    """
    del argv

    logger.info("Loading cloudpickled function...")
    compressed = base64.urlsafe_b64decode(CLOUDPICKLED_FN.value)
    payload = zlib.decompress(compressed)
    # NOTE(review): unpickling runs arbitrary code embedded in the payload;
    # the flag value must only ever come from a trusted launcher.
    function = cloudpickle.loads(payload)
    logger.info("Running cloudpickled function...")
    await function()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
if __name__ == "__main__":
|
|
28
|
+
app.run(main)
|
xm_slurm/scripts/cli.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from xmanager import xm
|
|
5
|
+
|
|
6
|
+
import xm_slurm
|
|
7
|
+
from xm_slurm.console import console
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
async def logs(
    experiment_id: int,
    *,
    wid: int | None,
    identity: str | None,
    follow: bool = True,
    num_lines: int = 10,
    block_size: int = 1024,
):
    """Print the logs of one work unit of an experiment to the console.

    Args:
        experiment_id: Experiment to look up.
        wid: Work unit ID; takes precedence over `identity` when given.
        identity: Work unit identity, used when `wid` is None.
        follow: Forwarded to the work unit's `logs()` call.
        num_lines: Forwarded to the work unit's `logs()` call.
        block_size: Forwarded to the work unit's `logs()` call.

    Raises:
        ValueError: If both `wid` and `identity` are None.
        SystemExit: With status 1 if the work unit cannot be found.
    """
    xp = xm_slurm.get_experiment(experiment_id)

    if wid is None and identity is None:
        raise ValueError("Must specify either wid or identity.")

    if wid is not None:
        wus = xp.work_units()
        if wid not in wus:
            console.print(
                f"[red]Work Unit ID {wid} not found for experiment {experiment_id} with {len(wus)} work units.[/red]"
            )
            sys.exit(1)
        wu = wus[wid]
    else:
        wu = xp._get_work_unit_by_identity(identity)
        if wu is None:
            console.print(f"[red]Work Unit with identity {identity} not found.[/red]")
            sys.exit(1)
    assert wu is not None

    with console.status("Waiting for logs...") as status:
        spinner_active = True
        log_stream = wu.logs(num_lines=num_lines, block_size=block_size, wait=True, follow=follow)
        async for log in log_stream:
            # Stop the spinner once, when the first chunk arrives.
            if spinner_active:
                status.stop()
                spinner_active = False
            console.print(log, end="\n")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@xm.run_in_asyncio_loop
|
|
50
|
+
async def main():
|
|
51
|
+
parser = argparse.ArgumentParser(description="XManager.")
|
|
52
|
+
subparsers = parser.add_subparsers(dest="subcommand", required=True)
|
|
53
|
+
|
|
54
|
+
logs_parser = subparsers.add_parser("logs", help="Display logs for a specific experiment.")
|
|
55
|
+
logs_parser.add_argument("xid", type=int, help="Experiment ID.")
|
|
56
|
+
|
|
57
|
+
# Create a mutually exclusive group for wid and identity
|
|
58
|
+
group = logs_parser.add_mutually_exclusive_group()
|
|
59
|
+
group.add_argument("--wid", type=int, help="Work Unit ID.")
|
|
60
|
+
group.add_argument("--identity", type=str, help="Work Unit identity.")
|
|
61
|
+
|
|
62
|
+
logs_parser.add_argument(
|
|
63
|
+
"-n",
|
|
64
|
+
"--n-lines",
|
|
65
|
+
type=int,
|
|
66
|
+
default=50,
|
|
67
|
+
help="Number of lines to display from the end of the log file.",
|
|
68
|
+
)
|
|
69
|
+
logs_parser.add_argument(
|
|
70
|
+
"-f",
|
|
71
|
+
"--follow",
|
|
72
|
+
default=True,
|
|
73
|
+
action="store_true",
|
|
74
|
+
help="Follow the log file as it is updated.",
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
args = parser.parse_args()
|
|
78
|
+
match args.subcommand:
|
|
79
|
+
case "logs":
|
|
80
|
+
await logs(
|
|
81
|
+
args.xid,
|
|
82
|
+
wid=args.wid,
|
|
83
|
+
identity=args.identity,
|
|
84
|
+
follow=args.follow,
|
|
85
|
+
num_lines=args.n_lines,
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
if __name__ == "__main__":
|
|
90
|
+
main() # type: ignore
|