xmanager-slurm 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. (The original page links to more details; the link is not preserved in this text.)

xm_slurm/__init__.py CHANGED
@@ -4,6 +4,7 @@ from xm_slurm.executables import Dockerfile, DockerImage
4
4
  from xm_slurm.executors import Slurm, SlurmSpec
5
5
  from xm_slurm.experiment import (
6
6
  Artifact,
7
+ SlurmExperiment,
7
8
  create_experiment,
8
9
  get_current_experiment,
9
10
  get_current_work_unit,
@@ -14,8 +15,8 @@ from xm_slurm.packageables import (
14
15
  docker_container,
15
16
  docker_image,
16
17
  mamba_container,
17
- pdm_container,
18
18
  python_container,
19
+ uv_container,
19
20
  )
20
21
  from xm_slurm.resources import JobRequirements, ResourceQuantity, ResourceType
21
22
 
@@ -35,10 +36,11 @@ __all__ = [
35
36
  "get_experiment",
36
37
  "JobRequirements",
37
38
  "mamba_container",
38
- "pdm_container",
39
+ "uv_container",
39
40
  "python_container",
40
41
  "ResourceQuantity",
41
42
  "ResourceType",
42
43
  "Slurm",
43
44
  "SlurmSpec",
45
+ "SlurmExperiment",
44
46
  ]
xm_slurm/api.py CHANGED
@@ -1,15 +1,27 @@
1
1
  import dataclasses
2
+ import enum
2
3
  import functools
3
4
  import importlib.util
4
5
  import logging
5
6
  import os
6
- import time
7
7
  import typing
8
+ from abc import ABC, abstractmethod
9
+ from contextlib import contextmanager
10
+ from pathlib import Path
8
11
  from typing import Any
9
12
 
13
+ from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
14
+ from sqlalchemy.ext.declarative import declarative_base
15
+ from sqlalchemy.orm import relationship, sessionmaker
16
+
10
17
  logger = logging.getLogger(__name__)
11
18
 
12
19
 
20
+ class ExperimentUnitRole(enum.Enum):
21
+ WORK_UNIT = enum.auto()
22
+ AUX_UNIT = enum.auto()
23
+
24
+
13
25
  @dataclasses.dataclass(kw_only=True, frozen=True)
14
26
  class ExperimentPatchModel:
15
27
  title: str | None = None
@@ -22,7 +34,7 @@ class ExperimentPatchModel:
22
34
  class SlurmJobModel:
23
35
  name: str
24
36
  slurm_job_id: int
25
- slurm_cluster: str
37
+ slurm_ssh_config: str
26
38
 
27
39
 
28
40
  @dataclasses.dataclass(kw_only=True, frozen=True)
@@ -32,15 +44,21 @@ class ArtifactModel:
32
44
 
33
45
 
34
46
  @dataclasses.dataclass(kw_only=True, frozen=True)
35
- class WorkUnitPatchModel:
36
- wid: int
47
+ class ExperimentUnitModel:
48
+ identity: str
49
+ args: str | None = None
50
+ jobs: list[SlurmJobModel] = dataclasses.field(default_factory=list)
51
+
52
+
53
+ @dataclasses.dataclass(kw_only=True, frozen=True)
54
+ class ExperimentUnitPatchModel:
37
55
  identity: str | None
38
56
  args: str | None = None
39
57
 
40
58
 
41
59
  @dataclasses.dataclass(kw_only=True, frozen=True)
42
- class WorkUnitModel(WorkUnitPatchModel):
43
- jobs: list[SlurmJobModel] = dataclasses.field(default_factory=list)
60
+ class WorkUnitModel(ExperimentUnitModel):
61
+ wid: int
44
62
  artifacts: list[ArtifactModel] = dataclasses.field(default_factory=list)
45
63
 
46
64
 
@@ -55,49 +73,54 @@ class ExperimentModel:
55
73
  artifacts: list[ArtifactModel]
56
74
 
57
75
 
58
- class XManagerAPI:
76
+ class XManagerAPI(ABC):
77
+ @abstractmethod
59
78
  def get_experiment(self, xid: int) -> ExperimentModel:
60
- del xid
61
- raise NotImplementedError("`get_experiment` is not implemented without a storage backend.")
79
+ pass
62
80
 
81
+ @abstractmethod
63
82
  def delete_experiment(self, experiment_id: int) -> None:
64
- del experiment_id
65
- logger.debug("`delete_experiment` is not implemented without a storage backend.")
83
+ pass
66
84
 
85
+ @abstractmethod
67
86
  def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
68
- del experiment
69
- logger.debug("`insert_experiment` is not implemented without a storage backend.")
70
- return int(time.time() * 10**3)
87
+ pass
71
88
 
89
+ @abstractmethod
72
90
  def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
73
- del experiment_id, experiment_patch
74
- logger.debug("`update_experiment` is not implemented without a storage backend.")
91
+ pass
75
92
 
93
+ @abstractmethod
76
94
  def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
77
- del experiment_id, work_unit_id, job
78
- logger.debug("`insert_job` is not implemented without a storage backend.")
95
+ pass
79
96
 
80
- def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitPatchModel) -> None:
81
- del experiment_id, work_unit
82
- logger.debug("`insert_work_unit` is not implemented without a storage backend.")
97
+ @abstractmethod
98
+ def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitModel) -> None:
99
+ pass
83
100
 
