nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -2
- nshtrainer/_directory.py +85 -0
- nshtrainer/callbacks/__init__.py +12 -1
- nshtrainer/callbacks/debug_flag.py +72 -0
- nshtrainer/callbacks/directory_setup.py +85 -0
- nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
- nshtrainer/callbacks/shared_parameters.py +87 -0
- nshtrainer/config.py +67 -0
- nshtrainer/ll/__init__.py +5 -4
- nshtrainer/ll/model.py +7 -0
- nshtrainer/loggers/wandb.py +1 -1
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
- nshtrainer/model/__init__.py +0 -21
- nshtrainer/model/base.py +124 -67
- nshtrainer/model/config.py +7 -1025
- nshtrainer/model/{modules → mixins}/logger.py +13 -16
- nshtrainer/profiler/__init__.py +13 -0
- nshtrainer/profiler/_base.py +29 -0
- nshtrainer/profiler/advanced.py +37 -0
- nshtrainer/profiler/pytorch.py +83 -0
- nshtrainer/profiler/simple.py +36 -0
- nshtrainer/trainer/_config.py +787 -0
- nshtrainer/trainer/trainer.py +16 -17
- nshtrainer/{config → util/config}/__init__.py +1 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
- nshtrainer/model/modules/callback.py +0 -206
- nshtrainer/model/modules/debug.py +0 -42
- nshtrainer/model/modules/distributed.py +0 -70
- nshtrainer/model/modules/profiler.py +0 -24
- nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
- nshtrainer/model/modules/shared_parameters.py +0 -72
- /nshtrainer/{config → util/config}/duration.py +0 -0
- {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
|
@@ -1,202 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
from collections.abc import Mapping
|
|
3
|
-
from typing import cast
|
|
4
|
-
|
|
5
|
-
import torch
|
|
6
|
-
from lightning.pytorch import LightningModule, Trainer
|
|
7
|
-
from lightning.pytorch.utilities.types import (
|
|
8
|
-
LRSchedulerConfigType,
|
|
9
|
-
LRSchedulerTypeUnion,
|
|
10
|
-
)
|
|
11
|
-
from typing_extensions import Protocol, override, runtime_checkable
|
|
12
|
-
|
|
13
|
-
from ...util.typing_utils import mixin_base_type
|
|
14
|
-
from ..config import BaseConfig
|
|
15
|
-
from .callback import CallbackModuleMixin
|
|
16
|
-
|
|
17
|
-
# Module-level logger used for the ReduceLROnPlateau sanity-check warnings in this file.
log = logging.getLogger(__name__)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
|
|
21
|
-
# If we're in PL's "sanity check" mode, we don't need to run this check
|
|
22
|
-
if trainer.sanity_checking:
|
|
23
|
-
return
|
|
24
|
-
|
|
25
|
-
config = cast(BaseConfig, pl_module.hparams)
|
|
26
|
-
if config.trainer.sanity_checking.reduce_lr_on_plateau == "disable":
|
|
27
|
-
return
|
|
28
|
-
|
|
29
|
-
# if no lr schedulers, return
|
|
30
|
-
if not trainer.lr_scheduler_configs:
|
|
31
|
-
return
|
|
32
|
-
|
|
33
|
-
errors: list[str] = []
|
|
34
|
-
disable_message = (
|
|
35
|
-
"Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau='disable'` "
|
|
36
|
-
"to disable this sanity check."
|
|
37
|
-
)
|
|
38
|
-
|
|
39
|
-
for lr_scheduler_config in trainer.lr_scheduler_configs:
|
|
40
|
-
if not lr_scheduler_config.reduce_on_plateau:
|
|
41
|
-
continue
|
|
42
|
-
|
|
43
|
-
match lr_scheduler_config.interval:
|
|
44
|
-
case "epoch":
|
|
45
|
-
# we need to make sure that the trainer runs val every `frequency` epochs
|
|
46
|
-
|
|
47
|
-
# If `trainer.check_val_every_n_epoch` is None, then Lightning
|
|
48
|
-
# will run val every `int(trainer.val_check_interval)` steps.
|
|
49
|
-
# So, first we need to make sure that `trainer.val_check_interval` is not None first.
|
|
50
|
-
if trainer.check_val_every_n_epoch is None:
|
|
51
|
-
errors.append(
|
|
52
|
-
"Trainer is not running validation at epoch intervals "
|
|
53
|
-
"(i.e., `trainer.check_val_every_n_epoch` is None) but "
|
|
54
|
-
f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
|
|
55
|
-
f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
|
|
56
|
-
+ disable_message
|
|
57
|
-
)
|
|
58
|
-
|
|
59
|
-
# Second, we make sure that the trainer runs val at least every `frequency` epochs
|
|
60
|
-
if (
|
|
61
|
-
trainer.check_val_every_n_epoch is not None
|
|
62
|
-
and lr_scheduler_config.frequency % trainer.check_val_every_n_epoch
|
|
63
|
-
!= 0
|
|
64
|
-
):
|
|
65
|
-
errors.append(
|
|
66
|
-
f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
|
|
67
|
-
f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
|
|
68
|
-
f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
|
|
69
|
-
+ disable_message
|
|
70
|
-
)
|
|
71
|
-
|
|
72
|
-
case "step":
|
|
73
|
-
# In this case, we need to make sure that the trainer runs val at step intervals
|
|
74
|
-
# that are multiples of `frequency`.
|
|
75
|
-
|
|
76
|
-
# First, we make sure that validation is run at step intervals
|
|
77
|
-
if trainer.check_val_every_n_epoch is not None:
|
|
78
|
-
errors.append(
|
|
79
|
-
"Trainer is running validation at epoch intervals "
|
|
80
|
-
"(i.e., `trainer.check_val_every_n_epoch` is not None) but "
|
|
81
|
-
f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
|
|
82
|
-
"Please set `config.trainer.check_val_every_n_epoch=None` "
|
|
83
|
-
f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
|
|
84
|
-
+ disable_message
|
|
85
|
-
)
|
|
86
|
-
|
|
87
|
-
# Second, we make sure `trainer.val_check_interval` is an integer
|
|
88
|
-
if not isinstance(trainer.val_check_interval, int):
|
|
89
|
-
errors.append(
|
|
90
|
-
f"Trainer is not running validation at step intervals "
|
|
91
|
-
f"(i.e., `trainer.val_check_interval` is not an integer) but "
|
|
92
|
-
f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
|
|
93
|
-
"Please set `config.trainer.val_check_interval=None` "
|
|
94
|
-
f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
|
|
95
|
-
+ disable_message
|
|
96
|
-
)
|
|
97
|
-
|
|
98
|
-
# Third, we make sure that the trainer runs val at least every `frequency` steps
|
|
99
|
-
if (
|
|
100
|
-
isinstance(trainer.val_check_interval, int)
|
|
101
|
-
and trainer.val_check_interval % lr_scheduler_config.frequency != 0
|
|
102
|
-
):
|
|
103
|
-
errors.append(
|
|
104
|
-
f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
|
|
105
|
-
f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
|
|
106
|
-
"Please set `config.trainer.val_check_interval` "
|
|
107
|
-
f"to a multiple of {lr_scheduler_config.frequency}. "
|
|
108
|
-
+ disable_message
|
|
109
|
-
)
|
|
110
|
-
|
|
111
|
-
case _:
|
|
112
|
-
pass
|
|
113
|
-
|
|
114
|
-
if not errors:
|
|
115
|
-
return
|
|
116
|
-
|
|
117
|
-
message = (
|
|
118
|
-
"ReduceLRPlateau sanity checks failed with the following errors:\n"
|
|
119
|
-
+ "\n".join(errors)
|
|
120
|
-
)
|
|
121
|
-
match config.trainer.sanity_checking.reduce_lr_on_plateau:
|
|
122
|
-
case "warn":
|
|
123
|
-
log.warning(message)
|
|
124
|
-
case "error":
|
|
125
|
-
raise ValueError(message)
|
|
126
|
-
case _:
|
|
127
|
-
pass
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
@runtime_checkable
class CustomRLPImplementation(Protocol):
    """Structural type for schedulers that opt in to ReduceLROnPlateau handling.

    Any object exposing a ``__reduce_lr_on_plateau__`` attribute satisfies this
    protocol under ``isinstance``. Setting that attribute to ``True`` marks a
    custom scheduler as a ReduceLROnPlateau-style scheduler without requiring it
    to subclass ``torch.optim.lr_scheduler.ReduceLROnPlateau``.
    """

    # Opt-in flag; a truthy value marks the scheduler as ReduceLROnPlateau-like.
    __reduce_lr_on_plateau__: bool
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
class RLPSanityCheckModuleMixin(mixin_base_type(CallbackModuleMixin)):
|
|
136
|
-
@override
|
|
137
|
-
def __init__(self, *args, **kwargs):
|
|
138
|
-
super().__init__(*args, **kwargs)
|
|
139
|
-
|
|
140
|
-
global _on_train_start_callback
|
|
141
|
-
self.register_callback(on_train_start=_on_train_start_callback)
|
|
142
|
-
|
|
143
|
-
def reduce_lr_on_plateau_config(
|
|
144
|
-
self,
|
|
145
|
-
lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
|
|
146
|
-
) -> LRSchedulerConfigType:
|
|
147
|
-
if (trainer := self._trainer) is None:
|
|
148
|
-
raise RuntimeError(
|
|
149
|
-
"Could not determine the frequency of ReduceLRPlateau scheduler "
|
|
150
|
-
"because `self.trainer` is None."
|
|
151
|
-
)
|
|
152
|
-
|
|
153
|
-
# First, resolve the LR scheduler from the provided config.
|
|
154
|
-
lr_scheduler_config: LRSchedulerConfigType
|
|
155
|
-
match lr_scheduler:
|
|
156
|
-
case Mapping():
|
|
157
|
-
lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
|
|
158
|
-
case _:
|
|
159
|
-
lr_scheduler_config = {"scheduler": lr_scheduler}
|
|
160
|
-
|
|
161
|
-
# Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
|
|
162
|
-
if (
|
|
163
|
-
not isinstance(
|
|
164
|
-
lr_scheduler_config["scheduler"],
|
|
165
|
-
torch.optim.lr_scheduler.ReduceLROnPlateau,
|
|
166
|
-
)
|
|
167
|
-
) and (
|
|
168
|
-
not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
|
|
169
|
-
or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
|
|
170
|
-
):
|
|
171
|
-
log.warning(
|
|
172
|
-
"`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
|
|
173
|
-
f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
|
|
174
|
-
"`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
|
|
175
|
-
"Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
|
|
176
|
-
"If you are using a custom ReduceLRPlateau scheduler implementation, "
|
|
177
|
-
"please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
|
|
178
|
-
"or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
# If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
|
|
182
|
-
if trainer.check_val_every_n_epoch is not None:
|
|
183
|
-
return {
|
|
184
|
-
"reduce_on_plateau": True,
|
|
185
|
-
"interval": "epoch",
|
|
186
|
-
"frequency": trainer.check_val_every_n_epoch,
|
|
187
|
-
**lr_scheduler_config,
|
|
188
|
-
}
|
|
189
|
-
|
|
190
|
-
# Otherwise, we run val at step intervals.
|
|
191
|
-
if not isinstance(trainer.val_check_batch, int):
|
|
192
|
-
raise ValueError(
|
|
193
|
-
"Could not determine the frequency of ReduceLRPlateau scheduler "
|
|
194
|
-
f"because {trainer.val_check_batch=} is not an integer."
|
|
195
|
-
)
|
|
196
|
-
|
|
197
|
-
return {
|
|
198
|
-
"reduce_on_plateau": True,
|
|
199
|
-
"interval": "step",
|
|
200
|
-
"frequency": trainer.val_check_batch,
|
|
201
|
-
**lr_scheduler_config,
|
|
202
|
-
}
|
|
@@ -1,72 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
from collections.abc import Sequence
|
|
3
|
-
from typing import cast
|
|
4
|
-
|
|
5
|
-
import torch.nn as nn
|
|
6
|
-
from lightning.pytorch import LightningModule, Trainer
|
|
7
|
-
from typing_extensions import override
|
|
8
|
-
|
|
9
|
-
from ...util.typing_utils import mixin_base_type
|
|
10
|
-
from ..config import BaseConfig
|
|
11
|
-
from .callback import CallbackRegistrarModuleMixin
|
|
12
|
-
|
|
13
|
-
# Module-level logger used for shared-parameter scaling messages in this file.
log = logging.getLogger(__name__)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
|
|
17
|
-
mapping = {id(p): n for n, p in model.named_parameters()}
|
|
18
|
-
return [mapping[id(p)] for p in parameters]
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class SharedParametersModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)):
    """Mixin that rescales the gradients of parameters marked as shared.

    Parameters registered via :meth:`register_shared_parameters` have their
    gradient divided by the associated factor after each backward pass. This
    only happens when ``config.trainer.supports_shared_parameters`` is truthy.
    """

    @override
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # (parameter, factor) pairs; each parameter's grad is divided by factor.
        self.shared_parameters: list[tuple[nn.Parameter, int | float]] = []
        # Ensures the "no gradients" warning is emitted at most once.
        self._warned_shared_parameters = False

        def on_after_backward(_trainer: Trainer, pl_module: LightningModule):
            # `self` is captured from `__init__`'s scope and never rebound, so
            # no `nonlocal` declaration is required.
            config = cast(BaseConfig, pl_module.hparams)
            if not config.trainer.supports_shared_parameters:
                return

            # Lazy %-formatting: this runs after every backward pass.
            log.debug("Scaling %d shared parameters...", len(self.shared_parameters))
            no_grad_parameters: list[nn.Parameter] = []
            for p, factor in self.shared_parameters:
                grad = getattr(p, "grad", None)
                if grad is None:
                    no_grad_parameters.append(p)
                    continue

                _ = grad.data.div_(factor)

            if no_grad_parameters and not self._warned_shared_parameters:
                no_grad_parameters_str = ", ".join(
                    _parameters_to_names(no_grad_parameters, pl_module)
                )
                log.warning(
                    "The following parameters were marked as shared, but had no gradients: "
                    f"{no_grad_parameters_str}"
                )
                self._warned_shared_parameters = True

            log.debug(
                "Done scaling shared parameters. (len=%d)", len(self.shared_parameters)
            )

        self.register_callback(on_after_backward=on_after_backward)

    def register_shared_parameters(
        self, parameters: list[tuple[nn.Parameter, int | float]]
    ):
        """Mark parameters as shared so their gradients are divided by a factor.

        All entries are validated before any are registered, so an invalid entry
        leaves ``self.shared_parameters`` unchanged.

        Args:
            parameters: ``(parameter, factor)`` pairs. Each ``parameter`` must be
                an ``nn.Parameter`` and each ``factor`` a non-zero int or float.

        Raises:
            ValueError: If a parameter is not an ``nn.Parameter``, or a factor is
                not a number or is zero.
        """
        entries = list(parameters)
        for parameter, factor in entries:
            if not isinstance(parameter, nn.Parameter):
                raise ValueError("Shared parameters must be PyTorch parameters")
            if not isinstance(factor, (int, float)):
                raise ValueError("Factor must be an integer or float")
            if factor == 0:
                # The factor divides gradients; zero would produce inf/nan grads.
                raise ValueError("Factor must be non-zero")

        self.shared_parameters.extend(entries)

        log.info("Registered %d shared parameters", len(entries))
|
|
File without changes
|
|
File without changes
|