nshtrainer 1.0.0b24__py3-none-any.whl → 1.0.0b26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/.nshconfig.generated.json +6 -0
- nshtrainer/_checkpoint/metadata.py +1 -1
- nshtrainer/callbacks/__init__.py +3 -0
- nshtrainer/callbacks/checkpoint/__init__.py +4 -0
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -2
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -2
- nshtrainer/callbacks/checkpoint/time_checkpoint.py +114 -0
- nshtrainer/callbacks/print_table.py +2 -2
- nshtrainer/configs/__init__.py +95 -10
- nshtrainer/configs/_checkpoint/__init__.py +6 -0
- nshtrainer/configs/_checkpoint/metadata/__init__.py +5 -0
- nshtrainer/configs/_directory/__init__.py +5 -1
- nshtrainer/configs/_hf_hub/__init__.py +6 -0
- nshtrainer/configs/callbacks/__init__.py +44 -1
- nshtrainer/configs/callbacks/actsave/__init__.py +5 -0
- nshtrainer/configs/callbacks/base/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +14 -0
- nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +7 -0
- nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +6 -0
- nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +5 -0
- nshtrainer/configs/callbacks/debug_flag/__init__.py +5 -0
- nshtrainer/configs/callbacks/directory_setup/__init__.py +5 -0
- nshtrainer/configs/callbacks/early_stopping/__init__.py +6 -0
- nshtrainer/configs/callbacks/ema/__init__.py +5 -0
- nshtrainer/configs/callbacks/finite_checks/__init__.py +5 -0
- nshtrainer/configs/callbacks/gradient_skipping/__init__.py +5 -0
- nshtrainer/configs/callbacks/log_epoch/__init__.py +5 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +5 -0
- nshtrainer/configs/callbacks/norm_logging/__init__.py +5 -0
- nshtrainer/configs/callbacks/print_table/__init__.py +5 -0
- nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +5 -0
- nshtrainer/configs/callbacks/shared_parameters/__init__.py +5 -0
- nshtrainer/configs/callbacks/timer/__init__.py +5 -0
- nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +5 -0
- nshtrainer/configs/callbacks/wandb_watch/__init__.py +5 -0
- nshtrainer/configs/loggers/__init__.py +16 -1
- nshtrainer/configs/loggers/_base/__init__.py +4 -0
- nshtrainer/configs/loggers/actsave/__init__.py +5 -0
- nshtrainer/configs/loggers/csv/__init__.py +5 -0
- nshtrainer/configs/loggers/tensorboard/__init__.py +5 -0
- nshtrainer/configs/loggers/wandb/__init__.py +8 -0
- nshtrainer/configs/lr_scheduler/__init__.py +10 -4
- nshtrainer/configs/lr_scheduler/_base/__init__.py +4 -0
- nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +5 -3
- nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -0
- nshtrainer/configs/metrics/__init__.py +5 -0
- nshtrainer/configs/metrics/_config/__init__.py +4 -0
- nshtrainer/configs/nn/__init__.py +21 -1
- nshtrainer/configs/nn/mlp/__init__.py +5 -1
- nshtrainer/configs/nn/nonlinearity/__init__.py +18 -1
- nshtrainer/configs/optimizer/__init__.py +5 -1
- nshtrainer/configs/profiler/__init__.py +11 -1
- nshtrainer/configs/profiler/_base/__init__.py +4 -0
- nshtrainer/configs/profiler/advanced/__init__.py +5 -0
- nshtrainer/configs/profiler/pytorch/__init__.py +5 -0
- nshtrainer/configs/profiler/simple/__init__.py +5 -0
- nshtrainer/configs/trainer/__init__.py +35 -6
- nshtrainer/configs/trainer/_config/__init__.py +33 -6
- nshtrainer/configs/trainer/trainer/__init__.py +9 -0
- nshtrainer/configs/util/__init__.py +19 -1
- nshtrainer/configs/util/_environment_info/__init__.py +14 -0
- nshtrainer/configs/util/config/__init__.py +8 -1
- nshtrainer/configs/util/config/dtype/__init__.py +4 -0
- nshtrainer/configs/util/config/duration/__init__.py +5 -1
- nshtrainer/model/mixins/logger.py +30 -12
- nshtrainer/trainer/_config.py +40 -21
- nshtrainer/trainer/trainer.py +4 -4
- {nshtrainer-1.0.0b24.dist-info → nshtrainer-1.0.0b26.dist-info}/METADATA +2 -1
- {nshtrainer-1.0.0b24.dist-info → nshtrainer-1.0.0b26.dist-info}/RECORD +71 -69
- {nshtrainer-1.0.0b24.dist-info → nshtrainer-1.0.0b26.dist-info}/WHEEL +1 -1
nshtrainer/callbacks/__init__.py
CHANGED
@@ -14,6 +14,8 @@ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCa
|
|
14
14
|
from .checkpoint import (
|
15
15
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
+
from .checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
|
18
|
+
from .checkpoint import TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig
|
17
19
|
from .debug_flag import DebugFlagCallback as DebugFlagCallback
|
18
20
|
from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
19
21
|
from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
|
@@ -71,6 +73,7 @@ CallbackConfig = Annotated[
|
|
71
73
|
| BestCheckpointCallbackConfig
|
72
74
|
| LastCheckpointCallbackConfig
|
73
75
|
| OnExceptionCheckpointCallbackConfig
|
76
|
+
| TimeCheckpointCallbackConfig
|
74
77
|
| SharedParametersCallbackConfig
|
75
78
|
| RLPSanityChecksCallbackConfig
|
76
79
|
| WandbWatchCallbackConfig
|
@@ -14,3 +14,7 @@ from .on_exception_checkpoint import (
|
|
14
14
|
from .on_exception_checkpoint import (
|
15
15
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
+
from .time_checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
|
18
|
+
from .time_checkpoint import (
|
19
|
+
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
20
|
+
)
|
@@ -7,8 +7,7 @@ from typing import Literal
|
|
7
7
|
from lightning.pytorch import LightningModule, Trainer
|
8
8
|
from typing_extensions import final, override
|
9
9
|
|
10
|
-
from
|
11
|
-
|
10
|
+
from ..._checkpoint.metadata import CheckpointMetadata
|
12
11
|
from ...metrics._config import MetricConfig
|
13
12
|
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
14
13
|
|
@@ -6,8 +6,7 @@ from typing import Literal
|
|
6
6
|
from lightning.pytorch import LightningModule, Trainer
|
7
7
|
from typing_extensions import final, override
|
8
8
|
|
9
|
-
from
|
10
|
-
|
9
|
+
from ..._checkpoint.metadata import CheckpointMetadata
|
11
10
|
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
12
11
|
|
13
12
|
log = logging.getLogger(__name__)
|
@@ -0,0 +1,114 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import time
|
5
|
+
from datetime import timedelta
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Literal
|
8
|
+
|
9
|
+
from lightning.pytorch import LightningModule, Trainer
|
10
|
+
from typing_extensions import final, override
|
11
|
+
|
12
|
+
from nshtrainer._checkpoint.metadata import CheckpointMetadata
|
13
|
+
|
14
|
+
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
15
|
+
|
16
|
+
log = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
|
19
|
+
@final
|
20
|
+
class TimeCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
21
|
+
name: Literal["time_checkpoint"] = "time_checkpoint"
|
22
|
+
|
23
|
+
interval: timedelta = timedelta(hours=12)
|
24
|
+
"""Time interval between checkpoints."""
|
25
|
+
|
26
|
+
@override
|
27
|
+
def create_checkpoint(self, trainer_config, dirpath):
|
28
|
+
return TimeCheckpointCallback(self, dirpath)
|
29
|
+
|
30
|
+
|
31
|
+
@final
|
32
|
+
class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
|
33
|
+
def __init__(self, config: TimeCheckpointCallbackConfig, dirpath: Path):
|
34
|
+
super().__init__(config, dirpath)
|
35
|
+
self.start_time = time.time()
|
36
|
+
self.last_checkpoint_time = self.start_time
|
37
|
+
self.interval_seconds = config.interval.total_seconds()
|
38
|
+
|
39
|
+
@override
|
40
|
+
def name(self):
|
41
|
+
return "time"
|
42
|
+
|
43
|
+
@override
|
44
|
+
def default_filename(self):
|
45
|
+
return "epoch{epoch}-step{step}-duration{train_duration}"
|
46
|
+
|
47
|
+
@override
|
48
|
+
def topk_sort_key(self, metadata: CheckpointMetadata):
|
49
|
+
return metadata.checkpoint_timestamp
|
50
|
+
|
51
|
+
@override
|
52
|
+
def topk_sort_reverse(self):
|
53
|
+
return True
|
54
|
+
|
55
|
+
def _should_checkpoint(self) -> bool:
|
56
|
+
current_time = time.time()
|
57
|
+
elapsed_time = current_time - self.last_checkpoint_time
|
58
|
+
return elapsed_time >= self.interval_seconds
|
59
|
+
|
60
|
+
def _format_duration(self, seconds: float) -> str:
|
61
|
+
"""Format duration in seconds to a human-readable string."""
|
62
|
+
td = timedelta(seconds=int(seconds))
|
63
|
+
days = td.days
|
64
|
+
hours, remainder = divmod(td.seconds, 3600)
|
65
|
+
minutes, seconds = divmod(remainder, 60)
|
66
|
+
|
67
|
+
parts = []
|
68
|
+
if days > 0:
|
69
|
+
parts.append(f"{days}d")
|
70
|
+
if hours > 0:
|
71
|
+
parts.append(f"{hours}h")
|
72
|
+
if minutes > 0:
|
73
|
+
parts.append(f"{minutes}m")
|
74
|
+
if seconds > 0 or not parts:
|
75
|
+
parts.append(f"{seconds}s")
|
76
|
+
|
77
|
+
return "_".join(parts)
|
78
|
+
|
79
|
+
@override
|
80
|
+
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
81
|
+
metrics = super().current_metrics(trainer)
|
82
|
+
train_duration = time.time() - self.start_time
|
83
|
+
metrics["train_duration"] = self._format_duration(train_duration)
|
84
|
+
return metrics
|
85
|
+
|
86
|
+
@override
|
87
|
+
def on_train_batch_end(
|
88
|
+
self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
|
89
|
+
):
|
90
|
+
if self._should_checkpoint():
|
91
|
+
self.save_checkpoints(trainer)
|
92
|
+
self.last_checkpoint_time = time.time()
|
93
|
+
|
94
|
+
@override
|
95
|
+
def state_dict(self) -> dict[str, Any]:
|
96
|
+
"""Save the timer state for checkpoint resumption.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
Dictionary containing the start time and last checkpoint time.
|
100
|
+
"""
|
101
|
+
return {
|
102
|
+
"start_time": self.start_time,
|
103
|
+
"last_checkpoint_time": self.last_checkpoint_time,
|
104
|
+
}
|
105
|
+
|
106
|
+
@override
|
107
|
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
108
|
+
"""Restore the timer state when resuming from a checkpoint.
|
109
|
+
|
110
|
+
Args:
|
111
|
+
state_dict: Dictionary containing the previously saved timer state.
|
112
|
+
"""
|
113
|
+
self.start_time = state_dict["start_time"]
|
114
|
+
self.last_checkpoint_time = state_dict["last_checkpoint_time"]
|
@@ -49,14 +49,14 @@ class PrintTableMetricsCallback(Callback):
|
|
49
49
|
}
|
50
50
|
self.metrics.append(metrics_dict)
|
51
51
|
|
52
|
-
from rich.console import Console
|
52
|
+
from rich.console import Console # type: ignore[reportMissingImports] # noqa
|
53
53
|
|
54
54
|
console = Console()
|
55
55
|
table = self.create_metrics_table()
|
56
56
|
console.print(table)
|
57
57
|
|
58
58
|
def create_metrics_table(self):
|
59
|
-
from rich.table import Table
|
59
|
+
from rich.table import Table # type: ignore[reportMissingImports] # noqa
|
60
60
|
|
61
61
|
table = Table(show_header=True, header_style="bold magenta")
|
62
62
|
|
nshtrainer/configs/__init__.py
CHANGED
@@ -14,7 +14,6 @@ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
|
14
14
|
from nshtrainer.callbacks import (
|
15
15
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
16
16
|
)
|
17
|
-
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
18
17
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
19
18
|
from nshtrainer.callbacks import (
|
20
19
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
@@ -58,13 +57,11 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
58
57
|
from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
|
59
58
|
from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
|
60
59
|
from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
|
61
|
-
from nshtrainer.loggers import LoggerConfig as LoggerConfig
|
62
60
|
from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
|
63
61
|
from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
|
64
62
|
from nshtrainer.lr_scheduler import (
|
65
63
|
LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
|
66
64
|
)
|
67
|
-
from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
|
68
65
|
from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
|
69
66
|
from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
|
70
67
|
from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
|
@@ -73,7 +70,6 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
|
|
73
70
|
from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
|
74
71
|
from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
|
75
72
|
from nshtrainer.nn import MLPConfig as MLPConfig
|
76
|
-
from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
|
77
73
|
from nshtrainer.nn import PReLUConfig as PReLUConfig
|
78
74
|
from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
|
79
75
|
from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
|
@@ -87,23 +83,21 @@ from nshtrainer.nn.nonlinearity import (
|
|
87
83
|
SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
|
88
84
|
)
|
89
85
|
from nshtrainer.optimizer import AdamWConfig as AdamWConfig
|
90
|
-
from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
|
91
86
|
from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
|
92
87
|
from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
|
93
88
|
from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
94
|
-
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
95
89
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
96
90
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
97
|
-
from nshtrainer.trainer._config import
|
98
|
-
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
99
|
-
)
|
91
|
+
from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
|
100
92
|
from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
|
101
93
|
from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
|
102
94
|
from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
|
103
95
|
from nshtrainer.trainer._config import (
|
104
96
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
105
97
|
)
|
98
|
+
from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
|
106
99
|
from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
|
100
|
+
from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
|
107
101
|
from nshtrainer.util._environment_info import (
|
108
102
|
EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
|
109
103
|
)
|
@@ -133,7 +127,6 @@ from nshtrainer.util._environment_info import (
|
|
133
127
|
)
|
134
128
|
from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
|
135
129
|
from nshtrainer.util.config import DTypeConfig as DTypeConfig
|
136
|
-
from nshtrainer.util.config import DurationConfig as DurationConfig
|
137
130
|
from nshtrainer.util.config import EpochsConfig as EpochsConfig
|
138
131
|
from nshtrainer.util.config import StepsConfig as StepsConfig
|
139
132
|
|
@@ -149,3 +142,95 @@ from . import optimizer as optimizer
|
|
149
142
|
from . import profiler as profiler
|
150
143
|
from . import trainer as trainer
|
151
144
|
from . import util as util
|
145
|
+
|
146
|
+
__all__ = [
|
147
|
+
"AcceleratorConfigBase",
|
148
|
+
"ActSaveConfig",
|
149
|
+
"ActSaveLoggerConfig",
|
150
|
+
"AdamWConfig",
|
151
|
+
"AdvancedProfilerConfig",
|
152
|
+
"BaseCheckpointCallbackConfig",
|
153
|
+
"BaseLoggerConfig",
|
154
|
+
"BaseNonlinearityConfig",
|
155
|
+
"BaseProfilerConfig",
|
156
|
+
"BestCheckpointCallbackConfig",
|
157
|
+
"CSVLoggerConfig",
|
158
|
+
"CallbackConfigBase",
|
159
|
+
"CheckpointMetadata",
|
160
|
+
"CheckpointSavingConfig",
|
161
|
+
"DTypeConfig",
|
162
|
+
"DebugFlagCallbackConfig",
|
163
|
+
"DirectoryConfig",
|
164
|
+
"DirectorySetupCallbackConfig",
|
165
|
+
"ELUNonlinearityConfig",
|
166
|
+
"EMACallbackConfig",
|
167
|
+
"EarlyStoppingCallbackConfig",
|
168
|
+
"EnvironmentCUDAConfig",
|
169
|
+
"EnvironmentClassInformationConfig",
|
170
|
+
"EnvironmentConfig",
|
171
|
+
"EnvironmentGPUConfig",
|
172
|
+
"EnvironmentHardwareConfig",
|
173
|
+
"EnvironmentLSFInformationConfig",
|
174
|
+
"EnvironmentLinuxEnvironmentConfig",
|
175
|
+
"EnvironmentPackageConfig",
|
176
|
+
"EnvironmentSLURMInformationConfig",
|
177
|
+
"EnvironmentSnapshotConfig",
|
178
|
+
"EpochTimerCallbackConfig",
|
179
|
+
"EpochsConfig",
|
180
|
+
"FiniteChecksCallbackConfig",
|
181
|
+
"GELUNonlinearityConfig",
|
182
|
+
"GitRepositoryConfig",
|
183
|
+
"GradientClippingConfig",
|
184
|
+
"GradientSkippingCallbackConfig",
|
185
|
+
"HuggingFaceHubAutoCreateConfig",
|
186
|
+
"HuggingFaceHubConfig",
|
187
|
+
"LRSchedulerConfigBase",
|
188
|
+
"LastCheckpointCallbackConfig",
|
189
|
+
"LeakyReLUNonlinearityConfig",
|
190
|
+
"LearningRateMonitorConfig",
|
191
|
+
"LinearWarmupCosineDecayLRSchedulerConfig",
|
192
|
+
"LogEpochCallbackConfig",
|
193
|
+
"MLPConfig",
|
194
|
+
"MetricConfig",
|
195
|
+
"MishNonlinearityConfig",
|
196
|
+
"NormLoggingCallbackConfig",
|
197
|
+
"OnExceptionCheckpointCallbackConfig",
|
198
|
+
"OptimizerConfigBase",
|
199
|
+
"PReLUConfig",
|
200
|
+
"PluginConfigBase",
|
201
|
+
"PrintTableMetricsCallbackConfig",
|
202
|
+
"PyTorchProfilerConfig",
|
203
|
+
"RLPSanityChecksCallbackConfig",
|
204
|
+
"ReLUNonlinearityConfig",
|
205
|
+
"ReduceLROnPlateauConfig",
|
206
|
+
"SanityCheckingConfig",
|
207
|
+
"SharedParametersCallbackConfig",
|
208
|
+
"SiLUNonlinearityConfig",
|
209
|
+
"SigmoidNonlinearityConfig",
|
210
|
+
"SimpleProfilerConfig",
|
211
|
+
"SoftmaxNonlinearityConfig",
|
212
|
+
"SoftplusNonlinearityConfig",
|
213
|
+
"SoftsignNonlinearityConfig",
|
214
|
+
"StepsConfig",
|
215
|
+
"StrategyConfigBase",
|
216
|
+
"SwiGLUNonlinearityConfig",
|
217
|
+
"SwishNonlinearityConfig",
|
218
|
+
"TanhNonlinearityConfig",
|
219
|
+
"TensorboardLoggerConfig",
|
220
|
+
"TrainerConfig",
|
221
|
+
"WandbLoggerConfig",
|
222
|
+
"WandbUploadCodeCallbackConfig",
|
223
|
+
"WandbWatchCallbackConfig",
|
224
|
+
"_checkpoint",
|
225
|
+
"_directory",
|
226
|
+
"_hf_hub",
|
227
|
+
"callbacks",
|
228
|
+
"loggers",
|
229
|
+
"lr_scheduler",
|
230
|
+
"metrics",
|
231
|
+
"nn",
|
232
|
+
"optimizer",
|
233
|
+
"profiler",
|
234
|
+
"trainer",
|
235
|
+
"util",
|
236
|
+
]
|
@@ -6,3 +6,9 @@ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMeta
|
|
6
6
|
from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
|
7
7
|
|
8
8
|
from . import metadata as metadata
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"CheckpointMetadata",
|
12
|
+
"EnvironmentConfig",
|
13
|
+
"metadata",
|
14
|
+
]
|
@@ -6,4 +6,8 @@ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
|
6
6
|
from nshtrainer._directory import (
|
7
7
|
DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
|
8
8
|
)
|
9
|
-
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"DirectoryConfig",
|
12
|
+
"DirectorySetupCallbackConfig",
|
13
|
+
]
|
@@ -7,3 +7,9 @@ from nshtrainer._hf_hub import (
|
|
7
7
|
HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
|
8
8
|
)
|
9
9
|
from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
10
|
+
|
11
|
+
__all__ = [
|
12
|
+
"CallbackConfigBase",
|
13
|
+
"HuggingFaceHubAutoCreateConfig",
|
14
|
+
"HuggingFaceHubConfig",
|
15
|
+
]
|
@@ -5,7 +5,6 @@ __codegen__ = True
|
|
5
5
|
from nshtrainer.callbacks import (
|
6
6
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
7
7
|
)
|
8
|
-
from nshtrainer.callbacks import CallbackConfig as CallbackConfig
|
9
8
|
from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
|
10
9
|
from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
|
11
10
|
from nshtrainer.callbacks import (
|
@@ -73,3 +72,47 @@ from . import shared_parameters as shared_parameters
|
|
73
72
|
from . import timer as timer
|
74
73
|
from . import wandb_upload_code as wandb_upload_code
|
75
74
|
from . import wandb_watch as wandb_watch
|
75
|
+
|
76
|
+
__all__ = [
|
77
|
+
"ActSaveConfig",
|
78
|
+
"BaseCheckpointCallbackConfig",
|
79
|
+
"BestCheckpointCallbackConfig",
|
80
|
+
"CallbackConfigBase",
|
81
|
+
"CheckpointMetadata",
|
82
|
+
"DebugFlagCallbackConfig",
|
83
|
+
"DirectorySetupCallbackConfig",
|
84
|
+
"EMACallbackConfig",
|
85
|
+
"EarlyStoppingCallbackConfig",
|
86
|
+
"EpochTimerCallbackConfig",
|
87
|
+
"FiniteChecksCallbackConfig",
|
88
|
+
"GradientSkippingCallbackConfig",
|
89
|
+
"LastCheckpointCallbackConfig",
|
90
|
+
"LearningRateMonitorConfig",
|
91
|
+
"LogEpochCallbackConfig",
|
92
|
+
"MetricConfig",
|
93
|
+
"NormLoggingCallbackConfig",
|
94
|
+
"OnExceptionCheckpointCallbackConfig",
|
95
|
+
"PrintTableMetricsCallbackConfig",
|
96
|
+
"RLPSanityChecksCallbackConfig",
|
97
|
+
"SharedParametersCallbackConfig",
|
98
|
+
"WandbUploadCodeCallbackConfig",
|
99
|
+
"WandbWatchCallbackConfig",
|
100
|
+
"actsave",
|
101
|
+
"base",
|
102
|
+
"checkpoint",
|
103
|
+
"debug_flag",
|
104
|
+
"directory_setup",
|
105
|
+
"early_stopping",
|
106
|
+
"ema",
|
107
|
+
"finite_checks",
|
108
|
+
"gradient_skipping",
|
109
|
+
"log_epoch",
|
110
|
+
"lr_monitor",
|
111
|
+
"norm_logging",
|
112
|
+
"print_table",
|
113
|
+
"rlp_sanity_checks",
|
114
|
+
"shared_parameters",
|
115
|
+
"timer",
|
116
|
+
"wandb_upload_code",
|
117
|
+
"wandb_watch",
|
118
|
+
]
|
@@ -26,3 +26,17 @@ from . import _base as _base
|
|
26
26
|
from . import best_checkpoint as best_checkpoint
|
27
27
|
from . import last_checkpoint as last_checkpoint
|
28
28
|
from . import on_exception_checkpoint as on_exception_checkpoint
|
29
|
+
|
30
|
+
__all__ = [
|
31
|
+
"BaseCheckpointCallbackConfig",
|
32
|
+
"BestCheckpointCallbackConfig",
|
33
|
+
"CallbackConfigBase",
|
34
|
+
"CheckpointMetadata",
|
35
|
+
"LastCheckpointCallbackConfig",
|
36
|
+
"MetricConfig",
|
37
|
+
"OnExceptionCheckpointCallbackConfig",
|
38
|
+
"_base",
|
39
|
+
"best_checkpoint",
|
40
|
+
"last_checkpoint",
|
41
|
+
"on_exception_checkpoint",
|
42
|
+
]
|
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint._base import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint._base import (
|
12
12
|
CheckpointMetadata as CheckpointMetadata,
|
13
13
|
)
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
"BaseCheckpointCallbackConfig",
|
17
|
+
"CallbackConfigBase",
|
18
|
+
"CheckpointMetadata",
|
19
|
+
]
|
@@ -12,3 +12,10 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
|
|
12
12
|
CheckpointMetadata as CheckpointMetadata,
|
13
13
|
)
|
14
14
|
from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
|
15
|
+
|
16
|
+
__all__ = [
|
17
|
+
"BaseCheckpointCallbackConfig",
|
18
|
+
"BestCheckpointCallbackConfig",
|
19
|
+
"CheckpointMetadata",
|
20
|
+
"MetricConfig",
|
21
|
+
]
|
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
|
|
11
11
|
from nshtrainer.callbacks.checkpoint.last_checkpoint import (
|
12
12
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
13
13
|
)
|
14
|
+
|
15
|
+
__all__ = [
|
16
|
+
"BaseCheckpointCallbackConfig",
|
17
|
+
"CheckpointMetadata",
|
18
|
+
"LastCheckpointCallbackConfig",
|
19
|
+
]
|
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
|
|
8
8
|
from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
|
9
9
|
OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
|
10
10
|
)
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"CallbackConfigBase",
|
14
|
+
"OnExceptionCheckpointCallbackConfig",
|
15
|
+
]
|
@@ -7,3 +7,9 @@ from nshtrainer.callbacks.early_stopping import (
|
|
7
7
|
EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
|
8
8
|
)
|
9
9
|
from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
|
10
|
+
|
11
|
+
__all__ = [
|
12
|
+
"CallbackConfigBase",
|
13
|
+
"EarlyStoppingCallbackConfig",
|
14
|
+
"MetricConfig",
|
15
|
+
]
|
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.finite_checks import CallbackConfigBase as CallbackCon
|
|
6
6
|
from nshtrainer.callbacks.finite_checks import (
|
7
7
|
FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
|
8
8
|
)
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"CallbackConfigBase",
|
12
|
+
"FiniteChecksCallbackConfig",
|
13
|
+
]
|
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfig
|
|
6
6
|
from nshtrainer.callbacks.lr_monitor import (
|
7
7
|
LearningRateMonitorConfig as LearningRateMonitorConfig,
|
8
8
|
)
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"CallbackConfigBase",
|
12
|
+
"LearningRateMonitorConfig",
|
13
|
+
]
|
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.norm_logging import CallbackConfigBase as CallbackConf
|
|
6
6
|
from nshtrainer.callbacks.norm_logging import (
|
7
7
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
8
8
|
)
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"CallbackConfigBase",
|
12
|
+
"NormLoggingCallbackConfig",
|
13
|
+
]
|
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.print_table import CallbackConfigBase as CallbackConfi
|
|
6
6
|
from nshtrainer.callbacks.print_table import (
|
7
7
|
PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
|
8
8
|
)
|
9
|
+
|
10
|
+
__all__ = [
|
11
|
+
"CallbackConfigBase",
|
12
|
+
"PrintTableMetricsCallbackConfig",
|
13
|
+
]
|