nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +12 -1
  4. nshtrainer/callbacks/debug_flag.py +72 -0
  5. nshtrainer/callbacks/directory_setup.py +85 -0
  6. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  7. nshtrainer/callbacks/shared_parameters.py +87 -0
  8. nshtrainer/config.py +67 -0
  9. nshtrainer/ll/__init__.py +5 -4
  10. nshtrainer/ll/model.py +7 -0
  11. nshtrainer/loggers/wandb.py +1 -1
  12. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  13. nshtrainer/model/__init__.py +0 -21
  14. nshtrainer/model/base.py +124 -67
  15. nshtrainer/model/config.py +7 -1025
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +787 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/callback.py +0 -206
  28. nshtrainer/model/modules/debug.py +0 -42
  29. nshtrainer/model/modules/distributed.py +0 -70
  30. nshtrainer/model/modules/profiler.py +0 -24
  31. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  32. nshtrainer/model/modules/shared_parameters.py +0 -72
  33. /nshtrainer/{config → util/config}/duration.py +0 -0
  34. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py CHANGED
@@ -7,10 +7,9 @@ from . import metrics as metrics
  from . import model as model
  from . import nn as nn
  from . import optimizer as optimizer
+ from . import profiler as profiler
  from .metrics import MetricConfig as MetricConfig
- from .model import Base as Base
  from .model import BaseConfig as BaseConfig
- from .model import ConfigList as ConfigList
  from .model import LightningModuleBase as LightningModuleBase
  from .runner import Runner as Runner
  from .trainer import Trainer as Trainer
nshtrainer/_directory.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+ from pathlib import Path
+
+ import nshconfig as C
+
+ from .callbacks.directory_setup import DirectorySetupConfig
+ from .loggers import LoggerConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class DirectoryConfig(C.Config):
+     project_root: Path | None = None
+     """
+     Root directory for this project.
+
+     This isn't specific to the run; it is the parent directory of all runs.
+     """
+
+     log: Path | None = None
+     """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
+
+     stdio: Path | None = None
+     """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
+
+     checkpoint: Path | None = None
+     """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
+
+     activation: Path | None = None
+     """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
+
+     profile: Path | None = None
+     """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
+
+     setup_callback: DirectorySetupConfig = DirectorySetupConfig()
+     """Configuration for the directory setup PyTorch Lightning callback."""
+
+     def resolve_run_root_directory(self, run_id: str) -> Path:
+         if (project_root_dir := self.project_root) is None:
+             project_root_dir = Path.cwd()
+
+         # The default base dir is $CWD/nshtrainer/{id}/
+         base_dir = project_root_dir / "nshtrainer"
+         base_dir.mkdir(exist_ok=True)
+
+         # Add a .gitignore file to the nshtrainer directory
+         # which will ignore all files except for the .gitignore file itself
+         gitignore_path = base_dir / ".gitignore"
+         if not gitignore_path.exists():
+             gitignore_path.touch()
+             gitignore_path.write_text("*\n")
+
+         base_dir = base_dir / run_id
+         base_dir.mkdir(exist_ok=True)
+
+         return base_dir
+
+     def resolve_subdirectory(
+         self,
+         run_id: str,
+         # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
+         subdirectory: str,
+     ) -> Path:
+         # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
+         if (subdir := getattr(self, subdirectory, None)) is not None:
+             assert isinstance(
+                 subdir, Path
+             ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
+             return subdir
+
+         dir = self.resolve_run_root_directory(run_id)
+         dir = dir / subdirectory
+         dir.mkdir(exist_ok=True)
+         return dir
+
+     def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
+         if (log_dir := logger.log_dir) is not None:
+             return log_dir
+
+         # Save to nshtrainer/{id}/log/{logger name}
+         log_dir = self.resolve_subdirectory(run_id, "log")
+         log_dir = log_dir / logger.name
+         log_dir.mkdir(exist_ok=True)
+
+         return log_dir
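For orientation, here is a minimal sketch of how the new DirectoryConfig resolves per-run paths. The project root and run id below are hypothetical, and directories are created on first resolution:

    from pathlib import Path

    from nshtrainer._directory import DirectoryConfig

    # Hypothetical project root and run id, purely for illustration.
    root = Path("/tmp/demo-project")
    root.mkdir(parents=True, exist_ok=True)

    directory = DirectoryConfig(project_root=root)
    # Resolves (and creates) /tmp/demo-project/nshtrainer/run-001/checkpoint/
    ckpt_dir = directory.resolve_subdirectory("run-001", "checkpoint")

    # An explicit per-field override short-circuits the default layout.
    custom = DirectoryConfig(checkpoint=Path("/scratch/ckpts"))
    assert custom.resolve_subdirectory("run-001", "checkpoint") == Path("/scratch/ckpts")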
nshtrainer/callbacks/__init__.py CHANGED
@@ -12,6 +12,10 @@ from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
  from .checkpoint import (
      OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
  )
+ from .debug_flag import DebugFlagCallback as DebugFlagCallback
+ from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
+ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
+ from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
  from .early_stopping import EarlyStopping as EarlyStopping
  from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
  from .ema import EMA as EMA
@@ -28,6 +32,10 @@ from .norm_logging import NormLoggingCallback as NormLoggingCallback
  from .norm_logging import NormLoggingConfig as NormLoggingConfig
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
  from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+ from .rlp_sanity_checks import RLPSanityChecksCallback as RLPSanityChecksCallback
+ from .rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
+ from .shared_parameters import SharedParametersCallback as SharedParametersCallback
+ from .shared_parameters import SharedParametersConfig as SharedParametersConfig
  from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
  from .timer import EpochTimer as EpochTimer
  from .timer import EpochTimerConfig as EpochTimerConfig
@@ -35,7 +43,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
  from .wandb_watch import WandbWatchConfig as WandbWatchConfig

  CallbackConfig = Annotated[
-     EarlyStoppingConfig
+     DebugFlagCallbackConfig
+     | EarlyStoppingConfig
      | ThroughputMonitorConfig
      | EpochTimerConfig
      | PrintTableMetricsConfig
@@ -46,6 +55,8 @@ CallbackConfig = Annotated[
      | BestCheckpointCallbackConfig
      | LastCheckpointCallbackConfig
      | OnExceptionCheckpointCallbackConfig
+     | SharedParametersConfig
+     | RLPSanityChecksConfig
      | WandbWatchConfig,
      C.Field(discriminator="name"),
  ]
nshtrainer/callbacks/debug_flag.py ADDED
@@ -0,0 +1,72 @@
+ import logging
+ from typing import TYPE_CHECKING, Literal, cast
+
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Callback
+ from typing_extensions import override
+
+ from nshtrainer.model.config import BaseConfig
+
+ from .base import CallbackConfigBase
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class DebugFlagCallbackConfig(CallbackConfigBase):
+     name: Literal["debug_flag"] = "debug_flag"
+
+     enabled: bool = True
+     """Whether to enable the callback."""
+
+     def __bool__(self):
+         return self.enabled
+
+     @override
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield DebugFlagCallback(self)
+
+
+ class DebugFlagCallback(Callback):
+     """
+     Sets the debug flag to true in the following circumstances:
+     - fast_dev_run is enabled
+     - sanity check is running
+     """
+
+     @override
+     def __init__(self, config: DebugFlagCallbackConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
+         if not getattr(trainer, "fast_dev_run", False):
+             return
+
+         hparams = cast("BaseConfig", pl_module.hparams)
+         if not hparams.debug:
+             log.critical("Fast dev run detected, setting debug flag to True.")
+             hparams.debug = True
+
+     @override
+     def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
+         hparams = cast("BaseConfig", pl_module.hparams)
+         self._debug = hparams.debug
+         if not self._debug:
+             log.critical("Enabling debug flag during sanity check routine.")
+             hparams.debug = True
+
+     @override
+     def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
+         hparams = cast("BaseConfig", pl_module.hparams)
+         if not self._debug:
+             log.critical("Sanity check routine complete, disabling debug flag.")
+             hparams.debug = self._debug
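The new callback configs share a common gating pattern: `__bool__` mirrors `enabled`, and `create_callbacks` yields nothing when the config is disabled. A small illustrative sketch follows; the `root_config` argument is unused by this particular config, so passing `None` here is only for demonstration (in real runs nshtrainer presumably passes the run's root BaseConfig):

    from nshtrainer.callbacks import DebugFlagCallback, DebugFlagCallbackConfig

    enabled = DebugFlagCallbackConfig()                # enabled=True by default
    disabled = DebugFlagCallbackConfig(enabled=False)

    assert not list(disabled.create_callbacks(None))   # disabled -> no callbacks yielded
    (callback,) = enabled.create_callbacks(None)       # enabled -> exactly one callback
    assert isinstance(callback, DebugFlagCallback)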
nshtrainer/callbacks/directory_setup.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Literal
+
+ from lightning.pytorch import Callback
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ def _create_symlink_to_nshrunner(base_dir: Path):
+     # Resolve the current nshrunner session directory
+     if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
+         log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+         return
+     session_dir = Path(session_dir)
+     if not session_dir.exists() or not session_dir.is_dir():
+         log.warning(
+             f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
+             "Skipping symlink creation."
+         )
+         return
+
+     # Create the symlink
+     symlink_path = base_dir / "nshrunner"
+     if symlink_path.exists():
+         # If it already points to the correct directory, we're done
+         if symlink_path.resolve() == session_dir.resolve():
+             return
+
+         # Otherwise, we should log a warning and remove the existing symlink
+         log.warning(
+             f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
+             "Removing the existing symlink."
+         )
+         symlink_path.unlink()
+
+     symlink_path.symlink_to(session_dir)
+
+
+ class DirectorySetupConfig(CallbackConfigBase):
+     name: Literal["directory_setup"] = "directory_setup"
+
+     enabled: bool = True
+     """Whether to enable the directory setup callback."""
+
+     create_symlink_to_nshrunner_root: bool = True
+     """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
+     def __bool__(self):
+         return self.enabled
+
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield DirectorySetupCallback(self)
+
+
+ class DirectorySetupCallback(Callback):
+     @override
+     def __init__(self, config: DirectorySetupConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def setup(self, trainer, pl_module, stage):
+         super().setup(trainer, pl_module, stage)
+
+         # Create a symlink to the root folder for the Runner
+         if self.config.create_symlink_to_nshrunner_root:
+             # Resolve the base dir
+             from ..model.config import BaseConfig
+
+             assert isinstance(
+                 config := pl_module.hparams, BaseConfig
+             ), f"Expected a BaseConfig, got {type(config)}"
+
+             base_dir = config.directory.resolve_run_root_directory(config.id)
+             _create_symlink_to_nshrunner(base_dir)
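The symlink helper is driven entirely by the NSHRUNNER_SESSION_DIR environment variable, so its behaviour can be exercised in isolation. A sketch with hypothetical temporary directories (`_create_symlink_to_nshrunner` is a private helper that is normally invoked only by DirectorySetupCallback):

    import os
    import tempfile
    from pathlib import Path

    from nshtrainer.callbacks.directory_setup import _create_symlink_to_nshrunner

    session_dir = Path(tempfile.mkdtemp())  # stand-in for an nshrunner session directory
    base_dir = Path(tempfile.mkdtemp())     # stand-in for the run's root directory

    os.environ["NSHRUNNER_SESSION_DIR"] = str(session_dir)
    _create_symlink_to_nshrunner(base_dir)

    # base_dir/nshrunner now points at the session directory.
    assert (base_dir / "nshrunner").resolve() == session_dir.resolve()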
nshtrainer/callbacks/rlp_sanity_checks.py ADDED
@@ -0,0 +1,230 @@
+ import logging
+ from collections.abc import Mapping
+ from typing import Literal, cast
+
+ import torch
+ from lightning.pytorch import LightningModule
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.utilities.types import (
+     LRSchedulerConfigType,
+     LRSchedulerTypeUnion,
+ )
+ from typing_extensions import Protocol, override, runtime_checkable
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ class RLPSanityChecksConfig(CallbackConfigBase):
+     """
+     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+     - If the ``interval`` is step, it makes sure that validation is called every ``frequency`` steps.
+     - If the ``interval`` is epoch, it makes sure that validation is called every ``frequency`` epochs.
+     """
+
+     name: Literal["rlp_sanity_checks"] = "rlp_sanity_checks"
+
+     enabled: bool = True
+     """Whether to enable ReduceLRPlateau sanity checks."""
+
+     on_error: Literal["warn", "error"] = "error"
+     """What to do when a sanity check fails."""
+
+     def __bool__(self):
+         return self.enabled
+
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield RLPSanityChecksCallback(self)
+
+
+ class RLPSanityChecksCallback(Callback):
+     @override
+     def __init__(self, config: RLPSanityChecksConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def on_train_start(self, trainer, pl_module):
+         # If we're in PL's "sanity check" mode, we don't need to run this check
+         if trainer.sanity_checking:
+             return
+
+         # If the sanity check is disabled, return.
+         if not self.config:
+             return
+
+         # If no lr schedulers, return.
+         if not trainer.lr_scheduler_configs:
+             return
+
+         errors: list[str] = []
+         disable_message = (
+             "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau = None` "
+             "to disable this sanity check."
+         )
+
+         for lr_scheduler_config in trainer.lr_scheduler_configs:
+             if not lr_scheduler_config.reduce_on_plateau:
+                 continue
+
+             match lr_scheduler_config.interval:
+                 case "epoch":
+                     # we need to make sure that the trainer runs val every `frequency` epochs
+
+                     # If `trainer.check_val_every_n_epoch` is None, then Lightning
+                     # will run val every `int(trainer.val_check_interval)` steps.
+                     # So, first we need to make sure that `trainer.val_check_interval` is not None first.
+                     if trainer.check_val_every_n_epoch is None:
+                         errors.append(
+                             "Trainer is not running validation at epoch intervals "
+                             "(i.e., `trainer.check_val_every_n_epoch` is None) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Second, we make sure that the trainer runs val at least every `frequency` epochs
+                     if (
+                         trainer.check_val_every_n_epoch is not None
+                         and lr_scheduler_config.frequency
+                         % trainer.check_val_every_n_epoch
+                         != 0
+                     ):
+                         errors.append(
+                             f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                             f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
+                             + disable_message
+                         )
+
+                 case "step":
+                     # In this case, we need to make sure that the trainer runs val at step intervals
+                     # that are multiples of `frequency`.
+
+                     # First, we make sure that validation is run at step intervals
+                     if trainer.check_val_every_n_epoch is not None:
+                         errors.append(
+                             "Trainer is running validation at epoch intervals "
+                             "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             "Please set `config.trainer.check_val_every_n_epoch=None` "
+                             f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Second, we make sure `trainer.val_check_interval` is an integer
+                     if not isinstance(trainer.val_check_interval, int):
+                         errors.append(
+                             f"Trainer is not running validation at step intervals "
+                             f"(i.e., `trainer.val_check_interval` is not an integer) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             "Please set `config.trainer.val_check_interval=None` "
+                             f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Third, we make sure that the trainer runs val at least every `frequency` steps
+                     if (
+                         isinstance(trainer.val_check_interval, int)
+                         and trainer.val_check_interval % lr_scheduler_config.frequency
+                         != 0
+                     ):
+                         errors.append(
+                             f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                             "Please set `config.trainer.val_check_interval` "
+                             f"to a multiple of {lr_scheduler_config.frequency}. "
+                             + disable_message
+                         )
+
+                 case _:
+                     pass
+
+         if not errors:
+             return
+
+         message = (
+             "ReduceLRPlateau sanity checks failed with the following errors:\n"
+             + "\n".join(errors)
+         )
+         match self.config.on_error:
+             case "warn":
+                 log.warning(message)
+             case "error":
+                 raise ValueError(message)
+             case _:
+                 pass
+
+
+ @runtime_checkable
+ class CustomRLPImplementation(Protocol):
+     __reduce_lr_on_plateau__: bool
+
+
+ class _RLPSanityCheckModuleMixin(LightningModule):
+     def reduce_lr_on_plateau_config(
+         self,
+         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
+     ) -> LRSchedulerConfigType:
+         if (trainer := self._trainer) is None:
+             raise RuntimeError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 "because `self.trainer` is None."
+             )
+
+         # First, resolve the LR scheduler from the provided config.
+         lr_scheduler_config: LRSchedulerConfigType
+         match lr_scheduler:
+             case Mapping():
+                 lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
+             case _:
+                 lr_scheduler_config = {"scheduler": lr_scheduler}
+
+         # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
+         if (
+             not isinstance(
+                 lr_scheduler_config["scheduler"],
+                 torch.optim.lr_scheduler.ReduceLROnPlateau,
+             )
+         ) and (
+             not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
+             or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
+         ):
+             log.warning(
+                 "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
+                 f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
+                 "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
+                 "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
+                 "If you are using a custom ReduceLRPlateau scheduler implementation, "
+                 "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
+                 "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
+             )
+
+         # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
+         if trainer.check_val_every_n_epoch is not None:
+             return {
+                 "reduce_on_plateau": True,
+                 "interval": "epoch",
+                 "frequency": trainer.check_val_every_n_epoch,
+                 **lr_scheduler_config,
+             }
+
+         # Otherwise, we run val at step intervals.
+         if not isinstance(trainer.val_check_batch, int):
+             raise ValueError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 f"because {trainer.val_check_batch=} is not an integer."
+             )
+
+         return {
+             "reduce_on_plateau": True,
+             "interval": "step",
+             "frequency": trainer.val_check_batch,
+             **lr_scheduler_config,
+         }
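The `_RLPSanityCheckModuleMixin` half of this file is what `configure_optimizers` implementations are expected to call so that the scheduler dict's `interval`/`frequency` match the trainer's validation cadence. A hedged sketch that uses the mixin directly (in practice it is presumably consumed via nshtrainer's LightningModuleBase, which this diff does not show; the monitored metric name "val/loss" is hypothetical):

    import torch
    from nshtrainer.callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin

    class Model(_RLPSanityCheckModuleMixin):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(16, 1)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
            return {
                "optimizer": optimizer,
                # Fills in reduce_on_plateau/interval/frequency from the attached
                # trainer's validation settings, keeping the sanity checks happy.
                "lr_scheduler": self.reduce_lr_on_plateau_config(
                    {"scheduler": scheduler, "monitor": "val/loss"}
                ),
            }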
nshtrainer/callbacks/shared_parameters.py ADDED
@@ -0,0 +1,87 @@
+ import logging
+ from collections.abc import Iterable
+ from typing import Literal, Protocol, TypeAlias, runtime_checkable
+
+ import torch.nn as nn
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Callback
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
+     mapping = {id(p): n for n, p in model.named_parameters()}
+     return [mapping[id(p)] for p in parameters]
+
+
+ class SharedParametersConfig(CallbackConfigBase):
+     """A callback that allows scaling the gradients of shared parameters that
+     are registered in the ``self.shared_parameters`` list of the root module.
+
+     This is useful for models that share parameters across multiple modules and
+     want to downscale the gradients of these parameters to avoid overfitting.
+     """
+
+     name: Literal["shared_parameters"] = "shared_parameters"
+
+     @override
+     def create_callbacks(self, root_config):
+         yield SharedParametersCallback(self)
+
+
+ SharedParametersList: TypeAlias = list[tuple[nn.Parameter, int | float]]
+
+
+ @runtime_checkable
+ class ModuleWithSharedParameters(Protocol):
+     @property
+     def shared_parameters(self) -> SharedParametersList: ...
+
+
+ class SharedParametersCallback(Callback):
+     @override
+     def __init__(self, config: SharedParametersConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+         self._warned_shared_parameters = False
+
+     def _shared_parameters(self, pl_module: LightningModule) -> SharedParametersList:
+         if not isinstance(pl_module, ModuleWithSharedParameters):
+             return []
+
+         return pl_module.shared_parameters
+
+     @override
+     def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+         if not (shared_parameters := self._shared_parameters(pl_module)):
+             log.debug(
+                 "No shared parameters to scale, skipping SharedParametersCallback"
+             )
+             return
+
+         log.debug(f"Scaling {len(shared_parameters)} shared parameters...")
+         no_grad_parameters: list[nn.Parameter] = []
+         for p, factor in shared_parameters:
+             if not hasattr(p, "grad") or p.grad is None:
+                 no_grad_parameters.append(p)
+                 continue
+
+             _ = p.grad.data.div_(factor)
+
+         if no_grad_parameters and not self._warned_shared_parameters:
+             no_grad_parameters_str = ", ".join(
+                 _parameters_to_names(no_grad_parameters, pl_module)
+             )
+             log.warning(
+                 "The following parameters were marked as shared, but had no gradients: "
+                 f"{no_grad_parameters_str}"
+             )
+             self._warned_shared_parameters = True
+
+         log.debug(f"Done scaling shared parameters. (len={len(shared_parameters)})")
nshtrainer/config.py ADDED
@@ -0,0 +1,67 @@
+ from nshconfig._config import Config as Config
+ from nshsnap._config import SnapshotConfig as SnapshotConfig
+ from nshtrainer._checkpoint.loader import CheckpointLoadingConfig as CheckpointLoadingConfig
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
+ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
+ from nshtrainer._hf_hub import HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig
+ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+ from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
+ from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
+ from nshtrainer.callbacks.checkpoint._base import BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig
+ from nshtrainer.callbacks.checkpoint.best_checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+ from nshtrainer.callbacks.checkpoint.last_checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
+ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig
+ from nshtrainer.callbacks.directory_setup import DirectorySetupConfig as DirectorySetupConfig
+ from nshtrainer.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
+ from nshtrainer.callbacks.ema import EMAConfig as EMAConfig
+ from nshtrainer.callbacks.finite_checks import FiniteChecksConfig as FiniteChecksConfig
+ from nshtrainer.callbacks.gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
+ from nshtrainer.callbacks.norm_logging import NormLoggingConfig as NormLoggingConfig
+ from nshtrainer.callbacks.print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+ from nshtrainer.callbacks.rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
+ from nshtrainer.callbacks.shared_parameters import SharedParametersConfig as SharedParametersConfig
+ from nshtrainer.callbacks.throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
+ from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
+ from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
+ from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
+ from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
+ from nshtrainer.loggers.tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
+ from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
+ from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
+ from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
+ from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
+ from nshtrainer.metrics._config import MetricConfig as MetricConfig
+ from nshtrainer.model.config import BaseConfig as BaseConfig
+ from nshtrainer.nn.mlp import MLPConfig as MLPConfig
+ from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
+ from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
+ from nshtrainer.nn.nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
+ from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+ from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+ from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
+ from nshtrainer.profiler.advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
+ from nshtrainer.profiler.pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
+ from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
+ from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+ from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
+ from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
+ from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
+ from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
+ from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
+ from nshtrainer.util._environment_info import EnvironmentClassInformationConfig as EnvironmentClassInformationConfig
+ from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+ from nshtrainer.util._environment_info import EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig
+ from nshtrainer.util._environment_info import EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig
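The new nshtrainer.config module is a flat re-export hub, so downstream code can import every public config class from one place rather than from individual submodules. A minimal sketch:

    # All of these names now resolve through a single import site.
    from nshtrainer.config import (
        AdamWConfig,
        DirectoryConfig,
        RLPSanityChecksConfig,
        TrainerConfig,
    )

    rlp_checks = RLPSanityChecksConfig(on_error="warn")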