nshtrainer 0.30.0__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -2
- nshtrainer/_directory.py +85 -0
- nshtrainer/callbacks/__init__.py +8 -0
- nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer/config.py +67 -0
- nshtrainer/ll/__init__.py +5 -4
- nshtrainer/ll/model.py +7 -0
- nshtrainer/loggers/wandb.py +1 -1
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +3 -8
- nshtrainer/model/__init__.py +0 -21
- nshtrainer/model/base.py +139 -44
- nshtrainer/model/config.py +7 -1025
- nshtrainer/model/{modules → mixins}/callback.py +2 -2
- nshtrainer/model/{modules → mixins}/logger.py +13 -16
- nshtrainer/profiler/__init__.py +13 -0
- nshtrainer/profiler/_base.py +29 -0
- nshtrainer/profiler/advanced.py +37 -0
- nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer/profiler/simple.py +36 -0
- nshtrainer/trainer/_config.py +778 -0
- nshtrainer/trainer/trainer.py +16 -17
- nshtrainer/{config → util/config}/__init__.py +1 -0
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/RECORD +28 -22
- nshtrainer/model/modules/debug.py +0 -42
- nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer/model/modules/shared_parameters.py +0 -72
- /nshtrainer/{config → util/config}/duration.py +0 -0
- {nshtrainer-0.30.0.dist-info → nshtrainer-0.31.0.dist-info}/WHEEL +0 -0
nshtrainer/model/{modules → mixins}/callback.py

@@ -165,7 +165,7 @@ class CallbackModuleMixin(
     CallbackRegistrarModuleMixin,
     mixin_base_type(LightningModule),
 ):
-    def
+    def _nshtrainer_gather_all_callbacks(self):
         modules: list[Any] = []
         if isinstance(self, CallbackRegistrarModuleMixin):
             modules.append(self)

@@ -189,7 +189,7 @@ class CallbackModuleMixin(
             callbacks = [callbacks]

         callbacks = list(callbacks)
-        for callback_fn in self.
+        for callback_fn in self._nshtrainer_gather_all_callbacks():
             callback_result = callback_fn()
             if callback_result is None:
                 continue
nshtrainer/model/{modules → mixins}/logger.py

@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast

 import torchmetrics
-from lightning.pytorch import
+from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import _METRIC
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from nshutils import ActSave

@@ -23,15 +23,13 @@ class _LogContext:
     kwargs: dict[str, Any] = field(default_factory=dict)


-class LoggerModuleMixin:
+class LoggerModuleMixin(mixin_base_type(LightningModule)):
     @property
     def log_dir(self):
-
-
-
-
-
-        if (trainer := self.trainer) is None:
+        """
+        The directory where logs are saved.
+        """
+        if (trainer := self._trainer) is None:
             raise RuntimeError("trainer is not defined")

         if (logger := trainer.logger) is None:

@@ -44,16 +42,15 @@ class LoggerModuleMixin:

     @property
     def should_update_logs(self):
-
-
-
+        """
+        Whether logs should be updated. This is true once every `log_every_n_steps` steps.
+        """
+        if self._trainer is None:
+            raise RuntimeError(
+                "`should_update_logs` can only be used after the module is attached to a trainer"
             )

-
-        if trainer is None:
-            return True
-
-        return trainer._logger_connector.should_update_logs
+        return self._trainer._logger_connector.should_update_logs


 class LoggerLightningModuleMixin(LoggerModuleMixin, mixin_base_type(LightningModule)):
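A minimal sketch, not part of the package, of how a module built on the relocated logger mixin might use the new `should_update_logs` property; `ExampleModule`, its layer, and the metric names are assumptions for illustration, and the import path follows the modules → mixins rename shown above.

```python
import torch
from lightning.pytorch import LightningModule

# Assumed import path after the modules -> mixins rename in this release.
from nshtrainer.model.mixins.logger import LoggerLightningModuleMixin


class ExampleModule(LoggerLightningModuleMixin, LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        self.log("train/loss", loss)
        if self.should_update_logs:
            # Extra metric only computed on steps where the logger
            # connector will actually flush logs (every log_every_n_steps).
            self.log("train/weight_norm", self.layer.weight.norm())
        return loss
```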
nshtrainer/profiler/__init__.py

@@ -0,0 +1,13 @@
+from typing import Annotated, TypeAlias
+
+import nshconfig as C
+
+from ._base import BaseProfilerConfig as BaseProfilerConfig
+from .advanced import AdvancedProfilerConfig as AdvancedProfilerConfig
+from .pytorch import PyTorchProfilerConfig as PyTorchProfilerConfig
+from .simple import SimpleProfilerConfig as SimpleProfilerConfig
+
+ProfilerConfig: TypeAlias = Annotated[
+    SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
+    C.Discriminator("name"),
+]
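The `name` field acts as the discriminator for the new `ProfilerConfig` union. A minimal sketch, assuming nshconfig follows pydantic-style construction; the `RunConfig` consumer below is hypothetical and not part of nshtrainer.

```python
import nshconfig as C

from nshtrainer.profiler import AdvancedProfilerConfig, ProfilerConfig


class RunConfig(C.Config):
    # Hypothetical consumer config; the "name" field on each member
    # selects the concrete profiler config when parsing.
    profiler: ProfilerConfig | None = None


cfg = RunConfig(profiler=AdvancedProfilerConfig(line_count_restriction=0.5))
assert cfg.profiler is not None and cfg.profiler.name == "advanced"
```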
nshtrainer/profiler/_base.py

@@ -0,0 +1,29 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import nshconfig as C
+from lightning.pytorch.profilers import Profiler
+
+if TYPE_CHECKING:
+    from ..model import BaseConfig
+
+log = logging.getLogger(__name__)
+
+
+class BaseProfilerConfig(C.Config, ABC):
+    dirpath: str | Path | None = None
+    """
+    Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
+    ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
+    will be used.
+    """
+    filename: str | None = None
+    """
+    If present, filename where the profiler results will be saved instead of printing to stdout.
+    The ``.txt`` extension will be used automatically.
+    """
+
+    @abstractmethod
+    def create_profiler(self, root_config: "BaseConfig") -> Profiler | None: ...
nshtrainer/profiler/advanced.py

@@ -0,0 +1,37 @@
+import logging
+from typing import Literal
+
+from typing_extensions import override
+
+from ._base import BaseProfilerConfig
+
+log = logging.getLogger(__name__)
+
+
+class AdvancedProfilerConfig(BaseProfilerConfig):
+    name: Literal["advanced"] = "advanced"
+
+    line_count_restriction: float = 1.0
+    """
+    This can be used to limit the number of functions
+    reported for each action. either an integer (to select a count of lines),
+    or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
+    """
+
+    @override
+    def create_profiler(self, root_config):
+        from lightning.pytorch.profilers.advanced import AdvancedProfiler
+
+        if (dirpath := self.dirpath) is None:
+            dirpath = root_config.directory.resolve_subdirectory(
+                root_config.id, "profile"
+            )
+
+        if (filename := self.filename) is None:
+            filename = f"{root_config.id}_profile.txt"
+
+        return AdvancedProfiler(
+            line_count_restriction=self.line_count_restriction,
+            dirpath=dirpath,
+            filename=filename,
+        )
nshtrainer/profiler/pytorch.py

@@ -0,0 +1,83 @@
+import logging
+from typing import Any, Literal
+
+from typing_extensions import override
+
+from ._base import BaseProfilerConfig
+
+log = logging.getLogger(__name__)
+
+
+class PyTorchProfilerConfig(BaseProfilerConfig):
+    name: Literal["pytorch"] = "pytorch"
+
+    group_by_input_shapes: bool = False
+    """Include operator input shapes and group calls by shape."""
+
+    emit_nvtx: bool = False
+    """
+    Context manager that makes every autograd operation emit an NVTX range
+    Run::
+
+        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
+
+    To visualize, you can either use::
+
+        nvvp trace_name.prof
+        torch.autograd.profiler.load_nvprof(path)
+    """
+
+    export_to_chrome: bool = True
+    """
+    Whether to export the sequence of profiled operators for Chrome.
+    It will generate a ``.json`` file which can be read by Chrome.
+    """
+
+    row_limit: int = 20
+    """
+    Limit the number of rows in a table, ``-1`` is a special value that
+    removes the limit completely.
+    """
+
+    sort_by_key: str | None = None
+    """
+    Attribute used to sort entries. By default
+    they are printed in the same order as they were registered.
+    Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
+    ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
+    ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
+    """
+
+    record_module_names: bool = True
+    """Whether to add module names while recording autograd operation."""
+
+    table_kwargs: dict[str, Any] | None = None
+    """Dictionary with keyword arguments for the summary table."""
+
+    additional_profiler_kwargs: dict[str, Any] = {}
+    """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version"""
+
+    @override
+    def create_profiler(self, root_config):
+        from lightning.pytorch.profilers.pytorch import PyTorchProfiler
+
+        if (dirpath := self.dirpath) is None:
+            dirpath = root_config.directory.resolve_subdirectory(
+                root_config.id, "profile"
+            )
+
+        if (filename := self.filename) is None:
+            filename = f"{root_config.id}_profile.txt"
+
+        return PyTorchProfiler(
+            group_by_input_shapes=self.group_by_input_shapes,
+            emit_nvtx=self.emit_nvtx,
+            export_to_chrome=self.export_to_chrome,
+            row_limit=self.row_limit,
+            sort_by_key=self.sort_by_key,
+            record_module_names=self.record_module_names,
+            table_kwargs=self.table_kwargs,
+            dirpath=dirpath,
+            filename=filename,
+            **self.additional_profiler_kwargs,
+        )
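`additional_profiler_kwargs` is forwarded straight through to `PyTorchProfiler`, so version-specific options of the underlying torch profiler can be set without a config change. A minimal sketch; the values below are illustrative only, with `profile_memory` being a standard `torch.profiler` option.

```python
from nshtrainer.profiler import PyTorchProfilerConfig

# Illustrative settings; "cuda_time_total" is one of the documented
# sort_by_key values, and profile_memory is passed through to the
# underlying torch profiler via additional_profiler_kwargs.
profiler_config = PyTorchProfilerConfig(
    row_limit=50,
    sort_by_key="cuda_time_total",
    additional_profiler_kwargs={"profile_memory": True},
)
```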
nshtrainer/profiler/simple.py

@@ -0,0 +1,36 @@
+import logging
+from typing import Literal
+
+from typing_extensions import override
+
+from ._base import BaseProfilerConfig
+
+log = logging.getLogger(__name__)
+
+
+class SimpleProfilerConfig(BaseProfilerConfig):
+    name: Literal["simple"] = "simple"
+
+    extended: bool = True
+    """
+    If ``True``, adds extra columns representing number of calls and percentage of
+    total time spent onrespective action.
+    """
+
+    @override
+    def create_profiler(self, root_config):
+        from lightning.pytorch.profilers.simple import SimpleProfiler
+
+        if (dirpath := self.dirpath) is None:
+            dirpath = root_config.directory.resolve_subdirectory(
+                root_config.id, "profile"
+            )
+
+        if (filename := self.filename) is None:
+            filename = f"{root_config.id}_profile.txt"
+
+        return SimpleProfiler(
+            extended=self.extended,
+            dirpath=dirpath,
+            filename=filename,
+        )
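All three `create_profiler` implementations only read `root_config.id` and `root_config.directory.resolve_subdirectory(...)` when `dirpath`/`filename` are unset. A minimal sketch with a stand-in root config; `_StubDirectory` and the `runs/` layout are assumptions, not nshtrainer's actual `BaseConfig`.

```python
from pathlib import Path
from types import SimpleNamespace

from nshtrainer.profiler import SimpleProfilerConfig


class _StubDirectory:
    # Stand-in for the real directory resolver, exposing only the method
    # that create_profiler() appears to rely on in this release.
    def resolve_subdirectory(self, run_id: str, subdir: str) -> Path:
        return Path("runs") / run_id / subdir  # hypothetical layout


root_config = SimpleNamespace(id="run-001", directory=_StubDirectory())

# Yields a lightning.pytorch SimpleProfiler configured to write its report
# under runs/run-001/profile/ with a run-scoped filename.
profiler = SimpleProfilerConfig(extended=True).create_profiler(root_config)
```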