xmanager-slurm 0.4.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. xm_slurm/__init__.py +47 -0
  2. xm_slurm/api/__init__.py +33 -0
  3. xm_slurm/api/abc.py +65 -0
  4. xm_slurm/api/models.py +70 -0
  5. xm_slurm/api/sqlite/client.py +358 -0
  6. xm_slurm/api/web/client.py +173 -0
  7. xm_slurm/batching.py +139 -0
  8. xm_slurm/config.py +189 -0
  9. xm_slurm/console.py +3 -0
  10. xm_slurm/constants.py +19 -0
  11. xm_slurm/contrib/__init__.py +0 -0
  12. xm_slurm/contrib/clusters/__init__.py +67 -0
  13. xm_slurm/contrib/clusters/drac.py +242 -0
  14. xm_slurm/dependencies.py +171 -0
  15. xm_slurm/executables.py +215 -0
  16. xm_slurm/execution.py +995 -0
  17. xm_slurm/executors.py +210 -0
  18. xm_slurm/experiment.py +1016 -0
  19. xm_slurm/experimental/parameter_controller.py +206 -0
  20. xm_slurm/filesystems.py +129 -0
  21. xm_slurm/job_blocks.py +21 -0
  22. xm_slurm/metadata_context.py +253 -0
  23. xm_slurm/packageables.py +309 -0
  24. xm_slurm/packaging/__init__.py +8 -0
  25. xm_slurm/packaging/docker.py +348 -0
  26. xm_slurm/packaging/registry.py +45 -0
  27. xm_slurm/packaging/router.py +56 -0
  28. xm_slurm/packaging/utils.py +22 -0
  29. xm_slurm/resources.py +350 -0
  30. xm_slurm/scripts/_cloudpickle.py +28 -0
  31. xm_slurm/scripts/cli.py +90 -0
  32. xm_slurm/status.py +197 -0
  33. xm_slurm/templates/docker/docker-bake.hcl.j2 +54 -0
  34. xm_slurm/templates/docker/mamba.Dockerfile +29 -0
  35. xm_slurm/templates/docker/python.Dockerfile +32 -0
  36. xm_slurm/templates/docker/uv.Dockerfile +38 -0
  37. xm_slurm/templates/slurm/entrypoint.bash.j2 +27 -0
  38. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +78 -0
  39. xm_slurm/templates/slurm/fragments/proxy.bash.j2 +31 -0
  40. xm_slurm/templates/slurm/job-array.bash.j2 +31 -0
  41. xm_slurm/templates/slurm/job-group.bash.j2 +47 -0
  42. xm_slurm/templates/slurm/job.bash.j2 +90 -0
  43. xm_slurm/templates/slurm/library/retry.bash +62 -0
  44. xm_slurm/templates/slurm/runtimes/apptainer.bash.j2 +73 -0
  45. xm_slurm/templates/slurm/runtimes/podman.bash.j2 +43 -0
  46. xm_slurm/types.py +23 -0
  47. xm_slurm/utils.py +196 -0
  48. xmanager_slurm-0.4.19.dist-info/METADATA +28 -0
  49. xmanager_slurm-0.4.19.dist-info/RECORD +52 -0
  50. xmanager_slurm-0.4.19.dist-info/WHEEL +4 -0
  51. xmanager_slurm-0.4.19.dist-info/entry_points.txt +2 -0
  52. xmanager_slurm-0.4.19.dist-info/licenses/LICENSE.md +227 -0
xm_slurm/experiment.py ADDED
@@ -0,0 +1,1016 @@
1
+ import asyncio
2
+ import collections.abc
3
+ import contextvars
4
+ import dataclasses
5
+ import datetime as dt
6
+ import functools
7
+ import inspect
8
+ import json
9
+ import logging
10
+ import os
11
+ import traceback
12
+ import typing as tp
13
+ from concurrent import futures
14
+
15
+ import more_itertools as mit
16
+ from rich.console import ConsoleRenderable
17
+ from xmanager import xm
18
+ from xmanager.xm import async_packager, core, id_predictor, job_operators
19
+ from xmanager.xm import job_blocks as xm_job_blocks
20
+
21
+ from xm_slurm import api, config, dependencies, execution, executors, metadata_context
22
+ from xm_slurm.console import console
23
+ from xm_slurm.job_blocks import JobArgs
24
+ from xm_slurm.packaging import router
25
+ from xm_slurm.status import SlurmWorkUnitStatus
26
+
27
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Context-local queue used while a job array is being assembled. It defaults to
# None; `SlurmExperiment._launch_job_array` sets it to a queue of
# (job group, future) pairs, and `SlurmWorkUnit._launch_job_group` enqueues
# into it instead of submitting directly whenever it is set.
_current_job_array_queue = contextvars.ContextVar[
    asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
]("_current_job_array_queue", default=None)
32
+
33
+
34
def _validate_job(
    job: xm.JobType,
    args_view: JobArgs | tp.Mapping[str, JobArgs],
) -> None:
    """Check that `args_view` is a legal argument mapping for `job`.

    Nothing is validated when `args_view` is empty or None.

    Args:
        job: The job (or job group) the arguments are meant for.
        args_view: Either a mapping with `args` / `env_vars` keys, or, for a
            job group, a mapping from child job name to such a mapping.

    Raises:
        ValueError: If the arguments are not a mapping, the job group is empty
            or contains nested groups, a key names no job in a multi-job group,
            or a key other than `args` / `env_vars` is used on a job.
    """
    if not args_view:
        return
    if not isinstance(args_view, collections.abc.Mapping):
        raise ValueError("Job arguments via `experiment.add` must be mappings")

    group = job if isinstance(job, xm.JobGroup) else None
    if group is not None:
        if len(group.jobs) == 0:
            raise ValueError("Job group is empty")
        for child in group.jobs.values():
            if isinstance(child, xm.JobGroup):
                raise ValueError("Nested job groups are not supported")

    permitted = {"args", "env_vars"}
    unsupported_key_message = f"Only `args` and `env_vars` are supported for args on job {job!r}."
    for key, expanded in args_view.items():
        if group is None:
            if key not in permitted:
                raise ValueError(unsupported_key_message)
            continue

        if key in group.jobs:
            # The key addresses a child job: validate that child's arguments.
            _validate_job(group.jobs[key], tp.cast(JobArgs, expanded))
        elif len(group.jobs) > 1:
            raise ValueError(
                f"Argument key `{key}` doesn't exist in job group with keys {group.jobs.keys()}"
            )
        elif key not in permitted:
            # Single-job groups may be addressed directly with `args`/`env_vars`.
            raise ValueError(unsupported_key_message)
