torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/ladagrad.py
@@ -0,0 +1,183 @@
+from collections import deque
+from typing import Literal, Any
+import warnings
+
+import torch
+from ...core import Chainable, TensorwiseTransform
+
+def lm_adagrad_update(history: deque[torch.Tensor], damping, rdamping):
+    M = torch.stack(tuple(history), dim=1)# / len(history)
+    MTM = M.T @ M
+    if damping != 0:
+        MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))
+
+    try:
+        L, Q = torch.linalg.eigh(MTM) # pylint:disable=not-callable
+
+        tol = torch.finfo(M.dtype).eps * L.amax() # remove small eigenvalues
+        indices = L > tol
+        L = L[indices]
+        Q = Q[:, indices]
+
+        U = (M @ Q) * L.rsqrt()
+
+        if rdamping != 0:
+            rdamping *= torch.linalg.vector_norm(L) # pylint:disable=not-callable
+            L.add_(rdamping)
+
+        return U, L
+
+    except torch.linalg.LinAlgError:
+        return None, None
+
+def lm_adagrad_apply(g: torch.Tensor, U: torch.Tensor, L: torch.Tensor):
+    Z = U.T @ g
+    return (U * L.rsqrt()) @ Z
+
+def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
+    if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
+    else:
+        if state_[key] is None or state_[key].shape != value.shape: state_[key] = value
+        else: state_[key].lerp_(value, 1-beta)
+
+class LMAdagrad(TensorwiseTransform):
+    """
+    Limited-memory full matrix Adagrad.
+
+    The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate update as U S^-1 Uᵀg.
+    But it uses eigendecomposition on MᵀM to get U and S^2 because that is faster when you don't neeed V.
+
+    This is equivalent to full-matrix Adagrad on recent gradients.
+
+    Args:
+        history_size (int, optional): number of past gradients to store. Defaults to 10.
+        update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
+        damping (float, optional): damping value. Defaults to 1e-4.
+        rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
+        order (int, optional):
+            order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
+        true_damping (bool, optional):
+            If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
+        eigh (bool, optional): uses a more efficient way to calculate U and S. Defaults to True.
+        U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
+        S_beta (float | None, optional): momentum for S (too unstable, don't use). Defaults to None.
+        interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
+        concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
+        inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
+
+    Examples:
+        Limited-memory Adagrad
+
+        .. code-block:: python
+
+            optimizer = tz.Modular(
+                model.parameters(),
+                tz.m.LMAdagrad(),
+                tz.m.LR(0.1)
+            )
+
+        Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
+
+        .. code-block:: python
+
+            optimizer = tz.Modular(
+                model.parameters(),
+                tz.m.LMAdagrad(inner=tz.m.EMA()),
+                tz.m.Debias(0.9, 0.999),
+                tz.m.LR(0.01)
+            )
+
+        Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
+
+        .. code-block:: python
+
+            optimizer = tz.Modular(
+                model.parameters(),
+                tz.m.LMAdagrad(inner=tz.m.EMA()),
+                tz.m.Debias(0.9, 0.999),
+                tz.m.ClipNormByEMA(max_ema_growth=1.2),
+                tz.m.LR(0.01)
+            )
+
+    Reference:
+        Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
+    """
+
+    def __init__(
+        self,
+        history_size: int = 100,
+        update_freq: int = 1,
+        damping: float = 1e-4,
+        rdamping: float = 0,
+        order: int = 1,
+        true_damping: bool = True,
+        U_beta: float | None = None,
+        L_beta: float | None = None,
+        interval: int = 1,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        # history is still updated each step so Precondition's update_freq has different meaning
+        defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, L_beta=L_beta)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+        order = setting['order']
+        history_size = setting['history_size']
+        update_freq = setting['update_freq']
+        damping = setting['damping']
+        rdamping = setting['rdamping']
+        U_beta = setting['U_beta']
+        L_beta = setting['L_beta']
+
+        if 'history' not in state: state['history'] = deque(maxlen=history_size)
+        history = state['history']
+
+        if order == 1:
+            t = tensor.clone().view(-1)
+            history.append(t)
+        else:
+
+            # if order=2, history is of gradient differences, order 3 is differences between differences, etc
+            # scaled by parameter differences
+            cur_p = param.clone()
+            cur_g = tensor.clone()
+            for i in range(1, order):
+                if f'prev_g_{i}' not in state:
+                    state[f'prev_p_{i}'] = cur_p
+                    state[f'prev_g_{i}'] = cur_g
+                    break
+
+                s = cur_p - state[f'prev_p_{i}']
+                y = cur_g - state[f'prev_g_{i}']
+                state[f'prev_p_{i}'] = cur_p
+                state[f'prev_g_{i}'] = cur_g
+                cur_p = s
+                cur_g = y
+
+                if i == order - 1:
+                    cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                    history.append(cur_g.view(-1))
+
+        step = state.get('step', 0)
+        if step % update_freq == 0 and len(history) != 0:
+            U, L = lm_adagrad_update(history, damping=damping, rdamping=rdamping)
+            maybe_lerp_(state, U_beta, 'U', U)
+            maybe_lerp_(state, L_beta, 'L', L)
+
+        if len(history) != 0:
+            state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        U = state.get('U', None)
+        if U is None:
+            # make a conservative step to avoid issues due to different GD scaling
+            return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
+
+        L = state['L']
+        update = lm_adagrad_apply(tensor.view(-1), U, L).view_as(tensor)
+
+        return update
+
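A note on the LMAdagrad docstring above: the claim that an eigendecomposition of MᵀM recovers the same preconditioner as an SVD of M can be checked directly. Below is a minimal, self-contained sketch in plain PyTorch with random data; it is an illustration, not the torchzero API.

    import torch

    # Columns of M are a short history of flattened gradients (n parameters x k gradients).
    torch.manual_seed(0)
    M = torch.randn(50, 8, dtype=torch.float64)
    g = torch.randn(50, dtype=torch.float64)

    # Route 1: thin SVD of M, update = U S^-1 U^T g.
    U_svd, S, _ = torch.linalg.svd(M, full_matrices=False)
    update_svd = U_svd @ torch.diag(1 / S) @ (U_svd.T @ g)

    # Route 2: eigendecomposition of the small k x k matrix M^T M.
    # Its eigenvalues L equal S^2 and U = M Q L^{-1/2}, so V is never materialized.
    L, Q = torch.linalg.eigh(M.T @ M)
    U_eig = (M @ Q) * L.rsqrt()
    update_eig = (U_eig * L.rsqrt()) @ (U_eig.T @ g)

    print(torch.allclose(update_svd, update_eig))  # True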
torchzero/modules/optimizers/lion.py
@@ -1,7 +1,7 @@
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 
 def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
@@ -28,8 +28,8 @@ class Lion(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
-        beta1, beta2 =
-        exp_avg =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
+        exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
         return lion_(TensorList(tensors),exp_avg,beta1,beta2)
 
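The hunk above only rewires Lion's apply_tensors onto the new unpack_dicts/unpack_states helpers; the body of lion_ is not part of this diff. For orientation, the standard Lion rule (Chen et al., 2023) that such a function computes looks roughly like this single-tensor sketch in plain PyTorch, not torchzero's TensorList version:

    import torch

    def lion_step(grad: torch.Tensor, exp_avg: torch.Tensor, beta1: float, beta2: float) -> torch.Tensor:
        # Direction: sign of an interpolation between the running average and the current gradient.
        update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
        # The running average itself is updated with the second beta.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
        return update

    m = torch.zeros(4)
    print(lion_step(torch.randn(4), m, beta1=0.9, beta2=0.99))  # entries in {-1, 0, 1}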
torchzero/modules/optimizers/mars.py
@@ -0,0 +1,91 @@
+from operator import itemgetter
+from functools import partial
+
+import torch
+
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..step_size.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def mars_correction_(
+    tensors_: TensorList,
+    prev_: TensorList,
+    beta: float | NumberList,
+    scaling: float | NumberList,
+    max_norm: float | NumberList | None,
+):
+    dg = (tensors_ - prev_).mul_(scaling * beta / (1-beta))
+    prev_.copy_(tensors_)
+
+    c = tensors_.add_(dg)
+    if max_norm is not None:
+        c.clip_norm_(max=max_norm, tensorwise=False)
+
+    return c
+
+class MARSCorrection(Transform):
+    """MARS variance reduction correction.
+
+    Place any other momentum-based optimizer after this,
+    make sure :code:`beta` parameter matches with momentum in the optimizer.
+
+    Args:
+        beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
+        scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
+        max_norm (float, optional): clips norm of corrected gradients, None to disable. Defaults to 1.
+
+    Examples:
+        Mars-AdamW
+
+        .. code-block:: python
+
+            optimizer = tz.Modular(
+                model.parameters(),
+                tz.m.MARSCorrection(beta=0.95),
+                tz.m.Adam(beta1=0.95, beta2=0.99),
+                tz.m.WeightDecay(1e-3),
+                tz.m.LR(0.1)
+            )
+
+        Mars-Lion
+
+        .. code-block:: python
+
+            optimizer = tz.Modular(
+                model.parameters(),
+                tz.m.MARSCorrection(beta=0.9),
+                tz.m.Lion(beta1=0.9),
+                tz.m.LR(0.1)
+            )
+
+    """
+    def __init__(
+        self,
+        beta: float = 0.9,
+        scaling: float = 0.025,
+        max_norm: float | None = 1,
+    ):
+        defaults=dict(beta=beta, scaling=scaling, max_norm=max_norm)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+        beta, scaling = unpack_dicts(settings, 'beta', 'scaling', cls=NumberList)
+        max_norm = settings[0]['max_norm']
+
+        return mars_correction_(
+            tensors_=TensorList(tensors),
+            prev_=prev,
+            beta=beta,
+            scaling=scaling,
+            max_norm=max_norm,
+        )
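mars_correction_ above implements the MARS gradient correction c_t = g_t + (scaling * beta / (1 - beta)) * (g_t - g_{t-1}) with optional norm clipping. A single-tensor sketch of the same arithmetic (illustrative names, plain PyTorch rather than the library's TensorList):

    import torch

    def mars_correct(g: torch.Tensor, g_prev: torch.Tensor, beta: float = 0.95,
                     scaling: float = 0.025, max_norm: float | None = 1.0) -> torch.Tensor:
        # c_t = g_t + scaling * beta / (1 - beta) * (g_t - g_{t-1})
        c = g + scaling * beta / (1 - beta) * (g - g_prev)
        # Clip the corrected gradient so the extrapolation cannot blow up.
        if max_norm is not None and c.norm() > max_norm:
            c = c * (max_norm / c.norm())
        return c

    g_prev = torch.randn(10)
    g = g_prev + 0.1 * torch.randn(10)
    corrected = mars_correct(g, g_prev)
    print(corrected.norm() <= 1.0 + 1e-6)  # True: the corrected gradient is clipped to max_norm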
torchzero/modules/optimizers/msam.py
@@ -0,0 +1,186 @@
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states, generic_ne
+from ..functional import ema_
+from ..momentum.momentum import nag_
+
+
+def msam_(
+    tensors: TensorList,
+    params: TensorList,
+    velocity_: TensorList,
+    momentum: float | NumberList,
+    lr: NumberList | None,
+    rho: float | NumberList,
+    weight_decay: float | NumberList,
+    nesterov: bool = False,
+    lerp: bool = False,
+
+    # inner args
+    inner: Module | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    # weights w and wh, momentum μ, perturbation strength ρ
+    # w = wh + rho * v / ||v||
+    # v1 = μv + g
+    # w1 = w - lr*v1
+    # wh1 = w1 - rho * v1 / ||v1||
+
+    # w1 = wh + rho * v / ||v|| - lr*v1
+    # vn = rho * v / ||v||
+    # v1n = rho * v1 / ||v1||
+    # wh1 = wh + vn - lr*v1 - v1n
+
+    # the update is
+    # vn - lr*v1 - v1n
+
+    # we track ascent direction so it becomes lr*v1 + v1n - vn
+
+    # can't really decouple it from lr
+    # but at least it is now expressed as function of g
+
+    denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+    vn = velocity_ / denom
+
+    mom_ = nag_ if nesterov else ema_
+    velocity_ = mom_(tensors, velocity_, momentum, dampening=0, lerp=lerp)
+
+    denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
+    v1n = velocity_ / denom
+
+    if inner is not None:
+        assert params is not None
+        inner_update = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+    else:
+        assert lr is not None
+        inner_update = velocity_ * lr
+
+    update = inner_update.add_(v1n).sub_(vn)
+
+    if generic_ne(weight_decay, 0):
+        wd = (params + vn).mul_(weight_decay)
+        update.add_(wd)
+
+    return update
+
+class MSAM(Transform):
+    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+    This implementation expresses the update rule as function of gradient. This way it can be used as a drop-in
+    replacement for momentum strategies in other optimizers.
+
+    To combine MSAM with other optimizers in the way done in the official implementation,
+    e.g. to make Adam_MSAM, use :code:`tz.m.MSAMObjective` module.
+
+    .. note::
+        MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
+        To avoid compounding learning rate mofications, remove the :code:`tz.m.LR` module if you had it.
+
+    Args:
+        lr (float): learning rate. Adding this module adds support for learning rate schedulers.
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        rho (float, optional): perturbation strength. Defaults to 0.3.
+        weight_decay (float, optional):
+            weight decay. It is applied to perturbed parameters, so it is differnet
+            from applying :code:`tz.m.WeightDecay` after MSAM. Defaults to 0.
+        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
+
+    Examples:
+        MSAM
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.MSAM(1e-3)
+            )
+
+        Adam with MSAM instead of exponential average. Note that this is different from Adam_MSAM.
+        To make Adam_MSAM and such, use the :code:`tz.m.MSAMObjective` module.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.RMSprop(0.999, inner=tz.m.MSAM(1e-3)),
+                tz.m.Debias(0.9, 0.999),
+            )
+    """
+    USES_LR = True
+    def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
+        defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
+        if self.USES_LR: defaults['lr'] = lr
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+        s = self.settings[params[0]]
+        lerp = s['lerp']
+        nesterov = s['nesterov']
+
+        if self.USES_LR:
+            lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
+
+        else:
+            lr=None
+            momentum,rho,weight_decay = unpack_dicts(settings, 'momentum','rho','weight_decay', cls=NumberList)
+
+        return msam_(
+            TensorList(tensors),
+            params=TensorList(params),
+            velocity_=velocity,
+            momentum=momentum,
+            lr=lr,
+            rho=rho,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            lerp=lerp,
+
+            # inner args
+            inner=self.children.get("modules", None),
+            grads=grads,
+        )
+
+
+class MSAMObjective(MSAM):
+    """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
+
+    .. note::
+        Please make sure to place :code:`tz.m.LR` inside the :code:`modules` argument. For example,
+        :code:`tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])`. Putting LR after MSAM will lead
+        to an incorrect update rule.
+
+    Args:
+        modules (Chainable): modules that will optimizer the MSAM objective. Make sure :code:`tz.m.LR` is one of them.
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        rho (float, optional): perturbation strength. Defaults to 0.3.
+        nesterov (bool, optional): whether to use nesterov momentum formula. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, MSAM momentum becomes similar to exponential moving average.
+            Defaults to False.
+
+    Examples:
+        AdamW-MSAM
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.MSAMObjective(
+                    [tz.m.Adam(), tz.m.WeightDecay(1e-3), tz.m.LR(1e-3)],
+                    rho=1.
+                )
+            )
+    """
+    USES_LR = False
+    def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
+        super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
+        self.set_child('modules', modules)
+
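The comment block inside msam_ above derives the MSAM step on the unperturbed weights as vn - lr*v1 - v1n, i.e. purely as a function of the incoming gradient. That algebra is easy to sanity-check numerically on a toy vector (illustrative names, plain PyTorch):

    import torch

    torch.manual_seed(0)
    wh = torch.randn(6)   # unperturbed weights tracked by the module
    v = torch.randn(6)    # previous velocity
    g = torch.randn(6)    # current gradient
    lr, mu, rho = 0.1, 0.9, 0.3

    # Original MSAM formulation: step from the perturbed weights w.
    w = wh + rho * v / v.norm()
    v1 = mu * v + g
    w1 = w - lr * v1
    wh1_direct = w1 - rho * v1 / v1.norm()

    # Reformulation used above: the change of wh expressed through vn and v1n.
    vn = rho * v / v.norm()
    v1n = rho * v1 / v1.norm()
    wh1_reformulated = wh + vn - lr * v1 - v1n

    print(torch.allclose(wh1_direct, wh1_reformulated))  # True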
torchzero/modules/optimizers/muon.py
@@ -19,6 +19,7 @@ def _is_at_least_2d(p: torch.Tensor):
 
 # stolen from:
 # https://github.com/KellerJordan/Muon/blob/master/muon.py
+# actually at this stage its a frankenstein
 @enable_compilation
 def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
     """
@@ -152,7 +153,7 @@ class Orthogonalize(TensorwiseTransform):
     The Muon page says that embeddings and classifier heads should not be orthogonalized.
     Usually only matrix parameters that are directly used in matmuls should be orthogonalized.
 
-    To make Muon, use Split with Adam on 1d params
+    To make Muon, use Split with Adam on 1d params
 
     Args:
         ns_steps (int, optional):
@@ -164,7 +165,31 @@ class Orthogonalize(TensorwiseTransform):
         method (str, optional):
             Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
         target (str, optional):
-            what to set on
+            what to set on var.
+
+
+    Examples:
+        standard Muon with Adam fallback
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.head.parameters(),
+                tz.m.Split(
+                    # apply muon only to 2D+ parameters
+                    filter = lambda t: t.ndim >= 2,
+                    true = [
+                        tz.m.HeavyBall(),
+                        tz.m.Orthogonalize(),
+                        tz.m.LR(1e-2),
+                    ],
+                    false = tz.m.Adam()
+                ),
+                tz.m.LR(1e-2)
+            )
+
+    Reference:
+        Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
     """
     def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                  method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
@@ -172,9 +197,9 @@ class Orthogonalize(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
-            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(
+            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(setting)
 
         if not orthogonalize: return tensor
 
@@ -199,7 +224,7 @@ class DualNormCorrection(TensorwiseTransform):
     def __init__(self, target: Target='update'):
         super().__init__({}, uses_grad=True, target=target)
 
-    def
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         assert grad is not None
         if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
             return _dual_norm_correction(tensor, grad, batch_first=False)
@@ -213,8 +238,8 @@ class MuonAdjustLR(Transform):
         defaults = dict(alpha=alpha)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
-    def
-        alphas =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        alphas = [s['alpha'] for s in settings]
         tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
         tensors = [i[0] for i in tensors_alphas]
         a = [i[1] for i in alphas]
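zeropower_via_newtonschulz5, referenced in the first hunk above, is not itself shown in this diff. For context, the quintic Newton-Schulz iteration used by the upstream Muon repository that the comment cites looks approximately like the sketch below; coefficients and structure are taken from that repository, so treat this as an illustration rather than torchzero's exact code:

    import torch

    def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
        # Quintic Newton-Schulz iteration: pushes the singular values of G toward 1
        # without computing an SVD. The result is only approximately orthogonal,
        # which is good enough for Muon's purposes.
        a, b, c = 3.4445, -4.7750, 2.0315
        X = G / (G.norm() + 1e-7)   # spectral norm <= Frobenius norm, so this normalization is safe
        transposed = X.size(0) > X.size(1)
        if transposed:
            X = X.T                 # iterate on the wide orientation
        for _ in range(steps):
            A = X @ X.T
            X = a * X + (b * A + c * A @ A) @ X
        return X.T if transposed else X

    O = newton_schulz_orthogonalize(torch.randn(64, 32))
    print(torch.linalg.svdvals(O))  # all singular values roughly of order 1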
torchzero/modules/optimizers/orthograd.py
@@ -30,16 +30,15 @@ class OrthoGrad(Transform):
     Args:
         eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
         renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
-        target (Target, optional): what to set on
+        target (Target, optional): what to set on var. Defaults to 'update'.
     """
     def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
         defaults = dict(eps=eps, renormalize=renormalize)
         super().__init__(defaults, uses_grad=False, target=target)
 
-    def
-
-
-        renormalize = settings['renormalize']
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        eps = settings[0]['eps']
+        renormalize = settings[0]['renormalize']
 
         params = as_tensorlist(params)
         target = as_tensorlist(tensors)
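For reference, the OrthoGrad rule this module implements (per the docstring: project the update to be orthogonal to the parameters, then optionally graft it back onto the original gradient norm) can be written for a single tensor as below. This is a hedged sketch of the published rule, not torchzero's TensorList implementation:

    import torch

    def orthograd(param: torch.Tensor, grad: torch.Tensor,
                  eps: float = 1e-30, renormalize: bool = True) -> torch.Tensor:
        # Remove the component of the gradient that points along the parameter vector.
        w, g = param.reshape(-1), grad.reshape(-1)
        g_orth = g - ((w @ g) / (w @ w + eps)) * w
        # Optionally graft the projected gradient back onto the original gradient norm.
        if renormalize:
            g_orth = g_orth * (g.norm() / (g_orth.norm() + eps))
        return g_orth.view_as(grad)

    p, g = torch.randn(3, 4), torch.randn(3, 4)
    out = orthograd(p, g)
    print(torch.dot(out.reshape(-1), p.reshape(-1)).abs() < 1e-5)  # orthogonal to the parameters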
torchzero/modules/optimizers/rmsprop.py
@@ -3,8 +3,8 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Target, Transform, Chainable,
-from ...utils import NumberList, TensorList
+from ...core import Module, Target, Transform, Chainable, Var, apply_transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_
 
 
@@ -23,7 +23,6 @@ def rmsprop_(
     inner: Module | None = None,
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
-    vars: Vars | None = None,
 ):
     """returns `tensors_`"""
     if exp_avg_ is not None:
@@ -36,12 +35,14 @@ def rmsprop_(
 
     if inner is not None:
         assert params is not None
-        tensors_ = TensorList(
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
 
     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
 
 class RMSprop(Transform):
-    """Divides graient by EMA of gradient squares.
+    """Divides graient by EMA of gradient squares.
+
+    This implementation is identical to :code:`torch.optim.RMSprop`.
 
     Args:
         smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
@@ -51,7 +52,8 @@ class RMSprop(Transform):
         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
         pow (float, optional): power used in second momentum power and root. Defaults to 2.
         init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
-        inner (Chainable | None, optional):
+        inner (Chainable | None, optional):
+            Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
     """
     def __init__(
         self,
@@ -61,26 +63,25 @@ class RMSprop(Transform):
         debiased: bool = False,
         amsgrad: bool = False,
         pow: float = 2,
-        init: Literal["zeros", "update"] = "
+        init: Literal["zeros", "update"] = "zeros",
         inner: Chainable | None = None,
     ):
         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
         super().__init__(defaults=defaults, uses_grad=False)
-
+
         if inner is not None:
             self.set_child('inner', inner)
 
-    def
-        self.
-
-
-        centered,debiased,amsgrad,pow,init = itemgetter('centered','debiased','amsgrad','pow','init')(self.settings[params[0]])
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
+        centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
 
-        exp_avg_sq =
-        exp_avg =
-        max_exp_avg_sq =
+        exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
+        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList) if centered else None
+        max_exp_avg_sq = unpack_states(states, tensors, 'max_exp_avg_sq', cls=TensorList) if amsgrad else None
 
-        if init == 'update' and
+        if init == 'update' and step == 1:
             exp_avg_sq.set_([t**2 for t in tensors])
             if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])
 
@@ -90,7 +91,7 @@ class RMSprop(Transform):
             smoothing=smoothing,
             eps=eps,
             debiased=debiased,
-            step=
+            step=step,
             exp_avg_=exp_avg,
            max_exp_avg_sq_=max_exp_avg_sq,
             pow=pow,
@@ -99,5 +100,4 @@ class RMSprop(Transform):
             inner=self.children.get("inner", None),
             params=params,
             grads=grads,
-            vars=vars,
         )
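The refactored RMSprop above keeps the same arithmetic as torch.optim.RMSprop, which the new docstring line points out. As a reminder of what that arithmetic is, here is a single-tensor sketch of the plain and centered variants; the real module additionally handles the debiased, amsgrad, pow and init options:

    import torch

    def rmsprop_step(g, exp_avg_sq, exp_avg=None, smoothing=0.99, eps=1e-8):
        # EMA of squared gradients.
        exp_avg_sq.mul_(smoothing).addcmul_(g, g, value=1 - smoothing)
        denom = exp_avg_sq
        if exp_avg is not None:
            # "centered" variant: divide by an estimate of the variance instead of the raw second moment.
            exp_avg.mul_(smoothing).add_(g, alpha=1 - smoothing)
            denom = exp_avg_sq - exp_avg * exp_avg
        return g / (denom.sqrt() + eps)

    g = torch.randn(5)
    print(rmsprop_step(g, torch.zeros(5)))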