torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/core/module.py CHANGED
@@ -1,510 +1,629 @@
1
- import sys
2
- import warnings
3
- from abc import ABC, abstractmethod
4
- from collections.abc import Callable, Iterable, Sequence
5
- from typing import Any, Literal
6
- from typing_extensions import Self, TypeAlias
7
-
8
- import torch
9
- from torch.optim.optimizer import ParamsT
10
-
11
- from ..tensorlist import TensorList
12
- from ..utils.python_tools import _ScalarLoss, flatten
13
-
14
- from .tensorlist_optimizer import (
15
- TensorListOptimizer,
16
- _ClosureType,
17
- _maybe_pass_backward,
18
- )
19
-
20
- def _get_loss(fx0, fx0_approx):
21
- """Returns fx0 if it is not None otherwise fx0_approx"""
22
- if fx0 is None: return fx0_approx
23
- return fx0
24
-
25
-
26
- class OptimizationVars:
27
- """Holds optimization variables. This is usually automatically created by :any:`torchzero.optim.Modular`."""
28
- def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
29
-
30
- self.closure: _ClosureType | None = closure
31
- """A closure that reevaluates the model and returns the loss.
32
- The closure should accept `backward` boolean argument that is True by default, which,
33
- if True, sets `.grad` attributes of all learnable params, for example via `loss.backward()`.
34
- Closure can be None for most first order optimizers."""
35
-
36
- self.ascent: TensorList | None = None
37
- """Ascent direction, for example the gradients.
38
- Will be None on the first module in the chain.
39
- May remain none for modules that create a new closure."""
40
-
41
- self.fx0: _ScalarLoss | None = None
42
- """Loss value strictly with initial parameters of the current step.
43
- If initial parameters have not been evaluated, this should be None."""
44
-
45
- self.fx0_approx: _ScalarLoss | None = None
46
- """Loss value, could be sampled nearby the initial parameters.
47
- This is mainly used as the return value of the step method when fx0 is None."""
48
-
49
- self.grad: TensorList | None = None
50
- """Gradient if it has been computed, otherwise None.
51
-
52
- Gradient must be evaluated strictly with initial parameters of the current step"""
53
-
54
- self.model = model
55
- """model itself (torch.nn.Module) if it was passed, otherwise None."""
56
-
57
- self.post_step_hooks = []
58
- """callables that get executed after each step. Used by periodic SWA to reset momentum when setting model parameters to SWA.
59
-
60
- Signature:
61
-
62
- .. code:: py
63
-
64
- def hook(optimizer: ModularOptimizer, state: OptimizationState) -> None:
65
- ...
66
- """
67
-
68
- def maybe_compute_grad_(self, params: TensorList | None) -> TensorList:
69
- """Computes gradient if it hasn't been computed already, and returns it"""
70
- if self.grad is None:
71
- if params is None: raise ValueError()
72
- if self.closure is not None:
73
- with torch.enable_grad(): self.fx0 = self.closure() # pylint:disable = not-callable (???)
74
- self.grad = params.ensure_grad_().grad
75
-
76
- return self.grad
77
-
78
- def maybe_use_grad_(self, params: TensorList | None) -> TensorList:
79
- """If ascent direction is None, use cloned gradient as ascent direction and returns it.
80
- Otherwise does nothing and returns existing ascent direction.
81
- If gradient hasn't been computed, this also sets `fx0`."""
82
- if self.ascent is None:
83
- self.ascent = self.maybe_compute_grad_(params).clone()
84
-
85
- return self.ascent
86
-
87
- def set_grad_(self, grad: TensorList, params: TensorList):
88
- """Sets gradient to this state and to params"""
89
- self.grad = grad
90
- params.set_grad_(grad)
91
-
92
- def evaluate_fx0_(self, backward=True) -> _ScalarLoss:
93
- """if fx0 is None or if backward is True and self.grad is None, evaluates closure and sets them. Returns fx0"""
94
- if self.fx0 is not None:
95
- if backward and self.grad is None:
96
- warnings.warn('evaluating fx0 with backward=True after it has already been evaluated with backward=False. Something may be inefficient.')
97
- with torch.enable_grad(): self.closure() # set grad #type:ignore
98
- return self.fx0
99
-
100
- if self.closure is None: raise ValueError("Closure is None")
101
- loss = self.fx0 = _maybe_pass_backward(self.closure, backward)
102
- return loss
103
-
104
- def evaluate_fx0_approx_(self, backward=True) -> _ScalarLoss:
105
- """evaluates closure, sets self.fx0_approx and returns it"""
106
- if self.closure is None: raise ValueError("Closure is None")
107
- loss = self.fx0_approx = _maybe_pass_backward(self.closure, backward)
108
- return loss
109
-
110
- def get_loss(self):
111
- """Returns fx0 if it is not None otherwise fx0_approx"""
112
- if self.fx0 is None: return self.fx0_approx
113
- return self.fx0
114
-
115
- def copy(self, clone_ascent = False):
116
- """Copy this optimization state. This will not clone anything other than optionally ascent_direction.
117
-
118
- Args:
119
- clone_ascent (bool, optional): Whether to clone ascent direction. Defaults to False.
120
-
121
- Returns:
122
- A copy of this OptimizationState.
123
- """
124
- vars = OptimizationVars(self.closure, self.model)
125
- vars.fx0 = self.fx0
126
- vars.fx0_approx = self.fx0_approx
127
- vars.grad = self.grad
128
-
129
- if clone_ascent and self.ascent is not None: vars.ascent = self.ascent.clone()
130
- else: vars.ascent = self.ascent
131
-
132
- return vars
133
-
134
- def update_attrs_(self, vars: "OptimizationVars"):
135
- """Updates attributes of this state with attributes of another state.
136
-
137
- This updates `grad`, `fx0` and `fx0_approx`."""
138
- if vars.grad is not None: self.grad = vars.grad
139
- if vars.fx0 is not None: self.fx0 = vars.fx0
140
- if vars.fx0_approx is not None: self.fx0_approx = vars.fx0_approx
141
-
142
-
143
- def add_post_step_hook(self, hook: Callable):
144
- """add a hook that runs after each step. The hook should look like this:
145
- .. code:: py
146
- def hook(optimizer: tz.optim.Modular, state: tz.core.OptimizationState): ...
147
- """
148
- self.post_step_hooks.append(hook)
149
-
150
- _Targets = Literal['ascent', 'grad', 'closure',]
151
- class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
152
- r"""Base class for all modules.
153
-
154
- Args:
155
- defaults (dict): dictionary with default parameters for the module.
156
- target (str, optional):
157
- determines how _update method is used in the default step method.
158
-
159
- "ascent" - it updates the ascent
160
-
161
- "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
162
-
163
- "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
164
- """
165
- IS_LR_MODULE = False
166
- def __init__(self, defaults: dict[str, Any], target: Literal['ascent', 'grad', 'closure',] = 'ascent'): # pylint:disable = super-init-not-called
167
- # there can only be 1 LR module, which is placed in the appropriate location among other modules.
168
- # scheduling and per-parameter "lr" options will be routed to that module.
169
- # otherwise, since many update rules like Adam have baked in lr, if multiple such modules are used,
170
- # any lr modification gets applied multiple times.
171
- # Some optimizers will automatically be fused if followed an LR() module (only LR specifically is supported).
172
- if not self.IS_LR_MODULE:
173
- if 'lr' in defaults:
174
- warnings.warn(
175
- f'{self.__class__.__name__} got an "lr" default, but it is not an LR module.\
176
- To support lr scheduling and per-parameter options, rename "lr" to "alpha" and set the default value to 1.\
177
- If this is a learning rate module, set a class attribute `IS_LR_MODULE=True`.'
178
- )
179
-
180
- self._defaults = defaults
181
- self.next_module: OptimizerModule | None = None
182
- """next module that takes this module's state and continues working on it."""
183
- self.children: dict[Any, OptimizerModule] = {}
184
- """children modules."""
185
- self._initialized = False
186
- """True if torch.optim.Optimzer.__init__ was called on this meaning this optimizer has parameters."""
187
- self._default_step_target: Literal['ascent', 'grad', 'closure'] = target
188
- """'ascent', 'grad' or 'closure'"""
189
-
190
- self._has_custom_params = False
191
- """Signifies that :any:`self.set_params` was called on this to set custom params.
192
- When this is True, when parent calls :any:`_update_child_params_` with this module as child,
193
- nothing will happen, as this module already has parameters set."""
194
-
195
- self._passed_params: list[torch.Tensor] | list[dict[str, Any]] | None = None
196
- """list of parameters or parameter groups that were passed to this module and will get passed to child modules."""
197
-
198
- self.post_init_hooks: list[Callable[[Any, Self], Any]] = []
199
- """Hooks that run once after a ModularOptimizer is initialized with this module.
200
-
201
- Signature:
202
-
203
- .. code:: py
204
-
205
- def hook(optimizer: ModularOptimizer, module: OptimizerModule) -> None:
206
- ...
207
-
208
- where `module` is this module.
209
- """
210
-
211
- def __repr__(self):
212
- if self._initialized: return super().__repr__()
213
- return f"uninitialized {self.__class__.__name__}()"
214
-
215
- def state_dict(self):
216
- state_dict = {}
217
- state_dict['__self__'] = super().state_dict()
218
- for k,v in self.children.items():
219
- state_dict[k] = v.state_dict()
220
- return state_dict
221
-
222
- def load_state_dict(self, state_dict: dict[str, Any]) -> None:
223
- super().load_state_dict(state_dict['__self__'])
224
- for k, v in self.children.items():
225
- if k in state_dict:
226
- v.load_state_dict(state_dict[k])
227
- else:
228
- warnings.warn(f"Tried to load state dict for {k}: {v.__class__.__name__}, but it is not present in state_dict with {list(state_dict.keys()) = }")
229
-
230
-
231
- def set_params(self, params: ParamsT):
232
- """
233
- Set parameters to this module. Use this to set per-parameter group settings.
234
- """
235
- self._initialize_(params, set_passed_params = False)
236
- self._has_custom_params = True
237
- return self
238
-
239
- def _initialize_(self, params: ParamsT, set_passed_params: bool):
240
- """Initializes this optimizer and all children with the given parameters."""
241
- if isinstance(params, torch.Tensor): raise ValueError("Params must be an iterable of tensors, not torch.Tensor")
242
- params_list = list(params)
243
- if set_passed_params: self._passed_params = params_list.copy() # type:ignore
244
-
245
- # super().__init__, which is torch.optim.Optimizer.__init__,
246
- # calls self.add_param_group on each param group,
247
- # which in turn calls _update_child_params_,
248
- # which calls add_param_group on each child.
249
- super().__init__(params_list.copy(), self._defaults) # type:ignore
250
- self._initialized = True
251
-
252
- def _set_child_(self, name, child: "_Chainable"):
253
- """Set a child and initialize it's params."""
254
- if not isinstance(child, OptimizerModule): child = _Chain(child)
255
- self.children[name] = child
256
- if self._initialized:
257
- self._update_child_params_(child)
258
-
259
- def _update_child_params_(self, child: "OptimizerModule"):
260
- """Initializes or updates child params with parameters of this module."""
261
- return self._update_next_module_params_(child)
262
-
263
- def _set_next_module(self, next_module: "OptimizerModule"):
264
- """Set next module and initialize it's params."""
265
- self.next_module = next_module
266
- if self._initialized:
267
- self._update_next_module_params_(next_module)
268
-
269
- def _update_next_module_params_(self, next_module: "OptimizerModule"):
270
- """Initializes or updates next module params with parameters of this module."""
271
- # Shouldn't forget that this method is overwritten by some modules
272
- # So if I update it I need to keep that in mind
273
- if self._passed_params is None:
274
- raise RuntimeError(
275
- f"{self.__class__.__name__} is not initialized, but _update_next_module_params_\
276
- was called with next_module = {next_module.__class__.__name__}"
277
- )
278
-
279
- # if child is not initialized, torch.optim.Optimizer.__init__ is called on it by _initialize_ method
280
- if not next_module._initialized:
281
- next_module._initialize_(self._passed_params, set_passed_params=True)
282
-
283
- # otherwise to avoid calling __init__ multiple twice, we erase the param groups and readd them
284
- elif not next_module._has_custom_params:
285
- next_module.param_groups = []
286
- for group in self._passed_params:
287
- if isinstance(group, torch.Tensor): group = {"params": group}
288
- next_module.add_param_group(group)
289
-
290
- else:
291
- # still pass per-parameter settings so that they propagate to further modules
292
- next_module._passed_params = self._passed_params.copy()
293
-
294
-
295
- def add_param_group(self, param_group: dict[str, Any]) -> None:
296
- super().add_param_group(param_group)
297
-
298
- if self.next_module is not None: self._update_next_module_params_(self.next_module)
299
- for c in self.children.values():
300
- self._update_child_params_(c)
301
-
302
- def _update_params_or_step_with_next(self, vars: OptimizationVars, params: TensorList | None = None) -> _ScalarLoss | None:
303
- """If this has no children, update params and return loss. Otherwise step with the next module.
304
-
305
- Optionally pass params to not recreate them if you've already made them.
306
-
307
- Returns:
308
- Loss (fx0 or fx0_approx)
309
- """
310
- # if this has no children, update params and return loss.
311
- if self.next_module is None:
312
- if vars.ascent is None: raise ValueError('Called _update_params_or_step_with_child but ascent_direction is None...')
313
- if params is None: params = self.get_params()
314
- params -= vars.ascent # type:ignore
315
- return vars.get_loss()
316
-
317
- # otherwise pass the updated ascent direction to the child
318
- return self.next_module.step(vars)
319
-
320
- @torch.no_grad
321
- def _step_update_closure(self, vars: OptimizationVars) -> _ScalarLoss | None:
322
- """Create a new closure which applies the `_update` method and passes it to the next module."""
323
- if vars.closure is None: raise ValueError('If target == "closure", closure must be provided')
324
-
325
- params = self.get_params()
326
- closure = vars.closure # closure shouldn't reference state attribute because it can be changed
327
- ascent_direction = vars.ascent
328
-
329
- def update_closure(backward = True):
330
- loss = _maybe_pass_backward(closure, backward)
331
-
332
- # on backward, update the ascent direction
333
- if backward:
334
- grad = self._update(vars, ascent_direction) # type:ignore
335
- # set new ascent direction as gradients
336
- # (accumulation doesn't make sense here as closure always calls zero_grad)
337
- for p, g in zip(params,grad):
338
- p.grad = g
339
-
340
- return loss
341
-
342
- # pass new closure to the child.
343
- # if self.next_module is None:
344
- # raise ValueError(f'{self.__class__.__name__} has no child to step with (maybe set "target" from "closure" to something else??).')
345
-
346
- vars.closure = update_closure
347
- return self._update_params_or_step_with_next(vars)
348
-
349
-
350
- @torch.no_grad
351
- def _step_update_target(self, vars: OptimizationVars) -> _ScalarLoss | None:
352
- """Apply _update method to the ascent direction and pass it to the child, or make a step if child is None."""
353
- # the following code by default uses `_update` method which simple modules can override.
354
- # But you can also just override the entire `step`.
355
-
356
- params = None
357
-
358
- # update ascent direction
359
- if self._default_step_target == 'ascent':
360
- # if this is the first module, it uses the gradients
361
- if vars.grad is None: params = self.get_params()
362
- t = vars.maybe_use_grad_(params)
363
- vars.ascent = self._update(vars, t)
364
-
365
- # update gradients
366
- elif self._default_step_target == 'grad':
367
- if params is None: params = self.get_params()
368
- g = vars.maybe_compute_grad_(params)
369
- g = self._update(vars, g)
370
- vars.set_grad_(g, params)
371
- else:
372
- raise ValueError(f"Invalid {self._default_step_target = }")
373
-
374
- # peform an update with the new state, or pass it to the child.
375
- return self._update_params_or_step_with_next(vars, params=params)
376
-
377
- @torch.no_grad
378
- def step( # type:ignore # pylint:disable=signature-differs # pylint:disable = arguments-renamed
379
- self,
380
- vars: OptimizationVars
381
- ) -> _ScalarLoss | None:
382
- """Perform a single optimization step to update parameter."""
383
-
384
- if self._default_step_target == 'closure': return self._step_update_closure(vars)
385
- return self._step_update_target(vars)
386
-
387
- @torch.no_grad
388
- def _update(self, vars: OptimizationVars, ascent: TensorList) -> TensorList:
389
- """Update `ascent_direction` and return the new ascent direction (but it may update it in place).
390
- Make sure it doesn't return anything from `self.state` to avoid future modules modifying that in-place.
391
-
392
- Before calling `_update`, if ascent direction was not provided to `step`, it will be set to the gradients.
393
-
394
- After generating a new ascent direction with this `_update` method,
395
- if this module has no child, ascent direction will be subtracted from params.
396
- Otherwise everything is passed to the child."""
397
- params = self.get_params()
398
- gradients = ascent.grad
399
- if gradients is None: gradients = [None] * len(params)
400
- settings = tuple(self.get_all_group_keys(list).items())
401
-
402
- new_ascent = TensorList()
403
- for i, (asc, param, grad) in enumerate(zip(ascent, params, gradients)):
404
- kwargs = {"vars": vars, "ascent": asc, "param": param, "grad": grad}
405
- kwargs.update({k:v[i] for k,v in settings})
406
- new_ascent.append(self._single_tensor_update(**kwargs))
407
- return new_ascent
408
-
409
-
410
- def _single_tensor_update(self, vars: OptimizationVars, ascent: torch.Tensor, param: torch.Tensor, grad: torch.Tensor | None, **kwargs) -> torch.Tensor:
411
- """Update function for a single tensor.
412
-
413
- Args:
414
- vars (OptimizationState): holds loss, gradients, etc.
415
- ascent (torch.Tensor): update tensor.
416
- param (torch.Tensor): parameter tensor.
417
- grad (torch.Tensor | None): gradient tensor (may be None)
418
- **kwargs: all per-parameter settings (stuff that you put in `defaults = dict(beta1=beta1, beta2=beta2, eps=eps)`).
419
- """
420
- raise NotImplementedError()
421
-
422
- def return_ascent(self, vars: OptimizationVars, params=None) -> TensorList:
423
- """step with this module and return the ascent as tensorlist"""
424
- if params is None: params = self.get_params()
425
- true_next = self.next_module
426
- self.next_module = _ReturnAscent(params) # type:ignore
427
- ascent: TensorList = self.step(vars) # type:ignore
428
- self.next_module = true_next
429
- return ascent
430
-
431
- def reset_stats(self):
432
- """Resets running stats of this optimizer such as momentum. This is meant to be used stop all
433
- momentum when significantly changing model parameters, for example when setting model parameters
434
- to weighted average every once in a while, like periodic SWA does. Pediodic resetting
435
- may also be beneficial for some optimizers.
436
- By default this method completely clears per-parameter state.
437
- Modules may override this to provide different functionality."""
438
- for g in self.param_groups:
439
- for p in g['params']:
440
- state = self.state[p]
441
- for k in state.copy().keys(): del state[k]
442
-
443
-
444
- class _ReturnAscent:
445
- __slots__ = ('IS_LR_MODULE', 'params', 'children', 'next_module', )
446
- def __init__(self, params):
447
- self.params = params
448
- self.IS_LR_MODULE = False
449
-
450
- self.children = {}
451
- self.next_module = None
452
-
453
- @torch.no_grad
454
- def step(self, vars: OptimizationVars) -> TensorList: # type:ignore
455
- update = vars.maybe_use_grad_(self.params) # this will execute the closure which might be modified
456
- return update
457
-
458
-
459
- class _MaybeReturnAscent(OptimizerModule):
460
- """utility module that either returns ascent or updates the parameters depending on `_return_ascent`, used in Chain."""
461
- def __init__(self):
462
- super().__init__({})
463
- self._return_ascent = False
464
-
465
- @torch.no_grad
466
- def step(self, vars: OptimizationVars):
467
- assert self.next_module is None, self.next_module
468
-
469
- if self._return_ascent:
470
- return vars.ascent
471
-
472
- return self._update_params_or_step_with_next(vars)
473
-
474
- _Chainable = OptimizerModule | Iterable[OptimizerModule]
475
-
476
- class _Chain(OptimizerModule):
477
- """
478
- Utility module that chains multiple modules together, usually used by other modules.
479
- """
480
- def __init__(self, *modules: _Chainable):
481
- super().__init__({})
482
- flat_modules: list[OptimizerModule] = flatten(modules)
483
-
484
- if any(not hasattr(i, "step") for i in flat_modules):
485
- raise TypeError(f"One of the modules is not an OptimizerModule, got {[i.__class__.__name__ for i in flat_modules]}")
486
-
487
- # first module is chain's child, second module is first module's child, etc
488
- self._set_child_('first', flat_modules[0])
489
- if len(flat_modules) > 1:
490
- for i, m in enumerate(flat_modules[:-1]):
491
- m._set_next_module(flat_modules[i+1])
492
-
493
- self._last_module = flat_modules[-1]
494
-
495
- self._chain_modules = flat_modules
496
-
497
- @torch.no_grad
498
- def step(self, vars: OptimizationVars):
499
- # no next module, step with the child
500
- if self.next_module is None:
501
- return self.children['first'].step(vars)
502
-
503
- # return ascent and pass it to next module
504
- # we do this because updating parameters directly is often more efficient
505
- params = self.get_params()
506
- self._last_module.next_module = _ReturnAscent(params) # type:ignore
507
- vars.ascent: TensorList = self.children['first'].step(vars) # type:ignore
508
- self._last_module.next_module = None
509
-
510
- return self._update_params_or_step_with_next(vars)
1
+ import warnings
2
+ from abc import ABC, abstractmethod
3
+ from collections import ChainMap, defaultdict
4
+ from collections.abc import Callable, Iterable, MutableMapping, Sequence
5
+ from operator import itemgetter
6
+ from typing import Any, final, overload
7
+
8
+ import torch
9
+
10
+ from ..utils import (
11
+ Init,
12
+ ListLike,
13
+ Params,
14
+ _make_param_groups,
15
+ get_state_vals,
16
+ )
17
+ from ..utils.python_tools import flatten
18
+
19
+
20
+ def _closure_backward(closure, params, retain_graph, create_graph):
21
+ with torch.enable_grad():
22
+ if not (retain_graph or create_graph):
23
+ return closure()
24
+
25
+ for p in params: p.grad = None
26
+ loss = closure(False)
27
+ grad = torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph)
28
+ for p,g in zip(params,grad): p.grad = g
29
+ return loss
30
+
31
+ # region Vars
32
+ # ----------------------------------- vars ----------------------------------- #
33
+ class Vars:
34
+ """
35
+ Holds the state and context passed between optimizer modules during a step.
36
+
37
+ This class acts as a mutable container for information relevant to the current
38
+ optimization step, such as parameters, gradients, loss, and the computed update.
39
+ Modules read from and write to this object to coordinate their actions.
40
+ """
41
+ def __init__(
42
+ self,
43
+ params: list[torch.Tensor],
44
+ closure: Callable | None,
45
+ model: torch.nn.Module | None,
46
+ current_step: int,
47
+ ):
48
+ self.params: list[torch.Tensor] = params
49
+ """List of all parameters with requires_grad = True."""
50
+
51
+ self.closure = closure
52
+ """A closure that reevaluates the model and returns the loss, None if it wasn't specified"""
53
+
54
+ self.model = model
55
+ """torch.nn.Module object of the model, None if it wasn't specified."""
56
+
57
+ self.current_step: int = current_step
58
+ """global current step, starts at 0"""
59
+
60
+ self.update: list[torch.Tensor] | None = None
61
+ """
62
+ current update, at the end this is subtracted from model parameters unless it is None.
63
+
64
+ If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
65
+ """
66
+
67
+ self.grad: list[torch.Tensor] | None = None
68
+ """gradient with current parameters. If closure is not None, this is set to None and can be calculated if needed."""
69
+
70
+ self.loss: torch.Tensor | Any | None = None
71
+ """loss with current parameters."""
72
+
73
+ self.loss_approx: torch.Tensor | Any | None = None
74
+ """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
75
+ whereas some other modules require loss strictly at current point."""
76
+
77
+ self.post_step_hooks: list[Callable[[Modular, Vars]]] = []
78
+ """list of functions to be called after optimizer step.
79
+ The signature is:
80
+
81
+ .. code:: py
82
+
83
+ def hook(optimizer: Modular, vars: Vars): ...
84
+
85
+ """
86
+
87
+ self.is_last: bool = False
88
+ """
89
+ Indicates that current module is either last or next-to-last before a learning rate module.
90
+ This is always False if current module has children or is a child.
91
+ """
92
+
93
+ self.nested_is_last: bool = False
94
+ """
95
+ Indicates that current module is either last or next-to-last before a learning rate module, for modules
96
+ that have children.
97
+ """
98
+
99
+ self.last_module_lrs: list[float] | None = None
100
+ """
101
+ List of per-parameter learning rates if current module is next-to-last before a
102
+ learning rate module, otherwise this is set to None. Ignore this unless you are manually applying
103
+ update to parameters.
104
+ """
105
+
106
+ self.stop: bool = False
107
+ """if True, all following modules will be skipped."""
108
+
109
+ self.skip_update: bool = False
110
+ """if True, the parameters will not be updated"""
111
+
112
+ def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
113
+ """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning :code:`vars.loss`.
114
+ Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
115
+
116
+ if self.loss is None:
117
+ if self.closure is None: raise RuntimeError("closure is None")
118
+ if backward:
119
+ with torch.enable_grad():
120
+ self.loss = self.loss_approx = _closure_backward(
121
+ closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
122
+ )
123
+
124
+ # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
125
+ # it is technically a more correct approach for when some parameters conditionally receive gradients
126
+ # and in this case it shouldn't be slower.
127
+ self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
128
+ else:
129
+ self.loss = self.loss_approx = self.closure(False)
130
+
131
+ # if self.loss was not None, above branch wasn't executed because loss has already been evaluated, but without backward since self.grad is None.
132
+ # and now it is requested to be evaluated with backward.
133
+ if backward and self.grad is None:
134
+ warnings.warn('get_loss was called with backward=False, and then with backward=True so it had to be re-evaluated, so the closure was evaluated twice where it could have been evaluated once.')
135
+ if self.closure is None: raise RuntimeError("closure is None")
136
+
137
+ with torch.enable_grad():
138
+ self.loss = self.loss_approx = _closure_backward(
139
+ closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
140
+ )
141
+ self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
142
+ return self.loss # type:ignore
143
+
144
+ def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
145
+ """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
146
+ :code:`vars.grad` and potentially :code:`vars.loss`. Do not call this at perturbed parameters."""
147
+ if self.grad is None:
148
+ if self.closure is None: raise RuntimeError("closure is None")
149
+ self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
150
+
151
+ assert self.grad is not None
152
+ return self.grad
153
+
154
+ def get_update(self) -> list[torch.Tensor]:
155
+ """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to :code:`vars.update`.
156
+ Computing the gradients may assign :code:`vars.grad` and :code:`vars.loss` if they haven't been computed.
157
+ Do not call this at perturbed parameters."""
158
+ if self.update is None: self.update = [g.clone() for g in self.get_grad()]
159
+ return self.update
160
+
161
+ def clone(self, clone_update: bool):
162
+ """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via :code:`torch.clone`)."""
163
+ copy = Vars(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step)
164
+
165
+ if clone_update and self.update is not None:
166
+ copy.update = [u.clone() for u in self.update]
167
+ else:
168
+ copy.update = self.update
169
+
170
+ copy.grad = self.grad
171
+ copy.loss = self.loss
172
+ copy.loss_approx = self.loss_approx
173
+ copy.post_step_hooks = self.post_step_hooks
174
+ copy.stop = self.stop
175
+ copy.skip_update = self.skip_update
176
+
177
+ return copy
178
+
179
+ def update_attrs_from_clone_(self, vars: "Vars"):
180
+ """Updates attributes of this `Vars` instance from a cloned instance.
181
+ Typically called after a child module has processed a cloned `Vars`
182
+ object. This propagates any newly computed loss or gradient values
183
+ from the child's context back to the parent `Vars` if the parent
184
+ didn't have them computed already.
185
+ """
186
+ if self.loss is None: self.loss = vars.loss
187
+ if self.loss_approx is None: self.loss_approx = vars.loss_approx
188
+ if self.grad is None: self.grad = vars.grad
189
+
190
+ def zero_grad(self, set_to_none=True):
191
+ if set_to_none:
192
+ for p in self.params: p.grad = None
193
+ else:
194
+ grads = [p.grad for p in self.params if p.grad is not None]
195
+ if len(grads) != 0: torch._foreach_zero_(grads)
196
+
197
+ # endregion
198
+
199
+ # region Module
200
+ # ---------------------------------- module ---------------------------------- #
201
+ class Module(ABC):
202
+ """Abstract base class for an optimizer modules.
203
+
204
+ Modules represent distinct steps or transformations within the optimization
205
+ process (e.g., momentum, line search, gradient accumulation).
206
+
207
+ A module does not store parameters, but it maintains per-parameter state and per-parameter settings
208
+ where tensors are used as keys (same as torch.optim.Optimizer state.)
209
+
210
+ Args:
211
+ defaults (dict[str, Any] | None):
212
+ a dict containing default values of optimization options (used when a parameter group doesn't specify them).
213
+ """
214
+ def __init__(self, defaults: dict[str, Any] | None = None):
215
+ if defaults is None: defaults = {}
216
+ self.defaults: dict[str, Any] = defaults
217
+
218
+ # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
219
+ # 0 - this module specific per-parameter setting overrides set via `set_param_groups` - highest priority
220
+ # 1 - global per-parameter setting overrides in param_groups passed to Modular - medium priority
221
+ # 2 - `defaults` - lowest priority
222
+ self.settings: defaultdict[torch.Tensor, ChainMap[str, Any]] = defaultdict(lambda: ChainMap({}, {}, self.defaults))
223
+ """per-parameter settings."""
224
+
225
+ self.state: defaultdict[torch.Tensor, dict[str, Any]] = defaultdict(dict)
226
+ """Per-parameter state (e.g., momentum buffers)."""
227
+
228
+ self.global_state: dict[str, Any] = {}
229
+ """Global state for things that are not per-parameter."""
230
+
231
+ self.children: dict[str, Module] = {}
232
+ """A dictionary of child modules."""
233
+
234
+ self._overridden_keys = set()
235
+ """tracks keys overridden with `set_param_groups`, only used to not give a warning"""
236
+
237
+
238
+ def set_param_groups(self, param_groups: Params):
239
+ """Set custom parameter groups with per-parameter settings that this module will use."""
240
+ param_groups = _make_param_groups(param_groups, differentiable=False)
241
+ for group in param_groups:
242
+ settings = group.copy()
243
+ params = settings.pop('params')
244
+ if not settings: continue
245
+ self._overridden_keys.update(*settings.keys())
246
+
247
+ for param in params:
248
+ self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
249
+ return self
250
+
251
+ def set_child(self, key: str, module: "Module | Sequence[Module]"):
252
+ self.children[key] = maybe_chain(module)
253
+
254
+ def set_children_sequence(self, modules: "Iterable[Module | Sequence[Module]]", prefix = 'module_'):
255
+ modules = list(modules)
256
+ for i, m in enumerate(modules):
257
+ self.set_child(f'{prefix}{i}', maybe_chain(m))
258
+
259
+ def get_children_sequence(self, prefix = 'module_'):
260
+ return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]
261
+
262
+ def __repr__(self):
263
+ s = self.__class__.__name__
264
+ if self.children:
265
+ s = f'{s}('
266
+ for k,v in self.children.items():
267
+ s = f'{s}{k}={v}, '
268
+ s = f'{s[:-2]})'
269
+ return s
270
+
271
+ @overload
272
+ def get_settings(self, key: str, *,
273
+ params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike: ...
274
+ @overload
275
+ def get_settings(self, key: list[str] | tuple[str,...], *,
276
+ params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...
277
+ @overload
278
+ def get_settings(self, key: str, key2: str, *keys: str,
279
+ params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> list[ListLike]: ...
280
+
281
+ def get_settings(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
282
+ params: Sequence[torch.Tensor], cls: type[ListLike] = list) -> ListLike | list[ListLike]:
283
+ # if isinstance(params, Vars): params = params.params
284
+ return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
285
+
286
+
287
+ @overload
288
+ def get_state(self, key: str, *,
289
+ params: Sequence[torch.Tensor], must_exist: bool = False, init: Init = torch.zeros_like,
290
+ cls: type[ListLike] = list) -> ListLike: ...
291
+ @overload
292
+ def get_state(self, key: list[str] | tuple[str,...], *,
293
+ params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
294
+ cls: type[ListLike] = list) -> list[ListLike]: ...
295
+ @overload
296
+ def get_state(self, key: str, key2: str, *keys: str,
297
+ params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
298
+ cls: type[ListLike] = list) -> list[ListLike]: ...
299
+
300
+ def get_state(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
301
+ params: Sequence[torch.Tensor], must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
302
+ cls: type[ListLike] = list) -> ListLike | list[ListLike]:
303
+ """Returns values of per-parameter state for a given key.
304
+ If key doesn't exist, create it with inits.
305
+
306
+ This functions like `operator.itemgetter`, returning a single value if called with a single key,
307
+ or tuple of called with multiple keys.
308
+
309
+ If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.
310
+
311
+ .. code:: py
312
+
313
+ exp_avg = self.state_vals("exp_avg")
314
+ # returns cls (by default TensorList)
315
+
316
+ exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
317
+ # returns list of cls
318
+
319
+ exp_avg = self.state_vals(["exp_avg"])
320
+ # always returns a list of cls, even if got a single key
321
+
322
+
323
+ Args:
324
+ *keys (str):
325
+ the keys to look for in each parameters state.
326
+ if a single key is specified, this returns a single value or cls,
327
+ otherwise this returns a list of values or cls per each key.
328
+ params (Iterable[torch.Tensor]): parameters to return the states for.
329
+ must_exist (bool, optional):
330
+ If a key doesn't exist in state, if True, raises a KeyError, if False, creates the value
331
+ using `init` argument (default = False).
332
+ init (Init | Sequence[Init], optional):
333
+ how to initialize a key if it doesn't exist.
334
+
335
+ can be
336
+ - Callable like torch.zeros_like
337
+ - string - "param" or "grad" to use cloned params or cloned grads.
338
+ - anything else other than list/tuples will be used as-is, tensors will be cloned.
339
+ - list/tuple of values per each parameter, only if got a single key.
340
+ - list/tuple of values per each key, only if got multiple keys.
341
+
342
+ if multiple `keys` are specified, inits is per-key!
343
+
344
+ Defaults to torch.zeros_like.
345
+ cls (type[ListLike], optional):
346
+ MutableSequence class to return, this only has effect when state_keys is a list/tuple. Defaults to list.
347
+
348
+ Returns:
349
+ - if state_keys has a single key and keys has a single key, return a single value.
350
+ - if state_keys has a single key and keys has multiple keys, return a list of values.
351
+ - if state_keys has multiple keys and keys has a single key, return cls.
352
+ - if state_keys has multiple keys and keys has multiple keys, return list of cls.
353
+ """
354
+ # if isinstance(params, Vars): params = params.params
355
+ return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]
356
+
357
+ # def first_setting(self, *keys:str, params:Sequence[torch.Tensor]):
358
+ # # if isinstance(params, Vars): params = params.params
359
+ # return itemgetter(*keys)(self.settings[params[0]])
360
+
361
+ def state_dict(self):
362
+ """state dict"""
363
+ packed_state = {id(k):v for k,v in self.state.items()}
364
+ packed_settings = {id(k):v for k,v in self.settings.items()}
365
+
366
+ state_dict = {
367
+ "state": packed_state,
368
+ "settings":
369
+ {
370
+ "local": {k:v.maps[0] for k,v in packed_settings.items()},
371
+ "global": {k:v.maps[1] for k,v in packed_settings.items()},
372
+ "defaults": {k:v.maps[2] for k,v in packed_settings.items()},
373
+ },
374
+ "global_state": self.global_state,
375
+ "extra": self._extra_pack(),
376
+ "children": {k: v.state_dict() for k, v in self.children.items()}
377
+ }
378
+ return state_dict
379
+
380
+ def load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
381
+ # load state
382
+ state = state_dict['state']
383
+ self.state.clear()
384
+ self.state.update({id_to_tensor[k]:v for k,v in state.items()})
385
+
386
+ # load settings
387
+ settings = state_dict['settings']
388
+ self.settings.clear()
389
+ for k, v in settings['local'].items(): self.settings[id_to_tensor[k]].maps[0].update(v)
390
+ for k, v in settings['global'].items(): self.settings[id_to_tensor[k]].maps[1].update(v)
391
+ for k, v in settings['defaults'].items(): self.settings[id_to_tensor[k]].maps[2].update(v)
392
+
393
+ # load global state
394
+ self.global_state.clear()
395
+ self.global_state.update(state_dict['global_state'])
396
+
397
+ # children
398
+ for k, v in state_dict['children']:
399
+ if k in self.children: self.children[k].load_state_dict(v, id_to_tensor)
400
+ else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')
401
+
402
+ # extra info
403
+ self._extra_unpack(state_dict['extra'])
404
+
405
+ # ---------------------------- OVERRIDABLE METHODS --------------------------- #
406
+ @abstractmethod
407
+ def step(self, vars: Vars) -> Vars:
408
+ """performs a step, returns new vars but may update them in-place."""
409
+
410
+ def reset(self):
411
+ """Resets the internal state of the module (e.g. momentum)."""
412
+ # no complex logic is allowed there because this is overridden by many modules
413
+ # where super().reset() shouldn't be called
414
+ self.state.clear()
415
+ self.global_state.clear()
416
+
417
+ def _extra_pack(self):
418
+ return {}
419
+
420
+ def _extra_unpack(self, x):
421
+ pass
422
+
423
+ # endregion
424
+
425
+ Chainable = Module | Sequence[Module]
426
+
427
+
428
+ def unroll_modules(*modules: Chainable) -> list[Module]:
429
+ unrolled = []
430
+
431
+ for m in modules:
432
+ if isinstance(m, Module):
433
+ unrolled.append(m)
434
+ unrolled.extend(unroll_modules(list(m.children.values())))
435
+ else:
436
+ unrolled.extend(unroll_modules(*m))
437
+
438
+ return unrolled
439
+
440
+
441
+ # region Modular
442
+ # ---------------------------------- Modular --------------------------------- #
443
+ # have to inherit from Modular to support lr schedulers
444
+ # although Accelerate doesn't work due to converting param_groups to a dict
445
+ class Modular(torch.optim.Optimizer):
446
+ """Chains multiple modules into an optimizer.
447
+
448
+ Args:
449
+ params (Params | torch.nn.Module): An iterable of parameters to optimize
450
+ (typically `model.parameters()`), an iterable of parameter group dicts,
451
+ or a `torch.nn.Module` instance.
452
+ *modules (Module): A sequence of `Module` instances that define the
453
+ optimization algorithm steps.
454
+ """
455
+ # this is specifically for lr schedulers
456
+ param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
457
+
458
+ def __init__(self, params: Params | torch.nn.Module, *modules: Module):
459
+ self.model: torch.nn.Module | None = None
460
+ """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
461
+ if isinstance(params, torch.nn.Module):
462
+ self.model = params
463
+ params = params.parameters()
464
+
465
+ self.modules = modules
466
+ """Top-level modules providedduring initialization."""
467
+
468
+ self.unrolled_modules = unroll_modules(self.modules)
469
+ """A flattened list of all modules including all children."""
470
+
471
+ param_groups = _make_param_groups(params, differentiable=False)
472
+ self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}
473
+
474
+ # make sure there is no more than a single learning rate module
475
+ lr_modules = [m for m in self.unrolled_modules if 'lr' in m.defaults]
476
+ if len(lr_modules) > 1:
477
+ warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to compounding of learning rate multiplication with per-parameter learning rates and schedulers.')
478
+
479
+ # iterate over all per-parameter settings overrides and check if they are applied at most once
480
+ for group in param_groups:
481
+ for k in group:
482
+ if k in ('params', 'lr'): continue
483
+ modules_with_k = [m for m in self.unrolled_modules if k in m.defaults and k not in m._overridden_keys]
484
+ if len(modules_with_k) > 1:
485
+ warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')
486
+
487
+ # defaults for schedulers
488
+ defaults = {}
489
+ for m in self.unrolled_modules: defaults.update(m.defaults)
490
+ super().__init__(param_groups, defaults=defaults)
491
+
492
+ # note - this is what super init does:
493
+
494
+ # self.defaults = defaults
495
+ # for param_group in param_groups:
496
+ # self.add_param_group(param_group)
497
+
498
+ self.current_step = 0
499
+ """The global step counter for the optimizer."""
500
+
501
+ def add_param_group(self, param_group: dict[str, Any]):
502
+ proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
503
+ self.param_groups.append(ChainMap(proc_param_group, self.defaults))
504
+
505
+ for p in proc_param_group['params']:
506
+ # updates global per-parameter setting overrides (medium priority)
507
+ self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.unrolled_modules]
508
+
509
+ def state_dict(self):
510
+ all_params = [p for g in self.param_groups for p in g['params']]
511
+ id_to_idx = {id(p): i for i,p in enumerate(all_params)}
512
+
513
+ groups = []
514
+ for g in self.param_groups:
515
+ g = g.copy()
516
+ g['params'] = [id_to_idx[id(p)] for p in g['params']]
517
+ groups.append(g)
518
+
519
+ state_dict = {
520
+ "idx_to_id": {v:k for k,v in id_to_idx.items()},
521
+ "params": all_params,
522
+ "groups": groups,
523
+ "defaults": self.defaults,
524
+ "modules": {i: m.state_dict() for i, m in enumerate(self.unrolled_modules)}
525
+ }
526
+ return state_dict
527
+
528
+ def load_state_dict(self, state_dict: dict):
529
+ self.defaults.clear()
530
+ self.defaults.update(state_dict['defaults'])
531
+
532
+ idx_to_param = dict(enumerate(state_dict['params']))
533
+ groups = []
534
+ for g in state_dict['groups']:
535
+ g = g.copy()
536
+ g['params'] = [idx_to_param[p] for p in g['params']]
537
+ groups.append(g)
538
+
539
+ self.param_groups.clear()
540
+ for group in groups:
541
+ self.add_param_group(group)
542
+
543
+ id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
544
+ for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
545
+ m.load_state_dict(sd, id_to_tensor)
546
+
547
+
548
+ def step(self, closure=None): # pyright: ignore[reportIncompatibleMethodOverride]
549
+ # propagate global per-parameter setting overrides
550
+ for g in self.param_groups:
551
+ settings = dict(g.maps[0]) # ignore defaults
552
+ params = settings.pop('params')
553
+ if not settings: continue
554
+
555
+ for p in params:
556
+ if not p.requires_grad: continue
557
+ for map in self._per_parameter_global_settings[p]: map.update(settings)
558
+
559
+ # create vars
560
+ params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
561
+ vars = Vars(params=params, closure=closure, model=self.model, current_step=self.current_step)
562
+
563
+ # if closure is None, assume backward has been called and gather grads
564
+ if closure is None:
565
+ vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
566
+
567
+ last_module = self.modules[-1]
568
+ last_lr = last_module.defaults.get('lr', None)
569
+ n_modules = len(self.modules)
570
+
571
+ # step
572
+ for i, module in enumerate(self.modules):
573
+ if i!=0: vars = vars.clone(clone_update=False)
574
+
575
+ # last module, or next to last module before lr
576
+ if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
577
+ if module.children: vars.nested_is_last = True
578
+ else: vars.is_last = True
579
+ if last_lr is not None: vars.last_module_lrs = last_module.get_settings('lr', params=vars.params)
580
+
581
+ vars = module.step(vars)
582
+ if vars.stop: break
583
+
584
+ # apply update
585
+ if not vars.skip_update:
586
+ with torch.no_grad():
587
+ torch._foreach_sub_(params, vars.get_update())
588
+
589
+ for hook in vars.post_step_hooks:
590
+ hook(self, vars)
591
+
592
+ self.current_step += 1
593
+ return vars.loss if vars.loss is not None else vars.loss_approx
594
+
595
+ def __repr__(self):
596
+ return f'Modular({", ".join(str(m) for m in self.modules)})'
597
+ # endregion
598
+
599
+ # region Chain
600
+ # ----------------------------------- Chain ---------------------------------- #
601
+ class Chain(Module):
602
+ """Chain of modules, mostly used internally"""
603
+ def __init__(self, *modules: Module | Iterable[Module]):
604
+ super().__init__()
605
+ flat_modules: list[Module] = flatten(modules)
606
+ for i, module in enumerate(flat_modules):
607
+ self.set_child(f'module_{i}', module)
608
+
609
+ def step(self, vars):
610
+ for i in range(len(self.children)):
611
+ vars = self.children[f'module_{i}'].step(vars)
612
+ if vars.stop: break
613
+ return vars
614
+
615
+ def __repr__(self):
616
+ s = self.__class__.__name__
617
+ if self.children:
618
+ if s == 'Chain': s = 'C' # to shorten it
619
+ s = f'{s}({", ".join(str(m) for m in self.children.values())}'
620
+ return s
621
+
622
+ def maybe_chain(*modules: Chainable) -> Module:
623
+ """Returns a single module directly if only one is provided, otherwise wraps them in a :code:`Chain`."""
624
+ flat_modules: list[Module] = flatten(modules)
625
+ if len(flat_modules) == 1:
626
+ return flat_modules[0]
627
+ return Chain(*flat_modules)
628
+ # endregion
629
+
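
For orientation, the following is a minimal usage sketch of the new `Module`/`Vars`/`Modular` API defined in the `torchzero/core/module.py` version shown above. It is based only on the definitions in this diff; the `Sign` module and the `torchzero.core.module` import path are illustrative assumptions, not documented package API.

```python
# Minimal sketch of the new Module/Vars/Modular API shown in this diff.
# The `Sign` module and the `torchzero.core.module` import path are assumptions
# made for illustration, not taken from the package's documentation.
import torch
from torchzero.core.module import Modular, Module, Vars  # assumed import path


class Sign(Module):
    """Hypothetical module: replaces the current update with its elementwise sign."""
    def step(self, vars: Vars) -> Vars:
        # get_update() clones the gradient on first use (see Vars.get_update above)
        vars.update = [u.sign() for u in vars.get_update()]
        return vars


model = torch.nn.Linear(4, 1)
opt = Modular(model, Sign())  # Modular accepts an nn.Module or an iterable of parameters


def closure(backward: bool = True):
    # closures take a `backward` flag, as expected by Vars and _closure_backward
    for p in model.parameters():
        p.grad = None
    loss = model(torch.randn(8, 4)).pow(2).mean()
    if backward:
        loss.backward()
    return loss


loss = opt.step(closure)
```

Each module transforms `vars.update` and passes it along; `Modular` then subtracts the final update from the parameters, so in practice a learning-rate module would typically be chained after `Sign` rather than applying the raw signed update.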