xmanager-slurm 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Click here for more details.

xm_slurm/experiment.py CHANGED
@@ -2,25 +2,31 @@ import asyncio
2
2
  import collections.abc
3
3
  import contextvars
4
4
  import dataclasses
5
+ import datetime as dt
5
6
  import functools
6
7
  import inspect
7
8
  import json
9
+ import logging
8
10
  import os
9
- import typing
11
+ import traceback
12
+ import typing as tp
10
13
  from concurrent import futures
11
- from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence
12
14
 
13
15
  import more_itertools as mit
16
+ from rich.console import ConsoleRenderable
14
17
  from xmanager import xm
15
- from xmanager.xm import async_packager, id_predictor
18
+ from xmanager.xm import async_packager, core, id_predictor, job_operators
19
+ from xmanager.xm import job_blocks as xm_job_blocks
16
20
 
17
- from xm_slurm import api, config, execution, executors
21
+ from xm_slurm import api, config, dependencies, execution, executors
18
22
  from xm_slurm.console import console
19
23
  from xm_slurm.job_blocks import JobArgs
20
24
  from xm_slurm.packaging import router
21
25
  from xm_slurm.status import SlurmWorkUnitStatus
22
26
  from xm_slurm.utils import UserSet
23
27
 
28
+ logger = logging.getLogger(__name__)
29
+
24
30
  _current_job_array_queue = contextvars.ContextVar[
25
31
  asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
26
32
  ]("_current_job_array_queue", default=None)
@@ -28,7 +34,7 @@ _current_job_array_queue = contextvars.ContextVar[
28
34
 
29
35
  def _validate_job(
30
36
  job: xm.JobType,
31
- args_view: JobArgs | Mapping[str, JobArgs],
37
+ args_view: JobArgs | tp.Mapping[str, JobArgs],
32
38
  ) -> None:
33
39
  if not args_view:
34
40
  return
@@ -51,7 +57,7 @@ def _validate_job(
51
57
  )
52
58
 
53
59
  if isinstance(job, xm.JobGroup) and key in job.jobs:
54
- _validate_job(job.jobs[key], typing.cast(JobArgs, expanded))
60
+ _validate_job(job.jobs[key], tp.cast(JobArgs, expanded))
55
61
  elif key not in allowed_keys:
56
62
  raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")
57
63
 
@@ -62,7 +68,7 @@ class Artifact:
62
68
  uri: str
63
69
 
64
70
  def __hash__(self) -> int:
65
- return hash(self.name)
71
+ return hash((type(self), self.name))
66
72
 
67
73
 
68
74
  class ContextArtifacts(UserSet[Artifact]):
@@ -70,7 +76,7 @@ class ContextArtifacts(UserSet[Artifact]):
70
76
  self,
71
77
  owner: "SlurmExperiment | SlurmExperimentUnit",
72
78
  *,
73
- artifacts: Sequence[Artifact],
79
+ artifacts: tp.Sequence[Artifact],
74
80
  ):
75
81
  super().__init__(
76
82
  artifacts,
@@ -124,8 +130,8 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
124
130
  def __init__(
125
131
  self,
126
132
  experiment: xm.Experiment,
127
- create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
128
- args: JobArgs | None,
133
+ create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future[tp.Any]],
134
+ args: JobArgs | tp.Mapping[str, JobArgs] | None,
129
135
  role: xm.ExperimentUnitRole,
130
136
  identity: str = "",
131
137
  ) -> None:
@@ -136,25 +142,137 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
136
142
  artifacts=ContextArtifacts(owner=self, artifacts=[]),
137
143
  )
138
144
 
