xmanager-slurm 0.4.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xm_slurm/__init__.py +47 -0
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/batching.py +139 -0
- xm_slurm/config.py +189 -0
- xm_slurm/console.py +3 -0
- xm_slurm/constants.py +19 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +67 -0
- xm_slurm/contrib/clusters/drac.py +242 -0
- xm_slurm/dependencies.py +171 -0
- xm_slurm/executables.py +215 -0
- xm_slurm/execution.py +995 -0
- xm_slurm/executors.py +210 -0
- xm_slurm/experiment.py +1016 -0
- xm_slurm/experimental/parameter_controller.py +206 -0
- xm_slurm/filesystems.py +129 -0
- xm_slurm/job_blocks.py +21 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +309 -0
- xm_slurm/packaging/__init__.py +8 -0
- xm_slurm/packaging/docker.py +348 -0
- xm_slurm/packaging/registry.py +45 -0
- xm_slurm/packaging/router.py +56 -0
- xm_slurm/packaging/utils.py +22 -0
- xm_slurm/resources.py +350 -0
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/scripts/cli.py +90 -0
- xm_slurm/status.py +197 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
- xm_slurm/templates/docker/mamba.Dockerfile +29 -0
- xm_slurm/templates/docker/python.Dockerfile +32 -0
- xm_slurm/templates/docker/uv.Dockerfile +38 -0
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
- xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
- xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
- xm_slurm/templates/slurm/job.bash.j2 +90 -0
- xm_slurm/templates/slurm/library/retry.bash +62 -0
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
- xm_slurm/types.py +23 -0
- xm_slurm/utils.py +196 -0
- xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
- xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
- xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
- xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
- xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import base64
|
|
3
|
+
import dataclasses
|
|
4
|
+
import enum
|
|
5
|
+
import functools
|
|
6
|
+
import logging
|
|
7
|
+
import typing as tp
|
|
8
|
+
import zlib
|
|
9
|
+
|
|
10
|
+
import backoff
|
|
11
|
+
import cloudpickle
|
|
12
|
+
from xmanager import xm
|
|
13
|
+
|
|
14
|
+
import xm_slurm
|
|
15
|
+
from xm_slurm import job_blocks, status
|
|
16
|
+
from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
|
|
17
|
+
|
|
18
|
+
# Type variables describing the decorated controller function: ``P`` captures the
# parameters that follow the leading ``SlurmExperiment`` argument and ``T`` its result.
P = tp.ParamSpec("P")
T = tp.TypeVar("T")

