nshtrainer 0.30.0__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +8 -0
  4. nshtrainer/callbacks/directory_setup.py +85 -0
  5. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  6. nshtrainer/callbacks/shared_parameters.py +87 -0
  7. nshtrainer/config.py +67 -0
  8. nshtrainer/ll/__init__.py +5 -4
  9. nshtrainer/ll/model.py +7 -0
  10. nshtrainer/loggers/wandb.py +1 -1
  11. nshtrainer/lr_scheduler/linear_warmup_cosine.py +3 -8
  12. nshtrainer/model/__init__.py +0 -21
  13. nshtrainer/model/base.py +139 -44
  14. nshtrainer/model/config.py +7 -1025
  15. nshtrainer/model/{modules → mixins}/callback.py +2 -2
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +778 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/debug.py +0 -42
  28. nshtrainer/model/modules/distributed.py +0 -70
  29. nshtrainer/model/modules/profiler.py +0 -24
  30. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  31. nshtrainer/model/modules/shared_parameters.py +0 -72
  32. /nshtrainer/{config → util/config}/duration.py +0 -0
  33. {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
nshtrainer/model/{modules → mixins}/callback.py
@@ -165,7 +165,7 @@ class CallbackModuleMixin(
     CallbackRegistrarModuleMixin,
     mixin_base_type(LightningModule),
 ):
-    def _gather_all_callbacks(self):
+    def _nshtrainer_gather_all_callbacks(self):
         modules: list[Any] = []
         if isinstance(self, CallbackRegistrarModuleMixin):
             modules.append(self)
@@ -189,7 +189,7 @@ class CallbackModuleMixin(
             callbacks = [callbacks]
 
         callbacks = list(callbacks)
-        for callback_fn in self._gather_all_callbacks():
+        for callback_fn in self._nshtrainer_gather_all_callbacks():
            callback_result = callback_fn()
            if callback_result is None:
                continue
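The rename is internal, but the hunk shows the pattern the method implements: walk the module tree, collect every submodule that acts as a callback registrar, invoke each registered callback factory, and keep the non-None results. The standalone sketch below mirrors that flow only; CallbackRegistrar, register_callback, and gather_callbacks are illustrative names, not nshtrainer's actual API.

# Illustrative sketch of the gather-and-invoke pattern in the hunk above.
# All names here are hypothetical; only the overall flow mirrors the diff.
from typing import Callable, Iterable

import torch.nn as nn
from lightning.pytorch import Callback

CallbackFactory = Callable[[], Callback | Iterable[Callback] | None]


class CallbackRegistrar(nn.Module):
    """A module that can hold callback factories (hypothetical stand-in for the mixin)."""

    def __init__(self) -> None:
        super().__init__()
        self._callback_factories: list[CallbackFactory] = []

    def register_callback(self, factory: CallbackFactory) -> None:
        self._callback_factories.append(factory)


def gather_callbacks(root: nn.Module) -> list[Callback]:
    callbacks: list[Callback] = []
    # Visit the root and all submodules, keeping only registrars.
    for module in root.modules():
        if not isinstance(module, CallbackRegistrar):
            continue
        # Invoke each factory; skip factories that return None.
        for factory in module._callback_factories:
            result = factory()
            if result is None:
                continue
            callbacks.extend([result] if isinstance(result, Callback) else result)
    return callbacks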
nshtrainer/model/{modules → mixins}/logger.py
@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
 import torchmetrics
-from lightning.pytorch import LightningDataModule, LightningModule
+from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import _METRIC
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from nshutils import ActSave
@@ -23,15 +23,13 @@ class _LogContext:
     kwargs: dict[str, Any] = field(default_factory=dict)
 
 
-class LoggerModuleMixin:
+class LoggerModuleMixin(mixin_base_type(LightningModule)):
     @property
     def log_dir(self):
-        if not isinstance(self, (LightningModule, LightningDataModule)):
-            raise TypeError(
-                "log_dir can only be used on LightningModule or LightningDataModule"
-            )
-
-        if (trainer := self.trainer) is None:
+        """
+        The directory where logs are saved.
+        """
+        if (trainer := self._trainer) is None:
             raise RuntimeError("trainer is not defined")
 
         if (logger := trainer.logger) is None:
@@ -44,16 +42,15 @@ class LoggerModuleMixin:
 
     @property
     def should_update_logs(self):
-        if not isinstance(self, (LightningModule, LightningDataModule)):
-            raise TypeError(
-                "should_update_logs can only be used on LightningModule or LightningDataModule"
+        """
+        Whether logs should be updated. This is true once every `log_every_n_steps` steps.
+        """
+        if self._trainer is None:
+            raise RuntimeError(
+                "`should_update_logs` can only be used after the module is attached to a trainer"
             )
 
-        trainer = self._trainer if isinstance(self, LightningModule) else self.trainer
-        if trainer is None:
-            return True
-
-        return trainer._logger_connector.should_update_logs
+        return self._trainer._logger_connector.should_update_logs
 
 
 class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningModule)):
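With LoggerModuleMixin now deriving from mixin_base_type(LightningModule), log_dir and should_update_logs read self._trainer directly instead of duck-typing against LightningModule/LightningDataModule. A common use of should_update_logs is to gate metrics that are expensive to compute so they are only logged once every log_every_n_steps. A hedged sketch: the import path follows the renamed module in this diff, while MyModel and its extra metric are purely illustrative.

# Sketch only: assumes LoggerLightningModuleMixin is importable from the renamed
# nshtrainer/model/mixins/logger.py shown above; MyModel and its metric are made up.
import torch.nn as nn
import torch.nn.functional as F

from nshtrainer.model.mixins.logger import LoggerLightningModuleMixin


class MyModel(LoggerLightningModuleMixin):
    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.net(x), y)
        self.log("train/loss", loss)
        if self.should_update_logs:
            # Only computed once every `log_every_n_steps`, per the docstring above.
            self.log("train/weight_norm", self.net.weight.norm())
        return loss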
nshtrainer/profiler/__init__.py (new file)
@@ -0,0 +1,13 @@
+from typing import Annotated, TypeAlias
+
+import nshconfig as C
+
+from ._base import BaseProfilerConfig as BaseProfilerConfig
+from .advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
+from .pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
+from .simple import SimpleProfilerConfig as SimpleProfilerConfig
+
+ProfilerConfig: TypeAlias = Annotated[
+    SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
+    C.Discriminator("name"),
+]
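ProfilerConfig is a tagged union discriminated on the name field, so a plain mapping (for example, one parsed from YAML) can be validated into the matching config class. A sketch of that behaviour, assuming nshconfig configs follow pydantic v2 semantics (an assumption; only the Discriminator re-export is visible in this diff):

# Assumption: nshconfig.Config is pydantic-v2 compatible, so TypeAdapter accepts the alias.
from pydantic import TypeAdapter

from nshtrainer.profiler import ProfilerConfig, SimpleProfilerConfig

adapter = TypeAdapter(ProfilerConfig)

cfg = adapter.validate_python({"name": "simple", "extended": False})
assert isinstance(cfg, SimpleProfilerConfig)

cfg = adapter.validate_python({"name": "pytorch", "row_limit": 50})
print(type(cfg).__name__)  # -> PyTorchProfilerConfig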
nshtrainer/profiler/_base.py (new file)
@@ -0,0 +1,29 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import nshconfig as C
+from lightning.pytorch.profilers import Profiler
+
+if TYPE_CHECKING:
+    from ..model import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BaseProfilerConfig(C.Config, ABC):
+    dirpath: str | Path | None = None
+    """
+    Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
+    ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
+    will be used.
+    """
+    filename: str | None = None
+    """
+    If present, filename where the profiler results will be saved instead of printing to stdout.
+    The ``.txt`` extension will be used automatically.
+    """
+
+    @abstractmethod
+    def create_profiler(self, root_config: "BaseConfig") -> Profiler | None: ...
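BaseProfilerConfig is the contract the three concrete configs below follow: optional dirpath/filename fields plus an abstract create_profiler(root_config) that returns a Lightning Profiler (or None to disable profiling). A new backend would be added by subclassing it in the same shape. The XLAProfilerConfig below is a hypothetical example, not part of this release, and it would also need to be added to the ProfilerConfig union above to become selectable by name.

# Hypothetical extension, mirroring the shape of the concrete configs below.
import logging
from typing import Literal

from typing_extensions import override

from nshtrainer.profiler._base import BaseProfilerConfig

log = logging.getLogger(__name__)


class XLAProfilerConfig(BaseProfilerConfig):
    name: Literal["xla"] = "xla"

    port: int = 9012
    """Port the XLA profiler server listens on."""

    @override
    def create_profiler(self, root_config):
        from lightning.pytorch.profilers.xla import XLAProfiler

        # Unlike the built-in configs, XLAProfiler takes no dirpath/filename,
        # so the inherited fields go unused here.
        return XLAProfiler(port=self.port)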
nshtrainer/profiler/advanced.py (new file)
@@ -0,0 +1,37 @@
+import logging
+from typing import Literal
+
+from typing_extensions import override
+
+from ._base import BaseProfilerConfig
+
+log = logging.getLogger(__name__)
+
+
+class AdvancedProfilerConfig(BaseProfilerConfig):
+    name: Literal["advanced"] = "advanced"
+
+    line_count_restriction: float = 1.0
+    """
+    This can be used to limit the number of functions
+    reported for each action. either an integer (to select a count of lines),
+    or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
+    """
+
+    @override
+    def create_profiler(self, root_config):
+        from lightning.pytorch.profilers.advanced import AdvancedProfiler
+
+        if (dirpath := self.dirpath) is None:
+            dirpath = root_config.directory.resolve_subdirectory(
+                root_config.id, "profile"
+            )
+
+        if (filename := self.filename) is None:
+            filename = f"{root_config.id}_profile.txt"
+
+        return AdvancedProfiler(
+            line_count_restriction=self.line_count_restriction,
+            dirpath=dirpath,
+            filename=filename,
+        )
nshtrainer/profiler/pytorch.py (new file)
@@ -0,0 +1,83 @@
+import logging
+from typing import Any, Literal
+
+from typing_extensions import override
+
+from ._base import BaseProfilerConfig
+
+log = logging.getLogger(__name__)
+
+
+class PyTorchProfilerConfig(BaseProfilerConfig):
+    name: Literal["pytorch"] = "pytorch"
+
+    group_by_input_shapes: bool = False
+    """Include operator input shapes and group calls by shape."""
+
+    emit_nvtx: bool = False
+    """
+    Context manager that makes every autograd operation emit an NVTX range
+    Run::
+
+        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
+
+    To visualize, you can either use::
+
+        nvvp trace_name.prof
+        torch.autograd.profiler.load_nvprof(path)
+    """
+
+    export_to_chrome: bool = True
+    """
+    Whether to export the sequence of profiled operators for Chrome.
+    It will generate a ``.json`` file which can be read by Chrome.
+    """
+
+    row_limit: int = 20
+    """
+    Limit the number of rows in a table, ``-1`` is a special value that
+    removes the limit completely.
+    """
+
+    sort_by_key: str | None = None
+    """
+    Attribute used to sort entries. By default
+    they are printed in the same order as they were registered.
+    Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
+    ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
+    ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
+    """
+
+    record_module_names: bool = True
+    """Whether to add module names while recording autograd operation."""
+
+    table_kwargs: dict[str, Any] | None = None
+    """Dictionary with keyword arguments for the summary table."""
+
+    additional_profiler_kwargs: dict[str, Any] = {}
+    """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""
+
+    @override
+    def create_profiler(self, root_config):
+        from lightning.pytorch.profilers.pytorch import PyTorchProfiler
+
+        if (dirpath := self.dirpath) is None:
+            dirpath = root_config.directory.resolve_subdirectory(
+                root_config.id, "profile"
+            )
+
+        if (filename := self.filename) is None:
+            filename = f"{root_config.id}_profile.txt"
+
+        return PyTorchProfiler(
+            group_by_input_shapes=self.group_by_input_shapes,
+            emit_nvtx=self.emit_nvtx,
+            export_to_chrome=self.export_to_chrome,
+            row_limit=self.row_limit,
+            sort_by_key=self.sort_by_key,
+            record_module_names=self.record_module_names,
+            table_kwargs=self.table_kwargs,
+            dirpath=dirpath,
+            filename=filename,
+            **self.additional_profiler_kwargs,
+        )
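Every knob above maps one-to-one onto Lightning's PyTorchProfiler constructor, and dirpath/filename fall back to the run's "profile" subdirectory and an id-derived filename when left unset. For example, a config that keeps Chrome trace export but sorts the summary table by CUDA time (the values are illustrative; the field names come from the class above):

# Illustrative values only; field names are taken verbatim from PyTorchProfilerConfig.
from nshtrainer.profiler import PyTorchProfilerConfig

profiler_config = PyTorchProfilerConfig(
    sort_by_key="cuda_time_total",  # one of the valid keys listed in the docstring
    row_limit=50,
    export_to_chrome=True,
    # dirpath/filename stay None, so create_profiler() resolves them to
    # <run directory>/profile and "<id>_profile.txt".
)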
nshtrainer/profiler/simple.py (new file)
@@ -0,0 +1,36 @@
+import logging
+from typing import Literal
+
+from typing_extensions import override
+
+from ._base import BaseProfilerConfig
+
+log = logging.getLogger(__name__)
+
+
+class SimpleProfilerConfig(BaseProfilerConfig):
+    name: Literal["simple"] = "simple"
+
+    extended: bool = True
+    """
+    If ``True``, adds extra columns representing number of calls and percentage of
+    total time spent on respective action.
+    """
+
+    @override
+    def create_profiler(self, root_config):
+        from lightning.pytorch.profilers.simple import SimpleProfiler
+
+        if (dirpath := self.dirpath) is None:
+            dirpath = root_config.directory.resolve_subdirectory(
+                root_config.id, "profile"
+            )
+
+        if (filename := self.filename) is None:
+            filename = f"{root_config.id}_profile.txt"
+
+        return SimpleProfiler(
+            extended=self.extended,
+            dirpath=dirpath,
+            filename=filename,
+        )