nshtrainer 0.42.0__py3-none-any.whl → 0.43.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. nshtrainer/__init__.py +2 -0
  2. nshtrainer/_callback.py +2 -0
  3. nshtrainer/_checkpoint/loader.py +2 -0
  4. nshtrainer/_checkpoint/metadata.py +2 -0
  5. nshtrainer/_checkpoint/saver.py +2 -0
  6. nshtrainer/_directory.py +4 -2
  7. nshtrainer/_experimental/__init__.py +2 -0
  8. nshtrainer/_hf_hub.py +2 -0
  9. nshtrainer/callbacks/__init__.py +45 -29
  10. nshtrainer/callbacks/_throughput_monitor_callback.py +2 -0
  11. nshtrainer/callbacks/actsave.py +2 -0
  12. nshtrainer/callbacks/base.py +2 -0
  13. nshtrainer/callbacks/checkpoint/__init__.py +6 -2
  14. nshtrainer/callbacks/checkpoint/_base.py +2 -0
  15. nshtrainer/callbacks/checkpoint/best_checkpoint.py +2 -0
  16. nshtrainer/callbacks/checkpoint/last_checkpoint.py +4 -2
  17. nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +6 -2
  18. nshtrainer/callbacks/debug_flag.py +2 -0
  19. nshtrainer/callbacks/directory_setup.py +4 -2
  20. nshtrainer/callbacks/early_stopping.py +6 -4
  21. nshtrainer/callbacks/ema.py +5 -3
  22. nshtrainer/callbacks/finite_checks.py +3 -1
  23. nshtrainer/callbacks/gradient_skipping.py +6 -4
  24. nshtrainer/callbacks/interval.py +2 -0
  25. nshtrainer/callbacks/log_epoch.py +13 -1
  26. nshtrainer/callbacks/norm_logging.py +4 -2
  27. nshtrainer/callbacks/print_table.py +3 -1
  28. nshtrainer/callbacks/rlp_sanity_checks.py +4 -2
  29. nshtrainer/callbacks/shared_parameters.py +4 -2
  30. nshtrainer/callbacks/throughput_monitor.py +2 -0
  31. nshtrainer/callbacks/timer.py +5 -3
  32. nshtrainer/callbacks/wandb_upload_code.py +4 -2
  33. nshtrainer/callbacks/wandb_watch.py +4 -2
  34. nshtrainer/config/__init__.py +130 -90
  35. nshtrainer/config/_checkpoint/loader/__init__.py +10 -8
  36. nshtrainer/config/_checkpoint/metadata/__init__.py +6 -4
  37. nshtrainer/config/_directory/__init__.py +9 -3
  38. nshtrainer/config/_hf_hub/__init__.py +6 -4
  39. nshtrainer/config/callbacks/__init__.py +82 -42
  40. nshtrainer/config/callbacks/actsave/__init__.py +4 -2
  41. nshtrainer/config/callbacks/base/__init__.py +2 -0
  42. nshtrainer/config/callbacks/checkpoint/__init__.py +6 -4
  43. nshtrainer/config/callbacks/checkpoint/_base/__init__.py +6 -4
  44. nshtrainer/config/callbacks/checkpoint/best_checkpoint/__init__.py +2 -0
  45. nshtrainer/config/callbacks/checkpoint/last_checkpoint/__init__.py +6 -4
  46. nshtrainer/config/callbacks/checkpoint/on_exception_checkpoint/__init__.py +6 -4
  47. nshtrainer/config/callbacks/debug_flag/__init__.py +6 -4
  48. nshtrainer/config/callbacks/directory_setup/__init__.py +7 -5
  49. nshtrainer/config/callbacks/early_stopping/__init__.py +9 -7
  50. nshtrainer/config/callbacks/ema/__init__.py +5 -3
  51. nshtrainer/config/callbacks/finite_checks/__init__.py +7 -5
  52. nshtrainer/config/callbacks/gradient_skipping/__init__.py +7 -5
  53. nshtrainer/config/callbacks/norm_logging/__init__.py +9 -5
  54. nshtrainer/config/callbacks/print_table/__init__.py +7 -5
  55. nshtrainer/config/callbacks/rlp_sanity_checks/__init__.py +7 -5
  56. nshtrainer/config/callbacks/shared_parameters/__init__.py +7 -5
  57. nshtrainer/config/callbacks/throughput_monitor/__init__.py +6 -4
  58. nshtrainer/config/callbacks/timer/__init__.py +9 -5
  59. nshtrainer/config/callbacks/wandb_upload_code/__init__.py +7 -5
  60. nshtrainer/config/callbacks/wandb_watch/__init__.py +9 -5
  61. nshtrainer/config/loggers/__init__.py +18 -10
  62. nshtrainer/config/loggers/_base/__init__.py +2 -0
  63. nshtrainer/config/loggers/csv/__init__.py +2 -0
  64. nshtrainer/config/loggers/tensorboard/__init__.py +2 -0
  65. nshtrainer/config/loggers/wandb/__init__.py +18 -10
  66. nshtrainer/config/lr_scheduler/__init__.py +2 -0
  67. nshtrainer/config/lr_scheduler/_base/__init__.py +2 -0
  68. nshtrainer/config/lr_scheduler/linear_warmup_cosine/__init__.py +2 -0
  69. nshtrainer/config/lr_scheduler/reduce_lr_on_plateau/__init__.py +6 -4
  70. nshtrainer/config/metrics/__init__.py +2 -0
  71. nshtrainer/config/metrics/_config/__init__.py +2 -0
  72. nshtrainer/config/model/__init__.py +8 -6
  73. nshtrainer/config/model/base/__init__.py +4 -2
  74. nshtrainer/config/model/config/__init__.py +8 -6
  75. nshtrainer/config/model/mixins/logger/__init__.py +2 -0
  76. nshtrainer/config/nn/__init__.py +16 -14
  77. nshtrainer/config/nn/mlp/__init__.py +2 -0
  78. nshtrainer/config/nn/nonlinearity/__init__.py +26 -24
  79. nshtrainer/config/optimizer/__init__.py +2 -0
  80. nshtrainer/config/profiler/__init__.py +2 -0
  81. nshtrainer/config/profiler/_base/__init__.py +2 -0
  82. nshtrainer/config/profiler/advanced/__init__.py +6 -4
  83. nshtrainer/config/profiler/pytorch/__init__.py +6 -4
  84. nshtrainer/config/profiler/simple/__init__.py +6 -4
  85. nshtrainer/config/runner/__init__.py +2 -0
  86. nshtrainer/config/trainer/_config/__init__.py +43 -39
  87. nshtrainer/config/trainer/checkpoint_connector/__init__.py +2 -0
  88. nshtrainer/config/util/_environment_info/__init__.py +20 -18
  89. nshtrainer/config/util/config/__init__.py +2 -0
  90. nshtrainer/config/util/config/dtype/__init__.py +2 -0
  91. nshtrainer/config/util/config/duration/__init__.py +2 -0
  92. nshtrainer/data/__init__.py +2 -0
  93. nshtrainer/data/balanced_batch_sampler.py +2 -0
  94. nshtrainer/data/datamodule.py +2 -0
  95. nshtrainer/data/transform.py +2 -0
  96. nshtrainer/ll/__init__.py +2 -0
  97. nshtrainer/ll/_experimental.py +2 -0
  98. nshtrainer/ll/actsave.py +2 -0
  99. nshtrainer/ll/callbacks.py +2 -0
  100. nshtrainer/ll/config.py +2 -0
  101. nshtrainer/ll/data.py +2 -0
  102. nshtrainer/ll/log.py +2 -0
  103. nshtrainer/ll/lr_scheduler.py +2 -0
  104. nshtrainer/ll/model.py +2 -0
  105. nshtrainer/ll/nn.py +2 -0
  106. nshtrainer/ll/optimizer.py +2 -0
  107. nshtrainer/ll/runner.py +2 -0
  108. nshtrainer/ll/snapshot.py +2 -0
  109. nshtrainer/ll/snoop.py +2 -0
  110. nshtrainer/ll/trainer.py +2 -0
  111. nshtrainer/ll/typecheck.py +2 -0
  112. nshtrainer/ll/util.py +2 -0
  113. nshtrainer/loggers/__init__.py +2 -0
  114. nshtrainer/loggers/_base.py +2 -0
  115. nshtrainer/loggers/csv.py +2 -0
  116. nshtrainer/loggers/tensorboard.py +2 -0
  117. nshtrainer/loggers/wandb.py +6 -4
  118. nshtrainer/lr_scheduler/__init__.py +2 -0
  119. nshtrainer/lr_scheduler/_base.py +2 -0
  120. nshtrainer/lr_scheduler/linear_warmup_cosine.py +2 -0
  121. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +2 -0
  122. nshtrainer/metrics/__init__.py +2 -0
  123. nshtrainer/metrics/_config.py +2 -0
  124. nshtrainer/model/__init__.py +2 -0
  125. nshtrainer/model/base.py +2 -0
  126. nshtrainer/model/config.py +2 -0
  127. nshtrainer/model/mixins/callback.py +2 -0
  128. nshtrainer/model/mixins/logger.py +2 -0
  129. nshtrainer/nn/__init__.py +2 -0
  130. nshtrainer/nn/mlp.py +2 -0
  131. nshtrainer/nn/module_dict.py +2 -0
  132. nshtrainer/nn/module_list.py +2 -0
  133. nshtrainer/nn/nonlinearity.py +2 -0
  134. nshtrainer/optimizer.py +2 -0
  135. nshtrainer/profiler/__init__.py +2 -0
  136. nshtrainer/profiler/_base.py +2 -0
  137. nshtrainer/profiler/advanced.py +2 -0
  138. nshtrainer/profiler/pytorch.py +2 -0
  139. nshtrainer/profiler/simple.py +2 -0
  140. nshtrainer/runner.py +2 -0
  141. nshtrainer/scripts/find_packages.py +2 -0
  142. nshtrainer/trainer/__init__.py +2 -0
  143. nshtrainer/trainer/_config.py +16 -13
  144. nshtrainer/trainer/_runtime_callback.py +2 -0
  145. nshtrainer/trainer/checkpoint_connector.py +2 -0
  146. nshtrainer/trainer/signal_connector.py +2 -0
  147. nshtrainer/trainer/trainer.py +2 -0
  148. nshtrainer/util/_environment_info.py +2 -0
  149. nshtrainer/util/bf16.py +2 -0
  150. nshtrainer/util/config/__init__.py +2 -0
  151. nshtrainer/util/config/dtype.py +2 -0
  152. nshtrainer/util/config/duration.py +2 -0
  153. nshtrainer/util/environment.py +2 -0
  154. nshtrainer/util/path.py +2 -0
  155. nshtrainer/util/seed.py +2 -0
  156. nshtrainer/util/slurm.py +3 -0
  157. nshtrainer/util/typed.py +2 -0
  158. nshtrainer/util/typing_utils.py +2 -0
  159. {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/METADATA +1 -1
  160. nshtrainer-0.43.0.dist-info/RECORD +162 -0
  161. nshtrainer-0.42.0.dist-info/RECORD +0 -162
  162. {nshtrainer-0.42.0.dist-info → nshtrainer-0.43.0.dist-info}/WHEEL +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -43,30 +45,30 @@ else:
43
45
 
44
46
  if name in globals():
45
47
  return globals()[name]
46
- if name == "EnvironmentPackageConfig":
47
- return importlib.import_module(
48
- "nshtrainer.util._environment_info"
49
- ).EnvironmentPackageConfig
50
- if name == "EnvironmentSnapshotConfig":
48
+ if name == "EnvironmentLinuxEnvironmentConfig":
51
49
  return importlib.import_module(
52
50
  "nshtrainer.util._environment_info"
53
- ).EnvironmentSnapshotConfig
51
+ ).EnvironmentLinuxEnvironmentConfig
54
52
  if name == "EnvironmentLSFInformationConfig":
