nshtrainer 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -1,25 +1,16 @@
1
1
  from . import _experimental as _experimental
2
2
  from . import actsave as actsave
3
3
  from . import callbacks as callbacks
4
+ from . import config as config
4
5
  from . import lr_scheduler as lr_scheduler
5
6
  from . import model as model
6
7
  from . import nn as nn
7
8
  from . import optimizer as optimizer
8
- from . import snapshot as snapshot
9
9
  from . import typecheck as typecheck
10
10
  from ._snoop import snoop as snoop
11
11
  from .actsave import ActLoad as ActLoad
12
12
  from .actsave import ActSave as ActSave
13
- from .config import MISSING as MISSING
14
- from .config import AllowMissing as AllowMissing
15
- from .config import Field as Field
16
- from .config import MissingField as MissingField
17
- from .config import PrivateAttr as PrivateAttr
18
- from .config import TypedConfig as TypedConfig
19
13
  from .data import dataset_transform as dataset_transform
20
- from .log import init_python_logging as init_python_logging
21
- from .log import lovely as lovely
22
- from .log import pretty as pretty
23
14
  from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
24
15
  from .model import ActSaveConfig as ActSaveConfig
25
16
  from .model import Base as Base
@@ -41,24 +32,17 @@ from .model import (
41
32
  EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
42
33
  )
43
34
  from .model import GradientClippingConfig as GradientClippingConfig
44
- from .model import LightningDataModuleBase as LightningDataModuleBase
45
35
  from .model import LightningModuleBase as LightningModuleBase
46
36
  from .model import LoggingConfig as LoggingConfig
47
37
  from .model import MetricConfig as MetricConfig
48
38
  from .model import OptimizationConfig as OptimizationConfig
49
39
  from .model import PrimaryMetricConfig as PrimaryMetricConfig
50
- from .model import PythonLogging as PythonLogging
51
40
  from .model import ReproducibilityConfig as ReproducibilityConfig
52
- from .model import RunnerConfig as RunnerConfig
53
41
  from .model import SanityCheckingConfig as SanityCheckingConfig
54
- from .model import SeedConfig as SeedConfig
55
42
  from .model import TrainerConfig as TrainerConfig
56
43
  from .model import WandbWatchConfig as WandbWatchConfig
57
44
  from .nn import TypedModuleDict as TypedModuleDict
58
45
  from .nn import TypedModuleList as TypedModuleList
59
46
  from .optimizer import OptimizerConfig as OptimizerConfig
60
47
  from .runner import Runner as Runner
61
- from .runner import SnapshotConfig as SnapshotConfig
62
48
  from .trainer import Trainer as Trainer
63
- from .util.singleton import Registry as Registry
64
- from .util.singleton import Singleton as Singleton
@@ -1,6 +1,7 @@
1
1
  from typing import Annotated
2
2
 
3
- from ..config import Field
3
+ import nshconfig as C
4
+
4
5
  from .base import CallbackConfigBase as CallbackConfigBase
5
6
  from .early_stopping import EarlyStopping as EarlyStopping
6
7
  from .ema import EMA as EMA
@@ -31,5 +32,5 @@ CallbackConfig = Annotated[
31
32
  | NormLoggingConfig
32
33
  | GradientSkippingConfig
33
34
  | EMAConfig,
34
- Field(discriminator="name"),
35
+ C.Field(discriminator="name"),
35
36
  ]
@@ -4,10 +4,9 @@ from collections.abc import Iterable
4
4
  from dataclasses import dataclass
5
5
  from typing import TYPE_CHECKING, TypeAlias, TypedDict
6
6
 
7
+ import nshconfig as C
7
8
  from lightning.pytorch import Callback
8
9
 
9
- from ..config import TypedConfig
10
-
11
10
  if TYPE_CHECKING:
12
11
  from ..model.config import BaseConfig
13
12
 
@@ -20,7 +19,7 @@ class CallbackMetadataDict(TypedDict, total=False):
20
19
  """Priority of the callback. Callbacks with higher priority will be loaded first."""
21
20
 
22
21
 
