nshtrainer 0.30.0.tar.gz → 0.31.0.tar.gz
This diff compares the contents of two publicly released versions of the package as published to its public registry. It is provided for informational purposes only.
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/PKG-INFO +1 -1
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/pyproject.toml +1 -1
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/__init__.py +1 -2
- nshtrainer-0.31.0/src/nshtrainer/_directory.py +85 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/__init__.py +8 -0
- nshtrainer-0.31.0/src/nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer-0.31.0/src/nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer-0.31.0/src/nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer-0.31.0/src/nshtrainer/config.py +67 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/__init__.py +5 -4
- nshtrainer-0.31.0/src/nshtrainer/ll/model.py +19 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/loggers/wandb.py +1 -1
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +3 -8
- nshtrainer-0.31.0/src/nshtrainer/model/__init__.py +5 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/model/base.py +139 -44
- nshtrainer-0.31.0/src/nshtrainer/model/config.py +206 -0
- {nshtrainer-0.30.0/src/nshtrainer/model/modules → nshtrainer-0.31.0/src/nshtrainer/model/mixins}/callback.py +2 -2
- {nshtrainer-0.30.0/src/nshtrainer/model/modules → nshtrainer-0.31.0/src/nshtrainer/model/mixins}/logger.py +13 -16
- nshtrainer-0.31.0/src/nshtrainer/profiler/__init__.py +13 -0
- nshtrainer-0.31.0/src/nshtrainer/profiler/_base.py +29 -0
- nshtrainer-0.31.0/src/nshtrainer/profiler/advanced.py +37 -0
- nshtrainer-0.31.0/src/nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer-0.31.0/src/nshtrainer/profiler/simple.py +36 -0
- nshtrainer-0.30.0/src/nshtrainer/model/config.py → nshtrainer-0.31.0/src/nshtrainer/trainer/_config.py +29 -475
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/trainer/trainer.py +16 -17
- {nshtrainer-0.30.0/src/nshtrainer → nshtrainer-0.31.0/src/nshtrainer/util}/config/__init__.py +1 -0
- nshtrainer-0.30.0/src/nshtrainer/ll/model.py +0 -12
- nshtrainer-0.30.0/src/nshtrainer/model/__init__.py +0 -26
- nshtrainer-0.30.0/src/nshtrainer/model/modules/debug.py +0 -42
- nshtrainer-0.30.0/src/nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer-0.30.0/src/nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer-0.30.0/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer-0.30.0/src/nshtrainer/model/modules/shared_parameters.py +0 -72
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/README.md +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.30.0/src/nshtrainer → nshtrainer-0.31.0/src/nshtrainer/util}/config/duration.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/__init__.py

```diff
@@ -7,10 +7,9 @@ from . import metrics as metrics
 from . import model as model
 from . import nn as nn
 from . import optimizer as optimizer
+from . import profiler as profiler
 from .metrics import MetricConfig as MetricConfig
-from .model import Base as Base
 from .model import BaseConfig as BaseConfig
-from .model import ConfigList as ConfigList
 from .model import LightningModuleBase as LightningModuleBase
 from .runner import Runner as Runner
 from .trainer import Trainer as Trainer
```
nshtrainer-0.31.0/src/nshtrainer/_directory.py (new file)

```diff
@@ -0,0 +1,85 @@
+import logging
+from pathlib import Path
+
+import nshconfig as C
+
+from .callbacks.directory_setup import DirectorySetupConfig
+from .loggers import LoggerConfig
+
+log = logging.getLogger(__name__)
+
+
+class DirectoryConfig(C.Config):
+    project_root: Path | None = None
+    """
+    Root directory for this project.
+
+    This isn't specific to the run; it is the parent directory of all runs.
+    """
+
+    log: Path | None = None
+    """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
+
+    stdio: Path | None = None
+    """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
+
+    checkpoint: Path | None = None
+    """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
+
+    activation: Path | None = None
+    """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
+
+    profile: Path | None = None
+    """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
+
+    setup_callback: DirectorySetupConfig = DirectorySetupConfig()
+    """Configuration for the directory setup PyTorch Lightning callback."""
+
+    def resolve_run_root_directory(self, run_id: str) -> Path:
+        if (project_root_dir := self.project_root) is None:
+            project_root_dir = Path.cwd()
+
+        # The default base dir is $CWD/nshtrainer/{id}/
+        base_dir = project_root_dir / "nshtrainer"
+        base_dir.mkdir(exist_ok=True)
+
+        # Add a .gitignore file to the nshtrainer directory
+        # which will ignore all files except for the .gitignore file itself
+        gitignore_path = base_dir / ".gitignore"
+        if not gitignore_path.exists():
+            gitignore_path.touch()
+            gitignore_path.write_text("*\n")
+
+        base_dir = base_dir / run_id
+        base_dir.mkdir(exist_ok=True)
+
+        return base_dir
+
+    def resolve_subdirectory(
+        self,
+        run_id: str,
+        # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
+        subdirectory: str,
+    ) -> Path:
+        # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
+        if (subdir := getattr(self, subdirectory, None)) is not None:
+            assert isinstance(
+                subdir, Path
+            ), f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            return subdir
+
+        dir = self.resolve_run_root_directory(run_id)
+        dir = dir / subdirectory
+        dir.mkdir(exist_ok=True)
+        return dir
+
+    def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
+        if (log_dir := logger.log_dir) is not None:
+            return log_dir
+
+        # Save to nshtrainer/{id}/log/{logger name}
+        log_dir = self.resolve_subdirectory(run_id, "log")
+        log_dir = log_dir / logger.name
+        log_dir.mkdir(exist_ok=True)
+
+        return log_dir
```
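For readers skimming the new `_directory.py` module above, the sketch below shows how the path resolution is meant to be used. It is a minimal, hypothetical example (the `/tmp/demo-project` path and the `my-run` id are made up), assuming nshtrainer 0.31.0 is installed.

```python
# Minimal sketch of DirectoryConfig path resolution (assumes nshtrainer 0.31.0;
# the project path and run id below are made-up values for illustration).
from pathlib import Path

from nshtrainer._directory import DirectoryConfig

project_root = Path("/tmp/demo-project")
project_root.mkdir(parents=True, exist_ok=True)

config = DirectoryConfig(project_root=project_root)

# Creates /tmp/demo-project/nshtrainer/my-run/ (plus a .gitignore in nshtrainer/).
run_root = config.resolve_run_root_directory("my-run")

# Falls back to /tmp/demo-project/nshtrainer/my-run/checkpoint/ because
# config.checkpoint is None.
ckpt_dir = config.resolve_subdirectory("my-run", "checkpoint")
print(run_root, ckpt_dir)
```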
{nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/callbacks/__init__.py

```diff
@@ -12,6 +12,8 @@ from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
 from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
+from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
+from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
 from .early_stopping import EarlyStopping as EarlyStopping
 from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
 from .ema import EMA as EMA
@@ -28,6 +30,10 @@ from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingConfig as NormLoggingConfig
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
 from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+from .rlp_sanity_checks import RLPSanityChecksCallback as RLPSanityChecksCallback
+from .rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
+from .shared_parameters import SharedParametersCallback as SharedParametersCallback
+from .shared_parameters import SharedParametersConfig as SharedParametersConfig
 from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
 from .timer import EpochTimer as EpochTimer
 from .timer import EpochTimerConfig as EpochTimerConfig
@@ -46,6 +52,8 @@ CallbackConfig = Annotated[
     | BestCheckpointCallbackConfig
     | LastCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig
+    | SharedParametersConfig
+    | RLPSanityChecksConfig
     | WandbWatchConfig,
     C.Field(discriminator="name"),
 ]
```
nshtrainer-0.31.0/src/nshtrainer/callbacks/directory_setup.py (new file)

```diff
@@ -0,0 +1,85 @@
+import logging
+import os
+from pathlib import Path
+from typing import Literal
+
+from lightning.pytorch import Callback
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+def _create_symlink_to_nshrunner(base_dir: Path):
+    # Resolve the current nshrunner session directory
+    if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
+        log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
+        return
+    session_dir = Path(session_dir)
+    if not session_dir.exists() or not session_dir.is_dir():
+        log.warning(
+            f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
+            "Skipping symlink creation."
+        )
+        return
+
+    # Create the symlink
+    symlink_path = base_dir / "nshrunner"
+    if symlink_path.exists():
+        # If it already points to the correct directory, we're done
+        if symlink_path.resolve() == session_dir.resolve():
+            return
+
+        # Otherwise, we should log a warning and remove the existing symlink
+        log.warning(
+            f"A symlink pointing to {symlink_path.resolve()} already exists at {symlink_path}. "
+            "Removing the existing symlink."
+        )
+        symlink_path.unlink()
+
+    symlink_path.symlink_to(session_dir)
+
+
+class DirectorySetupConfig(CallbackConfigBase):
+    name: Literal["directory_setup"] = "directory_setup"
+
+    enabled: bool = True
+    """Whether to enable the directory setup callback."""
+
+    create_symlink_to_nshrunner_root: bool = True
+    """Should we create a symlink to the root folder for the Runner (if we're in one)?"""
+
+    def __bool__(self):
+        return self.enabled
+
+    def create_callbacks(self, root_config):
+        if not self:
+            return
+
+        yield DirectorySetupCallback(self)
+
+
+class DirectorySetupCallback(Callback):
+    @override
+    def __init__(self, config: DirectorySetupConfig):
+        super().__init__()
+
+        self.config = config
+        del config
+
+    @override
+    def setup(self, trainer, pl_module, stage):
+        super().setup(trainer, pl_module, stage)
+
+        # Create a symlink to the root folder for the Runner
+        if self.config.create_symlink_to_nshrunner_root:
+            # Resolve the base dir
+            from ..model.config import BaseConfig
+
+            assert isinstance(
+                config := pl_module.hparams, BaseConfig
+            ), f"Expected a BaseConfig, got {type(config)}"
+
+            base_dir = config.directory.resolve_run_root_directory(config.id)
+            _create_symlink_to_nshrunner(base_dir)
```
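As a quick illustration of how the config gates its callback, here is a hedged sketch (assumes nshtrainer 0.31.0; passing `None` for `root_config` is a shortcut that works only because this particular config does not use it).

```python
# Sketch: DirectorySetupConfig is falsy when disabled and yields its callback
# when enabled (assumes nshtrainer 0.31.0; None for root_config is illustrative).
from nshtrainer.callbacks import DirectorySetupCallback, DirectorySetupConfig

enabled = DirectorySetupConfig()  # enabled=True by default
assert bool(enabled)
callbacks = list(enabled.create_callbacks(None))
assert isinstance(callbacks[0], DirectorySetupCallback)

disabled = DirectorySetupConfig(enabled=False)
assert not bool(disabled)
assert list(disabled.create_callbacks(None)) == []
```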
nshtrainer-0.31.0/src/nshtrainer/callbacks/rlp_sanity_checks.py (new file)

```diff
@@ -0,0 +1,230 @@
+import logging
+from collections.abc import Mapping
+from typing import Literal, cast
+
+import torch
+from lightning.pytorch import LightningModule
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.utilities.types import (
+    LRSchedulerConfigType,
+    LRSchedulerTypeUnion,
+)
+from typing_extensions import Protocol, override, runtime_checkable
+
+from .base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+class RLPSanityChecksConfig(CallbackConfigBase):
+    """
+    If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
+    - If the ``interval`` is step, it makes sure that validation is called every ``frequency`` steps.
+    - If the ``interval`` is epoch, it makes sure that validation is called every ``frequency`` epochs.
+    """
+
+    name: Literal["rlp_sanity_checks"] = "rlp_sanity_checks"
+
+    enabled: bool = True
+    """Whether to enable ReduceLRPlateau sanity checks."""
+
+    on_error: Literal["warn", "error"] = "error"
+    """What to do when a sanity check fails."""
+
+    def __bool__(self):
+        return self.enabled
+
+    def create_callbacks(self, root_config):
+        if not self:
+            return
+
+        yield RLPSanityChecksCallback(self)
+
+
+class RLPSanityChecksCallback(Callback):
+    @override
+    def __init__(self, config: RLPSanityChecksConfig):
+        super().__init__()
+
+        self.config = config
+        del config
+
+    @override
+    def on_train_start(self, trainer, pl_module):
+        # If we're in PL's "sanity check" mode, we don't need to run this check
+        if trainer.sanity_checking:
+            return
+
+        # If the sanity check is disabled, return.
+        if not self.config:
+            return
+
+        # If no lr schedulers, return.
+        if not trainer.lr_scheduler_configs:
+            return
+
+        errors: list[str] = []
+        disable_message = (
+            "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau = None` "
+            "to disable this sanity check."
+        )
+
+        for lr_scheduler_config in trainer.lr_scheduler_configs:
+            if not lr_scheduler_config.reduce_on_plateau:
+                continue
+
+            match lr_scheduler_config.interval:
+                case "epoch":
+                    # we need to make sure that the trainer runs val every `frequency` epochs
+
+                    # If `trainer.check_val_every_n_epoch` is None, then Lightning
+                    # will run val every `int(trainer.val_check_interval)` steps.
+                    # So, first we need to make sure that `trainer.val_check_interval` is not None first.
+                    if trainer.check_val_every_n_epoch is None:
+                        errors.append(
+                            "Trainer is not running validation at epoch intervals "
+                            "(i.e., `trainer.check_val_every_n_epoch` is None) but "
+                            f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                            f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
+                            + disable_message
+                        )
+
+                    # Second, we make sure that the trainer runs val at least every `frequency` epochs
+                    if (
+                        trainer.check_val_every_n_epoch is not None
+                        and lr_scheduler_config.frequency
+                        % trainer.check_val_every_n_epoch
+                        != 0
+                    ):
+                        errors.append(
+                            f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
+                            f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                            f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
+                            + disable_message
+                        )
+
+                case "step":
+                    # In this case, we need to make sure that the trainer runs val at step intervals
+                    # that are multiples of `frequency`.
+
+                    # First, we make sure that validation is run at step intervals
+                    if trainer.check_val_every_n_epoch is not None:
+                        errors.append(
+                            "Trainer is running validation at epoch intervals "
+                            "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
+                            f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                            "Please set `config.trainer.check_val_every_n_epoch=None` "
+                            f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                            + disable_message
+                        )
+
+                    # Second, we make sure `trainer.val_check_interval` is an integer
+                    if not isinstance(trainer.val_check_interval, int):
+                        errors.append(
+                            f"Trainer is not running validation at step intervals "
+                            f"(i.e., `trainer.val_check_interval` is not an integer) but "
+                            f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
+                            "Please set `config.trainer.val_check_interval=None` "
+                            f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
+                            + disable_message
+                        )
+
+                    # Third, we make sure that the trainer runs val at least every `frequency` steps
+                    if (
+                        isinstance(trainer.val_check_interval, int)
+                        and trainer.val_check_interval % lr_scheduler_config.frequency
+                        != 0
+                    ):
+                        errors.append(
+                            f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
+                            f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
+                            "Please set `config.trainer.val_check_interval` "
+                            f"to a multiple of {lr_scheduler_config.frequency}. "
+                            + disable_message
+                        )
+
+                case _:
+                    pass
+
+        if not errors:
+            return
+
+        message = (
+            "ReduceLRPlateau sanity checks failed with the following errors:\n"
+            + "\n".join(errors)
+        )
+        match self.config.on_error:
+            case "warn":
+                log.warning(message)
+            case "error":
+                raise ValueError(message)
+            case _:
+                pass
+
+
+@runtime_checkable
+class CustomRLPImplementation(Protocol):
+    __reduce_lr_on_plateau__: bool
+
+
+class _RLPSanityCheckModuleMixin(LightningModule):
+    def reduce_lr_on_plateau_config(
+        self,
+        lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
+    ) -> LRSchedulerConfigType:
+        if (trainer := self._trainer) is None:
+            raise RuntimeError(
+                "Could not determine the frequency of ReduceLRPlateau scheduler "
+                "because `self.trainer` is None."
+            )
+
+        # First, resolve the LR scheduler from the provided config.
+        lr_scheduler_config: LRSchedulerConfigType
+        match lr_scheduler:
+            case Mapping():
+                lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
+            case _:
+                lr_scheduler_config = {"scheduler": lr_scheduler}
+
+        # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
+        if (
+            not isinstance(
+                lr_scheduler_config["scheduler"],
+                torch.optim.lr_scheduler.ReduceLROnPlateau,
+            )
+        ) and (
+            not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
+            or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
+        ):
+            log.warning(
+                "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
+                f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
+                "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
+                "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
+                "If you are using a custom ReduceLRPlateau scheduler implementation, "
+                "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
+                "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
+            )
+
+        # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
+        if trainer.check_val_every_n_epoch is not None:
+            return {
+                "reduce_on_plateau": True,
+                "interval": "epoch",
+                "frequency": trainer.check_val_every_n_epoch,
+                **lr_scheduler_config,
+            }
+
+        # Otherwise, we run val at step intervals.
+        if not isinstance(trainer.val_check_batch, int):
+            raise ValueError(
+                "Could not determine the frequency of ReduceLRPlateau scheduler "
+                f"because {trainer.val_check_batch=} is not an integer."
+            )
+
+        return {
+            "reduce_on_plateau": True,
+            "interval": "step",
+            "frequency": trainer.val_check_batch,
+            **lr_scheduler_config,
+        }
```
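To show where `reduce_lr_on_plateau_config` fits, here is a hedged sketch of a LightningModule calling it inside `configure_optimizers`. The module name, the `"val/loss"` metric, and the hyperparameters are illustrative, and the mixin is imported directly only for demonstration; the helper needs an attached Trainer, so it only resolves once `trainer.fit` has started.

```python
# Sketch: using the mixin's reduce_lr_on_plateau_config helper to fill in the
# reduce_on_plateau/interval/frequency fields from the Trainer's validation
# cadence. MyModule, "val/loss", and the hyperparameters are illustrative.
import torch

from nshtrainer.callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin


class MyModule(_RLPSanityCheckModuleMixin):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": self.reduce_lr_on_plateau_config(
                {"scheduler": scheduler, "monitor": "val/loss"}
            ),
        }
```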
nshtrainer-0.31.0/src/nshtrainer/callbacks/shared_parameters.py (new file)

```diff
@@ -0,0 +1,87 @@
+import logging
+from collections.abc import Iterable
+from typing import Literal, Protocol, TypeAlias, runtime_checkable
+
+import torch.nn as nn
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from typing_extensions import override
+
+from .base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
+    mapping = {id(p): n for n, p in model.named_parameters()}
+    return [mapping[id(p)] for p in parameters]
+
+
+class SharedParametersConfig(CallbackConfigBase):
+    """A callback that allows scaling the gradients of shared parameters that
+    are registered in the ``self.shared_parameters`` list of the root module.
+
+    This is useful for models that share parameters across multiple modules and
+    want to downscale the gradients of these parameters to avoid overfitting.
+    """
+
+    name: Literal["shared_parameters"] = "shared_parameters"
+
+    @override
+    def create_callbacks(self, root_config):
+        yield SharedParametersCallback(self)
+
+
+SharedParametersList: TypeAlias = list[tuple[nn.Parameter, int | float]]
+
+
+@runtime_checkable
+class ModuleWithSharedParameters(Protocol):
+    @property
+    def shared_parameters(self) -> SharedParametersList: ...
+
+
+class SharedParametersCallback(Callback):
+    @override
+    def __init__(self, config: SharedParametersConfig):
+        super().__init__()
+
+        self.config = config
+        del config
+
+        self._warned_shared_parameters = False
+
+    def _shared_parameters(self, pl_module: LightningModule) -> SharedParametersList:
+        if not isinstance(pl_module, ModuleWithSharedParameters):
+            return []
+
+        return pl_module.shared_parameters
+
+    @override
+    def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+        if not (shared_parameters := self._shared_parameters(pl_module)):
+            log.debug(
+                "No shared parameters to scale, skipping SharedParametersCallback"
+            )
+            return
+
+        log.debug(f"Scaling {len(shared_parameters)} shared parameters...")
+        no_grad_parameters: list[nn.Parameter] = []
+        for p, factor in shared_parameters:
+            if not hasattr(p, "grad") or p.grad is None:
+                no_grad_parameters.append(p)
+                continue
+
+            _ = p.grad.data.div_(factor)
+
+        if no_grad_parameters and not self._warned_shared_parameters:
+            no_grad_parameters_str = ", ".join(
+                _parameters_to_names(no_grad_parameters, pl_module)
+            )
+            log.warning(
+                "The following parameters were marked as shared, but had no gradients: "
+                f"{no_grad_parameters_str}"
+            )
+            self._warned_shared_parameters = True
+
+        log.debug(f"Done scaling shared parameters. (len={len(shared_parameters)})")
```
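For context, the sketch below shows the shape of a module the callback can act on: anything matching the `ModuleWithSharedParameters` protocol above. The tied linear layers and the scaling factor of 2.0 are illustrative.

```python
# Sketch: a LightningModule that exposes shared parameters so
# SharedParametersCallback divides their gradients by the given factor after
# backward. The weight tying and the 2.0 factor are illustrative.
import torch.nn as nn
from lightning.pytorch import LightningModule


class TiedAutoencoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 8)
        # Tie the weights: the same nn.Parameter is registered in both modules,
        # so its gradient accumulates contributions from both uses.
        self.decoder.weight = self.encoder.weight

    @property
    def shared_parameters(self) -> list[tuple[nn.Parameter, int | float]]:
        # (parameter, factor): the callback calls grad.div_(factor) on each.
        return [(self.encoder.weight, 2.0)]
```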
nshtrainer-0.31.0/src/nshtrainer/config.py (new file)

```diff
@@ -0,0 +1,67 @@
+from nshconfig._config import Config as Config
+from nshsnap._config import SnapshotConfig as SnapshotConfig
+from nshtrainer._checkpoint.loader import CheckpointLoadingConfig as CheckpointLoadingConfig
+from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
+from nshtrainer._directory import DirectoryConfig as DirectoryConfig
+from nshtrainer._hf_hub import HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig
+from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
+from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
+from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
+from nshtrainer.callbacks.checkpoint._base import BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig
+from nshtrainer.callbacks.checkpoint.best_checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+from nshtrainer.callbacks.checkpoint.last_checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
+from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig
+from nshtrainer.callbacks.directory_setup import DirectorySetupConfig as DirectorySetupConfig
+from nshtrainer.callbacks.early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
+from nshtrainer.callbacks.ema import EMAConfig as EMAConfig
+from nshtrainer.callbacks.finite_checks import FiniteChecksConfig as FiniteChecksConfig
+from nshtrainer.callbacks.gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
+from nshtrainer.callbacks.norm_logging import NormLoggingConfig as NormLoggingConfig
+from nshtrainer.callbacks.print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
+from nshtrainer.callbacks.rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
+from nshtrainer.callbacks.shared_parameters import SharedParametersConfig as SharedParametersConfig
+from nshtrainer.callbacks.throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
+from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
+from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
+from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
+from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
+from nshtrainer.loggers.tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
+from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
+from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
+from nshtrainer.lr_scheduler.linear_warmup_cosine import LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig
+from nshtrainer.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
+from nshtrainer.metrics._config import MetricConfig as MetricConfig
+from nshtrainer.model.config import BaseConfig as BaseConfig
+from nshtrainer.nn.mlp import MLPConfig as MLPConfig
+from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig as BaseNonlinearityConfig
+from nshtrainer.nn.nonlinearity import ELUNonlinearityConfig as ELUNonlinearityConfig
+from nshtrainer.nn.nonlinearity import GELUNonlinearityConfig as GELUNonlinearityConfig
+from nshtrainer.nn.nonlinearity import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
+from nshtrainer.nn.nonlinearity import MishNonlinearityConfig as MishNonlinearityConfig
+from nshtrainer.nn.nonlinearity import PReLUConfig as PReLUConfig
+from nshtrainer.nn.nonlinearity import ReLUNonlinearityConfig as ReLUNonlinearityConfig
+from nshtrainer.nn.nonlinearity import SiLUNonlinearityConfig as SiLUNonlinearityConfig
+from nshtrainer.nn.nonlinearity import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
+from nshtrainer.nn.nonlinearity import SoftmaxNonlinearityConfig as SoftmaxNonlinearityConfig
+from nshtrainer.nn.nonlinearity import SoftplusNonlinearityConfig as SoftplusNonlinearityConfig
+from nshtrainer.nn.nonlinearity import SoftsignNonlinearityConfig as SoftsignNonlinearityConfig
+from nshtrainer.nn.nonlinearity import SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig
+from nshtrainer.nn.nonlinearity import SwishNonlinearityConfig as SwishNonlinearityConfig
+from nshtrainer.nn.nonlinearity import TanhNonlinearityConfig as TanhNonlinearityConfig
+from nshtrainer.optimizer import AdamWConfig as AdamWConfig
+from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
+from nshtrainer.profiler._base import BaseProfilerConfig as BaseProfilerConfig
+from nshtrainer.profiler.advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
+from nshtrainer.profiler.pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
+from nshtrainer.profiler.simple import SimpleProfilerConfig as SimpleProfilerConfig
+from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
+from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
+from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
+from nshtrainer.trainer._config import ReproducibilityConfig as ReproducibilityConfig
+from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
+from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
+from nshtrainer.util._environment_info import EnvironmentClassInformationConfig as EnvironmentClassInformationConfig
+from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+from nshtrainer.util._environment_info import EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig
+from nshtrainer.util._environment_info import EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig
```
{nshtrainer-0.30.0 → nshtrainer-0.31.0}/src/nshtrainer/ll/__init__.py

```diff
@@ -1,3 +1,5 @@
+from typing import TypeAlias
+
 from . import _experimental as _experimental
 from . import actsave as actsave
 from . import callbacks as callbacks
@@ -21,12 +23,9 @@ from .log import init_python_logging as init_python_logging
 from .log import lovely as lovely
 from .log import pretty as pretty
 from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
-from .model import Base as Base
 from .model import BaseConfig as BaseConfig
-from .model import BaseProfilerConfig as BaseProfilerConfig
 from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
 from .model import CheckpointSavingConfig as CheckpointSavingConfig
-from .model import ConfigList as ConfigList
 from .model import DirectoryConfig as DirectoryConfig
 from .model import (
     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
@@ -43,7 +42,6 @@ from .model import LightningModuleBase as LightningModuleBase
 from .model import LoggingConfig as LoggingConfig
 from .model import MetricConfig as MetricConfig
 from .model import OptimizationConfig as OptimizationConfig
-from .model import PrimaryMetricConfig as PrimaryMetricConfig
 from .model import ReproducibilityConfig as ReproducibilityConfig
 from .model import SanityCheckingConfig as SanityCheckingConfig
 from .model import TrainerConfig as TrainerConfig
@@ -54,3 +52,6 @@ from .runner import Runner as Runner
 from .runner import SnapshotConfig as SnapshotConfig
 from .snoop import snoop as snoop
 from .trainer import Trainer as Trainer
+
+PrimaryMetricConfig: TypeAlias = MetricConfig
+ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
```
nshtrainer-0.31.0/src/nshtrainer/ll/model.py (new file)

```diff
@@ -0,0 +1,19 @@
+from nshtrainer.model import *  # noqa: F403
+
+from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
+from ..trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+from ..trainer._config import GradientClippingConfig as GradientClippingConfig
+from ..trainer._config import LoggingConfig as LoggingConfig
+from ..trainer._config import OptimizationConfig as OptimizationConfig
+from ..trainer._config import ReproducibilityConfig as ReproducibilityConfig
+from ..trainer._config import SanityCheckingConfig as SanityCheckingConfig
+from ..util._environment_info import (
+    EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
+)
+from ..util._environment_info import EnvironmentConfig as EnvironmentConfig
+from ..util._environment_info import (
+    EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+)
+from ..util._environment_info import (
+    EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
+)
```