139
- @typing.overload
145
+ def add( # type: ignore
146
+ self,
147
+ job: xm.JobType,
148
+ args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
149
+ *,
150
+ dependency: dependencies.SlurmJobDependency | None = None,
151
+ identity: str = "",
152
+ ) -> tp.Awaitable[None]:
153
+ # Prioritize the identity given directly to the work unit at work unit
154
+ # creation time, as opposed to the identity passed when adding jobs to it as
155
+ # this is more consistent between job generator work units and regular work
156
+ # units.
157
+ identity = self.identity or identity
158
+
159
+ job = job_operators.shallow_copy_job_type(job) # type: ignore
160
+ if args is not None:
161
+ core._apply_args(job, args)
162
+ job_operators.populate_job_names(job) # type: ignore
163
+
164
+ def launch_job(job: xm.Job) -> tp.Awaitable[None]:
165
+ core._current_experiment.set(self.experiment)
166
+ core._current_experiment_unit.set(self)
167
+ return self._launch_job_group(
168
+ xm.JobGroup(**{job.name: job}), # type: ignore
169
+ core._work_unit_arguments(job, self._args),
170
+ dependency=dependency,
171
+ identity=identity,
172
+ )
173
+
174
+ def launch_job_group(group: xm.JobGroup) -> tp.Awaitable[None]:
175
+ core._current_experiment.set(self.experiment)
176
+ core._current_experiment_unit.set(self)
177
+ return self._launch_job_group(
178
+ group,
179
+ core._work_unit_arguments(group, self._args),
180
+ dependency=dependency,
181
+ identity=identity,
182
+ )
183
+
184
+ def launch_job_generator(
185
+ job_generator: xm.JobGeneratorType,
186
+ ) -> tp.Awaitable[None]:
187
+ if not inspect.iscoroutinefunction(job_generator) and not inspect.iscoroutinefunction(
188
+ getattr(job_generator, "__call__")
189
+ ):
190
+ raise ValueError(
191
+ "Job generator must be an async function. Signature needs to be "
192
+ "`async def job_generator(work_unit: xm.WorkUnit) -> None:`"
193
+ )
194
+ core._current_experiment.set(self.experiment)
195
+ core._current_experiment_unit.set(self)
196
+ coroutine = job_generator(self, **(args or {}))
197
+ assert coroutine is not None
198
+ return coroutine
199
+
200
+ def launch_job_config(job_config: xm.JobConfig) -> tp.Awaitable[None]:
201
+ core._current_experiment.set(self.experiment)
202
+ core._current_experiment_unit.set(self)
203
+ return self._launch_job_config(
204
+ job_config, dependency, tp.cast(JobArgs, args) or {}, identity
205
+ )
206
+
207
+ job_awaitable: tp.Awaitable[tp.Any]
208
+ match job:
209
+ case xm.Job() as job:
210
+ job_awaitable = launch_job(job)
211
+ case xm.JobGroup() as job_group:
212
+ job_awaitable = launch_job_group(job_group)
213
+ case job_generator if xm_job_blocks.is_job_generator(job):
214
+ job_awaitable = launch_job_generator(job_generator) # type: ignore
215
+ case xm.JobConfig() as job_config:
216
+ job_awaitable = launch_job_config(job_config)
217
+ case _:
218
+ raise TypeError(f"Unsupported job type: {job!r}")
219
+
220
+ launch_task = self._create_task(job_awaitable)
221
+ self._launch_tasks.append(launch_task)
222
+ return asyncio.wrap_future(launch_task)
223
+
224
+ async def _launch_job_group( # type: ignore
225
+ self,
226
+ job_group: xm.JobGroup,
227
+ args_view: tp.Mapping[str, JobArgs],
228
+ *,
229
+ dependency: dependencies.SlurmJobDependency | None,
230
+ identity: str,
231
+ ) -> None:
232
+ del job_group, dependency, args_view, identity
233
+ raise NotImplementedError
234
+
235
+ async def _launch_job_config( # type: ignore
236
+ self,
237
+ job_config: xm.JobConfig,
238
+ dependency: dependencies.SlurmJobDependency | None,
239
+ args_view: JobArgs,
240
+ identity: str,
241
+ ) -> None:
242
+ del job_config, dependency, args_view, identity
243
+ raise NotImplementedError
244
+
245
+ @tp.overload
140
246
  async def _submit_jobs_for_execution(
141
247
  self,
142
248
  job: xm.Job,
249
+ dependency: dependencies.SlurmJobDependency | None,
143
250
  args_view: JobArgs,
144
251
  identity: str | None = ...,
145
252
  ) -> execution.SlurmHandle: ...
146
253
 