55
53
  return importlib.import_module(
56
54
  "nshtrainer.util._environment_info"
57
55
  ).EnvironmentLSFInformationConfig
58
- if name == "EnvironmentLinuxEnvironmentConfig":
56
+ if name == "EnvironmentGPUConfig":
59
57
  return importlib.import_module(
60
58
  "nshtrainer.util._environment_info"
61
- ).EnvironmentLinuxEnvironmentConfig
62
- if name == "EnvironmentSLURMInformationConfig":
59
+ ).EnvironmentGPUConfig
60
+ if name == "EnvironmentPackageConfig":
63
61
  return importlib.import_module(
64
62
  "nshtrainer.util._environment_info"
65
- ).EnvironmentSLURMInformationConfig
66
- if name == "EnvironmentConfig":
63
+ ).EnvironmentPackageConfig
64
+ if name == "EnvironmentHardwareConfig":
67
65
  return importlib.import_module(
68
66
  "nshtrainer.util._environment_info"
69
- ).EnvironmentConfig
67
+ ).EnvironmentHardwareConfig
68
+ if name == "EnvironmentSnapshotConfig":
69
+ return importlib.import_module(
70
+ "nshtrainer.util._environment_info"
71
+ ).EnvironmentSnapshotConfig
70
72
  if name == "EnvironmentClassInformationConfig":