101
+ @abstractmethod
102
+ def update_work_unit(
103
+ self, experiment_id: int, work_unit_id: int, patch: ExperimentUnitPatchModel
104
+ ) -> None:
105
+ pass
106
+
107
+ @abstractmethod
84
108
  def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
85
- del experiment_id, work_unit_id, name
86
- logger.debug("`delete_work_unit_artifact` is not implemented without a storage backend.")
109
+ pass
87
110
 
111
+ @abstractmethod
88
112
  def insert_work_unit_artifact(
89
113
  self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
90
114
  ) -> None:
91
- del experiment_id, work_unit_id, artifact
92
- logger.debug("`insert_work_unit_artifact` is not implemented without a storage backend.")
115
+ pass
93
116
 
117
+ @abstractmethod
94
118
  def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
95
- del experiment_id, name
96
- logger.debug("`delete_experiment_artifact` is not implemented without a storage backend.")
119
+ pass
97
120
 
121
+ @abstractmethod
98
122
  def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
99
- del experiment_id, artifact
100
- logger.debug("`insert_experiment_artifact` is not implemented without a storage backend.")
123
+ pass
101
124
 
102
125
 
103
126
  class XManagerWebAPI(XManagerAPI):
@@ -180,8 +203,6 @@ class XManagerWebAPI(XManagerAPI):
180
203
  update_experiment as _update_experiment,
181
204
  )
182
205
 
183
- m = self.models.ExperimentPatch(**dataclasses.asdict(experiment_patch))
184
-
185
206
  _update_experiment.sync(
186
207
  experiment_id,
187
208
  client=self.client,
@@ -198,7 +219,7 @@ class XManagerWebAPI(XManagerAPI):
198
219
  body=self.models.SlurmJob(**dataclasses.asdict(job)),
199
220
  )
200
221
 
201
- def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitPatchModel) -> None:
222
+ def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitModel) -> None:
202
223
  from xm_slurm_api_client.api.work_unit import ( # type: ignore
203
224
  insert_work_unit as _insert_work_unit,
204
225
  )
@@ -244,6 +265,253 @@ class XManagerWebAPI(XManagerAPI):
244
265
  )
245
266
 
246
267
 
268
+ Base = declarative_base()
269
+
270
+
271
class Experiment(Base):
    """ORM row of the ``experiments`` table."""

    __tablename__ = "experiments"

    id = Column(Integer, primary_key=True)
    title = Column(String)
    description = Column(String)
    note = Column(String)
    # Kept as a single string; XManagerSqliteAPI joins and splits the tag
    # list on "," when writing and reading.
    tags = Column(String)

    # One-to-many links, mirrored by the `experiment` attribute on the children.
    work_units = relationship("WorkUnit", back_populates="experiment")
    artifacts = relationship("Artifact", back_populates="experiment")
