torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -76
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +229 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/spsa1.py +93 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/__init__.py +1 -1
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +6 -7
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +114 -175
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +16 -4
- torchzero/modules/line_search/strong_wolfe.py +319 -220
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +253 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +207 -170
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +99 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +122 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/optimizer.py +2 -2
- torchzero/utils/python_tools.py +7 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.14.dist-info/METADATA +14 -0
- torchzero-0.3.14.dist-info/RECORD +167 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/core/module.py
CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
 from collections.abc import Callable, Iterable, MutableMapping, Sequence
 from operator import itemgetter
-from typing import Any, final, overload, Literal
+from typing import Any, final, overload, Literal, cast
 
 import torch
 
@@ -16,6 +16,7 @@ from ..utils import (
 )
 from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 from ..utils.python_tools import flatten
+from ..utils.linalg.linear_operator import LinearOperator
 
 
 def _closure_backward(closure, params, retain_graph, create_graph):
@@ -33,11 +34,9 @@ def _closure_backward(closure, params, retain_graph, create_graph):
 # ----------------------------------- var ----------------------------------- #
 class Var:
     """
-    Holds
+    Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
+    Modules take in a ``Var`` object, modify and it is passed to the next module.
 
-    This class acts as a mutable container for information relevant to the current
-    optimization step, such as parameters, gradients, loss, and the computed update.
-    Modules read from and write to this object to coordinate their actions.
     """
     def __init__(
         self,
@@ -45,6 +44,10 @@ class Var:
         closure: Callable | None,
         model: torch.nn.Module | None,
         current_step: int,
+        parent: "Var | None" = None,
+        modular: "Modular | None" = None,
+        loss: torch.Tensor | None = None,
+        storage: dict | None = None,
     ):
         self.params: list[torch.Tensor] = params
         """List of all parameters with requires_grad = True."""
@@ -56,19 +59,31 @@ class Var:
         """torch.nn.Module object of the model, None if it wasn't specified."""
 
         self.current_step: int = current_step
-        """global current step, starts at 0
+        """global current step, starts at 0. This may not correspond to module current step,
+        for example a module may step every 10 global steps."""
+
+        self.parent: "Var | None" = parent
+        """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
+        Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
+        e.g. when projecting."""
+
+        self.modular: "Modular" = cast(Modular, modular)
+        """Modular optimizer object that created this ``Var``."""
 
         self.update: list[torch.Tensor] | None = None
         """
-        current update
+        current update. Update is assumed to be a transformed gradient, therefore it is subtracted.
 
         If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
+
+        At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
+        gradient will be used and calculated if needed.
         """
 
         self.grad: list[torch.Tensor] | None = None
-        """gradient with current parameters. If closure is not None
+        """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""
 
-        self.loss: torch.Tensor | Any | None =
+        self.loss: torch.Tensor | Any | None = loss
         """loss with current parameters."""
 
         self.loss_approx: torch.Tensor | Any | None = None
@@ -77,24 +92,28 @@ class Var:
 
         self.post_step_hooks: list[Callable[[Modular, Var]]] = []
         """list of functions to be called after optimizer step.
-        The signature is:
 
-
+        This attribute should always be modified in-place (using ``append`` or ``extend``).
 
-
+        The signature is:
 
+        ```python
+        def hook(optimizer: Modular, var: Vars): ...
+        ```
         """
 
         self.is_last: bool = False
         """
         Indicates that current module is either last or next-to-last before a learning rate module.
         This is always False if current module has children or is a child.
+        This is because otherwise the ``is_last`` would be passed to child modules, even though they aren't last.
         """
 
         self.nested_is_last: bool = False
         """
         Indicates that current module is either last or next-to-last before a learning rate module, for modules
-        that have children.
+        that have children. This will be passed to the children unless ``var.clone()`` is used, therefore
+        a child of a last module may also receive ``var.nested_is_last=True``.
         """
 
         self.last_module_lrs: list[float] | None = None
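The hook signature documented above can be used as follows; this is an illustrative sketch only (the `ReportStep` module is hypothetical, while `Module`, `Modular` and `Var` are the classes from this file):

```python
from torchzero.core.module import Modular, Module, Var

class ReportStep(Module):
    """Hypothetical example module that registers a post-step hook."""
    def __init__(self):
        super().__init__(None)  # assuming Module.__init__ takes the defaults dict (or None)

    def step(self, var: Var) -> Var:
        def hook(optimizer: Modular, v: Var):
            # runs after the parameters have been updated
            print(f"step {optimizer.current_step} done, "
                  f"{optimizer.num_evaluations} closure evaluations so far")

        # per the docstring above, post_step_hooks must be modified in-place (append/extend)
        var.post_step_hooks.append(hook)
        return var
```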
@@ -105,19 +124,30 @@ class Var:
         """
 
         self.stop: bool = False
-        """if True, all following modules will be skipped.
+        """if True, all following modules will be skipped.
+        If this module is a child, it only affects modules at the same level (in the same Chain)."""
 
         self.skip_update: bool = False
-        """if True, the parameters will not be updated"""
+        """if True, the parameters will not be updated."""
 
-        self.storage: dict = {}
-        """Storage for any other data, such as hessian estimates, etc"""
+        # self.storage: dict = {}
+        # """Storage for any other data, such as hessian estimates, etc."""
 
-
-        """
-
+        self.attrs: dict = {}
+        """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
+
+        if storage is None: storage = {}
+        self.storage: dict = storage
+        """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""
+
+        self.should_terminate: bool | None = None
+        """termination criteria, Modular.should_terminate is set to this after each step if not None"""
 
+    def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
+        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
+        Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
         if self.loss is None:
+
             if self.closure is None: raise RuntimeError("closure is None")
             if backward:
                 with torch.enable_grad():
@@ -128,7 +158,10 @@ class Var:
                 # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
                 # it is technically a more correct approach for when some parameters conditionally receive gradients
                 # and in this case it shouldn't be slower.
-
+
+                # next time closure() is called, it will set grad to None.
+                # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
+                self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
             else:
                 self.loss = self.loss_approx = self.closure(False)
@@ -143,11 +176,24 @@ class Var:
                 closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
             )
             self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+        # set parent grad
+        if self.parent is not None:
+            # the way projections/split work, they make a new closure which evaluates original
+            # closure and projects the gradient, and set it as their var.closure.
+            # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
+            # and we set it to parent var here.
+            if self.parent.loss is None: self.parent.loss = self.loss
+            if self.parent.grad is None and backward:
+                if all(p.grad is None for p in self.parent.params):
+                    warnings.warn("Parent grad is None after backward.")
+                self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
+
         return self.loss # type:ignore
 
     def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
         """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
-
+        ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
         if self.grad is None:
             if self.closure is None: raise RuntimeError("closure is None")
             self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
@@ -156,15 +202,21 @@ class Var:
         return self.grad
 
     def get_update(self) -> list[torch.Tensor]:
-        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to
-        Computing the gradients may assign
+        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
+        Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
         Do not call this at perturbed parameters."""
         if self.update is None: self.update = [g.clone() for g in self.get_grad()]
         return self.update
 
-    def clone(self, clone_update: bool):
-        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via
-
+    def clone(self, clone_update: bool, parent: "Var | None" = None):
+        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
+
+        Doesn't copy ``is_last``, ``nested_is_last`` and ``last_module_lrs``. They will always be ``False``/``None``.
+
+        Setting ``parent`` is only if clone's parameters are something different,
+        while clone's closure referes to the same objective but with a "view" on parameters.
+        """
+        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)
 
         if clone_update and self.update is not None:
             copy.update = [u.clone() for u in self.update]
@@ -174,10 +226,16 @@ class Var:
         copy.grad = self.grad
         copy.loss = self.loss
         copy.loss_approx = self.loss_approx
+        copy.closure = self.closure
         copy.post_step_hooks = self.post_step_hooks
         copy.stop = self.stop
         copy.skip_update = self.skip_update
 
+        copy.modular = self.modular
+        copy.attrs = self.attrs
+        copy.storage = self.storage
+        copy.should_terminate = self.should_terminate
+
         return copy
 
     def update_attrs_from_clone_(self, var: "Var"):
@@ -186,11 +244,15 @@ class Var:
         object. This propagates any newly computed loss or gradient values
         from the child's context back to the parent `Vars` if the parent
         didn't have them computed already.
+
+        Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
+        if the child updates them, the update will affect the parent too.
         """
         if self.loss is None: self.loss = var.loss
         if self.loss_approx is None: self.loss_approx = var.loss_approx
         if self.grad is None: self.grad = var.grad
-
+
+        if var.should_terminate is not None: self.should_terminate = var.should_terminate
 
     def zero_grad(self, set_to_none=True):
         if set_to_none:
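To make the `get_grad`/`get_update` contract concrete, here is a minimal illustrative module (hypothetical, not part of the package); whatever is left in `var.update` at the end of the chain is what `Modular` subtracts from the parameters:

```python
from torchzero.core.module import Module, Var

class ClipUpdateValue(Module):
    """Hypothetical example: clamps every element of the update to [-value, value]."""
    def __init__(self, value: float = 1.0):
        super().__init__({"value": value})  # assuming a defaults dict, as in Reformulation.__init__

    def step(self, var: Var) -> Var:
        value = self.defaults["value"]
        # get_update() lazily clones the gradient into var.update on first access,
        # computing the gradient (and possibly the loss) only if needed.
        for u in var.get_update():
            u.clamp_(-value, value)
        return var
```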
@@ -201,6 +263,7 @@ class Var:
 
 # endregion
 
+
 # region Module
 # ---------------------------------- module ---------------------------------- #
 class Module(ABC):
@@ -313,17 +376,16 @@ class Module(ABC):
 
         If you want to force it to return a tuple even with a single key, pass a list/tuple of 1 or more keys.
 
-
-
-
-        # returns cls (by default TensorList)
-
-        exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
-        # returns list of cls
+        ```python
+        exp_avg = self.state_vals("exp_avg")
+        # returns cls (by default TensorList)
 
-
-
+        exp_avg, exp_avg_sq = self.state_vals("exp_avg", "exp_avg_sq")
+        # returns list of cls
 
+        exp_avg = self.state_vals(["exp_avg"])
+        # always returns a list of cls, even if got a single key
+        ```
 
         Args:
             *keys (str):
@@ -402,7 +464,8 @@ class Module(ABC):
         }
         return state_dict
 
-    def
+    def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
+        """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
         # load state
         state = state_dict['state']
         self.state.clear()
@@ -421,7 +484,7 @@ class Module(ABC):
 
         # children
         for k, v in state_dict['children']:
-            if k in self.children: self.children[k].
+            if k in self.children: self.children[k]._load_state_dict(v, id_to_tensor)
             else: warnings.warn(f'State dict for {self} has child {k}, which is missing in {self}')
 
         # extra info
@@ -429,37 +492,72 @@ class Module(ABC):
 
     # ---------------------------- OVERRIDABLE METHODS --------------------------- #
     def step(self, var: Var) -> Var:
-        """performs a step, returns new var but may update it in-place."""
+        """performs a step, returns new ``var`` but may update it in-place."""
        self.update(var)
         return self.apply(var)
 
     def update(self, var:Var) -> Any:
-        """Updates the internal state of this module. This should not modify
+        """Updates the internal state of this module. This should not modify ``var.update``.
 
         Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as
+        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
         """
 
     def apply(self, var: Var) -> Var:
-        """Applies this module to ``var.get_update()``.
-
+        """Applies this module to ``var.get_update()``.
+        This should not modify the internal state of this module if possible.
+
+        Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
+        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        """
+        return self.step(var)
+
+    def get_H(self, var: Var) -> LinearOperator | None:
+        """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
+        The hessian approximation is assumed to be for all parameters concatenated to a vector."""
+        # if this method is not defined it searches in children
+        # this should be overwritten to return None if child params are different from this modules params
+        H = None
+        for k,v in self.children.items():
+            H_v = v.get_H(var)
+
+            if (H is not None) and (H_v is not None):
+                raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
+
+            if H_v is not None: H = H_v
+
+        return H
 
     def reset(self):
-        """Resets the internal state of the module (e.g. momentum). By default clears state and global state."""
-        # no complex logic is allowed there because this is overridden by many modules
-        # where super().reset() shouldn't be called
+        """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
         self.state.clear()
+
+        generator = self.global_state.get("generator", None)
         self.global_state.clear()
+        if generator is not None: self.global_state["generator"] = generator
+
+        for c in self.children.values(): c.reset()
 
     def reset_for_online(self):
-        """
+        """Resets buffers that depend on previous evaluation, such as previous gradient and loss,
+        which may become inaccurate due to mini-batching.
+
+        ``Online`` module calls ``reset_for_online``,
+        then it calls ``update`` with previous parameters,
+        then it calls ``update`` with current parameters,
+        and then ``apply``.
+        """
         for c in self.children.values(): c.reset_for_online()
 
     def _extra_pack(self):
+        """extra information to store in state_dict of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the state_dict."""
         return {}
 
     def _extra_unpack(self, x):
-
+        """``_extra_pack`` return will be passed to this method when loading state_dict.
+        This method is called after loading the rest of the state dict"""
+
 
 
     # ------------------------------ HELPER METHODS ------------------------------ #
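A sketch of how the `update`/`apply` split is meant to be used (the module below is hypothetical): `update` refreshes internal state, `apply` only reads it, which is what lets meta-modules such as `tz.m.Online` re-run `update` without re-applying.

```python
import torch
from torchzero.core.module import Module, Var

class EMAOfGradients(Module):
    """Hypothetical example: maintains an exponential moving average of the gradient."""
    def __init__(self, beta: float = 0.9):
        super().__init__({"beta": beta})

    def update(self, var: Var):
        beta = self.defaults["beta"]
        grad = var.get_grad()
        # buffers live in global_state so reset() clears them
        ema = self.global_state.setdefault("ema", [torch.zeros_like(g) for g in grad])
        torch._foreach_mul_(ema, beta)
        torch._foreach_add_(ema, grad, alpha=1 - beta)

    def apply(self, var: Var) -> Var:
        # reads state only, as the `apply` docstring recommends
        var.update = [e.clone() for e in self.global_state["ema"]]
        return var
```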
@@ -474,30 +572,33 @@ class Module(ABC):
         h: float,
         normalize: bool,
         retain_grad: bool,
-    ):
+    ) -> tuple[Sequence[torch.Tensor], Sequence[torch.Tensor] | None]:
         """
-        Returns ``(Hvp, rgrad)
+        Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
+        possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
+        Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
 
         Single sample example:
 
-
-
-
+        ```python
+        Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
+        ```
 
         Multiple samples example:
 
-
+        ```python
+        D = None
+        rgrad = None
+        for i in range(n_samples):
+            v = [torch.randn_like(p) for p in params]
+            Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
 
-        D =
-
-        for i in range(n_samples):
-            v = [torch.randn_like(p) for p in params]
-            Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
+            if D is None: D = Hvp
+            else: torch._foreach_add_(D, Hvp)
 
-
-
+        if n_samples > 1: torch._foreach_div_(D, n_samples)
+        ```
 
-        if n_samples > 1: torch._foreach_div_(D, n_samples)
         Args:
             v (Sequence[torch.Tensor]): vector in hessian-vector product
             at_x0 (bool): whether this is being called at original or perturbed parameters.
@@ -533,6 +634,14 @@ class Module(ABC):
 
         return Hvp, rgrad
 
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        if seed is None: return None
+
+        if 'generator' not in self.global_state:
+            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state['generator']
+
 # endregion
 
 Chainable = Module | Sequence[Module]
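The new `get_generator` helper caches a seeded `torch.Generator` in `global_state` (which `reset` above now preserves). An illustrative, hypothetical usage:

```python
import torch
from torchzero.core.module import Module, Var

class SeededNoise(Module):
    """Hypothetical example: adds reproducible Gaussian noise to the update."""
    def __init__(self, std: float = 1e-3, seed: int | None = 0):
        super().__init__({"std": std, "seed": seed})

    def step(self, var: Var) -> Var:
        std, seed = self.defaults["std"], self.defaults["seed"]
        # returns None when seed is None, otherwise a cached torch.Generator
        generator = self.get_generator(var.params[0].device, seed)
        for u in var.get_update():
            noise = torch.randn(u.shape, dtype=u.dtype, device=u.device, generator=generator)
            u.add_(noise, alpha=std)
        return var
```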
@@ -555,7 +664,7 @@ def unroll_modules(*modules: Chainable) -> list[Module]:
 # ---------------------------------- Modular --------------------------------- #
 
 class _EvalCounterClosure:
-    """keeps track of how many times closure has been evaluated"""
+    """keeps track of how many times closure has been evaluated, and sets closure return"""
     __slots__ = ("modular", "closure")
     def __init__(self, modular: "Modular", closure):
         self.modular = modular
@@ -565,8 +674,14 @@ class _EvalCounterClosure:
         if self.closure is None:
             raise RuntimeError("One of the modules requires closure to be passed to the step method")
 
+        v = self.closure(*args, **kwargs)
+
+        # set closure return on 1st evaluation
+        if self.modular._closure_return is None:
+            self.modular._closure_return = v
+
         self.modular.num_evaluations += 1
-        return
+        return v
 
 # have to inherit from Modular to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
@@ -584,6 +699,7 @@ class Modular(torch.optim.Optimizer):
     param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
 
     def __init__(self, params: Params | torch.nn.Module, *modules: Module):
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
         self.model: torch.nn.Module | None = None
         """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
         if isinstance(params, torch.nn.Module):
@@ -617,18 +733,34 @@ class Modular(torch.optim.Optimizer):
         for m in self.unrolled_modules: defaults.update(m.defaults)
         super().__init__(param_groups, defaults=defaults)
 
-        # note - this is what super
+        # note - this is what super().__init__(param_groups, defaults=defaults) does:
 
         # self.defaults = defaults
         # for param_group in param_groups:
         # self.add_param_group(param_group)
 
+        # add_param_group adds a ChainMap where defaults are lowest priority,
+        # and entries specifed in param_groups or scheduler are higher priority.
+        # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
+        # in each module, settings passed to that module by calling set_param_groups are highest priority
+
         self.current_step = 0
         """global step counter for the optimizer."""
 
         self.num_evaluations = 0
         """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
 
+        # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
+        # we want to return original loss so this attribute is used
+        self._closure_return = None
+        """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
+
+        self.attrs = {}
+        """custom attributes that can be set by modules, for example EMA of weights or best so far"""
+
+        self.should_terminate = False
+        """is set to True by termination criteria modules."""
+
     def add_param_group(self, param_group: dict[str, Any]):
         proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
         self.param_groups.append(ChainMap(proc_param_group, self.defaults))
@@ -673,10 +805,13 @@ class Modular(torch.optim.Optimizer):
 
         id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
         for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
-            m.
+            m._load_state_dict(sd, id_to_tensor)
 
 
-    def step(self, closure=None): # pyright: ignore[reportIncompatibleMethodOverride]
+    def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
+        # clear closure return from previous step
+        self._closure_return = None
+
         # propagate global per-parameter setting overrides
         for g in self.param_groups:
             settings = dict(g.maps[0]) # ignore defaults
@@ -689,16 +824,17 @@ class Modular(torch.optim.Optimizer):
 
         # create var
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step)
+        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
 
         # if closure is None, assume backward has been called and gather grads
         if closure is None:
             var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
             self.num_evaluations += 1
 
+        n_modules = len(self.modules)
+        if n_modules == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
         last_module = self.modules[-1]
         last_lr = last_module.defaults.get('lr', None)
-        n_modules = len(self.modules)
 
         # step
         for i, module in enumerate(self.modules):
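`Modular.step` now accepts an optional precomputed `loss` and arbitrary keyword arguments (which land in `var.storage`), and it returns the value of the first closure evaluation of the step. A usage sketch; the `tz.m.Adam`/`tz.m.LR` module names are assumptions for illustration:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)
opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))  # module names assumed

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)   # returns the first closure evaluation of this step
# extra keyword arguments go to var.storage, e.g. opt.step(closure, loss=precomputed_loss)
print(loss.item(), opt.should_terminate, opt.attrs)
```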
@@ -718,11 +854,17 @@ class Modular(torch.optim.Optimizer):
         with torch.no_grad():
             torch._foreach_sub_(params, var.get_update())
 
+        # update attributes
+        self.attrs.update(var.attrs)
+        if var.should_terminate is not None: self.should_terminate = var.should_terminate
+
+        # hooks
         for hook in var.post_step_hooks:
             hook(self, var)
 
         self.current_step += 1
-        return var.loss if var.loss is not None else var.loss_approx
+        #return var.loss if var.loss is not None else var.loss_approx
+        return self._closure_return
 
     def __repr__(self):
         return f'Modular({", ".join(str(m) for m in self.modules)})'
@@ -738,6 +880,21 @@ class Chain(Module):
         for i, module in enumerate(flat_modules):
             self.set_child(f'module_{i}', module)
 
+    def update(self, var):
+        # note here that `update` and `apply` shouldn't be used directly
+        # as it will update all modules, and then apply all modules
+        # it is used in specific cases like Chain as trust region hessian module
+        for i in range(len(self.children)):
+            self.children[f'module_{i}'].update(var)
+            if var.stop: break
+        return var
+
+    def apply(self, var):
+        for i in range(len(self.children)):
+            var = self.children[f'module_{i}'].apply(var)
+            if var.stop: break
+        return var
+
     def step(self, var):
         for i in range(len(self.children)):
             var = self.children[f'module_{i}'].step(var)
@@ -748,7 +905,7 @@ class Chain(Module):
         s = self.__class__.__name__
         if self.children:
             if s == 'Chain': s = 'C' # to shorten it
-            s = f'{s}({", ".join(str(m) for m in self.children.values())}'
+            s = f'{s}({", ".join(str(m) for m in self.children.values())})'
         return s
 
 def maybe_chain(*modules: Chainable) -> Module:
torchzero/core/reformulation.py
ADDED
@@ -0,0 +1,65 @@
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+
+import torch
+
+from .module import Chainable, Modular, Module, Var
+
+
+class Reformulation(Module, ABC):
+    def __init__(self, defaults: dict | None, modules: Chainable | None):
+        super().__init__(defaults)
+
+        if modules is not None:
+            self.set_child("modules", modules)
+
+    @abstractmethod
+    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+        """
+        returns (loss, gradient), if backward is False then gradient can be None.
+
+        If evaluating original loss/gradient at x_0, set them to ``var``.
+        """
+
+    def pre_step(self, var: Var) -> Var | None:
+        """This runs once before each step, whereas `closure` may run multiple times per step if further modules
+        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
+
+    def step(self, var):
+        ret = self.pre_step(var) # pylint:disable = assignment-from-no-return
+        if isinstance(ret, Var): var = ret
+
+        if var.closure is None: raise RuntimeError("Reformulation requires closure")
+        params, closure = var.params, var.closure
+
+        # step with children
+        if 'modules' in self.children:
+
+            # make a reformulated closure
+            def modified_closure(backward=True):
+                loss, grad = self.closure(backward, closure, params, var)
+
+                if grad is not None:
+                    for p,g in zip(params, grad):
+                        p.grad = g
+
+                return loss
+
+            # set it to a new Var object
+            modified_var = var.clone(clone_update=False)
+            modified_var.closure = modified_closure
+
+            # step with child
+            modules = self.children['modules']
+            modified_var = modules.step(modified_var)
+
+            # modified_var.loss and grad refers to loss and grad of a modified objective
+            # so we only take the update
+            var.update = modified_var.update
+
+        # or just evaluate new closure and set to update
+        else:
+            loss, grad = self.closure(backward=True, closure=closure, params=params, var=var)
+            if grad is not None: var.update = list(grad)
+
+        return var