nshtrainer 1.0.0b33__py3-none-any.whl → 1.0.0b36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -0
- nshtrainer/_hf_hub.py +8 -1
- nshtrainer/callbacks/__init__.py +10 -23
- nshtrainer/callbacks/actsave.py +6 -2
- nshtrainer/callbacks/base.py +3 -0
- nshtrainer/callbacks/checkpoint/__init__.py +0 -4
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +72 -2
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -2
- nshtrainer/callbacks/debug_flag.py +4 -2
- nshtrainer/callbacks/directory_setup.py +23 -21
- nshtrainer/callbacks/early_stopping.py +4 -2
- nshtrainer/callbacks/ema.py +29 -27
- nshtrainer/callbacks/finite_checks.py +21 -19
- nshtrainer/callbacks/gradient_skipping.py +29 -27
- nshtrainer/callbacks/log_epoch.py +4 -2
- nshtrainer/callbacks/lr_monitor.py +6 -1
- nshtrainer/callbacks/norm_logging.py +36 -34
- nshtrainer/callbacks/print_table.py +20 -18
- nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
- nshtrainer/callbacks/shared_parameters.py +9 -7
- nshtrainer/callbacks/timer.py +12 -10
- nshtrainer/callbacks/wandb_upload_code.py +4 -2
- nshtrainer/callbacks/wandb_watch.py +4 -2
- nshtrainer/configs/__init__.py +4 -8
- nshtrainer/configs/_hf_hub/__init__.py +2 -0
- nshtrainer/configs/callbacks/__init__.py +4 -8
- nshtrainer/configs/callbacks/actsave/__init__.py +2 -0
- nshtrainer/configs/callbacks/base/__init__.py +2 -0
- nshtrainer/configs/callbacks/checkpoint/__init__.py +4 -6
- nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -0
- nshtrainer/configs/callbacks/debug_flag/__init__.py +2 -0
- nshtrainer/configs/callbacks/directory_setup/__init__.py +2 -0
- nshtrainer/configs/callbacks/early_stopping/__init__.py +2 -0
- nshtrainer/configs/callbacks/ema/__init__.py +2 -0
- nshtrainer/configs/callbacks/finite_checks/__init__.py +2 -0
- nshtrainer/configs/callbacks/gradient_skipping/__init__.py +4 -0
- nshtrainer/configs/callbacks/log_epoch/__init__.py +2 -0
- nshtrainer/configs/callbacks/lr_monitor/__init__.py +2 -0
- nshtrainer/configs/callbacks/norm_logging/__init__.py +2 -0
- nshtrainer/configs/callbacks/print_table/__init__.py +2 -0
- nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py +4 -0
- nshtrainer/configs/callbacks/shared_parameters/__init__.py +4 -0
- nshtrainer/configs/callbacks/timer/__init__.py +2 -0
- nshtrainer/configs/callbacks/wandb_upload_code/__init__.py +4 -0
- nshtrainer/configs/callbacks/wandb_watch/__init__.py +2 -0
- nshtrainer/configs/trainer/__init__.py +2 -4
- nshtrainer/configs/trainer/_config/__init__.py +0 -8
- nshtrainer/trainer/__init__.py +3 -2
- nshtrainer/trainer/_config.py +4 -42
- {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b36.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b36.dist-info}/RECORD +55 -57
- nshtrainer/callbacks/checkpoint/time_checkpoint.py +0 -114
- nshtrainer/configs/callbacks/checkpoint/time_checkpoint/__init__.py +0 -19
- {nshtrainer-1.0.0b33.dist-info → nshtrainer-1.0.0b36.dist-info}/WHEEL +0 -0
@@ -8,8 +8,12 @@ from nshtrainer.callbacks.wandb_upload_code import (
|
|
8
8
|
from nshtrainer.callbacks.wandb_upload_code import (
|
9
9
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
10
10
|
)
|
11
|
+
from nshtrainer.callbacks.wandb_upload_code import (
|
12
|
+
callback_registry as callback_registry,
|
13
|
+
)
|
11
14
|
|
12
15
|
__all__ = [
|
13
16
|
"CallbackConfigBase",
|
14
17
|
"WandbUploadCodeCallbackConfig",
|
18
|
+
"callback_registry",
|
15
19
|
]
|
@@ -6,8 +6,10 @@ from nshtrainer.callbacks.wandb_watch import CallbackConfigBase as CallbackConfi
|
|
6
6
|
from nshtrainer.callbacks.wandb_watch import (
|
7
7
|
WandbWatchCallbackConfig as WandbWatchCallbackConfig,
|
8
8
|
)
|
9
|
+
from nshtrainer.callbacks.wandb_watch import callback_registry as callback_registry
|
9
10
|
|
10
11
|
__all__ = [
|
11
12
|
"CallbackConfigBase",
|
12
13
|
"WandbWatchCallbackConfig",
|
14
|
+
"callback_registry",
|
13
15
|
]
|
@@ -4,6 +4,7 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer.trainer import TrainerConfig as TrainerConfig
|
6
6
|
from nshtrainer.trainer import accelerator_registry as accelerator_registry
|
7
|
+
from nshtrainer.trainer import callback_registry as callback_registry
|
7
8
|
from nshtrainer.trainer import plugin_registry as plugin_registry
|
8
9
|
from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
|
9
10
|
from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
|
@@ -55,9 +56,6 @@ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
|
55
56
|
from nshtrainer.trainer._config import (
|
56
57
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
57
58
|
)
|
58
|
-
from nshtrainer.trainer._config import (
|
59
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
60
|
-
)
|
61
59
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
62
60
|
from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
|
63
61
|
from nshtrainer.trainer.accelerator import (
|
@@ -179,7 +177,6 @@ __all__ = [
|
|
179
177
|
"StrategyConfig",
|
180
178
|
"StrategyConfigBase",
|
181
179
|
"TensorboardLoggerConfig",
|
182
|
-
"TimeCheckpointCallbackConfig",
|
183
180
|
"TorchCheckpointIOPlugin",
|
184
181
|
"TorchElasticEnvironmentPlugin",
|
185
182
|
"TorchSyncBatchNormPlugin",
|
@@ -193,6 +190,7 @@ __all__ = [
|
|
193
190
|
"_config",
|
194
191
|
"accelerator",
|
195
192
|
"accelerator_registry",
|
193
|
+
"callback_registry",
|
196
194
|
"plugin",
|
197
195
|
"plugin_registry",
|
198
196
|
"strategy",
|
@@ -53,13 +53,8 @@ from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
|
|
53
53
|
from nshtrainer.trainer._config import (
|
54
54
|
TensorboardLoggerConfig as TensorboardLoggerConfig,
|
55
55
|
)
|
56
|
-
from nshtrainer.trainer._config import (
|
57
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
58
|
-
)
|
59
56
|
from nshtrainer.trainer._config import TrainerConfig as TrainerConfig
|
60
57
|
from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
|
61
|
-
from nshtrainer.trainer._config import accelerator_registry as accelerator_registry
|
62
|
-
from nshtrainer.trainer._config import plugin_registry as plugin_registry
|
63
58
|
|
64
59
|
__all__ = [
|
65
60
|
"AcceleratorConfig",
|
@@ -91,9 +86,6 @@ __all__ = [
|
|
91
86
|
"SharedParametersCallbackConfig",
|
92
87
|
"StrategyConfig",
|
93
88
|
"TensorboardLoggerConfig",
|
94
|
-
"TimeCheckpointCallbackConfig",
|
95
89
|
"TrainerConfig",
|
96
90
|
"WandbLoggerConfig",
|
97
|
-
"accelerator_registry",
|
98
|
-
"plugin_registry",
|
99
91
|
]
|
nshtrainer/trainer/__init__.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
from ..callbacks import callback_registry as callback_registry
|
3
4
|
from ._config import TrainerConfig as TrainerConfig
|
4
|
-
from .
|
5
|
-
from .
|
5
|
+
from .accelerator import accelerator_registry as accelerator_registry
|
6
|
+
from .plugin import plugin_registry as plugin_registry
|
6
7
|
from .trainer import Trainer as Trainer
|
nshtrainer/trainer/_config.py
CHANGED
@@ -37,7 +37,6 @@ from ..callbacks import (
|
|
37
37
|
OnExceptionCheckpointCallbackConfig,
|
38
38
|
)
|
39
39
|
from ..callbacks.base import CallbackConfigBase
|
40
|
-
from ..callbacks.checkpoint.time_checkpoint import TimeCheckpointCallbackConfig
|
41
40
|
from ..callbacks.debug_flag import DebugFlagCallbackConfig
|
42
41
|
from ..callbacks.log_epoch import LogEpochCallbackConfig
|
43
42
|
from ..callbacks.lr_monitor import LearningRateMonitorConfig
|
@@ -54,9 +53,9 @@ from ..loggers.actsave import ActSaveLoggerConfig
|
|
54
53
|
from ..metrics._config import MetricConfig
|
55
54
|
from ..profiler import ProfilerConfig
|
56
55
|
from ..util._environment_info import EnvironmentConfig
|
57
|
-
from .accelerator import AcceleratorConfig, AcceleratorLiteral
|
58
|
-
from .plugin import PluginConfig
|
59
|
-
from .strategy import StrategyConfig
|
56
|
+
from .accelerator import AcceleratorConfig, AcceleratorLiteral
|
57
|
+
from .plugin import PluginConfig
|
58
|
+
from .strategy import StrategyConfig, StrategyLiteral
|
60
59
|
|
61
60
|
log = logging.getLogger(__name__)
|
62
61
|
|
@@ -70,46 +69,12 @@ class GradientClippingConfig(C.Config):
|
|
70
69
|
"""Norm type to use for gradient clipping."""
|
71
70
|
|
72
71
|
|
73
|
-
StrategyLiteral = TypeAliasType(
|
74
|
-
"StrategyLiteral",
|
75
|
-
Literal[
|
76
|
-
"auto",
|
77
|
-
"ddp",
|
78
|
-
"ddp_find_unused_parameters_false",
|
79
|
-
"ddp_find_unused_parameters_true",
|
80
|
-
"ddp_spawn",
|
81
|
-
"ddp_spawn_find_unused_parameters_false",
|
82
|
-
"ddp_spawn_find_unused_parameters_true",
|
83
|
-
"ddp_fork",
|
84
|
-
"ddp_fork_find_unused_parameters_false",
|
85
|
-
"ddp_fork_find_unused_parameters_true",
|
86
|
-
"ddp_notebook",
|
87
|
-
"dp",
|
88
|
-
"deepspeed",
|
89
|
-
"deepspeed_stage_1",
|
90
|
-
"deepspeed_stage_1_offload",
|
91
|
-
"deepspeed_stage_2",
|
92
|
-
"deepspeed_stage_2_offload",
|
93
|
-
"deepspeed_stage_3",
|
94
|
-
"deepspeed_stage_3_offload",
|
95
|
-
"deepspeed_stage_3_offload_nvme",
|
96
|
-
"fsdp",
|
97
|
-
"fsdp_cpu_offload",
|
98
|
-
"single_xla",
|
99
|
-
"xla_fsdp",
|
100
|
-
"xla",
|
101
|
-
"single_tpu",
|
102
|
-
],
|
103
|
-
)
|
104
|
-
|
105
|
-
|
106
72
|
CheckpointCallbackConfig = TypeAliasType(
|
107
73
|
"CheckpointCallbackConfig",
|
108
74
|
Annotated[
|
109
75
|
BestCheckpointCallbackConfig
|
110
76
|
| LastCheckpointCallbackConfig
|
111
|
-
| OnExceptionCheckpointCallbackConfig
|
112
|
-
| TimeCheckpointCallbackConfig,
|
77
|
+
| OnExceptionCheckpointCallbackConfig,
|
113
78
|
C.Field(discriminator="name"),
|
114
79
|
],
|
115
80
|
)
|
@@ -123,7 +88,6 @@ class CheckpointSavingConfig(CallbackConfigBase):
|
|
123
88
|
BestCheckpointCallbackConfig(throw_on_no_metric=False),
|
124
89
|
LastCheckpointCallbackConfig(),
|
125
90
|
OnExceptionCheckpointCallbackConfig(),
|
126
|
-
TimeCheckpointCallbackConfig(interval=timedelta(hours=12)),
|
127
91
|
]
|
128
92
|
"""Checkpoint callback configurations."""
|
129
93
|
|
@@ -397,8 +361,6 @@ class SanityCheckingConfig(C.Config):
|
|
397
361
|
"""
|
398
362
|
|
399
363
|
|
400
|
-
@plugin_registry.rebuild_on_registers
|
401
|
-
@accelerator_registry.rebuild_on_registers
|
402
364
|
class TrainerConfig(C.Config):
|
403
365
|
# region Active Run Configuration
|
404
366
|
id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
|
@@ -1,65 +1,63 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
|
-
nshtrainer/__init__.py,sha256=
|
2
|
+
nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
|
3
3
|
nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
|
6
6
|
nshtrainer/_directory.py,sha256=p2uk1FnISFEpMqlDevKhoWhQsCEtvHUPg459K-86QA8,3053
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
|
-
nshtrainer/_hf_hub.py,sha256=
|
9
|
-
nshtrainer/callbacks/__init__.py,sha256=
|
10
|
-
nshtrainer/callbacks/actsave.py,sha256=
|
11
|
-
nshtrainer/callbacks/base.py,sha256=
|
12
|
-
nshtrainer/callbacks/checkpoint/__init__.py,sha256=
|
8
|
+
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
9
|
+
nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
|
10
|
+
nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
|
11
|
+
nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,3683
|
12
|
+
nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
|
13
13
|
nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
|
14
|
-
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=
|
15
|
-
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=
|
16
|
-
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=
|
17
|
-
nshtrainer/callbacks/
|
18
|
-
nshtrainer/callbacks/
|
19
|
-
nshtrainer/callbacks/
|
20
|
-
nshtrainer/callbacks/
|
21
|
-
nshtrainer/callbacks/
|
22
|
-
nshtrainer/callbacks/
|
23
|
-
nshtrainer/callbacks/gradient_skipping.py,sha256=k5qNaNeileZ_5YFad4ssfLplMxMKeKFhPcY8-QVmLek,3464
|
14
|
+
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
|
15
|
+
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=MJcNB0biOebx2si2IBFaSUiVOSLSCZTzxB-RcEgO2gY,3482
|
16
|
+
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
|
17
|
+
nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
|
18
|
+
nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
|
19
|
+
nshtrainer/callbacks/early_stopping.py,sha256=EjzN-gD_Xd4YHZLkXsbi00g_4ti3RTMJEdHJ8GMeaFM,4776
|
20
|
+
nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
|
21
|
+
nshtrainer/callbacks/finite_checks.py,sha256=3lZ3kEIjmYQfqTF0DcrgZ9_98ZLQhQj8usH7SgWst3o,2185
|
22
|
+
nshtrainer/callbacks/gradient_skipping.py,sha256=8g7oC7PF0LTAEzwiNoaS5tWOnkjk_EB0QG3JdHkQ8ek,3523
|
24
23
|
nshtrainer/callbacks/interval.py,sha256=UCzUzt3XCFVyQyCWL9lOrStkkxesvduNOYk8yMrGTTk,8116
|
25
|
-
nshtrainer/callbacks/log_epoch.py,sha256=
|
26
|
-
nshtrainer/callbacks/lr_monitor.py,sha256=
|
27
|
-
nshtrainer/callbacks/norm_logging.py,sha256=
|
28
|
-
nshtrainer/callbacks/print_table.py,sha256=
|
29
|
-
nshtrainer/callbacks/rlp_sanity_checks.py,sha256=
|
30
|
-
nshtrainer/callbacks/shared_parameters.py,sha256=
|
31
|
-
nshtrainer/callbacks/timer.py,sha256=
|
32
|
-
nshtrainer/callbacks/wandb_upload_code.py,sha256=
|
33
|
-
nshtrainer/callbacks/wandb_watch.py,sha256=
|
34
|
-
nshtrainer/configs/__init__.py,sha256=
|
24
|
+
nshtrainer/callbacks/log_epoch.py,sha256=B5Dm8XVZwCzKUhUWfT_5PDdDac993191OsbcxxuSVJE,1457
|
25
|
+
nshtrainer/callbacks/lr_monitor.py,sha256=qy_C0R40J0hBAukzBwng5FI2jJUpWuXOi5N6FU6ym3I,1210
|
26
|
+
nshtrainer/callbacks/norm_logging.py,sha256=nVIDWe-ASl5zN830-ODR8QMCqI1ma-QPCIwoy0Wb-Nk,6390
|
27
|
+
nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0zyal_LA,3035
|
28
|
+
nshtrainer/callbacks/rlp_sanity_checks.py,sha256=74BZvV2HLO__ucQXsLXb8eJLUZgRFUNJZ6TL9efMp74,10051
|
29
|
+
nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
|
30
|
+
nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
|
31
|
+
nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
|
32
|
+
nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
|
33
|
+
nshtrainer/configs/__init__.py,sha256=OevZEZxb4H8imadSQXK9huqdYUF4SrJPfNU_2fpMBvI,14084
|
35
34
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
36
35
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
37
36
|
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
38
|
-
nshtrainer/configs/_hf_hub/__init__.py,sha256=
|
39
|
-
nshtrainer/configs/callbacks/__init__.py,sha256=
|
40
|
-
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=
|
41
|
-
nshtrainer/configs/callbacks/base/__init__.py,sha256=
|
42
|
-
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=
|
37
|
+
nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
|
38
|
+
nshtrainer/configs/callbacks/__init__.py,sha256=jSWkbsdiu9vdGWTzqkDf-Bo9dXr9RengeNZLzWUhi7Y,4283
|
39
|
+
nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
|
40
|
+
nshtrainer/configs/callbacks/base/__init__.py,sha256=wT3RhXttLyf6RFWCIvsoiXcPdfGx5W309WBI18AI5os,278
|
41
|
+
nshtrainer/configs/callbacks/checkpoint/__init__.py,sha256=aGJ7vX14YamkMdwYAdPv6XrRnP0aZd5uZ5X0nSLc6IU,1475
|
43
42
|
nshtrainer/configs/callbacks/checkpoint/_base/__init__.py,sha256=5jl6A5Gv6arZXmHV6lz5dQ8DL6PdJIfJqHLP4acClKQ,479
|
44
|
-
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py,sha256=
|
45
|
-
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=
|
46
|
-
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=
|
47
|
-
nshtrainer/configs/callbacks/
|
48
|
-
nshtrainer/configs/callbacks/
|
49
|
-
nshtrainer/configs/callbacks/
|
50
|
-
nshtrainer/configs/callbacks/
|
51
|
-
nshtrainer/configs/callbacks/
|
52
|
-
nshtrainer/configs/callbacks/
|
53
|
-
nshtrainer/configs/callbacks/
|
54
|
-
nshtrainer/configs/callbacks/
|
55
|
-
nshtrainer/configs/callbacks/
|
56
|
-
nshtrainer/configs/callbacks/
|
57
|
-
nshtrainer/configs/callbacks/
|
58
|
-
nshtrainer/configs/callbacks/
|
59
|
-
nshtrainer/configs/callbacks/
|
60
|
-
nshtrainer/configs/callbacks/
|
61
|
-
nshtrainer/configs/callbacks/
|
62
|
-
nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=MW-ANrF529DxBhopovPjYEQ7nANX9ttd1K4_bJnKXks,322
|
43
|
+
nshtrainer/configs/callbacks/checkpoint/best_checkpoint/__init__.py,sha256=hL4WGBdo5_gtQuEGcRa3cWYMOSFdlNzkW-2Y3X3ZGTI,781
|
44
|
+
nshtrainer/configs/callbacks/checkpoint/last_checkpoint/__init__.py,sha256=SIRfz5QP30K4zzKw_1LZSSFr-3x-S3vc0vWL4ndyvjc,672
|
45
|
+
nshtrainer/configs/callbacks/checkpoint/on_exception_checkpoint/__init__.py,sha256=VSkO0TYCAYy_9mQuOBoAND7D3Cg6w6nMCpqivQZLPcE,551
|
46
|
+
nshtrainer/configs/callbacks/debug_flag/__init__.py,sha256=s_ifB-DbZjar0w11pr2oVAlcMTWWMnK_tCNilfswL04,425
|
47
|
+
nshtrainer/configs/callbacks/directory_setup/__init__.py,sha256=e8GCRy2Alds3AXLwp4ieSGtn8S0YjmKJ5khOaQ0zKGs,464
|
48
|
+
nshtrainer/configs/callbacks/early_stopping/__init__.py,sha256=m8N6H11PjqcWqXP5ZxWC8L4PHMUI6avYyN5rUNprjuQ,546
|
49
|
+
nshtrainer/configs/callbacks/ema/__init__.py,sha256=DUJrbDD8wWX_s0_4dwKpT_IWKSVpBmhe4-1aELq7G6w,377
|
50
|
+
nshtrainer/configs/callbacks/finite_checks/__init__.py,sha256=e-vx9Kn-noqw4wPvZw7fDMfb9Tsa6Duk0TIa8ZIgIIE,443
|
51
|
+
nshtrainer/configs/callbacks/gradient_skipping/__init__.py,sha256=T3eVxxJfnYBrO9WfLiycn4TyWP4vaqJ57yp7Epkg7B4,485
|
52
|
+
nshtrainer/configs/callbacks/log_epoch/__init__.py,sha256=IQ5owYYvyk7fiQP1QXYtncRRJrESuq3rRFhab-II2uE,419
|
53
|
+
nshtrainer/configs/callbacks/lr_monitor/__init__.py,sha256=qejy1AnXNDHmsFuXRAXQQ5B0TcbKzvpaw-I4dv2AXIs,431
|
54
|
+
nshtrainer/configs/callbacks/norm_logging/__init__.py,sha256=j2LrnYEbDGLwJR2lk-jmh-4J_iLEs2HNEoepvJSFLAg,437
|
55
|
+
nshtrainer/configs/callbacks/print_table/__init__.py,sha256=t6fA_dBkUCszUXDJKEdnlBH4oEpfAQqcmAlatTFYIyQ,452
|
56
|
+
nshtrainer/configs/callbacks/rlp_sanity_checks/__init__.py,sha256=dlP14Wh-w8zG_B4EtNmCIFzVMhf6bXCJ1O9cJWmEFnA,482
|
57
|
+
nshtrainer/configs/callbacks/shared_parameters/__init__.py,sha256=AU7_bSnSRSlj16UTaFBe9QVUf0T0zKUmKCOBSL4xYmg,485
|
58
|
+
nshtrainer/configs/callbacks/timer/__init__.py,sha256=cOUtbsl0_OhCO0fIcBfLuIF6FEGBHQu7AvQFzwVznWQ,413
|
59
|
+
nshtrainer/configs/callbacks/wandb_upload_code/__init__.py,sha256=CJeCc9OCu5F39lWiY5aIc4WxQlgBvB-8cga6cQtw0GQ,482
|
60
|
+
nshtrainer/configs/callbacks/wandb_watch/__init__.py,sha256=dzz1oavL1BwELE33xus45_avBEAZDeB6xtcb6CsOEos,431
|
63
61
|
nshtrainer/configs/loggers/__init__.py,sha256=5wTekL79mQxit8f1K3AMllvb0mKertTzOKfC3gpE2Zk,1251
|
64
62
|
nshtrainer/configs/loggers/_base/__init__.py,sha256=HxPPPePsEjlNuhnjsMgYIl0rwj_iqNKKOBTEk_zIOsM,169
|
65
63
|
nshtrainer/configs/loggers/actsave/__init__.py,sha256=2lZQ4bpbjwd4MuUE_Z_PGbmQjjGtWCZUCtXqKO4dTSc,280
|
@@ -81,8 +79,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
|
|
81
79
|
nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
|
82
80
|
nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
|
83
81
|
nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
|
84
|
-
nshtrainer/configs/trainer/__init__.py,sha256=
|
85
|
-
nshtrainer/configs/trainer/_config/__init__.py,sha256=
|
82
|
+
nshtrainer/configs/trainer/__init__.py,sha256=8Z4E1IeJHtDW8fpDxJkiC9CgDqKrTBIR5VMK1q4DYy4,7729
|
83
|
+
nshtrainer/configs/trainer/_config/__init__.py,sha256=t72kmUn60UtjpD6H38XzKbEs50gU2dS1IH0u-RnHZ04,3666
|
86
84
|
nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
|
87
85
|
nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
|
88
86
|
nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=GuubcKrbXt4VjJRT8VpNUQqBtyuutre_CfkS0EWZ5_E,368
|
@@ -129,8 +127,8 @@ nshtrainer/profiler/_base.py,sha256=kFcSVn9gJuMwgDxbfyHh46CmEAIPZjxw3yjPbKgzvwA,
|
|
129
127
|
nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5NOQ,1160
|
130
128
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
131
129
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
132
|
-
nshtrainer/trainer/__init__.py,sha256=
|
133
|
-
nshtrainer/trainer/_config.py,sha256=
|
130
|
+
nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
|
131
|
+
nshtrainer/trainer/_config.py,sha256=SPg3WXjF3ufhnr27sTHQLq23hdebnW6CTWa8AJkRG0A,32982
|
134
132
|
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
135
133
|
nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
|
136
134
|
nshtrainer/trainer/plugin/__init__.py,sha256=UM8f70Ml3RGpsXeQr1Yh1yBcxxjFNGFGgJbniCn_rws,366
|
@@ -153,6 +151,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
153
151
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
154
152
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
155
153
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
156
|
-
nshtrainer-1.0.
|
157
|
-
nshtrainer-1.0.
|
158
|
-
nshtrainer-1.0.
|
154
|
+
nshtrainer-1.0.0b36.dist-info/METADATA,sha256=R9O2SnflaNiDkxtoOPD_YFCXIgnEl8YjkhbEU5CbWHQ,988
|
155
|
+
nshtrainer-1.0.0b36.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
156
|
+
nshtrainer-1.0.0b36.dist-info/RECORD,,
|
@@ -1,114 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
import logging
|
4
|
-
import time
|
5
|
-
from datetime import timedelta
|
6
|
-
from pathlib import Path
|
7
|
-
from typing import Any, Literal
|
8
|
-
|
9
|
-
from lightning.pytorch import LightningModule, Trainer
|
10
|
-
from typing_extensions import final, override
|
11
|
-
|
12
|
-
from nshtrainer._checkpoint.metadata import CheckpointMetadata
|
13
|
-
|
14
|
-
from ._base import BaseCheckpointCallbackConfig, CheckpointBase
|
15
|
-
|
16
|
-
log = logging.getLogger(__name__)
|
17
|
-
|
18
|
-
|
19
|
-
@final
|
20
|
-
class TimeCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
21
|
-
name: Literal["time_checkpoint"] = "time_checkpoint"
|
22
|
-
|
23
|
-
interval: timedelta = timedelta(hours=12)
|
24
|
-
"""Time interval between checkpoints."""
|
25
|
-
|
26
|
-
@override
|
27
|
-
def create_checkpoint(self, trainer_config, dirpath):
|
28
|
-
return TimeCheckpointCallback(self, dirpath)
|
29
|
-
|
30
|
-
|
31
|
-
@final
|
32
|
-
class TimeCheckpointCallback(CheckpointBase[TimeCheckpointCallbackConfig]):
|
33
|
-
def __init__(self, config: TimeCheckpointCallbackConfig, dirpath: Path):
|
34
|
-
super().__init__(config, dirpath)
|
35
|
-
self.start_time = time.time()
|
36
|
-
self.last_checkpoint_time = self.start_time
|
37
|
-
self.interval_seconds = config.interval.total_seconds()
|
38
|
-
|
39
|
-
@override
|
40
|
-
def name(self):
|
41
|
-
return "time"
|
42
|
-
|
43
|
-
@override
|
44
|
-
def default_filename(self):
|
45
|
-
return "epoch{epoch}-step{step}-duration{train_duration}"
|
46
|
-
|
47
|
-
@override
|
48
|
-
def topk_sort_key(self, metadata: CheckpointMetadata):
|
49
|
-
return metadata.checkpoint_timestamp
|
50
|
-
|
51
|
-
@override
|
52
|
-
def topk_sort_reverse(self):
|
53
|
-
return True
|
54
|
-
|
55
|
-
def _should_checkpoint(self) -> bool:
|
56
|
-
current_time = time.time()
|
57
|
-
elapsed_time = current_time - self.last_checkpoint_time
|
58
|
-
return elapsed_time >= self.interval_seconds
|
59
|
-
|
60
|
-
def _format_duration(self, seconds: float) -> str:
|
61
|
-
"""Format duration in seconds to a human-readable string."""
|
62
|
-
td = timedelta(seconds=int(seconds))
|
63
|
-
days = td.days
|
64
|
-
hours, remainder = divmod(td.seconds, 3600)
|
65
|
-
minutes, seconds = divmod(remainder, 60)
|
66
|
-
|
67
|
-
parts = []
|
68
|
-
if days > 0:
|
69
|
-
parts.append(f"{days}d")
|
70
|
-
if hours > 0:
|
71
|
-
parts.append(f"{hours}h")
|
72
|
-
if minutes > 0:
|
73
|
-
parts.append(f"{minutes}m")
|
74
|
-
if seconds > 0 or not parts:
|
75
|
-
parts.append(f"{seconds}s")
|
76
|
-
|
77
|
-
return "_".join(parts)
|
78
|
-
|
79
|
-
@override
|
80
|
-
def current_metrics(self, trainer: Trainer) -> dict[str, Any]:
|
81
|
-
metrics = super().current_metrics(trainer)
|
82
|
-
train_duration = time.time() - self.start_time
|
83
|
-
metrics["train_duration"] = self._format_duration(train_duration)
|
84
|
-
return metrics
|
85
|
-
|
86
|
-
@override
|
87
|
-
def on_train_batch_end(
|
88
|
-
self, trainer: Trainer, pl_module: LightningModule, *args, **kwargs
|
89
|
-
):
|
90
|
-
if self._should_checkpoint():
|
91
|
-
self.save_checkpoints(trainer)
|
92
|
-
self.last_checkpoint_time = time.time()
|
93
|
-
|
94
|
-
@override
|
95
|
-
def state_dict(self) -> dict[str, Any]:
|
96
|
-
"""Save the timer state for checkpoint resumption.
|
97
|
-
|
98
|
-
Returns:
|
99
|
-
Dictionary containing the start time and last checkpoint time.
|
100
|
-
"""
|
101
|
-
return {
|
102
|
-
"start_time": self.start_time,
|
103
|
-
"last_checkpoint_time": self.last_checkpoint_time,
|
104
|
-
}
|
105
|
-
|
106
|
-
@override
|
107
|
-
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
108
|
-
"""Restore the timer state when resuming from a checkpoint.
|
109
|
-
|
110
|
-
Args:
|
111
|
-
state_dict: Dictionary containing the previously saved timer state.
|
112
|
-
"""
|
113
|
-
self.start_time = state_dict["start_time"]
|
114
|
-
self.last_checkpoint_time = state_dict["last_checkpoint_time"]
|
@@ -1,19 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__codegen__ = True
|
4
|
-
|
5
|
-
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
6
|
-
BaseCheckpointCallbackConfig as BaseCheckpointCallbackConfig,
|
7
|
-
)
|
8
|
-
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
9
|
-
CheckpointMetadata as CheckpointMetadata,
|
10
|
-
)
|
11
|
-
from nshtrainer.callbacks.checkpoint.time_checkpoint import (
|
12
|
-
TimeCheckpointCallbackConfig as TimeCheckpointCallbackConfig,
|
13
|
-
)
|
14
|
-
|
15
|
-
__all__ = [
|
16
|
-
"BaseCheckpointCallbackConfig",
|
17
|
-
"CheckpointMetadata",
|
18
|
-
"TimeCheckpointCallbackConfig",
|
19
|
-
]
|
File without changes
|