nshtrainer 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,350 +0,0 @@
1
- import copy
2
- import logging
3
- import os
4
- import signal
5
- import subprocess
6
- from collections.abc import Callable, Mapping, Sequence
7
- from datetime import timedelta
8
- from pathlib import Path
9
- from typing import Any, Literal
10
-
11
- from typing_extensions import (
12
- TypeAlias,
13
- TypedDict,
14
- TypeVar,
15
- TypeVarTuple,
16
- Unpack,
17
- assert_never,
18
- )
19
-
20
- from . import lsf, slurm
21
- from ._output import SubmitOutput
22
-
23
# Variadic type parameter: the positional-argument types of the job callable
# accepted by `to_array_batch_script`.
TArgs = TypeVarTuple("TArgs")
# Anything accepted as a filesystem path by the job options below.
_Path: TypeAlias = str | Path | os.PathLike

# Module-level logger; used to warn about options a scheduler ignores.
log = logging.getLogger(__name__)
27
-
28
-
29
class GenericJobKwargs(TypedDict, total=False):
    """Scheduler-agnostic options for submitting a job.

    All keys are optional (``total=False``). The options are translated into
    Slurm options by ``_to_slurm`` and into LSF options by ``_to_lsf``;
    options a scheduler does not support may be ignored (for some of them a
    warning is logged).
    """

    name: str
    """The name of the job."""

    partition: str | Sequence[str]
    """
    The partition or queue to submit the job to. Same as `queue`.

    Set at most one of `partition` and `queue` (or give both the same value);
    conflicting values raise a `ValueError` during translation.
    """

    queue: str | Sequence[str]
    """The queue to submit the job to. Same as `partition`."""

    qos: str
    """
    The quality of service to submit the job to.

    This corresponds to the "--qos" option in sbatch (only for Slurm).
    """

    account: str
    """
    The account (or project) to charge the job to. Same as `project`.

    Set at most one of `account` and `project` (or give both the same value);
    conflicting values raise a `ValueError` during translation.
    """

    project: str
    """The project (or account) to charge the job to. Same as `account`."""

    output_file: _Path
    """
    The file to write the job output to.

    For LSF this corresponds to the "-o" option in bsub; for Slurm it is passed
    to the Slurm backend as `output_file`. If not specified, the output will be
    written to the default output file.
    """

    error_file: _Path
    """
    The file to write the job errors to.

    For LSF this corresponds to the "-e" option in bsub; for Slurm it is passed
    to the Slurm backend as `error_file`. If not specified, the errors will be
    written to the default error file.
    """

    nodes: int
    """The number of nodes to request."""

    tasks_per_node: int
    """
    The number of tasks to request per node.

    Passed as `ntasks_per_node` to Slurm and as `rs_per_node` to LSF.
    """

    cpus_per_task: int
    """
    The number of CPUs to request per task.

    Passed as `cpus_per_task` to Slurm and as `cpus_per_rs` to LSF.
    """

    gpus_per_task: int
    """
    The number of GPUs to request per task.

    Passed as `gpus_per_task` to Slurm and as `gpus_per_rs` to LSF.
    """

    memory_mb: int
    """The maximum memory for the job in MB."""

    walltime: timedelta
    """
    The maximum walltime for the job.

    Passed as `time` to Slurm and as `walltime` to LSF.
    """

    email: str
    """The email address to send notifications to (Slurm: `mail_user`)."""

    notifications: set[Literal["begin", "end"]]
    """The notifications to send via email: any of "begin" and "end"."""

    setup_commands: Sequence[str]
    """
    The setup commands to run before the job.

    These commands will be executed prior to everything else in the job script.
    """

    environment: Mapping[str, str]
    """
    The environment variables to set for the job.

    These variables will be set prior to executing any commands in the job script.
    """

    command_prefix: str
    """
    A command to prefix the job command with.

    This is used to add commands like `srun` or `jsrun` to the job command.
    """

    constraint: str | Sequence[str]
    """
    The constraint to request for the job. For Slurm, this corresponds to the
    `--constraint` option. For LSF, it is ignored and a warning is logged.
    """

    signal: signal.Signals
    """The signal that will be sent to the job when it is time to stop it."""

    command_template: str
    """
    The template for the command to execute the helper script.

    Default: `bash {script}`.

    NOTE(review): neither `_to_slurm` nor `_to_lsf` forwards this option;
    confirm where it is consumed.
    """

    requeue_on_preempt: bool
    """
    Whether to requeue the job if it is preempted.

    This corresponds to the "--requeue" option in sbatch (only for Slurm).
    LSF does not support requeueing; the option is ignored there.
    """

    slurm_options: slurm.SlurmJobKwargs
    """
    Additional keyword arguments for Slurm jobs.

    Applied last, so they override any option translated from the generic keys.
    """

    lsf_options: lsf.LSFJobKwargs
    """
    Additional keyword arguments for LSF jobs.

    Applied last, so they override any option translated from the generic keys.
    """
