nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
@@ -19,12 +19,12 @@ else:
19
19
 
20
20
  if name in globals():
21
21
  return globals()[name]
22
+ if name == "DirectoryConfig":
23
+ return importlib.import_module("nshtrainer._directory").DirectoryConfig
22
24
  if name == "DirectorySetupCallbackConfig":
23
25
  return importlib.import_module(
24
26
  "nshtrainer._directory"
25
27
  ).DirectorySetupCallbackConfig
26
- if name == "DirectoryConfig":
27
- return importlib.import_module("nshtrainer._directory").DirectoryConfig
28
28
  if name == "LoggerConfig":
29
29
  return importlib.import_module("nshtrainer._directory").LoggerConfig
30
30
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -19,14 +19,14 @@ else:
19
19
 
20
20
  if name in globals():
21
21
  return globals()[name]
22
+ if name == "CallbackConfigBase":
23
+ return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
22
24
  if name == "HuggingFaceHubAutoCreateConfig":
23
25
  return importlib.import_module(
24
26
  "nshtrainer._hf_hub"
25
27
  ).HuggingFaceHubAutoCreateConfig
26
28
  if name == "HuggingFaceHubConfig":
27
29
  return importlib.import_module("nshtrainer._hf_hub").HuggingFaceHubConfig
28
- if name == "CallbackConfigBase":
29
- return importlib.import_module("nshtrainer._hf_hub").CallbackConfigBase
30
30
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
31
31
 
32
32
  # Submodule exports
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
32
32
  from nshtrainer.callbacks import (
33
33
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
34
34
  )
35
+ from nshtrainer.callbacks import LogEpochCallbackConfig as LogEpochCallbackConfig
35
36
  from nshtrainer.callbacks import (
36
37
  NormLoggingCallbackConfig as NormLoggingCallbackConfig,
37
38
  )
@@ -47,7 +48,6 @@ if TYPE_CHECKING:
47
48
  from nshtrainer.callbacks import (
48
49
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
49
50
  )
50
- from nshtrainer.callbacks import ThroughputMonitorConfig as ThroughputMonitorConfig
51
51
  from nshtrainer.callbacks import (
52
52
  WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
53
53
  )
@@ -69,88 +69,88 @@ else:
69
69
 
70
70
  if name in globals():
71
71
  return globals()[name]
72
- if name == "PrintTableMetricsCallbackConfig":
72
+ if name == "ActSaveConfig":
73
+ return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
74
+ if name == "BaseCheckpointCallbackConfig":
75
+ return importlib.import_module(
76
+ "nshtrainer.callbacks.checkpoint._base"
77
+ ).BaseCheckpointCallbackConfig
78
+ if name == "BestCheckpointCallbackConfig":
73
79
  return importlib.import_module(
74
80
  "nshtrainer.callbacks"
75
- ).PrintTableMetricsCallbackConfig
81
+ ).BestCheckpointCallbackConfig
76
82
  if name == "CallbackConfigBase":
77
83
  return importlib.import_module("nshtrainer.callbacks").CallbackConfigBase
84
+ if name == "CheckpointMetadata":
85
+ return importlib.import_module(
86
+ "nshtrainer.callbacks.checkpoint._base"
87
+ ).CheckpointMetadata
78
88
  if name == "DebugFlagCallbackConfig":
79
89
  return importlib.import_module(
80
90
  "nshtrainer.callbacks"
81
91
  ).DebugFlagCallbackConfig
