xmanager-slurm 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic. Click here for more details.
- xm_slurm/__init__.py +0 -2
- xm_slurm/api/__init__.py +33 -0
- xm_slurm/api/abc.py +65 -0
- xm_slurm/api/models.py +70 -0
- xm_slurm/api/sqlite/client.py +358 -0
- xm_slurm/api/web/client.py +173 -0
- xm_slurm/config.py +11 -3
- xm_slurm/contrib/clusters/__init__.py +3 -6
- xm_slurm/contrib/clusters/drac.py +4 -3
- xm_slurm/executables.py +4 -7
- xm_slurm/execution.py +290 -159
- xm_slurm/experiment.py +26 -180
- xm_slurm/filesystem.py +129 -0
- xm_slurm/metadata_context.py +253 -0
- xm_slurm/packageables.py +0 -9
- xm_slurm/packaging/docker.py +72 -22
- xm_slurm/packaging/utils.py +0 -108
- xm_slurm/scripts/cli.py +9 -2
- xm_slurm/templates/docker/uv.Dockerfile +6 -3
- xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
- xm_slurm/templates/slurm/job-array.bash.j2 +4 -4
- xm_slurm/templates/slurm/job-group.bash.j2 +2 -2
- xm_slurm/templates/slurm/job.bash.j2 +5 -4
- xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +18 -54
- xm_slurm/templates/slurm/runtimes/podman.bash.j2 +9 -24
- xm_slurm/utils.py +122 -41
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/METADATA +7 -3
- xmanager_slurm-0.4.7.dist-info/RECORD +51 -0
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/WHEEL +1 -1
- xm_slurm/api.py +0 -528
- xmanager_slurm-0.4.5.dist-info/RECORD +0 -44
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/entry_points.txt +0 -0
- {xmanager_slurm-0.4.5.dist-info → xmanager_slurm-0.4.7.dist-info}/licenses/LICENSE.md +0 -0
xm_slurm/__init__.py
CHANGED
|
@@ -3,7 +3,6 @@ import logging
|
|
|
3
3
|
from xm_slurm.executables import Dockerfile, DockerImage
|
|
4
4
|
from xm_slurm.executors import Slurm, SlurmSpec
|
|
5
5
|
from xm_slurm.experiment import (
|
|
6
|
-
Artifact,
|
|
7
6
|
SlurmExperiment,
|
|
8
7
|
create_experiment,
|
|
9
8
|
get_current_experiment,
|
|
@@ -25,7 +24,6 @@ logging.getLogger("asyncssh").setLevel(logging.WARN)
|
|
|
25
24
|
logging.getLogger("httpx").setLevel(logging.WARN)
|
|
26
25
|
|
|
27
26
|
__all__ = [
|
|
28
|
-
"Artifact",
|
|
29
27
|
"conda_container",
|
|
30
28
|
"create_experiment",
|
|
31
29
|
"docker_container",
|
xm_slurm/api/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from xm_slurm.api import models
|
|
6
|
+
from xm_slurm.api.abc import XManagerAPI
|
|
7
|
+
from xm_slurm.api.sqlite import client as sqlite_client
|
|
8
|
+
from xm_slurm.api.web import client as web_client
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)


@functools.cache
def client() -> XManagerAPI:
    """Build (once per process) the XManager metadata API client.

    The backend is selected by ``XM_SLURM_API_BACKEND`` (case-insensitive,
    default ``"sqlite"``):

    * ``"sqlite"`` -- ``sqlite_client.XManagerSqliteAPI()``.
    * ``"rest"``   -- ``web_client.XManagerWebAPI``; requires both
      ``XM_SLURM_REST_API_BASE_URL`` and ``XM_SLURM_REST_API_TOKEN``.

    A successful result is memoized by ``functools.cache``, so later calls
    return the same client without re-reading the environment.

    Raises:
        ValueError: if a required REST variable is missing or the backend name
            is not recognised.
    """
    backend = os.environ.get("XM_SLURM_API_BACKEND", "sqlite").lower()

    if backend == "sqlite":
        return sqlite_client.XManagerSqliteAPI()

    if backend != "rest":
        raise ValueError(f"Invalid XM_SLURM_API_BACKEND: {backend}")

    # REST backend: validate both settings before constructing the client.
    if "XM_SLURM_REST_API_BASE_URL" not in os.environ:
        raise ValueError("XM_SLURM_REST_API_BASE_URL is not set")
    if "XM_SLURM_REST_API_TOKEN" not in os.environ:
        raise ValueError("XM_SLURM_REST_API_TOKEN is not set")

    return web_client.XManagerWebAPI(
        base_url=os.environ["XM_SLURM_REST_API_BASE_URL"],
        token=os.environ["XM_SLURM_REST_API_TOKEN"],
    )


__all__ = ["client", "XManagerAPI", "models"]
|
xm_slurm/api/abc.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
|
|
3
|
+
from xm_slurm.api import models
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class XManagerAPI(abc.ABC):
    """Abstract storage interface for experiment / work-unit metadata.

    Concrete backends (e.g. the SQLite and REST clients selected by
    ``xm_slurm.api.client()``) implement every method below.
    """

    @abc.abstractmethod
    def get_experiment(self, xid: int) -> models.Experiment:
        """Return the experiment ``xid`` with its work units and artifacts."""

    @abc.abstractmethod
    def delete_experiment(self, experiment_id: int) -> None:
        """Delete experiment ``experiment_id``."""

    @abc.abstractmethod
    def insert_experiment(self, experiment: models.ExperimentPatch) -> int:
        """Create a new experiment and return its id."""

    @abc.abstractmethod
    def update_experiment(
        self, experiment_id: int, experiment_patch: models.ExperimentPatch
    ) -> None:
        """Apply ``experiment_patch`` to experiment ``experiment_id``."""

    @abc.abstractmethod
    def insert_job(self, experiment_id: int, work_unit_id: int, job: models.SlurmJob) -> None:
        """Record a Slurm job for work unit ``work_unit_id`` of an experiment."""

    @abc.abstractmethod
    def insert_work_unit(self, experiment_id: int, work_unit: models.WorkUnitPatch) -> None:
        """Add a work unit to experiment ``experiment_id``."""

    @abc.abstractmethod
    def update_work_unit(
        self, experiment_id: int, work_unit_id: int, patch: models.ExperimentUnitPatch
    ) -> None:
        """Apply ``patch`` to work unit ``work_unit_id`` of an experiment."""

    @abc.abstractmethod
    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Remove the artifact called ``name`` from a work unit."""

    @abc.abstractmethod
    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: models.Artifact
    ) -> None:
        """Attach ``artifact`` to a work unit."""

    @abc.abstractmethod
    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Remove the experiment-level artifact called ``name``."""

    @abc.abstractmethod
    def insert_experiment_artifact(self, experiment_id: int, artifact: models.Artifact) -> None:
        """Attach ``artifact`` to experiment ``experiment_id``."""

    @abc.abstractmethod
    def insert_experiment_config_artifact(
        self, experiment_id: int, artifact: models.ConfigArtifact
    ) -> None:
        """Attach a config artifact (``"GRAPHVIZ"`` / ``"PYTHON"``) to an experiment."""

    @abc.abstractmethod
    def delete_experiment_config_artifact(self, experiment_id: int, name: str) -> None:
        """Remove the config artifact called ``name`` from an experiment."""
|
xm_slurm/api/models.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import enum
|
|
3
|
+
import typing as tp
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExperimentUnitRole(enum.Enum):
    """Role of an experiment unit: a regular work unit or an auxiliary unit."""

    # Explicit values equal to what ``enum.auto()`` assigns (starting at 1).
    WORK_UNIT = 1
    AUX_UNIT = 2
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
12
|
+
class ExperimentPatch:
|
|
13
|
+
title: str | None = None
|
|
14
|
+
description: str | None = None
|
|
15
|
+
note: str | None = None
|
|
16
|
+
tags: list[str] | None = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
20
|
+
class SlurmJob:
|
|
21
|
+
name: str
|
|
22
|
+
slurm_job_id: str
|
|
23
|
+
slurm_ssh_config: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
27
|
+
class Artifact:
|
|
28
|
+
name: str
|
|
29
|
+
uri: str
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
33
|
+
class ConfigArtifact:
|
|
34
|
+
name: tp.Literal["GRAPHVIZ", "PYTHON"]
|
|
35
|
+
uri: str
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class ExperimentUnit:
    """Common fields for a unit of an experiment and the Slurm jobs it launched."""

    identity: str
    args: str | None = dataclasses.field(default=None)
    # A fresh list per instance (never a shared mutable default).
    jobs: list[SlurmJob] = dataclasses.field(default_factory=list)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
46
|
+
class ExperimentUnitPatch:
|
|
47
|
+
identity: str | None = None
|
|
48
|
+
args: str | None = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class WorkUnit(ExperimentUnit):
    """An experiment unit addressed by its work-unit number ``wid``, plus its artifacts."""

    # Required even though it follows defaulted base fields; allowed by kw_only.
    wid: int
    artifacts: list[Artifact] = dataclasses.field(default_factory=list)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class WorkUnitPatch(ExperimentUnitPatch):
    """Experiment-unit patch that also names the work unit via its required ``wid``."""

    wid: int
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class Experiment:
    """Full experiment record as returned by ``XManagerAPI.get_experiment``."""

    # Descriptive metadata; all but ``title`` may be absent (None) but must be passed.
    title: str
    description: str | None
    note: str | None
    tags: list[str] | None

    # Child records; each instance gets its own list.
    work_units: list[WorkUnit] = dataclasses.field(default_factory=list)
    artifacts: list[Artifact] = dataclasses.field(default_factory=list)
|
|
xm_slurm/api/sqlite/client.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
|
|
7
|
+
from sqlalchemy.ext.declarative import declarative_base
|
|
8
|
+
from sqlalchemy.orm import relationship, sessionmaker
|
|
9
|
+
|
|
10
|
+
from xm_slurm.api import models
|
|
11
|
+
from xm_slurm.api.abc import XManagerAPI
|
|
12
|
+
|
|
13
|
+
# Module logger for the SQLite backend.
logger = logging.getLogger(name=__name__)

# Declarative base shared by every ORM model below; its metadata is used by
# XManagerSqliteAPI to create the tables.
Base = declarative_base()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ExperimentSqliteModel(Base):
    """ORM row for an experiment (table ``experiments``).

    Deleting an experiment cascades to its work units and to its
    experiment-level artifacts (including config artifacts).
    """

    __tablename__ = "experiments"

    id = Column(Integer, primary_key=True)
    title = Column(String)
    description = Column(String)
    note = Column(String)
    # Tag list stored as a single comma-joined string.
    tags = Column(String)

    work_units = relationship(
        "WorkUnitSqliteModel",
        back_populates="experiment",
        cascade="all, delete-orphan",
    )
    artifacts = relationship(
        "ExperimentArtifactSqliteModel",
        back_populates="experiment",
        cascade="all, delete-orphan",
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class WorkUnitSqliteModel(Base):
    """ORM row for a work unit (table ``work_units``).

    ``id`` is the surrogate primary key; ``wid`` is the work-unit number that
    the API uses to address the unit within its experiment. Deleting a work
    unit cascades to its Slurm jobs and its artifacts.
    """

    __tablename__ = "work_units"

    id = Column(Integer, primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    wid = Column(Integer)
    identity = Column(String)
    args = Column(String)

    experiment = relationship("ExperimentSqliteModel", back_populates="work_units")
    jobs = relationship(
        "SlurmJobSqliteModel",
        back_populates="work_unit",
        cascade="all, delete-orphan",
    )
    artifacts = relationship(
        "WorkUnitArtifactSqliteModel",
        back_populates="work_unit",
        cascade="all, delete-orphan",
    )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SlurmJobSqliteModel(Base):
    """ORM row for a Slurm job launched by a work unit (table ``slurm_jobs``)."""

    __tablename__ = "slurm_jobs"

    id = Column(Integer, primary_key=True)
    # References the work unit's surrogate key (work_units.id), not its wid.
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))
    name = Column(String)
    # NOTE(review): declared Integer although models.SlurmJob.slurm_job_id is a
    # str — confirm this is intended (e.g. for array-task ids like "123_4").
    slurm_job_id = Column(Integer)
    slurm_ssh_config = Column(String)

    work_unit = relationship("WorkUnitSqliteModel", back_populates="jobs")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ArtifactSqliteModel(Base):
    """Base ORM row shared by all artifact kinds (table ``artifacts``).

    ``type`` is the polymorphic discriminator; the subclasses in this module
    use the identities ``"experiment"``, ``"config"`` and ``"work_unit"``.
    """

    __tablename__ = "artifacts"
    __mapper_args__ = {"polymorphic_on": "type"}

    id = Column(Integer, primary_key=True)
    name = Column(String)
    uri = Column(String)
    type = Column(String)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ExperimentArtifactSqliteModel(ArtifactSqliteModel):
    """Experiment-level artifact; extra columns live in ``experiment_artifacts``.

    Its primary key is also a foreign key to ``artifacts.id`` (joined-table
    inheritance from ArtifactSqliteModel).
    """

    __tablename__ = "experiment_artifacts"
    __mapper_args__ = {"polymorphic_identity": "experiment"}

    id = Column(Integer, ForeignKey("artifacts.id"), primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))

    experiment = relationship("ExperimentSqliteModel", back_populates="artifacts")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ConfigArtifactSqliteModel(ExperimentArtifactSqliteModel):
    """Experiment config artifact (e.g. ``"PYTHON"`` / ``"GRAPHVIZ"``).

    It declares no table of its own, so its rows share the parent's tables and
    are told apart only by ``type == "config"``.
    """

    __mapper_args__ = {"polymorphic_identity": "config"}
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class WorkUnitArtifactSqliteModel(ArtifactSqliteModel):
    """Artifact attached to a work unit; extra columns live in ``work_unit_artifacts``.

    ``work_unit_id`` references the work unit's surrogate key (``work_units.id``),
    not its ``wid``. ``experiment_id`` is denormalized for convenience.
    """

    __tablename__ = "work_unit_artifacts"
    __mapper_args__ = {"polymorphic_identity": "work_unit"}

    id = Column(Integer, ForeignKey("artifacts.id"), primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))

    # One-way many-to-one. It must NOT back-populate ExperimentSqliteModel.artifacts:
    # that collection holds ExperimentArtifactSqliteModel rows, so linking them
    # would push work-unit artifacts into the experiment-artifact collection.
    experiment = relationship("ExperimentSqliteModel")
    work_unit = relationship("WorkUnitSqliteModel", back_populates="artifacts")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class XManagerSqliteAPI(XManagerAPI):
    """SQLite-backed implementation of ``XManagerAPI``.

    Database location:
      * ``$XM_SLURM_STATE_DIR/db.sqlite3`` when ``XM_SLURM_STATE_DIR`` is set;
      * ``~/.local/state/xm-slurm/db.sqlite3`` otherwise.

    Construction creates the parent directory and any missing tables. Every
    public method runs in its own session via ``session_scope``, so each call is
    committed (or rolled back) on its own.

    ``work_unit_id`` parameters are the work unit's ``wid`` within its
    experiment, not the ``work_units.id`` primary key.
    """

    def __init__(self):
        if "XM_SLURM_STATE_DIR" in os.environ:
            db_path = Path(os.environ["XM_SLURM_STATE_DIR"]) / "db.sqlite3"
        else:
            db_path = Path.home() / ".local" / "state" / "xm-slurm" / "db.sqlite3"
        logger.debug("Looking for db at: %s", db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        engine = create_engine(f"sqlite:///{db_path}")
        Base.metadata.create_all(engine)  # type: ignore
        self.Session = sessionmaker(bind=engine)

    @contextmanager
    def session_scope(self):
        """Yield a session; commit on normal exit, roll back on any exception, always close."""
        session = self.Session()
        try:
            yield session
            session.commit()
        except BaseException:
            # Same scope as a bare ``except:`` (also covers KeyboardInterrupt /
            # SystemExit) but explicit about it.
            session.rollback()
            raise
        finally:
            session.close()

    def _get_work_unit(self, session, experiment_id: int, work_unit_id: int):
        """Return the work unit with ``wid == work_unit_id`` in the experiment, or None."""
        return (
            session.query(WorkUnitSqliteModel)
            .filter(
                WorkUnitSqliteModel.experiment_id == experiment_id,
                WorkUnitSqliteModel.wid == work_unit_id,
            )
            .first()  # type: ignore
        )

    def get_experiment(self, xid: int) -> models.Experiment:
        """Load experiment ``xid`` with its work units, jobs and artifacts.

        Raises:
            ValueError: if no experiment with that id exists.
        """
        with self.session_scope() as session:
            experiment = (
                session.query(ExperimentSqliteModel).filter(ExperimentSqliteModel.id == xid).first()  # type: ignore
            )
            if not experiment:
                raise ValueError(f"Experiment with id {xid} not found")

            work_units = [
                models.WorkUnit(
                    wid=wu.wid,
                    identity=wu.identity,
                    args=wu.args,
                    jobs=[
                        models.SlurmJob(
                            name=job.name,
                            # The column is declared Integer, but the model field is
                            # a str; normalize whatever SQLite hands back.
                            slurm_job_id=str(job.slurm_job_id),
                            slurm_ssh_config=job.slurm_ssh_config,
                        )
                        for job in wu.jobs
                    ],
                    artifacts=[
                        models.Artifact(name=artifact.name, uri=artifact.uri)
                        for artifact in wu.artifacts
                    ],
                )
                for wu in experiment.work_units
            ]

            # ``experiment.artifacts`` targets ExperimentArtifactSqliteModel, whose
            # polymorphic load already includes ConfigArtifactSqliteModel rows.
            # (The previous code also read a non-existent ``experiment.configs``,
            # which raised AttributeError on every call.)
            artifacts = [
                models.Artifact(name=artifact.name, uri=artifact.uri)
                for artifact in experiment.artifacts
            ]

            return models.Experiment(
                title=experiment.title,
                description=experiment.description,
                note=experiment.note,
                tags=experiment.tags.split(",") if experiment.tags else None,
                work_units=work_units,
                artifacts=artifacts,
            )

    def delete_experiment(self, experiment_id: int) -> None:
        """Delete the experiment (and, via ORM cascades, its children); no-op if absent."""
        with self.session_scope() as session:
            experiment = (
                session.query(ExperimentSqliteModel)
                .filter(ExperimentSqliteModel.id == experiment_id)
                .first()  # type: ignore
            )
            if experiment:
                session.delete(experiment)

    def insert_experiment(self, experiment: models.ExperimentPatch) -> int:
        """Create an experiment from ``experiment`` and return its new primary key."""
        with self.session_scope() as session:
            new_experiment = ExperimentSqliteModel(
                title=experiment.title,  # type: ignore
                description=experiment.description,  # type: ignore
                note=experiment.note,  # type: ignore
                tags=",".join(experiment.tags) if experiment.tags else None,  # type: ignore
            )
            session.add(new_experiment)
            # Flush so the autoincrement id is assigned before we return it.
            session.flush()
            return new_experiment.id

    def update_experiment(
        self, experiment_id: int, experiment_patch: models.ExperimentPatch
    ) -> None:
        """Overwrite the non-None fields of ``experiment_patch``.

        An unknown ``experiment_id`` is silently ignored (unlike
        ``update_work_unit``, which raises).
        """
        with self.session_scope() as session:
            experiment = (
                session.query(ExperimentSqliteModel)
                .filter(ExperimentSqliteModel.id == experiment_id)
                .first()  # type: ignore
            )
            if experiment:
                if experiment_patch.title is not None:
                    experiment.title = experiment_patch.title
                if experiment_patch.description is not None:
                    experiment.description = experiment_patch.description
                if experiment_patch.note is not None:
                    experiment.note = experiment_patch.note
                if experiment_patch.tags is not None:
                    experiment.tags = ",".join(experiment_patch.tags)

    def insert_job(self, experiment_id: int, work_unit_id: int, job: models.SlurmJob) -> None:
        """Record ``job`` for the work unit with ``wid == work_unit_id``.

        Raises:
            ValueError: if that work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._get_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )
            new_job = SlurmJobSqliteModel(
                work_unit_id=work_unit.id,  # type: ignore
                name=job.name,  # type: ignore
                slurm_job_id=job.slurm_job_id,  # type: ignore
                slurm_ssh_config=job.slurm_ssh_config,  # type: ignore
            )
            session.add(new_job)

    def insert_work_unit(self, experiment_id: int, work_unit: models.WorkUnitPatch) -> None:
        """Add a work unit row for ``work_unit.wid`` to the experiment."""
        with self.session_scope() as session:
            new_work_unit = WorkUnitSqliteModel(
                experiment_id=experiment_id,  # type: ignore
                wid=work_unit.wid,  # type: ignore
                identity=work_unit.identity,  # type: ignore
                args=work_unit.args,  # type: ignore
            )
            session.add(new_work_unit)

    def update_work_unit(
        self, experiment_id: int, work_unit_id: int, patch: models.ExperimentUnitPatch
    ) -> None:
        """Overwrite the non-None fields of ``patch`` on the work unit.

        Raises:
            ValueError: if the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._get_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )
            if patch.identity is not None:
                work_unit.identity = patch.identity
            if patch.args is not None:
                work_unit.args = patch.args

    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Delete the work-unit artifact called ``name``; no-op if it does not exist."""
        with self.session_scope() as session:
            # ``work_unit_artifacts.work_unit_id`` stores ``work_units.id`` (see
            # insert_work_unit_artifact), while ``work_unit_id`` here is the wid,
            # so resolve the work unit first instead of comparing them directly.
            work_unit = self._get_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                return
            artifact = (
                session.query(WorkUnitArtifactSqliteModel)
                .filter(
                    WorkUnitArtifactSqliteModel.experiment_id == experiment_id,
                    WorkUnitArtifactSqliteModel.work_unit_id == work_unit.id,
                    WorkUnitArtifactSqliteModel.name == name,
                )
                .first()  # type: ignore
            )
            if artifact:
                session.delete(artifact)

    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: models.Artifact
    ) -> None:
        """Attach ``artifact`` to the work unit with ``wid == work_unit_id``.

        Raises:
            ValueError: if the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._get_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )
            new_artifact = WorkUnitArtifactSqliteModel(
                experiment_id=experiment_id,  # type: ignore
                work_unit_id=work_unit.id,  # type: ignore
                name=artifact.name,  # type: ignore
                uri=artifact.uri,  # type: ignore
            )
            session.add(new_artifact)

    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Delete an experiment-level artifact (regular or config) called ``name``."""
        with self.session_scope() as session:
            # Querying the joined-table class directly also returns its
            # single-table subclass (config artifacts), so no hand-written join
            # of the base table onto itself is needed.
            artifact = (
                session.query(ExperimentArtifactSqliteModel)
                .filter(
                    ExperimentArtifactSqliteModel.experiment_id == experiment_id,
                    ExperimentArtifactSqliteModel.name == name,
                )
                .first()  # type: ignore
            )
            if artifact:
                session.delete(artifact)

    def insert_experiment_artifact(self, experiment_id: int, artifact: models.Artifact) -> None:
        """Attach ``artifact`` to the experiment.

        Artifacts named ``"PYTHON"`` or ``"GRAPHVIZ"`` are stored as config artifacts.
        """
        with self.session_scope() as session:
            if artifact.name in ("PYTHON", "GRAPHVIZ"):
                model_cls = ConfigArtifactSqliteModel
            else:
                model_cls = ExperimentArtifactSqliteModel
            new_artifact = model_cls(
                experiment_id=experiment_id,  # type: ignore
                name=artifact.name,  # type: ignore
                uri=artifact.uri,  # type: ignore
            )
            session.add(new_artifact)

    def insert_experiment_config_artifact(
        self, experiment_id: int, artifact: models.ConfigArtifact
    ) -> None:
        """Attach a config artifact to the experiment."""
        with self.session_scope() as session:
            new_artifact = ConfigArtifactSqliteModel(
                experiment_id=experiment_id,  # type: ignore
                name=artifact.name,  # type: ignore
                uri=artifact.uri,  # type: ignore
            )
            session.add(new_artifact)

    def delete_experiment_config_artifact(self, experiment_id: int, name: str) -> None:
        """Delete the config artifact called ``name``; no-op if it does not exist."""
        with self.session_scope() as session:
            artifact = (
                session.query(ConfigArtifactSqliteModel)
                .filter(
                    ConfigArtifactSqliteModel.experiment_id == experiment_id,
                    ConfigArtifactSqliteModel.name == name,
                )
                .first()  # type: ignore
            )
            if artifact:
                session.delete(artifact)
|