nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
7
7
  # Config/alias imports
8
8
 
9
9
  if TYPE_CHECKING:
10
+ from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
10
11
  from nshtrainer.trainer._config import (
11
12
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
12
13
  )
@@ -25,9 +26,11 @@ if TYPE_CHECKING:
25
26
  from nshtrainer.trainer._config import (
26
27
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
27
28
  )
29
+ from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
28
30
  from nshtrainer.trainer._config import (
29
31
  EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
30
32
  )
33
+ from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
31
34
  from nshtrainer.trainer._config import (
32
35
  GradientClippingConfig as GradientClippingConfig,
33
36
  )
@@ -35,8 +38,12 @@ if TYPE_CHECKING:
35
38
  from nshtrainer.trainer._config import (
36
39
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
37
40
  )
41
+ from nshtrainer.trainer._config import (
42
+ LogEpochCallbackConfig as LogEpochCallbackConfig,
43
+ )
38
44
  from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
39
45
  from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
46
+ from nshtrainer.trainer._config import MetricConfig as MetricConfig
40
47
  from nshtrainer.trainer._config import (
41
48
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
42
49
  )
@@ -64,80 +71,96 @@ else:
64
71
 
65
72
  if name in globals():
66
73
  return globals()[name]
67
- if name == "SanityCheckingConfig":
74
+ if name == "ActSaveLoggerConfig":
68
75
  return importlib.import_module(
69
76
  "nshtrainer.trainer._config"
70
- ).SanityCheckingConfig
71
- if name == "TrainerConfig":
72
- return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
73
- if name == "OnExceptionCheckpointCallbackConfig":
77
+ ).ActSaveLoggerConfig
78
+ if name == "BestCheckpointCallbackConfig":
74
79
  return importlib.import_module(
75
80
  "nshtrainer.trainer._config"
76
- ).OnExceptionCheckpointCallbackConfig
77
- if name == "GradientClippingConfig":
81
+ ).BestCheckpointCallbackConfig
82
+ if name == "CSVLoggerConfig":
83
+ return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
84
+ if name == "CallbackConfigBase":
78
85
  return importlib.import_module(
79
86
  "nshtrainer.trainer._config"
80
- ).GradientClippingConfig
81
- if name == "WandbLoggerConfig":
87
+ ).CallbackConfigBase
88
+ if name == "CheckpointLoadingConfig":
82
89
  return importlib.import_module(
83
90
  "nshtrainer.trainer._config"
84
- ).WandbLoggerConfig
85
- if name == "LoggingConfig":
86
- return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
87
- if name == "TensorboardLoggerConfig":
91
+ ).CheckpointLoadingConfig
92
+ if name == "CheckpointSavingConfig":
88
93
  return importlib.import_module(
89
94
  "nshtrainer.trainer._config"
90
- ).TensorboardLoggerConfig
91
- if name == "RLPSanityChecksCallbackConfig":
95
+ ).CheckpointSavingConfig
96
+ if name == "DebugFlagCallbackConfig":
92
97
  return importlib.import_module(
93
98
  "nshtrainer.trainer._config"
94
- ).RLPSanityChecksCallbackConfig
95
- if name == "CheckpointSavingConfig":
99
+ ).DebugFlagCallbackConfig
100
+ if name == "DirectoryConfig":
101
+ return importlib.import_module("nshtrainer.trainer._config").DirectoryConfig
102
+ if name == "EarlyStoppingCallbackConfig":
96
103
  return importlib.import_module(
97
104
  "nshtrainer.trainer._config"
98
- ).CheckpointSavingConfig
99
- if name == "CSVLoggerConfig":
100
- return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
105
+ ).EarlyStoppingCallbackConfig
106
+ if name == "EnvironmentConfig":
107
+ return importlib.import_module(
108
+ "nshtrainer.trainer._config"
109
+ ).EnvironmentConfig
110
+ if name == "GradientClippingConfig":
111
+ return importlib.import_module(
112
+ "nshtrainer.trainer._config"
113
+ ).GradientClippingConfig
101
114
  if name == "HuggingFaceHubConfig":
