nshtrainer 0.44.0__py3-none-any.whl → 1.0.0b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +51 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/_base.py +2 -1
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b9.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.0.dist-info/RECORD +0 -162
- {nshtrainer-0.44.0.dist-info → nshtrainer-1.0.0b9.dist-info}/WHEEL +0 -0
@@ -19,12 +19,12 @@ else:
|
|
19
19
|
|
20
20
|
if name in globals():
|
21
21
|
return globals()[name]
|
22
|
+
if name == "DirectoryConfig":
|
23
|
+
return importlib.import_module("nshtrainer._directory").DirectoryConfig
|
22
24
|
if name == "DirectorySetupCallbackConfig":
|
23
25
|
return importlib.import_module(
|
24
26
|
"nshtrainer._directory"
|
25
27
|
).DirectorySetupCallbackConfig
|
26
|
-
if name == "DirectoryConfig":
|
27
|
-
return importlib.import_module("nshtrainer._directory").DirectoryConfig
|
28
28
|
if name == "LoggerConfig":
|
29
29
|
return importlib.import_module("nshtrainer._directory").LoggerConfig
|
30
30
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
@@ -19,14 +19,14 @@ else:
|
|
19
19
|
|
20
20
|
if name in globals():
|
21
21
|
return globals()[name]
|
22
|
+
if name == "CallbackConfigBase":
|
23
|
+
return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
|
22
24
|
if name == "HuggingFaceHubAutoCreateConfig":
|
23
25
|
return importlib.import_module(
|
24
26
|
"nshtrainer._hf_hub"
|
25
27
|
).HuggingFaceHubAutoCreateConfig
|
26
28
|
if name == "HuggingFaceHubConfig":
|
27
29
|
return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
|
28
|
-
if name == "CallbackConfigBase":
|
29
|
-
return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
|
30
30
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
31
31
|
|
32
32
|
# Submodule exports
|
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
|
|
32
32
|
from nshtrainer.callbacks import (
|
33
33
|
LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
|
34
34
|
)
|
35
|
+
from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
|
35
36
|
from nshtrainer.callbacks import (
|
36
37
|
NormLoggingCallbackConfig as NormLoggingCallbackConfig,
|
37
38
|
)
|
@@ -47,7 +48,6 @@ if TYPE_CHECKING:
|
|
47
48
|
from nshtrainer.callbacks import (
|
48
49
|
SharedParametersCallbackConfig as SharedParametersCallbackConfig,
|
49
50
|
)
|
50
|
-
from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
|
51
51
|
from nshtrainer.callbacks import (
|
52
52
|
WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
|
53
53
|
)
|
@@ -69,88 +69,88 @@ else:
|
|
69
69
|
|
70
70
|
if name in globals():
|
71
71
|
return globals()[name]
|
72
|
-
if name == "
|
72
|
+
if name == "ActSaveConfig":
|
73
|
+
return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
|
74
|
+
if name == "BaseCheckpointCallbackConfig":
|
75
|
+
return importlib.import_module(
|
76
|
+
"nshtrainer.callbacks.checkpoint._base"
|
77
|
+
).BaseCheckpointCallbackConfig
|
78
|
+
if name == "BestCheckpointCallbackConfig":
|
73
79
|
return importlib.import_module(
|
74
80
|
"nshtrainer.callbacks"
|
75
|
-
).
|
81
|
+
).BestCheckpointCallbackConfig
|
76
82
|
if name == "CallbackConfigBase":
|
77
83
|
return importlib.import_module("nshtrainer.callbacks").CallbackConfigBase
|
84
|
+
if name == "CheckpointMetadata":
|
85
|
+
return importlib.import_module(
|
86
|
+
"nshtrainer.callbacks.checkpoint._base"
|
87
|
+
).CheckpointMetadata
|
78
88
|
if name == "DebugFlagCallbackConfig":
|
79
89
|
return importlib.import_module(
|
80
90
|
"nshtrainer.callbacks"
|
81
91
|
).DebugFlagCallbackConfig
|
82
|
-
if name == "
|
92
|
+
if name == "DirectorySetupCallbackConfig":
|
83
93
|
return importlib.import_module(
|
84
94
|
"nshtrainer.callbacks"
|
85
|
-
).
|
86
|
-
if name == "
|
95
|
+
).DirectorySetupCallbackConfig
|
96
|
+
if name == "EMACallbackConfig":
|
97
|
+
return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
|
98
|
+
if name == "EarlyStoppingCallbackConfig":
|
87
99
|
return importlib.import_module(
|
88
100
|
"nshtrainer.callbacks"
|
89
|
-
).
|
90
|
-
if name == "
|
101
|
+
).EarlyStoppingCallbackConfig
|
102
|
+
if name == "EpochTimerCallbackConfig":
|
91
103
|
return importlib.import_module(
|
92
104
|
"nshtrainer.callbacks"
|
93
|
-
).
|
94
|
-
if name == "
|
105
|
+
).EpochTimerCallbackConfig
|
106
|
+
if name == "FiniteChecksCallbackConfig":
|
95
107
|
return importlib.import_module(
|
96
108
|
"nshtrainer.callbacks"
|
97
|
-
).
|
98
|
-
if name == "
|
99
|
-
return importlib.import_module(
|
100
|
-
"nshtrainer.callbacks.early_stopping"
|
101
|
-
).MetricConfig
|
102
|
-
if name == "EarlyStoppingCallbackConfig":
|
109
|
+
).FiniteChecksCallbackConfig
|
110
|
+
if name == "GradientSkippingCallbackConfig":
|
103
111
|
return importlib.import_module(
|
104
112
|
"nshtrainer.callbacks"
|
105
|
-
).
|
106
|
-
if name == "
|
113
|
+
).GradientSkippingCallbackConfig
|
114
|
+
if name == "LastCheckpointCallbackConfig":
|
107
115
|
return importlib.import_module(
|
108
116
|
"nshtrainer.callbacks"
|
109
|
-
).
|
110
|
-
if name == "
|
111
|
-
return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
|
112
|
-
if name == "DirectorySetupCallbackConfig":
|
117
|
+
).LastCheckpointCallbackConfig
|
118
|
+
if name == "LogEpochCallbackConfig":
|
113
119
|
return importlib.import_module(
|
114
120
|
"nshtrainer.callbacks"
|
115
|
-
).
|
116
|
-
if name == "
|
117
|
-
return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
|
118
|
-
if name == "FiniteChecksCallbackConfig":
|
121
|
+
).LogEpochCallbackConfig
|
122
|
+
if name == "MetricConfig":
|
119
123
|
return importlib.import_module(
|
120
|
-
"nshtrainer.callbacks"
|
121
|
-
).
|
124
|
+
"nshtrainer.callbacks.early_stopping"
|
125
|
+
).MetricConfig
|
122
126
|
if name == "NormLoggingCallbackConfig":
|
123
127
|
return importlib.import_module(
|
124
128
|
"nshtrainer.callbacks"
|
125
129
|
).NormLoggingCallbackConfig
|
126
|
-
if name == "EpochTimerCallbackConfig":
|
127
|
-
return importlib.import_module(
|
128
|
-
"nshtrainer.callbacks"
|
129
|
-
).EpochTimerCallbackConfig
|
130
130
|
if name == "OnExceptionCheckpointCallbackConfig":
|
131
131
|
return importlib.import_module(
|
132
132
|
"nshtrainer.callbacks"
|
133
133
|
).OnExceptionCheckpointCallbackConfig
|
134
|
-
if name == "
|
134
|
+
if name == "PrintTableMetricsCallbackConfig":
|
135
135
|
return importlib.import_module(
|
136
136
|
"nshtrainer.callbacks"
|
137
|
-
).
|
137
|
+
).PrintTableMetricsCallbackConfig
|
138
|
+
if name == "RLPSanityChecksCallbackConfig":
|
139
|
+
return importlib.import_module(
|
140
|
+
"nshtrainer.callbacks"
|
141
|
+
).RLPSanityChecksCallbackConfig
|
138
142
|
if name == "SharedParametersCallbackConfig":
|
139
143
|
return importlib.import_module(
|
140
144
|
"nshtrainer.callbacks"
|
141
145
|
).SharedParametersCallbackConfig
|
142
|
-
if name == "
|
146
|
+
if name == "WandbUploadCodeCallbackConfig":
|
143
147
|
return importlib.import_module(
|
144
148
|
"nshtrainer.callbacks"
|
145
|
-
).
|
146
|
-
if name == "
|
147
|
-
return importlib.import_module(
|
148
|
-
"nshtrainer.callbacks.checkpoint._base"
|
149
|
-
).CheckpointMetadata
|
150
|
-
if name == "BaseCheckpointCallbackConfig":
|
149
|
+
).WandbUploadCodeCallbackConfig
|
150
|
+
if name == "WandbWatchCallbackConfig":
|
151
151
|
return importlib.import_module(
|
152
|
-
"nshtrainer.callbacks
|
153
|
-
).
|
152
|
+
"nshtrainer.callbacks"
|
153
|
+
).WandbWatchCallbackConfig
|
154
154
|
if name == "CallbackConfig":
|
155
155
|
return importlib.import_module("nshtrainer.callbacks").CallbackConfig
|
156
156
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
@@ -166,11 +166,11 @@ from . import early_stopping as early_stopping
|
|
166
166
|
from . import ema as ema
|
167
167
|
from . import finite_checks as finite_checks
|
168
168
|
from . import gradient_skipping as gradient_skipping
|
169
|
+
from . import log_epoch as log_epoch
|
169
170
|
from . import norm_logging as norm_logging
|
170
171
|
from . import print_table as print_table
|
171
172
|
from . import rlp_sanity_checks as rlp_sanity_checks
|
172
173
|
from . import shared_parameters as shared_parameters
|
173
|
-
from . import throughput_monitor as throughput_monitor
|
174
174
|
from . import timer as timer
|
175
175
|
from . import wandb_upload_code as wandb_upload_code
|
176
176
|
from . import wandb_watch as wandb_watch
|
@@ -35,34 +35,34 @@ else:
|
|
35
35
|
|
36
36
|
if name in globals():
|
37
37
|
return globals()[name]
|
38
|
-
if name == "CheckpointMetadata":
|
39
|
-
return importlib.import_module(
|
40
|
-
"nshtrainer.callbacks.checkpoint._base"
|
41
|
-
).CheckpointMetadata
|
42
38
|
if name == "BaseCheckpointCallbackConfig":
|
43
39
|
return importlib.import_module(
|
44
40
|
"nshtrainer.callbacks.checkpoint._base"
|
45
41
|
).BaseCheckpointCallbackConfig
|
46
|
-
if name == "
|
42
|
+
if name == "BestCheckpointCallbackConfig":
|
47
43
|
return importlib.import_module(
|
48
44
|
"nshtrainer.callbacks.checkpoint"
|
49
|
-
).
|
45
|
+
).BestCheckpointCallbackConfig
|
50
46
|
if name == "CallbackConfigBase":
|
51
47
|
return importlib.import_module(
|
52
48
|
"nshtrainer.callbacks.checkpoint._base"
|
53
49
|
).CallbackConfigBase
|
54
|
-
if name == "
|
50
|
+
if name == "CheckpointMetadata":
|
55
51
|
return importlib.import_module(
|
56
|
-
"nshtrainer.callbacks.checkpoint"
|
57
|
-
).
|
58
|
-
if name == "
|
52
|
+
"nshtrainer.callbacks.checkpoint._base"
|
53
|
+
).CheckpointMetadata
|
54
|
+
if name == "LastCheckpointCallbackConfig":
|
59
55
|
return importlib.import_module(
|
60
56
|
"nshtrainer.callbacks.checkpoint"
|
61
|
-
).
|
57
|
+
).LastCheckpointCallbackConfig
|
62
58
|
if name == "MetricConfig":
|
63
59
|
return importlib.import_module(
|
64
60
|
"nshtrainer.callbacks.checkpoint.best_checkpoint"
|
65
61
|
).MetricConfig
|
62
|
+
if name == "OnExceptionCheckpointCallbackConfig":
|
63
|
+
return importlib.import_module(
|
64
|
+
"nshtrainer.callbacks.checkpoint"
|
65
|
+
).OnExceptionCheckpointCallbackConfig
|
66
66
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
67
67
|
|
68
68
|
|
@@ -23,10 +23,6 @@ else:
|
|
23
23
|
|
24
24
|
if name in globals():
|
25
25
|
return globals()[name]
|
26
|
-
if name == "CheckpointMetadata":
|
27
|
-
return importlib.import_module(
|
28
|
-
"nshtrainer.callbacks.checkpoint._base"
|
29
|
-
).CheckpointMetadata
|
30
26
|
if name == "BaseCheckpointCallbackConfig":
|
31
27
|
return importlib.import_module(
|
32
28
|
"nshtrainer.callbacks.checkpoint._base"
|
@@ -35,6 +31,10 @@ else:
|
|
35
31
|
return importlib.import_module(
|
36
32
|
"nshtrainer.callbacks.checkpoint._base"
|
37
33
|
).CallbackConfigBase
|
34
|
+
if name == "CheckpointMetadata":
|
35
|
+
return importlib.import_module(
|
36
|
+
"nshtrainer.callbacks.checkpoint._base"
|
37
|
+
).CheckpointMetadata
|
38
38
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
39
39
|
|
40
40
|
# Submodule exports
|
@@ -26,22 +26,22 @@ else:
|
|
26
26
|
|
27
27
|
if name in globals():
|
28
28
|
return globals()[name]
|
29
|
-
if name == "CheckpointMetadata":
|
30
|
-
return importlib.import_module(
|
31
|
-
"nshtrainer.callbacks.checkpoint.best_checkpoint"
|
32
|
-
).CheckpointMetadata
|
33
29
|
if name == "BaseCheckpointCallbackConfig":
|
34
30
|
return importlib.import_module(
|
35
31
|
"nshtrainer.callbacks.checkpoint.best_checkpoint"
|
36
32
|
).BaseCheckpointCallbackConfig
|
37
|
-
if name == "MetricConfig":
|
38
|
-
return importlib.import_module(
|
39
|
-
"nshtrainer.callbacks.checkpoint.best_checkpoint"
|
40
|
-
).MetricConfig
|
41
33
|
if name == "BestCheckpointCallbackConfig":
|
42
34
|
return importlib.import_module(
|
43
35
|
"nshtrainer.callbacks.checkpoint.best_checkpoint"
|
44
36
|
).BestCheckpointCallbackConfig
|
37
|
+
if name == "CheckpointMetadata":
|
38
|
+
return importlib.import_module(
|
39
|
+
"nshtrainer.callbacks.checkpoint.best_checkpoint"
|
40
|
+
).CheckpointMetadata
|
41
|
+
if name == "MetricConfig":
|
42
|
+
return importlib.import_module(
|
43
|
+
"nshtrainer.callbacks.checkpoint.best_checkpoint"
|
44
|
+
).MetricConfig
|
45
45
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
46
46
|
|
47
47
|
# Submodule exports
|
@@ -23,14 +23,14 @@ else:
|
|
23
23
|
|
24
24
|
if name in globals():
|
25
25
|
return globals()[name]
|
26
|
-
if name == "CheckpointMetadata":
|
27
|
-
return importlib.import_module(
|
28
|
-
"nshtrainer.callbacks.checkpoint.last_checkpoint"
|
29
|
-
).CheckpointMetadata
|
30
26
|
if name == "BaseCheckpointCallbackConfig":
|
31
27
|
return importlib.import_module(
|
32
28
|
"nshtrainer.callbacks.checkpoint.last_checkpoint"
|
33
29
|
).BaseCheckpointCallbackConfig
|
30
|
+
if name == "CheckpointMetadata":
|
31
|
+
return importlib.import_module(
|
32
|
+
"nshtrainer.callbacks.checkpoint.last_checkpoint"
|
33
|
+
).CheckpointMetadata
|
34
34
|
if name == "LastCheckpointCallbackConfig":
|
35
35
|
return importlib.import_module(
|
36
36
|
"nshtrainer.callbacks.checkpoint.last_checkpoint"
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "OnExceptionCheckpointCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
|
26
|
-
).OnExceptionCheckpointCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "OnExceptionCheckpointCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
|
30
|
+
).OnExceptionCheckpointCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -18,14 +18,14 @@ else:
|
|
18
18
|
|
19
19
|
if name in globals():
|
20
20
|
return globals()[name]
|
21
|
-
if name == "DebugFlagCallbackConfig":
|
22
|
-
return importlib.import_module(
|
23
|
-
"nshtrainer.callbacks.debug_flag"
|
24
|
-
).DebugFlagCallbackConfig
|
25
21
|
if name == "CallbackConfigBase":
|
26
22
|
return importlib.import_module(
|
27
23
|
"nshtrainer.callbacks.debug_flag"
|
28
24
|
).CallbackConfigBase
|
25
|
+
if name == "DebugFlagCallbackConfig":
|
26
|
+
return importlib.import_module(
|
27
|
+
"nshtrainer.callbacks.debug_flag"
|
28
|
+
).DebugFlagCallbackConfig
|
29
29
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
30
30
|
|
31
31
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "DirectorySetupCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.directory_setup"
|
26
|
-
).DirectorySetupCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.directory_setup"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "DirectorySetupCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.directory_setup"
|
30
|
+
).DirectorySetupCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -21,18 +21,18 @@ else:
|
|
21
21
|
|
22
22
|
if name in globals():
|
23
23
|
return globals()[name]
|
24
|
-
if name == "
|
24
|
+
if name == "CallbackConfigBase":
|
25
25
|
return importlib.import_module(
|
26
26
|
"nshtrainer.callbacks.early_stopping"
|
27
|
-
).
|
27
|
+
).CallbackConfigBase
|
28
28
|
if name == "EarlyStoppingCallbackConfig":
|
29
29
|
return importlib.import_module(
|
30
30
|
"nshtrainer.callbacks.early_stopping"
|
31
31
|
).EarlyStoppingCallbackConfig
|
32
|
-
if name == "
|
32
|
+
if name == "MetricConfig":
|
33
33
|
return importlib.import_module(
|
34
34
|
"nshtrainer.callbacks.early_stopping"
|
35
|
-
).
|
35
|
+
).MetricConfig
|
36
36
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
37
37
|
|
38
38
|
# Submodule exports
|
@@ -16,12 +16,12 @@ else:
|
|
16
16
|
|
17
17
|
if name in globals():
|
18
18
|
return globals()[name]
|
19
|
-
if name == "EMACallbackConfig":
|
20
|
-
return importlib.import_module("nshtrainer.callbacks.ema").EMACallbackConfig
|
21
19
|
if name == "CallbackConfigBase":
|
22
20
|
return importlib.import_module(
|
23
21
|
"nshtrainer.callbacks.ema"
|
24
22
|
).CallbackConfigBase
|
23
|
+
if name == "EMACallbackConfig":
|
24
|
+
return importlib.import_module("nshtrainer.callbacks.ema").EMACallbackConfig
|
25
25
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
26
26
|
|
27
27
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "FiniteChecksCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.finite_checks"
|
26
|
-
).FiniteChecksCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.finite_checks"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "FiniteChecksCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.finite_checks"
|
30
|
+
).FiniteChecksCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "GradientSkippingCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.gradient_skipping"
|
26
|
-
).GradientSkippingCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.gradient_skipping"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "GradientSkippingCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.gradient_skipping"
|
30
|
+
).GradientSkippingCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -7,11 +7,9 @@ from typing import TYPE_CHECKING
|
|
7
7
|
# Config/alias imports
|
8
8
|
|
9
9
|
if TYPE_CHECKING:
|
10
|
-
from nshtrainer.callbacks.
|
11
|
-
|
12
|
-
|
13
|
-
from nshtrainer.callbacks.throughput_monitor import (
|
14
|
-
ThroughputMonitorConfig as ThroughputMonitorConfig,
|
10
|
+
from nshtrainer.callbacks.log_epoch import CallbackConfigBase as CallbackConfigBase
|
11
|
+
from nshtrainer.callbacks.log_epoch import (
|
12
|
+
LogEpochCallbackConfig as LogEpochCallbackConfig,
|
15
13
|
)
|
16
14
|
else:
|
17
15
|
|
@@ -20,14 +18,14 @@ else:
|
|
20
18
|
|
21
19
|
if name in globals():
|
22
20
|
return globals()[name]
|
23
|
-
if name == "ThroughputMonitorConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.throughput_monitor"
|
26
|
-
).ThroughputMonitorConfig
|
27
21
|
if name == "CallbackConfigBase":
|
28
22
|
return importlib.import_module(
|
29
|
-
"nshtrainer.callbacks.
|
23
|
+
"nshtrainer.callbacks.log_epoch"
|
30
24
|
).CallbackConfigBase
|
25
|
+
if name == "LogEpochCallbackConfig":
|
26
|
+
return importlib.import_module(
|
27
|
+
"nshtrainer.callbacks.log_epoch"
|
28
|
+
).LogEpochCallbackConfig
|
31
29
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
30
|
|
33
31
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "NormLoggingCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.norm_logging"
|
26
|
-
).NormLoggingCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.norm_logging"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "NormLoggingCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.norm_logging"
|
30
|
+
).NormLoggingCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "PrintTableMetricsCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.print_table"
|
26
|
-
).PrintTableMetricsCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.print_table"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "PrintTableMetricsCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.print_table"
|
30
|
+
).PrintTableMetricsCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "RLPSanityChecksCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.rlp_sanity_checks"
|
26
|
-
).RLPSanityChecksCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.rlp_sanity_checks"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "RLPSanityChecksCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.rlp_sanity_checks"
|
30
|
+
).RLPSanityChecksCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "SharedParametersCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.shared_parameters"
|
26
|
-
).SharedParametersCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.shared_parameters"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "SharedParametersCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.shared_parameters"
|
30
|
+
).SharedParametersCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -18,14 +18,14 @@ else:
|
|
18
18
|
|
19
19
|
if name in globals():
|
20
20
|
return globals()[name]
|
21
|
-
if name == "EpochTimerCallbackConfig":
|
22
|
-
return importlib.import_module(
|
23
|
-
"nshtrainer.callbacks.timer"
|
24
|
-
).EpochTimerCallbackConfig
|
25
21
|
if name == "CallbackConfigBase":
|
26
22
|
return importlib.import_module(
|
27
23
|
"nshtrainer.callbacks.timer"
|
28
24
|
).CallbackConfigBase
|
25
|
+
if name == "EpochTimerCallbackConfig":
|
26
|
+
return importlib.import_module(
|
27
|
+
"nshtrainer.callbacks.timer"
|
28
|
+
).EpochTimerCallbackConfig
|
29
29
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
30
30
|
|
31
31
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "WandbUploadCodeCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.wandb_upload_code"
|
26
|
-
).WandbUploadCodeCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.wandb_upload_code"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "WandbUploadCodeCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.wandb_upload_code"
|
30
|
+
).WandbUploadCodeCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|
@@ -20,14 +20,14 @@ else:
|
|
20
20
|
|
21
21
|
if name in globals():
|
22
22
|
return globals()[name]
|
23
|
-
if name == "WandbWatchCallbackConfig":
|
24
|
-
return importlib.import_module(
|
25
|
-
"nshtrainer.callbacks.wandb_watch"
|
26
|
-
).WandbWatchCallbackConfig
|
27
23
|
if name == "CallbackConfigBase":
|
28
24
|
return importlib.import_module(
|
29
25
|
"nshtrainer.callbacks.wandb_watch"
|
30
26
|
).CallbackConfigBase
|
27
|
+
if name == "WandbWatchCallbackConfig":
|
28
|
+
return importlib.import_module(
|
29
|
+
"nshtrainer.callbacks.wandb_watch"
|
30
|
+
).WandbWatchCallbackConfig
|
31
31
|
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
32
32
|
|
33
33
|
# Submodule exports
|