nshtrainer 0.44.1__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. nshtrainer/__init__.py +6 -3
  2. nshtrainer/_callback.py +297 -2
  3. nshtrainer/_checkpoint/loader.py +23 -30
  4. nshtrainer/_checkpoint/metadata.py +22 -18
  5. nshtrainer/_experimental/__init__.py +0 -2
  6. nshtrainer/_hf_hub.py +25 -26
  7. nshtrainer/callbacks/__init__.py +1 -3
  8. nshtrainer/callbacks/actsave.py +22 -20
  9. nshtrainer/callbacks/base.py +7 -7
  10. nshtrainer/callbacks/checkpoint/__init__.py +1 -1
  11. nshtrainer/callbacks/checkpoint/_base.py +8 -5
  12. nshtrainer/callbacks/checkpoint/best_checkpoint.py +4 -4
  13. nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  14. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +4 -4
  15. nshtrainer/callbacks/debug_flag.py +14 -19
  16. nshtrainer/callbacks/directory_setup.py +6 -11
  17. nshtrainer/callbacks/early_stopping.py +3 -3
  18. nshtrainer/callbacks/ema.py +1 -1
  19. nshtrainer/callbacks/finite_checks.py +1 -1
  20. nshtrainer/callbacks/gradient_skipping.py +1 -1
  21. nshtrainer/callbacks/log_epoch.py +1 -1
  22. nshtrainer/callbacks/norm_logging.py +1 -1
  23. nshtrainer/callbacks/print_table.py +1 -1
  24. nshtrainer/callbacks/rlp_sanity_checks.py +1 -1
  25. nshtrainer/callbacks/shared_parameters.py +1 -1
  26. nshtrainer/callbacks/timer.py +1 -1
  27. nshtrainer/callbacks/wandb_upload_code.py +1 -1
  28. nshtrainer/callbacks/wandb_watch.py +1 -1
  29. nshtrainer/config/__init__.py +189 -189
  30. nshtrainer/config/_checkpoint/__init__.py +70 -0
  31. nshtrainer/config/_checkpoint/loader/__init__.py +6 -6
  32. nshtrainer/config/_directory/__init__.py +2 -2
  33. nshtrainer/config/_hf_hub/__init__.py +2 -2
  34. nshtrainer/config/callbacks/__init__.py +44 -44
  35. nshtrainer/config/callbacks/checkpoint/__init__.py +11 -11
  36. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +4 -4
  37. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +8 -8
  38. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +4 -4
  39. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +4 -4
  40. nshtrainer/config/callbacks/debug_flag/__init__.py +4 -4
  41. nshtrainer/config/callbacks/directory_setup/__init__.py +4 -4
  42. nshtrainer/config/callbacks/early_stopping/__init__.py +4 -4
  43. nshtrainer/config/callbacks/ema/__init__.py +2 -2
  44. nshtrainer/config/callbacks/finite_checks/__init__.py +4 -4
  45. nshtrainer/config/callbacks/gradient_skipping/__init__.py +4 -4
  46. nshtrainer/config/callbacks/{throughput_monitor → log_epoch}/__init__.py +8 -10
  47. nshtrainer/config/callbacks/norm_logging/__init__.py +4 -4
  48. nshtrainer/config/callbacks/print_table/__init__.py +4 -4
  49. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +4 -4
  50. nshtrainer/config/callbacks/shared_parameters/__init__.py +4 -4
  51. nshtrainer/config/callbacks/timer/__init__.py +4 -4
  52. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +4 -4
  53. nshtrainer/config/callbacks/wandb_watch/__init__.py +4 -4
  54. nshtrainer/config/loggers/__init__.py +10 -6
  55. nshtrainer/config/loggers/actsave/__init__.py +29 -0
  56. nshtrainer/config/loggers/csv/__init__.py +2 -2
  57. nshtrainer/config/loggers/wandb/__init__.py +6 -6
  58. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +4 -4
  59. nshtrainer/config/nn/__init__.py +18 -18
  60. nshtrainer/config/nn/nonlinearity/__init__.py +26 -26
  61. nshtrainer/config/optimizer/__init__.py +2 -2
  62. nshtrainer/config/profiler/__init__.py +2 -2
  63. nshtrainer/config/profiler/pytorch/__init__.py +4 -4
  64. nshtrainer/config/profiler/simple/__init__.py +4 -4
  65. nshtrainer/config/trainer/__init__.py +180 -0
  66. nshtrainer/config/trainer/_config/__init__.py +59 -36
  67. nshtrainer/config/trainer/trainer/__init__.py +27 -0
  68. nshtrainer/config/util/__init__.py +109 -0
  69. nshtrainer/config/util/_environment_info/__init__.py +20 -20
  70. nshtrainer/config/util/config/__init__.py +2 -2
  71. nshtrainer/data/datamodule.py +52 -2
  72. nshtrainer/loggers/__init__.py +2 -1
  73. nshtrainer/loggers/_base.py +5 -2
  74. nshtrainer/loggers/actsave.py +59 -0
  75. nshtrainer/loggers/csv.py +5 -5
  76. nshtrainer/loggers/tensorboard.py +5 -5
  77. nshtrainer/loggers/wandb.py +17 -16
  78. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +9 -7
  79. nshtrainer/model/__init__.py +0 -4
  80. nshtrainer/model/base.py +64 -347
  81. nshtrainer/model/mixins/callback.py +24 -5
  82. nshtrainer/model/mixins/debug.py +86 -0
  83. nshtrainer/model/mixins/logger.py +142 -145
  84. nshtrainer/profiler/_base.py +2 -2
  85. nshtrainer/profiler/advanced.py +4 -4
  86. nshtrainer/profiler/pytorch.py +4 -4
  87. nshtrainer/profiler/simple.py +4 -4
  88. nshtrainer/trainer/__init__.py +1 -0
  89. nshtrainer/trainer/_config.py +164 -17
  90. nshtrainer/trainer/checkpoint_connector.py +23 -8
  91. nshtrainer/trainer/trainer.py +194 -76
  92. nshtrainer/util/_environment_info.py +21 -13
  93. nshtrainer/util/config/dtype.py +4 -4
  94. nshtrainer/util/typing_utils.py +1 -1
  95. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/METADATA +2 -2
  96. nshtrainer-1.0.0b10.dist-info/RECORD +143 -0
  97. nshtrainer/callbacks/_throughput_monitor_callback.py +0 -551
  98. nshtrainer/callbacks/throughput_monitor.py +0 -58
  99. nshtrainer/config/model/__init__.py +0 -41
  100. nshtrainer/config/model/base/__init__.py +0 -25
  101. nshtrainer/config/model/config/__init__.py +0 -37
  102. nshtrainer/config/model/mixins/logger/__init__.py +0 -22
  103. nshtrainer/config/runner/__init__.py +0 -22
  104. nshtrainer/ll/__init__.py +0 -59
  105. nshtrainer/ll/_experimental.py +0 -3
  106. nshtrainer/ll/actsave.py +0 -6
  107. nshtrainer/ll/callbacks.py +0 -3
  108. nshtrainer/ll/config.py +0 -6
  109. nshtrainer/ll/data.py +0 -3
  110. nshtrainer/ll/log.py +0 -5
  111. nshtrainer/ll/lr_scheduler.py +0 -3
  112. nshtrainer/ll/model.py +0 -21
  113. nshtrainer/ll/nn.py +0 -3
  114. nshtrainer/ll/optimizer.py +0 -3
  115. nshtrainer/ll/runner.py +0 -5
  116. nshtrainer/ll/snapshot.py +0 -3
  117. nshtrainer/ll/snoop.py +0 -3
  118. nshtrainer/ll/trainer.py +0 -3
  119. nshtrainer/ll/typecheck.py +0 -3
  120. nshtrainer/ll/util.py +0 -3
  121. nshtrainer/model/config.py +0 -218
  122. nshtrainer/runner.py +0 -101
  123. nshtrainer-0.44.1.dist-info/RECORD +0 -162
  124. {nshtrainer-0.44.1.dist-info → nshtrainer-1.0.0b10.dist-info}/WHEEL +0 -0
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
7
7
  # Config/alias imports