82
- if name == "ThroughputMonitorConfig":
92
+ if name == "DirectorySetupCallbackConfig":
83
93
  return importlib.import_module(
84
94
  "nshtrainer.callbacks"
85
- ).ThroughputMonitorConfig
86
- if name == "GradientSkippingCallbackConfig":
95
+ ).DirectorySetupCallbackConfig
96
+ if name == "EMACallbackConfig":
97
+ return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
98
+ if name == "EarlyStoppingCallbackConfig":
87
99
  return importlib.import_module(
88
100
  "nshtrainer.callbacks"
89
- ).GradientSkippingCallbackConfig
90
- if name == "RLPSanityChecksCallbackConfig":
101
+ ).EarlyStoppingCallbackConfig
102
+ if name == "EpochTimerCallbackConfig":
91
103
  return importlib.import_module(
92
104
  "nshtrainer.callbacks"
93
- ).RLPSanityChecksCallbackConfig
94
- if name == "WandbUploadCodeCallbackConfig":
105
+ ).EpochTimerCallbackConfig
106
+ if name == "FiniteChecksCallbackConfig":
95
107
  return importlib.import_module(
96
108
  "nshtrainer.callbacks"
97
- ).WandbUploadCodeCallbackConfig
98
- if name == "MetricConfig":
99
- return importlib.import_module(
100
- "nshtrainer.callbacks.early_stopping"
101
- ).MetricConfig
102
- if name == "EarlyStoppingCallbackConfig":
109
+ ).FiniteChecksCallbackConfig
110
+ if name == "GradientSkippingCallbackConfig":
103
111
  return importlib.import_module(
104
112
  "nshtrainer.callbacks"
105
- ).EarlyStoppingCallbackConfig
106
- if name == "WandbWatchCallbackConfig":
113
+ ).GradientSkippingCallbackConfig
114
+ if name == "LastCheckpointCallbackConfig":
107
115
  return importlib.import_module(
108
116
  "nshtrainer.callbacks"
109
- ).WandbWatchCallbackConfig
110
- if name == "EMACallbackConfig":
111
- return importlib.import_module("nshtrainer.callbacks").EMACallbackConfig
112
- if name == "DirectorySetupCallbackConfig":
117
+ ).LastCheckpointCallbackConfig
118
+ if name == "LogEpochCallbackConfig":
113
119
  return importlib.import_module(
114
120
  "nshtrainer.callbacks"
115
- ).DirectorySetupCallbackConfig
116
- if name == "ActSaveConfig":
117
- return importlib.import_module("nshtrainer.callbacks.actsave").ActSaveConfig
118
- if name == "FiniteChecksCallbackConfig":
121
+ ).LogEpochCallbackConfig
122
+ if name == "MetricConfig":
119
123
  return importlib.import_module(
120
- "nshtrainer.callbacks"
121
- ).FiniteChecksCallbackConfig
124
+ "nshtrainer.callbacks.early_stopping"
125
+ ).MetricConfig
122
126
  if name == "NormLoggingCallbackConfig":
123
127
  return importlib.import_module(
124
128
  "nshtrainer.callbacks"
125
129
  ).NormLoggingCallbackConfig
126
- if name == "EpochTimerCallbackConfig":
127
- return importlib.import_module(
128
- "nshtrainer.callbacks"
129
- ).EpochTimerCallbackConfig
130
130
  if name == "OnExceptionCheckpointCallbackConfig":
131
131
  return importlib.import_module(
132
132
  "nshtrainer.callbacks"
133
133
  ).OnExceptionCheckpointCallbackConfig
134
- if name == "LastCheckpointCallbackConfig":
134
+ if name == "PrintTableMetricsCallbackConfig":
135
135
  return importlib.import_module(
136
136
  "nshtrainer.callbacks"
137
- ).LastCheckpointCallbackConfig
137
+ ).PrintTableMetricsCallbackConfig
138
+ if name == "RLPSanityChecksCallbackConfig":
139
+ return importlib.import_module(
140
+ "nshtrainer.callbacks"
141
+ ).RLPSanityChecksCallbackConfig
138
142
  if name == "SharedParametersCallbackConfig":
139
143
  return importlib.import_module(
140
144
  "nshtrainer.callbacks"
141
145
  ).SharedParametersCallbackConfig
142
- if name == "BestCheckpointCallbackConfig":
146
+ if name == "WandbUploadCodeCallbackConfig":
143
147
  return importlib.import_module(
144
148
  "nshtrainer.callbacks"
145
- ).BestCheckpointCallbackConfig
146
- if name == "CheckpointMetadata":
147
- return importlib.import_module(
148
- "nshtrainer.callbacks.checkpoint._base"
149
- ).CheckpointMetadata
150
- if name == "BaseCheckpointCallbackConfig":
149
+ ).WandbUploadCodeCallbackConfig
150
+ if name == "WandbWatchCallbackConfig":
151
151
  return importlib.import_module(
152
- "nshtrainer.callbacks.checkpoint._base"
153
- ).BaseCheckpointCallbackConfig
152
+ "nshtrainer.callbacks"
153
+ ).WandbWatchCallbackConfig
154
154
  if name == "CallbackConfig":
155
155
  return importlib.import_module("nshtrainer.callbacks").CallbackConfig
156
156
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -166,11 +166,11 @@ from . import early_stopping as early_stopping
166
166
  from . import ema as ema
167
167
  from . import finite_checks as finite_checks
168
168
  from . import gradient_skipping as gradient_skipping
