xmanager-slurm 0.3.0__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xmanager-slurm might be problematic.

Files changed (38)
  1. xm_slurm/__init__.py +44 -0
  2. xm_slurm/api.py +261 -0
  3. xm_slurm/batching.py +139 -0
  4. xm_slurm/config.py +162 -0
  5. xm_slurm/console.py +3 -0
  6. xm_slurm/contrib/clusters/__init__.py +52 -0
  7. xm_slurm/contrib/clusters/drac.py +169 -0
  8. xm_slurm/executables.py +201 -0
  9. xm_slurm/execution.py +491 -0
  10. xm_slurm/executors.py +127 -0
  11. xm_slurm/experiment.py +737 -0
  12. xm_slurm/job_blocks.py +14 -0
  13. xm_slurm/packageables.py +292 -0
  14. xm_slurm/packaging/__init__.py +8 -0
  15. xm_slurm/packaging/docker/__init__.py +75 -0
  16. xm_slurm/packaging/docker/abc.py +112 -0
  17. xm_slurm/packaging/docker/cloud.py +503 -0
  18. xm_slurm/packaging/docker/local.py +206 -0
  19. xm_slurm/packaging/registry.py +45 -0
  20. xm_slurm/packaging/router.py +52 -0
  21. xm_slurm/packaging/utils.py +202 -0
  22. xm_slurm/resources.py +150 -0
  23. xm_slurm/status.py +188 -0
  24. xm_slurm/templates/docker/docker-bake.hcl.j2 +47 -0
  25. xm_slurm/templates/docker/mamba.Dockerfile +27 -0
  26. xm_slurm/templates/docker/pdm.Dockerfile +31 -0
  27. xm_slurm/templates/docker/python.Dockerfile +24 -0
  28. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +32 -0
  29. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  30. xm_slurm/templates/slurm/job-array.bash.j2 +29 -0
  31. xm_slurm/templates/slurm/job-group.bash.j2 +41 -0
  32. xm_slurm/templates/slurm/job.bash.j2 +78 -0
  33. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +103 -0
  34. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +56 -0
  35. xm_slurm/utils.py +69 -0
  36. xmanager_slurm-0.3.0.dist-info/METADATA +25 -0
  37. xmanager_slurm-0.3.0.dist-info/RECORD +38 -0
  38. xmanager_slurm-0.3.0.dist-info/WHEEL +4 -0
xm_slurm/experiment.py ADDED
@@ -0,0 +1,737 @@
+import asyncio
+import collections.abc
+import contextvars
+import dataclasses
+import functools
+import inspect
+import json
+import os
+import typing
+from concurrent import futures
+from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence
+
+from xmanager import xm
+from xmanager.xm import async_packager, id_predictor
+
+from xm_slurm import api, execution, executors
+from xm_slurm.console import console
+from xm_slurm.packaging import router
+from xm_slurm.status import SlurmWorkUnitStatus
+from xm_slurm.utils import UserSet
+
+_current_job_array_queue = contextvars.ContextVar[
+    asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
+]("_current_job_array_queue", default=None)
+
+
+def _validate_job(
+    job: xm.JobType,
+    args_view: Mapping[str, Any],
+) -> None:
+    if not args_view:
+        return
+    if not isinstance(args_view, collections.abc.Mapping):
+        raise ValueError("Job arguments via `experiment.add` must be mappings")
+
+    if isinstance(job, xm.JobGroup) and len(job.jobs) == 0:
+        raise ValueError("Job group is empty")
+
+    if isinstance(job, xm.JobGroup) and any(
+        isinstance(child, xm.JobGroup) for child in job.jobs.values()
+    ):
+        raise ValueError("Nested job groups are not supported")
+
+    allowed_keys = {"args", "env_vars"}
+    for key, expanded in args_view.items():
+        if isinstance(job, xm.JobGroup) and len(job.jobs) > 1 and key not in job.jobs:
+            raise ValueError(
+                f"Argument key `{key}` doesn't exist in job group with keys {job.jobs.keys()}"
+            )
+
+        if isinstance(job, xm.JobGroup) and key in job.jobs:
+            _validate_job(job.jobs[key], expanded)
+        elif key not in allowed_keys:
+            raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")
+
+
+@dataclasses.dataclass(kw_only=True, frozen=True)
+class Artifact:
+    name: str
+    uri: str
+
+    def __hash__(self) -> int:
+        return hash(self.name)
+
+
+class ContextArtifacts(UserSet[Artifact]):
+    def __init__(
+        self,
+        owner: "SlurmExperiment | SlurmExperimentUnit",
+        *,
+        artifacts: Sequence[Artifact],
+    ):
+        super().__init__(
+            artifacts,
+            on_add=self._on_add_artifact,
+            on_remove=self._on_remove_artifact,
+            on_discard=self._on_remove_artifact,
+        )
+        self._owner = owner
+        self._create_task = self._owner._create_task
+
+    def _on_add_artifact(self, artifact: Artifact) -> None:
+        match self._owner:
+            case SlurmExperiment():
+                api.client().insert_experiment_artifact(
+                    self._owner.experiment_id,
+                    api.ArtifactModel(
+                        name=artifact.name,
+                        uri=artifact.uri,
+                    ),
+                )
+            case SlurmWorkUnit():
+                api.client().insert_work_unit_artifact(
+                    self._owner.experiment_id,
+                    self._owner.work_unit_id,
+                    api.ArtifactModel(
+                        name=artifact.name,
+                        uri=artifact.uri,
+                    ),
+                )
+
+    def _on_remove_artifact(self, artifact: Artifact) -> None:
+        match self._owner:
+            case SlurmExperiment():
+                api.client().delete_experiment_artifact(self._owner.experiment_id, artifact.name)
+            case SlurmWorkUnit():
+                api.client().delete_work_unit_artifact(
+                    self._owner.experiment_id, self._owner.work_unit_id, artifact.name
+                )
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmExperimentUnitMetadataContext:
+    artifacts: ContextArtifacts
+
+
+class SlurmExperimentUnit(xm.ExperimentUnit):
+    """ExperimentUnit is a collection of semantically associated `Job`s."""
+
+    experiment: "SlurmExperiment"
+
+    def __init__(
+        self,
+        experiment: xm.Experiment,
+        create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
+        args: Mapping[str, Any] | None,
+        role: xm.ExperimentUnitRole,
+    ) -> None:
+        super().__init__(experiment, create_task, args, role)
+        self._launched_jobs: list[xm.LaunchedJob] = []
+        self._execution_handles: list[execution.SlurmHandle] = []
+        self._context = SlurmExperimentUnitMetadataContext(
+            artifacts=ContextArtifacts(owner=self, artifacts=[]),
+        )
+
+    async def _submit_jobs_for_execution(
+        self,
+        job: xm.Job | xm.JobGroup,
+        args_view: Mapping[str, Any],
+        identity: str | None = None,
+    ) -> execution.SlurmHandle:
+        return await execution.launch(
+            job=job,
+            args=args_view,
+            experiment_id=self.experiment_id,
+            identity=identity,
+        )
+
+    def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
+        match job:
+            case xm.JobGroup() as job_group:
+                for job in job_group.jobs.values():
+                    self._launched_jobs.append(
+                        xm.LaunchedJob(
+                            name=job.name,  # type: ignore
+                            address=str(handle.job_id),
+                        )
+                    )
+            case xm.Job():
+                self._launched_jobs.append(
+                    xm.LaunchedJob(
+                        name=handle.job.name,  # type: ignore
+                        address=str(handle.job_id),
+                    )
+                )
+
+    async def _wait_until_complete(self) -> None:
+        try:
+            await asyncio.gather(*[handle.wait() for handle in self._execution_handles])
+        except RuntimeError as error:
+            raise xm.ExperimentUnitFailedError(error)
+
+    def stop(
+        self,
+        *,
+        mark_as_failed: bool = False,
+        mark_as_completed: bool = False,
+        message: str | None = None,
+    ) -> None:
+        del mark_as_failed, mark_as_completed, message
+
+        async def _stop_awaitable() -> None:
+            try:
+                await asyncio.gather(*[handle.stop() for handle in self._execution_handles])
+            except RuntimeError as error:
+                raise xm.ExperimentUnitFailedError(error)
+
+        self.experiment._create_task(_stop_awaitable())
+
+    async def get_status(self) -> SlurmWorkUnitStatus:
+        states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
+        return SlurmWorkUnitStatus.aggregate(states)
+
+    def launched_jobs(self) -> list[xm.LaunchedJob]:
+        return self._launched_jobs
+
+    @property
+    def context(self) -> SlurmExperimentUnitMetadataContext:
+        return self._context
+
+
+class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
+    def __init__(
+        self,
+        experiment: "SlurmExperiment",
+        create_task: Callable[[Awaitable[Any]], futures.Future],
+        args: Mapping[str, Any],
+        role: xm.ExperimentUnitRole,
+        work_unit_id_predictor: id_predictor.Predictor,
+    ) -> None:
+        super().__init__(experiment, create_task, args, role)
+        self._work_unit_id_predictor = work_unit_id_predictor
+        self._work_unit_id = self._work_unit_id_predictor.reserve_id()
+        api.client().insert_work_unit(
+            self.experiment_id,
+            api.WorkUnitPatchModel(
+                wid=self.work_unit_id,
+                identity=self.identity,
+                args=json.dumps(args),
+            ),
+        )
+
+    def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
+        self._execution_handles.append(handle)
+        api.client().insert_job(
+            self.experiment_id,
+            self.work_unit_id,
+            api.SlurmJobModel(
+                name=self.experiment_unit_name,
+                slurm_job_id=handle.job_id,  # type: ignore
+                slurm_cluster=json.dumps({
+                    "host": handle.ssh_connection_options.host,
+                    "username": handle.ssh_connection_options.username,
+                    "port": handle.ssh_connection_options.port,
+                    "config": handle.ssh_connection_options.config.get_options(False),
+                }),
+            ),
+        )
+
+    async def _launch_job_group(
+        self,
+        job: xm.JobGroup,
+        args_view: Mapping[str, Any],
+        identity: str,
+    ) -> None:
+        global _current_job_array_queue
+        _validate_job(job, args_view)
+
+        future = asyncio.Future()
+        async with self._work_unit_id_predictor.submit_id(self.work_unit_id):  # type: ignore
+            # If we're scheduling as part of a job queue (i.e., the queue is set on the context)
+            # then we'll insert the job and current future that'll get resolved to the
+            # proper handle.
+            if job_array_queue := _current_job_array_queue.get():
+                job_array_queue.put_nowait((job, future))
+            # Otherwise we'll resolve the future with the scheduled job immediately
+            else:
+                # Set the result inside of the context manager so we don't get out-of-order
+                # id scheduling...
+                future.set_result(
+                    await self._submit_jobs_for_execution(job, args_view, identity=identity)
+                )
+
+        # Wait for the job handle, this is either coming from scheduling the job array
+        # or from the single job above.
+        handle = await future
+        self._ingest_handle(handle)
+        self._ingest_launched_jobs(job, handle)
+
+    @property
+    def experiment_unit_name(self) -> str:
+        return f"{self.experiment_id}_{self._work_unit_id}"
+
+    @property
+    def work_unit_id(self) -> int:
+        return self._work_unit_id
+
+    def __repr__(self, /) -> str:
+        return f"<SlurmWorkUnit {self.experiment_unit_name}>"
+
+
+class SlurmAuxiliaryUnit(SlurmExperimentUnit):
+    """An auxiliary unit operated by the Slurm backend."""
+
+    def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
+        del handle
+        console.print("[red]Auxiliary units do not currently support ingestion.[/red]")
+
+    async def _launch_job_group(
+        self,
+        job: xm.Job | xm.JobGroup,
+        args_view: Mapping[str, Any],
+        identity: str,
+    ) -> None:
+        _validate_job(job, args_view)
+
+        slurm_handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
+        self._ingest_handle(slurm_handle)
+        self._ingest_launched_jobs(job, slurm_handle)
+
+    @property
+    def experiment_unit_name(self) -> str:
+        return f"{self.experiment_id}_auxiliary"
+
+    def __repr__(self, /) -> str:
+        return f"<SlurmAuxiliaryUnit {self.experiment_unit_name}>"
+
+
+class SlurmExperimentContextAnnotations:
+    def __init__(
+        self,
+        experiment: "SlurmExperiment",
+        *,
+        title: str,
+        tags: set[str] | None = None,
+        description: str | None = None,
+        note: str | None = None,
+    ):
+        self._experiment = experiment
+        self._create_task = self._experiment._create_task
+        self._title = title
+        self._tags = UserSet(
+            tags or set(),
+            on_add=self._on_tag_added,
+            on_remove=self._on_tag_removed,
+            on_discard=self._on_tag_removed,
+        )
+        self._description = description or ""
+        self._note = note or ""
+
+    @property
+    def title(self) -> str:
+        return self._title
+
+    @title.setter
+    def title(self, value: str) -> None:
+        self._title = value
+        api.client().update_experiment(
+            self._experiment.experiment_id,
+            api.ExperimentPatchModel(title=value),
+        )
+
+    @property
+    def description(self) -> str:
+        return self._description
+
+    @description.setter
+    def description(self, value: str) -> None:
+        self._description = value
+        api.client().update_experiment(
+            self._experiment.experiment_id,
+            api.ExperimentPatchModel(description=value),
+        )
+
+    @property
+    def note(self) -> str:
+        return self._note
+
+    @note.setter
+    def note(self, value: str) -> None:
+        self._note = value
+        api.client().update_experiment(
+            self._experiment.experiment_id,
+            api.ExperimentPatchModel(note=value),
+        )
+
+    @property
+    def tags(self) -> MutableSet[str]:
+        return self._tags
+
+    @tags.setter
+    def tags(self, tags: set[str]) -> None:
+        # TODO(jfarebro): Create custom tag collection
+        # and set it here, we need this so we can hook add and remove
+        # to mutate the database transparently
+        self._tags = UserSet(tags, on_add=self._on_tag_added, on_remove=self._on_tag_removed)
+        api.client().update_experiment(
+            self._experiment.experiment_id,
+            api.ExperimentPatchModel(tags=list(self._tags)),
+        )
+
+    def _on_tag_added(self, tag: str) -> None:
+        del tag
+        api.client().update_experiment(
+            self._experiment.experiment_id,
+            api.ExperimentPatchModel(tags=list(self._tags)),
+        )
+
+    def _on_tag_removed(self, tag: str) -> None:
+        del tag
+        api.client().update_experiment(
+            self._experiment.experiment_id,
+            api.ExperimentPatchModel(tags=list(self._tags)),
+        )
+
+
+class SlurmExperimentContextArtifacts(ContextArtifacts):
+    def add_graphviz_config(self, config: str) -> None:
+        self.add(Artifact(name="GRAPHVIZ", uri=f"graphviz://{config}"))
+
+    def add_python_config(self, config: str) -> None:
+        self.add(Artifact(name="PYTHON", uri=config))
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class SlurmExperimentMetadataContext:
+    annotations: SlurmExperimentContextAnnotations
+    artifacts: ContextArtifacts
+
+
+class SlurmExperiment(xm.Experiment):
+    _id: int
+    _experiment_units: list[SlurmExperimentUnit]
+    _experiment_context: SlurmExperimentMetadataContext
+    _work_unit_count: int
+    _async_packager = async_packager.AsyncPackager(router.package)
+
+    def __init__(
+        self,
+        experiment_title: str,
+        experiment_id: int,
+    ) -> None:
+        super().__init__()
+        self._id = experiment_id
+        self._experiment_units = []
+        self._experiment_context = SlurmExperimentMetadataContext(
+            annotations=SlurmExperimentContextAnnotations(
+                experiment=self,
+                title=experiment_title,
+            ),
+            artifacts=ContextArtifacts(self, artifacts=[]),
+        )
+        self._work_unit_count = 0
+
+    @typing.overload
+    def add(
+        self,
+        job: xm.AuxiliaryUnitJob,
+        args: Mapping[str, Any] | None = ...,
+        *,
+        identity: str = "",
+    ) -> asyncio.Future[SlurmExperimentUnit]: ...
+
+    @typing.overload
+    def add(
+        self,
+        job: xm.JobType,
+        args: Mapping[str, Any] | None = ...,
+        *,
+        role: xm.WorkUnitRole = ...,
+        identity: str = "",
+    ) -> asyncio.Future[SlurmWorkUnit]: ...
+
+    @typing.overload
+    def add(
+        self,
+        job: xm.JobType,
+        args: Mapping[str, Any] | None,
+        *,
+        role: xm.ExperimentUnitRole,
+        identity: str = "",
+    ) -> asyncio.Future[SlurmExperimentUnit]: ...
+
+    @typing.overload
+    def add(
+        self,
+        job: xm.JobType,
+        args: Mapping[str, Any] | None = ...,
+        *,
+        role: xm.ExperimentUnitRole,
+        identity: str = "",
+    ) -> asyncio.Future[SlurmExperimentUnit]: ...
+
+    @typing.overload
+    def add(
+        self,
+        job: xm.Job | xm.JobGeneratorType,
+        args: Sequence[Mapping[str, Any]],
+        *,
+        role: xm.WorkUnitRole = ...,
+        identity: str = "",
+    ) -> asyncio.Future[Sequence[SlurmWorkUnit]]: ...
+
+    def add(
+        self,
+        job: xm.JobType,
+        args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None,
+        *,
+        role: xm.ExperimentUnitRole = xm.WorkUnitRole(),
+        identity: str = "",
+    ) -> (
+        asyncio.Future[SlurmExperimentUnit]
+        | asyncio.Future[SlurmWorkUnit]
+        | asyncio.Future[Sequence[SlurmWorkUnit]]
+    ):
+        if isinstance(args, collections.abc.Sequence):
+            if not isinstance(role, xm.WorkUnitRole):
+                raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
+            if identity:
+                raise ValueError(
+                    "Cannot set an identity on the root add call. "
+                    "Please use a job generator and set the identity within."
+                )
+            if isinstance(job, xm.JobGroup):
+                raise ValueError(
+                    "Job arrays over `xm.JobGroup`s aren't supported. "
+                    "Slurm doesn't support job arrays over heterogeneous jobs. "
+                    "Instead you should call `experiment.add` for each of these trials."
+                )
+            assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"
+
+            return asyncio.wrap_future(
+                self._create_task(self._launch_job_array(job, args, role, identity))
+            )
+        else:
+            return super().add(job, args, role=role, identity=identity)  # type: ignore
+
+    async def _launch_job_array(
+        self,
+        job: xm.Job | xm.JobGeneratorType,
+        args: Sequence[Mapping[str, Any]],
+        role: xm.WorkUnitRole,
+        identity: str = "",
+    ) -> Sequence[SlurmWorkUnit]:
+        global _current_job_array_queue
+
+        # Create our job array queue and assign it to the current context
+        job_array_queue = asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]](maxsize=len(args))
+        _current_job_array_queue.set(job_array_queue)
+
+        # For each trial we'll schedule the job
+        # and collect the futures
+        wu_futures = []
+        for trial in args:
+            wu_futures.append(super().add(job, args=trial, role=role, identity=identity))
+
+        # TODO(jfarebro): Set a timeout here
+        # We'll wait until XManager has filled the queue.
+        # There are two cases here, either we were given an xm.Job
+        # in which case this will be trivial and filled immediately.
+        # The other case is when you have a job generator and this is less
+        # trivial, you have to wait for wu.add to be called.
+        while not job_array_queue.full():
+            await asyncio.sleep(0.1)
+
+        # All jobs have been resolved
+        executable, executor, name = None, None, None
+        resolved_args, resolved_env_vars, resolved_futures = [], [], []
+        while not job_array_queue.empty():
+            # XManager automatically converts jobs to job groups so we must check
+            # that there's only a single job in this job group
+            job_group_view, future = job_array_queue.get_nowait()
+            assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
+            _, job_view = job_group_view.jobs.popitem()
+
+            if len(job_group_view.jobs) != 0 or not isinstance(job_view, xm.Job):
+                raise ValueError("Only `xm.Job` is supported for job arrays.")
+
+            if executable is None:
+                executable = job_view.executable
+            if id(job_view.executable) != id(executable):
+                raise RuntimeError("Found multiple executables in job array.")
+
+            if executor is None:
+                executor = job_view.executor
+            if id(job_view.executor) != id(executor):
+                raise RuntimeError("Found multiple executors in job array")
+
+            if name is None:
+                name = job_view.name
+            if job_view.name != name:
+                raise RuntimeError("Found multiple names in job array")
+
+            resolved_args.append(
+                set(xm.SequentialArgs.from_collection(job_view.args).to_dict().items())
+            )
+            resolved_env_vars.append(set(job_view.env_vars.items()))
+            resolved_futures.append(future)
+        assert executable is not None, "No executable found?"
+        assert executor is not None, "No executor found?"
+        assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
+        assert executor.requirements.cluster is not None, "Cluster must be set on executor."
+
+        # Intersect across trials; note that seeding reduce with `set()` would
+        # make the intersection (and thus the common args) always empty.
+        common_args: set = functools.reduce(lambda a, b: a & b, resolved_args)
+        common_env_vars: set = functools.reduce(lambda a, b: a & b, resolved_env_vars)
+
+        sweep_args = [
+            {
+                "args": dict(a.difference(common_args)),
+                "env_vars": dict(e.difference(common_env_vars)),
+            }
+            for a, e in zip(resolved_args, resolved_env_vars)
+        ]
+
+        # No support for sweep_env_vars right now.
+        # We schedule the job array and then we'll resolve all the work units with
+        # the handles Slurm gives back to us.
+        try:
+            handles = await execution.get_client().launch(
+                cluster=executor.requirements.cluster,
+                job=xm.Job(
+                    executable=executable,
+                    executor=executor,
+                    name=name,
+                    args=dict(common_args),
+                    env_vars=dict(common_env_vars),
+                ),
+                args=sweep_args,
+                experiment_id=self.experiment_id,
+                identity=identity,
+            )
+        except Exception as e:
+            for future in resolved_futures:
+                future.set_exception(e)
+            raise
+        else:
+            for handle, future in zip(handles, resolved_futures):
+                future.set_result(handle)
+
+        wus = await asyncio.gather(*wu_futures)
+        _current_job_array_queue.set(None)
+        return wus
+
+    def _create_experiment_unit(
+        self,
+        args: Mapping[str, Any],
+        role: xm.ExperimentUnitRole,
+        identity: str,
+    ) -> Awaitable[SlurmWorkUnit]:
+        del identity
+
+        def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
+            work_unit = SlurmWorkUnit(
+                self,
+                self._create_task,
+                args,
+                role,
+                self._work_unit_id_predictor,
+            )
+            self._experiment_units.append(work_unit)
+            self._work_unit_count += 1
+
+            future = asyncio.Future()
+            future.set_result(work_unit)
+            return future
+
+        match role:
+            case xm.WorkUnitRole():
+                return _create_work_unit(role)
+            case _:
+                raise ValueError(f"Unsupported role {role}")
+
+    def _get_experiment_unit(
+        self,
+        experiment_id: int,
+        identity: str,
+        role: xm.ExperimentUnitRole,
+        args: Mapping[str, Any] | None = None,
+    ) -> Awaitable[xm.ExperimentUnit]:
+        del experiment_id, identity, role, args
+        raise NotImplementedError
+
+    def _should_reload_experiment_unit(self, role: xm.ExperimentUnitRole) -> bool:
+        del role
+        return False
+
+    async def __aenter__(self) -> "SlurmExperiment":
+        await super().__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        # If no work units were added, delete this experiment
+        # This is to prevent empty experiments from being persisted
+        # and cluttering the database.
+        if self.work_unit_count == 0:
+            console.print(
+                f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
+            )
+            api.client().delete_experiment(self.experiment_id)
+
+        await super().__aexit__(exc_type, exc_value, traceback)
+
+    @property
+    def experiment_id(self) -> int:
+        return self._id
+
+    @property
+    def experiment_title(self) -> str:
+        return self.context.annotations.title
+
+    @property
+    def context(self) -> SlurmExperimentMetadataContext:
+        return self._experiment_context
+
+    @property
+    def work_unit_count(self) -> int:
+        return self._work_unit_count
+
+    @property
+    def work_units(self) -> Mapping[int, SlurmWorkUnit]:
+        """Gets work units created via self.add()."""
+        return {
+            wu.work_unit_id: wu for wu in self._experiment_units if isinstance(wu, SlurmWorkUnit)
+        }
+
+    def __repr__(self, /) -> str:
+        return f"<SlurmExperiment {self.experiment_id} {self.experiment_title}>"
+
+
+def create_experiment(experiment_title: str) -> SlurmExperiment:
+    """Create Experiment."""
+    experiment_id = api.client().insert_experiment(api.ExperimentPatchModel(title=experiment_title))
+    return SlurmExperiment(experiment_title=experiment_title, experiment_id=experiment_id)
+
+
+def get_experiment(experiment_id: int) -> SlurmExperiment:
+    """Get Experiment."""
+    experiment_model = api.client().get_experiment(experiment_id)
+    # TODO(jfarebro): Fill in jobs and work units and annotations
+    return SlurmExperiment(experiment_title=experiment_model.title, experiment_id=experiment_id)
+
+
+@functools.cache
+def get_current_experiment() -> SlurmExperiment | None:
+    if xid := os.environ.get("XM_SLURM_EXPERIMENT_ID"):
+        return get_experiment(int(xid))
+    return None
+
+
+@functools.cache
+def get_current_work_unit() -> SlurmWorkUnit | None:
+    if (xid := os.environ.get("XM_SLURM_EXPERIMENT_ID")) and (
+        wid := os.environ.get("XM_SLURM_WORK_UNIT_ID")
+    ):
+        experiment = get_experiment(int(xid))
+        return experiment.work_units[int(wid)]
+    return None
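
For orientation, here is a minimal usage sketch of the API this file defines. It is illustrative only: it assumes `create_experiment`, `Artifact`, `get_current_experiment`, and `get_current_work_unit` are re-exported from `xm_slurm/__init__.py` (a +44-line file not expanded in this diff), and it elides packaging and executor construction (`xm_slurm/packageables.py`, `xm_slurm/executors.py`), whose call signatures are likewise not shown here.

import asyncio

from xmanager import xm

import xm_slurm


async def main(executable: xm.Executable, executor: xm.Executor) -> None:
    # Assumes `executor` is an executors.Slurm instance with
    # requirements.cluster set, as asserted by _launch_job_array.
    async with xm_slurm.create_experiment("cartpole-sweep") as experiment:
        # Annotation writes are persisted immediately via ExperimentPatchModel.
        experiment.context.annotations.description = "PPO baseline sweep"
        experiment.context.annotations.tags = {"rl", "baseline"}

        # A Sequence of argument mappings routes through _launch_job_array:
        # all trials are submitted as one Slurm job array, with each trial's
        # `args`/`env_vars` diffed against the set common to every trial.
        work_units = await experiment.add(
            xm.Job(executable=executable, executor=executor, name="train"),
            args=[{"args": {"seed": seed}} for seed in range(5)],
        )

        # Work-unit artifacts are mirrored to the database by ContextArtifacts.
        for wu in work_units:
            wu.context.artifacts.add(
                xm_slurm.Artifact(name="logdir", uri=f"file:///logs/{wu.work_unit_id}")
            )

Inside a running job, `get_current_experiment()` and `get_current_work_unit()` recover the launching experiment and work unit from the `XM_SLURM_EXPERIMENT_ID` and `XM_SLURM_WORK_UNIT_ID` environment variables.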