8
8
 
9
9
  if TYPE_CHECKING:
10
+ from nshtrainer.loggers import ActSaveLoggerConfig as ActSaveLoggerConfig
10
11
  from nshtrainer.loggers import BaseLoggerConfig as BaseLoggerConfig
11
12
  from nshtrainer.loggers import CSVLoggerConfig as CSVLoggerConfig
12
13
  from nshtrainer.loggers import LoggerConfig as LoggerConfig
@@ -26,8 +27,16 @@ else:
26
27
 
27
28
  if name in globals():
28
29
  return globals()[name]
30
+ if name == "ActSaveLoggerConfig":
31
+ return importlib.import_module("nshtrainer.loggers").ActSaveLoggerConfig
29
32
  if name == "BaseLoggerConfig":
30
33
  return importlib.import_module("nshtrainer.loggers").BaseLoggerConfig
34
+ if name == "CSVLoggerConfig":
35
+ return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
36
+ if name == "CallbackConfigBase":
37
+ return importlib.import_module(
38
+ "nshtrainer.loggers.wandb"
39
+ ).CallbackConfigBase
31
40
  if name == "TensorboardLoggerConfig":
32
41
  return importlib.import_module("nshtrainer.loggers").TensorboardLoggerConfig
33
42
  if name == "WandbLoggerConfig":
