torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
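A pattern that recurs throughout the module diffs below is that per-parameter logic moves into an `apply(self, tensors, params, grads, loss, states, settings)` hook, where `settings` is a per-parameter list of setting dicts and `states` a matching list of state dicts. For orientation only, a minimal hypothetical subclass in that style might look like the sketch below; the class name `Mul`, the `alpha` setting, and the top-level import paths are assumptions inferred from the relative imports in the hunks, not part of this diff.

import torch
from torchzero.core import Transform      # assumed re-export, mirroring `from ...core import Transform`
from torchzero.utils import TensorList    # assumed re-export, mirroring `from ...utils import TensorList`

class Mul(Transform):
    """Hypothetical transform: scales the incoming update by a constant 'alpha' setting."""
    def __init__(self, alpha: float = 0.1):
        defaults = dict(alpha=alpha)                 # stored as per-parameter defaults, as in the diffs below
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        # settings is a list of per-parameter dicts; settings[0] mirrors usage like settings[0]['eps'] below
        alpha = settings[0]['alpha']
        # scale the incoming update tensors in place and return them (assumes TensorList.mul_ accepts a scalar)
        return TensorList(tensors).mul_(alpha)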
torchzero/modules/momentum/cautious.py +73 -35

@@ -1,3 +1,4 @@
+"""Cautioning related modules"""
 from collections import deque
 from operator import itemgetter
 from typing import Literal
@@ -5,7 +6,7 @@ from typing import Literal
 import torch

 from ...core import Target, Transform, Module, Chainable
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts


 def cautious_(
@@ -64,27 +65,33 @@ class Cautious(Transform):
         normalize=False,
         eps=1e-6,
         mode: Literal["zero", "grad", "backtrack"] = "zero",
-        target: Target = "update",
     ):
         defaults = dict(normalize=normalize, eps=eps, mode=mode)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
         return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)

 class UpdateGradientSignConsistency(Transform):
-    """
-
+    """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.
+
+    Args:
+        normalize (bool, optional):
+            renormalize update after masking. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+    """
+    def __init__(self, normalize = False, eps=1e-6):
+
         defaults = dict(normalize=normalize, eps=eps)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        normalize, eps = itemgetter('normalize', 'eps')(
+        normalize, eps = itemgetter('normalize', 'eps')(settings[0])

         mask = (TensorList(tensors).mul_(grads)).gt_(0)
         if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]
@@ -92,6 +99,23 @@ class UpdateGradientSignConsistency(Transform):
         return mask

 class IntermoduleCautious(Module):
+    """Negaties update on :code:`main` module where it's sign doesn't match with output of :code:`compare` module.
+
+    Args:
+        main (Chainable): main module or sequence of modules whose update will be cautioned.
+        compare (Chainable): modules or sequence of modules to compare the sign to.
+        normalize (bool, optional):
+            renormalize update after masking. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+        mode (str, optional):
+            what to do with updates with inconsistent signs.
+
+            "zero" - set them to zero (as in paper)
+
+            "grad" - set them to the gradient
+
+            "backtrack" - negate them (same as using update magnitude and gradient sign)
+    """
     def __init__(
         self,
         main: Chainable,
@@ -100,6 +124,7 @@ class IntermoduleCautious(Module):
         eps=1e-6,
         mode: Literal["zero", "grad", "backtrack"] = "zero",
     ):
+
         defaults = dict(normalize=normalize, eps=eps, mode=mode)
         super().__init__(defaults)

@@ -107,40 +132,45 @@ class IntermoduleCautious(Module):
         self.set_child('compare', compare)

     @torch.no_grad
-    def step(self,
+    def step(self, var):
         main = self.children['main']
         compare = self.children['compare']

-
-
+        main_var = main.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(main_var)

-
-
+        compare_var = compare.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(compare_var)

-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[
-
-        TensorList(
-        TensorList(
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[var.params[0]])
+        var.update = cautious_(
+            TensorList(main_var.get_update()),
+            TensorList(compare_var.get_update()),
             normalize=normalize,
             mode=mode,
             eps=eps,
         )

-        return
+        return var

 class ScaleByGradCosineSimilarity(Transform):
+    """Multiplies the update by cosine similarity with gradient.
+    If cosine similarity is negative, naturally the update will be negated as well.
+
+    Args:
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+    """
     def __init__(
         self,
-        eps=1e-6,
-        target: Target = "update",
+        eps: float = 1e-6,
     ):
         defaults = dict(eps=eps)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        eps =
+        eps = settings[0]['eps']
         tensors = TensorList(tensors)
         grads = TensorList(grads)
         cos_sim = (tensors.dot(grads)) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
@@ -148,6 +178,14 @@ class ScaleByGradCosineSimilarity(Transform):
         return tensors.mul_(cos_sim)

 class ScaleModulesByCosineSimilarity(Module):
+    """Scales the output of :code:`main` module by it's cosine similarity to the output
+    of :code:`compare` module.
+
+    Args:
+        main (Chainable): main module or sequence of modules whose update will be scaled.
+        compare (Chainable): module or sequence of modules to compare to
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+    """
     def __init__(
         self,
         main: Chainable,
@@ -161,21 +199,21 @@ class ScaleModulesByCosineSimilarity(Module):
         self.set_child('compare', compare)

     @torch.no_grad
-    def step(self,
+    def step(self, var):
         main = self.children['main']
         compare = self.children['compare']

-
-
+        main_var = main.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(main_var)

-
-
+        compare_var = compare.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(compare_var)

-        m = TensorList(
-        c = TensorList(
-        eps = self.settings[
+        m = TensorList(main_var.get_update())
+        c = TensorList(compare_var.get_update())
+        eps = self.settings[var.params[0]]['eps']

         cos_sim = (m.dot(c)) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

-
-        return
+        var.update = m.mul_(cos_sim)
+        return var
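For reference, the cautioning rule applied by `Cautious` / `UpdateGradientSignConsistency` above can be sketched on a single plain tensor as follows (the "zero" mode only). This paraphrases the masking code visible in the hunk above rather than torchzero's API: the helper name `caution_zero_` is invented here, and the library normalizes by the mean of the mask across all parameters, not per tensor.

import torch

def caution_zero_(update: torch.Tensor, grad: torch.Tensor,
                  normalize: bool = False, eps: float = 1e-6) -> torch.Tensor:
    # keep only components whose sign agrees with the gradient ("zero" mode)
    mask = ((update * grad) > 0).to(update.dtype)
    if normalize:
        # rescale so the masked update keeps roughly the same average magnitude
        mask = mask / mask.mean().clamp(min=eps)
    return update.mul_(mask)

u = torch.tensor([0.5, -0.2, 0.3])
g = torch.tensor([1.0, 0.4, -0.1])
print(caution_zero_(u, g, normalize=True))   # second and third components are zeroed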
torchzero/modules/momentum/ema.py +92 -41

@@ -5,18 +5,19 @@ from typing import Literal
 import torch

 from ...core import Target, Transform
-from ...utils import TensorList, NumberList
+from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
 from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum


 class EMA(Transform):
-    """Maintains
+    """Maintains an exponential moving average of update.

     Args:
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         dampening (float, optional): momentum dampening. Defaults to 0.
         debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
         lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
@@ -24,13 +25,14 @@ class EMA(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(
+        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])

-        exp_avg =
-
+        exp_avg = unpack_states(states, tensors, 'exp_avg',
+            init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)

         exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)

@@ -39,44 +41,58 @@ class EMA(Transform):


 class EMASquared(Transform):
+    """Maintains an exponential moving average of squared updates.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
     EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)

-    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2
+    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
-        beta =
+        beta = NumberList(s['beta'] for s in settings)

         if amsgrad:
-            exp_avg_sq, max_exp_avg_sq =
+            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg_sq =
+            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None

         return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()

 class SqrtEMASquared(Transform):
-
+    """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.

-
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)


     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(
-        beta =
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)

         if amsgrad:
-            exp_avg_sq, max_exp_avg_sq =
+            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg_sq =
+            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None

         return self.SQRT_EMA_SQ_FN(
@@ -91,47 +107,73 @@ class SqrtEMASquared(Transform):


 class Debias(Transform):
+    """Multiplies the update by an Adam debiasing term based first and/or second momentum.
+
+    Args:
+        beta1 (float | None, optional):
+            first momentum, should be the same as first momentum used in modules before. Defaults to None.
+        beta2 (float | None, optional):
+            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
+        alpha (float, optional): learning rate. Defaults to 1.
+        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
         defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-
-
-        alpha, beta1, beta2 = self.get_settings('alpha', 'beta1', 'beta2', params=params, cls=NumberList)
+        pow = settings[0]['pow']
+        alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)

         return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)

 class Debias2(Transform):
+    """Multiplies the update by an Adam debiasing term based on the second momentum.
+
+    Args:
+        beta (float | None, optional):
+            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
+        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
         defaults = dict(beta=beta, pow=pow)
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-        pow =
-        beta =
+        pow = settings[0]['pow']
+        beta = NumberList(s['beta'] for s in settings)
         return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)

 class CenteredEMASquared(Transform):
-
+    """Maintains a centered exponential moving average of squared updates. This also maintains an additional
+    exponential moving average of un-squared updates, square of which is subtracted from the EMA.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
-        amsgrad, pow = itemgetter('amsgrad', 'pow')(
-        beta =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)

         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None

         return centered_ema_sq_(
@@ -144,21 +186,30 @@ class CenteredEMASquared(Transform):
         ).clone()

 class CenteredSqrtEMASquared(Transform):
-
+    """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
+    This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(
-        beta =
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)

         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None

         return sqrt_centered_ema_sq_(
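As a plain-tensor reference for what `SqrtEMASquared` tracks above, the conventional Adam-style second-moment recursion with optional amsgrad and debiasing looks roughly like the sketch below. This is the standard textbook formula, not the `ema_sq_` / `sqrt_ema_sq_` helpers themselves, whose bodies are not shown in this diff; the function name is invented here.

import torch

def sqrt_ema_sq_step(u, exp_avg_sq, max_exp_avg_sq=None, beta=0.999,
                     debiased=False, step=1):
    # EMA of squared updates: v <- beta * v + (1 - beta) * u^2
    exp_avg_sq.mul_(beta).addcmul_(u, u, value=1 - beta)
    v = exp_avg_sq
    if max_exp_avg_sq is not None:              # amsgrad: use the running maximum instead
        torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        v = max_exp_avg_sq
    out = v.sqrt()
    if debiased:                                # Adam bias correction for the second moment
        out = out / (1 - beta ** step) ** 0.5
    return out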
torchzero/modules/momentum/experimental.py +21 -13

@@ -6,7 +6,7 @@ from typing import Literal
 import torch

 from ...core import Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
 from ..functional import ema_, ema_sq_, sqrt_ema_sq_
 from .ema import EMASquared, SqrtEMASquared
 from .momentum import nag_
@@ -43,22 +43,22 @@ def precentered_ema_sq_(
     return exp_avg_sq_

 class PrecenteredEMASquared(Transform):
+    """Maintains un-squared EMA, the updates are centered by it before being fed into squared EMA."""
     def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
         defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
         super().__init__(defaults, uses_grad=False, target=target)
-        self.current_step = 0

     @torch.no_grad
-    def
-        self.
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1

-        beta1, beta2 =
-        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(
+        beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
+        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(settings[0])

         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None

         return precentered_ema_sq_(
@@ -67,7 +67,7 @@ class PrecenteredEMASquared(Transform):
             exp_avg_sq_=exp_avg_sq,
             beta1=beta1,
             beta2=beta2,
-            step =
+            step = step,
             min_step=min_step,
             pow=pow,
             max_exp_avg_sq_=max_exp_avg_sq,
@@ -119,9 +119,11 @@ def sqrt_nag_ema_sq_(
         pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))

 class NesterovEMASquared(EMASquared):
+    """squared momentum with nesterov momentum rule"""
     EMA_SQ_FN = staticmethod(nag_ema_sq_)

 class SqrtNesterovEMASquared(SqrtEMASquared):
+    """square root of squared momentum with nesterov momentum rule"""
     SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)


@@ -141,14 +143,20 @@ def coordinate_momentum_(


 class CoordinateMomentum(Transform):
+    """Maintains a momentum buffer, on each step each value in the buffer has :code:`p` chance to be updated with the new value.
+
+    Args:
+        p (float, optional): _description_. Defaults to 0.1.
+        target (Target, optional): _description_. Defaults to 'update'.
+    """
     def __init__(self, p: float = 0.1, target: Target = 'update'):
         defaults = dict(p=p)
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def
-        p =
-        velocity =
+    def apply(self, tensors, params, grads, loss, states, settings):
+        p = NumberList(s['p'] for s in settings)
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()


@@ -180,7 +188,7 @@ class CoordinateMomentum(Transform):
     # super().__init__(defaults, uses_grad=False)

     # @torch.no_grad
-    # def
+    # def apply(self, tensors, params, grads, loss, states, settings):
     # momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
     # abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
     # velocity = self.get_state('velocity', params=params, cls=TensorList)