102
115
  return importlib.import_module(
103
116
  "nshtrainer.trainer._config"
104
117
  ).HuggingFaceHubConfig
105
- if name == "CheckpointLoadingConfig":
118
+ if name == "LastCheckpointCallbackConfig":
106
119
  return importlib.import_module(
107
120
  "nshtrainer.trainer._config"
108
- ).CheckpointLoadingConfig
109
- if name == "DebugFlagCallbackConfig":
121
+ ).LastCheckpointCallbackConfig
122
+ if name == "LogEpochCallbackConfig":
110
123
  return importlib.import_module(
111
124
  "nshtrainer.trainer._config"
112
- ).DebugFlagCallbackConfig
113
- if name == "CallbackConfigBase":
125
+ ).LogEpochCallbackConfig
126
+ if name == "LoggingConfig":
127
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
128
+ if name == "MetricConfig":
129
+ return importlib.import_module("nshtrainer.trainer._config").MetricConfig
130
+ if name == "OnExceptionCheckpointCallbackConfig":
114
131
  return importlib.import_module(
115
132
  "nshtrainer.trainer._config"
116
- ).CallbackConfigBase
117
- if name == "LastCheckpointCallbackConfig":
133
+ ).OnExceptionCheckpointCallbackConfig
134
+ if name == "OptimizationConfig":
118
135
  return importlib.import_module(
119
136
  "nshtrainer.trainer._config"
120
- ).LastCheckpointCallbackConfig
121
- if name == "SharedParametersCallbackConfig":
137
+ ).OptimizationConfig
138
+ if name == "RLPSanityChecksCallbackConfig":
122
139
  return importlib.import_module(
123
140
  "nshtrainer.trainer._config"
124
- ).SharedParametersCallbackConfig
141
+ ).RLPSanityChecksCallbackConfig
125
142
  if name == "ReproducibilityConfig":
126
143
  return importlib.import_module(
127
144
  "nshtrainer.trainer._config"
128
145
  ).ReproducibilityConfig
129
- if name == "EarlyStoppingCallbackConfig":
146
+ if name == "SanityCheckingConfig":
130
147
  return importlib.import_module(
131
148
  "nshtrainer.trainer._config"
132
- ).EarlyStoppingCallbackConfig
133
- if name == "OptimizationConfig":
149
+ ).SanityCheckingConfig
150
+ if name == "SharedParametersCallbackConfig":
134
151
  return importlib.import_module(
135
152
  "nshtrainer.trainer._config"
136
- ).OptimizationConfig
137
- if name == "BestCheckpointCallbackConfig":
153
+ ).SharedParametersCallbackConfig
154
+ if name == "TensorboardLoggerConfig":
138
155
  return importlib.import_module(
139
156
  "nshtrainer.trainer._config"
140
- ).BestCheckpointCallbackConfig
157
+ ).TensorboardLoggerConfig
158
+ if name == "TrainerConfig":
159
+ return importlib.import_module("nshtrainer.trainer._config").TrainerConfig
160
+ if name == "WandbLoggerConfig":
161
+ return importlib.import_module(
162
+ "nshtrainer.trainer._config"
163
+ ).WandbLoggerConfig
141
164
  if name == "CallbackConfig":
142
165
  return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
