torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
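Several entries in the listing are moves or renames rather than new code: `torchzero/utils/linalg/* → torchzero/linalg/*` and `torchzero/modules/functional.py → torchzero/modules/opt_utils.py`. A hedged sketch of the import-path changes these moves imply (the module paths simply mirror the files listed above; the packages' public re-exports may differ):

```python
# 0.3.15 paths                              0.4.1 paths (per the file moves above)
# torchzero.utils.linalg.qr          ->     torchzero.linalg.qr
# torchzero.utils.linalg.solve       ->     torchzero.linalg.solve
import torchzero.linalg.qr
import torchzero.linalg.solve

# modules/functional.py was renamed to modules/opt_utils.py; the msam.py diff below
# imports it as ``from ..opt_utils import ema_``
from torchzero.modules.opt_utils import ema_
```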
--- torchzero/modules/adaptive/msam.py (0.3.15)
+++ torchzero/modules/adaptive/msam.py (0.4.1)
@@ -2,9 +2,9 @@ from typing import Literal


 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Transform, TensorTransform, step, Objective
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
-from ..
+from ..opt_utils import ema_
 from ..momentum.momentum import nag_


@@ -21,7 +21,7 @@ def msam_(

     # inner args
     inner: Module | None = None,
-
+    objective: Objective | None = None,
 ):
     # weights w and wh, momentum μ, perturbation strength ρ
     # w = wh + rho * v / ||v||
@@ -54,8 +54,8 @@ def msam_(
     v1n = velocity_ / denom

     if inner is not None:
-        assert
-        inner_update = TensorList(
+        assert objective is not None and inner is not None
+        inner_update = TensorList(step(objective, inner).get_updates())

     else:
         assert lr is not None
@@ -69,7 +69,7 @@ def msam_(

     return update

-class
+class MSAMMomentum(TensorTransform):
     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

     This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
@@ -93,46 +93,40 @@ class MSAM(Transform):
         lerp (bool, optional):
             whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.

-    Examples:
-    MSAM
+    ### Examples:

-
+    MSAM

-
-        model.parameters(),
-        tz.m.MSAM(1e-3)
-    )
+    ```python

-
-
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.MSAM(1e-3)
+    )
+    ```

-
+    Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+    To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.

-
-
-
-
-
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+        tz.m.Debias(0.9, 0.999),
+    )
+    ```
     """
-
+
     def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
-        defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
-        if self._USES_LR: defaults['lr'] = lr
+        defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
-
-        lerp = s['lerp']
-        nesterov = s['nesterov']
+        fs = settings[0]

-
-            lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
-
-        else:
-            lr=None
-            momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+        lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)

         return msam_(
             TensorList(tensors),
@@ -142,16 +136,16 @@ class MSAM(Transform):
             lr=lr,
             rho=rho,
             weight_decay=weight_decay,
-            nesterov=nesterov,
-            lerp=lerp,
+            nesterov=fs['nesterov'],
+            lerp=fs['lerp'],

             # inner args
-            inner=
-
+            inner=None,
+            objective=None,
         )


-class
+class MSAM(Transform):
     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.

     Note:
@@ -160,7 +154,7 @@ class MSAMObjective(MSAM):
         to an incorrect update rule.

     Args:
-        modules (Chainable): modules that will
+        modules (Chainable): modules that will optimize the MSAM objective. Make sure ``tz.m.LR`` is one of them.
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         rho (float, optional): perturbation strength. Defaults to 0.3.
         nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
@@ -169,20 +163,44 @@ class MSAMObjective(MSAM):
             Defaults to False.

     Examples:
-
-
-
-
-
-
-
-
-
-
-
+    AdamW-MSAM
+
+    ```py
+    opt = tz.Optimizer(
+        bench.parameters(),
+        tz.m.MSAMObjective(
+            [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+            rho=1.
+        )
+    )
+    ```
     """
-    _USES_LR = False
     def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
-
+        defaults = dict(momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+        super().__init__(defaults)
+
         self.set_child('modules', modules)

+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        velocity = unpack_states(states, objective.params, 'velocity', cls=TensorList)
+        fs = settings[0]
+
+        momentum, rho, weight_decay = unpack_dicts(settings, 'momentum', 'rho', 'weight_decay', cls=NumberList)
+
+        return msam_(
+            TensorList(objective.get_updates()),
+            params=TensorList(objective.params),
+            velocity_=velocity,
+            momentum=momentum,
+            lr=None,
+            rho=rho,
+            weight_decay=weight_decay,
+            nesterov=fs['nesterov'],
+            lerp=fs['lerp'],
+
+            # inner args
+            inner=self.children["modules"],
+            objective=objective,
+        )
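The comments at the top of `msam_` state the rule the module implements: the parameters it stores are the perturbed weights `w = wh + rho * v / ||v||` of Momentum-SAM (https://arxiv.org/pdf/2401.12033). Below is a minimal single-tensor sketch of that arithmetic, for illustration only; it skips torchzero's `TensorList`/`Objective` machinery and the `weight_decay`, `nesterov` and `lerp` options, and the name `msam_step_` is made up for this example:

```python
import torch

def msam_step_(w: torch.Tensor, grad: torch.Tensor, velocity: torch.Tensor,
               lr: float, momentum: float = 0.9, rho: float = 0.3, eps: float = 1e-12):
    """Illustrative Momentum-SAM step on one tensor; ``w`` holds the perturbed weights."""
    u_old = velocity / (velocity.norm() + eps)   # perturbation direction before the update
    velocity.mul_(momentum).add_(grad)           # heavy-ball momentum buffer v
    u_new = velocity / (velocity.norm() + eps)   # perturbation direction after the update
    # unperturbed weights wh = w - rho * u_old take a momentum step, then the
    # perturbation is re-applied: w_new = (wh - lr * v_new) + rho * u_new
    w.sub_(lr * velocity + rho * (u_old - u_new))

# usage: the parameter tensor keeps the perturbed weights, with one velocity buffer each
w = torch.randn(10)
velocity = torch.zeros_like(w)
grad = torch.randn_like(w)   # stand-in for a real gradient
msam_step_(w, grad, velocity, lr=1e-3)
```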
--- torchzero/modules/adaptive/muon.py (0.3.15)
+++ torchzero/modules/adaptive/muon.py (0.4.1)
@@ -1,152 +1,85 @@
 from operator import itemgetter
 import math
-import
-from collections.abc import Iterable, Sequence
-from typing import Literal
+from collections.abc import Iterable

 import torch

-from ...core import
-from ...
-
+from ...core import TensorTransform, Transform
+from ...linalg.orthogonalize import orthogonalize as _orthogonalize, OrthogonalizeMethod

 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))

-def _is_at_least_2d(p: torch.Tensor):
-    if
+def _is_at_least_2d(p: torch.Tensor, channel_first:bool):
+    if p.ndim < 2: return False
+    if channel_first and (p.size(0) > 1) and (p.size(1) > 1): return True
+    if (not channel_first) and (p.size(-2) > 1) and (p.size(-1) > 1): return True
     return False

-
-
-
-
-
-    """
-
-
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all the way to one everywhere
-    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
-    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
-    a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.bfloat16()
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-
-    # Ensure spectral norm is at most 1
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
-    # Perform the NS iterations
-    for _ in range(steps):
-        A = X @ X.mT
-        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + B @ X
-
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-    return X
-
-# stolen from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
-# Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
-# Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
-@torch.no_grad
-def _svd_orthogonalize(G: torch.Tensor, warn_fail=True) -> torch.Tensor:
-    """
-    Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
-    """
-    X = G.view(G.shape[0], -1)
-
-    t = False
-    if X.size(0) > X.size(1):
-        X = X.T
-        t = True
-
-    orth_X: torch.Tensor | None = None
-    try:
-        u, s, vt = torch.linalg.svd(X, full_matrices=False) # pylint:disable=not-callable
-        orth_X = u @ vt
-    except RuntimeError:
-        # if warn: logging.warning('Failed to perform SVD, adding some noise.')
-        try:
-            u, s, v = torch.svd_lowrank(
-                X,
-                q=1, # assume rank is at least 1
-                M=1e-4 * X.mean() * torch.randn_like(X))
-            orth_X = u @ v.T
-        except RuntimeError:
-            if warn_fail: warnings.warn(('Failed to perform SVD with noise,'
-                                         ' skipping gradient orthogonalisation'))
-    if orth_X is not None:
-        if t: orth_X = orth_X.T
-        return orth_X.view_as(G)
-
-    return G # fail
+def _orthogonalize_format(
+    tensor: torch.Tensor,
+    method: OrthogonalizeMethod,
+    channel_first: bool,
+):
+    """orthogonalize either 1st two dims if channel first or last two otherwise"""
+    if channel_first:
+        return reverse_dims(_orthogonalize(reverse_dims(tensor), method=method))

+    return _orthogonalize(tensor, method=method)

 @torch.no_grad
-def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor,
-    """
+def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, channel_first: bool):
+    """``channel_first`` means it applies to first two dims, otherwise to last two dims"""
     # this is from https://github.com/leloykun/adaptive-muon
     # Adaptive scaling,`(G * X).sum() * X` == (G.T @ X).trace() * X
-    if
-    else: X = torch.einsum('ij
+    if channel_first: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+    else: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
     return X


 # code from
 # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
-def adjust_lr_for_muon(lr, param_shape):
-    A, B = param_shape[:2]
+def adjust_lr_for_muon(lr, param_shape, channel_first:bool):
+    if channel_first: A, B = param_shape[:2]
+    else: A, B = param_shape[-2:]
+
     # We adjust the learning rate and weight decay based on the size of the parameter matrix
     # as describted in the paper
     adjusted_ratio = 0.2 * math.sqrt(max(A, B))
     adjusted_lr = lr * adjusted_ratio
     return adjusted_lr

-def _orthogonalize_tensor(
-    tensor: torch.Tensor,
-    steps: int = 5,
-    method: Literal["newton-schulz", "svd"] = "newton-schulz",
-):
-    if method == 'newton-schulz': return reverse_dims(zeropower_via_newtonschulz5(reverse_dims(tensor), steps)).type_as(tensor)
-    if method == 'svd': return _svd_orthogonalize(tensor, False)
-    raise ValueError(method)
-

 def orthogonalize_grads_(
     params: Iterable[torch.Tensor],
-    steps: int = 5,
     dual_norm_correction=False,
-    method:
+    method: OrthogonalizeMethod = "newtonschulz",
+    channel_first:bool=True,
 ):
-    """
+    """Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.

     This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).

     Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
     Args:
         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
-        steps (int, optional):
-            The number of Newton-Schulz iterations to run. Defaults to 5.
         dual_norm_correction (bool, optional):
             enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
         method (str, optional):
             Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
+        channel_first (bool, optional):
+            if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+            are considered batch dimensions.
     """
     for p in params:
-        if (p.grad is not None) and _is_at_least_2d(p.grad):
-            X =
-            if dual_norm_correction: X = _dual_norm_correction(X, p.grad,
+        if (p.grad is not None) and _is_at_least_2d(p.grad, channel_first=channel_first):
+            X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
+            if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
             p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]



-class Orthogonalize(
+class Orthogonalize(TensorTransform):
     """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.

     To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
@@ -156,22 +89,21 @@ class Orthogonalize(TensorwiseTransform):
     To make Muon, use Split with Adam on 1d params

     Args:
-        ns_steps (int, optional):
-            The number of Newton-Schulz iterations to run. Defaults to 5.
         adjust_lr (bool, optional):
             Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
         dual_norm_correction (bool, optional):
             enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
         method (str, optional):
-            Newton-Schulz is very fast, SVD is
-
-
+            Newton-Schulz is very fast, SVD is slow but can be more precise.
+        channel_first (bool, optional):
+            if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+            are considered batch dimensions.

     ## Examples:

     standard Muon with Adam fallback
     ```py
-    opt = tz.
+    opt = tz.Optimizer(
         model.head.parameters(),
         tz.m.Split(
             # apply muon only to 2D+ parameters
@@ -190,56 +122,62 @@ class Orthogonalize(TensorwiseTransform):
     Reference:
         Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
     """
-    def __init__(self,
-        method:
-        defaults = dict(orthogonalize=True,
-        super().__init__(
+    def __init__(self, adjust_lr=False, dual_norm_correction=False,
+                 method: OrthogonalizeMethod = 'newtonschulz', channel_first:bool=True):
+        defaults = dict(orthogonalize=True, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower(), channel_first=channel_first)
+        super().__init__(defaults=defaults)

     @torch.no_grad
-    def
-        orthogonalize,
-        'orthogonalize', '
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        orthogonalize, dual_norm_correction, adjust_lr, method, channel_first = itemgetter(
+            'orthogonalize', 'dual_norm_correction', 'adjust_lr', 'method', 'channel_first')(setting)

         if not orthogonalize: return tensor

-        if _is_at_least_2d(tensor):
+        if _is_at_least_2d(tensor, channel_first=channel_first):

-            X =
+            X = _orthogonalize_format(tensor, method, channel_first=channel_first)

             if dual_norm_correction:
-                X = _dual_norm_correction(X, tensor,
+                X = _dual_norm_correction(X, tensor, channel_first=channel_first)

             if adjust_lr:
-                X.mul_(adjust_lr_for_muon(1, param.shape))
+                X.mul_(adjust_lr_for_muon(1, param.shape, channel_first=channel_first))

             return X.view_as(param)

         return tensor


-class DualNormCorrection(
+class DualNormCorrection(TensorTransform):
     """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
     Orthogonalize already has this built in with the `dual_norm_correction` setting."""
-    def __init__(self,
-
+    def __init__(self, channel_first: bool = True):
+        defaults = dict(channel_first=channel_first)
+        super().__init__(defaults)

-
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         assert grad is not None
         if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
-            return _dual_norm_correction(tensor, grad,
+            return _dual_norm_correction(tensor, grad, channel_first=setting["channel_first"])
         return tensor


 class MuonAdjustLR(Transform):
     """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
-    Orthogonalize already has this built in with the
-    def __init__(self,
-        defaults = dict(alpha=alpha)
-        super().__init__(defaults=defaults
+    Orthogonalize already has this built in with the ``adjust_lr`` setting, however you might want to move this to be later in the chain."""
+    def __init__(self, channel_first: bool = True, alpha: float = 1):
+        defaults = dict(channel_first=channel_first, alpha=alpha)
+        super().__init__(defaults=defaults)

-
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         alphas = [s['alpha'] for s in settings]
-
+        channel_first = [s["channel_first=channel_first"] for s in settings]
+        tensors_alphas = [
+            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t, channel_first=cf)
+        ]
         tensors = [i[0] for i in tensors_alphas]
         a = [i[1] for i in alphas]
         torch._foreach_mul_(tensors, a)
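The `adjust_lr_for_muon` helper in the diff reduces to scaling the step size by `0.2 * sqrt(max(A, B))`, where `A, B` are the two orthogonalized dimensions (the first two when `channel_first`, the last two otherwise). A standalone restatement with a worked number, for illustration:

```python
import math

def adjust_lr_for_muon(lr, param_shape, channel_first: bool):
    # restatement of the helper shown in the diff above
    A, B = param_shape[:2] if channel_first else param_shape[-2:]
    return lr * 0.2 * math.sqrt(max(A, B))

# a (4096, 1024) weight with channel_first=True: 0.2 * sqrt(4096) = 12.8,
# so a base lr of 1e-3 becomes roughly 1.28e-2
print(adjust_lr_for_muon(1e-3, (4096, 1024), channel_first=True))  # 0.0128...
```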
|