169
+ from . import log_epoch as log_epoch
169
170
  from . import norm_logging as norm_logging
170
171
  from . import print_table as print_table
171
172
  from . import rlp_sanity_checks as rlp_sanity_checks
172
173
  from . import shared_parameters as shared_parameters
173
- from . import throughput_monitor as throughput_monitor
174
174
  from . import timer as timer
175
175
  from . import wandb_upload_code as wandb_upload_code
176
176
  from . import wandb_watch as wandb_watch
@@ -35,34 +35,34 @@ else:
35
35
 
36
36
  if name in globals():
37
37
  return globals()[name]
38
- if name == "CheckpointMetadata":
39
- return importlib.import_module(
40
- "nshtrainer.callbacks.checkpoint._base"
41
- ).CheckpointMetadata
42
38
  if name == "BaseCheckpointCallbackConfig":
43
39
  return importlib.import_module(
44
40
  "nshtrainer.callbacks.checkpoint._base"
45
41
  ).BaseCheckpointCallbackConfig
46
- if name == "LastCheckpointCallbackConfig":
42
+ if name == "BestCheckpointCallbackConfig":
47
43
  return importlib.import_module(
48
44
  "nshtrainer.callbacks.checkpoint"
49
- ).LastCheckpointCallbackConfig
45
+ ).BestCheckpointCallbackConfig
50
46
  if name == "CallbackConfigBase":
51
47
  return importlib.import_module(
52
48
  "nshtrainer.callbacks.checkpoint._base"
53
49
  ).CallbackConfigBase
54
- if name == "OnExceptionCheckpointCallbackConfig":
50
+ if name == "CheckpointMetadata":
55
51
  return importlib.import_module(
56
- "nshtrainer.callbacks.checkpoint"
57
- ).OnExceptionCheckpointCallbackConfig
58
- if name == "BestCheckpointCallbackConfig":
52
+ "nshtrainer.callbacks.checkpoint._base"
53
+ ).CheckpointMetadata
54
+ if name == "LastCheckpointCallbackConfig":
59
55
  return importlib.import_module(
60
56
  "nshtrainer.callbacks.checkpoint"
61
- ).BestCheckpointCallbackConfig
57
+ ).LastCheckpointCallbackConfig
62
58
  if name == "MetricConfig":
63
59
  return importlib.import_module(
64
60
  "nshtrainer.callbacks.checkpoint.best_checkpoint"
65
61
  ).MetricConfig
62
+ if name == "OnExceptionCheckpointCallbackConfig":
63
+ return importlib.import_module(
64
+ "nshtrainer.callbacks.checkpoint"
65
+ ).OnExceptionCheckpointCallbackConfig
66
66
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
67
67
 
68
68
 
@@ -23,10 +23,6 @@ else:
23
23
 
24
24
  if name in globals():
25
25
  return globals()[name]
26
- if name == "CheckpointMetadata":
27
- return importlib.import_module(
28
- "nshtrainer.callbacks.checkpoint._base"
29
- ).CheckpointMetadata
30
26
  if name == "BaseCheckpointCallbackConfig":