# Module-level logger used by the controller monitor.
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def _monitor_parameter_controller(
|
|
25
|
+
aux_unit: SlurmAuxiliaryUnit,
|
|
26
|
+
local_parameter_controller_coro: tp.Coroutine[None, None, T],
|
|
27
|
+
*,
|
|
28
|
+
poll_interval: float = 30.0,
|
|
29
|
+
) -> None:
|
|
30
|
+
local_controller_finished = asyncio.Event()
|
|
31
|
+
local_parameter_controller = asyncio.create_task(local_parameter_controller_coro)
|
|
32
|
+
|
|
33
|
+
@local_parameter_controller.add_done_callback
|
|
34
|
+
def _(future: asyncio.Task[T]) -> None:
|
|
35
|
+
try:
|
|
36
|
+
_ = future.result()
|
|
37
|
+
except asyncio.CancelledError:
|
|
38
|
+
logger.info("Local parameter controller was cancelled, resuming on remote controller.")
|
|
39
|
+
pass
|
|
40
|
+
except Exception:
|
|
41
|
+
logger.error("Local parameter controller failed, stopping remote controller.")
|
|
42
|
+
aux_unit.stop(
|
|
43
|
+
mark_as_failed=True, mark_as_completed=False, message="Local controller failed."
|
|
44
|
+
)
|
|
45
|
+
raise
|
|
46
|
+
else:
|
|
47
|
+
logger.info(
|
|
48
|
+
"Local parameter controller finished before remote controller started, "
|
|
49
|
+
"stopping remote controller."
|
|
50
|
+
)
|
|
51
|
+
local_controller_finished.set()
|
|
52
|
+
aux_unit.stop(mark_as_completed=True, message="Local parameter controller finished.")
|
|
53
|
+
|
|
54
|
+
@backoff.on_predicate(
|
|
55
|
+
backoff.constant,
|
|
56
|
+
lambda aux_unit_status: aux_unit_status is status.SlurmWorkUnitStatusEnum.PENDING,
|
|
57
|
+
jitter=None,
|
|
58
|
+
interval=poll_interval,
|
|
59
|
+
)
|
|
60
|
+
async def wait_for_remote_controller() -> status.SlurmWorkUnitStatusEnum:
|
|
61
|
+
logger.info("Waiting for remote parameter controller to start.")
|
|
62
|
+
if local_controller_finished.is_set():
|
|
63
|
+
return status.SlurmWorkUnitStatusEnum.COMPLETED
|
|
64
|
+
return (await aux_unit.get_status()).status
|
|
65
|
+
|
|
66
|
+
logger.info("Monitoring remote parameter controller.")
|
|
67
|
+
# TODO(jfarebro): make get_status() more resiliant to errors when initially scheduling.
|
|
68
|
+
# We run into issues if we call get_status() too quickly when Slurm hasn't ingested the job.
|
|
69
|
+
await asyncio.sleep(15)
|
|
70
|
+
match await wait_for_remote_controller():
|
|
71
|
+
case status.SlurmWorkUnitStatusEnum.RUNNING:
|
|
72
|
+
logger.info("Remote parameter controller started.")
|
|
73
|
+
local_parameter_controller.cancel("Remote parameter controller started.")
|
|
74
|
+
case status.SlurmWorkUnitStatusEnum.COMPLETED:
|
|
75
|
+
if local_parameter_controller.done():
|
|
76
|
+
logger.info("Local parameter controller finished, stopping remote controller.")
|
|
77
|
+
aux_unit.stop(
|
|
78
|
+
mark_as_completed=True, message="Local parameter controller finished."
|
|
79
|
+
)
|
|
80
|
+
else:
|
|
81
|
+
logger.info("Remote parameter controller finished, stopping local controller.")
|
|
82
|
+
local_parameter_controller.cancel()
|
|
83
|
+
case status.SlurmWorkUnitStatusEnum.FAILED:
|
|
84
|
+
logger.error("Remote parameter controller failed, stopping local controller.")
|
|
85
|
+
local_parameter_controller.cancel()
|
|
86
|
+
case status.SlurmWorkUnitStatusEnum.CANCELLED:
|
|
87
|
+
logger.info("Remote parameter controller was cancelled, stopping local controller.")
|
|
88
|
+
local_parameter_controller.cancel()
|
|
89
|
+
case status.SlurmWorkUnitStatusEnum.PENDING:
|
|
90
|
+
raise RuntimeError("Remote parameter controller is still pending, invalid state.")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ParameterControllerMode(enum.Enum):
    """Where ``parameter_controller`` runs the controller logic.

    Members:
        AUTO: Also run the controller locally and hand off to the remote job once
            it starts running.
        REMOTE_ONLY: Only launch the remote controller job.
    """

    # Explicit values are identical to what ``enum.auto()`` assigns on a plain Enum.
    AUTO = 1
    REMOTE_ONLY = 2
    # TODO(jfarebro): is it possible to get LOCAL_ONLY?
    # We'd need to have a dummy job type as we need to return a JobType?
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def parameter_controller(
    *,
    executable: xm.Executable,
    executor: xm.Executor,
    controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
    controller_name: str = "parameter_controller",
    controller_args: xm.UserArgs | None = None,
    controller_env_vars: tp.Mapping[str, str] | None = None,
) -> tp.Callable[
    [
        tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
        | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
    ],
    tp.Callable[P, xm.AuxiliaryUnitJob],
]:
    """Converts a function to a controller which can be added to an experiment.

    The decorated function receives the experiment as its first argument and may be
    sync or async. Calling the decorated function returns an ``xm.AuxiliaryUnitJob``.
    Its job generator launches the function as a job on ``executor``. The function
    and its bound arguments are cloudpickled and run through the
    ``xm_slurm.scripts._cloudpickle`` module of ``executable``.

    Args:
        executable: An executable that has a Python entrypoint with all the
          necessary dependencies.
        executor: The executor to launch the controller job on.
        controller_mode: ``AUTO`` also runs the controller locally until the remote
          job starts running. ``REMOTE_ONLY`` only launches the remote job.
        controller_name: Name of the parameter controller job.
        controller_args: Extra arguments appended to the controller job's command
          line, for the XM client running inside the parameter controller job.
        controller_env_vars: Mapping of env variable names and values to be passed
          to the parameter controller job.

    Returns:
        A decorator to be applied to the function.
    """

    def decorator(
        f: tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
        | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
    ) -> tp.Callable[P, xm.AuxiliaryUnitJob]:
        @functools.wraps(f)
        def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
            # Modify the function to read the experiment from the API so that it can be pickled.

            async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
                # Only the id is captured by remote_controller; the experiment object
                # itself is never pickled.
                experiment_id = aux_unit.experiment.experiment_id

                # Runs `f` in this process against the in-memory experiment. A sync `f`
                # is called directly, i.e. on the event loop's thread.
                async def local_controller(
                    *args: P.args, **kwargs: P.kwargs
                ) -> T:
                    if asyncio.iscoroutinefunction(f):
                        return await f(aux_unit.experiment, *args, **kwargs)
                    else:
                        return f(aux_unit.experiment, *args, **kwargs)

                # Runs `f` inside the remote job after re-fetching the experiment by id.
                async def remote_controller(
                    *args: P.args, **kwargs: P.kwargs
                ) -> T:
                    async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
                        if asyncio.iscoroutinefunction(f):
                            return await f(exp, *args, **kwargs)
                        else:
                            return f(exp, *args, **kwargs)

                # cloudpickle -> zlib -> URL-safe base64, so the callable can be passed
                # as an ASCII command-line flag.
                remote_controller_serialized = base64.urlsafe_b64encode(
                    zlib.compress(
                        cloudpickle.dumps(
                            functools.partial(remote_controller, *args, **kwargs),
                        )
                    )
                )

                # The executable's own args are replaced by the `_cloudpickle`
                # entrypoint, the serialized callable and `controller_args`.
                # NOTE(review): `env_vars` also replaces (does not merge with) any env
                # vars already set on `executable` - confirm this is intended.
                parameter_controller_executable = dataclasses.replace(
                    executable,
                    args=xm.merge_args(
                        job_blocks.get_args_for_python_entrypoint(
                            xm.ModuleName("xm_slurm.scripts._cloudpickle")
                        ),
                        xm.SequentialArgs.from_collection({
                            "cloudpickled_fn": remote_controller_serialized.decode("ascii"),
                        }),
                        xm.SequentialArgs.from_collection(controller_args),
                    ),
                    env_vars=controller_env_vars or {},
                )

                await aux_unit.add(
                    xm.Job(
                        executor=executor,
                        executable=parameter_controller_executable,
                        name=controller_name,
                    )
                )

                # Launch local parameter controller and monitor for when it starts running
                # so we can kill the local controller.
                if controller_mode is ParameterControllerMode.AUTO:
                    aux_unit._create_task(
                        _monitor_parameter_controller(aux_unit, local_controller(*args, **kwargs))
                    )

            return xm.AuxiliaryUnitJob(
                job_generator,
                importance=xm.Importance.HIGH,
                termination_delay_secs=0,  # TODO: add support for termination delay.?
            )

        return make_controller

    return decorator
