torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/__init__.py

@@ -3,13 +3,22 @@ from .adadam import Adadam
 from .adamY import AdamY
 from .adasoap import AdaSOAP
 from .curveball import CurveBall
-from .
+from .eigendescent import EigenDescent
+from .etf import (
+    ExponentialTrajectoryFit,
+    ExponentialTrajectoryFitV2,
+    PointwiseExponential,
+)
 from .gradmin import GradMin
+from .newton_solver import NewtonSolver
+from .newtonnewton import NewtonNewton
 from .reduce_outward_lr import ReduceOutwardLR
+from .soapy import SOAPY
 from .spectral import SpectralPreconditioner
+from .structured_newton import StructuredNewton
 from .subspace_preconditioners import (
     HistorySubspacePreconditioning,
     RandomSubspacePreconditioning,
 )
-from .
-from .
+from .tada import TAda
+from .diagonal_higher_order_newton import DiagonalHigherOrderNewton
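The experimental subpackage now re-exports the modules added in this release (EigenDescent, the ETF variants, NewtonSolver, NewtonNewton, TAda, DiagonalHigherOrderNewton) alongside SOAPY and StructuredNewton. Assuming the usual package layout, they become importable directly from the subpackage; this usage is inferred from the `__init__.py` exports above, not from the project's documentation:

```python
# Assumed usage based on the new exports in torchzero/modules/experimental/__init__.py
from torchzero.modules.experimental import (
    EigenDescent,
    NewtonNewton,
    TAda,
    DiagonalHigherOrderNewton,
)
```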
torchzero/modules/experimental/absoap.py

@@ -1,12 +1,14 @@
 from operator import itemgetter
+from typing import Literal
 
 import torch
-
-from ...core import Chainable, Transform
+
+from ...core import Chainable, Transform
 from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR
 
 @torch.no_grad
-def
+def update_absoap_covariances_(
     g1: torch.Tensor,
     g2: torch.Tensor,
     GGs_: list[torch.Tensor | None],
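The renamed `update_absoap_covariances_` keeps one GGᵀ covariance factor per tensor dimension; the hunk below shows its body: a `torch.tensordot` of the two input vectors over every axis except the preconditioned one, accumulated directly or blended with `lerp_` when a beta is given. A minimal standalone sketch of that update, with toy shapes and illustrative names that are not the library's API:

```python
import torch

# Sketch of the per-dimension covariance update pattern used above (illustrative only).
def update_covariances(g1: torch.Tensor, g2: torch.Tensor,
                       GGs: list[torch.Tensor | None], beta: float | None) -> None:
    for dim, GG in enumerate(GGs):
        if GG is None:
            continue  # this dimension is not preconditioned
        # contract every axis except `dim`, leaving a (size_dim x size_dim) matrix
        axes = list(range(dim)) + list(range(dim + 1, g1.ndim))
        outer = torch.tensordot(g1, g2, dims=(axes, axes))
        if beta is None:
            GG.add_(outer)             # plain accumulation
        else:
            GG.lerp_(outer, 1 - beta)  # exponential moving average

g = torch.randn(8, 16)
GGs = [torch.zeros(8, 8), torch.zeros(16, 16)]
update_covariances(g, g, GGs, beta=0.95)
```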
@@ -19,138 +21,33 @@ def update_soap_covariances_(
         if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
 
-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
 
-
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq
-
-Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
+Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
 class ABSOAP(Transform):
-    """SOAP but with
-
-
+    """SOAP but with some extra options for testing. Please note that this is experimental and isn't guaranteed to work.
+
+    Args:
+        scale_by_s - whether to scale y by s
+        gg1 - 1st vector into GGᵀ
+        gg2 - 2nd vector into GGᵀ
+        ema1 - vector into 1st momentum
+        ema2 - 2 vectors into 2nd momentum
+        rel1 - if True, multiplies gg1 by params
+        rel2 - same but for gg2
+        norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
+
+    letters:
+        p - params
+        g - grad
+        s - param difference
+        y - grad difference
+        gy - g+y
+        sy - s+y
+        sn - s normalized
+        yn - y normalized
+        gys - g + y#g
+        sys - s + y#s
 
-    new args
-
-    scale by s whether to scale gradient differences by parameter differences
-
-    y_to_ema2 whether to use gradient differences for exponential moving average too
-
-    okay I changed these args into another ones
-
-    BASICALLY THIS IS FOR MY EXPERIMENTS
     """
     def __init__(
         self,
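The block deleted above is the local copy of the SOAP projection helpers; per the import change in the first hunk they are now taken from torchzero/modules/optimizers/soap.py instead of being duplicated here. Their role is unchanged: `project` rotates each dimension of a tensor into the eigenbasis of the corresponding covariance factor, and `project_back` undoes it. A small self-contained round-trip check of that idea, with toy shapes and without the None/empty-matrix handling of the real helpers:

```python
import torch

def project(t: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
    # rotate each dimension of t into the eigenbasis Q[dim]
    for mat in Q:
        t = torch.tensordot(t, mat, dims=([0], [0]))
    return t

def project_back(t: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
    # inverse rotation: contract against the second index of each Q[dim]
    for mat in Q:
        t = torch.tensordot(t, mat, dims=([0], [1]))
    return t

g = torch.randn(4, 6)
# orthogonal eigenbases, e.g. what torch.linalg.eigh of the GG^T factors would give
Q = [torch.linalg.qr(torch.randn(4, 4)).Q, torch.linalg.qr(torch.randn(6, 6)).Q]
assert torch.allclose(project_back(project(g, Q), Q), g, atol=1e-5)
```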
@@ -166,8 +63,8 @@ class ABSOAP(Transform):
         alpha: float = 1,
         bias_correction: bool = True,
         scale_by_s: bool = True,
-
-
+        gg1: Source='g',
+        gg2: Source='g',
         ema1: Source='g',
         ema2: tuple[Source, Source] = ('g','g'),
         rel1: bool=False,
@@ -189,29 +86,27 @@ class ABSOAP(Transform):
             scale_by_s=scale_by_s,
             ema1=ema1,
             ema2=ema2,
-            first=
-            second=
+            first=gg1,
+            second=gg2,
             rel1=rel1, rel2=rel2,
             norm=norm,
         )
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(
-            scale_by_s =
-            ema1 =
-            ema2 =
-            first=
-            second=
-            rel1 =
-            norm=
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
+            scale_by_s = setting['scale_by_s']
+            ema1 = setting['ema1']
+            ema2 = setting['ema2']
+            first=setting['first']
+            second=setting['second']
+            rel1 = setting['rel1']; rel2 = setting['rel2']
+            norm=setting['norm']
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -219,8 +114,8 @@ class ABSOAP(Transform):
             if 'g_prev' not in state:
                 state['p_prev'] = p.clone()
                 state['g_prev'] = t.clone()
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue
+                # updates.append(tensors[i].clip(-0.1,0.1))
+                # continue
 
             p_prev = state['p_prev']
             g_prev = state['g_prev']
@@ -270,11 +165,10 @@ class ABSOAP(Transform):
                 t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
                 t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable
 
-
             # initialize state on 1st step
             if 'GG' not in state:
                 state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.
+                state["exp_avg_sq"] = torch.zeros_like(t)
 
                 if not precondition_1d and t.ndim <= 1:
                     state['GG'] = []
@@ -287,7 +181,7 @@ class ABSOAP(Transform):
                     state['GG'] = None
 
                 if state['GG'] is not None:
-
+                    update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
                     state['Q'] = get_orthogonal_matrix(state['GG'])
 
                 state['step'] = 0
@@ -334,7 +228,7 @@ class ABSOAP(Transform):
             if z1_projected is not None:
                 update = project_back(update, state["Q"])
 
-            if
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -349,8 +243,8 @@ class ABSOAP(Transform):
 
             # Update is done after the gradient step to avoid using current gradients in the projection.
             if state['GG'] is not None:
-
-                if state['step'] %
+                update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
+                if state['step'] % setting['precond_freq'] == 0:
                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
 
         return updates
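ABSOAP.apply above (and AdaSOAP.apply further down) follow the reworked Transform interface in 0.3.10: per-parameter state dicts and settings arrive as parallel lists and are zipped alongside the parameters and tensors, instead of being looked up through `self.state[p]` / `self.settings[p]`. The subclass below is a schematic mirror of that pattern; the class itself is hypothetical, only the signature, the zip idiom, and the `super().__init__(defaults, uses_grad=False)` call are taken from this diff, and it is not guaranteed to match the library's full Transform contract:

```python
import torch
from torchzero.core import Transform  # assumed import path, per the diffs above

class ScaleBySetting(Transform):
    """Toy transform illustrating the apply() signature used in this release."""
    def __init__(self, factor: float = 0.1):
        super().__init__(dict(factor=factor), uses_grad=False)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        updates = []
        for p, t, state, setting in zip(params, tensors, states, settings):
            state['step'] = state.get('step', 0) + 1   # per-parameter state dict
            updates.append(t * setting['factor'])      # per-parameter settings dict
        return updates
```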
torchzero/modules/experimental/adadam.py

@@ -50,7 +50,7 @@ def adadam_(
     return None
 
 class Adadam(Module):
-    """Adam with a diagonally preconditioned preconditioner."""
+    """Adam with a diagonally preconditioned preconditioner. Please note that this is experimental and isn't guaranteed to work."""
     def __init__(
         self,
         beta1: float = 0.9,
@@ -67,31 +67,32 @@ class Adadam(Module):
         self.getter = itemgetter('amsgrad','pow','debiased')
 
     @torch.no_grad
-    def step(self,
+    def step(self, var):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        params = var.params
 
-        beta1,beta2,precond_beta,eps,alpha=self.get_settings('beta1','beta2','precond_beta','eps','alpha',
-        amsgrad,pow,debiased = self.getter(self.settings[
+        beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu',
+            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu',
+            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
             max_exp_avg_sq = None
            max_exp_avg_qu = None
 
         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if
-        if
-            passed_params = TensorList(
-
-
+        if var.is_last:
+            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+            passed_params = TensorList(var.params)
+            var.stop = True
+            var.skip_update = True
 
         else:
             passed_params = None
 
-
-            tensors=TensorList(
+        var.update = adadam_(
+            tensors=TensorList(var.get_update()),
             exp_avg_=exp_avg,
             exp_avg_sq_=exp_avg_sq,
             exp_avg_qu_=exp_avg_qu,
@@ -108,4 +109,4 @@ class Adadam(Module):
             params_=passed_params,
         )
 
-        return
+        return var
torchzero/modules/experimental/adamY.py

@@ -62,17 +62,7 @@ def adamy_(
     return None
 
 class AdamY(Module):
-    """Adam but uses scaled gradient differences for second momentum.
-
-    Args:
-        beta1 (float, optional): momentum. Defaults to 0.9.
-        beta2 (float, optional): second momentum. Defaults to 0.999.
-        eps (float, optional): epsilon. Defaults to 1e-8.
-        alpha (float, optional): learning rate. Defaults to 1.
-        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
-        pow (float, optional): power used in second momentum power and root. Defaults to 2.
-        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
-    """
+    """Adam but uses scaled gradient differences for second momentum. Please note that this is experimental and isn't guaranteed to work."""
     def __init__(
         self,
         beta1: float = 0.9,
@@ -88,36 +78,36 @@ class AdamY(Module):
         self.getter = itemgetter('amsgrad','pow','debiased')
 
     @torch.no_grad
-    def step(self,
+    def step(self, var):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha',
-        amsgrad,pow,debiased = self.getter(self.settings[
+        beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq',
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq',
+            exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if
-        if
-            passed_params = TensorList(
-
-
+        if var.is_last:
+            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+            passed_params = TensorList(var.params)
+            var.stop = True
+            var.skip_update = True
 
         else:
             passed_params = None
 
-        p_prev = self.get_state('p_prev',
-        g_prev = self.get_state('g_prev',
+        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+        g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)
 
 
-
-            p=TensorList(
+        var.update = adamy_(
+            p=TensorList(var.params),
             p_prev=p_prev,
-            g=TensorList(
+            g=TensorList(var.get_update()),
             g_prev=g_prev,
             exp_avg_=exp_avg,
             exp_avg_sq_=exp_avg_sq,
@@ -132,4 +122,4 @@ class AdamY(Module):
             params_=passed_params,
         )
 
-        return
+        return var
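Both Adadam and AdamY above migrate Module.step to a single `var` argument: parameters come from `var.params`, the incoming update from `var.get_update()`, per-parameter hyperparameters from `self.get_settings(params, ..., cls=NumberList)` and state from `self.get_state(params, ..., cls=TensorList)`, and when the module is last in the chain it applies the step in place and sets `var.stop` / `var.skip_update` before `return var`. The sketch below is not the torchzero implementation; it re-creates the same control flow around a stand-in `Var` object (field names taken from the hunks above, everything else illustrative) so the pattern can be run on its own:

```python
import torch
from dataclasses import dataclass

# Minimal stand-in for the `var` object threaded through Module.step() in 0.3.10.
@dataclass
class Var:
    params: list[torch.Tensor]
    update: list[torch.Tensor]
    is_last: bool = False
    last_module_lrs: list[float] | None = None
    stop: bool = False
    skip_update: bool = False
    def get_update(self):
        return self.update

def momentum_step(var: Var, state: dict, beta: float = 0.9, alpha: float = 1.0) -> Var:
    # same control flow as Adadam.step / AdamY.step in this diff
    if var.is_last:
        lrs = var.last_module_lrs or [1.0] * len(var.params)
        alphas = [alpha * lr for lr in lrs]
        var.stop = True          # nothing runs after this module
        var.skip_update = True   # parameters are updated in place below
        in_place = True
    else:
        alphas = [alpha] * len(var.params)
        in_place = False

    bufs = state.setdefault('exp_avg', [torch.zeros_like(p) for p in var.params])
    new_update = []
    for p, g, buf, a in zip(var.params, var.get_update(), bufs, alphas):
        buf.mul_(beta).add_(g)
        if in_place:
            p.sub_(buf, alpha=a)
        new_update.append(buf * a)
    var.update = new_update
    return var
```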
torchzero/modules/experimental/adasoap.py

@@ -2,11 +2,18 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Transform
+from ...core import Chainable, Transform
 from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+from ..optimizers.soap import (
+    get_orthogonal_matrix,
+    get_orthogonal_matrix_QR,
+    project,
+    project_back,
+)
+
 
 @torch.no_grad
-def
+def update_adasoap_covariances_(
     grad: torch.Tensor,
     GGs_: list[torch.Tensor | None],
     GG_sqs: list[torch.Tensor | None],
@@ -24,125 +31,9 @@ def update_soap_covariances_(
         if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
 
-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-@torch.no_grad
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq
 
 class AdaSOAP(Transform):
-    """SOAP with diagonally preconditioned GG^Ts
+    """SOAP with diagonally preconditioned GG^Ts. Please note that this is experimental and isn't guaranteed to work.
 
     precond_beta - beta for GG^T squares
     """
@@ -180,15 +71,14 @@ class AdaSOAP(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-
-            settings = self.settings[p]
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+
             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(
-            precond_beta =
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(setting)
+            precond_beta = setting['precond_beta']
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -213,7 +103,7 @@ class AdaSOAP(Transform):
 
                 if state['GG'] is not None:
                     assert state['GG_sq'] is not None
-
+                    update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                     GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
                     state['Q'] = get_orthogonal_matrix(GG_precond)
 
@@ -259,7 +149,7 @@ class AdaSOAP(Transform):
             if t_projected is not None:
                 update = project_back(update, state["Q"])
 
-            if
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -274,9 +164,9 @@ class AdaSOAP(Transform):
 
             # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
-
+                update_adasoap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
                 GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
-                if state['step'] %
+                if state['step'] % setting['precond_freq'] == 0:
                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
 
         return updates
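AdaSOAP's "diagonally preconditioned GGᵀ" is visible in the hunks above: each covariance factor GG is divided elementwise by a running magnitude estimate GG_sq (blended with `precond_beta`) before its eigenbasis is computed. A standalone sketch of that step; the interpretation of GG_sq as an EMA of elementwise squares is an assumption based on the argument names, not confirmed by code shown in this diff:

```python
import torch

# Illustrative sketch of AdaSOAP's diagonally preconditioned GG^T factor.
g = torch.randn(32, 16)
GG = torch.tensordot(g, g, dims=([1], [1]))   # 32x32 covariance factor, as in the update above
GG_sq = torch.zeros_like(GG)
GG_sq.lerp_(GG.pow(2), 1 - 0.95)              # assumed running magnitude estimate (precond_beta = 0.95)

# same expression as in the hunks above: elementwise division, then the eigenbasis
GG_precond = GG / (GG_sq + 1e-8)
_, Q = torch.linalg.eigh(GG_precond + 1e-30 * torch.eye(GG.shape[0]))
```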