xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xmanager-slurm might be problematic.

Files changed (42)
  1. xm_slurm/__init__.py +6 -2
  2. xm_slurm/api.py +301 -34
  3. xm_slurm/batching.py +4 -4
  4. xm_slurm/config.py +105 -55
  5. xm_slurm/constants.py +19 -0
  6. xm_slurm/contrib/__init__.py +0 -0
  7. xm_slurm/contrib/clusters/__init__.py +47 -13
  8. xm_slurm/contrib/clusters/drac.py +34 -16
  9. xm_slurm/dependencies.py +171 -0
  10. xm_slurm/executables.py +34 -22
  11. xm_slurm/execution.py +305 -107
  12. xm_slurm/executors.py +8 -12
  13. xm_slurm/experiment.py +601 -168
  14. xm_slurm/experimental/parameter_controller.py +202 -0
  15. xm_slurm/job_blocks.py +7 -0
  16. xm_slurm/packageables.py +42 -20
  17. xm_slurm/packaging/{docker/local.py → docker.py} +135 -40
  18. xm_slurm/packaging/router.py +3 -1
  19. xm_slurm/packaging/utils.py +9 -81
  20. xm_slurm/resources.py +28 -4
  21. xm_slurm/scripts/_cloudpickle.py +28 -0
  22. xm_slurm/scripts/cli.py +52 -0
  23. xm_slurm/status.py +9 -0
  24. xm_slurm/templates/docker/mamba.Dockerfile +4 -2
  25. xm_slurm/templates/docker/python.Dockerfile +18 -10
  26. xm_slurm/templates/docker/uv.Dockerfile +35 -0
  27. xm_slurm/templates/slurm/fragments/monitor.bash.j2 +5 -0
  28. xm_slurm/templates/slurm/job-array.bash.j2 +1 -2
  29. xm_slurm/templates/slurm/job.bash.j2 +4 -3
  30. xm_slurm/types.py +23 -0
  31. xm_slurm/utils.py +18 -10
  32. xmanager_slurm-0.4.1.dist-info/METADATA +26 -0
  33. xmanager_slurm-0.4.1.dist-info/RECORD +44 -0
  34. {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.1.dist-info}/WHEEL +1 -1
  35. xmanager_slurm-0.4.1.dist-info/entry_points.txt +2 -0
  36. xmanager_slurm-0.4.1.dist-info/licenses/LICENSE.md +227 -0
  37. xm_slurm/packaging/docker/__init__.py +0 -75
  38. xm_slurm/packaging/docker/abc.py +0 -112
  39. xm_slurm/packaging/docker/cloud.py +0 -503
  40. xm_slurm/templates/docker/pdm.Dockerfile +0 -31
  41. xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
  42. xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
xm_slurm/experiment.py CHANGED
@@ -2,23 +2,31 @@ import asyncio
  import collections.abc
  import contextvars
  import dataclasses
+ import datetime as dt
  import functools
  import inspect
  import json
+ import logging
  import os
- import typing
+ import traceback
+ import typing as tp
  from concurrent import futures
- from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence

+ import more_itertools as mit
+ from rich.console import ConsoleRenderable
  from xmanager import xm
- from xmanager.xm import async_packager, id_predictor
+ from xmanager.xm import async_packager, core, id_predictor, job_operators
+ from xmanager.xm import job_blocks as xm_job_blocks

- from xm_slurm import api, execution, executors
+ from xm_slurm import api, config, dependencies, execution, executors
  from xm_slurm.console import console
+ from xm_slurm.job_blocks import JobArgs
  from xm_slurm.packaging import router
  from xm_slurm.status import SlurmWorkUnitStatus
  from xm_slurm.utils import UserSet

+ logger = logging.getLogger(__name__)
+
  _current_job_array_queue = contextvars.ContextVar[
  asyncio.Queue[tuple[xm.JobGroup, asyncio.Future]] | None
  ]("_current_job_array_queue", default=None)
@@ -26,7 +34,7 @@ _current_job_array_queue = contextvars.ContextVar[

  def _validate_job(
  job: xm.JobType,
- args_view: Mapping[str, Any],
+ args_view: JobArgs | tp.Mapping[str, JobArgs],
  ) -> None:
  if not args_view:
  return
@@ -49,7 +57,7 @@ def _validate_job(
  )

  if isinstance(job, xm.JobGroup) and key in job.jobs:
- _validate_job(job.jobs[key], expanded)
+ _validate_job(job.jobs[key], tp.cast(JobArgs, expanded))
  elif key not in allowed_keys:
  raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")

@@ -60,7 +68,7 @@ class Artifact:
  uri: str

  def __hash__(self) -> int:
- return hash(self.name)
+ return hash((type(self), self.name))


  class ContextArtifacts(UserSet[Artifact]):
@@ -68,7 +76,7 @@ class ContextArtifacts(UserSet[Artifact]):
  self,
  owner: "SlurmExperiment | SlurmExperimentUnit",
  *,
- artifacts: Sequence[Artifact],
+ artifacts: tp.Sequence[Artifact],
  ):
  super().__init__(
  artifacts,
@@ -117,50 +125,194 @@ class SlurmExperimentUnitMetadataContext:
  class SlurmExperimentUnit(xm.ExperimentUnit):
  """ExperimentUnit is a collection of semantically associated `Job`s."""

- experiment: "SlurmExperiment"
+ experiment: "SlurmExperiment" # type: ignore

  def __init__(
  self,
  experiment: xm.Experiment,
- create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
- args: Mapping[str, Any] | None,
+ create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future[tp.Any]],
+ args: JobArgs | tp.Mapping[str, JobArgs] | None,
  role: xm.ExperimentUnitRole,
+ identity: str = "",
  ) -> None:
- super().__init__(experiment, create_task, args, role)
+ super().__init__(experiment, create_task, args, role, identity=identity)
  self._launched_jobs: list[xm.LaunchedJob] = []
  self._execution_handles: list[execution.SlurmHandle] = []
  self._context = SlurmExperimentUnitMetadataContext(
  artifacts=ContextArtifacts(owner=self, artifacts=[]),
  )

+ def add( # type: ignore
+ self,
+ job: xm.JobType,
+ args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
+ *,
+ dependency: dependencies.SlurmJobDependency | None = None,
+ identity: str = "",
+ ) -> tp.Awaitable[None]:
+ # Prioritize the identity given directly to the work unit at work unit
+ # creation time, as opposed to the identity passed when adding jobs to it as
+ # this is more consistent between job generator work units and regular work
+ # units.
+ identity = self.identity or identity
+
+ job = job_operators.shallow_copy_job_type(job) # type: ignore
+ if args is not None:
+ core._apply_args(job, args)
+ job_operators.populate_job_names(job) # type: ignore
+
+ def launch_job(job: xm.Job) -> tp.Awaitable[None]:
+ core._current_experiment.set(self.experiment)
+ core._current_experiment_unit.set(self)
+ return self._launch_job_group(
+ xm.JobGroup(**{job.name: job}), # type: ignore
+ core._work_unit_arguments(job, self._args),
+ dependency=dependency,
+ identity=identity,
+ )
+
+ def launch_job_group(group: xm.JobGroup) -> tp.Awaitable[None]:
+ core._current_experiment.set(self.experiment)
+ core._current_experiment_unit.set(self)
+ return self._launch_job_group(
+ group,
+ core._work_unit_arguments(group, self._args),
+ dependency=dependency,
+ identity=identity,
+ )
+
+ def launch_job_generator(
+ job_generator: xm.JobGeneratorType,
+ ) -> tp.Awaitable[None]:
+ if not inspect.iscoroutinefunction(job_generator) and not inspect.iscoroutinefunction(
+ getattr(job_generator, "__call__")
+ ):
+ raise ValueError(
+ "Job generator must be an async function. Signature needs to be "
+ "`async def job_generator(work_unit: xm.WorkUnit) -> None:`"
+ )
+ core._current_experiment.set(self.experiment)
+ core._current_experiment_unit.set(self)
+ coroutine = job_generator(self, **(args or {}))
+ assert coroutine is not None
+ return coroutine
+
+ def launch_job_config(job_config: xm.JobConfig) -> tp.Awaitable[None]:
+ core._current_experiment.set(self.experiment)
+ core._current_experiment_unit.set(self)
+ return self._launch_job_config(
+ job_config, dependency, tp.cast(JobArgs, args) or {}, identity
+ )
+
+ job_awaitable: tp.Awaitable[tp.Any]
+ match job:
+ case xm.Job() as job:
+ job_awaitable = launch_job(job)
+ case xm.JobGroup() as job_group:
+ job_awaitable = launch_job_group(job_group)
+ case job_generator if xm_job_blocks.is_job_generator(job):
+ job_awaitable = launch_job_generator(job_generator) # type: ignore
+ case xm.JobConfig() as job_config:
+ job_awaitable = launch_job_config(job_config)
+ case _:
+ raise TypeError(f"Unsupported job type: {job!r}")
+
+ launch_task = self._create_task(job_awaitable)
+ self._launch_tasks.append(launch_task)
+ return asyncio.wrap_future(launch_task)
+
+ async def _launch_job_group( # type: ignore
+ self,
+ job_group: xm.JobGroup,
+ args_view: tp.Mapping[str, JobArgs],
+ *,
+ dependency: dependencies.SlurmJobDependency | None,
+ identity: str,
+ ) -> None:
+ del job_group, dependency, args_view, identity
+ raise NotImplementedError
+
+ async def _launch_job_config( # type: ignore
+ self,
+ job_config: xm.JobConfig,
+ dependency: dependencies.SlurmJobDependency | None,
+ args_view: JobArgs,
+ identity: str,
+ ) -> None:
+ del job_config, dependency, args_view, identity
+ raise NotImplementedError
+
+ @tp.overload
+ async def _submit_jobs_for_execution(
+ self,
+ job: xm.Job,
+ dependency: dependencies.SlurmJobDependency | None,
+ args_view: JobArgs,
+ identity: str | None = ...,
+ ) -> execution.SlurmHandle: ...
+
+ @tp.overload
+ async def _submit_jobs_for_execution(
+ self,
+ job: xm.JobGroup,
+ dependency: dependencies.SlurmJobDependency | None,
+ args_view: tp.Mapping[str, JobArgs],
+ identity: str | None = ...,
+ ) -> execution.SlurmHandle: ...
+
+ @tp.overload
  async def _submit_jobs_for_execution(
  self,
- job: xm.Job | xm.JobGroup,
- args_view: Mapping[str, Any],
- identity: str | None = None,
- ) -> execution.SlurmHandle:
+ job: xm.Job,
+ dependency: dependencies.SlurmJobDependency | None,
+ args_view: tp.Sequence[JobArgs],
+ identity: str | None = ...,
+ ) -> list[execution.SlurmHandle]: ...
+
+ async def _submit_jobs_for_execution(self, job, dependency, args_view, identity=None):
  return await execution.launch(
  job=job,
+ dependency=dependency,
  args=args_view,
  experiment_id=self.experiment_id,
  identity=identity,
  )

  def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
+ self._execution_handles.append(handle)
+
+ def _ingest_job(job: xm.Job) -> None:
+ if not isinstance(self._role, xm.WorkUnitRole):
+ return
+ assert isinstance(self, SlurmWorkUnit)
+ assert job.name is not None
+ api.client().insert_job(
+ self.experiment_id,
+ self.work_unit_id,
+ api.SlurmJobModel(
+ name=job.name,
+ slurm_job_id=handle.slurm_job.job_id,
+ slurm_ssh_config=handle.ssh.serialize(),
+ ),
+ )
+
  match job:
  case xm.JobGroup() as job_group:
  for job in job_group.jobs.values():
+ assert isinstance(job, xm.Job)
+ _ingest_job(job)
  self._launched_jobs.append(
  xm.LaunchedJob(
  name=job.name, # type: ignore
- address=str(handle.job_id),
+ address=str(handle.slurm_job.job_id),
  )
  )
  case xm.Job():
+ _ingest_job(job)
  self._launched_jobs.append(
  xm.LaunchedJob(
  name=handle.job.name, # type: ignore
- address=str(handle.job_id),
+ address=str(handle.slurm_job.job_id),
  )
  )

@@ -187,85 +339,116 @@ class SlurmExperimentUnit(xm.ExperimentUnit):

  self.experiment._create_task(_stop_awaitable())

- async def get_status(self) -> SlurmWorkUnitStatus:
+ async def get_status(self) -> SlurmWorkUnitStatus: # type: ignore
  states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
  return SlurmWorkUnitStatus.aggregate(states)

+ async def logs(
+ self,
+ *,
+ num_lines: int = 10,
+ block_size: int = 1024,
+ wait: bool = True,
+ follow: bool = False,
+ ) -> tp.AsyncGenerator[ConsoleRenderable, None]:
+ assert len(self._execution_handles) == 1, "Only one job handle is supported for logs."
+ handle = self._execution_handles[0] # TODO(jfarebro): interleave?
+ async for log in handle.logs(
+ num_lines=num_lines, block_size=block_size, wait=wait, follow=follow
+ ):
+ yield log
+
+ @property
  def launched_jobs(self) -> list[xm.LaunchedJob]:
  return self._launched_jobs

  @property
- def context(self) -> SlurmExperimentUnitMetadataContext:
+ def context(self) -> SlurmExperimentUnitMetadataContext: # type: ignore
  return self._context

+ def after_started(
+ self, *, time: dt.timedelta | None = None
+ ) -> dependencies.SlurmJobDependencyAfter:
+ return dependencies.SlurmJobDependencyAfter(self._execution_handles, time=time)
+
+ def after_finished(self) -> dependencies.SlurmJobDependencyAfterAny:
+ return dependencies.SlurmJobDependencyAfterAny(self._execution_handles)
+
+ def after_completed(self) -> dependencies.SlurmJobDependencyAfterOK:
+ return dependencies.SlurmJobDependencyAfterOK(self._execution_handles)
+
+ def after_failed(self) -> dependencies.SlurmJobDependencyAfterNotOK:
+ return dependencies.SlurmJobDependencyAfterNotOK(self._execution_handles)
+

  class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
  def __init__(
  self,
  experiment: "SlurmExperiment",
- create_task: Callable[[Awaitable[Any]], futures.Future],
- args: Mapping[str, Any],
+ create_task: tp.Callable[[tp.Awaitable[tp.Any]], futures.Future],
+ args: JobArgs | tp.Mapping[str, JobArgs] | None,
  role: xm.ExperimentUnitRole,
  work_unit_id_predictor: id_predictor.Predictor,
+ identity: str = "",
  ) -> None:
- super().__init__(experiment, create_task, args, role)
+ super().__init__(experiment, create_task, args, role, identity=identity)
  self._work_unit_id_predictor = work_unit_id_predictor
  self._work_unit_id = self._work_unit_id_predictor.reserve_id()
- api.client().insert_work_unit(
- self.experiment_id,
- api.WorkUnitPatchModel(
- wid=self.work_unit_id,
- identity=self.identity,
- args=json.dumps(args),
- ),
- )

- def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
- self._execution_handles.append(handle)
- api.client().insert_job(
- self.experiment_id,
- self.work_unit_id,
- api.SlurmJobModel(
- name=self.experiment_unit_name,
- slurm_job_id=handle.job_id, # type: ignore
- slurm_cluster=json.dumps({
- "host": handle.ssh_connection_options.host,
- "username": handle.ssh_connection_options.username,
- "port": handle.ssh_connection_options.port,
- "config": handle.ssh_connection_options.config.get_options(False),
- }),
- ),
- )
+ def _get_existing_handle(self, job: xm.JobGroup) -> execution.SlurmHandle | None:
+ job_name = mit.one(job.jobs.keys())
+ for handle in self._execution_handles:
+ if handle.job_name == job_name:
+ return handle
+ return None

- async def _launch_job_group(
+ async def _launch_job_group( # type: ignore
  self,
  job: xm.JobGroup,
- args_view: Mapping[str, Any],
+ args_view: tp.Mapping[str, JobArgs],
+ *,
+ dependency: dependencies.SlurmJobDependency | None,
  identity: str,
  ) -> None:
  global _current_job_array_queue
  _validate_job(job, args_view)
+ future = asyncio.Future[execution.SlurmHandle]()
+
+ # If we already have a handle for this job, we don't need to submit it again.
+ # We'll just resolve the future with the existing handle.
+ # Otherwise we'll add callbacks to ingest the handle and the launched jobs.
+ if existing_handle := self._get_existing_handle(job):
+ future.set_result(existing_handle)
+ else:
+ future.add_done_callback(
+ lambda handle: self._ingest_launched_jobs(job, handle.result())
+ )
+
+ api.client().update_work_unit(
+ self.experiment_id,
+ self.work_unit_id,
+ api.ExperimentUnitPatchModel(args=json.dumps(args_view), identity=None),
+ )

- future = asyncio.Future()
  async with self._work_unit_id_predictor.submit_id(self.work_unit_id): # type: ignore
  # If we're scheduling as part of a job queue (i.e., the queue is set on the context)
- # then we'll insert the job and current future that'll get resolved to the
- # proper handle.
+ # then we'll insert the job and future that'll get resolved to the proper handle
+ # when the Slurm job array is scheduled.
  if job_array_queue := _current_job_array_queue.get():
  job_array_queue.put_nowait((job, future))
- # Otherwise we'll resolve the future with the scheduled job immediately
- else:
- # Set the result inside of the context manager so we don't get out-of-order
- # id scheduling...
- future.set_result(
- await self._submit_jobs_for_execution(job, args_view, identity=identity)
+ # Otherwise, we're scheduling a single job and we'll submit it for execution.
+ # If the future is already done, i.e., the handle is already resolved, we don't need
+ # to submit the job again.
+ elif not future.done():
+ handle = await self._submit_jobs_for_execution(
+ job, dependency, args_view, identity=identity
  )
+ future.set_result(handle)

  # Wait for the job handle, this is either coming from scheduling the job array
- # or from the single job above.
- handle = await future
- self._ingest_handle(handle)
- self._ingest_launched_jobs(job, handle)
+ # or from a single job submission. If an existing handle was found, this will be
+ # a no-op.
+ await future

  @property
  def experiment_unit_name(self) -> str:
@@ -282,20 +465,19 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
  class SlurmAuxiliaryUnit(SlurmExperimentUnit):
  """An auxiliary unit operated by the Slurm backend."""

- def _ingest_handle(self, handle: execution.SlurmHandle) -> None:
- del handle
- console.print("[red]Auxiliary units do not currently support ingestion.[/red]")
-
- async def _launch_job_group(
+ async def _launch_job_group( # type: ignore
  self,
- job: xm.Job | xm.JobGroup,
- args_view: Mapping[str, Any],
+ job: xm.JobGroup,
+ args_view: tp.Mapping[str, JobArgs],
+ *,
+ dependency: dependencies.SlurmJobDependency | None,
  identity: str,
  ) -> None:
  _validate_job(job, args_view)

- slurm_handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
- self._ingest_handle(slurm_handle)
+ slurm_handle = await self._submit_jobs_for_execution(
+ job, dependency, args_view, identity=identity
+ )
  self._ingest_launched_jobs(job, slurm_handle)

  @property
@@ -365,7 +547,7 @@ class SlurmExperimentContextAnnotations:
  )

  @property
- def tags(self) -> MutableSet[str]:
+ def tags(self) -> tp.MutableSet[str]:
  return self._tags

  @tags.setter
@@ -432,75 +614,85 @@ class SlurmExperiment(xm.Experiment):
  )
  self._work_unit_count = 0

- @typing.overload
- def add(
+ @tp.overload
+ def add( # type: ignore
  self,
  job: xm.AuxiliaryUnitJob,
- args: Mapping[str, Any] | None = ...,
+ args: JobArgs | tp.Mapping[str, JobArgs] | None = ...,
  *,
- identity: str = "",
- ) -> asyncio.Future[SlurmExperimentUnit]: ...
+ dependency: dependencies.SlurmJobDependency | None = ...,
+ identity: str = ...,
+ ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...

- @typing.overload
+ @tp.overload
  def add(
  self,
- job: xm.JobType,
- args: Mapping[str, Any] | None = ...,
+ job: xm.JobGroup,
+ args: tp.Mapping[str, JobArgs] | None = ...,
  *,
- role: xm.WorkUnitRole = ...,
- identity: str = "",
+ role: xm.WorkUnitRole | None = ...,
+ dependency: dependencies.SlurmJobDependency | None = ...,
+ identity: str = ...,
  ) -> asyncio.Future[SlurmWorkUnit]: ...

- @typing.overload
+ @tp.overload
  def add(
  self,
- job: xm.JobType,
- args: Mapping[str, Any] | None,
+ job: xm.Job | xm.JobGeneratorType,
+ args: tp.Sequence[JobArgs],
  *,
- role: xm.ExperimentUnitRole,
- identity: str = "",
- ) -> asyncio.Future[SlurmExperimentUnit]: ...
-
- @typing.overload
+ role: xm.WorkUnitRole | None = ...,
+ dependency: dependencies.SlurmJobDependency
+ | tp.Sequence[dependencies.SlurmJobDependency]
+ | None = ...,
+ identity: str = ...,
+ ) -> asyncio.Future[tp.Sequence[SlurmWorkUnit]]: ...
+
+ @tp.overload
  def add(
  self,
- job: xm.JobType,
- args: Mapping[str, Any] | None = ...,
+ job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
+ args: JobArgs | None = ...,
  *,
- role: xm.ExperimentUnitRole,
- identity: str = "",
- ) -> asyncio.Future[SlurmExperimentUnit]: ...
+ role: xm.WorkUnitRole | None = ...,
+ dependency: dependencies.SlurmJobDependency | None = ...,
+ identity: str = ...,
+ ) -> asyncio.Future[SlurmWorkUnit]: ...

- @typing.overload
+ @tp.overload
  def add(
  self,
- job: xm.Job | xm.JobGeneratorType,
- args: Sequence[Mapping[str, Any]],
+ job: xm.JobType,
  *,
- role: xm.WorkUnitRole = ...,
- identity: str = "",
- ) -> asyncio.Future[Sequence[SlurmWorkUnit]]: ...
+ role: xm.AuxiliaryUnitRole,
+ dependency: dependencies.SlurmJobDependency | None = ...,
+ identity: str = ...,
+ ) -> asyncio.Future[SlurmAuxiliaryUnit]: ...

- def add(
+ def add( # type: ignore
  self,
  job: xm.JobType,
- args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None,
+ args: JobArgs
+ | tp.Mapping[str, JobArgs]
+ | tp.Sequence[tp.Mapping[str, tp.Any]]
+ | None = None,
  *,
- role: xm.ExperimentUnitRole = xm.WorkUnitRole(),
+ role: xm.ExperimentUnitRole | None = None,
+ dependency: dependencies.SlurmJobDependency
+ | tp.Sequence[dependencies.SlurmJobDependency]
+ | None = None,
  identity: str = "",
  ) -> (
- asyncio.Future[SlurmExperimentUnit]
+ asyncio.Future[SlurmAuxiliaryUnit]
  | asyncio.Future[SlurmWorkUnit]
- | asyncio.Future[Sequence[SlurmWorkUnit]]
+ | asyncio.Future[tp.Sequence[SlurmWorkUnit]]
  ):
+ if role is None:
+ role = xm.WorkUnitRole()
+
  if isinstance(args, collections.abc.Sequence):
  if not isinstance(role, xm.WorkUnitRole):
  raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
- if identity:
- raise ValueError(
- "Cannot set an identity on the root add call. "
- "Please use a job generator and set the identity within."
- )
  if isinstance(job, xm.JobGroup):
  raise ValueError(
  "Job arrays over `xm.JobGroup`s aren't supported. "
@@ -509,19 +701,79 @@ class SlurmExperiment(xm.Experiment):
  )
  assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"

+ # Validate job & args
+ for trial in args:
+ _validate_job(job, trial)
+ args = tp.cast(tp.Sequence[JobArgs], args)
+
  return asyncio.wrap_future(
- self._create_task(self._launch_job_array(job, args, role, identity))
+ self._create_task(self._launch_job_array(job, dependency, args, role, identity)),
+ loop=self._event_loop,
+ )
+ if not (isinstance(dependency, dependencies.SlurmJobDependency) or dependency is None):
+ raise ValueError("Invalid dependency type, expected a SlurmJobDependency or None")
+
+ if isinstance(job, xm.AuxiliaryUnitJob):
+ role = job.role
+ self._added_roles[type(role)] += 1
+
+ if self._should_reload_experiment_unit(role):
+ experiment_unit_future = self._get_experiment_unit(
+ self.experiment_id, identity, role, args
  )
  else:
- return super().add(job, args, role=role, identity=identity) # type: ignore
+ experiment_unit_future = self._create_experiment_unit(args, role, identity)
+
+ async def launch():
+ experiment_unit = await experiment_unit_future
+ try:
+ await experiment_unit.add(job, args, dependency=dependency, identity=identity)
+ except Exception as experiment_exception:
+ logger.error(
+ "Stopping experiment unit (identity %r) after it failed with: %s",
+ identity,
+ experiment_exception,
+ )
+ try:
+ if isinstance(job, xm.AuxiliaryUnitJob):
+ experiment_unit.stop()
+ else:
+ experiment_unit.stop(
+ mark_as_failed=True,
+ message=f"Work unit creation failed. {traceback.format_exc()}",
+ )
+ except Exception as stop_exception: # pylint: disable=broad-except
+ logger.error("Couldn't stop experiment unit: %s", stop_exception)
+ raise
+ return experiment_unit
+
+ async def reload():
+ experiment_unit = await experiment_unit_future
+ try:
+ await experiment_unit.add(job, args, dependency=dependency, identity=identity)
+ except Exception as update_exception:
+ logging.error(
+ "Could not reload the experiment unit: %s",
+ update_exception,
+ )
+ raise
+ return experiment_unit
+
+ return asyncio.wrap_future(
+ self._create_task(reload() if self._should_reload_experiment_unit(role) else launch()),
+ loop=self._event_loop,
+ )

  async def _launch_job_array(
  self,
  job: xm.Job | xm.JobGeneratorType,
- args: Sequence[Mapping[str, Any]],
+ dependency: dependencies.SlurmJobDependency
+ | tp.Sequence[dependencies.SlurmJobDependency]
+ | None,
+ args: tp.Sequence[JobArgs],
  role: xm.WorkUnitRole,
  identity: str = "",
- ) -> Sequence[SlurmWorkUnit]:
+ ) -> tp.Sequence[SlurmWorkUnit]:
  global _current_job_array_queue

  # Create our job array queue and assign it to the current context
@@ -531,10 +783,13 @@ class SlurmExperiment(xm.Experiment):
  # For each trial we'll schedule the job
  # and collect the futures
  wu_futures = []
- for trial in args:
- wu_futures.append(super().add(job, args=trial, role=role, identity=identity))
+ for idx, trial in enumerate(args):
+ wu_futures.append(
+ self.add(
+ job, args=trial, role=role, identity=f"{identity}_{idx}" if identity else ""
+ )
+ )

- # TODO(jfarebro): Set a timeout here
  # We'll wait until XManager has filled the queue.
  # There are two cases here, either we were given an xm.Job
  # in which case this will be trivial and filled immediately.
@@ -543,7 +798,8 @@ class SlurmExperiment(xm.Experiment):
  while not job_array_queue.full():
  await asyncio.sleep(0.1)

- # All jobs have been resolved
+ # All jobs have been resolved so now we'll perform sanity checks
+ # to make sure we can infer the sweep
  executable, executor, name = None, None, None
  resolved_args, resolved_env_vars, resolved_futures = [], [], []
  while not job_array_queue.empty():
@@ -551,9 +807,13 @@
  # that there's only a single job in this job group
  job_group_view, future = job_array_queue.get_nowait()
  assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
- _, job_view = job_group_view.jobs.popitem()
+ job_view = mit.one(
+ job_group_view.jobs.values(),
+ too_short=ValueError("Expected a single `xm.Job` in job group."),
+ too_long=ValueError("Only one `xm.Job` is supported for job arrays."),
+ )

- if len(job_group_view.jobs) != 0 or not isinstance(job_view, xm.Job):
+ if not isinstance(job_view, xm.Job):
  raise ValueError("Only `xm.Job` is supported for job arrays. ")

  if executable is None:
@@ -571,92 +831,223 @@ class SlurmExperiment(xm.Experiment):
  if job_view.name != name:
  raise RuntimeError("Found multiple names in job array")

- resolved_args.append(
- set(xm.SequentialArgs.from_collection(job_view.args).to_dict().items())
- )
+ resolved_args.append(xm.SequentialArgs.from_collection(job_view.args).to_list())
  resolved_env_vars.append(set(job_view.env_vars.items()))
  resolved_futures.append(future)
  assert executable is not None, "No executable found?"
  assert executor is not None, "No executor found?"
  assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
- assert executor.requirements.cluster is not None, "Cluster must be set on executor."
-
- common_args: set = functools.reduce(lambda a, b: a & b, resolved_args, set())
+ assert (
+ executor.requirements.cluster is not None
+ ), "Cluster must be specified on requirements."
+
+ # XManager merges job arguments with keyword arguments with job arguments
+ # coming first. These are the arguments that may be common across all jobs
+ # so we can find the largest common prefix and remove them from each job.
+ common_args: list[str] = list(mit.longest_common_prefix(resolved_args))
  common_env_vars: set = functools.reduce(lambda a, b: a & b, resolved_env_vars, set())

  sweep_args = [
- {
- "args": dict(a.difference(common_args)),
- "env_vars": dict(e.difference(common_env_vars)),
- }
+ JobArgs(
+ args=functools.reduce(
+ # Remove the common arguments from each job
+ lambda args, to_remove: args.remove_args(to_remove),
+ common_args,
+ xm.SequentialArgs.from_collection(a),
+ ),
+ env_vars=dict(e.difference(common_env_vars)),
+ )
  for a, e in zip(resolved_args, resolved_env_vars)
  ]

+ # Dependency resolution
+ resolved_dependency = None
+ resolved_dependency_task_id_order = None
+ # one-to-one
+ if isinstance(dependency, collections.abc.Sequence):
+ if len(dependency) != len(wu_futures):
+ raise ValueError("Dependency list must be the same length as the number of trials.")
+ assert len(dependency) > 0, "Dependency list must not be empty."
+
+ # Convert any SlurmJobDependencyAfterOK to SlurmJobArrayDependencyAfterOK
+ # for any array jobs.
+ def _maybe_convert_afterok(
+ dep: dependencies.SlurmJobDependency,
+ ) -> dependencies.SlurmJobDependency:
+ if isinstance(dep, dependencies.SlurmJobDependencyAfterOK) and all([
+ handle.slurm_job.is_array_job for handle in dep.handles
+ ]):
+ return dependencies.SlurmJobArrayDependencyAfterOK([
+ dataclasses.replace(
+ handle,
+ slurm_job=handle.slurm_job.array_job_id,
+ )
+ for handle in dep.handles
+ ])
+ return dep
+
+ dependencies_converted = [dep.traverse(_maybe_convert_afterok) for dep in dependency]
+ dependency_sets = [set(dep.flatten()) for dep in dependencies_converted]
+ dependency_differences = functools.reduce(set.difference, dependency_sets, set())
+ # There should be NO differences between the dependencies of each trial after conversion.
+ if len(dependency_differences) > 0:
+ raise ValueError(
+ f"Found variable dependencies across trials: {dependency_differences}. "
+ "Slurm job arrays require the same dependencies across all trials. "
+ )
+ resolved_dependency = dependencies_converted[0]
+
+ # This is slightly annoying but we need to re-sort the sweep arguments in case the dependencies were passed
+ # in a different order than 1, 2, ..., N as the Job array can only have correspondance with the same task id.
+ original_array_dependencies = [
+ mit.one(
+ filter(
+ lambda dep: isinstance(dep, dependencies.SlurmJobDependencyAfterOK)
+ and all([handle.slurm_job.is_array_job for handle in dep.handles]),
+ deps.flatten(),
+ )
+ )
+ for deps in dependency
+ ]
+ resolved_dependency_task_id_order = [
+ int(
+ mit.one(
+ functools.reduce(
+ set.difference,
+ [handle.slurm_job.array_task_id for handle in dep.handles], # type: ignore
+ )
+ )
+ )
+ for dep in original_array_dependencies
+ ]
+ assert len(resolved_dependency_task_id_order) == len(sweep_args)
+ assert set(resolved_dependency_task_id_order) == set(range(len(sweep_args))), (
+ "Dependent job array tasks should have task ids 0, 1, ..., N. "
+ f"Found: {resolved_dependency_task_id_order}"
+ )
+ # one-to-many
+ elif isinstance(dependency, dependencies.SlurmJobDependency):
+ resolved_dependency = dependency
+ assert resolved_dependency is None or isinstance(
+ resolved_dependency, dependencies.SlurmJobDependency
+ ), "Invalid dependency type"
+
  # No support for sweep_env_vars right now.
  # We schedule the job array and then we'll resolve all the work units with
  # the handles Slurm gives back to us.
- try:
- handles = await execution.get_client().launch(
- cluster=executor.requirements.cluster,
- job=xm.Job(
- executable=executable,
- executor=executor,
- name=name,
- args=dict(common_args),
- env_vars=dict(common_env_vars),
- ),
- args=sweep_args,
- experiment_id=self.experiment_id,
- identity=identity,
+ # If we already have handles for all the work units, we don't need to submit the
+ # job array to SLURM.
+ num_resolved_handles = sum(future.done() for future in resolved_futures)
+ if num_resolved_handles == 0:
+ try:
+ handles = await execution.launch(
+ job=xm.Job(
+ executable=executable,
+ executor=executor,
+ name=name,
+ args=xm.SequentialArgs.from_collection(common_args),
+ env_vars=dict(common_env_vars),
+ ),
+ dependency=resolved_dependency,
+ args=[
+ sweep_args[resolved_dependency_task_id_order.index(i)]
+ for i in range(len(sweep_args))
+ ]
+ if resolved_dependency_task_id_order
+ else sweep_args,
+ experiment_id=self.experiment_id,
+ identity=identity,
+ )
+ if resolved_dependency_task_id_order:
+ handles = [handles[i] for i in resolved_dependency_task_id_order]
+ except Exception as e:
+ for future in resolved_futures:
+ future.set_exception(e)
+ raise
+ else:
+ for handle, future in zip(handles, resolved_futures):
+ future.set_result(handle)
+ elif 0 < num_resolved_handles < len(resolved_futures):
+ raise RuntimeError(
+ "Some array job elements have handles, but some don't. This shouldn't happen."
  )
- except Exception as e:
- for future in resolved_futures:
- future.set_exception(e)
- raise
- else:
- for handle, future in zip(handles, resolved_futures):
- future.set_result(handle)

  wus = await asyncio.gather(*wu_futures)
+
  _current_job_array_queue.set(None)
  return wus

- def _create_experiment_unit(
+ def _get_work_unit_by_identity(self, identity: str) -> SlurmWorkUnit | None:
+ if identity == "":
+ return None
+ for unit in self._experiment_units:
+ if isinstance(unit, SlurmWorkUnit) and unit.identity == identity:
+ return unit
+ return None
+
+ def _create_experiment_unit( # type: ignore
  self,
- args: Mapping[str, Any],
+ args: JobArgs | tp.Mapping[str, JobArgs] | None,
  role: xm.ExperimentUnitRole,
  identity: str,
- ) -> Awaitable[SlurmWorkUnit]:
- del identity
-
- def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
+ ) -> tp.Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
+ def _create_work_unit(role: xm.WorkUnitRole) -> tp.Awaitable[SlurmWorkUnit]:
  work_unit = SlurmWorkUnit(
  self,
  self._create_task,
  args,
  role,
  self._work_unit_id_predictor,
+ identity=identity,
  )
  self._experiment_units.append(work_unit)
  self._work_unit_count += 1

- future = asyncio.Future()
+ api.client().insert_work_unit(
+ self.experiment_id,
+ api.WorkUnitModel(
+ wid=work_unit.work_unit_id,
+ identity=work_unit.identity,
+ args=json.dumps(args),
+ ),
+ )
+
+ future = asyncio.Future[SlurmWorkUnit]()
  future.set_result(work_unit)
  return future

+ def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> tp.Awaitable[SlurmAuxiliaryUnit]:
+ auxiliary_unit = SlurmAuxiliaryUnit(
+ self,
+ self._create_task,
+ args,
+ role,
+ identity=identity,
+ )
+ self._experiment_units.append(auxiliary_unit)
+ future = asyncio.Future[SlurmAuxiliaryUnit]()
+ future.set_result(auxiliary_unit)
+ return future
+
  match role:
  case xm.WorkUnitRole():
+ if (existing_unit := self._get_work_unit_by_identity(identity)) is not None:
+ future = asyncio.Future[SlurmWorkUnit]()
+ future.set_result(existing_unit)
+ return future
  return _create_work_unit(role)
+ case xm.AuxiliaryUnitRole():
+ return _create_auxiliary_unit(role)
  case _:
  raise ValueError(f"Unsupported role {role}")

- def _get_experiment_unit(
+ def _get_experiment_unit( # type: ignore
  self,
  experiment_id: int,
  identity: str,
  role: xm.ExperimentUnitRole,
- args: Mapping[str, Any] | None = None,
- ) -> Awaitable[xm.ExperimentUnit]:
+ args: JobArgs | tp.Mapping[str, JobArgs] | None = None,
+ ) -> tp.Awaitable[SlurmExperimentUnit]:
  del experiment_id, identity, role, args
  raise NotImplementedError

@@ -672,7 +1063,7 @@ class SlurmExperiment(xm.Experiment):
  # If no work units were added, delete this experiment
  # This is to prevent empty experiments from being persisted
  # and cluttering the database.
- if self.work_unit_count == 0:
+ if len(self._experiment_units) == 0:
  console.print(
  f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
  )
@@ -689,15 +1080,14 @@ class SlurmExperiment(xm.Experiment):
  return self.context.annotations.title

  @property
- def context(self) -> SlurmExperimentMetadataContext:
+ def context(self) -> SlurmExperimentMetadataContext: # type: ignore
  return self._experiment_context

  @property
  def work_unit_count(self) -> int:
  return self._work_unit_count

- @property
- def work_units(self) -> Mapping[int, SlurmWorkUnit]:
+ def work_units(self) -> dict[int, SlurmWorkUnit]:
  """Gets work units created via self.add()."""
  return {
  wu.work_unit_id: wu for wu in self._experiment_units if isinstance(wu, SlurmWorkUnit)
@@ -716,8 +1106,51 @@ def create_experiment(experiment_title: str) -> SlurmExperiment:
  def get_experiment(experiment_id: int) -> SlurmExperiment:
  """Get Experiment."""
  experiment_model = api.client().get_experiment(experiment_id)
- # TODO(jfarebro): Fill in jobs and work units and annotations
- return SlurmExperiment(experiment_title=experiment_model.title, experiment_id=experiment_id)
+ experiment = SlurmExperiment(
+ experiment_title=experiment_model.title, experiment_id=experiment_id
+ )
+ experiment._work_unit_id_predictor = id_predictor.Predictor(1)
+
+ # Populate annotations
+ experiment.context.annotations.description = experiment_model.description or ""
+ experiment.context.annotations.note = experiment_model.note or ""
+ experiment.context.annotations.tags = set(experiment_model.tags or [])
+
+ # Populate artifacts
+ for artifact in experiment_model.artifacts:
+ experiment.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
+
+ # Populate work units
+ for wu_model in experiment_model.work_units:
+ work_unit = SlurmWorkUnit(
+ experiment=experiment,
+ create_task=experiment._create_task,
+ args=json.loads(wu_model.args) if wu_model.args else {},
+ role=xm.WorkUnitRole(),
+ identity=wu_model.identity or "",
+ work_unit_id_predictor=experiment._work_unit_id_predictor,
+ )
+ work_unit._work_unit_id = wu_model.wid
+
+ # Populate jobs for each work unit
+ for job_model in wu_model.jobs:
+ slurm_ssh_config = config.SlurmSSHConfig.deserialize(job_model.slurm_ssh_config)
+ handle = execution.SlurmHandle(
+ experiment_id=experiment_id,
+ ssh=slurm_ssh_config,
+ slurm_job=str(job_model.slurm_job_id),
+ job_name=job_model.name,
+ )
+ work_unit._execution_handles.append(handle)
+
+ # Populate artifacts for each work unit
+ for artifact in wu_model.artifacts:
+ work_unit.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
+
+ experiment._experiment_units.append(work_unit)
+ experiment._work_unit_count += 1
+
+ return experiment


  @functools.cache
@@ -733,5 +1166,5 @@ def get_current_work_unit() -> SlurmWorkUnit | None:
  wid := os.environ.get("XM_SLURM_WORK_UNIT_ID")
  ):
  experiment = get_experiment(int(xid))
- return experiment.work_units[int(wid)]
+ return experiment.work_units()[int(wid)]
  return None
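
The most user-visible additions in this file are the `dependency=` parameter on `SlurmExperiment.add` / `SlurmExperimentUnit.add` and the `after_started` / `after_finished` / `after_completed` / `after_failed` helpers on experiment units. A minimal launch sketch of how these pieces might fit together, based only on the signatures shown in this diff; the `experiment`, executables, and `slurm_executor` objects are assumed to be set up elsewhere and are not part of this file:

from xmanager import xm


async def launch(experiment, train_executable, eval_executable, slurm_executor) -> None:
    # Submit a sweep as one Slurm job array: the sequence-of-JobArgs overload of
    # `add` returns one SlurmWorkUnit per trial.
    train_units = await experiment.add(
        xm.Job(executable=train_executable, executor=slurm_executor, name="train"),
        args=[{"args": {"seed": seed}} for seed in range(4)],
    )

    # Chain an evaluation job that Slurm releases only after the first training
    # unit exits successfully (afterok), via the new dependency machinery.
    await experiment.add(
        xm.Job(executable=eval_executable, executor=slurm_executor, name="eval"),
        args={"args": {"seed": 0}},
        dependency=train_units[0].after_completed(),
    )

Per the dependency-resolution logic above, a job array can also take a sequence of dependencies, one per trial, but after conversion they must all resolve to the same underlying array job and their task ids must cover 0, 1, ..., N.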