71
73
  return importlib.import_module(
72
74
  "nshtrainer.util._environment_info"
@@ -75,18 +77,18 @@ else:
75
77
  return importlib.import_module(
76
78
  "nshtrainer.util._environment_info"
77
79
  ).GitRepositoryConfig
78
- if name == "EnvironmentCUDAConfig":
80
+ if name == "EnvironmentConfig":
79
81
  return importlib.import_module(
80
82
  "nshtrainer.util._environment_info"
81
- ).EnvironmentCUDAConfig
82
- if name == "EnvironmentGPUConfig":
83
+ ).EnvironmentConfig
84
+ if name == "EnvironmentCUDAConfig":
83
85
  return importlib.import_module(
84
86
  "nshtrainer.util._environment_info"
85
- ).EnvironmentGPUConfig
86
- if name == "EnvironmentHardwareConfig":
87
+ ).EnvironmentCUDAConfig
88
+ if name == "EnvironmentSLURMInformationConfig":
87
89
  return importlib.import_module(
88
90
  "nshtrainer.util._environment_info"
89
- ).EnvironmentHardwareConfig
91
+ ).EnvironmentSLURMInformationConfig
90
92
  raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
91
93
 
92
94
  # Submodule exports
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  __codegen__ = True
2
4
 
3
5
  from typing import TYPE_CHECKING
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from . import transform as dataset_transform
2
4
  from .balanced_batch_sampler import BalancedBatchSampler as BalancedBatchSampler