147
- @typing.overload
254
+ @tp.overload
148
255
  async def _submit_jobs_for_execution(
149
256
  self,
150
257
  job: xm.JobGroup,
151
- args_view: Mapping[str, JobArgs],
258
+ dependency: dependencies.SlurmJobDependency | None,
259
+ args_view: tp.Mapping[str, JobArgs],
152
260
  identity: str | None = ...,
153
261
  ) -> execution.SlurmHandle: ...
154
262
 
155
- async def _submit_jobs_for_execution(self, job, args_view, identity=None):
263
+ @tp.overload
264
+ async def _submit_jobs_for_execution(
265
+ self,
266
+ job: xm.Job,
267
+ dependency: dependencies.SlurmJobDependency | None,
268
+ args_view: tp.Sequence[JobArgs],
269
+ identity: str | None = ...,
270
+ ) -> list[execution.SlurmHandle]: ...
271
+
272
+ async def _submit_jobs_for_execution(self, job, dependency, args_view, identity=None):
156
273
  return await execution.launch(
157
274
  job=job,
275
+ dependency=dependency,
158
276
  args=args_view,
159
277
  experiment_id=self.experiment_id,
160
278
  identity=identity,
@@ -173,7 +291,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
173
291
  self.work_unit_id,
174
292
  api.SlurmJobModel(
175
293
  name=job.name,
176
- slurm_job_id=handle.job_id, # type: ignore
294
+ slurm_job_id=handle.slurm_job.job_id,
177
295
  slurm_ssh_config=handle.ssh.serialize(),
178
296
  ),
179
297
  )
@@ -186,7 +304,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
186
304
  self._launched_jobs.append(
187
305
  xm.LaunchedJob(
188
306
  name=job.name, # type: ignore
189
- address=str(handle.job_id),
307
+ address=str(handle.slurm_job.job_id),
190
308
  )
191
309
  )
192
310
  case xm.Job():
@@ -194,7 +312,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
194
312
  self._launched_jobs.append(
195
313
  xm.LaunchedJob(
196
314
  name=handle.job.name, # type: ignore
197
- address=str(handle.job_id),
315
+ address=str(handle.slurm_job.job_id),
198
316
  )
199
317
  )
200
318
 
@@ -221,10 +339,25 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
221
339
 
222
340
  self.experiment._create_task(_stop_awaitable())
223
341
 
224
- async def get_status(self) -> SlurmWorkUnitStatus:
342
+ async def get_status(self) -> SlurmWorkUnitStatus: # type: ignore
225
343
  states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
226
344
  return SlurmWorkUnitStatus.aggregate(states)
227
345
 
