torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the public registry.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
@@ -1,4 +1,14 @@
-
-
-
-
+from .averaging import Averaging, MedianAveraging, WeightedAveraging
+from .cautious import (
+    Cautious,
+    IntermoduleCautious,
+    ScaleByGradCosineSimilarity,
+    ScaleModulesByCosineSimilarity,
+    UpdateGradientSignConsistency,
+)
+from .ema import EMA, Debias, Debias2, EMASquared, SqrtEMASquared, CenteredEMASquared, CenteredSqrtEMASquared
+from .experimental import CoordinateMomentum
+# from .matrix_momentum import MatrixMomentum
+
+from .momentum import NAG, HeavyBall
+from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
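
The hunk above appears to be the rewritten torchzero/modules/momentum/__init__.py, which re-exports the new averaging, cautious, EMA, and experimental momentum modules shown in the hunks below. A minimal, illustrative sketch of the resulting import surface (how these modules are chained into an optimizer goes through torchzero's core Module/Transform machinery, which is not part of this hunk):

# Illustrative only: these imports mirror the re-exports added in the __init__.py hunk above
# (assumes torchzero 0.3.1 is installed).
from torchzero.modules.momentum import (
    Averaging, Cautious, CoordinateMomentum, EMA, EMASquared, HeavyBall, NAG,
)
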
@@ -0,0 +1,78 @@
+from collections import deque
+from collections.abc import Sequence
+from typing import Any, Literal, cast
+
+import torch
+
+from ...core import TensorwiseTransform, Target
+from ...utils import tolist
+
+
+class Averaging(TensorwiseTransform):
+    def __init__(self, history_size: int, target: Target = 'update'):
+        defaults = dict(history_size=history_size)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        history_size = self.settings[param]['history_size']
+        state = self.state[param]
+        if 'history' not in state:
+            state['history'] = deque(maxlen=history_size)
+            state['average'] = torch.zeros_like(tensor)
+
+        history = state['history']; average = state['average']
+        if len(history) == history_size: average -= history[0]
+        history.append(tensor)
+        average += tensor
+
+        return average / len(history)
+
+class WeightedAveraging(TensorwiseTransform):
+    """weights are oldest to newest"""
+    def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
+        defaults = dict(weights = tolist(weights))
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        weights = self.settings[param]['weights']
+        state = self.state[param]
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=len(weights))
+
+        history = state['history']
+        history.append(tensor)
+        if len(history) != len(weights):
+            weights = weights[-len(history):]
+
+        average = None
+        for i, (h, w) in enumerate(zip(history, weights)):
+            if average is None: average = h * (w / len(history))
+            else:
+                if w == 0: continue
+                average += h * (w / len(history))
+
+        assert average is not None
+        return average
+
+
+class MedianAveraging(TensorwiseTransform):
+    def __init__(self, history_size: int, target: Target = 'update'):
+        defaults = dict(history_size = history_size)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        history_size = self.settings[param]['history_size']
+        state = self.state[param]
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=history_size)
+
+        history = state['history']
+        history.append(tensor)
+
+        stacked = torch.stack(tuple(history), 0)
+        return torch.quantile(stacked, 0.5, dim = 0)
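
The Averaging transform above keeps a per-parameter deque of the last history_size updates together with a running sum, so each step costs one subtraction (the evicted entry) and one addition instead of re-summing the whole window. A minimal single-tensor sketch of that update rule, outside torchzero's TensorwiseTransform machinery (function and variable names here are illustrative, not part of the package):

# A minimal, single-tensor version of the running average kept by Averaging.transform:
# a deque of the last `history_size` updates plus a running sum, updated incrementally.
from collections import deque

import torch

def running_average_step(history: deque, total: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    if len(history) == history.maxlen:  # the oldest update is about to be evicted
        total -= history[0]
    history.append(update)
    total += update
    return total / len(history)         # mean over however many updates are stored

history = deque(maxlen=3)               # history_size = 3
total = torch.zeros(4)
for _ in range(5):
    averaged_update = running_average_step(history, total, torch.randn(4))
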
@@ -0,0 +1,181 @@
+from collections import deque
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform, Module, Chainable
+from ...utils import NumberList, TensorList
+
+
+def cautious_(
+    tensors_: TensorList,
+    grads: TensorList,
+    normalize: bool,
+    eps: float,
+    mode: Literal['zero', 'grad', 'backtrack']
+):
+    # mask will be > 0 for parameters where both signs are the same
+    mask = (tensors_ * grads) > 0
+    if mode in ('zero', 'grad'):
+        if normalize and mode == 'zero':
+            fmask = mask.to(tensors_[0].dtype)
+            fmask /= fmask.global_mean().clip(min=eps) # type:ignore
+        else:
+            fmask = mask
+
+        tensors_ *= fmask
+
+        if mode == 'grad':
+            tensors_ += grads * mask.logical_not_()
+
+        return tensors_
+
+    # mode = 'backtrack'
+    tensors_ -= tensors_.mul(2).mul_(mask.logical_not_())
+    return tensors_
+
+class Cautious(Transform):
+    """Negates update for parameters where update and gradient sign is inconsistent.
+    Optionally normalizes the update by the number of parameters that are not masked.
+    This is meant to be used after any momentum-based modules.
+
+    Args:
+        normalize (bool, optional):
+            renormalize update after masking.
+            only has effect when mode is 'zero'. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+        mode (str, optional):
+            what to do with updates with inconsistent signs.
+
+            "zero" - set them to zero (as in paper)
+
+            "grad" - set them to the gradient
+
+            "backtrack" - negate them (same as using update magnitude and gradient sign)
+
+    reference
+        *Cautious Optimizers: Improving Training with One Line of Code.
+        Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
+    """
+
+    def __init__(
+        self,
+        normalize=False,
+        eps=1e-6,
+        mode: Literal["zero", "grad", "backtrack"] = "zero",
+        target: Target = "update",
+    ):
+        defaults = dict(normalize=normalize, eps=eps, mode=mode)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[params[0]])
+        return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
+
+class UpdateGradientSignConsistency(Transform):
+    """1 where signs match 0 otherwise"""
+    def __init__(self, normalize = False, eps=1e-6, target: Target = 'update'):
+        defaults = dict(normalize=normalize, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        normalize, eps = itemgetter('normalize', 'eps')(self.settings[params[0]])
+
+        mask = (TensorList(tensors).mul_(grads)).gt_(0)
+        if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]
+
+        return mask
+
+class IntermoduleCautious(Module):
+    def __init__(
+        self,
+        main: Chainable,
+        compare: Chainable,
+        normalize=False,
+        eps=1e-6,
+        mode: Literal["zero", "grad", "backtrack"] = "zero",
+    ):
+        defaults = dict(normalize=normalize, eps=eps, mode=mode)
+        super().__init__(defaults)
+
+        self.set_child('main', main)
+        self.set_child('compare', compare)
+
+    @torch.no_grad
+    def step(self, vars):
+        main = self.children['main']
+        compare = self.children['compare']
+
+        main_vars = main.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(main_vars)
+
+        compare_vars = compare.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(compare_vars)
+
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[vars.params[0]])
+        vars.update = cautious_(
+            TensorList(main_vars.get_update()),
+            TensorList(compare_vars.get_update()),
+            normalize=normalize,
+            mode=mode,
+            eps=eps,
+        )
+
+        return vars
+
+class ScaleByGradCosineSimilarity(Transform):
+    def __init__(
+        self,
+        eps=1e-6,
+        target: Target = "update",
+    ):
+        defaults = dict(eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        eps = self.settings[params[0]]['eps']
+        tensors = TensorList(tensors)
+        grads = TensorList(grads)
+        cos_sim = (tensors.dot(grads)) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
+
+        return tensors.mul_(cos_sim)
+
+class ScaleModulesByCosineSimilarity(Module):
+    def __init__(
+        self,
+        main: Chainable,
+        compare: Chainable,
+        eps=1e-6,
+    ):
+        defaults = dict(eps=eps)
+        super().__init__(defaults)
+
+        self.set_child('main', main)
+        self.set_child('compare', compare)
+
+    @torch.no_grad
+    def step(self, vars):
+        main = self.children['main']
+        compare = self.children['compare']
+
+        main_vars = main.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(main_vars)
+
+        compare_vars = compare.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(compare_vars)
+
+        m = TensorList(main_vars.get_update())
+        c = TensorList(compare_vars.get_update())
+        eps = self.settings[vars.params[0]]['eps']
+
+        cos_sim = (m.dot(c)) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
+
+        vars.update = m.mul_(cos_sim)
+        return vars
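
The core of this file is cautious_, which masks out the components of the update whose sign disagrees with the gradient and, in 'zero' mode, optionally rescales the survivors by the inverse of the mask mean. A minimal single-tensor sketch of that 'zero' mode (plain torch with illustrative names; the real function operates on TensorLists and also implements the 'grad' and 'backtrack' modes):

# A minimal single-tensor version of cautious_ in 'zero' mode: zero the components of
# `update` whose sign disagrees with `grad`, then optionally rescale the survivors by
# the inverse of the mask mean (the fraction of components kept).
import torch

def cautious_zero(update: torch.Tensor, grad: torch.Tensor, normalize: bool = True, eps: float = 1e-6) -> torch.Tensor:
    mask = (update * grad) > 0                   # True where update and gradient agree in sign
    fmask = mask.to(update.dtype)
    if normalize:
        fmask = fmask / fmask.mean().clip(min=eps)
    return update * fmask

update = torch.tensor([0.5, -0.2, 0.1, -0.4])
grad = torch.tensor([1.0, 0.3, -0.2, -0.1])
print(cautious_zero(update, grad))               # sign-inconsistent components are zeroed
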
@@ -0,0 +1,173 @@
+from collections import deque
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import TensorList, NumberList
+from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum
+
+
+class EMA(Transform):
+    """Maintains EMA of update.
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(self.settings[params[0]])
+
+        exp_avg = self.get_state('exp_avg', params=params, init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+        momentum, dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+
+        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
+
+        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+
+
+class EMASquared(Transform):
+    EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)
+
+    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()
+
+class SqrtEMASquared(Transform):
+    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+
+    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update',):
+        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return self.SQRT_EMA_SQ_FN(
+            TensorList(tensors),
+            exp_avg_sq_=exp_avg_sq,
+            beta=beta,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            debiased=debiased,
+            step=step,
+            pow=pow,
+        )
+
+
+class Debias(Transform):
+    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
+        defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        settings = self.settings[params[0]]
+        pow = settings['pow']
+        alpha, beta1, beta2 = self.get_settings('alpha', 'beta1', 'beta2', params=params, cls=NumberList)
+
+        return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
+
+class Debias2(Transform):
+    def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
+        defaults = dict(beta=beta, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        pow = self.settings[params[0]]['pow']
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+        return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)
+
+class CenteredEMASquared(Transform):
+    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return centered_ema_sq_(
+            TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            beta=beta,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            pow=pow,
+        ).clone()
+
+class CenteredSqrtEMASquared(Transform):
+    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return sqrt_centered_ema_sq_(
+            TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            beta=beta,
+            debiased=debiased,
+            step=step,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            pow=pow,
+        )
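
EMA with debiased=True corresponds to the first-moment estimate in Adam: an exponential moving average (here via lerp) divided by the bias-correction factor 1 - beta**step. A minimal single-tensor sketch of one such step (illustrative only; the module above also supports dampening, an add-based update instead of lerp, and 'update'-initialized state):

# A minimal single-tensor version of EMA(debiased=True): lerp-based exponential moving
# average with Adam-style bias correction 1 - beta**step.
import torch

def debiased_ema_step(exp_avg: torch.Tensor, update: torch.Tensor, beta: float, step: int) -> torch.Tensor:
    exp_avg.lerp_(update, 1 - beta)              # exp_avg <- beta*exp_avg + (1-beta)*update
    return exp_avg / (1 - beta ** step)          # undo the bias toward the zero initialization

exp_avg = torch.zeros(4)
for step in range(1, 6):
    corrected = debiased_ema_step(exp_avg, torch.randn(4), beta=0.9, step=step)
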
@@ -0,0 +1,189 @@
+from collections.abc import Sequence
+from functools import partial
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import NumberList, TensorList
+from ..functional import ema_, ema_sq_, sqrt_ema_sq_
+from .ema import EMASquared, SqrtEMASquared
+from .momentum import nag_
+
+
+def precentered_ema_sq_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    step: int,
+    min_step: int,
+    pow: float,
+    max_exp_avg_sq_: TensorList | None,
+):
+    """
+    Squared EMA of (update - 1st EMA). Starts taking effect after `min_step` to avoid division by epsilon.
+
+    returns `exp_avg_sq_` or `max_exp_avg_sq_`.
+    """
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=False)
+
+    if step < min_step: centered_update = tensors
+    else: centered_update = tensors - exp_avg_
+
+    exp_avg_sq_=ema_sq_(
+        centered_update,
+        exp_avg_sq_=exp_avg_sq_,
+        beta=beta2,
+        pow=pow,
+        max_exp_avg_sq_=max_exp_avg_sq_,
+    )
+    return exp_avg_sq_
+
+class PrecenteredEMASquared(Transform):
+    def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
+        super().__init__(defaults, uses_grad=False, target=target)
+        self.current_step = 0
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        self.current_step += 1
+
+        beta1, beta2 = self.get_settings('beta1','beta2', params=params, cls=NumberList)
+        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(self.settings[params[0]])
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return precentered_ema_sq_(
+            TensorList(tensors),
+            exp_avg_ = exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            beta1=beta1,
+            beta2=beta2,
+            step = self.current_step,
+            min_step=min_step,
+            pow=pow,
+            max_exp_avg_sq_=max_exp_avg_sq,
+        ).clone()
+
+
+def nag_ema_sq_(
+    tensors: TensorList,
+    exp_avg_sq_: TensorList,
+    beta: float | NumberList,
+    max_exp_avg_sq_: TensorList | None,
+    pow: float,
+    lerp:bool=True,
+):
+    """
+    Nesterov EMA of squared tensors.
+
+    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
+    """
+    if pow == 1: tensors = tensors.abs()
+    elif pow%2 == 0: tensors = tensors.pow(pow)
+    else: tensors = tensors.pow(pow).abs()
+
+    exp_avg_sq_=nag_(tensors,velocity_=exp_avg_sq_,momentum=beta,dampening=0,lerp=lerp,)
+
+    # AMSGrad
+    if max_exp_avg_sq_ is not None:
+        max_exp_avg_sq_.maximum_(exp_avg_sq_)
+        exp_avg_sq_ = max_exp_avg_sq_
+
+    return exp_avg_sq_
+
+def sqrt_nag_ema_sq_(
+    tensors: TensorList,
+    exp_avg_sq_: TensorList,
+    beta: float | NumberList,
+    max_exp_avg_sq_: TensorList | None,
+    debiased: bool,
+    step: int,
+    pow: float,
+    lerp:bool=False,
+):
+    """
+    Square root of nesterov EMA of squared tensors.
+
+    Returns new tensors.
+    """
+    return sqrt_ema_sq_(tensors=tensors,exp_avg_sq_=exp_avg_sq_,beta=beta,max_exp_avg_sq_=max_exp_avg_sq_,
+                        pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))
+
+class NesterovEMASquared(EMASquared):
+    EMA_SQ_FN = staticmethod(nag_ema_sq_)
+
+class SqrtNesterovEMASquared(SqrtEMASquared):
+    SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)
+
+
+def coordinate_momentum_(
+    tensors: TensorList,
+    velocity_: TensorList,
+    p: float | NumberList,
+):
+    """
+    sets `velocity_` to p% random values from `tensors`.
+
+    Returns `velocity_`
+    """
+    mask = tensors.bernoulli_like(p).as_bool()
+    velocity_.masked_set_(mask, tensors)
+    return velocity_
+
+
+class CoordinateMomentum(Transform):
+    def __init__(self, p: float = 0.1, target: Target = 'update'):
+        defaults = dict(p=p)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        p = self.get_settings('p', params=params, cls=NumberList)
+        velocity = self.get_state('velocity', params=params, cls=TensorList)
+        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
+
+
+# def multiplicative_momentum_(
+#     tensors_: TensorList,
+#     velocity_: TensorList,
+#     momentum: float | NumberList,
+#     dampening: float | NumberList,
+#     normalize_velocity: bool = True,
+#     abs: bool = False,
+#     lerp: bool = False,
+# ):
+#     """
+#     abs: if True, tracks momentum of absolute magnitudes.
+
+#     returns `tensors_`.
+#     """
+#     tensors_into_velocity = tensors_.abs() if abs else tensors_
+#     ema_(tensors_into_velocity, exp_avg_=velocity_, beta=momentum, dampening=0, lerp=lerp)
+
+#     if normalize_velocity: velocity_ = velocity_ / velocity_.std().add_(1e-8)
+#     return tensors_.mul_(velocity_.lazy_mul(1-dampening) if abs else velocity_.abs().lazy_mul_(1-dampening))
+
+
+# class MultiplicativeMomentum(Transform):
+#     """sucks"""
+#     def __init__(self, momentum: float = 0.9, dampening: float = 0,normalize_velocity: bool = True, abs: bool = False, lerp: bool = False):
+#         defaults = dict(momentum=momentum, dampening=dampening, normalize_velocity=normalize_velocity,abs=abs, lerp=lerp)
+#         super().__init__(defaults, uses_grad=False)
+
+#     @torch.no_grad
+#     def transform(self, tensors, params, grads, vars):
+#         momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+#         abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
+#         velocity = self.get_state('velocity', params=params, cls=TensorList)
+#         return multiplicative_momentum_(TensorList(target), velocity_=velocity, momentum=momentum, dampening=dampening,
+#                                         normalize_velocity=normalize_velocity,abs=abs,lerp=lerp)
+
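
coordinate_momentum_ above implements a sparse, coordinate-wise form of momentum: each step a Bernoulli mask with probability p is sampled, and only the selected coordinates of the velocity buffer are overwritten with the fresh update, so each coordinate of the returned direction is the most recent update value that happened to be sampled for it. A minimal single-tensor sketch under that reading (plain torch with illustrative names; bernoulli_like and masked_set_ in the file are TensorList helpers):

# A minimal single-tensor version of coordinate_momentum_: sample a Bernoulli mask with
# probability p and overwrite only the selected coordinates of the velocity buffer.
import torch

def coordinate_momentum_step(velocity: torch.Tensor, update: torch.Tensor, p: float) -> torch.Tensor:
    mask = torch.bernoulli(torch.full_like(update, p)).bool()
    velocity[mask] = update[mask]                # refresh only the sampled coordinates
    return velocity                              # unsampled coordinates keep their older values

velocity = torch.zeros(6)
for _ in range(3):
    direction = coordinate_momentum_step(velocity, torch.randn(6), p=0.5)
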