torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/absoap.py (+53 -156)

@@ -1,12 +1,14 @@
 from operator import itemgetter
+from typing import Literal
 
 import torch
-
-from ...core import Chainable, Transform
+
+from ...core import Chainable, Transform
 from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+from ..optimizers.soap import project, project_back, get_orthogonal_matrix, get_orthogonal_matrix_QR
 
 @torch.no_grad
-def
+def update_absoap_covariances_(
     g1: torch.Tensor,
     g2: torch.Tensor,
     GGs_: list[torch.Tensor | None],

@@ -19,138 +21,36 @@ def update_soap_covariances_(
         if beta is None: GG.add_(torch.tensordot(g1, g2, (axes, axes))) # pyright:ignore[reportArgumentType]
         else: GG.lerp_(torch.tensordot(g1, g2, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
 
-@torch.no_grad
-def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
-    """
-    Projects the gradient to the eigenbases of the preconditioner.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
-        else:
-            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
 
-
-def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
-    """
-    Projects the gradient back to the original space.
-    """
-    for mat in Q:
-        if mat is None: continue
-        if len(mat) > 0:
-            tensors = torch.tensordot(tensors, mat,dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
-        else:
-            permute_order = list(range(1, len(tensors.shape))) + [0]
-            tensors = tensors.permute(permute_order)
-
-    return tensors
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
-@torch.no_grad
-def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
-    """
-    matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m in mat:
-        if m is None: continue
-        if len(m) == 0:
-            matrix.append([])
-            continue
-        if m.dtype != torch.float:
-            original_type = m.dtype
-            original_device = m.device
-            matrix.append(m.float())
-        else:
-            float_data = True
-            matrix.append(m)
-
-    final = []
-    for m in matrix:
-        if len(m) == 0:
-            final.append([])
-            continue
-        try:
-            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except Exception:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
-        Q = torch.flip(Q, [1])
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-    return final
-
-# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
-@torch.no_grad
-def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
-    """
-    Computes the eigenbases of the preconditioner using one round of power iteration
-    followed by torch.linalg.qr decomposition.
-    """
-    matrix = []
-    orth_matrix = []
-    float_data = False
-    original_type = original_device = None
-    for m,o in zip(GG, Q_list):
-        if m is None: continue
-        assert o is not None
-
-        if len(m) == 0:
-            matrix.append([])
-            orth_matrix.append([])
-            continue
-        if m.data.dtype != torch.float:
-            original_type = m.data.dtype
-            original_device = m.data.device
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-        else:
-            float_data = True
-            matrix.append(m.data.float())
-            orth_matrix.append(o.data.float())
-
-    final = []
-    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
-        if len(m)==0:
-            final.append([])
-            continue
-        est_eig = torch.diag(o.T @ m @ o)
-        sort_idx = torch.argsort(est_eig, descending=True)
-        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
-        o = o[:,sort_idx]
-        power_iter = m @ o
-        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
-
-        if not float_data:
-            Q = Q.to(original_device).type(original_type)
-        final.append(Q)
-
-    return final, exp_avg_sq
-
-Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys','sn', 'yn']
+Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
 class ABSOAP(Transform):
-    """SOAP but with
-
-
+    """SOAP but with some extra options for testing.
+
+    .. warning::
+        This module is just for testing my stupid ideas.
+
+    Args:
+        scale_by_s - whether to scale y by s
+        gg1 - 1st vector into GGᵀ
+        gg2 - 2nd vector into GGᵀ
+        ema1 - vector into 1st momentum
+        ema2 - 2 vectors into 2nd momentum
+        rel1 - if True, multiplies gg1 by params
+        rel2 - same but for gg2
+        norm - if True, gg1 a and gg2 are normalized, and I need to make that into a letter
+
+    letters:
+        p - params
+        g - grad
+        s - param difference
+        y - grad difference
+        gy - g+y
+        sy - s+y
+        sn - s normalized
+        yn - y normalized
+        gys - g + y#g
+        sys - s + y#s
 
-    new args
-
-    scale by s whether to scale gradient differences by parameter differences
-
-    y_to_ema2 whether to use gradient differences for exponential moving average too
-
-    okay I changed these args into another ones
-
-    BASICALLY THIS IS FOR MY EXPERIMENTS
     """
     def __init__(
         self,

@@ -166,8 +66,8 @@ class ABSOAP(Transform):
         alpha: float = 1,
         bias_correction: bool = True,
         scale_by_s: bool = True,
-
-
+        gg1: Source='g',
+        gg2: Source='g',
         ema1: Source='g',
         ema2: tuple[Source, Source] = ('g','g'),
         rel1: bool=False,

@@ -189,29 +89,27 @@ class ABSOAP(Transform):
             scale_by_s=scale_by_s,
             ema1=ema1,
            ema2=ema2,
-            first=
-            second=
+            first=gg1,
+            second=gg2,
             rel1=rel1, rel2=rel2,
             norm=norm,
         )
         super().__init__(defaults, uses_grad=False)
 
    @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        updates = []
        # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(
-            scale_by_s =
-            ema1 =
-            ema2 =
-            first=
-            second=
-            rel1 =
-            norm=
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
+            scale_by_s = setting['scale_by_s']
+            ema1 = setting['ema1']
+            ema2 = setting['ema2']
+            first=setting['first']
+            second=setting['second']
+            rel1 = setting['rel1']; rel2 = setting['rel2']
+            norm=setting['norm']
 
            if merge_small:
                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)

@@ -219,8 +117,8 @@ class ABSOAP(Transform):
            if 'g_prev' not in state:
                state['p_prev'] = p.clone()
                state['g_prev'] = t.clone()
-                updates.append(tensors[i].clip(-0.1,0.1))
-                continue
+                # updates.append(tensors[i].clip(-0.1,0.1))
+                # continue
 
            p_prev = state['p_prev']
            g_prev = state['g_prev']

@@ -270,11 +168,10 @@ class ABSOAP(Transform):
                t1 = t1/torch.linalg.vector_norm(t1).clip(min=1e-8) # pylint:disable=not-callable
                t2 = t2/torch.linalg.vector_norm(t2).clip(min=1e-8) # pylint:disable=not-callable
 
-
            # initialize state on 1st step
            if 'GG' not in state:
                state["exp_avg"] = torch.zeros_like(t)
-                state["exp_avg_sq"] = torch.
+                state["exp_avg_sq"] = torch.zeros_like(t)
 
                if not precondition_1d and t.ndim <= 1:
                    state['GG'] = []

@@ -287,7 +184,7 @@ class ABSOAP(Transform):
                    state['GG'] = None
 
                if state['GG'] is not None:
-
+                    update_absoap_covariances_(t1, t2, GGs_=state['GG'], beta=shampoo_beta)
                    state['Q'] = get_orthogonal_matrix(state['GG'])
 
                state['step'] = 0

@@ -334,7 +231,7 @@ class ABSOAP(Transform):
            if z1_projected is not None:
                update = project_back(update, state["Q"])
 
-            if
+            if setting['bias_correction']:
                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha

@@ -349,8 +246,8 @@ class ABSOAP(Transform):
 
            # Update is done after the gradient step to avoid using current gradients in the projection.
            if state['GG'] is not None:
-
-                if state['step'] %
+                update_absoap_covariances_(t1, t2, state['GG'], shampoo_beta)
+                if state['step'] % setting['precond_freq'] == 0:
                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
 
        return updates
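The helpers this diff removes from absoap.py and re-imports from `..optimizers.soap` rotate a gradient tensor into the preconditioner's eigenbases and back, one `tensordot` per dimension. Below is a minimal standalone sketch of that round trip, based on the removed code above; the shapes and the random orthogonal `Q` matrices are illustrative, and the `None`/empty-matrix handling of the original is omitted.

```python
# Sketch of eigenbasis projection as in the removed project/project_back above.
import torch

def project(t: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
    # rotate each dimension of `t` into the corresponding eigenbasis;
    # each tensordot contracts dim 0 and moves the rotated dim to the end
    for mat in Q:
        t = torch.tensordot(t, mat, dims=([0], [0]))
    return t

def project_back(t: torch.Tensor, Q: list[torch.Tensor]) -> torch.Tensor:
    # inverse rotation: contract against the second axis of each eigenbasis
    for mat in Q:
        t = torch.tensordot(t, mat, dims=([0], [1]))
    return t

g = torch.randn(4, 3)
# one orthogonal basis per dimension (in SOAP these come from eigh of the covariances)
Q = [torch.linalg.qr(torch.randn(n, n)).Q for n in g.shape]
assert torch.allclose(project_back(project(g, Q), Q), g, atol=1e-5)
```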
torchzero/modules/experimental/adadam.py (+22 -15)

@@ -10,7 +10,7 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..
+from ..step_size.lr import lazy_lr
 from ..momentum.experimental import sqrt_nag_ema_sq_
 from ..momentum.momentum import nag_
 

@@ -50,7 +50,13 @@ def adadam_(
     return None
 
 class Adadam(Module):
-    """Adam with a diagonally preconditioned preconditioner.
+    """Adam with a diagonally preconditioned preconditioner.
+
+    Verdict: I haven't tested this yet.
+
+    .. warning::
+        Experimental.
+    """
     def __init__(
         self,
         beta1: float = 0.9,

@@ -67,31 +73,32 @@ class Adadam(Module):
         self.getter = itemgetter('amsgrad','pow','debiased')
 
     @torch.no_grad
-    def step(self,
+    def step(self, var):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        params = var.params
 
-        beta1,beta2,precond_beta,eps,alpha=self.get_settings('beta1','beta2','precond_beta','eps','alpha',
-        amsgrad,pow,debiased = self.getter(self.settings[
+        beta1,beta2,precond_beta,eps,alpha=self.get_settings(params, 'beta1','beta2','precond_beta','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu',
+            exp_avg, exp_avg_sq, exp_avg_qu, max_exp_avg_sq, max_exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', 'max_exp_avg_sq', 'max_exp_avg_qu', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state('exp_avg','exp_avg_sq', 'exp_avg_qu',
+            exp_avg, exp_avg_sq, exp_avg_qu = self.get_state(params, 'exp_avg','exp_avg_sq', 'exp_avg_qu', cls=TensorList)
             max_exp_avg_sq = None
             max_exp_avg_qu = None
 
         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if
-        if
-            passed_params = TensorList(
-
-
+        if var.is_last:
+            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+            passed_params = TensorList(var.params)
+            var.stop = True
+            var.skip_update = True
 
         else:
             passed_params = None
 
-
-        tensors=TensorList(
+        var.update = adadam_(
+            tensors=TensorList(var.get_update()),
             exp_avg_=exp_avg,
             exp_avg_sq_=exp_avg_sq,
             exp_avg_qu_=exp_avg_qu,

@@ -108,4 +115,4 @@ class Adadam(Module):
             params_=passed_params,
         )
 
-        return
+        return var
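Both Adadam and AdamY pull several scalar options out of the per-parameter settings dict in one call via `operator.itemgetter`, as in `self.getter = itemgetter('amsgrad','pow','debiased')` above. A tiny self-contained demo of that pattern, with a made-up settings dict rather than torchzero's own:

```python
# itemgetter returns a callable that extracts the named keys in order.
from operator import itemgetter

settings = {'amsgrad': False, 'pow': 2, 'debiased': True, 'eps': 1e-8}
amsgrad, pow_, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings)
print(amsgrad, pow_, debiased)  # False 2 True
```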
torchzero/modules/experimental/adamY.py (+21 -25)

@@ -10,7 +10,7 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..
+from ..step_size.lr import lazy_lr
 from ..momentum.experimental import sqrt_nag_ema_sq_
 from ..momentum.momentum import nag_
 

@@ -64,14 +64,10 @@ def adamy_(
 class AdamY(Module):
     """Adam but uses scaled gradient differences for second momentum.
 
-
-
-
-
-        alpha (float, optional): learning rate. Defaults to 1.
-        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
-        pow (float, optional): power used in second momentum power and root. Defaults to 2.
-        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+    Verdict: I haven't tested this yet.
+
+    .. warning::
+        Experimental.
     """
     def __init__(
         self,

@@ -88,36 +84,36 @@ class AdamY(Module):
         self.getter = itemgetter('amsgrad','pow','debiased')
 
     @torch.no_grad
-    def step(self,
+    def step(self, var):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha',
-        amsgrad,pow,debiased = self.getter(self.settings[
+        beta1,beta2,eps,alpha=self.get_settings(var.params, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = self.getter(self.settings[var.params[0]])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq',
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state(var.params,'exp_avg','exp_avg_sq','max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq',
+            exp_avg, exp_avg_sq = self.get_state(var.params, 'exp_avg','exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if
-        if
-            passed_params = TensorList(
-
-
+        if var.is_last:
+            if var.last_module_lrs is not None: alpha = alpha * var.last_module_lrs
+            passed_params = TensorList(var.params)
+            var.stop = True
+            var.skip_update = True
 
         else:
             passed_params = None
 
-        p_prev = self.get_state('p_prev',
-        g_prev = self.get_state('g_prev',
+        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+        g_prev = self.get_state(var.params, 'g_prev', cls=TensorList)
 
 
-
-        p=TensorList(
+        var.update = adamy_(
+            p=TensorList(var.params),
             p_prev=p_prev,
-            g=TensorList(
+            g=TensorList(var.get_update()),
             g_prev=g_prev,
             exp_avg_=exp_avg,
             exp_avg_sq_=exp_avg_sq,

@@ -132,4 +128,4 @@ class AdamY(Module):
             params_=passed_params,
         )
 
-        return
+        return var
torchzero/modules/experimental/adam_lambertw.py (new file, +149 -0)

@@ -0,0 +1,149 @@
+from operator import itemgetter
+from functools import partial
+import math
+import torch
+
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..step_size.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def _lambertw_newton_raphson(x: TensorList, iterations=5):
+    # z = torch.zeros_like(x)
+    # mask_neg = x < 0
+    # mask_pos = ~mask_neg
+
+    # z[mask_pos] = torch.log(x[mask_pos] + 1.0)
+
+    # x_neg = x[mask_neg]
+    # z_neg = -1.0 + torch.sqrt(2.0 * (1.0 + math.e * x_neg))
+    # z[mask_neg] = z_neg
+
+    # x is always positive
+    z = (x+1).log_()
+    for _ in range(iterations):
+        exp_z = z.exp()
+        numerator = z * exp_z - x
+        denominator = exp_z * (z + 1.0) + 1e-8
+        delta = numerator / denominator
+        z -= delta
+    return z
+
+# https://github.com/gmgeorg/torchlambertw/blob/main/torchlambertw/special.py
+def _lambertw_winitzki(x: TensorList):
+    x_log1p = x.log1p()
+    return x_log1p * (1.0 - x_log1p.log1p() / (2.0 + x_log1p))
+
+
+def adam_lambertw_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_xpx_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    debiased: bool = True,
+    max_exp_avg_xpx_: TensorList | None = None,
+    iterations: int | None = 5,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """Returns new tensors."""
+    tensors_abs = tensors.abs().clip_(max=20)
+    tensors_xpx = tensors_abs.pow_(tensors_abs)
+    exp_avg_xpx_.lerp_(tensors_xpx, 1-beta2)
+
+    if max_exp_avg_xpx_ is not None:
+        max_exp_avg_xpx_.maximum_(exp_avg_xpx_)
+        exp_avg_xpx_ = max_exp_avg_xpx_
+
+    if inner is not None:
+        assert params is not None
+        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+    if iterations is None or iterations < 1: exp_avg_xpx_ = _lambertw_winitzki(exp_avg_xpx_)
+    else: exp_avg_xpx_ = _lambertw_newton_raphson(exp_avg_xpx_, iterations)
+
+    return (exp_avg_.lazy_mul(alpha) / exp_avg_xpx_.add_(eps))
+
+class AdamLambertW(Transform):
+    """Adam but uses abs x^x and LambertW instead of square and sqrt.
+    The gradient will be clipped to 20 because float32 which you have to use otherwise you're PC will explode.
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum. Defaults to 0.999.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        alpha (float, optional): learning rate. Defaults to 1.
+        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+        pow (float, optional): power used in second momentum power and root. Defaults to 2.
+        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+        iterations (int, optional): 0 or None means Winitzki approximation otherwise number of newton raphson iterations.
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        eps: float = 1e-8,
+        amsgrad: bool = False,
+        alpha: float = 1.,
+        pow: float = 2,
+        debiased: bool = True,
+        iterations: int | None = 5,
+        inner: Chainable | None = None
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased, iterations=iterations)
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased,iterations = itemgetter('amsgrad','pow','debiased','iterations')(settings[0])
+
+        if amsgrad:
+            exp_avg, exp_avg_xpx, max_exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', 'max_exp_avg_xpx', cls=TensorList)
+        else:
+            exp_avg, exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', cls=TensorList)
+            max_exp_avg_xpx = None
+
+
+        return adam_lambertw_(
+            tensors=TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_xpx_=exp_avg_xpx,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            step=step,
+            pow=pow,
+            debiased=debiased,
+            max_exp_avg_xpx_=max_exp_avg_xpx,
+            iterations=iterations,
+
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+
+        )
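The `_lambertw_newton_raphson` helper added above solves w·e^w = x by Newton's method starting from log(x + 1). A standalone sketch of the same iteration on a plain `torch.Tensor` (torchzero's `TensorList` is not used here), with a check that the result satisfies the defining equation:

```python
# Newton-Raphson for the Lambert W function, mirroring the diffed helper.
import torch

def lambertw_newton_raphson(x: torch.Tensor, iterations: int = 5) -> torch.Tensor:
    # x is assumed non-negative (AdamLambertW feeds it EMAs of |g|^|g|)
    z = torch.log(x + 1.0)  # initial guess
    for _ in range(iterations):
        exp_z = torch.exp(z)
        # Newton step for f(z) = z * e^z - x, with f'(z) = e^z * (z + 1)
        z = z - (z * exp_z - x) / (exp_z * (z + 1.0) + 1e-8)
    return z

x = torch.tensor([0.1, 1.0, 5.0, 20.0])
w = lambertw_newton_raphson(x)
print(w * torch.exp(w))  # ≈ x, since W(x) solves w * e^w = x
```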
torchzero/modules/line_search/trust_region.py → torchzero/modules/experimental/adaptive_step_size.py (+37 -8)

@@ -2,35 +2,64 @@ from operator import itemgetter
 
 import torch
 
-from
+from ..line_search import LineSearchBase
 
 
-class
-"""Basic first order
+class AdaptiveStepSize(LineSearchBase):
+    """Basic first order step size adaptation method. Re-evaluates the function after stepping, if value decreased sufficiently,
+    step size is increased. If value increased, step size is decreased.
+
+    .. note::
+        This works well in some cases, but it is often prone to collapsing.
+        For a more robust alternative use :code:`tz.m.AdaptiveBacktracking`.
+
+    Args:
+        nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
+        nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
+        c (float, optional): descent condition. Defaults to 1e-4.
+        init (float, optional): initial step size. Defaults to 1.
+        backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
+        adaptive (bool, optional):
+            If enabled, when multiple consecutive steps have been successful or unsuccessful,
+            the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
+
+
+    Examples:
+        Adagrad with trust region:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adagrad(),
+                tz.m.TrustRegion()
+            )
+
+    """
     def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
         defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
         super().__init__(defaults)
 
     @torch.no_grad
-    def search(self, update,
+    def search(self, update, var):
 
-        nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[
+        nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
         step_size = self.global_state.setdefault('step_size', init)
         previous_success = self.global_state.setdefault('previous_success', False)
         nplus_mul = self.global_state.setdefault('nplus_mul', 1)
         nminus_mul = self.global_state.setdefault('nminus_mul', 1)
 
 
-        f_0 = self.evaluate_step_size(0,
+        f_0 = self.evaluate_step_size(0, var, backward=False)
 
         # directional derivative (0 if c = 0 because it is not needed)
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
 
         # test step size
         sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
 
-        f_1 = self.evaluate_step_size(step_size,
+        f_1 = self.evaluate_step_size(step_size, var, backward=False)
 
         proposed = step_size
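The AdaptiveStepSize docstring above describes the rule: grow the step size by `nplus` after a successful step and shrink it by `nminus` (optionally undoing the step) after an unsuccessful one. A toy illustration of that rule on a 1-D quadratic, not the torchzero implementation:

```python
# nplus/nminus step size adaptation on f(x) = (x - 3)^2.
f = lambda x: (x - 3.0) ** 2

x, step_size = 0.0, 1.0
nplus, nminus = 1.5, 0.75
for _ in range(20):
    g = 2.0 * (x - 3.0)            # gradient of f at x
    f_0 = f(x)
    x_new = x - step_size * g      # try a step along the negative gradient
    if f(x_new) < f_0:             # successful step: accept it and grow the step size
        x, step_size = x_new, step_size * nplus
    else:                          # unsuccessful step: keep x (backtrack) and shrink
        step_size *= nminus
print(x, step_size)                # x moves toward the minimum at 3.0
```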