nshtrainer 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/PKG-INFO +13 -2
  2. nshtrainer-0.1.1/pyproject.toml +47 -0
  3. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/__init__.py +0 -16
  4. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/__init__.py +3 -2
  5. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/base.py +3 -4
  6. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/lr_scheduler/__init__.py +3 -2
  7. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/lr_scheduler/_base.py +3 -6
  8. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +5 -5
  9. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +5 -4
  10. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/__init__.py +0 -4
  11. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/base.py +9 -71
  12. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/config.py +39 -141
  13. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/nn/nonlinearity.py +3 -4
  14. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/optimizer.py +3 -7
  15. nshtrainer-0.1.1/src/nshtrainer/runner.py +31 -0
  16. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/trainer/signal_connector.py +22 -11
  17. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/trainer/trainer.py +1 -1
  18. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/typecheck.py +1 -0
  19. nshtrainer-0.1.0/pyproject.toml +0 -18
  20. nshtrainer-0.1.0/src/nshtrainer/_submit/print_environment_info.py +0 -31
  21. nshtrainer-0.1.0/src/nshtrainer/_submit/session/_output.py +0 -12
  22. nshtrainer-0.1.0/src/nshtrainer/_submit/session/_script.py +0 -109
  23. nshtrainer-0.1.0/src/nshtrainer/_submit/session/lsf.py +0 -467
  24. nshtrainer-0.1.0/src/nshtrainer/_submit/session/slurm.py +0 -573
  25. nshtrainer-0.1.0/src/nshtrainer/_submit/session/unified.py +0 -350
  26. nshtrainer-0.1.0/src/nshtrainer/config.py +0 -289
  27. nshtrainer-0.1.0/src/nshtrainer/runner.py +0 -21
  28. nshtrainer-0.1.0/src/nshtrainer/util/singleton.py +0 -89
  29. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/README.md +0 -0
  30. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  31. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  32. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  33. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  34. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/_snoop.py +0 -0
  35. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/actsave/__init__.py +0 -0
  36. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/actsave/_callback.py +0 -0
  37. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/actsave/_loader.py +0 -0
  38. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/actsave/_saver.py +0 -0
  39. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  40. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  41. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/ema.py +0 -0
  42. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  43. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  44. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/interval.py +0 -0
  45. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
  46. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  47. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  48. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  49. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  50. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  51. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/timer.py +0 -0
  52. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  53. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/data/__init__.py +0 -0
  54. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  55. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/data/transform.py +0 -0
  56. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/modules/callback.py +0 -0
  57. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/modules/debug.py +0 -0
  58. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/modules/distributed.py +0 -0
  59. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/modules/logger.py +0 -0
  60. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/modules/profiler.py +0 -0
  61. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  62. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  63. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/nn/__init__.py +0 -0
  64. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/nn/mlp.py +0 -0
  65. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/nn/module_dict.py +0 -0
  66. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/nn/module_list.py +0 -0
  67. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/scripts/check_env.py +0 -0
  68. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  69. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/trainer/__init__.py +0 -0
  70. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/util/environment.py +0 -0
  71. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/util/seed.py +0 -0
  72. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/util/slurm.py +0 -0
  73. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/util/typed.py +0 -0
  74. {nshtrainer-0.1.0 → nshtrainer-0.1.1}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -9,10 +9,21 @@ Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Programming Language :: Python :: 3.10
10
10
  Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
+ Requires-Dist: beartype (>=0.18.5,<0.19.0)
13
+ Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
14
+ Requires-Dist: lightning
15
+ Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
16
+ Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
12
17
  Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
13
- Requires-Dist: nshrunner (>=0.1.0,<0.2.0)
18
+ Requires-Dist: nshrunner (>=0.5.3,<0.6.0)
19
+ Requires-Dist: numpy
20
+ Requires-Dist: pysnooper
21
+ Requires-Dist: pytorch-lightning
22
+ Requires-Dist: rich
14
23
  Requires-Dist: torch
24
+ Requires-Dist: torchmetrics
15
25
  Requires-Dist: typing-extensions
26
+ Requires-Dist: wrapt
16
27
  Description-Content-Type: text/markdown
17
28
 
18
29
 
