torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/optimizers/muon.py
@@ -0,0 +1,222 @@
+from operator import itemgetter
+import math
+import warnings
+from collections.abc import Iterable, Sequence
+from typing import Literal
+
+import torch
+
+from ...core import Modular, TensorwiseTransform, Target, Transform
+from ...utils import enable_compilation
+
+
+def reverse_dims(t: torch.Tensor):
+    return t.permute(*reversed(range(t.ndim)))
+
+def _is_at_least_2d(p: torch.Tensor):
+    if (p.ndim >= 2) and (p.size(0) > 1) and (p.size(1) > 1): return True
+    return False
+
+# stolen from:
+# https://github.com/KellerJordan/Muon/blob/master/muon.py
+@enable_compilation
+def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
+    """
+    Applies to last 2 dims - so usually reverse_dims should be applied to G before and after.
+
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.mT
+        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+# stolen from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
+# Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
+# Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
+@torch.no_grad
+def _svd_orthogonalize(G: torch.Tensor, warn_fail=True) -> torch.Tensor:
+    """
+    Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
+    """
+    X = G.view(G.shape[0], -1)
+
+    t = False
+    if X.size(0) > X.size(1):
+        X = X.T
+        t = True
+
+    orth_X: torch.Tensor | None = None
+    try:
+        u, s, vt = torch.linalg.svd(X, full_matrices=False)  # pylint:disable=not-callable
+        orth_X = u @ vt
+    except RuntimeError:
+        # if warn: logging.warning('Failed to perform SVD, adding some noise.')
+        try:
+            u, s, v = torch.svd_lowrank(
+                X,
+                q=1,  # assume rank is at least 1
+                M=1e-4 * X.mean() * torch.randn_like(X))
+            orth_X = u @ v.T
+        except RuntimeError:
+            if warn_fail: warnings.warn(('Failed to perform SVD with noise,'
+                                         ' skipping gradient orthogonalisation'))
+    if orth_X is not None:
+        if t: orth_X = orth_X.T
+        return orth_X.view_as(G)
+
+    return G  # fail
+
+
+@torch.no_grad
+def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, batch_first):
+    """batch_first means it applies to last 2 dims, otherwise to first 2 dims"""
+    # this is from https://github.com/leloykun/adaptive-muon
+    # Adaptive scaling, `(G * X).sum() * X` == (G.T @ X).trace() * X
+    if batch_first: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
+    else: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+    return X
+
+
+# code from
+# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+def adjust_lr_for_muon(lr, param_shape):
+    A, B = param_shape[:2]
+    # We adjust the learning rate and weight decay based on the size of the parameter matrix
+    # as described in the paper
+    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+    adjusted_lr = lr * adjusted_ratio
+    return adjusted_lr
+
+def _orthogonalize_tensor(
+    tensor: torch.Tensor,
+    steps: int = 5,
+    method: Literal["newton-schulz", "svd"] = "newton-schulz",
+):
+    if method == 'newton-schulz': return reverse_dims(zeropower_via_newtonschulz5(reverse_dims(tensor), steps)).type_as(tensor)
+    if method == 'svd': return _svd_orthogonalize(tensor, False)
+    raise ValueError(method)
+
+
+def orthogonalize_grads_(
+    params: Iterable[torch.Tensor],
+    steps: int = 5,
+    dual_norm_correction=False,
+    method: Literal["newton-schulz", "svd"] = "newton-schulz",
+):
+    """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.
+
+    This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).
+
+    Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
+
+    Args:
+        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
+        steps (int, optional):
+            The number of Newton-Schulz iterations to run. Defaults to 5.
+        dual_norm_correction (bool, optional):
+            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
+        method (str, optional):
+            Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
+    """
+    for p in params:
+        if (p.grad is not None) and _is_at_least_2d(p.grad):
+            X = _orthogonalize_tensor(p.grad, steps, method)
+            if dual_norm_correction: X = _dual_norm_correction(X, p.grad, batch_first=False)
+            p.grad.set_(X.view_as(p))  # pyright:ignore[reportArgumentType]
+
+
+class Orthogonalize(TensorwiseTransform):
+    """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of the update along its first 2 dims.
+
+    To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
+    The Muon page says that embeddings and classifier heads should not be orthogonalized.
+    Usually only matrix parameters that are directly used in matmuls should be orthogonalized.
+
+    To make Muon, use Split with Adam on 1d params: TODO code example.
+
+    Args:
+        ns_steps (int, optional):
+            The number of Newton-Schulz iterations to run. Defaults to 5.
+        adjust_lr (bool, optional):
+            Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
+        dual_norm_correction (bool, optional):
+            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
+        method (str, optional):
+            Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
+        target (str, optional):
+            what to set on vars.
+    """
+    def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
+                 method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target: Target = 'update'):
+        defaults = dict(orthogonalize=True, ns_steps=ns_steps, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower())
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
+            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(self.settings[param])
+
+        if not orthogonalize: return tensor
+
+        if _is_at_least_2d(tensor):
+            X = _orthogonalize_tensor(tensor, ns_steps, method)
+
+            if dual_norm_correction:
+                X = _dual_norm_correction(X, tensor, batch_first=False)
+
+            if adjust_lr:
+                X.mul_(adjust_lr_for_muon(1, param.shape))
+
+            return X.view_as(param)
+
+        return tensor
+
+
+class DualNormCorrection(TensorwiseTransform):
+    """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
+    Orthogonalize already has this built in with the `dual_norm_correction` setting."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    def transform(self, tensor, param, grad, vars):
+        assert grad is not None
+        if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
+            return _dual_norm_correction(tensor, grad, batch_first=False)
+        return tensor
+
+
+class MuonAdjustLR(Transform):
+    """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
+    Orthogonalize already has this built in with the `adjust_lr` setting, however you might want to move this to be later in the chain."""
+    def __init__(self, alpha: float = 1, target: Target = 'update'):
+        defaults = dict(alpha=alpha)
+        super().__init__(defaults=defaults, uses_grad=False, target=target)
+
+    def transform(self, tensors, params, grads, vars):
+        alphas = self.get_settings('alpha', params=params)
+        tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
+        tensors = [i[0] for i in tensors_alphas]
+        a = [i[1] for i in tensors_alphas]
+        torch._foreach_mul_(tensors, a)
+        return tensors
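The functional entry point above, orthogonalize_grads_, mutates .grad in place and skips parameters that are not at least 2-dimensional, so it can be dropped into an ordinary PyTorch training loop before the optimizer step. A minimal sketch, assuming the import path from the file listing above; the model, data, and SGD settings are placeholders, not part of the package:

    import torch
    from torchzero.modules.optimizers.muon import orthogonalize_grads_

    model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
    opt = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

    x = torch.randn(16, 32)                # placeholder batch
    y = torch.randint(0, 10, (16,))
    for _ in range(10):
        opt.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        # orthogonalize matrix gradients in place; 1-d params such as biases are left untouched
        orthogonalize_grads_(model.parameters(), steps=5)
        opt.step()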
torchzero/modules/optimizers/orthograd.py
@@ -0,0 +1,55 @@
+from operator import itemgetter
+import math
+import warnings
+from collections.abc import Iterable, Sequence
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import as_tensorlist
+
+def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
+    """Applies ⟂Grad - projects gradients of an iterable of parameters to be orthogonal to the weights.
+
+    Args:
+        params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
+        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
+
+    Reference:
+        https://arxiv.org/abs/2501.04697
+    """
+    params = as_tensorlist(params).with_grad()
+    grad = params.grad
+    grad -= (params.dot(grad) / (params.dot(params) + eps)) * params
+
+
+class OrthoGrad(Transform):
+    """Applies ⟂Grad - projects gradients of an iterable of parameters to be orthogonal to the weights.
+
+    Args:
+        eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-8)
+        renormalize (bool, optional): whether to graft the projected gradient to the original gradient norm. Defaults to True.
+        target (Target, optional): what to set on vars. Defaults to 'update'.
+    """
+    def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
+        defaults = dict(eps=eps, renormalize=renormalize)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    def transform(self, tensors, params, grads, vars):
+        settings = self.settings[params[0]]
+        eps = settings['eps']
+        renormalize = settings['renormalize']
+
+        params = as_tensorlist(params)
+        target = as_tensorlist(tensors)
+
+        scale = params.dot(target) / (params.dot(params) + eps)
+        if renormalize:
+            norm = target.global_vector_norm()
+            target -= params * scale
+            target *= (norm / target.global_vector_norm())
+            return target
+
+        target -= params * scale
+        return target
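OrthoGrad treats the parameters and the update as one flattened vector: it subtracts the component of the update that lies along the weight vector and, with renormalize=True, grafts the result back to the original norm. A plain-PyTorch sketch of that projection on a single tensor, for illustration only; the module itself computes the dot products across the whole parameter list:

    import torch

    def orthograd_single(w: torch.Tensor, g: torch.Tensor, eps: float = 1e-30, renormalize: bool = True) -> torch.Tensor:
        # g_perp = g - (<w, g> / (<w, w> + eps)) * w
        scale = torch.dot(w.flatten(), g.flatten()) / (torch.dot(w.flatten(), w.flatten()) + eps)
        g_perp = g - scale * w
        if renormalize:
            # rescale so the projected gradient keeps the original gradient norm
            g_perp = g_perp * (g.norm() / g_perp.norm())
        return g_perp

    w, g = torch.randn(10, 10), torch.randn(10, 10)
    g_perp = orthograd_single(w, g)
    print(torch.dot(w.flatten(), g_perp.flatten()))  # close to zero: the update is orthogonal to the weights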
torchzero/modules/optimizers/rmsprop.py
@@ -1,51 +1,103 @@
[the 51 removed lines of the previous implementation are not preserved in this extraction]
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Module, Target, Transform, Chainable, Vars, apply
+from ...utils import NumberList, TensorList
+from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_
+
+
+def rmsprop_(
+    tensors_: TensorList,
+    exp_avg_sq_: TensorList,
+    smoothing: float | NumberList,
+    eps: float | NumberList,
+    debiased: bool,
+    step: int,
+    exp_avg_: TensorList | None = None,
+    max_exp_avg_sq_: TensorList | None = None,
+    pow: float = 2,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+    vars: Vars | None = None,
+):
+    """returns `tensors_`"""
+    if exp_avg_ is not None:
+        sqrt_exp_avg_sq = sqrt_centered_ema_sq_(tensors=tensors_, exp_avg_=exp_avg_,
+                                                exp_avg_sq_=exp_avg_sq_, max_exp_avg_sq_=max_exp_avg_sq_,
+                                                beta=smoothing, debiased=debiased, step=step, pow=pow)
+    else:
+        sqrt_exp_avg_sq = sqrt_ema_sq_(tensors=tensors_, exp_avg_sq_=exp_avg_sq_, max_exp_avg_sq_=max_exp_avg_sq_,
+                                       beta=smoothing, debiased=debiased, step=step, pow=pow)
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+
+    return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
+
+class RMSprop(Transform):
+    """Divides gradient by EMA of gradient squares. Matches pytorch RMSprop if "init" is set to "zeros".
+
+    Args:
+        smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
+        eps (float, optional): epsilon for division. Defaults to 1e-8.
+        centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
+        debiased (bool, optional): applies Adam debiasing. Defaults to False.
+        amsgrad (bool, optional): whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+        pow (float, optional): power used in second momentum power and root. Defaults to 2.
+        init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
+        inner (Chainable | None, optional): inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        smoothing: float = 0.99,
+        eps: float = 1e-8,
+        centered: bool = False,
+        debiased: bool = False,
+        amsgrad: bool = False,
+        pow: float = 2,
+        init: Literal["zeros", "update"] = "update",
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(smoothing=smoothing, eps=eps, centered=centered, debiased=debiased, amsgrad=amsgrad, pow=pow, init=init)
+        super().__init__(defaults=defaults, uses_grad=False)
+        self.current_step = 0
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    def transform(self, tensors, params, grads, vars):
+        self.current_step += 1
+
+        smoothing, eps = self.get_settings('smoothing', 'eps', params=params, cls=NumberList)
+        centered, debiased, amsgrad, pow, init = itemgetter('centered', 'debiased', 'amsgrad', 'pow', 'init')(self.settings[params[0]])
+
+        exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+        exp_avg = self.get_state('exp_avg', params=params, cls=TensorList) if centered else None
+        max_exp_avg_sq = self.get_state('max_exp_avg_sq', params=params, cls=TensorList) if amsgrad else None
+
+        if init == 'update' and self.current_step == 1:
+            exp_avg_sq.set_([t**2 for t in tensors])
+            if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])
+
+        return rmsprop_(
+            TensorList(tensors),
+            exp_avg_sq_=exp_avg_sq,
+            smoothing=smoothing,
+            eps=eps,
+            debiased=debiased,
+            step=self.current_step,
+            exp_avg_=exp_avg,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            pow=pow,
+
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+            vars=vars,
+        )
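The helpers sqrt_ema_sq_ and sqrt_centered_ema_sq_ live in torchzero/modules/functional.py and are not part of this hunk; on the default path (centered=False, amsgrad=False, pow=2) the module reduces to the usual RMSprop recursion v <- beta * v + (1 - beta) * g^2, update = g / (sqrt(v) + eps). A standalone plain-PyTorch sketch of that math, for illustration only, not the package's API:

    import torch

    def rmsprop_step(g: torch.Tensor, exp_avg_sq: torch.Tensor, step: int,
                     smoothing: float = 0.99, eps: float = 1e-8, debiased: bool = False) -> torch.Tensor:
        # EMA of squared gradients: v <- smoothing * v + (1 - smoothing) * g**2
        exp_avg_sq.mul_(smoothing).addcmul_(g, g, value=1 - smoothing)
        denom = exp_avg_sq
        if debiased:
            # Adam-style bias correction of the second moment
            denom = exp_avg_sq / (1 - smoothing ** step)
        return g / (denom.sqrt() + eps)

    g = torch.randn(5)
    v = torch.zeros(5)  # "zeros" init matches torch.optim.RMSprop; the torchzero module defaults to init="update"
    print(rmsprop_step(g, v, step=1))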