nshtrainer 1.0.0b25__py3-none-any.whl → 1.0.0b26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. nshtrainer/.nshconfig.generated.json +6 -0
  2. nshtrainer/_checkpoint/metadata.py +1 -1
  3. nshtrainer/callbacks/__init__.py +3 -0
  4. nshtrainer/callbacks/checkpoint/__init__.py +4 -0
  5. nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -2
  6. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -2
  7. nshtrainer/callbacks/checkpoint/time_checkpoint.py +114 -0
  8. nshtrainer/callbacks/print_table.py +2 -2
  9. nshtrainer/configs/__init__.py +95 -10
  10. nshtrainer/configs/_checkpoint/__init__.py +6 -0
  11. nshtrainer/configs/_checkpoint/metadata/__init__.py +5 -0
  12. nshtrainer/configs/_directory/__init__.py +5 -1
  13. nshtrainer/configs/_hf_hub/__init__.py +6 -0
  14. nshtrainer/configs/callbacks/__init__.py +44 -1
  15. nshtrainer/configs/callbacks/actsave/__init__.py +5 -0
  16. nshtrainer/configs/callbacks/base/__init__.py +4 -0
  17. nshtrainer/configs/callbacks/checkpoint/__init__.py +14 -0
  18. nshtrainer/configs/callbacks/checkpoint/_base/__init__.py +6 -0
  19. nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +7 -0
  20. nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +6 -0
  21. nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +5 -0
  22. nshtrainer/configs/callbacks/debug_flag/__init__.py +5 -0
  23. nshtrainer/configs/callbacks/directory_setup/__init__.py +5 -0
  24. nshtrainer/configs/callbacks/early_stopping/__init__.py +6 -0
  25. nshtrainer/configs/callbacks/ema/__init__.py +5 -0
  26. nshtrainer/configs/callbacks/finite_checks/__init__.py +5 -0
  27. nshtrainer/configs/callbacks/gradient_skipping/__init__.py +5 -0
  28. nshtrainer/configs/callbacks/log_epoch/__init__.py +5 -0
  29. nshtrainer/configs/callbacks/lr_monitor/__init__.py +5 -0
  30. nshtrainer/configs/callbacks/norm_logging/__init__.py +5 -0
  31. nshtrainer/configs/callbacks/print_table/__init__.py +5 -0
  32. nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +5 -0
  33. nshtrainer/configs/callbacks/shared_parameters/__init__.py +5 -0
  34. nshtrainer/configs/callbacks/timer/__init__.py +5 -0
  35. nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +5 -0
  36. nshtrainer/configs/callbacks/wandb_watch/__init__.py +5 -0
  37. nshtrainer/configs/loggers/__init__.py +16 -1
  38. nshtrainer/configs/loggers/_base/__init__.py +4 -0
  39. nshtrainer/configs/loggers/actsave/__init__.py +5 -0
  40. nshtrainer/configs/loggers/csv/__init__.py +5 -0
  41. nshtrainer/configs/loggers/tensorboard/__init__.py +5 -0
  42. nshtrainer/configs/loggers/wandb/__init__.py +8 -0
  43. nshtrainer/configs/lr_scheduler/__init__.py +10 -4
  44. nshtrainer/configs/lr_scheduler/_base/__init__.py +4 -0
  45. nshtrainer/configs/lr_scheduler/linear_warmup_cosine/__init__.py +5 -3
  46. nshtrainer/configs/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -0
  47. nshtrainer/configs/metrics/__init__.py +5 -0
  48. nshtrainer/configs/metrics/_config/__init__.py +4 -0
  49. nshtrainer/configs/nn/__init__.py +21 -1
  50. nshtrainer/configs/nn/mlp/__init__.py +5 -1
  51. nshtrainer/configs/nn/nonlinearity/__init__.py +18 -1
  52. nshtrainer/configs/optimizer/__init__.py +5 -1
  53. nshtrainer/configs/profiler/__init__.py +11 -1
  54. nshtrainer/configs/profiler/_base/__init__.py +4 -0
  55. nshtrainer/configs/profiler/advanced/__init__.py +5 -0
  56. nshtrainer/configs/profiler/pytorch/__init__.py +5 -0
  57. nshtrainer/configs/profiler/simple/__init__.py +5 -0
  58. nshtrainer/configs/trainer/__init__.py +35 -6
  59. nshtrainer/configs/trainer/_config/__init__.py +33 -6
  60. nshtrainer/configs/trainer/trainer/__init__.py +9 -0
  61. nshtrainer/configs/util/__init__.py +19 -1
  62. nshtrainer/configs/util/_environment_info/__init__.py +14 -0
  63. nshtrainer/configs/util/config/__init__.py +8 -1
  64. nshtrainer/configs/util/config/dtype/__init__.py +4 -0
  65. nshtrainer/configs/util/config/duration/__init__.py +5 -1
  66. nshtrainer/trainer/_config.py +40 -21
  67. nshtrainer/trainer/trainer.py +4 -4
  68. {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b26.dist-info}/METADATA +2 -1
  69. {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b26.dist-info}/RECORD +70 -68
  70. {nshtrainer-1.0.0b25.dist-info → nshtrainer-1.0.0b26.dist-info}/WHEEL +1 -1
