xmanager-slurm 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. (The "more details" link from the original page was not preserved in this copy.)

@@ -0,0 +1,173 @@
1
+ import dataclasses
2
+
3
+ import backoff
4
+ import httpx
5
+
6
+ from xm_slurm.api import models
7
+ from xm_slurm.api.abc import XManagerAPI
8
+
9
+ # Define which exceptions should trigger a retry
10
+ RETRY_EXCEPTIONS = (
11
+ httpx.ConnectError,
12
+ httpx.ConnectTimeout,
13
+ httpx.ReadTimeout,
14
+ httpx.WriteTimeout,
15
+ httpx.NetworkError,
16
+ )
17
+
18
+
19
+ # Common backoff decorator for all API calls
20
+ def with_backoff(f):
21
+ return backoff.on_exception(
22
+ backoff.expo,
23
+ RETRY_EXCEPTIONS,
24
+ max_tries=3, # Maximum number of attempts
25
+ max_time=30, # Maximum total time to try in seconds
26
+ jitter=backoff.full_jitter, # Add jitter to prevent thundering herd
27
+ )(f)
28
+
29
+
30
+ class XManagerWebAPI(XManagerAPI):
31
+ def __init__(self, base_url: str, token: str):
32
+ self.base_url = base_url.rstrip("/")
33
+ self.client = httpx.Client(headers={"Authorization": f"Bearer {token}"}, verify=False)
34
+
35
+ def _make_url(self, path: str) -> str:
36
+ return f"{self.base_url}/api{path}"
37
+
38
+ @with_backoff
39
+ def get_experiment(self, xid: int) -> models.Experiment:
40
+ response = self.client.get(self._make_url(f"/experiment/{xid}"))
41
+ response.raise_for_status()
42
+ data = response.json()
43
+ # Construct work units with nested jobs and artifacts
44
+ work_units = []
45
+ for wu_data in data.pop("work_units", []):
46
+ # Build jobs for this work unit
47
+ jobs = [
48
+ models.SlurmJob(
49
+ name=job["name"],
50
+ slurm_job_id=job["slurm_job_id"],
51
+ slurm_ssh_config=job["slurm_ssh_config"],
52
+ )
53
+ for job in wu_data.pop("jobs", [])
54
+ ]
55
+
56
+ # Build artifacts for this work unit
57
+ artifacts = [
58
+ models.Artifact(name=artifact["name"], uri=artifact["uri"])
59
+ for artifact in wu_data.pop("artifacts", [])
60
+ ]
61
+
62
+ # Create work unit with its jobs and artifacts
63
+ wu_data["jobs"] = jobs
64
+ wu_data["artifacts"] = artifacts
65
+ work_units.append(models.WorkUnit(**wu_data))
66
+
67
+ # Build experiment artifacts
68
+ artifacts = [
69
+ models.Artifact(name=artifact["name"], uri=artifact["uri"])
70
+ for artifact in data.pop("artifacts", [])
71
+ ]
72
+
73
+ return models.Experiment(**data, work_units=work_units, artifacts=artifacts)
74
+
75
+ @with_backoff
76
+ def delete_experiment(self, experiment_id: int) -> None:
77
+ response = self.client.delete(self._make_url(f"/experiment/{experiment_id}"))
78
+ response.raise_for_status()
79
+
80
+ @with_backoff
81
+ def insert_experiment(self, experiment: models.ExperimentPatch) -> int:
82
+ assert experiment.title is not None, "Title must be set in the experiment model."
83
+ assert (
84
+ experiment.description is None and experiment.note is None and experiment.tags is None
85
+ ), "Only title should be set in the experiment model."
86
+
87
+ response = self.client.put(
88
+ self._make_url("/experiment"), json=dataclasses.asdict(experiment)
89
+ )
90
+ response.raise_for_status()
91
+ return int(response.json()["xid"])
92
+
93
+ @with_backoff
94
+ def update_experiment(
95
+ self, experiment_id: int, experiment_patch: models.ExperimentPatch
96
+ ) -> None:
97
+ response = self.client.patch(
98
+ self._make_url(f"/experiment/{experiment_id}"),
99
+ json=dataclasses.asdict(experiment_patch),
100
+ )
101
+ response.raise_for_status()
102
+
103
+ @with_backoff
104
+ def insert_work_unit(self, experiment_id: int, work_unit: models.WorkUnitPatch) -> None:
105
+ response = self.client.put(
106
+ self._make_url(f"/experiment/{experiment_id}/wu"),
107
+ json=dataclasses.asdict(work_unit),
108
+ )
109
+ response.raise_for_status()
110
+
111
+ @with_backoff
112
+ def insert_job(self, experiment_id: int, work_unit_id: int, job: models.SlurmJob) -> None:
113
+ response = self.client.put(
114
+ self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/job"),
115
+ json=dataclasses.asdict(job),
116
+ )
117
+ response.raise_for_status()
118
+
119
+ @with_backoff
120
+ def insert_work_unit_artifact(
121
+ self, experiment_id: int, work_unit_id: int, artifact: models.Artifact
122
+ ) -> None:
123
+ response = self.client.put(
124
+ self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/artifact"),
125
+ json=dataclasses.asdict(artifact),
126
+ )
127
+ response.raise_for_status()
128
+
129
+ @with_backoff
130
+ def delete_work_unit_artifact(self, experiment_id: int, work_unit_id: int, name: str) -> None:
131
+ response = self.client.delete(
132
+ self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}/artifact/{name}")
133
+ )
134
+ response.raise_for_status()
135
+
136
+ @with_backoff
137
+ def delete_experiment_artifact(self, experiment_id: int, name: str) -> None:
138
+ response = self.client.delete(
139
+ self._make_url(f"/experiment/{experiment_id}/artifact/{name}")
140
+ )
141
+ response.raise_for_status()
142
+
143
+ @with_backoff
144
+ def insert_experiment_artifact(self, experiment_id: int, artifact: models.Artifact) -> None:
145
+ response = self.client.put(
146
+ self._make_url(f"/experiment/{experiment_id}/artifact"),
147
+ json=dataclasses.asdict(artifact),
148
+ )
149
+ response.raise_for_status()
150
+
151
+ @with_backoff
152
+ def insert_experiment_config_artifact(
153
+ self, experiment_id: int, artifact: models.ConfigArtifact
154
+ ) -> None:
155
+ response = self.client.put(
156
+ self._make_url(f"/experiment/{experiment_id}/config"), json=dataclasses.asdict(artifact)
157
+ )
158
+ response.raise_for_status()
159
+
160
+ @with_backoff
161
+ def delete_experiment_config_artifact(self, experiment_id: int, name: str) -> None:
162
+ response = self.client.delete(self._make_url(f"/experiment/{experiment_id}/config/{name}"))
163
+ response.raise_for_status()
164
+
165
+ @with_backoff
166
+ def update_work_unit(
167
+ self, experiment_id: int, work_unit_id: int, patch: models.ExperimentUnitPatch
168
+ ) -> None:
169
+ response = self.client.patch(
170
+ self._make_url(f"/experiment/{experiment_id}/wu/{work_unit_id}"),
171
+ json=dataclasses.asdict(patch),
172
+ )
173
+ response.raise_for_status()
xm_slurm/config.py CHANGED
@@ -80,6 +80,8 @@ class SlurmSSHConfig:
80
80
  None,
