torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +115 -68
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +62 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +185 -53
- torchzero/core/transform.py +327 -159
- torchzero/modules/__init__.py +3 -1
- torchzero/modules/clipping/clipping.py +120 -23
- torchzero/modules/clipping/ema_clipping.py +37 -22
- torchzero/modules/clipping/growth_clipping.py +20 -21
- torchzero/modules/experimental/__init__.py +30 -4
- torchzero/modules/experimental/absoap.py +53 -156
- torchzero/modules/experimental/adadam.py +22 -15
- torchzero/modules/experimental/adamY.py +21 -25
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
- torchzero/modules/experimental/adasoap.py +24 -129
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +120 -0
- torchzero/modules/experimental/etf.py +195 -0
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +92 -0
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +10 -7
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +20 -10
- torchzero/modules/experimental/tensor_adagrad.py +42 -0
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +31 -4
- torchzero/modules/grad_approximation/forward_gradient.py +17 -7
- torchzero/modules/grad_approximation/grad_approximator.py +69 -24
- torchzero/modules/grad_approximation/rfdm.py +310 -50
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +319 -0
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +75 -31
- torchzero/modules/line_search/line_search.py +107 -49
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +20 -5
- torchzero/modules/line_search/strong_wolfe.py +52 -36
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/misc/debug.py +48 -0
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/misc/split.py +103 -0
- torchzero/modules/{ops → misc}/switch.py +48 -7
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +115 -40
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +145 -76
- torchzero/modules/momentum/momentum.py +25 -4
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +51 -25
- torchzero/modules/ops/binary.py +108 -62
- torchzero/modules/ops/multi.py +95 -34
- torchzero/modules/ops/reduce.py +31 -23
- torchzero/modules/ops/unary.py +37 -21
- torchzero/modules/ops/utility.py +53 -45
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +48 -29
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +35 -37
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/optimizers/ladagrad.py +183 -0
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +32 -7
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +19 -19
- torchzero/modules/optimizers/rprop.py +89 -52
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +55 -27
- torchzero/modules/optimizers/soap.py +40 -37
- torchzero/modules/optimizers/sophia_h.py +82 -25
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +4 -2
- torchzero/modules/projections/projection.py +212 -118
- torchzero/modules/quasi_newton/__init__.py +44 -5
- torchzero/modules/quasi_newton/cg.py +190 -39
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +102 -58
- torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +245 -54
- torchzero/modules/second_order/newton_cg.py +311 -21
- torchzero/modules/second_order/nystrom.py +124 -21
- torchzero/modules/smoothing/gaussian.py +55 -21
- torchzero/modules/smoothing/laplacian.py +20 -12
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +126 -10
- torchzero/modules/wrappers/optim_wrapper.py +40 -12
- torchzero/optim/wrappers/directsearch.py +281 -0
- torchzero/optim/wrappers/fcmaes.py +105 -0
- torchzero/optim/wrappers/mads.py +89 -0
- torchzero/optim/wrappers/nevergrad.py +20 -5
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +167 -16
- torchzero/utils/__init__.py +3 -7
- torchzero/utils/derivatives.py +5 -4
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +27 -4
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
- torchzero-0.3.11.dist-info/RECORD +159 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/soapy.py +0 -290
- torchzero/modules/experimental/spectral.py +0 -288
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/lr.py +0 -59
- torchzero/modules/lr/step_size.py +0 -97
- torchzero/modules/ops/debug.py +0 -25
- torchzero/modules/ops/misc.py +0 -419
- torchzero/modules/ops/split.py +0 -75
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0

torchzero/modules/momentum/cautious.py

@@ -1,3 +1,4 @@
+"""Cautioning related modules"""
 from collections import deque
 from operator import itemgetter
 from typing import Literal
@@ -5,7 +6,7 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform, Module, Chainable
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts
 
 
 def cautious_(
@@ -54,9 +55,20 @@ class Cautious(Transform):
 
             "backtrack" - negate them (same as using update magnitude and gradient sign)
 
-
-
-
+    Examples:
+        Cautious Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.Adam(),
+                tz.m.Cautious(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
     """
 
     def __init__(
@@ -64,27 +76,33 @@ class Cautious(Transform):
         normalize=False,
         eps=1e-6,
         mode: Literal["zero", "grad", "backtrack"] = "zero",
-        target: Target = "update",
     ):
         defaults = dict(normalize=normalize, eps=eps, mode=mode)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
         return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
 
 class UpdateGradientSignConsistency(Transform):
-    """
-
+    """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.
+
+    Args:
+        normalize (bool, optional):
+            renormalize update after masking. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+    """
+    def __init__(self, normalize = False, eps=1e-6):
+
         defaults = dict(normalize=normalize, eps=eps)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        normalize, eps = itemgetter('normalize', 'eps')(
+        normalize, eps = itemgetter('normalize', 'eps')(settings[0])
 
         mask = (TensorList(tensors).mul_(grads)).gt_(0)
         if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]
@@ -92,6 +110,23 @@ class UpdateGradientSignConsistency(Transform):
         return mask
 
 class IntermoduleCautious(Module):
+    """Negaties update on :code:`main` module where it's sign doesn't match with output of :code:`compare` module.
+
+    Args:
+        main (Chainable): main module or sequence of modules whose update will be cautioned.
+        compare (Chainable): modules or sequence of modules to compare the sign to.
+        normalize (bool, optional):
+            renormalize update after masking. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+        mode (str, optional):
+            what to do with updates with inconsistent signs.
+
+            "zero" - set them to zero (as in paper)
+
+            "grad" - set them to the gradient
+
+            "backtrack" - negate them (same as using update magnitude and gradient sign)
+    """
     def __init__(
         self,
         main: Chainable,
@@ -100,6 +135,7 @@ class IntermoduleCautious(Module):
         eps=1e-6,
         mode: Literal["zero", "grad", "backtrack"] = "zero",
     ):
+
         defaults = dict(normalize=normalize, eps=eps, mode=mode)
         super().__init__(defaults)
 
@@ -107,47 +143,86 @@ class IntermoduleCautious(Module):
         self.set_child('compare', compare)
 
     @torch.no_grad
-    def step(self,
+    def step(self, var):
         main = self.children['main']
         compare = self.children['compare']
 
-
-
+        main_var = main.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(main_var)
 
-
-
+        compare_var = compare.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(compare_var)
 
-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[
-
-            TensorList(
-            TensorList(
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[var.params[0]])
+        var.update = cautious_(
+            TensorList(main_var.get_update()),
+            TensorList(compare_var.get_update()),
             normalize=normalize,
             mode=mode,
             eps=eps,
         )
 
-        return
+        return var
 
 class ScaleByGradCosineSimilarity(Transform):
+    """Multiplies the update by cosine similarity with gradient.
+    If cosine similarity is negative, naturally the update will be negated as well.
+
+    Args:
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    Examples:
+        Scaled Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.Adam(),
+                tz.m.ScaleByGradCosineSimilarity(),
+                tz.m.LR(1e-2)
+            )
+    """
     def __init__(
         self,
-        eps=1e-6,
-        target: Target = "update",
+        eps: float = 1e-6,
     ):
         defaults = dict(eps=eps)
-        super().__init__(defaults, uses_grad=True
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        eps =
+        eps = settings[0]['eps']
         tensors = TensorList(tensors)
         grads = TensorList(grads)
-        cos_sim =
+        cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
 
         return tensors.mul_(cos_sim)
 
 class ScaleModulesByCosineSimilarity(Module):
+    """Scales the output of :code:`main` module by it's cosine similarity to the output
+    of :code:`compare` module.
+
+    Args:
+        main (Chainable): main module or sequence of modules whose update will be scaled.
+        compare (Chainable): module or sequence of modules to compare to
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    Example:
+        Adam scaled by similarity to RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.ScaleModulesByCosineSimilarity(
+                    main = tz.m.Adam(),
+                    compare = tz.m.RMSprop(0.999, debiased=True),
+                ),
+                tz.m.LR(1e-2)
+            )
+    """
     def __init__(
         self,
         main: Chainable,
@@ -161,21 +236,21 @@ class ScaleModulesByCosineSimilarity(Module):
         self.set_child('compare', compare)
 
     @torch.no_grad
-    def step(self,
+    def step(self, var):
         main = self.children['main']
         compare = self.children['compare']
 
-
-
+        main_var = main.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(main_var)
 
-
-
+        compare_var = compare.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(compare_var)
 
-        m = TensorList(
-        c = TensorList(
-        eps = self.settings[
+        m = TensorList(main_var.get_update())
+        c = TensorList(compare_var.get_update())
+        eps = self.settings[var.params[0]]['eps']
 
-        cos_sim =
+        cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
 
-
-        return
+        var.update = m.mul_(cos_sim)
+        return var
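The cautioning rule documented in the new `Cautious` docstring reduces to a sign-agreement mask between the update and the gradient. Below is a minimal sketch of that rule on a single plain tensor; the helper name `cautious_sketch` is illustrative only and is not part of the torchzero API, which operates on `TensorList`s instead.

```python
import torch

def cautious_sketch(update, grad, normalize=False, eps=1e-6, mode="zero"):
    # 1 where the update and the gradient agree in sign, 0 elsewhere
    mask = (update * grad) > 0
    if mode == "grad":        # replace disagreeing entries with the gradient
        return torch.where(mask, update, grad)
    if mode == "backtrack":   # negate disagreeing entries
        return torch.where(mask, update, -update)
    # "zero" mode, as in the Cautious Optimizers paper: drop disagreeing entries,
    # optionally rescaling so the kept entries compensate for the zeroed ones
    fmask = mask.float()
    if normalize:
        fmask = fmask / fmask.mean().clamp(min=eps)
    return update * fmask
```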
torchzero/modules/momentum/ema.py

@@ -5,18 +5,19 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList, NumberList
+from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
 from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum
 
 
 class EMA(Transform):
-    """Maintains
+    """Maintains an exponential moving average of update.
 
     Args:
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         dampening (float, optional): momentum dampening. Defaults to 0.
         debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
         lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
@@ -24,13 +25,14 @@ class EMA(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(
+        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
 
-        exp_avg =
-
+        exp_avg = unpack_states(states, tensors, 'exp_avg',
+            init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
 
         exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
 
@@ -39,44 +41,58 @@ class EMA(Transform):
 
 
 class EMASquared(Transform):
+    """Maintains an exponential moving average of squared updates.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
     EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)
 
-    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2
+    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
-        beta =
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg_sq, max_exp_avg_sq =
+            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg_sq =
+            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()
 
 class SqrtEMASquared(Transform):
-
+    """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.
 
-
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)
 
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(
-        beta =
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg_sq, max_exp_avg_sq =
+            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg_sq =
+            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return self.SQRT_EMA_SQ_FN(
@@ -91,47 +107,73 @@ class SqrtEMASquared(Transform):
 
 
 class Debias(Transform):
+    """Multiplies the update by an Adam debiasing term based first and/or second momentum.
+
+    Args:
+        beta1 (float | None, optional):
+            first momentum, should be the same as first momentum used in modules before. Defaults to None.
+        beta2 (float | None, optional):
+            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
+        alpha (float, optional): learning rate. Defaults to 1.
+        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
         defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-
-
-        alpha, beta1, beta2 = self.get_settings('alpha', 'beta1', 'beta2', params=params, cls=NumberList)
+        pow = settings[0]['pow']
+        alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)
 
         return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
 
 class Debias2(Transform):
+    """Multiplies the update by an Adam debiasing term based on the second momentum.
+
+    Args:
+        beta (float | None, optional):
+            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
+        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
         defaults = dict(beta=beta, pow=pow)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        pow =
-        beta =
+        pow = settings[0]['pow']
+        beta = NumberList(s['beta'] for s in settings)
         return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)
 
 class CenteredEMASquared(Transform):
-
+    """Maintains a centered exponential moving average of squared updates. This also maintains an additional
+    exponential moving average of un-squared updates, square of which is subtracted from the EMA.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
-        amsgrad, pow = itemgetter('amsgrad', 'pow')(
-        beta =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return centered_ema_sq_(
@@ -144,21 +186,30 @@ class CenteredEMASquared(Transform):
         ).clone()
 
 class CenteredSqrtEMASquared(Transform):
-
+    """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
+    This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
-        super().__init__(defaults, uses_grad=False
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(
-        beta =
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return sqrt_centered_ema_sq_(
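The EMA and Debias modules above all revolve around the same Adam-style bias-corrected moving average. A minimal single-tensor sketch of that computation is shown below; the generator name and the plain-`torch` form are this document's own illustration, not the package's `TensorList`-based implementation.

```python
import torch

def debiased_ema_sketch(updates, beta=0.9):
    # bias-corrected exponential moving average over a stream of update tensors
    exp_avg = torch.zeros_like(updates[0])
    for step, u in enumerate(updates, start=1):
        exp_avg.lerp_(u, 1 - beta)           # exp_avg = beta*exp_avg + (1-beta)*u
        yield exp_avg / (1 - beta ** step)   # divide out the zero-initialization bias
```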
torchzero/modules/momentum/experimental.py

@@ -6,7 +6,7 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
 from ..functional import ema_, ema_sq_, sqrt_ema_sq_
 from .ema import EMASquared, SqrtEMASquared
 from .momentum import nag_
@@ -43,22 +43,22 @@ def precentered_ema_sq_(
     return exp_avg_sq_
 
 class PrecenteredEMASquared(Transform):
+    """Maintains un-squared EMA, the updates are centered by it before being fed into squared EMA."""
     def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
         defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
         super().__init__(defaults, uses_grad=False, target=target)
-        self.current_step = 0
 
     @torch.no_grad
-    def
-        self.
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        beta1, beta2 =
-        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(
+        beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
+        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(settings[0])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq =
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq =
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return precentered_ema_sq_(
@@ -67,7 +67,7 @@ class PrecenteredEMASquared(Transform):
             exp_avg_sq_=exp_avg_sq,
             beta1=beta1,
             beta2=beta2,
-            step =
+            step = step,
             min_step=min_step,
             pow=pow,
             max_exp_avg_sq_=max_exp_avg_sq,
@@ -119,9 +119,11 @@ def sqrt_nag_ema_sq_(
         pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))
 
 class NesterovEMASquared(EMASquared):
+    """squared momentum with nesterov momentum rule"""
     EMA_SQ_FN = staticmethod(nag_ema_sq_)
 
 class SqrtNesterovEMASquared(SqrtEMASquared):
+    """square root of squared momentum with nesterov momentum rule"""
     SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)
 
 
@@ -141,14 +143,20 @@ def coordinate_momentum_(
 
 
 class CoordinateMomentum(Transform):
+    """Maintains a momentum buffer, on each step each value in the buffer has :code:`p` chance to be updated with the new value.
+
+    Args:
+        p (float, optional): _description_. Defaults to 0.1.
+        target (Target, optional): _description_. Defaults to 'update'.
+    """
     def __init__(self, p: float = 0.1, target: Target = 'update'):
         defaults = dict(p=p)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
-        p =
-        velocity =
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        p = NumberList(s['p'] for s in settings)
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
 
 
@@ -180,7 +188,7 @@ class CoordinateMomentum(Transform):
     # super().__init__(defaults, uses_grad=False)
 
     # @torch.no_grad
-    # def
+    # def apply(self, tensors, params, grads, loss, states, settings):
     # momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
     # abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
     # velocity = self.get_state('velocity', params=params, cls=TensorList)
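The new `CoordinateMomentum` docstring states that on each step every entry of the momentum buffer has probability `p` of being overwritten by the incoming update. A rough single-tensor sketch of that rule follows; the function name is illustrative and not part of torchzero.

```python
import torch

def coordinate_momentum_sketch(velocity, update, p=0.1):
    # with probability p an entry takes the new update value,
    # otherwise it keeps its previous value
    replace = torch.rand_like(update) < p
    velocity.copy_(torch.where(replace, update, velocity))
    return velocity
```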