3
5
  from .datamodule import LightningDataModuleBase as LightningDataModuleBase
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import heapq
2
4
  import logging
3
5
  from typing import Any, Protocol, runtime_checkable
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from lightning.pytorch import LightningDataModule
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  from collections.abc import Callable
3
5
  from typing import Any, cast
nshtrainer/ll/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import TypeAlias
2
4
 
3
5
  from . import _experimental as _experimental
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer._experimental import * # noqa: F403
nshtrainer/ll/actsave.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshutils.actsave import * # type: ignore # noqa: F403
2
4
 
3
5
  from nshtrainer.callbacks.actsave import ActSaveCallback as ActSaveCallback
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.callbacks import * # noqa: F403
nshtrainer/ll/config.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshconfig import * # type: ignore # noqa: F403
2
4
  from nshconfig import Config as TypedConfig # type: ignore # noqa: F401
3
5
 
nshtrainer/ll/data.py CHANGED
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.data import * # noqa: F403
nshtrainer/ll/log.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshutils import init_python_logging as init_python_logging
2
4
  from nshutils import lovely as lovely
3
5
  from nshutils import pretty as pretty
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.lr_scheduler import * # noqa: F403
nshtrainer/ll/model.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.model import * # noqa: F403
2
4
 
3
5
  from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
nshtrainer/ll/nn.py CHANGED
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.nn import * # noqa: F403
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.optimizer import * # noqa: F403
nshtrainer/ll/runner.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshrunner import SnapshotConfig as SnapshotConfig
2
4
 
3
5
  from nshtrainer.runner import * # type: ignore # noqa: F403
nshtrainer/ll/snapshot.py CHANGED
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshsnap import * # pyright: ignore[reportWildcardImportFromLibrary] # noqa: F403
nshtrainer/ll/snoop.py CHANGED
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshutils import snoop as snoop
nshtrainer/ll/trainer.py CHANGED
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.trainer import * # noqa: F403
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshutils.typecheck import * # type: ignore # noqa: F403
nshtrainer/ll/util.py CHANGED
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from nshtrainer.util import * # noqa: F403
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Annotated, TypeAlias
2
4
 
3
5
  import nshconfig as C
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from abc import ABC, abstractmethod
2
4
  from typing import TYPE_CHECKING
3
5
 
nshtrainer/loggers/csv.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Literal
2
4
 
3
5
  from typing_extensions import override
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import importlib.metadata
2
4
  import logging
3
5
  from typing import TYPE_CHECKING, Literal
@@ -8,8 +10,8 @@ from packaging import version
8
10
  from typing_extensions import assert_never, override
9
11
 
10
12
  from ..callbacks.base import CallbackConfigBase
11
- from ..callbacks.wandb_upload_code import WandbUploadCodeConfig
12
- from ..callbacks.wandb_watch import WandbWatchConfig
13
+ from ..callbacks.wandb_upload_code import WandbUploadCodeCallbackConfig
14
+ from ..callbacks.wandb_watch import WandbWatchCallbackConfig
13
15
  from ._base import BaseLoggerConfig
14
16
 
