nshtrainer 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -17
- nshtrainer/callbacks/__init__.py +3 -2
- nshtrainer/callbacks/base.py +3 -4
- nshtrainer/config.py +3 -288
- nshtrainer/lr_scheduler/__init__.py +3 -2
- nshtrainer/lr_scheduler/_base.py +3 -6
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +5 -5
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +5 -4
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +9 -71
- nshtrainer/model/config.py +39 -141
- nshtrainer/nn/nonlinearity.py +3 -4
- nshtrainer/optimizer.py +3 -7
- nshtrainer/runner.py +18 -8
- nshtrainer/trainer/signal_connector.py +22 -11
- nshtrainer/trainer/trainer.py +1 -1
- nshtrainer/typecheck.py +1 -0
- {nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/METADATA +13 -2
- {nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/RECORD +20 -27
- nshtrainer/_submit/print_environment_info.py +0 -31
- nshtrainer/_submit/session/_output.py +0 -12
- nshtrainer/_submit/session/_script.py +0 -109
- nshtrainer/_submit/session/lsf.py +0 -467
- nshtrainer/_submit/session/slurm.py +0 -573
- nshtrainer/_submit/session/unified.py +0 -350
- nshtrainer/util/singleton.py +0 -89
- {nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
@@ -1,25 +1,16 @@
 from . import _experimental as _experimental
 from . import actsave as actsave
 from . import callbacks as callbacks
+from . import config as config
 from . import lr_scheduler as lr_scheduler
 from . import model as model
 from . import nn as nn
 from . import optimizer as optimizer
-from . import snapshot as snapshot
 from . import typecheck as typecheck
 from ._snoop import snoop as snoop
 from .actsave import ActLoad as ActLoad
 from .actsave import ActSave as ActSave
-from .config import MISSING as MISSING
-from .config import AllowMissing as AllowMissing
-from .config import Field as Field
-from .config import MissingField as MissingField
-from .config import PrivateAttr as PrivateAttr
-from .config import TypedConfig as TypedConfig
 from .data import dataset_transform as dataset_transform
-from .log import init_python_logging as init_python_logging
-from .log import lovely as lovely
-from .log import pretty as pretty
 from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
 from .model import ActSaveConfig as ActSaveConfig
 from .model import Base as Base
@@ -41,24 +32,17 @@ from .model import (
     EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
 )
 from .model import GradientClippingConfig as GradientClippingConfig
-from .model import LightningDataModuleBase as LightningDataModuleBase
 from .model import LightningModuleBase as LightningModuleBase
 from .model import LoggingConfig as LoggingConfig
 from .model import MetricConfig as MetricConfig
 from .model import OptimizationConfig as OptimizationConfig
 from .model import PrimaryMetricConfig as PrimaryMetricConfig
-from .model import PythonLogging as PythonLogging
 from .model import ReproducibilityConfig as ReproducibilityConfig
-from .model import RunnerConfig as RunnerConfig
 from .model import SanityCheckingConfig as SanityCheckingConfig
-from .model import SeedConfig as SeedConfig
 from .model import TrainerConfig as TrainerConfig
 from .model import WandbWatchConfig as WandbWatchConfig
 from .nn import TypedModuleDict as TypedModuleDict
 from .nn import TypedModuleList as TypedModuleList
 from .optimizer import OptimizerConfig as OptimizerConfig
 from .runner import Runner as Runner
-from .runner import SnapshotConfig as SnapshotConfig
 from .trainer import Trainer as Trainer
-from .util.singleton import Registry as Registry
-from .util.singleton import Singleton as Singleton
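Taken together, the import changes above replace nshtrainer's home-grown config layer with the external nshconfig package. A minimal migration sketch (hypothetical user code; only the import paths come from the diff):

```python
# 0.1.0 style (now removed from the top-level namespace):
#   from nshtrainer import TypedConfig, Field
# 0.2.0 style: use nshconfig directly, or the re-exporting nshtrainer.config.
import nshconfig as C


class MyModuleConfig(C.Config):  # C.Config replaces the old TypedConfig base
    hidden_dim: int = C.Field(default=128, ge=1)
```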
nshtrainer/callbacks/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import Annotated
 
-
+import nshconfig as C
+
 from .base import CallbackConfigBase as CallbackConfigBase
 from .early_stopping import EarlyStopping as EarlyStopping
 from .ema import EMA as EMA
@@ -31,5 +32,5 @@ CallbackConfig = Annotated[
     | NormLoggingConfig
     | GradientSkippingConfig
     | EMAConfig,
-    Field(discriminator="name"),
+    C.Field(discriminator="name"),
 ]
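For reference, `C.Field(discriminator="name")` is the pydantic-style tagged-union marker that nshconfig re-exports: each union member declares a `Literal` `name` field, and validation dispatches on it. A self-contained sketch with stand-in members (`FooConfig`/`BarConfig` are not real nshtrainer callbacks), assuming nshconfig builds on pydantic v2 as the removed config.py below suggests:

```python
from typing import Annotated, Literal

import nshconfig as C
from pydantic import TypeAdapter


class FooConfig(C.Config):
    name: Literal["foo"] = "foo"
    strength: float = 1.0


class BarConfig(C.Config):
    name: Literal["bar"] = "bar"


# Same shape as the CallbackConfig union above.
ExampleUnion = Annotated[FooConfig | BarConfig, C.Field(discriminator="name")]

cfg = TypeAdapter(ExampleUnion).validate_python({"name": "foo", "strength": 2.0})
assert isinstance(cfg, FooConfig)  # dispatched on the "name" tag
```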
nshtrainer/callbacks/base.py
CHANGED
@@ -4,10 +4,9 @@ from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, TypeAlias, TypedDict
 
+import nshconfig as C
 from lightning.pytorch import Callback
 
-from ..config import TypedConfig
-
 if TYPE_CHECKING:
     from ..model.config import BaseConfig
 
@@ -20,7 +19,7 @@ class CallbackMetadataDict(TypedDict, total=False):
     """Priority of the callback. Callbacks with higher priority will be loaded first."""
 
 
-class CallbackMetadataConfig(
+class CallbackMetadataConfig(C.Config):
     ignore_if_exists: bool = False
     """If `True`, the callback will not be added if another callback with the same class already exists."""
 
@@ -37,7 +36,7 @@ class CallbackWithMetadata:
 ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
 
 
-class CallbackConfigBase(
+class CallbackConfigBase(C.Config, ABC):
     metadata: CallbackMetadataConfig = CallbackMetadataConfig()
     """Metadata for the callback."""
 
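The practical effect for downstream code: a callback config now subclasses `nshconfig.Config` (via `CallbackConfigBase`) rather than the removed `TypedConfig`. A sketch of the new subclass shape; `MyCallbackConfig` is hypothetical, and any abstract hooks the base class defines outside this hunk would still need implementing:

```python
from nshtrainer.callbacks.base import CallbackConfigBase, CallbackMetadataConfig


class MyCallbackConfig(CallbackConfigBase):
    # Inherited `metadata` field, overridden here to skip duplicate registration.
    metadata: CallbackMetadataConfig = CallbackMetadataConfig(ignore_if_exists=True)
```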
nshtrainer/config.py
CHANGED
@@ -1,289 +1,4 @@
-from
-from
+from nshconfig import * # type: ignore # noqa: F403
+from nshconfig import Config
 
-
-from pydantic import Field as Field
-from pydantic import PrivateAttr as PrivateAttr
-from typing_extensions import deprecated, override
-
-from ._config.missing import MISSING, validate_no_missing_values
-from ._config.missing import AllowMissing as AllowMissing
-from ._config.missing import MissingField as MissingField
-
-_MutableMappingBase = MutableMapping[str, Any]
-if TYPE_CHECKING:
-    _MutableMappingBase = object
-
-
-_DraftConfigContextSentinel = object()
-
-
-class TypedConfig(BaseModel, _MutableMappingBase):
-    _is_draft_config: bool = PrivateAttr(default=False)
-    """
-    Whether this config is a draft config or not.
-
-    Draft configs are configs that are not yet fully validated.
-    They allow for a nicer API when creating configs, e.g.:
-
-    ```python
-    config = MyConfig.draft()
-
-    # Set some values
-    config.a = 10
-    config.b = "hello"
-
-    # Finalize the config
-    config = config.finalize()
-    ```
-    """
-
-    repr_diff_only: ClassVar[bool] = True
-    """
-    If `True`, the repr methods will only show values for fields that are different from the default.
-    """
-
-    MISSING: ClassVar[Any] = MISSING
-    """
-    Alias for the `MISSING` constant.
-    """
-
-    model_config: ClassVar[ConfigDict] = ConfigDict(
-        # By default, Pydantic will throw a warning if a field starts with "model_",
-        # so we need to disable that warning (beacuse "model_" is a popular prefix for ML).
-        protected_namespaces=(),
-        validate_assignment=True,
-        validate_return=True,
-        validate_default=True,
-        strict=True,
-        revalidate_instances="always",
-        arbitrary_types_allowed=True,
-        extra="ignore",
-        validation_error_cause=True,
-        use_attribute_docstrings=True,
-    )
-
-    def __draft_pre_init__(self):
-        """Called right before a draft config is finalized."""
-        pass
-
-    def __post_init__(self):
-        """Called after the final config is validated."""
-        pass
-
-    @classmethod
-    @deprecated("Use `model_validate` instead.")
-    def from_dict(cls, model_dict: Mapping[str, Any]):
-        return cls.model_validate(model_dict)
-
-    def model_deep_validate(self, strict: bool = True):
-        """
-        Validate the config and all of its sub-configs.
-
-        Args:
-            config: The config to validate.
-            strict: Whether to validate the config strictly.
-        """
-        config_dict = self.model_dump(round_trip=True)
-        config = self.model_validate(config_dict, strict=strict)
-
-        # Make sure that this is not a draft config
-        if config._is_draft_config:
-            raise ValueError("Draft configs are not valid. Call `finalize` first.")
-
-        return config
-
-    @classmethod
-    def draft(cls, **kwargs):
-        config = cls.model_construct_draft(**kwargs)
-        return config
-
-    def finalize(self, strict: bool = True):
-        # This must be a draft config, otherwise we raise an error
-        if not self._is_draft_config:
-            raise ValueError("Finalize can only be called on drafts.")
-
-        # First, we call `__draft_pre_init__` to allow the config to modify itself a final time
-        self.__draft_pre_init__()
-
-        # Then, we dump the config to a dict and then re-validate it
-        return self.model_deep_validate(strict=strict)
-
-    @override
-    def model_post_init(self, __context: Any) -> None:
-        super().model_post_init(__context)
-
-        # Call the `__post_init__` method if this is not a draft config
-        if __context is _DraftConfigContextSentinel:
-            return
-
-        self.__post_init__()
-
-        # After `_post_init__` is called, we perform the final round of validation
-        self.model_post_init_validate()
-
-    def model_post_init_validate(self):
-        validate_no_missing_values(self)
-
-    @classmethod
-    def model_construct_draft(cls, _fields_set: set[str] | None = None, **values: Any):
-        """
-        NOTE: This is a copy of the `model_construct` method from Pydantic's `Model` class,
-        with the following changes:
-        - The `model_post_init` method is called with the `_DraftConfigContext` context.
-        - The `_is_draft_config` attribute is set to `True` in the `values` dict.
-
-        Creates a new instance of the `Model` class with validated data.
-
-        Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
-        Default values are respected, but no other validation is performed.
-
-        !!! note
-            `model_construct()` generally respects the `model_config.extra` setting on the provided model.
-            That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
-            and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
-            Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
-            an error if extra values are passed, but they will be ignored.
-
-        Args:
-            _fields_set: The set of field names accepted for the Model instance.
-            values: Trusted or pre-validated data dictionary.
-
-        Returns:
-            A new instance of the `Model` class with validated data.
-        """
-
-        values["_is_draft_config"] = True
-
-        m = cls.__new__(cls)
-        fields_values: dict[str, Any] = {}
-        fields_set = set()
-
-        for name, field in cls.model_fields.items():
-            if field.alias and field.alias in values:
-                fields_values[name] = values.pop(field.alias)
-                fields_set.add(name)
-            elif name in values:
-                fields_values[name] = values.pop(name)
-                fields_set.add(name)
-            elif not field.is_required():
-                fields_values[name] = field.get_default(call_default_factory=True)
-        if _fields_set is None:
-            _fields_set = fields_set
-
-        _extra: dict[str, Any] | None = None
-        if cls.model_config.get("extra") == "allow":
-            _extra = {}
-            for k, v in values.items():
-                _extra[k] = v
-        object.__setattr__(m, "__dict__", fields_values)
-        object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
-        if not cls.__pydantic_root_model__:
-            object.__setattr__(m, "__pydantic_extra__", _extra)
-
-        if cls.__pydantic_post_init__:
-            m.model_post_init(_DraftConfigContextSentinel)
-            # update private attributes with values set
-            if (
-                hasattr(m, "__pydantic_private__")
-                and m.__pydantic_private__ is not None
-            ):
-                for k, v in values.items():
-                    if k in m.__private_attributes__:
-                        m.__pydantic_private__[k] = v
-
-        elif not cls.__pydantic_root_model__:
-            # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
-            # Since it doesn't, that means that `__pydantic_private__` should be set to None
-            object.__setattr__(m, "__pydantic_private__", None)
-
-        return m
-
-    @override
-    def __repr_args__(self):
-        # If `repr_diff_only` is `True`, we only show the fields that are different from the default.
-        if not self.repr_diff_only:
-            yield from super().__repr_args__()
-            return
-
-        # First, we get the default values for all fields.
-        default_values = self.model_construct_draft()
-
-        # Then, we compare the default values with the current values.
-        for k, v in super().__repr_args__():
-            if k is None:
-                yield k, v
-                continue
-
-            # If there is no default value or the value is different from the default, we yield it.
-            if not hasattr(default_values, k) or getattr(default_values, k) != v:
-                yield k, v
-                continue
-
-            # Otherwise, we can skip this field.
-
-    # region MutableMapping implementation
-    if not TYPE_CHECKING:
-        # This is mainly so the config can be used with lightning's hparams
-        # transparently and without any issues.
-
-        @property
-        def _ll_dict(self):
-            return self.model_dump()
-
-        # We need to make sure every config class
-        # is a MutableMapping[str, Any] so that it can be used
-        # with lightning's hparams.
-        @override
-        def __getitem__(self, key: str):
-            # Key can be of the format "a.b.c"
-            # so we need to split it into a list of keys.
-            [first_key, *rest_keys] = key.split(".")
-            value = self._ll_dict[first_key]
-
-            for key in rest_keys:
-                if isinstance(value, Mapping):
-                    value = value[key]
-                else:
-                    value = getattr(value, key)
-
-            return value
-
-        @override
-        def __setitem__(self, key: str, value: Any):
-            # Key can be of the format "a.b.c"
-            # so we need to split it into a list of keys.
-            [first_key, *rest_keys] = key.split(".")
-            if len(rest_keys) == 0:
-                self._ll_dict[first_key] = value
-                return
-
-            # We need to traverse the keys until we reach the last key
-            # and then set the value
-            current_value = self._ll_dict[first_key]
-            for key in rest_keys[:-1]:
-                if isinstance(current_value, Mapping):
-                    current_value = current_value[key]
-                else:
-                    current_value = getattr(current_value, key)
-
-            # Set the value
-            if isinstance(current_value, MutableMapping):
-                current_value[rest_keys[-1]] = value
-            else:
-                setattr(current_value, rest_keys[-1], value)
-
-        @override
-        def __delitem__(self, key: str):
-            # This is unsupported for this class
-            raise NotImplementedError
-
-        @override
-        def __iter__(self):
-            return iter(self._ll_dict)
-
-        @override
-        def __len__(self):
-            return len(self._ll_dict)
-
-    # endregion
+TypedConfig = Config
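Everything that made `TypedConfig` special in 0.1.0 (draft configs, `MISSING` fields, the MutableMapping shim for Lightning hparams) is deleted here rather than ported; 0.2.0 keeps the name only as an alias. A sketch of the remaining compatibility surface (class and field are hypothetical):

```python
from nshtrainer.config import TypedConfig  # now literally nshconfig.Config


class LegacyConfig(TypedConfig):
    lr: float = 1e-3


cfg = LegacyConfig(lr=3e-4)  # plain pydantic-style construction still works
# Whether draft()/finalize() still exist now depends entirely on what
# nshconfig.Config provides, not on this module.
```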
nshtrainer/lr_scheduler/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import Annotated, TypeAlias
 
-
+import nshconfig as C
+
 from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
 from ._base import LRSchedulerMetadata as LRSchedulerMetadata
 from .linear_warmup_cosine import (
@@ -14,5 +15,5 @@ from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauCo
 
 LRSchedulerConfig: TypeAlias = Annotated[
     LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
-    Field(discriminator="name"),
+    C.Field(discriminator="name"),
 ]
nshtrainer/lr_scheduler/_base.py
CHANGED
@@ -1,8 +1,9 @@
 import math
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Literal
 
+import nshconfig as C
 from lightning.pytorch.utilities.types import (
     LRSchedulerConfigType,
     LRSchedulerTypeUnion,
@@ -10,8 +11,6 @@ from lightning.pytorch.utilities.types import (
 from torch.optim import Optimizer
 from typing_extensions import NotRequired, TypedDict
 
-from ..config import TypedConfig
-
 if TYPE_CHECKING:
     from ..model.base import LightningModuleBase
 
@@ -37,9 +36,7 @@ class LRSchedulerMetadata(TypedDict):
     """Whether to enforce that the monitor exists for reducing the learning rate on plateau. Default is `True`."""
 
 
-class LRSchedulerConfigBase(
-    Metadata: ClassVar[TypeAlias] = LRSchedulerMetadata
-
+class LRSchedulerConfigBase(C.Config, ABC):
     @abstractmethod
     def metadata(self) -> LRSchedulerMetadata: ...
 
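Note the subclass contract change: `LRSchedulerConfigBase` drops its `Metadata` ClassVar alias, so subclasses annotate `metadata()` with `LRSchedulerMetadata` directly, as the two scheduler configs below now do. A sketch of the new shape (hypothetical config; other abstract hooks on the base class, such as the scheduler factory, are omitted):

```python
from typing_extensions import override

from nshtrainer.lr_scheduler import LRSchedulerConfigBase, LRSchedulerMetadata


class MySchedulerConfig(LRSchedulerConfigBase):
    @override
    def metadata(self) -> LRSchedulerMetadata:
        # "interval" is the only key the linear-warmup config below sets.
        return {"interval": "step"}
```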
nshtrainer/lr_scheduler/linear_warmup_cosine.py
CHANGED
@@ -2,12 +2,12 @@ import math
 import warnings
 from typing import Literal
 
+import nshconfig as C
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from typing_extensions import override
 
-from
-from ._base import LRSchedulerConfigBase
+from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
 
 
 class LinearWarmupCosineAnnealingLR(LRScheduler):
@@ -91,11 +91,11 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
 class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
     name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
 
-    warmup_epochs: int = Field(ge=0)
+    warmup_epochs: int = C.Field(ge=0)
     r"""The number of epochs for the linear warmup phase.
     The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""
 
-    max_epochs: int = Field(gt=0)
+    max_epochs: int = C.Field(gt=0)
     r"""The total number of epochs.
     The learning rate is decayed to `min_lr` over this number of epochs."""
 
@@ -113,7 +113,7 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
     If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
 
     @override
-    def metadata(self) ->
+    def metadata(self) -> LRSchedulerMetadata:
         return {
             "interval": "step",
         }
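`C.Field(ge=0)` and `C.Field(gt=0)` declare required, range-constrained fields in the pydantic style. A quick sketch of the resulting behavior (assuming nshconfig surfaces pydantic's ValidationError):

```python
import nshconfig as C
from pydantic import ValidationError


class Sketch(C.Config):
    warmup_epochs: int = C.Field(ge=0)  # required; must be >= 0


Sketch(warmup_epochs=5)  # fine
try:
    Sketch(warmup_epochs=-1)
except ValidationError:
    print("rejected: warmup_epochs must be >= 0")
```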
nshtrainer/lr_scheduler/reduce_lr_on_plateau.py
CHANGED
@@ -1,12 +1,11 @@
 from typing import TYPE_CHECKING, Literal, cast
 
+from lightning.pytorch.utilities.types import LRSchedulerConfigType
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from typing_extensions import override
 
-from ll.lr_scheduler._base import LRSchedulerMetadata
-
 from ..model.config import MetricConfig
-from ._base import LRSchedulerConfigBase
+from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
 
 if TYPE_CHECKING:
     from ..model.base import BaseConfig
@@ -43,7 +42,9 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
     r"""One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) in 'max' mode or best * (1 - threshold) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'."""
 
     @override
-    def create_scheduler_impl(
+    def create_scheduler_impl(
+        self, optimizer, lightning_module, lr
+    ) -> LRSchedulerConfigType:
         if (metric := self.metric) is None:
             lm_config = cast("BaseConfig", lightning_module.config)
             assert (
nshtrainer/model/__init__.py
CHANGED
@@ -1,7 +1,6 @@
 from typing_extensions import TypeAlias
 
 from .base import Base as Base
-from .base import LightningDataModuleBase as LightningDataModuleBase
 from .base import LightningModuleBase as LightningModuleBase
 from .config import ActSaveConfig as ActSaveConfig
 from .config import BaseConfig as BaseConfig
@@ -33,11 +32,8 @@ from .config import (
 )
 from .config import OptimizationConfig as OptimizationConfig
 from .config import PrimaryMetricConfig as PrimaryMetricConfig
-from .config import PythonLogging as PythonLogging
 from .config import ReproducibilityConfig as ReproducibilityConfig
-from .config import RunnerConfig as RunnerConfig
 from .config import SanityCheckingConfig as SanityCheckingConfig
-from .config import SeedConfig as SeedConfig
 from .config import TrainerConfig as TrainerConfig
 from .config import WandbWatchConfig as WandbWatchConfig
 
nshtrainer/model/base.py
CHANGED
@@ -23,11 +23,12 @@ from .config import (
     EnvironmentLinuxEnvironmentConfig,
     EnvironmentLSFInformationConfig,
     EnvironmentSLURMInformationConfig,
+    EnvironmentSnapshotConfig,
 )
-from .modules.callback import CallbackModuleMixin
+from .modules.callback import CallbackModuleMixin
 from .modules.debug import DebugModuleMixin
 from .modules.distributed import DistributedMixin
-from .modules.logger import LoggerLightningModuleMixin
+from .modules.logger import LoggerLightningModuleMixin
 from .modules.profiler import ProfilerMixin
 from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
 from .modules.shared_parameters import SharedParametersModuleMixin
@@ -265,6 +266,9 @@ class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
             boot_time=_try_get(lambda: _psutil().boot_time()),
             load_avg=_try_get(lambda: os.getloadavg()),
         )
+        hparams.environment.snapshot = (
+            EnvironmentSnapshotConfig.from_current_environment()
+        )
 
     def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
         """
@@ -309,15 +313,12 @@ class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
     @property
     def datamodule(self):
         datamodule = getattr(self.trainer, "datamodule", None)
-        if datamodule is None:
+        if (datamodule := getattr(self.trainer, "datamodule", None)) is None:
             return None
-
-        if not isinstance(datamodule, LightningDataModuleBase):
+        if not isinstance(datamodule, LightningDataModule):
             raise TypeError(
-                f"datamodule must be a
+                f"datamodule must be a LightningDataModule: {type(datamodule)}"
             )
-
-        datamodule = cast(LightningDataModuleBase[THparams], datamodule)
         return datamodule
 
     if TYPE_CHECKING:
@@ -576,66 +577,3 @@ class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
             "batch": batch,
             "batch_idx": batch_idx,
         }
-
-
-class LightningDataModuleBase(
-    LoggerModuleMixin,
-    CallbackRegistrarModuleMixin,
-    Base[THparams],
-    LightningDataModule,
-    ABC,
-    Generic[THparams],
-):
-    hparams: THparams  # pyright: ignore[reportIncompatibleMethodOverride]
-    hparams_initial: THparams  # pyright: ignore[reportIncompatibleMethodOverride]
-
-    def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
-        """
-        Override this method to update the hparams dictionary before it is used to create the hparams object.
-        Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
-        """
-        return hparams
-
-    def pre_init_update_hparams(self, hparams: THparams):
-        """
-        Override this method to update the hparams object before it is used to create the hparams_initial object.
-        """
-        return hparams
-
-    @classmethod
-    def _update_environment(cls, hparams: THparams):
-        hparams.environment.data = _cls_info(cls)
-
-    @override
-    def __init__(self, hparams: THparams):
-        if not isinstance(hparams, BaseConfig):
-            if not isinstance(hparams, MutableMapping):
-                raise TypeError(
-                    f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
-                )
-
-            hparams = self.pre_init_update_hparams_dict(hparams)
-            hparams = self.config_cls().from_dict(hparams)
-        self._update_environment(hparams)
-        hparams = self.pre_init_update_hparams(hparams)
-        super().__init__(hparams)
-
-        self.save_hyperparameters(hparams)
-
-    @property
-    def lightning_module(self):
-        if not self.trainer:
-            raise ValueError("Trainer has not been set.")
-
-        module = self.trainer.lightning_module
-        if not isinstance(module, LightningModuleBase):
-            raise ValueError(
-                f"Trainer's lightning_module is not a LightningModuleBase: {type(module)}"
-            )
-
-        module = cast(LightningModuleBase[THparams], module)
-        return module
-
-    @property
-    def device(self):
-        return self.lightning_module.device
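The `datamodule` property change above mirrors the deletion of `LightningDataModuleBase`: 0.2.0 accepts any stock Lightning datamodule instead of requiring the removed base class. A minimal sketch of what now passes the isinstance check (the datamodule class is hypothetical):

```python
from lightning.pytorch import LightningDataModule


class PlainDataModule(LightningDataModule):
    """Would have failed 0.1.0's isinstance(..., LightningDataModuleBase) check
    in the `datamodule` property; any LightningDataModule passes in 0.2.0."""
```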