xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xmanager-slurm might be problematic. Click here for more details.
- xm_slurm/__init__.py +4 -2
- xm_slurm/api.py +301 -34
- xm_slurm/batching.py +4 -4
- xm_slurm/config.py +99 -54
- xm_slurm/constants.py +15 -0
- xm_slurm/contrib/__init__.py +0 -0
- xm_slurm/contrib/clusters/__init__.py +22 -13
- xm_slurm/contrib/clusters/drac.py +34 -16
- xm_slurm/executables.py +15 -8
- xm_slurm/execution.py +86 -38
- xm_slurm/experiment.py +273 -131
- xm_slurm/experimental/parameter_controller.py +200 -0
- xm_slurm/job_blocks.py +7 -0
- xm_slurm/packageables.py +42 -20
- xm_slurm/packaging/docker/__init__.py +5 -11
- xm_slurm/packaging/docker/local.py +13 -12
- xm_slurm/packaging/utils.py +7 -55
- xm_slurm/resources.py +28 -4
- xm_slurm/scripts/_cloudpickle.py +28 -0
- xm_slurm/status.py +9 -0
- xm_slurm/templates/docker/mamba.Dockerfile +3 -1
- xm_slurm/templates/docker/python.Dockerfile +18 -10
- xm_slurm/templates/docker/uv.Dockerfile +35 -0
- xm_slurm/utils.py +18 -10
- xmanager_slurm-0.4.0.dist-info/METADATA +26 -0
- xmanager_slurm-0.4.0.dist-info/RECORD +42 -0
- {xmanager_slurm-0.3.2.dist-info → xmanager_slurm-0.4.0.dist-info}/WHEEL +1 -1
- xmanager_slurm-0.4.0.dist-info/licenses/LICENSE.md +227 -0
- xm_slurm/packaging/docker/cloud.py +0 -503
- xm_slurm/templates/docker/pdm.Dockerfile +0 -31
- xmanager_slurm-0.3.2.dist-info/METADATA +0 -25
- xmanager_slurm-0.3.2.dist-info/RECORD +0 -38
xm_slurm/experiment.py
CHANGED
|
@@ -10,11 +10,13 @@ import typing
|
|
|
10
10
|
from concurrent import futures
|
|
11
11
|
from typing import Any, Awaitable, Callable, Mapping, MutableSet, Sequence
|
|
12
12
|
|
|
13
|
+
import more_itertools as mit
|
|
13
14
|
from xmanager import xm
|
|
14
15
|
from xmanager.xm import async_packager, id_predictor
|
|
15
16
|
|
|
16
|
-
from xm_slurm import api, execution, executors
|
|
17
|
+
from xm_slurm import api, config, execution, executors
|
|
17
18
|
from xm_slurm.console import console
|
|
19
|
+
from xm_slurm.job_blocks import JobArgs
|
|
18
20
|
from xm_slurm.packaging import router
|
|
19
21
|
from xm_slurm.status import SlurmWorkUnitStatus
|
|
20
22
|
from xm_slurm.utils import UserSet
|
|
@@ -26,7 +28,7 @@ _current_job_array_queue = contextvars.ContextVar[
|
|
|
26
28
|
|
|
27
29
|
def _validate_job(
|
|
28
30
|
job: xm.JobType,
|
|
29
|
-
args_view: Mapping[str,
|
|
31
|
+
args_view: JobArgs | Mapping[str, JobArgs],
|
|
30
32
|
) -> None:
|
|
31
33
|
if not args_view:
|
|
32
34
|
return
|
|
@@ -49,7 +51,7 @@ def _validate_job(
|
|
|
49
51
|
)
|
|
50
52
|
|
|
51
53
|
if isinstance(job, xm.JobGroup) and key in job.jobs:
|
|
52
|
-
_validate_job(job.jobs[key], expanded)
|
|
54
|
+
_validate_job(job.jobs[key], typing.cast(JobArgs, expanded))
|
|
53
55
|
elif key not in allowed_keys:
|
|
54
56
|
raise ValueError(f"Only `args` and `env_vars` are supported for args on job {job!r}.")
|
|
55
57
|
|
|
@@ -117,28 +119,40 @@ class SlurmExperimentUnitMetadataContext:
|
|
|
117
119
|
class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
118
120
|
"""ExperimentUnit is a collection of semantically associated `Job`s."""
|
|
119
121
|
|
|
120
|
-
experiment: "SlurmExperiment"
|
|
122
|
+
experiment: "SlurmExperiment" # type: ignore
|
|
121
123
|
|
|
122
124
|
def __init__(
|
|
123
125
|
self,
|
|
124
126
|
experiment: xm.Experiment,
|
|
125
127
|
create_task: Callable[[Awaitable[Any]], futures.Future[Any]],
|
|
126
|
-
args:
|
|
128
|
+
args: JobArgs | None,
|
|
127
129
|
role: xm.ExperimentUnitRole,
|
|
130
|
+
identity: str = "",
|
|
128
131
|
) -> None:
|
|
129
|
-
super().__init__(experiment, create_task, args, role)
|
|
132
|
+
super().__init__(experiment, create_task, args, role, identity=identity)
|
|
130
133
|
self._launched_jobs: list[xm.LaunchedJob] = []
|
|
131
134
|
self._execution_handles: list[execution.SlurmHandle] = []
|
|
132
135
|
self._context = SlurmExperimentUnitMetadataContext(
|
|
133
136
|
artifacts=ContextArtifacts(owner=self, artifacts=[]),
|
|
134
137
|
)
|
|
135
138
|
|
|
139
|
+
@typing.overload
|
|
140
|
+
async def _submit_jobs_for_execution(
|
|
141
|
+
self,
|
|
142
|
+
job: xm.Job,
|
|
143
|
+
args_view: JobArgs,
|
|
144
|
+
identity: str | None = ...,
|
|
145
|
+
) -> execution.SlurmHandle: ...
|
|
146
|
+
|
|
147
|
+
@typing.overload
|
|
136
148
|
async def _submit_jobs_for_execution(
|
|
137
149
|
self,
|
|
138
|
-
job: xm.
|
|
139
|
-
args_view: Mapping[str,
|
|
140
|
-
identity: str | None =
|
|
141
|
-
) -> execution.SlurmHandle:
|
|
150
|
+
job: xm.JobGroup,
|
|
151
|
+
args_view: Mapping[str, JobArgs],
|
|
152
|
+
identity: str | None = ...,
|
|
153
|
+
) -> execution.SlurmHandle: ...
|
|
154
|
+
|
|
155
|
+
async def _submit_jobs_for_execution(self, job, args_view, identity=None):
|
|
142
156
|
return await execution.launch(
|
|
143
157
|
job=job,
|
|
144
158
|
args=args_view,
|
|
@@ -147,9 +161,28 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
147
161
|
)
|
|
148
162
|
|
|
149
163
|
def _ingest_launched_jobs(self, job: xm.JobType, handle: execution.SlurmHandle) -> None:
|
|
164
|
+
self._execution_handles.append(handle)
|
|
165
|
+
|
|
166
|
+
def _ingest_job(job: xm.Job) -> None:
|
|
167
|
+
if not isinstance(self._role, xm.WorkUnitRole):
|
|
168
|
+
return
|
|
169
|
+
assert isinstance(self, SlurmWorkUnit)
|
|
170
|
+
assert job.name is not None
|
|
171
|
+
api.client().insert_job(
|
|
172
|
+
self.experiment_id,
|
|
173
|
+
self.work_unit_id,
|
|
174
|
+
api.SlurmJobModel(
|
|
175
|
+
name=job.name,
|
|
176
|
+
slurm_job_id=handle.job_id, # type: ignore
|
|
177
|
+
slurm_ssh_config=handle.ssh.serialize(),
|
|
178
|
+
),
|
|
179
|
+
)
|
|
180
|
+
|
|
150
181
|
match job:
|
|
151
182
|
case xm.JobGroup() as job_group:
|
|
152
183
|
for job in job_group.jobs.values():
|
|
184
|
+
assert isinstance(job, xm.Job)
|
|
185
|
+
_ingest_job(job)
|
|
153
186
|
self._launched_jobs.append(
|
|
154
187
|
xm.LaunchedJob(
|
|
155
188
|
name=job.name, # type: ignore
|
|
@@ -157,6 +190,7 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
157
190
|
)
|
|
158
191
|
)
|
|
159
192
|
case xm.Job():
|
|
193
|
+
_ingest_job(job)
|
|
160
194
|
self._launched_jobs.append(
|
|
161
195
|
xm.LaunchedJob(
|
|
162
196
|
name=handle.job.name, # type: ignore
|
|
@@ -191,11 +225,12 @@ class SlurmExperimentUnit(xm.ExperimentUnit):
|
|
|
191
225
|
states = await asyncio.gather(*[handle.get_state() for handle in self._execution_handles])
|
|
192
226
|
return SlurmWorkUnitStatus.aggregate(states)
|
|
193
227
|
|
|
228
|
+
@property
|
|
194
229
|
def launched_jobs(self) -> list[xm.LaunchedJob]:
|
|
195
230
|
return self._launched_jobs
|
|
196
231
|
|
|
197
232
|
@property
|
|
198
|
-
def context(self) -> SlurmExperimentUnitMetadataContext:
|
|
233
|
+
def context(self) -> SlurmExperimentUnitMetadataContext: # type: ignore
|
|
199
234
|
return self._context
|
|
200
235
|
|
|
201
236
|
|
|
@@ -204,68 +239,65 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
|
|
|
204
239
|
self,
|
|
205
240
|
experiment: "SlurmExperiment",
|
|
206
241
|
create_task: Callable[[Awaitable[Any]], futures.Future],
|
|
207
|
-
args:
|
|
242
|
+
args: JobArgs,
|
|
208
243
|
role: xm.ExperimentUnitRole,
|
|
209
244
|
work_unit_id_predictor: id_predictor.Predictor,
|
|
245
|
+
identity: str = "",
|
|
210
246
|
) -> None:
|
|
211
|
-
super().__init__(experiment, create_task, args, role)
|
|
247
|
+
super().__init__(experiment, create_task, args, role, identity=identity)
|
|
212
248
|
self._work_unit_id_predictor = work_unit_id_predictor
|
|
213
249
|
self._work_unit_id = self._work_unit_id_predictor.reserve_id()
|
|
214
|
-
api.client().insert_work_unit(
|
|
215
|
-
self.experiment_id,
|
|
216
|
-
api.WorkUnitPatchModel(
|
|
217
|
-
wid=self.work_unit_id,
|
|
218
|
-
identity=self.identity,
|
|
219
|
-
args=json.dumps(args),
|
|
220
|
-
),
|
|
221
|
-
)
|
|
222
250
|
|
|
223
|
-
def
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
name=self.experiment_unit_name,
|
|
230
|
-
slurm_job_id=handle.job_id, # type: ignore
|
|
231
|
-
slurm_cluster=json.dumps({
|
|
232
|
-
"host": handle.ssh_connection_options.host,
|
|
233
|
-
"username": handle.ssh_connection_options.username,
|
|
234
|
-
"port": handle.ssh_connection_options.port,
|
|
235
|
-
"config": handle.ssh_connection_options.config.get_options(False),
|
|
236
|
-
}),
|
|
237
|
-
),
|
|
238
|
-
)
|
|
251
|
+
def _get_existing_handle(self, job: xm.JobGroup) -> execution.SlurmHandle | None:
|
|
252
|
+
job_name = mit.one(job.jobs.keys())
|
|
253
|
+
for handle in self._execution_handles:
|
|
254
|
+
if handle.job_name == job_name:
|
|
255
|
+
return handle
|
|
256
|
+
return None
|
|
239
257
|
|
|
240
|
-
async def _launch_job_group(
|
|
258
|
+
async def _launch_job_group( # type: ignore
|
|
241
259
|
self,
|
|
242
260
|
job: xm.JobGroup,
|
|
243
|
-
args_view: Mapping[str,
|
|
261
|
+
args_view: Mapping[str, JobArgs],
|
|
244
262
|
identity: str,
|
|
245
263
|
) -> None:
|
|
246
264
|
global _current_job_array_queue
|
|
247
265
|
_validate_job(job, args_view)
|
|
266
|
+
future = asyncio.Future[execution.SlurmHandle]()
|
|
267
|
+
|
|
268
|
+
# If we already have a handle for this job, we don't need to submit it again.
|
|
269
|
+
# We'll just resolve the future with the existing handle.
|
|
270
|
+
# Otherwise we'll add callbacks to ingest the handle and the launched jobs.
|
|
271
|
+
if existing_handle := self._get_existing_handle(job):
|
|
272
|
+
future.set_result(existing_handle)
|
|
273
|
+
else:
|
|
274
|
+
future.add_done_callback(
|
|
275
|
+
lambda handle: self._ingest_launched_jobs(job, handle.result())
|
|
276
|
+
)
|
|
277
|
+
|
|
278
|
+
api.client().update_work_unit(
|
|
279
|
+
self.experiment_id,
|
|
280
|
+
self.work_unit_id,
|
|
281
|
+
api.ExperimentUnitPatchModel(args=json.dumps(args_view), identity=None),
|
|
282
|
+
)
|
|
248
283
|
|
|
249
|
-
future = asyncio.Future()
|
|
250
284
|
async with self._work_unit_id_predictor.submit_id(self.work_unit_id): # type: ignore
|
|
251
285
|
# If we're scheduling as part of a job queue (i.e., the queue is set on the context)
|
|
252
|
-
# then we'll insert the job and
|
|
253
|
-
#
|
|
286
|
+
# then we'll insert the job and future that'll get resolved to the proper handle
|
|
287
|
+
# when the Slurm job array is scheduled.
|
|
254
288
|
if job_array_queue := _current_job_array_queue.get():
|
|
255
289
|
job_array_queue.put_nowait((job, future))
|
|
256
|
-
# Otherwise we'
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
)
|
|
290
|
+
# Otherwise, we're scheduling a single job and we'll submit it for execution.
|
|
291
|
+
# If the future is already done, i.e., the handle is already resolved, we don't need
|
|
292
|
+
# to submit the job again.
|
|
293
|
+
elif not future.done():
|
|
294
|
+
handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
|
|
295
|
+
future.set_result(handle)
|
|
263
296
|
|
|
264
297
|
# Wait for the job handle, this is either coming from scheduling the job array
|
|
265
|
-
# or from
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
self._ingest_launched_jobs(job, handle)
|
|
298
|
+
# or from a single job submission. If an existing handle was found, this will be
|
|
299
|
+
# a no-op.
|
|
300
|
+
await future
|
|
269
301
|
|
|
270
302
|
@property
|
|
271
303
|
def experiment_unit_name(self) -> str:
|
|
@@ -282,20 +314,15 @@ class SlurmWorkUnit(xm.WorkUnit, SlurmExperimentUnit):
|
|
|
282
314
|
class SlurmAuxiliaryUnit(SlurmExperimentUnit):
|
|
283
315
|
"""An auxiliary unit operated by the Slurm backend."""
|
|
284
316
|
|
|
285
|
-
def
|
|
286
|
-
del handle
|
|
287
|
-
console.print("[red]Auxiliary units do not currently support ingestion.[/red]")
|
|
288
|
-
|
|
289
|
-
async def _launch_job_group(
|
|
317
|
+
async def _launch_job_group( # type: ignore
|
|
290
318
|
self,
|
|
291
|
-
job: xm.
|
|
292
|
-
args_view: Mapping[str,
|
|
319
|
+
job: xm.JobGroup,
|
|
320
|
+
args_view: Mapping[str, JobArgs],
|
|
293
321
|
identity: str,
|
|
294
322
|
) -> None:
|
|
295
323
|
_validate_job(job, args_view)
|
|
296
324
|
|
|
297
325
|
slurm_handle = await self._submit_jobs_for_execution(job, args_view, identity=identity)
|
|
298
|
-
self._ingest_handle(slurm_handle)
|
|
299
326
|
self._ingest_launched_jobs(job, slurm_handle)
|
|
300
327
|
|
|
301
328
|
@property
|
|
@@ -438,37 +465,47 @@ class SlurmExperiment(xm.Experiment):
|
|
|
438
465
|
job: xm.AuxiliaryUnitJob,
|
|
439
466
|
args: Mapping[str, Any] | None = ...,
|
|
440
467
|
*,
|
|
441
|
-
identity: str =
|
|
442
|
-
) -> asyncio.Future[
|
|
468
|
+
identity: str = ...,
|
|
469
|
+
) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
|
|
443
470
|
|
|
444
471
|
@typing.overload
|
|
445
472
|
def add(
|
|
446
473
|
self,
|
|
447
|
-
job: xm.
|
|
448
|
-
args: Mapping[str, Any] | None
|
|
474
|
+
job: xm.JobGroup,
|
|
475
|
+
args: Mapping[str, Mapping[str, Any]] | None,
|
|
449
476
|
*,
|
|
450
|
-
role: xm.WorkUnitRole =
|
|
451
|
-
identity: str =
|
|
477
|
+
role: xm.WorkUnitRole = xm.WorkUnitRole(),
|
|
478
|
+
identity: str = ...,
|
|
452
479
|
) -> asyncio.Future[SlurmWorkUnit]: ...
|
|
453
480
|
|
|
454
481
|
@typing.overload
|
|
455
482
|
def add(
|
|
456
483
|
self,
|
|
457
|
-
job: xm.
|
|
458
|
-
args: Mapping[str, Any] | None,
|
|
484
|
+
job: xm.JobGroup,
|
|
485
|
+
args: Mapping[str, Mapping[str, Any]] | None,
|
|
459
486
|
*,
|
|
460
487
|
role: xm.ExperimentUnitRole,
|
|
461
|
-
identity: str =
|
|
488
|
+
identity: str = ...,
|
|
462
489
|
) -> asyncio.Future[SlurmExperimentUnit]: ...
|
|
463
490
|
|
|
464
491
|
@typing.overload
|
|
465
492
|
def add(
|
|
466
493
|
self,
|
|
467
|
-
job: xm.
|
|
468
|
-
args: Mapping[str, Any] | None
|
|
494
|
+
job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
|
|
495
|
+
args: Mapping[str, Any] | None,
|
|
496
|
+
*,
|
|
497
|
+
role: xm.WorkUnitRole = xm.WorkUnitRole(),
|
|
498
|
+
identity: str = ...,
|
|
499
|
+
) -> asyncio.Future[SlurmWorkUnit]: ...
|
|
500
|
+
|
|
501
|
+
@typing.overload
|
|
502
|
+
def add(
|
|
503
|
+
self,
|
|
504
|
+
job: xm.Job | xm.JobGeneratorType | xm.JobConfig,
|
|
505
|
+
args: Mapping[str, Any] | None,
|
|
469
506
|
*,
|
|
470
507
|
role: xm.ExperimentUnitRole,
|
|
471
|
-
identity: str =
|
|
508
|
+
identity: str = ...,
|
|
472
509
|
) -> asyncio.Future[SlurmExperimentUnit]: ...
|
|
473
510
|
|
|
474
511
|
@typing.overload
|
|
@@ -477,11 +514,20 @@ class SlurmExperiment(xm.Experiment):
|
|
|
477
514
|
job: xm.Job | xm.JobGeneratorType,
|
|
478
515
|
args: Sequence[Mapping[str, Any]],
|
|
479
516
|
*,
|
|
480
|
-
role: xm.WorkUnitRole =
|
|
481
|
-
identity: str =
|
|
517
|
+
role: xm.WorkUnitRole = xm.WorkUnitRole(),
|
|
518
|
+
identity: str = ...,
|
|
482
519
|
) -> asyncio.Future[Sequence[SlurmWorkUnit]]: ...
|
|
483
520
|
|
|
521
|
+
@typing.overload
|
|
484
522
|
def add(
|
|
523
|
+
self,
|
|
524
|
+
job: xm.JobType,
|
|
525
|
+
*,
|
|
526
|
+
role: xm.AuxiliaryUnitRole = ...,
|
|
527
|
+
identity: str = ...,
|
|
528
|
+
) -> asyncio.Future[SlurmAuxiliaryUnit]: ...
|
|
529
|
+
|
|
530
|
+
def add( # type: ignore
|
|
485
531
|
self,
|
|
486
532
|
job: xm.JobType,
|
|
487
533
|
args: Mapping[str, Any] | Sequence[Mapping[str, Any]] | None = None,
|
|
@@ -489,18 +535,14 @@ class SlurmExperiment(xm.Experiment):
|
|
|
489
535
|
role: xm.ExperimentUnitRole = xm.WorkUnitRole(),
|
|
490
536
|
identity: str = "",
|
|
491
537
|
) -> (
|
|
492
|
-
asyncio.Future[
|
|
538
|
+
asyncio.Future[SlurmAuxiliaryUnit]
|
|
539
|
+
| asyncio.Future[SlurmExperimentUnit]
|
|
493
540
|
| asyncio.Future[SlurmWorkUnit]
|
|
494
541
|
| asyncio.Future[Sequence[SlurmWorkUnit]]
|
|
495
542
|
):
|
|
496
543
|
if isinstance(args, collections.abc.Sequence):
|
|
497
544
|
if not isinstance(role, xm.WorkUnitRole):
|
|
498
545
|
raise ValueError("Only `xm.WorkUnit`s are supported for job arrays.")
|
|
499
|
-
if identity:
|
|
500
|
-
raise ValueError(
|
|
501
|
-
"Cannot set an identity on the root add call. "
|
|
502
|
-
"Please use a job generator and set the identity within."
|
|
503
|
-
)
|
|
504
546
|
if isinstance(job, xm.JobGroup):
|
|
505
547
|
raise ValueError(
|
|
506
548
|
"Job arrays over `xm.JobGroup`s aren't supported. "
|
|
@@ -509,6 +551,11 @@ class SlurmExperiment(xm.Experiment):
|
|
|
509
551
|
)
|
|
510
552
|
assert isinstance(job, xm.Job) or inspect.iscoroutinefunction(job), "Invalid job type"
|
|
511
553
|
|
|
554
|
+
# Validate job & args
|
|
555
|
+
for trial in args:
|
|
556
|
+
_validate_job(job, trial)
|
|
557
|
+
args = typing.cast(Sequence[JobArgs], args)
|
|
558
|
+
|
|
512
559
|
return asyncio.wrap_future(
|
|
513
560
|
self._create_task(self._launch_job_array(job, args, role, identity))
|
|
514
561
|
)
|
|
@@ -518,7 +565,7 @@ class SlurmExperiment(xm.Experiment):
|
|
|
518
565
|
async def _launch_job_array(
|
|
519
566
|
self,
|
|
520
567
|
job: xm.Job | xm.JobGeneratorType,
|
|
521
|
-
args: Sequence[
|
|
568
|
+
args: Sequence[JobArgs],
|
|
522
569
|
role: xm.WorkUnitRole,
|
|
523
570
|
identity: str = "",
|
|
524
571
|
) -> Sequence[SlurmWorkUnit]:
|
|
@@ -531,10 +578,9 @@ class SlurmExperiment(xm.Experiment):
|
|
|
531
578
|
# For each trial we'll schedule the job
|
|
532
579
|
# and collect the futures
|
|
533
580
|
wu_futures = []
|
|
534
|
-
for trial in args:
|
|
535
|
-
wu_futures.append(super().add(job, args=trial, role=role, identity=identity))
|
|
581
|
+
for idx, trial in enumerate(args):
|
|
582
|
+
wu_futures.append(super().add(job, args=trial, role=role, identity=f"{identity}_{idx}"))
|
|
536
583
|
|
|
537
|
-
# TODO(jfarebro): Set a timeout here
|
|
538
584
|
# We'll wait until XManager has filled the queue.
|
|
539
585
|
# There are two cases here, either we were given an xm.Job
|
|
540
586
|
# in which case this will be trivial and filled immediately.
|
|
@@ -551,9 +597,13 @@ class SlurmExperiment(xm.Experiment):
|
|
|
551
597
|
# that there's only a single job in this job group
|
|
552
598
|
job_group_view, future = job_array_queue.get_nowait()
|
|
553
599
|
assert isinstance(job_group_view, xm.JobGroup), "Expected a job group from xm"
|
|
554
|
-
|
|
600
|
+
job_view = mit.one(
|
|
601
|
+
job_group_view.jobs.values(),
|
|
602
|
+
too_short=ValueError("Expected a single `xm.Job` in job group."),
|
|
603
|
+
too_long=ValueError("Only one `xm.Job` is supported for job arrays."),
|
|
604
|
+
)
|
|
555
605
|
|
|
556
|
-
if
|
|
606
|
+
if not isinstance(job_view, xm.Job):
|
|
557
607
|
raise ValueError("Only `xm.Job` is supported for job arrays. ")
|
|
558
608
|
|
|
559
609
|
if executable is None:
|
|
@@ -571,64 +621,86 @@ class SlurmExperiment(xm.Experiment):
|
|
|
571
621
|
if job_view.name != name:
|
|
572
622
|
raise RuntimeError("Found multiple names in job array")
|
|
573
623
|
|
|
574
|
-
resolved_args.append(
|
|
575
|
-
set(xm.SequentialArgs.from_collection(job_view.args).to_dict().items())
|
|
576
|
-
)
|
|
624
|
+
resolved_args.append(xm.SequentialArgs.from_collection(job_view.args).to_list())
|
|
577
625
|
resolved_env_vars.append(set(job_view.env_vars.items()))
|
|
578
626
|
resolved_futures.append(future)
|
|
579
627
|
assert executable is not None, "No executable found?"
|
|
580
628
|
assert executor is not None, "No executor found?"
|
|
581
629
|
assert isinstance(executor, executors.Slurm), "Only Slurm executors are supported."
|
|
582
|
-
assert
|
|
583
|
-
|
|
584
|
-
|
|
630
|
+
assert (
|
|
631
|
+
executor.requirements.cluster is not None
|
|
632
|
+
), "Cluster must be specified on requirements."
|
|
633
|
+
|
|
634
|
+
# XManager merges job arguments with keyword arguments with job arguments
|
|
635
|
+
# coming first. These are the arguments that may be common across all jobs
|
|
636
|
+
# so we can find the largest common prefix and remove them from each job.
|
|
637
|
+
common_args: list[str] = list(mit.longest_common_prefix(resolved_args))
|
|
585
638
|
common_env_vars: set = functools.reduce(lambda a, b: a & b, resolved_env_vars, set())
|
|
586
639
|
|
|
587
640
|
sweep_args = [
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
641
|
+
JobArgs(
|
|
642
|
+
args=functools.reduce(
|
|
643
|
+
# Remove the common arguments from each job
|
|
644
|
+
lambda args, to_remove: args.remove_args(to_remove),
|
|
645
|
+
common_args,
|
|
646
|
+
xm.SequentialArgs.from_collection(a),
|
|
647
|
+
),
|
|
648
|
+
env_vars=dict(e.difference(common_env_vars)),
|
|
649
|
+
)
|
|
592
650
|
for a, e in zip(resolved_args, resolved_env_vars)
|
|
593
651
|
]
|
|
594
652
|
|
|
595
653
|
# No support for sweep_env_vars right now.
|
|
596
654
|
# We schedule the job array and then we'll resolve all the work units with
|
|
597
655
|
# the handles Slurm gives back to us.
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
656
|
+
# If we already have handles for all the work units, we don't need to submit the
|
|
657
|
+
# job array to SLURM.
|
|
658
|
+
num_resolved_handles = sum(future.done() for future in resolved_futures)
|
|
659
|
+
if num_resolved_handles == 0:
|
|
660
|
+
try:
|
|
661
|
+
handles = await execution.launch(
|
|
662
|
+
job=xm.Job(
|
|
663
|
+
executable=executable,
|
|
664
|
+
executor=executor,
|
|
665
|
+
name=name,
|
|
666
|
+
args=xm.SequentialArgs.from_collection(common_args),
|
|
667
|
+
env_vars=dict(common_env_vars),
|
|
668
|
+
),
|
|
669
|
+
args=sweep_args,
|
|
670
|
+
experiment_id=self.experiment_id,
|
|
671
|
+
identity=identity,
|
|
672
|
+
)
|
|
673
|
+
except Exception as e:
|
|
674
|
+
for future in resolved_futures:
|
|
675
|
+
future.set_exception(e)
|
|
676
|
+
raise
|
|
677
|
+
else:
|
|
678
|
+
for handle, future in zip(handles, resolved_futures):
|
|
679
|
+
future.set_result(handle)
|
|
680
|
+
elif 0 < num_resolved_handles < len(resolved_futures):
|
|
681
|
+
raise RuntimeError(
|
|
682
|
+
"Some array job elements have handles, but some don't. This shouldn't happen."
|
|
611
683
|
)
|
|
612
|
-
except Exception as e:
|
|
613
|
-
for future in resolved_futures:
|
|
614
|
-
future.set_exception(e)
|
|
615
|
-
raise
|
|
616
|
-
else:
|
|
617
|
-
for handle, future in zip(handles, resolved_futures):
|
|
618
|
-
future.set_result(handle)
|
|
619
684
|
|
|
620
685
|
wus = await asyncio.gather(*wu_futures)
|
|
686
|
+
|
|
621
687
|
_current_job_array_queue.set(None)
|
|
622
688
|
return wus
|
|
623
689
|
|
|
624
|
-
def
|
|
690
|
+
def _get_work_unit_by_identity(self, identity: str) -> SlurmWorkUnit | None:
|
|
691
|
+
if identity == "":
|
|
692
|
+
return None
|
|
693
|
+
for unit in self._experiment_units:
|
|
694
|
+
if isinstance(unit, SlurmWorkUnit) and unit.identity == identity:
|
|
695
|
+
return unit
|
|
696
|
+
return None
|
|
697
|
+
|
|
698
|
+
def _create_experiment_unit( # type: ignore
|
|
625
699
|
self,
|
|
626
|
-
args:
|
|
700
|
+
args: JobArgs,
|
|
627
701
|
role: xm.ExperimentUnitRole,
|
|
628
702
|
identity: str,
|
|
629
|
-
) -> Awaitable[SlurmWorkUnit]:
|
|
630
|
-
del identity
|
|
631
|
-
|
|
703
|
+
) -> Awaitable[SlurmWorkUnit | SlurmAuxiliaryUnit]:
|
|
632
704
|
def _create_work_unit(role: xm.WorkUnitRole) -> Awaitable[SlurmWorkUnit]:
|
|
633
705
|
work_unit = SlurmWorkUnit(
|
|
634
706
|
self,
|
|
@@ -636,26 +708,55 @@ class SlurmExperiment(xm.Experiment):
|
|
|
636
708
|
args,
|
|
637
709
|
role,
|
|
638
710
|
self._work_unit_id_predictor,
|
|
711
|
+
identity=identity,
|
|
639
712
|
)
|
|
640
713
|
self._experiment_units.append(work_unit)
|
|
641
714
|
self._work_unit_count += 1
|
|
642
715
|
|
|
643
|
-
|
|
716
|
+
api.client().insert_work_unit(
|
|
717
|
+
self.experiment_id,
|
|
718
|
+
api.WorkUnitModel(
|
|
719
|
+
wid=work_unit.work_unit_id,
|
|
720
|
+
identity=work_unit.identity,
|
|
721
|
+
args=json.dumps(args),
|
|
722
|
+
),
|
|
723
|
+
)
|
|
724
|
+
|
|
725
|
+
future = asyncio.Future[SlurmWorkUnit]()
|
|
644
726
|
future.set_result(work_unit)
|
|
645
727
|
return future
|
|
646
728
|
|
|
729
|
+
def _create_auxiliary_unit(role: xm.AuxiliaryUnitRole) -> Awaitable[SlurmAuxiliaryUnit]:
|
|
730
|
+
auxiliary_unit = SlurmAuxiliaryUnit(
|
|
731
|
+
self,
|
|
732
|
+
self._create_task,
|
|
733
|
+
args,
|
|
734
|
+
role,
|
|
735
|
+
identity=identity,
|
|
736
|
+
)
|
|
737
|
+
self._experiment_units.append(auxiliary_unit)
|
|
738
|
+
future = asyncio.Future[SlurmAuxiliaryUnit]()
|
|
739
|
+
future.set_result(auxiliary_unit)
|
|
740
|
+
return future
|
|
741
|
+
|
|
647
742
|
match role:
|
|
648
743
|
case xm.WorkUnitRole():
|
|
744
|
+
if (existing_unit := self._get_work_unit_by_identity(identity)) is not None:
|
|
745
|
+
future = asyncio.Future[SlurmWorkUnit]()
|
|
746
|
+
future.set_result(existing_unit)
|
|
747
|
+
return future
|
|
649
748
|
return _create_work_unit(role)
|
|
749
|
+
case xm.AuxiliaryUnitRole():
|
|
750
|
+
return _create_auxiliary_unit(role)
|
|
650
751
|
case _:
|
|
651
752
|
raise ValueError(f"Unsupported role {role}")
|
|
652
753
|
|
|
653
|
-
def _get_experiment_unit(
|
|
754
|
+
def _get_experiment_unit( # type: ignore
|
|
654
755
|
self,
|
|
655
756
|
experiment_id: int,
|
|
656
757
|
identity: str,
|
|
657
758
|
role: xm.ExperimentUnitRole,
|
|
658
|
-
args:
|
|
759
|
+
args: JobArgs | None = None,
|
|
659
760
|
) -> Awaitable[xm.ExperimentUnit]:
|
|
660
761
|
del experiment_id, identity, role, args
|
|
661
762
|
raise NotImplementedError
|
|
@@ -672,7 +773,7 @@ class SlurmExperiment(xm.Experiment):
|
|
|
672
773
|
# If no work units were added, delete this experiment
|
|
673
774
|
# This is to prevent empty experiments from being persisted
|
|
674
775
|
# and cluttering the database.
|
|
675
|
-
if self.
|
|
776
|
+
if len(self._experiment_units) == 0:
|
|
676
777
|
console.print(
|
|
677
778
|
f"[red]No work units were added to experiment `{self.experiment_title}`... deleting.[/red]"
|
|
678
779
|
)
|
|
@@ -689,14 +790,13 @@ class SlurmExperiment(xm.Experiment):
|
|
|
689
790
|
return self.context.annotations.title
|
|
690
791
|
|
|
691
792
|
@property
|
|
692
|
-
def context(self) -> SlurmExperimentMetadataContext:
|
|
793
|
+
def context(self) -> SlurmExperimentMetadataContext: # type: ignore
|
|
693
794
|
return self._experiment_context
|
|
694
795
|
|
|
695
796
|
@property
|
|
696
797
|
def work_unit_count(self) -> int:
|
|
697
798
|
return self._work_unit_count
|
|
698
799
|
|
|
699
|
-
@property
|
|
700
800
|
def work_units(self) -> Mapping[int, SlurmWorkUnit]:
|
|
701
801
|
"""Gets work units created via self.add()."""
|
|
702
802
|
return {
|
|
@@ -716,8 +816,50 @@ def create_experiment(experiment_title: str) -> SlurmExperiment:
|
|
|
716
816
|
def get_experiment(experiment_id: int) -> SlurmExperiment:
|
|
717
817
|
"""Get Experiment."""
|
|
718
818
|
experiment_model = api.client().get_experiment(experiment_id)
|
|
719
|
-
|
|
720
|
-
|
|
819
|
+
experiment = SlurmExperiment(
|
|
820
|
+
experiment_title=experiment_model.title, experiment_id=experiment_id
|
|
821
|
+
)
|
|
822
|
+
experiment._work_unit_id_predictor = id_predictor.Predictor(1)
|
|
823
|
+
|
|
824
|
+
# Populate annotations
|
|
825
|
+
experiment.context.annotations.description = experiment_model.description
|
|
826
|
+
experiment.context.annotations.note = experiment_model.note
|
|
827
|
+
experiment.context.annotations.tags = set(experiment_model.tags or [])
|
|
828
|
+
|
|
829
|
+
# Populate artifacts
|
|
830
|
+
for artifact in experiment_model.artifacts:
|
|
831
|
+
experiment.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
|
|
832
|
+
|
|
833
|
+
# Populate work units
|
|
834
|
+
for wu_model in experiment_model.work_units:
|
|
835
|
+
work_unit = SlurmWorkUnit(
|
|
836
|
+
experiment=experiment,
|
|
837
|
+
create_task=experiment._create_task,
|
|
838
|
+
args=json.loads(wu_model.args) if wu_model.args else {},
|
|
839
|
+
role=xm.WorkUnitRole(),
|
|
840
|
+
identity=wu_model.identity or "",
|
|
841
|
+
work_unit_id_predictor=experiment._work_unit_id_predictor,
|
|
842
|
+
)
|
|
843
|
+
work_unit._work_unit_id = wu_model.wid
|
|
844
|
+
|
|
845
|
+
# Populate jobs for each work unit
|
|
846
|
+
for job_model in wu_model.jobs:
|
|
847
|
+
slurm_ssh_config = config.SlurmSSHConfig.deserialize(job_model.slurm_ssh_config)
|
|
848
|
+
handle = execution.SlurmHandle(
|
|
849
|
+
ssh=slurm_ssh_config,
|
|
850
|
+
job_id=str(job_model.slurm_job_id),
|
|
851
|
+
job_name=job_model.name,
|
|
852
|
+
)
|
|
853
|
+
work_unit._execution_handles.append(handle)
|
|
854
|
+
|
|
855
|
+
# Populate artifacts for each work unit
|
|
856
|
+
for artifact in wu_model.artifacts:
|
|
857
|
+
work_unit.context.artifacts.add(Artifact(name=artifact.name, uri=artifact.uri))
|
|
858
|
+
|
|
859
|
+
experiment._experiment_units.append(work_unit)
|
|
860
|
+
experiment._work_unit_count += 1
|
|
861
|
+
|
|
862
|
+
return experiment
|
|
721
863
|
|
|
722
864
|
|
|
723
865
|
@functools.cache
|
|
@@ -733,5 +875,5 @@ def get_current_work_unit() -> SlurmWorkUnit | None:
|
|
|
733
875
|
wid := os.environ.get("XM_SLURM_WORK_UNIT_ID")
|
|
734
876
|
):
|
|
735
877
|
experiment = get_experiment(int(xid))
|
|
736
|
-
return experiment.work_units[int(wid)]
|
|
878
|
+
return experiment.work_units()[int(wid)]
|
|
737
879
|
return None
|