nshtrainer 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,350 +0,0 @@
1
- import copy
2
- import logging
3
- import os
4
- import signal
5
- import subprocess
6
- from collections.abc import Callable, Mapping, Sequence
7
- from datetime import timedelta
8
- from pathlib import Path
9
- from typing import Any, Literal
10
-
11
- from typing_extensions import (
12
- TypeAlias,
13
- TypedDict,
14
- TypeVar,
15
- TypeVarTuple,
16
- Unpack,
17
- assert_never,
18
- )
19
-
20
- from . import lsf, slurm
21
- from ._output import SubmitOutput
22
-
23
# Tuple of argument types forwarded to the submitted callable.
TArgs = TypeVarTuple("TArgs")
# Anything accepted where a filesystem path is expected.
_Path: TypeAlias = str | Path | os.PathLike

log = logging.getLogger(__name__)
27
-
28
-
29
class GenericJobKwargs(TypedDict, total=False):
    """Scheduler-agnostic job options.

    These options are translated into scheduler-specific keyword arguments
    by `_to_slurm` and `_to_lsf`. All keys are optional (`total=False`);
    unset keys fall back to the target scheduler's defaults.
    """

    name: str
    """The name of the job."""

    partition: str | Sequence[str]
    """The partition or queue to submit the job to. Same as `queue`."""

    queue: str | Sequence[str]
    """The queue to submit the job to. Same as `partition`."""

    qos: str
    """
    The quality of service to submit the job to.

    This corresponds to the "--qos" option in sbatch (only for Slurm).
    """

    account: str
    """The account (or project) to charge the job to. Same as `project`."""

    project: str
    """The project (or account) to charge the job to. Same as `account`."""

    output_file: _Path
    """
    The file to write the job output to.

    This corresponds to the "-o" option in bsub. If not specified, the output will be written to the default output file.
    """

    error_file: _Path
    """
    The file to write the job errors to.

    This corresponds to the "-e" option in bsub. If not specified, the errors will be written to the default error file.
    """

    nodes: int
    """The number of nodes to request."""

    tasks_per_node: int
    """The number of tasks to request per node."""

    cpus_per_task: int
    """The number of CPUs to request per task."""

    gpus_per_task: int
    """The number of GPUs to request per task."""

    memory_mb: int
    """The maximum memory for the job in MB."""

    walltime: timedelta
    """The maximum walltime for the job."""

    email: str
    """The email address to send notifications to."""

    notifications: set[Literal["begin", "end"]]
    """The notifications to send via email."""

    setup_commands: Sequence[str]
    """
    The setup commands to run before the job.

    These commands will be executed prior to everything else in the job script.
    """

    environment: Mapping[str, str]
    """
    The environment variables to set for the job.

    These variables will be set prior to executing any commands in the job script.
    """

    command_prefix: str
    """
    A command to prefix the job command with.

    This is used to add commands like `srun` or `jsrun` to the job command.
    """

    constraint: str | Sequence[str]
    """
    The constraint to request for the job. For SLURM, this corresponds to the `--constraint` option. For LSF, this is unused.
    """

    signal: signal.Signals
    """The signal that will be sent to the job when it is time to stop it."""

    command_template: str
    """
    The template for the command to execute the helper script.

    Default: `bash {script}`.
    """

    requeue_on_preempt: bool
    """
    Whether to requeue the job if it is preempted.

    This corresponds to the "--requeue" option in sbatch (only for Slurm).
    """

    slurm_options: slurm.SlurmJobKwargs
    """Additional keyword arguments for Slurm jobs."""

    lsf_options: lsf.LSFJobKwargs
    """Additional keyword arguments for LSF jobs."""
138
-
139
-
140
# The job schedulers this module can target.
Scheduler: TypeAlias = Literal["slurm", "lsf"]


