torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/experimental/adasoap.py
@@ -0,0 +1,282 @@
+from operator import itemgetter
+
+import torch
+
+from ...core import Chainable, Transform, apply
+from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
+
+@torch.no_grad
+def update_soap_covariances_(
+    grad: torch.Tensor,
+    GGs_: list[torch.Tensor | None],
+    GG_sqs: list[torch.Tensor | None],
+    beta: float | None,
+    precond_beta: float | None,
+):
+    for i, (GG, GG_sq) in enumerate(zip(GGs_, GG_sqs)):
+        if GG is None: continue
+        assert GG_sq is not None
+
+        if precond_beta is None: GG_sq.addcmul_(GG, GG)
+        else: GG_sq.mul_(precond_beta).addcmul_(GG, GG, value=1-precond_beta)
+
+        axes = list(range(i)) + list(range(i + 1, grad.ndim)) # this works fine with 1d params
+        if beta is None: GG.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
+        else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
+
+@torch.no_grad
+def project(tensors: torch.Tensor, Q: list[torch.Tensor | None]):
+    """
+    Projects the gradient to the eigenbases of the preconditioner.
+    """
+    for mat in Q:
+        if mat is None: continue
+        if len(mat) > 0:
+            tensors = torch.tensordot(tensors, mat, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
+        else:
+            # I don't understand this part but it is in https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+            permute_order = list(range(1, len(tensors.shape))) + [0]
+            tensors = tensors.permute(permute_order)
+
+    return tensors
+
+@torch.no_grad
+def project_back(tensors: torch.Tensor, Q: list[torch.Tensor| None]):
+    """
+    Projects the gradient back to the original space.
+    """
+    for mat in Q:
+        if mat is None: continue
+        if len(mat) > 0:
+            tensors = torch.tensordot(tensors, mat, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
+        else:
+            permute_order = list(range(1, len(tensors.shape))) + [0]
+            tensors = tensors.permute(permute_order)
+
+    return tensors
+
+# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
+@torch.no_grad
+def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
+    """
+    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
+    """
+    matrix = []
+    float_data = False
+    original_type = original_device = None
+    for m in mat:
+        if m is None: continue
+        if len(m) == 0:
+            matrix.append([])
+            continue
+        if m.dtype != torch.float:
+            original_type = m.dtype
+            original_device = m.device
+            matrix.append(m.float())
+        else:
+            float_data = True
+            matrix.append(m)
+
+    final = []
+    for m in matrix:
+        if len(m) == 0:
+            final.append([])
+            continue
+        try:
+            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+        except Exception:
+            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
+            Q = Q.to(m.dtype)
+        Q = torch.flip(Q, [1])
+
+        if not float_data:
+            Q = Q.to(original_device).type(original_type)
+        final.append(Q)
+    return final
+
+# function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
+@torch.no_grad
+def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | None], Q_list: list[torch.Tensor | None]):
+    """
+    Computes the eigenbases of the preconditioner using one round of power iteration
+    followed by torch.linalg.qr decomposition.
+    """
+    matrix = []
+    orth_matrix = []
+    float_data = False
+    original_type = original_device = None
+    for m,o in zip(GG, Q_list):
+        if m is None: continue
+        assert o is not None
+
+        if len(m) == 0:
+            matrix.append([])
+            orth_matrix.append([])
+            continue
+        if m.data.dtype != torch.float:
+            original_type = m.data.dtype
+            original_device = m.data.device
+            matrix.append(m.data.float())
+            orth_matrix.append(o.data.float())
+        else:
+            float_data = True
+            matrix.append(m.data.float())
+            orth_matrix.append(o.data.float())
+
+    final = []
+    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
+        if len(m)==0:
+            final.append([])
+            continue
+        est_eig = torch.diag(o.T @ m @ o)
+        sort_idx = torch.argsort(est_eig, descending=True)
+        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
+        o = o[:,sort_idx]
+        power_iter = m @ o
+        Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
+
+        if not float_data:
+            Q = Q.to(original_device).type(original_type)
+        final.append(Q)
+
+    return final, exp_avg_sq
+
+class AdaSOAP(Transform):
+    """SOAP with diagonally preconditioned GG^Ts
+
+    precond_beta - beta for GG^T squares
+    """
+    def __init__(
+        self,
+        beta1: float = 0.95,
+        beta2: float = 0.95,
+        shampoo_beta: float | None = 0.95,
+        precond_beta: float | None = 0.95,
+        precond_freq: int = 10,
+        merge_small: bool = True,
+        max_dim: int = 2_000,
+        precondition_1d: bool = True,
+        eps: float = 1e-8,
+        decay: float | None = None,
+        alpha: float = 1,
+        unprojected_exp_avg: bool = True,
+        bias_correction: bool = True,
+    ):
+        defaults = dict(
+            beta1=beta1,
+            beta2=beta2,
+            shampoo_beta=shampoo_beta,
+            precond_beta=precond_beta,
+            precond_freq=precond_freq,
+            merge_small=merge_small,
+            max_dim=max_dim,
+            precondition_1d=precondition_1d,
+            eps=eps,
+            decay=decay,
+            unprojected_exp_avg=unprojected_exp_avg,
+            bias_correction=bias_correction,
+            alpha=alpha,
+        )
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        updates = []
+        # update preconditioners
+        for i,(p,t) in enumerate(zip(params, tensors)):
+            state = self.state[p]
+            settings = self.settings[p]
+            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
+            precond_beta = settings['precond_beta']
+
+            if merge_small:
+                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
+
+            # initialize state on 1st step
+            if 'GG' not in state:
+                state["exp_avg"] = torch.zeros_like(t)
+                state["exp_avg_sq"] = torch.zeros_like(t)
+
+                if not precondition_1d and t.ndim <= 1:
+                    state['GG'] = []
+                    state['GG_sq'] = []
+
+                else:
+                    state['GG'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
+                    state['GG_sq'] = [torch.zeros(s, s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
+
+                # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+                if len([i is not None for i in state['GG']]) == 0:
+                    state['GG'] = None
+                    state['GG_sq'] = None
+
+                if state['GG'] is not None:
+                    assert state['GG_sq'] is not None
+                    update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                    GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
+                    state['Q'] = get_orthogonal_matrix(GG_precond)
+
+                state['step'] = 0
+                updates.append(tensors[i].sign())
+                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
+                # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
+
+            # Projecting gradients to the eigenbases of Shampoo's preconditioner
+            # i.e. projecting to the eigenbases of matrices in state['GG']
+            t_projected = None
+            if state['GG'] is not None:
+                t_projected = project(t, state['Q'])
+
+            # exponential moving averages
+            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
+            exp_avg: torch.Tensor = state["exp_avg"]
+            exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
+
+            if unprojected_exp_avg or t_projected is None:
+                exp_avg.lerp_(t, 1-beta1)
+            else:
+                exp_avg.lerp_(t_projected, 1-beta1)
+
+            if t_projected is None:
+                exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
+            else:
+                exp_avg_sq.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+
+            # project exponential moving averages if they are accumulated unprojected
+            exp_avg_projected = exp_avg
+            if unprojected_exp_avg and t_projected is not None:
+                exp_avg_projected = project(exp_avg, state['Q'])
+
+            exp_avg_sq_projected = exp_avg_sq
+
+            denom = exp_avg_sq_projected.sqrt().add_(eps)
+            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+
+            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
+            # to the original space
+            update = exp_avg_projected / denom
+            if t_projected is not None:
+                update = project_back(update, state["Q"])
+
+            if settings['bias_correction']:
+                bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
+                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
+                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
+            elif alpha is not None:
+                update *= alpha
+
+            if merge_small:
+                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+
+            updates.append(update)
+            state["step"] += 1
+
+            # Update is done after the gradient step to avoid using current gradients in the projection.
+            if state['GG'] is not None:
+                update_soap_covariances_(t, GGs_=state['GG'], GG_sqs=state['GG_sq'], beta=shampoo_beta, precond_beta=precond_beta)
+                GG_precond = [GG / (GG_sq+1e-8) if GG is not None and GG_sq is not None else None for GG, GG_sq in zip(state['GG'], state['GG_sq'])]
+                if state['step'] % settings['precond_freq'] == 0:
+                    state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, GG_precond, state['Q'])
+
+        return updates
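The heart of AdaSOAP.transform above is a rotate-precondition-rotate-back cycle: accumulate GG^T covariances, take their eigenbases, apply the Adam-style division in the rotated space, then rotate the result back. Below is a minimal plain-PyTorch sketch of that cycle for a single 2-D tensor; it does not use the torchzero Transform API, collapses the Adam state to a single step, and is only meant to illustrate what project/project_back compute.

import torch

torch.manual_seed(0)
G = torch.randn(8, 5)                     # a fake gradient for an (8, 5) weight

# per-dimension covariances, as update_soap_covariances_ accumulates them
GG_rows = G @ G.T                         # (8, 8): tensordot of G with itself over dim 1
GG_cols = G.T @ G                         # (5, 5): tensordot of G with itself over dim 0

# eigenbases of the accumulated covariances, as in get_orthogonal_matrix
_, Q_rows = torch.linalg.eigh(GG_rows)
_, Q_cols = torch.linalg.eigh(GG_cols)

# project(): rotate the gradient into the eigenbases
G_proj = Q_rows.T @ G @ Q_cols

# Adam-style division in the rotated space (stand-in for sqrt(exp_avg_sq) + eps)
denom = G_proj.abs().add(1e-8)
update_proj = G_proj / denom

# project_back(): rotate the preconditioned update back to the original space
update = Q_rows @ update_proj @ Q_cols.T
print(update.shape)                       # torch.Size([8, 5])

As far as the eigenbases are concerned, the difference between this module and stock SOAP is that AdaSOAP divides each accumulated GG^T elementwise by a running square (GG / (GG_sq + 1e-8), with precond_beta controlling the running average) before the eigendecomposition.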
torchzero/modules/experimental/algebraic_newton.py
@@ -0,0 +1,145 @@
+import warnings
+from functools import partial
+from typing import Literal
+from collections.abc import Callable
+import torch
+import torchalgebras as ta
+
+from ...core import Chainable, apply, Module
+from ...utils import vec_to_tensors, TensorList
+from ...utils.derivatives import (
+    hessian_list_to_mat,
+    hessian_mat,
+    jacobian_and_hessian_wrt,
+)
+
+class MaxItersReached(Exception): pass
+def tropical_lstsq(
+    H: torch.Tensor,
+    g: torch.Tensor,
+    solver,
+    maxiter,
+    tol,
+    algebra,
+    verbose,
+):
+    """it can run on any algebra with add despite it saying tropical"""
+    algebra = ta.get_algebra(algebra)
+
+    x = torch.zeros_like(g, requires_grad=True)
+    best_x = x.detach().clone()
+    best_loss = float('inf')
+    opt = solver([x])
+
+    niter = 0
+    def closure(backward=True):
+        nonlocal niter, best_x, best_loss
+        if niter == maxiter: raise MaxItersReached
+        niter += 1
+
+        g_hat = algebra.mm(H, x)
+        loss = torch.nn.functional.mse_loss(g_hat, g)
+        if loss < best_loss:
+            best_x = x.detach().clone()
+            best_loss = loss.detach()
+
+        if backward:
+            opt.zero_grad()
+            loss.backward()
+        return loss
+
+    loss = None
+    prev_loss = float('inf')
+    for i in range(maxiter):
+        try:
+            loss = opt.step(closure)
+            if loss == 0: break
+            if tol is not None and prev_loss - loss < tol: break
+            prev_loss = loss
+        except MaxItersReached:
+            break
+
+    if verbose: print(f'{best_loss = } after {niter} iters')
+    return best_x.detach()
+
+def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemiring()):
+    if reg!=0:
+        I = ta.AlgebraicTensor(torch.eye(H.size(-1), dtype=H.dtype, device=H.device), algebra)
+        I = I * reg
+        H = algebra.add(H, I.data)
+    return H
+
+
+class AlgebraicNewton(Module):
+    """newton in other algebras, not practical because solving linear system is very hard."""
+    def __init__(
+        self,
+        reg: float | None = None,
+        hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+        vectorize: bool = True,
+        solver=lambda p: torch.optim.LBFGS(p, line_search_fn='strong_wolfe'),
+        maxiter=1000,
+        tol: float | None = 1e-10,
+        algebra: ta.Algebra | str = 'tropical max',
+        verbose: bool = False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize)
+        super().__init__(defaults)
+
+        self.algebra = ta.get_algebra(algebra)
+        self.lstsq_args:dict = dict(solver=solver, maxiter=maxiter, tol=tol, algebra=algebra, verbose=verbose)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+        params = TensorList(vars.params)
+        closure = vars.closure
+        if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+        settings = self.settings[params[0]]
+        reg = settings['reg']
+        hessian_method = settings['hessian_method']
+        vectorize = settings['vectorize']
+
+        # ------------------------ calculate grad and hessian ------------------------ #
+        if hessian_method == 'autograd':
+            with torch.enable_grad():
+                loss = vars.loss = vars.loss_approx = closure(False)
+                g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                g_list = [t[0] for t in g_list] # remove leading dim from loss
+                vars.grad = g_list
+                H = hessian_list_to_mat(H_list)
+
+        elif hessian_method in ('func', 'autograd.functional'):
+            strat = 'forward-mode' if vectorize else 'reverse-mode'
+            with torch.enable_grad():
+                g_list = vars.get_grad(retain_graph=True)
+                H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                                              method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+        else:
+            raise ValueError(hessian_method)
+
+        # -------------------------------- inner step -------------------------------- #
+        if 'inner' in self.children:
+            g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
+        g = torch.cat([t.view(-1) for t in g_list])
+
+        # ------------------------------- regulazition ------------------------------- #
+        if reg is not None: H = tikhonov(H, reg)
+
+        # ----------------------------------- solve ---------------------------------- #
+        tropical_update = tropical_lstsq(H, g, **self.lstsq_args)
+        # what now? w - u is not defined, it is defined for max version if u < w
+        # w = params.to_vec()
+        # w_hat = self.algebra.sub(w, tropical_update)
+        # update = w_hat - w
+        # no
+        # it makes sense to solve tropical system and sub normally
+        # the only thing is that tropical system can have no solutions
+
+        vars.update = vec_to_tensors(tropical_update, params)
+        return vars
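The inner fit in tropical_lstsq hinges on algebra.mm: with the default 'tropical max' algebra the matrix product is taken in the max-plus semiring, where addition is max and multiplication is +, so (H ⊗ x)_i = max_j (H_ij + x_j), and the residual of H ⊗ x ≈ g is driven down by an ordinary PyTorch optimizer. The sketch below writes that product out directly rather than going through torchalgebras (whose exact API is not assumed here); maxplus_mv is a hypothetical helper used only for illustration.

import torch

def maxplus_mv(H: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # max-plus matrix-vector product: (H ⊗ x)_i = max_j (H_ij + x_j)
    return (H + x.unsqueeze(0)).amax(dim=1)

torch.manual_seed(0)
H = torch.randn(4, 4)
g = torch.randn(4)

x = torch.zeros(4, requires_grad=True)
opt = torch.optim.LBFGS([x], line_search_fn='strong_wolfe')  # same default solver as AlgebraicNewton

def closure():
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(maxplus_mv(H, x), g)
    loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)

print(maxplus_mv(H, x.detach()) - g)      # residual of the (possibly unsolvable) tropical system

As the trailing comments in the module note, a tropical system need not have an exact solution, which is why the loop keeps the best least-squares fit rather than expecting a zero residual.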
torchzero/modules/experimental/curveball.py
@@ -0,0 +1,89 @@
+from typing import Literal
+from collections.abc import Callable
+import torch
+
+from ...core import Module, Target, Transform, Chainable, apply
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+
+def curveball(
+    tensors: TensorList,
+    z_: TensorList,
+    Hz: TensorList,
+    momentum: float | NumberList,
+    precond_lr: float | NumberList,
+):
+    """returns z_, clone it!!!"""
+    delta = Hz + tensors
+    z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
+    return z_
+
+
+class CurveBall(Module):
+    """CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.
+
+    For now this implementation does not include automatic ρ, α and β hyper-parameters in closed form, therefore it is expected to underperform compared to official implementation (https://github.com/jotaf98/pytorch-curveball/tree/master) so I moved this to experimental.
+
+    Args:
+        precond_lr (float, optional): learning rate for updating preconditioned gradients. Defaults to 1e-3.
+        momentum (float, optional): decay rate for preconditioned gradients. Defaults to 0.9.
+        hvp_method (str, optional): how to calculate hessian vector products. Defaults to "autograd".
+        h (float, optional): finite difference step size for when hvp_method is set to finite difference. Defaults to 1e-3.
+        reg (float, optional): hessian regularization. Defaults to 1.
+        inner (Chainable | None, optional): Inner modules. Defaults to None.
+    """
+    def __init__(
+        self,
+        precond_lr: float=1e-3,
+        momentum: float=0.9,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h: float = 1e-3,
+        reg: float = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(precond_lr=precond_lr, momentum=momentum, hvp_method=hvp_method, h=h, reg=reg)
+        super().__init__(defaults)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, vars):
+
+        params = vars.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        h = settings['h']
+
+        precond_lr, momentum, reg = self.get_settings('momentum', 'decay_rate', 'reg', params=params, cls=NumberList)
+
+
+        closure = vars.closure
+        assert closure is not None
+
+        z, Hz = self.get_state('z', 'Hz', params=params, cls=TensorList)
+
+        if hvp_method == 'autograd':
+            grad = vars.get_grad(create_graph=True)
+            Hvp = hvp(params, grad, z)
+
+        elif hvp_method == 'forward':
+            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=vars.get_grad(), normalize=True)
+
+        elif hvp_method == 'central':
+            loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
+
+        else:
+            raise ValueError(hvp_method)
+
+
+        Hz.set_(Hvp + z*reg)
+
+
+        update = vars.get_update()
+        if 'inner' in self.children:
+            update = apply(self.children['inner'], update, params, grads=vars.grad, vars=vars)
+
+        z = curveball(TensorList(update), z, Hz, momentum=momentum, precond_lr=precond_lr)
+        vars.update = z.neg()
+
+        return vars
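The recursion implemented by curveball() and CurveBall.step is easiest to see on a quadratic f(w) = ½ wᵀAw − bᵀw, where the Hessian-vector product is exact. The sketch below uses plain tensors instead of TensorList and fixed hyper-parameters instead of the module's settings machinery; it also assumes the convention that the returned update is subtracted from the parameters, which is why vars.update = z.neg() corresponds to the w ← w + z step here.

import torch

torch.manual_seed(0)
evals = torch.linspace(1.0, 10.0, 6)
Qm, _ = torch.linalg.qr(torch.randn(6, 6))
A = Qm @ torch.diag(evals) @ Qm.T          # SPD Hessian with a known spectrum
b = torch.randn(6)

w = torch.zeros(6)                         # parameters
z = torch.zeros(6)                         # the buffered direction, state 'z'
momentum, precond_lr, reg = 0.9, 0.05, 1.0 # rho, beta, Hessian regularization

for _ in range(200):
    grad = A @ w - b                       # gradient of the quadratic (the "tensors" input)
    Hz = A @ z + reg * z                   # Hz.set_(Hvp + z*reg) with an exact Hvp
    z = momentum * z - precond_lr * (Hz + grad)   # z <- rho*z - beta*(Hz + grad)
    w = w + z                              # applying update = -z by subtraction

print(torch.linalg.norm(A @ w - b))        # small residual: w approaches the solution of A w = b

With fixed momentum and precond_lr this behaves like a heavy-ball iteration on the regularized quadratic model, which is exactly the simplification the docstring warns about relative to the official CurveBall implementation that solves for ρ, α and β in closed form.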