torchzero-0.3.10-py3-none-any.whl → torchzero-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
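For context only (not part of the release), a comparable unified diff can be produced locally from the two published wheels using only the standard library; everything in this sketch other than the wheel filenames is an illustrative assumption:

    # Sketch: compare the Python sources packaged in two wheels.
    # Renamed files appear as a removal plus an addition in this simple version.
    import difflib, zipfile

    def wheel_diff(old_wheel: str, new_wheel: str) -> str:
        out = []
        with zipfile.ZipFile(old_wheel) as old, zipfile.ZipFile(new_wheel) as new:
            old_names, new_names = set(old.namelist()), set(new.namelist())
            for name in sorted(old_names | new_names):
                if not name.endswith(".py"):
                    continue  # compare sources only, skip dist-info metadata
                a = old.read(name).decode().splitlines() if name in old_names else []
                b = new.read(name).decode().splitlines() if name in new_names else []
                out += difflib.unified_diff(a, b, fromfile=name, tofile=name, lineterm="")
        return "\n".join(out)

    print(wheel_diff("torchzero-0.3.10-py3-none-any.whl", "torchzero-0.3.11-py3-none-any.whl"))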
torchzero/modules/ops/utility.py
CHANGED
@@ -6,36 +6,35 @@ from ...core import Module, Target, Transform
 from ...utils.tensorlist import Distributions, TensorList


-class Clone(
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings): return [t.clone() for t in tensors]
-
-class Grad(Module):
+class Clone(Module):
+    """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [
+        var.update = [u.clone() for u in var.get_update()]
         return var

-class
+class Grad(Module):
+    """Outputs the gradient"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [
+        var.update = [g.clone() for g in var.get_grad()]
         return var

-class
+class Params(Module):
+    """Outputs parameters"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [
+        var.update = [p.clone() for p in var.params]
         return var

 class Zeros(Module):
+    """Outputs zeros"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
@@ -44,6 +43,7 @@ class Zeros(Module):
         return var

 class Ones(Module):
+    """Outputs ones"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
@@ -52,6 +52,7 @@ class Ones(Module):
         return var

 class Fill(Module):
+    """Outputs tensors filled with :code:`value`"""
     def __init__(self, value: float):
         defaults = dict(value=value)
         super().__init__(defaults)
@@ -62,6 +63,7 @@ class Fill(Module):
         return var

 class RandomSample(Module):
+    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
     def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
         defaults = dict(eps=eps, distribution=distribution)
         super().__init__(defaults)
@@ -74,6 +76,7 @@ class RandomSample(Module):
         return var

 class Randn(Module):
+    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
     def __init__(self):
         super().__init__({})

@@ -83,6 +86,7 @@ class Randn(Module):
         return var

 class Uniform(Module):
+    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
     def __init__(self, low: float, high: float):
         defaults = dict(low=low, high=high)
         super().__init__(defaults)
@@ -94,19 +98,23 @@ class Uniform(Module):
         return var

 class GradToNone(Module):
+    """Sets :code:`grad` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
     def step(self, var):
         var.grad = None
         return var

 class UpdateToNone(Module):
+    """Sets :code:`update` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
     def step(self, var):
         var.update = None
         return var

 class Identity(Module):
+    """A placeholder identity operator that is argument-insensitive."""
     def __init__(self, *args, **kwargs): super().__init__()
     def step(self, var): return var

-NoOp = Identity
+NoOp = Identity
+"""A placeholder identity operator that is argument-insensitive."""
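The hunk above mainly adds one-line docstrings to the utility ops. As a hedged illustration of how these ops slot into a modular optimizer (the tz.Modular / tz.m.LR pattern is taken from the AdaHessian docstring later in this diff; exposing Clone and Grad as tz.m.Clone / tz.m.Grad is an assumption):

    # Assumed usage sketch: Grad outputs the gradient, Clone copies it so that
    # in-place modules downstream don't modify the stored tensors, LR scales the step.
    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    opt = tz.Modular(
        model.parameters(),
        tz.m.Grad(),   # assumed export path for the Grad module above
        tz.m.Clone(),  # assumed export path for the Clone module above
        tz.m.LR(0.1),
    )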
torchzero/modules/optimizers/__init__.py
CHANGED

@@ -1,7 +1,18 @@
 from .adagrad import Adagrad, FullMatrixAdagrad
+
+# from .curveball import CurveBall
+# from .spectral import SpectralPreconditioner
+from .adahessian import AdaHessian
 from .adam import Adam
+from .adan import Adan
+from .adaptive_heavyball import AdaptiveHeavyBall
+from .esgd import ESGD
+from .ladagrad import LMAdagrad
 from .lion import Lion
+from .mars import MARSCorrection
+from .msam import MSAM, MSAMObjective
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+from .orthograd import OrthoGrad, orthograd_
 from .rmsprop import RMSprop
 from .rprop import (
     BacktrackOnSignChange,
@@ -10,9 +21,7 @@ from .rprop import (
     SignConsistencyLRs,
     SignConsistencyMask,
 )
+from .sam import ASAM, SAM
 from .shampoo import Shampoo
 from .soap import SOAP
-from .orthograd import OrthoGrad, orthograd_
 from .sophia_h import SophiaH
-# from .curveball import CurveBall
-# from .spectral import SpectralPreconditioner
torchzero/modules/optimizers/adagrad.py
CHANGED

@@ -25,6 +25,7 @@ def adagrad_(
     step: int,
     pow: float = 2,
     use_sqrt: bool = True,
+    divide: bool = False,

     # inner args
     inner: Module | None = None,
@@ -40,6 +41,8 @@ def adagrad_(
         assert params is not None
         tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))

+    if divide: sq_sum_ = sq_sum_ / max(step, 1)
+
     if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
     else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)

@@ -48,7 +51,9 @@ def adagrad_(


 class Adagrad(Transform):
-    """Adagrad, divides by sum of past squares of gradients
+    """Adagrad, divides by sum of past squares of gradients.
+
+    This implementation is identical to :code:`torch.optim.Adagrad`.

     Args:
         lr_decay (float, optional): learning rate decay. Defaults to 0.
@@ -67,23 +72,24 @@ class Adagrad(Transform):
         alpha: float = 1,
         pow: float = 2,
         use_sqrt: bool = True,
+        divide: bool=False,
         inner: Chainable | None = None,
     ):
         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                        eps = eps, pow=pow, use_sqrt = use_sqrt)
+                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide)
         super().__init__(defaults=defaults, uses_grad=False)

         if inner is not None:
             self.set_child('inner', inner)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

         lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)

-        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(settings[0])
+        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])

         sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)

@@ -100,6 +106,7 @@ class Adagrad(Transform):
             step=self.global_state["step"],
             pow=pow,
             use_sqrt=use_sqrt,
+            divide=divide,

             # inner args
             inner=self.children.get("inner", None),
@@ -110,17 +117,17 @@ class Adagrad(Transform):


 class FullMatrixAdagrad(TensorwiseTransform):
-    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=
-        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=True, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', divide: bool=False, inner: Chainable | None = None):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init, divide=divide)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner,)

     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state,
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
         G = tensor.ravel()
         GG = torch.outer(G, G)
-        decay =
-        beta =
-        init =
+        decay = setting['decay']
+        beta = setting['beta']
+        init = setting['init']

         if 'GG' not in state:
             if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
@@ -132,11 +139,14 @@ class FullMatrixAdagrad(TensorwiseTransform):

         if beta is not None: state['GG'].lerp_(GG, 1-beta)
         else: state['GG'].add_(GG)
+        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum

     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state,
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         GG = state['GG']
-        sqrt =
+        sqrt = setting['sqrt']
+        divide = setting['divide']
+        if divide: GG = GG/state.get('i', 1)

         if tensor.numel() == 1:
             GG = GG.squeeze()
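Both Adagrad and FullMatrixAdagrad gain a divide flag in 0.3.11; per the hunks above it divides the accumulated sum of squares (or the accumulated GGᵀ, tracked together with the step counter state['i']) by the number of accumulated steps, so the preconditioner uses a mean rather than a sum of past squared gradients. A hedged usage sketch (tz.m.Adagrad as the export path is an assumption; the tz.Modular / tz.m.LR pattern appears in the docstrings in this diff):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adagrad(divide=True),  # scale by the mean of past squared gradients instead of their sum
        tz.m.LR(0.01),
    )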
torchzero/modules/optimizers/adahessian.py
ADDED

@@ -0,0 +1,223 @@
+import math
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+
+def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
+    """averages x over first dimension in blocks"""
+    if enable and x.ndim >= 2:
+        if math.prod(x.shape[1:]) <= 1: return x
+        size = x.size(0)
+        if block_size is None: return x.mean(0, keepdim=True)
+
+        n_blocks = size // block_size
+        if n_blocks <= 1: return x.mean(0, keepdim = True)
+
+        n_remaining = size - n_blocks * block_size
+        remaining = None
+        if n_remaining > 0:
+            remaining = x[-n_remaining:].mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+            x = x[:-n_remaining]
+
+        x = x.view(block_size, n_blocks, *x.shape[1:])
+        x_mean = x.mean(0).repeat_interleave(block_size, 0)
+
+        if remaining is None: return x_mean
+        return torch.cat([x_mean, remaining], 0)
+
+    return x
+
+def _rademacher_like(tensor, p = 0.5, generator = None):
+    """p is probability of a 1, other values will be -1."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
+
+def adahessian(
+    tensors: TensorList,
+    D: TensorList | None,
+    exp_avg_: TensorList,
+    D_exp_avg_sq_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    update_freq: int,
+    eps: float | NumberList,
+    step: int,
+):
+    # momentum
+    exp_avg_.lerp_(tensors, 1-beta1)
+    num = exp_avg_ / (1-beta1)
+
+    # update preconditioner
+    if step % update_freq == 0:
+        assert D is not None
+        D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
+
+    else:
+        assert D is None
+
+    denom = (D_exp_avg_sq_ / (1-beta2)).sqrt_().add_(eps)
+
+    return num.div_(denom)
+
+
+class AdaHessian(Module):
+    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)
+
+    This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
+
+    .. note::
+        In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply AdaHessian preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
+        averaging (bool, optional):
+            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
+            This can be set per-parameter in param groups.
+        block_size (int, optional):
+            size of block in the block-diagonal averaging.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product.
+            This value can be increased to reduce computational cost. Defaults to 1.
+        eps (float, optional):
+            division stability epsilon. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner module. If this is specified, operations are performed in the following order.
+            1. compute hessian diagonal estimate.
+            2. pass inputs to :code:`inner`.
+            3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
+
+    Examples:
+        Using AdaHessian:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.AdaHessian(),
+                tz.m.LR(0.1)
+            )
+
+        AdaHessian preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+        Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+        AdaHessian preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+                tz.m.LR(0.1)
+            )
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        averaging: bool = False,
+        block_size: int | None = 9,
+        update_freq: int = 1,
+        eps: float = 1e-8,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        beta1, beta2, eps, averaging, block_size = self.get_settings(params,
+            'beta1', 'beta2', 'eps', 'averaging', 'block_size', cls=NumberList)
+
+        exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad=None
+            for i in range(n_samples):
+                u = [_rademacher_like(p, generator=generator) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            assert D is not None
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update = adahessian(
+            tensors=TensorList(update),
+            D=TensorList(D) if D is not None else None,
+            exp_avg_=exp_avg,
+            D_exp_avg_sq_=D_exp_avg_sq,
+            beta1=beta1,
+            beta2=beta2,
+            update_freq=update_freq,
+            eps=eps,
+            step=step,
+        )
+        return var
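For reference, the HVP loop in step above is a Hutchinson-style diagonal estimate (this summary is a reading of the code, not text from the package): with Rademacher vectors u_i,

    \operatorname{diag}(H) \;\approx\; D \;=\; \frac{1}{n}\sum_{i=1}^{n} u_i \odot (H u_i),
    \qquad u_i \in \{-1,+1\}^d, \quad \mathbb{E}\bigl[u_i u_i^\top\bigr] = I,

and the adahessian helper then uses D exactly where Adam uses the gradient's second moment:

    m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad
    s_t = \beta_2 s_{t-1} + (1-\beta_2)\,D_t^{2}, \qquad
    \Delta_t = \frac{m_t/(1-\beta_1)}{\sqrt{s_t/(1-\beta_2)} + \epsilon}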
torchzero/modules/optimizers/adam.py
CHANGED

@@ -10,7 +10,7 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..
+from ..step_size.lr import lazy_lr
 from ..momentum.experimental import sqrt_nag_ema_sq_
 from ..momentum.momentum import nag_

@@ -33,7 +33,7 @@ def adam_(
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
 ):
-    """Returns new tensors
+    """Returns new tensors."""
     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                    debiased=False,step=step,pow=pow)

@@ -43,11 +43,12 @@ def adam_(

     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-    return (exp_avg_ / sqrt_exp_avg_sq.add_(eps))
+    return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))

 class Adam(Transform):
-    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
-
+    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
+
+    This implementation is identical to :code:`torch.optim.Adam`.

     Args:
         beta1 (float, optional): momentum. Defaults to 0.9.
@@ -75,7 +76,7 @@ class Adam(Transform):
         if inner is not None: self.set_child('inner', inner)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

         beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
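The one behavioural change in adam_ above is that the step-size factor alpha is now folded into the returned update via lazy_mul. When debiased is set, alpha is presumably the usual Adam bias-corrected step size (standard formula, stated here as an assumption about debiased_step_size rather than quoted from the package):

    \alpha_t \;=\; \alpha \cdot \frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}}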
torchzero/modules/optimizers/adan.py
ADDED

@@ -0,0 +1,110 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+def adan_(
+    g: TensorList,
+    g_prev_: TensorList,
+    m_: TensorList, # exponential moving average
+    v_: TensorList, # exponential moving average of gradient differences
+    n_: TensorList, # kinda like squared momentum
+    n_prev_: TensorList | None,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    use_n_prev: bool,
+):
+    """Returns new tensors."""
+    m_.lerp_(g, 1-beta1)
+
+    y = g - g_prev_
+    v_.lerp_(y, 1-beta2)
+
+    y.mul_(1-beta2).add_(g)
+    n_.mul_(beta3).addcmul_(y, y, 1-beta3)
+
+    if use_n_prev:
+        assert n_prev_ is not None
+        ns = n_prev_.clone()
+        n_prev_.copy_(n_)
+        n_ = ns
+
+    eta = n_.sqrt().add_(eps).reciprocal_()
+    term = m_ + (1-beta2)*v_
+    update = eta.mul_(term)
+
+    g_prev_.copy_(g)
+
+    return update
+
+
+class Adan(Transform):
+    """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.98.
+        beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
+        beta3 (float, optional): thrid (squared) momentum. Defaults to 0.99.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        use_n_prev (bool, optional):
+            whether to use previous gradient differences momentum.
+
+    Reference:
+        Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
+    """
+    def __init__(
+        self,
+        beta1: float = 0.98,
+        beta2: float = 0.92,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+        use_n_prev: bool = False,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,use_n_prev=use_n_prev)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
+        s = settings[0]
+        use_n_prev = s['use_n_prev']
+
+        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)
+
+
+        if use_n_prev:
+            n_prev = unpack_states(states, tensors, 'n_prev', cls=TensorList)
+        else:
+            n_prev = None
+
+        if step == 1:
+            # initial values, also runs on restarts
+            m.copy_(tensors)
+            n.set_(tensors ** 2)
+            v.zero_()
+            g_prev.copy_(tensors)
+            if n_prev is not None: n_prev.set_(tensors ** 2)
+
+        if step == 2:
+            v.set_(tensors - g_prev)
+
+        update = adan_(
+            g=tensors,
+            g_prev_=g_prev,
+            m_=m,
+            v_=v,
+            n_=n,
+            n_prev_=n_prev,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            use_n_prev=use_n_prev,
+        )
+
+        return update
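In update-rule form, the adan_ helper above computes the following (a direct reading of the code; the learning rate itself is applied by a separate module such as tz.m.LR):

    m_k = \beta_1 m_{k-1} + (1-\beta_1)\,g_k
    v_k = \beta_2 v_{k-1} + (1-\beta_2)\,(g_k - g_{k-1})
    n_k = \beta_3 n_{k-1} + (1-\beta_3)\,\bigl[g_k + (1-\beta_2)(g_k - g_{k-1})\bigr]^2
    \Delta_k = \frac{m_k + (1-\beta_2)\,v_k}{\sqrt{n_k} + \epsilon}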
torchzero/modules/optimizers/adaptive_heavyball.py
ADDED

@@ -0,0 +1,57 @@
+import torch
+from ...core import Transform
+from ...utils import TensorList, unpack_dicts, unpack_states
+
+
+def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
+    if f - f_star <= torch.finfo(p[0].dtype).eps: return g
+
+    g_g = g.dot(g)
+    g_gp = g.dot(g_prev)
+    num = -(f - f_star) * g.dot(g_prev)
+    denom = (f_prev - f_star) * g_g + (f - f_star) * g_gp
+    m = num/denom
+
+    h = 2*(f - f_star) / g_g
+    return (1 + m) * h * g - m*(p-p_prev)
+
+
+class AdaptiveHeavyBall(Transform):
+    """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.
+
+    This is related to conjugate gradient methods, it may be very good for non-stochastic convex objectives, but won't work on stochastic ones.
+
+    .. note::
+        The step size is determined by the algorithm, so learning rate modules shouldn't be used.
+
+    Args:
+        f_star (int, optional):
+            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+        tol (float, optional):
+            tolerance on objective value change.
+    """
+    def __init__(self, f_star: float = 0):
+        defaults = dict(f_star=f_star)
+        super().__init__(defaults, uses_grad=False, uses_loss=True)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        tensors = TensorList(tensors)
+        setting = settings[0]
+        f_star = setting['f_star']
+
+        f_prev = self.global_state.get('f_prev', None)
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)
+
+        if f_prev is None:
+            self.global_state['f_prev'] = loss
+            h = 2*(loss - f_star) / tensors.dot(tensors)
+            return h * tensors
+
+        update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
+
+        self.global_state['f_prev'] = loss
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return update