281
+
282
+
283
class WorkUnit(Base):
    """ORM row of the ``work_units`` table."""

    __tablename__ = "work_units"

    # Surrogate primary key; distinct from `wid` below.
    id = Column(Integer, primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    # Caller-supplied work-unit id; XManagerSqliteAPI looks rows up by
    # (experiment_id, wid).
    wid = Column(Integer)
    identity = Column(String)
    args = Column(String)

    experiment = relationship("Experiment", back_populates="work_units")
    jobs = relationship("SlurmJob", back_populates="work_unit")
    artifacts = relationship("Artifact", back_populates="work_unit")
294
+
295
+
296
class SlurmJob(Base):
    """ORM row of the ``slurm_jobs`` table; one Slurm job of a work unit."""

    __tablename__ = "slurm_jobs"

    id = Column(Integer, primary_key=True)
    # References the work unit's surrogate key (`work_units.id`), not its `wid`.
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))
    name = Column(String)
    slurm_job_id = Column(Integer)
    # NOTE(review): presumably the JSON produced by SlurmSSHConfig.serialize();
    # confirm against the caller that builds SlurmJobModel.
    slurm_ssh_config = Column(String)

    work_unit = relationship("WorkUnit", back_populates="jobs")
305
+
306
+
307
class Artifact(Base):
    """ORM row of the ``artifacts`` table: a named URI.

    Both owner columns are nullable; the insert methods of XManagerSqliteAPI
    set either `experiment_id` or `work_unit_id`, never both.
    """

    __tablename__ = "artifacts"

    id = Column(Integer, primary_key=True)
    experiment_id = Column(Integer, ForeignKey("experiments.id"))
    # References the work unit's surrogate key (`work_units.id`), not its `wid`.
    work_unit_id = Column(Integer, ForeignKey("work_units.id"))
    name = Column(String)
    uri = Column(String)

    experiment = relationship("Experiment", back_populates="artifacts")
    work_unit = relationship("WorkUnit", back_populates="artifacts")
