torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
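The listing above shows the former `torchzero/utils/linalg/` modules moving to a top-level `torchzero/linalg/` package (and the single-file `scipy.py` wrapper becoming a `scipy/` subpackage). A hedged migration sketch, inferred only from the file moves listed above; whether 0.4.0 keeps re-export shims under the old paths is not verified here:

```python
# Hypothetical import migration inferred from the renames in the listing above.
# torchzero 0.3.14 (old layout):
#   from torchzero.utils.linalg.linear_operator import LinearOperator
#   from torchzero.utils.linalg import qr, solve
# torchzero 0.4.0 (new layout, assuming the moved modules are importable as-is):
from torchzero.linalg.linear_operator import LinearOperator
from torchzero.linalg import qr, solve
```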
torchzero/core/module.py
CHANGED
````diff
@@ -1,271 +1,20 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
-from collections.abc import Callable, Iterable,
-from
-from typing import Any, final, overload, Literal, cast
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, overload, TYPE_CHECKING
 
 import torch
 
-from ..
-
-
-
-    _make_param_groups,
-    get_state_vals,
-)
-from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.python_tools import flatten
-from ..utils.linalg.linear_operator import LinearOperator
-
-
-def _closure_backward(closure, params, retain_graph, create_graph):
-    with torch.enable_grad():
-        if not (retain_graph or create_graph):
-            return closure()
-
-        for p in params: p.grad = None
-        loss = closure(False)
-        grad = torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph)
-        for p,g in zip(params,grad): p.grad = g
-        return loss
-
-# region Vars
-# ----------------------------------- var ----------------------------------- #
-class Var:
-    """
-    Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
-    Modules take in a ``Var`` object, modify and it is passed to the next module.
-
-    """
-    def __init__(
-        self,
-        params: list[torch.Tensor],
-        closure: Callable | None,
-        model: torch.nn.Module | None,
-        current_step: int,
-        parent: "Var | None" = None,
-        modular: "Modular | None" = None,
-        loss: torch.Tensor | None = None,
-        storage: dict | None = None,
-    ):
-        self.params: list[torch.Tensor] = params
-        """List of all parameters with requires_grad = True."""
-
-        self.closure = closure
-        """A closure that reevaluates the model and returns the loss, None if it wasn't specified"""
-
-        self.model = model
-        """torch.nn.Module object of the model, None if it wasn't specified."""
-
-        self.current_step: int = current_step
-        """global current step, starts at 0. This may not correspond to module current step,
-        for example a module may step every 10 global steps."""
-
-        self.parent: "Var | None" = parent
-        """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
-        Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
-        e.g. when projecting."""
+from ..linalg.linear_operator import LinearOperator
+from ..utils.optimizer import Init, ListLike, get_state_vals
+from ..utils.params import Params, _make_param_groups
+from .functional import step_tensors
 
-
-
-
-        self.update: list[torch.Tensor] | None = None
-        """
-        current update. Update is assumed to be a transformed gradient, therefore it is subtracted.
-
-        If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
-
-        At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
-        gradient will be used and calculated if needed.
-        """
+if TYPE_CHECKING:
+    from .objective import Objective
 
-        self.grad: list[torch.Tensor] | None = None
-        """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""
 
-        self.loss: torch.Tensor | Any | None = loss
-        """loss with current parameters."""
-
-        self.loss_approx: torch.Tensor | Any | None = None
-        """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
-        whereas some other modules require loss strictly at current point."""
-
-        self.post_step_hooks: list[Callable[[Modular, Var]]] = []
-        """list of functions to be called after optimizer step.
-
-        This attribute should always be modified in-place (using ``append`` or ``extend``).
-
-        The signature is:
-
-        ```python
-        def hook(optimizer: Modular, var: Vars): ...
-        ```
-        """
-
-        self.is_last: bool = False
-        """
-        Indicates that current module is either last or next-to-last before a learning rate module.
-        This is always False if current module has children or is a child.
-        This is because otherwise the ``is_last`` would be passed to child modules, even though they aren't last.
-        """
-
-        self.nested_is_last: bool = False
-        """
-        Indicates that current module is either last or next-to-last before a learning rate module, for modules
-        that have children. This will be passed to the children unless ``var.clone()`` is used, therefore
-        a child of a last module may also receive ``var.nested_is_last=True``.
-        """
-
-        self.last_module_lrs: list[float] | None = None
-        """
-        List of per-parameter learning rates if current module is next-to-last before a
-        learning rate module, otherwise this is set to None. Ignore this unless you are manually applying
-        update to parameters.
-        """
-
-        self.stop: bool = False
-        """if True, all following modules will be skipped.
-        If this module is a child, it only affects modules at the same level (in the same Chain)."""
-
-        self.skip_update: bool = False
-        """if True, the parameters will not be updated."""
-
-        # self.storage: dict = {}
-        # """Storage for any other data, such as hessian estimates, etc."""
-
-        self.attrs: dict = {}
-        """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
-
-        if storage is None: storage = {}
-        self.storage: dict = storage
-        """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""
-
-        self.should_terminate: bool | None = None
-        """termination criteria, Modular.should_terminate is set to this after each step if not None"""
-
-    def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
-        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
-        Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
-        if self.loss is None:
-
-            if self.closure is None: raise RuntimeError("closure is None")
-            if backward:
-                with torch.enable_grad():
-                    self.loss = self.loss_approx = _closure_backward(
-                        closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
-                    )
-
-                # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
-                # it is technically a more correct approach for when some parameters conditionally receive gradients
-                # and in this case it shouldn't be slower.
-
-                # next time closure() is called, it will set grad to None.
-                # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
-                self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
-            else:
-                self.loss = self.loss_approx = self.closure(False)
-
-        # if self.loss was not None, above branch wasn't executed because loss has already been evaluated, but without backward since self.grad is None.
-        # and now it is requested to be evaluated with backward.
-        if backward and self.grad is None:
-            warnings.warn('get_loss was called with backward=False, and then with backward=True so it had to be re-evaluated, so the closure was evaluated twice where it could have been evaluated once.')
-            if self.closure is None: raise RuntimeError("closure is None")
-
-            with torch.enable_grad():
-                self.loss = self.loss_approx = _closure_backward(
-                    closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
-                )
-            self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
-
-        # set parent grad
-        if self.parent is not None:
-            # the way projections/split work, they make a new closure which evaluates original
-            # closure and projects the gradient, and set it as their var.closure.
-            # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
-            # and we set it to parent var here.
-            if self.parent.loss is None: self.parent.loss = self.loss
-            if self.parent.grad is None and backward:
-                if all(p.grad is None for p in self.parent.params):
-                    warnings.warn("Parent grad is None after backward.")
-                self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
-
-        return self.loss # type:ignore
-
-    def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
-        """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
-        ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
-        if self.grad is None:
-            if self.closure is None: raise RuntimeError("closure is None")
-            self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
-
-        assert self.grad is not None
-        return self.grad
-
-    def get_update(self) -> list[torch.Tensor]:
-        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
-        Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
-        Do not call this at perturbed parameters."""
-        if self.update is None: self.update = [g.clone() for g in self.get_grad()]
-        return self.update
-
-    def clone(self, clone_update: bool, parent: "Var | None" = None):
-        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
-
-        Doesn't copy ``is_last``, ``nested_is_last`` and ``last_module_lrs``. They will always be ``False``/``None``.
-
-        Setting ``parent`` is only if clone's parameters are something different,
-        while clone's closure referes to the same objective but with a "view" on parameters.
-        """
-        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)
-
-        if clone_update and self.update is not None:
-            copy.update = [u.clone() for u in self.update]
-        else:
-            copy.update = self.update
-
-        copy.grad = self.grad
-        copy.loss = self.loss
-        copy.loss_approx = self.loss_approx
-        copy.closure = self.closure
-        copy.post_step_hooks = self.post_step_hooks
-        copy.stop = self.stop
-        copy.skip_update = self.skip_update
-
-        copy.modular = self.modular
-        copy.attrs = self.attrs
-        copy.storage = self.storage
-        copy.should_terminate = self.should_terminate
-
-        return copy
-
-    def update_attrs_from_clone_(self, var: "Var"):
-        """Updates attributes of this `Vars` instance from a cloned instance.
-        Typically called after a child module has processed a cloned `Vars`
-        object. This propagates any newly computed loss or gradient values
-        from the child's context back to the parent `Vars` if the parent
-        didn't have them computed already.
-
-        Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
-        if the child updates them, the update will affect the parent too.
-        """
-        if self.loss is None: self.loss = var.loss
-        if self.loss_approx is None: self.loss_approx = var.loss_approx
-        if self.grad is None: self.grad = var.grad
-
-        if var.should_terminate is not None: self.should_terminate = var.should_terminate
-
-    def zero_grad(self, set_to_none=True):
-        if set_to_none:
-            for p in self.params: p.grad = None
-        else:
-            grads = [p.grad for p in self.params if p.grad is not None]
-            if len(grads) != 0: torch._foreach_zero_(grads)
-
-# endregion
-
-
-# region Module
-# ---------------------------------- module ---------------------------------- #
 class Module(ABC):
     """Abstract base class for an optimizer modules.
 
@@ -281,6 +30,7 @@ class Module(ABC):
     """
     def __init__(self, defaults: dict[str, Any] | None = None):
        if defaults is None: defaults = {}
+        if any(isinstance(v, Module) for v in defaults.values()): raise RuntimeError("Passed a module to defaults")
        self.defaults: dict[str, Any] = defaults
 
        # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
@@ -300,7 +50,7 @@ class Module(ABC):
        """A dictionary of child modules."""
 
        self._overridden_keys = set()
-        """tracks keys overridden with
+        """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""
 
 
    def set_param_groups(self, param_groups: Params):
@@ -316,10 +66,18 @@ class Module(ABC):
            self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
        return self
 
-    def set_child(self, key: str, module: "Module | Sequence[Module]"):
+    def set_child(self, key: str, module: "Module | Sequence[Module] | None"):
+        if key in self.children:
+            warnings.warn(f"set_child overwriting child `{key}`")
+
+        if module is None: return
+
+        from .chain import maybe_chain
        self.children[key] = maybe_chain(module)
 
    def set_children_sequence(self, modules: "Iterable[Module | Sequence[Module]]", prefix = 'module_'):
+        from .chain import maybe_chain
+
        modules = list(modules)
        for i, m in enumerate(modules):
            self.set_child(f'{prefix}{i}', maybe_chain(m))
@@ -327,6 +85,62 @@ class Module(ABC):
    def get_children_sequence(self, prefix = 'module_'):
        return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]
 
+    def inner_step(
+        self,
+        key: str,
+        objective: "Objective",
+        must_exist: bool = True,
+    ) -> "Objective":
+        """Passes ``objective`` to child and returns it."""
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return objective
+
+        return child.step(objective)
+
+
+    def inner_step_tensors(
+        self,
+        key: str,
+        tensors: list[torch.Tensor],
+        clone: bool,
+        params: Iterable[torch.Tensor] | None = None,
+        grads: Sequence[torch.Tensor] | None = None,
+        loss: torch.Tensor | None = None,
+        closure: Callable | None = None,
+        objective: "Objective | None" = None,
+        must_exist: bool = True
+    ) -> list[torch.Tensor]:
+        """Steps with child module. Can be used to apply transforms to any internal buffers.
+
+        If ``objective`` is specified, other attributes shouldn't to be specified.
+
+        Args:
+            key (str): Child module key.
+            tensors (Sequence[torch.Tensor]): tensors to pass to child module.
+            clone (bool):
+                If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
+                If ``key`` doesn't exist, ``tensors`` are always returned without cloning
+            params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
+            grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
+            loss (torch.Tensor | None, optional): loss. Defaults to None.
+            closure (Callable | None, optional): closure. Defaults to None.
+            must_exist (bool, optional): if True, if ``key`` doesn't exist, raises ``KeyError``. Defaults to True.
+        """
+
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return tensors
+
+        if clone: tensors = [t.clone() for t in tensors]
+        return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
+                            loss=loss, closure=closure, objective=objective)
+
+
    def __repr__(self):
        s = self.__class__.__name__
        if self.children:
@@ -348,7 +162,6 @@ class Module(ABC):
 
    def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
                     *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
-        # if isinstance(params, Vars): params = params.params
        return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
 
 
@@ -418,13 +231,8 @@ class Module(ABC):
            - if state_keys has multiple keys and keys has a single key, return cls.
            - if state_keys has multiple keys and keys has multiple keys, return list of cls.
        """
-        # if isinstance(params, Vars): params = params.params
        return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]
 
-    # def first_setting(self, *keys:str, params:Sequence[torch.Tensor]):
-    #     # if isinstance(params, Vars): params = params.params
-    #     return itemgetter(*keys)(self.settings[params[0]])
-
    def clear_state_keys(self, *keys:str):
        for s in self.state.values():
            for k in keys:
@@ -490,36 +298,73 @@ class Module(ABC):
        # extra info
        self._extra_unpack(state_dict['extra'])
 
-
-
-
-
-
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        """If ``seed=None``, returns ``None``.
+
+        Otherwise, if generator on this device and with this seed hasn't been created,
+        creates it and stores in global state.
+
+        Returns ``torch.Generator``."""
+        if seed is None: return None
+
+        if device is None: device_obj = torch.get_default_device()
+        else: device_obj = torch.device(device)
+        key = f"__generator-{seed}-{device_obj.type}:{device_obj.index}"
 
-
-
+        if key not in self.global_state:
+            self.global_state[key] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state[key]
+
+    def increment_counter(self, key: str, start: int):
+        """first value is ``start``"""
+        value = self.global_state.get(key, start - 1) + 1
+        self.global_state[key] = value
+        return value
+
+    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
+    def update(self, objective:"Objective") -> None:
+        """Updates internal state of this module. This should not modify ``objective.update``.
 
        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively,
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
        """
 
-
-
-
+    @abstractmethod
+    def apply(self, objective: "Objective") -> "Objective":
+        """Updates ``objective`` using the internal state of this module.
+
+        If ``update`` method is defined, ``apply`` shouldn't modify the internal state of this module if possible.
 
        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively,
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
        """
-
+        # if apply is empty, it should be defined explicitly.
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `apply`.")
+
+    def step(self, objective: "Objective") -> "Objective":
+        """Perform a step with this module. Calls ``update``, then ``apply``."""
+        self.update(objective)
+        return self.apply(objective)
 
-    def get_H(self,
+    def get_H(self, objective: "Objective") -> LinearOperator | None:
        """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
        The hessian approximation is assumed to be for all parameters concatenated to a vector."""
        # if this method is not defined it searches in children
        # this should be overwritten to return None if child params are different from this modules params
        H = None
        for k,v in self.children.items():
-            H_v = v.get_H(
+            H_v = v.get_H(objective)
 
            if (H is not None) and (H_v is not None):
                raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
@@ -549,370 +394,14 @@ class Module(ABC):
        """
        for c in self.children.values(): c.reset_for_online()
 
-    def _extra_pack(self):
-        """extra information to store in state_dict of this optimizer.
-        Will be passed to ``_extra_unpack`` when loading the state_dict
+    def _extra_pack(self) -> dict:
+        """extra information to store in ``state_dict`` of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
        return {}
 
-    def _extra_unpack(self,
-        """``_extra_pack`` return will be passed to this method when loading state_dict
+    def _extra_unpack(self, d: dict):
+        """``_extra_pack`` return will be passed to this method when loading ``state_dict``.
        This method is called after loading the rest of the state dict"""
 
 
-
-    # ------------------------------ HELPER METHODS ------------------------------ #
-    @torch.no_grad
-    def Hvp(
-        self,
-        v: Sequence[torch.Tensor],
-        at_x0: bool,
-        var: Var,
-        rgrad: Sequence[torch.Tensor] | None,
-        hvp_method: Literal['autograd', 'forward', 'central'],
-        h: float,
-        normalize: bool,
-        retain_grad: bool,
-    ) -> tuple[Sequence[torch.Tensor], Sequence[torch.Tensor] | None]:
-        """
-        Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
-        possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
-        Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
-
-        Single sample example:
-
-        ```python
-        Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
-        ```
-
-        Multiple samples example:
-
-        ```python
-        D = None
-        rgrad = None
-        for i in range(n_samples):
-            v = [torch.randn_like(p) for p in params]
-            Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
-
-            if D is None: D = Hvp
-            else: torch._foreach_add_(D, Hvp)
-
-        if n_samples > 1: torch._foreach_div_(D, n_samples)
-        ```
-
-        Args:
-            v (Sequence[torch.Tensor]): vector in hessian-vector product
-            at_x0 (bool): whether this is being called at original or perturbed parameters.
-            var (Var): Var
-            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
-            hvp_method (str): hvp method.
-            h (float): finite difference step size
-            normalize (bool): whether to normalize v for finite difference
-            retain_grad (bool): retain grad
-        """
-        # get grad
-        if rgrad is None and hvp_method in ('autograd', 'forward'):
-            if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
-            else:
-                if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
-                with torch.enable_grad():
-                    loss = var.closure()
-                    rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
-
-        if hvp_method == 'autograd':
-            assert rgrad is not None
-            Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
-
-        elif hvp_method == 'forward':
-            assert rgrad is not None
-            loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
-
-        elif hvp_method == 'central':
-            loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
-
-        else:
-            raise ValueError(hvp_method)
-
-        return Hvp, rgrad
-
-    def get_generator(self, device: torch.types.Device, seed: int | None):
-        if seed is None: return None
-
-        if 'generator' not in self.global_state:
-            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
-
-        return self.global_state['generator']
-
-# endregion
-
 Chainable = Module | Sequence[Module]
-
-
-def unroll_modules(*modules: Chainable) -> list[Module]:
-    unrolled = []
-
-    for m in modules:
-        if isinstance(m, Module):
-            unrolled.append(m)
-            unrolled.extend(unroll_modules(list(m.children.values())))
-        else:
-            unrolled.extend(unroll_modules(*m))
-
-    return unrolled
-
-
-# region Modular
-# ---------------------------------- Modular --------------------------------- #
-
-class _EvalCounterClosure:
-    """keeps track of how many times closure has been evaluated, and sets closure return"""
-    __slots__ = ("modular", "closure")
-    def __init__(self, modular: "Modular", closure):
-        self.modular = modular
-        self.closure = closure
-
-    def __call__(self, *args, **kwargs):
-        if self.closure is None:
-            raise RuntimeError("One of the modules requires closure to be passed to the step method")
-
-        v = self.closure(*args, **kwargs)
-
-        # set closure return on 1st evaluation
-        if self.modular._closure_return is None:
-            self.modular._closure_return = v
-
-        self.modular.num_evaluations += 1
-        return v
-
-# have to inherit from Modular to support lr schedulers
-# although Accelerate doesn't work due to converting param_groups to a dict
-class Modular(torch.optim.Optimizer):
-    """Chains multiple modules into an optimizer.
-
-    Args:
-        params (Params | torch.nn.Module): An iterable of parameters to optimize
-            (typically `model.parameters()`), an iterable of parameter group dicts,
-            or a `torch.nn.Module` instance.
-        *modules (Module): A sequence of `Module` instances that define the
-            optimization algorithm steps.
-    """
-    # this is specifically for lr schedulers
-    param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
-
-    def __init__(self, params: Params | torch.nn.Module, *modules: Module):
-        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
-        self.model: torch.nn.Module | None = None
-        """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
-        if isinstance(params, torch.nn.Module):
-            self.model = params
-            params = params.parameters()
-
-        self.modules = modules
-        """Top-level modules providedduring initialization."""
-
-        self.unrolled_modules = unroll_modules(self.modules)
-        """A flattened list of all modules including all children."""
-
-        param_groups = _make_param_groups(params, differentiable=False)
-        self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}
-
-        # make sure there is no more than a single learning rate module
-        lr_modules = [m for m in self.unrolled_modules if 'lr' in m.defaults]
-        if len(lr_modules) > 1:
-            warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to componding of learning rate multiplication with per-parameter learning rates and schedulers.')
-
-        # iterate over all per-parameter settings overrides and check if they are applied at most once
-        for group in param_groups:
-            for k in group:
-                if k in ('params', 'lr'): continue
-                modules_with_k = [m for m in self.unrolled_modules if k in m.defaults and k not in m._overridden_keys]
-                if len(modules_with_k) > 1:
-                    warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')
-
-        # defaults for schedulers
-        defaults = {}
-        for m in self.unrolled_modules: defaults.update(m.defaults)
-        super().__init__(param_groups, defaults=defaults)
-
-        # note - this is what super().__init__(param_groups, defaults=defaults) does:
-
-        # self.defaults = defaults
-        # for param_group in param_groups:
-        #     self.add_param_group(param_group)
-
-        # add_param_group adds a ChainMap where defaults are lowest priority,
-        # and entries specifed in param_groups or scheduler are higher priority.
-        # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
-        # in each module, settings passed to that module by calling set_param_groups are highest priority
-
-        self.current_step = 0
-        """global step counter for the optimizer."""
-
-        self.num_evaluations = 0
-        """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
-
-        # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
-        # we want to return original loss so this attribute is used
-        self._closure_return = None
-        """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
-
-        self.attrs = {}
-        """custom attributes that can be set by modules, for example EMA of weights or best so far"""
-
-        self.should_terminate = False
-        """is set to True by termination criteria modules."""
-
-    def add_param_group(self, param_group: dict[str, Any]):
-        proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
-        self.param_groups.append(ChainMap(proc_param_group, self.defaults))
-
-        for p in proc_param_group['params']:
-            # updates global per-parameter setting overrides (medium priority)
-            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.unrolled_modules]
-
-    def state_dict(self):
-        all_params = [p for g in self.param_groups for p in g['params']]
-        id_to_idx = {id(p): i for i,p in enumerate(all_params)}
-
-        groups = []
-        for g in self.param_groups:
-            g = g.copy()
-            g['params'] = [id_to_idx[id(p)] for p in g['params']]
-            groups.append(g)
-
-        state_dict = {
-            "idx_to_id": {v:k for k,v in id_to_idx.items()},
-            "params": all_params,
-            "groups": groups,
-            "defaults": self.defaults,
-            "modules": {i: m.state_dict() for i, m in enumerate(self.unrolled_modules)}
-        }
-        return state_dict
-
-    def load_state_dict(self, state_dict: dict):
-        self.defaults.clear()
-        self.defaults.update(state_dict['defaults'])
-
-        idx_to_param = dict(enumerate(state_dict['params']))
-        groups = []
-        for g in state_dict['groups']:
-            g = g.copy()
-            g['params'] = [idx_to_param[p] for p in g['params']]
-            groups.append(g)
-
-        self.param_groups.clear()
-        for group in groups:
-            self.add_param_group(group)
-
-        id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
-        for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
-            m._load_state_dict(sd, id_to_tensor)
-
-
-    def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
-        # clear closure return from previous step
-        self._closure_return = None
-
-        # propagate global per-parameter setting overrides
-        for g in self.param_groups:
-            settings = dict(g.maps[0]) # ignore defaults
-            params = settings.pop('params')
-            if not settings: continue
-
-            for p in params:
-                if not p.requires_grad: continue
-                for map in self._per_parameter_global_settings[p]: map.update(settings)
-
-        # create var
-        params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
-
-        # if closure is None, assume backward has been called and gather grads
-        if closure is None:
-            var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-            self.num_evaluations += 1
-
-        n_modules = len(self.modules)
-        if n_modules == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
-        last_module = self.modules[-1]
-        last_lr = last_module.defaults.get('lr', None)
-
-        # step
-        for i, module in enumerate(self.modules):
-            if i!=0: var = var.clone(clone_update=False)
-
-            # last module, or next to last module before lr
-            if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
-                if module.children: var.nested_is_last = True
-                else: var.is_last = True
-                if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
-
-            var = module.step(var)
-            if var.stop: break
-
-        # apply update
-        if not var.skip_update:
-            with torch.no_grad():
-                torch._foreach_sub_(params, var.get_update())
-
-        # update attributes
-        self.attrs.update(var.attrs)
-        if var.should_terminate is not None: self.should_terminate = var.should_terminate
-
-        # hooks
-        for hook in var.post_step_hooks:
-            hook(self, var)
-
-        self.current_step += 1
-        #return var.loss if var.loss is not None else var.loss_approx
-        return self._closure_return
-
-    def __repr__(self):
-        return f'Modular({", ".join(str(m) for m in self.modules)})'
-# endregion
-
-# region Chain
-# ----------------------------------- Chain ---------------------------------- #
-class Chain(Module):
-    """Chain of modules, mostly used internally"""
-    def __init__(self, *modules: Module | Iterable[Module]):
-        super().__init__()
-        flat_modules: list[Module] = flatten(modules)
-        for i, module in enumerate(flat_modules):
-            self.set_child(f'module_{i}', module)
-
-    def update(self, var):
-        # note here that `update` and `apply` shouldn't be used directly
-        # as it will update all modules, and then apply all modules
-        # it is used in specific cases like Chain as trust region hessian module
-        for i in range(len(self.children)):
-            self.children[f'module_{i}'].update(var)
-            if var.stop: break
-        return var
-
-    def apply(self, var):
-        for i in range(len(self.children)):
-            var = self.children[f'module_{i}'].apply(var)
-            if var.stop: break
-        return var
-
-    def step(self, var):
-        for i in range(len(self.children)):
-            var = self.children[f'module_{i}'].step(var)
-            if var.stop: break
-        return var
-
-    def __repr__(self):
-        s = self.__class__.__name__
-        if self.children:
-            if s == 'Chain': s = 'C' # to shorten it
-            s = f'{s}({", ".join(str(m) for m in self.children.values())})'
-        return s
-
-def maybe_chain(*modules: Chainable) -> Module:
-    """Returns a single module directly if only one is provided, otherwise wraps them in a :code:`Chain`."""
-    flat_modules: list[Module] = flatten(modules)
-    if len(flat_modules) == 1:
-        return flat_modules[0]
-    return Chain(*flat_modules)
-# endregion
-
````