138
-
139
-
140
# The batch schedulers this module can generate job scripts for.
Scheduler: TypeAlias = Literal["slurm", "lsf"]


# Generic value type used by `_one_of`.
T = TypeVar("T", infer_variance=True)
144
-
145
-
146
def _one_of(*fns: Callable[[], T | None]) -> T | None:
    """Return the single value produced by ``fns``, ignoring ``None`` results.

    Each zero-argument callable is invoked once, in order. A callable that
    returns ``None`` counts as "not set". All remaining values must be equal
    to one another, so aliases such as ``account``/``project`` may both be
    given as long as they agree.

    Args:
        *fns: Zero-argument callables, each producing a candidate value or ``None``.

    Returns:
        The first non-``None`` value, or ``None`` if every callable returned ``None``.

    Raises:
        ValueError: If two non-``None`` values are not equal.
    """
    values = [value for fn in fns if (value := fn()) is not None]
    if not values:
        return None

    first = values[0]
    # Compare by equality instead of building a ``set``: options such as
    # ``partition``/``queue`` may be lists, which are unhashable.
    if any(value != first for value in values[1:]):
        raise ValueError(f"Multiple values set: {values}")
    return first
154
-
155
-
156
def _to_slurm(kwargs: GenericJobKwargs) -> slurm.SlurmJobKwargs:
    """Translate scheduler-agnostic job options into Slurm job options.

    Only keys whose value is not ``None`` are forwarded. Generic names are
    mapped onto the Slurm backend's names (e.g. ``walltime`` -> ``time``,
    ``tasks_per_node`` -> ``ntasks_per_node``, ``email`` -> ``mail_user``,
    ``requeue_on_preempt`` -> ``requeue``). Entries in ``slurm_options`` are
    applied last and override any translated value.

    Args:
        kwargs: The generic job options. It is only read, never modified.

    Returns:
        A new ``SlurmJobKwargs`` dict.

    Raises:
        ValueError: If ``account``/``project`` or ``partition``/``queue`` are
            given with conflicting values, or if ``notifications`` contains an
            entry other than ``"begin"`` or ``"end"``.
    """
    slurm_kwargs: slurm.SlurmJobKwargs = {}
    if (name := kwargs.get("name")) is not None:
        slurm_kwargs["name"] = name
    if (
        account := _one_of(
            lambda: kwargs.get("account"),
            lambda: kwargs.get("project"),
        )
    ) is not None:
        slurm_kwargs["account"] = account
    if (
        partition := _one_of(
            lambda: kwargs.get("partition"),
            lambda: kwargs.get("queue"),
        )
    ) is not None:
        slurm_kwargs["partition"] = partition
    if (qos := kwargs.get("qos")) is not None:
        slurm_kwargs["qos"] = qos
    if (output_file := kwargs.get("output_file")) is not None:
        slurm_kwargs["output_file"] = output_file
    if (error_file := kwargs.get("error_file")) is not None:
        slurm_kwargs["error_file"] = error_file
    if (walltime := kwargs.get("walltime")) is not None:
        slurm_kwargs["time"] = walltime
    if (memory_mb := kwargs.get("memory_mb")) is not None:
        slurm_kwargs["memory_mb"] = memory_mb
    if (nodes := kwargs.get("nodes")) is not None:
        slurm_kwargs["nodes"] = nodes
    if (tasks_per_node := kwargs.get("tasks_per_node")) is not None:
        slurm_kwargs["ntasks_per_node"] = tasks_per_node
    if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
        slurm_kwargs["cpus_per_task"] = cpus_per_task
    if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
        slurm_kwargs["gpus_per_task"] = gpus_per_task
    if (constraint := kwargs.get("constraint")) is not None:
        slurm_kwargs["constraint"] = constraint
    # Named ``job_signal`` so the ``signal`` module is not shadowed in this scope.
    if (job_signal := kwargs.get("signal")) is not None:
        slurm_kwargs["signal"] = job_signal
    if (email := kwargs.get("email")) is not None:
        slurm_kwargs["mail_user"] = email
    if (notifications := kwargs.get("notifications")) is not None:
        for notification in notifications:
            if notification not in ("begin", "end"):
                raise ValueError(f"Unknown notification type: {notification}")
        # Emit mail types in a fixed order. Iterating the ``set`` directly gives
        # a hash-seed-dependent order, which would make generated scripts
        # differ between otherwise identical runs.
        mail_type: list[slurm.MailType] = []
        if "begin" in notifications:
            mail_type.append("BEGIN")
        if "end" in notifications:
            mail_type.append("END")
        slurm_kwargs["mail_type"] = mail_type
    if (setup_commands := kwargs.get("setup_commands")) is not None:
        slurm_kwargs["setup_commands"] = setup_commands
    if (environment := kwargs.get("environment")) is not None:
        slurm_kwargs["environment"] = environment
    if (command_prefix := kwargs.get("command_prefix")) is not None:
        slurm_kwargs["command_prefix"] = command_prefix
    if (requeue_on_preempt := kwargs.get("requeue_on_preempt")) is not None:
        slurm_kwargs["requeue"] = requeue_on_preempt
    # NOTE(review): ``command_template`` is not forwarded here; confirm intended.
    if (additional_kwargs := kwargs.get("slurm_options")) is not None:
        # Applied last so explicit Slurm options win over translated ones.
        slurm_kwargs.update(additional_kwargs)

    return slurm_kwargs