81
81
  ssh_config_paths,
82
82
  False,
83
+ True,
84
+ True,
83
85
  getpass.getuser(),
84
86
  self.user or (),
85
87
  self.host,
@@ -113,7 +115,11 @@ class SlurmSSHConfig:
113
115
 
114
116
  @functools.cached_property
115
117
  def connection_options(self) -> asyncssh.SSHClientConnectionOptions:
116
- options = asyncssh.SSHClientConnectionOptions(config=None)
118
+ options = asyncssh.SSHClientConnectionOptions(
119
+ config=None,
120
+ kbdint_auth=False,
121
+ disable_trivial_auth=True,
122
+ )
117
123
  options.prepare(last_config=self.config, known_hosts=self.known_hosts)
118
124
  return options
119
125
 
@@ -165,7 +171,8 @@ class SlurmClusterConfig:
165
171
  runtime: ContainerRuntime
166
172
 
167
173
  # Environment variables
168
- environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
174
+ host_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
175
+ container_environment: Mapping[str, str] = dataclasses.field(default_factory=dict)
169
176
 
170
177
  # Mounts
171
178
  mounts: Mapping[os.PathLike[str] | str, os.PathLike[str] | str] = dataclasses.field(
@@ -208,5 +215,6 @@ class SlurmClusterConfig:
208
215
  self.qos,
209
216
  self.proxy,
210
217
  self.runtime,
211
- frozenset(self.environment.items()),
218
+ frozenset(self.host_environment.items()),
219
+ frozenset(self.container_environment.items()),
212
220
  ))
