nshtrainer 0.30.1__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (33)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +8 -0
  4. nshtrainer/callbacks/directory_setup.py +85 -0
  5. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  6. nshtrainer/callbacks/shared_parameters.py +87 -0
  7. nshtrainer/config.py +67 -0
  8. nshtrainer/ll/__init__.py +5 -4
  9. nshtrainer/ll/model.py +7 -0
  10. nshtrainer/loggers/wandb.py +1 -1
  11. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  12. nshtrainer/model/__init__.py +0 -21
  13. nshtrainer/model/base.py +139 -44
  14. nshtrainer/model/config.py +7 -1025
  15. nshtrainer/model/{modules → mixins}/callback.py +2 -2
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +778 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/debug.py +0 -42
  28. nshtrainer/model/modules/distributed.py +0 -70
  29. nshtrainer/model/modules/profiler.py +0 -24
  30. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  31. nshtrainer/model/modules/shared_parameters.py +0 -72
  32. /nshtrainer/{config → util/config}/duration.py +0 -0
  33. {nshtrainer-0.30.1.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
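
The file list shows an internal reorganization: the surviving model mixins moved from nshtrainer/model/modules/ to nshtrainer/model/mixins/, the config helpers moved from nshtrainer/config/ to nshtrainer/util/config/, and the RLP sanity-check and shared-parameters logic moved into nshtrainer/callbacks/. The sketch below illustrates what the corresponding import-path changes would look like for code that imported these modules directly; it is based only on the renames listed above, and whether these paths are considered public API is an assumption.

```python
# Hedged migration sketch based only on the file renames listed above.
# Whether these internal module paths are public API is an assumption.

# 0.30.1: mixins lived under nshtrainer.model.modules
# from nshtrainer.model.modules.callback import CallbackModuleMixin
# from nshtrainer.model.modules.logger import LoggerLightningModuleMixin

# 0.31.0: the surviving mixins live under nshtrainer.model.mixins
from nshtrainer.model.mixins.callback import CallbackModuleMixin
from nshtrainer.model.mixins.logger import LoggerLightningModuleMixin

# The config helpers moved from nshtrainer.config to nshtrainer.util.config,
# and the RLP sanity-check / shared-parameters logic now lives under
# nshtrainer.callbacks instead of nshtrainer.model.modules.
```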
nshtrainer/model/__init__.py CHANGED
@@ -1,26 +1,5 @@
-from typing_extensions import TypeAlias
-
-from .base import Base as Base
 from .base import LightningModuleBase as LightningModuleBase
 from .config import BaseConfig as BaseConfig
-from .config import BaseProfilerConfig as BaseProfilerConfig
-from .config import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
-from .config import CheckpointLoadingConfig as CheckpointLoadingConfig
-from .config import CheckpointSavingConfig as CheckpointSavingConfig
 from .config import DirectoryConfig as DirectoryConfig
-from .config import EarlyStoppingConfig as EarlyStoppingConfig
-from .config import GradientClippingConfig as GradientClippingConfig
-from .config import HuggingFaceHubConfig as HuggingFaceHubConfig
-from .config import LastCheckpointCallbackConfig as LastCheckpointCallbackConfig
-from .config import LoggingConfig as LoggingConfig
 from .config import MetricConfig as MetricConfig
-from .config import (
-    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
-)
-from .config import OptimizationConfig as OptimizationConfig
-from .config import PrimaryMetricConfig as PrimaryMetricConfig
-from .config import ReproducibilityConfig as ReproducibilityConfig
-from .config import SanityCheckingConfig as SanityCheckingConfig
 from .config import TrainerConfig as TrainerConfig
-
-ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
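
Among the removals above is the ConfigList type alias, which is no longer exported from nshtrainer.model in 0.31.0. If downstream code relied on it, the alias can be reproduced locally from the deleted definition; the sketch below assumes only the re-exports that remain in this diff (BaseConfig and LightningModuleBase).

```python
# Hedged sketch: reproduce the removed ConfigList alias locally,
# using the names still re-exported from nshtrainer.model in 0.31.0.
from typing_extensions import TypeAlias

from nshtrainer.model import BaseConfig, LightningModuleBase

ConfigList: TypeAlias = list[tuple[BaseConfig, type[LightningModuleBase]]]
```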
nshtrainer/model/base.py CHANGED
@@ -2,61 +2,28 @@ import inspect
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import MutableMapping
-from typing import IO, TYPE_CHECKING, Any, Generic, cast
+from typing import IO, TYPE_CHECKING, Any, Generic, Literal, cast
 
 import torch
+import torch.distributed
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.profilers import PassThroughProfiler, Profiler
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from typing_extensions import Self, TypeVar, override
 
+from ..callbacks.rlp_sanity_checks import _RLPSanityCheckModuleMixin
 from ..util._environment_info import EnvironmentConfig
 from .config import BaseConfig
-from .modules.callback import CallbackModuleMixin
-from .modules.debug import DebugModuleMixin
-from .modules.distributed import DistributedMixin
-from .modules.logger import LoggerLightningModuleMixin
-from .modules.profiler import ProfilerMixin
-from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin
-from .modules.shared_parameters import SharedParametersModuleMixin
+from .mixins.callback import CallbackModuleMixin
+from .mixins.logger import LoggerLightningModuleMixin
 
 log = logging.getLogger(__name__)
 
 THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True)
 
 
-class Base(DebugModuleMixin, Generic[THparams]):
-    @torch.jit.unused
-    @property
-    def config(self) -> THparams:
-        return self.hparams
-
-    @torch.jit.unused
-    @property
-    def C(self) -> THparams:
-        return self.hparams
-
-    @property
-    def debug(self) -> bool:
-        if torch.jit.is_scripting():
-            return False
-        return self.config.debug
-
-    @property
-    def dev(self) -> bool:
-        if torch.jit.is_scripting():
-            return False
-        return self.config.debug
-
-    @override
-    def __init__(self, hparams: THparams):
-        super().__init__()
-
-        if not hasattr(self, "hparams"):
-            self.hparams = hparams
-
-
 class DebugFlagCallback(Callback):
     """
     Sets the debug flag to true in the following circumstances:
@@ -90,18 +57,146 @@ class DebugFlagCallback(Callback):
         hparams.debug = self._debug
 
 
+T = TypeVar("T", infer_variance=True)
+
+ReduceOpStr = Literal[
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+]
+VALID_REDUCE_OPS = (
+    "avg",
+    "mean",
+    "band",
+    "bor",
+    "bxor",
+    "max",
+    "min",
+    "premul_sum",
+    "product",
+    "sum",
+)
+
+
 class LightningModuleBase(  # pyright: ignore[reportIncompatibleMethodOverride]
-    ProfilerMixin,
-    RLPSanityCheckModuleMixin,
+    _RLPSanityCheckModuleMixin,
     LoggerLightningModuleMixin,
-    SharedParametersModuleMixin,
-    DistributedMixin,
     CallbackModuleMixin,
-    Base[THparams],
     LightningModule,
     ABC,
     Generic[THparams],
 ):
+    # region Config
+    @torch.jit.unused
+    @property
+    def config(self) -> THparams:
+        return self.hparams
+
+    @property
+    def debug(self) -> bool:
+        if torch.jit.is_scripting():
+            return False
+        return self.config.debug
+
+    # endregion
+
+    # region Debug
+
+    @torch.jit.unused
+    def breakpoint(self, rank_zero_only: bool = True):
+        if (
+            not rank_zero_only
+            or not torch.distributed.is_initialized()
+            or torch.distributed.get_rank() == 0
+        ):
+            breakpoint()
+
+        if rank_zero_only and torch.distributed.is_initialized():
+            _ = torch.distributed.barrier()
+
+    @torch.jit.unused
+    def ensure_finite(
+        self,
+        tensor: torch.Tensor,
+        name: str | None = None,
+        throw: bool = False,
+    ):
+        name_parts: list[str] = ["Tensor"]
+        if name is not None:
+            name_parts.append(name)
+        name = " ".join(name_parts)
+
+        not_finite = ~torch.isfinite(tensor)
+        if not_finite.any():
+            msg = f"{name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values."
+            if throw:
+                raise RuntimeError(msg)
+            else:
+                log.warning(msg)
+            return False
+        return True
+
+    # endregion
+
+    # region Profiler
+    @property
+    def profiler(self) -> Profiler:
+        if (trainer := self._trainer) is None:
+            raise RuntimeError("trainer is not defined")
+
+        if not hasattr(trainer, "profiler"):
+            raise RuntimeError("trainer does not have profiler")
+
+        if (profiler := getattr(trainer, "profiler")) is None:
+            profiler = PassThroughProfiler()
+
+        return profiler
+
+    # endregion
+
+    # region Distributed
+    def all_gather_object(
+        self,
+        object: T,
+        group: torch.distributed.ProcessGroup | None = None,
+    ) -> list[T]:
+        if (
+            not torch.distributed.is_available()
+            or not torch.distributed.is_initialized()
+        ):
+            return [object]
+
+        object_list = [cast(T, None) for _ in range(self.trainer.world_size)]
+        torch.distributed.all_gather_object(object_list, object, group=group)
+        return object_list
+
+    def barrier(self, name: str | None = None):
+        self.trainer.strategy.barrier(name=name)
+
+    def reduce(
+        self,
+        tensor: torch.Tensor,
+        reduce_op: torch.distributed.ReduceOp.RedOpType | ReduceOpStr,
+        group: Any | None = None,
+    ) -> torch.Tensor:
+        if isinstance(reduce_op, str):
+            # validate reduce_op
+            if reduce_op not in VALID_REDUCE_OPS:
+                raise ValueError(
+                    f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}"
+                )
+
+        return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op)
+
+    # endregion
+
     # Our own custom __repr__ method.
     # Torch's __repr__ method is too verbose and doesn't provide any useful information.
     @override
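
With this change, the debug, profiler, and distributed helpers that previously came from the removed Base, ProfilerMixin, DistributedMixin, and related mixin classes are defined directly on LightningModuleBase. The sketch below shows how a subclass might call them, using the signatures visible in the diff above; MyModel, its hparams, and compute_loss are illustrative assumptions, not part of the package.

```python
# Hypothetical usage sketch. MyModel and compute_loss are illustrative only;
# the helper methods and their signatures come from the diff above.
import torch

from nshtrainer.model import LightningModuleBase


class MyModel(LightningModuleBase):  # assumes a concrete BaseConfig subclass as hparams
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical model-specific helper

        # Debug helper: warns on NaN/Inf values (or raises with throw=True).
        self.ensure_finite(loss, name="loss")

        # Distributed helpers now live on the module itself.
        all_losses = self.all_gather_object(loss.detach().item())
        mean_loss = self.reduce(loss.detach(), reduce_op="mean")

        # Profiling goes through the trainer's profiler, falling back to a
        # PassThroughProfiler when none is configured.
        with self.profiler.profile("my_block"):
            ...

        return loss
```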