torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/core/module.py
CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
 from collections.abc import Callable, Iterable, MutableMapping, Sequence
 from operator import itemgetter
-from typing import Any, final, overload
+from typing import Any, final, overload, Literal, cast
 
 import torch
 
@@ -14,7 +14,9 @@ from ..utils import (
     _make_param_groups,
     get_state_vals,
 )
+from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 from ..utils.python_tools import flatten
+from ..utils.linalg.linear_operator import LinearOperator
 
 
 def _closure_backward(closure, params, retain_graph, create_graph):
@@ -32,11 +34,9 @@ def _closure_backward(closure, params, retain_graph, create_graph):
 # ----------------------------------- var ----------------------------------- #
 class Var:
     """
-    Holds
+    Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
+    Modules take in a ``Var`` object, modify and it is passed to the next module.
 
-    This class acts as a mutable container for information relevant to the current
-    optimization step, such as parameters, gradients, loss, and the computed update.
-    Modules read from and write to this object to coordinate their actions.
     """
     def __init__(
         self,
@@ -44,6 +44,10 @@ class Var:
         closure: Callable | None,
         model: torch.nn.Module | None,
         current_step: int,
+        parent: "Var | None" = None,
+        modular: "Modular | None" = None,
+        loss: torch.Tensor | None = None,
+        storage: dict | None = None,
     ):
         self.params: list[torch.Tensor] = params
         """List of all parameters with requires_grad = True."""
@@ -55,19 +59,31 @@ class Var:
         """torch.nn.Module object of the model, None if it wasn't specified."""
 
         self.current_step: int = current_step
-        """global current step, starts at 0
+        """global current step, starts at 0. This may not correspond to module current step,
+        for example a module may step every 10 global steps."""
+
+        self.parent: "Var | None" = parent
+        """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
+        Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
+        e.g. when projecting."""
+
+        self.modular: "Modular" = cast(Modular, modular)
+        """Modular optimizer object that created this ``Var``."""
 
         self.update: list[torch.Tensor] | None = None
         """
-        current update
+        current update. Update is assumed to be a transformed gradient, therefore it is subtracted.
 
         If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
+
+        At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
+        gradient will be used and calculated if needed.
         """
 
         self.grad: list[torch.Tensor] | None = None
-        """gradient with current parameters. If closure is not None
+        """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""
 
-        self.loss: torch.Tensor | Any | None =
+        self.loss: torch.Tensor | Any | None = loss
         """loss with current parameters."""
 
         self.loss_approx: torch.Tensor | Any | None = None
@@ -76,24 +92,28 @@ class Var:
 
         self.post_step_hooks: list[Callable[[Modular, Var]]] = []
         """list of functions to be called after optimizer step.
-        The signature is:
 
-
+        This attribute should always be modified in-place (using ``append`` or ``extend``).
 
-
+        The signature is:
 
+        ```python
+        def hook(optimizer: Modular, var: Vars): ...
+        ```
         """
 
         self.is_last: bool = False
         """
         Indicates that current module is either last or next-to-last before a learning rate module.
         This is always False if current module has children or is a child.
+        This is because otherwise the ``is_last`` would be passed to child modules, even though they aren't last.
        """
 
         self.nested_is_last: bool = False
         """
         Indicates that current module is either last or next-to-last before a learning rate module, for modules
-        that have children.
+        that have children. This will be passed to the children unless ``var.clone()`` is used, therefore
+        a child of a last module may also receive ``var.nested_is_last=True``.
         """
 
         self.last_module_lrs: list[float] | None = None
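The ``post_step_hooks`` docstring above documents the hook signature and the in-place requirement. As a rough illustration (not part of the diff; the module method and the hook body are hypothetical), a module could register a hook like this:

```python
# Hypothetical sketch of registering a post-step hook from a module's step().
# The signature follows the docstring above; the print body is illustrative.
def step(self, var):
    def report(modular, var):
        # runs once after Modular finishes the whole step
        print(f"step {modular.current_step} done, loss={var.loss}")
    var.post_step_hooks.append(report)  # modified in-place, as required
    return var
```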
@@ -104,16 +124,30 @@ class Var:
         """
 
         self.stop: bool = False
-        """if True, all following modules will be skipped.
+        """if True, all following modules will be skipped.
+        If this module is a child, it only affects modules at the same level (in the same Chain)."""
 
         self.skip_update: bool = False
-        """if True, the parameters will not be updated"""
+        """if True, the parameters will not be updated."""
 
-
-        """
-        Do not call this at perturbed parameters. Backward always zeroes grads before recomputing."""
+        # self.storage: dict = {}
+        # """Storage for any other data, such as hessian estimates, etc."""
 
+        self.attrs: dict = {}
+        """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
+
+        if storage is None: storage = {}
+        self.storage: dict = storage
+        """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""
+
+        self.should_terminate: bool | None = None
+        """termination criteria, Modular.should_terminate is set to this after each step if not None"""
+
+    def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
+        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
+        Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
         if self.loss is None:
+
             if self.closure is None: raise RuntimeError("closure is None")
             if backward:
                 with torch.enable_grad():
@@ -124,7 +158,10 @@ class Var:
                 # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
                 # it is technically a more correct approach for when some parameters conditionally receive gradients
                 # and in this case it shouldn't be slower.
-
+
+                # next time closure() is called, it will set grad to None.
+                # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
+                self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
             else:
                 self.loss = self.loss_approx = self.closure(False)
 
@@ -139,11 +176,24 @@ class Var:
                     closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
                 )
                 self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+                # set parent grad
+                if self.parent is not None:
+                    # the way projections/split work, they make a new closure which evaluates original
+                    # closure and projects the gradient, and set it as their var.closure.
+                    # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
+                    # and we set it to parent var here.
+                    if self.parent.loss is None: self.parent.loss = self.loss
+                    if self.parent.grad is None and backward:
+                        if all(p.grad is None for p in self.parent.params):
+                            warnings.warn("Parent grad is None after backward.")
+                        self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
+
         return self.loss # type:ignore
 
     def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
         """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
-
+        ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
         if self.grad is None:
             if self.closure is None: raise RuntimeError("closure is None")
             self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -152,15 +202,21 @@ class Var:
         return self.grad
 
     def get_update(self) -> list[torch.Tensor]:
-        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to
-        Computing the gradients may assign
+        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
+        Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
         Do not call this at perturbed parameters."""
         if self.update is None: self.update = [g.clone() for g in self.get_grad()]
         return self.update
 
-    def clone(self, clone_update: bool):
-        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via
-
+    def clone(self, clone_update: bool, parent: "Var | None" = None):
+        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
+
+        Doesn't copy ``is_last``, ``nested_is_last`` and ``last_module_lrs``. They will always be ``False``/``None``.
+
+        Setting ``parent`` is only if clone's parameters are something different,
+        while clone's closure referes to the same objective but with a "view" on parameters.
+        """
+        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)
 
         if clone_update and self.update is not None:
             copy.update = [u.clone() for u in self.update]
@@ -170,10 +226,16 @@ class Var:
         copy.grad = self.grad
         copy.loss = self.loss
         copy.loss_approx = self.loss_approx
+        copy.closure = self.closure
         copy.post_step_hooks = self.post_step_hooks
         copy.stop = self.stop
         copy.skip_update = self.skip_update
 
+        copy.modular = self.modular
+        copy.attrs = self.attrs
+        copy.storage = self.storage
+        copy.should_terminate = self.should_terminate
+
         return copy
 
     def update_attrs_from_clone_(self, var: "Var"):
@@ -182,11 +244,16 @@ class Var:
         object. This propagates any newly computed loss or gradient values
         from the child's context back to the parent `Vars` if the parent
         didn't have them computed already.
+
+        Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
+        if the child updates them, the update will affect the parent too.
         """
         if self.loss is None: self.loss = var.loss
         if self.loss_approx is None: self.loss_approx = var.loss_approx
         if self.grad is None: self.grad = var.grad
 
+        if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
     def zero_grad(self, set_to_none=True):
         if set_to_none:
             for p in self.params: p.grad = None
@@ -196,6 +263,7 @@ class Var:
 
 # endregion
 
+
 # region Module
 # ---------------------------------- module ---------------------------------- #
 class Module(ABC):
@@ -308,17 +376,16 @@ class Module(ABC):
 
         If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.
 
-
-
-
-        # returns cls (by default TensorList)
-
-        exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
-        # returns list of cls
+        ```python
+        exp_avg = self.state_vals("exp_avg")
+        # returns cls (by default TensorList)
 
-
-
+        exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
+        # returns list of cls
 
+        exp_avg = self.state_vals(["exp_avg"])
+        # always returns a list of cls, even if got a single key
+        ```
 
         Args:
             *keys (str):
@@ -358,6 +425,26 @@ class Module(ABC):
     # # if isinstance(params, Vars): params = params.params
     # return itemgetter(*keys)(self.settings[params[0]])
 
+    def clear_state_keys(self, *keys:str):
+        for s in self.state.values():
+            for k in keys:
+                if k in s: del s[k]
+
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: str, values: Sequence): ...
+    @overload
+    def store(self, params: Sequence[torch.Tensor], keys: Sequence[str], values: Sequence[Sequence]): ...
+    def store(self, params: Sequence[torch.Tensor], keys: str | Sequence[str], values: Sequence):
+        if isinstance(keys, str):
+            for p,v in zip(params, values):
+                state = self.state[p]
+                state[keys] = v
+            return
+
+        for p, *p_v in zip(params, *values):
+            state = self.state[p]
+            for k,v in zip(keys, p_v): state[k] = v
+
     def state_dict(self):
         """state dict"""
         packed_state = {id(k):v for k,v in self.state.items()}
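To make the intent of the new ``clear_state_keys`` and ``store`` helpers concrete, here is a hypothetical usage sketch (not from the diff; ``params``, ``grads``, ``exp_avgs`` and ``exp_avg_sqs`` are assumed to be matching per-parameter sequences):

```python
# Hypothetical usage of the new per-parameter state helpers (illustrative only).
# Single key: one value stored per parameter under "prev_grad".
self.store(params, "prev_grad", [g.clone() for g in grads])

# Multiple keys: one sequence of per-parameter values for each key.
self.store(params, ["exp_avg", "exp_avg_sq"], [exp_avgs, exp_avg_sqs])

# Drop those buffers again from every parameter's state dict.
self.clear_state_keys("prev_grad", "exp_avg", "exp_avg_sq")
```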
@@ -377,7 +464,8 @@ class Module(ABC):
         }
         return state_dict
 
-    def
+    def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
+        """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
         # load state
         state = state_dict['state']
         self.state.clear()
@@ -396,29 +484,159 @@ class Module(ABC):
 
         # children
         for k, v in state_dict['children']:
-            if k in self.children: self.children[k].
+            if k in self.children: self.children[k]._load_state_dict(v, id_to_tensor)
            else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')
 
         # extra info
         self._extra_unpack(state_dict['extra'])
 
    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
-    @abstractmethod
     def step(self, var: Var) -> Var:
-        """performs a step, returns new var but may update
+        """performs a step, returns new ``var`` but may update it in-place."""
+        self.update(var)
+        return self.apply(var)
+
+    def update(self, var:Var) -> Any:
+        """Updates the internal state of this module. This should not modify ``var.update``.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        """
+
+    def apply(self, var: Var) -> Var:
+        """Applies this module to ``var.get_update()``.
+        This should not modify the internal state of this module if possible.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        """
+        return self.step(var)
+
+    def get_H(self, var: Var) -> LinearOperator | None:
+        """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
+        The hessian approximation is assumed to be for all parameters concatenated to a vector."""
+        # if this method is not defined it searches in children
+        # this should be overwritten to return None if child params are different from this modules params
+        H = None
+        for k,v in self.children.items():
+            H_v = v.get_H(var)
+
+            if (H is not None) and (H_v is not None):
+                raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
+
+            if H_v is not None: H = H_v
+
+        return H
 
     def reset(self):
-        """Resets the internal state of the module (e.g. momentum)."""
-        # no complex logic is allowed there because this is overridden by many modules
-        # where super().reset() shouldn't be called
+        """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
         self.state.clear()
         self.global_state.clear()
+        for c in self.children.values(): c.reset()
+
+    def reset_for_online(self):
+        """Resets buffers that depend on previous evaluation, such as previous gradient and loss,
+        which may become inaccurate due to mini-batching.
+
+        ``Online`` module calls ``reset_for_online``,
+        then it calls ``update`` with previous parameters,
+        then it calls ``update`` with current parameters,
+        and then ``apply``.
+        """
+        for c in self.children.values(): c.reset_for_online()
 
     def _extra_pack(self):
+        """extra information to store in state_dict of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the state_dict."""
         return {}
 
     def _extra_unpack(self, x):
-
+        """``_extra_pack`` return will be passed to this method when loading state_dict.
+        This method is called after loading the rest of the state dict"""
+
+
+
+    # ------------------------------ HELPER METHODS ------------------------------ #
+    @torch.no_grad
+    def Hvp(
+        self,
+        v: Sequence[torch.Tensor],
+        at_x0: bool,
+        var: Var,
+        rgrad: Sequence[torch.Tensor] | None,
+        hvp_method: Literal['autograd', 'forward', 'central'],
+        h: float,
+        normalize: bool,
+        retain_grad: bool,
+    ) -> tuple[Sequence[torch.Tensor], Sequence[torch.Tensor] | None]:
+        """
+        Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
+        possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
+        Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
+
+        Single sample example:
+
+        ```python
+        Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+        ```
+
+        Multiple samples example:
+
+        ```python
+        D = None
+        rgrad = None
+        for i in range(n_samples):
+            v = [torch.randn_like(p) for p in params]
+            Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+
+            if D is None: D = Hvp
+            else: torch._foreach_add_(D, Hvp)
+
+        if n_samples > 1: torch._foreach_div_(D, n_samples)
+        ```
+
+        Args:
+            v (Sequence[torch.Tensor]): vector in hessian-vector product
+            at_x0 (bool): whether this is being called at original or perturbed parameters.
+            var (Var): Var
+            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
+            hvp_method (str): hvp method.
+            h (float): finite difference step size
+            normalize (bool): whether to normalize v for finite difference
+            retain_grad (bool): retain grad
+        """
+        # get grad
+        if rgrad is None and hvp_method in ('autograd', 'forward'):
+            if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
+            else:
+                if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
+                with torch.enable_grad():
+                    loss = var.closure()
+                    rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
+
+        if hvp_method == 'autograd':
+            assert rgrad is not None
+            Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
+
+        elif hvp_method == 'forward':
+            assert rgrad is not None
+            loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
+
+        elif hvp_method == 'central':
+            loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
+
+        else:
+            raise ValueError(hvp_method)
+
+        return Hvp, rgrad
+
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        if seed is None: return None
+
+        if 'generator' not in self.global_state:
+            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state['generator']
 
     # endregion
 
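The docstrings above split a module's work into ``update`` (mutate internal state) and ``apply`` (transform ``var.update`` without touching state), with ``step`` defaulting to one followed by the other. Below is a minimal sketch of a module written against that contract; it is illustrative only, not part of torchzero, and it assumes ``self.state`` maps each parameter to a dict, as the ``store`` helper above suggests.

```python
# Illustrative sketch: an exponential-moving-average module using the update/apply
# split, so meta-modules such as ``tz.m.Online`` can drive the two phases separately.
class EMASketch(Module):
    def update(self, var):
        # mutate internal state only; var.update is left untouched
        for p, g in zip(var.params, var.get_grad()):
            buf = self.state[p].get("ema")
            if buf is None:
                self.state[p]["ema"] = g.clone()
            else:
                buf.lerp_(g, 0.1)  # ema = 0.9 * ema + 0.1 * g

    def apply(self, var):
        # read state only; replace the update with the averaged direction
        var.update = [self.state[p]["ema"].clone() for p in var.params]
        return var
```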
@@ -440,6 +658,27 @@ def unroll_modules(*modules: Chainable) -> list[Module]:
 
 # region Modular
 # ---------------------------------- Modular --------------------------------- #
+
+class _EvalCounterClosure:
+    """keeps track of how many times closure has been evaluated, and sets closure return"""
+    __slots__ = ("modular", "closure")
+    def __init__(self, modular: "Modular", closure):
+        self.modular = modular
+        self.closure = closure
+
+    def __call__(self, *args, **kwargs):
+        if self.closure is None:
+            raise RuntimeError("One of the modules requires closure to be passed to the step method")
+
+        v = self.closure(*args, **kwargs)
+
+        # set closure return on 1st evaluation
+        if self.modular._closure_return is None:
+            self.modular._closure_return = v
+
+        self.modular.num_evaluations += 1
+        return v
+
 # have to inherit from Modular to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
 class Modular(torch.optim.Optimizer):
@@ -456,6 +695,7 @@ class Modular(torch.optim.Optimizer):
    param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
 
    def __init__(self, params: Params | torch.nn.Module, *modules: Module):
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
        self.model: torch.nn.Module | None = None
        """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
        if isinstance(params, torch.nn.Module):
@@ -489,14 +729,33 @@ class Modular(torch.optim.Optimizer):
        for m in self.unrolled_modules: defaults.update(m.defaults)
        super().__init__(param_groups, defaults=defaults)
 
-        # note - this is what super
+        # note - this is what super().__init__(param_groups, defaults=defaults) does:
 
        # self.defaults = defaults
        # for param_group in param_groups:
        #     self.add_param_group(param_group)
 
+        # add_param_group adds a ChainMap where defaults are lowest priority,
+        # and entries specifed in param_groups or scheduler are higher priority.
+        # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
+        # in each module, settings passed to that module by calling set_param_groups are highest priority
+
        self.current_step = 0
-        """
+        """global step counter for the optimizer."""
+
+        self.num_evaluations = 0
+        """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
+
+        # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
+        # we want to return original loss so this attribute is used
+        self._closure_return = None
+        """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
+
+        self.attrs = {}
+        """custom attributes that can be set by modules, for example EMA of weights or best so far"""
+
+        self.should_terminate = False
+        """is set to True by termination criteria modules."""
 
    def add_param_group(self, param_group: dict[str, Any]):
        proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
@@ -542,10 +801,13 @@ class Modular(torch.optim.Optimizer):
 
        id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
        for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
-            m.
+            m._load_state_dict(sd, id_to_tensor)
+
 
+    def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
+        # clear closure return from previous step
+        self._closure_return = None
 
-    def step(self, closure=None): # pyright: ignore[reportIncompatibleMethodOverride]
        # propagate global per-parameter setting overrides
        for g in self.param_groups:
            settings = dict(g.maps[0]) # ignore defaults
@@ -558,15 +820,17 @@ class Modular(torch.optim.Optimizer):
 
        # create var
        params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=closure, model=self.model, current_step=self.current_step)
+        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
 
        # if closure is None, assume backward has been called and gather grads
        if closure is None:
            var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            self.num_evaluations += 1
 
+        n_modules = len(self.modules)
+        if n_modules == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
        last_module = self.modules[-1]
        last_lr = last_module.defaults.get('lr', None)
-        n_modules = len(self.modules)
 
        # step
        for i, module in enumerate(self.modules):
@@ -586,11 +850,17 @@ class Modular(torch.optim.Optimizer):
            with torch.no_grad():
                torch._foreach_sub_(params, var.get_update())
 
+        # update attributes
+        self.attrs.update(var.attrs)
+        if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
+        # hooks
        for hook in var.post_step_hooks:
            hook(self, var)
 
        self.current_step += 1
-        return var.loss if var.loss is not None else var.loss_approx
+        #return var.loss if var.loss is not None else var.loss_approx
+        return self._closure_return
 
    def __repr__(self):
        return f'Modular({", ".join(str(m) for m in self.modules)})'
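For orientation, the reworked ``step(closure=None, loss=None, **kwargs)`` can be exercised roughly as below. This is a sketch rather than documented usage: the module names ``tz.m.Adam`` and ``tz.m.LR`` and the ``closure(backward=True)`` convention are assumed from the rest of the package, and any extra keyword arguments end up in ``var.storage`` as noted above.

```python
# Hypothetical usage sketch of Modular.step (module names are assumptions).
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))

def closure(backward=True):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)     # returns the first closure evaluation of this step
print(opt.num_evaluations)   # closure calls counted by _EvalCounterClosure
print(opt.should_terminate)  # set by termination-criteria modules, if any
```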
@@ -606,6 +876,21 @@ class Chain(Module):
        for i, module in enumerate(flat_modules):
            self.set_child(f'module_{i}', module)
 
+    def update(self, var):
+        # note here that `update` and `apply` shouldn't be used directly
+        # as it will update all modules, and then apply all modules
+        # it is used in specific cases like Chain as trust region hessian module
+        for i in range(len(self.children)):
+            self.children[f'module_{i}'].update(var)
+            if var.stop: break
+        return var
+
+    def apply(self, var):
+        for i in range(len(self.children)):
+            var = self.children[f'module_{i}'].apply(var)
+            if var.stop: break
+        return var
+
    def step(self, var):
        for i in range(len(self.children)):
            var = self.children[f'module_{i}'].step(var)
@@ -616,7 +901,7 @@ class Chain(Module):
        s = self.__class__.__name__
        if self.children:
            if s == 'Chain': s = 'C' # to shorten it
-            s = f'{s}({", ".join(str(m) for m in self.children.values())}'
+            s = f'{s}({", ".join(str(m) for m in self.children.values())})'
        return s
 
 def maybe_chain(*modules: Chainable) -> Module: