xmanager-slurm 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Click here for more details.

xm_slurm/__init__.py CHANGED
@@ -3,7 +3,6 @@ import logging
3
3
  from xm_slurm.executables import Dockerfile, DockerImage
4
4
  from xm_slurm.executors import Slurm, SlurmSpec
5
5
  from xm_slurm.experiment import (
6
- Artifact,
7
6
  SlurmExperiment,
8
7
  create_experiment,
9
8
  get_current_experiment,
@@ -25,7 +24,6 @@ logging.getLogger("asyncssh").setLevel(logging.WARN)
25
24
  logging.getLogger("httpx").setLevel(logging.WARN)
26
25
 
27
26
  __all__ = [
28
- "Artifact",
29
27
  "conda_container",
30
28
  "create_experiment",
31
29
  "docker_container",
@@ -0,0 +1,33 @@
1
+ import functools
2
+ import logging
3
+ import os
4
+
5
+ from xm_slurm.api import models
6
+ from xm_slurm.api.abc import XManagerAPI
7
+ from xm_slurm.api.sqlite import client as sqlite_client
8
+ from xm_slurm.api.web import client as web_client
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
@functools.cache
def client() -> XManagerAPI:
    """Return the process-wide API client selected by the environment.

    The backend is read from ``XM_SLURM_API_BACKEND`` (case-insensitive,
    defaulting to ``"sqlite"``). The result is memoized by ``functools.cache``,
    so the environment is only consulted on the first call.

    Raises:
        ValueError: If the backend name is not ``"rest"`` or ``"sqlite"``, or
            if the REST backend is selected without
            ``XM_SLURM_REST_API_BASE_URL`` / ``XM_SLURM_REST_API_TOKEN`` set.
    """
    selected = os.environ.get("XM_SLURM_API_BACKEND", "sqlite").lower()

    if selected == "sqlite":
        return sqlite_client.XManagerSqliteAPI()

    if selected != "rest":
        raise ValueError(f"Invalid XM_SLURM_API_BACKEND: {selected}")

    # REST backend: both connection settings are mandatory.
    env = os.environ
    if "XM_SLURM_REST_API_BASE_URL" not in env:
        raise ValueError("XM_SLURM_REST_API_BASE_URL is not set")
    if "XM_SLURM_REST_API_TOKEN" not in env:
        raise ValueError("XM_SLURM_REST_API_TOKEN is not set")

    return web_client.XManagerWebAPI(
        base_url=env["XM_SLURM_REST_API_BASE_URL"],
        token=env["XM_SLURM_REST_API_TOKEN"],
    )
31
+
32
+
33
+ __all__ = ["client", "XManagerAPI", "models"]
xm_slurm/api/abc.py ADDED
@@ -0,0 +1,65 @@
1
+ import abc
2
+
3
+ from xm_slurm.api import models
4
+
5
+
6
class XManagerAPI(abc.ABC):
    """Abstract storage interface for experiments, work units, jobs and artifacts.

    Concrete backends implement every method below; the base class carries no
    behaviour of its own.
    """

    @abc.abstractmethod
    def get_experiment(self, xid: int) -> models.Experiment:
        """Return the experiment identified by ``xid``."""

    @abc.abstractmethod
    def delete_experiment(self, experiment_id: int) -> None:
        """Delete the experiment identified by ``experiment_id``."""

    @abc.abstractmethod
    def insert_experiment(self, experiment: models.ExperimentPatch) -> int:
        """Create an experiment from ``experiment`` and return an ``int`` for it."""

    @abc.abstractmethod
    def update_experiment(
        self, experiment_id: int, experiment_patch: models.ExperimentPatch
    ) -> None:
        """Apply ``experiment_patch`` to the experiment ``experiment_id``."""

    @abc.abstractmethod
    def insert_job(self, experiment_id: int, work_unit_id: int, job: models.SlurmJob) -> None:
        """Record ``job`` under work unit ``work_unit_id`` of ``experiment_id``."""

    @abc.abstractmethod
    def insert_work_unit(self, experiment_id: int, work_unit: models.WorkUnitPatch) -> None:
        """Create a work unit described by ``work_unit`` in ``experiment_id``."""

    @abc.abstractmethod
    def update_work_unit(
        self, experiment_id: int, work_unit_id: int, patch: models.ExperimentUnitPatch
    ) -> None:
        """Apply ``patch`` to work unit ``work_unit_id`` of ``experiment_id``."""

    @abc.abstractmethod
    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Delete the artifact called ``name`` from the given work unit."""

    @abc.abstractmethod
    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: models.Artifact
    ) -> None:
        """Attach ``artifact`` to the given work unit."""

    @abc.abstractmethod
    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Delete the artifact called ``name`` from ``experiment_id``."""

    @abc.abstractmethod
    def insert_experiment_artifact(self, experiment_id: int, artifact: models.Artifact) -> None:
        """Attach ``artifact`` to ``experiment_id``."""

    @abc.abstractmethod
    def insert_experiment_config_artifact(
        self, experiment_id: int, artifact: models.ConfigArtifact
    ) -> None:
        """Attach the config artifact ``artifact`` to ``experiment_id``."""

    @abc.abstractmethod
    def delete_experiment_config_artifact(self, experiment_id: int, name: str) -> None:
        """Delete the config artifact called ``name`` from ``experiment_id``."""
xm_slurm/api/models.py ADDED
@@ -0,0 +1,70 @@
1
+ import dataclasses
2
+ import enum
3
+ import typing as tp
4
+
5
+
6
class ExperimentUnitRole(enum.Enum):
    """Role tag for an experiment unit; members are auto-numbered from 1."""

    WORK_UNIT = enum.auto()
    AUX_UNIT = enum.auto()
9
+
10
+
11
+ @dataclasses.dataclass(kw_only=True, frozen=True)
12
+ class ExperimentPatch:
13
+ title: str | None = None
14
+ description: str | None = None
15
+ note: str | None = None
16
+ tags: list[str] | None = None
17
+
18
+
19
+ @dataclasses.dataclass(kw_only=True, frozen=True)
20
+ class SlurmJob:
21
+ name: str
22
+ slurm_job_id: str
23
+ slurm_ssh_config: str
24
+
25
+
26
+ @dataclasses.dataclass(kw_only=True, frozen=True)
27
+ class Artifact:
28
+ name: str
29
+ uri: str
30
+
31
+
32
+ @dataclasses.dataclass(kw_only=True, frozen=True)
33
+ class ConfigArtifact:
34
+ name: tp.Literal["GRAPHVIZ", "PYTHON"]
35
+ uri: str
36
+
37
+
38
@dataclasses.dataclass(frozen=True, kw_only=True)
class ExperimentUnit:
    """Immutable base record for a unit of an experiment.

    ``jobs`` defaults to a fresh empty list per instance.
    """

    identity: str
    args: str | None = None
    jobs: list[SlurmJob] = dataclasses.field(default_factory=list)
43
+
44
+
45
+ @dataclasses.dataclass(kw_only=True, frozen=True)
46
+ class ExperimentUnitPatch:
47
+ identity: str | None = None
48
+ args: str | None = None
49
+
50
+
51
@dataclasses.dataclass(frozen=True, kw_only=True)
class WorkUnit(ExperimentUnit):
    """An :class:`ExperimentUnit` carrying a work-unit id and its artifacts.

    ``artifacts`` defaults to a fresh empty list per instance.
    """

    wid: int
    artifacts: list[Artifact] = dataclasses.field(default_factory=list)
55
+
56
+
57
@dataclasses.dataclass(frozen=True, kw_only=True)
class WorkUnitPatch(ExperimentUnitPatch):
    """An :class:`ExperimentUnitPatch` that additionally names the work unit by ``wid``."""

    wid: int
60
+
61
+
62
@dataclasses.dataclass(frozen=True, kw_only=True)
class Experiment:
    """Immutable snapshot of an experiment with its work units and artifacts.

    ``description``, ``note`` and ``tags`` have no default and must be passed
    explicitly (``None`` is allowed); the two list fields default to fresh
    empty lists.
    """

    title: str
    description: str | None
    note: str | None
    tags: list[str] | None

    work_units: list[WorkUnit] = dataclasses.field(default_factory=list)
    artifacts: list[Artifact] = dataclasses.field(default_factory=list)
@@ -0,0 +1,358 @@
1
+ import logging
2
+ import os
3
+ from contextlib import contextmanager
4
+ from pathlib import Path
5
+
6
+ from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
7
+ from sqlalchemy.ext.declarative import declarative_base
8
+ from sqlalchemy.orm import relationship, sessionmaker
9
+
10
+ from xm_slurm.api import models
11
+ from xm_slurm.api.abc import XManagerAPI
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ Base = declarative_base()
16
+
17
+
18
class ExperimentSqliteModel(Base):
    """ORM row of the ``experiments`` table.

    Deleting an experiment cascades to its work units and its experiment-level
    artifacts (``cascade="all, delete-orphan"`` on both relationships).
    """

    __tablename__ = "experiments"

    id = Column(Integer, primary_key=True)
    title = Column(String)
    description = Column(String)
    note = Column(String)
    # A single string column; XManagerSqliteAPI stores tags comma-joined.
    tags = Column(String)

    work_units = relationship(
        "WorkUnitSqliteModel",
        back_populates="experiment",
        cascade="all, delete-orphan",
    )
    artifacts = relationship(
        "ExperimentArtifactSqliteModel",
        back_populates="experiment",
        cascade="all, delete-orphan",
    )
32
+
33
+
34
class WorkUnitSqliteModel(Base):
    """ORM row of the ``work_units`` table.

    ``id`` is the table's own primary key; ``wid`` is a separate integer column
    that XManagerSqliteAPI uses, together with ``experiment_id``, to look a
    work unit up. Deleting a work unit cascades to its jobs and artifacts.
    """

    __tablename__ = "work_units"

    id = Column(Integer, primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    wid = Column(Integer)
    identity = Column(String)
    args = Column(String)

    experiment = relationship("ExperimentSqliteModel", back_populates="work_units")
    jobs = relationship(
        "SlurmJobSqliteModel",
        back_populates="work_unit",
        cascade="all, delete-orphan",
    )
    artifacts = relationship(
        "WorkUnitArtifactSqliteModel",
        back_populates="work_unit",
        cascade="all, delete-orphan",
    )
49
+
50
+
51
class SlurmJobSqliteModel(Base):
    """ORM row of the ``slurm_jobs`` table, owned by one work unit."""

    __tablename__ = "slurm_jobs"

    id = Column(Integer, primary_key=True)
    # References the work unit's primary key (work_units.id), not its ``wid``.
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))
    name = Column(String)
    # NOTE(review): declared Integer here, but models.SlurmJob.slurm_job_id is
    # annotated ``str`` -- confirm which type is intended.
    slurm_job_id = Column(Integer)
    slurm_ssh_config = Column(String)

    work_unit = relationship("WorkUnitSqliteModel", back_populates="jobs")
60
+
61
+
62
class ArtifactSqliteModel(Base):
    """Polymorphic base row of the ``artifacts`` table.

    The ``type`` column is the discriminator (``polymorphic_on``); subclasses
    declare their own ``polymorphic_identity``.
    """

    __tablename__ = "artifacts"
    __mapper_args__ = {"polymorphic_on": "type"}

    id = Column(Integer, primary_key=True)
    name = Column(String)
    uri = Column(String)
    type = Column(String)
70
+
71
+
72
class ExperimentArtifactSqliteModel(ArtifactSqliteModel):
    """Artifact attached to an experiment (identity ``"experiment"``).

    Stored in its own ``experiment_artifacts`` table whose primary key is also
    a foreign key to ``artifacts.id``.
    """

    __tablename__ = "experiment_artifacts"
    __mapper_args__ = {"polymorphic_identity": "experiment"}

    id = Column(Integer, ForeignKey("artifacts.id"), primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))

    experiment = relationship("ExperimentSqliteModel", back_populates="artifacts")
