torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/matrix_momentum.py

@@ -1,14 +1,13 @@
 from typing import Literal
-
+
 import torch
 
-from ...core import
-from ...utils import NumberList, TensorList,
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+from ...core import Chainable, Transform, HVPMethod
+from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
 from ..functional import initial_step_size
 
 
-class MatrixMomentum(
+class MatrixMomentum(Transform):
     """Second order momentum method.
 
     Matrix momentum is useful for convex objectives, also for some reason it has very really good generalization on elastic net logistic regression.
@@ -23,17 +22,17 @@ class MatrixMomentum(Module):
     Args:
         mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
         hvp_method (str, optional):
-            Determines how
-
-            - ``"
-
-            - ``"
-
-
-
-
-
-
+            Determines how hessian-vector products are computed.
+
+            - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+            - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+            - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+            - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+            Defaults to ``"autograd"``.
+        h (float, optional):
+            The step size for finite difference if ``hvp_method`` is
+            ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
 
     Reference:
@@ -44,51 +43,45 @@ class MatrixMomentum(Module):
         self,
         lr:float,
         mu=0.1,
-        hvp_method:
+        hvp_method: HVPMethod = "autograd",
         h: float = 1e-3,
         adaptive:bool = False,
         adapt_freq: int | None = None,
-
+
+        inner: Chainable | None = None,
     ):
         defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
-        super().__init__(defaults)
-
-        if hvp_tfm is not None:
-            self.set_child('hvp_tfm', hvp_tfm)
+        super().__init__(defaults, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
         self.clear_state_keys('p_prev')
 
     @torch.no_grad
-    def
-
-        p = TensorList(
-        p_prev =
+    def update_states(self, objective, states, settings):
+        step = self.increment_counter("step", 0)
+        p = TensorList(objective.params)
+        p_prev = unpack_states(states, p, 'p_prev', init=p)
 
-
-
-
-        self.global_state["step"] = step + 1
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
         if step > 0:
             s = p - p_prev
 
-            Hs, _ =
+            Hs, _ = objective.hessian_vector_product(s, at_x0=True, rgrad=None, hvp_method=hvp_method, h=h, retain_graph=False)
             Hs = [t.detach() for t in Hs]
 
-            if 'hvp_tfm' in self.children:
-                Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))
-
             self.store(p, ("Hs", "s"), (Hs, s))
 
         # -------------------------------- adaptive mu ------------------------------- #
-        if
-            g = TensorList(
+        if fs["adaptive"]:
+            g = TensorList(objective.get_grads())
 
-            if
+            if fs["adapt_freq"] is None:
                 # ---------------------------- deterministic case ---------------------------- #
-                g_prev =
+                g_prev = unpack_states(states, p, "g_prev", cls=TensorList)
                 y = g - g_prev
                 g_prev.copy_(g)
                 denom = y.global_vector_norm()
@@ -101,14 +94,14 @@ class MatrixMomentum(Module):
 
                 # we start on 1nd step, and want to adapt when we start, so use (step - 1)
                 if (step - 1) % adapt_freq == 0:
-                    assert
-                    params = TensorList(
+                    assert objective.closure is not None
+                    params = TensorList(objective.params)
                     p_cur = params.clone()
 
                     # move to previous params and evaluate p_prev with current mini-batch
-                    params.copy_(
+                    params.copy_(unpack_states(states, p, 'p_prev'))
                     with torch.enable_grad():
-
+                        objective.closure()
                     g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
                     y = g - g_prev
 
@@ -119,12 +112,12 @@ class MatrixMomentum(Module):
                 denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
                 self.global_state["mu_mul"] = s.global_vector_norm() / denom
 
-        torch._foreach_copy_(p_prev,
+        torch._foreach_copy_(p_prev, objective.params)
 
     @torch.no_grad
-    def
-        update = TensorList(
-        lr,mu =
+    def apply_states(self, objective, states, settings):
+        update = TensorList(objective.get_updates())
+        lr, mu = unpack_dicts(settings, "lr", 'mu', cls=NumberList)
 
         if "mu_mul" in self.global_state:
             mu = mu * self.global_state["mu_mul"]
@@ -133,14 +126,17 @@ class MatrixMomentum(Module):
         # p_prev is not available so make a small step
         step = self.global_state["step"]
         if step == 1:
-            if self.defaults["adaptive"]:
+            if self.defaults["adaptive"]:
+                # initialize
+                unpack_states(states, objective.params, "g_prev", init=objective.get_grads())
+
             update.mul_(lr) # separate so that initial_step_size can clip correctly
             update.mul_(initial_step_size(update, 1e-7))
-            return
+            return objective
 
         # -------------------------- matrix momentum update -------------------------- #
-        s, Hs =
+        s, Hs = unpack_states(states, objective.params, 's', 'Hs', cls=TensorList)
 
         update.mul_(lr).sub_(s).add_(Hs*mu)
-
-        return
+        objective.updates = update
+        return objective
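
The refactored `MatrixMomentum` now subclasses `Transform` and reads Hessian-vector products off the `objective` passed to `update_states`/`apply_states`. Below is a minimal construction sketch mirroring the `tz.Modular` docstring examples elsewhere in this diff; the `tz.m.MatrixMomentum` path and the `import torchzero as tz` alias are assumptions, and the keyword values are illustrative only.

```python
import torch
import torchzero as tz  # assumed import alias, matching the "tz." convention used in the docstrings

model = torch.nn.Linear(10, 1)

# hypothetical usage: mu plays a role similar to (1 - beta) in ordinary momentum,
# hvp_method selects autograd vs. finite-difference Hessian-vector products
opt = tz.Modular(
    model.parameters(),
    tz.m.MatrixMomentum(lr=1e-2, mu=0.1, hvp_method="autograd"),
)
```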
torchzero/modules/adaptive/msam.py

@@ -2,7 +2,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Transform, TensorTransform, step, Objective
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
 from ..functional import ema_
 from ..momentum.momentum import nag_
@@ -21,7 +21,7 @@ def msam_(
 
     # inner args
     inner: Module | None = None,
-
+    objective: Objective | None = None,
 ):
     # weights w and wh, momentum μ, perturbation strength ρ
     # w = wh + rho * v / ||v||
@@ -54,8 +54,8 @@ def msam_(
     v1n = velocity_ / denom
 
     if inner is not None:
-        assert
-        inner_update = TensorList(
+        assert objective is not None and inner is not None
+        inner_update = TensorList(step(objective, inner).get_updates())
 
     else:
         assert lr is not None
@@ -69,7 +69,7 @@ def msam_(
 
     return update
 
-class
+class MSAMMomentum(TensorTransform):
     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
 
     This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
@@ -93,46 +93,40 @@ class MSAM(Transform):
         lerp (bool, optional):
             whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
 
-    Examples:
-        MSAM
+    ### Examples:
 
-
+    MSAM
 
-
-        model.parameters(),
-        tz.m.MSAM(1e-3)
-    )
+    ```python
 
-
-
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.MSAM(1e-3)
+    )
+    ```
 
-
+    Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+    To make Adam_MSAM and such, use the ``tz.m.MSAMObjective`` module.
 
-
-
-
-
-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+        tz.m.Debias(0.9, 0.999),
+    )
+    ```
     """
-
+
     def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
-        defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
-        if self._USES_LR: defaults['lr'] = lr
+        defaults = dict(lr = lr, momentum=momentum, rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
-
-        lerp = s['lerp']
-        nesterov = s['nesterov']
+        fs = settings[0]
 
-
-            lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
-
-        else:
-            lr=None
-            momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+        lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
 
         return msam_(
             TensorList(tensors),
@@ -142,16 +136,16 @@ class MSAM(Transform):
             lr=lr,
             rho=rho,
             weight_decay=weight_decay,
-            nesterov=nesterov,
-            lerp=lerp,
+            nesterov=fs['nesterov'],
+            lerp=fs['lerp'],
 
             # inner args
-            inner=
-
+            inner=None,
+            objective=None,
         )
 
 
-class
+class MSAM(Transform):
     """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
 
     Note:
@@ -160,7 +154,7 @@ class MSAMObjective(MSAM):
         to an incorrect update rule.
 
     Args:
-        modules (Chainable): modules that will
+        modules (Chainable): modules that will optimize the MSAM objective. Make sure ``tz.m.LR`` is one of them.
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         rho (float, optional): perturbation strength. Defaults to 0.3.
         nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
@@ -169,20 +163,44 @@ class MSAMObjective(MSAM):
             Defaults to False.
 
     Examples:
-
-
-
-
-
-
-
-
-
-
-
+        AdamW-MSAM
+
+        ```py
+        opt = tz.Modular(
+            bench.parameters(),
+            tz.m.MSAMObjective(
+                [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+                rho=1.
+            )
+        )
+        ```
     """
-    _USES_LR = False
     def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
-
+        defaults = dict(momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+        super().__init__(defaults)
+
         self.set_child('modules', modules)
 
+
+    @torch.no_grad
+    def apply_states(self, objective, states, settings):
+        velocity = unpack_states(states, objective.params, 'velocity', cls=TensorList)
+        fs = settings[0]
+
+        momentum, rho, weight_decay = unpack_dicts(settings, 'momentum', 'rho', 'weight_decay', cls=NumberList)
+
+        return msam_(
+            TensorList(objective.get_updates()),
+            params=TensorList(objective.params),
+            velocity_=velocity,
+            momentum=momentum,
+            lr=None,
+            rho=rho,
+            weight_decay=weight_decay,
+            nesterov=fs['nesterov'],
+            lerp=fs['lerp'],
+
+            # inner args
+            inner=self.children["modules"],
+            objective=objective,
+        )
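
The comment in `msam_` (`# w = wh + rho * v / ||v||`) gives the relation between the stored and the perturbed weights. Restated in LaTeX, with $\hat w$ the stored (de-perturbed) weights, $v$ the momentum buffer, and $\rho$ the perturbation strength (this restates only the relation that appears in the diff's comments, not the full update rule):

$$ w = \hat{w} + \rho \, \frac{v}{\lVert v \rVert} $$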
torchzero/modules/adaptive/muon.py

@@ -1,14 +1,11 @@
 from operator import itemgetter
 import math
-import
-from collections.abc import Iterable, Sequence
-from typing import Literal
+from collections.abc import Iterable
 
 import torch
 
-from ...core import
-from ...
-
+from ...core import TensorTransform, Transform
+from ...linalg.orthogonalize import orthogonalize as _orthogonalize, OrthogonalizeMethod
 
 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))
@@ -17,136 +14,69 @@ def _is_at_least_2d(p: torch.Tensor):
     if (p.ndim >= 2) and (p.size(0) > 1) and (p.size(1) > 1): return True
     return False
 
-
-
-
-
-
-
-
-
-    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-    zero even beyond the point where the iteration no longer converges all the way to one everywhere
-    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-    performance at all relative to UV^T, where USV^T = G is the SVD.
-    """
-    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
-    a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.bfloat16()
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-
-    # Ensure spectral norm is at most 1
-    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
-    # Perform the NS iterations
-    for _ in range(steps):
-        A = X @ X.mT
-        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-        X = a * X + B @ X
-
-    if G.size(-2) > G.size(-1):
-        X = X.mT
-    return X
-
-# stolen from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
-# Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
-# Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
-@torch.no_grad
-def _svd_orthogonalize(G: torch.Tensor, warn_fail=True) -> torch.Tensor:
-    """
-    Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
-    """
-    X = G.view(G.shape[0], -1)
-
-    t = False
-    if X.size(0) > X.size(1):
-        X = X.T
-        t = True
-
-    orth_X: torch.Tensor | None = None
-    try:
-        u, s, vt = torch.linalg.svd(X, full_matrices=False) # pylint:disable=not-callable
-        orth_X = u @ vt
-    except RuntimeError:
-        # if warn: logging.warning('Failed to perform SVD, adding some noise.')
-        try:
-            u, s, v = torch.svd_lowrank(
-                X,
-                q=1, # assume rank is at least 1
-                M=1e-4 * X.mean() * torch.randn_like(X))
-            orth_X = u @ v.T
-        except RuntimeError:
-            if warn_fail: warnings.warn(('Failed to perform SVD with noise,'
-                ' skipping gradient orthogonalisation'))
-    if orth_X is not None:
-        if t: orth_X = orth_X.T
-        return orth_X.view_as(G)
-
-    return G # fail
+def _orthogonalize_format(
+    tensor: torch.Tensor,
+    method: OrthogonalizeMethod,
+    channel_first: bool,
+):
+    if channel_first:
+        return reverse_dims(_orthogonalize(reverse_dims(tensor), method=method))
 
+    return _orthogonalize(tensor, method=method)
 
 @torch.no_grad
-def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor,
-    """
+def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, channel_first: bool):
+    """``channel_first`` means it applies to first two dims, otherwise to last two dims"""
     # this is from https://github.com/leloykun/adaptive-muon
     # Adaptive scaling,`(G * X).sum() * X` == (G.T @ X).trace() * X
-    if
-    else: X = torch.einsum('ij
+    if channel_first: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+    else: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
     return X
 
 
 # code from
 # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
-def adjust_lr_for_muon(lr, param_shape):
-    A, B = param_shape[:2]
+def adjust_lr_for_muon(lr, param_shape, channel_first:bool):
+    if channel_first: A, B = param_shape[:2]
+    else: A, B = param_shape[-2:]
+
     # We adjust the learning rate and weight decay based on the size of the parameter matrix
     # as describted in the paper
     adjusted_ratio = 0.2 * math.sqrt(max(A, B))
     adjusted_lr = lr * adjusted_ratio
     return adjusted_lr
 
-def _orthogonalize_tensor(
-    tensor: torch.Tensor,
-    steps: int = 5,
-    method: Literal["newton-schulz", "svd"] = "newton-schulz",
-):
-    if method == 'newton-schulz': return reverse_dims(zeropower_via_newtonschulz5(reverse_dims(tensor), steps)).type_as(tensor)
-    if method == 'svd': return _svd_orthogonalize(tensor, False)
-    raise ValueError(method)
-
 
 def orthogonalize_grads_(
     params: Iterable[torch.Tensor],
-    steps: int = 5,
     dual_norm_correction=False,
-    method:
+    method: OrthogonalizeMethod = "newtonschulz",
+    channel_first:bool=True,
 ):
-    """
+    """Computes the zeroth power / orthogonalization of gradients of an iterable of parameters.
 
     This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).
 
     Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
     Args:
         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
-        steps (int, optional):
-            The number of Newton-Schulz iterations to run. Defaults to 5.
         dual_norm_correction (bool, optional):
            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
            Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
+        channel_first (bool, optional):
+            if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+            are considered batch dimensions.
    """
    for p in params:
        if (p.grad is not None) and _is_at_least_2d(p.grad):
-            X =
-            if dual_norm_correction: X = _dual_norm_correction(X, p.grad,
+            X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
+            if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
            p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]
 
 
 
-class Orthogonalize(
+class Orthogonalize(TensorTransform):
    """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.
 
    To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
@@ -156,16 +86,15 @@ class Orthogonalize(TensorwiseTransform):
    To make Muon, use Split with Adam on 1d params
 
    Args:
-        ns_steps (int, optional):
-            The number of Newton-Schulz iterations to run. Defaults to 5.
        adjust_lr (bool, optional):
            Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
        dual_norm_correction (bool, optional):
            enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
        method (str, optional):
-            Newton-Schulz is very fast, SVD is
-
-
+            Newton-Schulz is very fast, SVD is slow but can be more precise.
+        channel_first (bool, optional):
+            if True, orthogonalizes along 1st two dimensions, otherwise along last 2. Other dimensions
+            are considered batch dimensions.
 
    ## Examples:
 
@@ -190,56 +119,62 @@ class Orthogonalize(TensorwiseTransform):
    Reference:
        Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
    """
-    def __init__(self,
-        method:
-        defaults = dict(orthogonalize=True,
-        super().__init__(
+    def __init__(self, adjust_lr=False, dual_norm_correction=False,
+                 method: OrthogonalizeMethod = 'newtonschulz', channel_first:bool=True):
+        defaults = dict(orthogonalize=True, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower(), channel_first=channel_first)
+        super().__init__(defaults=defaults)
 
    @torch.no_grad
-    def
-        orthogonalize,
-        'orthogonalize', '
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
+        orthogonalize, dual_norm_correction, adjust_lr, method, channel_first = itemgetter(
+            'orthogonalize', 'dual_norm_correction', 'adjust_lr', 'method', 'channel_first')(setting)
 
        if not orthogonalize: return tensor
 
        if _is_at_least_2d(tensor):
 
-            X =
+            X = _orthogonalize_format(tensor, method, channel_first=channel_first)
 
            if dual_norm_correction:
-                X = _dual_norm_correction(X, tensor,
+                X = _dual_norm_correction(X, tensor, channel_first=channel_first)
 
            if adjust_lr:
-                X.mul_(adjust_lr_for_muon(1, param.shape))
+                X.mul_(adjust_lr_for_muon(1, param.shape, channel_first=channel_first))
 
            return X.view_as(param)
 
        return tensor
 
 
-class DualNormCorrection(
+class DualNormCorrection(TensorTransform):
    """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
    Orthogonalize already has this built in with the `dual_norm_correction` setting."""
-    def __init__(self,
-
+    def __init__(self, channel_first: bool = True):
+        defaults = dict(channel_first=channel_first)
+        super().__init__(defaults)
 
-
+    @torch.no_grad
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
        assert grad is not None
        if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
-            return _dual_norm_correction(tensor, grad,
+            return _dual_norm_correction(tensor, grad, channel_first=setting["channel_first"])
        return tensor
 
 
 class MuonAdjustLR(Transform):
    """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
-    Orthogonalize already has this built in with the
-    def __init__(self,
-        defaults = dict(alpha=alpha)
-        super().__init__(defaults=defaults
+    Orthogonalize already has this built in with the ``adjust_lr`` setting, however you might want to move this to be later in the chain."""
+    def __init__(self, channel_first: bool = True, alpha: float = 1):
+        defaults = dict(channel_first=channel_first, alpha=alpha)
+        super().__init__(defaults=defaults)
 
-
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        alphas = [s['alpha'] for s in settings]
-
+        channel_first = [s["channel_first=channel_first"] for s in settings]
+        tensors_alphas = [
+            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t)
+        ]
        tensors = [i[0] for i in tensors_alphas]
        a = [i[1] for i in alphas]
        torch._foreach_mul_(tensors, a)
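
A minimal sketch of calling the updated `orthogonalize_grads_` signature shown above. The import path follows the file list entry `torchzero/modules/adaptive/muon.py`; whether the function is also re-exported elsewhere is not shown in this diff, so the path is an assumption.

```python
import torch
from torchzero.modules.adaptive.muon import orthogonalize_grads_  # assumed import path

layer = torch.nn.Linear(16, 32)
layer(torch.randn(4, 16)).sum().backward()

# orthogonalizes layer.weight.grad in-place along its first two dims
# (out_channels, in_channels); the 1-D bias gradient is skipped by the
# _is_at_least_2d check shown in the diff
orthogonalize_grads_(layer.parameters(), method="newtonschulz", channel_first=True)
```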