xmanager-slurm 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of xmanager-slurm has been flagged as potentially problematic.

@@ -19,18 +19,22 @@ def mila(
     if mounts is None:
         mounts = {
             "/network/scratch/${USER:0:1}/$USER": "/scratch",
-            "/network/archive/${USER:0:1}/$USER": "/archive",
+            # TODO: move these somewhere common to all cluster configs.
+            "/home/mila/${USER:0:1}/$USER/.local/state/xm-slurm": "/xm-slurm-state",
+            "/home/mila/${USER:0:1}/$USER/.ssh": "/home/mila/${USER:0:1}/$USER/.ssh",
         }
 
     return config.SlurmClusterConfig(
         name="mila",
-        user=user,
-        host="login.server.mila.quebec",
-        host_public_key=config.PublicKey(
-            "ssh-ed25519",
-            "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
+        ssh=config.SlurmSSHConfig(
+            user=user,
+            host="login.server.mila.quebec",
+            host_public_key=config.PublicKey(
+                "ssh-ed25519",
+                "AAAAC3NzaC1lZDI1NTE5AAAAIBTPCzWRkwYDr/cFb4d2uR6rFlUtqfH3MoLMXPpJHK0n",
+            ),
+            port=2222,
         ),
-        port=2222,
         runtime=config.ContainerRuntime.SINGULARITY,
         partition=partition,
         prolog="module load singularity",
@@ -39,14 +43,19 @@ def mila(
             "SINGULARITY_TMPDIR": "$SLURM_TMPDIR",
             "SINGULARITY_LOCALCACHEDIR": "$SLURM_TMPDIR",
             "SCRATCH": "/scratch",
-            "ARCHIVE": "/archive",
+            # TODO: move this somewhere common to all cluster configs.
+            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
         },
         mounts=mounts,
         resources={
-            "rtx8000": resources.ResourceType.RTX8000,
-            "v100": resources.ResourceType.V100,
-            "a100": resources.ResourceType.A100,
-            "a100l": resources.ResourceType.A100_80GIB,
-            "a6000": resources.ResourceType.A6000,
+            resources.ResourceType.RTX8000: "rtx8000",
+            resources.ResourceType.V100: "v100",
+            resources.ResourceType.A100: "a100",
+            resources.ResourceType.A100_80GIB: "a100l",
+            resources.ResourceType.A6000: "a6000",
+        },
+        features={
+            resources.FeatureType.NVIDIA_MIG: "mig",
+            resources.FeatureType.NVIDIA_NVLINK: "nvlink",
         },
     )
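
Note on the resource maps in these hunks: the mapping now runs from `ResourceType` to the cluster-specific GRES name instead of the reverse, so launch code can ask "how is this accelerator spelled on this cluster?". A minimal sketch of how such a mapping might be consumed when emitting an sbatch directive; `gres_directive` is a hypothetical helper for illustration, not part of xm-slurm:

# Hypothetical helper, shown only to motivate the ResourceType -> name direction.
from xm_slurm import resources


def gres_directive(
    cluster_resources: dict[resources.ResourceType, str],
    requested: resources.ResourceType,
    count: int,
) -> str:
    # Look up the cluster-specific GRES spelling for the requested accelerator.
    gres_name = cluster_resources.get(requested)
    if gres_name is None:
        raise ValueError(f"cluster does not provide {requested!r}")
    return f"--gres=gpu:{gres_name}:{count}"


# e.g. gres_directive({resources.ResourceType.A100: "a100"}, resources.ResourceType.A100, 2)
# yields "--gres=gpu:a100:2"
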
@@ -2,7 +2,7 @@ import os
 from typing import Literal
 
 from xm_slurm import config
-from xm_slurm.resources import ResourceType
+from xm_slurm.resources import FeatureType, ResourceType
 
 __all__ = ["narval", "beluga", "cedar", "graham"]
 
@@ -18,18 +18,26 @@ def _drac_cluster(
     modules: list[str] | None = None,
     proxy: Literal["submission-host"] | str | None = None,
     mounts: dict[os.PathLike[str] | str, os.PathLike[str] | str] | None = None,
-    resources: dict[str, ResourceType] | None = None,
+    resources: dict[ResourceType, str] | None = None,
+    features: dict[FeatureType, str] | None = None,
 ) -> config.SlurmClusterConfig:
     """DRAC Cluster."""
     if mounts is None:
-        mounts = {"/scratch/$USER": "/scratch"}
+        mounts = {
+            "/scratch/$USER": "/scratch",
+            # TODO: move these somewhere common to all cluster configs.
+            "/home/$USER/.ssh": "/home/$USER/.ssh",
+            "/home/$USER/.local/state/xm-slurm": "/xm-slurm-state",
+        }
 
     return config.SlurmClusterConfig(
         name=name,
-        user=user,
-        host=host,
-        host_public_key=host_public_key,
-        port=port,
+        ssh=config.SlurmSSHConfig(
+            user=user,
+            host=host,
+            host_public_key=host_public_key,
+            port=port,
+        ),
         account=account,
         proxy=proxy,
         runtime=config.ContainerRuntime.APPTAINER,
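
Note: the flat `user`/`host`/`host_public_key`/`port` arguments are now grouped under a single `ssh=config.SlurmSSHConfig(...)` object, which `execution.py` later consumes via `ssh.host` and `ssh.connection_options`. A minimal sketch of what that config plausibly looks like, reconstructed from those call sites; the exact fields and asyncssh plumbing are assumptions, not the package's actual code:

# Assumed shape of config.SlurmSSHConfig, inferred from its usage in this diff.
import dataclasses
import functools
from typing import NamedTuple

import asyncssh


class PublicKey(NamedTuple):
    algorithm: str
    key: str


@dataclasses.dataclass(frozen=True, kw_only=True)
class SlurmSSHConfig:
    user: str
    host: str
    host_public_key: PublicKey | None = None
    port: int = 22

    @functools.cached_property
    def connection_options(self) -> asyncssh.SSHClientConnectionOptions:
        # The real config presumably also pins host_public_key (e.g. via
        # known_hosts); only the uncontroversial fields are shown here.
        return asyncssh.SSHClientConnectionOptions(username=self.user, port=self.port)
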
@@ -40,9 +48,12 @@ def _drac_cluster(
             "APPTAINER_LOCALCACHEDIR": "$SLURM_TMPDIR",
             "_XDG_DATA_HOME": "$SLURM_TMPDIR/.local",
             "SCRATCH": "/scratch",
+            # TODO: move this somewhere common to all cluster configs.
+            "XM_SLURM_STATE_DIR": "/xm-slurm-state",
         },
         mounts=mounts,
         resources=resources or {},
+        features=features or {},
     )
 
 
@@ -70,7 +81,11 @@ def narval(
         mounts=mounts,
         proxy=proxy,
         modules=modules,
-        resources={"a100": ResourceType.A100},
+        resources={ResourceType.A100: "a100"},
+        features={
+            FeatureType.NVIDIA_MIG: "a100mig",
+            FeatureType.NVIDIA_NVLINK: "nvlink",
+        },
     )
 
 
@@ -98,7 +113,10 @@ def beluga(
         mounts=mounts,
         proxy=proxy,
         modules=modules,
-        resources={"tesla_v100-sxm2-16gb": ResourceType.V100},
+        resources={ResourceType.V100: "tesla_v100-sxm2-16gb"},
+        features={
+            FeatureType.NVIDIA_NVLINK: "nvlink",
+        },
     )
 
 
@@ -120,9 +138,9 @@ def cedar(
         account=account,
         mounts=mounts,
         resources={
-            "v100l": ResourceType.V100_32GIB,
-            "p100": ResourceType.P100,
-            "p100l": ResourceType.P100_16GIB,
+            ResourceType.V100_32GIB: "v100l",
+            ResourceType.P100: "p100",
+            ResourceType.P100_16GIB: "p100l",
         },
     )
 
@@ -147,10 +165,10 @@ def graham(
         mounts=mounts,
         proxy=proxy,
         resources={
-            "v100": ResourceType.V100,
-            "p100": ResourceType.P100,
-            "a100": ResourceType.A100,
-            "a5000": ResourceType.A5000,
+            ResourceType.V100: "v100",
+            ResourceType.P100: "p100",
+            ResourceType.A100: "a100",
+            ResourceType.A5000: "a5000",
         },
     )
 
xm_slurm/executables.py CHANGED
@@ -1,10 +1,11 @@
 import dataclasses
 import pathlib
-import re
 from typing import Mapping, NamedTuple, Sequence
 
 from xmanager import xm
 
+from xm_slurm import constants
+
 
 @dataclasses.dataclass(frozen=True, kw_only=True)
 class Dockerfile(xm.ExecutableSpec):
@@ -14,6 +15,7 @@ class Dockerfile(xm.ExecutableSpec):
         dockerfile: The path to the Dockerfile.
         context: The path to the Docker context.
         target: The Docker build target.
+        ssh: A list of docker SSH sockets/keys.
         build_args: Build arguments to docker.
         cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
         workdir: The working directory in container.
@@ -29,6 +31,9 @@ class Dockerfile(xm.ExecutableSpec):
     # Docker build target
     target: str | None = None
 
+    # SSH sockets/keys for the docker build step.
+    ssh: Sequence[str] = dataclasses.field(default_factory=list)
+
     # Build arguments to docker
     build_args: Mapping[str, str] = dataclasses.field(default_factory=dict)
 
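Note: the new `ssh` field mirrors BuildKit's `--ssh` build flag, which is what `RUN --mount=type=ssh` steps in a Dockerfile use to reach an SSH agent or key during the build. A hedged usage sketch; that each entry is forwarded as `docker build --ssh <value>` is inferred from the field's description, not shown in this diff:

import pathlib

from xm_slurm import executables

spec = executables.Dockerfile(
    dockerfile=pathlib.Path("Dockerfile"),
    context=pathlib.Path("."),
    ssh=["default"],  # forward the local SSH agent socket into the build
)
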
@@ -56,6 +61,7 @@ class Dockerfile(xm.ExecutableSpec):
                 self.dockerfile,
                 self.context,
                 self.target,
+                tuple(sorted(self.ssh)),
                 tuple(sorted(self.build_args.items())),
                 tuple(sorted(self.cache_from)),
                 self.workdir,
@@ -87,11 +93,6 @@ class DockerImage(xm.ExecutableSpec):
         return hash((self.image, self.workdir))
 
 
-_IMAGE_URI_REGEX = re.compile(
-    r"^(?P<scheme>(?:[^:]+://)?)?(?P<domain>[^/]+)(?P<path>/[^:]*)?(?::(?P<tag>[^@]+))?@?(?P<digest>.+)?$"
-)
-
-
 @dataclasses.dataclass
 class ImageURI:
     image: dataclasses.InitVar[str]
@@ -103,7 +104,7 @@ class ImageURI:
     digest: str | None = dataclasses.field(init=False, default=None)
 
     def __post_init__(self, image: str):
-        match = _IMAGE_URI_REGEX.match(image)
+        match = constants.IMAGE_URI_REGEX.match(image)
         if not match:
             raise ValueError(f"Invalid OCI image URI: {image}")
         groups = {k: v for k, v in match.groupdict().items() if v is not None}
@@ -199,3 +200,14 @@ class RemoteImage(xm.Executable):
     @property
     def name(self) -> str:
         return str(self.image)
+
+    def __hash__(self) -> int:
+        return hash(
+            (
+                self.image,
+                self.workdir,
+                tuple(sorted(self.args.to_list())),
+                tuple(sorted(self.env_vars.items())),
+                self.credentials,
+            ),
+        )
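
Note: the `sorted(...)` calls make the new `RemoteImage.__hash__` insensitive to the insertion order of `args` and `env_vars`, so two logically identical executables hash alike. A standalone illustration:

a = {"FOO": "1", "BAR": "2"}
b = {"BAR": "2", "FOO": "1"}  # same items, different insertion order
assert tuple(a.items()) != tuple(b.items())
assert hash(tuple(sorted(a.items()))) == hash(tuple(sorted(b.items())))
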
xm_slurm/execution.py CHANGED
@@ -13,12 +13,14 @@ from typing import Any, Mapping, Sequence
 import asyncssh
 import backoff
 import jinja2 as j2
+import more_itertools as mit
 from asyncssh.auth import KbdIntPrompts, KbdIntResponse
 from asyncssh.misc import MaybeAwait
 from xmanager import xm
 
 from xm_slurm import batching, config, executors, status
 from xm_slurm.console import console
+from xm_slurm.job_blocks import JobArgs
 
 SlurmClusterConfig = config.SlurmClusterConfig
 ContainerRuntime = config.ContainerRuntime
@@ -163,8 +165,9 @@ class _BatchedSlurmHandle:
 class SlurmHandle(_BatchedSlurmHandle):
     """A handle for referring to the launched container."""
 
-    ssh_connection_options: asyncssh.SSHClientConnectionOptions
+    ssh: config.SlurmSSHConfig
     job_id: str
+    job_name: str  # XManager job name associated with this handle
 
     def __post_init__(self):
         if re.match(r"^\d+(_\d+|\+\d+)?$", self.job_id) is None:
@@ -180,10 +183,10 @@ class SlurmHandle(_BatchedSlurmHandle):
         return await self.get_state()
 
     async def stop(self) -> None:
-        await self._batched_cancel(self.ssh_connection_options, self.job_id)
+        await self._batched_cancel(self.ssh.connection_options, self.job_id)
 
     async def get_state(self) -> status.SlurmJobState:
-        return await self._batched_get_state(self.ssh_connection_options, self.job_id)
+        return await self._batched_get_state(self.ssh.connection_options, self.job_id)
 
 
 @functools.cache
@@ -208,7 +211,7 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
         case ContainerRuntime.PODMAN:
             runtime_template = template_env.get_template("runtimes/podman.bash.j2")
         case _:
-            raise NotImplementedError
+            raise NotImplementedError(f"Container runtime {container_runtime} is not implemented.")
     # Update our global env with the runtime template's exported globals
     template_env.globals.update(runtime_template.module.__dict__)
 
@@ -216,16 +219,17 @@ def get_template_env(container_runtime: ContainerRuntime) -> j2.Environment:
 
 
 class Client:
-    def __init__(self):
-        self._connections: dict[
+    def __init__(self) -> None:
+        self._connections = dict[
             asyncssh.SSHClientConnectionOptions, asyncssh.SSHClientConnection
-        ] = {}
+        ]()
         self._connection_lock = asyncio.Lock()
 
     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
     async def _setup_remote_connection(self, conn: asyncssh.SSHClientConnection) -> None:
         # Make sure the xm-slurm state directory exists
-        await conn.run("mkdir -p ~/.local/state/xm-slurm", check=True)
+        async with conn.start_sftp_client() as sftp_client:
+            await sftp_client.makedirs(".local/state/xm-slurm", exist_ok=True)
 
     async def connection(
         self,
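
Note: `dict[K, V]()` instantiates an ordinary dict at runtime (the subscript exists only for the type checker, replacing the separate annotated assignment), and creating the state directory over SFTP drops the dependency on a remote shell. The first point in isolation:

# dict[str, int] is a types.GenericAlias; calling it builds a plain dict.
d = dict[str, int]()
assert type(d) is dict
d["answer"] = 42  # the subscripted types are not enforced at runtime
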
@@ -238,7 +242,24 @@ class Client:
                 await self._setup_remote_connection(conn)
                 self._connections[options] = conn
             except asyncssh.misc.PermissionDenied as ex:
-                raise RuntimeError(f"Permission denied connecting to {options.host}") from ex
+                raise SlurmExecutionError(
+                    f"Permission denied connecting to {options.host}"
+                ) from ex
+            except asyncssh.misc.ConnectionLost as ex:
+                raise SlurmExecutionError(f"Connection lost to host {options.host}") from ex
+            except asyncssh.misc.HostKeyNotVerifiable as ex:
+                raise SlurmExecutionError(
+                    f"Cannot verify the public key for host {options.host}"
+                ) from ex
+            except asyncssh.misc.KeyExchangeFailed as ex:
+                raise SlurmExecutionError(
+                    f"Failed to exchange keys with host {options.host}"
+                ) from ex
+            except asyncssh.Error as ex:
+                raise SlurmExecutionError(
+                    f"SSH connection error when connecting to {options.host}"
+                ) from ex
+
         return self._connections[options]
 
     @backoff.on_exception(backoff.expo, asyncssh.Error, max_tries=5, max_time=60.0)
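
Note: `SlurmExecutionError` is raised above but defined outside the visible hunks. A minimal plausible form, stated as an assumption rather than the package's actual definition:

class SlurmExecutionError(RuntimeError):
    """Raised when an SSH or Slurm interaction with a cluster fails."""
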
@@ -340,8 +361,8 @@ class Client:
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.Job | xm.JobGroup,
-        args: Mapping[str, Any] | None,
+        job: xm.JobGroup,
+        args: Mapping[str, JobArgs] | None,
         experiment_id: int,
         identity: str | None = ...,
     ) -> SlurmHandle: ...
@@ -351,21 +372,24 @@ class Client:
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.Job | xm.JobGroup,
-        args: Sequence[Mapping[str, Any]],
+        job: xm.Job,
+        args: Sequence[JobArgs],
         experiment_id: int,
         identity: str | None = ...,
-    ) -> Sequence[SlurmHandle]: ...
+    ) -> list[SlurmHandle]: ...
 
+    @typing.overload
     async def launch(
         self,
         *,
         cluster: SlurmClusterConfig,
-        job: xm.Job | xm.JobGroup,
-        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None,
+        job: xm.Job,
+        args: JobArgs,
         experiment_id: int,
-        identity: str | None = None,
-    ) -> SlurmHandle | Sequence[SlurmHandle]:
+        identity: str | None = ...,
+    ) -> SlurmHandle: ...
+
+    async def launch(self, *, cluster, job, args, experiment_id, identity=None):
         # Construct template
         template = await self.template(
             job=job,
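
Note: dropping the annotations from the implementation signature is the standard `typing.overload` pattern; type checkers only see the decorated stubs, and the single undecorated definition is what actually runs. The same pattern in miniature:

import typing


@typing.overload
def first(xs: list[int]) -> int: ...
@typing.overload
def first(xs: str) -> str: ...
def first(xs):  # the only definition that exists at runtime
    return xs[0]


assert first([1, 2, 3]) == 1
assert first("abc") == "a"
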
@@ -379,7 +403,7 @@ class Client:
         # Hash submission script
         template_hash = hashlib.blake2s(template.encode()).hexdigest()[:8]
 
-        conn = await self.connection(cluster.ssh_connection_options)
+        conn = await self.connection(cluster.ssh.connection_options)
         async with conn.start_sftp_client() as sftp:
             # Write the submission script to the cluster
             # TODO(jfarebro): SHOULD FIND A WAY TO GET THE HOME DIRECTORY
@@ -392,9 +416,9 @@ class Client:
 
         # Construct and run command on the cluster
         command = f"sbatch --chdir .local/state/xm-slurm/{experiment_id} --parsable submission-script-{template_hash}.sh"
-        result = await self.run(cluster.ssh_connection_options, command)
+        result = await self.run(cluster.ssh.connection_options, command)
         if result.returncode != 0:
-            raise RuntimeError(f"Failed to schedule job on {cluster.host}: {result.stderr}")
+            raise RuntimeError(f"Failed to schedule job on {cluster.ssh.host}: {result.stderr}")
 
         assert isinstance(result.stdout, str)
         slurm_job_id, *_ = result.stdout.split(",")
@@ -405,26 +429,40 @@ class Client:
             f"[cyan]{cluster.name}[/cyan] "
         )
 
+        # If we scheduled an array job make sure to return a list of handles
+        # The indexing is always sequential in 0, 1, ..., n - 1
         if isinstance(job, xm.Job) and isinstance(args, collections.abc.Sequence):
+            assert job.name is not None
             return [
                 SlurmHandle(
-                    ssh_connection_options=cluster.ssh_connection_options,
+                    ssh=cluster.ssh,
                     job_id=f"{slurm_job_id}_{array_index}",
+                    job_name=job.name,
                 )
                 for array_index in range(len(args))
            ]
-
-        return SlurmHandle(
-            ssh_connection_options=cluster.ssh_connection_options,
-            job_id=slurm_job_id,
-        )
+        elif isinstance(job, xm.Job):
+            assert job.name is not None
+            return SlurmHandle(ssh=cluster.ssh, job_id=slurm_job_id, job_name=job.name)
+        elif isinstance(job, xm.JobGroup):
+            # TODO: make this work for actual job groups.
+            job = mit.one(job.jobs.values())
+            assert isinstance(job, xm.Job)
+            assert job.name is not None
+            return SlurmHandle(ssh=cluster.ssh, job_id=slurm_job_id, job_name=job.name)
+        else:
+            raise ValueError(f"Unsupported job type: {type(job)}")
+
+    def __del__(self):
+        for conn in self._connections.values():
+            conn.close()
 
 
 @typing.overload
 async def launch(
     *,
-    job: xm.Job | xm.JobGroup,
-    args: Mapping[str, Any],
+    job: xm.JobGroup,
+    args: Mapping[str, JobArgs],
     experiment_id: int,
     identity: str | None = ...,
 ) -> SlurmHandle: ...
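
Note: until real job-group support lands, the `xm.JobGroup` branch uses `more_itertools.one` to insist the group contains exactly one job; it returns the sole element and raises `ValueError` for zero or several. In isolation:

import more_itertools as mit

assert mit.one([42]) == 42
try:
    mit.one([1, 2])  # a group with more than one job
except ValueError:
    pass  # raised for both empty and multi-element iterables
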
@@ -433,22 +471,32 @@ async def launch(
 @typing.overload
 async def launch(
     *,
-    job: xm.Job | xm.JobGroup,
-    args: Sequence[Mapping[str, Any]],
+    job: xm.Job,
+    args: Sequence[JobArgs],
     experiment_id: int,
     identity: str | None = ...,
-) -> Sequence[SlurmHandle]: ...
+) -> list[SlurmHandle]: ...
+
+
+@typing.overload
+async def launch(
+    *,
+    job: xm.Job,
+    args: JobArgs,
+    experiment_id: int,
+    identity: str | None = ...,
+) -> SlurmHandle: ...
 
 
 async def launch(
     *,
     job: xm.Job | xm.JobGroup,
-    args: Mapping[str, Any] | Sequence[Mapping[str, Any]],
+    args: Mapping[str, JobArgs] | Sequence[JobArgs] | JobArgs,
     experiment_id: int,
     identity: str | None = None,
-) -> SlurmHandle | Sequence[SlurmHandle]:
+) -> SlurmHandle | list[SlurmHandle]:
     match job:
-        case xm.Job():
+        case xm.Job() as job:
             if not isinstance(job.executor, executors.Slurm):
                 raise ValueError("Job must have a Slurm executor")
             job_requirements = job.executor.requirements
@@ -459,7 +507,7 @@ async def launch(
             return await get_client().launch(
                 cluster=cluster,
                 job=job,
-                args=args,
+                args=typing.cast(JobArgs | Sequence[JobArgs], args),
                 experiment_id=experiment_id,
                 identity=identity,
             )
@@ -482,8 +530,8 @@ async def launch(
 
             return await get_client().launch(
                 cluster=job_group_clusters.pop(),
-                job=job,
-                args=args,
+                job=job_group,
+                args=typing.cast(Mapping[str, JobArgs], args),
                 experiment_id=experiment_id,
                 identity=identity,
             )
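
Note: forwarding `job=job_group` (with `args` cast to the mapping form) avoids relying on the name `job` inside the `match` statement, where case arms can rebind it, as the `case xm.Job() as job:` arm above now does explicitly. Capture patterns in miniature:

# `case <pattern> as <name>:` binds the matched object to a new name.
match {"a": 1}:
    case dict() as captured:
        assert captured == {"a": 1}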