@@ -0,0 +1,6 @@
1
+ {
2
+ "module": "nshtrainer",
3
+ "output": "configs",
4
+ "typed_dicts": null,
5
+ "json_schemas": null
6
+ }
@@ -55,7 +55,7 @@ class CheckpointMetadata(C.Config):
55
55
  metrics: dict[str, Any]
56
56
  environment: EnvironmentConfig
57
57
 
58
- hparams: Any
58
+ hparams: Any | None
59
59
 
60
60
  @classmethod
61
61
  def from_file(cls, path: Path):
@@ -14,6 +14,8 @@ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCa
14
14
  from .checkpoint import (
15
15
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
16
16
  )
17
+ from .checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
18
+ from .checkpoint import TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig
17
19
  from .debug_flag import DebugFlagCallback as DebugFlagCallback
18
20
  from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
19
21
  from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
@@ -71,6 +73,7 @@ CallbackConfig = Annotated[
71
73
  | BestCheckpointCallbackConfig
72
74
  | LastCheckpointCallbackConfig
73
75
  | OnExceptionCheckpointCallbackConfig
76
+ | TimeCheckpointCallbackConfig
74
77
  | SharedParametersCallbackConfig
75
78
  | RLPSanityChecksCallbackConfig
76
79
  | WandbWatchCallbackConfig
@@ -14,3 +14,7 @@ from .on_exception_checkpoint import (
14
14
  from .on_exception_checkpoint import (
15
15
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
16
16
  )
17
+ from .time_checkpoint import TimeCheckpointCallback as TimeCheckpointCallback
18
+ from .time_checkpoint import (
19
+ TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
20
+ )
@@ -7,8 +7,7 @@ from typing import Literal
7
7
  from lightning.pytorch import LightningModule, Trainer
8
8
  from typing_extensions import final, override
9
9
 
10
- from nshtrainer._checkpoint.metadata import CheckpointMetadata
11
-
10
+ from ..._checkpoint.metadata import CheckpointMetadata
12
11
  from ...metrics._config import MetricConfig
13
12
  from ._base import BaseCheckpointCallbackConfig, CheckpointBase
14
13
 
@@ -6,8 +6,7 @@ from typing import Literal
6
6
  from lightning.pytorch import LightningModule, Trainer
7
7
  from typing_extensions import final, override
8
8
 
9
- from nshtrainer._checkpoint.metadata import CheckpointMetadata
10
-
9
+ from ..._checkpoint.metadata import CheckpointMetadata
11
10
  from ._base import BaseCheckpointCallbackConfig, CheckpointBase
12
11
 
13
12
  log = logging.getLogger(__name__)
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import time
5
+ from datetime import timedelta
6
+ from pathlib import Path
7
+ from typing import Any, Literal
8
+
9
+ from lightning.pytorch import LightningModule, Trainer
10
+ from typing_extensions import final, override
11
+
12
+ from nshtrainer._checkpoint.metadata import CheckpointMetadata
13
+
14
+ from ._base import BaseCheckpointCallbackConfig, CheckpointBase
15
+
16
+ log = logging.getLogger(__name__)
17
+
18
+
19
+ @final
20
+ class TimeCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
21
+ name: Literal["time_checkpoint"] = "time_checkpoint"
22
+
23
+ interval: timedelta = timedelta(hours=12)
24
+ """Time interval between checkpoints."""
25
+
26
+ @override
27
+ def create_checkpoint(self, trainer_config, dirpath):
28
+ return TimeCheckpointCallback(self, dirpath)
29
+
30
+
31
+ @final
32
+ class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
33
+ def __init__(self, config: TimeCheckpointCallbackConfig, dirpath: Path):
34
+ super().__init__(config, dirpath)
35
+ self.start_time = time.time()
36
+ self.last_checkpoint_time = self.start_time
37
+ self.interval_seconds = config.interval.total_seconds()
38
+
39
+ @override
40
+ def name(self):
41
+ return "time"
42
+
43
+ @override
44
+ def default_filename(self):
45
+ return "epoch{epoch}-step{step}-duration{train_duration}"
46
+
47
+ @override
48
+ def topk_sort_key(self, metadata: CheckpointMetadata):
49
+ return metadata.checkpoint_timestamp
50
+
51
+ @override
52
+ def topk_sort_reverse(self):
53
+ return True
54
+
55
+ def _should_checkpoint(self) -> bool:
56
+ current_time = time.time()
57
+ elapsed_time = current_time - self.last_checkpoint_time
58
+ return elapsed_time >= self.interval_seconds
59
+
60
+ def _format_duration(self, seconds: float) -> str:
61
+ """Format duration in seconds to a human-readable string."""
62
+ td = timedelta(seconds=int(seconds))
63
+ days = td.days
64
+ hours, remainder = divmod(td.seconds, 3600)
65
+ minutes, seconds = divmod(remainder, 60)
66
+
67
+ parts = []
68
+ if days > 0:
69
+ parts.append(f"{days}d")
70
+ if hours > 0:
71
+ parts.append(f"{hours}h")
72
+ if minutes > 0:
73
+ parts.append(f"{minutes}m")
74
+ if seconds > 0 or not parts:
75
+ parts.append(f"{seconds}s")
76
+
77
+ return "_".join(parts)
78
+
79
+ @override
80
+ def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
81
+ metrics = super().current_metrics(trainer)
82
+ train_duration = time.time() - self.start_time
83
+ metrics["train_duration"] = self._format_duration(train_duration)
84
+ return metrics
85
+
86
+ @override
87
+ def on_train_batch_end(
88
+ self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
89
+ ):
90
+ if self._should_checkpoint():
91
+ self.save_checkpoints(trainer)
92
+ self.last_checkpoint_time = time.time()
93
+
94
+ @override
95
+ def state_dict(self) -> dict[str, Any]:
96
+ """Save the timer state for checkpoint resumption.
97
+
98
+ Returns:
99
+ Dictionary containing the start time and last checkpoint time.
100
+ """
101
+ return {
102
+ "start_time": self.start_time,
103
+ "last_checkpoint_time": self.last_checkpoint_time,
104
+ }
105
+
106
+ @override
107
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
108
+ """Restore the timer state when resuming from a checkpoint.
109
+
110
+ Args:
111
+ state_dict: Dictionary containing the previously saved timer state.
112
+ """
113
+ self.start_time = state_dict["start_time"]
114
+ self.last_checkpoint_time = state_dict["last_checkpoint_time"]
@@ -49,14 +49,14 @@ class PrintTableMetricsCallback(Callback):
49
49
  }
50
50
  self.metrics.append(metrics_dict)
51
51
 
52
- from rich.console import Console
52
+ from rich.console import Console # type: ignore[reportMissingImports] # noqa
53
53
 
54
54
  console = Console()
55
55
  table = self.create_metrics_table()
56
56
  console.print(table)
57
57
 
58
58
  def create_metrics_table(self):
59
- from rich.table import Table
59
+ from rich.table import Table # type: ignore[reportMissingImports] # noqa
60
60
 
61
61
  table = Table(show_header=True, header_style="bold magenta")
62
62
 
@@ -14,7 +14,6 @@ from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
14
14
  from nshtrainer.callbacks import (
15
15
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
16
16
  )
17
- from nshtrainer.callbacks import CallbackConfig as CallbackConfig
18
17
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
19
18
  from nshtrainer.callbacks import (
20
19
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
@@ -58,13 +57,11 @@ from nshtrainer.callbacks.checkpoint._base import (
58
57
  from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
59
58
  from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
60
59
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
61
- from nshtrainer.loggers import LoggerConfig as LoggerConfig
62
60
  from nshtrainer.loggers import TensorboardLoggerConfig as TensorboardLoggerConfig
63
61
  from nshtrainer.loggers import WandbLoggerConfig as WandbLoggerConfig
64
62
  from nshtrainer.lr_scheduler import (
65
63
  LinearWarmupCosineDecayLRSchedulerConfig as LinearWarmupCosineDecayLRSchedulerConfig,
66
64
  )
67
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
68
65
  from nshtrainer.lr_scheduler import LRSchedulerConfigBase as LRSchedulerConfigBase
69
66
  from nshtrainer.lr_scheduler import ReduceLROnPlateauConfig as ReduceLROnPlateauConfig
70
67
  from nshtrainer.nn import BaseNonlinearityConfig as BaseNonlinearityConfig
@@ -73,7 +70,6 @@ from nshtrainer.nn import GELUNonlinearityConfig as GELUNonlinearityConfig
73
70
  from nshtrainer.nn import LeakyReLUNonlinearityConfig as LeakyReLUNonlinearityConfig
74
71
  from nshtrainer.nn import MishNonlinearityConfig as MishNonlinearityConfig
75
72
  from nshtrainer.nn import MLPConfig as MLPConfig
76
- from nshtrainer.nn import NonlinearityConfig as NonlinearityConfig
77
73
  from nshtrainer.nn import PReLUConfig as PReLUConfig
78
74
  from nshtrainer.nn import ReLUNonlinearityConfig as ReLUNonlinearityConfig
79
75
  from nshtrainer.nn import SigmoidNonlinearityConfig as SigmoidNonlinearityConfig
@@ -87,23 +83,21 @@ from nshtrainer.nn.nonlinearity import (
87
83
  SwiGLUNonlinearityConfig as SwiGLUNonlinearityConfig,
88
84
  )
89
85
  from nshtrainer.optimizer import AdamWConfig as AdamWConfig
90
- from nshtrainer.optimizer import OptimizerConfig as OptimizerConfig
91
86
  from nshtrainer.optimizer import OptimizerConfigBase as OptimizerConfigBase
92
87
  from nshtrainer.profiler import AdvancedProfilerConfig as AdvancedProfilerConfig
93
88
  from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
94
- from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
95
89
  from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
96
90
  from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
97
- from nshtrainer.trainer._config import (
98
- CheckpointCallbackConfig as CheckpointCallbackConfig,
99
- )
91
+ from nshtrainer.trainer._config import AcceleratorConfigBase as AcceleratorConfigBase
100
92
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
101
93
  from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
102
94
  from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
103
95
  from nshtrainer.trainer._config import (
104
96
  LearningRateMonitorConfig as LearningRateMonitorConfig,
105
97
  )
98
+ from nshtrainer.trainer._config import PluginConfigBase as PluginConfigBase
106
99
  from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
100
+ from nshtrainer.trainer._config import StrategyConfigBase as StrategyConfigBase
107
101
  from nshtrainer.util._environment_info import (
108
102
  EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
109
103
  )
@@ -133,7 +127,6 @@ from nshtrainer.util._environment_info import (
133
127
  )
134
128
  from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
135
129
  from nshtrainer.util.config import DTypeConfig as DTypeConfig
136
- from nshtrainer.util.config import DurationConfig as DurationConfig
137
130
  from nshtrainer.util.config import EpochsConfig as EpochsConfig
138
131
  from nshtrainer.util.config import StepsConfig as StepsConfig
139
132
 
@@ -149,3 +142,95 @@ from . import optimizer as optimizer
149
142
  from . import profiler as profiler
150
143
  from . import trainer as trainer
151
144
  from . import util as util
145
+
146
+ __all__ = [
147
+ "AcceleratorConfigBase",
148
+ "ActSaveConfig",
149
+ "ActSaveLoggerConfig",
150
+ "AdamWConfig",
151
+ "AdvancedProfilerConfig",
152
+ "BaseCheckpointCallbackConfig",
153
+ "BaseLoggerConfig",
154
+ "BaseNonlinearityConfig",
155
+ "BaseProfilerConfig",
156
+ "BestCheckpointCallbackConfig",
157
+ "CSVLoggerConfig",
158
+ "CallbackConfigBase",
159
+ "CheckpointMetadata",
160
+ "CheckpointSavingConfig",
161
+ "DTypeConfig",
162
+ "DebugFlagCallbackConfig",
163
+ "DirectoryConfig",
164
+ "DirectorySetupCallbackConfig",
165
+ "ELUNonlinearityConfig",
166
+ "EMACallbackConfig",
167
+ "EarlyStoppingCallbackConfig",
168
+ "EnvironmentCUDAConfig",
169
+ "EnvironmentClassInformationConfig",
170
+ "EnvironmentConfig",
171
+ "EnvironmentGPUConfig",
172
+ "EnvironmentHardwareConfig",
173
+ "EnvironmentLSFInformationConfig",
174
+ "EnvironmentLinuxEnvironmentConfig",
175
+ "EnvironmentPackageConfig",
176
+ "EnvironmentSLURMInformationConfig",
177
+ "EnvironmentSnapshotConfig",
178
+ "EpochTimerCallbackConfig",
179
+ "EpochsConfig",
180
+ "FiniteChecksCallbackConfig",
181
+ "GELUNonlinearityConfig",
182
+ "GitRepositoryConfig",
183
+ "GradientClippingConfig",
184
+ "GradientSkippingCallbackConfig",
185
+ "HuggingFaceHubAutoCreateConfig",
186
+ "HuggingFaceHubConfig",
187
+ "LRSchedulerConfigBase",
188
+ "LastCheckpointCallbackConfig",
189
+ "LeakyReLUNonlinearityConfig",
190
+ "LearningRateMonitorConfig",
191
+ "LinearWarmupCosineDecayLRSchedulerConfig",
192
+ "LogEpochCallbackConfig",
193
+ "MLPConfig",
194
+ "MetricConfig",
195
+ "MishNonlinearityConfig",
196
+ "NormLoggingCallbackConfig",
197
+ "OnExceptionCheckpointCallbackConfig",
198
+ "OptimizerConfigBase",
199
+ "PReLUConfig",
200
+ "PluginConfigBase",
201
+ "PrintTableMetricsCallbackConfig",
202
+ "PyTorchProfilerConfig",
203
+ "RLPSanityChecksCallbackConfig",
204
+ "ReLUNonlinearityConfig",
205
+ "ReduceLROnPlateauConfig",
206
+ "SanityCheckingConfig",
207
+ "SharedParametersCallbackConfig",
208
+ "SiLUNonlinearityConfig",
209
+ "SigmoidNonlinearityConfig",
210
+ "SimpleProfilerConfig",
211
+ "SoftmaxNonlinearityConfig",
212
+ "SoftplusNonlinearityConfig",
213
+ "SoftsignNonlinearityConfig",
214
+ "StepsConfig",
215
+ "StrategyConfigBase",
216
+ "SwiGLUNonlinearityConfig",
217
+ "SwishNonlinearityConfig",
218
+ "TanhNonlinearityConfig",
219
+ "TensorboardLoggerConfig",
220
+ "TrainerConfig",
221
+ "WandbLoggerConfig",
222
+ "WandbUploadCodeCallbackConfig",
223
+ "WandbWatchCallbackConfig",
224
+ "_checkpoint",
225
+ "_directory",
226
+ "_hf_hub",
227
+ "callbacks",
228
+ "loggers",
229
+ "lr_scheduler",
230
+ "metrics",
231
+ "nn",
232
+ "optimizer",
233
+ "profiler",
234
+ "trainer",
235
+ "util",
236
+ ]
@@ -6,3 +6,9 @@ from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMeta
6
6
  from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
7
7
 
8
8
  from . import metadata as metadata
9
+
10
+ __all__ = [
11
+ "CheckpointMetadata",
12
+ "EnvironmentConfig",
13
+ "metadata",
14
+ ]
@@ -4,3 +4,8 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
6
6
  from nshtrainer._checkpoint.metadata import EnvironmentConfig as EnvironmentConfig
7
+
8
+ __all__ = [
9
+ "CheckpointMetadata",
10
+ "EnvironmentConfig",
11
+ ]
@@ -6,4 +6,8 @@ from nshtrainer._directory import DirectoryConfig as DirectoryConfig
6
6
  from nshtrainer._directory import (
7
7
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
8
8
  )
9
- from nshtrainer._directory import LoggerConfig as LoggerConfig
9
+
10
+ __all__ = [
11
+ "DirectoryConfig",
12
+ "DirectorySetupCallbackConfig",
13
+ ]
@@ -7,3 +7,9 @@ from nshtrainer._hf_hub import (
7
7
  HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
8
8
  )
9
9
  from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
10
+
11
+ __all__ = [
12
+ "CallbackConfigBase",
13
+ "HuggingFaceHubAutoCreateConfig",
14
+ "HuggingFaceHubConfig",
15
+ ]
@@ -5,7 +5,6 @@ __codegen__ = True
5
5
  from nshtrainer.callbacks import (
6
6
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
7
7
  )
8
- from nshtrainer.callbacks import CallbackConfig as CallbackConfig
9
8
  from nshtrainer.callbacks import CallbackConfigBase as CallbackConfigBase
10
9
  from nshtrainer.callbacks import DebugFlagCallbackConfig as DebugFlagCallbackConfig
11
10
  from nshtrainer.callbacks import (
@@ -73,3 +72,47 @@ from . import shared_parameters as shared_parameters
73
72
  from . import timer as timer
74
73
  from . import wandb_upload_code as wandb_upload_code
75
74
  from . import wandb_watch as wandb_watch
75
+
76
+ __all__ = [
77
+ "ActSaveConfig",
78
+ "BaseCheckpointCallbackConfig",
79
+ "BestCheckpointCallbackConfig",
80
+ "CallbackConfigBase",
81
+ "CheckpointMetadata",
82
+ "DebugFlagCallbackConfig",
83
+ "DirectorySetupCallbackConfig",
84
+ "EMACallbackConfig",
85
+ "EarlyStoppingCallbackConfig",
86
+ "EpochTimerCallbackConfig",
87
+ "FiniteChecksCallbackConfig",
88
+ "GradientSkippingCallbackConfig",
89
+ "LastCheckpointCallbackConfig",
90
+ "LearningRateMonitorConfig",
91
+ "LogEpochCallbackConfig",
92
+ "MetricConfig",
93
+ "NormLoggingCallbackConfig",
94
+ "OnExceptionCheckpointCallbackConfig",
95
+ "PrintTableMetricsCallbackConfig",
96
+ "RLPSanityChecksCallbackConfig",
97
+ "SharedParametersCallbackConfig",
98
+ "WandbUploadCodeCallbackConfig",
99
+ "WandbWatchCallbackConfig",
100
+ "actsave",
101
+ "base",
102
+ "checkpoint",
103
+ "debug_flag",
104
+ "directory_setup",
105
+ "early_stopping",
106
+ "ema",
107
+ "finite_checks",
108
+ "gradient_skipping",
109
+ "log_epoch",
110
+ "lr_monitor",
111
+ "norm_logging",
112
+ "print_table",
113
+ "rlp_sanity_checks",
114
+ "shared_parameters",
115
+ "timer",
116
+ "wandb_upload_code",
117
+ "wandb_watch",
118
+ ]
@@ -4,3 +4,8 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.callbacks.actsave import ActSaveConfig as ActSaveConfig
6
6
  from nshtrainer.callbacks.actsave import CallbackConfigBase as CallbackConfigBase
7
+
8
+ __all__ = [
9
+ "ActSaveConfig",
10
+ "CallbackConfigBase",
11
+ ]
@@ -3,3 +3,7 @@ from __future__ import annotations
3
3
  __codegen__ = True
4
4
 
5
5
  from nshtrainer.callbacks.base import CallbackConfigBase as CallbackConfigBase
6
+
7
+ __all__ = [
8
+ "CallbackConfigBase",
9
+ ]
@@ -26,3 +26,17 @@ from . import _base as _base
26
26
  from . import best_checkpoint as best_checkpoint
27
27
  from . import last_checkpoint as last_checkpoint
28
28
  from . import on_exception_checkpoint as on_exception_checkpoint
29
+
30
+ __all__ = [
31
+ "BaseCheckpointCallbackConfig",
32
+ "BestCheckpointCallbackConfig",
33
+ "CallbackConfigBase",
34
+ "CheckpointMetadata",
35
+ "LastCheckpointCallbackConfig",
36
+ "MetricConfig",
37
+ "OnExceptionCheckpointCallbackConfig",
38
+ "_base",
39
+ "best_checkpoint",
40
+ "last_checkpoint",
41
+ "on_exception_checkpoint",
42
+ ]
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint._base import (
11
11
  from nshtrainer.callbacks.checkpoint._base import (
12
12
  CheckpointMetadata as CheckpointMetadata,
13
13
  )
14
+
15
+ __all__ = [
16
+ "BaseCheckpointCallbackConfig",
17
+ "CallbackConfigBase",
18
+ "CheckpointMetadata",
19
+ ]
@@ -12,3 +12,10 @@ from nshtrainer.callbacks.checkpoint.best_checkpoint import (
12
12
  CheckpointMetadata as CheckpointMetadata,
13
13
  )
14
14
  from nshtrainer.callbacks.checkpoint.best_checkpoint import MetricConfig as MetricConfig
15
+
16
+ __all__ = [
17
+ "BaseCheckpointCallbackConfig",
18
+ "BestCheckpointCallbackConfig",
19
+ "CheckpointMetadata",
20
+ "MetricConfig",
21
+ ]
@@ -11,3 +11,9 @@ from nshtrainer.callbacks.checkpoint.last_checkpoint import (
11
11
  from nshtrainer.callbacks.checkpoint.last_checkpoint import (
12
12
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
13
13
  )
14
+
15
+ __all__ = [
16
+ "BaseCheckpointCallbackConfig",
17
+ "CheckpointMetadata",
18
+ "LastCheckpointCallbackConfig",
19
+ ]
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
8
8
  from nshtrainer.callbacks.checkpoint.on_exception_checkpoint import (
9
9
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
10
10
  )
11
+
12
+ __all__ = [
13
+ "CallbackConfigBase",
14
+ "OnExceptionCheckpointCallbackConfig",
15
+ ]
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.debug_flag import CallbackConfigBase as CallbackConfig
6
6
  from nshtrainer.callbacks.debug_flag import (
7
7
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
8
8
  )
9
+
10
+ __all__ = [
11
+ "CallbackConfigBase",
12
+ "DebugFlagCallbackConfig",
13
+ ]
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.directory_setup import (
8
8
  from nshtrainer.callbacks.directory_setup import (
9
9
  DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
10
10
  )
11
+
12
+ __all__ = [
13
+ "CallbackConfigBase",
14
+ "DirectorySetupCallbackConfig",
15
+ ]
@@ -7,3 +7,9 @@ from nshtrainer.callbacks.early_stopping import (
7
7
  EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
8
8
  )
9
9
  from nshtrainer.callbacks.early_stopping import MetricConfig as MetricConfig
10
+
11
+ __all__ = [
12
+ "CallbackConfigBase",
13
+ "EarlyStoppingCallbackConfig",
14
+ "MetricConfig",
15
+ ]
@@ -4,3 +4,8 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer.callbacks.ema import CallbackConfigBase as CallbackConfigBase
6
6
  from nshtrainer.callbacks.ema import EMACallbackConfig as EMACallbackConfig
7
+
8
+ __all__ = [
9
+ "CallbackConfigBase",
10
+ "EMACallbackConfig",
11
+ ]
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.finite_checks import CallbackConfigBase as CallbackCon
6
6
  from nshtrainer.callbacks.finite_checks import (
7
7
  FiniteChecksCallbackConfig as FiniteChecksCallbackConfig,
8
8
  )
9
+
10
+ __all__ = [
11
+ "CallbackConfigBase",
12
+ "FiniteChecksCallbackConfig",
13
+ ]
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.gradient_skipping import (
8
8
  from nshtrainer.callbacks.gradient_skipping import (
9
9
  GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
10
10
  )
11
+
12
+ __all__ = [
13
+ "CallbackConfigBase",
14
+ "GradientSkippingCallbackConfig",
15
+ ]
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.log_epoch import CallbackConfigBase as CallbackConfigB
6
6
  from nshtrainer.callbacks.log_epoch import (
7
7
  LogEpochCallbackConfig as LogEpochCallbackConfig,
8
8
  )
9
+
10
+ __all__ = [
11
+ "CallbackConfigBase",
12
+ "LogEpochCallbackConfig",
13
+ ]
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.lr_monitor import CallbackConfigBase as CallbackConfig
6
6
  from nshtrainer.callbacks.lr_monitor import (
7
7
  LearningRateMonitorConfig as LearningRateMonitorConfig,
8
8
  )
9
+
10
+ __all__ = [
11
+ "CallbackConfigBase",
12
+ "LearningRateMonitorConfig",
13
+ ]
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.norm_logging import CallbackConfigBase as CallbackConf
6
6
  from nshtrainer.callbacks.norm_logging import (
7
7
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
8
8
  )
9
+
10
+ __all__ = [
11
+ "CallbackConfigBase",
12
+ "NormLoggingCallbackConfig",
13
+ ]
@@ -6,3 +6,8 @@ from nshtrainer.callbacks.print_table import CallbackConfigBase as CallbackConfi
6
6
  from nshtrainer.callbacks.print_table import (
7
7
  PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
8
8
  )
9
+
10
+ __all__ = [
11
+ "CallbackConfigBase",
12
+ "PrintTableMetricsCallbackConfig",
13
+ ]
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.rlp_sanity_checks import (
8
8
  from nshtrainer.callbacks.rlp_sanity_checks import (
9
9
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
10
10
  )
11
+
12
+ __all__ = [
13
+ "CallbackConfigBase",
14
+ "RLPSanityChecksCallbackConfig",
15
+ ]
@@ -8,3 +8,8 @@ from nshtrainer.callbacks.shared_parameters import (
8
8
  from nshtrainer.callbacks.shared_parameters import (
9
9
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
10
10
  )
11
+
12
+ __all__ = [
13
+ "CallbackConfigBase",
14
+ "SharedParametersCallbackConfig",
15
+ ]