# Result type of `_one_of`; variance is inferred by the type checker.
T = TypeVar("T", infer_variance=True)
144
-
145
-
146
- def _one_of(*fns: Callable[[], T | None]) -> T | None:
147
- values = [value for fn in fns if (value := fn()) is not None]
148
-
149
- # Only one (or zero) value should be set. If not, raise an error.
150
- if len(set(values)) > 1:
151
- raise ValueError(f"Multiple values set: {values}")
152
-
153
- return next((value for value in values if value is not None), None)
154
-
155
-
156
- def _to_slurm(kwargs: GenericJobKwargs) -> slurm.SlurmJobKwargs:
157
- slurm_kwargs: slurm.SlurmJobKwargs = {}
158
- if (name := kwargs.get("name")) is not None:
159
- slurm_kwargs["name"] = name
160
- if (
161
- account := _one_of(
162
- lambda: kwargs.get("account"),
163
- lambda: kwargs.get("project"),
164
- )
165
- ) is not None:
166
- slurm_kwargs["account"] = account
167
- if (
168
- partition := _one_of(
169
- lambda: kwargs.get("partition"),
170
- lambda: kwargs.get("queue"),
171
- )
172
- ) is not None:
173
- slurm_kwargs["partition"] = partition
174
- if (qos := kwargs.get("qos")) is not None:
175
- slurm_kwargs["qos"] = qos
176
- if (output_file := kwargs.get("output_file")) is not None:
177
- slurm_kwargs["output_file"] = output_file
178
- if (error_file := kwargs.get("error_file")) is not None:
179
- slurm_kwargs["error_file"] = error_file
180
- if (walltime := kwargs.get("walltime")) is not None:
181
- slurm_kwargs["time"] = walltime
182
- if (memory_mb := kwargs.get("memory_mb")) is not None:
183
- slurm_kwargs["memory_mb"] = memory_mb
184
- if (nodes := kwargs.get("nodes")) is not None:
185
- slurm_kwargs["nodes"] = nodes
186
- if (tasks_per_node := kwargs.get("tasks_per_node")) is not None:
187
- slurm_kwargs["ntasks_per_node"] = tasks_per_node
188
- if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
189
- slurm_kwargs["cpus_per_task"] = cpus_per_task
190
- if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
191
- slurm_kwargs["gpus_per_task"] = gpus_per_task
192
- if (constraint := kwargs.get("constraint")) is not None:
193
- slurm_kwargs["constraint"] = constraint
194
- if (signal := kwargs.get("signal")) is not None:
195
- slurm_kwargs["signal"] = signal
196
- if (email := kwargs.get("email")) is not None:
197
- slurm_kwargs["mail_user"] = email
198
- if (notifications := kwargs.get("notifications")) is not None:
199
- mail_type: list[slurm.MailType] = []
200
- for notification in notifications:
201
- match notification:
202
- case "begin":
203
- mail_type.append("BEGIN")
204
- case "end":
205
- mail_type.append("END")
206
- case _:
207
- raise ValueError(f"Unknown notification type: {notification}")
208
- slurm_kwargs["mail_type"] = mail_type
209
- if (setup_commands := kwargs.get("setup_commands")) is not None:
210
- slurm_kwargs["setup_commands"] = setup_commands
211
- if (environment := kwargs.get("environment")) is not None:
212
- slurm_kwargs["environment"] = environment
213
- if (command_prefix := kwargs.get("command_prefix")) is not None:
214
- slurm_kwargs["command_prefix"] = command_prefix
215
- if (requeue_on_preempt := kwargs.get("requeue_on_preempt")) is not None:
216
- slurm_kwargs["requeue"] = requeue_on_preempt
217
- if (additional_kwargs := kwargs.get("slurm_options")) is not None:
218
- slurm_kwargs.update(additional_kwargs)
219
-
220
- return slurm_kwargs
221
-
222
-
223
def _to_lsf(kwargs: GenericJobKwargs) -> lsf.LSFJobKwargs:
    """Translate generic job options into LSF-specific keyword arguments.

    `account`/`project` and `partition`/`queue` are alias pairs resolved via
    `_one_of` (raises `ValueError` on conflicting values). Options LSF does
    not support (`constraint`, `requeue_on_preempt`) are logged and dropped.
    Anything in `lsf_options` is applied last and overrides the rest.
    """
    lsf_kwargs: lsf.LSFJobKwargs = {}
    if (name := kwargs.get("name")) is not None:
        lsf_kwargs["name"] = name
    if (
        account := _one_of(
            lambda: kwargs.get("account"),
            lambda: kwargs.get("project"),
        )
    ) is not None:
        lsf_kwargs["project"] = account
    if (
        partition := _one_of(
            lambda: kwargs.get("partition"),
            lambda: kwargs.get("queue"),
        )
    ) is not None:
        lsf_kwargs["queue"] = partition
    if (output_file := kwargs.get("output_file")) is not None:
        lsf_kwargs["output_file"] = output_file
    if (error_file := kwargs.get("error_file")) is not None:
        lsf_kwargs["error_file"] = error_file
    if (walltime := kwargs.get("walltime")) is not None:
        lsf_kwargs["walltime"] = walltime
    if (memory_mb := kwargs.get("memory_mb")) is not None:
        lsf_kwargs["memory_mb"] = memory_mb
    if (nodes := kwargs.get("nodes")) is not None:
        lsf_kwargs["nodes"] = nodes
    # LSF models resources as "resource sets" (rs) rather than tasks.
    if (tasks_per_node := kwargs.get("tasks_per_node")) is not None:
        lsf_kwargs["rs_per_node"] = tasks_per_node
    if (cpus_per_task := kwargs.get("cpus_per_task")) is not None:
        lsf_kwargs["cpus_per_rs"] = cpus_per_task
    if (gpus_per_task := kwargs.get("gpus_per_task")) is not None:
        lsf_kwargs["gpus_per_rs"] = gpus_per_task
    if (constraint := kwargs.get("constraint")) is not None:
        log.warning(f'LSF does not support constraints, ignoring "{constraint=}".')
    if (email := kwargs.get("email")) is not None:
        lsf_kwargs["email"] = email
    if (notifications := kwargs.get("notifications")) is not None:
        if "begin" in notifications:
            lsf_kwargs["notify_begin"] = True
        if "end" in notifications:
            lsf_kwargs["notify_end"] = True
    if (setup_commands := kwargs.get("setup_commands")) is not None:
        lsf_kwargs["setup_commands"] = setup_commands
    if (environment := kwargs.get("environment")) is not None:
        lsf_kwargs["environment"] = environment
    if (command_prefix := kwargs.get("command_prefix")) is not None:
        lsf_kwargs["command_prefix"] = command_prefix
    # Named `stop_signal` (not `signal`) so the module-level `import signal`
    # is not shadowed inside this function.
    if (stop_signal := kwargs.get("signal")) is not None:
        lsf_kwargs["signal"] = stop_signal
    if (requeue_on_preempt := kwargs.get("requeue_on_preempt")) is not None:
        log.warning(
            f'LSF does not support requeueing, ignoring "{requeue_on_preempt=}".'
        )
    # NOTE(review): `qos` is silently dropped for LSF — confirm this is intended.
    # User-supplied scheduler-specific options win over the translated ones.
    if (additional_kwargs := kwargs.get("lsf_options")) is not None:
        lsf_kwargs.update(additional_kwargs)

    return lsf_kwargs