317
+
318
+
319
class XManagerSqliteAPI(XManagerAPI):
    """`XManagerAPI` backend that keeps experiment metadata in a SQLite file.

    The database is ``$XM_SLURM_STATE_DIR/db.sqlite3`` when that environment
    variable is set and ``~/.local/state/xm-slurm/db.sqlite3`` otherwise.

    Every ``work_unit_id`` argument of this class is the per-experiment
    ``wid`` of the work unit, *not* the ``work_units.id`` primary key; rows
    are always resolved through ``(experiment_id, wid)``.
    """

    def __init__(self):
        """Create the state directory, the database file and all tables."""
        if "XM_SLURM_STATE_DIR" in os.environ:
            db_path = Path(os.environ["XM_SLURM_STATE_DIR"]) / "db.sqlite3"
        else:
            db_path = Path.home() / ".local" / "state" / "xm-slurm" / "db.sqlite3"
        # Lazy %-style argument: the message needs a placeholder for db_path.
        logger.debug("Looking for db at: %s", db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        engine = create_engine(f"sqlite:///{db_path}")
        Base.metadata.create_all(engine)
        self.Session = sessionmaker(bind=engine)

    @contextmanager
    def session_scope(self):
        """Yield a session that is committed on success and rolled back on error.

        The session is closed in every case. Any exception raised inside the
        ``with`` body (or by the commit) is re-raised after the rollback.
        """
        session = self.Session()
        try:
            yield session
            session.commit()
        except BaseException:
            # Deliberately broad: roll back on *anything*, then propagate.
            session.rollback()
            raise
        finally:
            session.close()

    @staticmethod
    def _find_work_unit(session, experiment_id: int, wid: int):
        """Return the `WorkUnit` row for ``(experiment_id, wid)`` or ``None``."""
        return (
            session.query(WorkUnit)
            .filter(WorkUnit.experiment_id == experiment_id, WorkUnit.wid == wid)
            .first()
        )

    def get_experiment(self, xid: int) -> ExperimentModel:
        """Load an experiment with its work units, jobs and artifacts.

        Raises:
            ValueError: if no experiment with id ``xid`` exists.
        """
        with self.session_scope() as session:
            experiment = session.query(Experiment).filter(Experiment.id == xid).first()
            if not experiment:
                raise ValueError(f"Experiment with id {xid} not found")

            # Convert ORM rows to plain models while the session is still open.
            work_units = [
                WorkUnitModel(
                    wid=wu.wid,
                    identity=wu.identity,
                    args=wu.args,
                    jobs=[
                        SlurmJobModel(
                            name=job.name,
                            slurm_job_id=job.slurm_job_id,
                            slurm_ssh_config=job.slurm_ssh_config,
                        )
                        for job in wu.jobs
                    ],
                    artifacts=[
                        ArtifactModel(name=artifact.name, uri=artifact.uri)
                        for artifact in wu.artifacts
                    ],
                )
                for wu in experiment.work_units
            ]
            artifacts = [
                ArtifactModel(name=artifact.name, uri=artifact.uri)
                for artifact in experiment.artifacts
            ]

            return ExperimentModel(
                title=experiment.title,
                description=experiment.description,
                note=experiment.note,
                tags=experiment.tags.split(",") if experiment.tags else None,
                work_units=work_units,
                artifacts=artifacts,
            )

    def delete_experiment(self, experiment_id: int) -> None:
        """Delete an experiment and everything it owns; no-op if it is missing.

        The ORM relationships declare no delete cascade, so the children are
        removed explicitly here. Otherwise they would be left behind with
        their foreign keys set to NULL.
        """
        with self.session_scope() as session:
            experiment = session.query(Experiment).filter(Experiment.id == experiment_id).first()
            if experiment is None:
                return
            # Mark everything first and flush once at commit, so that lazy
            # loads of the child collections do not trigger partial flushes.
            with session.no_autoflush:
                for work_unit in list(experiment.work_units):
                    for child in (*work_unit.jobs, *work_unit.artifacts):
                        session.delete(child)
                    session.delete(work_unit)
                for artifact in list(experiment.artifacts):
                    session.delete(artifact)
                session.delete(experiment)

    def insert_experiment(self, experiment: ExperimentPatchModel) -> int:
        """Insert a new experiment and return its generated id."""
        with self.session_scope() as session:
            new_experiment = Experiment(
                title=experiment.title,
                description=experiment.description,
                note=experiment.note,
                tags=",".join(experiment.tags) if experiment.tags else None,
            )
            session.add(new_experiment)
            # Flush so the autoincrement primary key is available to return.
            session.flush()
            return new_experiment.id

    def update_experiment(self, experiment_id: int, experiment_patch: ExperimentPatchModel) -> None:
        """Apply the non-``None`` fields of the patch; no-op if the experiment is missing."""
        with self.session_scope() as session:
            experiment = session.query(Experiment).filter(Experiment.id == experiment_id).first()
            if experiment is None:
                return
            if experiment_patch.title is not None:
                experiment.title = experiment_patch.title
            if experiment_patch.description is not None:
                experiment.description = experiment_patch.description
            if experiment_patch.note is not None:
                experiment.note = experiment_patch.note
            if experiment_patch.tags is not None:
                experiment.tags = ",".join(experiment_patch.tags)

    def insert_job(self, experiment_id: int, work_unit_id: int, job: SlurmJobModel) -> None:
        """Attach a Slurm job to an existing work unit.

        Raises:
            ValueError: if the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )
            session.add(
                SlurmJob(
                    work_unit_id=work_unit.id,
                    name=job.name,
                    slurm_job_id=job.slurm_job_id,
                    slurm_ssh_config=job.slurm_ssh_config,
                )
            )

    def insert_work_unit(self, experiment_id: int, work_unit: WorkUnitModel) -> None:
        """Insert a work unit together with the jobs and artifacts it carries."""
        with self.session_scope() as session:
            new_work_unit = WorkUnit(
                experiment_id=experiment_id,
                wid=work_unit.wid,
                identity=work_unit.identity,
                args=work_unit.args,
            )
            session.add(new_work_unit)
            # The primary key is only assigned on flush; without this the
            # children below would be stored with work_unit_id = NULL.
            session.flush()
            for job in work_unit.jobs:
                session.add(
                    SlurmJob(
                        work_unit_id=new_work_unit.id,
                        name=job.name,
                        slurm_job_id=job.slurm_job_id,
                        slurm_ssh_config=job.slurm_ssh_config,
                    )
                )
            for artifact in work_unit.artifacts:
                session.add(
                    Artifact(work_unit_id=new_work_unit.id, name=artifact.name, uri=artifact.uri)
                )

    def update_work_unit(
        self, experiment_id: int, work_unit_id: int, patch: ExperimentUnitPatchModel
    ) -> None:
        """Apply the non-``None`` fields of the patch to a work unit.

        Raises:
            ValueError: if the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )
            if patch.identity is not None:
                work_unit.identity = patch.identity
            if patch.args is not None:
                work_unit.args = patch.args

    def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
        """Delete the named artifact of a work unit; no-op if either is missing."""
        with self.session_scope() as session:
            # Resolve (experiment_id, wid) to the row first: the artifact's
            # foreign key holds `work_units.id`, not the `wid`.
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                return
            artifact = (
                session.query(Artifact)
                .filter(Artifact.work_unit_id == work_unit.id, Artifact.name == name)
                .first()
            )
            if artifact:
                session.delete(artifact)

    def insert_work_unit_artifact(
        self, experiment_id: int, work_unit_id: int, artifact: ArtifactModel
    ) -> None:
        """Attach an artifact to an existing work unit.

        Raises:
            ValueError: if the work unit does not exist in the experiment.
        """
        with self.session_scope() as session:
            work_unit = self._find_work_unit(session, experiment_id, work_unit_id)
            if work_unit is None:
                raise ValueError(
                    f"Work unit with id {work_unit_id} not found in experiment {experiment_id}"
                )
            session.add(Artifact(work_unit_id=work_unit.id, name=artifact.name, uri=artifact.uri))

    def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
        """Delete the named experiment-level artifact; no-op if it is missing."""
        with self.session_scope() as session:
            artifact = (
                session.query(Artifact)
                .filter(Artifact.experiment_id == experiment_id, Artifact.name == name)
                .first()
            )
            if artifact:
                session.delete(artifact)

    def insert_experiment_artifact(self, experiment_id: int, artifact: ArtifactModel) -> None:
        """Attach an artifact directly to an experiment."""
        with self.session_scope() as session:
            session.add(
                Artifact(experiment_id=experiment_id, name=artifact.name, uri=artifact.uri)
            )
