torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,28 @@ from ...core import Chainable, Module
 
 
 class Alternate(Module):
-    """
+    """Alternates between stepping with :code:`modules`.
+
+    That is, first step is performed with 1st module, second step with second module, etc.
+
+    Args:
+        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
+
+    Examples:
+        Alternate between Adam, SignSGD and RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Alternate(
+                    tz.m.Adam(),
+                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
+                    tz.m.RMSprop(),
+                ),
+                tz.m.LR(1e-3),
+            )
+    """
     LOOP = True
     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
         if isinstance(steps, Iterable):
@@ -54,14 +75,34 @@ class Alternate(Module):
         return var
 
 class Switch(Alternate):
-    """
+    """After :code:`steps` steps switches to the next module.
+
+    Args:
+        steps (int | Iterable[int]): Number of steps to perform with each module.
+
+    Examples:
+        Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Switch(
+                    [tz.m.Adam(), tz.m.LR(1e-3)],
+                    [tz.m.LBFGS(), tz.m.Backtracking()],
+                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+                    steps = (1000, 2000)
+                )
+            )
+    """
+
     LOOP = False
     def __init__(self, *modules: Chainable, steps: int | Iterable[int]):
 
         if isinstance(steps, Iterable):
             steps = list(steps)
             if len(steps) != len(modules) - 1:
-                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")
 
             steps.append(1)
 
@@ -11,4 +11,4 @@ from .experimental import CoordinateMomentum
 # from .matrix_momentum import MatrixMomentum
 
 from .momentum import NAG, HeavyBall
-from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
@@ -21,8 +21,8 @@ class Averaging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        history_size =
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
             state['average'] = torch.zeros_like(tensor)
@@ -46,8 +46,8 @@ class WeightedAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        weights =
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        weights = setting['weights']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=len(weights))
@@ -80,8 +80,8 @@ class MedianAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
-        history_size =
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
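
For orientation (this sketch is not part of the diff): the three hunks above change the per-tensor hook of TensorwiseTransform subclasses to apply_tensor(self, tensor, param, grad, loss, state, setting), with per-parameter options read from the explicit setting mapping. A minimal custom transform written against that signature could look as follows, assuming TensorwiseTransform is exported from torchzero.core alongside Transform and that entries passed via defaults surface through setting as the hunks suggest:

import torch
from torchzero.core import TensorwiseTransform  # assumed export location


class ClampUpdate(TensorwiseTransform):
    # Hypothetical transform: clamps each update tensor element-wise.
    def __init__(self, bound: float = 1.0):
        defaults = dict(bound=bound)
        super().__init__(uses_grad=False, defaults=defaults)

    @torch.no_grad
    def apply_tensor(self, tensor, param, grad, loss, state, setting):
        bound = setting['bound']  # 0.3.11: options arrive via the `setting` dict
        return tensor.clamp_(-bound, bound)
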
@@ -55,9 +55,20 @@ class Cautious(Transform):
 
         "backtrack" - negate them (same as using update magnitude and gradient sign)
 
-
-
-
+    Examples:
+        Cautious Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.Adam(),
+                tz.m.Cautious(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
     """
 
     def __init__(
@@ -70,7 +81,7 @@ class Cautious(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
         return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
@@ -89,7 +100,7 @@ class UpdateGradientSignConsistency(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         normalize, eps = itemgetter('normalize', 'eps')(settings[0])
 
@@ -159,6 +170,18 @@ class ScaleByGradCosineSimilarity(Transform):
 
     Args:
        eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    Examples:
+        Scaled Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.Adam(),
+                tz.m.ScaleByGradCosineSimilarity(),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
@@ -168,12 +191,12 @@ class ScaleByGradCosineSimilarity(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         eps = settings[0]['eps']
         tensors = TensorList(tensors)
         grads = TensorList(grads)
-        cos_sim =
+        cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
 
         return tensors.mul_(cos_sim)
 
@@ -185,6 +208,20 @@ class ScaleModulesByCosineSimilarity(Module):
         main (Chainable): main module or sequence of modules whose update will be scaled.
         compare (Chainable): module or sequence of modules to compare to
         eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    Example:
+        Adam scaled by similarity to RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.ScaleModulesByCosineSimilarity(
+                    main = tz.m.Adam(),
+                    compare = tz.m.RMSprop(0.999, debiased=True),
+                ),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
@@ -213,7 +250,7 @@ class ScaleModulesByCosineSimilarity(Module):
         c = TensorList(compare_var.get_update())
         eps = self.settings[var.params[0]]['eps']
 
-        cos_sim =
+        cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
 
         var.update = m.mul_(cos_sim)
         return var
|
@@ -25,7 +25,7 @@ class EMA(Transform):
|
|
|
25
25
|
super().__init__(defaults, uses_grad=False, target=target)
|
|
26
26
|
|
|
27
27
|
@torch.no_grad
|
|
28
|
-
def
|
|
28
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
29
29
|
step = self.global_state['step'] = self.global_state.get('step', 0) + 1
|
|
30
30
|
|
|
31
31
|
debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
|
|
@@ -55,7 +55,7 @@ class EMASquared(Transform):
|
|
|
55
55
|
super().__init__(defaults, uses_grad=False)
|
|
56
56
|
|
|
57
57
|
@torch.no_grad
|
|
58
|
-
def
|
|
58
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
59
59
|
amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
|
|
60
60
|
beta = NumberList(s['beta'] for s in settings)
|
|
61
61
|
|
|
@@ -83,7 +83,7 @@ class SqrtEMASquared(Transform):
|
|
|
83
83
|
|
|
84
84
|
|
|
85
85
|
@torch.no_grad
|
|
86
|
-
def
|
|
86
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
87
87
|
step = self.global_state['step'] = self.global_state.get('step', 0) + 1
|
|
88
88
|
|
|
89
89
|
amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
|
|
@@ -123,7 +123,7 @@ class Debias(Transform):
|
|
|
123
123
|
super().__init__(defaults, uses_grad=False, target=target)
|
|
124
124
|
|
|
125
125
|
@torch.no_grad
|
|
126
|
-
def
|
|
126
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
127
127
|
step = self.global_state['step'] = self.global_state.get('step', 0) + 1
|
|
128
128
|
|
|
129
129
|
pow = settings[0]['pow']
|
|
@@ -145,7 +145,7 @@ class Debias2(Transform):
|
|
|
145
145
|
super().__init__(defaults, uses_grad=False, target=target)
|
|
146
146
|
|
|
147
147
|
@torch.no_grad
|
|
148
|
-
def
|
|
148
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
149
149
|
step = self.global_state['step'] = self.global_state.get('step', 0) + 1
|
|
150
150
|
|
|
151
151
|
pow = settings[0]['pow']
|
|
@@ -166,7 +166,7 @@ class CenteredEMASquared(Transform):
|
|
|
166
166
|
super().__init__(defaults, uses_grad=False)
|
|
167
167
|
|
|
168
168
|
@torch.no_grad
|
|
169
|
-
def
|
|
169
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
170
170
|
amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
|
|
171
171
|
beta = NumberList(s['beta'] for s in settings)
|
|
172
172
|
|
|
@@ -200,7 +200,7 @@ class CenteredSqrtEMASquared(Transform):
|
|
|
200
200
|
super().__init__(defaults, uses_grad=False)
|
|
201
201
|
|
|
202
202
|
@torch.no_grad
|
|
203
|
-
def
|
|
203
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
204
204
|
step = self.global_state['step'] = self.global_state.get('step', 0) + 1
|
|
205
205
|
|
|
206
206
|
amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
|
|
@@ -49,7 +49,7 @@ class PrecenteredEMASquared(Transform):
|
|
|
49
49
|
super().__init__(defaults, uses_grad=False, target=target)
|
|
50
50
|
|
|
51
51
|
@torch.no_grad
|
|
52
|
-
def
|
|
52
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
53
53
|
step = self.global_state['step'] = self.global_state.get('step', 0) + 1
|
|
54
54
|
|
|
55
55
|
beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
|
|
@@ -154,7 +154,7 @@ class CoordinateMomentum(Transform):
|
|
|
154
154
|
super().__init__(defaults, uses_grad=False, target=target)
|
|
155
155
|
|
|
156
156
|
@torch.no_grad
|
|
157
|
-
def
|
|
157
|
+
def apply_tensors(self, tensors, params, grads, loss, states, settings):
|
|
158
158
|
p = NumberList(s['p'] for s in settings)
|
|
159
159
|
velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
|
|
160
160
|
return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
|
|
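
Again purely as an illustrative sketch, not part of the diff: the hunks above rename the per-step hook of Transform subclasses to apply_tensors(self, tensors, params, grads, loss, states, settings), where settings is a per-parameter sequence of dicts. A minimal custom Transform against that signature, assuming Transform and TensorList keep the import paths visible elsewhere in this diff (torchzero.core and torchzero.utils):

import torch
from torchzero.core import Transform
from torchzero.utils import TensorList


class ScaleUpdate(Transform):
    # Hypothetical transform: multiplies the whole update by a constant factor.
    def __init__(self, factor: float = 0.5):
        defaults = dict(factor=factor)
        super().__init__(defaults, uses_grad=False)

    @torch.no_grad
    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        factor = settings[0]['factor']  # one settings dict per parameter tensor
        return TensorList(tensors).mul_(factor)
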
@@ -7,18 +7,39 @@ from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 
 class MatrixMomentum(Module):
-    """
-
-
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives, also for some reason it has very really good generalization on elastic net logistic regression.
+
+    .. note::
+        :code:`mu` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
+
+    .. note::
+        I have devised an adaptive version of this - :code:`tz.m.AdaptiveMatrixMomentum`, and it works well
+        without having to tune :code:`mu`.
 
-
+    .. note::
+        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
         mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
         hvp_method (str, optional):
-
-
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
 
@@ -30,7 +51,7 @@ class MatrixMomentum(Module):
         self,
         mu=0.1,
         beta: float = 1,
-        hvp_method: Literal["autograd", "forward", "central"] = "
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
         h: float = 1e-3,
         hvp_tfm: Chainable | None = None,
     ):
@@ -40,57 +61,66 @@ class MatrixMomentum(Module):
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_update')
+
     @torch.no_grad
-    def
+    def update(self, var):
         assert var.closure is not None
-        prev_update = self.get_state(var.params, 'prev_update'
+        prev_update = self.get_state(var.params, 'prev_update')
         hvp_method = self.settings[var.params[0]]['hvp_method']
         h = self.settings[var.params[0]]['h']
 
-
-
-        if hvp_method == 'autograd':
-            with torch.enable_grad():
-                grad = var.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_method == 'forward':
-            var.get_grad()
-            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
-
-
-            if var.loss_approx is None: var.loss_approx = l
+        if 'hvp_tfm' in self.children:
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
-
-            raise ValueError(hvp_method)
+        self.store(var.params, "Hvp", Hvp)
 
-        if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
 
+    @torch.no_grad
+    def apply(self, var):
         update = TensorList(var.get_update())
+        Hvp, prev_update = self.get_state(var.params, 'Hvp', 'prev_update', cls=TensorList)
+        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
 
-
-        update.add_(prev_update - hvp_*mu)
+        update.add_(prev_update - Hvp*mu)
         prev_update.set_(update * beta)
         var.update = update
         return var
 
 
 class AdaptiveMatrixMomentum(Module):
-    """
-
-
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives, also for some reason it has very good generalization on elastic net logistic regression.
+
+    .. note::
+        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires the a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
 
-    This version estimates mu via a simple heuristic: ||s||/||y||, where s is parameter difference, y is gradient difference.
 
     Args:
         mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
         hvp_method (str, optional):
-
-
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
 
@@ -103,7 +133,7 @@ class AdaptiveMatrixMomentum(Module):
         mu_mul: float = 1,
         beta: float = 1,
         eps=1e-4,
-        hvp_method: Literal["autograd", "forward", "central"] = "
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
         h: float = 1e-3,
         hvp_tfm: Chainable | None = None,
     ):
@@ -113,8 +143,12 @@ class AdaptiveMatrixMomentum(Module):
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_params', 'prev_grad')
+
     @torch.no_grad
-    def
+    def update(self, var):
         assert var.closure is not None
         prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
 
@@ -123,43 +157,36 @@ class AdaptiveMatrixMomentum(Module):
         h = settings['h']
         eps = settings['eps']
 
-        mu_mul
-
-        if hvp_method == 'autograd':
-            with torch.enable_grad():
-                grad = var.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_method == 'forward':
-            var.get_grad()
-            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
-
-        elif hvp_method == 'central':
-            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
+        mu_mul = NumberList(self.settings[p]['mu_mul'] for p in var.params)
 
-
-
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
         if 'hvp_tfm' in self.children:
-
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
         # adaptive part
-        update = TensorList(var.get_update())
-
         s_k = var.params - prev_params
         prev_params.copy_(var.params)
 
-        assert var.grad is not None
-
-        prev_grad
+        if hvp_method != 'central': assert var.grad is not None
+        grad = var.get_grad()
+        y_k = grad - prev_grad
+        prev_grad.copy_(grad)
 
         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
 
-
-
-
+        self.store(var.params, ['Hvp', 'ada_mu'], [Hvp, ada_mu])
+
+    @torch.no_grad
+    def apply(self, var):
+        Hvp, ada_mu = self.get_state(var.params, 'Hvp', 'ada_mu')
+        Hvp = as_tensorlist(Hvp)
+        beta = NumberList(self.settings[p]['beta'] for p in var.params)
+        update = TensorList(var.get_update())
+        prev_update = TensorList(self.state[p]['prev_update'] for p in var.params)
+
+        update.add_(prev_update - Hvp*ada_mu)
         prev_update.set_(update * beta)
         var.update = update
         return var
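
As a usage note rather than part of the diff: the MatrixMomentum and AdaptiveMatrixMomentum hunks above split the old single-step logic into update()/apply() phases, route Hessian-vector products through self.Hvp(...), and add reset_for_online(). Per the added docstrings, these modules should sit first in the chain and the optimizer closure must accept a backward argument. A hedged end-to-end sketch (the model, loss and data are placeholders; the closure shape follows the added docstring):

import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
loss_fn = torch.nn.MSELoss()
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.AdaptiveMatrixMomentum(),  # estimates mu as ||s||/||y|| per the docstring
    tz.m.LR(1e-2),
)

def closure(backward=True):
    # The added notes require a closure that can re-evaluate loss and gradients
    # for the Hessian-vector products; `backward` controls whether grads are computed.
    loss = loss_fn(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    opt.step(closure)
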
@@ -55,9 +55,10 @@ class NAG(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']
 
         momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
         return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
+
@@ -7,7 +7,7 @@ from .accumulate import (
 )
 from .binary import (
     Add,
-
+    BinaryOperationBase,
     Clip,
     CopyMagnitude,
     CopySign,
@@ -27,37 +27,12 @@ from .binary import (
     Sub,
     Threshold,
 )
-from .debug import PrintShape, PrintUpdate
-from .misc import (
-    DivByLoss,
-    Dropout,
-    FillLoss,
-    GradientAccumulation,
-    GradSign,
-    GraftGradToUpdate,
-    GraftToGrad,
-    GraftToParams,
-    LastAbsoluteRatio,
-    LastDifference,
-    LastGradDifference,
-    LastProduct,
-    LastRatio,
-    MulByLoss,
-    Multistep,
-    NegateOnLossIncrease,
-    NoiseSign,
-    Previous,
-    Relative,
-    Sequential,
-    UpdateSign,
-    WeightDropout,
-)
 from .multi import (
     ClipModules,
     DivModules,
     GraftModules,
     LerpModules,
-
+    MultiOperationBase,
     PowModules,
     SubModules,
 )
@@ -66,13 +41,11 @@ from .reduce import (
     Mean,
     MinimumModules,
     Prod,
-
+    ReduceOperationBase,
     Sum,
     WeightedMean,
     WeightedSum,
 )
-from .split import Split
-from .switch import Alternate, Switch
 from .unary import (
     Abs,
     CustomUnaryOperation,
@@ -97,7 +70,6 @@ from .utility import (
     Randn,
     RandomSample,
     Uniform,
-    Update,
     UpdateToNone,
     Zeros,
 )
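
One practical consequence of the hunks above, stated with the caveat that the new re-export surface is inferred rather than shown here: ops/__init__.py no longer re-exports the debug, misc, split and switch members, which the file list at the top moves into torchzero/modules/misc/. Deep imports therefore need updating; code that goes through the aggregated tz.m namespace is presumably unaffected, though this diff alone does not confirm it.

# Migration sketch -- the 0.3.11 location is an assumption based on the file moves listed above.
# 0.3.10:
# from torchzero.modules.ops import Alternate, Switch, Split
# 0.3.11 (assumed):
from torchzero.modules.misc import Alternate, Switch, Split
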
@@ -1,11 +1,7 @@
-from collections import deque
-from operator import itemgetter
-from typing import Literal
-
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList,
+from ...utils import TensorList, unpack_states
 
 class AccumulateSum(Transform):
     """Accumulates sum of all past updates.
@@ -19,7 +15,7 @@ class AccumulateSum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         sum = unpack_states(states, tensors, 'sum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return sum.add_(tensors).lazy_mul(decay, clone=True)
@@ -36,7 +32,7 @@ class AccumulateMean(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         mean = unpack_states(states, tensors, 'mean', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
@@ -54,7 +50,7 @@ class AccumulateProduct(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         prod = unpack_states(states, tensors, 'prod', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return prod.mul_(tensors).lazy_mul(decay, clone=True)
@@ -71,7 +67,7 @@ class AccumulateMaximum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
@@ -88,7 +84,7 @@ class AccumulateMinimum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return minimum.minimum_(tensors).lazy_mul(decay, clone=True)