torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
```diff
--- a/torchzero/modules/momentum/matrix_momentum.py
+++ b/torchzero/modules/momentum/matrix_momentum.py
@@ -2,123 +2,192 @@ from typing import Literal
 
 import torch
 
-from ...core import Module,
+from ...core import Module, apply_transform, Chainable
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 
 class MatrixMomentum(Module):
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives; also, for some reason it has very good generalization on elastic net logistic regression.
+
+    .. note::
+        :code:`mu` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
+
+    .. note::
+        I have devised an adaptive version of this - :code:`tz.m.AdaptiveMatrixMomentum`, and it works well
+        without having to tune :code:`mu`.
+
+    .. note::
+        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
+        beta (float, optional): decay for the buffer; this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to Hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-    Evaluates hessian vector product on each step (via finite difference or autograd).
 
-
-
-
-
-
-
+    def __init__(
+        self,
+        mu=0.1,
+        beta: float = 1,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
-
-
-
-        prev_update = self.get_state('prev_update', params=vars.params, cls=TensorList)
-        hvp_mode = self.settings[vars.params[0]]['hvp_mode']
-        h = self.settings[vars.params[0]]['h']
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_update')
 
-
-
-
-
-
-
+    @torch.no_grad
+    def update(self, var):
+        assert var.closure is not None
+        prev_update = self.get_state(var.params, 'prev_update')
+        hvp_method = self.settings[var.params[0]]['hvp_method']
+        h = self.settings[var.params[0]]['h']
 
-
-
-        l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
-        if vars.loss_approx is None: vars.loss_approx = l
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
-
-
-        if vars.loss_approx is None: vars.loss_approx = l
+        if 'hvp_tfm' in self.children:
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
-
-        raise ValueError(hvp_mode)
+        self.store(var.params, "Hvp", Hvp)
 
-        if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
 
-
+    @torch.no_grad
+    def apply(self, var):
+        update = TensorList(var.get_update())
+        Hvp, prev_update = self.get_state(var.params, 'Hvp', 'prev_update', cls=TensorList)
+        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
 
-
-        update.add_(prev_update - hvp_*mu)
+        update.add_(prev_update - Hvp*mu)
         prev_update.set_(update * beta)
-
-        return
+        var.update = update
+        return var
 
 
 class AdaptiveMatrixMomentum(Module):
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives; also, for some reason it has very good generalization on elastic net logistic regression.
+
+    .. note::
+        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+
+    Args:
+        mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
+        beta (float, optional): decay for the buffer; this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to Hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-
-
-
-
+
+    def __init__(
+        self,
+        mu_mul: float = 1,
+        beta: float = 1,
+        eps=1e-4,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_params', 'prev_grad')
+
     @torch.no_grad
-    def
-    assert
-    prev_update, prev_params, prev_grad = self.get_state('prev_update', 'prev_params', 'prev_grad',
+    def update(self, var):
+        assert var.closure is not None
+        prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
 
-    settings = self.settings[
-
+        settings = self.settings[var.params[0]]
+        hvp_method = settings['hvp_method']
         h = settings['h']
         eps = settings['eps']
 
-        mu_mul
-
-        if hvp_mode == 'autograd':
-            with torch.enable_grad():
-                grad = vars.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_mode == 'forward':
-            vars.get_grad()
-            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        mu_mul = NumberList(self.settings[p]['mu_mul'] for p in var.params)
 
-
-
-            if vars.loss_approx is None: vars.loss_approx = l
-
-        else:
-            raise ValueError(hvp_mode)
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
         if 'hvp_tfm' in self.children:
-
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
         # adaptive part
-
-
-        s_k = vars.params - prev_params
-        prev_params.copy_(vars.params)
+        s_k = var.params - prev_params
+        prev_params.copy_(var.params)
 
-        assert
-
-        prev_grad
+        if hvp_method != 'central': assert var.grad is not None
+        grad = var.get_grad()
+        y_k = grad - prev_grad
+        prev_grad.copy_(grad)
 
         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
 
-
-
-
+        self.store(var.params, ['Hvp', 'ada_mu'], [Hvp, ada_mu])
+
+    @torch.no_grad
+    def apply(self, var):
+        Hvp, ada_mu = self.get_state(var.params, 'Hvp', 'ada_mu')
+        Hvp = as_tensorlist(Hvp)
+        beta = NumberList(self.settings[p]['beta'] for p in var.params)
+        update = TensorList(var.get_update())
+        prev_update = TensorList(self.state[p]['prev_update'] for p in var.params)
+
+        update.add_(prev_update - Hvp*ada_mu)
         prev_update.set_(update * beta)
-
-        return
+        var.update = update
+        return var
 
```
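The new `update`/`apply` pair computes the recurrence `update_k = grad_k + (I - mu*H) @ update_{k-1}`, with the buffer optionally decayed by `beta`. Below is a minimal self-contained sketch of that rule on a toy ill-conditioned quadratic, using plain autograd HVPs in place of torchzero's `Module`/`var` machinery; the objective, `lr`, `mu`, and step count are illustrative, not torchzero API.

```python
import torch

torch.manual_seed(0)
A = torch.diag(torch.tensor([1.0, 10.0, 100.0]))  # ill-conditioned quadratic
x = torch.randn(3, requires_grad=True)

lr, mu, beta = 1e-2, 5e-3, 1.0  # mu should stay below 1/largest eigenvalue (here 1/100)
prev_update = torch.zeros_like(x)

for step in range(1000):
    loss = 0.5 * x @ A @ x
    (grad,) = torch.autograd.grad(loss, x, create_graph=True)

    # exact Hessian-vector product H @ prev_update, as in hvp_method="autograd"
    (Hvp,) = torch.autograd.grad(grad, x, grad_outputs=prev_update)

    # core recurrence: update = grad + (I - mu*H) @ prev_update
    update = grad.detach() + prev_update - mu * Hvp
    prev_update = beta * update

    with torch.no_grad():
        x -= lr * update

print(f"final loss: {(0.5 * x @ A @ x).item():.3e}")
```

The adaptive variant replaces the hand-tuned `mu` with `mu_mul * ||s_k|| / (||y_k|| + eps)`, a curvature estimate built from consecutive parameter and gradient differences.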
```diff
--- a/torchzero/modules/momentum/momentum.py
+++ b/torchzero/modules/momentum/momentum.py
@@ -3,11 +3,22 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from .ema import EMA
 
 
 class HeavyBall(EMA):
+    """Polyak's momentum (heavy-ball method).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
         super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
 
@@ -30,14 +41,24 @@ def nag_(
 
 
 class NAG(Transform):
+    """Nesterov accelerated gradient method (nesterov momentum).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
         defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        velocity =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']
 
-        momentum,dampening =
+        momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
         return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
+
```
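For reference, the recurrences behind `HeavyBall` and `NAG` look roughly like this in plain PyTorch. `nag_`'s exact formulation lives earlier in this file, so treat this as the textbook version (the one `torch.optim.SGD` uses), not torchzero's exact code:

```python
import torch

def heavy_ball_step(update: torch.Tensor, velocity: torch.Tensor,
                    momentum: float = 0.9, dampening: float = 0.0) -> torch.Tensor:
    # v <- momentum * v + (1 - dampening) * u; the buffer itself is the output
    velocity.mul_(momentum).add_(update, alpha=1 - dampening)
    return velocity.clone()

def nag_step(update: torch.Tensor, velocity: torch.Tensor,
             momentum: float = 0.9, dampening: float = 0.0) -> torch.Tensor:
    # same buffer update, but the output looks one momentum step ahead
    velocity.mul_(momentum).add_(update, alpha=1 - dampening)
    return update + momentum * velocity
```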
```diff
--- a/torchzero/modules/ops/__init__.py
+++ b/torchzero/modules/ops/__init__.py
@@ -7,7 +7,7 @@ from .accumulate import (
 )
 from .binary import (
     Add,
-
+    BinaryOperationBase,
     Clip,
     CopyMagnitude,
     CopySign,
@@ -27,37 +27,12 @@ from .binary import (
     Sub,
     Threshold,
 )
-from .debug import PrintShape, PrintUpdate
-from .misc import (
-    DivByLoss,
-    Dropout,
-    FillLoss,
-    GradientAccumulation,
-    GradSign,
-    GraftGradToUpdate,
-    GraftToGrad,
-    GraftToParams,
-    LastAbsoluteRatio,
-    LastDifference,
-    LastGradDifference,
-    LastProduct,
-    LastRatio,
-    MulByLoss,
-    Multistep,
-    NegateOnLossIncrease,
-    NoiseSign,
-    Previous,
-    Relative,
-    Sequential,
-    UpdateSign,
-    WeightDropout,
-)
 from .multi import (
     ClipModules,
     DivModules,
     GraftModules,
     LerpModules,
-
+    MultiOperationBase,
     PowModules,
     SubModules,
 )
@@ -66,13 +41,11 @@ from .reduce import (
     Mean,
     MinimumModules,
     Prod,
-
+    ReduceOperationBase,
     Sum,
     WeightedMean,
     WeightedSum,
 )
-from .split import Split
-from .switch import Alternate, Switch
 from .unary import (
     Abs,
     CustomUnaryOperation,
@@ -97,7 +70,6 @@ from .utility import (
     Randn,
     RandomSample,
     Uniform,
-    Update,
     UpdateToNone,
     Zeros,
 )
```
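The removed `misc`, `debug`, `split`, and `switch` imports correspond to the file moves listed at the top of this diff (`ops/switch.py → misc/switch.py`, plus the new `misc/misc.py`, `misc/debug.py`, and `misc/split.py`). Assuming the new `torchzero.modules.misc` package re-exports these names from its `__init__.py` (which gains 27 lines in this release), migrating code would look like:

```python
# Hypothetical migration sketch -- assumes torchzero.modules.misc re-exports
# the classes that previously lived under torchzero.modules.ops.
# Pre-0.3.11 location (removed in this diff):
#   from torchzero.modules.ops import Dropout, Split, Alternate, Switch
from torchzero.modules.misc import Dropout, Split, Alternate, Switch
```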
```diff
--- a/torchzero/modules/ops/accumulate.py
+++ b/torchzero/modules/ops/accumulate.py
@@ -1,65 +1,91 @@
-from collections import deque
-from operator import itemgetter
-from typing import Literal
-
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList,
+from ...utils import TensorList, unpack_states
 
 class AccumulateSum(Transform):
+    """Accumulates sum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        sum =
-        decay =
-        return sum.add_(tensors).lazy_mul(
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return sum.add_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMean(Transform):
+    """Accumulates mean of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-        mean =
-        decay =
-        return mean.add_(tensors).lazy_mul(
+        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)
 
 class AccumulateProduct(Transform):
+    """Accumulates product of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        prod =
-        decay =
-        return prod.mul_(tensors).lazy_mul(
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return prod.mul_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMaximum(Transform):
+    """Accumulates maximum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        maximum =
-        decay =
-        return maximum.maximum_(tensors).lazy_mul(
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMinimum(Transform):
+    """Accumulates minimum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        minimum =
-        decay =
-        return minimum.minimum_(tensors).lazy_mul(
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
 
```
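As the new `apply_tensors` bodies read, each accumulator is updated in place and the returned tensors are a scaled copy: `decay` enters as a factor of `1 - decay` applied by `lazy_mul` (assumed here to skip the multiply when the factor is 1 and, with `clone=True`, to leave the stored buffer untouched). A tiny numeric trace of `AccumulateSum` under those assumptions:

```python
import torch

acc = torch.zeros(1)  # the stored 'sum' state
decay = 0.1
for step, update in enumerate([torch.ones(1)] * 3, start=1):
    acc += update            # in-place, like sum.add_(tensors)
    out = acc * (1 - decay)  # like .lazy_mul(decay_list, clone=True)
    print(step, acc.item(), out.item())
# 1 1.0 0.9
# 2 2.0 1.8
# 3 3.0 2.7
```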