torchzero-0.3.10-py3-none-any.whl → torchzero-0.3.11-py3-none-any.whl
This diff shows the changes between the two publicly released package versions as they appear in their public registry and is provided for informational purposes only.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/lbfgs.py

@@ -1,77 +1,76 @@
 from collections import deque
 from operator import itemgetter
+
 import torch
 
-from ...core import
-from ...utils import TensorList, as_tensorlist,
+from ...core import Chainable, Module, Transform, Var, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ..functional import safe_scaling_
 
 
 def _adaptive_damping(
-
-
-
+    s: TensorList,
+    y: TensorList,
+    sy: torch.Tensor,
     init_damping = 0.99,
     eigval_bounds = (0.01, 1.5)
 ):
     # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
     sigma_l, sigma_h = eigval_bounds
-    u =
+    u = sy / s.dot(s)
     if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
     elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
     else: tau = init_damping
-
-
+    y = tau * y + (1-tau) * s
+    sy = s.dot(y)
 
-    return
+    return s, y, sy
 
 def lbfgs(
     tensors_: TensorList,
     s_history: deque[TensorList],
     y_history: deque[TensorList],
     sy_history: deque[torch.Tensor],
-
-
+    y: TensorList | None,
+    sy: torch.Tensor | None,
     z_beta: float | None,
     z_ema: TensorList | None,
     step: int,
 ):
-    if len(s_history) == 0 or
+    if len(s_history) == 0 or y is None or sy is None:
 
         # initial step size guess modified from pytorch L-BFGS
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        z.add_(s_i, alpha = alpha_i - beta_i)
-
-    return z
+        return safe_scaling_(TensorList(tensors_))
+
+    # 1st loop
+    alpha_list = []
+    q = tensors_.clone()
+    for s_i, y_i, sy_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+        p_i = 1 / sy_i # this is also denoted as ρ (rho)
+        alpha = p_i * s_i.dot(q)
+        alpha_list.append(alpha)
+        q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+    # calculate z
+    # s.y/y.y is also this weird y-looking symbol I couldn't find
+    # z is it times q
+    # actually H0 = (s.y/y.y) * I, and z = H0 @ q
+    z = q * (sy / (y.dot(y)))
+
+    # an attempt into adding momentum, lerping initial z seems stable compared to other variables
+    if z_beta is not None:
+        assert z_ema is not None
+        if step == 1: z_ema.copy_(z)
+        else: z_ema.lerp(z, 1-z_beta)
+        z = z_ema
+
+    # 2nd loop
+    for s_i, y_i, sy_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+        p_i = 1 / sy_i
+        beta_i = p_i * y_i.dot(z)
+        z.add_(s_i, alpha = alpha_i - beta_i)
+
+    return z
 
 def _lerp_params_update_(
     self_: Module,

@@ -96,19 +95,24 @@ def _lerp_params_update_(
 
     return TensorList(params), TensorList(update)
 
-class LBFGS(
-    """L-BFGS
+class LBFGS(Transform):
+    """Limited-memory BFGS algorithm. A line search is recommended, although L-BFGS may be reasonably stable without it.
 
     Args:
-        history_size (int, optional):
-
-            tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
+        history_size (int, optional):
+            number of past parameter differences and gradient differences to store. Defaults to 10.
         damping (bool, optional):
             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
         init_damping (float, optional):
            initial damping for adaptive dampening. Defaults to 0.9.
         eigval_bounds (tuple, optional):
            eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
+        tol (float | None, optional):
+            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
+        tol_reset (bool, optional):
+            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+        gtol (float | None, optional):
+            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
         params_beta (float | None, optional):
             if not None, EMA of parameters is used for preconditioner update. Defaults to None.
         grads_beta (float | None, optional):

@@ -117,35 +121,62 @@ class LBFGS(Module):
             how often to update L-BFGS history. Defaults to 1.
         z_beta (float | None, optional):
             optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
-        tol_reset (bool, optional):
-            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
         inner (Chainable | None, optional):
             optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
+
+    Examples:
+        L-BFGS with strong-wolfe line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(100),
+                tz.m.StrongWolfe()
+            )
+
+        Dampened L-BFGS
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(damping=True),
+                tz.m.StrongWolfe()
+            )
+
+        L-BFGS preconditioning applied to momentum (may be unstable!)
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(inner=tz.m.EMA(0.9)),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
         history_size=10,
-        tol: float | None = 1e-10,
         damping: bool = False,
         init_damping=0.9,
         eigval_bounds=(0.5, 50),
+        tol: float | None = 1e-10,
+        tol_reset: bool = False,
+        gtol: float | None = 1e-10,
         params_beta: float | None = None,
         grads_beta: float | None = None,
         update_freq = 1,
         z_beta: float | None = None,
-        tol_reset: bool = False,
         inner: Chainable | None = None,
     ):
-        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
-        super().__init__(defaults)
+        defaults = dict(history_size=history_size, tol=tol, gtol=gtol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
+        super().__init__(defaults, uses_grad=False, inner=inner)
 
         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
         self.global_state['sy_history'] = deque(maxlen=history_size)
 
-        if inner is not None:
-            self.set_child('inner', inner)
-
     def reset(self):
         self.state.clear()
         self.global_state['step'] = 0

@@ -153,10 +184,15 @@ class LBFGS(Module):
         self.global_state['y_history'].clear()
         self.global_state['sy_history'].clear()
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_l_params', 'prev_l_grad')
+        self.global_state.pop('step', None)
+
     @torch.no_grad
-    def
-        params = as_tensorlist(
-        update = as_tensorlist(
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        params = as_tensorlist(params)
+        update = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 

@@ -165,65 +201,86 @@ class LBFGS(Module):
         y_history: deque[TensorList] = self.global_state['y_history']
         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
 
-
-
-        params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')
+        damping,init_damping,eigval_bounds,update_freq = itemgetter('damping','init_damping','eigval_bounds','update_freq')(settings[0])
+        params_beta, grads_beta = unpack_dicts(settings, 'params_beta', 'grads_beta')
 
         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-        prev_l_params, prev_l_grad =
+        prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)
 
-        # 1st step - there are no previous params and grads,
+        # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
         if step == 0:
-
+            s = None; y = None; sy = None
         else:
-
-
-
+            s = l_params - prev_l_params
+            y = l_update - prev_l_grad
+            sy = s.dot(y)
 
             if damping:
-
+                s, y, sy = _adaptive_damping(s, y, sy, init_damping=init_damping, eigval_bounds=eigval_bounds)
 
         prev_l_params.copy_(l_params)
         prev_l_grad.copy_(l_update)
 
         # update effective preconditioning state
         if step % update_freq == 0:
-            if
-            assert
-            s_history.append(
-            y_history.append(
-            sy_history.append(
+            if sy is not None and sy > 1e-10:
+                assert s is not None and y is not None
+                s_history.append(s)
+                y_history.append(y)
+                sy_history.append(sy)
+
+        # store for apply
+        self.global_state['s'] = s
+        self.global_state['y'] = y
+        self.global_state['sy'] = sy
+
+    def make_Hv(self):
+        ...
 
-
-
-            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
+    def make_Bv(self):
+        ...
 
-
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+
+        s = self.global_state.pop('s')
+        y = self.global_state.pop('y')
+        sy = self.global_state.pop('sy')
+
+        setting = settings[0]
+        tol = setting['tol']
+        gtol = setting['gtol']
+        tol_reset = setting['tol_reset']
+        z_beta = setting['z_beta']
+
+        # tolerance on parameter difference to avoid exploding after converging
         if tol is not None:
-            if
-            var.update = update # may have been updated by inner module, probably makes sense to use it here?
+            if s is not None and s.abs().global_max() <= tol:
                 if tol_reset: self.reset()
-                return
+                return safe_scaling_(TensorList(tensors))
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if tol is not None:
+            if y is not None and y.abs().global_max() <= gtol:
+                return safe_scaling_(TensorList(tensors))
 
         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema =
+            z_ema = unpack_states(states, tensors, 'z_ema', cls=TensorList)
 
         # precondition
         dir = lbfgs(
-            tensors_=
-            s_history=s_history,
-            y_history=y_history,
-            sy_history=sy_history,
-
-
+            tensors_=tensors,
+            s_history=self.global_state['s_history'],
+            y_history=self.global_state['y_history'],
+            sy_history=self.global_state['sy_history'],
+            y=y,
+            sy=sy,
             z_beta = z_beta,
             z_ema = z_ema,
-            step=step
+            step=self.global_state.get('step', 1)
        )
 
-
-
-        return var
-
+        return dir
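
The lbfgs() helper added above is the standard L-BFGS two-loop recursion for applying the approximate inverse Hessian to the current update. As a reference, here is a minimal stand-alone sketch of the same recursion written for a single flat tensor instead of torchzero's TensorList; the function name two_loop_direction and its list-based signature are illustrative only and are not part of the package.

# Reference sketch (ours, not torchzero code): L-BFGS two-loop recursion on a flat tensor.
import torch

def two_loop_direction(grad: torch.Tensor, s_history: list, y_history: list) -> torch.Tensor:
    # approximates H^-1 @ grad from stored parameter differences s and gradient differences y
    if len(s_history) == 0:
        return grad.clone()  # no curvature information yet

    q = grad.clone()
    alphas = []
    # first loop: newest pair to oldest
    for s, y in zip(reversed(s_history), reversed(y_history)):
        rho = 1.0 / y.dot(s)
        alpha = rho * s.dot(q)
        alphas.append(alpha)
        q = q - alpha * y

    # initial Hessian guess H0 = (s.y / y.y) * I, applied to q
    s_last, y_last = s_history[-1], y_history[-1]
    z = q * (s_last.dot(y_last) / y_last.dot(y_last))

    # second loop: oldest pair to newest
    for s, y, alpha in zip(s_history, y_history, reversed(alphas)):
        rho = 1.0 / y.dot(s)
        beta = rho * y.dot(z)
        z = z + (alpha - beta) * s
    return z

The torchzero version in the diff differs mainly in that it operates on TensorLists, reuses the cached s·y products from sy_history, falls back to safe_scaling_ when no history is available, and can optionally lerp the initial z with z_ema when z_beta is set.
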
torchzero/modules/quasi_newton/lsr1.py

@@ -4,10 +4,11 @@ from operator import itemgetter
 import torch
 
 from ...core import Chainable, Module, Transform, Var, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
-
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ..functional import safe_scaling_
 from .lbfgs import _lerp_params_update_
 
+
 def lsr1_(
     tensors_: TensorList,
     s_history: deque[TensorList],

@@ -15,11 +16,9 @@ def lsr1_(
     step: int,
     scale_second: bool,
 ):
-    if
+    if len(s_history) == 0:
         # initial step size guess from pytorch
-
-        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-        return tensors_.mul_(scale_factor)
+        return safe_scaling_(TensorList(tensors_))
 
     m = len(s_history)
 

@@ -64,7 +63,7 @@ def lsr1_(
 
         Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
 
-    if scale_second and step ==
+    if scale_second and step == 2:
         scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
         scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
         Hx.mul_(scale_factor)

@@ -72,103 +71,148 @@ def lsr1_(
     return Hx
 
 
-class LSR1(
-    """Limited Memory SR1
+class LSR1(Transform):
+    """Limited Memory SR1 algorithm. A line search is recommended.
+
+    .. note::
+        L-SR1 provides a better estimate of true hessian, however it is more unstable compared to L-BFGS.
+
+    .. note::
+        L-SR1 update rule uses a nested loop, computationally with history size `n` it is similar to L-BFGS with history size `(n^2)/2`. On small problems (ndim <= 2000) BFGS and SR1 may be faster than limited-memory versions.
+
+    .. note::
+        directions L-SR1 generates are not guaranteed to be descent directions. This can be alleviated in multiple ways,
+        for example using :code:`tz.m.StrongWolfe(plus_minus=True)` line search, or modifying the direction with :code:`tz.m.Cautious` or :code:`tz.m.ScaleByGradCosineSimilarity`.
+
     Args:
-        history_size (int, optional):
-            and gradient differences
-
-
-
-
-
+        history_size (int, optional):
+            number of past parameter differences and gradient differences to store. Defaults to 10.
+        tol (float | None, optional):
+            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
+        tol_reset (bool, optional):
+            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+        gtol (float | None, optional):
+            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
+        params_beta (float | None, optional):
+            if not None, EMA of parameters is used for
             preconditioner update (s_k vector). Defaults to None.
-        grads_beta (float | None, optional):
+        grads_beta (float | None, optional):
+            if not None, EMA of gradients is used for
             preconditioner update (y_k vector). Defaults to None.
         update_freq (int, optional): How often to update L-SR1 history. Defaults to 1.
-
-
-
-        inner (Chainable | None, optional): Optional inner modules applied after updating
+        scale_second (bool, optional): downscales second update which tends to be large. Defaults to False.
+        inner (Chainable | None, optional):
+            Optional inner modules applied after updating
            L-SR1 history and before preconditioning. Defaults to None.
+
+    Examples:
+        L-SR1 with Strong-Wolfe+- line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LSR1(100),
+                tz.m.StrongWolfe(plus_minus=True)
+            )
     """
     def __init__(
         self,
         history_size: int = 10,
-        tol: float = 1e-
+        tol: float | None = 1e-10,
+        tol_reset: bool = False,
+        gtol: float | None = 1e-10,
         params_beta: float | None = None,
         grads_beta: float | None = None,
         update_freq: int = 1,
-        scale_second: bool =
+        scale_second: bool = False,
         inner: Chainable | None = None,
     ):
         defaults = dict(
-            history_size=history_size, tol=tol,
+            history_size=history_size, tol=tol, gtol=gtol,
             params_beta=params_beta, grads_beta=grads_beta,
-            update_freq=update_freq, scale_second=scale_second
+            update_freq=update_freq, scale_second=scale_second,
+            tol_reset=tol_reset,
        )
-        super().__init__(defaults)
+        super().__init__(defaults, uses_grad=False, inner=inner)
 
         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
 
-        if inner is not None:
-            self.set_child('inner', inner)
-
     def reset(self):
         self.state.clear()
         self.global_state['step'] = 0
         self.global_state['s_history'].clear()
         self.global_state['y_history'].clear()
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_l_params', 'prev_l_grad')
+        self.global_state.pop('step', None)
 
     @torch.no_grad
-    def
-        params = as_tensorlist(
-        update = as_tensorlist(
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        params = as_tensorlist(params)
+        update = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
         s_history: deque[TensorList] = self.global_state['s_history']
         y_history: deque[TensorList] = self.global_state['y_history']
 
-
-
-
-        params_beta, grads_beta_ = self.get_settings(params, 'params_beta', 'grads_beta') # type: ignore
-        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)
+        setting = settings[0]
+        update_freq = itemgetter('update_freq')(setting)
 
-
+        params_beta, grads_beta = unpack_dicts(settings, 'params_beta', 'grads_beta')
+        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
+        prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)
 
-
+        s = None
+        y = None
         if step != 0:
             if step % update_freq == 0:
-
-
+                s = l_params - prev_l_params
+                y = l_update - prev_l_grad
 
-                s_history.append(
-                y_history.append(
+                s_history.append(s)
+                y_history.append(y)
 
         prev_l_params.copy_(l_params)
         prev_l_grad.copy_(l_update)
 
-
-
+        # store for apply
+        self.global_state['s'] = s
+        self.global_state['y'] = y
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        s = self.global_state.pop('s')
+        y = self.global_state.pop('y')
 
-
+        setting = settings[0]
+        tol = setting['tol']
+        gtol = setting['gtol']
+        tol_reset = setting['tol_reset']
+
+        # tolerance on parameter difference to avoid exploding after converging
+        if tol is not None:
+            if s is not None and s.abs().global_max() <= tol:
+                if tol_reset: self.reset()
+                return safe_scaling_(TensorList(tensors))
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
         if tol is not None:
-            if
-
-            return var
+            if y is not None and y.abs().global_max() <= gtol:
+                return safe_scaling_(TensorList(tensors))
 
+        # precondition
         dir = lsr1_(
-            tensors_=
-            s_history=s_history,
-            y_history=y_history,
-            step=step,
-            scale_second=scale_second,
+            tensors_=tensors,
+            s_history=self.global_state['s_history'],
+            y_history=self.global_state['y_history'],
+            step=self.global_state.get('step', 1),
+            scale_second=setting['scale_second'],
        )
 
-
-
-        return var
+        return dir
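
The _adaptive_damping helper added to lbfgs.py above implements Al-Baali's damping: it measures the curvature ratio s·y / s·s and, when that ratio leaves the configured eigenvalue bounds, blends y toward s so the secant pair stays well conditioned. A minimal flat-tensor sketch of the same rule follows; the name damp_secant_pair is illustrative and this is not the torchzero implementation.

# Reference sketch (ours): Al-Baali adaptive damping of a secant pair (s, y) on flat tensors.
import torch

def damp_secant_pair(s: torch.Tensor, y: torch.Tensor, init_damping=0.99, eigval_bounds=(0.01, 1.5)):
    sigma_l, sigma_h = eigval_bounds
    sy = s.dot(y)
    u = (sy / s.dot(s)).item()  # curvature ratio s.y / s.s
    if u <= sigma_l < 1:
        tau = min((1 - sigma_l) / (1 - u), init_damping)
    elif u >= sigma_h > 1:
        tau = min((sigma_h - 1) / (u - 1), init_damping)
    else:
        tau = init_damping
    y = tau * y + (1 - tau) * s  # damped gradient difference
    return s, y, s.dot(y)

In the diff, both LBFGS and LSR1 additionally guard against degenerate pairs through the tol/gtol checks and fall back to safe_scaling_ on the first step instead of the per-module step-size guess that lsr1_ used previously.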