@@ -40,12 +49,6 @@ else:
40
49
  return importlib.import_module(
41
50
  "nshtrainer.loggers.wandb"
42
51
  ).WandbWatchCallbackConfig
43
- if name == "CallbackConfigBase":
44
- return importlib.import_module(
45
- "nshtrainer.loggers.wandb"
46
- ).CallbackConfigBase
47
- if name == "CSVLoggerConfig":
48
- return importlib.import_module("nshtrainer.loggers").CSVLoggerConfig
49
52
  if name == "LoggerConfig":
50
53
  return importlib.import_module("nshtrainer.loggers").LoggerConfig
51
54
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -53,6 +56,7 @@ else:
53
56
 
54
57
  # Submodule exports
55
58
  from . import _base as _base
59
+ from . import actsave as actsave
56
60
  from . import csv as csv
57
61
  from . import tensorboard as tensorboard
58
62
  from . import wandb as wandb
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.loggers.actsave import ActSaveLoggerConfig as ActSaveLoggerConfig
11
+ from nshtrainer.loggers.actsave import BaseLoggerConfig as BaseLoggerConfig
12
+ else:
13
+
14
+ def __getattr__(name):
15
+ import importlib
16
+
17
+ if name in globals():
18
+ return globals()[name]
19
+ if name == "ActSaveLoggerConfig":
20
+ return importlib.import_module(
21
+ "nshtrainer.loggers.actsave"
22
+ ).ActSaveLoggerConfig
23
+ if name == "BaseLoggerConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.loggers.actsave"
26
+ ).BaseLoggerConfig
27
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
+
29
+ # Submodule exports
@@ -16,10 +16,10 @@ else:
16
16
 
17
17
  if name in globals():
18
18
  return globals()[name]
19
- if name == "CSVLoggerConfig":
20
- return importlib.import_module("nshtrainer.loggers.csv").CSVLoggerConfig
21
19
  if name == "BaseLoggerConfig":
22
20
  return importlib.import_module("nshtrainer.loggers.csv").BaseLoggerConfig
21
+ if name == "CSVLoggerConfig":
22
+ return importlib.import_module("nshtrainer.loggers.csv").CSVLoggerConfig
23
23
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
24
24
 
25
25
  # Submodule exports
@@ -23,6 +23,12 @@ else:
23
23
 
24
24
  if name in globals():
25
25
  return globals()[name]
26
+ if name == "BaseLoggerConfig":
27
+ return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
28
+ if name == "CallbackConfigBase":
29
+ return importlib.import_module(
30
+ "nshtrainer.loggers.wandb"
31
+ ).CallbackConfigBase
26
32
  if name == "WandbLoggerConfig":
27
33
  return importlib.import_module("nshtrainer.loggers.wandb").WandbLoggerConfig
28
34
  if name == "WandbUploadCodeCallbackConfig":
