nshtrainer 0.30.1__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +8 -0
  4. nshtrainer/callbacks/directory_setup.py +85 -0
  5. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  6. nshtrainer/callbacks/shared_parameters.py +87 -0
  7. nshtrainer/config.py +67 -0
  8. nshtrainer/ll/__init__.py +5 -4
  9. nshtrainer/ll/model.py +7 -0
  10. nshtrainer/loggers/wandb.py +1 -1
  11. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  12. nshtrainer/model/__init__.py +0 -21
  13. nshtrainer/model/base.py +139 -44
  14. nshtrainer/model/config.py +7 -1025
  15. nshtrainer/model/{modules → mixins}/callback.py +2 -2
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +778 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/debug.py +0 -42
  28. nshtrainer/model/modules/distributed.py +0 -70
  29. nshtrainer/model/modules/profiler.py +0 -24
  30. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  31. nshtrainer/model/modules/shared_parameters.py +0 -72
  32. /nshtrainer/{config → util/config}/duration.py +0 -0
  33. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py CHANGED
@@ -7,10 +7,9 @@ from . import metrics as metrics
  from . import model as model
  from . import nn as nn
  from . import optimizer as optimizer
+ from . import profiler as profiler
  from .metrics import MetricConfig as MetricConfig
- from .model import Base as Base
  from .model import BaseConfig as BaseConfig
- from .model import ConfigList as ConfigList
  from .model import LightningModuleBase as LightningModuleBase
  from .runner import Runner as Runner
  from .trainer import Trainer as Trainer
nshtrainer/_directory.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+ from pathlib import Path
+
+ import nshconfig as C
+
+ from .callbacks.directory_setup import DirectorySetupConfig
+ from .loggers import LoggerConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class DirectoryConfig(C.Config):
+     project_root: Path | None = None
+     """
+     Root directory for this project.
+
+     This isn't specific to the run; it is the parent directory of all runs.
+     """
+
+     log: Path | None = None
+     """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
+
+     stdio: Path | None = None
+     """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
+
+     checkpoint: Path | None = None
+     """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
+
+     activation: Path | None = None
+     """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
+
+     profile: Path | None = None
+     """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
+
+     setup_callback: DirectorySetupConfig = DirectorySetupConfig()
+     """Configuration for the directory setup PyTorch Lightning callback."""
+
+     def resolve_run_root_directory(self, run_id: str) -> Path:
+         if (project_root_dir := self.project_root) is None:
+             project_root_dir = Path.cwd()
+
+         # The default base dir is $CWD/nshtrainer/{id}/
+         base_dir = project_root_dir / "nshtrainer"
+         base_dir.mkdir(exist_ok=True)
+
+         # Add a .gitignore file to the nshtrainer directory
+         # which will ignore all files except for the .gitignore file itself
+         gitignore_path = base_dir / ".gitignore"
+         if not gitignore_path.exists():
+             gitignore_path.touch()
+             gitignore_path.write_text("*\n")
+
+         base_dir = base_dir / run_id
+         base_dir.mkdir(exist_ok=True)
+
+         return base_dir
+
+     def resolve_subdirectory(
+         self,
+         run_id: str,
+         # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
+         subdirectory: str,
+     ) -> Path:
+         # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
+         if (subdir := getattr(self, subdirectory, None)) is not None:
+             assert isinstance(
+                 subdir, Path
+             ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
+             return subdir
+
+         dir = self.resolve_run_root_directory(run_id)
+         dir = dir / subdirectory
+         dir.mkdir(exist_ok=True)
+         return dir
+
+     def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
+         if (log_dir := logger.log_dir) is not None:
+             return log_dir
+
+         # Save to nshtrainer/{id}/log/{logger name}
+         log_dir = self.resolve_subdirectory(run_id, "log")
+         log_dir = log_dir / logger.name
+         log_dir.mkdir(exist_ok=True)
+
+         return log_dir
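For orientation, here is a minimal usage sketch (not part of the diff) of how the new DirectoryConfig resolves per-run paths; the run id below is a made-up example value and the snippet assumes it is executed from the project root.

# Minimal sketch: resolving run directories with the new DirectoryConfig.
# Note that resolve_subdirectory() creates the directories as a side effect.
from nshtrainer._directory import DirectoryConfig

directory = DirectoryConfig()  # project_root=None falls back to Path.cwd()
ckpt_dir = directory.resolve_subdirectory("run-001", "checkpoint")
# -> $CWD/nshtrainer/run-001/checkpoint/, with a .gitignore containing "*"
#    written at $CWD/nshtrainer/.gitignore on first use
print(ckpt_dir)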
nshtrainer/callbacks/__init__.py CHANGED
@@ -12,6 +12,8 @@ from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
  from .checkpoint import (
      OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
  )
+ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
+ from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
  from .early_stopping import EarlyStopping as EarlyStopping
  from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
  from .ema import EMA as EMA
@@ -28,6 +30,10 @@ from .norm_logging import NormLoggingCallback as NormLoggingCallback
  from .norm_logging import NormLoggingConfig as NormLoggingConfig
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
  from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+ from .rlp_sanity_checks import RLPSanityChecksCallback as RLPSanityChecksCallback
+ from .rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
+ from .shared_parameters import SharedParametersCallback as SharedParametersCallback
+ from .shared_parameters import SharedParametersConfig as SharedParametersConfig
  from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
  from .timer import EpochTimer as EpochTimer
  from .timer import EpochTimerConfig as EpochTimerConfig
@@ -46,6 +52,8 @@ CallbackConfig = Annotated[
      | BestCheckpointCallbackConfig
      | LastCheckpointCallbackConfig
      | OnExceptionCheckpointCallbackConfig
+     | SharedParametersConfig
+     | RLPSanityChecksConfig
      | WandbWatchConfig,
      C.Field(discriminator="name"),
  ]
nshtrainer/callbacks/directory_setup.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Literal
+
+ from lightning.pytorch import Callback
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ def _create_symlink_to_nshrunner(base_dir: Path):
+     # Resolve the current nshrunner session directory
+     if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
+         log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+         return
+     session_dir = Path(session_dir)
+     if not session_dir.exists() or not session_dir.is_dir():
+         log.warning(
+             f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
+             "Skipping symlink creation."
+         )
+         return
+
+     # Create the symlink
+     symlink_path = base_dir / "nshrunner"
+     if symlink_path.exists():
+         # If it already points to the correct directory, we're done
+         if symlink_path.resolve() == session_dir.resolve():
+             return
+
+         # Otherwise, we should log a warning and remove the existing symlink
+         log.warning(
+             f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
+             "Removing the existing symlink."
+         )
+         symlink_path.unlink()
+
+     symlink_path.symlink_to(session_dir)
+
+
+ class DirectorySetupConfig(CallbackConfigBase):
+     name: Literal["directory_setup"] = "directory_setup"
+
+     enabled: bool = True
+     """Whether to enable the directory setup callback."""
+
+     create_symlink_to_nshrunner_root: bool = True
+     """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
+     def __bool__(self):
+         return self.enabled
+
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield DirectorySetupCallback(self)
+
+
+ class DirectorySetupCallback(Callback):
+     @override
+     def __init__(self, config: DirectorySetupConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def setup(self, trainer, pl_module, stage):
+         super().setup(trainer, pl_module, stage)
+
+         # Create a symlink to the root folder for the Runner
+         if self.config.create_symlink_to_nshrunner_root:
+             # Resolve the base dir
+             from ..model.config import BaseConfig
+
+             assert isinstance(
+                 config := pl_module.hparams, BaseConfig
+             ), f"Expected a BaseConfig, got {type(config)}"
+
+             base_dir = config.directory.resolve_run_root_directory(config.id)
+             _create_symlink_to_nshrunner(base_dir)
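As a quick hedged illustration (not from the diff), the setup callback is configured through the setup_callback field that the new DirectoryConfig exposes; the keyword-argument construction below assumes the usual nshconfig/pydantic-style init.

# Hypothetical usage sketch: keep the directory-setup callback enabled but
# skip the nshrunner symlink.
from nshtrainer._directory import DirectoryConfig
from nshtrainer.callbacks import DirectorySetupConfig

directory = DirectoryConfig(
    setup_callback=DirectorySetupConfig(
        enabled=True,
        create_symlink_to_nshrunner_root=False,
    )
)
# DirectorySetupConfig.create_callbacks(...) yields a DirectorySetupCallback
# only when `enabled` is truthy (see __bool__ above).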
nshtrainer/callbacks/rlp_sanity_checks.py ADDED
@@ -0,0 +1,230 @@
+ import logging
+ from collections.abc import Mapping
+ from typing import Literal, cast
+
+ import torch
+ from lightning.pytorch import LightningModule
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.utilities.types import (
+     LRSchedulerConfigType,
+     LRSchedulerTypeUnion,
+ )
+ from typing_extensions import Protocol, override, runtime_checkable
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ class RLPSanityChecksConfig(CallbackConfigBase):
+     """
+     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+     - If the ``interval`` is step, it makes sure that validation is called every ``frequency`` steps.
+     - If the ``interval`` is epoch, it makes sure that validation is called every ``frequency`` epochs.
+     """
+
+     name: Literal["rlp_sanity_checks"] = "rlp_sanity_checks"
+
+     enabled: bool = True
+     """Whether to enable ReduceLRPlateau sanity checks."""
+
+     on_error: Literal["warn", "error"] = "error"
+     """What to do when a sanity check fails."""
+
+     def __bool__(self):
+         return self.enabled
+
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield RLPSanityChecksCallback(self)
+
+
+ class RLPSanityChecksCallback(Callback):
+     @override
+     def __init__(self, config: RLPSanityChecksConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def on_train_start(self, trainer, pl_module):
+         # If we're in PL's "sanity check" mode, we don't need to run this check
+         if trainer.sanity_checking:
+             return
+
+         # If the sanity check is disabled, return.
+         if not self.config:
+             return
+
+         # If no lr schedulers, return.
+         if not trainer.lr_scheduler_configs:
+             return
+
+         errors: list[str] = []
+         disable_message = (
+             "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau = None` "
+             "to disable this sanity check."
+         )
+
+         for lr_scheduler_config in trainer.lr_scheduler_configs:
+             if not lr_scheduler_config.reduce_on_plateau:
+                 continue
+
+             match lr_scheduler_config.interval:
+                 case "epoch":
+                     # we need to make sure that the trainer runs val every `frequency` epochs
+
+                     # If `trainer.check_val_every_n_epoch` is None, then Lightning
+                     # will run val every `int(trainer.val_check_interval)` steps.
+                     # So, first we need to make sure that `trainer.val_check_interval` is not None first.
+                     if trainer.check_val_every_n_epoch is None:
+                         errors.append(
+                             "Trainer is not running validation at epoch intervals "
+                             "(i.e., `trainer.check_val_every_n_epoch` is None) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Second, we make sure that the trainer runs val at least every `frequency` epochs
+                     if (
+                         trainer.check_val_every_n_epoch is not None
+                         and lr_scheduler_config.frequency
+                         % trainer.check_val_every_n_epoch
+                         != 0
+                     ):
+                         errors.append(
+                             f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                             f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
+                             + disable_message
+                         )
+
+                 case "step":
+                     # In this case, we need to make sure that the trainer runs val at step intervals
+                     # that are multiples of `frequency`.
+
+                     # First, we make sure that validation is run at step intervals
+                     if trainer.check_val_every_n_epoch is not None:
+                         errors.append(
+                             "Trainer is running validation at epoch intervals "
+                             "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             "Please set `config.trainer.check_val_every_n_epoch=None` "
+                             f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Second, we make sure `trainer.val_check_interval` is an integer
+                     if not isinstance(trainer.val_check_interval, int):
+                         errors.append(
+                             f"Trainer is not running validation at step intervals "
+                             f"(i.e., `trainer.val_check_interval` is not an integer) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             "Please set `config.trainer.val_check_interval=None` "
+                             f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Third, we make sure that the trainer runs val at least every `frequency` steps
+                     if (
+                         isinstance(trainer.val_check_interval, int)
+                         and trainer.val_check_interval % lr_scheduler_config.frequency
+                         != 0
+                     ):
+                         errors.append(
+                             f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                             "Please set `config.trainer.val_check_interval` "
+                             f"to a multiple of {lr_scheduler_config.frequency}. "
+                             + disable_message
+                         )
+
+                 case _:
+                     pass
+
+         if not errors:
+             return
+
+         message = (
+             "ReduceLRPlateau sanity checks failed with the following errors:\n"
+             + "\n".join(errors)
+         )
+         match self.config.on_error:
+             case "warn":
+                 log.warning(message)
+             case "error":
+                 raise ValueError(message)
+             case _:
+                 pass
+
+
+ @runtime_checkable
+ class CustomRLPImplementation(Protocol):
+     __reduce_lr_on_plateau__: bool
+
+
+ class _RLPSanityCheckModuleMixin(LightningModule):
+     def reduce_lr_on_plateau_config(
+         self,
+         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
+     ) -> LRSchedulerConfigType:
+         if (trainer := self._trainer) is None:
+             raise RuntimeError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 "because `self.trainer` is None."
+             )
+
+         # First, resolve the LR scheduler from the provided config.
+         lr_scheduler_config: LRSchedulerConfigType
+         match lr_scheduler:
+             case Mapping():
+                 lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
+             case _:
+                 lr_scheduler_config = {"scheduler": lr_scheduler}
+
+         # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
+         if (
+             not isinstance(
+                 lr_scheduler_config["scheduler"],
+                 torch.optim.lr_scheduler.ReduceLROnPlateau,
+             )
+         ) and (
+             not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
+             or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
+         ):
+             log.warning(
+                 "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
+                 f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
+                 "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
+                 "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
+                 "If you are using a custom ReduceLRPlateau scheduler implementation, "
+                 "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
+                 "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
+             )
+
+         # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
+         if trainer.check_val_every_n_epoch is not None:
+             return {
+                 "reduce_on_plateau": True,
+                 "interval": "epoch",
+                 "frequency": trainer.check_val_every_n_epoch,
+                 **lr_scheduler_config,
+             }
+
+         # Otherwise, we run val at step intervals.
+         if not isinstance(trainer.val_check_batch, int):
+             raise ValueError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 f"because {trainer.val_check_batch=} is not an integer."
+             )
+
+         return {
+             "reduce_on_plateau": True,
+             "interval": "step",
+             "frequency": trainer.val_check_batch,
+             **lr_scheduler_config,
+         }
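To make the intent of reduce_lr_on_plateau_config concrete, here is a hedged sketch of how a module built on this mixin might wire it into configure_optimizers. The module class, optimizer hyperparameters, and monitored metric name are illustrative assumptions; in the released package the mixin is presumably composed into LightningModuleBase rather than subclassed directly as done here.

# Illustrative sketch only: class, lr, patience, and "val/loss" are made up.
import torch
from nshtrainer.callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin

class MyModule(_RLPSanityCheckModuleMixin):
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {
            "optimizer": optimizer,
            # reduce_lr_on_plateau_config fills in reduce_on_plateau/interval/
            # frequency so they match how often the Trainer actually runs validation.
            "lr_scheduler": {
                **self.reduce_lr_on_plateau_config(scheduler),
                "monitor": "val/loss",  # assumed metric name
            },
        }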
nshtrainer/callbacks/shared_parameters.py ADDED
@@ -0,0 +1,87 @@
+ import logging
+ from collections.abc import Iterable
+ from typing import Literal, Protocol, TypeAlias, runtime_checkable
+
+ import torch.nn as nn
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Callback
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
+     mapping = {id(p): n for n, p in model.named_parameters()}
+     return [mapping[id(p)] for p in parameters]
+
+
+ class SharedParametersConfig(CallbackConfigBase):
+     """A callback that allows scaling the gradients of shared parameters that
+     are registered in the ``self.shared_parameters`` list of the root module.
+
+     This is useful for models that share parameters across multiple modules and
+     want to downscale the gradients of these parameters to avoid overfitting.
+     """
+
+     name: Literal["shared_parameters"] = "shared_parameters"
+
+     @override
+     def create_callbacks(self, root_config):
+         yield SharedParametersCallback(self)
+
+
+ SharedParametersList: TypeAlias = list[tuple[nn.Parameter, int | float]]
+
+
+ @runtime_checkable
+ class ModuleWithSharedParameters(Protocol):
+     @property
+     def shared_parameters(self) -> SharedParametersList: ...
+
+
+ class SharedParametersCallback(Callback):
+     @override
+     def __init__(self, config: SharedParametersConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+         self._warned_shared_parameters = False
+
+     def _shared_parameters(self, pl_module: LightningModule) -> SharedParametersList:
+         if not isinstance(pl_module, ModuleWithSharedParameters):
+             return []
+
+         return pl_module.shared_parameters
+
+     @override
+     def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+         if not (shared_parameters := self._shared_parameters(pl_module)):
+             log.debug(
+                 "No shared parameters to scale, skipping SharedParametersCallback"
+             )
+             return
+
+         log.debug(f"Scaling {len(shared_parameters)} shared parameters...")
+         no_grad_parameters: list[nn.Parameter] = []
+         for p, factor in shared_parameters:
+             if not hasattr(p, "grad") or p.grad is None:
+                 no_grad_parameters.append(p)
+                 continue
+
+             _ = p.grad.data.div_(factor)
+
+         if no_grad_parameters and not self._warned_shared_parameters:
+             no_grad_parameters_str = ", ".join(
+                 _parameters_to_names(no_grad_parameters, pl_module)
+             )
+             log.warning(
+                 "The following parameters were marked as shared, but had no gradients: "
+                 f"{no_grad_parameters_str}"
+             )
+             self._warned_shared_parameters = True
+
+         log.debug(f"Done scaling shared parameters. (len={len(shared_parameters)})")
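For context, a module opts in to this gradient rescaling by exposing a shared_parameters property that matches the ModuleWithSharedParameters protocol above. The tied-embedding layout and the factor of 2 in this sketch are assumed examples, not part of nshtrainer; only the property's shape is prescribed by the protocol.

# Hypothetical example module with a weight-tied embedding/decoder pair.
import torch.nn as nn
from lightning.pytorch import LightningModule

class TiedEmbeddingModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(1000, 64)
        self.decoder = nn.Linear(64, 1000, bias=False)
        self.decoder.weight = self.embedding.weight  # weight tying

    @property
    def shared_parameters(self) -> list[tuple[nn.Parameter, int | float]]:
        # The parameter is reused in two places, so SharedParametersCallback
        # divides its gradient by 2 in on_after_backward.
        return [(self.embedding.weight, 2)]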
nshtrainer/config.py ADDED
@@ -0,0 +1,67 @@
+ from nshconfig._config import Config as Config
+ from nshsnap._config import SnapshotConfig as SnapshotConfig
+ from nshtrainer._checkpoint.loader import CheckpointLoadingConfig as CheckpointLoadingConfig
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
+ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
+ from nshtrainer._hf_hub import HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig
+ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+ from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
+ from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
+ from nshtrainer.callbacks.checkpoint._base import BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig
+ from nshtrainer.callbacks.checkpoint.best_checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+ from nshtrainer.callbacks.checkpoint.last_checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
+ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig
+ from nshtrainer.callbacks.directory_setup import DirectorySetupConfig as DirectorySetupConfig
+ from nshtrainer.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
+ from nshtrainer.callbacks.ema import EMAConfig as EMAConfig
+ from nshtrainer.callbacks.finite_checks import FiniteChecksConfig as FiniteChecksConfig
+ from nshtrainer.callbacks.gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
+ from nshtrainer.callbacks.norm_logging import NormLoggingConfig as NormLoggingConfig
+ from nshtrainer.callbacks.print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+ from nshtrainer.callbacks.rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
+ from nshtrainer.callbacks.shared_parameters import SharedParametersConfig as SharedParametersConfig
+ from nshtrainer.callbacks.throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
+ from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
+ from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
+ from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
+ from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
+ from nshtrainer.loggers.tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
+ from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
+ from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
+ from nshtrainer.metrics._config import MetricConfig as MetricConfig
+ from nshtrainer.model.config import BaseConfig as BaseConfig
+ from nshtrainer.nn.mlp import MLPConfig as MLPConfig
+ from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
+ from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
+ from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+ from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+ from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
+ from nshtrainer.profiler.advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
+ from nshtrainer.profiler.pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
+ from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
+ from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+ from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
+ from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
+ from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
+ from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
+ from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
+ from nshtrainer.util._environment_info import EnvironmentClassInformationConfig as EnvironmentClassInformationConfig
+ from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+ from nshtrainer.util._environment_info import EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig
+ from nshtrainer.util._environment_info import EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig
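As a usage note (not part of the diff), the new nshtrainer/config.py re-exports these config classes under one flat namespace, so downstream code can pull them from a single import; only classes listed above are used in this sketch.

# Sketch of the flat import surface provided by the new nshtrainer.config module.
from nshtrainer.config import (
    AdamWConfig,
    DirectoryConfig,
    MetricConfig,
    RLPSanityChecksConfig,
)

# Field names taken from the rlp_sanity_checks.py diff above.
rlp_checks = RLPSanityChecksConfig(enabled=True, on_error="warn")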
nshtrainer/ll/__init__.py CHANGED
@@ -1,3 +1,5 @@
+ from typing import TypeAlias
+
  from . import _experimental as _experimental
  from . import actsave as actsave
  from . import callbacks as callbacks
@@ -21,12 +23,9 @@ from .log import init_python_logging as init_python_logging
  from .log import lovely as lovely
  from .log import pretty as pretty
  from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
- from .model import Base as Base
  from .model import BaseConfig as BaseConfig
- from .model import BaseProfilerConfig as BaseProfilerConfig
  from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
  from .model import CheckpointSavingConfig as CheckpointSavingConfig
- from .model import ConfigList as ConfigList
  from .model import DirectoryConfig as DirectoryConfig
  from .model import (
      EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
@@ -43,7 +42,6 @@ from .model import LightningModuleBase as LightningModuleBase
  from .model import LoggingConfig as LoggingConfig
  from .model import MetricConfig as MetricConfig
  from .model import OptimizationConfig as OptimizationConfig
- from .model import PrimaryMetricConfig as PrimaryMetricConfig
  from .model import ReproducibilityConfig as ReproducibilityConfig
  from .model import SanityCheckingConfig as SanityCheckingConfig
  from .model import TrainerConfig as TrainerConfig
@@ -54,3 +52,6 @@ from .runner import Runner as Runner
  from .runner import SnapshotConfig as SnapshotConfig
  from .snoop import snoop as snoop
  from .trainer import Trainer as Trainer
+
+ PrimaryMetricConfig: TypeAlias = MetricConfig
+ ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
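The two trailing aliases above appear to preserve the old ll-style names as plain aliases of the surviving types; a minimal hedged check of what they resolve to (the assert is illustrative only).

# Back-compat aliases introduced above; at runtime PrimaryMetricConfig is the
# same object as MetricConfig, and ConfigList is a list-of-pairs type alias.
from nshtrainer.ll import ConfigList, MetricConfig, PrimaryMetricConfig

assert PrimaryMetricConfig is MetricConfig
configs: ConfigList = []  # (BaseConfig, LightningModuleBase subclass) pairs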
nshtrainer/ll/model.py CHANGED
@@ -1,5 +1,12 @@
  from nshtrainer.model import * # noqa: F403

+ from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
+ from ..trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+ from ..trainer._config import GradientClippingConfig as GradientClippingConfig
+ from ..trainer._config import LoggingConfig as LoggingConfig
+ from ..trainer._config import OptimizationConfig as OptimizationConfig
+ from ..trainer._config import ReproducibilityConfig as ReproducibilityConfig
+ from ..trainer._config import SanityCheckingConfig as SanityCheckingConfig
  from ..util._environment_info import (
      EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
  )
nshtrainer/loggers/wandb.py CHANGED
@@ -129,7 +129,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
  "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
  )
  else:
- wandb.require("core")
+ wandb.require("core") # type: ignore
  log.critical("Using the `wandb-core` backend for WandB.")
  except ImportError:
  pass
nshtrainer/lr_scheduler/linear_warmup_cosine.py CHANGED
@@ -6,7 +6,7 @@ from torch.optim import Optimizer
  from torch.optim.lr_scheduler import LRScheduler
  from typing_extensions import override

- from ..config import Duration
+ from ..util.config import Duration
  from ._base import LRSchedulerConfigBase, LRSchedulerMetadata

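Consistent with the rename in the file list (nshtrainer/config moved to nshtrainer/util/config, with duration.py along for the ride), downstream imports of the duration types presumably move the same way; a hedged before/after sketch, assuming the relocated package still re-exports Duration as the in-package import above implies.

# Old (0.30.1):
#   from nshtrainer.config import Duration
# New (0.31.0), matching the `from ..util.config import Duration` change above:
from nshtrainer.util.config import Duration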