62
+
63
+
64
+ class SlurmExperimentUnit(xm.ExperimentUnit):
65
+ """ExperimentUnit is a collection of semantically associated `Job`s."""
66
+
67
+ experiment: "SlurmExperiment" # type: ignore
68
+
69
def __init__(
    self,
    experiment: xm.Experiment,
    create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future[tp.Any]],
    args: JobArgs | tp.Mapping[str, JobArgs] | None,
    role: xm.ExperimentUnitRole,
    identity: str = "",
) -> None:
    """Initialise the unit with empty job/handle lists and a metadata context.

    All arguments are forwarded unchanged to the parent initialiser.
    """
    super().__init__(experiment, create_task, args, role, identity=identity)

    # Filled in by `_ingest_launched_jobs` as jobs get submitted.
    self._launched_jobs: list[xm.LaunchedJob] = []
    self._execution_handles: list[execution.SlurmHandle] = []

    unit_artifacts = metadata_context.SlurmContextArtifacts(owner=self, artifacts=[])
    self._context = metadata_context.SlurmExperimentUnitMetadataContext(
        self,
        artifacts=unit_artifacts,
    )
84
+
85
def add(  # type: ignore
    self,
    job: xm.JobType,
    args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
    *,
    dependency: dependencies.SlurmJobDependency | None = None,
    identity: str = "",
) -> tp.Awaitable[None]:
    """Schedule `job` for launch on this unit and return an awaitable for it.

    Args:
        job: A job, job group, async job generator or job config.
        args: Optional arguments applied to a shallow copy of `job`; also
            passed as keyword arguments to job generators.
        dependency: Optional dependency forwarded to the launch.
        identity: Identity used only when the unit has none of its own.

    Returns:
        An awaitable wrapping the launch task created via `self._create_task`.

    Raises:
        ValueError: If a job generator is not an async callable.
        TypeError: If `job` is of an unsupported type.
    """
    # The identity given at unit creation wins over the one passed here, which
    # keeps job-generator work units and regular work units consistent.
    identity = self.identity or identity

    job = job_operators.shallow_copy_job_type(job)  # type: ignore
    if args is not None:
        core._apply_args(job, args)
    job_operators.populate_job_names(job)  # type: ignore

    def enter_unit_context() -> None:
        # Make this unit and its experiment the "current" ones for the launch.
        core._current_experiment.set(self.experiment)
        core._current_experiment_unit.set(self)

    pending: tp.Awaitable[tp.Any]
    if isinstance(job, xm.Job):
        enter_unit_context()
        # A lone job is wrapped in a single-entry group keyed by its name.
        pending = self._launch_job_group(
            xm.JobGroup(**{job.name: job}),  # type: ignore
            core._work_unit_arguments(job, self._args),
            dependency=dependency,
            identity=identity,
        )
    elif isinstance(job, xm.JobGroup):
        enter_unit_context()
        pending = self._launch_job_group(
            job,
            core._work_unit_arguments(job, self._args),
            dependency=dependency,
            identity=identity,
        )
    elif xm_job_blocks.is_job_generator(job):
        is_async = inspect.iscoroutinefunction(job) or inspect.iscoroutinefunction(
            getattr(job, "__call__")
        )
        if not is_async:
            raise ValueError(
                "Job generator must be an async function. Signature needs to be "
                "`async def job_generator(work_unit: xm.WorkUnit) -> None:`"
            )
        enter_unit_context()
        pending = job(self, **(args or {}))  # type: ignore
        assert pending is not None
    elif isinstance(job, xm.JobConfig):
        enter_unit_context()
        pending = self._launch_job_config(
            job, dependency, tp.cast(JobArgs, args) or {}, identity
        )
    else:
        raise TypeError(f"Unsupported job type: {job!r}")

    task = self._create_task(pending)
    self._launch_tasks.append(task)
    return asyncio.wrap_future(task)
163
+
164
async def _launch_job_group(  # type: ignore
    self,
    job_group: xm.JobGroup,
    args_view: tp.Mapping[str, JobArgs],
    *,
    dependency: dependencies.SlurmJobDependency | None,
    identity: str,
) -> None:
    """Launch a job group; unimplemented here, overridden by subclasses.

    Raises:
        NotImplementedError: Always.
    """
    # The arguments are intentionally unused in this placeholder.
    del job_group, dependency, args_view, identity
    raise NotImplementedError
174
+
175
async def _launch_job_config(  # type: ignore
    self,
    job_config: xm.JobConfig,
    dependency: dependencies.SlurmJobDependency | None,
    args_view: JobArgs,
    identity: str,
) -> None:
    """Launch a job config; unimplemented in this class.

    Raises:
        NotImplementedError: Always.
    """
    # The arguments are intentionally unused in this placeholder.
    del job_config, dependency, args_view, identity
    raise NotImplementedError
184
+
185
# Typing overload: a single job with one set of arguments yields one handle.
@tp.overload
async def _submit_jobs_for_execution(
    self,
    job: xm.Job,
    dependency: dependencies.SlurmJobDependency | None,
    args_view: JobArgs,
    identity: str | None = ...,
) -> execution.SlurmHandle: ...

# Typing overload: a job group with per-job arguments yields one handle.
@tp.overload
async def _submit_jobs_for_execution(
    self,
    job: xm.JobGroup,
    dependency: dependencies.SlurmJobDependency | None,
    args_view: tp.Mapping[str, JobArgs],
    identity: str | None = ...,
) -> execution.SlurmHandle: ...

# Typing overload: a single job with a sequence of arguments yields a list of
# handles.
@tp.overload
async def _submit_jobs_for_execution(
    self,
    job: xm.Job,
    dependency: dependencies.SlurmJobDependency | None,
    args_view: tp.Sequence[JobArgs],
    identity: str | None = ...,
) -> list[execution.SlurmHandle]: ...