513
+
514
+
247
515
  @functools.cache
248
516
  def client() -> XManagerAPI:
249
517
  if importlib.util.find_spec("xm_slurm_api_client") is not None:
@@ -257,5 +525,4 @@ def client() -> XManagerAPI:
257
525
  "Disabling XManager API client."
258
526
  )
259
527
 
260
- logger.debug("xm_slurm_api_client not found... skipping logging to the API.")
261
- return XManagerAPI()
528
+ return XManagerSqliteAPI()
xm_slurm/batching.py CHANGED
@@ -21,7 +21,7 @@ def stack_bound_arguments(
21
21
  signature: inspect.Signature, bound_arguments: Sequence[inspect.BoundArguments]
22
22
  ) -> inspect.BoundArguments:
23
23
  """Stacks bound arguments into a single bound arguments object."""
24
- stacked_args = collections.OrderedDict()
24
+ stacked_args = collections.OrderedDict[str, Any]()
25
25
  for bound_args in bound_arguments:
26
26
  for name, value in bound_args.arguments.items():
27
27
  stacked_args.setdefault(name, [])
@@ -59,7 +59,7 @@ class batch(Generic[R]):
59
59
  self.loop: asyncio.AbstractEventLoop | None = None
60
60
  self.process_batch_task: asyncio.Task | None = None
