torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -3
- tests/test_opts.py +140 -100
- tests/test_tensorlist.py +8 -7
- tests/test_vars.py +1 -0
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +335 -50
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +197 -70
- torchzero/modules/__init__.py +13 -4
- torchzero/modules/adaptive/__init__.py +30 -0
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/adaptive/adahessian.py +224 -0
- torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
- torchzero/modules/adaptive/adan.py +96 -0
- torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/adaptive/esgd.py +171 -0
- torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
- torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
- torchzero/modules/adaptive/mars.py +79 -0
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/adaptive/msam.py +188 -0
- torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
- torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
- torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
- torchzero/modules/adaptive/sam.py +163 -0
- torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
- torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
- torchzero/modules/adaptive/sophia_h.py +185 -0
- torchzero/modules/clipping/clipping.py +115 -25
- torchzero/modules/clipping/ema_clipping.py +31 -17
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/conjugate_gradient/cg.py +355 -0
- torchzero/modules/experimental/__init__.py +13 -19
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +32 -15
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
- torchzero/modules/functional.py +52 -6
- torchzero/modules/grad_approximation/fdm.py +30 -4
- torchzero/modules/grad_approximation/forward_gradient.py +16 -4
- torchzero/modules/grad_approximation/grad_approximator.py +51 -10
- torchzero/modules/grad_approximation/rfdm.py +321 -52
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +164 -93
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +124 -0
- torchzero/modules/line_search/backtracking.py +95 -57
- torchzero/modules/line_search/line_search.py +171 -22
- torchzero/modules/line_search/scipy.py +3 -3
- torchzero/modules/line_search/strong_wolfe.py +327 -199
- torchzero/modules/misc/__init__.py +35 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +62 -0
- torchzero/modules/misc/gradient_accumulation.py +136 -0
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +383 -0
- torchzero/modules/misc/multistep.py +194 -0
- torchzero/modules/misc/regularization.py +167 -0
- torchzero/modules/misc/split.py +123 -0
- torchzero/modules/{ops → misc}/switch.py +45 -4
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +9 -9
- torchzero/modules/momentum/cautious.py +51 -19
- torchzero/modules/momentum/momentum.py +37 -2
- torchzero/modules/ops/__init__.py +11 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +81 -34
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
- torchzero/modules/ops/multi.py +82 -21
- torchzero/modules/ops/reduce.py +16 -8
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +30 -18
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +190 -96
- torchzero/modules/quasi_newton/__init__.py +9 -14
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
- torchzero/modules/quasi_newton/lbfgs.py +286 -173
- torchzero/modules/quasi_newton/lsr1.py +185 -106
- torchzero/modules/quasi_newton/quasi_newton.py +816 -268
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +3 -2
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +292 -68
- torchzero/modules/second_order/newton_cg.py +365 -15
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +387 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +94 -11
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +39 -3
- torchzero/optim/wrappers/fcmaes.py +24 -15
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +3 -3
- torchzero/optim/wrappers/scipy.py +86 -25
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +126 -114
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +369 -58
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +16 -0
- torchzero/utils/tensorlist.py +134 -51
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -57
- torchzero/modules/experimental/absoap.py +0 -250
- torchzero/modules/experimental/adadam.py +0 -112
- torchzero/modules/experimental/adamY.py +0 -125
- torchzero/modules/experimental/adasoap.py +0 -172
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/eigendescent.py +0 -117
- torchzero/modules/experimental/etf.py +0 -172
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/subspace_preconditioners.py +0 -138
- torchzero/modules/experimental/tada.py +0 -38
- torchzero/modules/line_search/trust_region.py +0 -73
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/momentum/matrix_momentum.py +0 -166
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/optimizers/__init__.py +0 -18
- torchzero/modules/optimizers/adagrad.py +0 -155
- torchzero/modules/optimizers/sophia_h.py +0 -129
- torchzero/modules/quasi_newton/cg.py +0 -268
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero/modules/smoothing/gaussian.py +0 -164
- torchzero-0.3.10.dist-info/METADATA +0 -379
- torchzero-0.3.10.dist-info/RECORD +0 -139
- torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
- {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0

torchzero/modules/adaptive/adagrad.py (new file)

@@ -0,0 +1,356 @@
+from operator import itemgetter
+from typing import Literal
+
+import torch
+from ...core import (
+    Chainable,
+    Module,
+    Target,
+    TensorwiseTransform,
+    Transform,
+    Var,
+    apply_transform,
+)
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ...utils.linalg import matrix_power_eigh
+from ..functional import add_power_, lerp_power_, root, epsilon_step_size
+from ...utils.linalg.linear_operator import Dense
+
+def adagrad_(
+    tensors_: TensorList,
+    sq_sum_: TensorList,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    use_sqrt: bool = True,
+    divide: bool = False,
+
+    decay: float | None = None,
+    beta: float | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """returns `tensors_`"""
+    clr = alpha / (1 + step * lr_decay)
+
+    if beta is None or step == 1: sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
+    else: sq_sum_ = lerp_power_(tensors_, exp_avg_pow_=sq_sum_, beta=beta, pow=pow)
+    if decay is not None:
+        sq_sum_.mul_(1-decay)
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+    if divide: sq_sum_ = sq_sum_ / max(step, 1)
+
+    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
+    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+
+    return tensors_
+
+
+
+class Adagrad(Transform):
+    """Adagrad, divides by sum of past squares of gradients.
+
+    This implementation is identical to ``torch.optim.Adagrad``.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        divide: bool = False,
+        beta: float | None = None,
+        decay: float | None = None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        lr_decay, alpha, eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
+
+        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
+
+        # initialize accumulator on 1st step
+        if step == 1:
+            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
+
+        return adagrad_(
+            tensors,
+            sq_sum_=sq_sum,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=step,
+            pow=pow,
+            use_sqrt=use_sqrt,
+            divide=divide,
+
+            beta = self.defaults["beta"],
+            decay = self.defaults["decay"],
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+        )
+
+
+def lerp(start, end, weight):
+    return start + weight * (end - start)
+
+def adagrad_norm_(
+    tensors_: TensorList,
+    accumulator: float | torch.Tensor,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    use_sqrt: bool = True,
+    divide: bool = False,
+
+    decay: float | None = None,
+    beta: float | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """returns `tensors_`"""
+    clr = alpha / (1 + step * lr_decay)
+
+    gg = tensors_.dot(tensors_)
+
+    if beta is None or step == 1: accumulator += gg
+    else: accumulator = lerp(accumulator, gg, 1-beta)
+
+    if decay is not None:
+        accumulator *= 1-decay
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+    if divide: accumulator = accumulator / max(step, 1)
+
+    if use_sqrt: tensors_.div_(eps + accumulator.sqrt()).mul_(clr)
+    else: tensors_.div_(eps + accumulator).mul_(clr)
+
+    return tensors_, accumulator
+
+class AdagradNorm(Transform):
+    """Adagrad-Norm, divides by sum of past means of squares of gradients.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        divide: bool = False,
+        beta: float | None = None,
+        decay: float | None = None,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        lr_decay, alpha, eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+
+        use_sqrt, divide, initial_accumulator_value = itemgetter('use_sqrt', 'divide', "initial_accumulator_value")(settings[0])
+
+        accumulator = self.global_state.get("accumulator", initial_accumulator_value)
+
+        d, self.global_state["accumulator"] = adagrad_norm_(
+            tensors,
+            accumulator=accumulator,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=step,
+            use_sqrt=use_sqrt,
+            divide=divide,
+
+            beta = self.defaults["beta"],
+            decay = self.defaults["decay"],
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+        )
+
+        return d
+
+
+class FullMatrixAdagrad(TensorwiseTransform):
+    """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).
+
+    Note:
+        A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.LMAdagrad``.
+
+    Args:
+        beta (float | None, optional): momentum for gradient outer product accumulators. If None, uses sum. Defaults to None.
+        decay (float | None, optional): decay for gradient outer product accumulators. Defaults to None.
+        sqrt (bool, optional): whether to take the square root of the accumulator. Defaults to True.
+        concat_params (bool, optional): if False, each parameter will have its own accumulator. Defaults to True.
+        precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
+        init (Literal[str], optional):
+            how to initialize the accumulator.
+            - "identity" - with identity matrix (default).
+            - "zeros" - with zero matrix.
+            - "ones" - with matrix of ones.
+            - "GGT" - with the first outer product.
+        divide (bool, optional): whether to divide the accumulator by number of gradients in it. Defaults to False.
+        inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.
+
+    ## Examples:
+
+    Plain full-matrix adagrad
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(),
+        tz.m.LR(1e-2),
+    )
+    ```
+
+    Full-matrix RMSprop
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(beta=0.99),
+        tz.m.LR(1e-2),
+    )
+    ```
+
+    Full-matrix Adam
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.FullMatrixAdagrad(beta=0.999, inner=tz.m.EMA(0.9)),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(1e-2),
+    )
+    ```
+    """
+    def __init__(
+        self,
+        beta: float | None = None,
+        decay: float | None = None,
+        sqrt: bool = True,
+        concat_params=True,
+        precond_freq: int = 1,
+        init: Literal["identity", "zeros", "ones", "GGT"] = "identity",
+        reg: float = 1e-12,
+        divide: bool = False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, precond_freq=precond_freq, init=init, divide=divide, reg=reg)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+        G = tensor.ravel()
+        GG = torch.outer(G, G)
+        decay = setting['decay']
+        beta = setting['beta']
+        init = setting['init']
+
+        if 'GG' not in state:
+            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+            elif init == 'ones': state['GG'] = torch.ones_like(GG)
+            elif init == 'GGT': state['GG'] = GG.clone()
+            else: raise ValueError(init)
+        if decay is not None: state['GG'].mul_(decay)
+
+        if beta is not None: state['GG'].lerp_(GG, 1-beta)
+        else: state['GG'].add_(GG)
+        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        step = state.get('step', 0)
+        state['step'] = step + 1
+
+        GG: torch.Tensor = state['GG']
+        sqrt = setting['sqrt']
+        divide = setting['divide']
+        precond_freq = setting['precond_freq']
+        reg = setting['reg']
+
+        if divide: GG = GG/state.get('i', 1)
+
+        if reg != 0:
+            GG = GG + torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype).mul_(reg)
+
+        if tensor.numel() == 1:
+            GG = GG.squeeze()
+            if sqrt: return tensor / GG.sqrt()
+            return tensor / GG
+
+        try:
+            if sqrt:
+                if "B" not in state or step % precond_freq == 0:
+                    B = state["B"] = matrix_power_eigh(GG, -1/2)
+                else:
+                    B = state["B"]
+
+            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
+        except torch.linalg.LinAlgError:
+            # fallback to diagonal AdaGrad
+            denom = GG.diagonal()
+            if sqrt: denom = denom.sqrt()
+            return tensor.div_(denom + max(reg, 1e-12))
+
+        return (B @ tensor.ravel()).view_as(tensor)
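
For readers skimming the hunk above: ``adagrad_`` implements the classic diagonal Adagrad rule, a per-element accumulator of squared gradients whose root scales the step. Below is a minimal standalone PyTorch restatement of the default path (``beta=None``, ``use_sqrt=True``, ``divide=False``); the function and variable names are illustrative only and are not part of the torchzero API:

```python
import torch

def adagrad_step(grad, sq_sum, step, alpha=1.0, lr_decay=0.0, eps=1e-10):
    """Plain-tensor restatement of the update computed by adagrad_ above.

    sq_sum accumulates element-wise squared gradients; the update divides the
    gradient by sqrt(sq_sum) + eps and scales it by a decayed step size.
    """
    clr = alpha / (1 + step * lr_decay)        # decayed step size, as in the diff
    sq_sum.add_(grad * grad)                   # beta=None branch: plain sum of squares
    return grad / (sq_sum.sqrt() + eps) * clr

# usage sketch
g = torch.randn(5)
acc = torch.zeros(5)
update = adagrad_step(g, acc, step=1)
```

The ``beta`` and ``decay`` branches in the diff replace this plain sum with an exponential moving average or a decayed sum of squares, which is how the same code also serves as an RMSprop-style accumulator.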
torchzero/modules/adaptive/adahessian.py (new file)

@@ -0,0 +1,224 @@
+import math
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+from ..functional import debiased_step_size
+
+def _full_average(hvp: torch.Tensor):
+    if hvp.ndim >= 3: # Conv kernel
+        return torch.mean(hvp.abs(), dim=[2, *range(3, hvp.ndim)], keepdim=True)
+    return hvp
+
+def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
+    """averages x over first dimension in blocks"""
+    if enable and x.ndim >= 2:
+        if math.prod(x.shape[1:]) <= 1: return x
+        if block_size is None: return _full_average(x)
+        size = x.size(0)
+
+        n_blocks = size // block_size
+        if n_blocks <= 1: return x.abs().mean(0, keepdim=True)
+
+        n_remaining = size - n_blocks * block_size
+        remaining = None
+        if n_remaining > 0:
+            remaining = x[-n_remaining:].abs().mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+            x = x[:-n_remaining]
+
+        x = x.view(block_size, n_blocks, *x.shape[1:])
+        x_mean = x.abs().mean(0).repeat_interleave(block_size, 0)
+
+        if remaining is None: return x_mean
+        return torch.cat([x_mean, remaining], 0)
+
+    return x
+
+
+def _rademacher_like(tensor, p = 0.5, generator = None):
+    """p is probability of a 1, other values will be -1."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
+
+def adahessian(
+    tensors: TensorList,
+    D: TensorList | None,
+    exp_avg_: TensorList,
+    D_exp_avg_sq_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    update_freq: int,
+    eps: float | NumberList,
+    hessian_power: float | NumberList,
+    step: int,
+):
+    # momentum
+    exp_avg_.lerp_(tensors, 1-beta1)
+
+    # update preconditioner
+    if step % update_freq == 0:
+        assert D is not None
+        D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
+
+    else:
+        assert D is None
+
+
+    denom = D_exp_avg_sq_.sqrt().pow_(hessian_power).add_(eps)
+    num = exp_avg_ * debiased_step_size(step+1, beta1, beta2)
+
+    return num.div_(denom)
+
+
+class AdaHessian(Module):
+    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)
+
+    This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
+
+    Notes:
+        - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.
+
+        - If you are using gradient estimators or reformulations, set ``hvp_method`` to "forward" or "central".
+
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
+        averaging (bool, optional):
+            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
+            This can be set per-parameter in param groups.
+        block_size (int, optional):
+            size of block in the block-diagonal averaging.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product.
+            This value can be increased to reduce computational cost. Defaults to 1.
+        eps (float, optional):
+            division stability epsilon. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        fd_h (float, optional): finite difference step size if ``hvp_method`` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner module. If this is specified, operations are performed in the following order.
+            1. compute hessian diagonal estimate.
+            2. pass inputs to ``inner``.
+            3. momentum and preconditioning are applied to the outputs of ``inner``.
+
+    ## Examples:
+
+    Using AdaHessian:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    AdaHessian preconditioner can be applied to any other module by passing it to the ``inner`` argument.
+    Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+    AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        averaging: bool = True,
+        block_size: int | None = None,
+        update_freq: int = 1,
+        eps: float = 1e-8,
+        hessian_power: float = 1,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hessian_power=hessian_power, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = self.get_generator(params[0].device, seed)
+
+        beta1, beta2, eps, averaging, block_size, hessian_power = self.get_settings(params,
+            'beta1', 'beta2', 'eps', 'averaging', 'block_size', "hessian_power", cls=NumberList)
+
+        exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad = None
+            for i in range(n_samples):
+                u = [_rademacher_like(p, generator=generator) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+                Hvp = tuple(Hvp)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            assert D is not None
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update = adahessian(
+            tensors=TensorList(update),
+            D=TensorList(D) if D is not None else None,
+            exp_avg_=exp_avg,
+            D_exp_avg_sq_=D_exp_avg_sq,
+            beta1=beta1,
+            beta2=beta2,
+            update_freq=update_freq,
+            eps=eps,
+            hessian_power=hessian_power,
+            step=step,
+        )
+        return var
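
The ``adahessian`` preconditioner above rests on Hutchinson's estimator: for Rademacher vectors ``z``, ``E[z * Hz] = diag(H)``, which is exactly what the ``_rademacher_like`` probes and the averaged Hessian-vector products compute. Here is a self-contained autograd sketch of that diagonal estimate; it is illustrative only and does not use the module's ``self.Hvp`` helper:

```python
import torch

def hutchinson_hessian_diag(loss_fn, params, n_samples=1):
    """Estimate diag(H) as the average of z * (Hz) over Rademacher probes z."""
    diag = [torch.zeros_like(p) for p in params]
    for _ in range(n_samples):
        loss = loss_fn()
        grads = torch.autograd.grad(loss, params, create_graph=True)
        z = [torch.randint_like(p, 0, 2) * 2 - 1 for p in params]   # Rademacher +-1
        hvps = torch.autograd.grad(grads, params, grad_outputs=z)    # Hessian-vector product Hz
        for d, zi, hv in zip(diag, z, hvps):
            d.add_(zi * hv / n_samples)                              # E[z * Hz] = diag(H)
    return diag

# usage sketch on a tiny quadratic: the true Hessian diagonal is all ones
w = torch.randn(3, requires_grad=True)
est = hutchinson_hessian_diag(lambda: 0.5 * (w ** 2).sum(), [w], n_samples=4)
```

With ``n_samples=1`` this matches the module's default; raising it reduces the variance of the diagonal estimate, as the ``n_samples`` docstring notes.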
torchzero/modules/{optimizers → adaptive}/adam.py

@@ -10,9 +10,6 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..lr.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
 
 
 def adam_(
@@ -33,7 +30,7 @@ def adam_(
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
 ):
-    """Returns new tensors
+    """Returns new tensors."""
     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                    debiased=False,step=step,pow=pow)
 
@@ -43,11 +40,12 @@ def adam_(
 
     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-    return (exp_avg_ / sqrt_exp_avg_sq.add_(eps))
+    return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))
 
 class Adam(Transform):
-    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
-
+    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
+
+    This implementation is identical to :code:`torch.optim.Adam`.
 
     Args:
         beta1 (float, optional): momentum. Defaults to 0.9.
@@ -75,7 +73,7 @@ class Adam(Transform):
         if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)