31
27
  return importlib.import_module(
32
28
  "nshtrainer.callbacks.checkpoint._base"
@@ -35,6 +31,10 @@ else:
35
31
  return importlib.import_module(
36
32
  "nshtrainer.callbacks.checkpoint._base"
37
33
  ).CallbackConfigBase
34
+ if name == "CheckpointMetadata":
35
+ return importlib.import_module(
36
+ "nshtrainer.callbacks.checkpoint._base"
37
+ ).CheckpointMetadata
38
38
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
39
39
 
40
40
  # Submodule exports
@@ -26,22 +26,22 @@ else:
26
26
 
27
27
  if name in globals():
28
28
  return globals()[name]
29
- if name == "CheckpointMetadata":
30
- return importlib.import_module(
31
- "nshtrainer.callbacks.checkpoint.best_checkpoint"
32
- ).CheckpointMetadata
33
29
  if name == "BaseCheckpointCallbackConfig":
34
30
  return importlib.import_module(
35
31
  "nshtrainer.callbacks.checkpoint.best_checkpoint"
36
32
  ).BaseCheckpointCallbackConfig
37
- if name == "MetricConfig":
38
- return importlib.import_module(
39
- "nshtrainer.callbacks.checkpoint.best_checkpoint"
40
- ).MetricConfig
41
33
  if name == "BestCheckpointCallbackConfig":
42
34
  return importlib.import_module(
43
35
  "nshtrainer.callbacks.checkpoint.best_checkpoint"
44
36
  ).BestCheckpointCallbackConfig
37
+ if name == "CheckpointMetadata":
38
+ return importlib.import_module(
39
+ "nshtrainer.callbacks.checkpoint.best_checkpoint"
40
+ ).CheckpointMetadata
41
+ if name == "MetricConfig":
42
+ return importlib.import_module(
43
+ "nshtrainer.callbacks.checkpoint.best_checkpoint"
44
+ ).MetricConfig
45
45
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
46
46
 
47
47
  # Submodule exports
@@ -23,14 +23,14 @@ else:
23
23
 
24
24
  if name in globals():
25
25
  return globals()[name]
26
- if name == "CheckpointMetadata":
27
- return importlib.import_module(
28
- "nshtrainer.callbacks.checkpoint.last_checkpoint"
29
- ).CheckpointMetadata
30
26
  if name == "BaseCheckpointCallbackConfig":
31
27
  return importlib.import_module(
32
28
  "nshtrainer.callbacks.checkpoint.last_checkpoint"
33
29
  ).BaseCheckpointCallbackConfig
30
+ if name == "CheckpointMetadata":
31
+ return importlib.import_module(
32
+ "nshtrainer.callbacks.checkpoint.last_checkpoint"
33
+ ).CheckpointMetadata
34
34
  if name == "LastCheckpointCallbackConfig":
35
35
  return importlib.import_module(
36
36
  "nshtrainer.callbacks.checkpoint.last_checkpoint"
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "OnExceptionCheckpointCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
26
- ).OnExceptionCheckpointCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
30
26
  ).CallbackConfigBase
27
+ if name == "OnExceptionCheckpointCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.checkpoint.on_exception_checkpoint"
30
+ ).OnExceptionCheckpointCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -18,14 +18,14 @@ else:
18
18
 
19
19
  if name in globals():
20
20
  return globals()[name]
21
- if name == "DebugFlagCallbackConfig":
22
- return importlib.import_module(
23
- "nshtrainer.callbacks.debug_flag"
24
- ).DebugFlagCallbackConfig
25
21
  if name == "CallbackConfigBase":
26
22
  return importlib.import_module(
27
23
  "nshtrainer.callbacks.debug_flag"
28
24
  ).CallbackConfigBase
25
+ if name == "DebugFlagCallbackConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.callbacks.debug_flag"
28
+ ).DebugFlagCallbackConfig
29
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
30
 
31
31
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "DirectorySetupCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.directory_setup"
26
- ).DirectorySetupCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.directory_setup"
30
26
  ).CallbackConfigBase
27
+ if name == "DirectorySetupCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.directory_setup"
30
+ ).DirectorySetupCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -21,18 +21,18 @@ else:
21
21
 
22
22
  if name in globals():
23
23
  return globals()[name]
24
- if name == "MetricConfig":
24
+ if name == "CallbackConfigBase":
25
25
  return importlib.import_module(
26
26
  "nshtrainer.callbacks.early_stopping"
27
- ).MetricConfig
27
+ ).CallbackConfigBase
28
28
  if name == "EarlyStoppingCallbackConfig":
29
29
  return importlib.import_module(
30
30
  "nshtrainer.callbacks.early_stopping"
31
31
  ).EarlyStoppingCallbackConfig
32
- if name == "CallbackConfigBase":
32
+ if name == "MetricConfig":
33
33
  return importlib.import_module(
34
34
  "nshtrainer.callbacks.early_stopping"
35
- ).CallbackConfigBase
35
+ ).MetricConfig
36
36
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
37
37
 
38
38
  # Submodule exports
@@ -16,12 +16,12 @@ else:
16
16
 
17
17
  if name in globals():
18
18
  return globals()[name]
19
- if name == "EMACallbackConfig":
20
- return importlib.import_module("nshtrainer.callbacks.ema").EMACallbackConfig
21
19
  if name == "CallbackConfigBase":
22
20
  return importlib.import_module(
23
21
  "nshtrainer.callbacks.ema"
24
22
  ).CallbackConfigBase
23
+ if name == "EMACallbackConfig":
24
+ return importlib.import_module("nshtrainer.callbacks.ema").EMACallbackConfig
25
25
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
26
26
 
27
27
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "FiniteChecksCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.finite_checks"
26
- ).FiniteChecksCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.finite_checks"
30
26
  ).CallbackConfigBase
