torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py (new file, per the +265-line entry in the file list above)
@@ -0,0 +1,265 @@

```python
from collections import deque
from operator import itemgetter
from typing import Any

import torch

from ....core import Chainable, Module, Transform, Vars, apply, maybe_chain
from ....utils import NumberList, TensorList, as_tensorlist


def _adaptive_damping(
    s_k: TensorList,
    y_k: TensorList,
    ys_k: torch.Tensor,
    init_damping = 0.99,
    eigval_bounds = (0.01, 1.5)
):
    # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
    sigma_l, sigma_h = eigval_bounds
    u = ys_k / s_k.dot(s_k)
    if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
    elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
    else: tau = init_damping
    y_k = tau * y_k + (1-tau) * s_k
    ys_k = s_k.dot(y_k)

    return s_k, y_k, ys_k

def lbfgs(
    tensors_: TensorList,
    vars: Vars,
    s_history: deque[TensorList],
    y_history: deque[TensorList],
    sy_history: deque[torch.Tensor],
    y_k: TensorList | None,
    ys_k: torch.Tensor | None,
    z_tfm: Any,
):
    if len(s_history) == 0 or y_k is None or ys_k is None:
        # dir = params.grad.sign() # may work fine

        # initial step size guess taken from pytorch L-BFGS
        return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]

    else:
        # 1st loop
        alpha_list = []
        q = tensors_.clone()
        for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
            p_i = 1 / ys_i # this is also denoted as ρ (rho)
            alpha = p_i * s_i.dot(q)
            alpha_list.append(alpha)
            q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]

        # calculate z
        # s.y/y.y is also this weird y-looking symbol I couldn't find
        # z is it times q
        # actually H0 = (s.y/y.y) * I, and z = H0 @ q
        z = q * (ys_k / (y_k.dot(y_k)))

        if z_tfm is not None:
            z = TensorList(apply(z_tfm, tensors=z, params=vars.params, grads=vars.grad, vars=vars))

        # 2nd loop
        for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
            p_i = 1 / ys_i
            beta_i = p_i * y_i.dot(z)
            z.add_(s_i, alpha = alpha_i - beta_i)

        return z

def _apply_tfms_into_history(
    self: Module,
    params: list[torch.Tensor],
    vars: Vars,
    update: list[torch.Tensor],
):
    if 'params_history_tfm' in self.children:
        params = apply(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)

    if 'grad_history_tfm' in self.children:
        update = apply(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=vars.grad, vars=vars)

    return params, update

def _apply_tfms_into_precond(
    self: Module,
    params: list[torch.Tensor],
    vars: Vars,
    update: list[torch.Tensor],
):
    if 'params_precond_tfm' in self.children:
        params = apply(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)

    if 'grad_precond_tfm' in self.children:
        update = apply(self.children['grad_precond_tfm'], tensors=update, params=params, grads=vars.grad, vars=vars)

    return params, update


class ModularLBFGS(Module):
    """L-BFGS with ability to apply transforms to many inner variables.

    Args:
        history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
        tol (float | None, optional):
            tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
        damping (bool, optional):
            whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
        init_damping (float, optional):
            initial damping for adaptive dampening. Defaults to 0.9.
        eigval_bounds (tuple, optional):
            eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
        update_freq (int, optional):
            how often to update L-BFGS history. Defaults to 1.
        z_tfm (float | None, optional):
            transform module applied to initial H^-1 @ q guess. Defaults to None.
        params_history_tfm (AnyTransform | None, optional):
            transform module applied to params before adding s_k to history. Defaults to None.
        grad_history_tfm (AnyTransform | None, optional):
            transform module applied to grads before adding y_k to history. Defaults to None.
        params_precond_tfm (AnyTransform | None, optional):
            transform module applied to params to calculate s_k before preconditioning. Defaults to None.
        grad_precond_tfm (AnyTransform | None, optional):
            transform module applied to grads to calculate y_k before preconditioning. Defaults to None.
        update_precond_tfm (Chainable | None, optional):
            transform module applied to grads that are being preconditioned. Defaults to None.
    """
    def __init__(
        self,
        history_size=10,
        tol: float | None = 1e-10,
        damping: bool = False,
        init_damping=0.9,
        eigval_bounds=(0.5, 50),
        update_freq = 1,
        params_history_tfm: Chainable | None = None,
        grad_history_tfm: Chainable | None = None,
        params_precond_tfm: Chainable | None = None,
        grad_precond_tfm: Chainable | None = None,
        update_precond_tfm: Chainable | None = None,
        z_tfm: Chainable | None = None,
    ):
        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, update_freq=update_freq)
        super().__init__(defaults)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)
        self.global_state['sy_history'] = deque(maxlen=history_size)

        loc = locals().copy()
        for k in ('update_precond_tfm', 'params_history_tfm', 'grad_history_tfm', 'params_precond_tfm', 'grad_precond_tfm','z_tfm'):
            v = loc[k]
            if v is not None:
                self.set_child(k,v)

    def reset(self):
        """Resets the internal state of the L-SR1 module."""
        # super().reset() # Clears self.state (per-parameter) if any, and "step"
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()
        self.global_state['sy_history'].clear()

    @torch.no_grad
    def step(self, vars):
        params = as_tensorlist(vars.params)
        update = as_tensorlist(vars.get_update())
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # history of s and k
        s_history: deque[TensorList] = self.global_state['s_history']
        y_history: deque[TensorList] = self.global_state['y_history']
        sy_history: deque[torch.Tensor] = self.global_state['sy_history']

        tol, damping, init_damping, eigval_bounds, update_freq = itemgetter(
            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq')(self.settings[params[0]])

        # params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params, cls=NumberList)
        # l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)

        # params and update that go into history
        params_h, update_h = _apply_tfms_into_history(
            self,
            params=params,
            vars=vars,
            update=update,
        )

        prev_params_h, prev_grad_h = self.get_state('prev_params_h', 'prev_grad_h', params=params, cls=TensorList)

        # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
        if step == 0:
            s_k_h = None; y_k_h = None; ys_k_h = None
        else:
            s_k_h = params_h - prev_params_h
            y_k_h = update_h - prev_grad_h
            ys_k_h = s_k_h.dot(y_k_h)

            if damping:
                s_k_h, y_k_h, ys_k_h = _adaptive_damping(s_k_h, y_k_h, ys_k_h, init_damping=init_damping, eigval_bounds=eigval_bounds)

        prev_params_h.copy_(params_h)
        prev_grad_h.copy_(update_h)

        # update effective preconditioning state
        if step % update_freq == 0:
            if ys_k_h is not None and ys_k_h > 1e-10:
                assert s_k_h is not None and y_k_h is not None
                s_history.append(s_k_h)
                y_history.append(y_k_h)
                sy_history.append(ys_k_h)

        # step with inner module before applying preconditioner
        if 'update_precond_tfm' in self.children:
            update_precond_tfm = self.children['update_precond_tfm']
            inner_vars = update_precond_tfm.step(vars.clone(clone_update=True))
            vars.update_attrs_from_clone_(inner_vars)
            tensors = inner_vars.update
            assert tensors is not None
        else:
            tensors = update.clone()

        # transforms into preconditioner
        params_p, update_p = _apply_tfms_into_precond(self, params=params, vars=vars, update=update)
        prev_params_p, prev_grad_p = self.get_state('prev_params_p', 'prev_grad_p', params=params, cls=TensorList)

        if step == 0:
            s_k_p = None; y_k_p = None; ys_k_p = None

        else:
            s_k_p = params_p - prev_params_p
            y_k_p = update_p - prev_grad_p
            ys_k_p = s_k_p.dot(y_k_p)

            if damping:
                s_k_p, y_k_p, ys_k_p = _adaptive_damping(s_k_p, y_k_p, ys_k_p, init_damping=init_damping, eigval_bounds=eigval_bounds)

        prev_params_p.copy_(params_p)
        prev_grad_p.copy_(update_p)

        # tolerance on gradient difference to avoid exploding after converging
        if tol is not None:
            if y_k_p is not None and y_k_p.abs().global_max() <= tol:
                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
                return vars

        # precondition
        dir = lbfgs(
            tensors_=as_tensorlist(tensors),
            vars=vars,
            s_history=s_history,
            y_history=y_history,
            sy_history=sy_history,
            y_k=y_k_p,
            ys_k=ys_k_p,
            z_tfm=self.children.get('z_tfm', None),
        )

        vars.update = dir

        return vars
```
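For readers skimming the hunk above, `_adaptive_damping` blends the gradient difference toward the parameter difference whenever the curvature ratio leaves the configured bounds; this is the damping the in-code comment attributes to Al-Baali (2004). A restatement of what the function computes, in notation chosen here (σ_l, σ_h are `eigval_bounds`, τ₀ is `init_damping`):

```latex
u = \frac{s_k^\top y_k}{s_k^\top s_k}, \qquad
\tau =
\begin{cases}
\min\!\left(\frac{1-\sigma_l}{1-u},\, \tau_0\right) & \text{if } u \le \sigma_l < 1,\\
\min\!\left(\frac{\sigma_h-1}{u-1},\, \tau_0\right) & \text{if } u \ge \sigma_h > 1,\\
\tau_0 & \text{otherwise,}
\end{cases}
\qquad
\bar{y}_k = \tau\, y_k + (1-\tau)\, s_k, \qquad
\overline{s_k^\top y_k} = s_k^\top \bar{y}_k .
```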
torchzero/modules/quasi_newton/lbfgs.py (new file, per the +228-line entry in the file list above)
@@ -0,0 +1,228 @@

```python
from collections import deque
from operator import itemgetter
import torch

from ...core import Transform, Chainable, Module, Vars, apply
from ...utils import TensorList, as_tensorlist, NumberList


def _adaptive_damping(
    s_k: TensorList,
    y_k: TensorList,
    ys_k: torch.Tensor,
    init_damping = 0.99,
    eigval_bounds = (0.01, 1.5)
):
    # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
    sigma_l, sigma_h = eigval_bounds
    u = ys_k / s_k.dot(s_k)
    if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
    elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
    else: tau = init_damping
    y_k = tau * y_k + (1-tau) * s_k
    ys_k = s_k.dot(y_k)

    return s_k, y_k, ys_k

def lbfgs(
    tensors_: TensorList,
    s_history: deque[TensorList],
    y_history: deque[TensorList],
    sy_history: deque[torch.Tensor],
    y_k: TensorList | None,
    ys_k: torch.Tensor | None,
    z_beta: float | None,
    z_ema: TensorList | None,
    step: int,
):
    if len(s_history) == 0 or y_k is None or ys_k is None:
        # dir = params.grad.sign() # may work fine

        # initial step size guess taken from pytorch L-BFGS
        return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]

    else:
        # 1st loop
        alpha_list = []
        q = tensors_.clone()
        for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
            p_i = 1 / ys_i # this is also denoted as ρ (rho)
            alpha = p_i * s_i.dot(q)
            alpha_list.append(alpha)
            q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]

        # calculate z
        # s.y/y.y is also this weird y-looking symbol I couldn't find
        # z is it times q
        # actually H0 = (s.y/y.y) * I, and z = H0 @ q
        z = q * (ys_k / (y_k.dot(y_k)))

        # an attempt into adding momentum, lerping initial z seems stable compared to other variables
        if z_beta is not None:
            assert z_ema is not None
            if step == 0: z_ema.copy_(z)
            else: z_ema.lerp(z, 1-z_beta)
            z = z_ema

        # 2nd loop
        for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
            p_i = 1 / ys_i
            beta_i = p_i * y_i.dot(z)
            z.add_(s_i, alpha = alpha_i - beta_i)

        return z

def _lerp_params_update_(
    self_: Module,
    params: list[torch.Tensor],
    update: list[torch.Tensor],
    params_beta: list[float | None],
    grads_beta: list[float | None],
):
    for i, (p, u, p_beta, u_beta) in enumerate(zip(params.copy(), update.copy(), params_beta, grads_beta)):
        if p_beta is not None or u_beta is not None:
            state = self_.state[p]

            if p_beta is not None:
                if 'param_ema' not in state: state['param_ema'] = p.clone()
                else: state['param_ema'].lerp_(p, 1-p_beta)
                params[i] = state['param_ema']

            if u_beta is not None:
                if 'grad_ema' not in state: state['grad_ema'] = u.clone()
                else: state['grad_ema'].lerp_(u, 1-u_beta)
                update[i] = state['grad_ema']

    return TensorList(params), TensorList(update)

class LBFGS(Module):
    """L-BFGS

    Args:
        history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
        tol (float | None, optional):
            tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
        damping (bool, optional):
            whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
        init_damping (float, optional):
            initial damping for adaptive dampening. Defaults to 0.9.
        eigval_bounds (tuple, optional):
            eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
        params_beta (float | None, optional):
            if not None, EMA of parameters is used for preconditioner update. Defaults to None.
        grads_beta (float | None, optional):
            if not None, EMA of gradients is used for preconditioner update. Defaults to None.
        update_freq (int, optional):
            how often to update L-BFGS history. Defaults to 1.
        z_beta (float | None, optional):
            optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
        tol_reset (bool, optional):
            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
        inner (Chainable | None, optional):
            optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        history_size=10,
        tol: float | None = 1e-10,
        damping: bool = False,
        init_damping=0.9,
        eigval_bounds=(0.5, 50),
        params_beta: float | None = None,
        grads_beta: float | None = None,
        update_freq = 1,
        z_beta: float | None = None,
        tol_reset: bool = False,
        inner: Chainable | None = None,
    ):
        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
        super().__init__(defaults)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)
        self.global_state['sy_history'] = deque(maxlen=history_size)

        if inner is not None:
            self.set_child('inner', inner)

    def reset(self):
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()
        self.global_state['sy_history'].clear()

    @torch.no_grad
    def step(self, vars):
        params = as_tensorlist(vars.params)
        update = as_tensorlist(vars.get_update())
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        # history of s and k
        s_history: deque[TensorList] = self.global_state['s_history']
        y_history: deque[TensorList] = self.global_state['y_history']
        sy_history: deque[torch.Tensor] = self.global_state['sy_history']

        tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
        params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params)

        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)

        # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
        if step == 0:
            s_k = None; y_k = None; ys_k = None
        else:
            s_k = l_params - prev_l_params
            y_k = l_update - prev_l_grad
            ys_k = s_k.dot(y_k)

            if damping:
                s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)

        prev_l_params.copy_(l_params)
        prev_l_grad.copy_(l_update)

        # update effective preconditioning state
        if step % update_freq == 0:
            if ys_k is not None and ys_k > 1e-10:
                assert s_k is not None and y_k is not None
                s_history.append(s_k)
                y_history.append(y_k)
                sy_history.append(ys_k)

        # step with inner module before applying preconditioner
        if self.children:
            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))

        # tolerance on gradient difference to avoid exploding after converging
        if tol is not None:
            if y_k is not None and y_k.abs().global_max() <= tol:
                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
                if tol_reset: self.reset()
                return vars

        # lerp initial H^-1 @ q guess
        z_ema = None
        if z_beta is not None:
            z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)

        # precondition
        dir = lbfgs(
            tensors_=as_tensorlist(update),
            s_history=s_history,
            y_history=y_history,
            sy_history=sy_history,
            y_k=y_k,
            ys_k=ys_k,
            z_beta = z_beta,
            z_ema = z_ema,
            step=step
        )

        vars.update = dir

        return vars
```
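The `lbfgs` helper above is the TensorList form of the standard two-loop recursion. As a point of reference only, here is a minimal, self-contained sketch of that recursion on flat `torch.Tensor`s (plain PyTorch; the history bookkeeping, damping, `z_beta` EMA and module plumbing from the hunk are deliberately omitted):

```python
from collections import deque
import torch

def two_loop_recursion(g: torch.Tensor,
                       s_hist: deque[torch.Tensor],
                       y_hist: deque[torch.Tensor]) -> torch.Tensor:
    """Approximate H @ g, where H is the L-BFGS inverse-Hessian approximation built
    from parameter differences s_k and gradient differences y_k."""
    if not s_hist:
        # no curvature information yet: scaled gradient step, as in the hunk above
        return g * min(1.0, 1.0 / float(g.abs().sum()))

    q = g.clone()
    alphas = []
    # first loop: newest to oldest
    for s, y in zip(reversed(s_hist), reversed(y_hist)):
        rho = 1.0 / y.dot(s)
        alpha = rho * s.dot(q)
        alphas.append(alpha)
        q -= alpha * y

    # initial inverse-Hessian guess H0 = (s.y / y.y) * I applied to q
    s, y = s_hist[-1], y_hist[-1]
    z = q * (s.dot(y) / y.dot(y))

    # second loop: oldest to newest
    for s, y, alpha in zip(s_hist, y_hist, reversed(alphas)):
        rho = 1.0 / y.dot(s)
        beta = rho * y.dot(z)
        z += (alpha - beta) * s

    return z
```

The module above differs from this sketch in a few ways: it scales H0 with the current (s_k, y_k) pair that is passed in explicitly rather than the last history entry, it can damp that pair, it can lerp the initial z through an EMA (`z_beta`), and it skips preconditioning entirely when the gradient difference falls below `tol`.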
torchzero/modules/quasi_newton/lsr1.py (new file, per the +170-line entry in the file list above)
@@ -0,0 +1,170 @@

```python
from collections import deque
from operator import itemgetter

import torch

from ...core import Chainable, Module, Transform, Vars, apply
from ...utils import NumberList, TensorList, as_tensorlist

from .lbfgs import _lerp_params_update_

def lsr1_(
    tensors_: TensorList,
    s_history: deque[TensorList],
    y_history: deque[TensorList],
    step: int,
    scale_second: bool,
):
    if step == 0 or not s_history:
        # initial step size guess from pytorch
        tensors_.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
        return tensors_

    m = len(s_history)

    w_list: list[TensorList] = []
    ww_list: list = [None for _ in range(m)]
    wy_list: list = [None for _ in range(m)]

    # 1st loop - all w_k = s_k - H_k_prev y_k
    for k in range(m):
        s_k = s_history[k]
        y_k = y_history[k]

        H_k = y_k.clone()
        for j in range(k):
            w_j = w_list[j]
            y_j = y_history[j]

            wy = wy_list[j]
            if wy is None: wy = wy_list[j] = w_j.dot(y_j)

            ww = ww_list[j]
            if ww is None: ww = ww_list[j] = w_j.dot(w_j)

            if wy == 0: continue

            H_k.add_(w_j, alpha=w_j.dot(y_k) / wy) # pyright:ignore[reportArgumentType]

        w_k = s_k - H_k
        w_list.append(w_k)

    Hx = tensors_.clone()
    for k in range(m):
        w_k = w_list[k]
        y_k = y_history[k]
        wy = wy_list[k]
        ww = ww_list[k]

        if wy is None: wy = w_k.dot(y_k) # this happens when m = 1 so inner loop doesn't run
        if ww is None: ww = w_k.dot(w_k)

        if wy == 0: continue

        Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]

    if scale_second and step == 1:
        Hx.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
    return Hx


class LSR1(Module):
    """Limited Memory SR1 (L-SR1)
    Args:
        history_size (int, optional): Number of past parameter differences (s)
            and gradient differences (y) to store. Defaults to 10.
        skip_R_val (float, optional): Tolerance R for the SR1 update skip condition
            |w_k^T y_k| >= R * ||w_k|| * ||y_k||. Defaults to 1e-8.
            Updates where this condition is not met are skipped during history accumulation
            and matrix-vector products.
        params_beta (float | None, optional): If not None, EMA of parameters is used for
            preconditioner update (s_k vector). Defaults to None.
        grads_beta (float | None, optional): If not None, EMA of gradients is used for
            preconditioner update (y_k vector). Defaults to None.
        update_freq (int, optional): How often to update L-SR1 history. Defaults to 1.
        conv_tol (float | None, optional): Tolerance for y_k norm. If max abs value of y_k
            is below this, the preconditioning step might be skipped, assuming convergence.
            Defaults to 1e-10.
        inner (Chainable | None, optional): Optional inner modules applied after updating
            L-SR1 history and before preconditioning. Defaults to None.
    """
    def __init__(
        self,
        history_size: int = 10,
        tol: float = 1e-8,
        params_beta: float | None = None,
        grads_beta: float | None = None,
        update_freq: int = 1,
        scale_second: bool = True,
        inner: Chainable | None = None,
    ):
        defaults = dict(
            history_size=history_size, tol=tol,
            params_beta=params_beta, grads_beta=grads_beta,
            update_freq=update_freq, scale_second=scale_second
        )
        super().__init__(defaults)

        self.global_state['s_history'] = deque(maxlen=history_size)
        self.global_state['y_history'] = deque(maxlen=history_size)

        if inner is not None:
            self.set_child('inner', inner)

    def reset(self):
        self.state.clear()
        self.global_state['step'] = 0
        self.global_state['s_history'].clear()
        self.global_state['y_history'].clear()


    @torch.no_grad
    def step(self, vars: Vars):
        params = as_tensorlist(vars.params)
        update = as_tensorlist(vars.get_update())
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1

        s_history: deque[TensorList] = self.global_state['s_history']
        y_history: deque[TensorList] = self.global_state['y_history']

        settings = self.settings[params[0]]
        tol, update_freq, scale_second = itemgetter('tol', 'update_freq', 'scale_second')(settings)

        params_beta, grads_beta_ = self.get_settings('params_beta', 'grads_beta', params=params) # type: ignore
        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)

        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)

        y_k = None
        if step != 0:
            if step % update_freq == 0:
                s_k = l_params - prev_l_params
                y_k = l_update - prev_l_grad

                s_history.append(s_k)
                y_history.append(y_k)

        prev_l_params.copy_(l_params)
        prev_l_grad.copy_(l_update)

        if 'inner' in self.children:
            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))

        # tolerance on gradient difference to avoid exploding after converging
        if tol is not None:
            if y_k is not None and y_k.abs().global_max() <= tol:
                vars.update = update
                return vars

        dir = lsr1_(
            tensors_=update,
            s_history=s_history,
            y_history=y_history,
            step=step,
            scale_second=scale_second,
        )

        vars.update = dir

        return vars
```
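For reference, `lsr1_` above applies the limited-memory SR1 inverse-Hessian recursion H_k = H_{k-1} + w_k w_kᵀ / (w_kᵀ y_k) with w_k = s_k - H_{k-1} y_k and H_0 = I. A minimal flat-tensor sketch of the same two loops (plain PyTorch; the caching of wᵀy and wᵀw, the `continue`-based skips, and the step-1 rescaling from the hunk are simplified away here):

```python
import torch

def lsr1_apply(g: torch.Tensor,
               s_hist: list[torch.Tensor],
               y_hist: list[torch.Tensor]) -> torch.Tensor:
    """Apply the limited-memory SR1 inverse-Hessian approximation to g:
    H_0 = I and H_k = H_{k-1} + w_k w_k^T / (w_k^T y_k), w_k = s_k - H_{k-1} y_k."""
    w_list: list[torch.Tensor] = []

    # build the w_k vectors (mirrors the first loop of lsr1_)
    for k, (s_k, y_k) in enumerate(zip(s_hist, y_hist)):
        Hy = y_k.clone()  # H_{k-1} @ y_k, starting from H_0 = I
        for j in range(k):
            w_j, y_j = w_list[j], y_hist[j]
            wy = w_j.dot(y_j)
            if wy != 0:
                Hy += w_j * (w_j.dot(y_k) / wy)
        w_list.append(s_k - Hy)

    # apply H_m to g (mirrors the second loop of lsr1_)
    Hg = g.clone()
    for w_k, y_k in zip(w_list, y_hist):
        wy = w_k.dot(y_k)
        if wy != 0:
            Hg += w_k * (w_k.dot(g) / wy)
    return Hg
```

Unlike BFGS, the SR1 update does not guarantee a positive definite approximation; the hunk above guards against the worst degeneracies by skipping terms with w_kᵀ y_k = 0 and by the `tol` check on y_k before preconditioning.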