221
-
222
-
223
def _to_lsf(kwargs: GenericJobKwargs) -> lsf.LSFJobKwargs:
    """Translate scheduler-agnostic job options into LSF job options.

    Only keys whose value is not ``None`` are forwarded. Generic names are
    mapped onto the LSF backend's names (e.g. ``account`` -> ``project``,
    ``partition`` -> ``queue``, ``tasks_per_node`` -> ``rs_per_node``).
    Options LSF cannot express (``qos``, ``constraint`` and an enabled
    ``requeue_on_preempt``) are dropped with a logged warning. Entries in
    ``lsf_options`` are applied last and override any translated value.

    Args:
        kwargs: The generic job options. It is only read, never modified.

    Returns:
        A new ``LSFJobKwargs`` dict.

    Raises:
        ValueError: If ``account``/``project`` or ``partition``/``queue`` are
            given with conflicting values, or if ``notifications`` contains an
            entry other than ``"begin"`` or ``"end"`` (consistent with Slurm).
    """
    lsf_kwargs: lsf.LSFJobKwargs = {}
    if (name := kwargs.get("name")) is not None:
        lsf_kwargs["name"] = name
    if (
        account := _one_of(
            lambda: kwargs.get("account"),
            lambda: kwargs.get("project"),
        )
    ) is not None:
        lsf_kwargs["project"] = account
    if (
        partition := _one_of(
            lambda: kwargs.get("partition"),
            lambda: kwargs.get("queue"),
        )
    ) is not None:
        lsf_kwargs["queue"] = partition
    # Previously dropped silently; warn like the other unsupported options.
    if (qos := kwargs.get("qos")) is not None:
        log.warning(f'LSF does not support QoS, ignoring "{qos=}".')
    if (output_file := kwargs.get("output_file")) is not None:
        lsf_kwargs["output_file"] = output_file
    if (error_file := kwargs.get("error_file")) is not None:
        lsf_kwargs["error_file"] = error_file
    if (walltime := kwargs.get("walltime")) is not None:
        lsf_kwargs["walltime"] = walltime
    if (memory_mb := kwargs.get("memory_mb")) is not None:
        lsf_kwargs["memory_mb"] = memory_mb
    if (nodes := kwargs.get("nodes")) is not None:
        lsf_kwargs["nodes"] = nodes
    if (tasks_per_node := kwargs.get("tasks_per_node")) is not None:
        lsf_kwargs["rs_per_node"] = tasks_per_node
    if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
        lsf_kwargs["cpus_per_rs"] = cpus_per_task
    if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
        lsf_kwargs["gpus_per_rs"] = gpus_per_task
    if (constraint := kwargs.get("constraint")) is not None:
        log.warning(f'LSF does not support constraints, ignoring "{constraint=}".')
    if (email := kwargs.get("email")) is not None:
        lsf_kwargs["email"] = email
    if (notifications := kwargs.get("notifications")) is not None:
        # Reject unknown entries instead of silently ignoring them, matching
        # the Slurm translation.
        for notification in notifications:
            if notification not in ("begin", "end"):
                raise ValueError(f"Unknown notification type: {notification}")
        if "begin" in notifications:
            lsf_kwargs["notify_begin"] = True
        if "end" in notifications:
            lsf_kwargs["notify_end"] = True
    if (setup_commands := kwargs.get("setup_commands")) is not None:
        lsf_kwargs["setup_commands"] = setup_commands
    if (environment := kwargs.get("environment")) is not None:
        lsf_kwargs["environment"] = environment
    if (command_prefix := kwargs.get("command_prefix")) is not None:
        lsf_kwargs["command_prefix"] = command_prefix
    # Named ``job_signal`` so the ``signal`` module is not shadowed in this scope.
    if (job_signal := kwargs.get("signal")) is not None:
        lsf_kwargs["signal"] = job_signal
    # Only warn when requeueing is actually requested; an explicit ``False``
    # asks for nothing that is being dropped.
    if requeue_on_preempt := kwargs.get("requeue_on_preempt"):
        log.warning(
            f'LSF does not support requeueing, ignoring "{requeue_on_preempt=}".'
        )
    # NOTE(review): ``command_template`` is not forwarded here; confirm intended.
    if (additional_kwargs := kwargs.get("lsf_options")) is not None:
        # Applied last so explicit LSF options win over translated ones.
        lsf_kwargs.update(additional_kwargs)

    return lsf_kwargs