346
+ async def logs(
347
+ self,
348
+ *,
349
+ num_lines: int = 10,
350
+ block_size: int = 1024,
351
+ wait: bool = True,
352
+ follow: bool = False,
353
+ ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
354
+ assert len(self._execution_handles) == 1, "Only one job handle is supported for logs."
355
+ handle = self._execution_handles[0] # TODO(jfarebro): interleave?
356
+ async for log in handle.logs(
357
+ num_lines=num_lines, block_size=block_size, wait=wait, follow=follow
358
+ ):
359
+ yield log
360
+
228
361
  @property
229
362
  def launched_jobs(self) -> list[xm.LaunchedJob]:
230
363
  return self._launched_jobs
@@ -233,13 +366,27 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
233
366
  def context(self) -> SlurmExperimentUnitMetadataContext: # type: ignore
234
367
  return self._context
235
368
 
369
+ def after_started(
370
+ self, *, time: dt.timedelta | None = None
371
+ ) -> dependencies.SlurmJobDependencyAfter:
372
+ return dependencies.SlurmJobDependencyAfter(self._execution_handles, time=time)
373
+
374
+ def after_finished(self) -> dependencies.SlurmJobDependencyAfterAny:
375
+ return dependencies.SlurmJobDependencyAfterAny(self._execution_handles)
376
+
377
+ def after_completed(self) -> dependencies.SlurmJobDependencyAfterOK:
378
+ return dependencies.SlurmJobDependencyAfterOK(self._execution_handles)
379
+
380
+ def after_failed(self) -> dependencies.SlurmJobDependencyAfterNotOK:
381
+ return dependencies.SlurmJobDependencyAfterNotOK(self._execution_handles)
382
+
236
383
 
237
384
  class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
238
385
  def __init__(
239
386
  self,
240
387
  experiment: "SlurmExperiment",
241
- create_task: Callable[[Awaitable[Any]], futures.Future],
242
- args: JobArgs,
388
+ create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future],
389
+ args: JobArgs | tp.Mapping[str, JobArgs] | None,
243
390
  role: xm.ExperimentUnitRole,
244
391
  work_unit_id_predictor: id_predictor.Predictor,
245
392
  identity: str = "",
@@ -258,7 +405,9 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
258
405
  async def _launch_job_group( # type: ignore
259
406
  self,
260
407
  job: xm.JobGroup,
261
- args_view: Mapping[str, JobArgs],
408
+ args_view: tp.Mapping[str, JobArgs],
409
+ *,
410
+ dependency: dependencies.SlurmJobDependency | None,
262
411
  identity: str,
263
412
  ) -> None:
264
413
  global _current_job_array_queue
@@ -291,7 +440,9 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
291
440
  # If the future is already done, i.e., the handle is already resolved, we don't need
292
441
  # to submit the job again.
293
442
  elif not future.done():
294
- handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
443
+ handle = await self._submit_jobs_for_execution(
444
+ job, dependency, args_view, identity=identity
445
+ )
295
446
  future.set_result(handle)
296
447
 
297
448
  # Wait for the job handle, this is either coming from scheduling the job array
@@ -317,12 +468,16 @@ class SlurmAuxiliaryUnit(SlurmExperimentUnit):
317
468
  async def _launch_job_group( # type: ignore
318
469
  self,
319
470
  job: xm.JobGroup,
320
- args_view: Mapping[str, JobArgs],
471
+ args_view: tp.Mapping[str, JobArgs],
472
+ *,
473
+ dependency: dependencies.SlurmJobDependency | None,
321
474
  identity: str,
322
475
  ) -> None:
323
476
  _validate_job(job, args_view)
324
477
 
325
- slurm_handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
478
+ slurm_handle = await self._submit_jobs_for_execution(
479
+ job, dependency, args_view, identity=identity
480
+ )
326
481
  self._ingest_launched_jobs(job, slurm_handle)
327
482
 
328
483
  @property
@@ -392,7 +547,7 @@ class SlurmExperimentContextAnnotations:
392
547
  )
393
548
 
394
549
  @property
395
- def tags(self) -> MutableSet[str]:
550
+ def tags(self) -> tp.MutableSet[str]:
396
551
  return self._tags
397
552
 
398
553
  @tags.setter
@@ -459,87 +614,82 @@ class SlurmExperiment(xm.Experiment):
459
614
  )
460
615
  self._work_unit_count = 0
461
616
 