79
+
80
+
81
class ConfigArtifactSqliteModel(ExperimentArtifactSqliteModel):
    """Experiment artifact distinguished only by the ``"config"`` identity.

    Declares no table of its own, so it shares its parent's columns.
    """

    __mapper_args__ = {"polymorphic_identity": "config"}
83
+
84
+
85
class WorkUnitArtifactSqliteModel(ArtifactSqliteModel):
    """Artifact attached to a work unit (identity ``"work_unit"``).

    Stored in its own ``work_unit_artifacts`` table whose primary key is also a
    foreign key to ``artifacts.id``.
    """

    __tablename__ = "work_unit_artifacts"
    __mapper_args__ = {"polymorphic_identity": "work_unit"}

    id = Column(Integer, ForeignKey("artifacts.id"), primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    # References the work unit's primary key (work_units.id), not its ``wid``.
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))

    # NOTE(review): ExperimentSqliteModel.artifacts targets
    # ExperimentArtifactSqliteModel, yet this relationship also back-populates
    # it -- confirm that pairing is intended.
    experiment = relationship("ExperimentSqliteModel", back_populates="artifacts")
    work_unit = relationship("WorkUnitSqliteModel", back_populates="artifacts")
94
+
95
+
96
class XManagerSqliteAPI(XManagerAPI):
    """SQLite-backed :class:`XManagerAPI` implemented with SQLAlchemy sessions.

    The database file is ``db.sqlite3`` inside ``$XM_SLURM_STATE_DIR`` when that
    variable is set, otherwise inside ``~/.local/state/xm-slurm``. Every public
    method runs in its own transaction via :meth:`session_scope`.

    Throughout this class the ``work_unit_id`` parameter is the work unit's
    ``wid`` within the experiment, not the ``work_units.id`` primary key.
    """

    def __init__(self):
        """Locate the database, create missing directories/tables, bind a session factory."""
        if "XM_SLURM_STATE_DIR" in os.environ:
            db_path = Path(os.environ["XM_SLURM_STATE_DIR"]) / "db.sqlite3"
        else:
            db_path = Path.home() / ".local" / "state" / "xm-slurm" / "db.sqlite3"
        logger.debug("Looking for db at: %s", db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        engine = create_engine(f"sqlite:///{db_path}")
        Base.metadata.create_all(engine)  # type: ignore
        self.Session = sessionmaker(bind=engine)

    @contextmanager
    def session_scope(self):
        """Yield a session; commit on success, roll back and re-raise on error, always close."""
        session = self.Session()
        try:
            yield session
            session.commit()
        except BaseException:
            # Same reach as a bare ``except``: roll back even on
            # KeyboardInterrupt/GeneratorExit, then propagate.
            session.rollback()
            raise
        finally:
            session.close()

    @staticmethod
    def _find_work_unit(session, experiment_id: int, wid: int):
        """Return the work-unit row matching ``(experiment_id, wid)``, or ``None``."""
        return (
            session.query(WorkUnitSqliteModel)
            .filter(
                WorkUnitSqliteModel.experiment_id == experiment_id,
                WorkUnitSqliteModel.wid == wid,
            )
            .first()  # type: ignore
        )

    @classmethod
    def _require_work_unit(cls, session, experiment_id: int, wid: int):
        """Like :meth:`_find_work_unit` but raise ``ValueError`` when nothing matches."""
        work_unit = cls._find_work_unit(session, experiment_id, wid)
        if work_unit is None:
            raise ValueError(f"Work unit with id {wid} not found in experiment {experiment_id}")
        return work_unit

    def get_experiment(self, xid: int) -> models.Experiment:
        """Load experiment ``xid`` with its work units, jobs and artifacts.

        Raises:
            ValueError: If no experiment with that id exists.
        """
        with self.session_scope() as session:
            experiment = (
                session.query(ExperimentSqliteModel)
                .filter(ExperimentSqliteModel.id == xid)
                .first()  # type: ignore
            )
            if not experiment:
                raise ValueError(f"Experiment with id {xid} not found")

            work_units = []
            for wu in experiment.work_units:
                jobs = [
                    models.SlurmJob(
                        name=job.name,
                        # The column is Integer while the model field is ``str``;
                        # normalise so callers always receive the declared type.
                        slurm_job_id=(
                            None if job.slurm_job_id is None else str(job.slurm_job_id)
                        ),  # type: ignore
                        slurm_ssh_config=job.slurm_ssh_config,
                    )
                    for job in wu.jobs
                ]
                wu_artifacts = [
                    models.Artifact(name=artifact.name, uri=artifact.uri)
                    for artifact in wu.artifacts
                ]
                work_units.append(
                    models.WorkUnit(
                        wid=wu.wid,
                        identity=wu.identity,
                        args=wu.args,
                        jobs=jobs,
                        artifacts=wu_artifacts,
                    )
                )

            # ``experiment.artifacts`` already contains config artifacts:
            # ConfigArtifactSqliteModel subclasses ExperimentArtifactSqliteModel.
            # The model defines no separate ``configs`` relationship to read.
            artifacts = [
                models.Artifact(name=artifact.name, uri=artifact.uri)
                for artifact in experiment.artifacts
            ]

            return models.Experiment(
                title=experiment.title,
                description=experiment.description,
                note=experiment.note,
                tags=experiment.tags.split(",") if experiment.tags else None,
                work_units=work_units,
                artifacts=artifacts,
            )

    def delete_experiment(self, experiment_id: int) -> None:
        """Delete experiment ``experiment_id``; a missing experiment is a no-op."""
        with self.session_scope() as session:
            experiment = (
                session.query(ExperimentSqliteModel)
                .filter(ExperimentSqliteModel.id == experiment_id)
                .first()  # type: ignore
            )
            if experiment:
                session.delete(experiment)

    def insert_experiment(self, experiment: models.ExperimentPatch) -> int:
        """Insert a new experiment row and return its primary key."""
        with self.session_scope() as session:
            new_experiment = ExperimentSqliteModel(
                title=experiment.title,  # type: ignore
                description=experiment.description,  # type: ignore
                note=experiment.note,  # type: ignore
                tags=",".join(experiment.tags) if experiment.tags else None,  # type: ignore
            )
            session.add(new_experiment)
            # Flush so the database assigns ``id`` before the scope commits.
            session.flush()
            return new_experiment.id

    def update_experiment(
        self, experiment_id: int, experiment_patch: models.ExperimentPatch
    ) -> None:
        """Copy every non-``None`` patch field onto the experiment.

        A missing experiment is silently ignored.
        """
        with self.session_scope() as session:
            experiment = (
                session.query(ExperimentSqliteModel)
                .filter(ExperimentSqliteModel.id == experiment_id)
                .first()  # type: ignore
            )
            if experiment:
                if experiment_patch.title is not None:
                    experiment.title = experiment_patch.title
                if experiment_patch.description is not None:
                    experiment.description = experiment_patch.description
                if experiment_patch.note is not None:
                    experiment.note = experiment_patch.note
                if experiment_patch.tags is not None:
                    experiment.tags = ",".join(experiment_patch.tags)

    def insert_job(self, experiment_id: int, work_unit_id: int, job: models.SlurmJob) -> None:
        """Attach ``job`` to the given work unit.

        Raises:
            ValueError: If the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._require_work_unit(session, experiment_id, work_unit_id)
            session.add(
                SlurmJobSqliteModel(
                    work_unit_id=work_unit.id,  # type: ignore
                    name=job.name,  # type: ignore
                    slurm_job_id=job.slurm_job_id,  # type: ignore
                    slurm_ssh_config=job.slurm_ssh_config,  # type: ignore
                )
            )

    def insert_work_unit(self, experiment_id: int, work_unit: models.WorkUnitPatch) -> None:
        """Insert a work-unit row for ``experiment_id`` from the patch fields."""
        with self.session_scope() as session:
            session.add(
                WorkUnitSqliteModel(
                    experiment_id=experiment_id,  # type: ignore
                    wid=work_unit.wid,  # type: ignore
                    identity=work_unit.identity,  # type: ignore
                    args=work_unit.args,  # type: ignore
                )
            )

    def update_work_unit(
        self, experiment_id: int, work_unit_id: int, patch: models.ExperimentUnitPatch
    ) -> None:
        """Copy every non-``None`` patch field onto the work unit.

        Raises:
            ValueError: If the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._require_work_unit(session, experiment_id, work_unit_id)
            if patch.identity is not None:
                work_unit.identity = patch.identity
            if patch.args is not None:
                work_unit.args = patch.args

    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Delete the artifact ``name`` of the given work unit; missing rows are a no-op."""
        with self.session_scope() as session:
            # Artifacts are stored against the work unit's primary key, so
            # translate the caller's ``wid`` first, as
            # insert_work_unit_artifact does.
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                return
            artifact = (
                session.query(WorkUnitArtifactSqliteModel)
                .filter(
                    WorkUnitArtifactSqliteModel.experiment_id == experiment_id,
                    WorkUnitArtifactSqliteModel.work_unit_id == work_unit.id,
                    WorkUnitArtifactSqliteModel.name == name,
                )
                .first()  # type: ignore
            )
            if artifact:
                session.delete(artifact)

    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: models.Artifact
    ) -> None:
        """Attach ``artifact`` to the given work unit.

        Raises:
            ValueError: If the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._require_work_unit(session, experiment_id, work_unit_id)
            session.add(
                WorkUnitArtifactSqliteModel(
                    experiment_id=experiment_id,  # type: ignore
                    work_unit_id=work_unit.id,  # type: ignore
                    name=artifact.name,  # type: ignore
                    uri=artifact.uri,  # type: ignore
                )
            )

    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Delete the experiment-level artifact ``name`` (plain or config); no-op if absent."""
        with self.session_scope() as session:
            # Querying ExperimentArtifactSqliteModel covers both the
            # "experiment" and "config" identities, because the config model
            # is a subclass of it.
            artifact = (
                session.query(ExperimentArtifactSqliteModel)
                .filter(
                    ExperimentArtifactSqliteModel.experiment_id == experiment_id,
                    ExperimentArtifactSqliteModel.name == name,
                )
                .first()  # type: ignore
            )
            if artifact:
                session.delete(artifact)

    def insert_experiment_artifact(self, experiment_id: int, artifact: models.Artifact) -> None:
        """Attach ``artifact`` to the experiment.

        Artifacts named ``"PYTHON"`` or ``"GRAPHVIZ"`` are stored as config
        artifacts; everything else is a plain experiment artifact.
        """
        with self.session_scope() as session:
            if artifact.name in ("PYTHON", "GRAPHVIZ"):
                model_cls = ConfigArtifactSqliteModel
            else:
                model_cls = ExperimentArtifactSqliteModel
            session.add(
                model_cls(
                    experiment_id=experiment_id,  # type: ignore
                    name=artifact.name,  # type: ignore
                    uri=artifact.uri,  # type: ignore
                )
            )

    def insert_experiment_config_artifact(
        self, experiment_id: int, artifact: models.ConfigArtifact
    ) -> None:
        """Attach ``artifact`` to the experiment as a config artifact."""
        with self.session_scope() as session:
            session.add(
                ConfigArtifactSqliteModel(
                    experiment_id=experiment_id,  # type: ignore
                    name=artifact.name,  # type: ignore
                    uri=artifact.uri,  # type: ignore
                )
            )

    def delete_experiment_config_artifact(self, experiment_id: int, name: str) -> None:
        """Delete the config artifact ``name`` of the experiment; no-op if absent."""
        with self.session_scope() as session:
            artifact = (
                session.query(ConfigArtifactSqliteModel)
                .filter(
                    ConfigArtifactSqliteModel.experiment_id == experiment_id,
                    ConfigArtifactSqliteModel.name == name,
                )
                .first()  # type: ignore
            )
            if artifact:
                session.delete(artifact)