nshtrainer 0.42.0__py3-none-any.whl → 0.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +130 -90
  35. nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
  37. nshtrainer/config/_directory/__init__.py +9 -3
  38. nshtrainer/config/_hf_hub/__init__.py +6 -4
  39. nshtrainer/config/callbacks/__init__.py +82 -42
  40. nshtrainer/config/callbacks/actsave/__init__.py +4 -2
  41. nshtrainer/config/callbacks/base/__init__.py +2 -0
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
  50. nshtrainer/config/callbacks/ema/__init__.py +5 -3
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
  54. nshtrainer/config/callbacks/print_table/__init__.py +7 -5
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
  58. nshtrainer/config/callbacks/timer/__init__.py +9 -5
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
  61. nshtrainer/config/loggers/__init__.py +18 -10
  62. nshtrainer/config/loggers/_base/__init__.py +2 -0
  63. nshtrainer/config/loggers/csv/__init__.py +2 -0
  64. nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
  65. nshtrainer/config/loggers/wandb/__init__.py +18 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +2 -0
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
  70. nshtrainer/config/metrics/__init__.py +2 -0
  71. nshtrainer/config/metrics/_config/__init__.py +2 -0
  72. nshtrainer/config/model/__init__.py +8 -6
  73. nshtrainer/config/model/base/__init__.py +4 -2
  74. nshtrainer/config/model/config/__init__.py +8 -6
  75. nshtrainer/config/model/mixins/logger/__init__.py +2 -0
  76. nshtrainer/config/nn/__init__.py +16 -14
  77. nshtrainer/config/nn/mlp/__init__.py +2 -0
  78. nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
  79. nshtrainer/config/optimizer/__init__.py +2 -0
  80. nshtrainer/config/profiler/__init__.py +2 -0
  81. nshtrainer/config/profiler/_base/__init__.py +2 -0
  82. nshtrainer/config/profiler/advanced/__init__.py +6 -4
  83. nshtrainer/config/profiler/pytorch/__init__.py +6 -4
  84. nshtrainer/config/profiler/simple/__init__.py +6 -4
  85. nshtrainer/config/runner/__init__.py +2 -0
  86. nshtrainer/config/trainer/_config/__init__.py +43 -39
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
  88. nshtrainer/config/util/_environment_info/__init__.py +20 -18
  89. nshtrainer/config/util/config/__init__.py +2 -0
  90. nshtrainer/config/util/config/dtype/__init__.py +2 -0
  91. nshtrainer/config/util/config/duration/__init__.py +2 -0
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +2 -0
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.43.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.42.0.dist-info/RECORD +0 -162
  162. {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from . import _experimental as _experimental
2
4
  from . import callbacks as callbacks
3
5
  from . import config as config
nshtrainer/_callback.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from pathlib import Path
2
4
  from typing import TYPE_CHECKING
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Iterable, Sequence
3
5
  from dataclasses import dataclass
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import datetime
3
5
  import logging
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  import shutil
nshtrainer/_directory.py CHANGED
@@ -1,9 +1,11 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from pathlib import Path
3
5
 
4
6
  import nshconfig as C
5
7
 
6
- from .callbacks.directory_setup import DirectorySetupConfig
8
+ from .callbacks.directory_setup import DirectorySetupCallbackConfig
7
9
  from .loggers import LoggerConfig
8
10
 
9
11
  log = logging.getLogger(__name__)
@@ -32,7 +34,7 @@ class DirectoryConfig(C.Config):
32
34
  profile: Path | None = None
33
35
  """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
34
36
 
35
- setup_callback: DirectorySetupConfig = DirectorySetupConfig()
37
+ setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
36
38
  """Configuration for the directory setup PyTorch Lightning callback."""
37
39
 
38
40
  def resolve_run_root_directory(self, run_id: str) -> Path:
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from lightning.fabric.utilities.throughput import measure_flops as measure_flops
nshtrainer/_hf_hub.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import logging
3
5
  import os
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Annotated
2
4
 
3
5
  import nshconfig as C
@@ -6,60 +8,74 @@ from . import checkpoint as checkpoint
6
8
  from .base import CallbackConfigBase as CallbackConfigBase
7
9
  from .checkpoint import BestCheckpoint as BestCheckpoint
8
10
  from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
9
- from .checkpoint import LastCheckpoint as LastCheckpoint
11
+ from .checkpoint import LastCheckpointCallback as LastCheckpointCallback
10
12
  from .checkpoint import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
11
- from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
13
+ from .checkpoint import OnExceptionCheckpointCallback as OnExceptionCheckpointCallback
12
14
  from .checkpoint import (
13
15
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
14
16
  )
15
17
  from .debug_flag import DebugFlagCallback as DebugFlagCallback
16
18
  from .debug_flag import DebugFlagCallbackConfig as DebugFlagCallbackConfig
17
19
  from .directory_setup import DirectorySetupCallback as DirectorySetupCallback
18
- from .directory_setup import DirectorySetupConfig as DirectorySetupConfig
19
- from .early_stopping import EarlyStopping as EarlyStopping
20
- from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
21
- from .ema import EMA as EMA
22
- from .ema import EMAConfig as EMAConfig
20
+ from .directory_setup import (
21
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
22
+ )
23
+ from .early_stopping import EarlyStoppingCallback as EarlyStoppingCallback
24
+ from .early_stopping import EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig
25
+ from .ema import EMACallback as EMACallback
26
+ from .ema import EMACallbackConfig as EMACallbackConfig
23
27
  from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
24
- from .finite_checks import FiniteChecksConfig as FiniteChecksConfig
25
- from .gradient_skipping import GradientSkipping as GradientSkipping
26
- from .gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
28
+ from .finite_checks import FiniteChecksCallbackConfig as FiniteChecksCallbackConfig
29
+ from .gradient_skipping import GradientSkippingCallback as GradientSkippingCallback
30
+ from .gradient_skipping import (
31
+ GradientSkippingCallbackConfig as GradientSkippingCallbackConfig,
32
+ )
27
33
  from .interval import EpochIntervalCallback as EpochIntervalCallback
28
34
  from .interval import IntervalCallback as IntervalCallback
29
35
  from .interval import StepIntervalCallback as StepIntervalCallback
30
36
  from .log_epoch import LogEpochCallback as LogEpochCallback
37
+ from .log_epoch import LogEpochCallbackConfig as LogEpochCallbackConfig
31
38
  from .norm_logging import NormLoggingCallback as NormLoggingCallback
32
- from .norm_logging import NormLoggingConfig as NormLoggingConfig
39
+ from .norm_logging import NormLoggingCallbackConfig as NormLoggingCallbackConfig
33
40
  from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
34
- from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
41
+ from .print_table import (
42
+ PrintTableMetricsCallbackConfig as PrintTableMetricsCallbackConfig,
43
+ )
35
44
  from .rlp_sanity_checks import RLPSanityChecksCallback as RLPSanityChecksCallback
36
- from .rlp_sanity_checks import RLPSanityChecksConfig as RLPSanityChecksConfig
45
+ from .rlp_sanity_checks import (
46
+ RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
47
+ )
37
48
  from .shared_parameters import SharedParametersCallback as SharedParametersCallback
38
- from .shared_parameters import SharedParametersConfig as SharedParametersConfig
49
+ from .shared_parameters import (
50
+ SharedParametersCallbackConfig as SharedParametersCallbackConfig,
51
+ )
39
52
  from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
40
- from .timer import EpochTimer as EpochTimer
41
- from .timer import EpochTimerConfig as EpochTimerConfig
53
+ from .timer import EpochTimerCallback as EpochTimerCallback
54
+ from .timer import EpochTimerCallbackConfig as EpochTimerCallbackConfig
42
55
  from .wandb_upload_code import WandbUploadCodeCallback as WandbUploadCodeCallback
43
- from .wandb_upload_code import WandbUploadCodeConfig as WandbUploadCodeConfig
56
+ from .wandb_upload_code import (
57
+ WandbUploadCodeCallbackConfig as WandbUploadCodeCallbackConfig,
58
+ )
44
59
  from .wandb_watch import WandbWatchCallback as WandbWatchCallback
45
- from .wandb_watch import WandbWatchConfig as WandbWatchConfig
60
+ from .wandb_watch import WandbWatchCallbackConfig as WandbWatchCallbackConfig
46
61
 
47
62
  CallbackConfig = Annotated[
48
63
  DebugFlagCallbackConfig
49
- | EarlyStoppingConfig
64
+ | EarlyStoppingCallbackConfig
50
65
  | ThroughputMonitorConfig
51
- | EpochTimerConfig
52
- | PrintTableMetricsConfig
53
- | FiniteChecksConfig
54
- | NormLoggingConfig
55
- | GradientSkippingConfig
56
- | EMAConfig
66
+ | EpochTimerCallbackConfig
67
+ | PrintTableMetricsCallbackConfig
68
+ | FiniteChecksCallbackConfig
69
+ | NormLoggingCallbackConfig
70
+ | GradientSkippingCallbackConfig
71
+ | LogEpochCallbackConfig
72
+ | EMACallbackConfig
57
73
  | BestCheckpointCallbackConfig
58
74
  | LastCheckpointCallbackConfig
59
75
  | OnExceptionCheckpointCallbackConfig
60
- | SharedParametersConfig
61
- | RLPSanityChecksConfig
62
- | WandbWatchConfig
63
- | WandbUploadCodeConfig,
76
+ | SharedParametersCallbackConfig
77
+ | RLPSanityChecksCallbackConfig
78
+ | WandbWatchCallbackConfig
79
+ | WandbUploadCodeCallbackConfig,
64
80
  C.Field(discriminator="name"),
65
81
  ]
@@ -12,6 +12,8 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+ from __future__ import annotations
16
+
15
17
  import time
16
18
  from collections import deque
17
19
  from typing import (
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  from pathlib import Path
3
5
  from typing import Literal
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from abc import ABC, abstractmethod
2
4
  from collections import Counter
3
5
  from collections.abc import Iterable
@@ -1,12 +1,16 @@
1
+ from __future__ import annotations
2
+
1
3
  from .best_checkpoint import BestCheckpoint as BestCheckpoint
2
4
  from .best_checkpoint import (
3
5
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
4
6
  )
5
- from .last_checkpoint import LastCheckpoint as LastCheckpoint
7
+ from .last_checkpoint import LastCheckpointCallback as LastCheckpointCallback
6
8
  from .last_checkpoint import (
7
9
  LastCheckpointCallbackConfig as LastCheckpointCallbackConfig,
8
10
  )
9
- from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
11
+ from .on_exception_checkpoint import (
12
+ OnExceptionCheckpointCallback as OnExceptionCheckpointCallback,
13
+ )
10
14
  from .on_exception_checkpoint import (
11
15
  OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
12
16
  )
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from abc import ABC, abstractmethod
3
5
  from pathlib import Path
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from pathlib import Path
3
5
  from typing import Literal
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal
3
5
 
@@ -17,11 +19,11 @@ class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
17
19
 
18
20
  @override
19
21
  def create_checkpoint(self, root_config, dirpath):
20
- return LastCheckpoint(self, dirpath)
22
+ return LastCheckpointCallback(self, dirpath)
21
23
 
22
24
 
23
25
  @final
24
- class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):
26
+ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
25
27
  @override
26
28
  def name(self):
27
29
  return "last"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import datetime
3
5
  import logging
@@ -59,10 +61,12 @@ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
59
61
 
60
62
  if not (filename := self.filename):
61
63
  filename = f"on_exception_{root_config.id}"
62
- yield OnExceptionCheckpoint(self, dirpath=Path(dirpath), filename=filename)
64
+ yield OnExceptionCheckpointCallback(
65
+ self, dirpath=Path(dirpath), filename=filename
66
+ )
63
67
 
64
68
 
65
- class OnExceptionCheckpoint(_OnExceptionCheckpoint):
69
+ class OnExceptionCheckpointCallback(_OnExceptionCheckpoint):
66
70
  @override
67
71
  def __init__(
68
72
  self,
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import TYPE_CHECKING, Literal, cast
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  from pathlib import Path
@@ -41,7 +43,7 @@ def _create_symlink_to_nshrunner(base_dir: Path):
41
43
  symlink_path.symlink_to(session_dir)
42
44
 
43
45
 
44
- class DirectorySetupConfig(CallbackConfigBase):
46
+ class DirectorySetupCallbackConfig(CallbackConfigBase):
45
47
  name: Literal["directory_setup"] = "directory_setup"
46
48
 
47
49
  enabled: bool = True
@@ -62,7 +64,7 @@ class DirectorySetupConfig(CallbackConfigBase):
62
64
 
63
65
  class DirectorySetupCallback(Callback):
64
66
  @override
65
- def __init__(self, config: DirectorySetupConfig):
67
+ def __init__(self, config: DirectorySetupCallbackConfig):
66
68
  super().__init__()
67
69
 
68
70
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import math
3
5
  from typing import Literal
@@ -14,7 +16,7 @@ from .base import CallbackConfigBase
14
16
  log = logging.getLogger(__name__)
15
17
 
16
18
 
17
- class EarlyStoppingConfig(CallbackConfigBase):
19
+ class EarlyStoppingCallbackConfig(CallbackConfigBase):
18
20
  name: Literal["early_stopping"] = "early_stopping"
19
21
 
20
22
  metric: MetricConfig | None = None
@@ -54,11 +56,11 @@ class EarlyStoppingConfig(CallbackConfigBase):
54
56
  "Either `metric` or `root_config.primary_metric` must be set to use EarlyStopping."
55
57
  )
56
58
 
57
- yield EarlyStopping(self, metric)
59
+ yield EarlyStoppingCallback(self, metric)
58
60
 
59
61
 
60
- class EarlyStopping(_EarlyStopping):
61
- def __init__(self, config: EarlyStoppingConfig, metric: MetricConfig):
62
+ class EarlyStoppingCallback(_EarlyStopping):
63
+ def __init__(self, config: EarlyStoppingCallbackConfig, metric: MetricConfig):
62
64
  self.config = config
63
65
  self.metric = metric
64
66
  del config, metric
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import copy
3
5
  import threading
@@ -13,7 +15,7 @@ from typing_extensions import override
13
15
  from .base import CallbackConfigBase
14
16
 
15
17
 
16
- class EMA(Callback):
18
+ class EMACallback(Callback):
17
19
  """
18
20
  Implements Exponential Moving Averaging (EMA).
19
21
 
@@ -358,7 +360,7 @@ class EMAOptimizer(torch.optim.Optimizer):
358
360
  self.rebuild_ema_params = True
359
361
 
360
362
 
361
- class EMAConfig(CallbackConfigBase):
363
+ class EMACallbackConfig(CallbackConfigBase):
362
364
  name: Literal["ema"] = "ema"
363
365
 
364
366
  decay: float
@@ -375,7 +377,7 @@ class EMAConfig(CallbackConfigBase):
375
377
 
376
378
  @override
377
379
  def create_callbacks(self, root_config):
378
- yield EMA(
380
+ yield EMACallback(
379
381
  decay=self.decay,
380
382
  validate_original_weights=self.validate_original_weights,
381
383
  every_n_steps=self.every_n_steps,
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal
3
5
 
@@ -58,7 +60,7 @@ class FiniteChecksCallback(Callback):
58
60
  )
59
61
 
60
62
 
61
- class FiniteChecksConfig(CallbackConfigBase):
63
+ class FiniteChecksCallbackConfig(CallbackConfigBase):
62
64
  name: Literal["finite_checks"] = "finite_checks"
63
65
 
64
66
  nonfinite_grads: bool = True
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Any, Literal, Protocol, runtime_checkable
3
5
 
@@ -18,8 +20,8 @@ class HasGradSkippedSteps(Protocol):
18
20
  grad_skipped_steps: Any
19
21
 
20
22
 
21
- class GradientSkipping(Callback):
22
- def __init__(self, config: "GradientSkippingConfig"):
23
+ class GradientSkippingCallback(Callback):
24
+ def __init__(self, config: "GradientSkippingCallbackConfig"):
23
25
  super().__init__()
24
26
  self.config = config
25
27
 
@@ -73,7 +75,7 @@ class GradientSkipping(Callback):
73
75
  )
74
76
 
75
77
 
76
- class GradientSkippingConfig(CallbackConfigBase):
78
+ class GradientSkippingCallbackConfig(CallbackConfigBase):
77
79
  name: Literal["gradient_skipping"] = "gradient_skipping"
78
80
 
79
81
  threshold: float
@@ -94,4 +96,4 @@ class GradientSkippingConfig(CallbackConfigBase):
94
96
 
95
97
  @override
96
98
  def create_callbacks(self, root_config):
97
- yield GradientSkipping(self)
99
+ yield GradientSkippingCallback(self)
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections.abc import Callable
2
4
  from typing import Literal
3
5
 
@@ -1,14 +1,26 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import math
3
- from typing import Any
5
+ from typing import Any, Literal
4
6
 
5
7
  from lightning.pytorch import LightningModule, Trainer
6
8
  from lightning.pytorch.callbacks import Callback
7
9
  from typing_extensions import override
8
10
 
11
+ from .base import CallbackConfigBase
12
+
9
13
  log = logging.getLogger(__name__)
10
14
 
11
15
 
16
+ class LogEpochCallbackConfig(CallbackConfigBase):
17
+ name: Literal["log_epoch"] = "log_epoch"
18
+
19
+ @override
20
+ def create_callbacks(self, root_config):
21
+ yield LogEpochCallback()
22
+
23
+
12
24
  class LogEpochCallback(Callback):
13
25
  def __init__(self, metric_name: str = "computed_epoch"):
14
26
  super().__init__()
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal, cast
3
5
 
@@ -96,7 +98,7 @@ def compute_norm(
96
98
 
97
99
 
98
100
  class NormLoggingCallback(Callback):
99
- def __init__(self, config: "NormLoggingConfig"):
101
+ def __init__(self, config: "NormLoggingCallbackConfig"):
100
102
  super().__init__()
101
103
 
102
104
  self.config = config
@@ -155,7 +157,7 @@ class NormLoggingCallback(Callback):
155
157
  )
156
158
 
157
159
 
158
- class NormLoggingConfig(CallbackConfigBase):
160
+ class NormLoggingCallbackConfig(CallbackConfigBase):
159
161
  name: Literal["norm_logging"] = "norm_logging"
160
162
 
161
163
  log_grad_norm: bool | str | float = False
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import fnmatch
3
5
  import importlib.util
@@ -74,7 +76,7 @@ class PrintTableMetricsCallback(Callback):
74
76
  return table
75
77
 
76
78
 
77
- class PrintTableMetricsConfig(CallbackConfigBase):
79
+ class PrintTableMetricsCallbackConfig(CallbackConfigBase):
78
80
  """Configuration class for PrintTableMetricsCallback."""
79
81
 
80
82
  name: Literal["print_table_metrics"] = "print_table_metrics"
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Mapping
3
5
  from typing import Literal, cast
@@ -16,7 +18,7 @@ from .base import CallbackConfigBase
16
18
  log = logging.getLogger(__name__)
17
19
 
18
20
 
19
- class RLPSanityChecksConfig(CallbackConfigBase):
21
+ class RLPSanityChecksCallbackConfig(CallbackConfigBase):
20
22
  """
21
23
  If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
22
24
  - If the ``interval`` is step, it makes sure that validation is called every ``frequency`` steps.
@@ -43,7 +45,7 @@ class RLPSanityChecksConfig(CallbackConfigBase):
43
45
 
44
46
  class RLPSanityChecksCallback(Callback):
45
47
  @override
46
- def __init__(self, config: RLPSanityChecksConfig):
48
+ def __init__(self, config: RLPSanityChecksCallbackConfig):
47
49
  super().__init__()
48
50
 
49
51
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Iterable
3
5
  from typing import Literal, Protocol, TypeAlias, runtime_checkable
@@ -17,7 +19,7 @@ def _parameters_to_names(parameters: Iterable[nn.Parameter], model: nn.Module):
17
19
  return [mapping[id(p)] for p in parameters]
18
20
 
19
21
 
20
- class SharedParametersConfig(CallbackConfigBase):
22
+ class SharedParametersCallbackConfig(CallbackConfigBase):
21
23
  """A callback that allows scaling the gradients of shared parameters that
22
24
  are registered in the ``self.shared_parameters`` list of the root module.
23
25
 
@@ -43,7 +45,7 @@ class ModuleWithSharedParameters(Protocol):
43
45
 
44
46
  class SharedParametersCallback(Callback):
45
47
  @override
46
- def __init__(self, config: SharedParametersConfig):
48
+ def __init__(self, config: SharedParametersCallbackConfig):
47
49
  super().__init__()
48
50
 
49
51
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import time
3
5
  from typing import Any, Literal
@@ -12,7 +14,7 @@ from .base import CallbackConfigBase
12
14
  log = logging.getLogger(__name__)
13
15
 
14
16
 
15
- class EpochTimer(Callback):
17
+ class EpochTimerCallback(Callback):
16
18
  def __init__(self):
17
19
  super().__init__()
18
20
 
@@ -149,9 +151,9 @@ class EpochTimer(Callback):
149
151
  self._total_batches = state_dict["total_batches"]
150
152
 
151
153
 
152
- class EpochTimerConfig(CallbackConfigBase):
154
+ class EpochTimerCallbackConfig(CallbackConfigBase):
153
155
  name: Literal["epoch_timer"] = "epoch_timer"
154
156
 
155
157
  @override
156
158
  def create_callbacks(self, root_config):
157
- yield EpochTimer()
159
+ yield EpochTimerCallback()
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  from pathlib import Path
@@ -14,7 +16,7 @@ from .base import CallbackConfigBase
14
16
  log = logging.getLogger(__name__)
15
17
 
16
18
 
17
- class WandbUploadCodeConfig(CallbackConfigBase):
19
+ class WandbUploadCodeCallbackConfig(CallbackConfigBase):
18
20
  name: Literal["wandb_upload_code"] = "wandb_upload_code"
19
21
 
20
22
  enabled: bool = True
@@ -32,7 +34,7 @@ class WandbUploadCodeConfig(CallbackConfigBase):
32
34
 
33
35
 
34
36
  class WandbUploadCodeCallback(Callback):
35
- def __init__(self, config: WandbUploadCodeConfig):
37
+ def __init__(self, config: WandbUploadCodeCallbackConfig):
36
38
  super().__init__()
37
39
 
38
40
  self.config = config
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal, Protocol, cast, runtime_checkable
3
5
 
@@ -12,7 +14,7 @@ from .base import CallbackConfigBase
12
14
  log = logging.getLogger(__name__)
13
15
 
14
16
 
15
- class WandbWatchConfig(CallbackConfigBase):
17
+ class WandbWatchCallbackConfig(CallbackConfigBase):
16
18
  name: Literal["wandb_watch"] = "wandb_watch"
17
19
 
18
20
  enabled: bool = True
@@ -41,7 +43,7 @@ class _HasWandbLogModuleProtocol(Protocol):
41
43
 
42
44
 
43
45
  class WandbWatchCallback(Callback):
44
- def __init__(self, config: WandbWatchConfig):
46
+ def __init__(self, config: WandbWatchCallbackConfig):
45
47
  super().__init__()
46
48
 
47
49
  self.config = config