462
- @typing.overload
463
- def add(
617
+ @tp.overload
618
+ def add( # type: ignore
464
619
  self,
465
620
  job: xm.AuxiliaryUnitJob,
466
- args: Mapping[str, Any] | None = ...,
621
+ args: JobArgs | tp.Mapping[str, JobArgs] | None = ...,
467
622
  *,
623
+ dependency: dependencies.SlurmJobDependency | None = ...,
468
624
  identity: str = ...,
469
625
  ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
470
626
 
471
- @typing.overload
627
+ @tp.overload
472
628
  def add(
473
629
  self,
474
630
  job: xm.JobGroup,
475
- args: Mapping[str, Mapping[str, Any]] | None,
631
+ args: tp.Mapping[str, JobArgs] | None = ...,
476
632
  *,
477
- role: xm.WorkUnitRole = xm.WorkUnitRole(),
633
+ role: xm.WorkUnitRole | None = ...,
634
+ dependency: dependencies.SlurmJobDependency | None = ...,
478
635
  identity: str = ...,
479
636
  ) -> asyncio.Future[SlurmWorkUnit]: ...
480
637
 
481
- @typing.overload
638
+ @tp.overload
482
639
  def add(
483
640
  self,
484
- job: xm.JobGroup,
485
- args: Mapping[str, Mapping[str, Any]] | None,
641
+ job: xm.Job | xm.JobGeneratorType,
642
+ args: tp.Sequence[JobArgs],
486
643
  *,
487
- role: xm.ExperimentUnitRole,
644
+ role: xm.WorkUnitRole | None = ...,
645
+ dependency: dependencies.SlurmJobDependency
646
+ | tp.Sequence[dependencies.SlurmJobDependency]
647
+ | None = ...,
488
648
  identity: str = ...,
489
- ) -> asyncio.Future[SlurmExperimentUnit]: ...
649
+ ) -> asyncio.Future[tp.Sequence[SlurmWorkUnit]]: ...
490
650
 
491
- @typing.overload
651
+ @tp.overload
492
652
  def add(
493
653
  self,
494
654
  job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
495
- args: Mapping[str, Any] | None,
655
+ args: JobArgs | None = ...,
496
656
  *,
497
- role: xm.WorkUnitRole = xm.WorkUnitRole(),
657
+ role: xm.WorkUnitRole | None = ...,
658
+ dependency: dependencies.SlurmJobDependency | None = ...,
498
659
  identity: str = ...,
499
660
  ) -> asyncio.Future[SlurmWorkUnit]: ...
500
661
 
501
- @typing.overload
502
- def add(
503
- self,
504
- job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
505
- args: Mapping[str, Any] | None,
506
- *,
507
- role: xm.ExperimentUnitRole,
508
- identity: str = ...,
509
- ) -> asyncio.Future[SlurmExperimentUnit]: ...
510
-
511
- @typing.overload
512
- def add(
513
- self,
514
- job: xm.Job | xm.JobGeneratorType,
515
- args: Sequence[Mapping[str, Any]],
516
- *,
517
- role: xm.WorkUnitRole = xm.WorkUnitRole(),
518
- identity: str = ...,
519
- ) -> asyncio.Future[Sequence[SlurmWorkUnit]]: ...
520
-
521
- @typing.overload
662
+ @tp.overload
522
663
  def add(
523
664
  self,
524
665
  job: xm.JobType,
525
666
  *,
526
- role: xm.AuxiliaryUnitRole = ...,
667
+ role: xm.AuxiliaryUnitRole,
668
+ dependency: dependencies.SlurmJobDependency | None = ...,
527
669
  identity: str = ...,
528
670
  ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
529
671
 
530
672
  def add( # type: ignore
531
673
  self,
532
674
  job: xm.JobType,
533
- args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None,
675
+ args: JobArgs
676
+ | tp.Mapping[str, JobArgs]
677
+ | tp.Sequence[tp.Mapping[str, tp.Any]]
678
+ | None = None,
534
679
  *,
535
- role: xm.ExperimentUnitRole = xm.WorkUnitRole(),
680
+ role: xm.ExperimentUnitRole | None = None,
681
+ dependency: dependencies.SlurmJobDependency
682
+ | tp.Sequence[dependencies.SlurmJobDependency]
683
+ | None = None,
536
684
  identity: str = "",
537
685
  ) -> (
538
686
  asyncio.Future[SlurmAuxiliaryUnit]
539
- | asyncio.Future[SlurmExperimentUnit]
540
687
  | asyncio.Future[SlurmWorkUnit]
541
- | asyncio.Future[Sequence[SlurmWorkUnit]]
688
+ | asyncio.Future[tp.Sequence[SlurmWorkUnit]]
542
689
  ):
690
+ if role is None:
691
+ role = xm.WorkUnitRole()
692
+
543
693
  if isinstance(args, collections.abc.Sequence):
544
694
  if not isinstance(role, xm.WorkUnitRole):
545
695
  raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
@@ -554,21 +704,76 @@ class SlurmExperiment(xm.Experiment):
554
704
  # Validate job & args
555
705
  for trial in args:
556
706
  _validate_job(job, trial)
557
- args = typing.cast(Sequence[JobArgs], args)
707
+ args = tp.cast(tp.Sequence[JobArgs], args)
558
708
 
559
709
  return asyncio.wrap_future(
560
- self._create_task(self._launch_job_array(job, args, role, identity))
710
+ self._create_task(self._launch_job_array(job, dependency, args, role, identity)),
711
+ loop=self._event_loop,
712
+ )
713
+ if not (isinstance(dependency, dependencies.SlurmJobDependency) or dependency is None):
714
+ raise ValueError("Invalid dependency type, expected a SlurmJobDependency or None")
715
+
716
+ if isinstance(job, xm.AuxiliaryUnitJob):
717
+ role = job.role
718
+ self._added_roles[type(role)] += 1
719
+
720
+ if self._should_reload_experiment_unit(role):
721
+ experiment_unit_future = self._get_experiment_unit(
722
+ self.experiment_id, identity, role, args
561
723
  )
562
724
  else:
563
- return super().add(job, args, role=role, identity=identity) # type: ignore
725
+ experiment_unit_future = self._create_experiment_unit(args, role, identity)
726
+
727
+ async def launch():
728
+ experiment_unit = await experiment_unit_future
729
+ try:
730
+ await experiment_unit.add(job, args, dependency=dependency, identity=identity)
731
+ except Exception as experiment_exception:
732
+ logger.error(
733
+ "Stopping experiment unit (identity %r) after it failed with: %s",
734
+ identity,
735
+ experiment_exception,
736
+ )
737
+ try:
738
+ if isinstance(job, xm.AuxiliaryUnitJob):
739
+ experiment_unit.stop()
740
+ else:
741
+ experiment_unit.stop(
742
+ mark_as_failed=True,
743
+ message=f"Work unit creation failed. {traceback.format_exc()}",
744
+ )
745
+ except Exception as stop_exception: # pylint: disable=broad-except
746
+ logger.error("Couldn't stop experiment unit: %s", stop_exception)
747
+ raise
748
+ return experiment_unit
749
+
750
+ async def reload():
751
+ experiment_unit = await experiment_unit_future
752
+ try:
753
+ await experiment_unit.add(job, args, dependency=dependency, identity=identity)
754
+ except Exception as update_exception:
755
+ logging.error(
756
+ "Could not reload the experiment unit: %s",
757
+ update_exception,
758
+ )
759
+ raise
760
+ return experiment_unit
761
+
762
+ return asyncio.wrap_future(
763
+ self._create_task(reload() if self._should_reload_experiment_unit(role) else launch()),
764
+ loop=self._event_loop,
765
+ )
564
766
 
565
767
  async def _launch_job_array(
566
768
  self,
567
769
  job: xm.Job | xm.JobGeneratorType,
568
- args: Sequence[JobArgs],
770
+ dependency: dependencies.SlurmJobDependency
771
+ | tp.Sequence[dependencies.SlurmJobDependency]
772
+ | None,
773
+ args: tp.Sequence[JobArgs],
569
774
  role: xm.WorkUnitRole,
570
775
  identity: str = "",
571
- ) -> Sequence[SlurmWorkUnit]:
776
+ ) -> tp.Sequence[SlurmWorkUnit]:
572
777
  global _current_job_array_queue
573
778
 
574
779
  # Create our job array queue and assign it to the current context
@@ -579,7 +784,11 @@ class SlurmExperiment(xm.Experiment):
579
784
  # and collect the futures
580
785
  wu_futures = []
581
786
  for idx, trial in enumerate(args):
582
- wu_futures.append(super().add(job, args=trial, role=role, identity=f"{identity}_{idx}"))
787
+ wu_futures.append(
788
+ self.add(
789
+ job, args=trial, role=role, identity=f"{identity}_{idx}" if identity else ""
790
+ )
791
+ )
583
792
 
584
793
  # We'll wait until XManager has filled the queue.
585
794
  # There are two cases here, either we were given an xm.Job
@@ -589,7 +798,8 @@ class SlurmExperiment(xm.Experiment):
589
798
  while not job_array_queue.full():
590
799
  await asyncio.sleep(0.1)
591
800
 
592
- # All jobs have been resolved
801
+ # All jobs have been resolved so now we'll perform sanity checks
802
+ # to make sure we can infer the sweep
593
803
  executable, executor, name = None, None, None
594
804
  resolved_args, resolved_env_vars, resolved_futures = [], [], []
595
805
  while not job_array_queue.empty():
@@ -650,6 +860,78 @@ class SlurmExperiment(xm.Experiment):
650
860
  for a, e in zip(resolved_args, resolved_env_vars)
651
861
  ]