61
61
 
62
- self.queue = asyncio.Queue()
62
+ self.queue = asyncio.Queue[Request]()
63
63
 
64
64
  async def _process_batch(self):
65
65
  assert self.loop is not None
@@ -128,12 +128,12 @@ class batch(Generic[R]):
128
128
  # until then this is just a hack
129
129
  return asyncio.coroutines._is_coroutine # type: ignore
130
130
 
131
- async def __call__(self, *args, **kwargs) -> asyncio.Future[R]:
131
+ async def __call__(self, *args, **kwargs) -> R:
132
132
  if self.loop is None and self.process_batch_task is None:
133
133
  self.loop = asyncio.get_event_loop()
134
134
  self.process_batch_task = self.loop.create_task(self._process_batch())
135
135
 
136
- future = asyncio.Future()
136
+ future = asyncio.Future[R]()
137
137
  bound_args = self.signature.bind(*args, **kwargs)
138
138
  self.queue.put_nowait(Request(args=bound_args, future=future))
139
139
  return await future
xm_slurm/config.py CHANGED
@@ -2,12 +2,15 @@ import dataclasses
2
2
  import enum
3
3
  import functools
4
4
  import getpass
5
+ import json
5
6
  import os
6
7
  import pathlib
7
8
  from typing import Literal, Mapping, NamedTuple
8
9
 
9
10
  import asyncssh
10
11
 
12
+ from xm_slurm import constants
13
+
11
14
 
12
15
  class ContainerRuntime(enum.Enum):
13
16
  """The container engine to use."""
@@ -46,15 +49,103 @@ class PublicKey(NamedTuple):
46
49
  key: str
47
50
 
48
51
 