@@ -33,12 +39,6 @@ else:
33
39
  return importlib.import_module(
34
40
  "nshtrainer.loggers.wandb"
35
41
  ).WandbWatchCallbackConfig
36
- if name == "BaseLoggerConfig":
37
- return importlib.import_module("nshtrainer.loggers.wandb").BaseLoggerConfig
38
- if name == "CallbackConfigBase":
39
- return importlib.import_module(
40
- "nshtrainer.loggers.wandb"
41
- ).CallbackConfigBase
42
42
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
43
43
 
44
44
  # Submodule exports
@@ -23,14 +23,14 @@ else:
23
23
 
24
24
  if name in globals():
25
25
  return globals()[name]
26
- if name == "LinearWarmupCosineDecayLRSchedulerConfig":
27
- return importlib.import_module(
28
- "nshtrainer.lr_scheduler.linear_warmup_cosine"
29
- ).LinearWarmupCosineDecayLRSchedulerConfig
30
26
  if name == "LRSchedulerConfigBase":
31
27
  return importlib.import_module(
32
28
  "nshtrainer.lr_scheduler.linear_warmup_cosine"
33
29
  ).LRSchedulerConfigBase
30
+ if name == "LinearWarmupCosineDecayLRSchedulerConfig":
31
+ return importlib.import_module(
32
+ "nshtrainer.lr_scheduler.linear_warmup_cosine"
33
+ ).LinearWarmupCosineDecayLRSchedulerConfig
34
34
  if name == "DurationConfig":