23
- class CallbackMetadataConfig(TypedConfig):
22
+ class CallbackMetadataConfig(C.Config):
24
23
  ignore_if_exists: bool = False
25
24
  """If `True`, the callback will not be added if another callback with the same class already exists."""
26
25
 
@@ -37,7 +36,7 @@ class CallbackWithMetadata:
37
36
  ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
38
37
 
39
38
 
40
- class CallbackConfigBase(TypedConfig, ABC):
39
+ class CallbackConfigBase(C.Config, ABC):
41
40
  metadata: CallbackMetadataConfig = CallbackMetadataConfig()
42
41
  """Metadata for the callback."""
43
42
 
nshtrainer/config.py CHANGED
@@ -1,289 +1,4 @@
1
- from collections.abc import Mapping, MutableMapping
2
- from typing import TYPE_CHECKING, Any, ClassVar
1
+ from nshconfig import * # type: ignore # noqa: F403
2
+ from nshconfig import Config
3
3
 
4
- from pydantic import BaseModel, ConfigDict
5
- from pydantic import Field as Field
6
- from pydantic import PrivateAttr as PrivateAttr
7
- from typing_extensions import deprecated, override
8
-
9
- from ._config.missing import MISSING, validate_no_missing_values
10
- from ._config.missing import AllowMissing as AllowMissing
11
- from ._config.missing import MissingField as MissingField
12
-
13
- _MutableMappingBase = MutableMapping[str, Any]
14
- if TYPE_CHECKING:
15
- _MutableMappingBase = object
16
-
17
-
18
- _DraftConfigContextSentinel = object()
19
-
20
-
21
- class TypedConfig(BaseModel, _MutableMappingBase):
22
- _is_draft_config: bool = PrivateAttr(default=False)
23
- """
24
- Whether this config is a draft config or not.
25
-
26
- Draft configs are configs that are not yet fully validated.
27
- They allow for a nicer API when creating configs, e.g.:
28
-
29
- ```python
30
- config = MyConfig.draft()
31
-
32
- # Set some values
33
- config.a = 10
34
- config.b = "hello"
35
-
36
- # Finalize the config
37
- config = config.finalize()
38
- ```
39
- """
40
-
41
- repr_diff_only: ClassVar[bool] = True
42
- """
43
- If `True`, the repr methods will only show values for fields that are different from the default.
44
- """
45
-
46
- MISSING: ClassVar[Any] = MISSING
47
- """
48
- Alias for the `MISSING` constant.
49
- """
50
-
51
- model_config: ClassVar[ConfigDict] = ConfigDict(
52
- # By default, Pydantic will throw a warning if a field starts with "model_",
53
- # so we need to disable that warning (beacuse "model_" is a popular prefix for ML).
54
- protected_namespaces=(),
55
- validate_assignment=True,
56
- validate_return=True,
57
- validate_default=True,
58
- strict=True,
59
- revalidate_instances="always",
60
- arbitrary_types_allowed=True,
61
- extra="ignore",
62
- validation_error_cause=True,
63
- use_attribute_docstrings=True,
64
- )
65
-
66
- def __draft_pre_init__(self):
67
- """Called right before a draft config is finalized."""
68
- pass
69
-
70
- def __post_init__(self):
71
- """Called after the final config is validated."""
72
- pass
73
-
74
- @classmethod
75
- @deprecated("Use `model_validate` instead.")
76
- def from_dict(cls, model_dict: Mapping[str, Any]):
77
- return cls.model_validate(model_dict)
78
-
79
- def model_deep_validate(self, strict: bool = True):
80
- """
81
- Validate the config and all of its sub-configs.
82
-
83
- Args:
84
- config: The config to validate.
85
- strict: Whether to validate the config strictly.
86
- """
87
- config_dict = self.model_dump(round_trip=True)
88
- config = self.model_validate(config_dict, strict=strict)
89
-
90
- # Make sure that this is not a draft config
91
- if config._is_draft_config:
92
- raise ValueError("Draft configs are not valid. Call `finalize` first.")
93
-
94
- return config
95
-
96
- @classmethod
97
- def draft(cls, **kwargs):
98
- config = cls.model_construct_draft(**kwargs)
99
- return config
100
-
101
- def finalize(self, strict: bool = True):
102
- # This must be a draft config, otherwise we raise an error
103
- if not self._is_draft_config:
104
- raise ValueError("Finalize can only be called on drafts.")
105
-
106
- # First, we call `__draft_pre_init__` to allow the config to modify itself a final time
107
- self.__draft_pre_init__()
108
-
109
- # Then, we dump the config to a dict and then re-validate it
110
- return self.model_deep_validate(strict=strict)
111
-
112
- @override
113
- def model_post_init(self, __context: Any) -> None:
114
- super().model_post_init(__context)
115
-
116
- # Call the `__post_init__` method if this is not a draft config
117
- if __context is _DraftConfigContextSentinel:
118
- return
119
-
120
- self.__post_init__()
121
-
122
- # After `_post_init__` is called, we perform the final round of validation
123
- self.model_post_init_validate()
124
-
125
- def model_post_init_validate(self):
126
- validate_no_missing_values(self)
127
-
128
- @classmethod
129
- def model_construct_draft(cls, _fields_set: set[str] | None = None, **values: Any):
130
- """
131
- NOTE: This is a copy of the `model_construct` method from Pydantic's `Model` class,
132
- with the following changes:
133
- - The `model_post_init` method is called with the `_DraftConfigContext` context.
134
- - The `_is_draft_config` attribute is set to `True` in the `values` dict.
135
-
136
- Creates a new instance of the `Model` class with validated data.
137
-
138
- Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
139
- Default values are respected, but no other validation is performed.
140
-
141
- !!! note
142
- `model_construct()` generally respects the `model_config.extra` setting on the provided model.
143
- That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
144
- and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
145
- Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
146
- an error if extra values are passed, but they will be ignored.
147
-
148
- Args:
149
- _fields_set: The set of field names accepted for the Model instance.
150
- values: Trusted or pre-validated data dictionary.
151
-
152
- Returns:
153
- A new instance of the `Model` class with validated data.
154
- """
155
-
156
- values["_is_draft_config"] = True
157
-
158
- m = cls.__new__(cls)
159
- fields_values: dict[str, Any] = {}
160
- fields_set = set()
161
-
162
- for name, field in cls.model_fields.items():
163
- if field.alias and field.alias in values:
164
- fields_values[name] = values.pop(field.alias)
165
- fields_set.add(name)
166
- elif name in values:
167
- fields_values[name] = values.pop(name)
168
- fields_set.add(name)
169
- elif not field.is_required():
170
- fields_values[name] = field.get_default(call_default_factory=True)
171
- if _fields_set is None:
172
- _fields_set = fields_set
173
-
174
- _extra: dict[str, Any] | None = None
175
- if cls.model_config.get("extra") == "allow":
176
- _extra = {}
177
- for k, v in values.items():
178
- _extra[k] = v
179
- object.__setattr__(m, "__dict__", fields_values)
180
- object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
181
- if not cls.__pydantic_root_model__:
182
- object.__setattr__(m, "__pydantic_extra__", _extra)
183
-
184
- if cls.__pydantic_post_init__:
185
- m.model_post_init(_DraftConfigContextSentinel)
186
- # update private attributes with values set
187
- if (
188
- hasattr(m, "__pydantic_private__")
189
- and m.__pydantic_private__ is not None
190
- ):
191
- for k, v in values.items():
192
- if k in m.__private_attributes__:
193
- m.__pydantic_private__[k] = v
194
-
195
- elif not cls.__pydantic_root_model__:
196
- # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
197
- # Since it doesn't, that means that `__pydantic_private__` should be set to None
198
- object.__setattr__(m, "__pydantic_private__", None)
199
-
200
- return m
201
-
202
- @override
203
- def __repr_args__(self):
204
- # If `repr_diff_only` is `True`, we only show the fields that are different from the default.
205
- if not self.repr_diff_only:
206
- yield from super().__repr_args__()
207
- return
208
-
209
- # First, we get the default values for all fields.
210
- default_values = self.model_construct_draft()
211
-
212
- # Then, we compare the default values with the current values.
213
- for k, v in super().__repr_args__():
214
- if k is None:
215
- yield k, v
216
- continue
217
-
218
- # If there is no default value or the value is different from the default, we yield it.
219
- if not hasattr(default_values, k) or getattr(default_values, k) != v:
220
- yield k, v
221
- continue
222
-
223
- # Otherwise, we can skip this field.
224
-
225
- # region MutableMapping implementation
226
- if not TYPE_CHECKING:
227
- # This is mainly so the config can be used with lightning's hparams
228
- # transparently and without any issues.
229
-
230
- @property
231
- def _ll_dict(self):
232
- return self.model_dump()
233
-
234
- # We need to make sure every config class
235
- # is a MutableMapping[str, Any] so that it can be used
236
- # with lightning's hparams.
237
- @override
238
- def __getitem__(self, key: str):
239
- # Key can be of the format "a.b.c"
240
- # so we need to split it into a list of keys.
241
- [first_key, *rest_keys] = key.split(".")
242
- value = self._ll_dict[first_key]
243
-
244
- for key in rest_keys:
245
- if isinstance(value, Mapping):
246
- value = value[key]
247
- else:
248
- value = getattr(value, key)
249
-
250
- return value
251
-
252
- @override
253
- def __setitem__(self, key: str, value: Any):
254
- # Key can be of the format "a.b.c"
255
- # so we need to split it into a list of keys.
256
- [first_key, *rest_keys] = key.split(".")
257
- if len(rest_keys) == 0:
258
- self._ll_dict[first_key] = value
259
- return
260
-
261
- # We need to traverse the keys until we reach the last key
262
- # and then set the value
263
- current_value = self._ll_dict[first_key]
264
- for key in rest_keys[:-1]:
265
- if isinstance(current_value, Mapping):
266
- current_value = current_value[key]
267
- else:
268
- current_value = getattr(current_value, key)
269
-
270
- # Set the value
271
- if isinstance(current_value, MutableMapping):
272
- current_value[rest_keys[-1]] = value
273
- else:
274
- setattr(current_value, rest_keys[-1], value)
275
-
276
- @override
277
- def __delitem__(self, key: str):
278
- # This is unsupported for this class
279
- raise NotImplementedError
280
-
281
- @override
282
- def __iter__(self):
283
- return iter(self._ll_dict)
284
-
285
- @override
286
- def __len__(self):
287
- return len(self._ll_dict)
288
-
289
- # endregion
4
+ TypedConfig = Config
@@ -1,6 +1,7 @@
1
1
  from typing import Annotated, TypeAlias