|
xm_slurm/filesystems.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import asyncio
|
|
3
|
+
import os
|
|
4
|
+
import typing as tp
|
|
5
|
+
|
|
6
|
+
import aiofile
|
|
7
|
+
import asyncssh
|
|
8
|
+
import typing_extensions as tpe
|
|
9
|
+
import wrapt
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AsyncFileIO(tp.Protocol):
    """Structural type for an asynchronous file handle used as an async context manager."""

    async def write(self, buffer: str | bytes, /) -> int:
        """Write ``buffer`` to the file and return an integer count."""

    async def read(self, size: int = -1, /) -> bytes:
        """Read up to ``size`` from the file."""

    async def seek(self, offset: int, /) -> int:
        """Move to ``offset`` and return the resulting position."""

    async def tell(self) -> int:
        """Return the current position."""

    async def __aenter__(self) -> tpe.Self:
        """Enter the async context and return the usable handle."""

    async def __aexit__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
        """Exit the async context."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class AsyncFileSystem(tp.Protocol):
    """Structural interface for an asynchronous filesystem."""

    async def open(
        self,
        path: os.PathLike[str] | str,
        mode: tp.Literal["r", "w", "rb", "wb"],
        *,
        encoding: str = "utf-8",
    ) -> AsyncFileIO:
        """Open ``path`` with ``mode`` and return an async file handle."""

    async def read(
        self, path: os.PathLike[str] | str, *, size: int = -1, offset: int = 0
    ) -> bytes:
        """Read ``size`` bytes of ``path`` starting at ``offset``."""

    async def write(
        self, path: os.PathLike[str] | str, data: str | bytes, *, offset: int = 0
    ) -> None:
        """Write ``data`` to ``path`` at ``offset``."""

    async def exists(self, path: os.PathLike[str] | str) -> bool:
        """Return whether ``path`` exists."""

    async def size(self, path: os.PathLike[str] | str) -> int | None:
        """Return the size of ``path``, or None if the backend cannot tell."""

    async def makedirs(
        self, path: os.PathLike[str] | str, mode: int = 0o777, exist_ok: bool = False
    ) -> None:
        """Create directory ``path`` with permission bits ``mode``."""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class AbstractAsyncFileSystem(AsyncFileSystem, abc.ABC):
    """Base filesystem that derives ``read``/``write`` from an abstract ``open``."""

    @abc.abstractmethod
    async def open(
        self,
        path: os.PathLike[str] | str,
        mode: tp.Literal["r", "w", "rb", "wb"],
        *,
        encoding: str = "utf-8",
    ) -> AsyncFileIO: ...

    async def read(self, path: os.PathLike[str] | str, *, size: int = -1, offset: int = 0) -> bytes:
        """Read ``size`` bytes from ``path`` starting at ``offset`` (binary mode)."""
        async with await self.open(path, "rb") as f:
            await f.seek(offset)
            return await f.read(size)

    async def write(
        self, path: os.PathLike[str] | str, data: str | bytes, *, offset: int = 0
    ) -> None:
        """Write ``data`` to ``path`` starting at ``offset``.

        The file is opened with mode ``"wb"``, which conventionally truncates any
        existing contents before writing. Because the handle is binary, ``str`` data
        is encoded as UTF-8 (the default ``encoding`` of ``open``) first.
        """
        if isinstance(data, str):
            # Binary-mode handles expect bytes; passing str would fail on write.
            data = data.encode("utf-8")
        async with await self.open(path, "wb") as f:
            await f.seek(offset)
            await f.write(data)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class AsyncLocalFileIO(wrapt.ObjectProxy):
    """Proxy over a local (aiofile) handle exposing ``seek``/``tell`` as coroutines.

    Attributes not defined here are forwarded to the wrapped object by
    ``wrapt.ObjectProxy``. ``seek`` and ``tell`` run the wrapped object's methods
    in a worker thread via ``asyncio.to_thread``.
    """

    async def tell(self) -> int:
        """Return the wrapped handle's current position."""
        return await asyncio.to_thread(self.__wrapped__.tell)

    async def seek(self, offset: int, /) -> int:
        """Move the wrapped handle to ``offset`` and return the new position."""
        await asyncio.to_thread(self.__wrapped__.seek, offset)
        return await self.tell()

    async def __aenter__(self) -> tpe.Self:
        entered = await self.__wrapped__.__aenter__()
        # Re-wrap so the object bound by ``async with`` still has async seek/tell.
        return AsyncLocalFileIO(entered)

    async def __aexit__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
        wrapped = self.__wrapped__
        return await wrapped.__aexit__(*args, **kwargs)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class AsyncLocalFileSystem(AbstractAsyncFileSystem):
    """Filesystem on local disk.

    Handles come from ``aiofile``, and blocking ``os`` calls run in a worker thread.
    """

    def __init__(self):
        """No state to set up."""

    async def open(
        self,
        path: os.PathLike[str] | str,
        mode: tp.Literal["r", "w", "rb", "wb"],
        *,
        encoding: str = "utf-8",
    ) -> AsyncFileIO:
        handle = aiofile.async_open(os.fspath(path), mode=mode, encoding=encoding)
        return AsyncLocalFileIO(handle)  # type: ignore

    async def exists(self, path: os.PathLike[str] | str) -> bool:
        local_path = os.fspath(path)
        return await asyncio.to_thread(os.path.exists, local_path)

    async def size(self, path: os.PathLike[str] | str) -> int | None:
        local_path = os.fspath(path)
        return await asyncio.to_thread(os.path.getsize, local_path)

    async def makedirs(
        self, path: os.PathLike[str] | str, mode: int = 0o777, exist_ok: bool = False
    ) -> None:
        local_path = os.fspath(path)
        await asyncio.to_thread(os.makedirs, local_path, mode=mode, exist_ok=exist_ok)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class AsyncSSHFileSystem(AbstractAsyncFileSystem):
    """Filesystem backed by an ``asyncssh`` SFTP client."""

    def __init__(self, client: asyncssh.SFTPClient):
        # SFTP session through which all operations are issued.
        self._client = client

    async def open(
        self,
        path: os.PathLike[str] | str,
        mode: tp.Literal["r", "w", "rb", "wb"],
        *,
        encoding: str = "utf-8",
    ) -> AsyncFileIO:
        remote_path = os.fspath(path)
        return await self._client.open(remote_path, mode, encoding=encoding)  # type: ignore

    async def exists(self, path: os.PathLike[str] | str) -> bool:
        remote_path = os.fspath(path)
        return await self._client.exists(remote_path)

    async def size(self, path: os.PathLike[str] | str) -> int | None:
        stat_result = await self._client.stat(os.fspath(path))
        return stat_result.size

    async def makedirs(
        self, path: os.PathLike[str] | str, mode: int = 0o777, exist_ok: bool = False
    ) -> None:
        remote_path = os.fspath(path)
        # The permission bits are conveyed as SFTP attributes.
        permissions = asyncssh.SFTPAttrs(permissions=mode)
        return await self._client.makedirs(remote_path, attrs=permissions, exist_ok=exist_ok)