282
-
283
-
284
def validate_kwargs(scheduler: Scheduler, kwargs: GenericJobKwargs) -> None:
    """Eagerly run the scheduler-specific translation to surface bad options.

    The translation is performed on a deep copy so `kwargs` is never mutated;
    the converted result is discarded — only the side effect of raising on
    invalid option combinations matters.
    """
    if scheduler == "slurm":
        _to_slurm(copy.deepcopy(kwargs))
    elif scheduler == "lsf":
        _to_lsf(copy.deepcopy(kwargs))
    else:
        assert_never(scheduler)
292
-
293
-
294
- def to_array_batch_script(
295
- scheduler: Scheduler,
296
- dest: Path,
297
- callable: Callable[[Unpack[TArgs]], Any],
298
- args_list: Sequence[tuple[Unpack[TArgs]]],
299
- /,
300
- job_index_variable: str | None = None,
301
- print_environment_info: bool = False,
302
- python_command_prefix: str | None = None,
303
- **kwargs: Unpack[GenericJobKwargs],
304
- ) -> SubmitOutput:
305
- job_index_variable_kwargs = {}
306
- if job_index_variable is not None:
307
- job_index_variable_kwargs["job_index_variable"] = job_index_variable
308
- match scheduler:
309
- case "slurm":
310
- slurm_kwargs = _to_slurm(kwargs)
311
- return slurm.to_array_batch_script(
312
- dest,
313
- callable,
314
- args_list,
315
- **job_index_variable_kwargs,
316
- print_environment_info=print_environment_info,
317
- python_command_prefix=python_command_prefix,
318
- **slurm_kwargs,
319
- )
320
- case "lsf":
321
- lsf_kwargs = _to_lsf(kwargs)
322
- return lsf.to_array_batch_script(
323
- dest,
324
- callable,
325
- args_list,
326
- **job_index_variable_kwargs,
327
- print_environment_info=print_environment_info,
328
- python_command_prefix=python_command_prefix,
329
- **lsf_kwargs,
330
- )
331
- case _:
332
- assert_never(scheduler)
333
-
334
-
335
def infer_current_scheduler() -> "Scheduler":
    """Detect which job scheduler is available on this machine.

    Probes for `bsub` (LSF) before `sbatch` (Slurm) because `bsub` is far
    less common: a machine with both is assumed to be an LSF system.

    Returns:
        "lsf" or "slurm".

    Raises:
        RuntimeError: If neither `bsub` nor `sbatch` can be executed.
    """
    # Catch only what "command missing/broken" can raise. The previous
    # `except BaseException` also swallowed KeyboardInterrupt/SystemExit,
    # which should propagate to the caller.
    try:
        subprocess.check_output(["bsub", "-V"])
        return "lsf"
    except (OSError, subprocess.CalledProcessError):
        pass

    try:
        subprocess.check_output(["sbatch", "--version"])
        return "slurm"
    except (OSError, subprocess.CalledProcessError):
        pass

    raise RuntimeError("Could not determine the current scheduler.")
