nshtrainer 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,573 +0,0 @@
- import copy
- import logging
- import math
- import os
- import signal
- from collections.abc import Callable, Mapping, Sequence
- from datetime import timedelta
- from pathlib import Path
- from typing import Any, Literal, cast
-
- from deepmerge import always_merger
- from typing_extensions import TypeAlias, TypedDict, TypeVarTuple, Unpack
-
- from ._output import SubmitOutput
- from ._script import helper_script_to_command, write_helper_script
-
- log = logging.getLogger(__name__)
-
- TArgs = TypeVarTuple("TArgs")
-
- _Path: TypeAlias = str | Path | os.PathLike
- MailType: TypeAlias = Literal[
-     "NONE",
-     "BEGIN",
-     "END",
-     "FAIL",
-     "REQUEUE",
-     "ALL",
-     "INVALID_DEPEND",
-     "STAGE_OUT",
-     "TIME_LIMIT",
-     "TIME_LIMIT_90",
-     "TIME_LIMIT_80",
-     "TIME_LIMIT_50",
-     "ARRAY_TASKS",
- ]
-
-
- class SlurmJobKwargs(TypedDict, total=False):
-     name: str
-     """
-     The name of the job.
-
-     This corresponds to the "-J" option in sbatch.
-     """
-
-     account: str
-     """
-     The account to charge resources used by this job to.
-
-     This corresponds to the "-A" option in sbatch.
-     """
-
-     partition: str | Sequence[str]
-     """
-     The partition to submit the job to.
-
-     This corresponds to the "-p" option in sbatch. If not specified, the default partition will be used.
-     Multiple partitions can be specified, and they will be combined using logical OR.
-     """
-
-     qos: str
-     """
-     The quality of service to submit the job to.
-
-     This corresponds to the "--qos" option in sbatch.
-     """
-
-     output_file: _Path
-     """
-     The file to write the job output to.
-
-     This corresponds to the "-o" option in sbatch. If not specified, the output will be written to the default output file.
-     """
-
-     error_file: _Path
-     """
-     The file to write the job errors to.
-
-     This corresponds to the "-e" option in sbatch. If not specified, the errors will be written to the default error file.
-     """
-
-     time: timedelta | Literal[0]
-     """
-     The maximum time for the job.
-
-     This corresponds to the "-t" option in sbatch. A value of 0 means no time limit.
-     """
-
-     memory_mb: int
-     """
-     The maximum memory for the job in MB.
-
-     This corresponds to the "--mem" option in sbatch. If not specified, the default memory limit will be used.
-     """
-
-     memory_per_cpu_mb: int
-     """
-     The minimum memory required per usable allocated CPU.
-
-     This corresponds to the "--mem-per-cpu" option in sbatch. If not specified, the default memory limit will be used.
-     """
-
-     memory_per_gpu_mb: int
-     """
-     The minimum memory required per allocated GPU.
-
-     This corresponds to the "--mem-per-gpu" option in sbatch. If not specified, the default memory limit will be used.
-     """
-
-     cpus_per_task: int
-     """
-     Advise the Slurm controller that ensuing job steps will require _ncpus_ number of processors per task.
-
-     This corresponds to the "-c" option in sbatch.
-     """
-
-     nodes: int
-     """
-     The number of nodes to use for the job.
-
-     This corresponds to the "-N" option in sbatch. The default is 1 node.
-     """
-
-     ntasks: int
-     """
-     The number of tasks to use for the job.
-
-     This corresponds to the "-n" option in sbatch. The default is 1 task.
-     """
-
-     ntasks_per_node: int
-     """
-     The number of tasks for each node.
-
-     This corresponds to the "--ntasks-per-node" option in sbatch.
-     """
-
-     constraint: str | Sequence[str]
-     """
-     Nodes can have features assigned to them by the Slurm administrator. Users can specify which of these features are required by their job using the constraint option.
-
-     This corresponds to the "-C" option in sbatch.
-     """
-
-     gres: str | Sequence[str]
-     """
-     Specifies a comma-delimited list of generic consumable resources.
-
-     This corresponds to the "--gres" option in sbatch.
-     """
-
-     gpus: int | str
-     """
-     Specify the total number of GPUs required for the job. An optional GPU type specification can be supplied.
-
-     This corresponds to the "-G" option in sbatch.
-     """
-
-     gpus_per_node: int | str
-     """
-     Specify the number of GPUs required for the job on each node included in the job's resource allocation. An optional GPU type specification can be supplied.
-
-     This corresponds to the "--gpus-per-node" option in sbatch.
-     """
-
-     gpus_per_task: int
-     """
-     Specify the number of GPUs required for the job on each task to be spawned in the job's resource allocation. An optional GPU type specification can be supplied.
-
-     This corresponds to the "--gpus-per-task" option in sbatch.
-     """
-
-     mail_user: str
-     """
-     User to receive email notification of state changes as defined by mail_type.
-
-     This corresponds to the "--mail-user" option in sbatch.
-     """
-
-     mail_type: MailType | Sequence[MailType]
-     """
-     Notify user by email when certain event types occur.
-
-     This corresponds to the "--mail-type" option in sbatch.
-     """
-
-     dependency: str
-     """
-     Defer the start of this job until the specified dependencies have been satisfied.
-
-     This corresponds to the "-d" option in sbatch.
-     """
-
-     exclusive: bool
-     """
-     The job allocation can not share nodes with other running jobs.
-
-     This corresponds to the "--exclusive" option in sbatch.
-     """
-
-     signal: signal.Signals
-     """
-     The signal to send to the job when the job is being terminated.
-
-     This corresponds to the "--signal" option in sbatch.
-     """
-
-     signal_delay: timedelta
-     """
-     The delay before sending the signal to the job.
-
-     This corresponds to the "--signal ...@[delay]" option in sbatch.
-     """
-
-     open_mode: str
-     """
-     The open mode for the output and error files.
-
-     This corresponds to the "--open-mode" option in sbatch.
-
-     Valid values are "append" and "truncate".
-     """
-
-     requeue: bool
-     """
-     Requeues the job if it's pre-empted.
-
-     This corresponds to the "--requeue" option in sbatch.
-     """
-
-     setup_commands: Sequence[str]
-     """
-     The setup commands to run before the job.
-
-     These commands will be executed prior to everything else in the job script.
-     """
-
-     environment: Mapping[str, str]
-     """
-     The environment variables to set for the job.
-
-     These variables will be set prior to executing any commands in the job script.
-     """
-
-     command_prefix: str
-     """
-     A command to prefix the job command with.
-
-     This is used to add commands like `srun` to the job command.
-     """
-
-     command_template: str
-     """
-     The template for the command to execute the helper script.
-
-     Default: `bash {/path/to/helper.sh}`.
-     """
-
-     srun_flags: str | Sequence[str]
-     """
-     The flags to pass to the `srun` command.
-     """
-
-
- DEFAULT_KWARGS: SlurmJobKwargs = {
-     "name": "ll",
-     # "nodes": 1,
-     # "time": timedelta(hours=2),
-     "signal": signal.SIGURG,
-     "signal_delay": timedelta(seconds=90),
-     "open_mode": "append",
-     # "requeue": True,
- }
-
-
- def _determine_gres(kwargs: SlurmJobKwargs) -> Sequence[str] | None:
-     """
-     There are many different ways to specify GPU resources, but some are buggy.
-
-     This function normalizes all other ways to specify GPU resources to the `gres` option.
-     """
-
-     # If `--gres` is set, just return it
-     if (gres := kwargs.get("gres")) is not None:
-         if isinstance(gres, str):
-             gres = [gres]
-         return gres
-
-     # We will only support `--gpus` if `--nodes` is set to 1
-     if (gpus := kwargs.get("gpus")) is not None:
-         if kwargs.get("nodes") != 1:
-             raise ValueError("Cannot specify `gpus` unless `nodes` is set to 1.")
-         if isinstance(gpus, int):
-             gpus = f"gpu:{gpus}"
-         return [gpus]
-
-     # `--gpus-per-task` is only supported if `--ntasks-per-node` is set (or can be inferred).
-     if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
-         if (ntasks_per_node := _determine_ntasks_per_node(kwargs)) is None:
-             raise ValueError(
-                 "Cannot specify `gpus_per_task` without `ntasks_per_node`."
-             )
-
-         gpus_per_node = ntasks_per_node * gpus_per_task
-         return [f"gpu:{gpus_per_node}"]
-
-     # `--gpus-per-node` has no restrictions
-     if (gpus_per_node := kwargs.get("gpus_per_node")) is not None:
-         if isinstance(gpus_per_node, int):
-             gpus_per_node = f"gpu:{gpus_per_node}"
-         return [gpus_per_node]
-
-     return None
-
-
- def _determine_ntasks_per_node(kwargs: SlurmJobKwargs) -> int | None:
-     # If `--ntasks-per-node` is set, just return it
-     if (ntasks_per_node := kwargs.get("ntasks_per_node")) is not None:
-         return ntasks_per_node
-
-     # If `--ntasks` is set, we can infer `--ntasks-per-node`
-     if (ntasks := kwargs.get("ntasks")) is not None:
-         if (nodes := kwargs.get("nodes")) is None:
-             raise ValueError("Cannot infer `ntasks_per_node` without `nodes`.")
-
-         # If ntasks is not evenly divisible by nodes, raise an error
-         if ntasks % nodes != 0:
-             raise ValueError(
-                 "The number of tasks must be divisible by the number of nodes."
-             )
-
-         return ntasks // nodes
-
-     return None
-
-
- def _write_batch_script_to_file(
-     path: Path,
-     kwargs: SlurmJobKwargs,
-     command: str,
-     job_array_n_jobs: int | None = None,
- ):
-     with path.open("w") as f:
-         f.write("#!/bin/bash\n")
-
-         if kwargs.get("requeue"):
-             f.write("#SBATCH --requeue\n")
-
-         if job_array_n_jobs is not None:
-             f.write(f"#SBATCH --array=1-{job_array_n_jobs}\n")
-
-         if (name := kwargs.get("name")) is not None:
-             f.write(f"#SBATCH -J {name}\n")
-
-         if (account := kwargs.get("account")) is not None:
-             f.write(f"#SBATCH --account={account}\n")
-
-         if (time := kwargs.get("time")) is not None:
-             # A time limit of zero requests that no time limit be imposed. Acceptable time formats include "minutes", "minutes:seconds", "hours:minutes:seconds", "days-hours", "days-hours:minutes" and "days-hours:minutes:seconds".
-             if time == 0:
-                 time_str = "0"
-             else:
-                 total_seconds = time.total_seconds()
-                 hours, remainder = divmod(total_seconds, 3600)
-                 minutes, seconds = divmod(remainder, 60)
-                 if hours > 24:
-                     days, hours = divmod(hours, 24)
-                     time_str = f"{int(days)}-{int(hours):02d}:{int(minutes):02d}"
-                 else:
-                     time_str = f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"
-             f.write(f"#SBATCH --time={time_str}\n")
-
-         if (nodes := kwargs.get("nodes")) is not None:
-             f.write(f"#SBATCH --nodes={nodes}\n")
-
-         if (ntasks := kwargs.get("ntasks")) is not None:
-             f.write(f"#SBATCH --ntasks={ntasks}\n")
-
-         if (ntasks_per_node := kwargs.get("ntasks_per_node")) is not None:
-             f.write(f"#SBATCH --ntasks-per-node={ntasks_per_node}\n")
-
-         if (output_file := kwargs.get("output_file")) is not None:
-             output_file = str(Path(output_file).absolute())
-             f.write(f"#SBATCH --output={output_file}\n")
-
-         if (error_file := kwargs.get("error_file")) is not None:
-             error_file = str(Path(error_file).absolute())
-             f.write(f"#SBATCH --error={error_file}\n")
-
-         if (partition := kwargs.get("partition")) is not None:
-             if isinstance(partition, str):
-                 partition = [partition]
-             f.write(f"#SBATCH --partition={','.join(partition)}\n")
-
-         if (qos := kwargs.get("qos")) is not None:
-             f.write(f"#SBATCH --qos={qos}\n")
-
-         if (memory_mb := kwargs.get("memory_mb")) is not None:
-             f.write(f"#SBATCH --mem={memory_mb}\n")
-
-         if (memory_per_cpu_mb := kwargs.get("memory_per_cpu_mb")) is not None:
-             f.write(f"#SBATCH --mem-per-cpu={memory_per_cpu_mb}\n")
-
-         if (memory_per_gpu_mb := kwargs.get("memory_per_gpu_mb")) is not None:
-             f.write(f"#SBATCH --mem-per-gpu={memory_per_gpu_mb}\n")
-
-         if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
-             f.write(f"#SBATCH --cpus-per-task={cpus_per_task}\n")
-
-         if gres := _determine_gres(kwargs):
-             f.write(f"#SBATCH --gres={','.join(gres)}\n")
-
-         if (mail_user := kwargs.get("mail_user")) is not None:
-             f.write(f"#SBATCH --mail-user={mail_user}\n")
-
-         if (mail_type := kwargs.get("mail_type")) is not None:
-             if isinstance(mail_type, str):
-                 mail_type = [mail_type]
-             f.write(f"#SBATCH --mail-type={','.join(mail_type)}\n")
-
-         if (dependency := kwargs.get("dependency")) is not None:
-             f.write(f"#SBATCH --dependency={dependency}\n")
-
-         if kwargs.get("exclusive"):
-             f.write("#SBATCH --exclusive\n")
-
-         if (open_mode := kwargs.get("open_mode")) is not None:
-             f.write(f"#SBATCH --open-mode={open_mode}\n")
-
-         if (constraint := kwargs.get("constraint")) is not None:
-             if isinstance(constraint, str):
-                 constraint = [constraint]
-             f.write(f"#SBATCH --constraint={','.join(constraint)}\n")
-
-         if (signal := kwargs.get("signal")) is not None:
-             signal_str = signal.name
-             if (signal_delay := kwargs.get("signal_delay")) is not None:
-                 signal_str += f"@{math.ceil(signal_delay.total_seconds())}"
-             f.write(f"#SBATCH --signal={signal_str}\n")
-
-         f.write("\n")
-
-         if (command_prefix := kwargs.get("command_prefix")) is not None:
-             command = " ".join(
-                 x_stripped
-                 for x in (command_prefix, command)
-                 if (x_stripped := x.strip())
-             )
-         f.write(f"{command}\n")
-
-     return path
-
-
- def _update_kwargs(kwargs_in: SlurmJobKwargs, base_path: Path):
-     # Update the kwargs with the default values
-     kwargs = copy.deepcopy(DEFAULT_KWARGS)
-
-     # Merge the kwargs
-     kwargs = cast(SlurmJobKwargs, always_merger.merge(kwargs, kwargs_in))
-     del kwargs_in
-
-     # If out/err files are not specified, set them
-     logs_base = base_path.parent / "logs"
-     logs_base.mkdir(exist_ok=True)
-
-     if kwargs.get("output_file") is None:
-         kwargs["output_file"] = logs_base / "output_%j_%a.out"
-
-     if kwargs.get("error_file") is None:
-         kwargs["error_file"] = logs_base / "error_%j_%a.err"
-
-     # Update the command_prefix to add srun:
-     command_parts: list[str] = ["srun"]
-     if (srun_flags := kwargs.get("srun_flags")) is not None:
-         if isinstance(srun_flags, str):
-             srun_flags = [srun_flags]
-         command_parts.extend(srun_flags)
-
-     # Add ntasks/cpus/gpus
-     if (ntasks := kwargs.get("ntasks")) is not None:
-         command_parts.append(f"--ntasks={ntasks}")
-
-     if (ntasks_per_node := kwargs.get("ntasks_per_node")) is not None:
-         command_parts.append(f"--ntasks-per-node={ntasks_per_node}")
-
-     if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
-         command_parts.append(f"--cpus-per-task={cpus_per_task}")
-
-     if gres := _determine_gres(kwargs):
-         command_parts.append(f"--gres={','.join(gres)}")
-
-     command_parts.append("--unbuffered")
-
-     # Add the task id to the output filenames
-     if (f := kwargs.get("output_file")) is not None:
-         f = Path(f).absolute()
-         command_parts.extend(
-             [
-                 "--output",
-                 str(f.with_name(f"{f.stem}-%t{f.suffix}").absolute()),
-             ]
-         )
-     if (f := kwargs.get("error_file")) is not None:
-         f = Path(f).absolute()
-         command_parts.extend(
-             [
-                 "--error",
-                 str(f.with_name(f"{f.stem}-%t{f.suffix}").absolute()),
-             ]
-         )
-
-     # If there is already a command prefix, combine them.
-     if (existing_command_prefix := kwargs.get("command_prefix")) is not None:
-         command_parts.extend(existing_command_prefix.split())
-     # Add the command prefix to the kwargs.
-     kwargs["command_prefix"] = " ".join(command_parts)
-
-     return kwargs
-
-
- def to_array_batch_script(
-     dest: Path,
-     callable: Callable[[Unpack[TArgs]], Any],
-     args_list: Sequence[tuple[Unpack[TArgs]]],
-     /,
-     job_index_variable: str = "SLURM_ARRAY_TASK_ID",
-     print_environment_info: bool = False,
-     python_command_prefix: str | None = None,
-     **kwargs: Unpack[SlurmJobKwargs],
- ) -> SubmitOutput:
-     """
-     Create the batch script for the job.
-     """
-
-     from ...picklerunner import serialize_many
-
-     kwargs = _update_kwargs(kwargs, dest)
-
-     # Convert the command/callable to a string for the command
-     num_jobs = len(args_list)
-
-     destdir = dest / "fns"
-     destdir.mkdir(exist_ok=True)
-
-     serialized_command = serialize_many(
-         destdir,
-         callable,
-         [(args, {}) for args in args_list],
-         start_idx=1,  # Slurm job indices are 1-based
-     )
-     helper_path = write_helper_script(
-         destdir,
-         serialized_command.to_bash_command(
-             job_index_variable, print_environment_info=print_environment_info
-         ),
-         kwargs.get("environment", {}),
-         kwargs.get("setup_commands", []),
-         command_prefix=python_command_prefix,
-     )
-     command = helper_script_to_command(helper_path, kwargs.get("command_template"))
-
-     script_path = _write_batch_script_to_file(
-         dest / "launch.sh",
-         kwargs,
-         command,
-         job_array_n_jobs=num_jobs,
-     )
-     script_path = script_path.resolve().absolute()
-     return SubmitOutput(
-         command_parts=["sbatch", f"{script_path}"],
-         script_path=script_path,
-     )
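
For context on what this removal takes away: to_array_batch_script was the module's entry point. It serialized one function call per Slurm array task (via picklerunner), wrapped them in a helper script, and returned a SubmitOutput whose launch.sh path is ready to hand to sbatch. A minimal usage sketch follows; the import path and the train function are hypothetical, since the diff does not show where the package exposed this function:

    from datetime import timedelta
    from pathlib import Path

    # Hypothetical import location; the diff does not name the module's path.
    from nshtrainer._submit import to_array_batch_script

    def train(seed: int, lr: float) -> None:
        # Placeholder workload; each argument tuple below becomes one array task.
        print(f"training run: seed={seed}, lr={lr}")

    output = to_array_batch_script(
        Path("runs/sweep"),
        train,
        [(0, 1e-3), (1, 1e-3), (2, 3e-4)],  # three tasks, array indices 1-3
        name="sweep",
        partition="gpu",
        nodes=1,
        ntasks=1,
        gpus_per_node=1,
        time=timedelta(hours=2),
    )
    # output.command_parts == ["sbatch", "<abs path>/runs/sweep/launch.sh"],
    # ready to pass to subprocess.run(...).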