torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -3,8 +3,10 @@ import warnings
 
 import torch
 
-from ...core import
+from ...core import TensorTransform, Chainable
+from ...utils import unpack_dicts, unpack_states, TensorList, NumberList
 from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
+from ...linalg import torch_linalg
 
 @torch.no_grad
 def update_soap_covariances_(
@@ -20,52 +22,48 @@ def update_soap_covariances_(
     else: GG.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
 
 @torch.no_grad
-def project(
+def project(tensor: torch.Tensor, Q: list[torch.Tensor | None]):
     """
     Projects the gradient to the eigenbases of the preconditioner.
     """
-    for
-        if
-
+    for M in Q:
+        if M is not None:
+            tensor = torch.tensordot(tensor, M, dims=[[0], [0]]) # pyright:ignore[reportArgumentType]
         else:
-            permute_order = list(range(1, len(
-
+            permute_order = list(range(1, len(tensor.shape))) + [0]
+            tensor = tensor.permute(permute_order)
 
-    return
+    return tensor
 
 @torch.no_grad
-def project_back(
+def project_back(tensor: torch.Tensor, Q: list[torch.Tensor| None]):
     """
     Projects the gradient back to the original space.
     """
-    for
-        if
-
+    for M in Q:
+        if M is not None:
+            tensor = torch.tensordot(tensor, M, dims=[[0], [1]]) # pyright:ignore[reportArgumentType]
         else:
-            permute_order = list(range(1, len(
-
+            permute_order = list(range(1, len(tensor.shape))) + [0]
+            tensor = tensor.permute(permute_order)
 
-    return
+    return tensor
 
 # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py
 @torch.no_grad
-def get_orthogonal_matrix(
+def get_orthogonal_matrix(mats: list[torch.Tensor | None]):
     """
     Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
     """
 
     final = []
-    for
+    for M in mats:
 
-        if
-            final.append(
+        if M is None:
+            final.append(None)
             continue
 
-
-        _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-        except torch.linalg.LinAlgError:
-            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
-            Q = Q.to(m.dtype)
+        _, Q = torch_linalg.eigh(M + 1e-30 * torch.eye(M.shape[0], device=M.device), retry_float64=True)
 
         Q = torch.flip(Q, [1])
         final.append(Q)
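
Note on the hunk above: the rewritten `project` / `project_back` helpers walk the dims of a tensor, contracting each preconditioned dim against its eigenbasis matrix (`dims=[[0],[0]]` on the way in, `[[0],[1]]` on the way back) and cycling any skipped dim to the end with a permute, so one pass over `Q` visits every dim exactly once. A minimal plain-torch sketch (outside torchzero, illustrative names only) showing that the round trip is the identity when the `Q` factors are orthogonal:

```python
import torch

def project(tensor, Q):
    # rotate each dim into its eigenbasis; None means "leave this dim alone"
    for M in Q:
        if M is not None:
            tensor = torch.tensordot(tensor, M, dims=([0], [0]))
        else:
            tensor = tensor.permute(list(range(1, tensor.ndim)) + [0])
    return tensor

def project_back(tensor, Q):
    # inverse rotation; with orthonormal columns Q @ Q.T acts as the identity
    for M in Q:
        if M is not None:
            tensor = torch.tensordot(tensor, M, dims=([0], [1]))
        else:
            tensor = tensor.permute(list(range(1, tensor.ndim)) + [0])
    return tensor

G = torch.randn(6, 4)
Q = [torch.linalg.qr(torch.randn(6, 6)).Q, None]  # orthogonal basis for dim 0, dim 1 skipped
assert torch.allclose(project_back(project(G, Q), Q), G, atol=1e-4)
```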
@@ -78,30 +76,33 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
     """
     Computes the eigenbases of the preconditioner using one round of power iteration
     followed by torch.linalg.qr decomposition.
-
+
+    Approximately modifies ``exp_avg_sq`` to be in the new eigenbases.
+    """
     final = []
 
-    for ind, (
+    for ind, (M, O) in enumerate(zip(GG, Q_list)):
 
         # skip 1d or large dims
-        if
-            final.append(
+        if M is None:
+            final.append(None)
             continue
-        assert o is not None
 
-
+        assert O is not None
+
+        est_eig = torch.diagonal(O.T @ M @ O)
         sort_idx = torch.argsort(est_eig, descending=True)
         exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
 
-        power_iter =
-        Q, _ =
+        power_iter = M @ O[:, sort_idx]
+        Q, _ = torch_linalg.qr(power_iter.to(torch.float32), retry_float64=True)
         Q = Q.to(power_iter.dtype)
 
         final.append(Q)
 
     return final, exp_avg_sq
 
-class SOAP(
+class SOAP(TensorTransform):
     """SOAP (ShampoO with Adam in the Preconditioner's eigenbasis from https://arxiv.org/abs/2409.11321).
 
     Args:
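
Note on the hunk above: `get_orthogonal_matrix_QR` now estimates eigenvalues with Rayleigh quotients, sorts the basis, takes one block power-iteration step, and re-orthogonalizes with QR (through the new `torch_linalg.qr(..., retry_float64=True)` wrapper). A rough plain-torch sketch of that refresh for a single accumulator, using `torch.linalg.qr` in place of the torchzero wrapper; `refresh_eigenbasis` is an illustrative name, and the full version also reorders `exp_avg_sq`:

```python
import torch

def refresh_eigenbasis(M: torch.Tensor, O: torch.Tensor) -> torch.Tensor:
    est_eig = torch.diagonal(O.T @ M @ O)               # Rayleigh-quotient eigenvalue estimates
    sort_idx = torch.argsort(est_eig, descending=True)  # keep largest directions first
    power_iter = M @ O[:, sort_idx]                      # one block power-iteration step
    Q, _ = torch.linalg.qr(power_iter)                   # re-orthogonalize
    return Q

# toy check: repeated refreshes align the basis with the eigenvectors of a PSD matrix,
# so the off-diagonal mass of O.T @ M @ O shrinks
A = torch.randn(8, 8)
M = A @ A.T
O = torch.linalg.qr(torch.randn(8, 8)).Q
for _ in range(50):
    O = refresh_eigenbasis(M, O)
D = O.T @ M @ O
print((D - torch.diag(torch.diagonal(D))).abs().max())
```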
@@ -111,35 +112,42 @@ class SOAP(Transform):
             beta for covariance matrices accumulators. Can be None, then it just sums them like Adagrad (which works worse). Defaults to 0.95.
         precond_freq (int, optional): How often to update the preconditioner. Defaults to 10.
         merge_small (bool, optional): Whether to merge small dims. Defaults to True.
-        max_dim (int, optional): Won't precondition dims larger than this. Defaults to
+        max_dim (int, optional): Won't precondition dims larger than this. Defaults to 10_000.
         precondition_1d (bool, optional):
             Whether to precondition 1d params (SOAP paper sets this to False). Defaults to True.
         eps (float, optional):
             epsilon for dividing first momentum by second. Defaults to 1e-8.
-
-
+        debias (bool, optional):
+            enables adam bias correction. Defaults to True.
+        proj_exp_avg (bool, optional):
+            if True, maintains exponential average of gradients (momentum) in projected space.
+            If False - in original space Defaults to True.
         alpha (float, optional):
             learning rate. Defaults to 1.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        inner (Chainable | None, optional):
+            output of this module is projected and Adam will run on it, but preconditioners are updated
+            from original gradients.
+
+    ### Examples:
+    SOAP:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SOAP(),
+        tz.m.LR(1e-3)
+    )
+    ```
+    Stabilized SOAP:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SOAP(),
+        tz.m.NormalizeByEMA(max_ema_growth=1.2),
+        tz.m.LR(1e-2)
+    )
+    ```
     """
     def __init__(
         self,
@@ -148,118 +156,174 @@ class SOAP(Transform):
         shampoo_beta: float | None = 0.95,
         precond_freq: int = 10,
         merge_small: bool = True,
-        max_dim: int =
+        max_dim: int = 10_000,
         precondition_1d: bool = True,
         eps: float = 1e-8,
-
+        debias: bool = True,
+        proj_exp_avg: bool = True,
         alpha: float = 1,
-
+
+        inner: Chainable | None = None,
     ):
-        defaults =
-
-
-
-
-            merge_small=merge_small,
-            max_dim=max_dim,
-            precondition_1d=precondition_1d,
-            eps=eps,
-            decay=decay,
-            bias_correction=bias_correction,
-            alpha=alpha,
-        )
-        super().__init__(defaults, uses_grad=False)
+        defaults = locals().copy()
+        del defaults['self'], defaults["inner"]
+
+        super().__init__(defaults)
+        self.set_child("inner", inner)
 
     @torch.no_grad
-    def
-
-
-
-
-
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        if setting["merge_small"]:
+            tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+        state["exp_avg_proj"] = torch.zeros_like(tensor)
+        state["exp_avg_sq_proj"] = torch.zeros_like(tensor)
 
-
-
+        if tensor.ndim <= 1 and not setting["precondition_1d"]:
+            state['GG'] = []
 
-
-
-
-
+        else:
+            max_dim = setting["max_dim"]
+            state['GG'] = [
+                torch.zeros(s, s, dtype=tensor.dtype, device=tensor.device) if 1<s<max_dim else None for s in tensor.shape
+            ]
 
-
-
+        # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
+        if len([i is not None for i in state['GG']]) == 0:
+            state['GG'] = None
 
-
-
+        # first covariance accumulation
+        if state['GG'] is not None:
+            update_soap_covariances_(tensor, GGs_=state['GG'], beta=setting["shampoo_beta"])
 
-
-
-
+            # get projection matrix with first gradients with eigh
+            try: state['Q'] = get_orthogonal_matrix(state['GG'])
+            except torch.linalg.LinAlgError as e:
+                warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
+                state["GG"] = None
+
+        state['step'] = 0
+
+
+        # no update to avoid running merge_dims twice
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        # note
+        # do not modify tensors in-place
+        # because they are used to update preconditioner at the end
+
+        steps = [s["step"] for s in states]
+        if any(s == 0 for s in steps):
+            # skip 1st update so to avoid using current gradient in the projection
+            # I scale it instead to avoid issues with further modules
+            for s in states: s["step"] += 1
+            return TensorList(tensors).clamp(-0.1, 0.1)
+            # return TensorList(tensors).zero_()
+
+
+        fs = settings[0]
+        merged = []
+        projected = []
+        # ---------------------------------- project --------------------------------- #
+
+        for tensor, state, setting in zip(tensors, states, settings):
+            if setting["merge_small"]:
+                tensor, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(tensor, setting["max_dim"])
+
+            merged.append(tensor)
 
-            if state['GG'] is not None:
-                update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
-                try: state['Q'] = get_orthogonal_matrix(state['GG'])
-                except torch.linalg.LinAlgError as e:
-                    warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
-                    state["GG"] = None
-
-                state['step'] = 0
-                updates.append(tensors[i].clip(-0.1, 0.1))
-                continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                # I use scaled update instead as to not mess up with next modules.
-
-            # Projecting gradients to the eigenbases of Shampoo's preconditioner
-            # i.e. projecting to the eigenbases of matrices in state['GG']
-            t_projected = None
             if state['GG'] is not None:
-
+                tensor = project(tensor, state['Q'])
 
-
-            # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-            exp_avg: torch.Tensor = state["exp_avg"]
-            exp_avg_sq_projected: torch.Tensor = state["exp_avg_sq_projected"]
+            projected.append(tensor)
 
-
+        # ------------------------ run adam in projected space ----------------------- #
+        exp_avg_proj, exp_avg_sq_proj = unpack_states(states, tensors, "exp_avg_proj", "exp_avg_sq_proj", must_exist=True, cls=TensorList)
+        alpha, beta1, beta2, eps = unpack_dicts(settings, "alpha", "beta1", "beta2", "eps", cls=NumberList)
 
-
-
-
-            exp_avg_sq_projected.mul_(beta2).addcmul_(t_projected, t_projected, value=1-beta2)
+        # lerp exp_avg in projected space
+        if fs["proj_exp_avg"]:
+            exp_avg_proj.lerp_(projected, weight=1-beta1)
 
-
-
-
-
+        # or lerp in original space and project
+        else:
+            exp_avg = exp_avg_proj
+            exp_avg.lerp_(merged, weight=1-beta1)
+            exp_avg_proj = []
+            for t, state, setting in zip(exp_avg, states, settings):
+                if state['GG'] is not None:
+                    t = project(t, state["Q"])
+                exp_avg_proj.append(t)
 
-
-            # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
+        exp_avg_sq_proj.mul_(beta2).addcmul_(projected, projected, value=1-beta2)
 
-
-
-            update = exp_avg_projected / denom
+        denom = exp_avg_sq_proj.sqrt().add_(eps)
+        dirs_proj = exp_avg_proj / denom
 
-
-
+        # ------------------------------- project back ------------------------------- #
+        dirs: list[torch.Tensor] = []
+        for dir, state, setting in zip(dirs_proj, states, settings):
+            if state['GG'] is not None:
+                dir = project_back(dir, state['Q'])
 
-            if setting[
-
-                bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-            elif alpha is not None:
-                update *= alpha
+            if setting["merge_small"]:
+                dir = _unmerge_small_dims(dir, state['flat_sizes'], state['sort_idxs'])
 
-
-                update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
+            dirs.append(dir)
 
-            updates.append(update)
-            state["step"] += 1
 
-
+        # -------------------------------- inner step -------------------------------- #
+        if "inner" in self.children:
+            tensors = self.inner_step_tensors("inner", tensors, clone=False,
+                                              params=params, grads=grads,loss=loss)
+
+            # we now have to re-merge small dims on updated tensors
+            merged = []
+            for tensor, state, setting in zip(tensors, states, settings):
+                if setting["merge_small"]:
+                    tensor, _, _ = _merge_small_dims(tensor, setting["max_dim"])
+                merged.append(tensor)
+
+        # -------------------------- update preconditioners -------------------------- #
+        # Update is done after the gradient step to avoid using current gradients in the projection.
+
+        for tensor, state, setting in zip(merged, states, settings):
             if state['GG'] is not None:
-
-
+
+                # lerp covariances
+                update_soap_covariances_(tensor, state['GG'], beta=setting["shampoo_beta"])
+
+                # (state['step'] - 1) since we start updating on 2nd step
+                if (state['step'] - 1) % setting['precond_freq'] == 0:
+
+                    # unproject exp_avg before updating if it is maintained projected
+                    exp_avg = None
+                    if fs["proj_exp_avg"]:
+                        exp_avg = project_back(state["exp_avg_proj"], state["Q"])
+
+                    # update projection matrix and exp_avg_sq_proj
                     try:
-                        state['Q'], state['
+                        state['Q'], state['exp_avg_sq_proj'] = get_orthogonal_matrix_QR(
+                            state["exp_avg_sq_proj"], state['GG'], state['Q'])
+
+                        # re-project exp_avg if it is maintained projected
+                        if fs["proj_exp_avg"]:
+                            assert exp_avg is not None
+                            state["exp_avg_proj"] = project(exp_avg, state["Q"])
+
                     except torch.linalg.LinAlgError:
                         pass
-
+
+            state["step"] += 1
+
+
+        # ------------------------- bias-corrected step size ------------------------- #
+        if fs["debias"]:
+            steps1 = [s+1 for s in steps]
+            bias_correction1 = 1.0 - beta1 ** steps1
+            bias_correction2 = 1.0 - beta2 ** steps1
+            alpha = alpha * (bias_correction2 ** .5) / bias_correction1
+
+        torch._foreach_mul_(dirs, alpha)
+        return dirs
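
Taken together, the new `multi_tensor_apply` projects each gradient into the Shampoo eigenbases, runs an Adam update there on `exp_avg_proj` / `exp_avg_sq_proj`, projects the resulting direction back, and scales it by the bias-correction factor when `debias=True`. A single-tensor plain-torch sketch of that direction computation, assuming a fixed 2-D eigenbasis and illustrative names (not torchzero API):

```python
import torch

def soap_direction(grad, Q_left, Q_right, state, beta1=0.95, beta2=0.95, eps=1e-8, step=1):
    g_proj = Q_left.T @ grad @ Q_right                           # project into the eigenbasis
    state["exp_avg_proj"].lerp_(g_proj, 1 - beta1)               # Adam first moment (projected)
    state["exp_avg_sq_proj"].mul_(beta2).addcmul_(g_proj, g_proj, value=1 - beta2)
    denom = state["exp_avg_sq_proj"].sqrt().add_(eps)
    direction = Q_left @ (state["exp_avg_proj"] / denom) @ Q_right.T   # project back
    bc1, bc2 = 1 - beta1 ** step, 1 - beta2 ** step
    return direction * (bc2 ** 0.5 / bc1)                        # debiased, as with debias=True

grad = torch.randn(6, 4)
Q_left = torch.linalg.qr(torch.randn(6, 6)).Q
Q_right = torch.linalg.qr(torch.randn(4, 4)).Q
state = {"exp_avg_proj": torch.zeros(6, 4), "exp_avg_sq_proj": torch.zeros(6, 4)}
step_dir = soap_direction(grad, Q_left, Q_right, state)
```

In the module itself the eigenbases are refreshed every `precond_freq` steps via `get_orthogonal_matrix_QR` after the direction is computed, and the very first step only returns a clamped gradient so the current gradient never enters its own projection.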