nshtrainer 1.0.0b25__py3-none-any.whl → 1.0.0b27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/.nshconfig.generated.json +6 -0
- nshtrainer/_checkpoint/metadata.py +1 -1
- nshtrainer/callbacks/__init__.py +3 -0
- nshtrainer/callbacks/actsave.py +2 -2
- nshtrainer/callbacks/base.py +5 -3
- nshtrainer/callbacks/checkpoint/__init__.py +4 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -2
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -2
- nshtrainer/callbacks/checkpoint/time_checkpoint.py +114 -0
- nshtrainer/callbacks/print_table.py +2 -2
- nshtrainer/callbacks/shared_parameters.py +5 -3
- nshtrainer/configs/__init__.py +99 -10
- nshtrainer/configs/_checkpoint/__init__.py +6 -0
- nshtrainer/configs/_checkpoint/metadata/__init__.py +5 -0
- nshtrainer/configs/_directory/__init__.py +5 -1
- nshtrainer/configs/_hf_hub/__init__.py +6 -0
- nshtrainer/configs/callbacks/__init__.py +48 -1
- nshtrainer/configs/callbacks/actsave/__init__.py +5 -0
- nshtrainer/configs/callbacks/base/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +20 -0
- nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +7 -0
- nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +5 -0
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +19 -0
- nshtrainer/configs/callbacks/debug_flag/__init__.py +5 -0
- nshtrainer/configs/callbacks/directory_setup/__init__.py +5 -0
- nshtrainer/configs/callbacks/early_stopping/__init__.py +6 -0
- nshtrainer/configs/callbacks/ema/__init__.py +5 -0
- nshtrainer/configs/callbacks/finite_checks/__init__.py +5 -0
- nshtrainer/configs/callbacks/gradient_skipping/__init__.py +5 -0
- nshtrainer/configs/callbacks/log_epoch/__init__.py +5 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +5 -0
- nshtrainer/configs/callbacks/norm_logging/__init__.py +5 -0
- nshtrainer/configs/callbacks/print_table/__init__.py +5 -0
- nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +5 -0
- nshtrainer/configs/callbacks/shared_parameters/__init__.py +5 -0
- nshtrainer/configs/callbacks/timer/__init__.py +5 -0
- nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +5 -0
- nshtrainer/configs/callbacks/wandb_watch/__init__.py +5 -0
- nshtrainer/configs/loggers/__init__.py +16 -1
- nshtrainer/configs/loggers/_base/__init__.py +4 -0
- nshtrainer/configs/loggers/actsave/__init__.py +5 -0
- nshtrainer/configs/loggers/csv/__init__.py +5 -0
- nshtrainer/configs/loggers/tensorboard/__init__.py +5 -0
- nshtrainer/configs/loggers/wandb/__init__.py +8 -0
- nshtrainer/configs/lr_scheduler/__init__.py +10 -4
- nshtrainer/configs/lr_scheduler/_base/__init__.py +4 -0
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +5 -3
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -0
- nshtrainer/configs/metrics/__init__.py +5 -0
- nshtrainer/configs/metrics/_config/__init__.py +4 -0
- nshtrainer/configs/nn/__init__.py +21 -1
- nshtrainer/configs/nn/mlp/__init__.py +5 -1
- nshtrainer/configs/nn/nonlinearity/__init__.py +18 -1
- nshtrainer/configs/optimizer/__init__.py +5 -1
- nshtrainer/configs/profiler/__init__.py +11 -1
- nshtrainer/configs/profiler/_base/__init__.py +4 -0
- nshtrainer/configs/profiler/advanced/__init__.py +5 -0
- nshtrainer/configs/profiler/pytorch/__init__.py +5 -0
- nshtrainer/configs/profiler/simple/__init__.py +5 -0
- nshtrainer/configs/trainer/__init__.py +39 -6
- nshtrainer/configs/trainer/_config/__init__.py +37 -6
- nshtrainer/configs/trainer/trainer/__init__.py +9 -0
- nshtrainer/configs/util/__init__.py +19 -1
- nshtrainer/configs/util/_environment_info/__init__.py +14 -0
- nshtrainer/configs/util/config/__init__.py +8 -1
- nshtrainer/configs/util/config/dtype/__init__.py +4 -0
- nshtrainer/configs/util/config/duration/__init__.py +5 -1
- nshtrainer/loggers/__init__.py +12 -5
- nshtrainer/lr_scheduler/__init__.py +9 -5
- nshtrainer/model/mixins/callback.py +6 -4
- nshtrainer/optimizer.py +5 -3
- nshtrainer/profiler/__init__.py +9 -5
- nshtrainer/trainer/_config.py +85 -61
- nshtrainer/trainer/_runtime_callback.py +3 -3
- nshtrainer/trainer/signal_connector.py +6 -4
- nshtrainer/trainer/trainer.py +4 -4
- nshtrainer/util/_useful_types.py +11 -2
- nshtrainer/util/config/dtype.py +46 -43
- nshtrainer/util/path.py +3 -2
- {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/METADATA +2 -1
- nshtrainer-1.0.0b27.dist-info/RECORD +143 -0
- {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/WHEEL +1 -1
- nshtrainer-1.0.0b25.dist-info/RECORD +0 -140
nshtrainer/callbacks/__init__.py
CHANGED
@@ -14,6 +14,8 @@ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCa
|
|
14
14
|
from .checkpoint import (
|
15
15
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
+
from .checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
|
18
|
+
from .checkpoint import TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig
|
17
19
|
from .debug_flag import DebugFlagCallback as DebugFlagCallback
|
18
20
|
from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
19
21
|
from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
|
@@ -71,6 +73,7 @@ CallbackConfig = Annotated[
|
|
71
73
|
| BestCheckpointCallbackConfig
|
72
74
|
| LastCheckpointCallbackConfig
|
73
75
|
| OnExceptionCheckpointCallbackConfig
|
76
|
+
| TimeCheckpointCallbackConfig
|
74
77
|
| SharedParametersCallbackConfig
|
75
78
|
| RLPSanityChecksCallbackConfig
|
76
79
|
| WandbWatchCallbackConfig
|
nshtrainer/callbacks/actsave.py
CHANGED
@@ -4,12 +4,12 @@ import contextlib
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Literal
|
6
6
|
|
7
|
-
from typing_extensions import
|
7
|
+
from typing_extensions import TypeAliasType, override
|
8
8
|
|
9
9
|
from .._callback import NTCallbackBase
|
10
10
|
from .base import CallbackConfigBase
|
11
11
|
|
12
|
-
Stage
|
12
|
+
Stage = TypeAliasType("Stage", Literal["train", "validation", "test", "predict"])
|
13
13
|
|
14
14
|
|
15
15
|
class ActSaveConfig(CallbackConfigBase):
|
nshtrainer/callbacks/base.py
CHANGED
@@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
|
|
4
4
|
from collections import Counter
|
5
5
|
from collections.abc import Iterable
|
6
6
|
from dataclasses import dataclass
|
7
|
-
from typing import TYPE_CHECKING, ClassVar
|
7
|
+
from typing import TYPE_CHECKING, ClassVar
|
8
8
|
|
9
9
|
import nshconfig as C
|
10
10
|
from lightning.pytorch import Callback
|
11
|
-
from typing_extensions import TypedDict, Unpack
|
11
|
+
from typing_extensions import TypeAliasType, TypedDict, Unpack
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from ..trainer._config import TrainerConfig
|
@@ -30,7 +30,9 @@ class CallbackWithMetadata:
|
|
30
30
|
metadata: CallbackMetadataConfig
|
31
31
|
|
32
32
|
|
33
|
-
ConstructedCallback
|
33
|
+
ConstructedCallback = TypeAliasType(
|
34
|
+
"ConstructedCallback", Callback | CallbackWithMetadata
|
35
|
+
)
|
34
36
|
|
35
37
|
|
36
38
|
class CallbackConfigBase(C.Config, ABC):
|
nshtrainer/callbacks/checkpoint/__init__.py
CHANGED
@@ -14,3 +14,7 @@ from .on_exception_checkpoint import (
|
|
14
14
|
from .on_exception_checkpoint import (
|
15
15
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
+
from .time_checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
|
18
|
+
from .time_checkpoint import (
|
19
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
20
|
+
)
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py
CHANGED
@@ -7,8 +7,7 @@ from typing import Literal
|
|
7
7
|
from lightning.pytorch import LightningModule, Trainer
|
8
8
|
from typing_extensions import final, override
|
9
9
|
|
10
|
-
from
|
11
|
-
|
10
|
+
from ..._checkpoint.metadata import CheckpointMetadata
|
12
11
|
from ...metrics._config import MetricConfig
|
13
12
|
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
14
13
|
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py
CHANGED
@@ -6,8 +6,7 @@ from typing import Literal
|
|
6
6
|
from lightning.pytorch import LightningModule, Trainer
|
7
7
|
from typing_extensions import final, override
|
8
8
|
|
9
|
-
from
|
10
|
-
|
9
|
+
from ..._checkpoint.metadata import CheckpointMetadata
|
11
10
|
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
12
11
|
|
13
12
|
log = logging.getLogger(__name__)
|
nshtrainer/callbacks/checkpoint/time_checkpoint.py
ADDED
@@ -0,0 +1,114 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import time
|
5
|
+
from datetime import timedelta
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Literal
|
8
|
+
|
9
|
+
from lightning.pytorch import LightningModule, Trainer
|
10
|
+
from typing_extensions import final, override
|
11
|
+
|
12
|
+
from nshtrainer._checkpoint.metadata import CheckpointMetadata
|
13
|
+
|
14
|
+
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
15
|
+
|
16
|
+
log = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
@final
|
20
|
+
class TimeCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
21
|
+
name: Literal["time_checkpoint"] = "time_checkpoint"
|
22
|
+
|
23
|
+
interval: timedelta = timedelta(hours=12)
|
24
|
+
"""Time interval between checkpoints."""
|
25
|
+
|
26
|
+
@override
|
27
|
+
def create_checkpoint(self, trainer_config, dirpath):
|
28
|
+
return TimeCheckpointCallback(self, dirpath)
|
29
|
+
|
30
|
+
|
31
|
+
@final
|
32
|
+
class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
|
33
|
+
def __init__(self, config: TimeCheckpointCallbackConfig, dirpath: Path):
|
34
|
+
super().__init__(config, dirpath)
|
35
|
+
self.start_time = time.time()
|
36
|
+
self.last_checkpoint_time = self.start_time
|
37
|
+
self.interval_seconds = config.interval.total_seconds()
|
38
|
+
|
39
|
+
@override
|
40
|
+
def name(self):
|
41
|
+
return "time"
|
42
|
+
|
43
|
+
@override
|
44
|
+
def default_filename(self):
|
45
|
+
return "epoch{epoch}-step{step}-duration{train_duration}"
|
46
|
+
|
47
|
+
@override
|
48
|
+
def topk_sort_key(self, metadata: CheckpointMetadata):
|
49
|
+
return metadata.checkpoint_timestamp
|
50
|
+
|
51
|
+
@override
|
52
|
+
def topk_sort_reverse(self):
|
53
|
+
return True
|
54
|
+
|
55
|
+
def _should_checkpoint(self) -> bool:
|
56
|
+
current_time = time.time()
|
57
|
+
elapsed_time = current_time - self.last_checkpoint_time
|
58
|
+
return elapsed_time >= self.interval_seconds
|
59
|
+
|
60
|
+
def _format_duration(self, seconds: float) -> str:
|
61
|
+
"""Format duration in seconds to a human-readable string."""
|
62
|
+
td = timedelta(seconds=int(seconds))
|
63
|
+
days = td.days
|
64
|
+
hours, remainder = divmod(td.seconds, 3600)
|
65
|
+
minutes, seconds = divmod(remainder, 60)
|
66
|
+
|
67
|
+
parts = []
|
68
|
+
if days > 0:
|
69
|
+
parts.append(f"{days}d")
|
70
|
+
if hours > 0:
|
71
|
+
parts.append(f"{hours}h")
|
72
|
+
if minutes > 0:
|
73
|
+
parts.append(f"{minutes}m")
|
74
|
+
if seconds > 0 or not parts:
|
75
|
+
parts.append(f"{seconds}s")
|
76
|
+
|
77
|
+
return "_".join(parts)
|
78
|
+
|
79
|
+
@override
|
80
|
+
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
81
|
+
metrics = super().current_metrics(trainer)
|
82
|
+
train_duration = time.time() - self.start_time
|
83
|
+
metrics["train_duration"] = self._format_duration(train_duration)
|
84
|
+
return metrics
|
85
|
+
|
86
|
+
@override
|
87
|
+
def on_train_batch_end(
|
88
|
+
self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
|
89
|
+
):
|
90
|
+
if self._should_checkpoint():
|
91
|
+
self.save_checkpoints(trainer)
|
92
|
+
self.last_checkpoint_time = time.time()
|
93
|
+
|
94
|
+
@override
|
95
|
+
def state_dict(self) -> dict[str, Any]:
|
96
|
+
"""Save the timer state for checkpoint resumption.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
Dictionary containing the start time and last checkpoint time.
|
100
|
+
"""
|
101
|
+
return {
|
102
|
+
"start_time": self.start_time,
|
103
|
+
"last_checkpoint_time": self.last_checkpoint_time,
|
104
|
+
}
|
105
|
+
|
106
|
+
@override
|
107
|
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
108
|
+
"""Restore the timer state when resuming from a checkpoint.
|
109
|
+
|
110
|
+
Args:
|
111
|
+
state_dict: Dictionary containing the previously saved timer state.
|
112
|
+
"""
|
113
|
+
self.start_time = state_dict["start_time"]
|
114
|
+
self.last_checkpoint_time = state_dict["last_checkpoint_time"]
|
nshtrainer/callbacks/print_table.py
CHANGED
@@ -49,14 +49,14 @@ class PrintTableMetricsCallback(Callback):
|
|
49
49
|
}
|
50
50
|
self.metrics.append(metrics_dict)
|
51
51
|
|
52
|
-
from rich.console import Console
|
52
|
+
from rich.console import Console # type: ignore[reportMissingImports] # noqa
|
53
53
|
|
54
54
|
console = Console()
|
55
55
|
table = self.create_metrics_table()
|
56
56
|
console.print(table)
|
57
57
|
|
58
58
|
def create_metrics_table(self):
|
59
|
-
from rich.table import Table
|
59
|
+
from rich.table import Table # type: ignore[reportMissingImports] # noqa
|
60
60
|
|
61
61
|
table = Table(show_header=True, header_style="bold magenta")
|
62
62
|
|
nshtrainer/callbacks/shared_parameters.py
CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import logging
|
4
4
|
from collections.abc import Iterable
|
5
|
-
from typing import Literal, Protocol,
|
5
|
+
from typing import Literal, Protocol, runtime_checkable
|
6
6
|
|
7
7
|
import torch.nn as nn
|
8
8
|
from lightning.pytorch import LightningModule, Trainer
|
9
9
|
from lightning.pytorch.callbacks import Callback
|
10
|
-
from typing_extensions import override
|
10
|
+
from typing_extensions import TypeAliasType, override
|
11
11
|
|
12
12
|
from .base import CallbackConfigBase
|
13
13
|
|
@@ -34,7 +34,9 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
|
|
34
34
|
yield SharedParametersCallback(self)
|
35
35
|
|
36
36
|
|
37
|
-
SharedParametersList
|
37
|
+
SharedParametersList = TypeAliasType(
|
38
|
+
"SharedParametersList", list[tuple[nn.Parameter, int | float]]
|
39
|
+
)
|
38
40
|
|
39
41
|
|
40
42
|
@runtime_checkable
|
nshtrainer/configs/__init__.py
CHANGED
@@ -14,7 +14,6 @@ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
|
14
14
|
from nshtrainer.callbacks import (
|
15
15
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
-
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
18
17
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
19
18
|
from nshtrainer.callbacks import (
|
20
19
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
@@ -47,6 +46,9 @@ from nshtrainer.callbacks import (
|
|
47
46
|
from nshtrainer.callbacks import (
|
48
47
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
49
48
|
)
|
49
|
+
from nshtrainer.callbacks import (
|
50
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
51
|
+
)
|
50
52
|
from nshtrainer.callbacks import (
|
51
53
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
52
54
|
)
|
@@ -58,13 +60,11 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
58
60
|
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
59
61
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
60
62
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
61
|
-
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
62
63
|
from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
|
63
64
|
from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
|
64
65
|
from nshtrainer.lr_scheduler import (
|
65
66
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
66
67
|
)
|
67
|
-
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
68
68
|
from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
|
69
69
|
from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
70
70
|
from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
|
@@ -73,7 +73,6 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
|
73
73
|
from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
|
74
74
|
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
75
75
|
from nshtrainer.nn import MLPConfig as MLPConfig
|
76
|
-
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
77
76
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
78
77
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
79
78
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
@@ -87,23 +86,21 @@ from nshtrainer.nn.nonlinearity import (
|
|
87
86
|
SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
|
88
87
|
)
|
89
88
|
from nshtrainer.optimizer import AdamWConfig as AdamWConfig
|
90
|
-
from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
|
91
89
|
from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
|
92
90
|
from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
|
93
91
|
from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
94
|
-
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
95
92
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
96
93
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
97
|
-
from nshtrainer.trainer._config import
|
98
|
-
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
99
|
-
)
|
94
|
+
from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
|
100
95
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
101
96
|
from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
|
102
97
|
from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
|
103
98
|
from nshtrainer.trainer._config import (
|
104
99
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
105
100
|
)
|
101
|
+
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
106
102
|
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
103
|
+
from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
|
107
104
|
from nshtrainer.util._environment_info import (
|
108
105
|
EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
|
109
106
|
)
|
@@ -133,7 +130,6 @@ from nshtrainer.util._environment_info import (
|
|
133
130
|
)
|
134
131
|
from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
|
135
132
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
136
|
-
from nshtrainer.util.config import DurationConfig as DurationConfig
|
137
133
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
138
134
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
139
135
|
|
@@ -149,3 +145,96 @@ from . import optimizer as optimizer
|
|
149
145
|
from . import profiler as profiler
|
150
146
|
from . import trainer as trainer
|
151
147
|
from . import util as util
|
148
|
+
|
149
|
+
__all__ = [
|
150
|
+
"AcceleratorConfigBase",
|
151
|
+
"ActSaveConfig",
|
152
|
+
"ActSaveLoggerConfig",
|
153
|
+
"AdamWConfig",
|
154
|
+
"AdvancedProfilerConfig",
|
155
|
+
"BaseCheckpointCallbackConfig",
|
156
|
+
"BaseLoggerConfig",
|
157
|
+
"BaseNonlinearityConfig",
|
158
|
+
"BaseProfilerConfig",
|
159
|
+
"BestCheckpointCallbackConfig",
|
160
|
+
"CSVLoggerConfig",
|
161
|
+
"CallbackConfigBase",
|
162
|
+
"CheckpointMetadata",
|
163
|
+
"CheckpointSavingConfig",
|
164
|
+
"DTypeConfig",
|
165
|
+
"DebugFlagCallbackConfig",
|
166
|
+
"DirectoryConfig",
|
167
|
+
"DirectorySetupCallbackConfig",
|
168
|
+
"ELUNonlinearityConfig",
|
169
|
+
"EMACallbackConfig",
|
170
|
+
"EarlyStoppingCallbackConfig",
|
171
|
+
"EnvironmentCUDAConfig",
|
172
|
+
"EnvironmentClassInformationConfig",
|
173
|
+
"EnvironmentConfig",
|
174
|
+
"EnvironmentGPUConfig",
|
175
|
+
"EnvironmentHardwareConfig",
|
176
|
+
"EnvironmentLSFInformationConfig",
|
177
|
+
"EnvironmentLinuxEnvironmentConfig",
|
178
|
+
"EnvironmentPackageConfig",
|
179
|
+
"EnvironmentSLURMInformationConfig",
|
180
|
+
"EnvironmentSnapshotConfig",
|
181
|
+
"EpochTimerCallbackConfig",
|
182
|
+
"EpochsConfig",
|
183
|
+
"FiniteChecksCallbackConfig",
|
184
|
+
"GELUNonlinearityConfig",
|
185
|
+
"GitRepositoryConfig",
|
186
|
+
"GradientClippingConfig",
|
187
|
+
"GradientSkippingCallbackConfig",
|
188
|
+
"HuggingFaceHubAutoCreateConfig",
|
189
|
+
"HuggingFaceHubConfig",
|
190
|
+
"LRSchedulerConfigBase",
|
191
|
+
"LastCheckpointCallbackConfig",
|
192
|
+
"LeakyReLUNonlinearityConfig",
|
193
|
+
"LearningRateMonitorConfig",
|
194
|
+
"LinearWarmupCosineDecayLRSchedulerConfig",
|
195
|
+
"LogEpochCallbackConfig",
|
196
|
+
"MLPConfig",
|
197
|
+
"MetricConfig",
|
198
|
+
"MishNonlinearityConfig",
|
199
|
+
"NormLoggingCallbackConfig",
|
200
|
+
"OnExceptionCheckpointCallbackConfig",
|
201
|
+
"OptimizerConfigBase",
|
202
|
+
"PReLUConfig",
|
203
|
+
"PluginConfigBase",
|
204
|
+
"PrintTableMetricsCallbackConfig",
|
205
|
+
"PyTorchProfilerConfig",
|
206
|
+
"RLPSanityChecksCallbackConfig",
|
207
|
+
"ReLUNonlinearityConfig",
|
208
|
+
"ReduceLROnPlateauConfig",
|
209
|
+
"SanityCheckingConfig",
|
210
|
+
"SharedParametersCallbackConfig",
|
211
|
+
"SiLUNonlinearityConfig",
|
212
|
+
"SigmoidNonlinearityConfig",
|
213
|
+
"SimpleProfilerConfig",
|
214
|
+
"SoftmaxNonlinearityConfig",
|
215
|
+
"SoftplusNonlinearityConfig",
|
216
|
+
"SoftsignNonlinearityConfig",
|
217
|
+
"StepsConfig",
|
218
|
+
"StrategyConfigBase",
|
219
|
+
"SwiGLUNonlinearityConfig",
|
220
|
+
"SwishNonlinearityConfig",
|
221
|
+
"TanhNonlinearityConfig",
|
222
|
+
"TensorboardLoggerConfig",
|
223
|
+
"TimeCheckpointCallbackConfig",
|
224
|
+
"TrainerConfig",
|
225
|
+
"WandbLoggerConfig",
|
226
|
+
"WandbUploadCodeCallbackConfig",
|
227
|
+
"WandbWatchCallbackConfig",
|
228
|
+
"_checkpoint",
|
229
|
+
"_directory",
|
230
|
+
"_hf_hub",
|
231
|
+
"callbacks",
|
232
|
+
"loggers",
|
233
|
+
"lr_scheduler",
|
234
|
+
"metrics",
|
235
|
+
"nn",
|
236
|
+
"optimizer",
|
237
|
+
"profiler",
|
238
|
+
"trainer",
|
239
|
+
"util",
|
240
|
+
]
|
nshtrainer/configs/_checkpoint/__init__.py
CHANGED
@@ -6,3 +6,9 @@ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMeta
|
|
6
6
|
from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
|
7
7
|
|
8
8
|
from . import metadata as metadata
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"CheckpointMetadata",
|
12
|
+
"EnvironmentConfig",
|
13
|
+
"metadata",
|
14
|
+
]
|
nshtrainer/configs/_directory/__init__.py
CHANGED
@@ -6,4 +6,8 @@ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
|
6
6
|
from nshtrainer._directory import (
|
7
7
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
8
8
|
)
|
9
|
-
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"DirectoryConfig",
|
12
|
+
"DirectorySetupCallbackConfig",
|
13
|
+
]
|
nshtrainer/configs/_hf_hub/__init__.py
CHANGED
@@ -7,3 +7,9 @@ from nshtrainer._hf_hub import (
|
|
7
7
|
HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
|
8
8
|
)
|
9
9
|
from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
10
|
+
|
11
|
+
__all__ = [
|
12
|
+
"CallbackConfigBase",
|
13
|
+
"HuggingFaceHubAutoCreateConfig",
|
14
|
+
"HuggingFaceHubConfig",
|
15
|
+
]
|
nshtrainer/configs/callbacks/__init__.py
CHANGED
@@ -5,7 +5,6 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.callbacks import (
|
6
6
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
7
7
|
)
|
8
|
-
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
9
8
|
from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
|
10
9
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
11
10
|
from nshtrainer.callbacks import (
|
@@ -39,6 +38,9 @@ from nshtrainer.callbacks import (
|
|
39
38
|
from nshtrainer.callbacks import (
|
40
39
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
41
40
|
)
|
41
|
+
from nshtrainer.callbacks import (
|
42
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
43
|
+
)
|
42
44
|
from nshtrainer.callbacks import (
|
43
45
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
44
46
|
)
|
@@ -73,3 +75,48 @@ from . import shared_parameters as shared_parameters
|
|
73
75
|
from . import timer as timer
|
74
76
|
from . import wandb_upload_code as wandb_upload_code
|
75
77
|
from . import wandb_watch as wandb_watch
|
78
|
+
|
79
|
+
__all__ = [
|
80
|
+
"ActSaveConfig",
|
81
|
+
"BaseCheckpointCallbackConfig",
|
82
|
+
"BestCheckpointCallbackConfig",
|
83
|
+
"CallbackConfigBase",
|
84
|
+
"CheckpointMetadata",
|
85
|
+
"DebugFlagCallbackConfig",
|
86
|
+
"DirectorySetupCallbackConfig",
|
87
|
+
"EMACallbackConfig",
|
88
|
+
"EarlyStoppingCallbackConfig",
|
89
|
+
"EpochTimerCallbackConfig",
|
90
|
+
"FiniteChecksCallbackConfig",
|
91
|
+
"GradientSkippingCallbackConfig",
|
92
|
+
"LastCheckpointCallbackConfig",
|
93
|
+
"LearningRateMonitorConfig",
|
94
|
+
"LogEpochCallbackConfig",
|
95
|
+
"MetricConfig",
|
96
|
+
"NormLoggingCallbackConfig",
|
97
|
+
"OnExceptionCheckpointCallbackConfig",
|
98
|
+
"PrintTableMetricsCallbackConfig",
|
99
|
+
"RLPSanityChecksCallbackConfig",
|
100
|
+
"SharedParametersCallbackConfig",
|
101
|
+
"TimeCheckpointCallbackConfig",
|
102
|
+
"WandbUploadCodeCallbackConfig",
|
103
|
+
"WandbWatchCallbackConfig",
|
104
|
+
"actsave",
|
105
|
+
"base",
|
106
|
+
"checkpoint",
|
107
|
+
"debug_flag",
|
108
|
+
"directory_setup",
|
109
|
+
"early_stopping",
|
110
|
+
"ema",
|
111
|
+
"finite_checks",
|
112
|
+
"gradient_skipping",
|
113
|
+
"log_epoch",
|
114
|
+
"lr_monitor",
|
115
|
+
"norm_logging",
|
116
|
+
"print_table",
|
117
|
+
"rlp_sanity_checks",
|
118
|
+
"shared_parameters",
|
119
|
+
"timer",
|
120
|
+
"wandb_upload_code",
|
121
|
+
"wandb_watch",
|
122
|
+
]
|
nshtrainer/configs/callbacks/checkpoint/__init__.py
CHANGED
@@ -11,6 +11,9 @@ from nshtrainer.callbacks.checkpoint import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint import (
|
12
12
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
13
13
|
)
|
14
|
+
from nshtrainer.callbacks.checkpoint import (
|
15
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
16
|
+
)
|
14
17
|
from nshtrainer.callbacks.checkpoint._base import (
|
15
18
|
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
16
19
|
)
|
@@ -26,3 +29,20 @@ from . import _base as _base
|
|
26
29
|
from . import best_checkpoint as best_checkpoint
|
27
30
|
from . import last_checkpoint as last_checkpoint
|
28
31
|
from . import on_exception_checkpoint as on_exception_checkpoint
|
32
|
+
from . import time_checkpoint as time_checkpoint
|
33
|
+
|
34
|
+
__all__ = [
|
35
|
+
"BaseCheckpointCallbackConfig",
|
36
|
+
"BestCheckpointCallbackConfig",
|
37
|
+
"CallbackConfigBase",
|
38
|
+
"CheckpointMetadata",
|
39
|
+
"LastCheckpointCallbackConfig",
|
40
|
+
"MetricConfig",
|
41
|
+
"OnExceptionCheckpointCallbackConfig",
|
42
|
+
"TimeCheckpointCallbackConfig",
|
43
|
+
"_base",
|
44
|
+
"best_checkpoint",
|
45
|
+
"last_checkpoint",
|
46
|
+
"on_exception_checkpoint",
|
47
|
+
"time_checkpoint",
|
48
|
+
]
|
nshtrainer/configs/callbacks/checkpoint/_base/__init__.py
CHANGED
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint._base import (
|
12
12
|
CheckpointMetadata as CheckpointMetadata,
|
13
13
|
)
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
"BaseCheckpointCallbackConfig",
|
17
|
+
"CallbackConfigBase",
|
18
|
+
"CheckpointMetadata",
|
19
|
+
]
|
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py
CHANGED
@@ -12,3 +12,10 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
|
|
12
12
|
CheckpointMetadata as CheckpointMetadata,
|
13
13
|
)
|
14
14
|
from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
|
15
|
+
|
16
|
+
__all__ = [
|
17
|
+
"BaseCheckpointCallbackConfig",
|
18
|
+
"BestCheckpointCallbackConfig",
|
19
|
+
"CheckpointMetadata",
|
20
|
+
"MetricConfig",
|
21
|
+
]
|
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py
CHANGED
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint.last_checkpoint import (
|
12
12
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
13
13
|
)
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
"BaseCheckpointCallbackConfig",
|
17
|
+
"CheckpointMetadata",
|
18
|
+
"LastCheckpointCallbackConfig",
|
19
|
+
]
|
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py
CHANGED
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
|
|
8
8
|
from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
|
9
9
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
10
10
|
)
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"CallbackConfigBase",
|
14
|
+
"OnExceptionCheckpointCallbackConfig",
|
15
|
+
]
|
nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__codegen__ = True
|
4
|
+
|
5
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
6
|
+
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
7
|
+
)
|
8
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
9
|
+
CheckpointMetadata as CheckpointMetadata,
|
10
|
+
)
|
11
|
+
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
12
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
13
|
+
)
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
"BaseCheckpointCallbackConfig",
|
17
|
+
"CheckpointMetadata",
|
18
|
+
"TimeCheckpointCallbackConfig",
|
19
|
+
]
|