torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/clipping/ema_clipping.py

@@ -5,7 +5,7 @@ from collections.abc import Iterable, Sequence
 import torch

 from ...core import Module, Target, Transform, apply_transform, Chainable
-from ...utils import NumberList, TensorList,
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states

 class ClipNormByEMA(Transform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -14,9 +14,10 @@ class ClipNormByEMA(Transform):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional):
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
         """
@@ -29,12 +30,13 @@ class ClipNormByEMA(Transform):
         tensorwise:bool=True,
         max_ema_growth: float | None = 1.5,
         ema_init: Literal['zeros', 'update'] = 'zeros',
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ord=ord, tensorwise=tensorwise, ema_init=ema_init, eps=eps, max_ema_growth=max_ema_growth)
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner)

     @torch.no_grad
-    def
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])

@@ -78,7 +80,12 @@ class ClipNormByEMA(Transform):
         if self.NORMALIZE: denom.clip_(min=eps[0])
         else: denom.clip_(min=1)

-
+        self.global_state['denom'] = denom
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        denom = self.global_state.pop('denom')
+        torch._foreach_div_(tensors, denom)
         return tensors

 class NormalizeByEMA(ClipNormByEMA):
@@ -88,9 +95,10 @@ class NormalizeByEMA(ClipNormByEMA):
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ord (float, optional): order of the norm. Defaults to 2.
         eps (float, optional): epsilon for division. Defaults to 1e-6.
-        tensorwise (bool, optional):
+        tensorwise (bool, optional):
+            if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
         max_ema_growth (float | None, optional):
-            if specified, exponential moving average norm can grow
+            if specified, restricts how quickly exponential moving average norm can grow. The norm is allowed to grow by at most this value per step. Defaults to 1.5.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
         """
@@ -99,28 +107,30 @@ class NormalizeByEMA(ClipNormByEMA):
     # TODO Centralize by EMA?

 class ClipValueByEMA(Transform):
-    """Clips magnitude of update to be no larger than magnitude of
+    """Clips magnitude of update to be no larger than magnitude of exponential moving average of past (unclipped) updates.

     Args:
         beta (float, optional): beta for the exponential moving average. Defaults to 0.99.
         ema_init (str, optional):
             How to initialize exponential moving average on first step, "update" to use the first update or "zeros". Defaults to 'zeros'.
-        ema_tfm (Chainable | None, optional):
+        ema_tfm (Chainable | None, optional):
+            optional modules applied to exponential moving average before clipping by it. Defaults to None.
     """
     def __init__(
         self,
         beta=0.99,
         ema_init: Literal['zeros', 'update'] = 'zeros',
         ema_tfm:Chainable | None=None,
+        inner: Chainable | None = None,
     ):
         defaults = dict(beta=beta, ema_init=ema_init)
-        super().__init__(defaults,
+        super().__init__(defaults, inner=inner)

         if ema_tfm is not None:
             self.set_child('ema_tfm', ema_tfm)

     @torch.no_grad
-    def
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         ema_init = itemgetter('ema_init')(settings[0])

         beta = unpack_dicts(settings, 'beta', cls=NumberList)
@@ -129,8 +139,12 @@ class ClipValueByEMA(Transform):
         ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
         ema.lerp_(tensors.abs(), 1-beta)

+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        ema = unpack_states(states, tensors, 'ema', cls=TensorList)
+
         if 'ema_tfm' in self.children:
-            ema = TensorList(apply_transform(self.children['ema_tfm'], ema, params, grads, loss))
+            ema = TensorList(apply_transform(self.children['ema_tfm'], ema.clone(), params, grads, loss))

         tensors.clip_(-ema, ema)
         return tensors
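The changes above split ClipNormByEMA into an update_tensors/apply_tensors pair and add an `inner` argument so another module can run between the two phases. A minimal usage sketch, assuming the `tz.Modular` / `tz.m.*` API shown in the TrustRegion docstring further down; whether `ClipNormByEMA` is exported under `tz.m` is an assumption, not verified from this diff:

```python
# Hypothetical usage sketch, not taken from the diff: clip the Adagrad update's
# norm against an EMA of past update norms. Assumes the module is reachable as
# tz.m.ClipNormByEMA.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(
    model.parameters(),
    tz.m.Adagrad(),                 # produces the raw update
    tz.m.ClipNormByEMA(beta=0.99),  # clips its norm to the EMA of past update norms
)
```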
torchzero/modules/clipping/growth_clipping.py

@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
            Next update is at most :code:`max(previous update * mul, max_decay)`.
            Defaults to 2.
-        target (Target, optional): what to set on var
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,
@@ -30,11 +30,11 @@ class ClipValueGrowth(TensorwiseTransform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)


-    def apply_tensor(self, tensor, param, grad, loss, state,
-        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(setting)
         add: float | None

         if add is None and mul is None:
@@ -120,7 +120,8 @@ class ClipNormGrowth(Transform):

     Args:
         add (float | None, optional): additive clipping, next update norm is at most `previous norm + add`. Defaults to None.
-        mul (float | None, optional):
+        mul (float | None, optional):
+            multiplicative clipping, next update norm is at most `previous norm * mul`. Defaults to 1.5.
         min_value (float | None, optional):
             minimum value for multiplicative clipping to prevent collapse to 0.
             Next norm is at most :code:`max(prev_norm, min_value) * mul`. Defaults to 1e-4.
@@ -144,11 +145,11 @@ class ClipNormGrowth(Transform):
         target: Target = "update",
     ):
         defaults = dict(add=add, mul=mul, min_value=min_value, max_decay=max_decay, ord=ord, parameterwise=parameterwise)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)



-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         parameterwise = settings[0]['parameterwise']
         tensors = TensorList(tensors)

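The ClipNormGrowth docstring above describes the rule: the next update norm may exceed the previous one by at most `add` additively, and by at most `max(prev_norm, min_value) * mul` multiplicatively. A standalone sketch of that rule on a plain tensor, illustrative only and with the `max_decay` tracking omitted; it is not the module's implementation:

```python
# Illustrative sketch of the norm-growth clipping rule described above.
import torch

def clip_norm_growth(update: torch.Tensor, prev_norm: float | None,
                     add: float | None = None, mul: float | None = 1.5,
                     min_value: float = 1e-4) -> tuple[torch.Tensor, float]:
    norm = float(update.norm())
    if prev_norm is None or norm == 0.0:
        return update, norm                                        # first step: nothing to compare against
    allowed = float('inf')
    if add is not None:
        allowed = min(allowed, prev_norm + add)                     # additive bound
    if mul is not None:
        allowed = min(allowed, max(prev_norm, min_value) * mul)     # multiplicative bound
    if norm > allowed:
        update = update * (allowed / norm)                          # rescale so the new norm equals the bound
        norm = allowed
    return update, norm

# Example: a norm-8 update following a norm-2 update gets scaled down to norm 3 (2 * 1.5).
u, n = clip_norm_growth(torch.full((4,), 4.0), prev_norm=2.0)
print(round(n, 4))  # 3.0
```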
torchzero/modules/experimental/__init__.py

@@ -1,24 +1,41 @@
+"""This submodule contains various untested experimental modules, some of them are to be moved out of experimental when properly tested, some are to remain here forever or to be deleted depending on the degree of their usefulness."""
 from .absoap import ABSOAP
 from .adadam import Adadam
+from .adam_lambertw import AdamLambertW
 from .adamY import AdamY
+from .adaptive_step_size import AdaptiveStepSize
 from .adasoap import AdaSOAP
+from .cosine import (
+    AdaptiveDifference,
+    AdaptiveDifferenceEMA,
+    CosineDebounce,
+    CosineMomentum,
+    CosineStepSize,
+    ScaledAdaptiveDifference,
+)
+from .cubic_adam import CubicAdam
 from .curveball import CurveBall
+
+# from dct import DCTProjection
 from .eigendescent import EigenDescent
 from .etf import (
     ExponentialTrajectoryFit,
     ExponentialTrajectoryFitV2,
     PointwiseExponential,
 )
+from .exp_adam import ExpAdam
+from .expanded_lbfgs import ExpandedLBFGS
+from .fft import FFTProjection
 from .gradmin import GradMin
+from .hnewton import HNewton
+from .modular_lbfgs import ModularLBFGS
 from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
+from .parabolic_search import CubicParabolaSearch, ParabolaSearch
 from .reduce_outward_lr import ReduceOutwardLR
-from .
-from .spectral import SpectralPreconditioner
-from .structured_newton import StructuredNewton
+from .structural_projections import BlockPartition, TensorizeProjection
 from .subspace_preconditioners import (
     HistorySubspacePreconditioning,
     RandomSubspacePreconditioning,
 )
-from .
-from .diagonal_higher_order_newton import DiagonalHigherOrderNewton
+from .tensor_adagrad import TensorAdagrad
torchzero/modules/experimental/absoap.py

@@ -24,7 +24,10 @@ def update_absoap_covariances_(

 Source=Literal['p','g','s','y', 'gy', 'sy', 'sn', 'yn', 'gys', 'sys']
 class ABSOAP(Transform):
-    """SOAP but with some extra options for testing.
+    """SOAP but with some extra options for testing.
+
+    .. warning::
+        This module is just for testing my stupid ideas.

     Args:
         scale_by_s - whether to scale y by s
@@ -94,7 +97,7 @@ class ABSOAP(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
torchzero/modules/experimental/adadam.py

@@ -10,7 +10,7 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..
+from ..step_size.lr import lazy_lr
 from ..momentum.experimental import sqrt_nag_ema_sq_
 from ..momentum.momentum import nag_

@@ -50,7 +50,13 @@ def adadam_(
     return None

 class Adadam(Module):
-    """Adam with a diagonally preconditioned preconditioner.
+    """Adam with a diagonally preconditioned preconditioner.
+
+    Verdict: I haven't tested this yet.
+
+    .. warning::
+        Experimental.
+    """
     def __init__(
         self,
         beta1: float = 0.9,
torchzero/modules/experimental/adamY.py

@@ -10,7 +10,7 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..
+from ..step_size.lr import lazy_lr
 from ..momentum.experimental import sqrt_nag_ema_sq_
 from ..momentum.momentum import nag_

@@ -62,7 +62,13 @@ def adamy_(
     return None

 class AdamY(Module):
-    """Adam but uses scaled gradient differences for second momentum.
+    """Adam but uses scaled gradient differences for second momentum.
+
+    Verdict: I haven't tested this yet.
+
+    .. warning::
+        Experimental.
+    """
     def __init__(
         self,
         beta1: float = 0.9,
torchzero/modules/experimental/adam_lambertw.py (new file)

@@ -0,0 +1,149 @@
+from operator import itemgetter
+from functools import partial
+import math
+import torch
+
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..step_size.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def _lambertw_newton_raphson(x: TensorList, iterations=5):
+    # z = torch.zeros_like(x)
+    # mask_neg = x < 0
+    # mask_pos = ~mask_neg
+
+    # z[mask_pos] = torch.log(x[mask_pos] + 1.0)
+
+    # x_neg = x[mask_neg]
+    # z_neg = -1.0 + torch.sqrt(2.0 * (1.0 + math.e * x_neg))
+    # z[mask_neg] = z_neg
+
+    # x is always positive
+    z = (x+1).log_()
+    for _ in range(iterations):
+        exp_z = z.exp()
+        numerator = z * exp_z - x
+        denominator = exp_z * (z + 1.0) + 1e-8
+        delta = numerator / denominator
+        z -= delta
+    return z
+
+# https://github.com/gmgeorg/torchlambertw/blob/main/torchlambertw/special.py
+def _lambertw_winitzki(x: TensorList):
+    x_log1p = x.log1p()
+    return x_log1p * (1.0 - x_log1p.log1p() / (2.0 + x_log1p))
+
+
+def adam_lambertw_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_xpx_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    debiased: bool = True,
+    max_exp_avg_xpx_: TensorList | None = None,
+    iterations: int | None = 5,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """Returns new tensors."""
+    tensors_abs = tensors.abs().clip_(max=20)
+    tensors_xpx = tensors_abs.pow_(tensors_abs)
+    exp_avg_xpx_.lerp_(tensors_xpx, 1-beta2)
+
+    if max_exp_avg_xpx_ is not None:
+        max_exp_avg_xpx_.maximum_(exp_avg_xpx_)
+        exp_avg_xpx_ = max_exp_avg_xpx_
+
+    if inner is not None:
+        assert params is not None
+        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+    if iterations is None or iterations < 1: exp_avg_xpx_ = _lambertw_winitzki(exp_avg_xpx_)
+    else: exp_avg_xpx_ = _lambertw_newton_raphson(exp_avg_xpx_, iterations)
+
+    return (exp_avg_.lazy_mul(alpha) / exp_avg_xpx_.add_(eps))
+
+class AdamLambertW(Transform):
+    """Adam but uses abs x^x and LambertW instead of square and sqrt.
+    The gradient will be clipped to 20 because float32 which you have to use otherwise you're PC will explode.
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum. Defaults to 0.999.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        alpha (float, optional): learning rate. Defaults to 1.
+        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+        pow (float, optional): power used in second momentum power and root. Defaults to 2.
+        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+        iterations (int, optional): 0 or None means Winitzki approximation otherwise number of newton raphson iterations.
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        eps: float = 1e-8,
+        amsgrad: bool = False,
+        alpha: float = 1.,
+        pow: float = 2,
+        debiased: bool = True,
+        iterations: int | None = 5,
+        inner: Chainable | None = None
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased, iterations=iterations)
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased,iterations = itemgetter('amsgrad','pow','debiased','iterations')(settings[0])
+
+        if amsgrad:
+            exp_avg, exp_avg_xpx, max_exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', 'max_exp_avg_xpx', cls=TensorList)
+        else:
+            exp_avg, exp_avg_xpx = unpack_states(states, tensors, 'exp_avg', 'exp_avg_xpx', cls=TensorList)
+            max_exp_avg_xpx = None
+
+
+        return adam_lambertw_(
+            tensors=TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_xpx_=exp_avg_xpx,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            step=step,
+            pow=pow,
+            debiased=debiased,
+            max_exp_avg_xpx_=max_exp_avg_xpx,
+            iterations=iterations,
+
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+
+        )
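The new AdamLambertW module replaces Adam's squared-gradient accumulator with an EMA of |g|^|g| and inverts it with the Lambert W function, the inverse of w ↦ w·e^w, solved either by Newton-Raphson or by the Winitzki approximation. A self-contained sketch of the same Newton-Raphson solver on plain tensors, to show what `_lambertw_newton_raphson` computes; the module's version above operates in-place on a torchzero `TensorList`:

```python
# Minimal sketch of the Lambert W solver used above, on plain torch tensors.
# Solves w * exp(w) = x for x >= 0 by Newton-Raphson, starting from log(1 + x).
import torch

def lambertw_newton(x: torch.Tensor, iterations: int = 8) -> torch.Tensor:
    z = torch.log1p(x)                        # same initial guess as the module
    for _ in range(iterations):
        ez = torch.exp(z)
        z = z - (z * ez - x) / (ez * (z + 1.0) + 1e-8)
    return z

x = torch.tensor([0.1, 1.0, 5.0, 20.0])       # updates are clipped to 20 above
w = lambertw_newton(x)
print(torch.allclose(w * torch.exp(w), x, atol=1e-4))  # True: W(x)·e^{W(x)} ≈ x
```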
torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py}

@@ -2,12 +2,16 @@ from operator import itemgetter

 import torch

-from
+from ..line_search import LineSearchBase


-class
-    """Basic first order
-    step size is increased. If value increased, step size is decreased.
+class AdaptiveStepSize(LineSearchBase):
+    """Basic first order step size adaptation method. Re-evaluates the function after stepping, if value decreased sufficiently,
+    step size is increased. If value increased, step size is decreased.
+
+    .. note::
+        This works well in some cases, but it is often prone to collapsing.
+        For a more robust alternative use :code:`tz.m.AdaptiveBacktracking`.

     Args:
         nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
@@ -18,6 +22,19 @@ class TrustRegion(LineSearch):
         adaptive (bool, optional):
             If enabled, when multiple consecutive steps have been successful or unsuccessful,
             the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
+
+
+    Examples:
+        Adagrad with trust region:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adagrad(),
+                tz.m.TrustRegion()
+            )
+
     """
     def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
         defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
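The AdaptiveStepSize docstring above describes the rule: after a trial step, grow the step size by `nplus` when the value decreased sufficiently and shrink it by `nminus` otherwise. A toy sketch of that rule; the Armijo-style sufficient-decrease test with parameter `c` is an assumption for illustration, not taken from the diff:

```python
# Toy sketch of the step-size adaptation rule, not the module's implementation.
# `directional_derivative` is g·d for the step direction d (negative for descent).
def adapt_step_size(f_old: float, f_new: float, step_size: float,
                    directional_derivative: float,
                    nplus: float = 1.5, nminus: float = 0.75, c: float = 1e-4) -> float:
    sufficient_decrease = f_new <= f_old + c * step_size * directional_derivative
    return step_size * nplus if sufficient_decrease else step_size * nminus
```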
torchzero/modules/experimental/adasoap.py

@@ -33,9 +33,14 @@ def update_adasoap_covariances_(


 class AdaSOAP(Transform):
-    """SOAP with diagonally preconditioned GG^Ts.
+    """SOAP with diagonally preconditioned GG^Ts.
+
+    .. warning::
+        Experimental.

     precond_beta - beta for GG^T squares
+
+    Verdict: It works, but it is about the same performance as Adam, but maybe more tuning potential?
     """
     def __init__(
         self,
@@ -71,7 +76,7 @@ class AdaSOAP(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):