torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/misc.py
DELETED
@@ -1,418 +0,0 @@
-from collections import deque
-from collections.abc import Iterable
-from operator import itemgetter
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
-from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
-
-
-class Previous(TensorwiseTransform):
-    """Maintains an update from n steps back, for example if n=1, returns previous update"""
-    def __init__(self, n=1, target: Target = 'update'):
-        defaults = dict(n=n)
-        super().__init__(uses_grad=False, defaults=defaults, target=target)
-
-
-    @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        n = settings['n']
-
-        if 'history' not in state:
-            state['history'] = deque(maxlen=n+1)
-
-        state['history'].append(tensor)
-
-        return state['history'][0]
-
-
-class LastDifference(Transform):
-    """Difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev_target') # initialized to 0
-        difference = torch._foreach_sub(tensors, prev)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return difference
-
-class LastGradDifference(Module):
-    """Difference between past two grads."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def step(self, var):
-        grad = var.get_grad()
-        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
-        difference = torch._foreach_sub(grad, prev_grad)
-        for p, c in zip(prev_grad, grad): p.set_(c)
-        var.update = list(difference)
-        return var
-
-
-class LastProduct(Transform):
-    """Difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
-        prod = torch._foreach_mul(tensors, prev)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return prod
-
-class LastRatio(Transform):
-    """Ratio between past two updates."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
-        defaults = dict(numerator=numerator)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-        numerator = settings[0]['numerator']
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-        else: ratio = torch._foreach_div(prev, tensors)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return ratio
-
-class LastAbsoluteRatio(Transform):
-    """Ratio between absolute values of past two updates."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
-        defaults = dict(numerator=numerator, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-        numerator = settings[0]['numerator']
-        eps = NumberList(s['eps'] for s in settings)
-
-        torch._foreach_abs_(tensors)
-        torch._foreach_clamp_min_(prev, eps)
-
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-        else: ratio = torch._foreach_div(prev, tensors)
-        for p, c in zip(prev, tensors): p.set_(c)
-        return ratio
-
-class GradSign(Transform):
-    """copy gradient sign to update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        return [t.copysign_(g) for t,g in zip(tensors, grads)]
-
-class UpdateSign(Transform):
-    """use per-weight magnitudes from grad while using sign from update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
-
-class GraftToGrad(Transform):
-    """use gradient norm and update direction."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
-
-class GraftGradToUpdate(Transform):
-    """use update norm and gradient direction."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
-
-
-class GraftToParams(Transform):
-    """makes update norm be set to parameter norm, but norm won't go below eps"""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
-
-class Relative(Transform):
-    """multiplies update by absolute parameter values to make it relative to their magnitude, min_value is minimum value to avoid getting stuck at 0"""
-    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
-        defaults = dict(min_value=min_value)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
-        torch._foreach_mul_(tensors, mul)
-        return tensors
-
-class FillLoss(Module):
-    """makes tensors filled with loss value times alpha"""
-    def __init__(self, alpha: float = 1, backward: bool = True):
-        defaults = dict(alpha=alpha, backward=backward)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        alpha = self.get_settings(var.params, 'alpha')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
-        return var
-
-class MulByLoss(Module):
-    """multiplies update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_mul_(var.update, mul)
-        return var
-
-class DivByLoss(Module):
-    """divides update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_div_(var.update, mul)
-        return var
-
-
-
-def _sequential_step(self: Module, var: Var, sequential: bool):
-    params = var.params
-    steps = self.settings[params[0]]['steps']
-
-    if sequential: modules = self.get_children_sequence()
-    else: modules = [self.children['module']] * steps
-
-    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
-
-    # store original params unless this is last module and can update params directly
-    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
-
-    # first step - pass var as usual
-    var = modules[0].step(var)
-    new_var = var
-
-    # subsequent steps - update parameters and create new var
-    if len(modules) > 1:
-        for m in modules[1:]:
-
-            # update params
-            if (not new_var.skip_update):
-                if new_var.last_module_lrs is not None:
-                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-                torch._foreach_sub_(params, new_var.get_update())
-
-            # create new var since we are at a new point, that means grad, update and loss will be None
-            new_var = Var(params=new_var.params, closure=new_var.closure,
-                          model=new_var.model, current_step=new_var.current_step + 1)
-
-            # step
-            new_var = m.step(new_var)
-
-    # final parameter update
-    if (not new_var.skip_update):
-        if new_var.last_module_lrs is not None:
-            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-        torch._foreach_sub_(params, new_var.get_update())
-
-    # if last module, update is applied so return new var
-    if params_before_steps is None:
-        new_var.stop = True
-        new_var.skip_update = True
-        return new_var
-
-    # otherwise use parameter difference as update
-    var.update = list(torch._foreach_sub(params_before_steps, params))
-    for p, bef in zip(params, params_before_steps):
-        p.set_(bef) # pyright:ignore[reportArgumentType]
-    return var
-
-class Multistep(Module):
-    def __init__(self, module: Chainable, steps: int):
-        defaults = dict(steps=steps)
-        super().__init__(defaults)
-        self.set_child('module', module)
-
-    @torch.no_grad
-    def step(self, var):
-        return _sequential_step(self, var, sequential=False)
-
-class Sequential(Module):
-    def __init__(self, modules: Iterable[Chainable], steps: int):
-        defaults = dict(steps=steps)
-        super().__init__(defaults)
-        self.set_children_sequence(modules)
-
-    @torch.no_grad
-    def step(self, var):
-        return _sequential_step(self, var, sequential=True)
-
-
-class GradientAccumulation(Module):
-    """gradient accumulation"""
-    def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
-        defaults = dict(n=n, mean=mean, stop=stop)
-        super().__init__(defaults)
-        self.set_child('modules', modules)
-
-
-    @torch.no_grad
-    def step(self, var):
-        accumulator = self.get_state(var.params, 'accumulator')
-        settings = self.settings[var.params[0]]
-        n = settings['n']; mean = settings['mean']; stop = settings['stop']
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        # add update to accumulator
-        torch._foreach_add_(accumulator, var.get_update())
-
-        # step with accumulated updates
-        if step % n == 0:
-            if mean:
-                torch._foreach_div_(accumulator, n)
-
-            var.update = [a.clone() for a in accumulator]
-            var = self.children['modules'].step(var)
-
-            # zero accumulator
-            torch._foreach_zero_(accumulator)
-
-        else:
-            # prevent update
-            if stop:
-                var.stop=True
-                var.skip_update=True
-
-        return var
-
-
-class Dropout(Transform):
-    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
-        defaults = dict(p=p, graft=graft)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        tensors = TensorList(tensors)
-        p = NumberList(s['p'] for s in settings)
-        graft = settings[0]['graft']
-
-        if graft:
-            target_norm = tensors.global_vector_norm()
-            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
-
-        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-
-class WeightDropout(Module):
-    """Applies dropout directly to weights."""
-    def __init__(self, p: float = 0.5, graft: bool = True):
-        defaults = dict(p=p, graft=graft)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(var.params)
-        p = NumberList(self.settings[p]['p'] for p in params)
-        mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
-
-        @torch.no_grad
-        def dropout_closure(backward=True):
-            orig_params = params.clone()
-            params.mul_(mask)
-            if backward:
-                with torch.enable_grad(): loss = closure()
-            else:
-                loss = closure(False)
-            params.copy_(orig_params)
-            return loss
-
-        var.closure = dropout_closure
-        return var
-
-class NoiseSign(Transform):
-    """uses random vector with update sign"""
-    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
-        defaults = dict(distribution=distribution, alpha=alpha)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
-        alpha = [s['alpha'] for s in settings]
-        distribution = self.settings[params[0]]['distribution']
-        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
-
-
-class NegateOnLossIncrease(Module):
-    def __init__(self, backtrack=True):
-        defaults = dict(backtrack=backtrack)
-        super().__init__(defaults=defaults)
-
-    @torch.no_grad
-    def step(self, var):
-        closure = var.closure
-        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-        backtrack = self.settings[var.params[0]]['backtrack']
-
-        update = var.get_update()
-        f_0 = var.get_loss(backward=False)
-
-        torch._foreach_sub_(var.params, update)
-        f_1 = closure(False)
-
-        if f_1 <= f_0:
-            if var.is_last and var.last_module_lrs is None:
-                var.stop = True
-                var.skip_update = True
-                return var

-            torch._foreach_add_(var.params, update)
-            return var
-
-        torch._foreach_add_(var.params, update)
-        if backtrack:
-            torch._foreach_neg_(var.update)
-        else:
-            torch._foreach_zero_(var.update)
-        return var
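The deleted `GradientAccumulation` module above buffers each step's update and only lets its child module act every n-th step, optionally averaging the buffer first. Below is a minimal, self-contained sketch of that bookkeeping in plain PyTorch; it is not the torchzero API, and `accumulate_and_step` plus the surrounding loop are hypothetical stand-ins for illustration.

```python
import torch

def accumulate_and_step(accumulator, update, step, n, mean=True):
    """Add `update` into `accumulator`; every n-th call return the
    (optionally averaged) buffer and reset it, otherwise return None."""
    torch._foreach_add_(accumulator, update)      # accumulate this micro-step
    if step % n != 0:
        return None                               # skip the parameter update
    if mean:
        torch._foreach_div_(accumulator, n)       # average over n micro-steps
    out = [a.clone() for a in accumulator]
    torch._foreach_zero_(accumulator)             # reset for the next cycle
    return out

# usage sketch with dummy parameters and updates
params = [torch.zeros(3)]
accumulator = [torch.zeros_like(p) for p in params]
for step in range(1, 7):
    update = [torch.ones_like(p) for p in params]  # stand-in for a real update
    agg = accumulate_and_step(accumulator, update, step, n=3)
    if agg is not None:
        torch._foreach_sub_(params, agg)           # apply the averaged update
```

In the deleted module the skipped steps additionally set `var.stop` and `var.skip_update`, so later modules in the chain leave the parameters untouched until the accumulator is flushed.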

torchzero/modules/quasi_newton/experimental/__init__.py
DELETED
@@ -1 +0,0 @@
-from .modular_lbfgs import ModularLBFGS

torchzero/modules/quasi_newton/olbfgs.py
DELETED
@@ -1,196 +0,0 @@
-from collections import deque
-from functools import partial
-from operator import itemgetter
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, Transform, Var, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
-from .lbfgs import _adaptive_damping, lbfgs
-
-
-@torch.no_grad
-def _store_sk_yk_after_step_hook(optimizer, var: Var, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
-    assert var.closure is not None
-    with torch.enable_grad(): var.closure()
-    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in var.params]
-    s_k = var.params - prev_params
-    y_k = grad - prev_grad
-    ys_k = s_k.dot(y_k)
-
-    if damping:
-        s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-    if ys_k > 1e-10:
-        s_history.append(s_k)
-        y_history.append(y_k)
-        sy_history.append(ys_k)
-
-
-
-class OnlineLBFGS(Module):
-    """Online L-BFGS.
-    Parameter and gradient differences are sampled from the same mini-batch by performing an extra forward and backward pass.
-    However I did a bunch of experiments and the online part doesn't seem to help. Normal L-BFGS is usually still
-    better because it performs twice as many steps, and it is reasonably stable with normalization or grafting.
-
-    Args:
-        history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-        sample_grads (str, optional):
-            - "before" - samples current mini-batch gradient at previous and current parameters, calculates y_k
-              and adds it to history before stepping.
-            - "after" - samples current mini-batch gradient at parameters before stepping and after updating parameters.
-              s_k and y_k are added after parameter update, therefore they are delayed by 1 step.
-
-            In practice both modes behave very similarly. Defaults to 'before'.
-        tol (float | None, optional):
-            tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
-        damping (bool, optional):
-            whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-        init_damping (float, optional):
-            initial damping for adaptive dampening. Defaults to 0.9.
-        eigval_bounds (tuple, optional):
-            eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-        params_beta (float | None, optional):
-            if not None, EMA of parameters is used for preconditioner update. Defaults to None.
-        grads_beta (float | None, optional):
-            if not None, EMA of gradients is used for preconditioner update. Defaults to None.
-        update_freq (int, optional):
-            how often to update L-BFGS history. Defaults to 1.
-        z_beta (float | None, optional):
-            optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
-        inner (Chainable | None, optional):
-            optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
-    """
-    def __init__(
-        self,
-        history_size=10,
-        sample_grads: Literal['before', 'after'] = 'before',
-        tol: float | None = 1e-10,
-        damping: bool = False,
-        init_damping=0.9,
-        eigval_bounds=(0.5, 50),
-        z_beta: float | None = None,
-        inner: Chainable | None = None,
-    ):
-        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, sample_grads=sample_grads, z_beta=z_beta)
-        super().__init__(defaults)
-
-        self.global_state['s_history'] = deque(maxlen=history_size)
-        self.global_state['y_history'] = deque(maxlen=history_size)
-        self.global_state['sy_history'] = deque(maxlen=history_size)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    def reset(self):
-        """Resets the internal state of the L-SR1 module."""
-        # super().reset() # Clears self.state (per-parameter) if any, and "step"
-        # Re-initialize L-SR1 specific global state
-        self.state.clear()
-        self.global_state['step'] = 0
-        self.global_state['s_history'].clear()
-        self.global_state['y_history'].clear()
-        self.global_state['sy_history'].clear()
-
-    @torch.no_grad
-    def step(self, var):
-        assert var.closure is not None
-
-        params = as_tensorlist(var.params)
-        update = as_tensorlist(var.get_update())
-        step = self.global_state.get('step', 0)
-        self.global_state['step'] = step + 1
-
-        # history of s and k
-        s_history: deque[TensorList] = self.global_state['s_history']
-        y_history: deque[TensorList] = self.global_state['y_history']
-        sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-
-        tol, damping, init_damping, eigval_bounds, sample_grads, z_beta = itemgetter(
-            'tol', 'damping', 'init_damping', 'eigval_bounds', 'sample_grads', 'z_beta')(self.settings[params[0]])
-
-        # sample gradient at previous params with current mini-batch
-        if sample_grads == 'before':
-            prev_params = self.get_state(params, 'prev_params', cls=TensorList)
-            if step == 0:
-                s_k = None; y_k = None; ys_k = None
-            else:
-                s_k = params - prev_params
-
-                current_params = params.clone()
-                params.set_(prev_params)
-                with torch.enable_grad(): var.closure()
-                y_k = update - params.grad
-                ys_k = s_k.dot(y_k)
-                params.set_(current_params)
-
-                if damping:
-                    s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-                if ys_k > 1e-10:
-                    s_history.append(s_k)
-                    y_history.append(y_k)
-                    sy_history.append(ys_k)
-
-            prev_params.copy_(params)
-
-        # use previous s_k, y_k pair, samples gradient at current batch before and after updating parameters
-        elif sample_grads == 'after':
-            if len(s_history) == 0:
-                s_k = None; y_k = None; ys_k = None
-            else:
-                s_k = s_history[-1]
-                y_k = y_history[-1]
-                ys_k = s_k.dot(y_k)
-
-            # this will run after params are updated by Modular after running all future modules
-            var.post_step_hooks.append(
-                partial(
-                    _store_sk_yk_after_step_hook,
-                    prev_params=params.clone(),
-                    prev_grad=update.clone(),
-                    damping=damping,
-                    init_damping=init_damping,
-                    eigval_bounds=eigval_bounds,
-                    s_history=s_history,
-                    y_history=y_history,
-                    sy_history=sy_history,
-            ))
-
-        else:
-            raise ValueError(sample_grads)
-
-        # step with inner module before applying preconditioner
-        if self.children:
-            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
-
-        # tolerance on gradient difference to avoid exploding after converging
-        if tol is not None:
-            if y_k is not None and y_k.abs().global_max() <= tol:
-                var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                return var
-
-        # lerp initial H^-1 @ q guess
-        z_ema = None
-        if z_beta is not None:
-            z_ema = self.get_state(params, 'z_ema', cls=TensorList)
-
-        # precondition
-        dir = lbfgs(
-            tensors_=as_tensorlist(update),
-            s_history=s_history,
-            y_history=y_history,
-            sy_history=sy_history,
-            y_k=y_k,
-            ys_k=ys_k,
-            z_beta = z_beta,
-            z_ema = z_ema,
-            step=step
-        )
-
-        var.update = dir
-
-        return var
-