35
35
  return importlib.import_module(
36
36
  "nshtrainer.lr_scheduler.linear_warmup_cosine"
@@ -35,38 +35,38 @@ else:
35
35
  return globals()[name]
36
36
  if name == "BaseNonlinearityConfig":
37
37
  return importlib.import_module("nshtrainer.nn").BaseNonlinearityConfig
38
+ if name == "ELUNonlinearityConfig":
39
+ return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
40
+ if name == "GELUNonlinearityConfig":
41
+ return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
42
+ if name == "LeakyReLUNonlinearityConfig":
43
+ return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
38
44
  if name == "MLPConfig":
39
45
  return importlib.import_module("nshtrainer.nn").MLPConfig
46
+ if name == "MishNonlinearityConfig":
47
+ return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
40
48
  if name == "PReLUConfig":
41
49
  return importlib.import_module("nshtrainer.nn").PReLUConfig
42
- if name == "LeakyReLUNonlinearityConfig":
43
- return importlib.import_module("nshtrainer.nn").LeakyReLUNonlinearityConfig
44
- if name == "SwiGLUNonlinearityConfig":
45
- return importlib.import_module(
46
- "nshtrainer.nn.nonlinearity"
47
- ).SwiGLUNonlinearityConfig
48
- if name == "SoftsignNonlinearityConfig":
49
- return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
50
+ if name == "ReLUNonlinearityConfig":
51
+ return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
50
52
  if name == "SiLUNonlinearityConfig":
51
53
  return importlib.import_module("nshtrainer.nn").SiLUNonlinearityConfig
52
54
  if name == "SigmoidNonlinearityConfig":
53
55
  return importlib.import_module("nshtrainer.nn").SigmoidNonlinearityConfig
54
- if name == "SoftplusNonlinearityConfig":
55
- return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
56
- if name == "ELUNonlinearityConfig":
57
- return importlib.import_module("nshtrainer.nn").ELUNonlinearityConfig
58
56
  if name == "SoftmaxNonlinearityConfig":
59
57
  return importlib.import_module("nshtrainer.nn").SoftmaxNonlinearityConfig
60
- if name == "GELUNonlinearityConfig":
61
- return importlib.import_module("nshtrainer.nn").GELUNonlinearityConfig
58
+ if name == "SoftplusNonlinearityConfig":
59
+ return importlib.import_module("nshtrainer.nn").SoftplusNonlinearityConfig
60
+ if name == "SoftsignNonlinearityConfig":
61
+ return importlib.import_module("nshtrainer.nn").SoftsignNonlinearityConfig
62
+ if name == "SwiGLUNonlinearityConfig":
63
+ return importlib.import_module(
64
+ "nshtrainer.nn.nonlinearity"
65
+ ).SwiGLUNonlinearityConfig
62
66
  if name == "SwishNonlinearityConfig":
63
67
  return importlib.import_module("nshtrainer.nn").SwishNonlinearityConfig
64
- if name == "MishNonlinearityConfig":
65
- return importlib.import_module("nshtrainer.nn").MishNonlinearityConfig
66
68
  if name == "TanhNonlinearityConfig":
67
69
  return importlib.import_module("nshtrainer.nn").TanhNonlinearityConfig
68
- if name == "ReLUNonlinearityConfig":
69
- return importlib.import_module("nshtrainer.nn").ReLUNonlinearityConfig
70
70
  if name == "NonlinearityConfig":
71
71
  return importlib.import_module("nshtrainer.nn").NonlinearityConfig
72
72
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -58,20 +58,32 @@ else:
58
58
 
59
59
  if name in globals():
60
60
  return globals()[name]
61
- if name == "PReLUConfig":
62
- return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
61
+ if name == "BaseNonlinearityConfig":
62
+ return importlib.import_module(
63
+ "nshtrainer.nn.nonlinearity"
64
+ ).BaseNonlinearityConfig
65
+ if name == "ELUNonlinearityConfig":
66
+ return importlib.import_module(
67
+ "nshtrainer.nn.nonlinearity"
68
+ ).ELUNonlinearityConfig
69
+ if name == "GELUNonlinearityConfig":
70
+ return importlib.import_module(
71
+ "nshtrainer.nn.nonlinearity"
72
+ ).GELUNonlinearityConfig
63
73
  if name == "LeakyReLUNonlinearityConfig":
64
74
  return importlib.import_module(
65
75
  "nshtrainer.nn.nonlinearity"
66
76
  ).LeakyReLUNonlinearityConfig
67
- if name == "SwiGLUNonlinearityConfig":
77
+ if name == "MishNonlinearityConfig":
68
78
  return importlib.import_module(
69
79
  "nshtrainer.nn.nonlinearity"
70
- ).SwiGLUNonlinearityConfig
71
- if name == "SoftsignNonlinearityConfig":
80
+ ).MishNonlinearityConfig
81
+ if name == "PReLUConfig":
82
+ return importlib.import_module("nshtrainer.nn.nonlinearity").PReLUConfig
83
+ if name == "ReLUNonlinearityConfig":
72
84
  return importlib.import_module(
73
85
  "nshtrainer.nn.nonlinearity"
74
- ).SoftsignNonlinearityConfig
86
+ ).ReLUNonlinearityConfig
75
87
  if name == "SiLUNonlinearityConfig":
76
88
  return importlib.import_module(
77
89
  "nshtrainer.nn.nonlinearity"
@@ -80,42 +92,30 @@ else:
80
92
  return importlib.import_module(
81
93
  "nshtrainer.nn.nonlinearity"
82
94
  ).SigmoidNonlinearityConfig
83
- if name == "SoftplusNonlinearityConfig":
95
+ if name == "SoftmaxNonlinearityConfig":
84
96
  return importlib.import_module(
85
97
  "nshtrainer.nn.nonlinearity"
86
- ).SoftplusNonlinearityConfig
87
- if name == "ELUNonlinearityConfig":
98
+ ).SoftmaxNonlinearityConfig
99
+ if name == "SoftplusNonlinearityConfig":
88
100
  return importlib.import_module(
89
101
  "nshtrainer.nn.nonlinearity"
90
- ).ELUNonlinearityConfig
91
- if name == "SoftmaxNonlinearityConfig":
102
+ ).SoftplusNonlinearityConfig
103
+ if name == "SoftsignNonlinearityConfig":
92
104
  return importlib.import_module(
93
105
  "nshtrainer.nn.nonlinearity"
94
- ).SoftmaxNonlinearityConfig
95
- if name == "GELUNonlinearityConfig":
106
+ ).SoftsignNonlinearityConfig
107
+ if name == "SwiGLUNonlinearityConfig":
96
108
  return importlib.import_module(
97
109
  "nshtrainer.nn.nonlinearity"
98
- ).GELUNonlinearityConfig
110
+ ).SwiGLUNonlinearityConfig
99
111
  if name == "SwishNonlinearityConfig":