2
2
 
3
- from ..config import Field
3
+ import nshconfig as C
4
+
4
5
  from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
5
6
  from ._base import LRSchedulerMetadata as LRSchedulerMetadata
6
7
  from .linear_warmup_cosine import (
@@ -14,5 +15,5 @@ from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauCo
14
15
 
15
16
  LRSchedulerConfig: TypeAlias = Annotated[
16
17
  LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
17
- Field(discriminator="name"),
18
+ C.Field(discriminator="name"),
18
19
  ]
@@ -1,8 +1,9 @@
1
1
  import math
2
2
  from abc import ABC, abstractmethod
3
3
  from collections.abc import Mapping
4
- from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias
4
+ from typing import TYPE_CHECKING, Literal
5
5
 
6
+ import nshconfig as C
6
7
  from lightning.pytorch.utilities.types import (
7
8
  LRSchedulerConfigType,
8
9
  LRSchedulerTypeUnion,
@@ -10,8 +11,6 @@ from lightning.pytorch.utilities.types import (
10
11
  from torch.optim import Optimizer
11
12
  from typing_extensions import NotRequired, TypedDict
12
13
 
13
- from ..config import TypedConfig
14
-
15
14
  if TYPE_CHECKING:
16
15
  from ..model.base import LightningModuleBase
17
16
 
@@ -37,9 +36,7 @@ class LRSchedulerMetadata(TypedDict):
37
36
  """Whether to enforce that the monitor exists for reducing the learning rate on plateau. Default is `True`."""
38
37
 
39
38
 
40
- class LRSchedulerConfigBase(TypedConfig, ABC):
41
- Metadata: ClassVar[TypeAlias] = LRSchedulerMetadata
42
-
39
+ class LRSchedulerConfigBase(C.Config, ABC):
43
40
  @abstractmethod
44
41
  def metadata(self) -> LRSchedulerMetadata: ...
45
42
 
@@ -2,12 +2,12 @@ import math
2
2
  import warnings
3
3
  from typing import Literal
4
4
 
5
+ import nshconfig as C
5
6
  from torch.optim import Optimizer
6
7
  from torch.optim.lr_scheduler import LRScheduler
7
8
  from typing_extensions import override
8
9
 
9
- from ..config import Field
10
- from ._base import LRSchedulerConfigBase
10
+ from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
11
11
 
12
12
 
13
13
  class LinearWarmupCosineAnnealingLR(LRScheduler):
@@ -91,11 +91,11 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
91
91
  class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
92
92
  name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
93
93
 
94
- warmup_epochs: int = Field(ge=0)
94
+ warmup_epochs: int = C.Field(ge=0)
95
95
  r"""The number of epochs for the linear warmup phase.
96
96
  The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""
97
97
 
98
- max_epochs: int = Field(gt=0)
98
+ max_epochs: int = C.Field(gt=0)
99
99
  r"""The total number of epochs.
100
100
  The learning rate is decayed to `min_lr` over this number of epochs."""
101
101
 
@@ -113,7 +113,7 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
113
113
  If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
114
114
 
115
115
  @override
116
- def metadata(self) -> LRSchedulerConfigBase.Metadata:
116
+ def metadata(self) -> LRSchedulerMetadata:
117
117
  return {
118
118
  "interval": "step",
119
119
  }
@@ -1,12 +1,11 @@
1
1
  from typing import TYPE_CHECKING, Literal, cast
2
2
 
3
+ from lightning.pytorch.utilities.types import LRSchedulerConfigType
3
4
  from torch.optim.lr_scheduler import ReduceLROnPlateau
4
5
  from typing_extensions import override
5
6
 
6
- from ll.lr_scheduler._base import LRSchedulerMetadata
7
-
8
7
  from ..model.config import MetricConfig
9
- from ._base import LRSchedulerConfigBase
8
+ from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
10
9
 
11
10
  if TYPE_CHECKING:
12
11
  from ..model.base import BaseConfig
@@ -43,7 +42,9 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
43
42
  r"""One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) in 'max' mode or best * (1 - threshold) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'."""
44
43
 
45
44
  @override
46
- def create_scheduler_impl(self, optimizer, lightning_module, lr):
45
+ def create_scheduler_impl(
46
+ self, optimizer, lightning_module, lr
47
+ ) -> LRSchedulerConfigType:
47
48
  if (metric := self.metric) is None:
48
49
  lm_config = cast("BaseConfig", lightning_module.config)
49
50
  assert (
@@ -1,7 +1,6 @@
1
1
  from typing_extensions import TypeAlias
2
2
 
3
3
  from .base import Base as Base
4
- from .base import LightningDataModuleBase as LightningDataModuleBase
5
4
  from .base import LightningModuleBase as LightningModuleBase
6
5
  from .config import ActSaveConfig as ActSaveConfig
7
6
  from .config import BaseConfig as BaseConfig
@@ -33,11 +32,8 @@ from .config import (
33
32
  )
34
33
  from .config import OptimizationConfig as OptimizationConfig
35
34
  from .config import PrimaryMetricConfig as PrimaryMetricConfig
36
- from .config import PythonLogging as PythonLogging
37
35
  from .config import ReproducibilityConfig as ReproducibilityConfig
38
- from .config import RunnerConfig as RunnerConfig
39
36
  from .config import SanityCheckingConfig as SanityCheckingConfig
40
- from .config import SeedConfig as SeedConfig
41
37
  from .config import TrainerConfig as TrainerConfig
42
38
  from .config import WandbWatchConfig as WandbWatchConfig
43
39
 
nshtrainer/model/base.py CHANGED
@@ -23,11 +23,12 @@ from .config import (
23
23
  EnvironmentLinuxEnvironmentConfig,
24
24
  EnvironmentLSFInformationConfig,
25
25
  EnvironmentSLURMInformationConfig,
26
+ EnvironmentSnapshotConfig,
26
27
  )
27
- from .modules.callback import CallbackModuleMixin, CallbackRegistrarModuleMixin
28
+ from .modules.callback import CallbackModuleMixin
28
29
  from .modules.debug import DebugModuleMixin
29
30
  from .modules.distributed import DistributedMixin
30
- from .modules.logger import LoggerLightningModuleMixin, LoggerModuleMixin
31
+ from .modules.logger import LoggerLightningModuleMixin
31
32
  from .modules.profiler import ProfilerMixin
32
33
  from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
33
34
  from .modules.shared_parameters import SharedParametersModuleMixin
@@ -265,6 +266,9 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
265
266
  boot_time=_try_get(lambda: _psutil().boot_time()),
266
267
  load_avg=_try_get(lambda: os.getloadavg()),
267
268
  )
269
+ hparams.environment.snapshot = (
270
+ EnvironmentSnapshotConfig.from_current_environment()
271
+ )
268
272
 
269
273
  def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
270
274
  """
@@ -309,15 +313,12 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
309
313
  @property
310
314
  def datamodule(self):
311
315
  datamodule = getattr(self.trainer, "datamodule", None)
312
- if datamodule is None:
316
+ if (datamodule := getattr(self.trainer, "datamodule", None)) is None:
313
317
  return None
314
-
315
- if not isinstance(datamodule, LightningDataModuleBase):
318
+ if not isinstance(datamodule, LightningDataModule):
316
319
  raise TypeError(
317
- f"datamodule must be a LightningDataModuleBase: {type(datamodule)}"
320
+ f"datamodule must be a LightningDataModule: {type(datamodule)}"
318
321
  )
319
-
320
- datamodule = cast(LightningDataModuleBase[THparams], datamodule)
321
322
  return datamodule
322
323
 
323
324
  if TYPE_CHECKING:
@@ -576,66 +577,3 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
576
577
  "batch": batch,
577
578
  "batch_idx": batch_idx,
578
579
  }
579
-
580
-
581
- class LightningDataModuleBase(
582
- LoggerModuleMixin,
583
- CallbackRegistrarModuleMixin,
584
- Base[THparams],
585
- LightningDataModule,
586
- ABC,
587
- Generic[THparams],
588
- ):
589
- hparams: THparams # pyright: ignore[reportIncompatibleMethodOverride]
590
- hparams_initial: THparams # pyright: ignore[reportIncompatibleMethodOverride]
591
-
592
- def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
593
- """
594
- Override this method to update the hparams dictionary before it is used to create the hparams object.
595
- Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
596
- """
597
- return hparams
598
-
599
- def pre_init_update_hparams(self, hparams: THparams):
600
- """
601
- Override this method to update the hparams object before it is used to create the hparams_initial object.
602
- """
603
- return hparams
604
-
605
- @classmethod
606
- def _update_environment(cls, hparams: THparams):
607
- hparams.environment.data = _cls_info(cls)
608
-
609
- @override
610
- def __init__(self, hparams: THparams):
611
- if not isinstance(hparams, BaseConfig):
612
- if not isinstance(hparams, MutableMapping):
613
- raise TypeError(
614
- f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
615
- )
616
-
617
- hparams = self.pre_init_update_hparams_dict(hparams)
618
- hparams = self.config_cls().from_dict(hparams)
619
- self._update_environment(hparams)
620
- hparams = self.pre_init_update_hparams(hparams)
621
- super().__init__(hparams)
622
-
623
- self.save_hyperparameters(hparams)
624
-
625
- @property
626
- def lightning_module(self):
627
- if not self.trainer:
628
- raise ValueError("Trainer has not been set.")
629
-
630
- module = self.trainer.lightning_module
631
- if not isinstance(module, LightningModuleBase):
632
- raise ValueError(
633
- f"Trainer's lightning_module is not a LightningModuleBase: {type(module)}"
634
- )
635
-
636
- module = cast(LightningModuleBase[THparams], module)
637
- return module
638
-
639
- @property
640
- def device(self):
641
- return self.lightning_module.device