async def _submit_jobs_for_execution(self, job, dependency, args_view, identity=None):
    """Forward the job to `execution.launch` under this unit's experiment id.

    Returns whatever `execution.launch` returns; the overloads above describe
    the expected shapes.
    """
    return await execution.launch(
        job=job,
        dependency=dependency,
        args=args_view,
        experiment_id=self.experiment_id,
        identity=identity,
    )
220
+
221
def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
    """Record `handle` and the launched job(s) it covers on this unit.

    The handle is appended to `_execution_handles`. For work units, each job is
    also inserted through the API client. One `xm.LaunchedJob` per job is
    appended to `_launched_jobs`, addressed by the handle's Slurm job id.
    """
    self._execution_handles.append(handle)

    def persist(member: xm.Job) -> None:
        # Only work units are written through the API client.
        if not isinstance(self._role, xm.WorkUnitRole):
            return
        assert isinstance(self, SlurmWorkUnit)
        assert member.name is not None
        record = api.models.SlurmJob(
            name=member.name,
            slurm_job_id=handle.slurm_job.job_id,
            slurm_ssh_config=handle.ssh.serialize(),
        )
        api.client().insert_job(self.experiment_id, self.work_unit_id, record)

    if isinstance(job, xm.JobGroup):
        for member in job.jobs.values():
            assert isinstance(member, xm.Job)
            persist(member)
            launched = xm.LaunchedJob(
                name=member.name,  # type: ignore
                address=str(handle.slurm_job.job_id),
            )
            self._launched_jobs.append(launched)
    elif isinstance(job, xm.Job):
        persist(job)
        launched = xm.LaunchedJob(
            name=handle.job.name,  # type: ignore
            address=str(handle.slurm_job.job_id),
        )
        self._launched_jobs.append(launched)
258
+
259
async def _wait_until_complete(self) -> None:
    """Wait for every execution handle of this unit to finish.

    Raises:
        xm.ExperimentUnitFailedError: If waiting on a handle raises
            `RuntimeError`; the original error is attached as the cause.
    """
    try:
        await asyncio.gather(*[handle.wait() for handle in self._execution_handles])
    except RuntimeError as error:
        # Chain explicitly so the original failure is reported as the cause.
        raise xm.ExperimentUnitFailedError(error) from error
264
+
265
def stop(
    self,
    *,
    mark_as_failed: bool = False,
    mark_as_completed: bool = False,
    message: str | None = None,
) -> None:
    """Schedule a stop of every execution handle of this unit.

    The stop runs as a task created through the experiment's `_create_task`;
    this method returns without waiting for it.

    Args:
        mark_as_failed: Accepted for interface compatibility; unused.
        mark_as_completed: Accepted for interface compatibility; unused.
        message: Accepted for interface compatibility; unused.
    """
    del mark_as_failed, mark_as_completed, message

    async def _stop_awaitable() -> None:
        try:
            await asyncio.gather(*[handle.stop() for handle in self._execution_handles])
        except RuntimeError as error:
            # Chain explicitly so the original failure is reported as the cause.
            raise xm.ExperimentUnitFailedError(error) from error

    self.experiment._create_task(_stop_awaitable())
281
+
282
async def get_status(self) -> SlurmWorkUnitStatus:  # type: ignore
    """Fetch the state of every execution handle and aggregate them."""
    state_queries = [handle.get_state() for handle in self._execution_handles]
    handle_states = await asyncio.gather(*state_queries)
    return SlurmWorkUnitStatus.aggregate(handle_states)
285
+
286
async def logs(
    self,
    *,
    num_lines: int = 10,
    block_size: int = 1024,
    wait: bool = True,
    follow: bool = False,
) -> tp.AsyncGenerator[ConsoleRenderable, None]:
    """Yield log renderables from this unit's single execution handle.

    All keyword arguments are forwarded unchanged to the handle's `logs`.

    Raises:
        ValueError: If the unit has no execution handle or more than one.
    """
    handles = self._execution_handles
    if not handles:
        raise ValueError(f"No execution handles found for experiment unit {self!r}")
    if len(handles) > 1:
        raise ValueError(f"Multiple execution handles found for experiment unit {self!r}")
    assert len(handles) == 1

    (only_handle,) = handles  # TODO(jfarebro): interleave?
    log_stream = only_handle.logs(
        num_lines=num_lines, block_size=block_size, wait=wait, follow=follow
    )
    async for entry in log_stream:
        yield entry
305
+
306
@property
def launched_jobs(self) -> list[xm.LaunchedJob]:
    """The launched jobs recorded on this unit so far (the live list)."""
    return self._launched_jobs

@property
def context(self) -> metadata_context.SlurmExperimentUnitMetadataContext:  # type: ignore
    """The metadata context created for this unit in `__init__`."""
    return self._context
313
+
314
def after_started(
    self, *, time: dt.timedelta | None = None
) -> dependencies.SlurmJobDependencyAfter:
    """Build a `SlurmJobDependencyAfter` over this unit's handles, passing `time` through."""
    return dependencies.SlurmJobDependencyAfter(self._execution_handles, time=time)

def after_finished(self) -> dependencies.SlurmJobDependencyAfterAny:
    """Build a `SlurmJobDependencyAfterAny` over this unit's handles."""
    return dependencies.SlurmJobDependencyAfterAny(self._execution_handles)

def after_completed(self) -> dependencies.SlurmJobDependencyAfterOK:
    """Build a `SlurmJobDependencyAfterOK` over this unit's handles."""
    return dependencies.SlurmJobDependencyAfterOK(self._execution_handles)

def after_failed(self) -> dependencies.SlurmJobDependencyAfterNotOK:
    """Build a `SlurmJobDependencyAfterNotOK` over this unit's handles."""
    return dependencies.SlurmJobDependencyAfterNotOK(self._execution_handles)
