torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/core/var.py
DELETED
@@ -1,376 +0,0 @@

import warnings
from abc import ABC, abstractmethod
from collections import ChainMap, defaultdict
from collections.abc import Callable, Iterable, MutableMapping, Sequence
from operator import itemgetter
from typing import Any, final, overload, Literal, cast, TYPE_CHECKING

import torch

from ..utils import (
    Init,
    ListLike,
    Params,
    _make_param_groups,
    get_state_vals,
    vec_to_tensors
)
from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward, flatten_jacobian
from ..utils.python_tools import flatten
from ..utils.linalg.linear_operator import LinearOperator

if TYPE_CHECKING:
    from .modular import Modular

def _closure_backward(closure, params, retain_graph, create_graph):
    with torch.enable_grad():
        if not (retain_graph or create_graph):
            return closure()

        for p in params: p.grad = None
        loss = closure(False)
        grad = torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph)
        for p,g in zip(params,grad): p.grad = g
        return loss

# region Vars
# ----------------------------------- var ----------------------------------- #
class Var:
    """
    Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
    Modules take in a ``Var`` object, modify and it is passed to the next module.

    """
    def __init__(
        self,
        params: list[torch.Tensor],
        closure: Callable | None,
        model: torch.nn.Module | None,
        current_step: int,
        parent: "Var | None" = None,
        modular: "Modular | None" = None,
        loss: torch.Tensor | None = None,
        storage: dict | None = None,
    ):
        self.params: list[torch.Tensor] = params
        """List of all parameters with requires_grad = True."""

        self.closure = closure
        """A closure that reevaluates the model and returns the loss, None if it wasn't specified"""

        self.model = model
        """torch.nn.Module object of the model, None if it wasn't specified."""

        self.current_step: int = current_step
        """global current step, starts at 0. This may not correspond to module current step,
        for example a module may step every 10 global steps."""

        self.parent: "Var | None" = parent
        """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
        Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
        e.g. when projecting."""

        self.modular: "Modular | None" = modular
        """Modular optimizer object that created this ``Var``."""

        self.update: list[torch.Tensor] | None = None
        """
        current update. Update is assumed to be a transformed gradient, therefore it is subtracted.

        If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.

        At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
        gradient will be used and calculated if needed.
        """

        self.grad: list[torch.Tensor] | None = None
        """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""

        self.loss: torch.Tensor | Any | None = loss
        """loss with current parameters."""

        self.loss_approx: torch.Tensor | Any | None = None
        """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
        whereas some other modules require loss strictly at current point."""

        self.post_step_hooks: list[Callable[[Modular, Var]]] = []
        """list of functions to be called after optimizer step.

        This attribute should always be modified in-place (using ``append`` or ``extend``).

        The signature is:

        ```python
        def hook(optimizer: Modular, var: Vars): ...
        ```
        """

        self.stop: bool = False
        """if True, all following modules will be skipped.
        If this module is a child, it only affects modules at the same level (in the same Chain)."""

        self.skip_update: bool = False
        """if True, the parameters will not be updated."""

        # self.storage: dict = {}
        # """Storage for any other data, such as hessian estimates, etc."""

        self.attrs: dict = {}
        """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""

        if storage is None: storage = {}
        self.storage: dict = storage
        """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""

        self.should_terminate: bool | None = None
        """termination criteria, Modular.should_terminate is set to this after each step if not None"""

    def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
        """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
        Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
        if self.loss is None:

            if self.closure is None: raise RuntimeError("closure is None")
            if backward:
                with torch.enable_grad():
                    self.loss = self.loss_approx = _closure_backward(
                        closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
                    )

                # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
                # it is technically a more correct approach for when some parameters conditionally receive gradients
                # and in this case it shouldn't be slower.

                # next time closure() is called, it will set grad to None.
                # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
                self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
            else:
                self.loss = self.loss_approx = self.closure(False)

        # if self.loss was not None, above branch wasn't executed because loss has already been evaluated, but without backward since self.grad is None.
        # and now it is requested to be evaluated with backward.
        if backward and self.grad is None:
            warnings.warn('get_loss was called with backward=False, and then with backward=True so it had to be re-evaluated, so the closure was evaluated twice where it could have been evaluated once.')
            if self.closure is None: raise RuntimeError("closure is None")

            with torch.enable_grad():
                self.loss = self.loss_approx = _closure_backward(
                    closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
                )
            self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]

        # set parent grad
        if self.parent is not None:
            # the way projections/split work, they make a new closure which evaluates original
            # closure and projects the gradient, and set it as their var.closure.
            # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
            # and we set it to parent var here.
            if self.parent.loss is None: self.parent.loss = self.loss
            if self.parent.grad is None and backward:
                if all(p.grad is None for p in self.parent.params):
                    warnings.warn("Parent grad is None after backward.")
                self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]

        return self.loss # type:ignore

    def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
        """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
        ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
        if self.grad is None:
            if self.closure is None: raise RuntimeError("closure is None")
            self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad

        assert self.grad is not None
        return self.grad

    def get_update(self) -> list[torch.Tensor]:
        """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
        Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
        Do not call this at perturbed parameters."""
        if self.update is None: self.update = [g.clone() for g in self.get_grad()]
        return self.update

    def clone(self, clone_update: bool, parent: "Var | None" = None):
        """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).

        Setting ``parent`` is only if clone's parameters are something different,
        while clone's closure referes to the same objective but with a "view" on parameters.
        """
        copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)

        if clone_update and self.update is not None:
            copy.update = [u.clone() for u in self.update]
        else:
            copy.update = self.update

        copy.grad = self.grad
        copy.loss = self.loss
        copy.loss_approx = self.loss_approx
        copy.closure = self.closure
        copy.post_step_hooks = self.post_step_hooks
        copy.stop = self.stop
        copy.skip_update = self.skip_update

        copy.modular = self.modular
        copy.attrs = self.attrs
        copy.storage = self.storage
        copy.should_terminate = self.should_terminate

        return copy

    def update_attrs_from_clone_(self, var: "Var"):
        """Updates attributes of this `Vars` instance from a cloned instance.
        Typically called after a child module has processed a cloned `Vars`
        object. This propagates any newly computed loss or gradient values
        from the child's context back to the parent `Vars` if the parent
        didn't have them computed already.

        Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
        if the child updates them, the update will affect the parent too.
        """
        if self.loss is None: self.loss = var.loss
        if self.loss_approx is None: self.loss_approx = var.loss_approx
        if self.grad is None: self.grad = var.grad

        if var.should_terminate is not None: self.should_terminate = var.should_terminate

    def zero_grad(self, set_to_none=True):
        if set_to_none:
            for p in self.params: p.grad = None
        else:
            grads = [p.grad for p in self.params if p.grad is not None]
            if len(grads) != 0: torch._foreach_zero_(grads)


    # ------------------------------ HELPER METHODS ------------------------------ #
    @torch.no_grad
    def hessian_vector_product(
        self,
        v: Sequence[torch.Tensor],
        at_x0: bool,
        rgrad: Sequence[torch.Tensor] | None,
        hvp_method: Literal['autograd', 'forward', 'central'],
        h: float,
        normalize: bool,
        retain_graph: bool,
    ) -> tuple[list[torch.Tensor], Sequence[torch.Tensor] | None]:
        """
        Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
        possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
        Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``

        Single sample example:

        ```python
        Hvp, _ = self.hessian_vector_product(v, at_x0=True, rgrad=None, ..., retain_graph=False)
        ```

        Multiple samples example:

        ```python
        D = None
        rgrad = None
        for i in range(n_samples):
            v = [torch.randn_like(p) for p in params]
            Hvp, rgrad = self.hessian_vector_product(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)

            if D is None: D = Hvp
            else: torch._foreach_add_(D, Hvp)

        if n_samples > 1: torch._foreach_div_(D, n_samples)
        ```

        Args:
            v (Sequence[torch.Tensor]): vector in hessian-vector product
            at_x0 (bool): whether this is being called at original or perturbed parameters.
            var (Var): Var
            rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
            hvp_method (str): hvp method.
            h (float): finite difference step size
            normalize (bool): whether to normalize v for finite difference
            retain_grad (bool): retain grad
        """
        # get grad
        if rgrad is None and hvp_method in ('autograd', 'forward'):
            if at_x0: rgrad = self.get_grad(create_graph = hvp_method=='autograd')
            else:
                if self.closure is None: raise RuntimeError("Closure is required to calculate HVp")
                with torch.enable_grad():
                    loss = self.closure()
                    rgrad = torch.autograd.grad(loss, self.params, create_graph = hvp_method=='autograd')

        if hvp_method == 'autograd':
            assert rgrad is not None
            Hvp = hvp(self.params, rgrad, v, retain_graph=retain_graph)

        elif hvp_method == 'forward':
            assert rgrad is not None
            loss, Hvp = hvp_fd_forward(self.closure, self.params, v, h=h, g_0=rgrad, normalize=normalize)

        elif hvp_method == 'central':
            loss, Hvp = hvp_fd_central(self.closure, self.params, v, h=h, normalize=normalize)

        else:
            raise ValueError(hvp_method)

        return list(Hvp), rgrad

    @torch.no_grad
    def hessian_matrix_product(
        self,
        M: torch.Tensor,
        at_x0: bool,
        rgrad: Sequence[torch.Tensor] | None,
        hvp_method: Literal["batched", 'autograd', 'forward', 'central'],
        h: float,
        normalize: bool,
        retain_graph: bool,
    ) -> tuple[torch.Tensor, Sequence[torch.Tensor] | None]:
        """M is (n_dim, n_hvps), computes H @ M - (n_dim, n_hvps)."""

        # get grad
        if rgrad is None and hvp_method in ('autograd', 'forward', "batched"):
            if at_x0: rgrad = self.get_grad(create_graph = hvp_method in ('autograd', "batched"))
            else:
                if self.closure is None: raise RuntimeError("Closure is required to calculate HVp")
                with torch.enable_grad():
                    loss = self.closure()
                    create_graph = hvp_method in ('autograd', "batched")
                    rgrad = torch.autograd.grad(loss, self.params, create_graph=create_graph)

        if hvp_method == "batched":
            assert rgrad is not None
            with torch.enable_grad():
                flat_inputs = torch.cat([g.ravel() for g in rgrad])
                HM_list = torch.autograd.grad(flat_inputs, self.params, grad_outputs=M.T, is_grads_batched=True, retain_graph=retain_graph)
            HM = flatten_jacobian(HM_list).T

        elif hvp_method == 'autograd':
            assert rgrad is not None
            with torch.enable_grad():
                flat_inputs = torch.cat([g.ravel() for g in rgrad])
                HV_tensors = [torch.autograd.grad(
                    flat_inputs, self.params, grad_outputs=col,
                    retain_graph = retain_graph or (i < M.size(1) - 1)
                ) for i,col in enumerate(M.unbind(1))]
            HM_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HV_tensors]
            HM = torch.stack(HM_list, 1)

        elif hvp_method == 'forward':
            assert rgrad is not None
            HV_tensors = [hvp_fd_forward(self.closure, self.params, vec_to_tensors(col, self.params), h=h, g_0=rgrad, normalize=normalize)[1] for col in M.unbind(1)]
            HM_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HV_tensors]
            HM = flatten_jacobian(HM_list)

        elif hvp_method == 'central':
            HV_tensors = [hvp_fd_central(self.closure, self.params, vec_to_tensors(col, self.params), h=h, normalize=normalize)[1] for col in M.unbind(1)]
            HM_list = [torch.cat([t.ravel() for t in tensors]) for tensors in HV_tensors]
            HM = flatten_jacobian(HM_list)

        else:
            raise ValueError(hvp_method)

        return HM, rgrad

# endregion
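The deleted helpers above rely on a closure convention: the closure accepts an optional ``backward`` flag, ``closure(False)`` only evaluates and returns the loss, and ``closure()`` additionally populates ``.grad`` (these are exactly the calls made in ``_closure_backward`` and ``Var.get_loss``). A minimal sketch of such a closure, with an illustrative model and data chosen only for this example:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 1)
X, y = torch.randn(32, 4), torch.randn(32, 1)

def closure(backward: bool = True):
    # closure(False) only evaluates the loss (e.g. for line searches or
    # finite-difference HVPs); closure() also fills p.grad, matching the
    # closure()/closure(False) usage in the deleted Var helpers above.
    loss = F.mse_loss(model(X), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

loss_only = closure(False)  # loss only, gradients untouched
loss = closure()            # loss plus gradients in p.grad
```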
torchzero/modules/adaptive/lmadagrad.py
DELETED
@@ -1,186 +0,0 @@
from collections import deque
from typing import Literal, Any
import warnings

import torch
from ...core import Chainable, TensorwiseTransform

def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
    if isinstance(history, torch.Tensor):
        M = history
    else:
        M = torch.stack(tuple(history), dim=1)# / len(history)

    MTM = M.T @ M
    if damping != 0:
        MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))

    try:
        L, Q = torch.linalg.eigh(MTM) # pylint:disable=not-callable

        tol = torch.finfo(M.dtype).eps * L.amax() # remove small eigenvalues
        indices = L > tol
        L = L[indices]
        Q = Q[:, indices]

        U = (M @ Q) * L.rsqrt()

        if rdamping != 0:
            rdamping *= torch.linalg.vector_norm(L) # pylint:disable=not-callable
            L.add_(rdamping)

        return U, L

    except torch.linalg.LinAlgError:
        return None, None

def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
    Z = U.T @ g
    return (U * L.rsqrt()) @ Z

def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
    if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
    else:
        if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
        else: state_[key].lerp_(value, 1-beta)

class LMAdagrad(TensorwiseTransform):
    """
    Limited-memory full matrix Adagrad.

    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't neeed V.

    This is equivalent to full-matrix Adagrad on recent gradients.

    Args:
        history_size (int, optional): number of past gradients to store. Defaults to 10.
        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
        damping (float, optional): damping value. Defaults to 1e-4.
        rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
        order (int, optional):
            order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
        true_damping (bool, optional):
            If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
        U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
        L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
        interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
        concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.

    ## Examples:

    Limited-memory Adagrad

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(),
        tz.m.LR(0.1)
    )
    ```
    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.LR(0.01)
    )
    ```

    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)

    ```python
    optimizer = tz.Modular(
        model.parameters(),
        tz.m.LMAdagrad(inner=tz.m.EMA()),
        tz.m.Debias(0.9, 0.999),
        tz.m.ClipNormByEMA(max_ema_growth=1.2),
        tz.m.LR(0.01)
    )
    ```
    Reference:
        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
    """

    def __init__(
        self,
        history_size: int = 100,
        update_freq: int = 1,
        damping: float = 1e-4,
        rdamping: float = 0,
        order: int = 1,
        true_damping: bool = True,
        U_beta: float | None = None,
        L_beta: float | None = None,
        interval: int = 1,
        concat_params: bool = True,
        inner: Chainable | None = None,
    ):
        # history is still updated each step so Precondition's update_freq has different meaning
        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)

    @torch.no_grad
    def update_tensor(self, tensor, param, grad, loss, state, setting):
        order = setting['order']
        history_size = setting['history_size']
        update_freq = setting['update_freq']
        damping = setting['damping']
        rdamping = setting['rdamping']
        U_beta = setting['U_beta']
        L_beta = setting['L_beta']

        if 'history' not in state: state['history'] = deque(maxlen=history_size)
        history = state['history']

        if order == 1:
            t = tensor.clone().view(-1)
            history.append(t)
        else:

            # if order=2, history is of gradient differences, order 3 is differences between differences, etc
            # scaled by parameter differences
            cur_p = param.clone()
            cur_g = tensor.clone()
            eps = torch.finfo(cur_p.dtype).tiny * 2
            for i in range(1, order):
                if f'prev_g_{i}' not in state:
                    state[f'prev_p_{i}'] = cur_p
                    state[f'prev_g_{i}'] = cur_g
                    break

                s = cur_p - state[f'prev_p_{i}']
                y = cur_g - state[f'prev_g_{i}']
                state[f'prev_p_{i}'] = cur_p
                state[f'prev_g_{i}'] = cur_g
                cur_p = s
                cur_g = y

                if i == order - 1:
                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
                    history.append(cur_g.view(-1))

        step = state.get('step', 0)
        if step % update_freq == 0 and len(history) != 0:
            U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
            maybe_lerp_(state, U_beta, 'U', U)
            maybe_lerp_(state, L_beta, 'L', L)

        if len(history) != 0:
            state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        U = state.get('U', None)
        if U is None:
            # make a conservative step to avoid issues due to different GD scaling
            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

        L = state['L']
        update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)

        return update
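The docstring of the deleted module describes replacing the SVD of M with an eigendecomposition of MᵀM. A small standalone check (not part of torchzero; sizes and the absence of damping are illustrative choices) confirms the two routes give the same update:

```python
import torch

torch.manual_seed(0)
n_dim, hist = 50, 10
M = torch.randn(n_dim, hist, dtype=torch.float64)   # columns = recent gradients
g = torch.randn(n_dim, dtype=torch.float64)         # gradient to precondition

# route used by lm_adagrad_update / lm_adagrad_apply (no damping here)
L, Q = torch.linalg.eigh(M.T @ M)                   # L = S^2, Q = right singular vectors
keep = L > torch.finfo(M.dtype).eps * L.amax()      # drop near-zero eigenvalues
L, Q = L[keep], Q[:, keep]
U = (M @ Q) * L.rsqrt()                             # left singular vectors of M
update_eigh = (U * L.rsqrt()) @ (U.T @ g)

# same update computed directly from an SVD of M: U diag(1/S) Uᵀ g
U_svd, S, _ = torch.linalg.svd(M, full_matrices=False)
update_svd = U_svd @ ((U_svd.T @ g) / S)

print(torch.allclose(update_eigh, update_svd))      # True
```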