nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +12 -1
  4. nshtrainer/callbacks/debug_flag.py +72 -0
  5. nshtrainer/callbacks/directory_setup.py +85 -0
  6. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  7. nshtrainer/callbacks/shared_parameters.py +87 -0
  8. nshtrainer/config.py +67 -0
  9. nshtrainer/ll/__init__.py +5 -4
  10. nshtrainer/ll/model.py +7 -0
  11. nshtrainer/loggers/wandb.py +1 -1
  12. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  13. nshtrainer/model/__init__.py +0 -21
  14. nshtrainer/model/base.py +124 -67
  15. nshtrainer/model/config.py +7 -1025
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +787 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/callback.py +0 -206
  28. nshtrainer/model/modules/debug.py +0 -42
  29. nshtrainer/model/modules/distributed.py +0 -70
  30. nshtrainer/model/modules/profiler.py +0 -24
  31. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  32. nshtrainer/model/modules/shared_parameters.py +0 -72
  33. /nshtrainer/{config → util/config}/duration.py +0 -0
  34. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
nshtrainer/ll/__init__.py CHANGED
@@ -1,3 +1,5 @@
+ from typing import TypeAlias
+
  from . import _experimental as _experimental
  from . import actsave as actsave
  from . import callbacks as callbacks
@@ -21,12 +23,9 @@ from .log import init_python_logging as init_python_logging
  from .log import lovely as lovely
  from .log import pretty as pretty
  from .lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
- from .model import Base as Base
  from .model import BaseConfig as BaseConfig
- from .model import BaseProfilerConfig as BaseProfilerConfig
  from .model import CheckpointLoadingConfig as CheckpointLoadingConfig
  from .model import CheckpointSavingConfig as CheckpointSavingConfig
- from .model import ConfigList as ConfigList
  from .model import DirectoryConfig as DirectoryConfig
  from .model import (
      EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
@@ -43,7 +42,6 @@ from .model import LightningModuleBase as LightningModuleBase
  from .model import LoggingConfig as LoggingConfig
  from .model import MetricConfig as MetricConfig
  from .model import OptimizationConfig as OptimizationConfig
- from .model import PrimaryMetricConfig as PrimaryMetricConfig
  from .model import ReproducibilityConfig as ReproducibilityConfig
  from .model import SanityCheckingConfig as SanityCheckingConfig
  from .model import TrainerConfig as TrainerConfig
@@ -54,3 +52,6 @@ from .runner import Runner as Runner
  from .runner import SnapshotConfig as SnapshotConfig
  from .snoop import snoop as snoop
  from .trainer import Trainer as Trainer
+
+ PrimaryMetricConfig: TypeAlias = MetricConfig
+ ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
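Note that the removed `PrimaryMetricConfig` and `ConfigList` exports survive as module-level type aliases, so existing imports keep resolving. A minimal sketch of what that implies for downstream code (illustrative only; it assumes nothing beyond the aliases defined above):

    from nshtrainer.ll import (
        BaseConfig,
        ConfigList,
        LightningModuleBase,
        MetricConfig,
        PrimaryMetricConfig,
    )

    # PrimaryMetricConfig is now a plain alias, not a separate class.
    assert PrimaryMetricConfig is MetricConfig

    # ConfigList is shorthand for (config, module class) pairs.
    configs: ConfigList = []  # i.e. list[tuple[BaseConfig, type[LightningModuleBase]]]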
nshtrainer/ll/model.py CHANGED
@@ -1,5 +1,12 @@
  from nshtrainer.model import * # noqa: F403

+ from ..trainer._config import CheckpointLoadingConfig as CheckpointLoadingConfig
+ from ..trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+ from ..trainer._config import GradientClippingConfig as GradientClippingConfig
+ from ..trainer._config import LoggingConfig as LoggingConfig
+ from ..trainer._config import OptimizationConfig as OptimizationConfig
+ from ..trainer._config import ReproducibilityConfig as ReproducibilityConfig
+ from ..trainer._config import SanityCheckingConfig as SanityCheckingConfig
  from ..util._environment_info import (
      EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
  )
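The trainer-level config classes now live in `nshtrainer.trainer._config` and are only re-exported here. A quick way to confirm the re-export points at the same objects (a sketch, not part of the package's test suite):

    from nshtrainer.ll.model import LoggingConfig as LLLoggingConfig
    from nshtrainer.trainer._config import LoggingConfig

    # Both import paths should resolve to the identical class object after this change.
    assert LLLoggingConfig is LoggingConfig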
nshtrainer/loggers/wandb.py CHANGED
@@ -129,7 +129,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
                      "Please either upgrade to a newer version of WandB or disable the `use_wandb_core` option."
                  )
              else:
-                 wandb.require("core")
+                 wandb.require("core") # type: ignore
                  log.critical("Using the `wandb-core` backend for WandB.")
          except ImportError:
              pass
nshtrainer/lr_scheduler/linear_warmup_cosine.py CHANGED
@@ -6,7 +6,7 @@ from torch.optim import Optimizer
  from torch.optim.lr_scheduler import LRScheduler
  from typing_extensions import override

- from ..config import Duration
+ from ..util.config import Duration
  from ._base import LRSchedulerConfigBase, LRSchedulerMetadata

nshtrainer/model/__init__.py CHANGED
@@ -1,26 +1,5 @@
- from typing_extensions import TypeAlias
-
- from .base import Base as Base
  from .base import LightningModuleBase as LightningModuleBase
  from .config import BaseConfig as BaseConfig
- from .config import BaseProfilerConfig as BaseProfilerConfig
- from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
- from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
- from .config import CheckpointSavingConfig as CheckpointSavingConfig
  from .config import DirectoryConfig as DirectoryConfig
- from .config import EarlyStoppingConfig as EarlyStoppingConfig
- from .config import GradientClippingConfig as GradientClippingConfig
- from .config import HuggingFaceHubConfig as HuggingFaceHubConfig
- from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
- from .config import LoggingConfig as LoggingConfig
  from .config import MetricConfig as MetricConfig
- from .config import (
-     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
- )
- from .config import OptimizationConfig as OptimizationConfig
- from .config import PrimaryMetricConfig as PrimaryMetricConfig
- from .config import ReproducibilityConfig as ReproducibilityConfig
- from .config import SanityCheckingConfig as SanityCheckingConfig
  from .config import TrainerConfig as TrainerConfig
-
- ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
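After this trim, the public surface of `nshtrainer.model` is reduced to the five names kept above; most of the removed names are re-exported from `nshtrainer.trainer._config` via `nshtrainer.ll.model` (see above) or survive as aliases in `nshtrainer.ll`. For reference, the imports that still work here, taken directly from the diff:

    from nshtrainer.model import (
        BaseConfig,
        DirectoryConfig,
        LightningModuleBase,
        MetricConfig,
        TrainerConfig,
    )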
nshtrainer/model/base.py CHANGED
@@ -2,39 +2,65 @@ import inspect
  import logging
  from abc import ABC, abstractmethod
  from collections.abc import MutableMapping
- from typing import IO, TYPE_CHECKING, Any, Generic, cast
+ from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast

  import torch
+ import torch.distributed
  from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
- from lightning.pytorch import LightningModule, Trainer
- from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch import LightningModule
+ from lightning.pytorch.profilers import PassThroughProfiler, Profiler
  from lightning.pytorch.utilities.types import STEP_OUTPUT
  from typing_extensions import Self, TypeVar, override

+ from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
  from ..util._environment_info import EnvironmentConfig
  from .config import BaseConfig
- from .modules.callback import CallbackModuleMixin
- from .modules.debug import DebugModuleMixin
- from .modules.distributed import DistributedMixin
- from .modules.logger import LoggerLightningModuleMixin
- from .modules.profiler import ProfilerMixin
- from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
- from .modules.shared_parameters import SharedParametersModuleMixin
+ from .mixins.logger import LoggerLightningModuleMixin

  log = logging.getLogger(__name__)

  THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)


- class Base(DebugModuleMixin, Generic[THparams]):
-     @torch.jit.unused
-     @property
-     def config(self) -> THparams:
-         return self.hparams
+ T = TypeVar("T", infer_variance=True)
+
+ ReduceOpStr = Literal[
+     "avg",
+     "mean",
+     "band",
+     "bor",
+     "bxor",
+     "max",
+     "min",
+     "premul_sum",
+     "product",
+     "sum",
+ ]
+ VALID_REDUCE_OPS = (
+     "avg",
+     "mean",
+     "band",
+     "bor",
+     "bxor",
+     "max",
+     "min",
+     "premul_sum",
+     "product",
+     "sum",
+ )
+

+ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
+     _RLPSanityCheckModuleMixin,
+     LoggerLightningModuleMixin,
+     LightningModule,
+     ABC,
+     Generic[THparams],
+ ):
+     # region Config
      @torch.jit.unused
      @property
-     def C(self) -> THparams:
+     def config(self) -> THparams:
          return self.hparams

      @property
@@ -43,65 +69,98 @@ class Base(DebugModuleMixin, Generic[THparams]):
              return False
          return self.config.debug

-     @property
-     def dev(self) -> bool:
-         if torch.jit.is_scripting():
+     # endregion
+
+     # region Debug
+
+     @torch.jit.unused
+     def breakpoint(self, rank_zero_only: bool = True):
+         if (
+             not rank_zero_only
+             or not torch.distributed.is_initialized()
+             or torch.distributed.get_rank() == 0
+         ):
+             breakpoint()
+
+         if rank_zero_only and torch.distributed.is_initialized():
+             _ = torch.distributed.barrier()
+
+     @torch.jit.unused
+     def ensure_finite(
+         self,
+         tensor: torch.Tensor,
+         name: str | None = None,
+         throw: bool = False,
+     ):
+         name_parts: list[str] = ["Tensor"]
+         if name is not None:
+             name_parts.append(name)
+         name = " ".join(name_parts)
+
+         not_finite = ~torch.isfinite(tensor)
+         if not_finite.any():
+             msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+             if throw:
+                 raise RuntimeError(msg)
+             else:
+                 log.warning(msg)
              return False
-         return self.config.debug
+         return True

-     @override
-     def __init__(self, hparams: THparams):
-         super().__init__()
+     # endregion

-         if not hasattr(self, "hparams"):
-             self.hparams = hparams
+     # region Profiler
+     @property
+     def profiler(self) -> Profiler:
+         if (trainer := self._trainer) is None:
+             raise RuntimeError("trainer is not defined")

+         if not hasattr(trainer, "profiler"):
+             raise RuntimeError("trainer does not have profiler")

- class DebugFlagCallback(Callback):
-     """
-     Sets the debug flag to true in the following circumstances:
-     - fast_dev_run is enabled
-     - sanity check is running
-     """
+         if (profiler := getattr(trainer, "profiler")) is None:
+             profiler = PassThroughProfiler()

-     @override
-     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str):
-         if not getattr(trainer, "fast_dev_run", False):
-             return
+         return profiler

-         hparams = cast(BaseConfig, pl_module.hparams)
-         if not hparams.debug:
-             log.critical("Fast dev run detected, setting debug flag to True.")
-             hparams.debug = True
+     # endregion

-     @override
-     def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule):
-         hparams = cast(BaseConfig, pl_module.hparams)
-         self._debug = hparams.debug
-         if not self._debug:
-             log.critical("Enabling debug flag during sanity check routine.")
-             hparams.debug = True
+     # region Distributed
+     def all_gather_object(
+         self,
+         object: T,
+         group: torch.distributed.ProcessGroup | None = None,
+     ) -> list[T]:
+         if (
+             not torch.distributed.is_available()
+             or not torch.distributed.is_initialized()
+         ):
+             return [object]
+
+         object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
+         torch.distributed.all_gather_object(object_list, object, group=group)
+         return object_list
+
+     def barrier(self, name: str | None = None):
+         self.trainer.strategy.barrier(name=name)
+
+     def reduce(
+         self,
+         tensor: torch.Tensor,
+         reduce_op: torch.distributed.ReduceOp.RedOpType | ReduceOpStr,
+         group: Any | None = None,
+     ) -> torch.Tensor:
+         if isinstance(reduce_op, str):
+             # validate reduce_op
+             if reduce_op not in VALID_REDUCE_OPS:
+                 raise ValueError(
+                     f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
+                 )

-     @override
-     def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule):
-         hparams = cast(BaseConfig, pl_module.hparams)
-         if not self._debug:
-             log.critical("Sanity check routine complete, disabling debug flag.")
-             hparams.debug = self._debug
+         return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)

+     # endregion

- class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
-     ProfilerMixin,
-     RLPSanityCheckModuleMixin,
-     LoggerLightningModuleMixin,
-     SharedParametersModuleMixin,
-     DistributedMixin,
-     CallbackModuleMixin,
-     Base[THparams],
-     LightningModule,
-     ABC,
-     Generic[THparams],
- ):
      # Our own custom __repr__ method.
      # Torch's __repr__ method is too verbose and doesn't provide any useful information.
      @override
@@ -193,10 +252,8 @@ class LightningModuleBase( # pyright: ignore[reportIncompatibleMethodOverride]
          hparams.environment = EnvironmentConfig.from_current_environment(hparams, self)
          hparams = self.pre_init_update_hparams(hparams)

-         super().__init__(hparams)
-
+         super().__init__()
          self.save_hyperparameters(hparams)
-         self.register_callback(lambda: DebugFlagCallback())

      def zero_loss(self):
          """