@@ -0,0 +1,47 @@
1
+ [tool.poetry]
2
+ name = "nshtrainer"
3
+ version = "0.1.1"
4
+ description = ""
5
+ authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
+ readme = "README.md"
7
+
8
+ [tool.poetry.dependencies]
9
+ python = "^3.10"
10
+ nshrunner = "^0.5.3"
11
+ nshconfig = "^0.2.0"
12
+ torch = "*"
13
+ typing-extensions = "*"
14
+ lightning = "*"
15
+ pytorch-lightning = "*"
16
+ torchmetrics = "*"
17
+ numpy = "*"
18
+ jaxtyping = "^0.2.33"
19
+ beartype = "^0.18.5"
20
+ lovely-numpy = "^0.2.13"
21
+ lovely-tensors = "^0.1.16"
22
+ pysnooper = "*"
23
+ wrapt = "*"
24
+ rich = "*"
25
+
26
+
27
+ [tool.poetry.group.dev.dependencies]
28
+ pyright = "^1.1.372"
29
+ ruff = "^0.5.4"
30
+ ipykernel = "^6.29.5"
31
+ ipywidgets = "^8.1.3"
32
+
33
+ [build-system]
34
+ requires = ["poetry-core"]
35
+ build-backend = "poetry.core.masonry.api"
36
+
37
+ [tool.pyright]
38
+ typeCheckingMode = "standard"
39
+ deprecateTypingAliases = true
40
+ strictListInference = true
41
+ strictDictionaryInference = true
42
+ strictSetInference = true
43
+ reportPrivateImportUsage = false
44
+ ignore = ["./build/"]
45
+
46
+ [tool.ruff.lint]
47
+ ignore = ["F722", "F821", "E731", "E741"]
@@ -10,16 +10,7 @@ from . import typecheck as typecheck
10
10
  from ._snoop import snoop as snoop
11
11
  from .actsave import ActLoad as ActLoad
12
12
  from .actsave import ActSave as ActSave
13
- from .config import MISSING as MISSING
14
- from .config import AllowMissing as AllowMissing
15
- from .config import Field as Field
16
- from .config import MissingField as MissingField
17
- from .config import PrivateAttr as PrivateAttr
18
- from .config import TypedConfig as TypedConfig
19
13
  from .data import dataset_transform as dataset_transform
20
- from .log import init_python_logging as init_python_logging
21
- from .log import lovely as lovely
22
- from .log import pretty as pretty
23
14
  from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
24
15
  from .model import ActSaveConfig as ActSaveConfig
25
16
  from .model import Base as Base
@@ -41,24 +32,17 @@ from .model import (
41
32
  EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
42
33
  )
43
34
  from .model import GradientClippingConfig as GradientClippingConfig
44
- from .model import LightningDataModuleBase as LightningDataModuleBase
45
35
  from .model import LightningModuleBase as LightningModuleBase
46
36
  from .model import LoggingConfig as LoggingConfig
47
37
  from .model import MetricConfig as MetricConfig
48
38
  from .model import OptimizationConfig as OptimizationConfig
49
39
  from .model import PrimaryMetricConfig as PrimaryMetricConfig
50
- from .model import PythonLogging as PythonLogging
51
40
  from .model import ReproducibilityConfig as ReproducibilityConfig
52
- from .model import RunnerConfig as RunnerConfig
53
41
  from .model import SanityCheckingConfig as SanityCheckingConfig
54
- from .model import SeedConfig as SeedConfig
55
42
  from .model import TrainerConfig as TrainerConfig
56
43
  from .model import WandbWatchConfig as WandbWatchConfig
57
44
  from .nn import TypedModuleDict as TypedModuleDict
58
45
  from .nn import TypedModuleList as TypedModuleList
59
46
  from .optimizer import OptimizerConfig as OptimizerConfig
60
47
  from .runner import Runner as Runner
61
- from .runner import SnapshotConfig as SnapshotConfig
62
48
  from .trainer import Trainer as Trainer
63
- from .util.singleton import Registry as Registry
64
- from .util.singleton import Singleton as Singleton
@@ -1,6 +1,7 @@
1
1
  from typing import Annotated
2
2
 
3
- from ..config import Field
3
+ import nshconfig as C
4
+
4
5
  from .base import CallbackConfigBase as CallbackConfigBase
5
6
  from .early_stopping import EarlyStopping as EarlyStopping
6
7
  from .ema import EMA as EMA