652
862
 
863
+ # Dependency resolution
864
+ resolved_dependency = None
865
+ resolved_dependency_task_id_order = None
866
+ # one-to-one
867
+ if isinstance(dependency, collections.abc.Sequence):
868
+ if len(dependency) != len(wu_futures):
869
+ raise ValueError("Dependency list must be the same length as the number of trials.")
870
+ assert len(dependency) > 0, "Dependency list must not be empty."
871
+
872
+ # Convert any SlurmJobDependencyAfterOK to SlurmJobArrayDependencyAfterOK
873
+ # for any array jobs.
874
+ def _maybe_convert_afterok(
875
+ dep: dependencies.SlurmJobDependency,
876
+ ) -> dependencies.SlurmJobDependency:
877
+ if isinstance(dep, dependencies.SlurmJobDependencyAfterOK) and all([
878
+ handle.slurm_job.is_array_job for handle in dep.handles
879
+ ]):
880
+ return dependencies.SlurmJobArrayDependencyAfterOK([
881
+ dataclasses.replace(
882
+ handle,
883
+ slurm_job=handle.slurm_job.array_job_id,
884
+ )
885
+ for handle in dep.handles
886
+ ])
887
+ return dep
888
+
889
+ dependencies_converted = [dep.traverse(_maybe_convert_afterok) for dep in dependency]
890
+ dependency_sets = [set(dep.flatten()) for dep in dependencies_converted]
891
+ dependency_differences = functools.reduce(set.difference, dependency_sets, set())
892
+ # There should be NO differences between the dependencies of each trial after conversion.
893
+ if len(dependency_differences) > 0:
894
+ raise ValueError(
895
+ f"Found variable dependencies across trials: {dependency_differences}. "
896
+ "Slurm job arrays require the same dependencies across all trials. "
897
+ )
898
+ resolved_dependency = dependencies_converted[0]
899
+
900
+ # This is slightly annoying but we need to re-sort the sweep arguments in case the dependencies were passed
901
+ # in a different order than 1, 2, ..., N as the Job array can only have correspondance with the same task id.
902
+ original_array_dependencies = [
903
+ mit.one(
904
+ filter(
905
+ lambda dep: isinstance(dep, dependencies.SlurmJobDependencyAfterOK)
906
+ and all([handle.slurm_job.is_array_job for handle in dep.handles]),
907
+ deps.flatten(),
908
+ )
909
+ )
910
+ for deps in dependency
911
+ ]
912
+ resolved_dependency_task_id_order = [
913
+ int(
914
+ mit.one(
915
+ functools.reduce(
916
+ set.difference,
917
+ [handle.slurm_job.array_task_id for handle in dep.handles], # type: ignore
918
+ )
919
+ )
920
+ )
921
+ for dep in original_array_dependencies
922
+ ]
923
+ assert len(resolved_dependency_task_id_order) == len(sweep_args)
924
+ assert set(resolved_dependency_task_id_order) == set(range(len(sweep_args))), (
925
+ "Dependent job array tasks should have task ids 0, 1, ..., N. "
926
+ f"Found: {resolved_dependency_task_id_order}"
927
+ )
928
+ # one-to-many
929
+ elif isinstance(dependency, dependencies.SlurmJobDependency):
930
+ resolved_dependency = dependency
931
+ assert resolved_dependency is None or isinstance(
932
+ resolved_dependency, dependencies.SlurmJobDependency
933
+ ), "Invalid dependency type"
934
+
653
935
  # No support for sweep_env_vars right now.