15
17
  if TYPE_CHECKING:
@@ -92,10 +94,10 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
92
94
  - "none" or False: Do not log any checkpoints
93
95
  """
94
96
 
95
- log_code: WandbUploadCodeConfig | None = WandbUploadCodeConfig()
97
+ log_code: WandbUploadCodeCallbackConfig | None = WandbUploadCodeCallbackConfig()
96
98
  """WandB code upload configuration. Used to upload code to WandB."""
97
99
 
98
- watch: WandbWatchConfig | None = WandbWatchConfig()
100
+ watch: WandbWatchCallbackConfig | None = WandbWatchCallbackConfig()
99
101
  """WandB model watch configuration. Used to log model architecture, gradients, and parameters."""
100
102
 
101
103
  offline: bool = False
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Annotated, TypeAlias
2
4
 
3
5
  import nshconfig as C
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import math
2
4
  from abc import ABC, abstractmethod
3
5
  from collections.abc import Mapping
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import math
2
4
  import warnings
3
5
  from typing import Literal
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import TYPE_CHECKING, Literal, cast
2
4
 
3
5
  from lightning.pytorch.utilities.types import LRSchedulerConfigType
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from ._config import MetricConfig as MetricConfig
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import builtins
2
4
  from typing import Literal
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from .base import LightningModuleBase as LightningModuleBase
2
4
  from .config import BaseConfig as BaseConfig
3
5
  from .config import DirectoryConfig as DirectoryConfig
nshtrainer/model/base.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import inspect
2
4
  import logging
3
5
  from abc import ABC, abstractmethod
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import logging
3
5
  import os
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Callable, Iterable, Sequence
3
5
  from typing import Any, TypeAlias, cast, final
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections import deque
2
4
  from collections.abc import Callable, Generator
3
5
  from contextlib import contextmanager
nshtrainer/nn/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from .mlp import MLP as MLP
2
4
  from .mlp import MLPConfig as MLPConfig
3
5
  from .mlp import MLPConfigDict as MLPConfigDict
nshtrainer/nn/mlp.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  from collections.abc import Callable, Sequence
3
5
  from typing import Literal, Protocol, runtime_checkable
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections.abc import Iterable, Mapping
2
4
  from typing import Generic, cast
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from collections.abc import Iterable, Iterator
2
4
  from typing import Generic, TypeVar, overload
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from abc import ABC, abstractmethod
2
4
  from typing import Annotated, Literal
3
5
 
nshtrainer/optimizer.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from abc import ABC, abstractmethod
2
4
  from collections.abc import Iterable
3
5
  from typing import Annotated, Any, Literal, TypeAlias
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Annotated, TypeAlias
2
4
 
3
5
  import nshconfig as C
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from abc import ABC, abstractmethod
3
5
  from pathlib import Path
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Any, Literal
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Literal
3
5
 
nshtrainer/runner.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  import logging
3
5
  from collections.abc import Callable, Iterable, Mapping, Sequence
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import argparse
2
4
  import ast
3
5
  import glob
@@ -1 +1,3 @@
1
+ from __future__ import annotations
2
+
1
3
  from .trainer import Trainer as Trainer
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from collections.abc import Iterable, Sequence
3
5
  from datetime import timedelta
@@ -30,14 +32,15 @@ from .._hf_hub import HuggingFaceHubConfig
30
32
  from ..callbacks import (
31
33
  BestCheckpointCallbackConfig,
32
34
  CallbackConfig,
33
- EarlyStoppingConfig,
35
+ EarlyStoppingCallbackConfig,
34
36
  LastCheckpointCallbackConfig,
35
37
  OnExceptionCheckpointCallbackConfig,
36
38
  )
37
39
  from ..callbacks.base import CallbackConfigBase
38
40
  from ..callbacks.debug_flag import DebugFlagCallbackConfig
39
- from ..callbacks.rlp_sanity_checks import RLPSanityChecksConfig
40
- from ..callbacks.shared_parameters import SharedParametersConfig
41
+ from ..callbacks.log_epoch import LogEpochCallbackConfig
42
+ from ..callbacks.rlp_sanity_checks import RLPSanityChecksCallbackConfig
43
+ from ..callbacks.shared_parameters import SharedParametersCallbackConfig
41
44
  from ..loggers import (
42
45
  CSVLoggerConfig,
43
46
  LoggerConfig,
@@ -65,7 +68,7 @@ class LoggingConfig(CallbackConfigBase):
65
68
 
66
69
  log_lr: bool | Literal["step", "epoch"] = True
67
70
  """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger."""
68
- log_epoch: bool = True
71
+ log_epoch: LogEpochCallbackConfig | None = LogEpochCallbackConfig()
69
72
  """If enabled, will log the fractional epoch number to the logger."""
70
73
 
71
74
  actsave_logged_metrics: bool = False
@@ -136,9 +139,7 @@ class LoggingConfig(CallbackConfigBase):
136
139
  yield LearningRateMonitor(logging_interval=logging_interval)
137
140
 
138
141
  if self.log_epoch:
139
- from ..callbacks.log_epoch import LogEpochCallback
140
-
141
- yield LogEpochCallback()
142
+ yield from self.log_epoch.create_callbacks(root_config)
142
143
 
143
144
  for logger in self.loggers:
144
145
  if not logger or not isinstance(logger, CallbackConfigBase):
@@ -172,9 +173,9 @@ class OptimizationConfig(CallbackConfigBase):
172
173
 
173
174
  @override
174
175
  def create_callbacks(self, root_config):
175
- from ..callbacks.norm_logging import NormLoggingConfig
176
+ from ..callbacks.norm_logging import NormLoggingCallbackConfig
176
177
 
177
- yield from NormLoggingConfig(
178
+ yield from NormLoggingCallbackConfig(
178
179
  log_grad_norm=self.log_grad_norm,
179
180
  log_grad_norm_per_param=self.log_grad_norm_per_param,
180
181
  log_param_norm=self.log_param_norm,
@@ -564,8 +565,8 @@ class TrainerConfig(C.Config):
564
565
  reproducibility: ReproducibilityConfig = ReproducibilityConfig()
565
566
  """Reproducibility configuration options."""
566
567
 
567
- reduce_lr_on_plateau_sanity_checking: RLPSanityChecksConfig | None = (
568
- RLPSanityChecksConfig()
568
+ reduce_lr_on_plateau_sanity_checking: RLPSanityChecksCallbackConfig | None = (
569
+ RLPSanityChecksCallbackConfig()
569
570
  )
570
571
  """
