torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +95 -69
- tests/test_tensorlist.py +8 -7
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +2 -2
- torchzero/core/module.py +225 -72
- torchzero/core/reformulation.py +65 -0
- torchzero/core/transform.py +44 -24
- torchzero/modules/__init__.py +13 -5
- torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
- torchzero/modules/adaptive/adagrad.py +356 -0
- torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
- torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
- torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
- torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
- torchzero/modules/adaptive/aegd.py +54 -0
- torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
- torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
- torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
- torchzero/modules/adaptive/matrix_momentum.py +146 -0
- torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
- torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
- torchzero/modules/adaptive/natural_gradient.py +175 -0
- torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
- torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
- torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
- torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
- torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
- torchzero/modules/clipping/clipping.py +85 -92
- torchzero/modules/clipping/ema_clipping.py +5 -5
- torchzero/modules/conjugate_gradient/__init__.py +11 -0
- torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
- torchzero/modules/experimental/__init__.py +9 -32
- torchzero/modules/experimental/dct.py +2 -2
- torchzero/modules/experimental/fft.py +2 -2
- torchzero/modules/experimental/gradmin.py +4 -3
- torchzero/modules/experimental/l_infinity.py +111 -0
- torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
- torchzero/modules/experimental/newton_solver.py +79 -17
- torchzero/modules/experimental/newtonnewton.py +27 -14
- torchzero/modules/experimental/scipy_newton_cg.py +105 -0
- torchzero/modules/experimental/structural_projections.py +1 -1
- torchzero/modules/functional.py +50 -14
- torchzero/modules/grad_approximation/fdm.py +19 -20
- torchzero/modules/grad_approximation/forward_gradient.py +4 -2
- torchzero/modules/grad_approximation/grad_approximator.py +43 -47
- torchzero/modules/grad_approximation/rfdm.py +144 -122
- torchzero/modules/higher_order/__init__.py +1 -1
- torchzero/modules/higher_order/higher_order_newton.py +31 -23
- torchzero/modules/least_squares/__init__.py +1 -0
- torchzero/modules/least_squares/gn.py +161 -0
- torchzero/modules/line_search/__init__.py +2 -2
- torchzero/modules/line_search/_polyinterp.py +289 -0
- torchzero/modules/line_search/adaptive.py +69 -44
- torchzero/modules/line_search/backtracking.py +83 -70
- torchzero/modules/line_search/line_search.py +159 -68
- torchzero/modules/line_search/scipy.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +319 -218
- torchzero/modules/misc/__init__.py +8 -0
- torchzero/modules/misc/debug.py +4 -4
- torchzero/modules/misc/escape.py +9 -7
- torchzero/modules/misc/gradient_accumulation.py +88 -22
- torchzero/modules/misc/homotopy.py +59 -0
- torchzero/modules/misc/misc.py +82 -15
- torchzero/modules/misc/multistep.py +47 -11
- torchzero/modules/misc/regularization.py +5 -9
- torchzero/modules/misc/split.py +55 -35
- torchzero/modules/misc/switch.py +1 -1
- torchzero/modules/momentum/__init__.py +1 -5
- torchzero/modules/momentum/averaging.py +3 -3
- torchzero/modules/momentum/cautious.py +42 -47
- torchzero/modules/momentum/momentum.py +35 -1
- torchzero/modules/ops/__init__.py +9 -1
- torchzero/modules/ops/binary.py +9 -8
- torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
- torchzero/modules/ops/multi.py +15 -15
- torchzero/modules/ops/reduce.py +1 -1
- torchzero/modules/ops/utility.py +12 -8
- torchzero/modules/projections/projection.py +4 -4
- torchzero/modules/quasi_newton/__init__.py +1 -16
- torchzero/modules/quasi_newton/damping.py +105 -0
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
- torchzero/modules/quasi_newton/lbfgs.py +256 -200
- torchzero/modules/quasi_newton/lsr1.py +167 -132
- torchzero/modules/quasi_newton/quasi_newton.py +346 -446
- torchzero/modules/restarts/__init__.py +7 -0
- torchzero/modules/restarts/restars.py +252 -0
- torchzero/modules/second_order/__init__.py +2 -1
- torchzero/modules/second_order/multipoint.py +238 -0
- torchzero/modules/second_order/newton.py +133 -88
- torchzero/modules/second_order/newton_cg.py +141 -80
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/sampling.py +300 -0
- torchzero/modules/step_size/__init__.py +1 -1
- torchzero/modules/step_size/adaptive.py +312 -47
- torchzero/modules/termination/__init__.py +14 -0
- torchzero/modules/termination/termination.py +207 -0
- torchzero/modules/trust_region/__init__.py +5 -0
- torchzero/modules/trust_region/cubic_regularization.py +170 -0
- torchzero/modules/trust_region/dogleg.py +92 -0
- torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
- torchzero/modules/trust_region/trust_cg.py +97 -0
- torchzero/modules/trust_region/trust_region.py +350 -0
- torchzero/modules/variance_reduction/__init__.py +1 -0
- torchzero/modules/variance_reduction/svrg.py +208 -0
- torchzero/modules/weight_decay/weight_decay.py +65 -64
- torchzero/modules/zeroth_order/__init__.py +1 -0
- torchzero/modules/zeroth_order/cd.py +359 -0
- torchzero/optim/root.py +65 -0
- torchzero/optim/utility/split.py +8 -8
- torchzero/optim/wrappers/directsearch.py +0 -1
- torchzero/optim/wrappers/fcmaes.py +3 -2
- torchzero/optim/wrappers/nlopt.py +0 -2
- torchzero/optim/wrappers/optuna.py +2 -2
- torchzero/optim/wrappers/scipy.py +81 -22
- torchzero/utils/__init__.py +40 -4
- torchzero/utils/compile.py +1 -1
- torchzero/utils/derivatives.py +123 -111
- torchzero/utils/linalg/__init__.py +9 -2
- torchzero/utils/linalg/linear_operator.py +329 -0
- torchzero/utils/linalg/matrix_funcs.py +2 -2
- torchzero/utils/linalg/orthogonalize.py +2 -1
- torchzero/utils/linalg/qr.py +2 -2
- torchzero/utils/linalg/solve.py +226 -154
- torchzero/utils/metrics.py +83 -0
- torchzero/utils/python_tools.py +6 -0
- torchzero/utils/tensorlist.py +105 -34
- torchzero/utils/torch_tools.py +9 -4
- torchzero-0.3.13.dist-info/METADATA +14 -0
- torchzero-0.3.13.dist-info/RECORD +166 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
- docs/source/conf.py +0 -59
- docs/source/docstring template.py +0 -46
- torchzero/modules/experimental/absoap.py +0 -253
- torchzero/modules/experimental/adadam.py +0 -118
- torchzero/modules/experimental/adamY.py +0 -131
- torchzero/modules/experimental/adam_lambertw.py +0 -149
- torchzero/modules/experimental/adaptive_step_size.py +0 -90
- torchzero/modules/experimental/adasoap.py +0 -177
- torchzero/modules/experimental/cosine.py +0 -214
- torchzero/modules/experimental/cubic_adam.py +0 -97
- torchzero/modules/experimental/eigendescent.py +0 -120
- torchzero/modules/experimental/etf.py +0 -195
- torchzero/modules/experimental/exp_adam.py +0 -113
- torchzero/modules/experimental/expanded_lbfgs.py +0 -141
- torchzero/modules/experimental/hnewton.py +0 -85
- torchzero/modules/experimental/modular_lbfgs.py +0 -265
- torchzero/modules/experimental/parabolic_search.py +0 -220
- torchzero/modules/experimental/subspace_preconditioners.py +0 -145
- torchzero/modules/experimental/tensor_adagrad.py +0 -42
- torchzero/modules/line_search/polynomial.py +0 -233
- torchzero/modules/momentum/matrix_momentum.py +0 -193
- torchzero/modules/optimizers/adagrad.py +0 -165
- torchzero/modules/quasi_newton/trust_region.py +0 -397
- torchzero/modules/smoothing/gaussian.py +0 -198
- torchzero-0.3.11.dist-info/METADATA +0 -404
- torchzero-0.3.11.dist-info/RECORD +0 -159
- torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
- /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
- /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
- {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
--- a/torchzero/modules/experimental/modular_lbfgs.py
+++ /dev/null
@@ -1,265 +0,0 @@
-from collections import deque
-from operator import itemgetter
-from typing import Any
-
-import torch
-
-from ...core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
-from ...utils import NumberList, TensorList, as_tensorlist
-
-
-def _adaptive_damping(
-    s_k: TensorList,
-    y_k: TensorList,
-    ys_k: torch.Tensor,
-    init_damping = 0.99,
-    eigval_bounds = (0.01, 1.5)
-):
-    # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
-    sigma_l, sigma_h = eigval_bounds
-    u = ys_k / s_k.dot(s_k)
-    if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
-    elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
-    else: tau = init_damping
-    y_k = tau * y_k + (1-tau) * s_k
-    ys_k = s_k.dot(y_k)
-
-    return s_k, y_k, ys_k
-
-def lbfgs(
-    tensors_: TensorList,
-    var: Var,
-    s_history: deque[TensorList],
-    y_history: deque[TensorList],
-    sy_history: deque[torch.Tensor],
-    y_k: TensorList | None,
-    ys_k: torch.Tensor | None,
-    z_tfm: Any,
-):
-    if len(s_history) == 0 or y_k is None or ys_k is None:
-
-        # initial step size guess modified from pytorch L-BFGS
-        scale = 1 / tensors_.abs().global_sum()
-        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
-
-    # 1st loop
-    alpha_list = []
-    q = tensors_.clone()
-    for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-        p_i = 1 / ys_i # this is also denoted as ρ (rho)
-        alpha = p_i * s_i.dot(q)
-        alpha_list.append(alpha)
-        q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-    # calculate z
-    # s.y/y.y is also this weird y-looking symbol I couldn't find
-    # z is it times q
-    # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-    z = q * (ys_k / (y_k.dot(y_k)))
-
-    if z_tfm is not None:
-        z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
-
-    # 2nd loop
-    for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-        p_i = 1 / ys_i
-        beta_i = p_i * y_i.dot(z)
-        z.add_(s_i, alpha = alpha_i - beta_i)
-
-    return z
-
-def _apply_tfms_into_history(
-    self: Module,
-    params: list[torch.Tensor],
-    var: Var,
-    update: list[torch.Tensor],
-):
-    if 'params_history_tfm' in self.children:
-        params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
-
-    if 'grad_history_tfm' in self.children:
-        update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)
-
-    return params, update
-
-def _apply_tfms_into_precond(
-    self: Module,
-    params: list[torch.Tensor],
-    var: Var,
-    update: list[torch.Tensor],
-):
-    if 'params_precond_tfm' in self.children:
-        params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
-
-    if 'grad_precond_tfm' in self.children:
-        update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)
-
-    return params, update
-
-
-class ModularLBFGS(Module):
-    """L-BFGS with ability to apply transforms to many inner variables.
-
-    Args:
-        history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-        tol (float | None, optional):
-            tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
-        damping (bool, optional):
-            whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-        init_damping (float, optional):
-            initial damping for adaptive dampening. Defaults to 0.9.
-        eigval_bounds (tuple, optional):
-            eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-        update_freq (int, optional):
-            how often to update L-BFGS history. Defaults to 1.
-        z_tfm (float | None, optional):
-            transform module applied to initial H^-1 @ q guess. Defaults to None.
-        params_history_tfm (AnyTransform | None, optional):
-            transform module applied to params before adding s_k to history. Defaults to None.
-        grad_history_tfm (AnyTransform | None, optional):
-            transform module applied to grads before adding y_k to history. Defaults to None.
-        params_precond_tfm (AnyTransform | None, optional):
-            transform module applied to params to calculate s_k before preconditioning. Defaults to None.
-        grad_precond_tfm (AnyTransform | None, optional):
-            transform module applied to grads to calculate y_k before preconditioning. Defaults to None.
-        update_precond_tfm (Chainable | None, optional):
-            transform module applied to grads that are being preconditioned. Defaults to None.
-    """
-    def __init__(
-        self,
-        history_size=10,
-        tol: float | None = 1e-10,
-        damping: bool = False,
-        init_damping=0.9,
-        eigval_bounds=(0.5, 50),
-        update_freq = 1,
-        params_history_tfm: Chainable | None = None,
-        grad_history_tfm: Chainable | None = None,
-        params_precond_tfm: Chainable | None = None,
-        grad_precond_tfm: Chainable | None = None,
-        update_precond_tfm: Chainable | None = None,
-        z_tfm: Chainable | None = None,
-    ):
-        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, update_freq=update_freq)
-        super().__init__(defaults)
-
-        self.global_state['s_history'] = deque(maxlen=history_size)
-        self.global_state['y_history'] = deque(maxlen=history_size)
-        self.global_state['sy_history'] = deque(maxlen=history_size)
-
-        loc = locals().copy()
-        for k in ('update_precond_tfm', 'params_history_tfm', 'grad_history_tfm', 'params_precond_tfm', 'grad_precond_tfm','z_tfm'):
-            v = loc[k]
-            if v is not None:
-                self.set_child(k,v)
-
-    def reset(self):
-        """Resets the internal state of the L-SR1 module."""
-        # super().reset() # Clears self.state (per-parameter) if any, and "step"
-        self.state.clear()
-        self.global_state['step'] = 0
-        self.global_state['s_history'].clear()
-        self.global_state['y_history'].clear()
-        self.global_state['sy_history'].clear()
-
-    @torch.no_grad
-    def step(self, var):
-        params = as_tensorlist(var.params)
-        update = as_tensorlist(var.get_update())
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        # history of s and k
-        s_history: deque[TensorList] = self.global_state['s_history']
-        y_history: deque[TensorList] = self.global_state['y_history']
-        sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-
-        tol, damping, init_damping, eigval_bounds, update_freq = itemgetter(
-            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq')(self.settings[params[0]])
-
-        # params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params, cls=NumberList)
-        # l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-
-        # params and update that go into history
-        params_h, update_h = _apply_tfms_into_history(
-            self,
-            params=params,
-            var=var,
-            update=update,
-        )
-
-        prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)
-
-        # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
-        if step == 0:
-            s_k_h = None; y_k_h = None; ys_k_h = None
-        else:
-            s_k_h = params_h - prev_params_h
-            y_k_h = update_h - prev_grad_h
-            ys_k_h = s_k_h.dot(y_k_h)
-
-            if damping:
-                s_k_h, y_k_h, ys_k_h = _adaptive_damping(s_k_h, y_k_h, ys_k_h, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-        prev_params_h.copy_(params_h)
-        prev_grad_h.copy_(update_h)
-
-        # update effective preconditioning state
-        if step % update_freq == 0:
-            if ys_k_h is not None and ys_k_h > 1e-10:
-                assert s_k_h is not None and y_k_h is not None
-                s_history.append(s_k_h)
-                y_history.append(y_k_h)
-                sy_history.append(ys_k_h)
-
-        # step with inner module before applying preconditioner
-        if 'update_precond_tfm' in self.children:
-            update_precond_tfm = self.children['update_precond_tfm']
-            inner_var = update_precond_tfm.step(var.clone(clone_update=True))
-            var.update_attrs_from_clone_(inner_var)
-            tensors = inner_var.update
-            assert tensors is not None
-        else:
-            tensors = update.clone()
-
-        # transforms into preconditioner
-        params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
-        prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)
-
-        if step == 0:
-            s_k_p = None; y_k_p = None; ys_k_p = None
-
-        else:
-            s_k_p = params_p - prev_params_p
-            y_k_p = update_p - prev_grad_p
-            ys_k_p = s_k_p.dot(y_k_p)
-
-            if damping:
-                s_k_p, y_k_p, ys_k_p = _adaptive_damping(s_k_p, y_k_p, ys_k_p, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-        prev_params_p.copy_(params_p)
-        prev_grad_p.copy_(update_p)
-
-        # tolerance on gradient difference to avoid exploding after converging
-        if tol is not None:
-            if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-                var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                return var
-
-        # precondition
-        dir = lbfgs(
-            tensors_=as_tensorlist(tensors),
-            var=var,
-            s_history=s_history,
-            y_history=y_history,
-            sy_history=sy_history,
-            y_k=y_k_p,
-            ys_k=ys_k_p,
-            z_tfm=self.children.get('z_tfm', None),
-        )
-
-        var.update = dir
-
-        return var
-
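Note (not part of the diff): the removed ModularLBFGS is built around the standard L-BFGS two-loop recursion shown in its `lbfgs` helper above. Below is a minimal, self-contained sketch of that recursion on flat torch tensors, for readers who want the algorithm without the torchzero Module/TensorList machinery; the names `lbfgs_direction`, `s_list`, `y_list` are illustrative, not torchzero API.

import torch

def lbfgs_direction(g, s_list, y_list):
    """Two-loop recursion: returns an approximation of H^-1 @ g (a sketch, not the package's implementation)."""
    q = g.clone()
    rhos = [1.0 / torch.dot(s, y) for s, y in zip(s_list, y_list)]
    alphas = []
    # first loop: newest pair to oldest
    for s, y, rho in zip(reversed(s_list), reversed(y_list), reversed(rhos)):
        alpha = rho * torch.dot(s, q)
        alphas.append(alpha)
        q = q - alpha * y
    # initial inverse-Hessian guess H0 = (s.y / y.y) * I, as in the removed code
    s, y = s_list[-1], y_list[-1]
    z = q * (torch.dot(s, y) / torch.dot(y, y))
    # second loop: oldest pair to newest
    for s, y, rho, alpha in zip(s_list, y_list, rhos, reversed(alphas)):
        beta = rho * torch.dot(y, z)
        z = z + (alpha - beta) * s
    return z

With curvature pairs satisfying s.y > 0 the implicit inverse-Hessian estimate stays positive definite, which is why the module above only appends a pair to history when ys_k_h > 1e-10.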
--- a/torchzero/modules/experimental/parabolic_search.py
+++ /dev/null
@@ -1,220 +0,0 @@
-import math
-from collections.abc import Mapping
-from operator import itemgetter
-
-import torch
-
-from ...core import Module
-from ...utils import TensorList
-
-
-
-def adaptive_tracking(
-    f,
-    f_0,
-    f_1,
-    t_0,
-    maxiter: int
-):
-
-    t = t_0
-    f_t = f(t)
-
-    # backtrack
-    if f_t > f_0:
-        if f_1 > f_0: t = min(0.5, t_0/2)
-        while f_t > f_0:
-            maxiter -= 1
-            if maxiter < 0: return 0, f_0
-            t = t/2
-            f_t = f(t) if t!=1 else f_1
-        return t, f_t
-
-    # forwardtrack
-    f_prev = f_t
-    t *= 2
-    f_t = f(t)
-    if f_prev < f_t: return t/2, f_prev
-    while f_prev >= f_t:
-        maxiter -= 1
-        if maxiter < 0: return t, f_t
-        f_prev = f_t
-        t *= 2
-        f_t = f(t)
-    return t/2, f_prev
-
-
-
-class ParabolaSearch(Module):
-    """"""
-    def __init__(
-        self,
-        step_size: float = 1e-2,
-        adaptive: bool=True,
-        normalize: bool=False,
-        # method: str | None = None,
-        maxiter: int | None = 10,
-        # bracket=None,
-        # bounds=None,
-        # tol: float | None = None,
-        # options=None,
-    ):
-        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
-        super().__init__(defaults)
-
-        import scipy.optimize
-        self.scopt = scipy.optimize
-
-
-    @torch.no_grad
-    def step(self, var):
-        x_0 = TensorList(var.params)
-        closure = var.closure
-        assert closure is not None
-        settings = self.settings[x_0[0]]
-        step_size = settings['step_size']
-        adaptive = settings['adaptive']
-        normalize = settings['normalize']
-        maxiter = settings['maxiter']
-        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-
-        grad = TensorList(var.get_grad())
-        f_0 = var.get_loss(False)
-
-        scale = 1
-        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
-
-        # make step
-        v_0 = grad * (step_size/scale)
-        x_0 -= v_0
-        with torch.enable_grad():
-            f_1 = closure()
-        grad = x_0.grad
-
-        x_0 += v_0
-        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-        v_1 = grad * (step_size/scale)
-        a = v_1 - v_0
-
-        def parabolic_objective(t: float):
-            nonlocal x_0
-
-            step = v_0*t + 0.5*a*t**2
-            x_0 -= step
-            value = closure(False)
-            x_0 += step
-            return value.detach().cpu()
-
-        prev_t = self.global_state.get('prev_t', 2)
-        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
-        self.global_state['prev_t'] = t
-
-        # method, bracket, bounds, tol, options, maxiter = itemgetter(
-        #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
-
-        # if maxiter is not None:
-        #     options = dict(options) if isinstance(options, Mapping) else {}
-        #     options['maxiter'] = maxiter
-
-        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
-        # t = res.x
-
-        var.update = v_0*t + 0.5*a*t**2
-        return var
-
-class CubicParabolaSearch(Module):
-    """"""
-    def __init__(
-        self,
-        step_size: float = 1e-2,
-        adaptive: bool=True,
-        normalize: bool=False,
-        # method: str | None = None,
-        maxiter: int | None = 10,
-        # bracket=None,
-        # bounds=None,
-        # tol: float | None = None,
-        # options=None,
-    ):
-        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
-        super().__init__(defaults)
-
-        import scipy.optimize
-        self.scopt = scipy.optimize
-
-
-    @torch.no_grad
-    def step(self, var):
-        x_0 = TensorList(var.params)
-        closure = var.closure
-        assert closure is not None
-        settings = self.settings[x_0[0]]
-        step_size = settings['step_size']
-        adaptive = settings['adaptive']
-        maxiter = settings['maxiter']
-        normalize = settings['normalize']
-        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-
-        grad = TensorList(var.get_grad())
-        f_0 = var.get_loss(False)
-
-        scale = 1
-        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
-
-        # make step
-        v_0 = grad * (step_size/scale)
-        x_0 -= v_0
-        with torch.enable_grad():
-            f_1 = closure()
-        grad = x_0.grad
-
-        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-        v_1 = grad * (step_size/scale)
-        a_0 = v_1 - v_0
-
-        # make another step
-        x_0 -= v_1
-        with torch.enable_grad():
-            f_2 = closure()
-        grad = x_0.grad
-
-        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-        v_2 = grad * (step_size/scale)
-        a_1 = v_2 - v_1
-
-        j = a_1 - a_0
-
-        x_0 += v_0
-        x_0 += v_1
-
-        def parabolic_objective(t: float):
-            nonlocal x_0
-
-            step = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
-            x_0 -= step
-            value = closure(False)
-            x_0 += step
-            return value
-
-
-        prev_t = self.global_state.get('prev_t', 2)
-        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
-        self.global_state['prev_t'] = t
-
-        # method, bracket, bounds, tol, options, maxiter = itemgetter(
-        #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
-
-        # if maxiter is not None:
-        #     options = dict(options) if isinstance(options, Mapping) else {}
-        #     options['maxiter'] = maxiter
-
-        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
-        # t = res.x
-
-        var.update = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
-        return var
-
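Note (not part of the diff): a reading aid for the removed parabola searches, in the notation of the code above. ParabolaSearch takes two gradient samples, v_0 = (step_size/scale) * grad(x_0) and v_1 = (step_size/scale) * grad(x_0 - v_0), sets a = v_1 - v_0, and then minimizes the loss along the curve

    x(t) = x_0 - (v_0*t + 0.5*a*t**2)

using adaptive_tracking, which halves or doubles t starting from the previously accepted value. CubicParabolaSearch takes one more gradient sample and adds a cubic term, x(t) = x_0 - (v_0*t + 0.5*a_0*t**2 + (1/6)*j*t**3) with j = a_1 - a_0.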
--- a/torchzero/modules/experimental/subspace_preconditioners.py
+++ /dev/null
@@ -1,145 +0,0 @@
-from collections import deque
-
-import torch
-# import visualbench as vb
-
-# import torchzero as tz
-
-from ...core import Transform, Chainable, apply_transform
-from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
-from ...utils import TensorList, vec_to_tensors_
-
-
-def inverse_sqrt(M):
-    if M.shape[-1] == 2: return inv_sqrt_2x2(M, force_pd=True) # general formula for 2x2 matrices
-    return matrix_power_eigh(M, -1/2)
-
-def update_subspace_preconditioner_(
-    grad: torch.Tensor, # store grads and basis as vectors for matmul
-    basis: torch.Tensor, # ndim, k
-    accumulator_: torch.Tensor, # k, k
-    beta: float | None,
-):
-    projected = basis.T @ grad # k
-    outer = torch.outer(projected, projected)
-
-    if beta is None: accumulator_.add_(outer)
-    else: accumulator_.lerp_(outer, 1-beta)
-
-def apply_subspace_preconditioner(
-    tensor: torch.Tensor,
-    basis: torch.Tensor, # ndim, k
-    accumulator: torch.Tensor,
-):
-    preconditioner = inverse_sqrt(accumulator) # k,k
-
-    tensor_projected = basis.T @ tensor # k
-    update_projected = preconditioner @ tensor_projected # k
-    return basis @ update_projected # d
-
-class RandomSubspacePreconditioning(Transform):
-    """Whitens in random slowly changing subspace.
-
-    .. warning::
-        Experimental and this is a barebones implementation.
-
-    """
-    def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
-        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        settings = settings[0]
-        g = torch.cat([t.view(-1) for t in tensors])
-        k = settings['k']
-        beta = settings['beta']
-        basis_beta = settings['basis_beta']
-
-        if 'basis' not in self.global_state:
-            self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-            self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-
-        basis = self.global_state['basis']
-        accumulator = self.global_state['accumulator']
-
-        if basis_beta is not None:
-            basis.lerp_(torch.randn_like(basis), 1-basis_beta)
-
-        update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-        if 'inner' in self.children:
-            tensors = apply_transform(self.children['inner'], tensors, params, grads)
-            g = torch.cat([t.view(-1) for t in tensors])
-
-        try:
-            preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-        except torch.linalg.LinAlgError:
-            preconditioned = g.clip(-0.1, 0.1)
-        vec_to_tensors_(preconditioned, tensors)
-
-        return tensors
-
-
-class HistorySubspacePreconditioning(Transform):
-    """Whitens in subspace spanned by history of gradient differences.
-
-    .. warning::
-        Experimental and this is a barebones implementation.
-
-    Args:
-        beta - for preconditioner itself in the basis.
-        basis_beta - how much basis is allowed to change.
-    """
-    def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
-        defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None: self.set_child('inner', inner)
-
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        settings = settings[0]
-
-        g = torch.cat([t.view(-1) for t in tensors])
-        k = settings['k']
-        beta = settings['beta']
-        basis_beta = settings['basis_beta']
-
-        if 'history' not in self.global_state:
-            self.global_state['history'] = deque(maxlen=k)
-            self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-            self.global_state['basis'] = torch.ones(g.numel(), k, device=g.device, dtype=g.dtype)
-
-
-        history: deque = self.global_state['history']
-        accumulator = self.global_state['accumulator']
-        basis = self.global_state['basis']
-
-        history.append(g)
-        if len(history) < k:
-            basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-            history_basis = torch.stack(tuple(history), -1)
-            basis_t[:, -len(history):] = history_basis
-
-        else:
-            basis_t = torch.stack(tuple(history), -1)
-
-        basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]
-        basis_t = (basis_t - basis_t.mean()) / basis_t.std()
-
-        basis.lerp_(basis_t, 1-basis_beta)
-        update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-        if 'inner' in self.children:
-            tensors = apply_transform(self.children['inner'], tensors, params, grads)
-            g = torch.cat([t.view(-1) for t in tensors])
-
-        try:
-            preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-        except torch.linalg.LinAlgError:
-            preconditioned = g.clip(-0.1,0.1)
-        vec_to_tensors_(preconditioned, tensors)
-
-        return tensors
-
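Note (not part of the diff): both removed subspace preconditioners share one idea, in the notation of the code above. Given a basis B of shape (ndim, k), the accumulator tracks C, roughly a running average of (B^T g)(B^T g)^T (a lerp with weight 1 - beta, or a plain sum when beta is None), and the step is

    update = B @ C^(-1/2) @ B^T @ g

that is, Adagrad-style whitening restricted to the k-dimensional subspace. RandomSubspacePreconditioning drifts B toward fresh Gaussian noise each step; HistorySubspacePreconditioning builds B from recent gradients and their successive differences.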
--- a/torchzero/modules/experimental/tensor_adagrad.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from collections import deque
-
-import torch
-
-from ...core import Chainable, TensorwiseTransform
-from ...utils.linalg import matrix_power_eigh
-
-
-class TensorAdagrad(TensorwiseTransform):
-    """3rd order whitening (maybe normalizes skewness, but don't quote me on it).
-
-    .. warning::
-        Experimental.
-    """
-    def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
-        defaults = dict(history_size=history_size, reg=reg)
-        super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
-
-    @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, setting):
-        reg = setting['reg']
-        if 'history' not in state:
-            state['history'] = deque(maxlen=setting['history_size'])
-
-        g = tensor.view(-1)
-        history = state['history']
-        history.append(g.clone())
-
-        I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
-        g_k = history[0]
-        outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-        if len(history) > 1:
-            for g_k in list(history)[1:]:
-                outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-
-        state['outer'] = outer.add_(I)
-
-    @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
-        outer = state['outer']
-        P = matrix_power_eigh(outer, -1/2)
-        return (P @ tensor.ravel()).view_as(tensor)
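Note (not part of the diff): in the notation of the code above, the removed TensorAdagrad accumulates, per tensor,

    C = reg * I + sum_k max(<g_k, g>, reg) * g_k g_k^T

over the stored gradient history and applies C^(-1/2) to the current flattened gradient. Weighting each outer product g_k g_k^T by its alignment with the current gradient g is what the "3rd order whitening" remark in its docstring refers to.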