100
112
  return importlib.import_module(
101
113
  "nshtrainer.nn.nonlinearity"
102
114
  ).SwishNonlinearityConfig
103
- if name == "MishNonlinearityConfig":
104
- return importlib.import_module(
105
- "nshtrainer.nn.nonlinearity"
106
- ).MishNonlinearityConfig
107
115
  if name == "TanhNonlinearityConfig":
108
116
  return importlib.import_module(
109
117
  "nshtrainer.nn.nonlinearity"
110
118
  ).TanhNonlinearityConfig
111
- if name == "ReLUNonlinearityConfig":
112
- return importlib.import_module(
113
- "nshtrainer.nn.nonlinearity"
114
- ).ReLUNonlinearityConfig
115
- if name == "BaseNonlinearityConfig":
116
- return importlib.import_module(
117
- "nshtrainer.nn.nonlinearity"
118
- ).BaseNonlinearityConfig
119
119
  if name == "NonlinearityConfig":
120
120
  return importlib.import_module(
121
121
  "nshtrainer.nn.nonlinearity"
@@ -17,10 +17,10 @@ else:
17
17
 
18
18
  if name in globals():
19
19
  return globals()[name]
20
- if name == "OptimizerConfigBase":
21
- return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
22
20
  if name == "AdamWConfig":
23
21
  return importlib.import_module("nshtrainer.optimizer").AdamWConfig
22
+ if name == "OptimizerConfigBase":
23
+ return importlib.import_module("nshtrainer.optimizer").OptimizerConfigBase
24
24
  if name == "OptimizerConfig":
25
25
  return importlib.import_module("nshtrainer.optimizer").OptimizerConfig
26
26
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -19,12 +19,12 @@ else:
19
19
 
20
20
  if name in globals():
21
21
  return globals()[name]
22
+ if name == "AdvancedProfilerConfig":
23
+ return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
22
24
  if name == "BaseProfilerConfig":
23
25
  return importlib.import_module("nshtrainer.profiler").BaseProfilerConfig
24
26
  if name == "PyTorchProfilerConfig":
25
27
  return importlib.import_module("nshtrainer.profiler").PyTorchProfilerConfig
26
- if name == "AdvancedProfilerConfig":
27
- return importlib.import_module("nshtrainer.profiler").AdvancedProfilerConfig
28
28
  if name == "SimpleProfilerConfig":
29
29
  return importlib.import_module("nshtrainer.profiler").SimpleProfilerConfig
30
30
  if name == "ProfilerConfig":
@@ -18,14 +18,14 @@ else:
18
18
 
19
19
  if name in globals():
20
20
  return globals()[name]
21
- if name == "PyTorchProfilerConfig":
22
- return importlib.import_module(
23
- "nshtrainer.profiler.pytorch"
24
- ).PyTorchProfilerConfig
25
21
  if name == "BaseProfilerConfig":
26
22
  return importlib.import_module(
27
23
  "nshtrainer.profiler.pytorch"
28
24
  ).BaseProfilerConfig
25
+ if name == "PyTorchProfilerConfig":
26
+ return importlib.import_module(
27
+ "nshtrainer.profiler.pytorch"
28
+ ).PyTorchProfilerConfig
29
29
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
30
30
 
31
31
  # Submodule exports
@@ -16,14 +16,14 @@ else:
16
16
 
17
17
  if name in globals():
18
18
  return globals()[name]
19
- if name == "SimpleProfilerConfig":
20
- return importlib.import_module(
21
- "nshtrainer.profiler.simple"
22
- ).SimpleProfilerConfig
23
19
  if name == "BaseProfilerConfig":
24
20
  return importlib.import_module(
25
21
  "nshtrainer.profiler.simple"
26
22
  ).BaseProfilerConfig
23
+ if name == "SimpleProfilerConfig":
24
+ return importlib.import_module(
25
+ "nshtrainer.profiler.simple"
26
+ ).SimpleProfilerConfig
27
27
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
28
28
 
29
29
  # Submodule exports
@@ -0,0 +1,180 @@
1
+ from __future__ import annotations
2
+
3
+ __codegen__ = True
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ # Config/alias imports
8
+
9
+ if TYPE_CHECKING:
10
+ from nshtrainer.trainer import TrainerConfig as TrainerConfig
11
+ from nshtrainer.trainer._config import ActSaveLoggerConfig as ActSaveLoggerConfig
12
+ from nshtrainer.trainer._config import (
13
+ BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
14
+ )
15
+ from nshtrainer.trainer._config import CallbackConfig as CallbackConfig
16
+ from nshtrainer.trainer._config import CallbackConfigBase as CallbackConfigBase
17
+ from nshtrainer.trainer._config import (
18
+ CheckpointCallbackConfig as CheckpointCallbackConfig,
19
+ )
20
+ from nshtrainer.trainer._config import (
21
+ CheckpointLoadingConfig as CheckpointLoadingConfig,
22
+ )
23
+ from nshtrainer.trainer._config import (
24
+ CheckpointSavingConfig as CheckpointSavingConfig,
25
+ )
26
+ from nshtrainer.trainer._config import CSVLoggerConfig as CSVLoggerConfig
27
+ from nshtrainer.trainer._config import (
28
+ DebugFlagCallbackConfig as DebugFlagCallbackConfig,
29
+ )
30
+ from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
31
+ from nshtrainer.trainer._config import (
32
+ EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
33
+ )
34
+ from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
35
+ from nshtrainer.trainer._config import (
36
+ GradientClippingConfig as GradientClippingConfig,
37
+ )
38
+ from nshtrainer.trainer._config import HuggingFaceHubConfig as HuggingFaceHubConfig
39
+ from nshtrainer.trainer._config import (
40
+ LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
41
+ )
42
+ from nshtrainer.trainer._config import (
43
+ LogEpochCallbackConfig as LogEpochCallbackConfig,
44
+ )
45
+ from nshtrainer.trainer._config import LoggerConfig as LoggerConfig
46
+ from nshtrainer.trainer._config import LoggingConfig as LoggingConfig
47
+ from nshtrainer.trainer._config import MetricConfig as MetricConfig
48
+ from nshtrainer.trainer._config import (
49
+ OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
50
+ )
51
+ from nshtrainer.trainer._config import OptimizationConfig as OptimizationConfig
52
+ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
53
+ from nshtrainer.trainer._config import (
54
+ ReproducibilityConfig as ReproducibilityConfig,
55
+ )
56
+ from nshtrainer.trainer._config import (
57
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
58
+ )
59
+ from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
60
+ from nshtrainer.trainer._config import (
61
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
62
+ )
63
+ from nshtrainer.trainer._config import (
64
+ TensorboardLoggerConfig as TensorboardLoggerConfig,
65
+ )
66
+ from nshtrainer.trainer._config import WandbLoggerConfig as WandbLoggerConfig
67
+ else:
68
+
69
+ def __getattr__(name):
70
+ import importlib
71
+
72
+ if name in globals():
73
+ return globals()[name]
74
+ if name == "ActSaveLoggerConfig":
75
+ return importlib.import_module(
76
+ "nshtrainer.trainer._config"
77
+ ).ActSaveLoggerConfig
78
+ if name == "BestCheckpointCallbackConfig":
79
+ return importlib.import_module(
80
+ "nshtrainer.trainer._config"
81
+ ).BestCheckpointCallbackConfig
82
+ if name == "CSVLoggerConfig":
83
+ return importlib.import_module("nshtrainer.trainer._config").CSVLoggerConfig
84
+ if name == "CallbackConfigBase":
85
+ return importlib.import_module(
86
+ "nshtrainer.trainer._config"
87
+ ).CallbackConfigBase
88
+ if name == "CheckpointLoadingConfig":
89
+ return importlib.import_module(
90
+ "nshtrainer.trainer._config"
91
+ ).CheckpointLoadingConfig
92
+ if name == "CheckpointSavingConfig":
93
+ return importlib.import_module(
94
+ "nshtrainer.trainer._config"
95
+ ).CheckpointSavingConfig
96
+ if name == "DebugFlagCallbackConfig":
97
+ return importlib.import_module(
98
+ "nshtrainer.trainer._config"
99
+ ).DebugFlagCallbackConfig
100
+ if name == "DirectoryConfig":
101
+ return importlib.import_module("nshtrainer.trainer._config").DirectoryConfig
102
+ if name == "EarlyStoppingCallbackConfig":
103
+ return importlib.import_module(
104
+ "nshtrainer.trainer._config"
105
+ ).EarlyStoppingCallbackConfig
106
+ if name == "EnvironmentConfig":
107
+ return importlib.import_module(
108
+ "nshtrainer.trainer._config"
109
+ ).EnvironmentConfig
110
+ if name == "GradientClippingConfig":
111
+ return importlib.import_module(
112
+ "nshtrainer.trainer._config"
113
+ ).GradientClippingConfig
114
+ if name == "HuggingFaceHubConfig":
115
+ return importlib.import_module(
116
+ "nshtrainer.trainer._config"
117
+ ).HuggingFaceHubConfig
118
+ if name == "LastCheckpointCallbackConfig":
119
+ return importlib.import_module(
120
+ "nshtrainer.trainer._config"
121
+ ).LastCheckpointCallbackConfig
122
+ if name == "LogEpochCallbackConfig":
123
+ return importlib.import_module(
124
+ "nshtrainer.trainer._config"
125
+ ).LogEpochCallbackConfig
126
+ if name == "LoggingConfig":
127
+ return importlib.import_module("nshtrainer.trainer._config").LoggingConfig
128
+ if name == "MetricConfig":
129
+ return importlib.import_module("nshtrainer.trainer._config").MetricConfig
130
+ if name == "OnExceptionCheckpointCallbackConfig":
131
+ return importlib.import_module(
132
+ "nshtrainer.trainer._config"
133
+ ).OnExceptionCheckpointCallbackConfig
134
+ if name == "OptimizationConfig":
135
+ return importlib.import_module(
136
+ "nshtrainer.trainer._config"
137
+ ).OptimizationConfig
138
+ if name == "RLPSanityChecksCallbackConfig":
139
+ return importlib.import_module(
140
+ "nshtrainer.trainer._config"
141
+ ).RLPSanityChecksCallbackConfig
142
+ if name == "ReproducibilityConfig":
143
+ return importlib.import_module(
144
+ "nshtrainer.trainer._config"
145
+ ).ReproducibilityConfig
146
+ if name == "SanityCheckingConfig":
147
+ return importlib.import_module(
148
+ "nshtrainer.trainer._config"
149
+ ).SanityCheckingConfig
150
+ if name == "SharedParametersCallbackConfig":
151
+ return importlib.import_module(
152
+ "nshtrainer.trainer._config"
153
+ ).SharedParametersCallbackConfig
154
+ if name == "TensorboardLoggerConfig":
155
+ return importlib.import_module(
156
+ "nshtrainer.trainer._config"
157
+ ).TensorboardLoggerConfig
158
+ if name == "TrainerConfig":
159
+ return importlib.import_module("nshtrainer.trainer").TrainerConfig
160
+ if name == "WandbLoggerConfig":
161
+ return importlib.import_module(
162
+ "nshtrainer.trainer._config"
163
+ ).WandbLoggerConfig
164
+ if name == "CallbackConfig":
165
+ return importlib.import_module("nshtrainer.trainer._config").CallbackConfig
166
+ if name == "CheckpointCallbackConfig":
167
+ return importlib.import_module(
168
+ "nshtrainer.trainer._config"
169
+ ).CheckpointCallbackConfig
170
+ if name == "LoggerConfig":
171
+ return importlib.import_module("nshtrainer.trainer._config").LoggerConfig
172
+ if name == "ProfilerConfig":
173
+ return importlib.import_module("nshtrainer.trainer._config").ProfilerConfig
174
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
175
+
176
+
177
+ # Submodule exports
178
+ from . import _config as _config
179
+ from . import checkpoint_connector as checkpoint_connector
180
+ from . import trainer as trainer