nshtrainer/config.py DELETED
@@ -1,289 +0,0 @@
1
- from collections.abc import Mapping, MutableMapping
2
- from typing import TYPE_CHECKING, Any, ClassVar
3
-
4
- from pydantic import BaseModel, ConfigDict
5
- from pydantic import Field as Field
6
- from pydantic import PrivateAttr as PrivateAttr
7
- from typing_extensions import deprecated, override
8
-
9
- from ._config.missing import MISSING, validate_no_missing_values
10
- from ._config.missing import AllowMissing as AllowMissing
11
- from ._config.missing import MissingField as MissingField
12
-
13
# At runtime configs also act as a `MutableMapping[str, Any]` (see the
# MutableMapping region in `TypedConfig`). For static type checkers we swap
# the base to plain `object` so the abstract-method requirements of
# `MutableMapping` are not imposed on every config subclass.
_MutableMappingBase = MutableMapping[str, Any]
if TYPE_CHECKING:
    _MutableMappingBase = object


# Sentinel passed as the `model_post_init` context when constructing a draft
# config, so post-init hooks and missing-value validation are skipped.
_DraftConfigContextSentinel = object()
19
-
20
-
21
class TypedConfig(BaseModel, _MutableMappingBase):
    """Base class for strictly-validated configuration objects.

    Built on Pydantic's `BaseModel` with strict validation enabled, plus:

    - A "draft config" workflow (`draft()` / `finalize()`) that lets callers
      build a config incrementally before validating it in one shot.
    - A runtime `MutableMapping[str, Any]` interface (see the region at the
      bottom) so configs can be used with Lightning's hparams machinery.
    - An optional diff-only `repr` controlled by `repr_diff_only`.
    """

    # Private flag set by `model_construct_draft`; drafts skip `__post_init__`
    # and missing-value validation until `finalize()` is called.
    _is_draft_config: bool = PrivateAttr(default=False)
    """
    Whether this config is a draft config or not.

    Draft configs are configs that are not yet fully validated.
    They allow for a nicer API when creating configs, e.g.:

    ```python
    config = MyConfig.draft()

    # Set some values
    config.a = 10
    config.b = "hello"

    # Finalize the config
    config = config.finalize()
    ```
    """

    repr_diff_only: ClassVar[bool] = True
    """
    If `True`, the repr methods will only show values for fields that are different from the default.
    """

    MISSING: ClassVar[Any] = MISSING
    """
    Alias for the `MISSING` constant.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(
        # By default, Pydantic will throw a warning if a field starts with "model_",
        # so we need to disable that warning (because "model_" is a popular prefix for ML).
        protected_namespaces=(),
        validate_assignment=True,
        validate_return=True,
        validate_default=True,
        strict=True,
        revalidate_instances="always",
        arbitrary_types_allowed=True,
        extra="ignore",
        validation_error_cause=True,
        use_attribute_docstrings=True,
    )

    def __draft_pre_init__(self):
        """Hook called right before a draft config is finalized. Override to
        mutate the draft one last time; the default does nothing."""
        pass

    def __post_init__(self):
        """Hook called after the final (non-draft) config is validated.
        The default does nothing."""
        pass

    @classmethod
    @deprecated("Use `model_validate` instead.")
    def from_dict(cls, model_dict: Mapping[str, Any]):
        """Deprecated alias for `model_validate`."""
        return cls.model_validate(model_dict)

    def model_deep_validate(self, strict: bool = True):
        """
        Validate the config and all of its sub-configs.

        Implemented as a round-trip: dump to a dict, then re-validate the
        whole tree from scratch.

        Args:
            strict: Whether to validate the config strictly.

        Returns:
            The freshly re-validated config instance.

        Raises:
            ValueError: If the validated config is still a draft
                (call `finalize` first).
        """
        config_dict = self.model_dump(round_trip=True)
        config = self.model_validate(config_dict, strict=strict)

        # Make sure that this is not a draft config
        if config._is_draft_config:
            raise ValueError("Draft configs are not valid. Call `finalize` first.")

        return config

    @classmethod
    def draft(cls, **kwargs):
        """Create an unvalidated draft of this config; see `finalize`."""
        config = cls.model_construct_draft(**kwargs)
        return config

    def finalize(self, strict: bool = True):
        """Validate a draft config and return the finalized instance."""
        # This must be a draft config, otherwise we raise an error
        if not self._is_draft_config:
            raise ValueError("Finalize can only be called on drafts.")

        # First, we call `__draft_pre_init__` to allow the config to modify itself a final time
        self.__draft_pre_init__()

        # Then, we dump the config to a dict and then re-validate it
        return self.model_deep_validate(strict=strict)

    @override
    def model_post_init(self, __context: Any) -> None:
        super().model_post_init(__context)

        # Call the `__post_init__` method if this is not a draft config.
        # Drafts are detected via the sentinel context passed in by
        # `model_construct_draft`; they skip post-init and validation.
        if __context is _DraftConfigContextSentinel:
            return

        self.__post_init__()

        # After `__post_init__` is called, we perform the final round of validation
        self.model_post_init_validate()

    def model_post_init_validate(self):
        # Raise if any field is still set to the MISSING sentinel.
        validate_no_missing_values(self)

    @classmethod
    def model_construct_draft(cls, _fields_set: set[str] | None = None, **values: Any):
        """
        NOTE: This is a copy of the `model_construct` method from Pydantic's `Model` class,
        with the following changes:
        - The `model_post_init` method is called with the `_DraftConfigContext` context.
        - The `_is_draft_config` attribute is set to `True` in the `values` dict.

        Creates a new instance of the `Model` class with validated data.

        Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
        Default values are respected, but no other validation is performed.

        !!! note
            `model_construct()` generally respects the `model_config.extra` setting on the provided model.
            That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
            and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
            Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
            an error if extra values are passed, but they will be ignored.

        Args:
            _fields_set: The set of field names accepted for the Model instance.
            values: Trusted or pre-validated data dictionary.

        Returns:
            A new instance of the `Model` class with validated data.
        """

        # Mark the instance as a draft before construction so the private
        # attribute is populated below.
        values["_is_draft_config"] = True

        m = cls.__new__(cls)
        fields_values: dict[str, Any] = {}
        fields_set = set()

        # Populate field values from `values` (alias first, then name),
        # falling back to defaults for non-required fields.
        for name, field in cls.model_fields.items():
            if field.alias and field.alias in values:
                fields_values[name] = values.pop(field.alias)
                fields_set.add(name)
            elif name in values:
                fields_values[name] = values.pop(name)
                fields_set.add(name)
            elif not field.is_required():
                fields_values[name] = field.get_default(call_default_factory=True)
        if _fields_set is None:
            _fields_set = fields_set

        # Leftover values become extras only when the model allows them.
        _extra: dict[str, Any] | None = None
        if cls.model_config.get("extra") == "allow":
            _extra = {}
            for k, v in values.items():
                _extra[k] = v
        # Bypass validation by writing Pydantic's internal state directly.
        object.__setattr__(m, "__dict__", fields_values)
        object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
        if not cls.__pydantic_root_model__:
            object.__setattr__(m, "__pydantic_extra__", _extra)

        if cls.__pydantic_post_init__:
            # The sentinel context tells our `model_post_init` override to
            # skip `__post_init__` and missing-value validation for drafts.
            m.model_post_init(_DraftConfigContextSentinel)
            # update private attributes with values set
            if (
                hasattr(m, "__pydantic_private__")
                and m.__pydantic_private__ is not None
            ):
                for k, v in values.items():
                    if k in m.__private_attributes__:
                        m.__pydantic_private__[k] = v

        elif not cls.__pydantic_root_model__:
            # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
            # Since it doesn't, that means that `__pydantic_private__` should be set to None
            object.__setattr__(m, "__pydantic_private__", None)

        return m

    @override
    def __repr_args__(self):
        # If `repr_diff_only` is `True`, we only show the fields that are different from the default.
        if not self.repr_diff_only:
            yield from super().__repr_args__()
            return

        # First, we get the default values for all fields.
        default_values = self.model_construct_draft()

        # Then, we compare the default values with the current values.
        for k, v in super().__repr_args__():
            if k is None:
                yield k, v
                continue

            # If there is no default value or the value is different from the default, we yield it.
            if not hasattr(default_values, k) or getattr(default_values, k) != v:
                yield k, v
                continue

            # Otherwise, we can skip this field.

    # region MutableMapping implementation
    if not TYPE_CHECKING:
        # This is mainly so the config can be used with lightning's hparams
        # transparently and without any issues.

        @property
        def _ll_dict(self):
            # Fresh plain-dict snapshot of the model, recomputed per access.
            return self.model_dump()

        # We need to make sure every config class
        # is a MutableMapping[str, Any] so that it can be used
        # with lightning's hparams.
        @override
        def __getitem__(self, key: str):
            # Key can be of the format "a.b.c"
            # so we need to split it into a list of keys.
            [first_key, *rest_keys] = key.split(".")
            value = self._ll_dict[first_key]

            # Walk the remaining dotted path, supporting both mappings
            # and attribute access on nested objects.
            for key in rest_keys:
                if isinstance(value, Mapping):
                    value = value[key]
                else:
                    value = getattr(value, key)

            return value

        @override
        def __setitem__(self, key: str, value: Any):
            # Key can be of the format "a.b.c"
            # so we need to split it into a list of keys.
            # NOTE(review): for a top-level key this writes into the dict
            # returned by `_ll_dict` (a fresh `model_dump()` snapshot), not
            # into the model itself — confirm this is the intended behavior.
            [first_key, *rest_keys] = key.split(".")
            if len(rest_keys) == 0:
                self._ll_dict[first_key] = value
                return

            # We need to traverse the keys until we reach the last key
            # and then set the value
            current_value = self._ll_dict[first_key]
            for key in rest_keys[:-1]:
                if isinstance(current_value, Mapping):
                    current_value = current_value[key]
                else:
                    current_value = getattr(current_value, key)

            # Set the value
            if isinstance(current_value, MutableMapping):
                current_value[rest_keys[-1]] = value
            else:
                setattr(current_value, rest_keys[-1], value)

        @override
        def __delitem__(self, key: str):
            # This is unsupported for this class
            raise NotImplementedError

        @override
        def __iter__(self):
            return iter(self._ll_dict)

        @override
        def __len__(self):
            return len(self._ll_dict)

    # endregion