torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0

torchzero/modules/experimental/adasoap.py
@@ -2,11 +2,18 @@ from operator import itemgetter

 import torch

-from ...core import Chainable, Transform
+from ...core import Chainable, Transform
 from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+from ..optimizers.soap import (
+    get_orthogonal_matrix,
+    get_orthogonal_matrix_QR,
+    project,
+    project_back,
+)
+

 @torch.no_grad
-def update_soap_covariances_(
+def update_adasoap_covariances_(
     grad: torch.Tensor,
     GGs_: list[torch.Tensor | None],
     GG_sqs: list[torch.Tensor | None],
@@ -24,127 +31,16 @@ def update_soap_covariances_(
         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]

-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-@torch.no_grad
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq

 class AdaSOAP(Transform):
-    """SOAP with diagonally preconditioned GG^Ts
+    """SOAP with diagonally preconditioned GG^Ts.
+
+    .. warning::
+        Experimental.

     precond_beta - beta for GG^T squares
+
+    Verdict: It works, but it is about the same performance as Adam, but maybe more tuning potential?
     """
     def __init__(
         self,
@@ -180,15 +76,14 @@ class AdaSOAP(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-
-            settings = self.settings[p]
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+
             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(
-            precond_beta =
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
+            precond_beta = setting['precond_beta']

             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -213,7 +108,7 @@ class AdaSOAP(Transform):

             if state['GG'] is not None:
                 assert state['GG_sq'] is not None
-
+                update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
                 state['Q'] = get_orthogonal_matrix(GG_precond)

@@ -259,7 +154,7 @@ class AdaSOAP(Transform):
             if t_projected is not None:
                 update = project_back(update, state["Q"])

-            if
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -274,9 +169,9 @@ class AdaSOAP(Transform):

             # Update is done after the gradient step to avoid using current gradients in the projection.
             if state['GG'] is not None:
-
+                update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                if state['step'] %
+                if state['step'] % setting['precond_freq'] == 0:
                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])

         return updates
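The hunks above replace adasoap.py's local copies of the SOAP helpers with imports from ..optimizers.soap and precondition each GG^T covariance by its elementwise square before taking an eigenbasis. Below is a minimal standalone sketch of that step for a single axis; the GG_sq update rule is not shown in this diff, so the EMA of the squared covariance is an assumption, and the function name is illustrative.

# Standalone sketch (not library code): one-axis version of the covariance update and
# diagonal preconditioning shown above. The GG_sq update is an assumption; the diff only
# shows that GG_precond = GG / (GG_sq + 1e-8) feeds get_orthogonal_matrix.
import torch

def adasoap_basis_sketch(grad, GG, GG_sq, shampoo_beta=0.95, precond_beta=0.99):
    axes = list(range(1, grad.ndim))
    outer = torch.tensordot(grad, grad, dims=(axes, axes))  # same accumulation as update_adasoap_covariances_
    GG.lerp_(outer, 1 - shampoo_beta)
    GG_sq.lerp_(outer.square(), 1 - precond_beta)           # assumed EMA of the squared covariance
    GG_precond = GG / (GG_sq + 1e-8)                        # the preconditioning step from the hunk
    _, Q = torch.linalg.eigh(GG_precond + 1e-30 * torch.eye(GG_precond.shape[0], device=GG_precond.device))
    return torch.flip(Q, [1])                               # eigenbasis ordered as in get_orthogonal_matrix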

torchzero/modules/experimental/cosine.py
@@ -0,0 +1,214 @@
+"""A bunch of useless modules that I hate and that didn't work"""
+import torch
+
+from ...core import Chainable, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+
+
+class CosineStepSize(Transform):
+    """Adaptive step size based on cosine similarity
+
+    VERDICT: Useless. This is too unstable.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+        init (float, optional): initial step size. Defaults to 1.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        target_cossim (float, optional): cosine similarity needs to be above this to increase step size. Defaults to 1e-8.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before step size correction. Defaults to None.
+    """
+    def __init__(self, scale:float = 0.95, init:float=1, eps:float=1e-12, inner:Chainable | None = None):
+        defaults = dict(scale=scale, init=init, eps=eps)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, init = unpack_dicts(settings, 'scale', 'init', cls=NumberList)
+        unpack_states(states, tensors, 'alpha', init=init, cls=NumberList) # initializes alpha to init
+        eps = settings[0]['eps']
+
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        new_alpha = []
+        for s, sc in zip(states, scale):
+            s['alpha'] *= 1 + cos_sim * sc
+            new_alpha.append(s['alpha'])
+
+        tensors.mul_(new_alpha)
+        prev.copy_(tensors)
+
+        return tensors
+
+
+
+class CosineDebounce(Transform):
+    """Debouncing when cosine similarity is less than 0.
+
+    VERDICT: Useless. This doesn't help at all.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before debouncing correction. Defaults to None.
+    """
+    def __init__(self, scale:float = 0.95, eps:float=1e-12, damping:float=0.95, inner:Chainable | None = None):
+        defaults = dict(scale=scale, eps=eps, damping=damping)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+        eps = settings[0]['eps']
+
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList).mul_(damping)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        if cos_sim < -eps:
+            undo = prev.neg().mul_(-cos_sim * scale)
+            comb = prev.graft(tensors).add_(tensors).graft_(prev).mul_(-cos_sim*scale)
+            tensors = undo.add_(comb)
+
+        prev.copy_(tensors)
+        return tensors
+
+
+
+class CosineMomentum(Transform):
+    """Beta depends on cosine similarity. At cossim=1, beta is 0. At cossim=-1, beta is 2^power. This basically removes oscillations.
+
+    VERDICT: Useless. Worse than all other momentums.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 1.
+        nesterov (float, optional): whether to use nesterov momentum. Defaults to False.
+        power (float, optional): power for beta. Defaults to 1.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before updating exponential moving average. Defaults to None.
+    """
+    def __init__(self, scale:float = 1, nesterov: bool = False, power: float = 1, eps:float=1e-12, inner:Chainable | None = None):
+        defaults = dict(scale=scale, eps=eps, nesterov=nesterov, power=power)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, power = unpack_dicts(settings, 'scale', 'power', cls=NumberList)
+        eps = settings[0]['eps']
+        nesterov = settings[0]['nesterov']
+        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
+
+        tensors = as_tensorlist(tensors)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(exp_avg) / (tensors_norm * exp_avg.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        beta = (1 - (cos_sim*scale)) ** power
+        if nesterov:
+            exp_avg.add_(tensors.mul(beta))
+            return tensors.add_(exp_avg)
+        else:
+            exp_avg.add_(tensors.mul_(beta))
+            return exp_avg.clone()
+
+
+class AdaptiveDifference(Transform):
+    """VERDICT: Useless. Doesn't help (sort of to be expected)."""
+    def __init__(self, inner:Chainable | None = None):
+        defaults = dict()
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+        diff = tensors - prev.graft_(tensors)
+        prev.copy_(tensors)
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        tensors.add_(diff.graft_(tensors))
+
+        return tensors
+
+class AdaptiveDifferenceEMA(Transform):
+    """VERDICT: better than non-EMA but still useless."""
+    def __init__(self, beta=0.99, inner:Chainable | None = None):
+        defaults = dict(beta=beta)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
+        prev, diff_exp_avg = unpack_states(states, tensors, 'prev', 'diff_exp_avg', init=[tensors,torch.zeros_like], cls=TensorList)
+
+        diff = (tensors - prev.graft_(tensors)).graft_(tensors)
+        diff_exp_avg.lerp_(diff, 1-beta)
+        prev.copy_(tensors)
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        tensors.add_(diff_exp_avg.graft(tensors))
+
+        return tensors
+
+
+class ScaledAdaptiveDifference(Transform):
+    """VERDICT: Useless and doesn't help."""
+    def __init__(self, scale=0.95, damping:float=0.99, inner:Chainable | None = None):
+        defaults = dict(scale=scale, damping=damping)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+        prev_tensors, prev_update = unpack_states(states, tensors, 'prev', 'prev_update', init=[tensors,tensors], cls=TensorList)
+
+        cos_sim = (tensors.dot(prev_update) / (tensors.global_vector_norm() * prev_update.global_vector_norm()).clip(min=1e-10)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        if cos_sim > 0:
+            tensors.add_(prev_tensors*(cos_sim*scale))
+
+        else:
+            undo = prev_tensors.neg().mul_(-cos_sim*scale)
+            comb = prev_tensors.graft(tensors).add_(tensors).graft_(prev_tensors).mul_(-cos_sim*scale)
+            tensors = undo.add_(comb).graft_((tensors-prev_tensors).mul_(damping))
+
+        diff = tensors - prev_tensors.graft_(tensors)
+        prev_tensors.copy_(tensors)
+        diff.graft_(tensors)
+        tensors.add_(diff)
+        prev_update.copy_(tensors)
+
+        return tensors
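The CosineStepSize module added above multiplies its per-parameter step size by 1 + cos_sim * scale, where cos_sim compares the current update with the previous one. A scalar illustration of that rule with plain tensors, outside torchzero's TensorList/state machinery (the helper name here is made up):

# Scalar illustration (names made up) of the CosineStepSize rule: grow the step size when
# the current update agrees with the previous one, shrink it when they oppose each other.
import torch

def cosine_step_size(update, prev, alpha, scale=0.95, eps=1e-12):
    cos_sim = torch.dot(update, prev) / (update.norm() * prev.norm()).clip(min=eps)
    alpha = alpha * (1 + cos_sim.item() * scale)   # matches s['alpha'] *= 1 + cos_sim * sc
    return update * alpha, alpha

prev = torch.randn(10)
update = prev + 0.1 * torch.randn(10)                       # mostly aligned with the previous direction
step, alpha = cosine_step_size(update, prev, alpha=1.0)     # alpha ends up close to 1.95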

torchzero/modules/experimental/cubic_adam.py
@@ -0,0 +1,97 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+
+def signed_cbrt(x: TensorList) -> TensorList:
+    return x.sign() * x.abs().pow(1/3)
+
+def cubic_adam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    exp_avg_cu_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    debiased: bool,
+    step: int,
+):
+    exp_avg_.lerp_(tensors, 1-beta1)
+    exp_avg_sq_.lerp_(tensors**2, 1-beta2)
+    exp_avg_cu_.lerp_(tensors**3, 1-beta3)
+
+    if debiased:
+        m1 = exp_avg_ / (1 - beta1 ** step)
+        m2 = exp_avg_sq_ / (1 - beta2 ** step)
+        m3 = exp_avg_cu_ / (1 - beta3 ** step)
+    else:
+        m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
+
+    # adam minimizes ax^2 + bx
+    # we are going to minimize ax^3 + bx^2 + cx
+    A = signed_cbrt(m3)
+    B = m2.sqrt()
+    C = m1
+    discriminant = B.pow(2) - 4 * A * C
+
+    denom = 2 * A
+    root = discriminant.clamp(min=0).sqrt_()
+
+    x0 = (-B + root) / (denom + eps)
+    x1 = (-B - root) / (denom + eps)
+
+    f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
+    f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
+
+    x_star = x0.where(f0 < f1, x1)
+
+    adam = -C / (B + eps)
+    x_star = adam.where(discriminant < 0, x_star)
+
+    return x_star.mul_(-alpha)
+
+class CubicAdam(Transform):
+    """Adam which has 3rd momentum and minimizes a cubic polynomial.
+
+    VERDICT: can outperform Adam very slightly. Usually very similar performance.
+
+    .. warning::
+        Experimental.
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.99,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+        debiased:bool=True,
+        alpha: float = 1.,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,beta3,eps,alpha=unpack_dicts(settings, 'beta1','beta2','beta3','eps','alpha', cls=NumberList)
+        exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
+
+        return cubic_adam_(
+            tensors=TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            exp_avg_cu_=exp_avg_cu,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            debiased=settings[0]['debiased'],
+            step=step,
+        )
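The cubic_adam_ helper above fits the cubic model f(x) = (A/3)x^3 + (B/2)x^2 + Cx with A = cbrt(m3), B = sqrt(m2), C = m1, so its stationary points solve Ax^2 + Bx + C = 0 and the code's x0, x1 are the quadratic-formula roots. A scalar walk-through of that root selection, with made-up values:

# Scalar walk-through with made-up A, B, C: the roots below are exactly the code's x0 and x1,
# the lower-valued one is x_star, and when the discriminant is negative the code falls back
# to the plain Adam ratio -C / (B + eps).
import math

A, B, C = 0.5, 1.2, -0.3                       # stand-ins for signed_cbrt(m3), m2.sqrt(), m1
f = lambda x: (A / 3) * x**3 + (B / 2) * x**2 + C * x
disc = B**2 - 4 * A * C                        # discriminant of f'(x) = A*x^2 + B*x + C
x0 = (-B + math.sqrt(disc)) / (2 * A)
x1 = (-B - math.sqrt(disc)) / (2 * A)
x_star = x0 if f(x0) < f(x1) else x1           # pick the stationary point with the lower model value
adam_fallback = -C / B                         # used when disc < 0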

torchzero/modules/experimental/curveball.py
@@ -2,7 +2,7 @@ from typing import Literal
 from collections.abc import Callable
 import torch

-from ...core import Module, Target, Transform, Chainable,
+from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central

@@ -47,27 +47,27 @@ class CurveBall(Module):
         if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def step(self,
+    def step(self, var):

-        params =
+        params = var.params
         settings = self.settings[params[0]]
         hvp_method = settings['hvp_method']
         h = settings['h']

-        precond_lr, momentum, reg = self.get_settings('
+        precond_lr, momentum, reg = self.get_settings(params, 'precond_lr', 'momentum', 'reg', cls=NumberList)


-        closure =
+        closure = var.closure
         assert closure is not None

-        z, Hz = self.get_state('z', 'Hz',
+        z, Hz = self.get_state(params, 'z', 'Hz', cls=TensorList)

         if hvp_method == 'autograd':
-            grad =
+            grad = var.get_grad(create_graph=True)
             Hvp = hvp(params, grad, z)

         elif hvp_method == 'forward':
-            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=
+            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=var.get_grad(), normalize=True)

         elif hvp_method == 'central':
             loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
@@ -79,11 +79,11 @@ class CurveBall(Module):
         Hz.set_(Hvp + z*reg)


-        update =
+        update = var.get_update()
         if 'inner' in self.children:
-            update =
+            update = apply_transform(self.children['inner'], update, params, grads=var.grad, var=var)

         z = curveball(TensorList(update), z, Hz, momentum=momentum, precond_lr=precond_lr)
-
+        var.update = z.neg()

-        return
+        return var
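CurveBall's 'autograd' branch above asks for a gradient with create_graph=True and then calls hvp(params, grad, z). A generic sketch of that double-backward Hessian-vector product, for illustration only and not torchzero's own hvp helper:

# Generic double-backward Hessian-vector product (a sketch, not torchzero's hvp helper):
# differentiate <grad, z> with respect to the parameters after a create_graph=True backward.
import torch

def hvp_autograd(loss, params, z):
    grads = torch.autograd.grad(loss, params, create_graph=True)
    dot = sum((g * v).sum() for g, v in zip(grads, z))
    return torch.autograd.grad(dot, params)

p = torch.randn(3, requires_grad=True)
loss = (p ** 4).sum()
(Hz,) = hvp_autograd(loss, [p], [torch.ones(3)])   # equals 12 * p**2 for this loss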

torchzero/modules/{projections → experimental}/dct.py
@@ -1,13 +1,13 @@
 from typing import Literal
 import torch
 import torch_dct
-from
+from ..projections import ProjectionBase
 from ...core import Chainable

 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))

-class DCTProjection(
+class DCTProjection(ProjectionBase):
     # norm description copied from pytorch docstring
     """Project update into Discrete Cosine Transform space, requires `torch_dct` library.

@@ -34,8 +34,8 @@ class DCTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

     @torch.no_grad
-    def project(self, tensors,
-        settings =
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         dims = settings['dims']
         norm = settings['norm']

@@ -54,18 +54,18 @@ class DCTProjection(Projection):
         return projected

     @torch.no_grad
-    def unproject(self,
-        settings =
+    def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
+        settings = projected_settings[0]
         dims = settings['dims']
         norm = settings['norm']

         unprojected = []
-        for
-            dim = min(
+        for t in projected_tensors:
+            dim = min(t.ndim, dims)

-            if dim == 1: idct = torch_dct.idct(
-            elif dim == 2: idct = torch_dct.idct_2d(
-            elif dim == 3: idct = torch_dct.idct_3d(
+            if dim == 1: idct = torch_dct.idct(t, norm = norm)
+            elif dim == 2: idct = torch_dct.idct_2d(t, norm=norm)
+            elif dim == 3: idct = torch_dct.idct_3d(t, norm=norm)
             else: raise ValueError(f"Unsupported number of dimensions {dim}")

             unprojected.append(reverse_dims(idct))