143
166
  if name == "CheckpointCallbackConfig":
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.trainer.trainer import EnvironmentConfig as EnvironmentConfig
11
+ from nshtrainer.trainer.trainer import TrainerConfig as TrainerConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
16
+
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "EnvironmentConfig":
20
+ return importlib.import_module(
21
+ "nshtrainer.trainer.trainer"
22
+ ).EnvironmentConfig
23
+ if name == "TrainerConfig":
24
+ return importlib.import_module("nshtrainer.trainer.trainer").TrainerConfig
25
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
26
+
27
+ # Submodule exports
@@ -0,0 +1,109 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.util._environment_info import (
11
+ EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
12
+ )
13
+ from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
14
+ from nshtrainer.util._environment_info import (
15
+ EnvironmentCUDAConfig as EnvironmentCUDAConfig,
16
+ )
17
+ from nshtrainer.util._environment_info import (
18
+ EnvironmentGPUConfig as EnvironmentGPUConfig,
19
+ )
20
+ from nshtrainer.util._environment_info import (
21
+ EnvironmentHardwareConfig as EnvironmentHardwareConfig,
22
+ )
23
+ from nshtrainer.util._environment_info import (
24
+ EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
25
+ )
26
+ from nshtrainer.util._environment_info import (
27
+ EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
28
+ )
29
+ from nshtrainer.util._environment_info import (
30
+ EnvironmentPackageConfig as EnvironmentPackageConfig,
31
+ )
32
+ from nshtrainer.util._environment_info import (
33
+ EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
34
+ )
35
+ from nshtrainer.util._environment_info import (
36
+ EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
37
+ )
38
+ from nshtrainer.util._environment_info import (
39
+ GitRepositoryConfig as GitRepositoryConfig,
40
+ )
41
+ from nshtrainer.util.config import DTypeConfig as DTypeConfig
42
+ from nshtrainer.util.config import DurationConfig as DurationConfig
43
+ from nshtrainer.util.config import EpochsConfig as EpochsConfig
44
+ from nshtrainer.util.config import StepsConfig as StepsConfig
45
+ else:
46
+
47
+ def __getattr__(name):
48
+ import importlib
49
+
50
+ if name in globals():
51
+ return globals()[name]
52
+ if name == "DTypeConfig":
53
+ return importlib.import_module("nshtrainer.util.config").DTypeConfig
54
+ if name == "EnvironmentCUDAConfig":
55
+ return importlib.import_module(
56
+ "nshtrainer.util._environment_info"
57
+ ).EnvironmentCUDAConfig
58
+ if name == "EnvironmentClassInformationConfig":
59
+ return importlib.import_module(
60
+ "nshtrainer.util._environment_info"
61
+ ).EnvironmentClassInformationConfig
62
+ if name == "EnvironmentConfig":
63
+ return importlib.import_module(
64
+ "nshtrainer.util._environment_info"
65
+ ).EnvironmentConfig
66
+ if name == "EnvironmentGPUConfig":
67
+ return importlib.import_module(
68
+ "nshtrainer.util._environment_info"
69
+ ).EnvironmentGPUConfig
70
+ if name == "EnvironmentHardwareConfig":
71
+ return importlib.import_module(
72
+ "nshtrainer.util._environment_info"
73
+ ).EnvironmentHardwareConfig
74
+ if name == "EnvironmentLSFInformationConfig":
75
+ return importlib.import_module(
76
+ "nshtrainer.util._environment_info"
77
+ ).EnvironmentLSFInformationConfig
78
+ if name == "EnvironmentLinuxEnvironmentConfig":
79
+ return importlib.import_module(
80
+ "nshtrainer.util._environment_info"
81
+ ).EnvironmentLinuxEnvironmentConfig
82
+ if name == "EnvironmentPackageConfig":
83
+ return importlib.import_module(
84
+ "nshtrainer.util._environment_info"
85
+ ).EnvironmentPackageConfig
86
+ if name == "EnvironmentSLURMInformationConfig":
87
+ return importlib.import_module(
88
+ "nshtrainer.util._environment_info"
89
+ ).EnvironmentSLURMInformationConfig
90
+ if name == "EnvironmentSnapshotConfig":
91
+ return importlib.import_module(
92
+ "nshtrainer.util._environment_info"
93
+ ).EnvironmentSnapshotConfig
94
+ if name == "EpochsConfig":
95
+ return importlib.import_module("nshtrainer.util.config").EpochsConfig
96
+ if name == "GitRepositoryConfig":
97
+ return importlib.import_module(
98
+ "nshtrainer.util._environment_info"
99
+ ).GitRepositoryConfig
100
+ if name == "StepsConfig":
101
+ return importlib.import_module("nshtrainer.util.config").StepsConfig
102
+ if name == "DurationConfig":
103
+ return importlib.import_module("nshtrainer.util.config").DurationConfig
104
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
105
+
106
+
107
+ # Submodule exports
108
+ from . import _environment_info as _environment_info
109
+ from . import config as config
@@ -45,50 +45,50 @@ else:
45
45
 
