nshtrainer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,573 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
5
|
+
import signal
|
|
6
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
7
|
+
from datetime import timedelta
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Literal, cast
|
|
10
|
+
|
|
11
|
+
from deepmerge import always_merger
|
|
12
|
+
from typing_extensions import TypeAlias, TypedDict, TypeVarTuple, Unpack
|
|
13
|
+
|
|
14
|
+
from ._output import SubmitOutput
|
|
15
|
+
from ._script import helper_script_to_command, write_helper_script
|
|
16
|
+
|
|
17
|
+
log = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
TArgs = TypeVarTuple("TArgs")
|
|
20
|
+
|
|
21
|
+
_Path: TypeAlias = str | Path | os.PathLike
|
|
22
|
+
MailType: TypeAlias = Literal[
|
|
23
|
+
"NONE",
|
|
24
|
+
"BEGIN",
|
|
25
|
+
"END",
|
|
26
|
+
"FAIL",
|
|
27
|
+
"REQUEUE",
|
|
28
|
+
"ALL",
|
|
29
|
+
"INVALID_DEPEND",
|
|
30
|
+
"STAGE_OUT",
|
|
31
|
+
"TIME_LIMIT",
|
|
32
|
+
"TIME_LIMIT_90",
|
|
33
|
+
"TIME_LIMIT_80",
|
|
34
|
+
"TIME_LIMIT_50",
|
|
35
|
+
"ARRAY_TASKS",
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class SlurmJobKwargs(TypedDict, total=False):
    """
    Options for a Slurm batch job. Every key is optional.

    Most keys correspond to an ``sbatch`` option (noted per field). The keys
    ``setup_commands``, ``environment``, ``command_prefix``,
    ``command_template`` and ``srun_flags`` instead control how the job script
    and its command line are assembled.
    """

    name: str
    """
    The name of the job.

    This corresponds to the "-J" option in sbatch.
    """

    account: str
    """
    The account to charge resources used by this job to.

    This corresponds to the "-A" option in sbatch.
    """

    partition: str | Sequence[str]
    """
    The partition to submit the job to.

    This corresponds to the "-p" option in sbatch. If not specified, the default partition will be used.
    Multiple partitions can be specified; they are passed to sbatch as a comma-separated list.
    """

    qos: str
    """
    The quality of service to submit the job to.

    This corresponds to the "--qos" option in sbatch.
    """

    output_file: _Path
    """
    The file to write the job output to.

    This corresponds to the "-o" option in sbatch. If not specified, the output will be written to the default output file.
    """

    error_file: _Path
    """
    The file to write the job errors to.

    This corresponds to the "-e" option in sbatch. If not specified, the errors will be written to the default error file.
    """

    time: timedelta | Literal[0]
    """
    The maximum time for the job.

    This corresponds to the "-t" option in sbatch. A value of 0 means no time limit.
    """

    memory_mb: int
    """
    The memory required per node for the job, in MB.

    This corresponds to the "--mem" option in sbatch. If not specified, the default memory limit will be used.
    """

    memory_per_cpu_mb: int
    """
    The minimum memory required per usable allocated CPU, in MB.

    This corresponds to the "--mem-per-cpu" option in sbatch. If not specified, the default memory limit will be used.
    """

    memory_per_gpu_mb: int
    """
    The minimum memory required per allocated GPU, in MB.

    This corresponds to the "--mem-per-gpu" option in sbatch. If not specified, the default memory limit will be used.
    """

    cpus_per_task: int
    """
    Advise the Slurm controller that ensuing job steps will require _ncpus_ number of processors per task.

    This corresponds to the "-c" option in sbatch.
    """

    nodes: int
    """
    The number of nodes to use for the job.

    This corresponds to the "-N" option in sbatch. The default is 1 node.
    """

    ntasks: int
    """
    The number of tasks to use for the job.

    This corresponds to the "-n" option in sbatch. The default is 1 task.
    """

    ntasks_per_node: int
    """
    The number of tasks for each node.

    This corresponds to the "--ntasks-per-node" option in sbatch.
    """

    constraint: str | Sequence[str]
    """
    Nodes can have features assigned to them by the Slurm administrator. Users can specify which of these features are required by their job using the constraint option.

    This corresponds to the "-C" option in sbatch.
    """

    gres: str | Sequence[str]
    """
    Specifies a comma-delimited list of generic consumable resources.

    This corresponds to the "--gres" option in sbatch. When set, it takes
    precedence over `gpus`, `gpus_per_task` and `gpus_per_node`.
    """

    gpus: int | str
    """
    Specify the total number of GPUs required for the job, either as a count or
    as "[type:]count" (e.g. "a100:2").

    Mirrors the "-G" option in sbatch, but this module emits it as a "--gres"
    request, and it is only accepted when `nodes` is set to 1.
    """

    gpus_per_node: int | str
    """
    Specify the number of GPUs required on each node in the job's resource
    allocation, either as a count or as "[type:]count" (e.g. "a100:2").

    Mirrors the "--gpus-per-node" option in sbatch, but this module emits it as
    a "--gres" request.
    """

    gpus_per_task: int
    """
    Specify the number of GPUs required for each task in the job's resource
    allocation. Must be a plain integer count; a GPU type cannot be given here.

    Mirrors the "--gpus-per-task" option in sbatch. This module multiplies it by
    the number of tasks per node (`ntasks_per_node`, or `ntasks / nodes`) and
    emits the result as a "--gres" request.
    """

    mail_user: str
    """
    User to receive email notification of state changes as defined by mail_type.

    This corresponds to the "--mail-user" option in sbatch.
    """

    mail_type: MailType | Sequence[MailType]
    """
    Notify user by email when certain event types occur.

    This corresponds to the "--mail-type" option in sbatch.
    """

    dependency: str
    """
    Defer the start of this job until the specified dependencies have been satisfied.

    This corresponds to the "-d" option in sbatch.
    """

    exclusive: bool
    """
    The job allocation can not share nodes with other running jobs.

    This corresponds to the "--exclusive" option in sbatch.
    """

    signal: signal.Signals
    """
    The signal Slurm sends to the job shortly before its time limit is reached
    (see `signal_delay`).

    This corresponds to the "--signal" option in sbatch.
    """

    signal_delay: timedelta
    """
    How long before the end of the job's time limit the signal is sent. It is
    rounded up to whole seconds.

    This corresponds to the "--signal ...@[sig_time]" option in sbatch.
    """

    open_mode: str
    """
    The open mode for the output and error files.

    This corresponds to the "--open-mode" option in sbatch.

    Valid values are "append" and "truncate".
    """

    requeue: bool
    """
    Makes the job eligible for requeueing (e.g. when it is pre-empted).

    This corresponds to the "--requeue" option in sbatch.
    """

    setup_commands: Sequence[str]
    """
    The setup commands to run before the job.

    These commands will be executed prior to everything else in the job script.
    """

    environment: Mapping[str, str]
    """
    The environment variables to set for the job.

    These variables will be set prior to executing any commands in the job script.
    """

    command_prefix: str
    """
    A command to prefix the job command with.

    This is used to add commands like `srun` to the job command.
    """

    command_template: str
    """
    The template for the command to execute the helper script.

    Default: `bash {/path/to/helper.sh}`.
    """

    srun_flags: str | Sequence[str]
    """
    The flags to pass to the `srun` command.
    """
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
# Defaults that every job starts from; user-supplied options are merged over
# them in ``_update_kwargs``. ``nodes``, ``time`` and ``requeue`` are not
# defaulted here, so sbatch's own defaults apply unless the caller sets them.
DEFAULT_KWARGS: SlurmJobKwargs = SlurmJobKwargs(
    name="ll",
    signal=signal.SIGURG,
    signal_delay=timedelta(seconds=90),
    open_mode="append",
)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _determine_gres(kwargs: SlurmJobKwargs) -> Sequence[str] | None:
|
|
278
|
+
"""
|
|
279
|
+
There are many different ways to specify GPU resources, but some are buggy.
|
|
280
|
+
|
|
281
|
+
This function normalizes all other ways to specify GPU resources to the `gres` option.
|
|
282
|
+
"""
|
|
283
|
+
|
|
284
|
+
# If `--gres` is set, just return it
|
|
285
|
+
if (gres := kwargs.get("gres")) is not None:
|
|
286
|
+
if isinstance(gres, str):
|
|
287
|
+
gres = [gres]
|
|
288
|
+
return gres
|
|
289
|
+
|
|
290
|
+
# We will only support `--gpus` if `--nodes` is set to 1
|
|
291
|
+
if (gpus := kwargs.get("gpus")) is not None:
|
|
292
|
+
if kwargs.get("nodes") != 1:
|
|
293
|
+
raise ValueError("Cannot specify `gpus` without `nodes` set to 1.")
|
|
294
|
+
if isinstance(gpus, int):
|
|
295
|
+
gpus = [f"gpu:{gpus}"]
|
|
296
|
+
return gpus
|
|
297
|
+
|
|
298
|
+
# `--gpus-per-task` is only supported if `--ntasks-per-node` is set (or can be inferred).
|
|
299
|
+
if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
|
|
300
|
+
if (ntasks_per_node := _determine_ntasks_per_node(kwargs)) is None:
|
|
301
|
+
raise ValueError(
|
|
302
|
+
"Cannot specify `gpus_per_task` without `ntasks_per_node`."
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
gpus_per_node = ntasks_per_node * gpus_per_task
|
|
306
|
+
return [f"gpu:{gpus_per_node}"]
|
|
307
|
+
|
|
308
|
+
# `--gpus-per-node` has no restrictions
|
|
309
|
+
if (gpus_per_node := kwargs.get("gpus_per_node")) is not None:
|
|
310
|
+
if isinstance(gpus_per_node, int):
|
|
311
|
+
gpus_per_node = [f"gpu:{gpus_per_node}"]
|
|
312
|
+
return gpus_per_node
|
|
313
|
+
|
|
314
|
+
return None
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _determine_ntasks_per_node(kwargs: SlurmJobKwargs) -> int | None:
|
|
318
|
+
# If `--ntasks-per-node` is set, just return it
|
|
319
|
+
if (ntasks_per_node := kwargs.get("ntasks_per_node")) is not None:
|
|
320
|
+
return ntasks_per_node
|
|
321
|
+
|
|
322
|
+
# If `--ntasks` is set, we can infer `--ntasks-per-node`
|
|
323
|
+
if (ntasks := kwargs.get("ntasks")) is not None:
|
|
324
|
+
if (nodes := kwargs.get("nodes")) is None:
|
|
325
|
+
raise ValueError("Cannot infer `ntasks_per_node` without `nodes`.")
|
|
326
|
+
|
|
327
|
+
# If nnodes is not divisible by ntasks, raise an error
|
|
328
|
+
if nodes % ntasks != 0:
|
|
329
|
+
raise ValueError(
|
|
330
|
+
"The number of nodes must be divisible by the number of tasks."
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
return ntasks // nodes
|
|
334
|
+
|
|
335
|
+
return None
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _format_slurm_time(time_limit: timedelta | Literal[0]) -> str:
    """
    Render a time limit in a format accepted by ``sbatch --time``.

    The integer ``0`` means "no time limit" and is rendered as ``"0"``.
    Durations longer than 24 hours use ``days-hours:minutes:seconds``; shorter
    ones use ``hours:minutes:seconds``. Fractional seconds are truncated.
    """
    if time_limit == 0:
        return "0"

    total_seconds = int(time_limit.total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    if hours > 24:
        days, hours = divmod(hours, 24)
        # Keep the seconds component too; "days-hours:minutes" would drop it.
        return f"{days}-{hours:02d}:{minutes:02d}:{seconds:02d}"
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"


def _write_batch_script_to_file(
    path: Path,
    kwargs: SlurmJobKwargs,
    command: str,
    job_array_n_jobs: int | None = None,
):
    """
    Write an sbatch script to ``path`` and return ``path``.

    The script layout is:
    - a bash shebang;
    - one ``#SBATCH`` directive per option set in ``kwargs`` (unset options are
      skipped, and boolean flags are written only when truthy);
    - a blank line;
    - ``command``, prefixed with ``kwargs["command_prefix"]`` when present.

    Args:
        path: Destination file; overwritten if it already exists.
        kwargs: Job options (see ``SlurmJobKwargs``).
        command: The shell command the job runs.
        job_array_n_jobs: When given, adds ``--array=1-<job_array_n_jobs>``.

    Returns:
        ``path``.

    Raises:
        ValueError: Propagated from GPU option normalization (``_determine_gres``).
    """
    # Resolve GPU options before opening the file, so an invalid combination
    # raises without leaving a truncated or partial script behind.
    gres = _determine_gres(kwargs)

    with path.open("w", encoding="utf-8") as f:
        f.write("#!/bin/bash\n")

        if kwargs.get("requeue"):
            f.write("#SBATCH --requeue\n")

        if job_array_n_jobs is not None:
            f.write(f"#SBATCH --array=1-{job_array_n_jobs}\n")

        if (name := kwargs.get("name")) is not None:
            f.write(f"#SBATCH -J {name}\n")

        if (account := kwargs.get("account")) is not None:
            f.write(f"#SBATCH --account={account}\n")

        if (time_limit := kwargs.get("time")) is not None:
            f.write(f"#SBATCH --time={_format_slurm_time(time_limit)}\n")

        if (nodes := kwargs.get("nodes")) is not None:
            f.write(f"#SBATCH --nodes={nodes}\n")

        if (ntasks := kwargs.get("ntasks")) is not None:
            f.write(f"#SBATCH --ntasks={ntasks}\n")

        if (ntasks_per_node := kwargs.get("ntasks_per_node")) is not None:
            f.write(f"#SBATCH --ntasks-per-node={ntasks_per_node}\n")

        if (output_file := kwargs.get("output_file")) is not None:
            f.write(f"#SBATCH --output={Path(output_file).absolute()}\n")

        if (error_file := kwargs.get("error_file")) is not None:
            f.write(f"#SBATCH --error={Path(error_file).absolute()}\n")

        if (partition := kwargs.get("partition")) is not None:
            if isinstance(partition, str):
                partition = [partition]
            f.write(f"#SBATCH --partition={','.join(partition)}\n")

        if (qos := kwargs.get("qos")) is not None:
            f.write(f"#SBATCH --qos={qos}\n")

        if (memory_mb := kwargs.get("memory_mb")) is not None:
            f.write(f"#SBATCH --mem={memory_mb}\n")

        if (memory_per_cpu_mb := kwargs.get("memory_per_cpu_mb")) is not None:
            f.write(f"#SBATCH --mem-per-cpu={memory_per_cpu_mb}\n")

        if (memory_per_gpu_mb := kwargs.get("memory_per_gpu_mb")) is not None:
            f.write(f"#SBATCH --mem-per-gpu={memory_per_gpu_mb}\n")

        if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
            f.write(f"#SBATCH --cpus-per-task={cpus_per_task}\n")

        if gres:
            f.write(f"#SBATCH --gres={','.join(gres)}\n")

        if (mail_user := kwargs.get("mail_user")) is not None:
            f.write(f"#SBATCH --mail-user={mail_user}\n")

        if (mail_type := kwargs.get("mail_type")) is not None:
            if isinstance(mail_type, str):
                mail_type = [mail_type]
            f.write(f"#SBATCH --mail-type={','.join(mail_type)}\n")

        if (dependency := kwargs.get("dependency")) is not None:
            f.write(f"#SBATCH --dependency={dependency}\n")

        if kwargs.get("exclusive"):
            f.write("#SBATCH --exclusive\n")

        if (open_mode := kwargs.get("open_mode")) is not None:
            f.write(f"#SBATCH --open-mode={open_mode}\n")

        if (constraint := kwargs.get("constraint")) is not None:
            if isinstance(constraint, str):
                constraint = [constraint]
            f.write(f"#SBATCH --constraint={','.join(constraint)}\n")

        # Named `job_signal` so the `signal` module is not shadowed in this scope.
        if (job_signal := kwargs.get("signal")) is not None:
            signal_str = job_signal.name
            if (signal_delay := kwargs.get("signal_delay")) is not None:
                # sbatch takes the lead time in whole seconds; round up.
                signal_str += f"@{math.ceil(signal_delay.total_seconds())}"
            f.write(f"#SBATCH --signal={signal_str}\n")

        f.write("\n")

        if (command_prefix := kwargs.get("command_prefix")) is not None:
            # Join prefix and command with a single space, dropping blank parts.
            command = " ".join(
                stripped
                for part in (command_prefix, command)
                if (stripped := part.strip())
            )
        f.write(f"{command}\n")

    return path
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def _update_kwargs(kwargs_in: SlurmJobKwargs, base_path: Path):
    """
    Merge user options over ``DEFAULT_KWARGS`` and derive log and ``srun`` settings.

    Steps:
    - Merge ``kwargs_in`` with ``deepmerge.always_merger`` into a deep copy of
      ``DEFAULT_KWARGS``.
    - Create ``base_path.parent / "logs"`` if it is missing.
    - Default ``output_file``/``error_file`` to files inside that directory
      when they are unset.
    - Set ``command_prefix`` to an ``srun`` invocation. It contains any
      ``srun_flags``, the task/CPU/GPU options, ``--unbuffered`` and per-task
      ``--output``/``--error`` files, followed by any pre-existing
      ``command_prefix`` words.

    Returns:
        The merged and updated kwargs.
    """
    merged = cast(
        SlurmJobKwargs,
        always_merger.merge(copy.deepcopy(DEFAULT_KWARGS), kwargs_in),
    )

    logs_dir = base_path.parent / "logs"
    logs_dir.mkdir(exist_ok=True)

    if merged.get("output_file") is None:
        merged["output_file"] = logs_dir / "output_%j_%a.out"
    if merged.get("error_file") is None:
        merged["error_file"] = logs_dir / "error_%j_%a.err"

    srun_parts: list[str] = ["srun"]

    extra_flags = merged.get("srun_flags")
    if extra_flags is not None:
        srun_parts += [extra_flags] if isinstance(extra_flags, str) else list(extra_flags)

    for flag, value in (
        ("--ntasks", merged.get("ntasks")),
        ("--ntasks-per-node", merged.get("ntasks_per_node")),
        ("--cpus-per-task", merged.get("cpus_per_task")),
    ):
        if value is not None:
            srun_parts.append(f"{flag}={value}")

    gres = _determine_gres(merged)
    if gres:
        srun_parts.append("--gres=" + ",".join(gres))

    srun_parts.append("--unbuffered")

    def _per_task_log(log_file: str | os.PathLike) -> str:
        # "<stem>-%t<suffix>" next to the job-level file; srun expands %t
        # into the task id.
        log_path = Path(log_file).absolute()
        return str(log_path.with_name(f"{log_path.stem}-%t{log_path.suffix}").absolute())

    out_file = merged.get("output_file")
    if out_file is not None:
        srun_parts += ["--output", _per_task_log(out_file)]
    err_file = merged.get("error_file")
    if err_file is not None:
        srun_parts += ["--error", _per_task_log(err_file)]

    # Any prefix the caller already set runs after (i.e. under) srun.
    prior_prefix = merged.get("command_prefix")
    if prior_prefix is not None:
        srun_parts += prior_prefix.split()
    merged["command_prefix"] = " ".join(srun_parts)

    return merged
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def to_array_batch_script(
    dest: Path,
    fn: Callable[[Unpack[TArgs]], Any],
    args_list: Sequence[tuple[Unpack[TArgs]]],
    /,
    job_index_variable: str = "SLURM_ARRAY_TASK_ID",
    print_environment_info: bool = False,
    python_command_prefix: str | None = None,
    **kwargs: Unpack[SlurmJobKwargs],
) -> SubmitOutput:
    """
    Create the batch script for a Slurm job array: one array task per entry in ``args_list``.

    ``fn`` and the argument tuples are serialized into ``dest / "fns"``
    together with a helper script. That helper selects the call to run from
    the environment variable ``job_index_variable``; array indices start at 1.
    The sbatch script itself is written to ``dest / "launch.sh"``.

    Args:
        dest: Directory receiving ``launch.sh`` and ``fns/``. Default log files
            go to ``dest.parent / "logs"``.
        fn: The callable to run in each array task (positional-only).
        args_list: One positional-argument tuple per array task.
        job_index_variable: Environment variable holding the array task index.
        print_environment_info: Forwarded to the serialized command's bash
            command builder.
        python_command_prefix: Forwarded to the helper script as its
            ``command_prefix``.
        **kwargs: Slurm job options (see ``SlurmJobKwargs``).

    Returns:
        A ``SubmitOutput`` whose ``command_parts`` are
        ``["sbatch", <resolved path of launch.sh>]``.

    Raises:
        ValueError: If ``args_list`` is empty, or if the Slurm options are
            inconsistent (e.g. invalid GPU/task combinations).
    """
    num_jobs = len(args_list)
    if num_jobs == 0:
        # An empty array would produce "#SBATCH --array=1-0", which is not a
        # valid array range; fail early instead of writing a broken script.
        raise ValueError("`args_list` must contain at least one set of arguments.")

    from ...picklerunner import serialize_many

    kwargs = _update_kwargs(kwargs, dest)

    # Serialize the callable and its per-task arguments
    destdir = dest / "fns"
    destdir.mkdir(parents=True, exist_ok=True)

    serialized_command = serialize_many(
        destdir,
        fn,
        [(args, {}) for args in args_list],
        start_idx=1,  # Slurm job indices are 1-based
    )
    helper_path = write_helper_script(
        destdir,
        serialized_command.to_bash_command(
            job_index_variable, print_environment_info=print_environment_info
        ),
        kwargs.get("environment", {}),
        kwargs.get("setup_commands", []),
        command_prefix=python_command_prefix,
    )
    command = helper_script_to_command(helper_path, kwargs.get("command_template"))

    script_path = _write_batch_script_to_file(
        dest / "launch.sh",
        kwargs,
        command,
        job_array_n_jobs=num_jobs,
    )
    # resolve() already yields an absolute path.
    script_path = script_path.resolve()
    return SubmitOutput(
        command_parts=["sbatch", f"{script_path}"],
        script_path=script_path,
    )
|