27
+ if name == "FiniteChecksCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.finite_checks"
30
+ ).FiniteChecksCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "GradientSkippingCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.gradient_skipping"
26
- ).GradientSkippingCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.gradient_skipping"
30
26
  ).CallbackConfigBase
27
+ if name == "GradientSkippingCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.gradient_skipping"
30
+ ).GradientSkippingCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -7,11 +7,9 @@ from typing import TYPE_CHECKING
7
7
  # Config/alias imports
8
8
 
9
9
  if TYPE_CHECKING:
10
- from nshtrainer.callbacks.throughput_monitor import (
11
- CallbackConfigBase as CallbackConfigBase,
12
- )
13
- from nshtrainer.callbacks.throughput_monitor import (
14
- ThroughputMonitorConfig as ThroughputMonitorConfig,
10
+ from nshtrainer.callbacks.log_epoch import CallbackConfigBase as CallbackConfigBase
11
+ from nshtrainer.callbacks.log_epoch import (
12
+ LogEpochCallbackConfig as LogEpochCallbackConfig,
15
13
  )
16
14
  else:
17
15
 
@@ -20,14 +18,14 @@ else:
20
18
 
21
19
  if name in globals():
22
20
  return globals()[name]
23
- if name == "ThroughputMonitorConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.throughput_monitor"
26
- ).ThroughputMonitorConfig
27
21
  if name == "CallbackConfigBase":
28
22
  return importlib.import_module(
29
- "nshtrainer.callbacks.throughput_monitor"
23
+ "nshtrainer.callbacks.log_epoch"
30
24
  ).CallbackConfigBase
25
+ if name == "LogEpochCallbackConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.callbacks.log_epoch"
28
+ ).LogEpochCallbackConfig
31
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
30
 
33
31
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "NormLoggingCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.norm_logging"
26
- ).NormLoggingCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.norm_logging"
30
26
  ).CallbackConfigBase
27
+ if name == "NormLoggingCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.norm_logging"
30
+ ).NormLoggingCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "PrintTableMetricsCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.print_table"
26
- ).PrintTableMetricsCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.print_table"
30
26
  ).CallbackConfigBase
27
+ if name == "PrintTableMetricsCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.print_table"
30
+ ).PrintTableMetricsCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "RLPSanityChecksCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.rlp_sanity_checks"
26
- ).RLPSanityChecksCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.rlp_sanity_checks"
30
26
  ).CallbackConfigBase
27
+ if name == "RLPSanityChecksCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.rlp_sanity_checks"
30
+ ).RLPSanityChecksCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "SharedParametersCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.shared_parameters"
26
- ).SharedParametersCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.shared_parameters"
30
26
  ).CallbackConfigBase
27
+ if name == "SharedParametersCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.shared_parameters"
30
+ ).SharedParametersCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -18,14 +18,14 @@ else:
18
18
 
19
19
  if name in globals():
20
20
  return globals()[name]
21
- if name == "EpochTimerCallbackConfig":
22
- return importlib.import_module(
23
- "nshtrainer.callbacks.timer"
24
- ).EpochTimerCallbackConfig
25
21
  if name == "CallbackConfigBase":
26
22
  return importlib.import_module(
27
23
  "nshtrainer.callbacks.timer"
28
24
  ).CallbackConfigBase
25
+ if name == "EpochTimerCallbackConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.callbacks.timer"
28
+ ).EpochTimerCallbackConfig
29
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
30
 
31
31
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "WandbUploadCodeCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.wandb_upload_code"
26
- ).WandbUploadCodeCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.wandb_upload_code"
30
26
  ).CallbackConfigBase
27
+ if name == "WandbUploadCodeCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.wandb_upload_code"
30
+ ).WandbUploadCodeCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports
@@ -20,14 +20,14 @@ else:
20
20
 
21
21
  if name in globals():
22
22
  return globals()[name]
23
- if name == "WandbWatchCallbackConfig":
24
- return importlib.import_module(
25
- "nshtrainer.callbacks.wandb_watch"
26
- ).WandbWatchCallbackConfig
27
23
  if name == "CallbackConfigBase":
28
24
  return importlib.import_module(
29
25
  "nshtrainer.callbacks.wandb_watch"
30
26
  ).CallbackConfigBase
27
+ if name == "WandbWatchCallbackConfig":
28
+ return importlib.import_module(
29
+ "nshtrainer.callbacks.wandb_watch"
30
+ ).WandbWatchCallbackConfig
31
31
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
32
32
 
33
33
  # Submodule exports