torchzero-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/__init__.py
ADDED
torchzero/core/module.py
ADDED
@@ -0,0 +1,471 @@
+import sys
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, Literal
+from typing_extensions import Self, TypeAlias
+
+import torch
+from torch.optim.optimizer import ParamsT
+
+from ..tensorlist import TensorList
+from ..utils.python_tools import _ScalarLoss, flatten
+
+from .tensorlist_optimizer import (
+    TensorListOptimizer,
+    _ClosureType,
+    _maybe_pass_backward,
+)
+
+def _get_loss(fx0, fx0_approx):
+    """Returns fx0 if it is not None otherwise fx0_approx"""
+    if fx0 is None: return fx0_approx
+    return fx0
+
+
+class OptimizationState:
+    """Holds optimization state. This is usually automatically created by :any:`torchzero.optim.Modular`."""
+    def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
+
+        self.closure: _ClosureType | None = closure
+        """A closure that reevaluates the model and returns the loss.
+        The closure should accept `backward` boolean argument that is True by default, which,
+        if True, sets `.grad` attributes of all learnable params, for example via `loss.backward()`.
+        Closure can be None for most first order optimizers."""
+
+        self.ascent: TensorList | None = None
+        """Ascent direction, for example the gradients.
+        Will be None on the first module in the chain.
+        May remain none for modules that create a new closure."""
+
+        self.fx0: _ScalarLoss | None = None
+        """Loss value strictly with initial parameters of the current step.
+        If initial parameters have not been evaluated, this should be None."""
+
+        self.fx0_approx: _ScalarLoss | None = None
+        """Loss value, could be sampled nearby the initial parameters.
+        This is mainly used as the return value of the step method when fx0 is None."""
+
+        self.grad: TensorList | None = None
+        """Gradient if it has been computed, otherwise None.
+
+        Gradient must be evaluated strictly with initial parameters of the current step"""
+
+        self.model = model
+        """model itself (torch.nn.Module) if it was passed, otherwise None."""
+
+        self.post_step_hooks = []
+        """callables that get executed after each step. Used by periodic SWA to reset momentum when setting model parameters to SWA.
+
+        Signature:
+
+        .. code:: py
+
+            def hook(optimizer: ModularOptimizer, state: OptimizationState) -> None:
+                ...
+        """
+
+    def maybe_compute_grad_(self, params: TensorList | None) -> TensorList:
+        """Computes gradient if it hasn't been computed already, and returns it"""
+        if self.grad is None:
+            if params is None: raise ValueError()
+            if self.closure is not None:
+                with torch.enable_grad(): self.fx0 = self.closure() # pylint:disable = not-callable (???)
+            self.grad = params.ensure_grad_().grad
+
+        return self.grad
+
+    def maybe_use_grad_(self, params: TensorList | None) -> TensorList:
+        """If ascent direction is None, use cloned gradient as ascent direction and returns it.
+        Otherwise does nothing and returns existing ascent direction.
+        If gradient hasn't been computed, this also sets `fx0`."""
+        if self.ascent is None:
+            self.ascent = self.maybe_compute_grad_(params).clone()
+
+        return self.ascent
+
+    def set_grad_(self, grad: TensorList, params: TensorList):
+        """Sets gradient to this state and to params"""
+        self.grad = grad
+        params.set_grad_(grad)
+
+    def evaluate_fx0_(self, backward=True) -> _ScalarLoss:
+        """if fx0 is None or if backward is True and self.grad is None, evaluates closure and sets them. Returns fx0"""
+        if self.fx0 is not None:
+            if backward and self.grad is None:
+                warnings.warn('evaluating fx0 with backward=True after it has already been evaluated with backward=False. Something may be inefficient.')
+                with torch.enable_grad(): self.closure() # set grad #type:ignore
+            return self.fx0
+
+        if self.closure is None: raise ValueError("Closure is None")
+        loss = self.fx0 = _maybe_pass_backward(self.closure, backward)
+        return loss
+
+    def evaluate_fx0_approx_(self, backward=True) -> _ScalarLoss:
+        """evaluates closure, sets self.fx0_approx and returns it"""
+        if self.closure is None: raise ValueError("Closure is None")
+        loss = self.fx0_approx = _maybe_pass_backward(self.closure, backward)
+        return loss
+
+    def get_loss(self):
+        """Returns fx0 if it is not None otherwise fx0_approx"""
+        if self.fx0 is None: return self.fx0_approx
+        return self.fx0
+
+    def copy(self, clone_ascent = False):
+        """Copy this optimization state. This will not clone anything other than optionally ascent_direction.
+
+        Args:
+            clone_ascent (bool, optional): Whether to clone ascent direction. Defaults to False.
+
+        Returns:
+            A copy of this OptimizationState.
+        """
+        state = OptimizationState(self.closure, self.model)
+        state.fx0 = self.fx0
+        state.fx0_approx = self.fx0_approx
+        state.grad = self.grad
+
+        if clone_ascent and self.ascent is not None: state.ascent = self.ascent.clone()
+        else: state.ascent = self.ascent
+
+        return state
+
+    def update_attrs_(self, state: "OptimizationState"):
+        """Updates attributes of this state with attributes of another state.
+
+        This updates `grad`, `fx0` and `fx0_approx`."""
+        if state.grad is not None: self.grad = state.grad
+        if state.fx0 is not None: self.fx0 = state.fx0
+        if state.fx0_approx is not None: self.fx0_approx = state.fx0_approx
+
+
+    def add_post_step_hook(self, hook: Callable):
+        """add a hook that runs after each step. The hook should look like this:
+        .. code:: py
+            def hook(optimizer: tz.optim.Modular, state: tz.core.OptimizationState): ...
+        """
+        self.post_step_hooks.append(hook)
+
+_Targets = Literal['ascent', 'grad', 'closure',]
+class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
+    r"""Base class for all modules.
+
+    Args:
+        defaults (dict): dictionary with default parameters for the module.
+        target (str, optional):
+            determines how _update method is used in the default step method.
+
+            "ascent" - it updates the ascent
+
+            "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
+
+            "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
+    """
+    IS_LR_MODULE = False
+    def __init__(self, defaults: dict[str, Any], target: Literal['ascent', 'grad', 'closure',] = 'ascent'): # pylint:disable = super-init-not-called
+        # there can only be 1 LR module, which is placed in the appropriate location among other modules.
+        # scheduling and per-parameter "lr" options will be routed to that module.
+        # otherwise, since many update rules like Adam have baked in lr, if multiple such modules are used,
+        # any lr modification gets applied multiple times.
+        # Some optimizers will automatically be fused if followed an LR() module (only LR specifically is supported).
+        if not self.IS_LR_MODULE:
+            if 'lr' in defaults:
+                warnings.warn(
+                    f'{self.__class__.__name__} got an "lr" default, but it is not an LR module.\
+                    To support lr scheduling and per-parameter options, rename "lr" to "alpha" and set the default value to 1.\
+                    If this is a learning rate module, set a class attribute `IS_LR_MODULE=True`.'
+                )
+
+        self._defaults = defaults
+        self.next_module: OptimizerModule | None = None
+        """next module that takes this module's state and continues working on it."""
+        self.children: dict[Any, OptimizerModule] = {}
+        """children modules."""
+        self._initialized = False
+        """True if torch.optim.Optimzer.__init__ was called on this meaning this optimizer has parameters."""
+        self._default_step_target: Literal['ascent', 'grad', 'closure'] = target
+        """'ascent', 'grad' or 'closure'"""
+
+        self._has_custom_params = False
+        """Signifies that :any:`self.set_params` was called on this to set custom params.
+        When this is True, when parent calls :any:`_update_child_params_` with this module as child,
+        nothing will happen, as this module already has parameters set."""
+
+        self._passed_params: list[torch.Tensor] | list[dict[str, Any]] | None = None
+        """list of parameters or parameter groups that were passed to this module and will get passed to child modules."""
+
+        self.post_init_hooks: list[Callable[[Any, Self], Any]] = []
+        """Hooks that run once after a ModularOptimizer is initialized with this module.
+
+        Signature:
+
+        .. code:: py
+
+            def hook(optimizer: ModularOptimizer, module: OptimizerModule) -> None:
+                ...
+
+        where `module` is this module.
+        """
+
+    def __repr__(self):
+        if self._initialized: return super().__repr__()
+        return f"uninitialized {self.__class__.__name__}()"
+
+    def set_params(self, params: ParamsT):
+        """
+        Set parameters to this module. Use this to set per-parameter group settings.
+        """
+        self._initialize_(params, set_passed_params = False)
+        self._has_custom_params = True
+        return self
+
+    def _initialize_(self, params: ParamsT, set_passed_params: bool):
+        """Initializes this optimizer and all children with the given parameters."""
+        if isinstance(params, torch.Tensor): raise ValueError("Params must be an iterable of tensors, not torch.Tensor")
+        params_list = list(params)
+        if set_passed_params: self._passed_params = params_list.copy() # type:ignore
+
+        # super().__init__, which is torch.optim.Optimizer.__init__,
+        # calls self.add_param_group on each param group,
+        # which in turn calls _update_child_params_,
+        # which calls add_param_group on each child.
+        super().__init__(params_list.copy(), self._defaults) # type:ignore
+        self._initialized = True
+
+    def _set_child_(self, name, child: "_Chainable"):
+        """Set a child and initialize it's params."""
+        if not isinstance(child, OptimizerModule): child = _Chain(child)
+        self.children[name] = child
+        if self._initialized:
+            self._update_child_params_(child)
+
+    def _update_child_params_(self, child: "OptimizerModule"):
+        """Initializes or updates child params with parameters of this module."""
+        return self._update_next_module_params_(child)
+
+    def _set_next_module(self, next_module: "OptimizerModule"):
+        """Set next module and initialize it's params."""
+        self.next_module = next_module
+        if self._initialized:
+            self._update_next_module_params_(next_module)
+
+    def _update_next_module_params_(self, next_module: "OptimizerModule"):
+        """Initializes or updates next module params with parameters of this module."""
+        # Shouldn't forget that this method is overwritten by some modules
+        # So if I update it I need to keep that in mind
+        if self._passed_params is None:
+            raise RuntimeError(
+                f"{self.__class__.__name__} is not initialized, but _update_next_module_params_\
+                was called with next_module = {next_module.__class__.__name__}"
+            )
+
+        # if child is not initialized, torch.optim.Optimizer.__init__ is called on it by _initialize_ method
+        if not next_module._initialized:
+            next_module._initialize_(self._passed_params, set_passed_params=True)
+
+        # otherwise to avoid calling __init__ multiple twice, we erase the param groups and readd them
+        elif not next_module._has_custom_params:
+            next_module.param_groups = []
+            for group in self._passed_params:
+                if isinstance(group, torch.Tensor): group = {"params": group}
+                next_module.add_param_group(group)
+
+        else:
+            # still pass per-parameter settings so that they propagate to further modules
+            next_module._passed_params = self._passed_params.copy()
+
+
+    def add_param_group(self, param_group: dict[str, Any]) -> None:
+        super().add_param_group(param_group)
+
+        if self.next_module is not None: self._update_next_module_params_(self.next_module)
+        for c in self.children.values():
+            self._update_child_params_(c)
+
+    def _update_params_or_step_with_next(self, state: OptimizationState, params: TensorList | None = None) -> _ScalarLoss | None:
+        """If this has no children, update params and return loss. Otherwise step with the next module.
+
+        Optionally pass params to not recreate them if you've already made them.
+
+        Returns:
+            Loss (fx0 or fx0_approx)
+        """
+        # if this has no children, update params and return loss.
+        if self.next_module is None:
+            if state.ascent is None: raise ValueError('Called _update_params_or_step_with_child but ascent_direction is None...')
+            if params is None: params = self.get_params()
+            params -= state.ascent # type:ignore
+            return state.get_loss()
+
+        # otherwise pass the updated ascent direction to the child
+        return self.next_module.step(state)
+
+    @torch.no_grad
+    def _step_update_closure(self, state: OptimizationState) -> _ScalarLoss | None:
+        """Create a new closure which applies the `_update` method and passes it to the next module."""
+        if state.closure is None: raise ValueError('If target == "closure", closure must be provided')
+
+        params = self.get_params()
+        closure = state.closure # closure shouldn't reference state attribute because it can be changed
+        ascent_direction = state.ascent
+
+        def update_closure(backward = True):
+            loss = _maybe_pass_backward(closure, backward)
+
+            # on backward, update the ascent direction
+            if backward:
+                grad = self._update(state, ascent_direction) # type:ignore
+                # set new ascent direction as gradients
+                # (accumulation doesn't make sense here as closure always calls zero_grad)
+                for p, g in zip(params,grad):
+                    p.grad = g
+
+            return loss
+
+        # pass new closure to the child.
+        # if self.next_module is None:
+        #     raise ValueError(f'{self.__class__.__name__} has no child to step with (maybe set "target" from "closure" to something else??).')
+
+        state.closure = update_closure
+        return self._update_params_or_step_with_next(state)
+
+
+    @torch.no_grad
+    def _step_update_target(self, state: OptimizationState) -> _ScalarLoss | None:
+        """Apply _update method to the ascent direction and pass it to the child, or make a step if child is None."""
+        # the following code by default uses `_update` method which simple modules can override.
+        # But you can also just override the entire `step`.
+
+        params = None
+
+        # update ascent direction
+        if self._default_step_target == 'ascent':
+            # if this is the first module, it uses the gradients
+            if state.grad is None: params = self.get_params()
+            t = state.maybe_use_grad_(params)
+            state.ascent = self._update(state, t)
+
+        # update gradients
+        elif self._default_step_target == 'grad':
+            if params is None: params = self.get_params()
+            g = state.maybe_compute_grad_(params)
+            g = self._update(state, g)
+            state.set_grad_(g, params)
+        else:
+            raise ValueError(f"Invalid {self._default_step_target = }")
+
+        # peform an update with the new state, or pass it to the child.
+        return self._update_params_or_step_with_next(state, params=params)
+
+    @torch.no_grad
+    def step( # type:ignore # pylint:disable=signature-differs # pylint:disable = arguments-renamed
+        self,
+        state: OptimizationState
+    ) -> _ScalarLoss | None:
+        """Perform a single optimization step to update parameter."""
+
+        if self._default_step_target == 'closure': return self._step_update_closure(state)
+        return self._step_update_target(state)
+
+    @torch.no_grad
+    def _update(self, state: OptimizationState, ascent: TensorList) -> TensorList:
+        """Update `ascent_direction` and return the new ascent direction (but it may update it in place).
+        Make sure it doesn't return anything from `state` to avoid future modules modifying that in-place.
+
+        Before calling `_update`, if ascent direction was not provided to `step`, it will be set to the gradients.
+
+        After generating a new ascent direction with this `_update` method,
+        if this module has no child, ascent direction will be subtracted from params.
+        Otherwise everything is passed to the child."""
+        raise NotImplementedError()
+
+    def return_ascent(self, state: OptimizationState, params=None) -> TensorList:
+        """step with this module and return the ascent as tensorlist"""
+        if params is None: params = self.get_params()
+        true_next = self.next_module
+        self.next_module = _ReturnAscent(params) # type:ignore
+        ascent: TensorList = self.step(state) # type:ignore
+        self.next_module = true_next
+        return ascent
+
+    def reset_stats(self):
+        """Resets running stats of this optimizer such as momentum. This is meant to be used stop all
+        momentum when significantly changing model parameters, for example when setting model parameters
+        to weighted average every once in a while, like periodic SWA does. Pediodic resetting
+        may also be beneficial for some optimizers.
+        By default this method completely clears per-parameter state.
+        Modules may override this to provide different functionality."""
+        for g in self.param_groups:
+            for p in g['params']:
+                state = self.state[p]
+                for k in state.copy().keys(): del state[k]
+
+
+class _ReturnAscent:
+    __slots__ = ('IS_LR_MODULE', 'params', 'children', 'next_module', )
+    def __init__(self, params):
+        self.params = params
+        self.IS_LR_MODULE = False
+
+        self.children = {}
+        self.next_module = None
+
+    @torch.no_grad
+    def step(self, state: OptimizationState) -> TensorList: # type:ignore
+        update = state.maybe_use_grad_(self.params) # this will execute the closure which might be modified
+        return update
+
+
+class _MaybeReturnAscent(OptimizerModule):
+    """utility module that either returns ascent or updates the parameters depending on `_return_ascent`, used in Chain."""
+    def __init__(self):
+        super().__init__({})
+        self._return_ascent = False
+
+    @torch.no_grad
+    def step(self, state: OptimizationState):
+        assert self.next_module is None, self.next_module
+
+        if self._return_ascent:
+            return state.ascent
+
+        return self._update_params_or_step_with_next(state)
+
+_Chainable = OptimizerModule | Iterable[OptimizerModule]
+
+class _Chain(OptimizerModule):
+    """
+    Utility module that chains multiple modules together, usually used by other modules.
+    """
+    def __init__(self, *modules: _Chainable):
+        super().__init__({})
+        flat_modules: list[OptimizerModule] = flatten(modules)
+
+        if any(not hasattr(i, "step") for i in flat_modules):
+            raise TypeError(f"One of the modules is not an OptimizerModule, got {[i.__class__.__name__ for i in flat_modules]}")
+
+        # first module is chain's child, second module is first module's child, etc
+        self._set_child_('first', flat_modules[0])
+        if len(flat_modules) > 1:
+            for i, m in enumerate(flat_modules[:-1]):
+                m._set_next_module(flat_modules[i+1])
+
+        self._last_module = flat_modules[-1]
+
+        self._chain_modules = flat_modules
+
+    @torch.no_grad
+    def step(self, state: OptimizationState):
+        # no next module, step with the child
+        if self.next_module is None:
+            return self.children['first'].step(state)
+
+        # return ascent and pass it to next module
+        # we do this because updating parameters directly is often more efficient
+        params = self.get_params()
+        self._last_module.next_module = _ReturnAscent(params) # type:ignore
+        state.ascent: TensorList = self.children['first'].step(state) # type:ignore
+        self._last_module.next_module = None
+
+        return self._update_params_or_step_with_next(state)
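The chaining machinery in this file (OptimizationState, OptimizerModule, _Chain) is what torchzero.optim.Modular (torchzero/optim/modular.py) drives on each step. As a rough usage sketch only, the snippet below defines a hypothetical module that overrides `_update` to scale the ascent direction and wires it through `tz.optim.Modular`; the Modular constructor signature, the `tz.optim` import path, the in-place `mul_` on the ascent tensors, and the `scale` parameter name are illustrative assumptions, not taken from this diff.

import torch
import torchzero as tz
from torchzero.core.module import OptimizerModule, OptimizationState
from torchzero.tensorlist import TensorList

class ScaleAscent(OptimizerModule):
    """Hypothetical module: multiplies the ascent direction (initially the gradient) by `scale`."""
    def __init__(self, scale: float = 0.5):
        # the default is named "scale" rather than "lr": non-LR modules warn if given an "lr" default
        super().__init__({"scale": scale}, target='ascent')

    @torch.no_grad
    def _update(self, state: OptimizationState, ascent: TensorList) -> TensorList:
        # per its docstring, _update may modify the ascent direction in place and return it;
        # iterating the TensorList tensor-by-tensor mirrors the zip(params, grad) loop in _step_update_closure
        for t in ascent:
            t.mul_(self.defaults['scale'])
        return ascent

model = torch.nn.Linear(4, 1)
# assumed Modular signature: parameters followed by the module chain
opt = tz.optim.Modular(model.parameters(), ScaleAscent(0.5))

def closure(backward=True):
    # closure contract from OptimizationState.closure: accept `backward`, set .grad attributes when it is True
    loss = (model(torch.randn(8, 4)) ** 2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)

Whether Modular is re-exported at tz.optim and accepts modules positionally should be verified against torchzero/optim/modular.py before relying on this sketch.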