571
572
  If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
@@ -573,7 +574,7 @@ class TrainerConfig(C.Config):
573
574
  - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
574
575
  """
575
576
 
576
- early_stopping: EarlyStoppingConfig | None = None
577
+ early_stopping: EarlyStoppingCallbackConfig | None = None
577
578
  """Early stopping configuration options."""
578
579
 
579
580
  profiler: ProfilerConfig | None = None
@@ -741,7 +742,9 @@ class TrainerConfig(C.Config):
741
742
  automatic selection based on the chosen accelerator. Default: ``"auto"``.
742
743
  """
743
744
 
744
- shared_parameters: SharedParametersConfig | None = SharedParametersConfig()
745
+ shared_parameters: SharedParametersCallbackConfig | None = (
746
+ SharedParametersCallbackConfig()
747
+ )
745
748
  """If enabled, the model supports scaling the gradients of shared parameters that
746
749
  are registered in the self.shared_parameters list. This is useful for models that
747
750
  share parameters across multiple modules (e.g., in a GPT model) and want to
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import datetime
2
4
  import logging
3
5
  import time
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from pathlib import Path
3
5
  from typing import TYPE_CHECKING, cast
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  import platform
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  from collections.abc import Sequence
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import getpass
2
4
  import importlib.metadata
3
5
  import inspect
nshtrainer/util/bf16.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import torch
2
4
 
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from . import duration as duration
2
4
  from .dtype import DTypeConfig as DTypeConfig
3
5
  from .duration import DurationConfig as DurationConfig
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import TYPE_CHECKING, Literal, TypeAlias
2
4
 
3
5
  import nshconfig as C
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import math
2
4
  from typing import Annotated, Literal
3
5
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  from contextlib import contextmanager
nshtrainer/util/path.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import hashlib
2
4
  import logging
3
5
  import os
nshtrainer/util/seed.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
 
3
5
  import lightning.fabric.utilities.seed as LS
nshtrainer/util/slurm.py CHANGED
@@ -1,3 +1,6 @@
1
+ from __future__ import annotations
2
+
3
+
1
4
  class SlurmParseException(Exception):
2
5
  pass
3
6
 
nshtrainer/util/typed.py CHANGED
@@ -1,2 +1,4 @@
1
+ from __future__ import annotations
2
+
1
3
  from ..nn.module_dict import TypedModuleDict as TypedModuleDict
2
4
  from ..nn.module_list import TypedModuleList as TypedModuleList