654
936
  # We schedule the job array and then we'll resolve all the work units with
655
937
  # the handles Slurm gives back to us.
@@ -666,10 +948,18 @@ class SlurmExperiment(xm.Experiment):
666
948
  args=xm.SequentialArgs.from_collection(common_args),
667
949
  env_vars=dict(common_env_vars),
668
950
  ),
669
- args=sweep_args,
951
+ dependency=resolved_dependency,
952
+ args=[
953
+ sweep_args[resolved_dependency_task_id_order.index(i)]
954
+ for i in range(len(sweep_args))
955
+ ]
956
+ if resolved_dependency_task_id_order
957
+ else sweep_args,
670
958
  experiment_id=self.experiment_id,
671
959
  identity=identity,
672
960
  )
961
+ if resolved_dependency_task_id_order:
962
+ handles = [handles[i] for i in resolved_dependency_task_id_order]
673
963
  except Exception as e:
674
964
  for future in resolved_futures:
675
965
  future.set_exception(e)
@@ -697,11 +987,11 @@ class SlurmExperiment(xm.Experiment):
697
987
 
698
988
  def _create_experiment_unit( # type: ignore
699
989
  self,
700
- args: JobArgs,
990
+ args: JobArgs | tp.Mapping[str, JobArgs] | None,
701
991
  role: xm.ExperimentUnitRole,
702
992
  identity: str,
703
- ) -> Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
704
- def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
993
+ ) -> tp.Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
994
+ def _create_work_unit(role: xm.WorkUnitRole) -> tp.Awaitable[SlurmWorkUnit]:
705
995
  work_unit = SlurmWorkUnit(
706
996
  self,
707
997
  self._create_task,
@@ -726,7 +1016,7 @@ class SlurmExperiment(xm.Experiment):
726
1016
  future.set_result(work_unit)
727
1017
  return future
728
1018
 
729
- def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> Awaitable[SlurmAuxiliaryUnit]:
1019
+ def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> tp.Awaitable[SlurmAuxiliaryUnit]:
730
1020
  auxiliary_unit = SlurmAuxiliaryUnit(
731
1021
  self,
732
1022
  self._create_task,
@@ -756,8 +1046,8 @@ class SlurmExperiment(xm.Experiment):
756
1046
  experiment_id: int,
757
1047
  identity: str,
758
1048
  role: xm.ExperimentUnitRole,
759
- args: JobArgs | None = None,
760
- ) -> Awaitable[xm.ExperimentUnit]:
1049
+ args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
1050
+ ) -> tp.Awaitable[SlurmExperimentUnit]:
761
1051
  del experiment_id, identity, role, args