327
+
328
+
329
class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
    """A work unit whose jobs are submitted singly or as part of a job array."""

    def __init__(
        self,
        experiment: "SlurmExperiment",
        create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future],
        args: JobArgs | tp.Mapping[str, JobArgs] | None,
        role: xm.ExperimentUnitRole,
        work_unit_id_predictor: id_predictor.Predictor,
        identity: str = "",
    ) -> None:
        """Initialise the unit and reserve its work unit id from the predictor."""
        super().__init__(experiment, create_task, args, role, identity=identity)
        self._work_unit_id_predictor = work_unit_id_predictor
        self._work_unit_id = self._work_unit_id_predictor.reserve_id()

    def _get_existing_handle(self, job: xm.JobGroup) -> execution.SlurmHandle | None:
        """Return an already-ingested handle whose job name matches `job`.

        Only single-job groups are matched by name. Groups with several jobs
        return None so they proceed to submission instead of failing on the
        single-element lookup.
        """
        if len(job.jobs) != 1:
            return None
        job_name = mit.one(job.jobs.keys())
        for handle in self._execution_handles:
            if handle.job_name == job_name:
                return handle
        return None

    async def _launch_job_group(  # type: ignore
        self,
        job: xm.JobGroup,
        args_view: tp.Mapping[str, JobArgs],
        *,
        dependency: dependencies.SlurmJobDependency | None,
        identity: str,
    ) -> None:
        """Validate and launch `job`, directly or through the job-array queue.

        Args:
            job: The job group to launch.
            args_view: Arguments for the jobs in the group.
            dependency: Optional dependency for a direct submission.
            identity: Identity forwarded to a direct submission.

        Raises:
            ValueError: If `_validate_job` rejects the job or its arguments.
        """
        _validate_job(job, args_view)
        future = asyncio.Future[execution.SlurmHandle]()

        def _ingest_when_resolved(resolved: asyncio.Future) -> None:
            # A failed or cancelled future has no handle to ingest. The error
            # still reaches the caller through `await future` below.
            if resolved.cancelled() or resolved.exception() is not None:
                return
            self._ingest_launched_jobs(job, resolved.result())

        # If we already have a handle for this job, we don't need to submit it again.
        # We'll just resolve the future with the existing handle.
        # Otherwise we'll add a callback to ingest the handle and the launched jobs.
        if existing_handle := self._get_existing_handle(job):
            future.set_result(existing_handle)
        else:
            future.add_done_callback(_ingest_when_resolved)

        api.client().update_work_unit(
            self.experiment_id,
            self.work_unit_id,
            api.models.ExperimentUnitPatch(args=json.dumps(args_view), identity=None),
        )

        async with self._work_unit_id_predictor.submit_id(self.work_unit_id):  # type: ignore
            # If we're scheduling as part of a job array (i.e., the queue is set on the
            # context) we enqueue the job with a future that gets resolved to the proper
            # handle once the Slurm job array is scheduled.
            job_array_queue = _current_job_array_queue.get()
            if job_array_queue is not None:
                job_array_queue.put_nowait((job, future))
            # Otherwise, we're scheduling a single job and submit it for execution,
            # unless the future is already resolved with an existing handle.
            elif not future.done():
                handle = await self._submit_jobs_for_execution(
                    job, dependency, args_view, identity=identity
                )
                future.set_result(handle)

            # Wait for the job handle, coming either from the job array or from the
            # single submission. This is a no-op when an existing handle was found.
            await future

    @property
    def experiment_unit_name(self) -> str:
        """Name of the form `<experiment_id>_<work_unit_id>`."""
        return f"{self.experiment_id}_{self._work_unit_id}"

    @property
    def work_unit_id(self) -> int:
        """The id reserved for this work unit at construction."""
        return self._work_unit_id

    def __repr__(self, /) -> str:
        return f"<SlurmWorkUnit {self.experiment_unit_name}>"
408
+
409
+
410
class SlurmAuxiliaryUnit(SlurmExperimentUnit):
    """An auxiliary unit operated by the Slurm backend."""

    async def _launch_job_group(  # type: ignore
        self,
        job: xm.JobGroup,
        args_view: tp.Mapping[str, JobArgs],
        *,
        dependency: dependencies.SlurmJobDependency | None,
        identity: str,
    ) -> None:
        """Validate the job group, submit it, and record the resulting handle."""
        _validate_job(job, args_view)

        submitted_handle = await self._submit_jobs_for_execution(
            job,
            dependency,
            args_view,
            identity=identity,
        )
        self._ingest_launched_jobs(job, submitted_handle)

    @property
    def experiment_unit_name(self) -> str:
        """Name of the form `<experiment_id>_auxiliary`."""
        return f"{self.experiment_id}_auxiliary"

    def __repr__(self, /) -> str:
        unit_name = self.experiment_unit_name
        return f"<SlurmAuxiliaryUnit {unit_name}>"
434
+
435
+
436
+ class SlurmExperiment(xm.Experiment):
437
+ _id: int
438
+ _experiment_units: list[SlurmExperimentUnit]
439
+ _experiment_context: metadata_context.SlurmExperimentMetadataContext
440
+ _work_unit_count: int
441
+ _async_packager = async_packager.AsyncPackager(router.package)
442
+
443
def __init__(
    self,
    experiment_title: str,
    experiment_id: int,
) -> None:
    """Set up an experiment with the given id, no units, and a metadata context.

    Args:
        experiment_title: Title stored in the experiment's context annotations.
        experiment_id: Identifier stored as `_id`.
    """
    super().__init__()
    self._id = experiment_id
    self._experiment_units = []

    context_annotations = metadata_context.SlurmExperimentContextAnnotations(
        experiment=self,
        title=experiment_title,
    )
    context_artifacts = metadata_context.SlurmContextArtifacts(self, artifacts=[])
    self._experiment_context = metadata_context.SlurmExperimentMetadataContext(
        self,
        annotations=context_annotations,
        artifacts=context_artifacts,
    )

    self._work_unit_count = 0
460
+
461
# Typing overload: an auxiliary unit job yields a future of an auxiliary unit.
@tp.overload
def add(  # type: ignore
    self,
    job: xm.AuxiliaryUnitJob,
    args: JobArgs | tp.Mapping[str, JobArgs] | None = ...,
    *,
    dependency: dependencies.SlurmJobDependency | None = ...,
    identity: str = ...,
) -> asyncio.Future[SlurmAuxiliaryUnit]: ...

# Typing overload: a job group with per-job arguments yields a future of one
# work unit.
@tp.overload
def add(
    self,
    job: xm.JobGroup,
    args: tp.Mapping[str, JobArgs] | None = ...,
    *,
    role: xm.WorkUnitRole | None = ...,
    dependency: dependencies.SlurmJobDependency | None = ...,
    identity: str = ...,
) -> asyncio.Future[SlurmWorkUnit]: ...

