nshtrainer 1.0.0b25__py3-none-any.whl → 1.0.0b27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85):
  1. nshtrainer/.nshconfig.generated.json +6 -0
  2. nshtrainer/_checkpoint/metadata.py +1 -1
  3. nshtrainer/callbacks/__init__.py +3 -0
  4. nshtrainer/callbacks/actsave.py +2 -2
  5. nshtrainer/callbacks/base.py +5 -3
  6. nshtrainer/callbacks/checkpoint/__init__.py +4 -0
  7. nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -2
  8. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -2
  9. nshtrainer/callbacks/checkpoint/time_checkpoint.py +114 -0
  10. nshtrainer/callbacks/print_table.py +2 -2
  11. nshtrainer/callbacks/shared_parameters.py +5 -3
  12. nshtrainer/configs/__init__.py +99 -10
  13. nshtrainer/configs/_checkpoint/__init__.py +6 -0
  14. nshtrainer/configs/_checkpoint/metadata/__init__.py +5 -0
  15. nshtrainer/configs/_directory/__init__.py +5 -1
  16. nshtrainer/configs/_hf_hub/__init__.py +6 -0
  17. nshtrainer/configs/callbacks/__init__.py +48 -1
  18. nshtrainer/configs/callbacks/actsave/__init__.py +5 -0
  19. nshtrainer/configs/callbacks/base/__init__.py +4 -0
  20. nshtrainer/configs/callbacks/checkpoint/__init__.py +20 -0
  21. nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +6 -0
  22. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +7 -0
  23. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +6 -0
  24. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +5 -0
  25. nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +19 -0
  26. nshtrainer/configs/callbacks/debug_flag/__init__.py +5 -0
  27. nshtrainer/configs/callbacks/directory_setup/__init__.py +5 -0
  28. nshtrainer/configs/callbacks/early_stopping/__init__.py +6 -0
  29. nshtrainer/configs/callbacks/ema/__init__.py +5 -0
  30. nshtrainer/configs/callbacks/finite_checks/__init__.py +5 -0
  31. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +5 -0
  32. nshtrainer/configs/callbacks/log_epoch/__init__.py +5 -0
  33. nshtrainer/configs/callbacks/lr_monitor/__init__.py +5 -0
  34. nshtrainer/configs/callbacks/norm_logging/__init__.py +5 -0
  35. nshtrainer/configs/callbacks/print_table/__init__.py +5 -0
  36. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +5 -0
  37. nshtrainer/configs/callbacks/shared_parameters/__init__.py +5 -0
  38. nshtrainer/configs/callbacks/timer/__init__.py +5 -0
  39. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +5 -0
  40. nshtrainer/configs/callbacks/wandb_watch/__init__.py +5 -0
  41. nshtrainer/configs/loggers/__init__.py +16 -1
  42. nshtrainer/configs/loggers/_base/__init__.py +4 -0
  43. nshtrainer/configs/loggers/actsave/__init__.py +5 -0
  44. nshtrainer/configs/loggers/csv/__init__.py +5 -0
  45. nshtrainer/configs/loggers/tensorboard/__init__.py +5 -0
  46. nshtrainer/configs/loggers/wandb/__init__.py +8 -0
  47. nshtrainer/configs/lr_scheduler/__init__.py +10 -4
  48. nshtrainer/configs/lr_scheduler/_base/__init__.py +4 -0
  49. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +5 -3
  50. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -0
  51. nshtrainer/configs/metrics/__init__.py +5 -0
  52. nshtrainer/configs/metrics/_config/__init__.py +4 -0
  53. nshtrainer/configs/nn/__init__.py +21 -1
  54. nshtrainer/configs/nn/mlp/__init__.py +5 -1
  55. nshtrainer/configs/nn/nonlinearity/__init__.py +18 -1
  56. nshtrainer/configs/optimizer/__init__.py +5 -1
  57. nshtrainer/configs/profiler/__init__.py +11 -1
  58. nshtrainer/configs/profiler/_base/__init__.py +4 -0
  59. nshtrainer/configs/profiler/advanced/__init__.py +5 -0
  60. nshtrainer/configs/profiler/pytorch/__init__.py +5 -0
  61. nshtrainer/configs/profiler/simple/__init__.py +5 -0
  62. nshtrainer/configs/trainer/__init__.py +39 -6
  63. nshtrainer/configs/trainer/_config/__init__.py +37 -6
  64. nshtrainer/configs/trainer/trainer/__init__.py +9 -0
  65. nshtrainer/configs/util/__init__.py +19 -1
  66. nshtrainer/configs/util/_environment_info/__init__.py +14 -0
  67. nshtrainer/configs/util/config/__init__.py +8 -1
  68. nshtrainer/configs/util/config/dtype/__init__.py +4 -0
  69. nshtrainer/configs/util/config/duration/__init__.py +5 -1
  70. nshtrainer/loggers/__init__.py +12 -5
  71. nshtrainer/lr_scheduler/__init__.py +9 -5
  72. nshtrainer/model/mixins/callback.py +6 -4
  73. nshtrainer/optimizer.py +5 -3
  74. nshtrainer/profiler/__init__.py +9 -5
  75. nshtrainer/trainer/_config.py +85 -61
  76. nshtrainer/trainer/_runtime_callback.py +3 -3
  77. nshtrainer/trainer/signal_connector.py +6 -4
  78. nshtrainer/trainer/trainer.py +4 -4
  79. nshtrainer/util/_useful_types.py +11 -2
  80. nshtrainer/util/config/dtype.py +46 -43
  81. nshtrainer/util/path.py +3 -2
  82. {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/METADATA +2 -1
  83. nshtrainer-1.0.0b27.dist-info/RECORD +143 -0
  84. {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b27.dist-info}/WHEEL +1 -1
  85. nshtrainer-1.0.0b25.dist-info/RECORD +0 -140
@@ -0,0 +1,6 @@
1
+ {
2
+ "module": "nshtrainer",
3
+ "output": "configs",
4
+ "typed_dicts": null,
5
+ "json_schemas": null
6
+ }
@@ -55,7 +55,7 @@ class CheckpointMetadata(C.Config):
55
55
  metrics: dict[str, Any]
56
56
  environment: EnvironmentConfig
57
57
 
58
- hparams: Any
58
+ hparams: Any | None
59
59
 
60
60
  @classmethod
61
61
  def from_file(cls, path: Path):
@@ -14,6 +14,8 @@ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCa
14
14
  from .checkpoint import (
15
15
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
16
16
  )
17
+ from .checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
18
+ from .checkpoint import TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig
17
19
  from .debug_flag import DebugFlagCallback as DebugFlagCallback
18
20
  from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
19
21
  from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
@@ -71,6 +73,7 @@ CallbackConfig = Annotated[
71
73
  | BestCheckpointCallbackConfig
72
74
  | LastCheckpointCallbackConfig
73
75
  | OnExceptionCheckpointCallbackConfig
76
+ | TimeCheckpointCallbackConfig
74
77
  | SharedParametersCallbackConfig
75
78
  | RLPSanityChecksCallbackConfig
76
79
  | WandbWatchCallbackConfig
@@ -4,12 +4,12 @@ import contextlib
4
4
  from pathlib import Path
5
5
  from typing import Literal
6
6
 
7
- from typing_extensions import TypeAlias, override
7
+ from typing_extensions import TypeAliasType, override
8
8
 
9
9
  from .._callback import NTCallbackBase
10
10
  from .base import CallbackConfigBase
11
11
 
12
- Stage: TypeAlias = Literal["train", "validation", "test", "predict"]
12
+ Stage = TypeAliasType("Stage", Literal["train", "validation", "test", "predict"])
13
13
 
14
14
 
15
15
  class ActSaveConfig(CallbackConfigBase):
@@ -4,11 +4,11 @@ from abc import ABC, abstractmethod
4
4
  from collections import Counter
5
5
  from collections.abc import Iterable
6
6
  from dataclasses import dataclass
7
- from typing import TYPE_CHECKING, ClassVar, TypeAlias
7
+ from typing import TYPE_CHECKING, ClassVar
8
8
 
9
9
  import nshconfig as C
10
10
  from lightning.pytorch import Callback
11
- from typing_extensions import TypedDict, Unpack
11
+ from typing_extensions import TypeAliasType, TypedDict, Unpack
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from ..trainer._config import TrainerConfig
@@ -30,7 +30,9 @@ class CallbackWithMetadata:
30
30
  metadata: CallbackMetadataConfig
31
31
 
32
32
 
33
- ConstructedCallback: TypeAlias = Callback | CallbackWithMetadata
33
+ ConstructedCallback = TypeAliasType(
34
+ "ConstructedCallback", Callback | CallbackWithMetadata
35
+ )
34
36
 
35
37
 
36
38
  class CallbackConfigBase(C.Config, ABC):
@@ -14,3 +14,7 @@ from .on_exception_checkpoint import (
14
14
  from .on_exception_checkpoint import (
15
15
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
16
16
  )
17
+ from .time_checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
18
+ from .time_checkpoint import (
19
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
20
+ )
@@ -7,8 +7,7 @@ from typing import Literal
7
7
  from lightning.pytorch import LightningModule, Trainer
8
8
  from typing_extensions import final, override
9
9
 
10
- from nshtrainer._checkpoint.metadata import CheckpointMetadata
11
-
10
+ from ..._checkpoint.metadata import CheckpointMetadata
12
11
  from ...metrics._config import MetricConfig
13
12
  from ._base import BaseCheckpointCallbackConfig, CheckpointBase
14
13
 
@@ -6,8 +6,7 @@ from typing import Literal
6
6
  from lightning.pytorch import LightningModule, Trainer
7
7
  from typing_extensions import final, override
8
8
 
9
- from nshtrainer._checkpoint.metadata import CheckpointMetadata
10
-
9
+ from ..._checkpoint.metadata import CheckpointMetadata
11
10
  from ._base import BaseCheckpointCallbackConfig, CheckpointBase
12
11
 
13
12
  log = logging.getLogger(__name__)
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import time
5
+ from datetime import timedelta
6
+ from pathlib import Path
7
+ from typing import Any, Literal
8
+
9
+ from lightning.pytorch import LightningModule, Trainer
10
+ from typing_extensions import final, override
11
+
12
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata
13
+
14
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
@final
class TimeCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
    """Configuration for a checkpoint callback driven by elapsed wall-clock time.

    Instances act as a factory: ``create_checkpoint`` returns the
    ``TimeCheckpointCallback`` that reads ``interval`` from this config.
    """

    name: Literal["time_checkpoint"] = "time_checkpoint"

    interval: timedelta = timedelta(hours=12)
    """Time interval between checkpoints."""

    @override
    def create_checkpoint(self, trainer_config, dirpath):
        """Construct the time-based checkpoint callback for ``dirpath``.

        Args:
            trainer_config: Part of the overridden signature; not used here.
            dirpath: Directory handed to the callback unchanged.

        Returns:
            A ``TimeCheckpointCallback`` bound to this configuration.
        """
        return TimeCheckpointCallback(config=self, dirpath=dirpath)
29
+
30
+
31
@final
class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
    """Save a checkpoint whenever ``config.interval`` of wall-clock time has passed.

    The callback tracks two timestamps (seconds since the epoch, from
    ``time.time()``):

    * ``start_time`` -- when timing began; used to compute ``train_duration``.
    * ``last_checkpoint_time`` -- when the most recent checkpoint was taken.

    The actual saving is delegated to ``save_checkpoints`` on the base class,
    which is defined outside this file.
    """

    def __init__(self, config: TimeCheckpointCallbackConfig, dirpath: Path):
        """Start the timers and cache the interval in seconds.

        Args:
            config: Configuration providing ``interval``.
            dirpath: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(config, dirpath)
        self.start_time = time.time()
        self.last_checkpoint_time = self.start_time
        self.interval_seconds = config.interval.total_seconds()

    @override
    def name(self):
        """Return the short identifier of this checkpoint kind."""
        return "time"

    @override
    def default_filename(self):
        """Return the filename template.

        ``{train_duration}`` is supplied by ``current_metrics`` below.
        """
        return "epoch{epoch}-step{step}-duration{train_duration}"

    @override
    def topk_sort_key(self, metadata: CheckpointMetadata):
        """Order checkpoints by the timestamp recorded in their metadata."""
        return metadata.checkpoint_timestamp

    @override
    def topk_sort_reverse(self):
        """Sort in descending key order (largest timestamp first)."""
        return True

    def _should_checkpoint(self) -> bool:
        """Return True once at least ``interval_seconds`` passed since the last save."""
        return time.time() - self.last_checkpoint_time >= self.interval_seconds

    def _format_duration(self, seconds: float) -> str:
        """Format a duration in seconds as e.g. ``1d_2h_3m_4s``.

        Zero-valued units are omitted; a duration below one second yields
        ``"0s"``. Fractions are truncated. Negative input (possible if the
        clock moved backwards or a restored timestamp lies in the future) is
        clamped to zero instead of producing a wrapped-around value.

        Args:
            seconds: Duration in seconds.

        Returns:
            The underscore-joined, human-readable duration.
        """
        total = max(int(seconds), 0)
        days, remainder = divmod(total, 86400)
        hours, remainder = divmod(remainder, 3600)
        minutes, secs = divmod(remainder, 60)

        parts = []
        if days > 0:
            parts.append(f"{days}d")
        if hours > 0:
            parts.append(f"{hours}h")
        if minutes > 0:
            parts.append(f"{minutes}m")
        if secs > 0 or not parts:
            parts.append(f"{secs}s")

        return "_".join(parts)

    @override
    def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
        """Extend the base metrics with a formatted ``train_duration`` entry."""
        metrics = super().current_metrics(trainer)
        metrics["train_duration"] = self._format_duration(
            time.time() - self.start_time
        )
        return metrics

    @override
    def on_train_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
    ):
        """Save checkpoints if the interval has elapsed, then restart the interval."""
        # Clocks differ slightly between ranks, so near the boundary some ranks
        # could decide to save while others do not. Use rank 0's decision
        # everywhere; Lightning's own time-based ModelCheckpoint does the same.
        # NOTE(review): assumes save_checkpoints must run on all ranks together.
        if not trainer.strategy.broadcast(self._should_checkpoint(), src=0):
            return

        # Stamp the time *before* saving: callback state is presumably captured
        # while save_checkpoints runs, and it should describe this checkpoint
        # rather than the previous one.
        previous_checkpoint_time = self.last_checkpoint_time
        self.last_checkpoint_time = time.time()
        try:
            self.save_checkpoints(trainer)
        except Exception:
            # Keep the old stamp so a caller that survives the error retries
            # on the next batch, as before.
            self.last_checkpoint_time = previous_checkpoint_time
            raise
        # Measure the next interval from the end of the save.
        self.last_checkpoint_time = time.time()

    @override
    def state_dict(self) -> dict[str, Any]:
        """Save the timer state for checkpoint resumption.

        Returns:
            Dictionary with the absolute ``start_time`` and
            ``last_checkpoint_time`` (kept for compatibility with older
            readers) plus the same information as durations relative to now
            (``elapsed_seconds``, ``seconds_since_last_checkpoint``).
        """
        now = time.time()
        return {
            "start_time": self.start_time,
            "last_checkpoint_time": self.last_checkpoint_time,
            "elapsed_seconds": now - self.start_time,
            "seconds_since_last_checkpoint": now - self.last_checkpoint_time,
        }

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the timer state when resuming from a checkpoint.

        When the relative durations are present, both timers are rebased onto
        the current clock so that time spent between saving and resuming is
        not counted. State dicts that only carry the absolute timestamps are
        restored verbatim.

        Args:
            state_dict: Dictionary containing the previously saved timer state.
        """
        elapsed = state_dict.get("elapsed_seconds")
        since_last = state_dict.get("seconds_since_last_checkpoint")
        if elapsed is None or since_last is None:
            # Older format: absolute timestamps only.
            self.start_time = state_dict["start_time"]
            self.last_checkpoint_time = state_dict["last_checkpoint_time"]
            return

        now = time.time()
        self.start_time = now - max(float(elapsed), 0.0)
        self.last_checkpoint_time = now - max(float(since_last), 0.0)
@@ -49,14 +49,14 @@ class PrintTableMetricsCallback(Callback):
49
49
  }
50
50
  self.metrics.append(metrics_dict)
51
51
 
52
- from rich.console import Console
52
+ from rich.console import Console # type: ignore[reportMissingImports] # noqa
53
53
 
54
54
  console = Console()
55
55
  table = self.create_metrics_table()
56
56
  console.print(table)
57
57
 
58
58
  def create_metrics_table(self):
59
- from rich.table import Table
59
+ from rich.table import Table # type: ignore[reportMissingImports] # noqa
60
60
 
61
61
  table = Table(show_header=True, header_style="bold magenta")
62
62
 
@@ -2,12 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from collections.abc import Iterable
5
- from typing import Literal, Protocol, TypeAlias, runtime_checkable
5
+ from typing import Literal, Protocol, runtime_checkable
6
6
 
7
7
  import torch.nn as nn
8
8
  from lightning.pytorch import LightningModule, Trainer
9
9
  from lightning.pytorch.callbacks import Callback
10
- from typing_extensions import override
10
+ from typing_extensions import TypeAliasType, override
11
11
 
12
12
  from .base import CallbackConfigBase
13
13
 
@@ -34,7 +34,9 @@ class SharedParametersCallbackConfig(CallbackConfigBase):
34
34
  yield SharedParametersCallback(self)
35
35
 
36
36
 
37
- SharedParametersList: TypeAlias = list[tuple[nn.Parameter, int | float]]
37
+ SharedParametersList = TypeAliasType(
38
+ "SharedParametersList", list[tuple[nn.Parameter, int | float]]
39
+ )
38
40
 
39
41
 
40
42
  @runtime_checkable
@@ -14,7 +14,6 @@ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
14
14
  from nshtrainer.callbacks import (
15
15
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
16
16
  )
17
- from nshtrainer.callbacks import CallbackConfig as CallbackConfig
18
17
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
19
18
  from nshtrainer.callbacks import (
20
19
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
@@ -47,6 +46,9 @@ from nshtrainer.callbacks import (
47
46
  from nshtrainer.callbacks import (
48
47
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
49
48
  )
49
+ from nshtrainer.callbacks import (
50
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
51
+ )
50
52
  from nshtrainer.callbacks import (
51
53
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
52
54
  )
@@ -58,13 +60,11 @@ from nshtrainer.callbacks.checkpoint._base import (
58
60
  from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
59
61
  from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
60
62
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
61
- from nshtrainer.loggers import LoggerConfig as LoggerConfig
62
63
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
63
64
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
64
65
  from nshtrainer.lr_scheduler import (
65
66
  LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
66
67
  )
67
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
68
68
  from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
69
69
  from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
70
70
  from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
@@ -73,7 +73,6 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
73
73
  from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
74
74
  from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
75
75
  from nshtrainer.nn import MLPConfig as MLPConfig
76
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
77
76
  from nshtrainer.nn import PReLUConfig as PReLUConfig
78
77
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
79
78
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -87,23 +86,21 @@ from nshtrainer.nn.nonlinearity import (
87
86
  SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
88
87
  )
89
88
  from nshtrainer.optimizer import AdamWConfig as AdamWConfig
90
- from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
91
89
  from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
92
90
  from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
93
91
  from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
94
- from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
95
92
  from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
96
93
  from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
97
- from nshtrainer.trainer._config import (
98
- CheckpointCallbackConfig as CheckpointCallbackConfig,
99
- )
94
+ from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
100
95
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
101
96
  from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
102
97
  from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
103
98
  from nshtrainer.trainer._config import (
104
99
  LearningRateMonitorConfig as LearningRateMonitorConfig,
105
100
  )
101
+ from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
106
102
  from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
103
+ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
107
104
  from nshtrainer.util._environment_info import (
108
105
  EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
109
106
  )
@@ -133,7 +130,6 @@ from nshtrainer.util._environment_info import (
133
130
  )
134
131
  from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
135
132
  from nshtrainer.util.config import DTypeConfig as DTypeConfig
136
- from nshtrainer.util.config import DurationConfig as DurationConfig
137
133
  from nshtrainer.util.config import EpochsConfig as EpochsConfig
138
134
  from nshtrainer.util.config import StepsConfig as StepsConfig
139
135
 
@@ -149,3 +145,96 @@ from . import optimizer as optimizer
149
145
  from . import profiler as profiler
150
146
  from . import trainer as trainer
151
147
  from . import util as util
148
+
149
+ __all__ = [
150
+ "AcceleratorConfigBase",
151
+ "ActSaveConfig",
152
+ "ActSaveLoggerConfig",
153
+ "AdamWConfig",
154
+ "AdvancedProfilerConfig",
155
+ "BaseCheckpointCallbackConfig",
156
+ "BaseLoggerConfig",
157
+ "BaseNonlinearityConfig",
158
+ "BaseProfilerConfig",
159
+ "BestCheckpointCallbackConfig",
160
+ "CSVLoggerConfig",
161
+ "CallbackConfigBase",
162
+ "CheckpointMetadata",
163
+ "CheckpointSavingConfig",
164
+ "DTypeConfig",
165
+ "DebugFlagCallbackConfig",
166
+ "DirectoryConfig",
167
+ "DirectorySetupCallbackConfig",
168
+ "ELUNonlinearityConfig",
169
+ "EMACallbackConfig",
170
+ "EarlyStoppingCallbackConfig",
171
+ "EnvironmentCUDAConfig",
172
+ "EnvironmentClassInformationConfig",
173
+ "EnvironmentConfig",
174
+ "EnvironmentGPUConfig",
175
+ "EnvironmentHardwareConfig",
176
+ "EnvironmentLSFInformationConfig",
177
+ "EnvironmentLinuxEnvironmentConfig",
178
+ "EnvironmentPackageConfig",
179
+ "EnvironmentSLURMInformationConfig",
180
+ "EnvironmentSnapshotConfig",
181
+ "EpochTimerCallbackConfig",
182
+ "EpochsConfig",
183
+ "FiniteChecksCallbackConfig",
184
+ "GELUNonlinearityConfig",
185
+ "GitRepositoryConfig",
186
+ "GradientClippingConfig",
187
+ "GradientSkippingCallbackConfig",
188
+ "HuggingFaceHubAutoCreateConfig",
189
+ "HuggingFaceHubConfig",
190
+ "LRSchedulerConfigBase",
191
+ "LastCheckpointCallbackConfig",
192
+ "LeakyReLUNonlinearityConfig",
193
+ "LearningRateMonitorConfig",
194
+ "LinearWarmupCosineDecayLRSchedulerConfig",
195
+ "LogEpochCallbackConfig",
196
+ "MLPConfig",
197
+ "MetricConfig",
198
+ "MishNonlinearityConfig",
199
+ "NormLoggingCallbackConfig",
200
+ "OnExceptionCheckpointCallbackConfig",
201
+ "OptimizerConfigBase",
202
+ "PReLUConfig",
203
+ "PluginConfigBase",
204
+ "PrintTableMetricsCallbackConfig",
205
+ "PyTorchProfilerConfig",
206
+ "RLPSanityChecksCallbackConfig",
207
+ "ReLUNonlinearityConfig",
208
+ "ReduceLROnPlateauConfig",
209
+ "SanityCheckingConfig",
210
+ "SharedParametersCallbackConfig",
211
+ "SiLUNonlinearityConfig",
212
+ "SigmoidNonlinearityConfig",
213
+ "SimpleProfilerConfig",
214
+ "SoftmaxNonlinearityConfig",
215
+ "SoftplusNonlinearityConfig",
216
+ "SoftsignNonlinearityConfig",
217
+ "StepsConfig",
218
+ "StrategyConfigBase",
219
+ "SwiGLUNonlinearityConfig",
220
+ "SwishNonlinearityConfig",
221
+ "TanhNonlinearityConfig",
222
+ "TensorboardLoggerConfig",
223
+ "TimeCheckpointCallbackConfig",
224
+ "TrainerConfig",
225
+ "WandbLoggerConfig",
226
+ "WandbUploadCodeCallbackConfig",
227
+ "WandbWatchCallbackConfig",
228
+ "_checkpoint",
229
+ "_directory",
230
+ "_hf_hub",
231
+ "callbacks",
232
+ "loggers",
233
+ "lr_scheduler",
234
+ "metrics",
235
+ "nn",
236
+ "optimizer",
237
+ "profiler",
238
+ "trainer",
239
+ "util",
240
+ ]
@@ -6,3 +6,9 @@ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMeta
6
6
  from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
7
7
 
8
8
  from . import metadata as metadata
9
+
10
+ __all__ = [
11
+ "CheckpointMetadata",
12
+ "EnvironmentConfig",
13
+ "metadata",
14
+ ]
@@ -4,3 +4,8 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
6
6
  from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
7
+
8
+ __all__ = [
9
+ "CheckpointMetadata",
10
+ "EnvironmentConfig",
11
+ ]
@@ -6,4 +6,8 @@ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
6
6
  from nshtrainer._directory import (
7
7
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
8
8
  )
9
- from nshtrainer._directory import LoggerConfig as LoggerConfig
9
+
10
+ __all__ = [
11
+ "DirectoryConfig",
12
+ "DirectorySetupCallbackConfig",
13
+ ]
@@ -7,3 +7,9 @@ from nshtrainer._hf_hub import (
7
7
  HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
8
8
  )
9
9
  from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
10
+
11
+ __all__ = [
12
+ "CallbackConfigBase",
13
+ "HuggingFaceHubAutoCreateConfig",
14
+ "HuggingFaceHubConfig",
15
+ ]
@@ -5,7 +5,6 @@ __codegen__ = True
5
5
  from nshtrainer.callbacks import (
6
6
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
7
7
  )
8
- from nshtrainer.callbacks import CallbackConfig as CallbackConfig
9
8
  from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
10
9
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
11
10
  from nshtrainer.callbacks import (
@@ -39,6 +38,9 @@ from nshtrainer.callbacks import (
39
38
  from nshtrainer.callbacks import (
40
39
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
41
40
  )
41
+ from nshtrainer.callbacks import (
42
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
43
+ )
42
44
  from nshtrainer.callbacks import (
43
45
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
44
46
  )
@@ -73,3 +75,48 @@ from . import shared_parameters as shared_parameters
73
75
  from . import timer as timer
74
76
  from . import wandb_upload_code as wandb_upload_code
75
77
  from . import wandb_watch as wandb_watch
78
+
79
+ __all__ = [
80
+ "ActSaveConfig",
81
+ "BaseCheckpointCallbackConfig",
82
+ "BestCheckpointCallbackConfig",
83
+ "CallbackConfigBase",
84
+ "CheckpointMetadata",
85
+ "DebugFlagCallbackConfig",
86
+ "DirectorySetupCallbackConfig",
87
+ "EMACallbackConfig",
88
+ "EarlyStoppingCallbackConfig",
89
+ "EpochTimerCallbackConfig",
90
+ "FiniteChecksCallbackConfig",
91
+ "GradientSkippingCallbackConfig",
92
+ "LastCheckpointCallbackConfig",
93
+ "LearningRateMonitorConfig",
94
+ "LogEpochCallbackConfig",
95
+ "MetricConfig",
96
+ "NormLoggingCallbackConfig",
97
+ "OnExceptionCheckpointCallbackConfig",
98
+ "PrintTableMetricsCallbackConfig",
99
+ "RLPSanityChecksCallbackConfig",
100
+ "SharedParametersCallbackConfig",
101
+ "TimeCheckpointCallbackConfig",
102
+ "WandbUploadCodeCallbackConfig",
103
+ "WandbWatchCallbackConfig",
104
+ "actsave",
105
+ "base",
106
+ "checkpoint",
107
+ "debug_flag",
108
+ "directory_setup",
109
+ "early_stopping",
110
+ "ema",
111
+ "finite_checks",
112
+ "gradient_skipping",
113
+ "log_epoch",
114
+ "lr_monitor",
115
+ "norm_logging",
116
+ "print_table",
117
+ "rlp_sanity_checks",
118
+ "shared_parameters",
119
+ "timer",
120
+ "wandb_upload_code",
121
+ "wandb_watch",
122
+ ]
@@ -4,3 +4,8 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
6
6
  from nshtrainer.callbacks.actsave import CallbackConfigBase as CallbackConfigBase
7
+
8
+ __all__ = [
9
+ "ActSaveConfig",
10
+ "CallbackConfigBase",
11
+ ]
@@ -3,3 +3,7 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
6
+
7
+ __all__ = [
8
+ "CallbackConfigBase",
9
+ ]
@@ -11,6 +11,9 @@ from nshtrainer.callbacks.checkpoint import (
11
11
  from nshtrainer.callbacks.checkpoint import (
12
12
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
13
13
  )
14
+ from nshtrainer.callbacks.checkpoint import (
15
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
16
+ )
14
17
  from nshtrainer.callbacks.checkpoint._base import (
15
18
  BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
16
19
  )
@@ -26,3 +29,20 @@ from . import _base as _base
26
29
  from . import best_checkpoint as best_checkpoint
27
30
  from . import last_checkpoint as last_checkpoint
28
31
  from . import on_exception_checkpoint as on_exception_checkpoint
32
+ from . import time_checkpoint as time_checkpoint
33
+
34
+ __all__ = [
35
+ "BaseCheckpointCallbackConfig",
36
+ "BestCheckpointCallbackConfig",
37
+ "CallbackConfigBase",
38
+ "CheckpointMetadata",
39
+ "LastCheckpointCallbackConfig",
40
+ "MetricConfig",
41
+ "OnExceptionCheckpointCallbackConfig",
42
+ "TimeCheckpointCallbackConfig",
43
+ "_base",
44
+ "best_checkpoint",
45
+ "last_checkpoint",
46
+ "on_exception_checkpoint",
47
+ "time_checkpoint",
48
+ ]
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint._base import (
11
11
  from nshtrainer.callbacks.checkpoint._base import (
12
12
  CheckpointMetadata as CheckpointMetadata,
13
13
  )
14
+
15
+ __all__ = [
16
+ "BaseCheckpointCallbackConfig",
17
+ "CallbackConfigBase",
18
+ "CheckpointMetadata",
19
+ ]
@@ -12,3 +12,10 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
12
12
  CheckpointMetadata as CheckpointMetadata,
13
13
  )
14
14
  from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
15
+
16
+ __all__ = [
17
+ "BaseCheckpointCallbackConfig",
18
+ "BestCheckpointCallbackConfig",
19
+ "CheckpointMetadata",
20
+ "MetricConfig",
21
+ ]
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
11
11
  from nshtrainer.callbacks.checkpoint.last_checkpoint import (
12
12
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
13
13
  )
14
+
15
+ __all__ = [
16
+ "BaseCheckpointCallbackConfig",
17
+ "CheckpointMetadata",
18
+ "LastCheckpointCallbackConfig",
19
+ ]
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
8
8
  from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
9
9
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
10
10
  )
11
+
12
+ __all__ = [
13
+ "CallbackConfigBase",
14
+ "OnExceptionCheckpointCallbackConfig",
15
+ ]
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from nshtrainer.callbacks.checkpoint.time_checkpoint import (
6
+ BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
7
+ )
8
+ from nshtrainer.callbacks.checkpoint.time_checkpoint import (
9
+ CheckpointMetadata as CheckpointMetadata,
10
+ )
11
+ from nshtrainer.callbacks.checkpoint.time_checkpoint import (
12
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
13
+ )
14
+
15
+ __all__ = [
16
+ "BaseCheckpointCallbackConfig",
17
+ "CheckpointMetadata",
18
+ "TimeCheckpointCallbackConfig",
19
+ ]
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.debug_flag import CallbackConfigBase as CallbackConfig
6
6
  from nshtrainer.callbacks.debug_flag import (
7
7
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
8
8
  )
9
+
10
+ __all__ = [
11
+ "CallbackConfigBase",
12
+ "DebugFlagCallbackConfig",
13
+ ]