torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
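
The per-file hunks follow. A comparable file-level comparison can be reproduced locally from the two wheels; the snippet below is a minimal sketch, not part of the diff itself. The wheel filenames are assumptions, and both files are assumed to have been fetched beforehand (for example with pip download torchzero==0.3.10 --no-deps and the same for 0.3.11).

# Minimal sketch: diff the Python sources inside two locally downloaded wheels.
import difflib
import zipfile

OLD = "torchzero-0.3.10-py3-none-any.whl"  # assumed filename, downloaded beforehand
NEW = "torchzero-0.3.11-py3-none-any.whl"  # assumed filename, downloaded beforehand

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    old_names = set(old_whl.namelist())
    new_names = set(new_whl.namelist())

    for name in sorted(old_names | new_names):
        if not name.endswith(".py"):
            continue  # skip dist-info metadata such as RECORD/WHEEL for brevity
        old_src = old_whl.read(name).decode() if name in old_names else ""
        new_src = new_whl.read(name).decode() if name in new_names else ""
        if old_src == new_src:
            continue
        # print a unified diff for each changed module
        diff = difflib.unified_diff(
            old_src.splitlines(keepends=True),
            new_src.splitlines(keepends=True),
            fromfile=f"0.3.10/{name}",
            tofile=f"0.3.11/{name}",
        )
        print("".join(diff))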

torchzero/modules/optimizers/msam.py (new file, +186 -0)

@@ -0,0 +1,186 @@
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
+from ..functional import ema_
+from ..momentum.momentum import nag_
+
+
+def msam_(
+    tensors: TensorList,
+    params: TensorList,
+    velocity_: TensorList,
+    momentum: float | NumberList,
+    lr: NumberList | None,
+    rho: float | NumberList,
+    weight_decay: float | NumberList,
+    nesterov: bool = False,
+    lerp: bool = False,
+
+    # inner args
+    inner: Module | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    # weights w and wh, momentum μ, perturbation strength ρ
+    # w = wh + rho * v / ||v||
+    # v1 = μv + g
+    # w1 = w - lr*v1
+    # wh1 = w1 - rho * v1 / ||v1||
+
+    # w1 = wh + rho * v / ||v|| - lr*v1
+    # vn = rho * v / ||v||
+    # v1n = rho * v1 / ||v1||
+    # wh1 = wh + vn - lr*v1 - v1n
+
+    # the update is
+    # vn - lr*v1 - v1n
+
+    # we track ascent direction so it becomes lr*v1 + v1n - vn
+
+    # can't really decouple it from lr
+    # but at least it is now expressed as function of g
+
+    denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+    vn = velocity_ / denom
+
+    mom_ = nag_ if nesterov else ema_
+    velocity_ = mom_(tensors, velocity_, momentum, dampening=0, lerp=lerp)
+
+    denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+    v1n = velocity_ / denom
+
+    if inner is not None:
+        assert params is not None
+        inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+    else:
+        assert lr is not None
+        inner_update = velocity_ * lr
+
+    update = inner_update.add_(v1n).sub_(vn)
+
+    if generic_ne(weight_decay, 0):
+        wd = (params + vn).mul_(weight_decay)
+        update.add_(wd)
+
+    return update
+
+class MSAM(Transform):
+    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+    This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
+    replacement for momentum strategies in other optimizers.
+
+    To combine MSAM with other optimizers in the way done in the official implementation,
+    e.g. to make Adam_MSAM, use :code:`tz.m.MSAMObjective` module.
+
+    .. note::
+        MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
+        To avoid compounding learning rate mofications, remove the :code:`tz.m.LR` module if you had it.
+
+    Args:
+        lr (float): learning rate. Adding this module adds support for learning rate schedulers.
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        rho (float, optional): perturbation strength. Defaults to 0.3.
+        weight_decay (float, optional):
+            weight decay. It is applied to perturbed parameters, so it is differnet
+            from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
+        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
+
+    Examples:
+        MSAM
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.MSAM(1e-3)
+            )
+
+        Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+        To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+                tz.m.Debias(0.9, 0.999),
+            )
+    """
+    USES_LR = True
+    def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
+        defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
+        if self.USES_LR: defaults['lr'] = lr
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+        s = self.settings[params[0]]
+        lerp = s['lerp']
+        nesterov = s['nesterov']
+
+        if self.USES_LR:
+            lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
+
+        else:
+            lr=None
+            momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+
+        return msam_(
+            TensorList(tensors),
+            params=TensorList(params),
+            velocity_=velocity,
+            momentum=momentum,
+            lr=lr,
+            rho=rho,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            lerp=lerp,
+
+            # inner args
+            inner=self.children.get("modules", None),
+            grads=grads,
+        )
+
+
+class MSAMObjective(MSAM):
+    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+    .. note::
+        Please make sure to place :code:`tz.m.LR` inside the :code:`modules` argument. For example,
+        :code:`tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])`. Putting LR after MSAM will lead
+        to an incorrect update rule.
+
+    Args:
+        modules (Chainable): modules that will optimizer the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        rho (float, optional): perturbation strength. Defaults to 0.3.
+        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average.
+            Defaults to False.
+
+    Examples:
+        AdamW-MSAM
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.MSAMObjective(
+                    [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+                    rho=1.
+                )
+            )
+    """
+    USES_LR = False
+    def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
+        super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+        self.set_child('modules', modules)
+
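For reference, the derivation sketched in the comment block of msam_ above can be written out with the same symbols (w, ŵ, v, μ, ρ, lr); this is only a restatement of those comments:

\begin{aligned}
w   &= \hat{w} + \rho\,\frac{v}{\lVert v\rVert}, &
v_1 &= \mu v + g, \\
w_1 &= w - \mathrm{lr}\,v_1, &
\hat{w}_1 &= w_1 - \rho\,\frac{v_1}{\lVert v_1\rVert}.
\end{aligned}

Writing $v_n = \rho\, v/\lVert v\rVert$ and $v_{1n} = \rho\, v_1/\lVert v_1\rVert$ gives $\hat{w}_1 = \hat{w} + v_n - \mathrm{lr}\,v_1 - v_{1n}$, so the descent step is $v_n - \mathrm{lr}\,v_1 - v_{1n}$. Since torchzero modules track the ascent direction, msam_ returns $\mathrm{lr}\,v_1 + v_{1n} - v_n$, plus the weight-decay term built from the perturbed parameters (params + vn).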

torchzero/modules/optimizers/muon.py (+30 -5)

@@ -19,6 +19,7 @@ def _is_at_least_2d(p: torch.Tensor):
 
 # stolen from:
 # https://github.com/KellerJordan/Muon/blob/master/muon.py
+# actually at this stage its a frankenstein
 @enable_compilation
 def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
     """
@@ -152,7 +153,7 @@ class Orthogonalize(TensorwiseTransform):
     The Muon page says that embeddings and classifier heads should not be orthogonalized.
     Usually only matrix parameters that are directly used in matmuls should be orthogonalized.
 
-    To make Muon, use Split with Adam on 1d params
+    To make Muon, use Split with Adam on 1d params
 
     Args:
         ns_steps (int, optional):
@@ -165,6 +166,30 @@ class Orthogonalize(TensorwiseTransform):
             Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
         target (str, optional):
             what to set on var.
+
+
+    Examples:
+        standard Muon with Adam fallback
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.head.parameters(),
+                tz.m.Split(
+                    # apply muon only to 2D+ parameters
+                    filter = lambda t: t.ndim >= 2,
+                    true = [
+                        tz.m.HeavyBall(),
+                        tz.m.Orthogonalize(),
+                        tz.m.LR(1e-2),
+                    ],
+                    false = tz.m.Adam()
+                ),
+                tz.m.LR(1e-2)
+            )
+
+    Reference:
+        Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
     """
     def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                  method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
@@ -172,9 +197,9 @@ class Orthogonalize(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
-            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(
+            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)
 
         if not orthogonalize: return tensor
 
@@ -199,7 +224,7 @@ class DualNormCorrection(TensorwiseTransform):
     def __init__(self, target: Target='update'):
         super().__init__({}, uses_grad=True, target=target)
 
-    def apply_tensor(self, tensor, param, grad, loss, state,
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         assert grad is not None
         if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
             return _dual_norm_correction(tensor, grad, batch_first=False)
@@ -213,7 +238,7 @@ class MuonAdjustLR(Transform):
         defaults = dict(alpha=alpha)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         alphas = [s['alpha'] for s in settings]
         tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
         tensors = [i[0] for i in tensors_alphas]

torchzero/modules/optimizers/orthograd.py (+1 -1)

@@ -36,7 +36,7 @@ class OrthoGrad(Transform):
         defaults = dict(eps=eps, renormalize=renormalize)
         super().__init__(defaults, uses_grad=False, target=target)
 
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         eps = settings[0]['eps']
         renormalize = settings[0]['renormalize']
 

torchzero/modules/optimizers/rmsprop.py (+7 -4)

@@ -40,7 +40,9 @@ def rmsprop_(
     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
 
 class RMSprop(Transform):
-    """Divides graient by EMA of gradient squares.
+    """Divides graient by EMA of gradient squares.
+
+    This implementation is identical to :code:`torch.optim.RMSprop`.
 
     Args:
         smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
@@ -50,7 +52,8 @@ class RMSprop(Transform):
         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
         pow (float, optional): power used in second momentum power and root. Defaults to 2.
         init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
-        inner (Chainable | None, optional):
+        inner (Chainable | None, optional):
+            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
     """
     def __init__(
         self,
@@ -60,7 +63,7 @@ class RMSprop(Transform):
         debiased: bool = False,
         amsgrad: bool = False,
         pow: float = 2,
-        init: Literal["zeros", "update"] = "
+        init: Literal["zeros", "update"] = "zeros",
         inner: Chainable | None = None,
     ):
         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
@@ -69,7 +72,7 @@ class RMSprop(Transform):
         if inner is not None:
             self.set_child('inner', inner)
 
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
         centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])

torchzero/modules/optimizers/rprop.py (+42 -8)

@@ -135,7 +135,8 @@ class Rprop(Transform):
     Next step, magnitude for that weight won't change.
 
     Compared to pytorch this also implements backtracking update when sign changes.
-
+
+    This implementation is identical to :code:`torch.optim.Rprop` if :code:`backtrack` is set to False.
 
     Args:
         nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
@@ -164,7 +165,7 @@ class Rprop(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -223,7 +224,7 @@ class ScaleLRBySignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -272,7 +273,7 @@ class BacktrackOnSignChange(Transform):
         super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -294,12 +295,29 @@ class BacktrackOnSignChange(Transform):
         return tensors
 
 class SignConsistencyMask(Transform):
-    """
+    """
+    Outputs a mask of sign consistency of current and previous inputs.
+
+    The output is 0 for weights where input sign changed compared to previous input, 1 otherwise.
+
+    Examples:
+
+        GD that skips update for weights where gradient sign changed compared to previous gradient.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyMask()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(self,target: Target = 'update'):
         super().__init__({}, uses_grad=False, target = target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
         prev.copy_(tensors)
@@ -307,7 +325,23 @@ class SignConsistencyMask(Transform):
 
 
 class SignConsistencyLRs(Transform):
-    """
+    """Outputs per-weight learning rates based on consecutive sign consistency.
+
+    The learning rate for a weight is multiplied by :code:`nplus` when two consecutive update signs are the same, otherwise it is multiplied by :code:`nplus`. The learning rates are bounded to be in :code:`(lb, ub)` range.
+
+    Examples:
+
+        GD scaled by consecutive gradient sign consistency
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Mul(tz.m.SignConsistencyLRs()),
+                tz.m.LR(1e-2)
+            )
+
+    """
     def __init__(
         self,
         nplus: float = 1.2,
@@ -321,7 +355,7 @@ class SignConsistencyLRs(Transform):
         super().__init__(defaults, uses_grad=False, target = target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 

torchzero/modules/optimizers/sam.py (new file, +163 -0)

@@ -0,0 +1,163 @@
+from contextlib import nullcontext
+import torch
+from ...utils import TensorList, NumberList
+from ...core import Module
+
+
+class SAM(Module):
+    """Sharpness-Aware Minimization from https://arxiv.org/pdf/2010.01412
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+        asam (bool, optional):
+            enables ASAM variant which makes perturbation relative to weight magnitudes.
+            ASAM requires a much larger :code:`rho`, like 0.5 or 1.
+            The :code:`tz.m.ASAM` class is idential to setting this argument to True, but
+            it has larger :code:`rho` by default.
+
+    Examples:
+        SAM-SGD:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.LR(1e-2)
+            )
+
+        SAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.SAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Foret, P., Kleiner, A., Mobahi, H., & Neyshabur, B. (2020). Sharpness-aware minimization for efficiently improving generalization. arXiv preprint arXiv:2010.01412. https://arxiv.org/abs/2010.01412#page=3.16
+    """
+    def __init__(self, rho: float = 0.05, p: float = 2, eps=1e-10, asam=False):
+        defaults = dict(rho=rho, p=p, eps=eps, asam=asam)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+
+        params = var.params
+        closure = var.closure
+        zero_grad = var.zero_grad
+        if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
+        p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
+        s = self.settings[var.params[0]]
+        eps = s['eps']
+        asam = s['asam']
+
+        # 1/p + 1/q = 1
+        # okay, authors of SAM paper, I will manually solve your equation
+        # so q = -p/(1-p)
+        q = -p / (1-p)
+        # as a validation for 2 it is -2 / -1 = 2
+
+        @torch.no_grad
+        def sam_closure(backward=True):
+            orig_grads = None
+            if not backward:
+                # if backward is False, make sure this doesn't modify gradients
+                # to avoid issues
+                orig_grads = [p.grad for p in params]
+
+            # gradient at initial parameters
+            zero_grad()
+            with torch.enable_grad():
+                closure()
+
+            grad = TensorList(p.grad if p.grad is not None else torch.zeros_like(p) for p in params)
+            grad_abs = grad.abs()
+
+            # compute e
+            term1 = grad.sign().mul_(rho)
+            term2 = grad_abs.pow(q-1)
+
+            if asam:
+                grad_abs.mul_(torch._foreach_abs(params))
+
+            denom = grad_abs.pow_(q).sum().pow(1/p)
+
+            e = term1.mul_(term2).div_(denom.clip(min=eps))
+
+            if asam:
+                e.mul_(torch._foreach_pow(params, 2))
+
+            # calculate loss and gradient approximation of inner problem
+            torch._foreach_add_(params, e)
+            if backward:
+                zero_grad()
+                with torch.enable_grad():
+                    # this sets .grad attributes
+                    sam_loss = closure()
+
+            else:
+                sam_loss = closure(False)
+
+            # and restore initial parameters
+            torch._foreach_sub_(params, e)
+
+            if orig_grads is not None:
+                for param,orig_grad in zip(params, orig_grads):
+                    param.grad = orig_grad
+
+            return sam_loss
+
+        var.closure = sam_closure
+        return var
+
+# different class because defaults for SAM are bad for ASAM
+class ASAM(SAM):
+    """Adaptive Sharpness-Aware Minimization from https://arxiv.org/pdf/2102.11600#page=6.52
+
+    SAM functions by seeking parameters that lie in neighborhoods having uniformly low loss value.
+    It performs two forward and backward passes per step.
+
+    This implementation modifies the closure to return loss and calculate gradients
+    of the SAM objective. All modules after this will use the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at two points on each step.
+
+    Args:
+        rho (float, optional): Neighborhood size. Defaults to 0.05.
+        p (float, optional): norm of the SAM objective. Defaults to 2.
+
+    Examples:
+        ASAM-Adam:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ASAM(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Kwon, J., Kim, J., Park, H., & Choi, I. K. (2021, July). Asam: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks. In International Conference on Machine Learning (pp. 5905-5914). PMLR. https://arxiv.org/abs/2102.11600
+    """
+    def __init__(self, rho: float = 0.5, p: float = 2, eps=1e-10):
+        super().__init__(rho=rho, p=p, eps=eps, asam=True)
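
For reference, the exponent q computed in SAM.step above is just the Hölder conjugate from the inline comment, written out:

\frac{1}{p} + \frac{1}{q} = 1
\;\Longrightarrow\;
q = \frac{p}{p-1} = \frac{-p}{1-p},
\qquad
p = 2 \;\Rightarrow\; q = \frac{-2}{-1} = 2.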

torchzero/modules/optimizers/shampoo.py (+39 -5)

@@ -59,7 +59,7 @@ def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
     if tensor.shape[sort_idxs[0]] > max_dim:
         return tensor, None, None
 
-    tensor = tensor.permute(*sort_idxs)
+    tensor = tensor.permute(*sort_idxs.tolist())
     flatten_end_idx = 0
     flat_sizes = []
     flat_numel = 1
@@ -80,19 +80,28 @@ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None,
     if flat_sizes is None: return tensor
     assert sort_idxs is not None
     tensor = tensor.unflatten(0, flat_sizes)
-    return tensor.permute(*np.argsort(sort_idxs))
+    return tensor.permute(*np.argsort(sort_idxs).tolist())
 
 
 class Shampoo(Transform):
     """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
 
+    .. note::
+        Shampoo is usually grafted to another optimizer like Adam, otherwise it can be unstable. An example of how to do grafting is given below in the Examples section.
+
+    .. note::
+        Shampoo is a very computationally expensive optimizer, increase :code:`update_freq` if it is too slow.
+
+    .. note::
+        SOAP optimizer usually outperforms Shampoo and is also not as computationally expensive. SOAP implementation is available as :code:`tz.m.SOAP`.
+
     Args:
         decay (float | None, optional): slowly decays preconditioners. Defaults to None.
         beta (float | None, optional):
             if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
         matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
         update_freq (int, optional): preconditioner update frequency. Defaults to 10.
-        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to
+        exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
         merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
         max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
         precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
@@ -101,13 +110,38 @@ class Shampoo(Transform):
             module applied after updating preconditioners and before applying preconditioning.
            For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
            Defaults to None.
+
+    Examples:
+        Shampoo grafted to Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GraftModules(
+                    direction = tz.m.Shampoo(),
+                    magnitude = tz.m.Adam(),
+                ),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with Shampoo preconditioner
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Shampoo(beta=0.999, inner=tz.m.EMA(0.9)),
+                tz.m.Debias(0.9, 0.999),
+                tz.m.LR(1e-3)
+            )
     """
     def __init__(
         self,
         decay: float | None = None,
         beta: float | None = None,
         update_freq: int = 10,
-        exp_override: int | None =
+        exp_override: int | None = 2,
         merge_small: bool = True,
         max_dim: int = 2_000,
         precondition_1d: bool = True,
@@ -120,7 +154,7 @@ class Shampoo(Transform):
         if inner is not None:
             self.set_child('inner', inner)
 
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         merged_tensors = [] # target with merged dims
 
         # update preconditioners