762
1052
  raise NotImplementedError
763
1053
 
@@ -797,7 +1087,7 @@ class SlurmExperiment(xm.Experiment):
797
1087
  def work_unit_count(self) -> int:
798
1088
  return self._work_unit_count
799
1089
 
800
- def work_units(self) -> Mapping[int, SlurmWorkUnit]:
1090
+ def work_units(self) -> dict[int, SlurmWorkUnit]:
801
1091
  """Gets work units created via self.add()."""
802
1092
  return {
803
1093
  wu.work_unit_id: wu for wu in self._experiment_units if isinstance(wu, SlurmWorkUnit)
@@ -822,8 +1112,8 @@ def get_experiment(experiment_id: int) -> SlurmExperiment:
822
1112
  experiment._work_unit_id_predictor = id_predictor.Predictor(1)
823
1113
 
824
1114
  # Populate annotations
825
- experiment.context.annotations.description = experiment_model.description
826
- experiment.context.annotations.note = experiment_model.note
1115
+ experiment.context.annotations.description = experiment_model.description or ""
1116
+ experiment.context.annotations.note = experiment_model.note or ""
827
1117
  experiment.context.annotations.tags = set(experiment_model.tags or [])
828
1118
 
829
1119
  # Populate artifacts
@@ -846,8 +1136,9 @@ def get_experiment(experiment_id: int) -> SlurmExperiment:
846
1136
  for job_model in wu_model.jobs:
847
1137
  slurm_ssh_config = config.SlurmSSHConfig.deserialize(job_model.slurm_ssh_config)
848
1138
  handle = execution.SlurmHandle(
1139
+ experiment_id=experiment_id,
849
1140
  ssh=slurm_ssh_config,
850
- job_id=str(job_model.slurm_job_id),
1141
+ slurm_job=str(job_model.slurm_job_id),
851
1142
  job_name=job_model.name,
852
1143
  )
853
1144
  work_unit._execution_handles.append(handle)