282
-
283
-
284
def validate_kwargs(scheduler: Scheduler, kwargs: GenericJobKwargs) -> None:
    """Check that ``kwargs`` can be translated into options for ``scheduler``.

    The same translation used at submission time is run on a deep copy of
    ``kwargs`` and its result is discarded, so the caller's mapping is left
    untouched. Any exception raised by the translation (for example the
    ``ValueError`` for conflicting aliases) propagates to the caller; the
    translation may also log warnings.
    """
    if scheduler == "slurm":
        translate = _to_slurm
    elif scheduler == "lsf":
        translate = _to_lsf
    else:
        assert_never(scheduler)

    # Only the side effect of validating matters; the translated options are dropped.
    translate(copy.deepcopy(kwargs))
292
-
293
-
294
def to_array_batch_script(
    scheduler: Scheduler,
    dest: Path,
    fn: Callable[[Unpack[TArgs]], Any],
    args_list: Sequence[tuple[Unpack[TArgs]]],
    /,
    job_index_variable: str | None = None,
    print_environment_info: bool = False,
    python_command_prefix: str | None = None,
    **kwargs: Unpack[GenericJobKwargs],
) -> SubmitOutput:
    """Build an array batch script using the backend for ``scheduler``.

    The generic ``kwargs`` are translated with ``_to_slurm`` or ``_to_lsf`` and
    forwarded, together with ``dest``, ``fn``, ``args_list`` and the remaining
    keyword arguments, to ``slurm.to_array_batch_script`` or
    ``lsf.to_array_batch_script``; that call's result is returned unchanged.
    ``job_index_variable`` is forwarded only when it is not ``None``
    (presumably so the backend's own default applies otherwise).
    """
    common: dict[str, Any] = {}
    if job_index_variable is not None:
        common["job_index_variable"] = job_index_variable
    common["print_environment_info"] = print_environment_info
    common["python_command_prefix"] = python_command_prefix

    if scheduler == "slurm":
        return slurm.to_array_batch_script(
            dest, fn, args_list, **common, **_to_slurm(kwargs)
        )
    if scheduler == "lsf":
        return lsf.to_array_batch_script(
            dest, fn, args_list, **common, **_to_lsf(kwargs)
        )
    assert_never(scheduler)