@@ -31,5 +32,5 @@ CallbackConfig = Annotated[
31
32
  | NormLoggingConfig
32
33
  | GradientSkippingConfig
33
34
  | EMAConfig,
34
- Field(discriminator="name"),
35
+ C.Field(discriminator="name"),
35
36
  ]
@@ -4,10 +4,9 @@ from collections.abc import Iterable
4
4
  from dataclasses import dataclass
5
5
  from typing import TYPE_CHECKING, TypeAlias, TypedDict
6
6
 
7
+ import nshconfig as C
7
8
  from lightning.pytorch import Callback
8
9
 
9
- from ..config import TypedConfig
10
-
11
10
  if TYPE_CHECKING:
12
11
  from ..model.config import BaseConfig
13
12
 
@@ -20,7 +19,7 @@ class CallbackMetadataDict(TypedDict, total=False):
20
19
  """Priority of the callback. Callbacks with higher priority will be loaded first."""
21
20
 
22
21
 
23
- class CallbackMetadataConfig(TypedConfig):
22
+ class CallbackMetadataConfig(C.Config):
24
23
  ignore_if_exists: bool = False
25
24
  """If `True`, the callback will not be added if another callback with the same class already exists."""
26
25
 
@@ -37,7 +36,7 @@ class CallbackWithMetadata:
37
36
  ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
38
37
 
39
38
 
40
- class CallbackConfigBase(TypedConfig, ABC):
39
+ class CallbackConfigBase(C.Config, ABC):
41
40
  metadata: CallbackMetadataConfig = CallbackMetadataConfig()
42
41
  """Metadata for the callback."""
43
42
 
@@ -1,6 +1,7 @@
1
1
  from typing import Annotated, TypeAlias
2
2
 
3
- from ..config import Field
3
+ import nshconfig as C
4
+
4
5
  from ._base import LRSchedulerConfigBase as LRSchedulerConfigBase
5
6
  from ._base import LRSchedulerMetadata as LRSchedulerMetadata