@@ -1,12 +1,8 @@
1
- import datetime as dt
2
1
  import logging
3
2
  import os
4
3
 
5
- from xmanager import xm
6
-
7
4
  from xm_slurm import config, resources
8
5
  from xm_slurm.contrib.clusters import drac
9
- from xm_slurm.executors import Slurm
10
6
 
11
7
  # ComputeCanada alias
12
8
  cc = drac
@@ -45,12 +41,13 @@ def mila(
45
41
  runtime=config.ContainerRuntime.SINGULARITY,
46
42
  partition=partition,
47
43
  prolog="module load singularity",
48
- environment={
44
+ host_environment={
49
45
  "SINGULARITY_CACHEDIR": "$SCRATCH/.apptainer",
50
46
  "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
51
47
  "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
48
+ },
49
+ container_environment={
52
50
  "SCRATCH": "/scratch",
53
- # TODO: move this somewhere common to all cluster configs.
54
51
  "XM_SLURM_STATE_DIR": "/xm-slurm-state",
55
52
  },
56
53
  mounts=mounts,
@@ -42,13 +42,14 @@ def _drac_cluster(
42
42
  proxy=proxy,
43
43
  runtime=config.ContainerRuntime.APPTAINER,
44
44
  prolog=f"module load apptainer {' '.join(modules) if modules else ''}".rstrip(),
45
- environment={
45
+ host_environment={
46
+ "XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
46
47
  "APPTAINER_CACHEDIR": "$SCRATCH/.apptainer",
47
48
  "APPTAINER_TMPDIR": "$SLURM_TMPDIR",
48
49
  "APPTAINER_LOCALCACHEDIR": "$SLURM_TMPDIR",
49
- "_XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
50
+ },
51
+ container_environment={
50
52
  "SCRATCH": "/scratch",
51
- # TODO: move this somewhere common to all cluster configs.
52
53
  "XM_SLURM_STATE_DIR": "/xm-slurm-state",
53
54
  },
54
55
  mounts=mounts,
xm_slurm/executables.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import dataclasses
2
+ import os
2
3
  import pathlib
3
4
  import typing as tp
4
5
 
@@ -19,7 +20,6 @@ class Dockerfile(xm.ExecutableSpec):
19
20
  ssh: A list of docker SSH sockets/keys.
20
21
  build_args: Build arguments to docker.
21
22
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
22
- workdir: The working directory in container.
23
23
  labels: The container labels.
24
24
  platforms: The target platform.
25
25
  """
@@ -41,9 +41,6 @@ class Dockerfile(xm.ExecutableSpec):
41
41
  # --cache-from field in BuildKit
42
42
  cache_from: tp.Sequence[str] = dataclasses.field(default_factory=list)
43
43
 
44
- # Working directory in container
45
- workdir: pathlib.Path | None = None
46
-
47
44
  # Container labels
48
45
  labels: tp.Mapping[str, str] = dataclasses.field(default_factory=dict)
49
46
 
@@ -66,7 +63,6 @@ class Dockerfile(xm.ExecutableSpec):
66
63
  tuple(sorted(self.ssh)),
67
64
  tuple(sorted(self.build_args.items())),
68
65
  tuple(sorted(self.cache_from)),
69
- self.workdir,
70
66
  tuple(sorted(self.labels.items())),
71
67
  tuple(sorted(self.platforms)),
72
68
  ))
@@ -190,8 +186,8 @@ class RemoteImage(xm.Executable):
190
186
  # Remote base image
191
187
  image: Descriptor[ImageURI, str | ImageURI] = ImageDescriptor()
192
188
 
193
- # Working directory in container
194
- workdir: pathlib.Path | None = None
189
+ workdir: os.PathLike[str] | str
190
+ entrypoint: xm.SequentialArgs
195
191
 
196
192
  # Container arguments
197
193
  args: xm.SequentialArgs = dataclasses.field(default_factory=xm.SequentialArgs)
@@ -211,6 +207,7 @@ class RemoteImage(xm.Executable):
211
207
  type(self),
212
208
  self.image,
213
209
  self.workdir,
210
+ tuple(sorted(self.entrypoint.to_list())),
214
211
  tuple(sorted(self.args.to_list())),
215
212
  tuple(sorted(self.env_vars.items())),
216
213
  self.credentials,