xmanager_slurm-0.4.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. xm_slurm/__init__.py +47 -0
  2. xm_slurm/api/__init__.py +33 -0
  3. xm_slurm/api/abc.py +65 -0
  4. xm_slurm/api/models.py +70 -0
  5. xm_slurm/api/sqlite/client.py +358 -0
  6. xm_slurm/api/web/client.py +173 -0
  7. xm_slurm/batching.py +139 -0
  8. xm_slurm/config.py +189 -0
  9. xm_slurm/console.py +3 -0
  10. xm_slurm/constants.py +19 -0
  11. xm_slurm/contrib/__init__.py +0 -0
  12. xm_slurm/contrib/clusters/__init__.py +67 -0
  13. xm_slurm/contrib/clusters/drac.py +242 -0
  14. xm_slurm/dependencies.py +171 -0
  15. xm_slurm/executables.py +215 -0
  16. xm_slurm/execution.py +995 -0
  17. xm_slurm/executors.py +210 -0
  18. xm_slurm/experiment.py +1016 -0
  19. xm_slurm/experimental/parameter_controller.py +206 -0
  20. xm_slurm/filesystems.py +129 -0
  21. xm_slurm/job_blocks.py +21 -0
  22. xm_slurm/metadata_context.py +253 -0
  23. xm_slurm/packageables.py +309 -0
  24. xm_slurm/packaging/__init__.py +8 -0
  25. xm_slurm/packaging/docker.py +348 -0
  26. xm_slurm/packaging/registry.py +45 -0
  27. xm_slurm/packaging/router.py +56 -0
  28. xm_slurm/packaging/utils.py +22 -0
  29. xm_slurm/resources.py +350 -0
  30. xm_slurm/scripts/_cloudpickle.py +28 -0
  31. xm_slurm/scripts/cli.py +90 -0
  32. xm_slurm/status.py +197 -0
  33. xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
  34. xm_slurm/templates/docker/mamba.Dockerfile +29 -0
  35. xm_slurm/templates/docker/python.Dockerfile +32 -0
  36. xm_slurm/templates/docker/uv.Dockerfile +38 -0
  37. xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
  38. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
  39. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  40. xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
  41. xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
  42. xm_slurm/templates/slurm/job.bash.j2 +90 -0
  43. xm_slurm/templates/slurm/library/retry.bash +62 -0
  44. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
  45. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
  46. xm_slurm/types.py +23 -0
  47. xm_slurm/utils.py +196 -0
  48. xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
  49. xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
  50. xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
  51. xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
  52. xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
xm_slurm/experimental/parameter_controller.py ADDED
@@ -0,0 +1,206 @@
+ import asyncio
+ import base64
+ import dataclasses
+ import enum
+ import functools
+ import logging
+ import typing as tp
+ import zlib
+
+ import backoff
+ import cloudpickle
+ from xmanager import xm
+
+ import xm_slurm
+ from xm_slurm import job_blocks, status
+ from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
+
+ P = tp.ParamSpec("P")
+ T = tp.TypeVar("T")
+
+ logger = logging.getLogger(__name__)
+
+
+ async def _monitor_parameter_controller(
+     aux_unit: SlurmAuxiliaryUnit,
+     local_parameter_controller_coro: tp.Coroutine[None, None, T],
+     *,
+     poll_interval: float = 30.0,
+ ) -> None:
+     local_controller_finished = asyncio.Event()
+     local_parameter_controller = asyncio.create_task(local_parameter_controller_coro)
+
+     @local_parameter_controller.add_done_callback
+     def _(future: asyncio.Task[T]) -> None:
+         try:
+             _ = future.result()
+         except asyncio.CancelledError:
+             logger.info("Local parameter controller was cancelled, resuming on remote controller.")
+             pass
+         except Exception:
+             logger.error("Local parameter controller failed, stopping remote controller.")
+             aux_unit.stop(
+                 mark_as_failed=True, mark_as_completed=False, message="Local controller failed."
+             )
+             raise
+         else:
+             logger.info(
+                 "Local parameter controller finished before remote controller started, "
+                 "stopping remote controller."
+             )
+             local_controller_finished.set()
+             aux_unit.stop(mark_as_completed=True, message="Local parameter controller finished.")
+
+     @backoff.on_predicate(
+         backoff.constant,
+         lambda aux_unit_status: aux_unit_status is status.SlurmWorkUnitStatusEnum.PENDING,
+         jitter=None,
+         interval=poll_interval,
+     )
+     async def wait_for_remote_controller() -> status.SlurmWorkUnitStatusEnum:
+         logger.info("Waiting for remote parameter controller to start.")
+         if local_controller_finished.is_set():
+             return status.SlurmWorkUnitStatusEnum.COMPLETED
+         return (await aux_unit.get_status()).status
+
+     logger.info("Monitoring remote parameter controller.")
+     # TODO(jfarebro): make get_status() more resilient to errors when initially scheduling.
+     # We run into issues if we call get_status() too quickly when Slurm hasn't ingested the job.
+     await asyncio.sleep(15)
+     match await wait_for_remote_controller():
+         case status.SlurmWorkUnitStatusEnum.RUNNING:
+             logger.info("Remote parameter controller started.")
+             local_parameter_controller.cancel("Remote parameter controller started.")
+         case status.SlurmWorkUnitStatusEnum.COMPLETED:
+             if local_parameter_controller.done():
+                 logger.info("Local parameter controller finished, stopping remote controller.")
+                 aux_unit.stop(
+                     mark_as_completed=True, message="Local parameter controller finished."
+                 )
+             else:
+                 logger.info("Remote parameter controller finished, stopping local controller.")
+                 local_parameter_controller.cancel()
+         case status.SlurmWorkUnitStatusEnum.FAILED:
+             logger.error("Remote parameter controller failed, stopping local controller.")
+             local_parameter_controller.cancel()
+         case status.SlurmWorkUnitStatusEnum.CANCELLED:
+             logger.info("Remote parameter controller was cancelled, stopping local controller.")
+             local_parameter_controller.cancel()
+         case status.SlurmWorkUnitStatusEnum.PENDING:
+             raise RuntimeError("Remote parameter controller is still pending, invalid state.")
+
+
+ class ParameterControllerMode(enum.Enum):
+     AUTO = enum.auto()
+     REMOTE_ONLY = enum.auto()
+     # TODO(jfarebro): is it possible to get LOCAL_ONLY?
+     # We'd need to have a dummy job type as we need to return a JobType?
+
+
+ def parameter_controller(
+     *,
+     executable: xm.Executable,
+     executor: xm.Executor,
+     controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
+     controller_name: str = "parameter_controller",
+     controller_args: xm.UserArgs | None = None,
+     controller_env_vars: tp.Mapping[str, str] | None = None,
+ ) -> tp.Callable[
+     [
+         tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
+         | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
+     ],
+     tp.Callable[P, xm.AuxiliaryUnitJob],
+ ]:
+     """Converts a function to a controller which can be added to an experiment.
+
+     Calling the wrapped function returns an xm.JobGenerator that runs it as an
+     auxiliary unit on the specified executor.
+
+     Args:
+         executable: An executable that has a Python entrypoint with all the necessary dependencies.
+         executor: The executor to launch the controller job on.
+         controller_name: Name of the parameter controller job.
+         controller_args: Mapping of flag names and values to be used by the XM
+             client running inside the parameter controller job.
+         controller_env_vars: Mapping of env variable names and values to be passed
+             to the parameter controller job.
+
+     Returns:
+         A decorator to be applied to the function.
+     """
+
+     def decorator(
+         f: tp.Callable[tp.Concatenate[SlurmExperiment, P], T]
+         | tp.Callable[tp.Concatenate[SlurmExperiment, P], tp.Awaitable[T]],
+     ) -> tp.Callable[P, xm.AuxiliaryUnitJob]:
+         @functools.wraps(f)
+         def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
+             # Modify the function to read the experiment from the API so that it can be pickled.
+
+             async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
+                 experiment_id = aux_unit.experiment.experiment_id
+
+                 async def local_controller(
+                     *args: P.args, **kwargs: P.kwargs
+                 ) -> T | tp.Awaitable[T]:
+                     if asyncio.iscoroutinefunction(f):
+                         return await f(aux_unit.experiment, *args, **kwargs)
+                     else:
+                         return f(aux_unit.experiment, *args, **kwargs)
+
+                 async def remote_controller(
+                     *args: P.args, **kwargs: P.kwargs
+                 ) -> T | tp.Awaitable[T]:
+                     async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
+                         if asyncio.iscoroutinefunction(f):
+                             return await f(exp, *args, **kwargs)
+                         else:
+                             return f(exp, *args, **kwargs)
+
+                 remote_controller_serialized = base64.urlsafe_b64encode(
+                     zlib.compress(
+                         cloudpickle.dumps(
+                             functools.partial(remote_controller, *args, **kwargs),
+                         )
+                     )
+                 )
+
+                 parameter_controller_executable = dataclasses.replace(
+                     executable,
+                     args=xm.merge_args(
+                         job_blocks.get_args_for_python_entrypoint(
+                             xm.ModuleName("xm_slurm.scripts._cloudpickle")
+                         ),
+                         xm.SequentialArgs.from_collection({
+                             "cloudpickled_fn": remote_controller_serialized.decode("ascii"),
+                         }),
+                         xm.SequentialArgs.from_collection(controller_args),
+                     ),
+                     env_vars=controller_env_vars or {},
+                 )
+
+                 await aux_unit.add(
+                     xm.Job(
+                         executor=executor,
+                         executable=parameter_controller_executable,
+                         name=controller_name,
+                     )
+                 )
+
+                 # Launch local parameter controller and monitor for when it starts running
+                 # so we can kill the local controller.
+                 if controller_mode is ParameterControllerMode.AUTO:
+                     aux_unit._create_task(
+                         _monitor_parameter_controller(aux_unit, local_controller(*args, **kwargs))
+                     )
+
+             return xm.AuxiliaryUnitJob(
+                 job_generator,
+                 importance=xm.Importance.HIGH,
+                 termination_delay_secs=0,  # TODO: add support for termination delay.
+             )
+
+         return make_controller
+
+     return decorator
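
For orientation, here is a minimal usage sketch of the parameter_controller decorator above. It is not taken from the package's documentation: the names controller_executable, controller_executor, train_executable, and train_executor are placeholders, and adding the resulting xm.AuxiliaryUnitJob via experiment.add(...) follows the standard XManager convention rather than anything confirmed by this diff.

# Hypothetical usage sketch. The decorated function receives the (Slurm)Experiment as its
# first argument; the decorator returns a factory whose result is an xm.AuxiliaryUnitJob.
from xmanager import xm

from xm_slurm.experimental.parameter_controller import parameter_controller


def make_sweep(
    controller_executable: xm.Executable,  # placeholder: packaged controller entrypoint
    controller_executor: xm.Executor,      # placeholder: Slurm executor for the controller job
    train_executable: xm.Executable,       # placeholder: packaged training entrypoint
    train_executor: xm.Executor,           # placeholder: Slurm executor for training jobs
):
    @parameter_controller(executable=controller_executable, executor=controller_executor)
    async def sweep(experiment, learning_rates: list[float]) -> None:
        # Runs locally until the remote controller job starts (see
        # _monitor_parameter_controller above), then resumes remotely.
        for lr in learning_rates:
            await experiment.add(
                xm.Job(executable=train_executable, executor=train_executor, args={"lr": lr})
            )

    return sweep

# In a launch script the resulting auxiliary unit would be added to the experiment, e.g.:
#   experiment.add(make_sweep(...)(learning_rates=[1e-3, 3e-4]))
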
xm_slurm/filesystems.py ADDED
@@ -0,0 +1,129 @@
+ import abc
+ import asyncio
+ import os
+ import typing as tp
+
+ import aiofile
+ import asyncssh
+ import typing_extensions as tpe
+ import wrapt
+
+
+ class AsyncFileIO(tp.Protocol):
+     async def write(self, buffer: str | bytes, /) -> int: ...
+     async def read(self, size: int = -1, /) -> bytes: ...
+     async def seek(self, offset: int, /) -> int: ...
+     async def tell(self) -> int: ...
+     async def __aenter__(self) -> tpe.Self: ...
+     async def __aexit__(self, *args: tp.Any, **kwargs: tp.Any) -> None: ...
+
+
+ class AsyncFileSystem(tp.Protocol):
+     async def open(
+         self,
+         path: os.PathLike[str] | str,
+         mode: tp.Literal["r", "w", "rb", "wb"],
+         *,
+         encoding: str = "utf-8",
+     ) -> AsyncFileIO: ...
+
+     async def read(
+         self, path: os.PathLike[str] | str, *, size: int = -1, offset: int = 0
+     ) -> bytes: ...
+     async def write(
+         self, path: os.PathLike[str] | str, data: str | bytes, *, offset: int = 0
+     ) -> None: ...
+
+     async def exists(self, path: os.PathLike[str] | str) -> bool: ...
+     async def size(self, path: os.PathLike[str] | str) -> int | None: ...
+     async def makedirs(
+         self, path: os.PathLike[str] | str, mode: int = 511, exist_ok: bool = False
+     ) -> None: ...
+
+
+ class AbstractAsyncFileSystem(AsyncFileSystem, abc.ABC):
+     @abc.abstractmethod
+     async def open(
+         self,
+         path: os.PathLike[str] | str,
+         mode: tp.Literal["r", "w", "rb", "wb"],
+         *,
+         encoding: str = "utf-8",
+     ) -> AsyncFileIO: ...
+
+     async def read(self, path: os.PathLike[str] | str, *, size: int = -1, offset: int = 0) -> bytes:
+         async with await self.open(path, "rb") as f:
+             await f.seek(offset)
+             return await f.read(size)
+
+     async def write(
+         self, path: os.PathLike[str] | str, data: str | bytes, *, offset: int = 0
+     ) -> None:
+         async with await self.open(path, "wb") as f:
+             await f.seek(offset)
+             await f.write(data)
+
+
+ class AsyncLocalFileIO(wrapt.ObjectProxy):
+     async def seek(self, offset: int, /) -> int:
+         await asyncio.to_thread(self.__wrapped__.seek, offset)
+         return await asyncio.to_thread(self.__wrapped__.tell)
+
+     async def tell(self) -> int:
+         return await asyncio.to_thread(self.__wrapped__.tell)
+
+     async def __aenter__(self) -> tpe.Self:
+         return AsyncLocalFileIO(await self.__wrapped__.__aenter__())
+
+     async def __aexit__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
+         return await self.__wrapped__.__aexit__(*args, **kwargs)
+
+
+ class AsyncLocalFileSystem(AbstractAsyncFileSystem):
+     def __init__(self): ...
+
+     async def open(
+         self,
+         path: os.PathLike[str] | str,
+         mode: tp.Literal["r", "w", "rb", "wb"],
+         *,
+         encoding: str = "utf-8",
+     ) -> AsyncFileIO:
+         return AsyncLocalFileIO(aiofile.async_open(os.fspath(path), mode=mode, encoding=encoding))  # type: ignore
+
+     async def exists(self, path: os.PathLike[str] | str) -> bool:
+         return await asyncio.to_thread(os.path.exists, os.fspath(path))
+
+     async def size(self, path: os.PathLike[str] | str) -> int | None:
+         return await asyncio.to_thread(os.path.getsize, os.fspath(path))
+
+     async def makedirs(
+         self, path: os.PathLike[str] | str, mode: int = 0o777, exist_ok: bool = False
+     ) -> None:
+         return await asyncio.to_thread(os.makedirs, os.fspath(path), mode=mode, exist_ok=exist_ok)
+
+
+ class AsyncSSHFileSystem(AbstractAsyncFileSystem):
+     def __init__(self, client: asyncssh.SFTPClient):
+         self._client = client
+
+     async def open(
+         self,
+         path: os.PathLike[str] | str,
+         mode: tp.Literal["r", "w", "rb", "wb"],
+         *,
+         encoding: str = "utf-8",
+     ) -> AsyncFileIO:
+         return await self._client.open(os.fspath(path), mode, encoding=encoding)  # type: ignore
+
+     async def exists(self, path: os.PathLike[str] | str) -> bool:
+         return await self._client.exists(os.fspath(path))
+
+     async def size(self, path: os.PathLike[str] | str) -> int | None:
+         return (await self._client.stat(os.fspath(path))).size
+
+     async def makedirs(
+         self, path: os.PathLike[str] | str, mode: int = 0o777, exist_ok: bool = False
+     ) -> None:
+         attrs = asyncssh.SFTPAttrs(permissions=mode)
+         return await self._client.makedirs(os.fspath(path), attrs=attrs, exist_ok=exist_ok)
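
A short sketch of how these filesystem abstractions could be exercised. The AsyncLocalFileSystem part only uses what is defined above; the SSH half assumes asyncssh's standard connect/start_sftp_client entry points and a placeholder host name, since this diff does not show how xm_slurm constructs the SFTP client.

# Hypothetical usage sketch: both backends satisfy the AsyncFileSystem protocol,
# so the same calling code works locally and over SFTP.
import asyncio

import asyncssh

from xm_slurm.filesystems import AsyncLocalFileSystem, AsyncSSHFileSystem


async def main() -> None:
    local_fs = AsyncLocalFileSystem()
    await local_fs.makedirs("/tmp/xm-slurm-demo", exist_ok=True)
    await local_fs.write("/tmp/xm-slurm-demo/hello.txt", b"hello world")
    print(await local_fs.read("/tmp/xm-slurm-demo/hello.txt"))

    # Remote variant; "login.example.org" is a placeholder host.
    async with asyncssh.connect("login.example.org") as conn:
        async with conn.start_sftp_client() as sftp:
            remote_fs = AsyncSSHFileSystem(sftp)
            print(await remote_fs.exists("/etc/hostname"))


if __name__ == "__main__":
    asyncio.run(main())
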
xm_slurm/job_blocks.py ADDED
@@ -0,0 +1,21 @@
+ import typing as tp
+
+ from xmanager import xm
+
+
+ class JobArgs(tp.TypedDict, total=False):
+     args: xm.UserArgs
+     env_vars: tp.Mapping[str, str]
+
+
+ def get_args_for_python_entrypoint(
+     entrypoint: xm.ModuleName | xm.CommandList,
+ ) -> xm.SequentialArgs:
+     match entrypoint:
+         case xm.ModuleName():
+             entrypoint_args = ["-m", entrypoint.module_name]
+         case xm.CommandList():
+             entrypoint_args = entrypoint.commands
+         case _:
+             raise TypeError(f"Invalid entrypoint type: {type(entrypoint)}")
+     return xm.SequentialArgs.from_collection(entrypoint_args)
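
The helper above is small enough that a concrete example is worth showing; the module name and flag values below are illustrative only, and xm.merge_args / xm.SequentialArgs come from upstream XManager.

# Hypothetical usage sketch: building interpreter arguments for a Python job.
from xmanager import xm

from xm_slurm.job_blocks import get_args_for_python_entrypoint

# Module entrypoint -> equivalent to `python -m my_project.train`.
module_args = get_args_for_python_entrypoint(xm.ModuleName("my_project.train"))

# Command-list entrypoint -> the commands are passed through unchanged.
script_args = get_args_for_python_entrypoint(xm.CommandList(["train.py", "--epochs=10"]))

# Entrypoint arguments can then be merged with user flags, as the
# parameter_controller module does when it wires up its cloudpickle runner.
merged = xm.merge_args(module_args, xm.SequentialArgs.from_collection({"seed": 0}))
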
xm_slurm/metadata_context.py ADDED
@@ -0,0 +1,253 @@
+ import collections.abc
+ import typing as tp
+
+ from xmanager import xm
+
+ from xm_slurm import api
+
+
+ class SlurmContextArtifacts(collections.abc.MutableMapping[str, str]):
+     def __init__(
+         self,
+         owner: xm.Experiment | xm.ExperimentUnit,
+         *,
+         artifacts: tp.Sequence[api.models.Artifact],
+     ):
+         self._data = {artifact.name: artifact.uri for artifact in artifacts}
+         self._owner = owner
+         self._create_task = self._owner._create_task
+
+     def add(self, name: str, uri: str) -> None:
+         artifact = api.models.Artifact(name=name, uri=uri)
+         match self._owner:
+             case xm.Experiment():
+                 api.client().insert_experiment_artifact(self._owner.experiment_id, artifact)
+             case xm.WorkUnit():
+                 api.client().insert_work_unit_artifact(
+                     self._owner.experiment_id, self._owner.work_unit_id, artifact
+                 )
+
+         self._data[name] = uri
+
+     def remove(self, name: str) -> None:
+         match self._owner:
+             case xm.Experiment():
+                 api.client().delete_experiment_artifact(self._owner.experiment_id, name)
+             case xm.WorkUnit():
+                 api.client().delete_work_unit_artifact(
+                     self._owner.experiment_id, self._owner.work_unit_id, name
+                 )
+
+     def __setitem__(self, name: str, uri: str) -> None:
+         self.add(name, uri)
+
+     def __delitem__(self, name: str) -> None:
+         self.remove(name)
+
+     def __getitem__(self, name: str) -> str:
+         return self._data[name]
+
+     def __iter__(self) -> tp.Iterator[str]:
+         return iter(self._data)
+
+     def __len__(self) -> int:
+         return len(self._data)
+
+
+ class SlurmExperimentAnnotationTags(collections.abc.MutableSet[str]):
+     def __init__(self, experiment: xm.Experiment, *, tags: tp.Iterable[str]):
+         self._experiment = experiment
+         # Use a dict to ensure order is preserved
+         self._tags = dict.fromkeys(tags)
+
+     def add(self, tag: str) -> None:
+         self._tags[tag] = None
+         api.client().update_experiment(
+             self._experiment.experiment_id,
+             api.models.ExperimentPatch(tags=list(self._tags)),
+         )
+
+     def remove(self, tag: str) -> None:
+         self.discard(tag)
+
+     def discard(self, tag: str) -> None:
+         self._tags.pop(tag)
+         api.client().update_experiment(
+             self._experiment.experiment_id,
+             api.models.ExperimentPatch(tags=list(self._tags)),
+         )
+
+     def __contains__(self, tag: str) -> bool:
+         return tag in self._tags
+
+     def __iter__(self) -> tp.Iterator[str]:
+         return iter(self._tags)
+
+     def __len__(self) -> int:
+         return len(self._tags)
+
+
+ class SlurmExperimentUnitMetadataContext:
+     def __init__(
+         self,
+         experiment_unit: xm.ExperimentUnit,
+         *,
+         artifacts: SlurmContextArtifacts,
+     ):
+         self._experiment_unit = experiment_unit
+         self._artifacts = artifacts
+
+     @property
+     def artifacts(self) -> SlurmContextArtifacts:
+         return self._artifacts
+
+     @artifacts.setter
+     def artifacts(self, artifacts: SlurmContextArtifacts) -> None:
+         del artifacts
+         raise ValueError("The artifacts object is immutable.")
+
+
+ class SlurmExperimentContextAnnotations:
+     def __init__(
+         self,
+         experiment: xm.Experiment,
+         *,
+         title: str,
+         tags: set[str] | None = None,
+         description: str | None = None,
+         note: str | None = None,
+     ):
+         self._experiment = experiment
+         self._create_task = self._experiment._create_task
+         self._title = title
+         self._tags = SlurmExperimentAnnotationTags(experiment, tags=tags or [])
+         self._description = description or ""
+         self._note = note or ""
+
+     @property
+     def title(self) -> str:
+         return self._title
+
+     @title.setter
+     def title(self, value: str) -> None:
+         self._title = value
+         api.client().update_experiment(
+             self._experiment.experiment_id,
+             api.models.ExperimentPatch(title=value),
+         )
+
+     @property
+     def description(self) -> str:
+         return self._description
+
+     @description.setter
+     def description(self, value: str) -> None:
+         self._description = value
+         api.client().update_experiment(
+             self._experiment.experiment_id,
+             api.models.ExperimentPatch(description=value),
+         )
+
+     @property
+     def note(self) -> str:
+         return self._note
+
+     @note.setter
+     def note(self, value: str) -> None:
+         self._note = value
+         api.client().update_experiment(
+             self._experiment.experiment_id,
+             api.models.ExperimentPatch(note=value),
+         )
+
+     @property
+     def tags(self) -> SlurmExperimentAnnotationTags:
+         return self._tags
+
+     @tags.setter
+     def tags(self, tags: tp.Iterable[str]) -> None:
+         self._tags = SlurmExperimentAnnotationTags(self._experiment, tags=tags)
+         api.client().update_experiment(
+             self._experiment.experiment_id,
+             api.models.ExperimentPatch(tags=list(self._tags)),
+         )
+
+
+ class SlurmExperimentMetadataContext:
+     def __init__(
+         self,
+         experiment: xm.Experiment,
+         *,
+         annotations: SlurmExperimentContextAnnotations,
+         artifacts: SlurmContextArtifacts,
+     ):
+         self._experiment = experiment
+         self._annotations = annotations
+         self._artifacts = artifacts
+
+         self._graphviz_config = None
+         self._python_config = None
+
+     @property
+     def annotations(self) -> SlurmExperimentContextAnnotations:
+         return self._annotations
+
+     @annotations.setter
+     def annotations(self, annotations: SlurmExperimentContextAnnotations) -> None:
+         del annotations
+         raise ValueError("The annotations object is immutable.")
+
+     @property
+     def artifacts(self) -> SlurmContextArtifacts:
+         return self._artifacts
+
+     @artifacts.setter
+     def artifacts(self, artifacts: SlurmContextArtifacts) -> None:
+         del artifacts
+         raise ValueError("The artifacts object is immutable.")
+
+     @property
+     def graphviz_config(self) -> str | None:
+         return self._graphviz_config
+
+     @graphviz_config.setter
+     def graphviz_config(self, config: str | None) -> None:
+         self._graphviz_config = config
+         match config:
+             case None:
+                 api.client().delete_experiment_config_artifact(
+                     self._experiment.experiment_id, "GRAPHVIZ"
+                 )
+             case str():
+                 api.client().insert_experiment_config_artifact(
+                     self._experiment.experiment_id,
+                     api.models.ConfigArtifact(name="GRAPHVIZ", uri=f"graphviz://{config}"),
+                 )
+
+     @graphviz_config.deleter
+     def graphviz_config(self) -> None:
+         self._graphviz_config = None
+         api.client().delete_experiment_config_artifact(self._experiment.experiment_id, "GRAPHVIZ")
+
+     @property
+     def python_config(self) -> str | None:
+         return self._python_config
+
+     @python_config.setter
+     def python_config(self, config: str | None) -> None:
+         self._python_config = config
+         match config:
+             case None:
+                 api.client().delete_experiment_config_artifact(
+                     self._experiment.experiment_id, "PYTHON"
+                 )
+             case str():
+                 api.client().insert_experiment_config_artifact(
+                     self._experiment.experiment_id,
+                     api.models.ConfigArtifact(name="PYTHON", uri=config),
+                 )
+
+     @python_config.deleter
+     def python_config(self) -> None:
+         self._python_config = None
+         api.client().delete_experiment_config_artifact(self._experiment.experiment_id, "PYTHON")