nshtrainer 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. nshtrainer/__init__.py +1 -2
  2. nshtrainer/_directory.py +85 -0
  3. nshtrainer/callbacks/__init__.py +12 -1
  4. nshtrainer/callbacks/debug_flag.py +72 -0
  5. nshtrainer/callbacks/directory_setup.py +85 -0
  6. nshtrainer/callbacks/rlp_sanity_checks.py +230 -0
  7. nshtrainer/callbacks/shared_parameters.py +87 -0
  8. nshtrainer/config.py +67 -0
  9. nshtrainer/ll/__init__.py +5 -4
  10. nshtrainer/ll/model.py +7 -0
  11. nshtrainer/loggers/wandb.py +1 -1
  12. nshtrainer/lr_scheduler/linear_warmup_cosine.py +1 -1
  13. nshtrainer/model/__init__.py +0 -21
  14. nshtrainer/model/base.py +124 -67
  15. nshtrainer/model/config.py +7 -1025
  16. nshtrainer/model/{modules → mixins}/logger.py +13 -16
  17. nshtrainer/profiler/__init__.py +13 -0
  18. nshtrainer/profiler/_base.py +29 -0
  19. nshtrainer/profiler/advanced.py +37 -0
  20. nshtrainer/profiler/pytorch.py +83 -0
  21. nshtrainer/profiler/simple.py +36 -0
  22. nshtrainer/trainer/_config.py +787 -0
  23. nshtrainer/trainer/trainer.py +16 -17
  24. nshtrainer/{config → util/config}/__init__.py +1 -0
  25. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/METADATA +1 -1
  26. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/RECORD +28 -22
  27. nshtrainer/model/modules/callback.py +0 -206
  28. nshtrainer/model/modules/debug.py +0 -42
  29. nshtrainer/model/modules/distributed.py +0 -70
  30. nshtrainer/model/modules/profiler.py +0 -24
  31. nshtrainer/model/modules/rlp_sanity_checks.py +0 -202
  32. nshtrainer/model/modules/shared_parameters.py +0 -72
  33. /nshtrainer/{config → util/config}/duration.py +0 -0
  34. {nshtrainer-0.30.1.dist-info → nshtrainer-0.32.0.dist-info}/WHEEL +0 -0
nshtrainer/model/modules/rlp_sanity_checks.py
@@ -1,202 +0,0 @@
- import logging
- from collections.abc import Mapping
- from typing import cast
-
- import torch
- from lightning.pytorch import LightningModule, Trainer
- from lightning.pytorch.utilities.types import (
-     LRSchedulerConfigType,
-     LRSchedulerTypeUnion,
- )
- from typing_extensions import Protocol, override, runtime_checkable
-
- from ...util.typing_utils import mixin_base_type
- from ..config import BaseConfig
- from .callback import CallbackModuleMixin
-
- log = logging.getLogger(__name__)
-
-
- def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule):
-     # If we're in PL's "sanity check" mode, we don't need to run this check
-     if trainer.sanity_checking:
-         return
-
-     config = cast(BaseConfig, pl_module.hparams)
-     if config.trainer.sanity_checking.reduce_lr_on_plateau == "disable":
-         return
-
-     # if no lr schedulers, return
-     if not trainer.lr_scheduler_configs:
-         return
-
-     errors: list[str] = []
-     disable_message = (
-         "Otherwise, set `config.trainer.sanity_checking.reduce_lr_on_plateau='disable'` "
-         "to disable this sanity check."
-     )
-
-     for lr_scheduler_config in trainer.lr_scheduler_configs:
-         if not lr_scheduler_config.reduce_on_plateau:
-             continue
-
-         match lr_scheduler_config.interval:
-             case "epoch":
-                 # we need to make sure that the trainer runs val every `frequency` epochs
-
-                 # If `trainer.check_val_every_n_epoch` is None, then Lightning
-                 # will run val every `int(trainer.val_check_interval)` steps.
-                 # So, first we need to make sure that `trainer.val_check_interval` is not None first.
-                 if trainer.check_val_every_n_epoch is None:
-                     errors.append(
-                         "Trainer is not running validation at epoch intervals "
-                         "(i.e., `trainer.check_val_every_n_epoch` is None) but "
-                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                         f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. "
-                         + disable_message
-                     )
-
-                 # Second, we make sure that the trainer runs val at least every `frequency` epochs
-                 if (
-                     trainer.check_val_every_n_epoch is not None
-                     and lr_scheduler_config.frequency % trainer.check_val_every_n_epoch
-                     != 0
-                 ):
-                     errors.append(
-                         f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but "
-                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
-                         f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. "
-                         + disable_message
-                     )
-
-             case "step":
-                 # In this case, we need to make sure that the trainer runs val at step intervals
-                 # that are multiples of `frequency`.
-
-                 # First, we make sure that validation is run at step intervals
-                 if trainer.check_val_every_n_epoch is not None:
-                     errors.append(
-                         "Trainer is running validation at epoch intervals "
-                         "(i.e., `trainer.check_val_every_n_epoch` is not None) but "
-                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                         "Please set `config.trainer.check_val_every_n_epoch=None` "
-                         f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
-                         + disable_message
-                     )
-
-                 # Second, we make sure `trainer.val_check_interval` is an integer
-                 if not isinstance(trainer.val_check_interval, int):
-                     errors.append(
-                         f"Trainer is not running validation at step intervals "
-                         f"(i.e., `trainer.val_check_interval` is not an integer) but "
-                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used."
-                         "Please set `config.trainer.val_check_interval=None` "
-                         f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. "
-                         + disable_message
-                     )
-
-                 # Third, we make sure that the trainer runs val at least every `frequency` steps
-                 if (
-                     isinstance(trainer.val_check_interval, int)
-                     and trainer.val_check_interval % lr_scheduler_config.frequency != 0
-                 ):
-                     errors.append(
-                         f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but "
-                         f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used."
-                         "Please set `config.trainer.val_check_interval` "
-                         f"to a multiple of {lr_scheduler_config.frequency}. "
-                         + disable_message
-                     )
-
-             case _:
-                 pass
-
-     if not errors:
-         return
-
-     message = (
-         "ReduceLRPlateau sanity checks failed with the following errors:\n"
-         + "\n".join(errors)
-     )
-     match config.trainer.sanity_checking.reduce_lr_on_plateau:
-         case "warn":
-             log.warning(message)
-         case "error":
-             raise ValueError(message)
-         case _:
-             pass
-
-
- @runtime_checkable
- class CustomRLPImplementation(Protocol):
-     __reduce_lr_on_plateau__: bool
-
-
- class RLPSanityCheckModuleMixin(mixin_base_type(CallbackModuleMixin)):
-     @override
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-
-         global _on_train_start_callback
-         self.register_callback(on_train_start=_on_train_start_callback)
-
-     def reduce_lr_on_plateau_config(
-         self,
-         lr_scheduler: LRSchedulerTypeUnion | LRSchedulerConfigType,
-     ) -> LRSchedulerConfigType:
-         if (trainer := self._trainer) is None:
-             raise RuntimeError(
-                 "Could not determine the frequency of ReduceLRPlateau scheduler "
-                 "because `self.trainer` is None."
-             )
-
-         # First, resolve the LR scheduler from the provided config.
-         lr_scheduler_config: LRSchedulerConfigType
-         match lr_scheduler:
-             case Mapping():
-                 lr_scheduler_config = cast(LRSchedulerConfigType, lr_scheduler)
-             case _:
-                 lr_scheduler_config = {"scheduler": lr_scheduler}
-
-         # Make sure the scheduler is a ReduceLRPlateau scheduler. Otherwise, warn the user.
-         if (
-             not isinstance(
-                 lr_scheduler_config["scheduler"],
-                 torch.optim.lr_scheduler.ReduceLROnPlateau,
-             )
-         ) and (
-             not isinstance(lr_scheduler_config["scheduler"], CustomRLPImplementation)
-             or not lr_scheduler_config["scheduler"].__reduce_lr_on_plateau__
-         ):
-             log.warning(
-                 "`reduce_lr_on_plateau_config` should only be used with a ReduceLRPlateau scheduler. "
-                 f"The provided scheduler, {lr_scheduler_config['scheduler']}, does not subclass "
-                 "`torch.optim.lr_scheduler.ReduceLROnPlateau`. "
-                 "Please ensure that the scheduler is a ReduceLRPlateau scheduler. "
-                 "If you are using a custom ReduceLRPlateau scheduler implementation, "
-                 "please either (1) make sure that it subclasses `torch.optim.lr_scheduler.ReduceLROnPlateau`, "
-                 "or (2) set the scheduler's `__reduce_lr_on_plateau__` attribute to `True`."
-             )
-
-         # If trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals.
-         if trainer.check_val_every_n_epoch is not None:
-             return {
-                 "reduce_on_plateau": True,
-                 "interval": "epoch",
-                 "frequency": trainer.check_val_every_n_epoch,
-                 **lr_scheduler_config,
-             }
-
-         # Otherwise, we run val at step intervals.
-         if not isinstance(trainer.val_check_batch, int):
-             raise ValueError(
-                 "Could not determine the frequency of ReduceLRPlateau scheduler "
-                 f"because {trainer.val_check_batch=} is not an integer."
-             )
-
-         return {
-             "reduce_on_plateau": True,
-             "interval": "step",
-             "frequency": trainer.val_check_batch,
-             **lr_scheduler_config,
-         }
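
For reference, the removed `RLPSanityCheckModuleMixin` exposes `reduce_lr_on_plateau_config`, which wraps an LR scheduler in a Lightning scheduler-config dict whose `interval` and `frequency` are derived from the trainer's validation cadence, so the `on_train_start` check above passes by construction. Below is a minimal, hedged usage sketch of that 0.30.1 API; the module class, optimizer settings, and the `"val/loss"` monitor key are illustrative assumptions, and in practice the mixin was composed into the package's own LightningModule base rather than subclassed directly like this.

```python
# Hedged usage sketch of the removed 0.30.1 mixin shown above. `MyModule`, the
# optimizer hyperparameters, and the "val/loss" monitor key are hypothetical.
import torch

from nshtrainer.model.modules.rlp_sanity_checks import RLPSanityCheckModuleMixin  # 0.30.1 path


class MyModule(RLPSanityCheckModuleMixin):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.5, patience=5
        )
        return {
            "optimizer": optimizer,
            # Fills in `reduce_on_plateau`, `interval`, and `frequency` to match the
            # trainer's validation cadence, which is exactly what the sanity check in
            # `_on_train_start_callback` verifies.
            "lr_scheduler": {
                **self.reduce_lr_on_plateau_config(scheduler),
                "monitor": "val/loss",
            },
        }
```

Per the file list, these checks apparently move to `nshtrainer/callbacks/rlp_sanity_checks.py` in 0.32.0; the new callback's API is not shown in this diff.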
nshtrainer/model/modules/shared_parameters.py
@@ -1,72 +0,0 @@
- import logging
- from collections.abc import Sequence
- from typing import cast
-
- import torch.nn as nn
- from lightning.pytorch import LightningModule, Trainer
- from typing_extensions import override
-
- from ...util.typing_utils import mixin_base_type
- from ..config import BaseConfig
- from .callback import CallbackRegistrarModuleMixin
-
- log = logging.getLogger(__name__)
-
-
- def _parameters_to_names(parameters: Sequence[nn.Parameter], model: nn.Module):
-     mapping = {id(p): n for n, p in model.named_parameters()}
-     return [mapping[id(p)] for p in parameters]
-
-
- class SharedParametersModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)):
-     @override
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-
-         self.shared_parameters: list[tuple[nn.Parameter, int | float]] = []
-         self._warned_shared_parameters = False
-
-         def on_after_backward(_trainer: Trainer, pl_module: LightningModule):
-             nonlocal self
-
-             config = cast(BaseConfig, pl_module.hparams)
-             if not config.trainer.supports_shared_parameters:
-                 return
-
-             log.debug(f"Scaling {len(self.shared_parameters)} shared parameters...")
-             no_grad_parameters: list[nn.Parameter] = []
-             for p, factor in self.shared_parameters:
-                 if not hasattr(p, "grad") or p.grad is None:
-                     no_grad_parameters.append(p)
-                     continue
-
-                 _ = p.grad.data.div_(factor)
-
-             if no_grad_parameters and not self._warned_shared_parameters:
-                 no_grad_parameters_str = ", ".join(
-                     _parameters_to_names(no_grad_parameters, pl_module)
-                 )
-                 log.warning(
-                     "The following parameters were marked as shared, but had no gradients: "
-                     f"{no_grad_parameters_str}"
-                 )
-                 self._warned_shared_parameters = True
-
-             log.debug(
-                 f"Done scaling shared parameters. (len={len(self.shared_parameters)})"
-             )
-
-         self.register_callback(on_after_backward=on_after_backward)
-
-     def register_shared_parameters(
-         self, parameters: list[tuple[nn.Parameter, int | float]]
-     ):
-         for parameter, factor in parameters:
-             if not isinstance(parameter, nn.Parameter):
-                 raise ValueError("Shared parameters must be PyTorch parameters")
-             if not isinstance(factor, (int, float)):
-                 raise ValueError("Factor must be an integer or float")
-
-             self.shared_parameters.append((parameter, factor))
-
-         log.info(f"Registered {len(parameters)} shared parameters")