xmanager-slurm 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic. Click here for more details.
- xm_slurm/__init__.py +4 -2
- xm_slurm/api.py +301 -34
- xm_slurm/batching.py +4 -4
- xm_slurm/config.py +99 -54
- xm_slurm/constants.py +15 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +22 -13
- xm_slurm/contrib/clusters/drac.py +34 -16
- xm_slurm/executables.py +19 -7
- xm_slurm/execution.py +86 -38
- xm_slurm/experiment.py +273 -131
- xm_slurm/experimental/parameter_controller.py +200 -0
- xm_slurm/job_blocks.py +7 -0
- xm_slurm/packageables.py +45 -18
- xm_slurm/packaging/docker/__init__.py +5 -11
- xm_slurm/packaging/docker/local.py +13 -12
- xm_slurm/packaging/utils.py +7 -55
- xm_slurm/resources.py +28 -4
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/status.py +9 -0
- xm_slurm/templates/docker/docker-bake.hcl.j2 +7 -0
- xm_slurm/templates/docker/mamba.Dockerfile +3 -1
- xm_slurm/templates/docker/python.Dockerfile +18 -10
- xm_slurm/templates/docker/uv.Dockerfile +35 -0
- xm_slurm/utils.py +18 -10
- xmanager_slurm-0.4.0.dist-info/METADATA +26 -0
- xmanager_slurm-0.4.0.dist-info/RECORD +42 -0
- {xmanager_slurm-0.3.1.dist-info → xmanager_slurm-0.4.0.dist-info}/WHEEL +1 -1
- xmanager_slurm-0.4.0.dist-info/licenses/LICENSE.md +227 -0
- xm_slurm/packaging/docker/cloud.py +0 -503
- xm_slurm/templates/docker/pdm.Dockerfile +0 -31
- xmanager_slurm-0.3.1.dist-info/METADATA +0 -25
- xmanager_slurm-0.3.1.dist-info/RECORD +0 -38
xm_slurm/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ from xm_slurm.executables import Dockerfile, DockerImage
|
|
|
4
4
|
from xm_slurm.executors import Slurm, SlurmSpec
|
|
5
5
|
from xm_slurm.experiment import (
|
|
6
6
|
Artifact,
|
|
7
|
+
SlurmExperiment,
|
|
7
8
|
create_experiment,
|
|
8
9
|
get_current_experiment,
|
|
9
10
|
get_current_work_unit,
|
|
@@ -14,8 +15,8 @@ from xm_slurm.packageables import (
|
|
|
14
15
|
docker_container,
|
|
15
16
|
docker_image,
|
|
16
17
|
mamba_container,
|
|
17
|
-
pdm_container,
|
|
18
18
|
python_container,
|
|
19
|
+
uv_container,
|
|
19
20
|
)
|
|
20
21
|
from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType
|
|
21
22
|
|
|
@@ -35,10 +36,11 @@ __all__ = [
|
|
|
35
36
|
"get_experiment",
|
|
36
37
|
"JobRequirements",
|
|
37
38
|
"mamba_container",
|
|
38
|
-
"
|
|
39
|
+
"uv_container",
|
|
39
40
|
"python_container",
|
|
40
41
|
"ResourceQuantity",
|
|
41
42
|
"ResourceType",
|
|
42
43
|
"Slurm",
|
|
43
44
|
"SlurmSpec",
|
|
45
|
+
"SlurmExperiment",
|
|
44
46
|
]
|
xm_slurm/api.py
CHANGED
|
@@ -1,15 +1,27 @@
|
|
|
1
1
|
import dataclasses
|
|
2
|
+
import enum
|
|
2
3
|
import functools
|
|
3
4
|
import importlib.util
|
|
4
5
|
import logging
|
|
5
6
|
import os
|
|
6
|
-
import time
|
|
7
7
|
import typing
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from contextlib import contextmanager
|
|
10
|
+
from pathlib import Path
|
|
8
11
|
from typing import Any
|
|
9
12
|
|
|
13
|
+
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
|
|
14
|
+
from sqlalchemy.ext.declarative import declarative_base
|
|
15
|
+
from sqlalchemy.orm import relationship, sessionmaker
|
|
16
|
+
|
|
10
17
|
logger = logging.getLogger(__name__)
|
|
11
18
|
|
|
12
19
|
|
|
20
|
+
class ExperimentUnitRole(enum.Enum):
    """Enumerates the roles of an experiment unit.

    Members are `WORK_UNIT` and `AUX_UNIT` (presumably "auxiliary unit" —
    confirm against the callers).
    """

    # Member values are generated by `enum.auto()` in declaration order.
    WORK_UNIT = enum.auto()
    AUX_UNIT = enum.auto()
|
|
23
|
+
|
|
24
|
+
|
|
13
25
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
14
26
|
class ExperimentPatchModel:
|
|
15
27
|
title: str | None = None
|
|
@@ -22,7 +34,7 @@ class ExperimentPatchModel:
|
|
|
22
34
|
class SlurmJobModel:
|
|
23
35
|
name: str
|
|
24
36
|
slurm_job_id: int
|
|
25
|
-
|
|
37
|
+
slurm_ssh_config: str
|
|
26
38
|
|
|
27
39
|
|
|
28
40
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
@@ -32,15 +44,21 @@ class ArtifactModel:
|
|
|
32
44
|
|
|
33
45
|
|
|
34
46
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
35
|
-
class
|
|
36
|
-
|
|
47
|
+
class ExperimentUnitModel:
|
|
48
|
+
identity: str
|
|
49
|
+
args: str | None = None
|
|
50
|
+
jobs: list[SlurmJobModel] = dataclasses.field(default_factory=list)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
54
|
+
class ExperimentUnitPatchModel:
|
|
37
55
|
identity: str | None
|
|
38
56
|
args: str | None = None
|
|
39
57
|
|
|
40
58
|
|
|
41
59
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
|
42
|
-
class WorkUnitModel(
|
|
43
|
-
|
|
60
|
+
class WorkUnitModel(ExperimentUnitModel):
|
|
61
|
+
wid: int
|
|
44
62
|
artifacts: list[ArtifactModel] = dataclasses.field(default_factory=list)
|
|
45
63
|
|
|
46
64
|
|
|
@@ -55,49 +73,54 @@ class ExperimentModel:
|
|
|
55
73
|
artifacts: list[ArtifactModel]
|
|
56
74
|
|
|
57
75
|
|
|
58
|
-
class XManagerAPI:
|
|
76
|
+
class XManagerAPI(ABC):
|
|
77
|
+
@abstractmethod
|
|
59
78
|
def get_experiment(self, xid: int) -> ExperimentModel:
|
|
60
|
-
|
|
61
|
-
raise NotImplementedError("`get_experiment` is not implemented without a storage backend.")
|
|
79
|
+
pass
|
|
62
80
|
|
|
81
|
+
@abstractmethod
|
|
63
82
|
def delete_experiment(self, experiment_id: int) -> None:
|
|
64
|
-
|
|
65
|
-
logger.debug("`delete_experiment` is not implemented without a storage backend.")
|
|
83
|
+
pass
|
|
66
84
|
|
|
85
|
+
@abstractmethod
|
|
67
86
|
def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
|
|
68
|
-
|
|
69
|
-
logger.debug("`insert_experiment` is not implemented without a storage backend.")
|
|
70
|
-
return int(time.time() * 10**3)
|
|
87
|
+
pass
|
|
71
88
|
|
|
89
|
+
@abstractmethod
|
|
72
90
|
def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
|
|
73
|
-
|
|
74
|
-
logger.debug("`update_experiment` is not implemented without a storage backend.")
|
|
91
|
+
pass
|
|
75
92
|
|
|
93
|
+
@abstractmethod
|
|
76
94
|
def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
|
|
77
|
-
|
|
78
|
-
logger.debug("`insert_job` is not implemented without a storage backend.")
|
|
95
|
+
pass
|
|
79
96
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
97
|
+
@abstractmethod
|
|
98
|
+
def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitModel) -> None:
|
|
99
|
+
pass
|
|
83
100
|
|
|
101
|
+
@abstractmethod
|
|
102
|
+
def update_work_unit(
|
|
103
|
+
self, experiment_id: int, work_unit_id: int, patch: ExperimentUnitPatchModel
|
|
104
|
+
) -> None:
|
|
105
|
+
pass
|
|
106
|
+
|
|
107
|
+
@abstractmethod
|
|
84
108
|
def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
|
|
85
|
-
|
|
86
|
-
logger.debug("`delete_work_unit_artifact` is not implemented without a storage backend.")
|
|
109
|
+
pass
|
|
87
110
|
|
|
111
|
+
@abstractmethod
|
|
88
112
|
def insert_work_unit_artifact(
|
|
89
113
|
self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
|
|
90
114
|
) -> None:
|
|
91
|
-
|
|
92
|
-
logger.debug("`insert_work_unit_artifact` is not implemented without a storage backend.")
|
|
115
|
+
pass
|
|
93
116
|
|
|
117
|
+
@abstractmethod
|
|
94
118
|
def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
|
|
95
|
-
|
|
96
|
-
logger.debug("`delete_experiment_artifact` is not implemented without a storage backend.")
|
|
119
|
+
pass
|
|
97
120
|
|
|
121
|
+
@abstractmethod
|
|
98
122
|
def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
|
|
99
|
-
|
|
100
|
-
logger.debug("`insert_experiment_artifact` is not implemented without a storage backend.")
|
|
123
|
+
pass
|
|
101
124
|
|
|
102
125
|
|
|
103
126
|
class XManagerWebAPI(XManagerAPI):
|
|
@@ -180,8 +203,6 @@ class XManagerWebAPI(XManagerAPI):
|
|
|
180
203
|
update_experiment as _update_experiment,
|
|
181
204
|
)
|
|
182
205
|
|
|
183
|
-
m = self.models.ExperimentPatch(**dataclasses.asdict(experiment_patch))
|
|
184
|
-
|
|
185
206
|
_update_experiment.sync(
|
|
186
207
|
experiment_id,
|
|
187
208
|
client=self.client,
|
|
@@ -198,7 +219,7 @@ class XManagerWebAPI(XManagerAPI):
|
|
|
198
219
|
body=self.models.SlurmJob(**dataclasses.asdict(job)),
|
|
199
220
|
)
|
|
200
221
|
|
|
201
|
-
def insert_work_unit(self, experiment_id: int, work_unit:
|
|
222
|
+
def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitModel) -> None:
|
|
202
223
|
from xm_slurm_api_client.api.work_unit import ( # type: ignore
|
|
203
224
|
insert_work_unit as _insert_work_unit,
|
|
204
225
|
)
|
|
@@ -244,6 +265,253 @@ class XManagerWebAPI(XManagerAPI):
|
|
|
244
265
|
)
|
|
245
266
|
|
|
246
267
|
|
|
268
|
+
Base = declarative_base()
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class Experiment(Base):
    """ORM mapping for rows of the `experiments` table."""

    __tablename__ = "experiments"

    id = Column(Integer, primary_key=True)
    title = Column(String)
    description = Column(String)
    note = Column(String)
    # Stored as one string column; the SQLite API joins/splits tags on ",".
    tags = Column(String)
    # Bidirectional links, paired with `WorkUnit.experiment` and
    # `Artifact.experiment` through `back_populates`.
    work_units = relationship("WorkUnit", back_populates="experiment")
    artifacts = relationship("Artifact", back_populates="experiment")
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class WorkUnit(Base):
    """ORM mapping for rows of the `work_units` table."""

    __tablename__ = "work_units"

    # Row primary key; distinct from `wid` below.
    id = Column(Integer, primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    # Work unit identifier that the API looks up together with `experiment_id`.
    wid = Column(Integer)
    identity = Column(String)
    args = Column(String)
    # Bidirectional links, paired with `Experiment.work_units`,
    # `SlurmJob.work_unit` and `Artifact.work_unit` through `back_populates`.
    experiment = relationship("Experiment", back_populates="work_units")
    jobs = relationship("SlurmJob", back_populates="work_unit")
    artifacts = relationship("Artifact", back_populates="work_unit")
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
class SlurmJob(Base):
    """ORM mapping for rows of the `slurm_jobs` table."""

    __tablename__ = "slurm_jobs"

    id = Column(Integer, primary_key=True)
    # References the primary key of `work_units`.
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))
    name = Column(String)
    slurm_job_id = Column(Integer)
    # Stored as a string (presumably a serialized SSH configuration — confirm
    # against the code that writes it).
    slurm_ssh_config = Column(String)
    # Bidirectional link, paired with `WorkUnit.jobs` through `back_populates`.
    work_unit = relationship("WorkUnit", back_populates="jobs")
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class Artifact(Base):
    """ORM mapping for rows of the `artifacts` table.

    A row carries foreign keys to both `experiments` and `work_units`; the
    SQLite API sets only one of them per artifact.
    """

    __tablename__ = "artifacts"

    id = Column(Integer, primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))
    name = Column(String)
    uri = Column(String)
    # Bidirectional links, paired with `Experiment.artifacts` and
    # `WorkUnit.artifacts` through `back_populates`.
    experiment = relationship("Experiment", back_populates="artifacts")
    work_unit = relationship("WorkUnit", back_populates="artifacts")
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
class XManagerSqliteAPI(XManagerAPI):
    """`XManagerAPI` implementation that persists state in a SQLite database.

    The database file is `db.sqlite3` inside `$XM_SLURM_STATE_DIR` when that
    environment variable is set, otherwise inside `~/.local/state/xm-slurm`.

    Throughout this class the `work_unit_id` argument is the experiment-scoped
    `WorkUnit.wid`, not the `work_units.id` primary key; it is always resolved
    together with `experiment_id`.
    """

    def __init__(self):
        """Locate the database, create missing tables and build the session factory."""
        if "XM_SLURM_STATE_DIR" in os.environ:
            db_path = Path(os.environ["XM_SLURM_STATE_DIR"]) / "db.sqlite3"
        else:
            db_path = Path.home() / ".local" / "state" / "xm-slurm" / "db.sqlite3"
        # The path is passed as a lazy %-argument; an argument without a
        # matching placeholder makes the logging module fail to format the record.
        logger.debug("Looking for db at: %s", db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        engine = create_engine(f"sqlite:///{db_path}")
        Base.metadata.create_all(engine)
        self.Session = sessionmaker(bind=engine)

    @contextmanager
    def session_scope(self):
        """Yield a session that is committed on success and rolled back on any error.

        The session is always closed, and any exception raised inside the
        `with` body is re-raised after the rollback.
        """
        session = self.Session()
        try:
            yield session
            session.commit()
        except BaseException:
            # Explicit equivalent of a bare `except:`; roll back even on
            # KeyboardInterrupt/SystemExit, then propagate.
            session.rollback()
            raise
        finally:
            session.close()

    @staticmethod
    def _find_work_unit(session, experiment_id: int, work_unit_id: int):
        """Return the `WorkUnit` row with `wid == work_unit_id` in the experiment, or None."""
        return (
            session.query(WorkUnit)
            .filter(WorkUnit.experiment_id == experiment_id, WorkUnit.wid == work_unit_id)
            .first()
        )

    def get_experiment(self, xid: int) -> ExperimentModel:
        """Load an experiment with its work units, jobs and artifacts.

        Raises:
            ValueError: If no experiment with id `xid` exists.
        """
        with self.session_scope() as session:
            experiment = session.query(Experiment).filter(Experiment.id == xid).first()
            if not experiment:
                raise ValueError(f"Experiment with id {xid} not found")

            work_units = []
            for wu in experiment.work_units:
                jobs = [
                    SlurmJobModel(
                        name=job.name,
                        slurm_job_id=job.slurm_job_id,
                        slurm_ssh_config=job.slurm_ssh_config,
                    )
                    for job in wu.jobs
                ]
                artifacts = [
                    ArtifactModel(name=artifact.name, uri=artifact.uri) for artifact in wu.artifacts
                ]
                work_units.append(
                    WorkUnitModel(
                        wid=wu.wid,
                        identity=wu.identity,
                        args=wu.args,
                        jobs=jobs,
                        artifacts=artifacts,
                    )
                )

            artifacts = [
                ArtifactModel(name=artifact.name, uri=artifact.uri)
                for artifact in experiment.artifacts
            ]

            return ExperimentModel(
                title=experiment.title,
                description=experiment.description,
                note=experiment.note,
                # Tags are stored as one comma-joined string.
                tags=experiment.tags.split(",") if experiment.tags else None,
                work_units=work_units,
                artifacts=artifacts,
            )

    def delete_experiment(self, experiment_id: int) -> None:
        """Delete the experiment row if it exists; a missing experiment is ignored."""
        with self.session_scope() as session:
            experiment = session.query(Experiment).filter(Experiment.id == experiment_id).first()
            if experiment:
                # NOTE(review): the relationships on `Experiment` declare no
                # cascade, so related work unit / artifact rows are presumably
                # left behind — confirm whether that is intended.
                session.delete(experiment)

    def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
        """Insert a new experiment and return its generated id."""
        with self.session_scope() as session:
            new_experiment = Experiment(
                title=experiment.title,
                description=experiment.description,
                note=experiment.note,
                tags=",".join(experiment.tags) if experiment.tags else None,
            )
            session.add(new_experiment)
            # Flush so the primary key is available before returning.
            session.flush()
            return new_experiment.id

    def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
        """Apply the non-None fields of `experiment_patch`; a missing experiment is ignored."""
        with self.session_scope() as session:
            experiment = session.query(Experiment).filter(Experiment.id == experiment_id).first()
            if experiment:
                if experiment_patch.title is not None:
                    experiment.title = experiment_patch.title
                if experiment_patch.description is not None:
                    experiment.description = experiment_patch.description
                if experiment_patch.note is not None:
                    experiment.note = experiment_patch.note
                if experiment_patch.tags is not None:
                    experiment.tags = ",".join(experiment_patch.tags)

    def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
        """Attach a Slurm job to an existing work unit.

        Raises:
            ValueError: If the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit:
                new_job = SlurmJob(
                    work_unit_id=work_unit.id,
                    name=job.name,
                    slurm_job_id=job.slurm_job_id,
                    slurm_ssh_config=job.slurm_ssh_config,
                )
                session.add(new_job)
            else:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )

    def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitModel) -> None:
        """Insert a work unit together with its jobs and artifacts."""
        with self.session_scope() as session:
            new_work_unit = WorkUnit(
                experiment_id=experiment_id,
                wid=work_unit.wid,
                identity=work_unit.identity,
                args=work_unit.args,
            )
            session.add(new_work_unit)
            # Flush so `new_work_unit.id` is populated; without it the child
            # rows below would be stored with a NULL `work_unit_id`.
            session.flush()
            for job in work_unit.jobs:
                new_job = SlurmJob(
                    work_unit_id=new_work_unit.id,
                    name=job.name,
                    slurm_job_id=job.slurm_job_id,
                    slurm_ssh_config=job.slurm_ssh_config,
                )
                session.add(new_job)
            for artifact in work_unit.artifacts:
                new_artifact = Artifact(
                    work_unit_id=new_work_unit.id, name=artifact.name, uri=artifact.uri
                )
                session.add(new_artifact)

    def update_work_unit(
        self, experiment_id: int, work_unit_id: int, patch: ExperimentUnitPatchModel
    ) -> None:
        """Apply the non-None fields of `patch` to a work unit.

        Raises:
            ValueError: If the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)

            if work_unit:
                if patch.identity is not None:
                    work_unit.identity = patch.identity
                if patch.args is not None:
                    work_unit.args = patch.args
            else:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )

    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Delete the named artifact of a work unit; missing rows are ignored."""
        with self.session_scope() as session:
            # `Artifact.work_unit_id` references `work_units.id`, so the
            # experiment-scoped `wid` has to be resolved to the row first.
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                return
            artifact = (
                session.query(Artifact)
                .filter(Artifact.work_unit_id == work_unit.id, Artifact.name == name)
                .first()
            )
            if artifact:
                session.delete(artifact)

    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
    ) -> None:
        """Attach an artifact to an existing work unit.

        Raises:
            ValueError: If the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            # Resolve the experiment-scoped `wid` to the row's primary key so
            # the artifact is linked to the work unit of this experiment.
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )
            new_artifact = Artifact(work_unit_id=work_unit.id, name=artifact.name, uri=artifact.uri)
            session.add(new_artifact)

    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Delete the named experiment-level artifact; a missing row is ignored."""
        with self.session_scope() as session:
            artifact = (
                session.query(Artifact)
                .filter(Artifact.experiment_id == experiment_id, Artifact.name == name)
                .first()
            )
            if artifact:
                session.delete(artifact)

    def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
        """Attach an artifact directly to an experiment."""
        with self.session_scope() as session:
            new_artifact = Artifact(
                experiment_id=experiment_id, name=artifact.name, uri=artifact.uri
            )
            session.add(new_artifact)
|
|
513
|
+
|
|
514
|
+
|
|
247
515
|
@functools.cache
|
|
248
516
|
def client() -> XManagerAPI:
|
|
249
517
|
if importlib.util.find_spec("xm_slurm_api_client") is not None:
|
|
@@ -257,5 +525,4 @@ def client() -> XManagerAPI:
|
|
|
257
525
|
"Disabling XManager API client."
|
|
258
526
|
)
|
|
259
527
|
|
|
260
|
-
|
|
261
|
-
return XManagerAPI()
|
|
528
|
+
return XManagerSqliteAPI()
|
xm_slurm/batching.py
CHANGED
|
@@ -21,7 +21,7 @@ def stack_bound_arguments(
|
|
|
21
21
|
signature: inspect.Signature, bound_arguments: Sequence[inspect.BoundArguments]
|
|
22
22
|
) -> inspect.BoundArguments:
|
|
23
23
|
"""Stacks bound arguments into a single bound arguments object."""
|
|
24
|
-
stacked_args = collections.OrderedDict()
|
|
24
|
+
stacked_args = collections.OrderedDict[str, Any]()
|
|
25
25
|
for bound_args in bound_arguments:
|
|
26
26
|
for name, value in bound_args.arguments.items():
|
|
27
27
|
stacked_args.setdefault(name, [])
|
|
@@ -59,7 +59,7 @@ class batch(Generic[R]):
|
|
|
59
59
|
self.loop: asyncio.AbstractEventLoop | None = None
|
|
60
60
|
self.process_batch_task: asyncio.Task | None = None
|
|
61
61
|
|
|
62
|
-
self.queue = asyncio.Queue()
|
|
62
|
+
self.queue = asyncio.Queue[Request]()
|
|
63
63
|
|
|
64
64
|
async def _process_batch(self):
|
|
65
65
|
assert self.loop is not None
|
|
@@ -128,12 +128,12 @@ class batch(Generic[R]):
|
|
|
128
128
|
# until then this is just a hack
|
|
129
129
|
return asyncio.coroutines._is_coroutine # type: ignore
|
|
130
130
|
|
|
131
|
-
async def __call__(self, *args, **kwargs) ->
|
|
131
|
+
async def __call__(self, *args, **kwargs) -> R:
|
|
132
132
|
if self.loop is None and self.process_batch_task is None:
|
|
133
133
|
self.loop = asyncio.get_event_loop()
|
|
134
134
|
self.process_batch_task = self.loop.create_task(self._process_batch())
|
|
135
135
|
|
|
136
|
-
future = asyncio.Future()
|
|
136
|
+
future = asyncio.Future[R]()
|
|
137
137
|
bound_args = self.signature.bind(*args, **kwargs)
|
|
138
138
|
self.queue.put_nowait(Request(args=bound_args, future=future))
|
|
139
139
|
return await future
|
xm_slurm/config.py
CHANGED
|
@@ -2,12 +2,15 @@ import dataclasses
|
|
|
2
2
|
import enum
|
|
3
3
|
import functools
|
|
4
4
|
import getpass
|
|
5
|
+
import json
|
|
5
6
|
import os
|
|
6
7
|
import pathlib
|
|
7
8
|
from typing import Literal, Mapping, NamedTuple
|
|
8
9
|
|
|
9
10
|
import asyncssh
|
|
10
11
|
|
|
12
|
+
from xm_slurm import constants
|
|
13
|
+
|
|
11
14
|
|
|
12
15
|
class ContainerRuntime(enum.Enum):
|
|
13
16
|
"""The container engine to use."""
|
|
@@ -46,15 +49,103 @@ class PublicKey(NamedTuple):
|
|
|
46
49
|
key: str
|
|
47
50
|
|
|
48
51
|
|
|
49
|
-
@dataclasses.dataclass
|
|
50
|
-
class
|
|
51
|
-
name: str
|
|
52
|
-
|
|
52
|
+
@dataclasses.dataclass
class SlurmSSHConfig:
    """SSH connection settings for a host, resolved against the user's SSH configs.

    Only `host` is required; `user` and `port` may instead come from the SSH
    config files read by `config`.
    """

    host: str
    host_public_key: PublicKey | None = None
    user: str | None = None
    port: int | None = None

    @functools.cached_property
    def known_hosts(self) -> asyncssh.SSHKnownHosts | None:
        """Known-hosts object built from `host_public_key`, or None when no key is set."""
        public_key = self.host_public_key
        if public_key is None:
            return None

        # NOTE(review): when `port` is None this renders as `[host]:None` —
        # confirm whether a port is always expected alongside a public key.
        entry = f"[{self.host}]:{self.port} {public_key.algorithm} {public_key.key}"
        return asyncssh.import_known_hosts(entry)

    @functools.cached_property
    def config(self) -> asyncssh.config.SSHConfig:
        """Load the SSH client config for this host.

        Reads `~/.ssh/config` and the file named by `$XM_SLURM_SSH_CONFIG`
        when they exist.

        Raises:
            RuntimeError: If no hostname can be determined, or if no user is
                configured for the host.
        """
        ssh_config_paths = []
        default_config = pathlib.Path.home() / ".ssh" / "config"
        if default_config.exists():
            ssh_config_paths.append(default_config)

        override = os.environ.get("XM_SLURM_SSH_CONFIG")
        if override:
            override_path = pathlib.Path(override).expanduser()
            if override_path.exists():
                ssh_config_paths.append(override_path)

        resolved = asyncssh.config.SSHClientConfig.load(
            None,
            ssh_config_paths,
            False,
            getpass.getuser(),
            self.user or (),
            self.host,
            self.port or (),
        )

        if resolved.get("Hostname") is None:
            # No config block supplied a hostname: fall back to `host` itself,
            # but only when it looks like a domain name or an IP address.
            host_patterns = (
                constants.DOMAIN_NAME_REGEX,
                constants.IPV4_REGEX,
                constants.IPV6_REGEX,
            )
            if not any(pattern.match(self.host) for pattern in host_patterns):
                raise RuntimeError(
                    f"Failed to parse hostname from host `{self.host}` using "
                    f"SSH configs: {', '.join(map(str, ssh_config_paths))} and "
                    f"provided hostname `{self.host}` isn't a valid domain name "
                    "or IPv{4,6} address."
                )
            # Writes to the config object's private option table.
            resolved._options["Hostname"] = self.host

        if resolved.get("User") is None:
            hostname = resolved.get("Hostname")
            quoted_paths = ", ".join(f"`{path}`" for path in ssh_config_paths)
            raise RuntimeError(
                f"We could not find a user for the cluster configuration: `{self.host}`. "
                "No user was specified in the configuration and we could not parse "
                f"any users for host `{hostname}` from the SSH configs: "
                f"{quoted_paths}. Please either specify a user "
                "in the configuration or add a user to your SSH configuration under the block "
                f"`Host {hostname}`."
            )

        return resolved

    @functools.cached_property
    def connection_options(self) -> asyncssh.SSHClientConnectionOptions:
        """Connection options prepared from `config` and `known_hosts`."""
        prepared = asyncssh.SSHClientConnectionOptions(config=None)
        prepared.prepare(last_config=self.config, known_hosts=self.known_hosts)
        return prepared

    def serialize(self):
        """Return the four configuration fields encoded as a JSON object string."""
        payload = {
            "host": self.host,
            "host_public_key": self.host_public_key,
            "user": self.user,
            "port": self.port,
        }
        return json.dumps(payload)

    @classmethod
    def deserialize(cls, data):
        """Rebuild an instance from the JSON string produced by `serialize`."""
        payload = json.loads(data)
        host = payload["host"]
        raw_key = payload["host_public_key"]
        # A falsy stored key (e.g. null) means no public key was configured.
        public_key = PublicKey(*raw_key) if raw_key else None
        return cls(
            host=host,
            host_public_key=public_key,
            user=payload["user"],
            port=payload["port"],
        )

    def __hash__(self):
        """Hash over the four configuration fields."""
        return hash((self.host, self.host_public_key, self.user, self.port))
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@dataclasses.dataclass(frozen=True, kw_only=True)
|
|
144
|
+
class SlurmClusterConfig:
|
|
145
|
+
name: str
|
|
146
|
+
|
|
147
|
+
ssh: SlurmSSHConfig
|
|
148
|
+
|
|
58
149
|
# Job submission directory
|
|
59
150
|
cwd: str | None = None
|
|
60
151
|
|
|
@@ -81,7 +172,9 @@ class SlurmClusterConfig:
|
|
|
81
172
|
)
|
|
82
173
|
|
|
83
174
|
# Resource mapping
|
|
84
|
-
resources: Mapping[
|
|
175
|
+
resources: Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
|
|
176
|
+
|
|
177
|
+
features: Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
|
|
85
178
|
|
|
86
179
|
def __post_init__(self) -> None:
|
|
87
180
|
for src, dst in self.mounts.items():
|
|
@@ -99,57 +192,9 @@ class SlurmClusterConfig:
|
|
|
99
192
|
if not pathlib.Path(dst).is_absolute():
|
|
100
193
|
raise ValueError(f"Mount destination must be an absolute path: {dst}")
|
|
101
194
|
|
|
102
|
-
@functools.cached_property
|
|
103
|
-
def ssh_known_hosts(self) -> asyncssh.SSHKnownHosts | None:
|
|
104
|
-
if self.host_public_key is None:
|
|
105
|
-
return None
|
|
106
|
-
|
|
107
|
-
return asyncssh.import_known_hosts(
|
|
108
|
-
f"[{self.host}]:{self.port} {self.host_public_key.algorithm} {self.host_public_key.key}"
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
@functools.cached_property
|
|
112
|
-
def ssh_config(self) -> asyncssh.config.SSHConfig:
|
|
113
|
-
ssh_config_paths = []
|
|
114
|
-
if (ssh_config := pathlib.Path.home() / ".ssh" / "config").exists():
|
|
115
|
-
ssh_config_paths.append(ssh_config)
|
|
116
|
-
if (xm_ssh_config := os.environ.get("XM_SLURM_SSH_CONFIG")) and (
|
|
117
|
-
xm_ssh_config := pathlib.Path(xm_ssh_config).expanduser()
|
|
118
|
-
).exists():
|
|
119
|
-
ssh_config_paths.append(xm_ssh_config)
|
|
120
|
-
|
|
121
|
-
config = asyncssh.config.SSHClientConfig.load(
|
|
122
|
-
None,
|
|
123
|
-
ssh_config_paths,
|
|
124
|
-
True,
|
|
125
|
-
getpass.getuser(),
|
|
126
|
-
self.user or (),
|
|
127
|
-
self.host or (),
|
|
128
|
-
self.port or (),
|
|
129
|
-
)
|
|
130
|
-
|
|
131
|
-
if config.get("Hostname") is None:
|
|
132
|
-
raise RuntimeError(
|
|
133
|
-
f"Failed to parse hostname from host `{self.host}` using SSH configs: {', '.join(map(str, ssh_config_paths))}"
|
|
134
|
-
)
|
|
135
|
-
if config.get("User") is None:
|
|
136
|
-
raise RuntimeError(
|
|
137
|
-
f"Failed to parse user from SSH configs: {', '.join(map(str, ssh_config_paths))}"
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
return config
|
|
141
|
-
|
|
142
|
-
@functools.cached_property
|
|
143
|
-
def ssh_connection_options(self) -> asyncssh.SSHClientConnectionOptions:
|
|
144
|
-
options = asyncssh.SSHClientConnectionOptions(config=None)
|
|
145
|
-
options.prepare(last_config=self.ssh_config, known_hosts=self.ssh_known_hosts)
|
|
146
|
-
return options
|
|
147
|
-
|
|
148
195
|
def __hash__(self):
|
|
149
196
|
return hash((
|
|
150
|
-
self.
|
|
151
|
-
self.user,
|
|
152
|
-
self.port,
|
|
197
|
+
self.ssh,
|
|
153
198
|
self.cwd,
|
|
154
199
|
self.prolog,
|
|
155
200
|
self.epilog,
|
xm_slurm/constants.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
# Image reference split into named groups: optional `scheme://`, `domain`,
# optional `/path`, optional `:tag` and optional `@digest`.
IMAGE_URI_REGEX = re.compile(
    r"^(?P<scheme>(?:[^:]+://)?)?"
    r"(?P<domain>[^/]+)"
    r"(?P<path>/[^:]*)?"
    r"(?::(?P<tag>[^@]+))?"
    r"@?(?P<digest>.+)?$"
)

# Dotted name: no leading hyphen, no "--" anywhere, first label not ending in
# a hyphen, and a final label of at least two letters.
DOMAIN_NAME_REGEX = re.compile(
    r"^(?!-)(?!.*--)"
    r"[A-Za-z0-9-]{1,63}(?<!-)"
    r"(\.[A-Za-z0-9-]{1,63})*"
    r"(\.[A-Za-z]{2,})$"
)

# Dotted quad: three "octet." groups followed by a final octet, each 0-255.
IPV4_REGEX = re.compile(
    r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)

# IPv6 address: one alternative per line, joined into a single group.
IPV6_REGEX = re.compile(
    r"^("
    r"([0-9a-fA-F]{1,4}:){7}([0-9a-fA-F]{1,4}|:)"
    r"|(([0-9a-fA-F]{1,4}:){1,7}|:):(([0-9a-fA-F]{1,4}:){1,6}|:)"
    r"|(([0-9a-fA-F]{1,4}:){1,6}|:):(([0-9a-fA-F]{1,4}:){1,5}|:)"
    r"|(([0-9a-fA-F]{1,4}:){1,5}|:):(([0-9a-fA-F]{1,4}:){1,4}|:)"
    r"|(([0-9a-fA-F]{1,4}:){1,4}|:):(([0-9a-fA-F]{1,4}:){1,3}|:)"
    r"|(([0-9a-fA-F]{1,4}:){1,3}|:):(([0-9a-fA-F]{1,4}:){1,2}|:)"
    r"|(([0-9a-fA-F]{1,4}:){1,2}|:):([0-9a-fA-F]{1,4}|:)"
    r"|([0-9a-fA-F]{1,4}|:):([0-9a-fA-F]{1,4}|:)(:([0-9a-fA-F]{1,4}|:)){1,6}"
    r")$"
)
|
|
File without changes
|