nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +6 -3
- nshtrainer/_callback.py +297 -2
- nshtrainer/_checkpoint/loader.py +23 -30
- nshtrainer/_checkpoint/metadata.py +22 -18
- nshtrainer/_experimental/__init__.py +0 -2
- nshtrainer/_hf_hub.py +25 -26
- nshtrainer/callbacks/__init__.py +1 -3
- nshtrainer/callbacks/actsave.py +22 -20
- nshtrainer/callbacks/base.py +7 -7
- nshtrainer/callbacks/checkpoint/__init__.py +1 -1
- nshtrainer/callbacks/checkpoint/_base.py +8 -5
- nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
- nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
- nshtrainer/callbacks/debug_flag.py +14 -19
- nshtrainer/callbacks/directory_setup.py +6 -11
- nshtrainer/callbacks/early_stopping.py +3 -3
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/log_epoch.py +1 -1
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
- nshtrainer/callbacks/shared_parameters.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_upload_code.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/config/__init__.py +189 -189
- nshtrainer/config/_checkpoint/__init__.py +70 -0
- nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
- nshtrainer/config/_directory/__init__.py +2 -2
- nshtrainer/config/_hf_hub/__init__.py +2 -2
- nshtrainer/config/callbacks/__init__.py +44 -44
- nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
- nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
- nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
- nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
- nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
- nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
- nshtrainer/config/callbacks/ema/__init__.py +2 -2
- nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
- nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
- nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
- nshtrainer/config/callbacks/print_table/__init__.py +4 -4
- nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
- nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
- nshtrainer/config/callbacks/timer/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
- nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
- nshtrainer/config/loggers/__init__.py +10 -6
- nshtrainer/config/loggers/actsave/__init__.py +29 -0
- nshtrainer/config/loggers/csv/__init__.py +2 -2
- nshtrainer/config/loggers/wandb/__init__.py +6 -6
- nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
- nshtrainer/config/nn/__init__.py +18 -18
- nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
- nshtrainer/config/optimizer/__init__.py +2 -2
- nshtrainer/config/profiler/__init__.py +2 -2
- nshtrainer/config/profiler/pytorch/__init__.py +4 -4
- nshtrainer/config/profiler/simple/__init__.py +4 -4
- nshtrainer/config/trainer/__init__.py +180 -0
- nshtrainer/config/trainer/_config/__init__.py +59 -36
- nshtrainer/config/trainer/trainer/__init__.py +27 -0
- nshtrainer/config/util/__init__.py +109 -0
- nshtrainer/config/util/_environment_info/__init__.py +20 -20
- nshtrainer/config/util/config/__init__.py +2 -2
- nshtrainer/data/datamodule.py +52 -2
- nshtrainer/loggers/__init__.py +2 -1
- nshtrainer/loggers/_base.py +5 -2
- nshtrainer/loggers/actsave.py +59 -0
- nshtrainer/loggers/csv.py +5 -5
- nshtrainer/loggers/tensorboard.py +5 -5
- nshtrainer/loggers/wandb.py +17 -16
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +64 -347
- nshtrainer/model/mixins/callback.py +24 -5
- nshtrainer/model/mixins/debug.py +86 -0
- nshtrainer/model/mixins/logger.py +142 -145
- nshtrainer/profiler/_base.py +2 -2
- nshtrainer/profiler/advanced.py +4 -4
- nshtrainer/profiler/pytorch.py +4 -4
- nshtrainer/profiler/simple.py +4 -4
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/_config.py +164 -17
- nshtrainer/trainer/checkpoint_connector.py +23 -8
- nshtrainer/trainer/trainer.py +194 -76
- nshtrainer/util/_environment_info.py +21 -13
- nshtrainer/util/config/dtype.py +4 -4
- nshtrainer/util/typing_utils.py +1 -1
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
- nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
- nshtrainer/callbacks/throughput_monitor.py +0 -58
- nshtrainer/config/model/__init__.py +0 -41
- nshtrainer/config/model/base/__init__.py +0 -25
- nshtrainer/config/model/config/__init__.py +0 -37
- nshtrainer/config/model/mixins/logger/__init__.py +0 -22
- nshtrainer/config/runner/__init__.py +0 -22
- nshtrainer/ll/__init__.py +0 -59
- nshtrainer/ll/_experimental.py +0 -3
- nshtrainer/ll/actsave.py +0 -6
- nshtrainer/ll/callbacks.py +0 -3
- nshtrainer/ll/config.py +0 -6
- nshtrainer/ll/data.py +0 -3
- nshtrainer/ll/log.py +0 -5
- nshtrainer/ll/lr_scheduler.py +0 -3
- nshtrainer/ll/model.py +0 -21
- nshtrainer/ll/nn.py +0 -3
- nshtrainer/ll/optimizer.py +0 -3
- nshtrainer/ll/runner.py +0 -5
- nshtrainer/ll/snapshot.py +0 -3
- nshtrainer/ll/snoop.py +0 -3
- nshtrainer/ll/trainer.py +0 -3
- nshtrainer/ll/typecheck.py +0 -3
- nshtrainer/ll/util.py +0 -3
- nshtrainer/model/config.py +0 -218
- nshtrainer/runner.py +0 -101
- nshtrainer-0.44.1.dist-info/RECORD +0 -162
- {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
 # Config/alias imports

 if TYPE_CHECKING:
+    from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
     from nshtrainer.trainer._config import (
         BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
     )
@@ -25,9 +26,11 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         DebugFlagCallbackConfig as DebugFlagCallbackConfig,
     )
+    from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
     from nshtrainer.trainer._config import (
         EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
     )
+    from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
     from nshtrainer.trainer._config import (
         GradientClippingConfig as GradientClippingConfig,
     )
@@ -35,8 +38,12 @@ if TYPE_CHECKING:
     from nshtrainer.trainer._config import (
         LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
     )
+    from nshtrainer.trainer._config import (
+        LogEpochCallbackConfig as LogEpochCallbackConfig,
+    )
     from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
     from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
+    from nshtrainer.trainer._config import MetricConfig as MetricConfig
     from nshtrainer.trainer._config import (
         OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
     )
@@ -64,80 +71,96 @@ else:

         if name in globals():
             return globals()[name]
-        if name == "
+        if name == "ActSaveLoggerConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
-            return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
-        if name == "OnExceptionCheckpointCallbackConfig":
+            ).ActSaveLoggerConfig
+        if name == "BestCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).BestCheckpointCallbackConfig
+        if name == "CSVLoggerConfig":
+            return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
+        if name == "CallbackConfigBase":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).CallbackConfigBase
+        if name == "CheckpointLoadingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
-            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
-        if name == "TensorboardLoggerConfig":
+            ).CheckpointLoadingConfig
+        if name == "CheckpointSavingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).CheckpointSavingConfig
+        if name == "DebugFlagCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).DebugFlagCallbackConfig
+        if name == "DirectoryConfig":
+            return importlib.import_module("nshtrainer.trainer._config").DirectoryConfig
+        if name == "EarlyStoppingCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
-            return importlib.import_module(
+            ).EarlyStoppingCallbackConfig
+        if name == "EnvironmentConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).EnvironmentConfig
+        if name == "GradientClippingConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).GradientClippingConfig
         if name == "HuggingFaceHubConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).HuggingFaceHubConfig
-        if name == "
+        if name == "LastCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).LastCheckpointCallbackConfig
+        if name == "LogEpochCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).LogEpochCallbackConfig
+        if name == "LoggingConfig":
+            return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
+        if name == "MetricConfig":
+            return importlib.import_module("nshtrainer.trainer._config").MetricConfig
+        if name == "OnExceptionCheckpointCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).OnExceptionCheckpointCallbackConfig
+        if name == "OptimizationConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).OptimizationConfig
+        if name == "RLPSanityChecksCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).RLPSanityChecksCallbackConfig
         if name == "ReproducibilityConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
             ).ReproducibilityConfig
-        if name == "
+        if name == "SanityCheckingConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).SanityCheckingConfig
+        if name == "SharedParametersCallbackConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
-        if name == "
+            ).SharedParametersCallbackConfig
+        if name == "TensorboardLoggerConfig":
             return importlib.import_module(
                 "nshtrainer.trainer._config"
-            ).
+            ).TensorboardLoggerConfig
+        if name == "TrainerConfig":
+            return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
+        if name == "WandbLoggerConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer._config"
+            ).WandbLoggerConfig
         if name == "CallbackConfig":
             return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
         if name == "CheckpointCallbackConfig":
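The re-export hunks above, and the generated `nshtrainer/config/...` modules that follow, all use the same PEP 562 lazy-import pattern: names are visible to type checkers under `TYPE_CHECKING`, and a module-level `__getattr__` imports the real symbol on first access. A minimal, self-contained sketch of the idea (the module and class names here are illustrative, not part of nshtrainer):

```python
# aliases.py -- illustrative sketch of the lazy re-export pattern (PEP 562)
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type checkers resolve the alias without paying the import cost at runtime.
    from heavy_module import HeavyConfig as HeavyConfig
else:

    def __getattr__(name):
        # Called only when `name` is not already defined on this module.
        import importlib

        if name == "HeavyConfig":
            return importlib.import_module("heavy_module").HeavyConfig
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
```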

nshtrainer/config/trainer/trainer/__init__.py
ADDED
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
+    from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "EnvironmentConfig":
+            return importlib.import_module(
+                "nshtrainer.trainer.trainer"
+            ).EnvironmentConfig
+        if name == "TrainerConfig":
+            return importlib.import_module("nshtrainer.trainer.trainer").TrainerConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+# Submodule exports

nshtrainer/config/util/__init__.py
ADDED
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+__codegen__ = True
+
+from typing import TYPE_CHECKING
+
+# Config/alias imports
+
+if TYPE_CHECKING:
+    from nshtrainer.util._environment_info import (
+        EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
+    )
+    from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+    from nshtrainer.util._environment_info import (
+        EnvironmentCUDAConfig as EnvironmentCUDAConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        EnvironmentGPUConfig as EnvironmentGPUConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        EnvironmentHardwareConfig as EnvironmentHardwareConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        EnvironmentPackageConfig as EnvironmentPackageConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
+    )
+    from nshtrainer.util._environment_info import (
+        GitRepositoryConfig as GitRepositoryConfig,
+    )
+    from nshtrainer.util.config import DTypeConfig as DTypeConfig
+    from nshtrainer.util.config import DurationConfig as DurationConfig
+    from nshtrainer.util.config import EpochsConfig as EpochsConfig
+    from nshtrainer.util.config import StepsConfig as StepsConfig
+else:
+
+    def __getattr__(name):
+        import importlib
+
+        if name in globals():
+            return globals()[name]
+        if name == "DTypeConfig":
+            return importlib.import_module("nshtrainer.util.config").DTypeConfig
+        if name == "EnvironmentCUDAConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentCUDAConfig
+        if name == "EnvironmentClassInformationConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentClassInformationConfig
+        if name == "EnvironmentConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentConfig
+        if name == "EnvironmentGPUConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentGPUConfig
+        if name == "EnvironmentHardwareConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentHardwareConfig
+        if name == "EnvironmentLSFInformationConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentLSFInformationConfig
+        if name == "EnvironmentLinuxEnvironmentConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentLinuxEnvironmentConfig
+        if name == "EnvironmentPackageConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentPackageConfig
+        if name == "EnvironmentSLURMInformationConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentSLURMInformationConfig
+        if name == "EnvironmentSnapshotConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).EnvironmentSnapshotConfig
+        if name == "EpochsConfig":
+            return importlib.import_module("nshtrainer.util.config").EpochsConfig
+        if name == "GitRepositoryConfig":
+            return importlib.import_module(
+                "nshtrainer.util._environment_info"
+            ).GitRepositoryConfig
+        if name == "StepsConfig":
+            return importlib.import_module("nshtrainer.util.config").StepsConfig
+        if name == "DurationConfig":
+            return importlib.import_module("nshtrainer.util.config").DurationConfig
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+# Submodule exports
+from . import _environment_info as _environment_info
+from . import config as config
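Usage-wise, these generated alias modules resolve to the same objects as the underlying implementation modules. A small sanity check (not taken from the package's test suite, just an illustration of the behavior above):

```python
# Both names refer to the same class at runtime; the alias module simply
# forwards attribute access to nshtrainer.util._environment_info.
from nshtrainer.config.util import EnvironmentConfig as AliasedEnvironmentConfig
from nshtrainer.util._environment_info import EnvironmentConfig

assert AliasedEnvironmentConfig is EnvironmentConfig
```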

nshtrainer/config/util/_environment_info/__init__.py
CHANGED
@@ -45,50 +45,50 @@ else:

         if name in globals():
             return globals()[name]
-        if name == "
+        if name == "EnvironmentCUDAConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentCUDAConfig
+        if name == "EnvironmentClassInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentClassInformationConfig
+        if name == "EnvironmentConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentConfig
+        if name == "EnvironmentGPUConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
+            ).EnvironmentGPUConfig
         if name == "EnvironmentHardwareConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
             ).EnvironmentHardwareConfig
-        if name == "
+        if name == "EnvironmentLSFInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentLSFInformationConfig
+        if name == "EnvironmentLinuxEnvironmentConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentLinuxEnvironmentConfig
+        if name == "EnvironmentPackageConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentPackageConfig
+        if name == "EnvironmentSLURMInformationConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentSLURMInformationConfig
+        if name == "EnvironmentSnapshotConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
-        if name == "
+            ).EnvironmentSnapshotConfig
+        if name == "GitRepositoryConfig":
             return importlib.import_module(
                 "nshtrainer.util._environment_info"
-            ).
+            ).GitRepositoryConfig
         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

 # Submodule exports

nshtrainer/config/util/config/__init__.py
CHANGED
@@ -18,12 +18,12 @@ else:

         if name in globals():
             return globals()[name]
+        if name == "DTypeConfig":
+            return importlib.import_module("nshtrainer.util.config").DTypeConfig
         if name == "EpochsConfig":
             return importlib.import_module("nshtrainer.util.config").EpochsConfig
         if name == "StepsConfig":
             return importlib.import_module("nshtrainer.util.config").StepsConfig
-        if name == "DTypeConfig":
-            return importlib.import_module("nshtrainer.util.config").DTypeConfig
         if name == "DurationConfig":
             return importlib.import_module("nshtrainer.util.config").DurationConfig
         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

nshtrainer/data/datamodule.py
CHANGED
@@ -1,7 +1,57 @@
 from __future__ import annotations

+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from typing import Any, Generic, cast
+
+import nshconfig as C
 from lightning.pytorch import LightningDataModule
+from typing_extensions import Never, TypeVar, deprecated, override
+
+from ..model.mixins.callback import CallbackRegistrarModuleMixin
+from ..model.mixins.debug import _DebugModuleMixin
+
+THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
+
+
+class LightningDataModuleBase(
+    _DebugModuleMixin,
+    CallbackRegistrarModuleMixin,
+    LightningDataModule,
+    ABC,
+    Generic[THparams],
+):
+    @property
+    @override
+    def hparams(self) -> THparams:  # pyright: ignore[reportIncompatibleMethodOverride]
+        return cast(THparams, super().hparams)
+
+    @property
+    @override
+    def hparams_initial(self):  # pyright: ignore[reportIncompatibleMethodOverride]
+        hparams = cast(THparams, super().hparams_initial)
+        return cast(Never, {"datamodule": hparams.model_dump(mode="json")})
+
+    @property
+    @deprecated("Use `hparams` instead")
+    def config(self):
+        return cast(Never, self.hparams)
+
+    @classmethod
+    @abstractmethod
+    def hparams_cls(cls) -> type[THparams]: ...

+    @override
+    def __init__(self, hparams: THparams | Mapping[str, Any]):
+        super().__init__()

-
-
+        # Validate and save hyperparameters
+        hparams_cls = self.hparams_cls()
+        if isinstance(hparams, Mapping):
+            hparams = hparams_cls.model_validate(hparams)
+        elif not isinstance(hparams, hparams_cls):
+            raise TypeError(
+                f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
+            )
+        hparams = hparams.model_deep_validate()
+        self.save_hyperparameters(hparams)
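The new `LightningDataModuleBase` pairs a datamodule with an nshconfig hyperparameter class: subclasses declare the config type via `hparams_cls()`, and the constructor accepts either a config instance or a plain mapping, validating it before `save_hyperparameters`. A hypothetical subclass (the `MyData*` names are illustrative, not part of nshtrainer) might look like:

```python
import nshconfig as C

from nshtrainer.data.datamodule import LightningDataModuleBase


class MyDataConfig(C.Config):
    batch_size: int = 32


class MyDataModule(LightningDataModuleBase[MyDataConfig]):
    @classmethod
    def hparams_cls(cls):
        return MyDataConfig


dm = MyDataModule(MyDataConfig(batch_size=64))
dm = MyDataModule({"batch_size": 64})  # mappings are validated into MyDataConfig
```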
nshtrainer/loggers/__init__.py
CHANGED
@@ -5,11 +5,12 @@ from typing import Annotated, TypeAlias
 import nshconfig as C

 from ._base import BaseLoggerConfig as BaseLoggerConfig
+from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
 from .csv import CSVLoggerConfig as CSVLoggerConfig
 from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
 from .wandb import WandbLoggerConfig as WandbLoggerConfig

 LoggerConfig: TypeAlias = Annotated[
-    CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig,
+    CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig | ActSaveLoggerConfig,
     C.Field(discriminator="name"),
 ]

nshtrainer/loggers/_base.py
CHANGED
@@ -7,7 +7,7 @@ import nshconfig as C
 from lightning.pytorch.loggers import Logger

 if TYPE_CHECKING:
-    from ..
+    from ..trainer._config import TrainerConfig


 class BaseLoggerConfig(C.Config, ABC):
@@ -21,8 +21,11 @@ class BaseLoggerConfig(C.Config, ABC):
     """Directory to save the logs to. If None, will use the default log directory for the trainer."""

     @abstractmethod
-    def create_logger(self,
+    def create_logger(self, trainer_config: TrainerConfig) -> Logger | None: ...

     def disable_(self):
         self.enabled = False
         return self
+
+    def __bool__(self):
+        return self.enabled
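With the `__bool__` addition, a logger config can be truth-tested directly instead of checking `.enabled`. A small illustration using the new ActSave logger config defined below (assuming the inherited fields all have defaults, as in the other logger configs):

```python
from nshtrainer.loggers import ActSaveLoggerConfig

cfg = ActSaveLoggerConfig()
assert cfg        # enabled by default, so the config is truthy
cfg.disable_()
assert not cfg    # disabled configs are falsy; create_logger() would return None
```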

nshtrainer/loggers/actsave.py
ADDED
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+from argparse import Namespace
+from typing import Any, Literal
+
+import numpy as np
+from lightning.pytorch.loggers import Logger
+
+from ._base import BaseLoggerConfig
+
+
+class ActSaveLoggerConfig(BaseLoggerConfig):
+    name: Literal["actsave"] = "actsave"
+
+    def create_logger(self, trainer_config):
+        if not self.enabled:
+            return None
+
+        return ActSaveLogger()
+
+
+class ActSaveLogger(Logger):
+    @property
+    def name(self):
+        return None
+
+    @property
+    def version(self):
+        from nshutils import ActSave
+
+        if ActSave._saver is None:
+            return None
+
+        return ActSave._saver._id
+
+    @property
+    def save_dir(self):
+        from nshutils import ActSave
+
+        if ActSave._saver is None:
+            return None
+
+        return str(ActSave._saver._save_dir)
+
+    def log_hyperparams(
+        self,
+        params: dict[str, Any] | Namespace,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        from nshutils import ActSave
+
+        # Wrap the hparams as a object-dtype np array
+        return ActSave.save({"hyperparameters": np.array(params, dtype=object)})
+
+    def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
+        from nshutils import ActSave
+
+        ActSave.save({**metrics})

nshtrainer/loggers/csv.py
CHANGED
@@ -23,20 +23,20 @@ class CSVLoggerConfig(BaseLoggerConfig):
     """How often to flush logs to disk."""

     @override
-    def create_logger(self,
+    def create_logger(self, trainer_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.csv_logs import CSVLogger

-        save_dir =
-
+        save_dir = trainer_config.directory._resolve_log_directory_for_logger(
+            trainer_config.id,
             self,
         )
         return CSVLogger(
             save_dir=save_dir,
-            name=
-            version=
+            name=trainer_config.full_name,
+            version=trainer_config.id,
             prefix=self.prefix,
             flush_logs_every_n_steps=self.flush_logs_every_n_steps,
         )

nshtrainer/loggers/tensorboard.py
CHANGED
@@ -56,20 +56,20 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
     """A string to put at the beginning of metric keys."""

     @override
-    def create_logger(self,
+    def create_logger(self, trainer_config):
         if not self.enabled:
             return None

         from lightning.pytorch.loggers.tensorboard import TensorBoardLogger

-        save_dir =
-
+        save_dir = trainer_config.directory._resolve_log_directory_for_logger(
+            trainer_config.id,
             self,
         )
         return TensorBoardLogger(
             save_dir=save_dir,
-            name=
-            version=
+            name=trainer_config.full_name,
+            version=trainer_config.id,
             log_graph=self.log_graph,
             default_hp_metric=self.default_hp_metric,
         )
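Both logger configs now follow the same recipe: resolve the save directory from `trainer_config.directory`, and name the run with `trainer_config.full_name` / `trainer_config.id`. A hypothetical third-party logger config could mirror it (sketch only; `MyLoggerConfig` is not part of nshtrainer):

```python
from typing import Literal

from typing_extensions import override

from nshtrainer.loggers._base import BaseLoggerConfig


class MyLoggerConfig(BaseLoggerConfig):
    name: Literal["my_logger"] = "my_logger"

    @override
    def create_logger(self, trainer_config):
        if not self.enabled:
            return None

        # Resolve the per-run log directory the same way the built-in configs do.
        save_dir = trainer_config.directory._resolve_log_directory_for_logger(
            trainer_config.id,
            self,
        )

        # Return any lightning.pytorch Logger rooted at save_dir; CSVLogger is
        # used here purely as a stand-in.
        from lightning.pytorch.loggers.csv_logs import CSVLogger

        return CSVLogger(
            save_dir=save_dir,
            name=trainer_config.full_name,
            version=trainer_config.id,
        )
```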
|