nshtrainer 0.30.1__tar.gz → 0.32.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/PKG-INFO +1 -1
  2. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/pyproject.toml +1 -1
  3. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/__init__.py +1 -2
  4. nshtrainer-0.32.0/src/nshtrainer/_directory.py +85 -0
  5. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/__init__.py +12 -1
  6. nshtrainer-0.32.0/src/nshtrainer/callbacks/debug_flag.py +72 -0
  7. nshtrainer-0.32.0/src/nshtrainer/callbacks/directory_setup.py +85 -0
  8. nshtrainer-0.32.0/src/nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  9. nshtrainer-0.32.0/src/nshtrainer/callbacks/shared_parameters.py +87 -0
  10. nshtrainer-0.32.0/src/nshtrainer/config.py +67 -0
  11. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/__init__.py +5 -4
  12. nshtrainer-0.32.0/src/nshtrainer/ll/model.py +19 -0
  13. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/loggers/wandb.py +1 -1
  14. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  15. nshtrainer-0.32.0/src/nshtrainer/model/__init__.py +5 -0
  16. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/model/base.py +124 -67
  17. nshtrainer-0.32.0/src/nshtrainer/model/config.py +206 -0
  18. {nshtrainer-0.30.1/src/nshtrainer/model/modules → nshtrainer-0.32.0/src/nshtrainer/model/mixins}/logger.py +13 -16
  19. nshtrainer-0.32.0/src/nshtrainer/profiler/__init__.py +13 -0
  20. nshtrainer-0.32.0/src/nshtrainer/profiler/_base.py +29 -0
  21. nshtrainer-0.32.0/src/nshtrainer/profiler/advanced.py +37 -0
  22. nshtrainer-0.32.0/src/nshtrainer/profiler/pytorch.py +83 -0
  23. nshtrainer-0.32.0/src/nshtrainer/profiler/simple.py +36 -0
  24. nshtrainer-0.30.1/src/nshtrainer/model/config.py → nshtrainer-0.32.0/src/nshtrainer/trainer/_config.py +38 -475
  25. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/trainer/trainer.py +16 -17
  26. {nshtrainer-0.30.1/src/nshtrainer → nshtrainer-0.32.0/src/nshtrainer/util}/config/__init__.py +1 -0
  27. nshtrainer-0.30.1/src/nshtrainer/ll/model.py +0 -12
  28. nshtrainer-0.30.1/src/nshtrainer/model/__init__.py +0 -26
  29. nshtrainer-0.30.1/src/nshtrainer/model/modules/callback.py +0 -206
  30. nshtrainer-0.30.1/src/nshtrainer/model/modules/debug.py +0 -42
  31. nshtrainer-0.30.1/src/nshtrainer/model/modules/distributed.py +0 -70
  32. nshtrainer-0.30.1/src/nshtrainer/model/modules/profiler.py +0 -24
  33. nshtrainer-0.30.1/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  34. nshtrainer-0.30.1/src/nshtrainer/model/modules/shared_parameters.py +0 -72
  35. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/README.md +0 -0
  36. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/_callback.py +0 -0
  37. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
  38. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  39. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
  40. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/_experimental/__init__.py +0 -0
  41. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/_hf_hub.py +0 -0
  42. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  43. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/actsave.py +0 -0
  44. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/base.py +0 -0
  45. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  46. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  47. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  48. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  49. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  50. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  51. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/ema.py +0 -0
  52. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  53. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  54. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/interval.py +0 -0
  55. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  56. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  57. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/print_table.py +0 -0
  58. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  59. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/timer.py +0 -0
  60. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  61. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/data/__init__.py +0 -0
  62. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  63. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/data/transform.py +0 -0
  64. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/_experimental.py +0 -0
  65. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/actsave.py +0 -0
  66. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/callbacks.py +0 -0
  67. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/config.py +0 -0
  68. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/data.py +0 -0
  69. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/log.py +0 -0
  70. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  71. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/nn.py +0 -0
  72. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/optimizer.py +0 -0
  73. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/runner.py +0 -0
  74. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/snapshot.py +0 -0
  75. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/snoop.py +0 -0
  76. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/trainer.py +0 -0
  77. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/typecheck.py +0 -0
  78. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/ll/util.py +0 -0
  79. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/loggers/__init__.py +0 -0
  80. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/loggers/_base.py +0 -0
  81. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/loggers/csv.py +0 -0
  82. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
  83. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  84. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  85. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  86. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/metrics/__init__.py +0 -0
  87. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/metrics/_config.py +0 -0
  88. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/nn/__init__.py +0 -0
  89. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/nn/mlp.py +0 -0
  90. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/nn/module_dict.py +0 -0
  91. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/nn/module_list.py +0 -0
  92. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
  93. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/optimizer.py +0 -0
  94. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/runner.py +0 -0
  95. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/scripts/find_packages.py +0 -0
  96. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/trainer/__init__.py +0 -0
  97. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  98. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  99. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
  100. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/_environment_info.py +0 -0
  101. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/_useful_types.py +0 -0
  102. {nshtrainer-0.30.1/src/nshtrainer → nshtrainer-0.32.0/src/nshtrainer/util}/config/duration.py +0 -0
  103. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/environment.py +0 -0
  104. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/path.py +0 -0
  105. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/seed.py +0 -0
  106. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/slurm.py +0 -0
  107. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/typed.py +0 -0
  108. {nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/util/typing_utils.py +0 -0

{nshtrainer-0.30.1 → nshtrainer-0.32.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.30.1
+ Version: 0.32.0
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com

{nshtrainer-0.30.1 → nshtrainer-0.32.0}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "nshtrainer"
- version = "0.30.1"
+ version = "0.32.0"
  description = ""
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
  readme = "README.md"

{nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/__init__.py
@@ -7,10 +7,9 @@ from . import metrics as metrics
  from . import model as model
  from . import nn as nn
  from . import optimizer as optimizer
+ from . import profiler as profiler
  from .metrics import MetricConfig as MetricConfig
- from .model import Base as Base
  from .model import BaseConfig as BaseConfig
- from .model import ConfigList as ConfigList
  from .model import LightningModuleBase as LightningModuleBase
  from .runner import Runner as Runner
  from .trainer import Trainer as Trainer

nshtrainer-0.32.0/src/nshtrainer/_directory.py
@@ -0,0 +1,85 @@
+ import logging
+ from pathlib import Path
+
+ import nshconfig as C
+
+ from .callbacks.directory_setup import DirectorySetupConfig
+ from .loggers import LoggerConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class DirectoryConfig(C.Config):
+     project_root: Path | None = None
+     """
+     Root directory for this project.
+
+     This isn't specific to the run; it is the parent directory of all runs.
+     """
+
+     log: Path | None = None
+     """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
+
+     stdio: Path | None = None
+     """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
+
+     checkpoint: Path | None = None
+     """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
+
+     activation: Path | None = None
+     """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
+
+     profile: Path | None = None
+     """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
+
+     setup_callback: DirectorySetupConfig = DirectorySetupConfig()
+     """Configuration for the directory setup PyTorch Lightning callback."""
+
+     def resolve_run_root_directory(self, run_id: str) -> Path:
+         if (project_root_dir := self.project_root) is None:
+             project_root_dir = Path.cwd()
+
+         # The default base dir is $CWD/nshtrainer/{id}/
+         base_dir = project_root_dir / "nshtrainer"
+         base_dir.mkdir(exist_ok=True)
+
+         # Add a .gitignore file to the nshtrainer directory
+         # which will ignore all files except for the .gitignore file itself
+         gitignore_path = base_dir / ".gitignore"
+         if not gitignore_path.exists():
+             gitignore_path.touch()
+             gitignore_path.write_text("*\n")
+
+         base_dir = base_dir / run_id
+         base_dir.mkdir(exist_ok=True)
+
+         return base_dir
+
+     def resolve_subdirectory(
+         self,
+         run_id: str,
+         # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
+         subdirectory: str,
+     ) -> Path:
+         # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
+         if (subdir := getattr(self, subdirectory, None)) is not None:
+             assert isinstance(
+                 subdir, Path
+             ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
+             return subdir
+
+         dir = self.resolve_run_root_directory(run_id)
+         dir = dir / subdirectory
+         dir.mkdir(exist_ok=True)
+         return dir
+
+     def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
+         if (log_dir := logger.log_dir) is not None:
+             return log_dir
+
+         # Save to nshtrainer/{id}/log/{logger name}
+         log_dir = self.resolve_subdirectory(run_id, "log")
+         log_dir = log_dir / logger.name
+         log_dir.mkdir(exist_ok=True)
+
+         return log_dir
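
As a rough, self-contained sketch of the new directory layout (not from the package itself): the run id below is made up, and inside the trainer it would come from the run's `BaseConfig.id`.

    from nshtrainer._directory import DirectoryConfig

    run_id = "example-run"  # hypothetical; normally BaseConfig.id

    dirs = DirectoryConfig()  # project_root defaults to Path.cwd()
    ckpt_dir = dirs.resolve_subdirectory(run_id, "checkpoint")
    # -> $CWD/nshtrainer/example-run/checkpoint, created on first use,
    #    with a .gitignore dropped into $CWD/nshtrainer/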

{nshtrainer-0.30.1 → nshtrainer-0.32.0}/src/nshtrainer/callbacks/__init__.py
@@ -12,6 +12,10 @@ from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
  from .checkpoint import (
      OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
  )
+ from .debug_flag import DebugFlagCallback as DebugFlagCallback
+ from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
+ from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
+ from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
  from .early_stopping import EarlyStopping as EarlyStopping
  from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
  from .ema import EMA as EMA
@@ -28,6 +32,10 @@ from .norm_logging import NormLoggingCallback as NormLoggingCallback
  from .norm_logging import NormLoggingConfig as NormLoggingConfig
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
  from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+ from .rlp_sanity_checks import RLPSanityChecksCallback as RLPSanityChecksCallback
+ from .rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
+ from .shared_parameters import SharedParametersCallback as SharedParametersCallback
+ from .shared_parameters import SharedParametersConfig as SharedParametersConfig
  from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
  from .timer import EpochTimer as EpochTimer
  from .timer import EpochTimerConfig as EpochTimerConfig
@@ -35,7 +43,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
  from .wandb_watch import WandbWatchConfig as WandbWatchConfig

  CallbackConfig = Annotated[
-     EarlyStoppingConfig
+     DebugFlagCallbackConfig
+     | EarlyStoppingConfig
      | ThroughputMonitorConfig
      | EpochTimerConfig
      | PrintTableMetricsConfig
@@ -46,6 +55,8 @@ CallbackConfig = Annotated[
      | BestCheckpointCallbackConfig
      | LastCheckpointCallbackConfig
      | OnExceptionCheckpointCallbackConfig
+     | SharedParametersConfig
+     | RLPSanityChecksConfig
      | WandbWatchConfig,
      C.Field(discriminator="name"),
  ]
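
The union above is discriminated on the `name` literal each config carries. As a rough sketch of what that enables, and assuming nshconfig's `Config` builds on pydantic v2 (which the `C.Field(discriminator="name")` usage suggests), a plain dict can be validated into the matching concrete config:

    from pydantic import TypeAdapter

    from nshtrainer.callbacks import CallbackConfig, DebugFlagCallbackConfig

    # Hypothetical payload; "debug_flag" selects DebugFlagCallbackConfig via
    # the `name` discriminator.
    cb = TypeAdapter(CallbackConfig).validate_python({"name": "debug_flag", "enabled": True})
    assert isinstance(cb, DebugFlagCallbackConfig)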

nshtrainer-0.32.0/src/nshtrainer/callbacks/debug_flag.py
@@ -0,0 +1,72 @@
+ import logging
+ from typing import TYPE_CHECKING, Literal, cast
+
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Callback
+ from typing_extensions import override
+
+ from nshtrainer.model.config import BaseConfig
+
+ from .base import CallbackConfigBase
+
+ if TYPE_CHECKING:
+     from ..model.config import BaseConfig
+
+ log = logging.getLogger(__name__)
+
+
+ class DebugFlagCallbackConfig(CallbackConfigBase):
+     name: Literal["debug_flag"] = "debug_flag"
+
+     enabled: bool = True
+     """Whether to enable the callback."""
+
+     def __bool__(self):
+         return self.enabled
+
+     @override
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield DebugFlagCallback(self)
+
+
+ class DebugFlagCallback(Callback):
+     """
+     Sets the debug flag to true in the following circumstances:
+     - fast_dev_run is enabled
+     - sanity check is running
+     """
+
+     @override
+     def __init__(self, config: DebugFlagCallbackConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
+         if not getattr(trainer, "fast_dev_run", False):
+             return
+
+         hparams = cast("BaseConfig", pl_module.hparams)
+         if not hparams.debug:
+             log.critical("Fast dev run detected, setting debug flag to True.")
+             hparams.debug = True
+
+     @override
+     def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
+         hparams = cast("BaseConfig", pl_module.hparams)
+         self._debug = hparams.debug
+         if not self._debug:
+             log.critical("Enabling debug flag during sanity check routine.")
+             hparams.debug = True
+
+     @override
+     def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
+         hparams = cast("BaseConfig", pl_module.hparams)
+         if not self._debug:
+             log.critical("Sanity check routine complete, disabling debug flag.")
+             hparams.debug = self._debug
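
For orientation on how the callback is meant to be wired in: nshtrainer normally builds it from `DebugFlagCallbackConfig.create_callbacks`, but as a hedged sketch it can also be attached to a plain Lightning `Trainer`, provided the module's `hparams` is an nshtrainer `BaseConfig` (the callback reads and writes `hparams.debug`). `module` and `datamodule` below are placeholders for your own classes.

    from lightning.pytorch import Trainer

    from nshtrainer.callbacks import DebugFlagCallback, DebugFlagCallbackConfig

    trainer = Trainer(
        fast_dev_run=True,  # takes the callback's setup() branch
        callbacks=[DebugFlagCallback(DebugFlagCallbackConfig())],
    )
    # trainer.fit(module, datamodule)  # hparams.debug is forced to True for the run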

nshtrainer-0.32.0/src/nshtrainer/callbacks/directory_setup.py
@@ -0,0 +1,85 @@
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Literal
+
+ from lightning.pytorch import Callback
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ def _create_symlink_to_nshrunner(base_dir: Path):
+     # Resolve the current nshrunner session directory
+     if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
+         log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+         return
+     session_dir = Path(session_dir)
+     if not session_dir.exists() or not session_dir.is_dir():
+         log.warning(
+             f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
+             "Skipping symlink creation."
+         )
+         return
+
+     # Create the symlink
+     symlink_path = base_dir / "nshrunner"
+     if symlink_path.exists():
+         # If it already points to the correct directory, we're done
+         if symlink_path.resolve() == session_dir.resolve():
+             return
+
+         # Otherwise, we should log a warning and remove the existing symlink
+         log.warning(
+             f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
+             "Removing the existing symlink."
+         )
+         symlink_path.unlink()
+
+     symlink_path.symlink_to(session_dir)
+
+
+ class DirectorySetupConfig(CallbackConfigBase):
+     name: Literal["directory_setup"] = "directory_setup"
+
+     enabled: bool = True
+     """Whether to enable the directory setup callback."""
+
+     create_symlink_to_nshrunner_root: bool = True
+     """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
+     def __bool__(self):
+         return self.enabled
+
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield DirectorySetupCallback(self)
+
+
+ class DirectorySetupCallback(Callback):
+     @override
+     def __init__(self, config: DirectorySetupConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def setup(self, trainer, pl_module, stage):
+         super().setup(trainer, pl_module, stage)
+
+         # Create a symlink to the root folder for the Runner
+         if self.config.create_symlink_to_nshrunner_root:
+             # Resolve the base dir
+             from ..model.config import BaseConfig
+
+             assert isinstance(
+                 config := pl_module.hparams, BaseConfig
+             ), f"Expected a BaseConfig, got {type(config)}"
+
+             base_dir = config.directory.resolve_run_root_directory(config.id)
+             _create_symlink_to_nshrunner(base_dir)
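
The symlink helper above depends only on the NSHRUNNER_SESSION_DIR environment variable, so it can be exercised in isolation. A small sketch with made-up paths:

    import os
    from pathlib import Path

    from nshtrainer.callbacks.directory_setup import _create_symlink_to_nshrunner

    # Made-up locations for illustration only.
    session_dir = Path("/tmp/nshrunner-session")
    session_dir.mkdir(parents=True, exist_ok=True)
    os.environ["NSHRUNNER_SESSION_DIR"] = str(session_dir)

    run_dir = Path("/tmp/example-run")
    run_dir.mkdir(parents=True, exist_ok=True)

    _create_symlink_to_nshrunner(run_dir)
    # /tmp/example-run/nshrunner now points at /tmp/nshrunner-session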

nshtrainer-0.32.0/src/nshtrainer/callbacks/rlp_sanity_checks.py
@@ -0,0 +1,230 @@
+ import logging
+ from collections.abc import Mapping
+ from typing import Literal, cast
+
+ import torch
+ from lightning.pytorch import LightningModule
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.utilities.types import (
+     LRSchedulerConfigType,
+     LRSchedulerTypeUnion,
+ )
+ from typing_extensions import Protocol, override, runtime_checkable
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ class RLPSanityChecksConfig(CallbackConfigBase):
+     """
+     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+     - If the ``interval`` is step, it makes sure that validation is called every ``frequency`` steps.
+     - If the ``interval`` is epoch, it makes sure that validation is called every ``frequency`` epochs.
+     """
+
+     name: Literal["rlp_sanity_checks"] = "rlp_sanity_checks"
+
+     enabled: bool = True
+     """Whether to enable ReduceLRPlateau sanity checks."""
+
+     on_error: Literal["warn", "error"] = "error"
+     """What to do when a sanity check fails."""
+
+     def __bool__(self):
+         return self.enabled
+
+     def create_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield RLPSanityChecksCallback(self)
+
+
+ class RLPSanityChecksCallback(Callback):
+     @override
+     def __init__(self, config: RLPSanityChecksConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+     @override
+     def on_train_start(self, trainer, pl_module):
+         # If we're in PL's "sanity check" mode, we don't need to run this check
+         if trainer.sanity_checking:
+             return
+
+         # If the sanity check is disabled, return.
+         if not self.config:
+             return
+
+         # If no lr schedulers, return.
+         if not trainer.lr_scheduler_configs:
+             return
+
+         errors: list[str] = []
+         disable_message = (
+             "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau = None` "
+             "to disable this sanity check."
+         )
+
+         for lr_scheduler_config in trainer.lr_scheduler_configs:
+             if not lr_scheduler_config.reduce_on_plateau:
+                 continue
+
+             match lr_scheduler_config.interval:
+                 case "epoch":
+                     # we need to make sure that the trainer runs val every `frequency` epochs
+
+                     # If `trainer.check_val_every_n_epoch` is None, then Lightning
+                     # will run val every `int(trainer.val_check_interval)` steps.
+                     # So, first we need to make sure that `trainer.val_check_interval` is not None first.
+                     if trainer.check_val_every_n_epoch is None:
+                         errors.append(
+                             "Trainer is not running validation at epoch intervals "
+                             "(i.e., `trainer.check_val_every_n_epoch` is None) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Second, we make sure that the trainer runs val at least every `frequency` epochs
+                     if (
+                         trainer.check_val_every_n_epoch is not None
+                         and lr_scheduler_config.frequency
+                         % trainer.check_val_every_n_epoch
+                         != 0
+                     ):
+                         errors.append(
+                             f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                             f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
+                             + disable_message
+                         )
+
+                 case "step":
+                     # In this case, we need to make sure that the trainer runs val at step intervals
+                     # that are multiples of `frequency`.
+
+                     # First, we make sure that validation is run at step intervals
+                     if trainer.check_val_every_n_epoch is not None:
+                         errors.append(
+                             "Trainer is running validation at epoch intervals "
+                             "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             "Please set `config.trainer.check_val_every_n_epoch=None` "
+                             f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Second, we make sure `trainer.val_check_interval` is an integer
+                     if not isinstance(trainer.val_check_interval, int):
+                         errors.append(
+                             f"Trainer is not running validation at step intervals "
+                             f"(i.e., `trainer.val_check_interval` is not an integer) but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                             "Please set `config.trainer.val_check_interval=None` "
+                             f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                             + disable_message
+                         )
+
+                     # Third, we make sure that the trainer runs val at least every `frequency` steps
+                     if (
+                         isinstance(trainer.val_check_interval, int)
+                         and trainer.val_check_interval % lr_scheduler_config.frequency
+                         != 0
+                     ):
+                         errors.append(
+                             f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
+                             f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                             "Please set `config.trainer.val_check_interval` "
+                             f"to a multiple of {lr_scheduler_config.frequency}. "
+                             + disable_message
+                         )
+
+                 case _:
+                     pass
+
+         if not errors:
+             return
+
+         message = (
+             "ReduceLRPlateau sanity checks failed with the following errors:\n"
+             + "\n".join(errors)
+         )
+         match self.config.on_error:
+             case "warn":
+                 log.warning(message)
+             case "error":
+                 raise ValueError(message)
+             case _:
+                 pass
+
+
+ @runtime_checkable
+ class CustomRLPImplementation(Protocol):
+     __reduce_lr_on_plateau__: bool
+
+
+ class _RLPSanityCheckModuleMixin(LightningModule):
+     def reduce_lr_on_plateau_config(
+         self,
+         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
+     ) -> LRSchedulerConfigType:
+         if (trainer := self._trainer) is None:
+             raise RuntimeError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 "because `self.trainer` is None."
+             )
+
+         # First, resolve the LR scheduler from the provided config.
+         lr_scheduler_config: LRSchedulerConfigType
+         match lr_scheduler:
+             case Mapping():
+                 lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
+             case _:
+                 lr_scheduler_config = {"scheduler": lr_scheduler}
+
+         # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
+         if (
+             not isinstance(
+                 lr_scheduler_config["scheduler"],
+                 torch.optim.lr_scheduler.ReduceLROnPlateau,
+             )
+         ) and (
+             not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
+             or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
+         ):
+             log.warning(
+                 "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
+                 f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
+                 "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
+                 "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
+                 "If you are using a custom ReduceLRPlateau scheduler implementation, "
+                 "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
+                 "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
+             )
+
+         # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
+         if trainer.check_val_every_n_epoch is not None:
+             return {
+                 "reduce_on_plateau": True,
+                 "interval": "epoch",
+                 "frequency": trainer.check_val_every_n_epoch,
+                 **lr_scheduler_config,
+             }
+
+         # Otherwise, we run val at step intervals.
+         if not isinstance(trainer.val_check_batch, int):
+             raise ValueError(
+                 "Could not determine the frequency of ReduceLRPlateau scheduler "
+                 f"because {trainer.val_check_batch=} is not an integer."
+             )
+
+         return {
+             "reduce_on_plateau": True,
+             "interval": "step",
+             "frequency": trainer.val_check_batch,
+             **lr_scheduler_config,
+         }
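
To see how the mixin and the callback fit together: `reduce_lr_on_plateau_config` stamps `reduce_on_plateau`, `interval`, and `frequency` onto the scheduler config based on the trainer's validation cadence, which is what `RLPSanityChecksCallback` later verifies at `on_train_start`. A hedged sketch, assuming your module inherits the mixin directly (in the package it is presumably wired into the base module class) and that `val/loss` is a metric you actually log:

    import torch
    import torch.nn as nn

    from nshtrainer.callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin


    class MyModule(_RLPSanityCheckModuleMixin):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(8, 1)  # placeholder model

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
            return {
                "optimizer": optimizer,
                # Fills in reduce_on_plateau/interval/frequency from the trainer's
                # validation settings, matching what the callback checks.
                "lr_scheduler": self.reduce_lr_on_plateau_config(
                    {"scheduler": scheduler, "monitor": "val/loss"}
                ),
            }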

nshtrainer-0.32.0/src/nshtrainer/callbacks/shared_parameters.py
@@ -0,0 +1,87 @@
+ import logging
+ from collections.abc import Iterable
+ from typing import Literal, Protocol, TypeAlias, runtime_checkable
+
+ import torch.nn as nn
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Callback
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = logging.getLogger(__name__)
+
+
+ def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
+     mapping = {id(p): n for n, p in model.named_parameters()}
+     return [mapping[id(p)] for p in parameters]
+
+
+ class SharedParametersConfig(CallbackConfigBase):
+     """A callback that allows scaling the gradients of shared parameters that
+     are registered in the ``self.shared_parameters`` list of the root module.
+
+     This is useful for models that share parameters across multiple modules and
+     want to downscale the gradients of these parameters to avoid overfitting.
+     """
+
+     name: Literal["shared_parameters"] = "shared_parameters"
+
+     @override
+     def create_callbacks(self, root_config):
+         yield SharedParametersCallback(self)
+
+
+ SharedParametersList: TypeAlias = list[tuple[nn.Parameter, int | float]]
+
+
+ @runtime_checkable
+ class ModuleWithSharedParameters(Protocol):
+     @property
+     def shared_parameters(self) -> SharedParametersList: ...
+
+
+ class SharedParametersCallback(Callback):
+     @override
+     def __init__(self, config: SharedParametersConfig):
+         super().__init__()
+
+         self.config = config
+         del config
+
+         self._warned_shared_parameters = False
+
+     def _shared_parameters(self, pl_module: LightningModule) -> SharedParametersList:
+         if not isinstance(pl_module, ModuleWithSharedParameters):
+             return []
+
+         return pl_module.shared_parameters
+
+     @override
+     def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+         if not (shared_parameters := self._shared_parameters(pl_module)):
+             log.debug(
+                 "No shared parameters to scale, skipping SharedParametersCallback"
+             )
+             return
+
+         log.debug(f"Scaling {len(shared_parameters)} shared parameters...")
+         no_grad_parameters: list[nn.Parameter] = []
+         for p, factor in shared_parameters:
+             if not hasattr(p, "grad") or p.grad is None:
+                 no_grad_parameters.append(p)
+                 continue
+
+             _ = p.grad.data.div_(factor)
+
+         if no_grad_parameters and not self._warned_shared_parameters:
+             no_grad_parameters_str = ", ".join(
+                 _parameters_to_names(no_grad_parameters, pl_module)
+             )
+             log.warning(
+                 "The following parameters were marked as shared, but had no gradients: "
+                 f"{no_grad_parameters_str}"
+             )
+             self._warned_shared_parameters = True
+
+         log.debug(f"Done scaling shared parameters. (len={len(shared_parameters)})")