49
- @dataclasses.dataclass(frozen=True, kw_only=True)
50
- class SlurmClusterConfig:
51
- name: str
52
-
52
@dataclasses.dataclass
class SlurmSSHConfig:
    """SSH endpoint of a Slurm cluster plus the asyncssh settings derived from it.

    Attributes are the raw user-supplied values; `known_hosts`, `config` and
    `connection_options` are computed once on first access and then cached.
    """

    host: str
    host_public_key: PublicKey | None = None
    user: str | None = None
    port: int | None = None

    @functools.cached_property
    def known_hosts(self) -> asyncssh.SSHKnownHosts | None:
        """Known-hosts entry pinning `host_public_key`, or ``None`` if no key is set."""
        if self.host_public_key is None:
            return None

        # known_hosts uses the bracketed "[host]:port" form only for
        # non-default ports; a plain host name stands for port 22.  Always
        # bracketing produced "[host]:None" when no port was configured, which
        # can never match.  Only `self.port` is consulted here; a Port set
        # solely in an SSH config file is not reflected.
        if self.port is None or self.port == 22:
            host_pattern = self.host
        else:
            host_pattern = f"[{self.host}]:{self.port}"

        return asyncssh.import_known_hosts(
            f"{host_pattern} {self.host_public_key.algorithm} {self.host_public_key.key}"
        )

    @functools.cached_property
    def config(self) -> asyncssh.config.SSHConfig:
        """SSH client config for this host, resolved against the user's SSH config files.

        Reads ``~/.ssh/config`` and, if set, the file named by
        ``$XM_SLURM_SSH_CONFIG`` (each only when it exists).

        Raises:
            RuntimeError: if no ``Hostname`` can be resolved and `host` is not
                itself a domain name or IP address, or if no ``User`` is known.
        """
        ssh_config_paths = []
        if (ssh_config := pathlib.Path.home() / ".ssh" / "config").exists():
            ssh_config_paths.append(ssh_config)
        if (xm_ssh_config_var := os.environ.get("XM_SLURM_SSH_CONFIG")) and (
            xm_ssh_config := pathlib.Path(xm_ssh_config_var).expanduser()
        ).exists():
            ssh_config_paths.append(xm_ssh_config)

        # NOTE(review): positional arguments follow asyncssh's
        # `SSHClientConfig.load`; verify the order against the pinned version.
        config = asyncssh.config.SSHClientConfig.load(
            None,
            ssh_config_paths,
            False,
            getpass.getuser(),
            self.user or (),
            self.host,
            self.port or (),
        )

        if config.get("Hostname") is None:
            # No config block matched: accept `host` verbatim only if it is
            # already something that can be connected to.
            if (
                constants.DOMAIN_NAME_REGEX.match(self.host)
                or constants.IPV4_REGEX.match(self.host)
                or constants.IPV6_REGEX.match(self.host)
            ):
                config._options["Hostname"] = self.host
            else:
                raise RuntimeError(
                    f"Failed to parse hostname from host `{self.host}` using "
                    f"SSH configs: {', '.join(map(str, ssh_config_paths))} and "
                    f"provided hostname `{self.host}` isn't a valid domain name "
                    "or IPv{4,6} address."
                )

        if config.get("User") is None:
            raise RuntimeError(
                f"We could not find a user for the cluster configuration: `{self.host}`. "
                "No user was specified in the configuration and we could not parse "
                f"any users for host `{config.get('Hostname')}` from the SSH configs: "
                f"{', '.join(map(lambda h: f'`{h}`', ssh_config_paths))}. Please either specify a user "
                "in the configuration or add a user to your SSH configuration under the block "
                f"`Host {config.get('Hostname')}`."
            )

        return config

    @functools.cached_property
    def connection_options(self) -> asyncssh.SSHClientConnectionOptions:
        """asyncssh connection options built from `config` and `known_hosts`."""
        options = asyncssh.SSHClientConnectionOptions(config=None)
        options.prepare(last_config=self.config, known_hosts=self.known_hosts)
        return options

    def serialize(self) -> str:
        """Return the four configuration fields as a JSON object string.

        `host_public_key` is a named tuple, so JSON encodes it as a two-item
        array (or ``null``); `deserialize` reverses this.
        """
        return json.dumps({
            "host": self.host,
            "host_public_key": self.host_public_key,
            "user": self.user,
            "port": self.port,
        })

    @classmethod
    def deserialize(cls, data: str) -> "SlurmSSHConfig":
        """Rebuild an instance from the JSON string produced by `serialize`."""
        data = json.loads(data)
        return cls(
            host=data["host"],
            host_public_key=PublicKey(*data["host_public_key"])
            if data["host_public_key"]
            else None,
            user=data["user"],
            port=data["port"],
        )

    def __hash__(self):
        # Explicit hash: the dataclass is not frozen, so none is generated.
        return hash((self.host, self.host_public_key, self.user, self.port))
141
+
142
+
143
+ @dataclasses.dataclass(frozen=True, kw_only=True)
144
+ class SlurmClusterConfig:
145
+ name: str
146
+
147
+ ssh: SlurmSSHConfig
148
+
58
149
  # Job submission directory
59
150
  cwd: str | None = None
60
151
 
@@ -81,7 +172,9 @@ class SlurmClusterConfig:
81
172
  )
82
173
 
83
174
  # Resource mapping
