torchzero 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from . import tensorlist as tl # this needs to be imported first to avoid circular imports
+ from .tensorlist import TensorList
+ from . import optim, modules as m, core, random
+ from .optim import Modular
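
For orientation, the sketch below spells out the public surface these four lines create; the `tz` alias matches the one used in the docstrings of `core/module.py` further down, while the exact classes available under `tz.m` are not shown in this diff.

```python
# A minimal sketch of the import surface implied by torchzero/__init__.py.
import torchzero as tz

tz.TensorList   # re-exported from torchzero.tensorlist (also available as tz.tl.TensorList)
tz.Modular      # the modular optimizer re-exported from torchzero.optim
tz.m            # alias for the torchzero.modules namespace of chainable modules
tz.core         # OptimizerModule, OptimizationState and friends
tz.random       # random sampling helpers
```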
torchzero/core/__init__.py ADDED
@@ -0,0 +1,13 @@
+ import sys
+
+ from .module import (
+ OptimizationState,
+ OptimizerModule,
+ _Chain,
+ _Chainable,
+ _get_loss,
+ _ScalarLoss,
+ _Targets,
+ )
+
+ from .tensorlist_optimizer import TensorListOptimizer, ParamsT, _ClosureType, _maybe_pass_backward
torchzero/core/module.py ADDED
@@ -0,0 +1,471 @@
+ import sys
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Iterable, Sequence
+ from typing import Any, Literal
+ from typing_extensions import Self, TypeAlias
+
+ import torch
+ from torch.optim.optimizer import ParamsT
+
+ from ..tensorlist import TensorList
+ from ..utils.python_tools import _ScalarLoss, flatten
+
+ from .tensorlist_optimizer import (
+ TensorListOptimizer,
+ _ClosureType,
+ _maybe_pass_backward,
+ )
+
+ def _get_loss(fx0, fx0_approx):
+ """Returns fx0 if it is not None, otherwise fx0_approx"""
+ if fx0 is None: return fx0_approx
+ return fx0
+
+
+ class OptimizationState:
+ """Holds optimization state. This is usually automatically created by :any:`torchzero.optim.Modular`."""
+ def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
+
+ self.closure: _ClosureType | None = closure
+ """A closure that reevaluates the model and returns the loss.
+ The closure should accept a `backward` boolean argument that is True by default, which,
+ if True, sets `.grad` attributes of all learnable params, for example via `loss.backward()`.
+ The closure can be None for most first-order optimizers."""
+
+ self.ascent: TensorList | None = None
+ """Ascent direction, for example the gradients.
+ Will be None on the first module in the chain.
+ May remain None for modules that create a new closure."""
+
+ self.fx0: _ScalarLoss | None = None
+ """Loss value evaluated strictly at the initial parameters of the current step.
+ If the initial parameters have not been evaluated, this should be None."""
+
+ self.fx0_approx: _ScalarLoss | None = None
+ """Loss value that may have been sampled near the initial parameters.
+ This is mainly used as the return value of the step method when fx0 is None."""
+
+ self.grad: TensorList | None = None
+ """Gradient if it has been computed, otherwise None.
+
+ The gradient must be evaluated strictly at the initial parameters of the current step."""
+
+ self.model = model
+ """The model itself (torch.nn.Module) if it was passed, otherwise None."""
+
+ self.post_step_hooks = []
+ """Callables that get executed after each step. Used by periodic SWA to reset momentum when setting model parameters to the averaged weights.
+
+ Signature:
+
+ .. code:: py
+
+ def hook(optimizer: ModularOptimizer, state: OptimizationState) -> None:
+ ...
+ """
+
+ def maybe_compute_grad_(self, params: TensorList | None) -> TensorList:
+ """Computes the gradient if it hasn't been computed already, and returns it"""
+ if self.grad is None:
+ if params is None: raise ValueError()
+ if self.closure is not None:
+ with torch.enable_grad(): self.fx0 = self.closure() # pylint:disable = not-callable (???)
+ self.grad = params.ensure_grad_().grad
+
+ return self.grad
+
+ def maybe_use_grad_(self, params: TensorList | None) -> TensorList:
+ """If the ascent direction is None, uses a cloned gradient as the ascent direction and returns it.
+ Otherwise does nothing and returns the existing ascent direction.
+ If the gradient hasn't been computed, this also sets `fx0`."""
+ if self.ascent is None:
+ self.ascent = self.maybe_compute_grad_(params).clone()
+
+ return self.ascent
+
+ def set_grad_(self, grad: TensorList, params: TensorList):
+ """Sets the gradient on this state and on the params"""
+ self.grad = grad
+ params.set_grad_(grad)
+
+ def evaluate_fx0_(self, backward=True) -> _ScalarLoss:
+ """If fx0 is None, or if backward is True and self.grad is None, evaluates the closure and sets them. Returns fx0"""
+ if self.fx0 is not None:
+ if backward and self.grad is None:
+ warnings.warn('evaluating fx0 with backward=True after it has already been evaluated with backward=False. Something may be inefficient.')
+ with torch.enable_grad(): self.closure() # set grad #type:ignore
+ return self.fx0
+
+ if self.closure is None: raise ValueError("Closure is None")
+ loss = self.fx0 = _maybe_pass_backward(self.closure, backward)
+ return loss
+
+ def evaluate_fx0_approx_(self, backward=True) -> _ScalarLoss:
+ """Evaluates the closure, sets self.fx0_approx and returns it"""
+ if self.closure is None: raise ValueError("Closure is None")
+ loss = self.fx0_approx = _maybe_pass_backward(self.closure, backward)
+ return loss
+
+ def get_loss(self):
+ """Returns fx0 if it is not None, otherwise fx0_approx"""
+ if self.fx0 is None: return self.fx0_approx
+ return self.fx0
+
+ def copy(self, clone_ascent = False):
+ """Copy this optimization state. This will not clone anything other than, optionally, the ascent direction.
+
+ Args:
+ clone_ascent (bool, optional): Whether to clone the ascent direction. Defaults to False.
+
+ Returns:
+ A copy of this OptimizationState.
+ """
+ state = OptimizationState(self.closure, self.model)
+ state.fx0 = self.fx0
+ state.fx0_approx = self.fx0_approx
+ state.grad = self.grad
+
+ if clone_ascent and self.ascent is not None: state.ascent = self.ascent.clone()
+ else: state.ascent = self.ascent
+
+ return state
+
+ def update_attrs_(self, state: "OptimizationState"):
+ """Updates attributes of this state with attributes of another state.
+
+ This updates `grad`, `fx0` and `fx0_approx`."""
+ if state.grad is not None: self.grad = state.grad
+ if state.fx0 is not None: self.fx0 = state.fx0
+ if state.fx0_approx is not None: self.fx0_approx = state.fx0_approx
+
+
+ def add_post_step_hook(self, hook: Callable):
+ """Add a hook that runs after each step. The hook should look like this:
+ .. code:: py
+ def hook(optimizer: tz.optim.Modular, state: tz.core.OptimizationState): ...
+ """
+ self.post_step_hooks.append(hook)
+
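
The closure contract documented on `self.closure` above is the backbone of this state object, so here is a minimal sketch of a conforming closure. The model and data are placeholders, and the direct calls at the end only mirror what modules do internally (`maybe_compute_grad_` evaluates `closure()` under `torch.enable_grad()`, while gradient-free modules may pass `backward=False`); how the closure reaches a `torchzero.optim.Modular` step is not shown in this file.

```python
import torch

model = torch.nn.Linear(4, 1)                       # placeholder model
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

def closure(backward: bool = True) -> torch.Tensor:
    """Reevaluates the model and returns the loss, per the contract above."""
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    if backward:
        # when backward=True (the default), set .grad on all learnable params
        model.zero_grad()
        loss.backward()
    return loss

loss = closure()                      # evaluates the loss and populates .grad
loss_only = closure(backward=False)   # reevaluates the loss without touching .grad
```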
+ _Targets = Literal['ascent', 'grad', 'closure',]
+ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
+ r"""Base class for all modules.
+
+ Args:
+ defaults (dict): dictionary with default parameters for the module.
+ target (str, optional):
+ determines how the `_update` method is used in the default `step` method.
+
+ "ascent" - it updates the ascent direction
+
+ "grad" - it updates the gradient (and sets `.grad` attributes to the updated gradient).
+
+ "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
+ """
+ IS_LR_MODULE = False
+ def __init__(self, defaults: dict[str, Any], target: Literal['ascent', 'grad', 'closure',] = 'ascent'): # pylint:disable = super-init-not-called
+ # there can only be 1 LR module, which is placed in the appropriate location among other modules.
+ # scheduling and per-parameter "lr" options will be routed to that module.
+ # otherwise, since many update rules like Adam have a baked-in lr, if multiple such modules are used,
+ # any lr modification gets applied multiple times.
+ # Some optimizers will automatically be fused if followed by an LR() module (only LR specifically is supported).
+ if not self.IS_LR_MODULE:
+ if 'lr' in defaults:
+ warnings.warn(
+ f'{self.__class__.__name__} got an "lr" default, but it is not an LR module.\
+ To support lr scheduling and per-parameter options, rename "lr" to "alpha" and set the default value to 1.\
+ If this is a learning rate module, set a class attribute `IS_LR_MODULE=True`.'
+ )
+
+ self._defaults = defaults
+ self.next_module: OptimizerModule | None = None
+ """The next module, which takes this module's state and continues working on it."""
+ self.children: dict[Any, OptimizerModule] = {}
+ """Child modules."""
+ self._initialized = False
+ """True if torch.optim.Optimizer.__init__ was called on this, meaning this optimizer has parameters."""
+ self._default_step_target: Literal['ascent', 'grad', 'closure'] = target
+ """'ascent', 'grad' or 'closure'"""
+
+ self._has_custom_params = False
+ """Signifies that :any:`self.set_params` was called on this to set custom params.
+ When this is True and the parent calls :any:`_update_child_params_` with this module as the child,
+ nothing will happen, as this module already has parameters set."""
+
+ self._passed_params: list[torch.Tensor] | list[dict[str, Any]] | None = None
+ """List of parameters or parameter groups that were passed to this module and will get passed to child modules."""
+
+ self.post_init_hooks: list[Callable[[Any, Self], Any]] = []
+ """Hooks that run once after a ModularOptimizer is initialized with this module.
+
+ Signature:
+
+ .. code:: py
+
+ def hook(optimizer: ModularOptimizer, module: OptimizerModule) -> None:
+ ...
+
+ where `module` is this module.
+ """
+
+ def __repr__(self):
+ if self._initialized: return super().__repr__()
+ return f"uninitialized {self.__class__.__name__}()"
+
+ def set_params(self, params: ParamsT):
+ """
+ Set parameters for this module. Use this to set per-parameter group settings.
+ """
+ self._initialize_(params, set_passed_params = False)
+ self._has_custom_params = True
+ return self
+
+ def _initialize_(self, params: ParamsT, set_passed_params: bool):
+ """Initializes this optimizer and all children with the given parameters."""
+ if isinstance(params, torch.Tensor): raise ValueError("Params must be an iterable of tensors, not torch.Tensor")
+ params_list = list(params)
+ if set_passed_params: self._passed_params = params_list.copy() # type:ignore
+
+ # super().__init__, which is torch.optim.Optimizer.__init__,
+ # calls self.add_param_group on each param group,
+ # which in turn calls _update_child_params_,
+ # which calls add_param_group on each child.
+ super().__init__(params_list.copy(), self._defaults) # type:ignore
+ self._initialized = True
+
+ def _set_child_(self, name, child: "_Chainable"):
+ """Set a child and initialize its params."""
+ if not isinstance(child, OptimizerModule): child = _Chain(child)
+ self.children[name] = child
+ if self._initialized:
+ self._update_child_params_(child)
+
+ def _update_child_params_(self, child: "OptimizerModule"):
+ """Initializes or updates the child's params with the parameters of this module."""
+ return self._update_next_module_params_(child)
+
+ def _set_next_module(self, next_module: "OptimizerModule"):
+ """Set the next module and initialize its params."""
+ self.next_module = next_module
+ if self._initialized:
+ self._update_next_module_params_(next_module)
+
+ def _update_next_module_params_(self, next_module: "OptimizerModule"):
+ """Initializes or updates the next module's params with the parameters of this module."""
+ # Don't forget that this method is overridden by some modules,
+ # so keep that in mind when changing it.
+ if self._passed_params is None:
+ raise RuntimeError(
+ f"{self.__class__.__name__} is not initialized, but _update_next_module_params_\
+ was called with next_module = {next_module.__class__.__name__}"
+ )
+
+ # if the next module is not initialized, torch.optim.Optimizer.__init__ is called on it by the _initialize_ method
+ if not next_module._initialized:
+ next_module._initialize_(self._passed_params, set_passed_params=True)
+
+ # otherwise, to avoid calling __init__ more than once, we erase the param groups and re-add them
+ elif not next_module._has_custom_params:
+ next_module.param_groups = []
+ for group in self._passed_params:
+ if isinstance(group, torch.Tensor): group = {"params": group}
+ next_module.add_param_group(group)
+
+ else:
+ # still pass per-parameter settings so that they propagate to further modules
+ next_module._passed_params = self._passed_params.copy()
+
+
+ def add_param_group(self, param_group: dict[str, Any]) -> None:
+ super().add_param_group(param_group)
+
+ if self.next_module is not None: self._update_next_module_params_(self.next_module)
+ for c in self.children.values():
+ self._update_child_params_(c)
+
+ def _update_params_or_step_with_next(self, state: OptimizationState, params: TensorList | None = None) -> _ScalarLoss | None:
+ """If this has no next module, update params and return the loss. Otherwise step with the next module.
+
+ Optionally pass params to avoid recreating them if you've already made them.
+
+ Returns:
+ Loss (fx0 or fx0_approx)
+ """
+ # if this has no next module, update params and return the loss.
+ if self.next_module is None:
+ if state.ascent is None: raise ValueError('Called _update_params_or_step_with_next but state.ascent is None...')
+ if params is None: params = self.get_params()
+ params -= state.ascent # type:ignore
+ return state.get_loss()
+
+ # otherwise pass the updated ascent direction to the next module
+ return self.next_module.step(state)
+
+ @torch.no_grad
+ def _step_update_closure(self, state: OptimizationState) -> _ScalarLoss | None:
+ """Create a new closure which applies the `_update` method and passes it to the next module."""
+ if state.closure is None: raise ValueError('If target == "closure", closure must be provided')
+
+ params = self.get_params()
+ closure = state.closure # the closure shouldn't reference the state attribute because it can be changed
+ ascent_direction = state.ascent
+
+ def update_closure(backward = True):
+ loss = _maybe_pass_backward(closure, backward)
+
+ # on backward, update the ascent direction
+ if backward:
+ grad = self._update(state, ascent_direction) # type:ignore
+ # set new ascent direction as gradients
+ # (accumulation doesn't make sense here as the closure always calls zero_grad)
+ for p, g in zip(params, grad):
+ p.grad = g
+
+ return loss
+
+ # pass the new closure to the next module.
+ # if self.next_module is None:
+ # raise ValueError(f'{self.__class__.__name__} has no child to step with (maybe set "target" from "closure" to something else??).')
+
+ state.closure = update_closure
+ return self._update_params_or_step_with_next(state)
+
+
+ @torch.no_grad
+ def _step_update_target(self, state: OptimizationState) -> _ScalarLoss | None:
+ """Apply the _update method to the ascent direction and pass it to the next module, or make a step if there is no next module."""
+ # the following code by default uses the `_update` method, which simple modules can override.
+ # But you can also just override the entire `step`.
+
+ params = None
+
+ # update ascent direction
+ if self._default_step_target == 'ascent':
+ # if this is the first module, it uses the gradients
+ if state.grad is None: params = self.get_params()
+ t = state.maybe_use_grad_(params)
+ state.ascent = self._update(state, t)
+
+ # update gradients
+ elif self._default_step_target == 'grad':
+ if params is None: params = self.get_params()
+ g = state.maybe_compute_grad_(params)
+ g = self._update(state, g)
+ state.set_grad_(g, params)
+ else:
+ raise ValueError(f"Invalid {self._default_step_target = }")
+
+ # perform an update with the new state, or pass it to the next module.
+ return self._update_params_or_step_with_next(state, params=params)
+
+ @torch.no_grad
+ def step( # type:ignore # pylint:disable=signature-differs # pylint:disable = arguments-renamed
+ self,
+ state: OptimizationState
+ ) -> _ScalarLoss | None:
+ """Perform a single optimization step to update the parameters."""
+
+ if self._default_step_target == 'closure': return self._step_update_closure(state)
+ return self._step_update_target(state)
+
+ @torch.no_grad
+ def _update(self, state: OptimizationState, ascent: TensorList) -> TensorList:
+ """Update `ascent` and return the new ascent direction (it may be updated in place).
+ Make sure it doesn't return anything from `state` to avoid future modules modifying that in-place.
+
+ Before calling `_update`, if the ascent direction was not provided to `step`, it will be set to the gradients.
+
+ After generating a new ascent direction with this `_update` method,
+ if this module has no next module, the ascent direction will be subtracted from the params.
+ Otherwise everything is passed to the next module."""
+ raise NotImplementedError()
+
+ def return_ascent(self, state: OptimizationState, params=None) -> TensorList:
+ """Step with this module and return the ascent as a TensorList"""
+ if params is None: params = self.get_params()
+ true_next = self.next_module
+ self.next_module = _ReturnAscent(params) # type:ignore
+ ascent: TensorList = self.step(state) # type:ignore
+ self.next_module = true_next
+ return ascent
+
+ def reset_stats(self):
+ """Resets running stats of this optimizer such as momentum. This is meant to be used to stop all
+ momentum when significantly changing model parameters, for example when setting model parameters
+ to a weighted average every once in a while, like periodic SWA does. Periodic resetting
+ may also be beneficial for some optimizers.
+ By default this method completely clears per-parameter state.
+ Modules may override this to provide different functionality."""
+ for g in self.param_groups:
+ for p in g['params']:
+ state = self.state[p]
+ for k in state.copy().keys(): del state[k]
+
+
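
For reference, a third-party module built on the default `step` path above would look roughly like the sketch below: with the default `target='ascent'`, `_update` receives the current ascent direction (the cloned gradient when the module is first in the chain) and returns the transformed direction, which is then passed on or subtracted from the parameters. `SignAscent` is an illustrative name, not a class shipped in this wheel.

```python
import torch
from torchzero.core import OptimizationState, OptimizerModule
from torchzero.tensorlist import TensorList

class SignAscent(OptimizerModule):
    """Hypothetical module that replaces the ascent direction with its elementwise sign."""

    def __init__(self):
        # per the warning in __init__ above, a non-LR module should not declare
        # an "lr" default; the step size belongs to a dedicated LR module placed after it
        super().__init__({}, target='ascent')

    @torch.no_grad
    def _update(self, state: OptimizationState, ascent: TensorList) -> TensorList:
        # in-place modification is fine; just return the resulting TensorList
        for t in ascent:
            t.sign_()
        return ascent
```

With `target='grad'` the same `_update` would instead rewrite the gradient (and the params' `.grad` attributes), and with `target='closure'` it would be applied lazily inside the wrapped closure created by `_step_update_closure`.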
+ class _ReturnAscent:
+ __slots__ = ('IS_LR_MODULE', 'params', 'children', 'next_module', )
+ def __init__(self, params):
+ self.params = params
+ self.IS_LR_MODULE = False
+
+ self.children = {}
+ self.next_module = None
+
+ @torch.no_grad
+ def step(self, state: OptimizationState) -> TensorList: # type:ignore
+ update = state.maybe_use_grad_(self.params) # this will execute the closure, which might have been modified
+ return update
+
+
+ class _MaybeReturnAscent(OptimizerModule):
+ """Utility module that either returns the ascent or updates the parameters depending on `_return_ascent`; used in _Chain."""
+ def __init__(self):
+ super().__init__({})
+ self._return_ascent = False
+
+ @torch.no_grad
+ def step(self, state: OptimizationState):
+ assert self.next_module is None, self.next_module
+
+ if self._return_ascent:
+ return state.ascent
+
+ return self._update_params_or_step_with_next(state)
+
+ _Chainable = OptimizerModule | Iterable[OptimizerModule]
+
+ class _Chain(OptimizerModule):
+ """
+ Utility module that chains multiple modules together, usually used by other modules.
+ """
+ def __init__(self, *modules: _Chainable):
+ super().__init__({})
+ flat_modules: list[OptimizerModule] = flatten(modules)
+
+ if any(not hasattr(i, "step") for i in flat_modules):
+ raise TypeError(f"One of the modules is not an OptimizerModule, got {[i.__class__.__name__ for i in flat_modules]}")
+
+ # the first module is the chain's child, the second module is the first module's next module, etc.
+ self._set_child_('first', flat_modules[0])
+ if len(flat_modules) > 1:
+ for i, m in enumerate(flat_modules[:-1]):
+ m._set_next_module(flat_modules[i+1])
+
+ self._last_module = flat_modules[-1]
+
+ self._chain_modules = flat_modules
+
+ @torch.no_grad
+ def step(self, state: OptimizationState):
+ # no next module, step with the child
+ if self.next_module is None:
+ return self.children['first'].step(state)
+
+ # return ascent and pass it to the next module
+ # we do this because updating parameters directly is often more efficient
+ params = self.get_params()
+ self._last_module.next_module = _ReturnAscent(params) # type:ignore
+ state.ascent: TensorList = self.children['first'].step(state) # type:ignore
+ self._last_module.next_module = None
+
+ return self._update_params_or_step_with_next(state)
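
Putting the pieces of this file together, the step protocol is: modules are linked through `_set_next_module`, the first module turns the gradient into an ascent direction, each following module transforms it, and the module with no `next_module` subtracts it from the parameters. The sketch below wires two toy modules by hand using only what is defined above; `Halve` is a hypothetical module, and in normal use `torchzero.optim.Modular` (or `_Chain`) performs this wiring instead of calling the private `_set_next_module` / `_initialize_` methods directly.

```python
import torch
from torchzero.core import OptimizationState, OptimizerModule

class Halve(OptimizerModule):
    """Hypothetical module: multiplies the ascent direction by 0.5."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, state, ascent):
        for t in ascent:       # a TensorList iterates over its tensors
            t.mul_(0.5)
        return ascent

model = torch.nn.Linear(2, 1)

first, second = Halve(), Halve()
first._set_next_module(second)  # link before initializing, the way _Chain does
first._initialize_(model.parameters(), set_passed_params=True)  # also propagates params to `second`

def closure(backward: bool = True) -> torch.Tensor:
    loss = model(torch.ones(1, 2)).pow(2).sum()
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

state = OptimizationState(closure, model)
loss = first.step(state)  # gradient -> halved twice -> subtracted from the params by `second`
```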