|
xm_slurm/job_blocks.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import typing as tp
|
|
2
|
+
|
|
3
|
+
from xmanager import xm
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Optional job settings: command-line ``args`` and ``env_vars``.
# ``total=False`` makes every key optional.
JobArgs = tp.TypedDict(
    "JobArgs",
    {
        "args": xm.UserArgs,
        "env_vars": tp.Mapping[str, str],
    },
    total=False,
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_args_for_python_entrypoint(
    entrypoint: xm.ModuleName | xm.CommandList,
) -> xm.SequentialArgs:
    """Build the interpreter arguments for a Python entrypoint.

    Args:
        entrypoint: A module name, turned into ``-m <module_name>``, or a command
            list whose ``commands`` are used verbatim.

    Returns:
        The arguments wrapped in ``xm.SequentialArgs``.

    Raises:
        TypeError: If ``entrypoint`` is neither an ``xm.ModuleName`` nor an
            ``xm.CommandList``.
    """
    if isinstance(entrypoint, xm.ModuleName):
        pieces = ["-m", entrypoint.module_name]
    elif isinstance(entrypoint, xm.CommandList):
        pieces = entrypoint.commands
    else:
        raise TypeError(f"Invalid entrypoint type: {type(entrypoint)}")
    return xm.SequentialArgs.from_collection(pieces)
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
import collections.abc
|
|
2
|
+
import typing as tp
|
|
3
|
+
|
|
4
|
+
from xmanager import xm
|
|
5
|
+
|
|
6
|
+
from xm_slurm import api
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SlurmContextArtifacts(collections.abc.MutableMapping[str, str]):
|
|
10
|
+
def __init__(
|
|
11
|
+
self,
|
|
12
|
+
owner: xm.Experiment | xm.ExperimentUnit,
|
|
13
|
+
*,
|
|
14
|
+
artifacts: tp.Sequence[api.models.Artifact],
|
|
15
|
+
):
|
|
16
|
+
self._data = {artifact.name: artifact.uri for artifact in artifacts}
|
|
17
|
+
self._owner = owner
|
|
18
|
+
self._create_task = self._owner._create_task
|
|
19
|
+
|
|
20
|
+
def add(self, name: str, uri: str) -> None:
|
|
21
|
+
artifact = api.models.Artifact(name=name, uri=uri)
|
|
22
|
+
match self._owner:
|
|
23
|
+
case xm.Experiment():
|
|
24
|
+
api.client().insert_experiment_artifact(self._owner.experiment_id, artifact)
|
|
25
|
+
case xm.WorkUnit():
|
|
26
|
+
api.client().insert_work_unit_artifact(
|
|
27
|
+
self._owner.experiment_id, self._owner.work_unit_id, artifact
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
self._data[name] = uri
|
|
31
|
+
|
|
32
|
+
def remove(self, name: str) -> None:
|
|
33
|
+
match self._owner:
|
|
34
|
+
case xm.Experiment():
|
|
35
|
+
api.client().delete_experiment_artifact(self._owner.experiment_id, name)
|
|
36
|
+
case xm.WorkUnit():
|
|
37
|
+
api.client().delete_work_unit_artifact(
|
|
38
|
+
self._owner.experiment_id, self._owner.work_unit_id, name
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
def __setitem__(self, name: str, uri: str) -> None:
|
|
42
|
+
self.add(name, uri)
|
|
43
|
+
|
|
44
|
+
def __delitem__(self, name: str) -> None:
|
|
45
|
+
self.remove(name)
|
|
46
|
+
|
|
47
|
+
def __getitem__(self, name: str) -> str:
|
|
48
|
+
return self._data[name]
|
|
49
|
+
|
|
50
|
+
def __iter__(self) -> tp.Iterator[str]:
|
|
51
|
+
return iter(self._data)
|
|
52
|
+
|
|
53
|
+
def __len__(self) -> int:
|
|
54
|
+
return len(self._data)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class SlurmExperimentAnnotationTags(collections.abc.MutableSet[str]):
    """Ordered set of experiment tags; every mutation pushes the full tag list to the API."""

    def __init__(self, experiment: xm.Experiment, *, tags: tp.Iterable[str]):
        self._experiment = experiment
        # Use a dict to ensure order is preserved
        self._tags = dict.fromkeys(tags)

    def _sync(self) -> None:
        """Send the current tag list for this experiment to the API."""
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.models.ExperimentPatch(tags=list(self._tags)),
        )

    def add(self, tag: str) -> None:
        """Add ``tag`` and push the tag list (even if the tag was already present)."""
        self._tags[tag] = None
        self._sync()

    def remove(self, tag: str) -> None:
        """Remove ``tag``; raise ``KeyError`` if it is not present."""
        if tag not in self._tags:
            raise KeyError(tag)
        self.discard(tag)

    def discard(self, tag: str) -> None:
        """Remove ``tag`` if present; do nothing otherwise.

        Per the ``MutableSet`` contract, ``discard`` must not raise for a missing
        element. The ``-=`` and ``pop`` mixins rely on this.
        """
        if tag not in self._tags:
            return
        del self._tags[tag]
        self._sync()

    def __contains__(self, tag: str) -> bool:
        return tag in self._tags

    def __iter__(self) -> tp.Iterator[str]:
        return iter(self._tags)

    def __len__(self) -> int:
        return len(self._tags)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class SlurmExperimentUnitMetadataContext:
    """Metadata for a single experiment unit.

    The ``artifacts`` mapping can be mutated in place but cannot be reassigned.
    """

    def __init__(
        self,
        experiment_unit: xm.ExperimentUnit,
        *,
        artifacts: SlurmContextArtifacts,
    ):
        self._artifacts = artifacts
        self._experiment_unit = experiment_unit

    @property
    def artifacts(self) -> SlurmContextArtifacts:
        """The unit's artifact mapping."""
        return self._artifacts

    @artifacts.setter
    def artifacts(self, artifacts: SlurmContextArtifacts) -> None:
        # Reassignment is always rejected; the argument is intentionally unused.
        raise ValueError("The artifacts object is immutable.")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class SlurmExperimentContextAnnotations:
    """Experiment title, description, note and tags.

    Each setter updates the cached value, then sends an ``ExperimentPatch`` with that
    field through ``api.client().update_experiment``.
    """

    def __init__(
        self,
        experiment: xm.Experiment,
        *,
        title: str,
        tags: set[str] | None = None,
        description: str | None = None,
        note: str | None = None,
    ):
        self._experiment = experiment
        self._create_task = self._experiment._create_task
        self._title = title
        self._tags = SlurmExperimentAnnotationTags(experiment, tags=tags or [])
        self._description = description or ""
        self._note = note or ""

    def _patch(self, **fields: tp.Any) -> None:
        """Persist ``fields`` for this experiment through the API."""
        api.client().update_experiment(
            self._experiment.experiment_id,
            api.models.ExperimentPatch(**fields),
        )

    @property
    def title(self) -> str:
        return self._title

    @title.setter
    def title(self, new_title: str) -> None:
        self._title = new_title
        self._patch(title=new_title)

    @property
    def description(self) -> str:
        return self._description

    @description.setter
    def description(self, new_description: str) -> None:
        self._description = new_description
        self._patch(description=new_description)

    @property
    def note(self) -> str:
        return self._note

    @note.setter
    def note(self, new_note: str) -> None:
        self._note = new_note
        self._patch(note=new_note)

    @property
    def tags(self) -> SlurmExperimentAnnotationTags:
        return self._tags

    @tags.setter
    def tags(self, new_tags: tp.Iterable[str]) -> None:
        self._tags = SlurmExperimentAnnotationTags(self._experiment, tags=new_tags)
        self._patch(tags=list(self._tags))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class SlurmExperimentMetadataContext:
    """Experiment-level metadata: annotations, artifacts and config artifacts.

    ``annotations`` and ``artifacts`` cannot be reassigned. ``graphviz_config`` and
    ``python_config`` are cached locally and mirrored as the ``"GRAPHVIZ"`` and
    ``"PYTHON"`` config artifacts through the API.
    """

    def __init__(
        self,
        experiment: xm.Experiment,
        *,
        annotations: SlurmExperimentContextAnnotations,
        artifacts: SlurmContextArtifacts,
    ):
        self._experiment = experiment
        self._annotations = annotations
        self._artifacts = artifacts

        self._graphviz_config = None
        self._python_config = None

    def _delete_config_artifact(self, name: str) -> None:
        """Remove config artifact ``name`` for this experiment."""
        api.client().delete_experiment_config_artifact(self._experiment.experiment_id, name)

    def _insert_config_artifact(self, name: str, uri: str) -> None:
        """Store config artifact ``name`` with ``uri`` for this experiment."""
        api.client().insert_experiment_config_artifact(
            self._experiment.experiment_id,
            api.models.ConfigArtifact(name=name, uri=uri),
        )

    @property
    def annotations(self) -> SlurmExperimentContextAnnotations:
        return self._annotations

    @annotations.setter
    def annotations(self, annotations: SlurmExperimentContextAnnotations) -> None:
        raise ValueError("The annotations object is immutable.")

    @property
    def artifacts(self) -> SlurmContextArtifacts:
        return self._artifacts

    @artifacts.setter
    def artifacts(self, artifacts: SlurmContextArtifacts) -> None:
        raise ValueError("The artifacts object is immutable.")

    @property
    def graphviz_config(self) -> str | None:
        return self._graphviz_config

    @graphviz_config.setter
    def graphviz_config(self, config: str | None) -> None:
        self._graphviz_config = config
        if config is None:
            self._delete_config_artifact("GRAPHVIZ")
        elif isinstance(config, str):
            self._insert_config_artifact("GRAPHVIZ", f"graphviz://{config}")

    @graphviz_config.deleter
    def graphviz_config(self) -> None:
        self._graphviz_config = None
        self._delete_config_artifact("GRAPHVIZ")

    @property
    def python_config(self) -> str | None:
        return self._python_config

    @python_config.setter
    def python_config(self, config: str | None) -> None:
        self._python_config = config
        if config is None:
            self._delete_config_artifact("PYTHON")
        elif isinstance(config, str):
            self._insert_config_artifact("PYTHON", config)

    @python_config.deleter
    def python_config(self) -> None:
        self._python_config = None
        self._delete_config_artifact("PYTHON")
|