nshtrainer 1.0.0b32__py3-none-any.whl → 1.0.0b36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -0
- nshtrainer/_hf_hub.py +8 -1
- nshtrainer/callbacks/__init__.py +10 -23
- nshtrainer/callbacks/actsave.py +6 -2
- nshtrainer/callbacks/base.py +3 -0
- nshtrainer/callbacks/checkpoint/__init__.py +0 -4
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
- nshtrainer/callbacks/debug_flag.py +4 -2
- nshtrainer/callbacks/directory_setup.py +23 -21
- nshtrainer/callbacks/early_stopping.py +4 -2
- nshtrainer/callbacks/ema.py +29 -27
- nshtrainer/callbacks/finite_checks.py +21 -19
- nshtrainer/callbacks/gradient_skipping.py +29 -27
- nshtrainer/callbacks/log_epoch.py +4 -2
- nshtrainer/callbacks/lr_monitor.py +6 -1
- nshtrainer/callbacks/norm_logging.py +36 -34
- nshtrainer/callbacks/print_table.py +20 -18
- nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- nshtrainer/callbacks/shared_parameters.py +9 -7
- nshtrainer/callbacks/timer.py +12 -10
- nshtrainer/callbacks/wandb_upload_code.py +4 -2
- nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer/configs/__init__.py +4 -8
- nshtrainer/configs/_hf_hub/__init__.py +2 -0
- nshtrainer/configs/callbacks/__init__.py +4 -8
- nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
- nshtrainer/configs/callbacks/base/__init__.py +2 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
- nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
- nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
- nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
- nshtrainer/configs/callbacks/ema/__init__.py +2 -0
- nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
- nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
- nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
- nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
- nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
- nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
- nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
- nshtrainer/configs/callbacks/timer/__init__.py +2 -0
- nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
- nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
- nshtrainer/configs/trainer/__init__.py +2 -4
- nshtrainer/configs/trainer/_config/__init__.py +0 -8
- nshtrainer/data/datamodule.py +0 -2
- nshtrainer/model/base.py +0 -2
- nshtrainer/trainer/__init__.py +3 -2
- nshtrainer/trainer/_config.py +4 -42
- {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/RECORD +57 -60
- nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
- nshtrainer/util/hparams.py +0 -18
- {nshtrainer-1.0.0b32.dist-info → nshtrainer-1.0.0b36.dist-info}/WHEEL +0 -0
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.wandb_upload_code import (
|
|
8
8
|
from nshtrainer.callbacks.wandb_upload_code import (
|
9
9
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.callbacks.wandb_upload_code import (
|
12
|
+
callback_registry as callback_registry,
|
13
|
+
)
|
11
14
|
|
12
15
|
__all__ = [
|
13
16
|
"CallbackConfigBase",
|
14
17
|
"WandbUploadCodeCallbackConfig",
|
18
|
+
"callback_registry",
|
15
19
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.wandb_watch import CallbackConfigBase as CallbackConfi
|
|
6
6
|
from nshtrainer.callbacks.wandb_watch import (
|
7
7
|
WandbWatchCallbackConfig as WandbWatchCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.wandb_watch import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"WandbWatchCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -4,6 +4,7 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.trainer import TrainerConfig as TrainerConfig
|
6
6
|
from nshtrainer.trainer import accelerator_registry as accelerator_registry
|
7
|
+
from nshtrainer.trainer import callback_registry as callback_registry
|
7
8
|
from nshtrainer.trainer import plugin_registry as plugin_registry
|
8
9
|
from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
|
9
10
|
from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
|
@@ -55,9 +56,6 @@ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
|
55
56
|
from nshtrainer.trainer._config import (
|
56
57
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
57
58
|
)
|
58
|
-
from nshtrainer.trainer._config import (
|
59
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
60
|
-
)
|
61
59
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
62
60
|
from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
|
63
61
|
from nshtrainer.trainer.accelerator import (
|
@@ -179,7 +177,6 @@ __all__ = [
|
|
179
177
|
"StrategyConfig",
|
180
178
|
"StrategyConfigBase",
|
181
179
|
"TensorboardLoggerConfig",
|
182
|
-
"TimeCheckpointCallbackConfig",
|
183
180
|
"TorchCheckpointIOPlugin",
|
184
181
|
"TorchElasticEnvironmentPlugin",
|
185
182
|
"TorchSyncBatchNormPlugin",
|
@@ -193,6 +190,7 @@ __all__ = [
|
|
193
190
|
"_config",
|
194
191
|
"accelerator",
|
195
192
|
"accelerator_registry",
|
193
|
+
"callback_registry",
|
196
194
|
"plugin",
|
197
195
|
"plugin_registry",
|
198
196
|
"strategy",
|
@@ -53,13 +53,8 @@ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
|
53
53
|
from nshtrainer.trainer._config import (
|
54
54
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
55
55
|
)
|
56
|
-
from nshtrainer.trainer._config import (
|
57
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
58
|
-
)
|
59
56
|
from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
|
60
57
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
61
|
-
from nshtrainer.trainer._config import accelerator_registry as accelerator_registry
|
62
|
-
from nshtrainer.trainer._config import plugin_registry as plugin_registry
|
63
58
|
|
64
59
|
__all__ = [
|
65
60
|
"AcceleratorConfig",
|
@@ -91,9 +86,6 @@ __all__ = [
|
|
91
86
|
"SharedParametersCallbackConfig",
|
92
87
|
"StrategyConfig",
|
93
88
|
"TensorboardLoggerConfig",
|
94
|
-
"TimeCheckpointCallbackConfig",
|
95
89
|
"TrainerConfig",
|
96
90
|
"WandbLoggerConfig",
|
97
|
-
"accelerator_registry",
|
98
|
-
"plugin_registry",
|
99
91
|
]
|
nshtrainer/data/datamodule.py
CHANGED
@@ -12,13 +12,11 @@ from typing_extensions import Never, TypeVar, deprecated, override
|
|
12
12
|
|
13
13
|
from ..model.mixins.callback import CallbackRegistrarModuleMixin
|
14
14
|
from ..model.mixins.debug import _DebugModuleMixin
|
15
|
-
from ..util.hparams import HyperparamsMixin
|
16
15
|
|
17
16
|
THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
|
18
17
|
|
19
18
|
|
20
19
|
class LightningDataModuleBase(
|
21
|
-
HyperparamsMixin,
|
22
20
|
_DebugModuleMixin,
|
23
21
|
CallbackRegistrarModuleMixin,
|
24
22
|
LightningDataModule,
|
nshtrainer/model/base.py
CHANGED
@@ -16,7 +16,6 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn
|
|
16
16
|
from typing_extensions import Never, TypeVar, deprecated, override
|
17
17
|
|
18
18
|
from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
|
19
|
-
from ..util.hparams import HyperparamsMixin
|
20
19
|
from .mixins.callback import CallbackModuleMixin
|
21
20
|
from .mixins.debug import _DebugModuleMixin
|
22
21
|
from .mixins.logger import LoggerLightningModuleMixin
|
@@ -55,7 +54,6 @@ VALID_REDUCE_OPS = (
|
|
55
54
|
|
56
55
|
|
57
56
|
class LightningModuleBase(
|
58
|
-
HyperparamsMixin,
|
59
57
|
_DebugModuleMixin,
|
60
58
|
_RLPSanityCheckModuleMixin,
|
61
59
|
LoggerLightningModuleMixin,
|
nshtrainer/trainer/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
from ..callbacks import callback_registry as callback_registry
|
3
4
|
from ._config import TrainerConfig as TrainerConfig
|
4
|
-
from .
|
5
|
-
from .
|
5
|
+
from .accelerator import accelerator_registry as accelerator_registry
|
6
|
+
from .plugin import plugin_registry as plugin_registry
|
6
7
|
from .trainer import Trainer as Trainer
|
nshtrainer/trainer/_config.py
CHANGED
@@ -37,7 +37,6 @@ from ..callbacks import (
|
|
37
37
|
OnExceptionCheckpointCallbackConfig,
|
38
38
|
)
|
39
39
|
from ..callbacks.base import CallbackConfigBase
|
40
|
-
from ..callbacks.checkpoint.time_checkpoint import TimeCheckpointCallbackConfig
|
41
40
|
from ..callbacks.debug_flag import DebugFlagCallbackConfig
|
42
41
|
from ..callbacks.log_epoch import LogEpochCallbackConfig
|
43
42
|
from ..callbacks.lr_monitor import LearningRateMonitorConfig
|
@@ -54,9 +53,9 @@ from ..loggers.actsave import ActSaveLoggerConfig
|
|
54
53
|
from ..metrics._config import MetricConfig
|
55
54
|
from ..profiler import ProfilerConfig
|
56
55
|
from ..util._environment_info import EnvironmentConfig
|
57
|
-
from .accelerator import AcceleratorConfig, AcceleratorLiteral
|
58
|
-
from .plugin import PluginConfig
|
59
|
-
from .strategy import StrategyConfig
|
56
|
+
from .accelerator import AcceleratorConfig, AcceleratorLiteral
|
57
|
+
from .plugin import PluginConfig
|
58
|
+
from .strategy import StrategyConfig, StrategyLiteral
|
60
59
|
|
61
60
|
log = logging.getLogger(__name__)
|
62
61
|
|
@@ -70,46 +69,12 @@ class GradientClippingConfig(C.Config):
|
|
70
69
|
"""Norm type to use for gradient clipping."""
|
71
70
|
|
72
71
|
|
73
|
-
StrategyLiteral = TypeAliasType(
|
74
|
-
"StrategyLiteral",
|
75
|
-
Literal[
|
76
|
-
"auto",
|
77
|
-
"ddp",
|
78
|
-
"ddp_find_unused_parameters_false",
|
79
|
-
"ddp_find_unused_parameters_true",
|
80
|
-
"ddp_spawn",
|
81
|
-
"ddp_spawn_find_unused_parameters_false",
|
82
|
-
"ddp_spawn_find_unused_parameters_true",
|
83
|
-
"ddp_fork",
|
84
|
-
"ddp_fork_find_unused_parameters_false",
|
85
|
-
"ddp_fork_find_unused_parameters_true",
|
86
|
-
"ddp_notebook",
|
87
|
-
"dp",
|
88
|
-
"deepspeed",
|
89
|
-
"deepspeed_stage_1",
|
90
|
-
"deepspeed_stage_1_offload",
|
91
|
-
"deepspeed_stage_2",
|
92
|
-
"deepspeed_stage_2_offload",
|
93
|
-
"deepspeed_stage_3",
|
94
|
-
"deepspeed_stage_3_offload",
|
95
|
-
"deepspeed_stage_3_offload_nvme",
|
96
|
-
"fsdp",
|
97
|
-
"fsdp_cpu_offload",
|
98
|
-
"single_xla",
|
99
|
-
"xla_fsdp",
|
100
|
-
"xla",
|
101
|
-
"single_tpu",
|
102
|
-
],
|
103
|
-
)
|
104
|
-
|
105
|
-
|
106
72
|
CheckpointCallbackConfig = TypeAliasType(
|
107
73
|
"CheckpointCallbackConfig",
|
108
74
|
Annotated[
|
109
75
|
BestCheckpointCallbackConfig
|
110
76
|
| LastCheckpointCallbackConfig
|
111
|
-
| OnExceptionCheckpointCallbackConfig
|
112
|
-
| TimeCheckpointCallbackConfig,
|
77
|
+
| OnExceptionCheckpointCallbackConfig,
|
113
78
|
C.Field(discriminator="name"),
|
114
79
|
],
|
115
80
|
)
|
@@ -123,7 +88,6 @@ class CheckpointSavingConfig(CallbackConfigBase):
|
|
123
88
|
BestCheckpointCallbackConfig(throw_on_no_metric=False),
|
124
89
|
LastCheckpointCallbackConfig(),
|
125
90
|
OnExceptionCheckpointCallbackConfig(),
|
126
|
-
TimeCheckpointCallbackConfig(interval=timedelta(hours=12)),
|
127
91
|
]
|
128
92
|
"""Checkpoint callback configurations."""
|
129
93
|
|
@@ -397,8 +361,6 @@ class SanityCheckingConfig(C.Config):
|
|
397
361
|
"""
|
398
362
|
|
399
363
|
|
400
|
-
@plugin_registry.rebuild_on_registers
|
401
|
-
@accelerator_registry.rebuild_on_registers
|
402
364
|
class TrainerConfig(C.Config):
|
403
365
|
# region Active Run Configuration
|
404
366
|
id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
|
@@ -1,65 +1,63 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
|
-
nshtrainer/__init__.py,sha256=
|
2
|
+
nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
|
3
3
|
nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
|
6
6
|
nshtrainer/_directory.py,sha256=p2uk1FnISFEpMqlDevKhoWhQsCEtvHUPg459K-86QA8,3053
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
|
-
nshtrainer/_hf_hub.py,sha256=
|
9
|
-
nshtrainer/callbacks/__init__.py,sha256=
|
10
|
-
nshtrainer/callbacks/actsave.py,sha256=
|
11
|
-
nshtrainer/callbacks/base.py,sha256=
|
12
|
-
nshtrainer/callbacks/checkpoint/__init__.py,sha256=
|
8
|
+
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
9
|
+
nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
|
10
|
+
nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
|
11
|
+
nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
|
12
|
+
nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
|
13
13
|
nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
|
14
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
15
|
-
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=
|
16
|
-
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=
|
17
|
-
nshtrainer/callbacks/
|
18
|
-
nshtrainer/callbacks/
|
19
|
-
nshtrainer/callbacks/
|
20
|
-
nshtrainer/callbacks/
|
21
|
-
nshtrainer/callbacks/
|
22
|
-
nshtrainer/callbacks/
|
23
|
-
nshtrainer/callbacks/gradient_skipping.py,sha256=k5qNaNeileZ_5YFad4ssfLplMxMKeKFhPcY8-QVmLek,3464
|
14
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
|
15
|
+
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=MJcNB0biOebx2si2IBFaSUiVOSLSCZTzxB-RcEgO2gY,3482
|
16
|
+
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
|
17
|
+
nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
|
18
|
+
nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
|
19
|
+
nshtrainer/callbacks/early_stopping.py,sha256=EjzN-gD_Xd4YHZLkXsbi00g_4ti3RTMJEdHJ8GMeaFM,4776
|
20
|
+
nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
|
21
|
+
nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
|
22
|
+
nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
|
24
23
|
nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
|
25
|
-
nshtrainer/callbacks/log_epoch.py,sha256=
|
26
|
-
nshtrainer/callbacks/lr_monitor.py,sha256=
|
27
|
-
nshtrainer/callbacks/norm_logging.py,sha256=
|
28
|
-
nshtrainer/callbacks/print_table.py,sha256=
|
29
|
-
nshtrainer/callbacks/rlp_sanity_checks.py,sha256=
|
30
|
-
nshtrainer/callbacks/shared_parameters.py,sha256=
|
31
|
-
nshtrainer/callbacks/timer.py,sha256=
|
32
|
-
nshtrainer/callbacks/wandb_upload_code.py,sha256=
|
33
|
-
nshtrainer/callbacks/wandb_watch.py,sha256=
|
34
|
-
nshtrainer/configs/__init__.py,sha256=
|
24
|
+
nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
|
25
|
+
nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
|
26
|
+
nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
|
27
|
+
nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
|
28
|
+
nshtrainer/callbacks/rlp_sanity_checks.py,sha256=74BZvV2HLO__ucQXsLXb8eJLUZgRFUNJZ6TL9efMp74,10051
|
29
|
+
nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
|
30
|
+
nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
|
31
|
+
nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
|
32
|
+
nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
|
33
|
+
nshtrainer/configs/__init__.py,sha256=OevZEZxb4H8imadSQXK9huqdYUF4SrJPfNU_2fpMBvI,14084
|
35
34
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
36
35
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
37
36
|
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
38
|
-
nshtrainer/configs/_hf_hub/__init__.py,sha256=
|
39
|
-
nshtrainer/configs/callbacks/__init__.py,sha256=
|
40
|
-
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=
|
41
|
-
nshtrainer/configs/callbacks/base/__init__.py,sha256=
|
42
|
-
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=
|
37
|
+
nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
|
38
|
+
nshtrainer/configs/callbacks/__init__.py,sha256=jSWkbsdiu9vdGWTzqkDf-Bo9dXr9RengeNZLzWUhi7Y,4283
|
39
|
+
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
|
40
|
+
nshtrainer/configs/callbacks/base/__init__.py,sha256=wT3RhXttLyf6RFWCIvsoiXcPdfGx5W309WBI18AI5os,278
|
41
|
+
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=aGJ7vX14YamkMdwYAdPv6XrRnP0aZd5uZ5X0nSLc6IU,1475
|
43
42
|
nshtrainer/configs/callbacks/checkpoint/_base/__init__.py,sha256=5jl6A5Gv6arZXmHV6lz5dQ8DL6PdJIfJqHLP4acClKQ,479
|
44
|
-
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py,sha256=
|
45
|
-
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=
|
46
|
-
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=
|
47
|
-
nshtrainer/configs/callbacks/
|
48
|
-
nshtrainer/configs/callbacks/
|
49
|
-
nshtrainer/configs/callbacks/
|
50
|
-
nshtrainer/configs/callbacks/
|
51
|
-
nshtrainer/configs/callbacks/
|
52
|
-
nshtrainer/configs/callbacks/
|
53
|
-
nshtrainer/configs/callbacks/
|
54
|
-
nshtrainer/configs/callbacks/
|
55
|
-
nshtrainer/configs/callbacks/
|
56
|
-
nshtrainer/configs/callbacks/
|
57
|
-
nshtrainer/configs/callbacks/
|
58
|
-
nshtrainer/configs/callbacks/
|
59
|
-
nshtrainer/configs/callbacks/
|
60
|
-
nshtrainer/configs/callbacks/
|
61
|
-
nshtrainer/configs/callbacks/
|
62
|
-
nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=MW-ANrF529DxBhopovPjYEQ7nANX9ttd1K4_bJnKXks,322
|
43
|
+
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py,sha256=hL4WGBdo5_gtQuEGcRa3cWYMOSFdlNzkW-2Y3X3ZGTI,781
|
44
|
+
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=SIRfz5QP30K4zzKw_1LZSSFr-3x-S3vc0vWL4ndyvjc,672
|
45
|
+
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=VSkO0TYCAYy_9mQuOBoAND7D3Cg6w6nMCpqivQZLPcE,551
|
46
|
+
nshtrainer/configs/callbacks/debug_flag/__init__.py,sha256=s_ifB-DbZjar0w11pr2oVAlcMTWWMnK_tCNilfswL04,425
|
47
|
+
nshtrainer/configs/callbacks/directory_setup/__init__.py,sha256=e8GCRy2Alds3AXLwp4ieSGtn8S0YjmKJ5khOaQ0zKGs,464
|
48
|
+
nshtrainer/configs/callbacks/early_stopping/__init__.py,sha256=m8N6H11PjqcWqXP5ZxWC8L4PHMUI6avYyN5rUNprjuQ,546
|
49
|
+
nshtrainer/configs/callbacks/ema/__init__.py,sha256=DUJrbDD8wWX_s0_4dwKpT_IWKSVpBmhe4-1aELq7G6w,377
|
50
|
+
nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=e-vx9Kn-noqw4wPvZw7fDMfb9Tsa6Duk0TIa8ZIgIIE,443
|
51
|
+
nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=T3eVxxJfnYBrO9WfLiycn4TyWP4vaqJ57yp7Epkg7B4,485
|
52
|
+
nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=IQ5owYYvyk7fiQP1QXYtncRRJrESuq3rRFhab-II2uE,419
|
53
|
+
nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=qejy1AnXNDHmsFuXRAXQQ5B0TcbKzvpaw-I4dv2AXIs,431
|
54
|
+
nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=j2LrnYEbDGLwJR2lk-jmh-4J_iLEs2HNEoepvJSFLAg,437
|
55
|
+
nshtrainer/configs/callbacks/print_table/__init__.py,sha256=t6fA_dBkUCszUXDJKEdnlBH4oEpfAQqcmAlatTFYIyQ,452
|
56
|
+
nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=dlP14Wh-w8zG_B4EtNmCIFzVMhf6bXCJ1O9cJWmEFnA,482
|
57
|
+
nshtrainer/configs/callbacks/shared_parameters/__init__.py,sha256=AU7_bSnSRSlj16UTaFBe9QVUf0T0zKUmKCOBSL4xYmg,485
|
58
|
+
nshtrainer/configs/callbacks/timer/__init__.py,sha256=cOUtbsl0_OhCO0fIcBfLuIF6FEGBHQu7AvQFzwVznWQ,413
|
59
|
+
nshtrainer/configs/callbacks/wandb_upload_code/__init__.py,sha256=CJeCc9OCu5F39lWiY5aIc4WxQlgBvB-8cga6cQtw0GQ,482
|
60
|
+
nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=dzz1oavL1BwELE33xus45_avBEAZDeB6xtcb6CsOEos,431
|
63
61
|
nshtrainer/configs/loggers/__init__.py,sha256=5wTekL79mQxit8f1K3AMllvb0mKertTzOKfC3gpE2Zk,1251
|
64
62
|
nshtrainer/configs/loggers/_base/__init__.py,sha256=HxPPPePsEjlNuhnjsMgYIl0rwj_iqNKKOBTEk_zIOsM,169
|
65
63
|
nshtrainer/configs/loggers/actsave/__init__.py,sha256=2lZQ4bpbjwd4MuUE_Z_PGbmQjjGtWCZUCtXqKO4dTSc,280
|
@@ -81,8 +79,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
|
|
81
79
|
nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
|
82
80
|
nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
|
83
81
|
nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
|
84
|
-
nshtrainer/configs/trainer/__init__.py,sha256=
|
85
|
-
nshtrainer/configs/trainer/_config/__init__.py,sha256=
|
82
|
+
nshtrainer/configs/trainer/__init__.py,sha256=8Z4E1IeJHtDW8fpDxJkiC9CgDqKrTBIR5VMK1q4DYy4,7729
|
83
|
+
nshtrainer/configs/trainer/_config/__init__.py,sha256=t72kmUn60UtjpD6H38XzKbEs50gU2dS1IH0u-RnHZ04,3666
|
86
84
|
nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
|
87
85
|
nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
|
88
86
|
nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
|
@@ -99,7 +97,7 @@ nshtrainer/configs/util/config/dtype/__init__.py,sha256=PmGF-O4r6SXqEaagVsQ5YxEq
|
|
99
97
|
nshtrainer/configs/util/config/duration/__init__.py,sha256=44lS2irOIPVfgshMTfnZM2jC6l0Pjst9w2M_lJoS_MU,353
|
100
98
|
nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
|
101
99
|
nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
|
102
|
-
nshtrainer/data/datamodule.py,sha256=
|
100
|
+
nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk,4138
|
103
101
|
nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
|
104
102
|
nshtrainer/loggers/__init__.py,sha256=-y8B-9TF6vJdZUQewJNDcZ2aOv04FEUFtKwaiDobIO0,670
|
105
103
|
nshtrainer/loggers/_base.py,sha256=nw4AZzJP3Z-fljgQlgq7FkuMkPmYKTsXj7OfJJSmtXI,811
|
@@ -114,7 +112,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=vXH5S26ESHO_LPPqW8aDC3S5N
|
|
114
112
|
nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
|
115
113
|
nshtrainer/metrics/_config.py,sha256=XIRokFM8PHrhBa3w2R6BM6a4es3ncsoBqE_LqXQFsFE,1223
|
116
114
|
nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
|
117
|
-
nshtrainer/model/base.py,sha256=
|
115
|
+
nshtrainer/model/base.py,sha256=JL3AmH17GQjQIoMrZl3O0vUI7dj5ZsO5iEJgoLPyzHw,10356
|
118
116
|
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
119
117
|
nshtrainer/model/mixins/debug.py,sha256=1LX9KzeFX9JDPs_a6YCdYDZXLhEk_5rBO2aCqlfBy7w,2087
|
120
118
|
nshtrainer/model/mixins/logger.py,sha256=27H99FuLaxc6_dDLG2pid4E_5E0-eLGnc2Ifpt0HYIM,6066
|
@@ -129,8 +127,8 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
|
|
129
127
|
nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
|
130
128
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
131
129
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
132
|
-
nshtrainer/trainer/__init__.py,sha256=
|
133
|
-
nshtrainer/trainer/_config.py,sha256=
|
130
|
+
nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
|
131
|
+
nshtrainer/trainer/_config.py,sha256=SPg3WXjF3ufhnr27sTHQLq23hdebnW6CTWa8AJkRG0A,32982
|
134
132
|
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
135
133
|
nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
|
136
134
|
nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
|
@@ -148,12 +146,11 @@ nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDW
|
|
148
146
|
nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
|
149
147
|
nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
|
150
148
|
nshtrainer/util/environment.py,sha256=s-B5nY0cKYXdFMdNYumvC_xxacMATiI4DvV2gUDu20k,4195
|
151
|
-
nshtrainer/util/hparams.py,sha256=4i9czN6JQfDke2wuZzaOTNvwqHJvAvmoVD-PeL5c4r4,475
|
152
149
|
nshtrainer/util/path.py,sha256=L-Nh9tlXSUfoP19TFbQq8I0AfS5ugCfGYTYFeddDHcs,3516
|
153
150
|
nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
154
151
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
155
152
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
156
153
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
157
|
-
nshtrainer-1.0.
|
158
|
-
nshtrainer-1.0.
|
159
|
-
nshtrainer-1.0.
|
154
|
+
nshtrainer-1.0.0b36.dist-info/METADATA,sha256=R9O2SnflaNiDkxtoOPD_YFCXIgnEl8YjkhbEU5CbWHQ,988
|
155
|
+
nshtrainer-1.0.0b36.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
156
|
+
nshtrainer-1.0.0b36.dist-info/RECORD,,
|
@@ -1,114 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import logging
|
4
|
-
import time
|
5
|
-
from datetime import timedelta
|
6
|
-
from pathlib import Path
|
7
|
-
from typing import Any, Literal
|
8
|
-
|
9
|
-
from lightning.pytorch import LightningModule, Trainer
|
10
|
-
from typing_extensions import final, override
|
11
|
-
|
12
|
-
from nshtrainer._checkpoint.metadata import CheckpointMetadata
|
13
|
-
|
14
|
-
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
15
|
-
|
16
|
-
log = logging.getLogger(__name__)
|
17
|
-
|
18
|
-
|
19
|
-
@final
|
20
|
-
class TimeCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
21
|
-
name: Literal["time_checkpoint"] = "time_checkpoint"
|
22
|
-
|
23
|
-
interval: timedelta = timedelta(hours=12)
|
24
|
-
"""Time interval between checkpoints."""
|
25
|
-
|
26
|
-
@override
|
27
|
-
def create_checkpoint(self, trainer_config, dirpath):
|
28
|
-
return TimeCheckpointCallback(self, dirpath)
|
29
|
-
|
30
|
-
|
31
|
-
@final
|
32
|
-
class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
|
33
|
-
def __init__(self, config: TimeCheckpointCallbackConfig, dirpath: Path):
|
34
|
-
super().__init__(config, dirpath)
|
35
|
-
self.start_time = time.time()
|
36
|
-
self.last_checkpoint_time = self.start_time
|
37
|
-
self.interval_seconds = config.interval.total_seconds()
|
38
|
-
|
39
|
-
@override
|
40
|
-
def name(self):
|
41
|
-
return "time"
|
42
|
-
|
43
|
-
@override
|
44
|
-
def default_filename(self):
|
45
|
-
return "epoch{epoch}-step{step}-duration{train_duration}"
|
46
|
-
|
47
|
-
@override
|
48
|
-
def topk_sort_key(self, metadata: CheckpointMetadata):
|
49
|
-
return metadata.checkpoint_timestamp
|
50
|
-
|
51
|
-
@override
|
52
|
-
def topk_sort_reverse(self):
|
53
|
-
return True
|
54
|
-
|
55
|
-
def _should_checkpoint(self) -> bool:
|
56
|
-
current_time = time.time()
|
57
|
-
elapsed_time = current_time - self.last_checkpoint_time
|
58
|
-
return elapsed_time >= self.interval_seconds
|
59
|
-
|
60
|
-
def _format_duration(self, seconds: float) -> str:
|
61
|
-
"""Format duration in seconds to a human-readable string."""
|
62
|
-
td = timedelta(seconds=int(seconds))
|
63
|
-
days = td.days
|
64
|
-
hours, remainder = divmod(td.seconds, 3600)
|
65
|
-
minutes, seconds = divmod(remainder, 60)
|
66
|
-
|
67
|
-
parts = []
|
68
|
-
if days > 0:
|
69
|
-
parts.append(f"{days}d")
|
70
|
-
if hours > 0:
|
71
|
-
parts.append(f"{hours}h")
|
72
|
-
if minutes > 0:
|
73
|
-
parts.append(f"{minutes}m")
|
74
|
-
if seconds > 0 or not parts:
|
75
|
-
parts.append(f"{seconds}s")
|
76
|
-
|
77
|
-
return "_".join(parts)
|
78
|
-
|
79
|
-
@override
|
80
|
-
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
81
|
-
metrics = super().current_metrics(trainer)
|
82
|
-
train_duration = time.time() - self.start_time
|
83
|
-
metrics["train_duration"] = self._format_duration(train_duration)
|
84
|
-
return metrics
|
85
|
-
|
86
|
-
@override
|
87
|
-
def on_train_batch_end(
|
88
|
-
self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
|
89
|
-
):
|
90
|
-
if self._should_checkpoint():
|
91
|
-
self.save_checkpoints(trainer)
|
92
|
-
self.last_checkpoint_time = time.time()
|
93
|
-
|
94
|
-
@override
|
95
|
-
def state_dict(self) -> dict[str, Any]:
|
96
|
-
"""Save the timer state for checkpoint resumption.
|
97
|
-
|
98
|
-
Returns:
|
99
|
-
Dictionary containing the start time and last checkpoint time.
|
100
|
-
"""
|
101
|
-
return {
|
102
|
-
"start_time": self.start_time,
|
103
|
-
"last_checkpoint_time": self.last_checkpoint_time,
|
104
|
-
}
|
105
|
-
|
106
|
-
@override
|
107
|
-
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
108
|
-
"""Restore the timer state when resuming from a checkpoint.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
state_dict: Dictionary containing the previously saved timer state.
|
112
|
-
"""
|
113
|
-
self.start_time = state_dict["start_time"]
|
114
|
-
self.last_checkpoint_time = state_dict["last_checkpoint_time"]
|
@@ -1,19 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__codegen__ = True
|
4
|
-
|
5
|
-
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
6
|
-
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
7
|
-
)
|
8
|
-
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
9
|
-
CheckpointMetadata as CheckpointMetadata,
|
10
|
-
)
|
11
|
-
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
12
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
13
|
-
)
|
14
|
-
|
15
|
-
__all__ = [
|
16
|
-
"BaseCheckpointCallbackConfig",
|
17
|
-
"CheckpointMetadata",
|
18
|
-
"TimeCheckpointCallbackConfig",
|
19
|
-
]
|
nshtrainer/util/hparams.py
DELETED
@@ -1,18 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from typing import TYPE_CHECKING
|
4
|
-
|
5
|
-
import nshconfig as C
|
6
|
-
from lightning.pytorch.core.mixins.hparams_mixin import (
|
7
|
-
HyperparametersMixin as _LightningHyperparametersMixin,
|
8
|
-
)
|
9
|
-
|
10
|
-
|
11
|
-
class HyperparamsMixin(_LightningHyperparametersMixin):
|
12
|
-
if not TYPE_CHECKING:
|
13
|
-
|
14
|
-
def _to_hparams_dict(self, hp):
|
15
|
-
if isinstance(hp, C.Config):
|
16
|
-
return hp.model_dump(mode="python")
|
17
|
-
|
18
|
-
return super()._set_hparams(hp)
|
File without changes
|