torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/cosine.py
@@ -0,0 +1,214 @@
+"""A bunch of useless modules that I hate and that didn't work"""
+import torch
+
+from ...core import Chainable, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+
+
+class CosineStepSize(Transform):
+    """Adaptive step size based on cosine similarity
+
+    VERDICT: Useless. This is too unstable.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+        init (float, optional): initial step size. Defaults to 1.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        target_cossim (float, optional): cosine similarity needs to be above this to increase step size. Defaults to 1e-8.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before step size correction. Defaults to None.
+    """
+    def __init__(self, scale:float = 0.95, init:float=1, eps:float=1e-12, inner:Chainable | None = None):
+        defaults = dict(scale=scale, init=init, eps=eps)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, init = unpack_dicts(settings, 'scale', 'init', cls=NumberList)
+        unpack_states(states, tensors, 'alpha', init=init, cls=NumberList) # initializes alpha to init
+        eps = settings[0]['eps']
+
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        new_alpha = []
+        for s, sc in zip(states, scale):
+            s['alpha'] *= 1 + cos_sim * sc
+            new_alpha.append(s['alpha'])
+
+        tensors.mul_(new_alpha)
+        prev.copy_(tensors)
+
+        return tensors
+
+
+
+class CosineDebounce(Transform):
+    """Debouncing when cosine similarity is less than 0.
+
+    VERDICT: Useless. This doesn't help at all.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 0.95.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before debouncing correction. Defaults to None.
+    """
+    def __init__(self, scale:float = 0.95, eps:float=1e-12, damping:float=0.95, inner:Chainable | None = None):
+        defaults = dict(scale=scale, eps=eps, damping=damping)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+        eps = settings[0]['eps']
+
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList).mul_(damping)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(prev) / (tensors_norm * prev.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        if cos_sim < -eps:
+            undo = prev.neg().mul_(-cos_sim * scale)
+            comb = prev.graft(tensors).add_(tensors).graft_(prev).mul_(-cos_sim*scale)
+            tensors = undo.add_(comb)
+
+        prev.copy_(tensors)
+        return tensors
+
+
+
+class CosineMomentum(Transform):
+    """Beta depends on cosine similarity. At cossim=1, beta is 0. At cossim=-1, beta is 2^power. This basically removes oscillations.
+
+    VERDICT: Useless. Worse than all other momentums.
+
+    Args:
+        scale (float, optional): cosine similarity multiplier. Defaults to 1.
+        nesterov (float, optional): whether to use nesterov momentum. Defaults to False.
+        power (float, optional): power for beta. Defaults to 1.
+        eps (float, optional): epsilon for division stability. Defaults to 1e-12.
+        inner (Chainable | None, optional):
+            inner modules applied after calculating cosine similarity and before updating exponential moving average. Defaults to None.
+    """
+    def __init__(self, scale:float = 1, nesterov: bool = False, power: float = 1, eps:float=1e-12, inner:Chainable | None = None):
+        defaults = dict(scale=scale, eps=eps, nesterov=nesterov, power=power)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale, power = unpack_dicts(settings, 'scale', 'power', cls=NumberList)
+        eps = settings[0]['eps']
+        nesterov = settings[0]['nesterov']
+        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
+
+        tensors = as_tensorlist(tensors)
+
+        tensors_norm = tensors.global_vector_norm()
+        cos_sim = (tensors.dot(exp_avg) / (tensors_norm * exp_avg.global_vector_norm()).clip(min=eps)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        beta = (1 - (cos_sim*scale)) ** power
+        if nesterov:
+            exp_avg.add_(tensors.mul(beta))
+            return tensors.add_(exp_avg)
+        else:
+            exp_avg.add_(tensors.mul_(beta))
+            return exp_avg.clone()
+
+
+class AdaptiveDifference(Transform):
+    """VERDICT: Useless. Doesn't help (sort of to be expected)."""
+    def __init__(self, inner:Chainable | None = None):
+        defaults = dict()
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        prev = unpack_states(states, tensors, 'prev', init=tensors, cls=TensorList)
+
+        diff = tensors - prev.graft_(tensors)
+        prev.copy_(tensors)
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        tensors.add_(diff.graft_(tensors))
+
+        return tensors
+
+class AdaptiveDifferenceEMA(Transform):
+    """VERDICT: better than non-EMA but still useless."""
+    def __init__(self, beta=0.99, inner:Chainable | None = None):
+        defaults = dict(beta=beta)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
+        prev, diff_exp_avg = unpack_states(states, tensors, 'prev', 'diff_exp_avg', init=[tensors,torch.zeros_like], cls=TensorList)
+
+        diff = (tensors - prev.graft_(tensors)).graft_(tensors)
+        diff_exp_avg.lerp_(diff, 1-beta)
+        prev.copy_(tensors)
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        tensors.add_(diff_exp_avg.graft(tensors))
+
+        return tensors
+
+
+class ScaledAdaptiveDifference(Transform):
+    """VERDICT: Useless and doesn't help."""
+    def __init__(self, scale=0.95, damping:float=0.99, inner:Chainable | None = None):
+        defaults = dict(scale=scale, damping=damping)
+        super().__init__(defaults, uses_grad=False)
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        scale, damping = unpack_dicts(settings, 'scale', 'damping', cls=NumberList)
+        prev_tensors, prev_update = unpack_states(states, tensors, 'prev', 'prev_update', init=[tensors,tensors], cls=TensorList)
+
+        cos_sim = (tensors.dot(prev_update) / (tensors.global_vector_norm() * prev_update.global_vector_norm()).clip(min=1e-10)).item()
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads, loss))
+
+        if cos_sim > 0:
+            tensors.add_(prev_tensors*(cos_sim*scale))
+
+        else:
+            undo = prev_tensors.neg().mul_(-cos_sim*scale)
+            comb = prev_tensors.graft(tensors).add_(tensors).graft_(prev_tensors).mul_(-cos_sim*scale)
+            tensors = undo.add_(comb).graft_((tensors-prev_tensors).mul_(damping))
+
+        diff = tensors - prev_tensors.graft_(tensors)
+        prev_tensors.copy_(tensors)
+        diff.graft_(tensors)
+        tensors.add_(diff)
+        prev_update.copy_(tensors)
+
+        return tensors
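A minimal standalone sketch of the rule CosineMomentum implements above: beta is derived from the cosine similarity between the incoming update and the running average, so fully aligned updates add no extra momentum and fully reversed ones add up to 2^power. Plain PyTorch on a flat tensor; the helper name and usage below are illustrative only, not torchzero's API.

    import torch

    def cosine_momentum_step(update, exp_avg, scale=1.0, power=1.0, eps=1e-12):
        # cosine similarity between the new update and the running average
        cos_sim = torch.dot(update, exp_avg) / (update.norm() * exp_avg.norm()).clamp(min=eps)
        # cos_sim = 1 -> beta = 0 (aligned, nothing added); cos_sim = -1 -> beta = 2**power
        beta = (1 - cos_sim * scale) ** power
        exp_avg.add_(update * beta)   # accumulate the scaled update
        return exp_avg.clone()        # the returned step (non-Nesterov branch)

    update, exp_avg = torch.randn(10), torch.zeros(10)
    step = cosine_momentum_step(update, exp_avg)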
torchzero/modules/experimental/cubic_adam.py
@@ -0,0 +1,97 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+
+def signed_cbrt(x: TensorList) -> TensorList:
+    return x.sign() * x.abs().pow(1/3)
+
+def cubic_adam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    exp_avg_cu_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    debiased: bool,
+    step: int,
+):
+    exp_avg_.lerp_(tensors, 1-beta1)
+    exp_avg_sq_.lerp_(tensors**2, 1-beta2)
+    exp_avg_cu_.lerp_(tensors**3, 1-beta3)
+
+    if debiased:
+        m1 = exp_avg_ / (1 - beta1 ** step)
+        m2 = exp_avg_sq_ / (1 - beta2 ** step)
+        m3 = exp_avg_cu_ / (1 - beta3 ** step)
+    else:
+        m1, m2, m3 = exp_avg_, exp_avg_sq_, exp_avg_cu_
+
+    # adam minimizes ax^2 + bx
+    # we are going to minimize ax^3 + bx^2 + cx
+    A = signed_cbrt(m3)
+    B = m2.sqrt()
+    C = m1
+    discriminant = B.pow(2) - 4 * A * C
+
+    denom = 2 * A
+    root = discriminant.clamp(min=0).sqrt_()
+
+    x0 = (-B + root) / (denom + eps)
+    x1 = (-B - root) / (denom + eps)
+
+    f0 = (A/3)*x0**3 + (B/2)*x0**2 + C*x0
+    f1 = (A/3)*x1**3 + (B/2)*x1**2 + C*x1
+
+    x_star = x0.where(f0 < f1, x1)
+
+    adam = -C / (B + eps)
+    x_star = adam.where(discriminant < 0, x_star)
+
+    return x_star.mul_(-alpha)
+
+class CubicAdam(Transform):
+    """Adam which has 3rd momentum and minimizes a cubic polynomial.
+
+    VERDICT: can outperform Adam very slightly. Usually very similar performance.
+
+    .. warning::
+        Experimental.
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.99,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+        debiased:bool=True,
+        alpha: float = 1.,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,debiased=debiased,alpha=alpha)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,beta3,eps,alpha=unpack_dicts(settings, 'beta1','beta2','beta3','eps','alpha', cls=NumberList)
+        exp_avg, exp_avg_sq, exp_avg_cu = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'exp_avg_cu', cls=TensorList)
+
+        return cubic_adam_(
+            tensors=TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            exp_avg_cu_=exp_avg_cu,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            debiased=settings[0]['debiased'],
+            step=step,
+        )
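The comment inside cubic_adam_ above is the whole idea: where Adam implicitly minimizes a quadratic model per coordinate, this minimizes a cubic one built from the first three moments. A self-contained sketch of that per-element step on plain tensors (the function name and standalone form are illustrative only):

    import torch

    def cubic_step(m1, m2, m3, eps=1e-8):
        # minimize (A/3)x^3 + (B/2)x^2 + C*x elementwise with
        # A = signed cbrt of the third moment, B = sqrt of the second, C = the first
        A = m3.sign() * m3.abs().pow(1 / 3)
        B = m2.sqrt()
        C = m1
        disc = B.pow(2) - 4 * A * C            # discriminant of the derivative A*x^2 + B*x + C
        root = disc.clamp(min=0).sqrt()
        x0 = (-B + root) / (2 * A + eps)
        x1 = (-B - root) / (2 * A + eps)
        f = lambda x: (A / 3) * x**3 + (B / 2) * x**2 + C * x
        x_star = torch.where(f(x0) < f(x1), x0, x1)
        # no real stationary point -> fall back to the plain Adam step -m1 / (sqrt(m2) + eps)
        return torch.where(disc < 0, -C / (B + eps), x_star)

    g = torch.randn(100)
    print(cubic_step(g, g**2, g**3)[:5])   # single-sample "moments", just to run it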
torchzero/modules/{projections → experimental}/dct.py
@@ -1,13 +1,13 @@
 from typing import Literal
 import torch
 import torch_dct
-from
+from ..projections import ProjectionBase
 from ...core import Chainable
 
 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))
 
-class DCTProjection(
+class DCTProjection(ProjectionBase):
     # norm description copied from pytorch docstring
     """Project update into Discrete Cosine Transform space, requires `torch_dct` library.
 
@@ -34,8 +34,8 @@ class DCTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors,
-        settings =
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         dims = settings['dims']
         norm = settings['norm']
 
@@ -54,18 +54,18 @@ class DCTProjection(Projection):
         return projected
 
     @torch.no_grad
-    def unproject(self,
-        settings =
+    def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
+        settings = projected_settings[0]
        dims = settings['dims']
         norm = settings['norm']
 
         unprojected = []
-        for
-            dim = min(
+        for t in projected_tensors:
+            dim = min(t.ndim, dims)
 
-            if dim == 1: idct = torch_dct.idct(
-            elif dim == 2: idct = torch_dct.idct_2d(
-            elif dim == 3: idct = torch_dct.idct_3d(
+            if dim == 1: idct = torch_dct.idct(t, norm = norm)
+            elif dim == 2: idct = torch_dct.idct_2d(t, norm=norm)
+            elif dim == 3: idct = torch_dct.idct_3d(t, norm=norm)
             else: raise ValueError(f"Unsupported number of dimensions {dim}")
 
             unprojected.append(reverse_dims(idct))
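For reference, the transform pair DCTProjection wraps is a plain forward/inverse DCT from the torch_dct library. Only the idct_* side appears in the hunk, so the forward dct_2d call below is assumed from that same library, and this round trip is just an illustration, not the ProjectionBase interface:

    import torch
    import torch_dct

    update = torch.randn(8, 8)
    coeffs = torch_dct.dct_2d(update, norm='ortho')     # "project" the update into DCT space
    restored = torch_dct.idct_2d(coeffs, norm='ortho')  # "unproject" back to parameter space
    assert torch.allclose(update, restored, atol=1e-5)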
torchzero/modules/experimental/eigendescent.py
@@ -23,7 +23,10 @@ def _cosine_similarity(x, y):
 
 class EigenDescent(Module):
     """
-    Uses eigenvectors corresponding to certain eigenvalues.
+    Uses eigenvectors corresponding to certain eigenvalues. For now they are just extracted from hessian.
+
+    .. warning::
+        Experimental.
 
     Args:
         mode (str, optional):
torchzero/modules/experimental/etf.py
@@ -4,13 +4,17 @@ import warnings
 import torch
 
 from ...core import Module
-from ...utils import vec_to_tensors, vec_to_tensors_
+from ...utils import vec_to_tensors, vec_to_tensors_, as_tensorlist
 
 
 class ExponentialTrajectoryFit(Module):
-    """A method.
-
-
+    """A method.
+
+    .. warning::
+        Experimental.
+    """
+    def __init__(self, step_size=1e-2, adaptive:bool=True):
+        defaults = dict(step_size = step_size,adaptive=adaptive)
         super().__init__(defaults)
 
     @torch.no_grad
@@ -18,11 +22,17 @@ class ExponentialTrajectoryFit(Module):
         closure = var.closure
         assert closure is not None
         step_size = self.settings[var.params[0]]['step_size']
+        adaptive = self.settings[var.params[0]]['adaptive']
+
 
         # 1. perform 3 GD steps to obtain 4 points
         points = [torch.cat([p.view(-1) for p in var.params])]
         for i in range(3):
-            if i == 0:
+            if i == 0:
+                grad = var.get_grad()
+                if adaptive:
+                    step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
+
             else:
                 with torch.enable_grad(): closure()
                 grad = [cast(torch.Tensor, p.grad) for p in var.params]
@@ -67,9 +77,14 @@ class ExponentialTrajectoryFit(Module):
 
 
 class ExponentialTrajectoryFitV2(Module):
-    """Should be better than one above, except it isn't.
-
-
+    """Should be better than one above, except it isn't.
+
+    .. warning::
+        Experimental.
+
+    """
+    def __init__(self, step_size=1e-3, num_steps: int= 4, adaptive:bool=True):
+        defaults = dict(step_size = step_size, num_steps=num_steps, adaptive=adaptive)
         super().__init__(defaults)
 
     @torch.no_grad
@@ -78,9 +93,13 @@ class ExponentialTrajectoryFitV2(Module):
         assert closure is not None
         step_size = self.settings[var.params[0]]['step_size']
         num_steps = self.settings[var.params[0]]['num_steps']
+        adaptive = self.settings[var.params[0]]['adaptive']
 
         # 1. perform 3 GD steps to obtain 4 points (or more)
         grad = var.get_grad()
+        if adaptive:
+            step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
+
         points = [torch.cat([p.view(-1) for p in var.params])]
         point_grads = [torch.cat([g.view(-1) for g in grad])]
 
@@ -132,7 +151,11 @@ def _fit_exponential(y0, y1, y2):
     return A, B, r
 
 class PointwiseExponential(Module):
-    """A stupid method (for my youtube channel).
+    """A stupid method (for my youtube channel).
+
+    .. warning::
+        Experimental.
+    """
     def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
         defaults = dict(reg=reg, steps=steps, step_size=step_size)
         super().__init__(defaults)
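The adaptive flag added above simply rescales the probe step size by the mean absolute gradient entry (clipped away from zero), so the initial GD steps have a roughly problem-independent length. A hypothetical standalone version of that one line, in plain PyTorch:

    import torch

    def adaptive_step_size(step_size, grads, min_scale=1e-4):
        # divide the nominal step size by the global mean |g|, as in the hunk above
        mean_abs = torch.cat([g.abs().flatten() for g in grads]).mean()
        return step_size / mean_abs.clamp(min=min_scale)

    grads = [torch.randn(3, 3), torch.randn(5)]
    print(adaptive_step_size(1e-2, grads))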
torchzero/modules/experimental/exp_adam.py
@@ -0,0 +1,113 @@
+from operator import itemgetter
+from functools import partial
+import math
+import torch
+
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..step_size.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def exp_adam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_exp_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    debiased: bool = True,
+    max_exp_avg_exp_: TensorList | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+):
+    """Returns new tensors."""
+    tensors_exp = tensors.abs().clip_(max=math.log(torch.finfo(tensors[0].dtype).max) / 2).exp_()
+    exp_avg_exp_.lerp_(tensors_exp, 1-beta2)
+
+    if max_exp_avg_exp_ is not None:
+        max_exp_avg_exp_.maximum_(exp_avg_exp_)
+        exp_avg_exp_ = max_exp_avg_exp_
+
+    if inner is not None:
+        assert params is not None
+        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
+
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+    return (exp_avg_.lazy_mul(alpha) / exp_avg_exp_.log().add_(eps))
+
+class ExpAdam(Transform):
+    """Adam but uses abs exp and log instead of square and sqrt.
+    The gradient will be clipped to half the maximum value representable by its dtype (around 50 for float32)
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum. Defaults to 0.999.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        alpha (float, optional): learning rate. Defaults to 1.
+        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+        pow (float, optional): power used in second momentum power and root. Defaults to 2.
+        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        eps: float = 1e-8,
+        amsgrad: bool = False,
+        alpha: float = 1.,
+        pow: float = 2,
+        debiased: bool = True,
+        inner: Chainable | None = None
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None: self.set_child('inner', inner)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
+
+        if amsgrad:
+            exp_avg, exp_avg_exp, max_exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', 'max_exp_avg_exp', cls=TensorList)
+        else:
+            exp_avg, exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', cls=TensorList)
+            max_exp_avg_exp = None
+
+
+        return exp_adam_(
+            tensors=TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_exp_=exp_avg_exp,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            step=step,
+            pow=pow,
+            debiased=debiased,
+            max_exp_avg_exp_=max_exp_avg_exp,
+
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+
+        )
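ExpAdam's only departure from Adam is the preconditioner: an EMA of exp(|g|) whose log replaces the sqrt of the EMA of g^2, with |g| clipped so the exponential cannot overflow. A minimal sketch of just that denominator in plain PyTorch (the function and the SGD-style use at the end are illustrative assumptions, not torchzero's API):

    import math
    import torch

    def exp_adam_denominator(grad, exp_avg_exp, beta2=0.999, eps=1e-8):
        # clip |g| to half of log(dtype max) (~44 for float32) so exp() stays finite
        max_abs = math.log(torch.finfo(grad.dtype).max) / 2
        exp_avg_exp.lerp_(grad.abs().clamp(max=max_abs).exp(), 1 - beta2)
        # log of the EMA of exp(|g|) plays the role of sqrt(EMA(g^2)) in Adam
        return exp_avg_exp.log().add(eps)

    grad = torch.randn(5)
    state = torch.ones(5)                 # EMA of exp(|g|), started at exp(0) = 1
    update = -0.001 * grad / exp_adam_denominator(grad, state)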