84
- resources: Mapping[str, "xm_slurm.ResourceType"] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
175
+ resources: Mapping["xm_slurm.ResourceType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
176
+
177
+ features: Mapping["xm_slurm.FeatureType", str] = dataclasses.field(default_factory=dict) # type: ignore # noqa: F821
85
178
 
86
179
  def __post_init__(self) -> None:
87
180
  for src, dst in self.mounts.items():
@@ -99,57 +192,9 @@ class SlurmClusterConfig:
99
192
  if not pathlib.Path(dst).is_absolute():
100
193
  raise ValueError(f"Mount destination must be an absolute path: {dst}")
101
194
 
102
- @functools.cached_property
103
- def ssh_known_hosts(self) -> asyncssh.SSHKnownHosts | None:
104
- if self.host_public_key is None:
105
- return None
106
-
107
- return asyncssh.import_known_hosts(
108
- f"[{self.host}]:{self.port} {self.host_public_key.algorithm} {self.host_public_key.key}"
109
- )
110
-
111
- @functools.cached_property
112
- def ssh_config(self) -> asyncssh.config.SSHConfig:
113
- ssh_config_paths = []
114
- if (ssh_config := pathlib.Path.home() / ".ssh" / "config").exists():
115
- ssh_config_paths.append(ssh_config)
116
- if (xm_ssh_config := os.environ.get("XM_SLURM_SSH_CONFIG")) and (
117
- xm_ssh_config := pathlib.Path(xm_ssh_config).expanduser()
118
- ).exists():
119
- ssh_config_paths.append(xm_ssh_config)
120
-
121
- config = asyncssh.config.SSHClientConfig.load(
122
- None,
123
- ssh_config_paths,
124
- True,
125
- getpass.getuser(),
126
- self.user or (),
127
- self.host or (),
128
- self.port or (),
129
- )
130
-
131
- if config.get("Hostname") is None:
132
- raise RuntimeError(
133
- f"Failed to parse hostname from host `{self.host}` using SSH configs: {', '.join(map(str, ssh_config_paths))}"
134
- )
135
- if config.get("User") is None:
136
- raise RuntimeError(
137
- f"Failed to parse user from SSH configs: {', '.join(map(str, ssh_config_paths))}"
138
- )
139
-
140
- return config
141
-
142
- @functools.cached_property
143
- def ssh_connection_options(self) -> asyncssh.SSHClientConnectionOptions:
144
- options = asyncssh.SSHClientConnectionOptions(config=None)
145
- options.prepare(last_config=self.ssh_config, known_hosts=self.ssh_known_hosts)
146
- return options
147
-
148
195
  def __hash__(self):
149
196
  return hash((
150
- self.host,
151
- self.user,
152
- self.port,
197
+ self.ssh,
153
198
  self.cwd,
154
199
  self.prolog,
155
200
  self.epilog,
xm_slurm/constants.py ADDED
@@ -0,0 +1,15 @@
1
+ import re
2
+
3
+ IMAGE_URI_REGEX = re.compile(
4
+ r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
5
+ )
6
+
7
+ DOMAIN_NAME_REGEX = re.compile(
8
+ r"^(?!-)(?!.*--)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*(\.[A-Za-z]{2,})$"
9
+ )
10
+ IPV4_REGEX = re.compile(
11
+ r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
12
+ )
13
+ IPV6_REGEX = re.compile(
14
+ r"^(([0-9a-fA-F]{1,4}:){7}([0-9a-fA-F]{1,4}|:)|(([0-9a-fA-F]{1,4}:){1,7}|:):(([0-9a-fA-F]{1,4}:){1,6}|:)|(([0-9a-fA-F]{1,4}:){1,6}|:):(([0-9a-fA-F]{1,4}:){1,5}|:)|(([0-9a-fA-F]{1,4}:){1,5}|:):(([0-9a-fA-F]{1,4}:){1,4}|:)|(([0-9a-fA-F]{1,4}:){1,4}|:):(([0-9a-fA-F]{1,4}:){1,3}|:)|(([0-9a-fA-F]{1,4}:){1,3}|:):(([0-9a-fA-F]{1,4}:){1,2}|:)|(([0-9a-fA-F]{1,4}:){1,2}|:):([0-9a-fA-F]{1,4}|:)|([0-9a-fA-F]{1,4}|:):([0-9a-fA-F]{1,4}|:)(:([0-9a-fA-F]{1,4}|:)){1,6})$"
15
+ )
File without changes