# Typing overload: a job or job generator with a sequence of arguments (a job
# array) yields a future of a sequence of work units. The dependency may be a
# single one or a sequence.
@tp.overload
def add(
    self,
    job: xm.Job | xm.JobGeneratorType,
    args: tp.Sequence[JobArgs],
    *,
    role: xm.WorkUnitRole | None = ...,
    dependency: dependencies.SlurmJobDependency
    | tp.Sequence[dependencies.SlurmJobDependency]
    | None = ...,
    identity: str = ...,
) -> asyncio.Future[tp.Sequence[SlurmWorkUnit]]: ...

# Typing overload: a job, job generator or job config with a single set of
# arguments yields a future of one work unit.
@tp.overload
def add(
    self,
    job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
    args: JobArgs | None = ...,
    *,
    role: xm.WorkUnitRole | None = ...,
    dependency: dependencies.SlurmJobDependency | None = ...,
    identity: str = ...,
) -> asyncio.Future[SlurmWorkUnit]: ...

# Typing overload: any job added with an explicit auxiliary role yields a
# future of an auxiliary unit.
@tp.overload
def add(
    self,
    job: xm.JobType,
    *,
    role: xm.AuxiliaryUnitRole,
    dependency: dependencies.SlurmJobDependency | None = ...,
    identity: str = ...,
) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
515
+
516
def add(  # type: ignore
    self,
    job: xm.JobType,
    args: JobArgs
    | tp.Mapping[str, JobArgs]
    | tp.Sequence[tp.Mapping[str, tp.Any]]
    | None = None,
    *,
    role: xm.ExperimentUnitRole | None = None,
    dependency: dependencies.SlurmJobDependency
    | tp.Sequence[dependencies.SlurmJobDependency]
    | None = None,
    identity: str = "",
) -> (
    asyncio.Future[SlurmAuxiliaryUnit]
    | asyncio.Future[SlurmWorkUnit]
    | asyncio.Future[tp.Sequence[SlurmWorkUnit]]
):
    """Add a job to the experiment and return a future for its unit(s).

    A sequence of arguments is treated as a job array and delegated to
    `_launch_job_array`. Otherwise a single experiment unit is created, or
    reloaded, and the job is added to it.

    Args:
        job: The job, job group, job generator or job config to add.
        args: One set of arguments, or a sequence of them for a job array.
        role: Role of the unit; defaults to a work unit role.
        dependency: A single dependency, or a sequence for job arrays.
        identity: Optional identity of the unit.

    Returns:
        A future resolving to the created unit, or to a sequence of work units
        for job arrays.

    Raises:
        ValueError: If a job array is requested for a non-work-unit role or a
            job group, if its trial arguments fail validation, or if a
            non-array `dependency` is not a `SlurmJobDependency` or None.
    """
    if role is None:
        role = xm.WorkUnitRole()

    if isinstance(args, collections.abc.Sequence):
        if not isinstance(role, xm.WorkUnitRole):
            raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
        if isinstance(job, xm.JobGroup):
            raise ValueError(
                "Job arrays over `xm.JobGroup`s aren't supported. "
                "Slurm doesn't support job arrays over heterogeneous jobs. "
                "Instead you should call `experiment.add` for each of these trials."
            )
        assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"

        # Validate job & args
        for trial in args:
            _validate_job(job, trial)
        args = tp.cast(tp.Sequence[JobArgs], args)

        return asyncio.wrap_future(
            self._create_task(self._launch_job_array(job, dependency, args, role, identity)),
            loop=self._event_loop,
        )
    if not (isinstance(dependency, dependencies.SlurmJobDependency) or dependency is None):
        raise ValueError("Invalid dependency type, expected a SlurmJobDependency or None")

    if isinstance(job, xm.AuxiliaryUnitJob):
        role = job.role
    self._added_roles[type(role)] += 1

    # Decide once whether this unit is reloaded or newly created, so the unit
    # future and the coroutine chosen below always agree.
    should_reload = self._should_reload_experiment_unit(role)
    if should_reload:
        experiment_unit_future = self._get_experiment_unit(
            self.experiment_id, identity, role, args
        )
    else:
        experiment_unit_future = self._create_experiment_unit(args, role, identity)

    async def launch():
        experiment_unit = await experiment_unit_future
        try:
            await experiment_unit.add(job, args, dependency=dependency, identity=identity)
        except Exception as experiment_exception:
            logger.error(
                "Stopping experiment unit (identity %r) after it failed with: %s",
                identity,
                experiment_exception,
            )
            try:
                if isinstance(job, xm.AuxiliaryUnitJob):
                    experiment_unit.stop()
                else:
                    experiment_unit.stop(
                        mark_as_failed=True,
                        message=f"Work unit creation failed. {traceback.format_exc()}",
                    )
            except Exception as stop_exception:  # pylint: disable=broad-except
                # Best effort: a failure to stop must not mask the original error.
                logger.error("Couldn't stop experiment unit: %s", stop_exception)
            raise
        return experiment_unit

    async def reload():
        experiment_unit = await experiment_unit_future
        try:
            await experiment_unit.add(job, args, dependency=dependency, identity=identity)
        except Exception as update_exception:
            # Use the module logger, like `launch`, rather than the root logger.
            logger.error(
                "Could not reload the experiment unit: %s",
                update_exception,
            )
            raise
        return experiment_unit

    return asyncio.wrap_future(
        self._create_task(reload() if should_reload else launch()),
        loop=self._event_loop,
    )
