torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/diagonal_higher_order_newton.py
DELETED
@@ -1,225 +0,0 @@
-import itertools
-import math
-import warnings
-from collections.abc import Callable
-from contextlib import nullcontext
-from functools import partial
-from typing import Any, Literal
-
-import numpy as np
-import scipy.optimize
-import torch
-
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    jacobian_wrt,
-    hvp,
-)
-
-def _poly_eval_diag(s: np.ndarray, c, derivatives):
-    val = float(c) + (derivatives[0] * s).sum(-1)
-
-    if len(derivatives) > 1:
-        for i, d_diag in enumerate(derivatives[1:], 2):
-            val += (d_diag * (s**i)).sum(-1) / math.factorial(i)
-
-    return val
-
-def _proximal_poly_v_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-    """Computes the value of the proximal polynomial approximation."""
-    if x.ndim == 2: x = x.T
-    s = x - x0
-
-    val = _poly_eval_diag(s, c, derivatives)
-
-    penalty = 0
-    if prox != 0:
-        penalty = (prox / 2) * (s**2).sum(-1)
-
-    return val + penalty
-
-def _proximal_poly_g_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-    """Computes the gradient of the proximal polynomial approximation."""
-    s = x - x0
-
-    g = derivatives[0].copy()
-
-    if len(derivatives) > 1:
-        for i, d_diag in enumerate(derivatives[1:], 2):
-            g += d_diag * (s**(i - 1)) / math.factorial(i - 1)
-
-    if prox != 0:
-        g += prox * s
-
-    return g
-
-def _proximal_poly_H_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-    """Computes the Hessian of the proximal polynomial approximation."""
-    s = x - x0
-    n = x.shape[0]
-
-    if len(derivatives) < 2:
-        H_diag = np.zeros(n, dtype=s.dtype)
-    else:
-        H_diag = derivatives[1].copy()
-
-    if len(derivatives) > 2:
-        for i, d_diag in enumerate(derivatives[2:], 3):
-            H_diag += d_diag * (s**(i - 2)) / math.factorial(i - 2)
-
-    if prox != 0:
-        H_diag += prox
-
-    return np.diag(H_diag)
-
-def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
-    derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
-    x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-    bounds = None
-    if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
-
-    # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
-    if bounds is None:
-        if len(derivatives) == 1: method = 'bfgs'
-        else: method = 'trust-exact'
-    else:
-        if len(derivatives) == 1: method = 'l-bfgs-b'
-        else: method = 'trust-constr'
-
-    x_init = x0.copy()
-    v0 = _proximal_poly_v_diag(x0, c, prox, x0, derivatives)
-    if de_iters is not None and de_iters != 0:
-        if de_iters == -1: de_iters = None # let scipy decide
-        res = scipy.optimize.differential_evolution(
-            _proximal_poly_v_diag,
-            bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
-            args=(c, prox, x0.copy(), derivatives),
-            maxiter=de_iters,
-            vectorized=True,
-        )
-        if res.fun < v0: x_init = res.x
-
-    res = scipy.optimize.minimize(
-        _proximal_poly_v_diag,
-        x_init,
-        method=method,
-        args=(c, prox, x0.copy(), derivatives),
-        jac=_proximal_poly_g_diag,
-        hess=_proximal_poly_H_diag,
-        bounds=bounds
-    )
-
-    return torch.from_numpy(res.x).to(x), res.fun
-
-
-
-class DiagonalHigherOrderNewton(Module):
-    """
-    Hvp with ones doesn't give you the diagonal unless derivatives are diagonal, but somehow it still works,
-    except it doesn't work in all cases except ones where it works.
-    """
-    def __init__(
-        self,
-        order: int = 4,
-        trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-        increase: float = 1.5,
-        decrease: float = 0.75,
-        trust_init: float | None = None,
-        trust_tol: float = 1,
-        de_iters: int | None = None,
-        vectorize: bool = True,
-    ):
-        if trust_init is None:
-            if trust_method == 'bounds': trust_init = 1
-            else: trust_init = 0.1
-
-        defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        order = settings['order']
-        increase = settings['increase']
-        decrease = settings['decrease']
-        trust_tol = settings['trust_tol']
-        trust_init = settings['trust_init']
-        trust_method = settings['trust_method']
-        de_iters = settings['de_iters']
-
-        trust_value = self.global_state.get('trust_value', trust_init)
-
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
-
-            g = torch.autograd.grad(loss, params, create_graph=True)
-            var.grad = list(g)
-
-            derivatives = [g]
-            T = g # current derivatives tensor diagonal
-            ones = [torch.ones_like(t) for t in g]
-
-            # get all derivatives up to order
-            for o in range(2, order + 1):
-                T = hvp(params, T, ones, create_graph=o != order)
-                derivatives.append(T)
-
-        x0 = torch.cat([p.ravel() for p in params])
-
-        if trust_method is None: trust_method = 'none'
-        else: trust_method = trust_method.lower()
-
-        if trust_method == 'none':
-            trust_region = None
-            prox = 0
-
-        elif trust_method == 'bounds':
-            trust_region = trust_value
-            prox = 0
-
-        elif trust_method == 'proximal':
-            trust_region = None
-            prox = 1 / trust_value
-
-        else:
-            raise ValueError(trust_method)
-
-        x_star, expected_loss = _poly_minimize(
-            trust_region=trust_region,
-            prox=prox,
-            de_iters=de_iters,
-            c=loss.item(),
-            x=x0,
-            derivatives=[torch.cat([t.ravel() for t in d]) for d in derivatives],
-        )
-
-        # trust region
-        if trust_method != 'none':
-            expected_reduction = loss - expected_loss
-
-            vec_to_tensors_(x_star, params)
-            loss_star = closure(False)
-            vec_to_tensors_(x0, params)
-            reduction = loss - loss_star
-
-            # failed step
-            if reduction <= 0:
-                x_star = x0
-                self.global_state['trust_value'] = trust_value * decrease
-
-            # very good step
-            elif expected_reduction / reduction <= trust_tol:
-                self.global_state['trust_value'] = trust_value * increase
-
-        difference = vec_to_tensors(x0 - x_star, params)
-        var.update = list(difference)
-        return var
-
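Note: the `_poly_eval_diag` helper removed above evaluates a Taylor polynomial whose higher-order terms are restricted to a diagonal, i.e. f(x0 + s) ≈ c + g·s + Σᵢ₌₂ dᵢ·sⁱ/i!. A minimal standalone NumPy sketch (not torchzero code) of that formula, checked on a separable quartic where the diagonal model is exact:

```python
import math
import numpy as np

def poly_eval_diag(s, c, derivatives):
    # same formula as the removed _poly_eval_diag: value of the diagonal Taylor model
    val = float(c) + (derivatives[0] * s).sum(-1)
    for i, d_diag in enumerate(derivatives[1:], 2):
        val += (d_diag * s**i).sum(-1) / math.factorial(i)
    return val

# separable quartic f(x) = sum(x**4): every derivative tensor is diagonal, so the model is exact
x0 = np.array([1.0, -2.0])
c = (x0**4).sum()                                                  # f(x0)
derivatives = [4*x0**3, 12*x0**2, 24*x0, np.full_like(x0, 24.0)]   # diagonals of orders 1..4
s = np.array([0.3, 0.5])
assert np.isclose(poly_eval_diag(s, c, derivatives), ((x0 + s)**4).sum())
```

For non-separable objectives the dᵢ obtained from Hessian-vector products against ones are only a heuristic, which is what the `DiagonalHigherOrderNewton` docstring above is alluding to.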
torchzero/modules/experimental/soapy.py
DELETED
@@ -1,163 +0,0 @@
-from operator import itemgetter
-
-import torch
-
-from ...core import Chainable, Transform
-from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
-from ..optimizers.soap import (
-    update_soap_covariances_,
-    get_orthogonal_matrix,
-    get_orthogonal_matrix_QR,
-    project,
-    project_back,
-)
-
-class SOAPY(Transform):
-    """Adam but uses scaled gradient differences for GGᵀ. Please note that this is experimental and isn't guaranteed to work.
-
-    New args:
-        scale_by_s - whether to scale gradient differences by parameter differences
-        y_to_ema2 - whether to use gradient differences for exponential moving average too
-    """
-    def __init__(
-        self,
-        beta1: float = 0.95,
-        beta2: float = 0.95,
-        shampoo_beta: float | None = 0.95,
-        precond_freq: int = 10,
-        merge_small: bool = True,
-        max_dim: int = 2_000,
-        precondition_1d: bool = True,
-        eps: float = 1e-8,
-        decay: float | None = None,
-        alpha: float = 1,
-        bias_correction: bool = True,
-        scale_by_s: bool = True,
-        y_to_ema2: bool = False,
-    ):
-        defaults = dict(
-            beta1=beta1,
-            beta2=beta2,
-            shampoo_beta=shampoo_beta,
-            precond_freq=precond_freq,
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            bias_correction=bias_correction,
-            alpha=alpha,
-            scale_by_s=scale_by_s,
-            y_to_ema2=y_to_ema2,
-        )
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        updates = []
-        # update preconditioners
-        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-            scale_by_s = setting['scale_by_s']
-            y_to_ema2 = setting['y_to_ema2']
-
-            if merge_small:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-            if 'g_prev' not in state:
-                state['p_prev'] = p.clone()
-                state['g_prev'] = t.clone()
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue
-
-            p_prev = state['p_prev']
-            g_prev = state['g_prev']
-            s = p - p_prev
-            y = t - g_prev
-            if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable
-
-            state['p_prev'].copy_(p)
-            state['g_prev'].copy_(t)
-
-            # initialize state on 1st step
-            if 'GG' not in state:
-                state["exp_avg"] = torch.zeros_like(t)
-                if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
-                else: state["exp_avg_sq"] = torch.zeros_like(t)
-
-                if not precondition_1d and t.ndim <= 1:
-                    state['GG'] = []
-
-                else:
-                    state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                if len([i is not None for i in state['GG']]) == 0:
-                    state['GG'] = None
-
-                if state['GG'] is not None:
-                    update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
-                    state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            z_projected = None
-            if state['GG'] is not None:
-                if y_to_ema2: z_projected = project(y, state['Q'])
-                else: z_projected = project(t, state['Q'])
-
-            # exponential moving averages
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-            exp_avg.lerp_(t, 1-beta1)
-
-            if z_projected is None:
-                if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
-                else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-            else:
-                exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)
-
-            # project exponential moving averages if they are accumulated unprojected
-            exp_avg_projected = exp_avg
-            if z_projected is not None:
-                exp_avg_projected = project(exp_avg, state['Q'])
-
-            exp_avg_sq_projected = exp_avg_sq
-
-            denom = exp_avg_sq_projected.sqrt().add_(eps)
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-            # to the original space
-            update = exp_avg_projected / denom
-            if z_projected is not None:
-                update = project_back(update, state["Q"])
-
-            if setting['bias_correction']:
-                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
-
-            if merge_small:
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-            updates.append(update)
-            state["step"] += 1
-
-            # Update is done after the gradient step to avoid using current gradients in the projection.
-            if state['GG'] is not None:
-                update_soap_covariances_(y, state['GG'], shampoo_beta)
-                if state['step'] % setting['precond_freq'] == 0:
-                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-        return updates
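Note: the defining change in the removed `SOAPY` transform relative to SOAP is the statistic fed into the Shampoo covariances: a gradient difference, optionally divided by the norm of the parameter step (`scale_by_s`). A minimal standalone sketch of just that quantity, using made-up tensors rather than torchzero's per-parameter state:

```python
import torch

def scaled_grad_difference(p, p_prev, g, g_prev, scale_by_s=True):
    # y as accumulated into GG^T by the removed SOAPY transform: a gradient
    # difference, optionally divided by the parameter-step norm (clipped away from 0)
    y = g - g_prev
    if scale_by_s:
        y = y / torch.linalg.norm(p - p_prev).clip(min=1e-8)
    return y

p_prev, p = torch.zeros(3), torch.tensor([0.10, -0.20, 0.05])
g_prev, g = torch.tensor([1.0, 2.0, -1.0]), torch.tensor([0.8, 1.5, -0.7])
print(scaled_grad_difference(p, p_prev, g, g_prev))
```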
torchzero/modules/experimental/structured_newton.py
DELETED
@@ -1,111 +0,0 @@
-# idea https://arxiv.org/pdf/2212.09841
-import warnings
-from collections.abc import Callable
-from functools import partial
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, apply_transform
-from ...utils import TensorList, vec_to_tensors
-from ...utils.derivatives import (
-    hessian_list_to_mat,
-    hessian_mat,
-    hvp,
-    hvp_fd_central,
-    hvp_fd_forward,
-    jacobian_and_hessian_wrt,
-)
-
-
-class StructuredNewton(Module):
-    """TODO. Please note that this is experimental and isn't guaranteed to work.
-    Args:
-        structure (str, optional): structure.
-        reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-        hvp_method (str):
-            how to calculate hvp_method. Defaults to "autograd".
-        inner (Chainable | None, optional): inner modules. Defaults to None.
-
-    """
-    def __init__(
-        self,
-        structure: Literal[
-            "diagonal",
-            "diagonal1",
-            "diagonal_abs",
-            "tridiagonal",
-            "circulant",
-            "toeplitz",
-            "toeplitz_like",
-            "hankel",
-            "rank1",
-            "rank2", # any rank
-        ]
-        | str = "diagonal",
-        reg: float = 1e-6,
-        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-        h: float = 1e-3,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
-        super().__init__(defaults)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
-        if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-        settings = self.settings[params[0]]
-        reg = settings['reg']
-        hvp_method = settings['hvp_method']
-        structure = settings['structure']
-        h = settings['h']
-
-        # ------------------------ calculate grad and hessian ------------------------ #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-            def Hvp_fn1(x):
-                return hvp(params, grad, x, retain_graph=True)
-            Hvp_fn = Hvp_fn1
-
-        elif hvp_method == 'forward':
-            grad = var.get_grad()
-            def Hvp_fn2(x):
-                return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
-            Hvp_fn = Hvp_fn2
-
-        elif hvp_method == 'central':
-            grad = var.get_grad()
-            def Hvp_fn3(x):
-                return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
-            Hvp_fn = Hvp_fn3
-
-        else: raise ValueError(hvp_method)
-
-        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=params, grads=grad, var=var)
-
-        # hessian
-        if structure.startswith('diagonal'):
-            H = Hvp_fn([torch.ones_like(p) for p in params])
-            if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
-            if structure == 'diagonal_abs': torch._foreach_abs_(H)
-            torch._foreach_add_(H, reg)
-            torch._foreach_div_(update, H)
-            var.update = update
-            return var
-
-        # hessian
-        raise NotImplementedError(structure)
-
-
-
-
-
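Note: the only structure the removed `StructuredNewton` module actually implemented is the `diagonal*` family: it estimates the Hessian diagonal with a single Hessian-vector product against a vector of ones, adds the Tikhonov regularizer `reg`, and divides the update elementwise. As the `DiagonalHigherOrderNewton` docstring above concedes, that estimate is only exact when the Hessian is diagonal. A minimal standalone sketch on a separable quadratic, where it is exact:

```python
import torch

# Separable quadratic f(x) = 0.5 * sum(d * x**2); its Hessian is diag(d),
# so a Hessian-vector product against ones recovers the diagonal exactly.
d = torch.tensor([1.0, 4.0, 10.0])
x = torch.tensor([2.0, -1.0, 0.5], requires_grad=True)
loss = 0.5 * (d * x**2).sum()

(grad,) = torch.autograd.grad(loss, x, create_graph=True)
(hvp_ones,) = torch.autograd.grad(grad, x, grad_outputs=torch.ones_like(x))

reg = 1e-6
step = grad.detach() / (hvp_ones + reg)   # elementwise "diagonal Newton" step
print(step)                               # ~= x, the exact Newton step for this quadratic
```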
torchzero/modules/lr/__init__.py
DELETED
torchzero/modules/lr/adaptive.py
DELETED
@@ -1,93 +0,0 @@
-"""Various step size strategies"""
-import random
-from typing import Any
-from operator import itemgetter
-import torch
-
-from ...core import Transform
-from ...utils import TensorList, NumberList, unpack_dicts
-
-
-class PolyakStepSize(Transform):
-    """Polyak's step-size method.
-
-    Args:
-        max (float | None, optional): maximum possible step size. Defaults to None.
-        min_obj_value (int, optional):
-            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-        use_grad (bool, optional):
-            if True, uses dot product of update and gradient to compute the step size.
-            Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
-            Defaults to True.
-        parameterwise (bool, optional):
-            if True, calculate Polyak step-size for each parameter separately,
-            if False calculate one global step size for all parameters. Defaults to False.
-        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
-    """
-    def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
-
-        defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
-        super().__init__(defaults, uses_grad=use_grad)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        tensors = TensorList(tensors)
-        grads = TensorList(grads)
-        alpha = NumberList(s['alpha'] for s in settings)
-
-        parameterwise, use_grad, max, min_obj_value = itemgetter('parameterwise', 'use_grad', 'max', 'min_obj_value')(settings[0])
-
-        if use_grad: denom = tensors.dot(grads)
-        else: denom = tensors.dot(tensors)
-
-        if parameterwise:
-            polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
-            polyak_step_size = polyak_step_size.where(denom != 0, 0)
-            if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
-
-        else:
-            if denom.abs() <= torch.finfo(denom.dtype).eps: polyak_step_size = 0 # converged
-            else: polyak_step_size = (loss - min_obj_value) / denom
-
-            if max is not None:
-                if polyak_step_size > max: polyak_step_size = max
-
-        tensors.mul_(alpha * polyak_step_size)
-        return tensors
-
-
-class RandomStepSize(Transform):
-    """Uses random global or layer-wise step size from `low` to `high`.
-
-    Args:
-        low (float, optional): minimum learning rate. Defaults to 0.
-        high (float, optional): maximum learning rate. Defaults to 1.
-        parameterwise (bool, optional):
-            if True, generate random step size for each parameter separately,
-            if False generate one global random step size. Defaults to False.
-    """
-    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
-        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        s = settings[0]
-        parameterwise = s['parameterwise']
-
-        seed = s['seed']
-        if 'generator' not in self.global_state:
-            self.global_state['generator'] = random.Random(seed)
-        generator: random.Random = self.global_state['generator']
-
-        if parameterwise:
-            low, high = unpack_dicts(settings, 'low', 'high')
-            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
-        else:
-            low = s['low']
-            high = s['high']
-            lr = generator.uniform(low, high)
-
-        torch._foreach_mul_(tensors, lr)
-        return tensors
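Note: in its non-parameterwise branch, the removed `PolyakStepSize` computes the classical Polyak step size (f(x) − f*) / ⟨update, grad⟩, optionally capped at `max`. A tiny self-contained sketch with plain tensors rather than the Transform interface:

```python
import torch

def polyak_step_size(loss, update, grad, min_obj_value=0.0, max_size=None):
    # global (non-parameterwise) branch of the removed PolyakStepSize transform
    denom = (update * grad).sum()
    if denom.abs() <= torch.finfo(denom.dtype).eps:
        return 0.0                      # treated as converged
    size = (loss - min_obj_value) / denom
    if max_size is not None:
        size = min(size, max_size)
    return float(size)

grad = torch.tensor([2.0, -4.0])   # gradient of f(x) = x1**2 + x2**2 at (1, -2)
print(polyak_step_size(loss=5.0, update=grad, grad=grad))   # 5 / ||grad||**2 = 0.25
```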
torchzero/modules/lr/lr.py
DELETED
@@ -1,63 +0,0 @@
-"""Learning rate"""
-import torch
-
-from ...core import Transform
-from ...utils import NumberList, TensorList, generic_eq, unpack_dicts
-
-def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
-    """multiplies by lr if lr is not 1"""
-    if generic_eq(lr, 1): return tensors
-    if inplace: return tensors.mul_(lr)
-    return tensors * lr
-
-class LR(Transform):
-    """Learning rate. Adding this module also adds support for LR schedulers."""
-    def __init__(self, lr: float):
-        defaults=dict(lr=lr)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
-
-class StepSize(Transform):
-    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
-    def __init__(self, step_size: float, key = 'step_size'):
-        defaults={"key": key, key: step_size}
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
-
-
-def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
-    """returns warm up lr scalar"""
-    if step > steps: return end_lr
-    return start_lr + (end_lr - start_lr) * (step / steps)
-
-class Warmup(Transform):
-    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
-
-    Args:
-        start_lr (_type_, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
-        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
-        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
-    """
-    def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
-        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
-        num_steps = settings[0]['steps']
-        step = self.global_state.get('step', 0)
-
-        target = lazy_lr(
-            TensorList(tensors),
-            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
-            inplace=True
-        )
-        self.global_state['step'] = step + 1
-        return target
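Note: the removed `_warmup_lr` helper is a plain linear interpolation from `start_lr` to `end_lr` over `steps` steps, held at `end_lr` afterwards (its replacement presumably lives in the new `torchzero/modules/step_size/lr.py` listed above). A quick numeric illustration of the schedule it produces:

```python
def warmup_lr(step, start_lr=1e-5, end_lr=1.0, steps=100):
    # same linear ramp as the removed _warmup_lr, held at end_lr once warmup is over
    if step > steps:
        return end_lr
    return start_lr + (end_lr - start_lr) * (step / steps)

for step in (0, 25, 50, 100, 150):
    print(step, warmup_lr(step))   # ramps from 1e-05 toward 1.0 by step 100, then stays there
```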