46
46
  if name in globals():
47
47
  return globals()[name]
48
- if name == "EnvironmentLinuxEnvironmentConfig":
48
+ if name == "EnvironmentCUDAConfig":
49
49
  return importlib.import_module(
50
50
  "nshtrainer.util._environment_info"
51
- ).EnvironmentLinuxEnvironmentConfig
52
- if name == "EnvironmentLSFInformationConfig":
51
+ ).EnvironmentCUDAConfig
52
+ if name == "EnvironmentClassInformationConfig":
53
53
  return importlib.import_module(
54
54
  "nshtrainer.util._environment_info"
55
- ).EnvironmentLSFInformationConfig
56
- if name == "EnvironmentGPUConfig":
55
+ ).EnvironmentClassInformationConfig
56
+ if name == "EnvironmentConfig":
57
57
  return importlib.import_module(
58
58
  "nshtrainer.util._environment_info"
59
- ).EnvironmentGPUConfig
60
- if name == "EnvironmentPackageConfig":
59
+ ).EnvironmentConfig
60
+ if name == "EnvironmentGPUConfig":
61
61
  return importlib.import_module(
62
62
  "nshtrainer.util._environment_info"
63
- ).EnvironmentPackageConfig
63
+ ).EnvironmentGPUConfig
64
64
  if name == "EnvironmentHardwareConfig":
65
65
  return importlib.import_module(
66
66
  "nshtrainer.util._environment_info"
67
67
  ).EnvironmentHardwareConfig
68
- if name == "EnvironmentSnapshotConfig":
68
+ if name == "EnvironmentLSFInformationConfig":
69
69
  return importlib.import_module(
70
70
  "nshtrainer.util._environment_info"
71
- ).EnvironmentSnapshotConfig
72
- if name == "EnvironmentClassInformationConfig":
71
+ ).EnvironmentLSFInformationConfig
72
+ if name == "EnvironmentLinuxEnvironmentConfig":
73
73
  return importlib.import_module(
74
74
  "nshtrainer.util._environment_info"
75
- ).EnvironmentClassInformationConfig
76
- if name == "GitRepositoryConfig":
75
+ ).EnvironmentLinuxEnvironmentConfig
76
+ if name == "EnvironmentPackageConfig":
77
77
  return importlib.import_module(
78
78
  "nshtrainer.util._environment_info"
79
- ).GitRepositoryConfig
80
- if name == "EnvironmentConfig":
79
+ ).EnvironmentPackageConfig
80
+ if name == "EnvironmentSLURMInformationConfig":
81
81
  return importlib.import_module(
82
82
  "nshtrainer.util._environment_info"
83
- ).EnvironmentConfig
84
- if name == "EnvironmentCUDAConfig":
83
+ ).EnvironmentSLURMInformationConfig
84
+ if name == "EnvironmentSnapshotConfig":
85
85
  return importlib.import_module(
86
86
  "nshtrainer.util._environment_info"
87
- ).EnvironmentCUDAConfig
88
- if name == "EnvironmentSLURMInformationConfig":
87
+ ).EnvironmentSnapshotConfig
88
+ if name == "GitRepositoryConfig":
89
89
  return importlib.import_module(
90
90
  "nshtrainer.util._environment_info"
91
- ).EnvironmentSLURMInformationConfig
91
+ ).GitRepositoryConfig
92
92
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
93
93
 
94
94
  # Submodule exports
@@ -18,12 +18,12 @@ else:
18
18
 
19
19
  if name in globals():
20
20
  return globals()[name]
21
+ if name == "DTypeConfig":
22
+ return importlib.import_module("nshtrainer.util.config").DTypeConfig
21
23
  if name == "EpochsConfig":