6
7
  from .linear_warmup_cosine import (
@@ -14,5 +15,5 @@ from .reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauCo
14
15
 
15
16
  LRSchedulerConfig: TypeAlias = Annotated[
16
17
  LinearWarmupCosineDecayLRSchedulerConfig | ReduceLROnPlateauConfig,
17
- Field(discriminator="name"),
18
+ C.Field(discriminator="name"),
18
19
  ]
@@ -1,8 +1,9 @@
1
1
  import math
2
2
  from abc import ABC, abstractmethod
3
3
  from collections.abc import Mapping
4
- from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias
4
+ from typing import TYPE_CHECKING, Literal
5
5
 
6
+ import nshconfig as C
6
7
  from lightning.pytorch.utilities.types import (
7
8
  LRSchedulerConfigType,
8
9
  LRSchedulerTypeUnion,
@@ -10,8 +11,6 @@ from lightning.pytorch.utilities.types import (
10
11
  from torch.optim import Optimizer
11
12
  from typing_extensions import NotRequired, TypedDict
12
13
 
13
- from ..config import TypedConfig
14
-
15
14
  if TYPE_CHECKING:
16
15
  from ..model.base import LightningModuleBase
17
16
 
@@ -37,9 +36,7 @@ class LRSchedulerMetadata(TypedDict):
37
36
  """Whether to enforce that the monitor exists for reducing the learning rate on plateau. Default is `True`."""
38
37
 
39
38
 
40
- class LRSchedulerConfigBase(TypedConfig, ABC):
41
- Metadata: ClassVar[TypeAlias] = LRSchedulerMetadata
42
-
39
+ class LRSchedulerConfigBase(C.Config, ABC):
43
40
  @abstractmethod
44
41
  def metadata(self) -> LRSchedulerMetadata: ...
45
42
 
@@ -2,12 +2,12 @@ import math
2
2
  import warnings
3
3
  from typing import Literal
4
4
 
5
+ import nshconfig as C
5
6
  from torch.optim import Optimizer
6
7
  from torch.optim.lr_scheduler import LRScheduler
7
8
  from typing_extensions import override
8
9
 
9
- from ..config import Field
10
- from ._base import LRSchedulerConfigBase
10
+ from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
11
11
 
12
12
 
13
13
  class LinearWarmupCosineAnnealingLR(LRScheduler):
@@ -91,11 +91,11 @@ class LinearWarmupCosineAnnealingLR(LRScheduler):
91
91
  class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
92
92
  name: Literal["linear_warmup_cosine_decay"] = "linear_warmup_cosine_decay"
93
93
 
94
- warmup_epochs: int = Field(ge=0)
94
+ warmup_epochs: int = C.Field(ge=0)
95
95
  r"""The number of epochs for the linear warmup phase.
96
96
  The learning rate is linearly increased from `warmup_start_lr` to the initial learning rate over this number of epochs."""
97
97
 
98
- max_epochs: int = Field(gt=0)
98
+ max_epochs: int = C.Field(gt=0)
99
99
  r"""The total number of epochs.
100
100
  The learning rate is decayed to `min_lr` over this number of epochs."""
101
101
 
@@ -113,7 +113,7 @@ class LinearWarmupCosineDecayLRSchedulerConfig(LRSchedulerConfigBase):
113
113
  If `True`, the learning rate will be decayed to `min_lr` over `max_epochs` epochs, and then the learning rate will be increased back to the initial learning rate over `max_epochs` epochs, and so on (this is called a cosine annealing schedule)."""
114
114
 
115
115
  @override
116
- def metadata(self) -> LRSchedulerConfigBase.Metadata:
116
+ def metadata(self) -> LRSchedulerMetadata:
117
117
  return {
118
118
  "interval": "step",
119
119
  }
@@ -1,12 +1,11 @@
1
1
  from typing import TYPE_CHECKING, Literal, cast
2
2
 
3
+ from lightning.pytorch.utilities.types import LRSchedulerConfigType
3
4
  from torch.optim.lr_scheduler import ReduceLROnPlateau
4
5
  from typing_extensions import override
5
6
 
6
- from ll.lr_scheduler._base import LRSchedulerMetadata
7
-
8
7
  from ..model.config import MetricConfig
9
- from ._base import LRSchedulerConfigBase
8
+ from ._base import LRSchedulerConfigBase, LRSchedulerMetadata
10
9
 
11
10
  if TYPE_CHECKING:
12
11
  from ..model.base import BaseConfig
@@ -43,7 +42,9 @@ class ReduceLROnPlateauConfig(LRSchedulerConfigBase):
43
42
  r"""One of `rel`, `abs`. In `rel` mode, dynamic_threshold = best * (1 + threshold) in 'max' mode or best * (1 - threshold) in `min` mode. In `abs` mode, dynamic_threshold = best + threshold in `max` mode or best - threshold in `min` mode. Default: 'rel'."""
44
43
 
45
44
  @override
46
- def create_scheduler_impl(self, optimizer, lightning_module, lr):
45
+ def create_scheduler_impl(
46
+ self, optimizer, lightning_module, lr
47
+ ) -> LRSchedulerConfigType:
47
48
  if (metric := self.metric) is None:
48
49
  lm_config = cast("BaseConfig", lightning_module.config)
49
50
  assert (
@@ -1,7 +1,6 @@
1
1
  from typing_extensions import TypeAlias
2
2
 
3
3
  from .base import Base as Base
4
- from .base import LightningDataModuleBase as LightningDataModuleBase
5
4
  from .base import LightningModuleBase as LightningModuleBase
6
5
  from .config import ActSaveConfig as ActSaveConfig
7
6
  from .config import BaseConfig as BaseConfig
@@ -33,11 +32,8 @@ from .config import (
33
32
  )
34
33
  from .config import OptimizationConfig as OptimizationConfig
35
34
  from .config import PrimaryMetricConfig as PrimaryMetricConfig
36
- from .config import PythonLogging as PythonLogging
37
35
  from .config import ReproducibilityConfig as ReproducibilityConfig
38
- from .config import RunnerConfig as RunnerConfig
39
36
  from .config import SanityCheckingConfig as SanityCheckingConfig
40
- from .config import SeedConfig as SeedConfig
41
37
  from .config import TrainerConfig as TrainerConfig
42
38
  from .config import WandbWatchConfig as WandbWatchConfig
43
39
 
@@ -23,11 +23,12 @@ from .config import (
23
23
  EnvironmentLinuxEnvironmentConfig,
24
24
  EnvironmentLSFInformationConfig,
25
25
  EnvironmentSLURMInformationConfig,
26
+ EnvironmentSnapshotConfig,
26
27
  )
27
- from .modules.callback import CallbackModuleMixin, CallbackRegistrarModuleMixin
28
+ from .modules.callback import CallbackModuleMixin
28
29
  from .modules.debug import DebugModuleMixin
29
30
  from .modules.distributed import DistributedMixin
30
- from .modules.logger import LoggerLightningModuleMixin, LoggerModuleMixin
31
+ from .modules.logger import LoggerLightningModuleMixin
31
32
  from .modules.profiler import ProfilerMixin
32
33
  from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
33
34
  from .modules.shared_parameters import SharedParametersModuleMixin
@@ -265,6 +266,9 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
265
266
  boot_time=_try_get(lambda: _psutil().boot_time()),
266
267
  load_avg=_try_get(lambda: os.getloadavg()),
267
268
  )
269
+ hparams.environment.snapshot = (
270
+ EnvironmentSnapshotConfig.from_current_environment()
271
+ )
268
272
 
269
273
  def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
270
274
  """
@@ -309,15 +313,12 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
309
313
  @property
310
314
  def datamodule(self):
311
315
  datamodule = getattr(self.trainer, "datamodule", None)
312
- if datamodule is None:
316
+ if (datamodule := getattr(self.trainer, "datamodule", None)) is None:
313
317
  return None
314
-
315
- if not isinstance(datamodule, LightningDataModuleBase):
318
+ if not isinstance(datamodule, LightningDataModule):
316
319
  raise TypeError(
317
- f"datamodule must be a LightningDataModuleBase: {type(datamodule)}"
320
+ f"datamodule must be a LightningDataModule: {type(datamodule)}"
318
321
  )
319
-
320
- datamodule = cast(LightningDataModuleBase[THparams], datamodule)
321
322
  return datamodule
322
323
 
323
324
  if TYPE_CHECKING:
@@ -576,66 +577,3 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
576
577
  "batch": batch,
577
578
  "batch_idx": batch_idx,
578
579
  }
579
-
580
-
581
- class LightningDataModuleBase(
582
- LoggerModuleMixin,
583
- CallbackRegistrarModuleMixin,
584
- Base[THparams],
585
- LightningDataModule,
586
- ABC,
587
- Generic[THparams],
588
- ):
589
- hparams: THparams # pyright: ignore[reportIncompatibleMethodOverride]
590
- hparams_initial: THparams # pyright: ignore[reportIncompatibleMethodOverride]
591
-
592
- def pre_init_update_hparams_dict(self, hparams: MutableMapping[str, Any]):
593
- """
594
- Override this method to update the hparams dictionary before it is used to create the hparams object.
595
- Mapping-based parameters are passed to the constructor of the hparams object when we're loading the model from a checkpoint.
596
- """
597
- return hparams
598
-
599
- def pre_init_update_hparams(self, hparams: THparams):
600
- """
601
- Override this method to update the hparams object before it is used to create the hparams_initial object.
602
- """
603
- return hparams
604
-
605
- @classmethod
606
- def _update_environment(cls, hparams: THparams):
607
- hparams.environment.data = _cls_info(cls)
608
-
609
- @override
610
- def __init__(self, hparams: THparams):
611
- if not isinstance(hparams, BaseConfig):
612
- if not isinstance(hparams, MutableMapping):
613
- raise TypeError(
614
- f"hparams must be a BaseConfig or a MutableMapping: {type(hparams)}"
615
- )
616
-
617
- hparams = self.pre_init_update_hparams_dict(hparams)
618
- hparams = self.config_cls().from_dict(hparams)
619
- self._update_environment(hparams)
620
- hparams = self.pre_init_update_hparams(hparams)
621
- super().__init__(hparams)
622
-
623
- self.save_hyperparameters(hparams)
624
-
625
- @property
626
- def lightning_module(self):
627
- if not self.trainer:
628
- raise ValueError("Trainer has not been set.")
629
-
630
- module = self.trainer.lightning_module
631
- if not isinstance(module, LightningModuleBase):
632
- raise ValueError(
633
- f"Trainer's lightning_module is not a LightningModuleBase: {type(module)}"
634
- )
635
-
636
- module = cast(LightningModuleBase[THparams], module)
637
- return module
638
-
639
- @property
640
- def device(self):
641
- return self.lightning_module.device