610
+
611
async def _launch_job_array(
    self,
    job: xm.Job | xm.JobGeneratorType,
    dependency: dependencies.SlurmJobDependency
    | tp.Sequence[dependencies.SlurmJobDependency]
    | None,
    args: tp.Sequence[JobArgs],
    role: xm.WorkUnitRole,
    identity: str = "",
) -> tp.Sequence[SlurmWorkUnit]:
    """Launch one work unit per trial in `args` as a single Slurm job array.

    Each trial is added as its own work unit. While the job-array queue is set
    on the context, those units enqueue their job and a future instead of
    submitting. Once every trial has enqueued, the shared executable, executor
    and name are checked, the common arguments and environment variables are
    factored out, and the array is submitted with one `execution.launch` call.
    The returned handles resolve the per-unit futures.

    Args:
        job: The job or job generator shared by all trials.
        dependency: One dependency for all trials, a sequence with one
            dependency per trial, or None.
        args: One argument set per trial.
        role: The work unit role for every trial.
        identity: Optional identity prefix; trial `i` gets `<identity>_<i>`.

    Returns:
        The work units, one per trial; an empty list when `args` is empty.

    Raises:
        ValueError: If the queued jobs are not single `xm.Job`s, or the
            dependency sequence has the wrong length or differs across trials.
        RuntimeError: If trials disagree on executable, executor or name, or
            only some trials already have handles.
    """
    # asyncio.Queue treats maxsize <= 0 as unbounded, so `full()` would never be
    # true for an empty sweep and the wait loop below would spin forever.
    if len(args) == 0:
        return []

    # Create our job array queue and assign it to the current context
    job_array_queue = asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]](maxsize=len(args))
    _current_job_array_queue.set(job_array_queue)
    try:
        # For each trial we'll schedule the job and collect the futures
        wu_futures = [
            self.add(job, args=trial, role=role, identity=f"{identity}_{idx}" if identity else "")
            for idx, trial in enumerate(args)
        ]

        # We'll wait until XManager has filled the queue.
        # There are two cases here, either we were given an xm.Job
        # in which case this will be trivial and filled immediately.
        # The other case is when you have a job generator and this is less
        # trivial, you have to wait for wu.add to be called.
        try:
            while not job_array_queue.full():
                # A unit that failed before enqueueing will never fill the queue,
                # so surface its error instead of polling forever.
                for wu_future in wu_futures:
                    if wu_future.done():
                        wu_future.result()
                await asyncio.sleep(0.1)
        except Exception as launch_error:
            # Fail the trials that already enqueued so they don't wait forever
            # on a handle that will never arrive.
            while not job_array_queue.empty():
                _, pending_future = job_array_queue.get_nowait()
                if not pending_future.done():
                    pending_future.set_exception(launch_error)
            raise

        # All jobs have been resolved so now we'll perform sanity checks
        # to make sure we can infer the sweep
        executable, executor, name = None, None, None
        resolved_args, resolved_env_vars, resolved_futures = [], [], []
        while not job_array_queue.empty():
            # XManager automatically converts jobs to job groups so we must check
            # that there's only a single job in this job group
            job_group_view, future = job_array_queue.get_nowait()
            assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
            job_view = mit.one(
                job_group_view.jobs.values(),
                too_short=ValueError("Expected a single `xm.Job` in job group."),
                too_long=ValueError("Only one `xm.Job` is supported for job arrays."),
            )

            if not isinstance(job_view, xm.Job):
                raise ValueError("Only `xm.Job` is supported for job arrays. ")

            if executable is None:
                executable = job_view.executable
            if id(job_view.executable) != id(executable):
                raise RuntimeError("Found multiple executables in job array.")

            if executor is None:
                executor = job_view.executor
            if id(job_view.executor) != id(executor):
                raise RuntimeError("Found multiple executors in job array")

            if name is None:
                name = job_view.name
            if job_view.name != name:
                raise RuntimeError("Found multiple names in job array")

            resolved_args.append(xm.SequentialArgs.from_collection(job_view.args).to_list())
            resolved_env_vars.append(set(job_view.env_vars.items()))
            resolved_futures.append(future)
        assert executable is not None, "No executable found?"
        assert executor is not None, "No executor found?"
        assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
        assert (
            executor.requirements.cluster is not None
        ), "Cluster must be specified on requirements."

        # XManager merges job arguments with keyword arguments with job arguments
        # coming first. These are the arguments that may be common across all jobs
        # so we can find the largest common prefix and remove them from each job.
        common_args: list[str] = list(mit.longest_common_prefix(resolved_args))
        # Intersect the trials with each other. Seeding a reduction with an empty
        # set would always produce an empty result.
        common_env_vars: set = (
            set.intersection(*resolved_env_vars) if resolved_env_vars else set()
        )

        sweep_args = [
            JobArgs(
                args=functools.reduce(
                    # Remove the common arguments from each job
                    lambda args, to_remove: args.remove_args(to_remove),
                    common_args,
                    xm.SequentialArgs.from_collection(a),
                ),
                env_vars=dict(e.difference(common_env_vars)),
            )
            for a, e in zip(resolved_args, resolved_env_vars)
        ]

        # Dependency resolution
        resolved_dependency = None
        resolved_dependency_task_id_order = None
        # one-to-one
        if isinstance(dependency, collections.abc.Sequence):
            if len(dependency) != len(wu_futures):
                raise ValueError("Dependency list must be the same length as the number of trials.")
            assert len(dependency) > 0, "Dependency list must not be empty."

            # Convert any SlurmJobDependencyAfterOK to SlurmJobArrayDependencyAfterOK
            # for any array jobs.
            def _maybe_convert_afterok(
                dep: dependencies.SlurmJobDependency,
            ) -> dependencies.SlurmJobDependency:
                if isinstance(dep, dependencies.SlurmJobDependencyAfterOK) and all(
                    handle.slurm_job.is_array_job for handle in dep.handles
                ):
                    return dependencies.SlurmJobArrayDependencyAfterOK([
                        dataclasses.replace(
                            handle,
                            slurm_job=handle.slurm_job.array_job_id,
                        )
                        for handle in dep.handles
                    ])
                return dep

            dependencies_converted = [dep.traverse(_maybe_convert_afterok) for dep in dependency]
            dependency_sets = [set(dep.flatten()) for dep in dependencies_converted]
            # Anything not shared by every trial counts as a difference. Seeding a
            # `set.difference` reduction with an empty set could never report one.
            dependency_differences = set.union(*dependency_sets) - set.intersection(
                *dependency_sets
            )
            # There should be NO differences between the dependencies of each trial
            # after conversion.
            if len(dependency_differences) > 0:
                raise ValueError(
                    f"Found variable dependencies across trials: {dependency_differences}. "
                    "Slurm job arrays require the same dependencies across all trials. "
                )
            resolved_dependency = dependencies_converted[0]

            # We need to re-sort the sweep arguments in case the dependencies were
            # passed in a different order than 1, 2, ..., N, as the job array can
            # only correspond with the same task id.
            original_array_dependencies = [
                mit.one(
                    filter(
                        lambda dep: isinstance(dep, dependencies.SlurmJobDependencyAfterOK)
                        and all(handle.slurm_job.is_array_job for handle in dep.handles),
                        deps.flatten(),
                    )
                )
                for deps in dependency
            ]
            # NOTE(review): this reduction is kept as written; it appears to assume
            # one handle per dependency — confirm against `array_task_id`'s type.
            resolved_dependency_task_id_order = [
                int(
                    mit.one(
                        functools.reduce(
                            set.difference,
                            [handle.slurm_job.array_task_id for handle in dep.handles],  # type: ignore
                        )
                    )
                )
                for dep in original_array_dependencies
            ]
            assert len(resolved_dependency_task_id_order) == len(sweep_args)
            assert set(resolved_dependency_task_id_order) == set(range(len(sweep_args))), (
                "Dependent job array tasks should have task ids 0, 1, ..., N. "
                f"Found: {resolved_dependency_task_id_order}"
            )
        # one-to-many
        elif isinstance(dependency, dependencies.SlurmJobDependency):
            resolved_dependency = dependency
        assert resolved_dependency is None or isinstance(
            resolved_dependency, dependencies.SlurmJobDependency
        ), "Invalid dependency type"

        # No support for sweep_env_vars right now.
        # We schedule the job array and then we'll resolve all the work units with
        # the handles Slurm gives back to us.
        # If we already have handles for all the work units, we don't need to submit
        # the job array to SLURM.
        num_resolved_handles = sum(future.done() for future in resolved_futures)
        if num_resolved_handles == 0:
            try:
                handles = await execution.launch(
                    job=xm.Job(
                        executable=executable,
                        executor=executor,
                        name=name,
                        args=xm.SequentialArgs.from_collection(common_args),
                        env_vars=dict(common_env_vars),
                    ),
                    dependency=resolved_dependency,
                    args=[
                        sweep_args[resolved_dependency_task_id_order.index(i)]
                        for i in range(len(sweep_args))
                    ]
                    if resolved_dependency_task_id_order
                    else sweep_args,
                    experiment_id=self.experiment_id,
                    identity=identity,
                )
                if resolved_dependency_task_id_order:
                    handles = [handles[i] for i in resolved_dependency_task_id_order]
            except Exception as e:
                # Propagate the submission failure to every waiting work unit.
                for future in resolved_futures:
                    future.set_exception(e)
                raise
            else:
                for handle, future in zip(handles, resolved_futures):
                    future.set_result(handle)
        elif 0 < num_resolved_handles < len(resolved_futures):
            raise RuntimeError(
                "Some array job elements have handles, but some don't. This shouldn't happen."
            )

        return await asyncio.gather(*wu_futures)
    finally:
        # Always clear the queue from the context, including on failure.
        _current_job_array_queue.set(None)