22
24
  return importlib.import_module("nshtrainer.util.config").EpochsConfig
23
25
  if name == "StepsConfig":
24
26
  return importlib.import_module("nshtrainer.util.config").StepsConfig
25
- if name == "DTypeConfig":
26
- return importlib.import_module("nshtrainer.util.config").DTypeConfig
27
27
  if name == "DurationConfig":
28
28
  return importlib.import_module("nshtrainer.util.config").DurationConfig
29
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -1,7 +1,57 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import Mapping
5
+ from typing import Any, Generic, cast
6
+
7
+ import nshconfig as C
3
8
  from lightning.pytorch import LightningDataModule
9
+ from typing_extensions import Never, TypeVar, deprecated, override
10
+
11
+ from ..model.mixins.callback import CallbackRegistrarModuleMixin
12
+ from ..model.mixins.debug import _DebugModuleMixin
13
+
14
+ THparams = TypeVar("THparams", bound=C.Config, infer_variance=True)
15
+
16
+
17
+ class LightningDataModuleBase(
18
+ _DebugModuleMixin,
19
+ CallbackRegistrarModuleMixin,
20
+ LightningDataModule,
21
+ ABC,
22
+ Generic[THparams],
23
+ ):
24
+ @property
25
+ @override
26
+ def hparams(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
27
+ return cast(THparams, super().hparams)
28
+
29
+ @property
30
+ @override
31
+ def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
32
+ hparams = cast(THparams, super().hparams_initial)
33
+ return cast(Never, {"datamodule": hparams.model_dump(mode="json")})
34
+
35
+ @property
36
+ @deprecated("Use `hparams` instead")
37
+ def config(self):
38
+ return cast(Never, self.hparams)
39
+
40
+ @classmethod
41
+ @abstractmethod
42
+ def hparams_cls(cls) -> type[THparams]: ...
4
43
 
44
+ @override
45
+ def __init__(self, hparams: THparams | Mapping[str, Any]):
46
+ super().__init__()
5
47
 
6
- class LightningDataModuleBase(LightningDataModule):
7
- pass
48
+ # Validate and save hyperparameters
49
+ hparams_cls = self.hparams_cls()
50
+ if isinstance(hparams, Mapping):
51
+ hparams = hparams_cls.model_validate(hparams)
52
+ elif not isinstance(hparams, hparams_cls):
53
+ raise TypeError(
54
+ f"Expected hparams to be either a Mapping or an instance of {hparams_cls}, got {type(hparams)}"
55
+ )
56
+ hparams = hparams.model_deep_validate()
57
+ self.save_hyperparameters(hparams)
@@ -5,11 +5,12 @@ from typing import Annotated, TypeAlias
5
5
  import nshconfig as C
6
6
 
7
7
  from ._base import BaseLoggerConfig as BaseLoggerConfig
8
+ from .actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
8
9
  from .csv import CSVLoggerConfig as CSVLoggerConfig
9
10
  from .tensorboard import TensorboardLoggerConfig as TensorboardLoggerConfig
10
11
  from .wandb import WandbLoggerConfig as WandbLoggerConfig
11
12
 
12
13
  LoggerConfig: TypeAlias = Annotated[
13
- CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig,
14
+ CSVLoggerConfig | TensorboardLoggerConfig | WandbLoggerConfig | ActSaveLoggerConfig,
14
15
  C.Field(discriminator="name"),
15
16
  ]
@@ -7,7 +7,7 @@ import nshconfig as C
7
7
  from lightning.pytorch.loggers import Logger
8
8
 
9
9
  if TYPE_CHECKING:
10
- from ..model import BaseConfig
10
+ from ..trainer._config import TrainerConfig
11
11
 
12
12
 
13
13
  class BaseLoggerConfig(C.Config, ABC):
@@ -21,8 +21,11 @@ class BaseLoggerConfig(C.Config, ABC):
21
21
  """Directory to save the logs to. If None, will use the default log directory for the trainer."""
22
22
 
23
23
  @abstractmethod
24
- def create_logger(self, root_config: "BaseConfig") -> Logger | None: ...
24
+ def create_logger(self, trainer_config: TrainerConfig) -> Logger | None: ...
25
25
 
26
26
  def disable_(self):
27
27
  self.enabled = False
28
28
  return self
29
+
30
+ def __bool__(self):
31
+ return self.enabled
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ from argparse import Namespace
4
+ from typing import Any, Literal
5
+
6
+ import numpy as np
7
+ from lightning.pytorch.loggers import Logger
8
+
9
+ from ._base import BaseLoggerConfig
10
+
11
+
12
+ class ActSaveLoggerConfig(BaseLoggerConfig):
13
+ name: Literal["actsave"] = "actsave"
14
+
15
+ def create_logger(self, trainer_config):
16
+ if not self.enabled:
17
+ return None
18
+
19
+ return ActSaveLogger()
20
+
21
+
22
+ class ActSaveLogger(Logger):
23
+ @property
24
+ def name(self):
25
+ return None
26
+
27
+ @property
28
+ def version(self):
29
+ from nshutils import ActSave
30
+
31
+ if ActSave._saver is None:
32
+ return None
33
+
34
+ return ActSave._saver._id
35
+
36
+ @property
37
+ def save_dir(self):
38
+ from nshutils import ActSave
39
+
40
+ if ActSave._saver is None:
41
+ return None
42
+
43
+ return str(ActSave._saver._save_dir)
44
+
45
+ def log_hyperparams(
46
+ self,
47
+ params: dict[str, Any] | Namespace,
48
+ *args: Any,
49
+ **kwargs: Any,
50
+ ):
51
+ from nshutils import ActSave
52
+
53
+ # Wrap the hparams as a object-dtype np array
54
+ return ActSave.save({"hyperparameters": np.array(params, dtype=object)})
55
+
56
+ def log_metrics(self, metrics: dict[str, float], step: int | None = None) -> None:
57
+ from nshutils import ActSave
58
+
59
+ ActSave.save({**metrics})
nshtrainer/loggers/csv.py CHANGED
@@ -23,20 +23,20 @@ class CSVLoggerConfig(BaseLoggerConfig):
23
23
  """How often to flush logs to disk."""
24
24
 
25
25
  @override
26
- def create_logger(self, root_config):
26
+ def create_logger(self, trainer_config):
27
27
  if not self.enabled:
28
28
  return None
29
29
 
30
30
  from lightning.pytorch.loggers.csv_logs import CSVLogger
31
31
 
32
- save_dir = root_config.directory._resolve_log_directory_for_logger(
33
- root_config.id,
32
+ save_dir = trainer_config.directory._resolve_log_directory_for_logger(
33
+ trainer_config.id,
34
34
  self,
35
35
  )
36
36
  return CSVLogger(
37
37
  save_dir=save_dir,
38
- name=root_config.run_name,
39
- version=root_config.id,
38
+ name=trainer_config.full_name,
39
+ version=trainer_config.id,
40
40
  prefix=self.prefix,
41
41
  flush_logs_every_n_steps=self.flush_logs_every_n_steps,
42
42
  )
@@ -56,20 +56,20 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
56
56
  """A string to put at the beginning of metric keys."""
57
57
 
58
58
  @override
59
- def create_logger(self, root_config):
59
+ def create_logger(self, trainer_config):
60
60
  if not self.enabled:
61
61
  return None
62
62
 
63
63
  from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
64
64
 
65
- save_dir = root_config.directory._resolve_log_directory_for_logger(
66
- root_config.id,
65
+ save_dir = trainer_config.directory._resolve_log_directory_for_logger(
66
+ trainer_config.id,
67
67
  self,
68
68
  )
69
69
  return TensorBoardLogger(
70
70
  save_dir=save_dir,
71
- name=root_config.run_name,
72
- version=root_config.id,
71
+ name=trainer_config.full_name,
72
+ version=trainer_config.id,
73
73
  log_graph=self.log_graph,
74
74
  default_hp_metric=self.default_hp_metric,
75
75
  )