333
-
334
-
335
def _command_succeeds(argv: list[str]) -> bool:
    """Return True if ``argv`` can be executed and exits with status 0."""
    try:
        subprocess.run(
            argv,
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing or non-executable program (e.g.
        # FileNotFoundError); SubprocessError covers a non-zero exit status.
        # Anything else (notably KeyboardInterrupt/SystemExit) propagates.
        return False
    return True


def infer_current_scheduler() -> Scheduler:
    """Detect which batch scheduler is available on this machine.

    ``bsub -V`` is probed first, because LSF is much less common than Slurm;
    if it cannot be run successfully, ``sbatch --version`` is probed next.
    Probe output is discarded.

    Returns:
        ``"lsf"`` if the ``bsub`` probe succeeds, otherwise ``"slurm"`` if the
        ``sbatch`` probe succeeds.

    Raises:
        RuntimeError: If neither probe command runs successfully.
    """
    if _command_succeeds(["bsub", "-V"]):
        return "lsf"
    if _command_succeeds(["sbatch", "--version"]):
        return "slurm"
    raise RuntimeError("Could not determine the current scheduler.")
@@ -1,89 +0,0 @@
1
- from logging import getLogger
2
- from typing import Any
3
-
4
- from typing_extensions import Self, TypeVar, override
5
-
6
# Module-level logger; Singleton.set uses it to warn when an instance is replaced.
log = getLogger(__name__)
7
-
8
-
9
- class Singleton:
10
- singleton_key = "_singleton_instance"
11
-
12
- @classmethod
13
- def get(cls) -> Self | None:
14
- return getattr(cls, cls.singleton_key, None)
15
-
16
- @classmethod
17
- def set(cls, instance: Self) -> None:
18
- if cls.get() is not None:
19
- log.warning(f"{cls.__qualname__} instance is already set")
20
-
21
- setattr(cls, cls.singleton_key, instance)
22
-
23
- @classmethod
24
- def reset(cls) -> None:
25
- if cls.get() is not None:
26
- delattr(cls, cls.singleton_key)
27
-
28
- @classmethod
29
- def register(cls, instance: Self) -> None:
30
- cls.set(instance)
31
-
32
- def register_self(self):
33
- self.register(self)
34
-
35
- @classmethod
36
- def instance(cls) -> Self:
37
- instance = cls.get()
38
- if instance is None:
39
- raise RuntimeError(f"{cls.__qualname__} instance is not set")
40
-
41
- return instance
42
-
43
- @override
44
- def __init_subclass__(cls, *args, **kwargs) -> None:
45
- super().__init_subclass__(*args, **kwargs)
46
-
47
- cls.reset()
48
-
49
-
50
# Ties a class to the instance registered for it in Registry's method signatures.
T = TypeVar("T", infer_variance=True)
51
-
52
-
53
class Registry:
    """Global mapping from a class to the single instance registered for it.

    All state lives in the class-level ``_registry`` dict, which every
    static method reads and writes directly. No locking is performed.
    """

    _registry: dict[type, Any] = {}

    @staticmethod
    def register(cls_: type[T], instance: T):
        """Register ``instance`` as the instance for ``cls_``.

        Raises:
            ValueError: If ``instance`` is not an instance of ``cls_``, or if
                ``cls_`` already has a registered instance.
        """
        table = Registry._registry
        if not isinstance(instance, cls_):
            raise ValueError(f"{instance} is not an instance of {cls_.__qualname__}")
        if cls_ in table:
            raise ValueError(f"{cls_.__qualname__} is already registered")
        table[cls_] = instance

    @staticmethod
    def try_get(cls_: type[T]) -> T | None:
        """Return the instance registered for ``cls_``, or ``None``."""
        return Registry._registry.get(cls_, None)

    @staticmethod
    def get(cls_: type[T]) -> T:
        """Return the instance registered for ``cls_``.

        Raises:
            ValueError: If nothing is registered for ``cls_``.
        """
        found = Registry.try_get(cls_)
        if found is not None:
            return found
        raise ValueError(f"{cls_.__qualname__} is not registered")

    @staticmethod
    def instance(cls_: type[T]) -> T:
        """Alias of :meth:`get`."""
        return Registry.get(cls_)

    @staticmethod
    def reset(cls_: type[T]):
        """Forget the instance registered for ``cls_``; a no-op if there is none."""
        Registry._registry.pop(cls_, None)

    @staticmethod
    def reset_all():
        """Forget every registered instance."""
        Registry._registry.clear()