823
+
824
def _get_work_unit_by_identity(self, identity: str) -> SlurmWorkUnit | None:
    """Find an already registered work unit by its identity string.

    Args:
        identity: Identity to search for. The empty string never matches.

    Returns:
        The first `SlurmWorkUnit` in `self._experiment_units` whose
        `identity` equals the argument, or `None` when there is none.
    """
    if identity == "":
        return None
    matching_units = (
        candidate
        for candidate in self._experiment_units
        if isinstance(candidate, SlurmWorkUnit) and candidate.identity == identity
    )
    # Units are scanned in registration order; the first hit wins.
    return next(matching_units, None)
831
+
832
def _create_experiment_unit(  # type: ignore
    self,
    args: JobArgs | tp.Mapping[str, JobArgs] | None,
    role: xm.ExperimentUnitRole,
    identity: str,
) -> tp.Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
    """Create (or reuse) the experiment unit that corresponds to `role`.

    Args:
        args: Arguments forwarded to the new unit; for work units they are
            also JSON-encoded and stored through the API client.
        role: Selects the unit type: `xm.WorkUnitRole` gives a
            `SlurmWorkUnit`, `xm.AuxiliaryUnitRole` a `SlurmAuxiliaryUnit`.
        identity: Identity of the unit. A work unit already registered under
            the same identity is returned instead of creating a new one.

    Returns:
        An already resolved future holding the unit.

    Raises:
        ValueError: If `role` is neither a work unit nor an auxiliary role.
    """
    if isinstance(role, xm.WorkUnitRole):
        unit = self._get_work_unit_by_identity(identity)
        if unit is None:
            unit = SlurmWorkUnit(
                self,
                self._create_task,
                args,
                role,
                self._work_unit_id_predictor,
                identity=identity,
            )
            self._experiment_units.append(unit)
            self._work_unit_count += 1

            # Record the freshly created work unit through the API client.
            api.client().insert_work_unit(
                self.experiment_id,
                api.models.WorkUnitPatch(
                    wid=unit.work_unit_id,
                    identity=unit.identity,
                    args=json.dumps(args),
                ),
            )

        resolved_work_unit = asyncio.Future[SlurmWorkUnit]()
        resolved_work_unit.set_result(unit)
        return resolved_work_unit

    if isinstance(role, xm.AuxiliaryUnitRole):
        aux_unit = SlurmAuxiliaryUnit(
            self,
            self._create_task,
            args,
            role,
            identity=identity,
        )
        self._experiment_units.append(aux_unit)
        resolved_aux_unit = asyncio.Future[SlurmAuxiliaryUnit]()
        resolved_aux_unit.set_result(aux_unit)
        return resolved_aux_unit

    raise ValueError(f"Unsupported role {role}")
