xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the published contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xmanager-slurm might be problematic.

xm_slurm/experiment.py CHANGED
@@ -10,11 +10,13 @@ import typing
  from concurrent import futures
  from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence
 
+ import more_itertools as mit
  from xmanager import xm
  from xmanager.xm import async_packager, id_predictor
 
- from xm_slurm import api, execution, executors
+ from xm_slurm import api, config, execution, executors
  from xm_slurm.console import console
+ from xm_slurm.job_blocks import JobArgs
  from xm_slurm.packaging import router
  from xm_slurm.status import SlurmWorkUnitStatus
  from xm_slurm.utils import UserSet
@@ -26,7 +28,7 @@ _current_job_array_queue = contextvars.ContextVar[
 
  def _validate_job(
  job: xm.JobType,
- args_view: Mapping[str, Any],
+ args_view: JobArgs | Mapping[str, JobArgs],
  ) -> None:
  if not args_view:
  return
@@ -49,7 +51,7 @@ def _validate_job(
  )
 
  if isinstance(job, xm.JobGroup) and key in job.jobs:
- _validate_job(job.jobs[key], expanded)
+ _validate_job(job.jobs[key], typing.cast(JobArgs, expanded))
  elif key not in allowed_keys:
  raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")
 
@@ -117,28 +119,40 @@ class SlurmExperimentUnitMetadataContext:
  class SlurmExperimentUnit(xm.ExperimentUnit):
  """ExperimentUnit is a collection of semantically associated `Job`s."""
 
- experiment: "SlurmExperiment"
+ experiment: "SlurmExperiment" # type: ignore
 
  def __init__(
  self,
  experiment: xm.Experiment,
  create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
- args: Mapping[str, Any] | None,
+ args: JobArgs | None,
  role: xm.ExperimentUnitRole,
+ identity: str = "",
  ) -> None:
- super().__init__(experiment, create_task, args, role)
+ super().__init__(experiment, create_task, args, role, identity=identity)
  self._launched_jobs: list[xm.LaunchedJob] = []
  self._execution_handles: list[execution.SlurmHandle] = []
  self._context = SlurmExperimentUnitMetadataContext(
  artifacts=ContextArtifacts(owner=self, artifacts=[]),
  )
 
+ @typing.overload
+ async def _submit_jobs_for_execution(
+ self,
+ job: xm.Job,
+ args_view: JobArgs,
+ identity: str | None = ...,
+ ) -> execution.SlurmHandle: ...
+
+ @typing.overload
  async def _submit_jobs_for_execution(
  self,
- job: xm.Job | xm.JobGroup,
- args_view: Mapping[str, Any],
- identity: str | None = None,
- ) -> execution.SlurmHandle:
+ job: xm.JobGroup,
+ args_view: Mapping[str, JobArgs],
+ identity: str | None = ...,
+ ) -> execution.SlurmHandle: ...
+
+ async def _submit_jobs_for_execution(self, job, args_view, identity=None):
  return await execution.launch(
  job=job,
  args=args_view,
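
A note on the hunk above: the single `_submit_jobs_for_execution` signature taking `xm.Job | xm.JobGroup` is split into `@typing.overload` stubs, so the checker pairs `xm.Job` with a flat `JobArgs` view and `xm.JobGroup` with a per-job mapping, while one untyped implementation does the work. A minimal sketch of the pattern with placeholder types (not the package's real classes):

    import typing

    class Job: ...
    class JobGroup: ...
    class Handle: ...

    @typing.overload
    async def submit(job: Job, args: dict[str, str]) -> Handle: ...
    @typing.overload
    async def submit(job: JobGroup, args: dict[str, dict[str, str]]) -> Handle: ...
    async def submit(job, args):
        # Only this implementation exists at runtime; the stubs are erased and
        # serve purely to narrow which (job, args) pairings type-check.
        return Handle()
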
@@ -147,9 +161,28 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
  )
 
  def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
+ self._execution_handles.append(handle)
+
+ def _ingest_job(job: xm.Job) -> None:
+ if not isinstance(self._role, xm.WorkUnitRole):
+ return
+ assert isinstance(self, SlurmWorkUnit)
+ assert job.name is not None
+ api.client().insert_job(
+ self.experiment_id,
+ self.work_unit_id,
+ api.SlurmJobModel(
+ name=job.name,
+ slurm_job_id=handle.job_id, # type: ignore
+ slurm_ssh_config=handle.ssh.serialize(),
+ ),
+ )
+
  match job:
  case xm.JobGroup() as job_group:
  for job in job_group.jobs.values():
+ assert isinstance(job, xm.Job)
+ _ingest_job(job)
  self._launched_jobs.append(
  xm.LaunchedJob(
  name=job.name, # type: ignore
@@ -157,6 +190,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
  )
  )
  case xm.Job():
+ _ingest_job(job)
  self._launched_jobs.append(
  xm.LaunchedJob(
  name=handle.job.name, # type: ignore
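
These two hunks route every concrete job through the new local `_ingest_job` closure, which persists the job via `api.client().insert_job` before the launched-job bookkeeping, using structural pattern matching to treat bare jobs and job groups uniformly. A self-contained sketch of that dispatch with stand-in classes (not the real `xm` types):

    class Job:
        def __init__(self, name: str):
            self.name = name

    class JobGroup:
        def __init__(self, **jobs: "Job"):
            self.jobs = jobs

    def ingest(job) -> list[str]:
        recorded: list[str] = []

        def _ingest_job(j: Job) -> None:
            recorded.append(j.name)  # stand-in for api.client().insert_job(...)

        match job:
            case JobGroup() as group:  # a group contributes each member job
                for j in group.jobs.values():
                    _ingest_job(j)
            case Job():  # a bare job contributes itself
                _ingest_job(job)
        return recorded

    assert ingest(JobGroup(train=Job("train"), eval=Job("eval"))) == ["train", "eval"]
    assert ingest(Job("solo")) == ["solo"]
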
@@ -191,11 +225,12 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
  states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
  return SlurmWorkUnitStatus.aggregate(states)
 
+ @property
  def launched_jobs(self) -> list[xm.LaunchedJob]:
  return self._launched_jobs
 
  @property
- def context(self) -> SlurmExperimentUnitMetadataContext:
+ def context(self) -> SlurmExperimentUnitMetadataContext: # type: ignore
  return self._context
 
 
@@ -204,68 +239,65 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
  self,
  experiment: "SlurmExperiment",
  create_task: Callable[[Awaitable[Any]], futures.Future],
- args: Mapping[str, Any],
+ args: JobArgs,
  role: xm.ExperimentUnitRole,
  work_unit_id_predictor: id_predictor.Predictor,
+ identity: str = "",
  ) -> None:
- super().__init__(experiment, create_task, args, role)
+ super().__init__(experiment, create_task, args, role, identity=identity)
  self._work_unit_id_predictor = work_unit_id_predictor
  self._work_unit_id = self._work_unit_id_predictor.reserve_id()
- api.client().insert_work_unit(
- self.experiment_id,
- api.WorkUnitPatchModel(
- wid=self.work_unit_id,
- identity=self.identity,
- args=json.dumps(args),
- ),
- )
 
- def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
- self._execution_handles.append(handle)
- api.client().insert_job(
- self.experiment_id,
- self.work_unit_id,
- api.SlurmJobModel(
- name=self.experiment_unit_name,
- slurm_job_id=handle.job_id, # type: ignore
- slurm_cluster=json.dumps({
- "host": handle.ssh_connection_options.host,
- "username": handle.ssh_connection_options.username,
- "port": handle.ssh_connection_options.port,
- "config": handle.ssh_connection_options.config.get_options(False),
- }),
- ),
- )
+ def _get_existing_handle(self, job: xm.JobGroup) -> execution.SlurmHandle | None:
+ job_name = mit.one(job.jobs.keys())
+ for handle in self._execution_handles:
+ if handle.job_name == job_name:
+ return handle
+ return None
 
- async def _launch_job_group(
+ async def _launch_job_group( # type: ignore
  self,
  job: xm.JobGroup,
- args_view: Mapping[str, Any],
+ args_view: Mapping[str, JobArgs],
  identity: str,
  ) -> None:
  global _current_job_array_queue
  _validate_job(job, args_view)
+ future = asyncio.Future[execution.SlurmHandle]()
+
+ # If we already have a handle for this job, we don't need to submit it again.
+ # We'll just resolve the future with the existing handle.
+ # Otherwise we'll add callbacks to ingest the handle and the launched jobs.
+ if existing_handle := self._get_existing_handle(job):
+ future.set_result(existing_handle)
+ else:
+ future.add_done_callback(
+ lambda handle: self._ingest_launched_jobs(job, handle.result())
+ )
+
+ api.client().update_work_unit(
+ self.experiment_id,
+ self.work_unit_id,
+ api.ExperimentUnitPatchModel(args=json.dumps(args_view), identity=None),
+ )
 
- future = asyncio.Future()
  async with self._work_unit_id_predictor.submit_id(self.work_unit_id): # type: ignore
  # If we're scheduling as part of a job queue (i.e., the queue is set on the context)
- # then we'll insert the job and current future that'll get resolved to the
- # proper handle.
+ # then we'll insert the job and future that'll get resolved to the proper handle
+ # when the Slurm job array is scheduled.
  if job_array_queue := _current_job_array_queue.get():
  job_array_queue.put_nowait((job, future))
- # Otherwise we'll resolve the future with the scheduled job immediately
- else:
- # Set the result inside of the context manager so we don't get out-of-order
- # id scheduling...
- future.set_result(
- await self._submit_jobs_for_execution(job, args_view, identity=identity)
- )
+ # Otherwise, we're scheduling a single job and we'll submit it for execution.
+ # If the future is already done, i.e., the handle is already resolved, we don't need
+ # to submit the job again.
+ elif not future.done():
+ handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
+ future.set_result(handle)
 
  # Wait for the job handle, this is either coming from scheduling the job array
- # or from the single job above.
- handle = await future
- self._ingest_handle(handle)
- self._ingest_launched_jobs(job, handle)
+ # or from a single job submission. If an existing handle was found, this will be
+ # a no-op.
+ await future
 
  @property
  def experiment_unit_name(self) -> str:
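
The rewrite of `_launch_job_group` above threads one `asyncio.Future` through three paths: it is pre-resolved from a cached handle when the same job was launched before, handed to the job-array queue when one is active, or resolved by a direct submission otherwise, with ingestion attached as a done-callback so it runs exactly once. A runnable sketch of the resolve-once shape (names are illustrative, not the package's API):

    import asyncio

    async def launch_once(cached, submit, ingest):
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        if cached is not None:
            future.set_result(cached)  # reuse the existing handle; nothing is resubmitted
        else:
            # Ingestion fires exactly once, whenever the handle first resolves.
            future.add_done_callback(lambda f: ingest(f.result()))
        if not future.done():
            future.set_result(await submit())
        return await future

    async def main():
        submissions = []

        async def submit():
            submissions.append("sbatch")
            return "handle-1"

        assert await launch_once(None, submit, print) == "handle-1"        # submits once
        assert await launch_once("handle-1", submit, print) == "handle-1"  # cached: no-op
        await asyncio.sleep(0)  # let the done-callback flush
        assert submissions == ["sbatch"]

    asyncio.run(main())
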
@@ -282,20 +314,15 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
  class SlurmAuxiliaryUnit(SlurmExperimentUnit):
  """An auxiliary unit operated by the Slurm backend."""
 
- def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
- del handle
- console.print("[red]Auxiliary units do not currently support ingestion.[/red]")
-
- async def _launch_job_group(
+ async def _launch_job_group( # type: ignore
  self,
- job: xm.Job | xm.JobGroup,
- args_view: Mapping[str, Any],
+ job: xm.JobGroup,
+ args_view: Mapping[str, JobArgs],
  identity: str,
  ) -> None:
  _validate_job(job, args_view)
 
  slurm_handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
- self._ingest_handle(slurm_handle)
  self._ingest_launched_jobs(job, slurm_handle)
 
  @property
@@ -438,37 +465,47 @@ class SlurmExperiment(xm.Experiment):
  job: xm.AuxiliaryUnitJob,
  args: Mapping[str, Any] | None = ...,
  *,
- identity: str = "",
- ) -> asyncio.Future[SlurmExperimentUnit]: ...
+ identity: str = ...,
+ ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
 
  @typing.overload
  def add(
  self,
- job: xm.JobType,
- args: Mapping[str, Any] | None = ...,
+ job: xm.JobGroup,
+ args: Mapping[str, Mapping[str, Any]] | None,
  *,
- role: xm.WorkUnitRole = ...,
- identity: str = "",
+ role: xm.WorkUnitRole = xm.WorkUnitRole(),
+ identity: str = ...,
  ) -> asyncio.Future[SlurmWorkUnit]: ...
 
  @typing.overload
  def add(
  self,
- job: xm.JobType,
- args: Mapping[str, Any] | None,
+ job: xm.JobGroup,
+ args: Mapping[str, Mapping[str, Any]] | None,
  *,
  role: xm.ExperimentUnitRole,
- identity: str = "",
+ identity: str = ...,
  ) -> asyncio.Future[SlurmExperimentUnit]: ...
 
  @typing.overload
  def add(
  self,
- job: xm.JobType,
- args: Mapping[str, Any] | None = ...,
+ job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
+ args: Mapping[str, Any] | None,
+ *,
+ role: xm.WorkUnitRole = xm.WorkUnitRole(),
+ identity: str = ...,
+ ) -> asyncio.Future[SlurmWorkUnit]: ...
+
+ @typing.overload
+ def add(
+ self,
+ job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
+ args: Mapping[str, Any] | None,
  *,
  role: xm.ExperimentUnitRole,
- identity: str = "",
+ identity: str = ...,
  ) -> asyncio.Future[SlurmExperimentUnit]: ...
 
  @typing.overload
@@ -477,11 +514,20 @@ class SlurmExperiment(xm.Experiment):
  job: xm.Job | xm.JobGeneratorType,
  args: Sequence[Mapping[str, Any]],
  *,
- role: xm.WorkUnitRole = ...,
- identity: str = "",
+ role: xm.WorkUnitRole = xm.WorkUnitRole(),
+ identity: str = ...,
  ) -> asyncio.Future[Sequence[SlurmWorkUnit]]: ...
 
+ @typing.overload
  def add(
+ self,
+ job: xm.JobType,
+ *,
+ role: xm.AuxiliaryUnitRole = ...,
+ identity: str = ...,
+ ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
+
+ def add( # type: ignore
  self,
  job: xm.JobType,
  args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None,
@@ -489,18 +535,14 @@ class SlurmExperiment(xm.Experiment):
  role: xm.ExperimentUnitRole = xm.WorkUnitRole(),
  identity: str = "",
  ) -> (
- asyncio.Future[SlurmExperimentUnit]
+ asyncio.Future[SlurmAuxiliaryUnit]
+ | asyncio.Future[SlurmExperimentUnit]
  | asyncio.Future[SlurmWorkUnit]
  | asyncio.Future[Sequence[SlurmWorkUnit]]
  ):
  if isinstance(args, collections.abc.Sequence):
  if not isinstance(role, xm.WorkUnitRole):
  raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
- if identity:
- raise ValueError(
- "Cannot set an identity on the root add call. "
- "Please use a job generator and set the identity within."
- )
  if isinstance(job, xm.JobGroup):
  raise ValueError(
  "Job arrays over `xm.JobGroup`s aren't supported. "
@@ -509,6 +551,11 @@ class SlurmExperiment(xm.Experiment):
  )
  assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"
 
+ # Validate job & args
+ for trial in args:
+ _validate_job(job, trial)
+ args = typing.cast(Sequence[JobArgs], args)
+
  return asyncio.wrap_future(
  self._create_task(self._launch_job_array(job, args, role, identity))
  )
@@ -518,7 +565,7 @@ class SlurmExperiment(xm.Experiment):
  async def _launch_job_array(
  self,
  job: xm.Job | xm.JobGeneratorType,
- args: Sequence[Mapping[str, Any]],
+ args: Sequence[JobArgs],
  role: xm.WorkUnitRole,
  identity: str = "",
  ) -> Sequence[SlurmWorkUnit]:
@@ -531,10 +578,9 @@ class SlurmExperiment(xm.Experiment):
  # For each trial we'll schedule the job
  # and collect the futures
  wu_futures = []
- for trial in args:
- wu_futures.append(super().add(job, args=trial, role=role, identity=identity))
+ for idx, trial in enumerate(args):
+ wu_futures.append(super().add(job, args=trial, role=role, identity=f"{identity}_{idx}"))
 
- # TODO(jfarebro): Set a timeout here
  # We'll wait until XManager has filled the queue.
  # There are two cases here, either we were given an xm.Job
  # in which case this will be trivial and filled immediately.
@@ -551,9 +597,13 @@ class SlurmExperiment(xm.Experiment):
  # that there's only a single job in this job group
  job_group_view, future = job_array_queue.get_nowait()
  assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
- _, job_view = job_group_view.jobs.popitem()
+ job_view = mit.one(
+ job_group_view.jobs.values(),
+ too_short=ValueError("Expected a single `xm.Job` in job group."),
+ too_long=ValueError("Only one `xm.Job` is supported for job arrays."),
+ )
 
- if len(job_group_view.jobs) != 0 or not isinstance(job_view, xm.Job):
+ if not isinstance(job_view, xm.Job):
  raise ValueError("Only `xm.Job` is supported for job arrays. ")
 
  if executable is None:
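
Replacing `jobs.popitem()` with `more_itertools.one` stops mutating the job group view and turns the exactly-one-job invariant into explicit, descriptive errors. A small demonstration of the call as used above:

    import more_itertools as mit

    jobs = {"train": "job-object"}
    job_view = mit.one(
        jobs.values(),
        too_short=ValueError("Expected a single `xm.Job` in job group."),
        too_long=ValueError("Only one `xm.Job` is supported for job arrays."),
    )
    assert job_view == "job-object"
    assert jobs  # unlike popitem(), the mapping is left intact
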
@@ -571,64 +621,86 @@ class SlurmExperiment(xm.Experiment):
  if job_view.name != name:
  raise RuntimeError("Found multiple names in job array")
 
- resolved_args.append(
- set(xm.SequentialArgs.from_collection(job_view.args).to_dict().items())
- )
+ resolved_args.append(xm.SequentialArgs.from_collection(job_view.args).to_list())
  resolved_env_vars.append(set(job_view.env_vars.items()))
  resolved_futures.append(future)
  assert executable is not None, "No executable found?"
  assert executor is not None, "No executor found?"
  assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
- assert executor.requirements.cluster is not None, "Cluster must be set on executor."
-
- common_args: set = functools.reduce(lambda a, b: a & b, resolved_args, set())
+ assert (
+ executor.requirements.cluster is not None
+ ), "Cluster must be specified on requirements."
+
+ # XManager merges job arguments with keyword arguments with job arguments
+ # coming first. These are the arguments that may be common across all jobs
+ # so we can find the largest common prefix and remove them from each job.
+ common_args: list[str] = list(mit.longest_common_prefix(resolved_args))
  common_env_vars: set = functools.reduce(lambda a, b: a & b, resolved_env_vars, set())
 
  sweep_args = [
- {
- "args": dict(a.difference(common_args)),
- "env_vars": dict(e.difference(common_env_vars)),
- }
+ JobArgs(
+ args=functools.reduce(
+ # Remove the common arguments from each job
+ lambda args, to_remove: args.remove_args(to_remove),
+ common_args,
+ xm.SequentialArgs.from_collection(a),
+ ),
+ env_vars=dict(e.difference(common_env_vars)),
+ )
  for a, e in zip(resolved_args, resolved_env_vars)
  ]
 
  # No support for sweep_env_vars right now.
  # We schedule the job array and then we'll resolve all the work units with
  # the handles Slurm gives back to us.
- try:
- handles = await execution.get_client().launch(
- cluster=executor.requirements.cluster,
- job=xm.Job(
- executable=executable,
- executor=executor,
- name=name,
- args=dict(common_args),
- env_vars=dict(common_env_vars),
- ),
- args=sweep_args,
- experiment_id=self.experiment_id,
- identity=identity,
+ # If we already have handles for all the work units, we don't need to submit the
+ # job array to SLURM.
+ num_resolved_handles = sum(future.done() for future in resolved_futures)
+ if num_resolved_handles == 0:
+ try:
+ handles = await execution.launch(
+ job=xm.Job(
+ executable=executable,
+ executor=executor,
+ name=name,
+ args=xm.SequentialArgs.from_collection(common_args),
+ env_vars=dict(common_env_vars),
+ ),
+ args=sweep_args,
+ experiment_id=self.experiment_id,
+ identity=identity,
+ )
+ except Exception as e:
+ for future in resolved_futures:
+ future.set_exception(e)
+ raise
+ else:
+ for handle, future in zip(handles, resolved_futures):
+ future.set_result(handle)
+ elif 0 < num_resolved_handles < len(resolved_futures):
+ raise RuntimeError(
+ "Some array job elements have handles, but some don't. This shouldn't happen."
  )
- except Exception as e:
- for future in resolved_futures:
- future.set_exception(e)
- raise
- else:
- for handle, future in zip(handles, resolved_futures):
- future.set_result(handle)
 
  wus = await asyncio.gather(*wu_futures)
+
  _current_job_array_queue.set(None)
  return wus
 
- def _create_experiment_unit(
+ def _get_work_unit_by_identity(self, identity: str) -> SlurmWorkUnit | None:
+ if identity == "":
+ return None
+ for unit in self._experiment_units:
+ if isinstance(unit, SlurmWorkUnit) and unit.identity == identity:
+ return unit
+ return None
+
+ def _create_experiment_unit( # type: ignore
  self,
- args: Mapping[str, Any],
+ args: JobArgs,
  role: xm.ExperimentUnitRole,
  identity: str,
- ) -> Awaitable[SlurmWorkUnit]:
- del identity
-
+ ) -> Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
  def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
  work_unit = SlurmWorkUnit(
  self,
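
The argument-factoring change above is the behavioral heart of this hunk: each trial's arguments are now kept as an ordered list, the longest common prefix is hoisted onto the array-level job, and only the per-trial remainder is swept, whereas the old set intersection ignored argument order. A sketch with plain lists (the real code operates on `xm.SequentialArgs`; the flag values are made up):

    import more_itertools as mit

    trials = [
        ["--env=breakout", "--eval_episodes=10", "--seed=1"],
        ["--env=breakout", "--eval_episodes=10", "--seed=2"],
        ["--env=breakout", "--eval_episodes=10", "--seed=3"],
    ]
    common = list(mit.longest_common_prefix(trials))   # hoisted onto the array job
    swept = [trial[len(common):] for trial in trials]  # per-trial remainder

    assert common == ["--env=breakout", "--eval_episodes=10"]
    assert swept == [["--seed=1"], ["--seed=2"], ["--seed=3"]]
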
@@ -636,26 +708,55 @@ class SlurmExperiment(xm.Experiment):
  args,
  role,
  self._work_unit_id_predictor,
+ identity=identity,
  )
  self._experiment_units.append(work_unit)
  self._work_unit_count += 1
 
- future = asyncio.Future()
+ api.client().insert_work_unit(
+ self.experiment_id,
+ api.WorkUnitModel(
+ wid=work_unit.work_unit_id,
+ identity=work_unit.identity,
+ args=json.dumps(args),
+ ),
+ )
+
+ future = asyncio.Future[SlurmWorkUnit]()
  future.set_result(work_unit)
  return future
 
+ def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> Awaitable[SlurmAuxiliaryUnit]:
+ auxiliary_unit = SlurmAuxiliaryUnit(
+ self,
+ self._create_task,
+ args,
+ role,
+ identity=identity,
+ )
+ self._experiment_units.append(auxiliary_unit)
+ future = asyncio.Future[SlurmAuxiliaryUnit]()
+ future.set_result(auxiliary_unit)
+ return future
+
  match role:
  case xm.WorkUnitRole():
+ if (existing_unit := self._get_work_unit_by_identity(identity)) is not None:
+ future = asyncio.Future[SlurmWorkUnit]()
+ future.set_result(existing_unit)
+ return future
  return _create_work_unit(role)
+ case xm.AuxiliaryUnitRole():
+ return _create_auxiliary_unit(role)
  case _:
  raise ValueError(f"Unsupported role {role}")
 
- def _get_experiment_unit(
+ def _get_experiment_unit( # type: ignore
  self,
  experiment_id: int,
  identity: str,
  role: xm.ExperimentUnitRole,
- args: Mapping[str, Any] | None = None,
+ args: JobArgs | None = None,
  ) -> Awaitable[xm.ExperimentUnit]:
  del experiment_id, identity, role, args
  raise NotImplementedError
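
With `_get_work_unit_by_identity` consulted first, calling `add` again with the same non-empty identity now returns the existing work unit rather than creating a duplicate, which together with the rehydrating `get_experiment` below makes relaunch scripts idempotent. The lookup reduces to a sketch like this:

    from dataclasses import dataclass

    @dataclass
    class Unit:
        identity: str

    def get_by_identity(units: list[Unit], identity: str) -> Unit | None:
        if identity == "":
            return None  # anonymous units never match and are always created fresh
        return next((u for u in units if u.identity == identity), None)

    units = [Unit("sweep_0"), Unit("sweep_1")]
    assert get_by_identity(units, "sweep_1") is units[1]
    assert get_by_identity(units, "") is None
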
@@ -672,7 +773,7 @@ class SlurmExperiment(xm.Experiment):
  # If no work units were added, delete this experiment
  # This is to prevent empty experiments from being persisted
  # and cluttering the database.
- if self.work_unit_count == 0:
+ if len(self._experiment_units) == 0:
  console.print(
  f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
  )
@@ -689,14 +790,13 @@ class SlurmExperiment(xm.Experiment):
  return self.context.annotations.title
 
  @property
- def context(self) -> SlurmExperimentMetadataContext:
+ def context(self) -> SlurmExperimentMetadataContext: # type: ignore
  return self._experiment_context
 
  @property
  def work_unit_count(self) -> int:
  return self._work_unit_count
 
- @property
  def work_units(self) -> Mapping[int, SlurmWorkUnit]:
  """Gets work units created via self.add()."""
  return {
@@ -716,8 +816,50 @@ def create_experiment(experiment_title: str) -> SlurmExperiment:
  def get_experiment(experiment_id: int) -> SlurmExperiment:
  """Get Experiment."""
  experiment_model = api.client().get_experiment(experiment_id)
- # TODO(jfarebro): Fill in jobs and work units and annotations
- return SlurmExperiment(experiment_title=experiment_model.title, experiment_id=experiment_id)
+ experiment = SlurmExperiment(
+ experiment_title=experiment_model.title, experiment_id=experiment_id
+ )
+ experiment._work_unit_id_predictor = id_predictor.Predictor(1)
+
+ # Populate annotations
+ experiment.context.annotations.description = experiment_model.description
+ experiment.context.annotations.note = experiment_model.note
+ experiment.context.annotations.tags = set(experiment_model.tags or [])
+
+ # Populate artifacts
+ for artifact in experiment_model.artifacts:
+ experiment.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
+
+ # Populate work units
+ for wu_model in experiment_model.work_units:
+ work_unit = SlurmWorkUnit(
+ experiment=experiment,
+ create_task=experiment._create_task,
+ args=json.loads(wu_model.args) if wu_model.args else {},
+ role=xm.WorkUnitRole(),
+ identity=wu_model.identity or "",
+ work_unit_id_predictor=experiment._work_unit_id_predictor,
+ )
+ work_unit._work_unit_id = wu_model.wid
+
+ # Populate jobs for each work unit
+ for job_model in wu_model.jobs:
+ slurm_ssh_config = config.SlurmSSHConfig.deserialize(job_model.slurm_ssh_config)
+ handle = execution.SlurmHandle(
+ ssh=slurm_ssh_config,
+ job_id=str(job_model.slurm_job_id),
+ job_name=job_model.name,
+ )
+ work_unit._execution_handles.append(handle)
+
+ # Populate artifacts for each work unit
+ for artifact in wu_model.artifacts:
+ work_unit.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
+
+ experiment._experiment_units.append(work_unit)
+ experiment._work_unit_count += 1
+
+ return experiment
 
 
  @functools.cache
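
`get_experiment` previously returned a bare `SlurmExperiment`; it now rehydrates annotations, artifacts, work units, and their `SlurmHandle`s from the stored models, rebuilding each handle's SSH target with `config.SlurmSSHConfig.deserialize`. Under those changes, resuming by id could look like the following (the experiment id and the fields printed are illustrative only):

    from xm_slurm.experiment import get_experiment

    experiment = get_experiment(42)  # hypothetical experiment id
    for wid, work_unit in experiment.work_units().items():  # work_units is now a method
        for handle in work_unit._execution_handles:
            print(wid, handle.job_name, handle.job_id)
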
@@ -733,5 +875,5 @@ def get_current_work_unit() -> SlurmWorkUnit | None:
  wid := os.environ.get("XM_SLURM_WORK_UNIT_ID")
  ):
  experiment = get_experiment(int(xid))
- return experiment.work_units[int(wid)]
+ return experiment.work_units()[int(wid)]
  return None