887
+
888
def _get_experiment_unit(  # type: ignore
    self,
    experiment_id: int,
    identity: str,
    role: xm.ExperimentUnitRole,
    args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
) -> tp.Awaitable[SlurmExperimentUnit]:
    """Unsupported hook: this method is not implemented.

    All parameters are accepted but ignored.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
897
+
898
def _should_reload_experiment_unit(self, role: xm.ExperimentUnitRole) -> bool:
    """Return `False` for every role; the `role` argument is ignored."""
    return False
901
+
902
async def __aenter__(self) -> "SlurmExperiment":
    """Enter the parent class's async context, then return this experiment.

    The value produced by the parent's `__aenter__` is discarded so that
    `async with` always binds this object.
    """
    await super().__aenter__()
    return self
905
+
906
async def __aexit__(self, exc_type, exc_value, traceback):
    """Leave the experiment context, deleting the experiment if it is empty.

    An experiment that ends with no experiment units is deleted through the
    API client so that empty experiments are not persisted. The parent
    class's exit handling runs afterwards with the unchanged exception
    details.
    """
    if not self._experiment_units:
        console.print(
            f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
        )
        api.client().delete_experiment(self.experiment_id)

    await super().__aexit__(exc_type, exc_value, traceback)
917
+
918
@property
def experiment_id(self) -> int:
    """Identifier of this experiment, read from `self._id`."""
    return self._id
921
+
922
@property
def experiment_title(self) -> str:
    """Title stored in this experiment's context annotations."""
    annotations = self.context.annotations
    return annotations.title
925
+
926
@property
def context(self) -> metadata_context.SlurmExperimentMetadataContext:  # type: ignore
    """Metadata context of this experiment (`self._experiment_context`)."""
    return self._experiment_context
929
+
930
@property
def work_unit_count(self) -> int:
    """Current value of the experiment's work unit counter."""
    return self._work_unit_count
933
+
934
def work_units(self) -> dict[int, SlurmWorkUnit]:
    """Return this experiment's work units keyed by their work unit id.

    Only `SlurmWorkUnit` instances from `self._experiment_units` are
    included; other unit kinds are skipped.
    """
    units_by_id = {}
    for unit in self._experiment_units:
        if isinstance(unit, SlurmWorkUnit):
            units_by_id[unit.work_unit_id] = unit
    return units_by_id
939
+
940
+ def __repr__(self, /) -> str:
941
+ return f"<SlurmExperiment {self.experiment_id} {self.experiment_title}>"
942
+
943
+
944
def create_experiment(experiment_title: str) -> SlurmExperiment:
    """Register a new experiment and return an object for it.

    Args:
        experiment_title: Title stored with the new experiment.

    Returns:
        A `SlurmExperiment` built from the given title and the id returned
        by the API client's `insert_experiment`.
    """
    client = api.client()
    patch = api.models.ExperimentPatch(title=experiment_title)
    new_experiment_id = client.insert_experiment(patch)
    return SlurmExperiment(experiment_title=experiment_title, experiment_id=new_experiment_id)
950
+
951
+
952
def get_experiment(experiment_id: int) -> SlurmExperiment:
    """Rebuild a `SlurmExperiment` from the record held by the API client.

    Args:
        experiment_id: Id of the experiment to fetch.

    Returns:
        A `SlurmExperiment` whose annotations, artifacts, work units, and
        per-work-unit job handles and artifacts are filled in from the
        fetched model.
    """
    stored = api.client().get_experiment(experiment_id)
    restored = SlurmExperiment(experiment_title=stored.title, experiment_id=experiment_id)
    restored._work_unit_id_predictor = id_predictor.Predictor(1)

    # Annotations: missing values fall back to empty defaults.
    restored.context.annotations.description = stored.description or ""
    restored.context.annotations.note = stored.note or ""
    restored.context.annotations.tags = stored.tags or []

    # Experiment-level artifacts.
    for stored_artifact in stored.artifacts:
        restored.context.artifacts[stored_artifact.name] = stored_artifact.uri

    def _rebuild_work_unit(record):
        """Recreate one work unit, its job handles and its artifacts."""
        unit = SlurmWorkUnit(
            experiment=restored,
            create_task=restored._create_task,
            args=json.loads(record.args) if record.args else {},
            role=xm.WorkUnitRole(),
            identity=record.identity or "",
            work_unit_id_predictor=restored._work_unit_id_predictor,
        )
        # Use the stored work unit id for the rebuilt unit.
        unit._work_unit_id = record.wid

        # One execution handle per stored job.
        for job_record in record.jobs:
            ssh_config = config.SSHConfig.deserialize(job_record.slurm_ssh_config)
            job_handle = execution.SlurmHandle(
                experiment_id=experiment_id,
                ssh=ssh_config,
                slurm_job=str(job_record.slurm_job_id),
                job_name=job_record.name,
            )
            unit._execution_handles.append(job_handle)

        # Work-unit-level artifacts.
        for unit_artifact in record.artifacts:
            unit.context.artifacts[unit_artifact.name] = unit_artifact.uri

        return unit

    for work_unit_record in stored.work_units:
        restored._experiment_units.append(_rebuild_work_unit(work_unit_record))
        restored._work_unit_count += 1

    return restored
1000
+
1001
+
1002
@functools.cache
def get_current_experiment() -> SlurmExperiment | None:
    """Return the experiment named by `XM_SLURM_EXPERIMENT_ID`, if any.

    The result is memoized by `functools.cache`, so later calls return the
    first call's result without consulting the environment again.

    Returns:
        The experiment loaded via `get_experiment`, or `None` when the
        environment variable is unset or empty.
    """
    raw_experiment_id = os.environ.get("XM_SLURM_EXPERIMENT_ID")
    if not raw_experiment_id:
        return None
    return get_experiment(int(raw_experiment_id))
1007
+
1008
+
1009
@functools.cache
def get_current_work_unit() -> SlurmWorkUnit | None:
    """Return the work unit named by the `XM_SLURM_*` environment variables.

    Both `XM_SLURM_EXPERIMENT_ID` and `XM_SLURM_WORK_UNIT_ID` must be set to
    non-empty values. The result is memoized by `functools.cache`.

    Returns:
        The matching work unit of the loaded experiment, or `None` when
        either environment variable is unset or empty.

    Raises:
        KeyError: If the experiment has no work unit with the given id.
    """
    raw_experiment_id = os.environ.get("XM_SLURM_EXPERIMENT_ID")
    if not raw_experiment_id:
        return None
    raw_work_unit_id = os.environ.get("XM_SLURM_WORK_UNIT_ID")
    if not raw_work_unit_id:
        return None

    loaded_experiment = get_experiment(int(raw_experiment_id))
    return loaded_experiment.work_units()[int(raw_work_unit_id)]